use std::sync::Mutex;
use windows::Win32::Security::Authentication::Identity::{
AcquireCredentialsHandleW, DeleteSecurityContext, FreeCredentialsHandle, ISC_REQ_CONNECTION,
ISC_REQ_MUTUAL_AUTH, ISC_REQ_REPLAY_DETECT, ISC_REQ_SEQUENCE_DETECT,
InitializeSecurityContextW, SECPKG_CRED_OUTBOUND, SecBuffer, SecBufferDesc,
};
use windows::Win32::Security::Credentials::SecHandle;
use windows::core::{HRESULT, PCWSTR};
use crate::error::AuthError;
// Value written to `SecBuffer::BufferType` for every input/output token buffer
// built in this file.
const SECBUFFER_TOKEN: u32 = 2;
// Value written to `SecBufferDesc::ulVersion` for every buffer descriptor
// built in this file.
const SECBUFFER_VERSION: u32 = 0;
// Passed unchanged to each `InitializeSecurityContextW` call as the
// data-representation argument.
const SECURITY_NATIVE_DREP: u32 = 0x10;
// Size in bytes of the scratch buffer that receives each outgoing token.
// NOTE(review): this is a fixed 16 KiB; confirm it is large enough for the
// maximum token the negotiated package can emit (see `QuerySecurityPackageInfo`).
const MAX_TOKEN_SIZE: usize = 16_384;
// Status compared against the result of `InitializeSecurityContextW`; treated
// here as "handshake finished".
const SEC_E_OK: HRESULT = HRESULT(0);
// Status compared against the result of `InitializeSecurityContextW`; treated
// here as "token produced, another round trip is expected". The `u32 as i32`
// cast converts the literal to the `i32` that `HRESULT` wraps.
const SEC_I_CONTINUE_NEEDED: HRESULT = HRESULT(0x0009_0312_u32 as i32);
/// Client-side SSPI authenticator backed by the Windows security API.
///
/// Holds the target service principal name together with the mutex-guarded
/// credential/context handle state used by `initialize` and `step`.
pub struct NativeSspiAuth {
/// Target service principal name, built in `new` as
/// `MSSQLSvc/{hostname}:{port}`.
spn: String,
/// Handle state; every method locks this mutex before touching the handles.
context: Mutex<NativeSspiContext>,
}
/// Raw SSPI handles plus the flags that track how far the handshake has got.
///
/// The handles are released by this type's `Drop` implementation.
struct NativeSspiContext {
/// Credentials handle filled in by `AcquireCredentialsHandleW`.
cred_handle: SecHandle,
/// Security context handle written by `InitializeSecurityContextW`.
/// Only meaningful while `has_context` is `true`.
ctx_handle: SecHandle,
/// Set once `InitializeSecurityContextW` has returned `SEC_E_OK` or
/// `SEC_I_CONTINUE_NEEDED`; gates the `DeleteSecurityContext` call on drop.
has_context: bool,
/// Set once `InitializeSecurityContextW` has returned `SEC_E_OK`.
complete: bool,
}
// Declares the handle state movable across threads so it can live inside the
// `Mutex` held by `NativeSspiAuth`.
// SAFETY: NOTE(review) - this relies on SSPI credential/context handles not
// being tied to the thread that created them; confirm against the SSPI
// documentation. Within this file the handles are only touched while the
// owning `Mutex` is locked, or from `Drop`.
unsafe impl Send for NativeSspiContext {}
impl NativeSspiAuth {
/// Creates an authenticator for the SQL Server instance at `hostname:port`.
///
/// The target SPN is formatted as `MSSQLSvc/{hostname}:{port}` and an
/// outbound credentials handle is acquired for the `Negotiate` package.
/// No principal, logon id or explicit auth data is supplied to
/// `AcquireCredentialsHandleW`.
///
/// # Errors
///
/// Returns `AuthError::Sspi` when `AcquireCredentialsHandleW` fails.
pub fn new(hostname: &str, port: u16) -> Result<Self, AuthError> {
    let spn = format!("MSSQLSvc/{hostname}:{port}");

    // NUL-terminated UTF-16 package name; must stay alive for the call below.
    let package: Vec<u16> = "Negotiate\0".encode_utf16().collect();
    let mut cred_handle = SecHandle::default();
    let mut expiry: i64 = 0;

    // SAFETY: `package` is NUL-terminated and, like `cred_handle` and
    // `expiry`, outlives the call; the out-pointers refer to initialized locals.
    let acquired = unsafe {
        AcquireCredentialsHandleW(
            None,
            PCWSTR(package.as_ptr()),
            SECPKG_CRED_OUTBOUND,
            None,
            None,
            None,
            None,
            &mut cred_handle,
            Some(&mut expiry),
        )
    };
    acquired
        .map_err(|e| AuthError::Sspi(format!("Failed to acquire Windows credentials: {e}")))?;

    tracing::debug!(
        spn = %spn,
        "Acquired native Windows SSPI credentials for current user"
    );

    // No security context exists until `initialize` has run.
    let state = NativeSspiContext {
        cred_handle,
        ctx_handle: SecHandle::default(),
        has_context: false,
        complete: false,
    };
    Ok(Self {
        spn,
        context: Mutex::new(state),
    })
}
pub fn initialize(&self) -> Result<Vec<u8>, AuthError> {
let mut ctx = self
.context
.lock()
.map_err(|_| AuthError::Sspi("Failed to acquire context lock".into()))?;
let spn_wide: Vec<u16> = format!("{}\0", self.spn).encode_utf16().collect();
let mut out_buf = vec![0u8; MAX_TOKEN_SIZE];
let mut out_sec_buf = SecBuffer {
cbBuffer: out_buf.len() as u32,
BufferType: SECBUFFER_TOKEN,
pvBuffer: out_buf.as_mut_ptr().cast(),
};
let mut out_buf_desc = SecBufferDesc {
ulVersion: SECBUFFER_VERSION,
cBuffers: 1,
pBuffers: &mut out_sec_buf,
};
let mut context_attrs: u32 = 0;
let mut expiry: i64 = 0;
let context_req = ISC_REQ_MUTUAL_AUTH
| ISC_REQ_REPLAY_DETECT
| ISC_REQ_SEQUENCE_DETECT
| ISC_REQ_CONNECTION;
let hr = unsafe {
InitializeSecurityContextW(
Some(&ctx.cred_handle), None, Some(PCWSTR(spn_wide.as_ptr()).as_ptr()), context_req, 0, SECURITY_NATIVE_DREP, None, 0, Some(&mut ctx.ctx_handle), Some(&mut out_buf_desc), &mut context_attrs, Some(&mut expiry), )
};
if hr == SEC_E_OK || hr == SEC_I_CONTINUE_NEEDED {
ctx.has_context = true;
if hr == SEC_E_OK {
ctx.complete = true;
}
let token_len = out_sec_buf.cbBuffer as usize;
let token = out_buf[..token_len].to_vec();
tracing::debug!(
token_len,
continue_needed = (hr == SEC_I_CONTINUE_NEEDED),
"SSPI initialization produced token"
);
Ok(token)
} else {
Err(AuthError::Sspi(format!(
"InitializeSecurityContext failed: HRESULT 0x{:08X}",
hr.0 as u32
)))
}
}
/// Feeds a server token into the handshake and returns the next client token.
///
/// * `Ok(Some(token))` - a token was produced and must be sent to the server.
/// * `Ok(None)` - the handshake was already complete before this call, or
///   this call completed it without producing any output.
///
/// # Errors
///
/// Returns `AuthError::Sspi` when the state mutex is poisoned, when
/// `initialize` has not yet created a context, or when
/// `InitializeSecurityContextW` returns anything other than `SEC_E_OK` or
/// `SEC_I_CONTINUE_NEEDED`.
pub fn step(&self, server_token: &[u8]) -> Result<Option<Vec<u8>>, AuthError> {
    let mut state = match self.context.lock() {
        Ok(guard) => guard,
        Err(_) => return Err(AuthError::Sspi("Failed to acquire context lock".into())),
    };

    // Nothing left to negotiate once the context is established.
    if state.complete {
        return Ok(None);
    }
    if !state.has_context {
        return Err(AuthError::Sspi(
            "Context not initialized - call initialize() first".into(),
        ));
    }

    // NUL-terminated UTF-16 copy of the SPN for the wide-character API.
    let target_name: Vec<u16> = self
        .spn
        .encode_utf16()
        .chain(std::iter::once(0))
        .collect();

    // The buffer descriptor needs a mutable pointer, so work on a private
    // copy of the caller's token.
    let mut incoming = server_token.to_vec();
    let mut incoming_sec_buf = SecBuffer {
        cbBuffer: incoming.len() as u32,
        BufferType: SECBUFFER_TOKEN,
        pvBuffer: incoming.as_mut_ptr().cast(),
    };
    let incoming_desc = SecBufferDesc {
        ulVersion: SECBUFFER_VERSION,
        cBuffers: 1,
        pBuffers: &mut incoming_sec_buf,
    };

    // Scratch buffer that receives the outgoing token.
    let mut outgoing = vec![0u8; MAX_TOKEN_SIZE];
    let mut outgoing_sec_buf = SecBuffer {
        cbBuffer: outgoing.len() as u32,
        BufferType: SECBUFFER_TOKEN,
        pvBuffer: outgoing.as_mut_ptr().cast(),
    };
    let mut outgoing_desc = SecBufferDesc {
        ulVersion: SECBUFFER_VERSION,
        cBuffers: 1,
        pBuffers: &mut outgoing_sec_buf,
    };

    let mut granted_attrs: u32 = 0;
    let mut lifetime: i64 = 0;
    let requested = ISC_REQ_MUTUAL_AUTH
        | ISC_REQ_REPLAY_DETECT
        | ISC_REQ_SEQUENCE_DETECT
        | ISC_REQ_CONNECTION;

    // SAFETY: every pointer refers to a local (or to state behind the held
    // lock) that outlives the call; the existing context handle is passed as
    // both the input context and the slot for the updated context.
    let status = unsafe {
        InitializeSecurityContextW(
            Some(&state.cred_handle),
            Some(&state.ctx_handle),
            Some(PCWSTR(target_name.as_ptr()).as_ptr()),
            requested,
            0,
            SECURITY_NATIVE_DREP,
            Some(&incoming_desc),
            0,
            Some(&mut state.ctx_handle),
            Some(&mut outgoing_desc),
            &mut granted_attrs,
            Some(&mut lifetime),
        )
    };

    if status == SEC_E_OK {
        state.complete = true;
        let token_len = outgoing_sec_buf.cbBuffer as usize;
        if token_len == 0 {
            tracing::debug!("SSPI step complete, no final token");
            return Ok(None);
        }
        let token = outgoing[..token_len].to_vec();
        tracing::debug!(token_len, "SSPI step complete with final token");
        return Ok(Some(token));
    }

    if status == SEC_I_CONTINUE_NEEDED {
        let token_len = outgoing_sec_buf.cbBuffer as usize;
        let token = outgoing[..token_len].to_vec();
        tracing::debug!(token_len, "SSPI step needs continuation");
        return Ok(Some(token));
    }

    Err(AuthError::Sspi(format!(
        "SSPI step failed: HRESULT 0x{:08X}",
        status.0 as u32
    )))
}
/// Reports whether the handshake has finished.
///
/// Returns `false` when the state mutex is poisoned.
pub fn is_complete(&self) -> bool {
    match self.context.lock() {
        Ok(state) => state.complete,
        Err(_) => false,
    }
}
/// Returns the service principal name this authenticator targets.
#[must_use]
pub fn spn(&self) -> &str {
    self.spn.as_str()
}
}
/// Debug output shows the SPN and the completion flag only; the raw handles
/// are left out.
impl std::fmt::Debug for NativeSspiAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let complete = self.is_complete();
        let mut out = f.debug_struct("NativeSspiAuth");
        out.field("spn", &self.spn);
        out.field("complete", &complete);
        out.finish()
    }
}
/// Releases the SSPI handles. Failures are ignored because nothing useful can
/// be done about them during drop.
impl Drop for NativeSspiContext {
    fn drop(&mut self) {
        if self.has_context {
            // SAFETY: `has_context` is only set after `InitializeSecurityContextW`
            // reported success and wrote a handle into `ctx_handle`.
            let _ = unsafe { DeleteSecurityContext(&self.ctx_handle) };
        }
        // SAFETY: the only construction site visible in this file stores a
        // handle that `AcquireCredentialsHandleW` filled in successfully.
        let _ = unsafe { FreeCredentialsHandle(&self.cred_handle) };
    }
}
/// Exposes the inherent handshake methods through the crate's negotiator
/// trait. Each method forwards to the inherent method of the same name;
/// inherent associated functions take priority over trait methods in path
/// resolution, so these calls do not recurse.
impl crate::negotiator::SspiNegotiator for NativeSspiAuth {
    /// Forwards to the inherent `initialize`.
    fn initialize(&self) -> Result<Vec<u8>, AuthError> {
        Self::initialize(self)
    }

