kenobi-windows 0.2.0

A Windows Negotiate implementation.
use std::{ffi::c_void, marker::PhantomData, sync::Arc};

use windows::Win32::{
    Foundation::{
        SEC_E_INTERNAL_ERROR, SEC_E_INVALID_HANDLE, SEC_E_INVALID_TOKEN, SEC_E_LOGON_DENIED,
        SEC_E_NO_AUTHENTICATING_AUTHORITY, SEC_E_OK, SEC_E_UNSUPPORTED_FUNCTION, SEC_I_CONTINUE_NEEDED,
    },
    Security::Authentication::Identity::{
        ASC_REQ_CONFIDENTIALITY, ASC_REQ_DELEGATE, ASC_REQ_FLAGS, ASC_REQ_INTEGRITY, ASC_REQ_MUTUAL_AUTH,
        AcceptSecurityContext, SEC_CHANNEL_BINDINGS, SECBUFFER_CHANNEL_BINDINGS, SECBUFFER_TOKEN, SECBUFFER_VERSION,
        SECURITY_NATIVE_DREP, SecBuffer, SecBufferDesc,
    },
};

use kenobi_core::{cred::usage::InboundUsable, flags::CapabilityFlags};

use crate::{
    buffer::NonResizableVec,
    context::ContextHandle,
    cred::Credentials,
    server::typestate::{DelegationPolicy, EncryptionPolicy, SigningPolicy},
    sign_encrypt::{Altered, Plaintext, Signature, WrapError},
};

mod builder;
mod error;
mod typestate;

pub use builder::ServerBuilder;
pub use error::AcceptContextError;
use kenobi_core::typestate::{
    Delegation, Encryption, MaybeDelegation, MaybeEncryption, MaybeSigning, NoDelegation, NoEncryption, NoSigning,
    Signing,
};

/// A server-side security context whose negotiation has finished.
///
/// `S`, `E` and `D` are typestate markers describing the signing, encryption and delegation
/// capabilities of the context. They carry no data and only live inside `PhantomData`.
pub struct ServerContext<Usage, S, E, D> {
    /// Credentials the context was accepted with; held for as long as the context exists.
    cred: Arc<Credentials<Usage>>,
    /// Handle to the underlying security context.
    context: ContextHandle,
    /// Attribute bits written by the last `AcceptSecurityContext` call.
    attributes: u32,
    /// Storage for the last output token. Its capacity must stay fixed, so it is never resized.
    token_buffer: NonResizableVec,
    _enc: PhantomData<(D, E, S)>,
}
impl<Usage, S, E, D> ServerContext<Usage, S, E, D>
where
    Usage: InboundUsable,
    S: SigningPolicy,
    E: EncryptionPolicy,
    D: DelegationPolicy,
{
    /// Begins accepting a new security context from the client's first token.
    ///
    /// Mutual authentication, integrity and confidentiality are requested. The result is
    /// either a completed context or a pending one that needs more rounds.
    ///
    /// # Errors
    /// Returns an [`AcceptContextError`] when SSPI rejects the token or the credentials.
    ///
    /// # Panics
    /// Panics if SSPI reports a status code this module does not map to an error variant.
    pub fn initialize(cred: Arc<Credentials<Usage>>, first_token: &[u8]) -> Result<StepOut<Usage>, AcceptContextError> {
        // NOTE(review): the requested flags are fixed and do not depend on `S`, `E` or `D`,
        // and no channel bindings are passed — confirm this is intended.
        let requested =
            CapabilityFlags::MUTUAL_AUTH | CapabilityFlags::INTEGRITY | CapabilityFlags::CONFIDENTIALITY;
        let no_previous_context = None;
        let no_channel_bindings = None;
        step(
            cred,
            no_previous_context,
            requested,
            0,
            NonResizableVec::new(),
            no_channel_bindings,
            first_token,
        )
    }
}
impl<Usage, S, E, D> ServerContext<Usage, S, E, D> {
    /// Returns the output token produced by the final negotiation round.
    ///
    /// Yields `None` when that round produced no token (the buffer is empty).
    pub fn last_token(&self) -> Option<&[u8]> {
        let token: &[u8] = &self.token_buffer;
        if self.token_buffer.is_empty() {
            None
        } else {
            Some(token)
        }
    }
}
impl<Usage, E, D> ServerContext<Usage, Signing, E, D> {
    /// Computes a signature over `message` with the established context.
    ///
    /// # Errors
    /// Returns a [`WrapError`] wrapping the underlying failure when signing fails.
    pub fn sign(&self, message: &[u8]) -> Result<Signature, WrapError> {
        let signature = self.context.wrap_sign(message).map_err(WrapError)?;
        Ok(signature)
    }

