use std::sync::Mutex;
use sspi::{
AuthIdentity, BufferType, ClientRequestFlags, CredentialUse, Credentials, CredentialsBuffers,
DataRepresentation, Negotiate, NegotiateConfig, SecurityBuffer, SecurityStatus, Sspi, SspiImpl,
Username, ntlm::NtlmConfig,
};
use crate::error::AuthError;
use crate::provider::{AuthData, AuthMethod, AuthProvider};
pub struct SspiAuth {
spn: String,
credentials: Option<(String, String)>,
context: Mutex<SspiContext>,
}
struct SspiContext {
negotiate: Negotiate,
creds_handle: Option<CredentialsBuffers>,
complete: bool,
}
fn create_negotiate_config() -> NegotiateConfig {
NegotiateConfig::new(
Box::new(NtlmConfig::default()),
Some("kerberos,ntlm".to_string()),
String::new(),
)
}
impl SspiAuth {
pub fn new(hostname: &str, port: u16) -> Result<Self, AuthError> {
let spn = format!("MSSQLSvc/{hostname}:{port}");
let negotiate = Negotiate::new_client(create_negotiate_config())
.map_err(|e| AuthError::Sspi(format!("Failed to create Negotiate context: {e}")))?;
Ok(Self {
spn,
credentials: None,
context: Mutex::new(SspiContext {
negotiate,
creds_handle: None,
complete: false,
}),
})
}
pub fn with_credentials(
hostname: &str,
port: u16,
username: impl Into<String>,
password: impl Into<String>,
) -> Result<Self, AuthError> {
let spn = format!("MSSQLSvc/{hostname}:{port}");
let negotiate = Negotiate::new_client(create_negotiate_config())
.map_err(|e| AuthError::Sspi(format!("Failed to create Negotiate context: {e}")))?;
Ok(Self {
spn,
credentials: Some((username.into(), password.into())),
context: Mutex::new(SspiContext {
negotiate,
creds_handle: None,
complete: false,
}),
})
}
pub fn with_spn(spn: impl Into<String>) -> Result<Self, AuthError> {
let negotiate = Negotiate::new_client(create_negotiate_config())
.map_err(|e| AuthError::Sspi(format!("Failed to create Negotiate context: {e}")))?;
Ok(Self {
spn: spn.into(),
credentials: None,
context: Mutex::new(SspiContext {
negotiate,
creds_handle: None,
complete: false,
}),
})
}
pub fn initialize(&self) -> Result<Vec<u8>, AuthError> {
let mut ctx = self
.context
.lock()
.map_err(|_| AuthError::Sspi("Failed to acquire context lock".into()))?;
let credentials = if let Some((ref username, ref password)) = self.credentials {
let parsed_user = Username::parse(username)
.map_err(|e| AuthError::Sspi(format!("Invalid username format: {e}")))?;
let identity = AuthIdentity {
username: parsed_user,
password: password.clone().into(),
};
Some(Credentials::from(identity))
} else {
None
};
let creds_result = {
let mut builder = ctx
.negotiate
.acquire_credentials_handle()
.with_credential_use(CredentialUse::Outbound);
if let Some(ref creds) = credentials {
builder = builder.with_auth_data(creds);
}
builder
.execute(&mut ctx.negotiate)
.map_err(|e| AuthError::Sspi(format!("Failed to acquire credentials: {e}")))?
};
ctx.creds_handle = creds_result.credentials_handle;
let mut creds = ctx.creds_handle.take();
let mut output_buffer = vec![SecurityBuffer::new(Vec::new(), BufferType::Token)];
let spn = self.spn.clone();
let mut builder = ctx
.negotiate
.initialize_security_context()
.with_credentials_handle(&mut creds)
.with_context_requirements(
ClientRequestFlags::MUTUAL_AUTH
| ClientRequestFlags::REPLAY_DETECT
| ClientRequestFlags::SEQUENCE_DETECT,
)
.with_target_data_representation(DataRepresentation::Native)
.with_target_name(&spn)
.with_output(&mut output_buffer);
let init_result = ctx
.negotiate
.initialize_security_context_impl(&mut builder)
.map_err(|e| AuthError::Sspi(format!("Failed to initialize context: {e}")))?
.resolve_to_result()
.map_err(|e| AuthError::Sspi(format!("Failed to resolve context: {e}")))?;
ctx.creds_handle = creds;
match init_result.status {
SecurityStatus::Ok | SecurityStatus::ContinueNeeded => {
if init_result.status == SecurityStatus::Ok {
ctx.complete = true;
}
let token = output_buffer
.into_iter()
.find(|b| b.buffer_type.buffer_type == BufferType::Token)
.map(|b| b.buffer)
.unwrap_or_default();
Ok(token)
}
status => Err(AuthError::Sspi(format!(
"Unexpected status during initialization: {status:?}"
))),
}
}
pub fn step(&self, server_token: &[u8]) -> Result<Option<Vec<u8>>, AuthError> {
let mut ctx = self
.context
.lock()
.map_err(|_| AuthError::Sspi("Failed to acquire context lock".into()))?;
if ctx.complete {
return Ok(None);
}
if ctx.creds_handle.is_none() {
return Err(AuthError::Sspi(
"Context not initialized - call initialize() first".into(),
));
}
let mut input_buffer = vec![SecurityBuffer::new(
server_token.to_vec(),
BufferType::Token,
)];
let mut output_buffer = vec![SecurityBuffer::new(Vec::new(), BufferType::Token)];
let spn = self.spn.clone();
let mut creds = ctx.creds_handle.take();
let mut builder = ctx
.negotiate
.initialize_security_context()
.with_credentials_handle(&mut creds)
.with_context_requirements(
ClientRequestFlags::MUTUAL_AUTH
| ClientRequestFlags::REPLAY_DETECT
| ClientRequestFlags::SEQUENCE_DETECT,
)
.with_target_data_representation(DataRepresentation::Native)
.with_target_name(&spn)
.with_input(&mut input_buffer)
.with_output(&mut output_buffer);
let result = ctx
.negotiate
.initialize_security_context_impl(&mut builder)
.map_err(|e| AuthError::Sspi(format!("SSPI step failed: {e}")))?
.resolve_to_result()
.map_err(|e| AuthError::Sspi(format!("Failed to resolve step result: {e}")))?;
ctx.creds_handle = creds;
match result.status {
SecurityStatus::Ok => {
ctx.complete = true;
let token = output_buffer
.into_iter()
.find(|b| {
b.buffer_type.buffer_type == BufferType::Token && !b.buffer.is_empty()
})
.map(|b| b.buffer);
Ok(token)
}
SecurityStatus::ContinueNeeded => {
let token = output_buffer
.into_iter()
.find(|b| b.buffer_type.buffer_type == BufferType::Token)
.map(|b| b.buffer)
.unwrap_or_default();
Ok(Some(token))
}
status => Err(AuthError::Sspi(format!(
"Unexpected status during step: {status:?}"
))),
}
}
pub fn is_complete(&self) -> bool {
self.context.lock().map(|ctx| ctx.complete).unwrap_or(false)
}
#[must_use]
pub fn spn(&self) -> &str {
&self.spn
}
}
/// Debug output shows the SPN, whether explicit credentials were supplied, and
/// the completion flag. The stored username and password are never printed.
impl std::fmt::Debug for SspiAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_explicit_credentials = self.credentials.is_some();
        let complete = self.is_complete();

        let mut out = f.debug_struct("SspiAuth");
        out.field("spn", &self.spn);
        out.field("has_explicit_credentials", &has_explicit_credentials);
        out.field("complete", &complete);
        out.finish()
    }
}
/// Exposes the handshake through the crate's `SspiNegotiator` trait. Each
/// method forwards to the inherent method of the same name.
impl crate::negotiator::SspiNegotiator for SspiAuth {
    /// Forwards to the inherent `SspiAuth::initialize`.
    fn initialize(&self) -> Result<Vec<u8>, AuthError> {
        Self::initialize(self)
    }