    /// Forwards to the inherent `step`.
    fn step(&self, server_token: &[u8]) -> Result<Option<Vec<u8>>, AuthError> {
        Self::step(self, server_token)
    }

    /// Forwards to the inherent `is_complete`.
    fn is_complete(&self) -> bool {
        Self::is_complete(self)
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    // Every test constructs a real authenticator, so each one calls
    // `AcquireCredentialsHandleW`.

    /// The SPN is built from the host name and port.
    #[test]
    fn test_spn_format() {
        let negotiator = NativeSspiAuth::new("sqlserver.example.com", 1433).unwrap();
        assert_eq!(negotiator.spn(), "MSSQLSvc/sqlserver.example.com:1433");
    }

    /// Debug output names the type and includes the SPN.
    #[test]
    fn test_debug_output() {
        let negotiator = NativeSspiAuth::new("test.example.com", 1433).unwrap();
        let rendered = format!("{negotiator:?}");
        assert!(rendered.contains("NativeSspiAuth"));
        assert!(rendered.contains("test.example.com"));
    }

    /// A freshly created authenticator has not completed the handshake.
    #[test]
    fn test_is_complete_initially_false() {
        let negotiator = NativeSspiAuth::new("test.example.com", 1433).unwrap();
        assert!(!negotiator.is_complete());
    }

    /// The first handshake call yields a non-empty token with the expected
    /// leading byte.
    #[test]
    fn test_initialize_produces_token() {
        let negotiator = NativeSspiAuth::new("localhost", 1433).unwrap();
        let first_token = negotiator.initialize().unwrap();
        assert!(
            !first_token.is_empty(),
            "Initial SSPI token should not be empty"
        );
        assert_eq!(
            first_token[0], 0x60,
            "Token should start with SPNEGO APPLICATION tag"
        );
    }
}