use kenobi_core::cred::usage::OutboundUsable;
use kenobi_core::flags::CapabilityFlags;
use kenobi_core::typestate::{
Encryption, MaybeDelegation, MaybeEncryption, MaybeSigning, NoDelegation, NoEncryption, NoSigning, Signing,
};
use std::sync::Arc;
use std::{ffi::c_void, marker::PhantomData};
use windows::Win32::Security::Authentication::Identity::{
ISC_REQ_CONFIDENTIALITY, ISC_REQ_DELEGATE, ISC_REQ_INTEGRITY, ISC_REQ_NO_INTEGRITY,
};
use windows::Win32::{
Foundation::{
SEC_E_INTERNAL_ERROR, SEC_E_INVALID_HANDLE, SEC_E_INVALID_TOKEN, SEC_E_LOGON_DENIED,
SEC_E_NO_AUTHENTICATING_AUTHORITY, SEC_E_NO_CREDENTIALS, SEC_E_OK, SEC_E_TARGET_UNKNOWN,
SEC_E_UNSUPPORTED_FUNCTION, SEC_E_WRONG_PRINCIPAL, SEC_I_COMPLETE_AND_CONTINUE, SEC_I_COMPLETE_NEEDED,
SEC_I_CONTINUE_NEEDED,
},
Security::Authentication::Identity::{
ISC_REQ_FLAGS, ISC_REQ_MUTUAL_AUTH, ISC_RET_MUTUAL_AUTH, InitializeSecurityContextW, QueryContextAttributesW,
SEC_CHANNEL_BINDINGS, SECBUFFER_CHANNEL_BINDINGS, SECBUFFER_TOKEN, SECBUFFER_VERSION, SECPKG_ATTR_SESSION_KEY,
SECURITY_NATIVE_DREP, SecBuffer, SecBufferDesc, SecPkgContext_SessionKey,
},
};
mod builder;
mod error;
mod typestate;
use crate::sign_encrypt::WrapError;
use crate::{
buffer::NonResizableVec,
context::{ContextHandle, SessionKey},
cred::Credentials,
sign_encrypt::{Altered, Encrypted, Plaintext, Signature},
};
pub use builder::ClientBuilder;
pub use error::InitializeContextError;
pub use typestate::{DelegationPolicy, EncryptionPolicy, SigningPolicy};
/// A client-side security context produced by a finished `InitializeSecurityContextW` exchange.
///
/// The `S`, `E` and `D` type parameters record at the type level whether signing, encryption
/// and delegation are known to be available (`Signing`/`NoSigning`/`MaybeSigning`, and so on).
pub struct ClientContext<Usage, S = NoSigning, E = NoEncryption, D = NoDelegation> {
    /// `ISC_RET_*` attribute bits reported by the last `InitializeSecurityContextW` call.
    attributes: u32,
    /// The underlying SSPI context handle.
    ///
    /// Declared before `cred` on purpose: struct fields are dropped in declaration order, so
    /// the context is released (presumably via `DeleteSecurityContext` in `ContextHandle`'s
    /// destructor) while the credentials it was created from are still alive.
    context: ContextHandle,
    /// Credentials the context was initialized with, kept alive for the context's lifetime.
    cred: Arc<Credentials<Usage>>,
    /// Output token written by the step that completed the handshake (may be empty).
    token_buffer: NonResizableVec,
    _enc: PhantomData<(S, E, D)>,
}
impl<Usage, S, E, D> ClientContext<Usage, S, E, D> {
pub fn is_mutually_authenticated(&self) -> bool {
self.attributes & ISC_RET_MUTUAL_AUTH != 0
}
pub fn attributes(&self) -> u32 {
self.attributes
}
pub fn last_token(&self) -> Option<&[u8]> {
(!self.token_buffer.is_empty()).then_some(self.token_buffer.as_slice())
}
pub fn get_session_key(&self) -> windows_result::Result<SessionKey> {
let mut key = SecPkgContext_SessionKey::default();
unsafe {
QueryContextAttributesW(
self.context.as_ptr(),
SECPKG_ATTR_SESSION_KEY,
std::ptr::from_mut(&mut key) as *mut c_void,
)
}?;
unsafe { Ok(SessionKey::new(key)) }
}
}
impl<Usage, E, D> ClientContext<Usage, Signing, E, D> {
pub fn sign(&self, message: &[u8]) -> Result<Signature, WrapError> {
self.context.wrap_sign(message).map_err(WrapError)
}
pub fn unwrap(&self, message: &[u8]) -> Result<Plaintext, Altered> {
self.context.unwrap(message)
}
}
impl<Usage, D> ClientContext<Usage, Signing, Encryption, D> {
pub fn encrypt(&self, message: &[u8]) -> Result<Encrypted, WrapError> {
self.context.wrap_encrypt(message).map_err(WrapError)
}
}
impl<Usage: OutboundUsable> ClientContext<Usage, NoSigning, NoEncryption, NoDelegation> {
    /// Begins a client-side handshake from already acquired outbound credentials.
    ///
    /// Shorthand for `ClientBuilder::new_from_credentials` followed immediately by
    /// `initialize`, with no extra builder configuration. `target_principal` optionally names
    /// the server to authenticate against.
    ///
    /// # Errors
    /// Propagates any [`InitializeContextError`] returned by the builder's `initialize`.
    pub fn new_from_cred(
        cred: Arc<Credentials<Usage>>,
        target_principal: Option<&str>,
    ) -> Result<StepOut<Usage>, InitializeContextError> {
        ClientBuilder::new_from_credentials(cred, target_principal).initialize()
    }
}
/// `Ok` holds the context re-typed as signing-capable; `Err` holds it re-typed as not signing.
type CheckSignResult<Usage, E, D> =
    Result<ClientContext<Usage, Signing, E, D>, ClientContext<Usage, NoSigning, E, D>>;