    /// Unwraps a message received from the peer using the established context.
    ///
    /// # Errors
    /// Returns [`Altered`] when the context reports the message as not genuine
    /// (presumably a failed integrity check — confirm in `sign_encrypt`).
    pub fn unwrap(&self, message: &[u8]) -> Result<Plaintext, Altered> {
        let context = &self.context;
        context.unwrap(message)
    }
}
impl<Usage, S, E> ServerContext<Usage, S, E, MaybeDelegation> {
    /// Resolves the `MaybeDelegation` typestate from the negotiated attributes.
    ///
    /// Returns `Ok` with a [`Delegation`] context when any of `MaybeDelegation`'s request bits
    /// are set in the attributes, and `Err` with a [`NoDelegation`] context otherwise.
    // Both arms of the `Result` are full context types; an alias would hide the transition.
    #[allow(clippy::type_complexity)]
    pub fn check_delegation(
        self,
    ) -> Result<ServerContext<Usage, S, E, Delegation>, ServerContext<Usage, S, E, NoDelegation>> {
        let delegation_bits = <MaybeDelegation as typestate::delegation::Sealed>::REQUEST_FLAGS.0;
        if (self.attributes & delegation_bits) == 0 {
            return Err(self.convert_policy());
        }
        Ok(self.convert_policy())
    }
}
impl<Usage, E, D> ServerContext<Usage, MaybeSigning, E, D> {
    /// Resolves the `MaybeSigning` typestate from the negotiated attributes.
    ///
    /// Returns `Ok` with a [`Signing`] context when any of `MaybeSigning`'s request bits are
    /// set in the attributes, and `Err` with a [`NoSigning`] context otherwise.
    // Both arms of the `Result` are full context types; an alias would hide the transition.
    #[allow(clippy::type_complexity)]
    pub fn check_signing(self) -> Result<ServerContext<Usage, Signing, E, D>, ServerContext<Usage, NoSigning, E, D>> {
        let signing_bits = <MaybeSigning as typestate::sign::Sealed>::REQUEST_FLAGS.0;
        if (self.attributes & signing_bits) == 0 {
            return Err(self.convert_policy());
        }
        Ok(self.convert_policy())
    }
}
impl<Usage, S, D> ServerContext<Usage, S, MaybeEncryption, D> {
    /// Resolves the `MaybeEncryption` typestate from the negotiated attributes.
    ///
    /// Returns `Ok` with an [`Encryption`] context when any of `MaybeEncryption`'s request bits
    /// are set in the attributes, and `Err` with a [`NoEncryption`] context otherwise.
    // Both arms of the `Result` are full context types; an alias would hide the transition.
    #[allow(clippy::type_complexity)]
    pub fn check_encryption(
        self,
    ) -> Result<ServerContext<Usage, S, Encryption, D>, ServerContext<Usage, S, NoEncryption, D>> {
        let encryption_bits = <MaybeEncryption as typestate::encrypt::Sealed>::REQUEST_FLAGS.0;
        if (self.attributes & encryption_bits) == 0 {
            return Err(self.convert_policy());
        }
        Ok(self.convert_policy())
    }
}
impl<Usage, S1, E1, D1> ServerContext<Usage, S1, E1, D1> {
    /// Re-labels the context with new typestate markers; all runtime data is moved over as-is.
    fn convert_policy<S2, E2, D2>(self) -> ServerContext<Usage, S2, E2, D2> {
        let Self {
            cred,
            context,
            attributes,
            token_buffer,
            _enc: _,
        } = self;
        ServerContext { cred, context, attributes, token_buffer, _enc: PhantomData }
    }
}