    /// Forwards to the inherent `SspiAuth::step`.
    fn step(&self, server_token: &[u8]) -> Result<Option<Vec<u8>>, AuthError> {
        Self::step(self, server_token)
    }

    /// Forwards to the inherent `SspiAuth::is_complete`.
    fn is_complete(&self) -> bool {
        Self::is_complete(self)
    }
}
impl AuthProvider for SspiAuth {
    /// SSPI is reported as integrated authentication.
    fn method(&self) -> AuthMethod {
        AuthMethod::Integrated
    }

    /// Runs `SspiAuth::initialize` and wraps the first client token as
    /// `AuthData::Sspi`.
    fn authenticate(&self) -> Result<AuthData, AuthError> {
        self.initialize().map(|blob| AuthData::Sspi { blob })
    }
}
#[cfg(test)]
// Tests may panic on setup failures; unwrap/expect keep them concise.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    #[test]
    fn test_spn_format() {
        let auth = SspiAuth::new("sqlserver.example.com", 1433).unwrap();
        assert_eq!(auth.spn(), "MSSQLSvc/sqlserver.example.com:1433");
    }

    #[test]
    fn test_custom_spn() {
        let auth = SspiAuth::with_spn("MSSQLSvc/cluster.example.com:1433").unwrap();
        assert_eq!(auth.spn(), "MSSQLSvc/cluster.example.com:1433");
    }

    #[test]
    fn test_with_credentials_spn_format() {
        let auth = SspiAuth::with_credentials("db.example.com", 1434, "user", "pw").unwrap();
        assert_eq!(auth.spn(), "MSSQLSvc/db.example.com:1434");
    }

    #[test]
    fn test_debug_output() {
        let auth = SspiAuth::new("test.example.com", 1433).unwrap();
        let debug = format!("{auth:?}");
        assert!(debug.contains("SspiAuth"));
        assert!(debug.contains("test.example.com"));
    }

    #[test]
    fn test_is_complete_initially_false() {
        let auth = SspiAuth::new("test.example.com", 1433).unwrap();
        assert!(!auth.is_complete());
    }

    #[test]
    fn test_with_credentials() {
        let auth = SspiAuth::with_credentials("test.example.com", 1433, "DOMAIN\\user", "password")
            .unwrap();
        let debug = format!("{auth:?}");
        assert!(debug.contains("has_explicit_credentials: true"));
    }

    #[test]
    fn test_debug_output_redacts_credentials() {
        let auth =
            SspiAuth::with_credentials("test.example.com", 1433, "DOMAIN\\user", "s3cr3t-pw")
                .unwrap();
        let debug = format!("{auth:?}");
        assert!(!debug.contains("s3cr3t-pw"));
        assert!(!debug.contains("DOMAIN"));
    }

    #[test]
    fn test_step_before_initialize_is_rejected() {
        let auth = SspiAuth::new("test.example.com", 1433).unwrap();
        assert!(matches!(auth.step(b"token"), Err(AuthError::Sspi(_))));
        assert!(!auth.is_complete());
    }
}