impl<Usage, E, D> ClientContext<Usage, MaybeSigning, E, D> {
    /// Resolves the `MaybeSigning` typestate by inspecting the negotiated attributes.
    ///
    /// Either way the context itself is returned unchanged; only its type differs.
    pub fn check_signing(self) -> CheckSignResult<Usage, E, D> {
        let signing_available =
            <MaybeSigning as typestate::signing::Sealed>::requirements_met_manual(self.attributes);
        if !signing_available {
            return Err(self.convert_policy());
        }
        Ok(self.convert_policy())
    }
}
/// `Ok` holds the context re-typed as encryption-capable; `Err` holds it re-typed as not encrypting.
type CheckEncryptionResult<Usage, S, D> =
    Result<ClientContext<Usage, S, Encryption, D>, ClientContext<Usage, S, NoEncryption, D>>;
impl<Usage, S, D> ClientContext<Usage, S, MaybeEncryption, D> {
    /// Resolves the `MaybeEncryption` typestate by inspecting the negotiated attributes.
    ///
    /// Either way the context itself is returned unchanged; only its type differs.
    pub fn check_encryption(self) -> CheckEncryptionResult<Usage, S, D> {
        let encryption_available =
            <MaybeEncryption as typestate::encryption::Sealed>::requirements_met_manual(self.attributes);
        if !encryption_available {
            return Err(self.convert_policy());
        }
        Ok(self.convert_policy())
    }
}
impl<Usage, S1, E1, D1> ClientContext<Usage, S1, E1, D1> {
    /// Re-labels the context with different typestate parameters, moving every field across.
    ///
    /// No runtime check happens here; the visible callers (`check_signing`,
    /// `check_encryption`) consult the negotiated attributes before choosing the target type.
    fn convert_policy<S2, E2, D2>(self) -> ClientContext<Usage, S2, E2, D2> {
        ClientContext {
            attributes: self.attributes,
            cred: self.cred,
            context: self.context,
            token_buffer: self.token_buffer,
            _enc: PhantomData,
        }
    }
}
/// A client handshake that needs at least one more round trip with the server.
///
/// Created when `InitializeSecurityContextW` returns `SEC_I_CONTINUE_NEEDED`. It carries
/// everything needed to issue the next call.
pub struct PendingClientContext<Usage> {
    /// UTF-16 target name passed as `pszTargetName` on every step (presumably NUL-terminated
    /// by whoever built it — TODO confirm).
    target_spn: Option<Box<[u16]>>,
    /// The partially established SSPI context.
    ///
    /// Declared before `cred` on purpose: struct fields are dropped in declaration order, so
    /// the context is released (presumably via `DeleteSecurityContext` in `ContextHandle`'s
    /// destructor) while the credentials it was created from are still alive.
    context: ContextHandle,
    /// Credentials used for every step of the handshake.
    cred: Arc<Credentials<Usage>>,
    /// Capabilities requested from SSPI on every step.
    flags: CapabilityFlags,
    /// Token produced by the last step that still has to be transmitted to the server.
    token_buffer: NonResizableVec,
    /// `ISC_RET_*` attribute bits reported by the last step.
    attributes: u32,
}
impl<Usage: OutboundUsable> PendingClientContext<Usage> {
    /// Feeds the server's reply `token` into the handshake and advances it by one step.
    ///
    /// # Errors
    /// Returns an [`InitializeContextError`] when `InitializeSecurityContextW` fails with one of
    /// the mapped status codes (for example a rejected token).
    pub fn step(self, token: &[u8]) -> Result<StepOut<Usage>, InitializeContextError> {
        let Self {
            target_spn,
            cred,
            context,
            flags,
            token_buffer,
            attributes,
        } = self;
        // Channel bindings are not passed on continuation calls.
        // NOTE(review): confirm the security package does not expect them again here.
        step(cred, target_spn, Some(context), flags, attributes, token_buffer, None, Some(token))
    }
}
impl<Usage> PendingClientContext<Usage> {
    /// Returns the token to transmit to the server before calling `step` with its reply.
    ///
    /// # Panics
    /// Panics if the last step asked for another round trip but produced no output token.
    pub fn next_token(&self) -> &[u8] {
        if self.token_buffer.is_empty() {
            panic!("Pending client context returned no token to transmit");
        }
        self.token_buffer.as_slice()
    }
}
/// Runs a single `InitializeSecurityContextW` round and classifies the result.
///
/// * `cred` – credentials the context is (or was) created from.
/// * `target_spn` – UTF-16 target name forwarded as `pszTargetName` (presumably NUL-terminated
///   by whoever built it — TODO confirm).
/// * `context` – the in-progress context from a previous round, or `None` on the first call.
/// * `flags` – requested capabilities, translated by [`convert_flags`].
/// * `attributes` – receives the `ISC_RET_*` bits reported by SSPI.
/// * `token_buffer` – scratch space for the output token. It is grown with `resize_max` and
///   then trimmed to the length SSPI wrote.
/// * `channel_bindings` – optional application data, wrapped in a `SEC_CHANNEL_BINDINGS` buffer.
/// * `in_token` – token received from the server, if any.
///
/// Returns [`StepOut::Completed`] on `SEC_E_OK` and [`StepOut::Pending`] on
/// `SEC_I_CONTINUE_NEEDED`.
///
/// # Errors
/// Returns [`InitializeContextError::InvalidToken`] if `in_token` is longer than `u32::MAX`
/// bytes, and maps the SSPI status codes listed below to the matching
/// [`InitializeContextError`] variant. If a previous `context` was passed in, it is released
/// before the error is returned instead of being leaked.
///
/// # Panics
/// Panics on `SEC_I_COMPLETE_NEEDED`/`SEC_I_COMPLETE_AND_CONTINUE`,
/// `SEC_E_UNSUPPORTED_FUNCTION`, `SEC_E_NO_CREDENTIALS` and any unrecognised status code.
// Internal driver shared by the builder and `PendingClientContext::step`; bundling the
// arguments into a struct would only move the same list elsewhere.
#[allow(clippy::too_many_arguments)]
fn step<Usage: OutboundUsable>(
    cred: Arc<Credentials<Usage>>,
    target_spn: Option<Box<[u16]>>,
    context: Option<ContextHandle>,
    flags: CapabilityFlags,
    mut attributes: u32,
    mut token_buffer: NonResizableVec,
    channel_bindings: Option<&[u8]>,
    in_token: Option<&[u8]>,
) -> Result<StepOut<Usage>, InitializeContextError> {
    // Let SSPI write into the whole buffer; the real token length is restored after the call.
    token_buffer.resize_max();
    let mut out_token_buffer = token_buffer.sec_buffer(SECBUFFER_TOKEN);
    let mut out_token_buffer_desc = SecBufferDesc {
        ulVersion: SECBUFFER_VERSION,
        cBuffers: 1,
        pBuffers: &mut out_token_buffer,
    };

    // Input buffers: the server token (if any) followed by the channel bindings (if any).
    let in_token_buf = in_token
        .map(|token| {
            let cb_buffer = token
                .len()
                .try_into()
                .map_err(|_| InitializeContextError::InvalidToken)?;
            Ok(SecBuffer {
                cbBuffer: cb_buffer,
                BufferType: SECBUFFER_TOKEN,
                pvBuffer: token.as_ptr() as *mut c_void,
            })
        })
        .transpose()?;
    let mut buffers = Vec::with_capacity(2);
    buffers.extend(in_token_buf);
    let mut channel_binding_buffer = channel_bindings.map(|cb| {
        // The SEC_CHANNEL_BINDINGS header is immediately followed by the application data.
        let header_len = std::mem::size_of::<SEC_CHANNEL_BINDINGS>();
        let scb = SEC_CHANNEL_BINDINGS {
            dwApplicationDataOffset: header_len as u32,
            cbApplicationDataLength: cb.len() as u32,
            ..Default::default()
        };
        let mut buffer = vec![0u8; header_len + cb.len()];
        // SAFETY: `buffer` holds at least `header_len` bytes, so the write stays in bounds.
        // A `Vec<u8>` only guarantees 1-byte alignment, so an unaligned write is required;
        // `ptr::write` would demand alignment suitable for `SEC_CHANNEL_BINDINGS`.
        unsafe {
            std::ptr::write_unaligned(buffer.as_mut_ptr().cast::<SEC_CHANNEL_BINDINGS>(), scb);
        }
        buffer[header_len..].copy_from_slice(cb);
        buffer
    });
    buffers.extend(channel_binding_buffer.as_mut().map(|cb| SecBuffer {
        cbBuffer: cb.len() as u32,
        BufferType: SECBUFFER_CHANNEL_BINDINGS,
        pvBuffer: cb.as_mut_ptr() as *mut c_void,
    }));
    let in_token_buf_desc = match buffers.as_mut_slice() {
        [] => None,
        v => Some(SecBufferDesc {
            ulVersion: SECBUFFER_VERSION,
            cBuffers: v.len() as u32,
            pBuffers: v.as_mut_ptr(),
        }),
    };

    // SSPI reads the existing context through `phContext` and writes the (possibly updated)
    // one through `phNewContext`. Both pointers are derived from the same `&mut` so neither is
    // invalidated by the other. Any previous `ContextHandle` is leaked here and re-adopted
    // below once SSPI has reported the outcome.
    let had_context = context.is_some();
    let mut sec_handle = context.map(ContextHandle::leak).unwrap_or_default();
    let sec_handle_ptr = std::ptr::from_mut(&mut sec_handle);
    let in_sec_handle = had_context.then_some(sec_handle_ptr.cast_const());

    // SAFETY: every pointer handed to SSPI refers to data that outlives this call: the
    // credential handle behind `cred`, the `target_spn` allocation, `buffers` and
    // `channel_binding_buffer`, `out_token_buffer` (backed by `token_buffer`), `sec_handle`
    // and `attributes`.
    let hres = unsafe {
        InitializeSecurityContextW(
            Some(cred.as_ref().as_raw_handle()),
            in_sec_handle,
            target_spn.as_ref().map(|b| b.as_ptr()),
            convert_flags(flags),
            0,
            SECURITY_NATIVE_DREP,
            in_token_buf_desc.as_ref().map(std::ptr::from_ref),
            0,
            Some(sec_handle_ptr),
            Some(&mut out_token_buffer_desc),
            &mut attributes,
            None,
        )
    };
    token_buffer.set_length(out_token_buffer.cbBuffer);

    let error = match hres {
        SEC_E_OK => {
            // SAFETY: SSPI returned an established context in `sec_handle`, and no other
            // `ContextHandle` owns it (any previous owner was leaked above).
            let context = unsafe { ContextHandle::pick_up(sec_handle) };
            return Ok(StepOut::Completed(ClientContext {
                attributes,
                cred,
                context,
                token_buffer,
                _enc: PhantomData,
            }));
        }
        SEC_I_COMPLETE_AND_CONTINUE | SEC_I_COMPLETE_NEEDED => {
            panic!("CompleteAuthToken is not supported by Negotiate")
        }
        SEC_I_CONTINUE_NEEDED => {
            // SAFETY: as above, SSPI returned a valid partial context that nothing else owns.
            let context = unsafe { ContextHandle::pick_up(sec_handle) };
            return Ok(StepOut::Pending(PendingClientContext {
                target_spn,
                cred,
                context,
                flags,
                token_buffer,
                attributes,
            }));
        }
        SEC_E_INTERNAL_ERROR => InitializeContextError::Internal,
        SEC_E_INVALID_HANDLE => InitializeContextError::InvalidHandle,
        SEC_E_INVALID_TOKEN => InitializeContextError::InvalidToken,
        SEC_E_LOGON_DENIED => InitializeContextError::Denied,
        SEC_E_NO_CREDENTIALS => todo!("constrained delegation"),
        SEC_E_NO_AUTHENTICATING_AUTHORITY => InitializeContextError::NoAuthority,
        SEC_E_TARGET_UNKNOWN => InitializeContextError::TargetUnknown,
        SEC_E_UNSUPPORTED_FUNCTION => panic!("unsupported function"),
        SEC_E_WRONG_PRINCIPAL => InitializeContextError::WrongPrincipal,
        e => todo!("unknown error code: {e:?} ({})", e.message()),
    };
    if had_context {
        // SAFETY: the caller's context was leaked above and nothing else owns it. After a
        // failed continuation call the caller remains responsible for deleting it, so
        // re-adopting it lets `ContextHandle`'s destructor release it instead of leaking it.
        drop(unsafe { ContextHandle::pick_up(sec_handle) });
    }
    Err(error)
}
/// Translates requested [`CapabilityFlags`] into `ISC_REQ_*` flags for
/// `InitializeSecurityContextW`.
///
/// Integrity is always stated explicitly: `ISC_REQ_INTEGRITY` when requested, otherwise
/// `ISC_REQ_NO_INTEGRITY`. Mutual authentication, confidentiality and delegation are only
/// requested when the corresponding capability is present.
fn convert_flags(flags: CapabilityFlags) -> ISC_REQ_FLAGS {
    let wants = |capability: CapabilityFlags| flags.contains_all(capability);
    let pick = |capability: CapabilityFlags, requested: ISC_REQ_FLAGS| {
        if wants(capability) {
            requested
        } else {
            ISC_REQ_FLAGS(0)
        }
    };
    let integrity = if wants(CapabilityFlags::INTEGRITY) {
        ISC_REQ_INTEGRITY
    } else {
        ISC_REQ_NO_INTEGRITY
    };
    pick(CapabilityFlags::MUTUAL_AUTH, ISC_REQ_MUTUAL_AUTH)
        | integrity
        | pick(CapabilityFlags::CONFIDENTIALITY, ISC_REQ_CONFIDENTIALITY)
        | pick(CapabilityFlags::DELEGATE, ISC_REQ_DELEGATE)
}
/// Outcome of a single client handshake step.
pub enum StepOut<Usage> {
    /// SSPI asked for another round trip (`SEC_I_CONTINUE_NEEDED`). Transmit
    /// `PendingClientContext::next_token` to the server and pass its reply to
    /// `PendingClientContext::step`.
    Pending(PendingClientContext<Usage>),
    /// The handshake finished (`SEC_E_OK`). Whether signing, encryption and delegation are
    /// available is not yet fixed at the type level; resolve it with, for example,
    /// `check_signing`/`check_encryption`. If `last_token` is `Some`, that token presumably
    /// still has to be delivered to the server.
    Completed(ClientContext<Usage, MaybeSigning, MaybeEncryption, MaybeDelegation>),
}