/// A server-side context that SSPI reported as needing another round
/// (`SEC_I_CONTINUE_NEEDED`).
pub struct PendingServerContext<Usage> {
    /// Credentials used for every `AcceptSecurityContext` round.
    cred: Arc<Credentials<Usage>>,
    /// Handle to the partially established security context.
    context: ContextHandle,
    /// Capabilities requested from SSPI, reused on each subsequent round.
    flags: CapabilityFlags,
    /// Attribute bits written by the most recent round.
    attributes: u32,
    /// Output token produced by the most recent round.
    token_buffer: NonResizableVec,
}
impl<Usage> PendingServerContext<Usage> {
    /// Returns the output token produced by the most recent negotiation round.
    ///
    /// # Panics
    /// Panics if that round produced no output token (the buffer is empty).
    pub fn next_token(&self) -> &[u8] {
        assert!(
            !self.token_buffer.is_empty(),
            "pending server context has no output token to send to the client"
        );
        &self.token_buffer
    }
}
impl<Usage: InboundUsable> PendingServerContext<Usage> {
    pub fn step(self, token: &[u8]) -> Result<StepOut<Usage>, AcceptContextError> {
        step(
            self.cred,
            Some(self.context),
            self.flags,
            self.attributes,
            self.token_buffer,
            None,
            token,
        )
    }
}

/// Runs one `AcceptSecurityContext` round and classifies the outcome.
///
/// * `cred` – server credentials; moved into the returned context on success.
/// * `context` – `None` on the first round, otherwise the handle from the previous round.
/// * `flags` – requested capabilities, translated with [`convert_flags`].
/// * `attributes` – attribute bits; overwritten by this call.
/// * `token_buffer` – storage for the output token, reused across rounds.
/// * `channel_bindings` – optional application channel-binding data.
/// * `in_token` – the token received from the client.
///
/// # Errors
/// Maps the recognised SSPI failure codes to [`AcceptContextError`] variants.
///
/// # Panics
/// Panics on `SEC_E_UNSUPPORTED_FUNCTION` and on any status code not listed below.
fn step<Usage: InboundUsable>(
    cred: Arc<Credentials<Usage>>,
    context: Option<ContextHandle>,
    flags: CapabilityFlags,
    mut attributes: u32,
    mut token_buffer: NonResizableVec,
    channel_bindings: Option<&[u8]>,
    in_token: &[u8],
) -> Result<StepOut<Usage>, AcceptContextError> {
    // Expose the full backing storage so SSPI has room to write the output token.
    token_buffer.resize_max();

    let mut out_token_buffer = token_buffer.sec_buffer(SECBUFFER_TOKEN);
    let mut out_token_buffer_desc = SecBufferDesc {
        ulVersion: SECBUFFER_VERSION,
        cBuffers: 1,
        pBuffers: &mut out_token_buffer,
    };

    let mut buffers = vec![SecBuffer {
        cbBuffer: in_token.len() as u32,
        BufferType: SECBUFFER_TOKEN,
        pvBuffer: in_token.as_ptr() as *mut c_void,
    }];

    // Must stay alive until after the FFI call: `buffers` points into it.
    let mut channel_binding_buffer = channel_bindings.map(encode_channel_bindings);
    buffers.extend(channel_binding_buffer.as_mut().map(|cb| SecBuffer {
        cbBuffer: cb.len() as u32,
        BufferType: SECBUFFER_CHANNEL_BINDINGS,
        pvBuffer: cb.as_mut_ptr() as *mut c_void,
    }));

    // `buffers` is not modified after this point, so the pointer below stays valid.
    let in_buf_desc = SecBufferDesc {
        ulVersion: SECBUFFER_VERSION,
        cBuffers: buffers.len() as u32,
        pBuffers: buffers.as_mut_ptr(),
    };

    // First round: `phContext` must be NULL, but `phNewContext` must point at storage we keep,
    // because SSPI writes the newly created handle there. Later rounds pass the existing handle
    // for both parameters.
    let is_first_call = context.is_none();
    // NOTE(review): assumes `ContextHandle::default()` is an empty handle usable as the output
    // slot of the first round — TODO confirm against `crate::context`.
    let mut context = context.unwrap_or_default();
    let old_context_ptr = (!is_first_call).then(|| context.as_ptr());
    let new_context_ptr = context.as_mut_ptr();
    // SAFETY: all pointers handed to SSPI refer to locals (descriptors, buffers, context handle)
    // that outlive this call.
    let hres = unsafe {
        AcceptSecurityContext(
            Some(cred.as_ref().as_raw_handle()),
            old_context_ptr,
            Some(&in_buf_desc),
            convert_flags(flags),
            SECURITY_NATIVE_DREP,
            Some(new_context_ptr),
            Some(&mut out_token_buffer_desc),
            &mut attributes,
            None,
        )
    };
    // SSPI reports how many bytes of the output token it actually wrote.
    token_buffer.set_length(out_token_buffer.cbBuffer);
    match hres {
        SEC_E_OK => Ok(StepOut::Completed(ServerContext {
            cred,
            context,
            attributes,
            token_buffer,
            _enc: PhantomData,
        })),
        SEC_I_CONTINUE_NEEDED => Ok(StepOut::Pending(PendingServerContext {
            cred,
            context,
            flags,
            attributes,
            token_buffer,
        })),
        SEC_E_INTERNAL_ERROR => Err(AcceptContextError::Internal),
        SEC_E_INVALID_HANDLE => Err(AcceptContextError::InvalidHandle),
        SEC_E_INVALID_TOKEN => Err(AcceptContextError::InvalidToken),
        SEC_E_LOGON_DENIED => Err(AcceptContextError::Denied),
        SEC_E_NO_AUTHENTICATING_AUTHORITY => Err(AcceptContextError::NoAuthority),
        SEC_E_UNSUPPORTED_FUNCTION => unreachable!("only applicable from Schannel SSP"),
        // NOTE(review): any other status (influenced by the client's token) panics here;
        // consider dedicated `AcceptContextError` variants instead of `todo!`.
        e => todo!("unknown error code: {e:?} (\"{}\")", e.message()),
    }
}

