use std::sync::Mutex;
use libgssapi::{
context::{ClientCtx, CtxFlags},
credential::{Cred, CredUsage},
name::Name,
oid::{GSS_MECH_KRB5, GSS_NT_HOSTBASED_SERVICE, OidSet},
};
use crate::error::AuthError;
use crate::provider::{AuthData, AuthMethod, AuthProvider};
/// Encoded body of the SPNEGO mechanism OID, 1.3.6.1.5.5.2 (RFC 4178):
/// `0x2b` packs the first two arcs (1 * 40 + 3), followed by 6, 1, 5, 5, 2.
const SPNEGO_OID_BYTES: &[u8] = &[0x2b, 0x06, 0x01, 0x05, 0x05, 0x02];

/// GSSAPI mechanism OID for SPNEGO, used both when acquiring credentials and
/// as the requested mechanism of the client security context.
const GSS_MECH_SPNEGO: libgssapi::oid::Oid = libgssapi::oid::Oid::from_slice(SPNEGO_OID_BYTES);
/// GSSAPI-based integrated authentication for SQL Server (target SPNs of the
/// form `MSSQLSvc/<host>:<port>`), negotiating via SPNEGO/Kerberos.
pub struct IntegratedAuth {
    // Target service principal name handed to GSSAPI as the context's target.
    spn: String,
    // Client security context; `None` until `initialize` stores one.
    context: Mutex<Option<ClientCtx>>,
    // Whether `step` has observed the handshake finishing.
    complete: Mutex<bool>,
}
impl IntegratedAuth {
#[must_use]
pub fn new(hostname: &str, port: u16) -> Self {
let spn = format!("MSSQLSvc/{hostname}:{port}");
Self {
spn,
context: Mutex::new(None),
complete: Mutex::new(false),
}
}
#[must_use]
pub fn with_spn(spn: impl Into<String>) -> Self {
Self {
spn: spn.into(),
context: Mutex::new(None),
complete: Mutex::new(false),
}
}
fn create_service_name(&self) -> Result<Name, AuthError> {
Name::new(self.spn.as_bytes(), Some(&GSS_NT_HOSTBASED_SERVICE))
.map_err(|e| AuthError::Sspi(format!("Failed to create service name: {e}")))
}
pub fn initialize(&self) -> Result<Vec<u8>, AuthError> {
let service_name = self.create_service_name()?;
let mut mechs =
OidSet::new().map_err(|e| AuthError::Sspi(format!("Failed to create OID set: {e}")))?;
mechs
.add(&GSS_MECH_SPNEGO)
.map_err(|e| AuthError::Sspi(format!("Failed to add SPNEGO mechanism: {e}")))?;
mechs
.add(&GSS_MECH_KRB5)
.map_err(|e| AuthError::Sspi(format!("Failed to add Kerberos mechanism: {e}")))?;
let cred = Cred::acquire(None, None, CredUsage::Initiate, Some(&mechs))
.map_err(|e| AuthError::Sspi(format!("Failed to acquire credentials: {e}")))?;
let mut ctx = ClientCtx::new(
Some(cred),
service_name,
CtxFlags::GSS_C_MUTUAL_FLAG | CtxFlags::GSS_C_REPLAY_FLAG,
Some(&GSS_MECH_SPNEGO),
);
let token = ctx
.step(None, None)
.map_err(|e| AuthError::Sspi(format!("Failed to initialize context: {e}")))?
.ok_or_else(|| {
AuthError::Sspi("No initial token generated (context already complete?)".into())
})?;
let mut context_guard = self
.context
.lock()
.map_err(|_| AuthError::Sspi("Failed to acquire context lock".into()))?;
*context_guard = Some(ctx);
Ok(token.to_vec())
}
pub fn step(&self, server_token: &[u8]) -> Result<Option<Vec<u8>>, AuthError> {
let mut context_guard = self
.context
.lock()
.map_err(|_| AuthError::Sspi("Failed to acquire context lock".into()))?;
let ctx = context_guard.as_mut().ok_or_else(|| {
AuthError::Sspi("Context not initialized - call initialize() first".into())
})?;
match ctx.step(Some(server_token), None) {
Ok(Some(token)) => Ok(Some(token.to_vec())),
Ok(None) => {
let mut complete_guard = self
.complete
.lock()
.map_err(|_| AuthError::Sspi("Failed to acquire complete lock".into()))?;
*complete_guard = true;
Ok(None)
}
Err(e) => Err(AuthError::Sspi(format!("GSSAPI step failed: {e}"))),
}
}
pub fn is_complete(&self) -> bool {
self.complete.lock().map(|guard| *guard).unwrap_or(false)
}
pub fn negotiated_mechanism(&self) -> Option<String> {
self.context.lock().ok().and_then(|guard| {
guard.as_ref().map(|_ctx| {
"SPNEGO/Kerberos".to_string()
})
})
}
}
/// Debug output shows only the completion state; the security context and
/// the SPN are not rendered.
impl std::fmt::Debug for IntegratedAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let complete = self.is_complete();
        let mut builder = f.debug_struct("IntegratedAuth");
        builder.field("complete", &complete);
        builder.finish_non_exhaustive()
    }
}
/// Exposes the GSSAPI handshake through the crate's negotiator trait.
///
/// Each method forwards to the inherent method of the same name. The
/// `Self::name(self)` path form resolves to the inherent item first, so these
/// calls cannot recurse into the trait methods.
impl crate::negotiator::SspiNegotiator for IntegratedAuth {
    fn initialize(&self) -> Result<Vec<u8>, AuthError> {
        Self::initialize(self)
    }

    fn step(&self, server_token: &[u8]) -> Result<Option<Vec<u8>>, AuthError> {
        Self::step(self, server_token)
    }

    fn is_complete(&self) -> bool {
        Self::is_complete(self)
    }
}
/// Integrated authentication as an [`AuthProvider`]: `authenticate` produces
/// the initial SSPI blob by starting a fresh security context.
impl AuthProvider for IntegratedAuth {
    fn method(&self) -> AuthMethod {
        AuthMethod::Integrated
    }

    fn authenticate(&self) -> Result<AuthData, AuthError> {
        self.initialize().map(|blob| AuthData::Sspi { blob })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` derives the SPN as `MSSQLSvc/<host>:<port>`.
    #[test]
    fn test_service_name_format() {
        let expected = "MSSQLSvc/sqlserver.example.com:1433";
        let auth = IntegratedAuth::new("sqlserver.example.com", 1433);
        assert_eq!(auth.spn, expected);
    }

    /// `with_spn` stores the supplied SPN verbatim.
    #[test]
    fn test_custom_spn() {
        let spn = "MSSQLSvc/cluster.example.com:1433";
        let auth = IntegratedAuth::with_spn(spn);
        assert_eq!(auth.spn, spn);
    }

    /// The `Debug` rendering names the type.
    #[test]
    fn test_debug_output() {
        let rendered = format!("{:?}", IntegratedAuth::new("test.example.com", 1433));
        assert!(rendered.contains("IntegratedAuth"));
    }

    /// A freshly constructed authenticator has not completed a handshake.
    #[test]
    fn test_is_complete_initially_false() {
        let fresh = IntegratedAuth::new("test.example.com", 1433);
        assert!(!fresh.is_complete());
    }
}