/// Serializes `application_data` as a `SEC_CHANNEL_BINDINGS` header immediately followed by the
/// data, which is the layout passed in a `SECBUFFER_CHANNEL_BINDINGS` buffer.
fn encode_channel_bindings(application_data: &[u8]) -> Vec<u8> {
    let header_len = std::mem::size_of::<SEC_CHANNEL_BINDINGS>();
    let header = SEC_CHANNEL_BINDINGS {
        dwApplicationDataOffset: header_len as u32,
        cbApplicationDataLength: application_data.len() as u32,
        ..Default::default()
    };
    let mut buffer = vec![0u8; header_len + application_data.len()];
    // SAFETY: `buffer` is at least `header_len` bytes long. `write_unaligned` is required because
    // a `Vec<u8>` allocation only guarantees 1-byte alignment, while `SEC_CHANNEL_BINDINGS`
    // needs 4-byte alignment.
    unsafe {
        std::ptr::write_unaligned(buffer.as_mut_ptr().cast::<SEC_CHANNEL_BINDINGS>(), header);
    }
    buffer[header_len..].copy_from_slice(application_data);
    buffer
}

/// Translates the crate's capability flags into the `ASC_REQ_*` bits requested from SSPI.
///
/// Only mutual authentication, integrity, confidentiality and delegation are mapped; any
/// other capability bit is ignored.
fn convert_flags(flag: CapabilityFlags) -> ASC_REQ_FLAGS {
    let mapping = [
        (CapabilityFlags::MUTUAL_AUTH, ASC_REQ_MUTUAL_AUTH),
        (CapabilityFlags::INTEGRITY, ASC_REQ_INTEGRITY),
        (CapabilityFlags::CONFIDENTIALITY, ASC_REQ_CONFIDENTIALITY),
        (CapabilityFlags::DELEGATE, ASC_REQ_DELEGATE),
    ];
    let mut requested = ASC_REQ_FLAGS(0);
    for (capability, asc_flag) in mapping {
        if flag.contains_all(capability) {
            requested |= asc_flag;
        }
    }
    requested
}

/// Outcome of one server-side `AcceptSecurityContext` round.
pub enum StepOut<Usage> {
    /// SSPI returned `SEC_I_CONTINUE_NEEDED`; negotiation needs another round.
    Pending(PendingServerContext<Usage>),
    /// SSPI returned `SEC_E_OK`. Signing, encryption and delegation are still `Maybe*` and
    /// are resolved with the `check_*` methods on [`ServerContext`].
    Completed(ServerContext<Usage, MaybeSigning, MaybeEncryption, MaybeDelegation>),
}