#![allow(dead_code)]
use crate::error::CredSspError;
/// ASN.1 identifier byte for a constructed SEQUENCE.
const TAG_SEQUENCE: u8 = 0x30;
/// ASN.1 identifier byte for an OCTET STRING.
const TAG_OCTET_STRING: u8 = 0x04;
/// ASN.1 identifier byte for an INTEGER.
const TAG_INTEGER: u8 = 0x02;
/// ASN.1 identifier byte for an ENUMERATED value.
const TAG_ENUM: u8 = 0x0A;
/// ASN.1 identifier byte for an OBJECT IDENTIFIER.
const TAG_OID: u8 = 0x06;
/// ASN.1 identifier byte for a BIT STRING.
const TAG_BIT_STRING: u8 = 0x03;
/// Complete OID TLV (tag `0x06`, length 6) for SPNEGO, 1.3.6.1.5.5.2.
///
/// Already includes its own tag and length, so it is spliced into output as-is.
const SPNEGO_OID: &[u8] = &[0x06, 0x06, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x02];
/// Complete OID TLV (tag `0x06`, length 10) for NTLMSSP, 1.3.6.1.4.1.311.2.2.10.
const NTLM_OID: &[u8] = &[
0x06, 0x0a, 0x2b, 0x06, 0x01, 0x04, 0x01, 0x82, 0x37, 0x02, 0x02, 0x0a,
];
/// Fields extracted from a DER-encoded TSRequest by `decode_ts_request`.
///
/// Each optional field is `None` when its context tag is absent from the message.
#[derive(Debug, Clone)]
pub(crate) struct TsRequest {
/// Value of context tag `[0]`; the decoder substitutes 2 when the tag is missing.
pub version: u32,
/// Payload of the first token inside context tag `[1]`.
pub nego_token: Option<Vec<u8>>,
/// OCTET STRING payload of context tag `[2]`.
pub auth_info: Option<Vec<u8>>,
/// OCTET STRING payload of context tag `[3]`.
pub pub_key_auth: Option<Vec<u8>>,
/// INTEGER value of context tag `[4]`.
pub error_code: Option<u32>,
/// OCTET STRING payload of context tag `[5]`.
pub client_nonce: Option<Vec<u8>>,
}
/// Encodes `len` as a DER definite-length field.
///
/// Lengths below `0x80` use the single-byte short form. Larger lengths use the
/// long form: one byte `0x80 | n` followed by the `n` big-endian bytes of the
/// length, where `n` is the minimal number of bytes needed. Every `usize`
/// value is representable, so nothing is truncated.
fn encode_length(len: usize) -> Vec<u8> {
    if len < 0x80 {
        return vec![len as u8];
    }
    let be = len.to_be_bytes();
    // `len >= 0x80` guarantees at least one non-zero byte remains.
    let leading_zeros = be.iter().take_while(|&&b| b == 0).count();
    let significant = &be[leading_zeros..];
    let mut out = Vec::with_capacity(1 + significant.len());
    // At most `size_of::<usize>()` bytes, so the count always fits in 7 bits.
    out.push(0x80 | significant.len() as u8);
    out.extend_from_slice(significant);
    out
}
/// Builds a tag-length-value triple: the `tag` byte, the encoded length of
/// `contents`, then `contents` itself.
fn encode_tlv(tag: u8, contents: &[u8]) -> Vec<u8> {
    let length = encode_length(contents.len());
    [&[tag][..], length.as_slice(), contents].concat()
}
/// Wraps `contents` in a SEQUENCE TLV (tag `TAG_SEQUENCE`).
fn encode_sequence(contents: &[u8]) -> Vec<u8> {
    let sequence = encode_tlv(TAG_SEQUENCE, contents);
    sequence
}
/// Wraps `contents` in a TLV whose identifier byte is `0xA0 | tag`.
///
/// `tag` is OR-ed into the identifier unchecked; callers in this file only
/// pass small field numbers.
fn encode_context_tag(tag: u8, contents: &[u8]) -> Vec<u8> {
    let identifier = tag | 0xA0;
    encode_tlv(identifier, contents)
}
/// Wraps `data` in an OCTET STRING TLV (tag `TAG_OCTET_STRING`).
fn encode_octet_string(data: &[u8]) -> Vec<u8> {
    let octets = encode_tlv(TAG_OCTET_STRING, data);
    octets
}
/// Encodes `value` as an INTEGER TLV using the fewest content bytes.
///
/// Leading zero bytes of the big-endian form are dropped (zero keeps a single
/// `0x00` byte). A `0x00` byte is prepended when the first remaining byte has
/// its high bit set.
fn encode_integer_value(value: u32) -> Vec<u8> {
    let be = value.to_be_bytes();
    // Always keep at least the final byte so that zero encodes as one byte.
    let skip = ((value.leading_zeros() / 8) as usize).min(be.len() - 1);
    let magnitude = &be[skip..];

    let mut content = Vec::with_capacity(magnitude.len() + 1);
    if magnitude[0] & 0x80 != 0 {
        content.push(0x00);
    }
    content.extend_from_slice(magnitude);
    encode_tlv(TAG_INTEGER, &content)
}
/// Serialises a TSRequest as a DER SEQUENCE.
///
/// Fields are emitted in ascending context-tag order:
/// `[0]` version (always), `[1]` nego token wrapped in two nested SEQUENCEs,
/// `[2]` `auth_info`, `[3]` `pub_key_auth`, `[5]` `client_nonce`.
/// Optional fields are omitted when `None`. Context tag `[4]` is never written.
///
/// Note that the parameter order (`pub_key_auth` before `auth_info`) differs
/// from the order in which the fields appear on the wire.
pub(crate) fn encode_ts_request(
    version: u32,
    nego_token: Option<&[u8]>,
    pub_key_auth: Option<&[u8]>,
    auth_info: Option<&[u8]>,
    client_nonce: Option<&[u8]>,
) -> Vec<u8> {
    let mut body = encode_context_tag(0, &encode_integer_value(version));

    if let Some(token) = nego_token {
        // One entry holding `[0] OCTET STRING`, inside the outer list SEQUENCE.
        let entry = encode_sequence(&encode_context_tag(0, &encode_octet_string(token)));
        body.extend(encode_context_tag(1, &encode_sequence(&entry)));
    }

    // Remaining optional fields are plain `[n] OCTET STRING` wrappers.
    let octet_fields = [(2u8, auth_info), (3u8, pub_key_auth), (5u8, client_nonce)];
    for &(field, value) in octet_fields.iter() {
        if let Some(bytes) = value {
            body.extend(encode_context_tag(field, &encode_octet_string(bytes)));
        }
    }

    encode_sequence(&body)
}
/// Serialises password credentials as a two-level DER structure.
///
/// The inner SEQUENCE holds `[0]` domain, `[1]` username and `[2]` password,
/// each converted with `to_utf16le` and wrapped as an OCTET STRING. The outer
/// SEQUENCE holds `[0]` INTEGER 1 and `[1]` the inner SEQUENCE as an OCTET STRING.
pub(crate) fn encode_ts_credentials(domain: &str, username: &str, password: &str) -> Vec<u8> {
    let mut password_fields = Vec::new();
    for (index, text) in [domain, username, password].iter().enumerate() {
        let utf16 = crate::ntlm::crypto::to_utf16le(*text);
        password_fields.extend(encode_context_tag(index as u8, &encode_octet_string(&utf16)));
    }
    let ts_password_creds = encode_sequence(&password_fields);

    let cred_type = encode_context_tag(0, &encode_integer_value(1));
    let credentials = encode_context_tag(1, &encode_octet_string(&ts_password_creds));
    encode_sequence(&[cred_type, credentials].concat())
}
/// Wraps an NTLM token in an initial SPNEGO message.
///
/// Layout: application tag `0x60` containing the SPNEGO OID followed by
/// `[0] SEQUENCE { [0] SEQUENCE { NTLM OID }, [2] OCTET STRING ntlm_token }`.
pub(crate) fn encode_spnego_init(ntlm_token: &[u8]) -> Vec<u8> {
    let mech_types = encode_context_tag(0, &encode_sequence(NTLM_OID));
    let mech_token = encode_context_tag(2, &encode_octet_string(ntlm_token));
    let neg_token_init =
        encode_context_tag(0, &encode_sequence(&[mech_types, mech_token].concat()));

    let gss_body = [SPNEGO_OID, neg_token_init.as_slice()].concat();
    encode_tlv(0x60, &gss_body)
}
/// Wraps an NTLM token in a follow-up SPNEGO message tagged `[1]`.
///
/// The inner SEQUENCE carries `[0]` an ENUMERATED with value 1, `[2]` the
/// token as an OCTET STRING and, when `mech_list_mic` is given, `[3]` the MIC
/// as an OCTET STRING.
pub(crate) fn encode_spnego_response(ntlm_token: &[u8], mech_list_mic: Option<&[u8]>) -> Vec<u8> {
    // Pre-encoded ENUMERATED TLV: tag 0x0a, length 1, value 1.
    let mut fields = encode_context_tag(0, &[0x0a, 0x01, 0x01]);
    fields.extend(encode_context_tag(2, &encode_octet_string(ntlm_token)));
    if let Some(mic) = mech_list_mic {
        fields.extend(encode_context_tag(3, &encode_octet_string(mic)));
    }
    encode_context_tag(1, &encode_sequence(&fields))
}
/// Pre-encoded DER SEQUENCE (tag `0x30`, length 12) containing only the NTLM OID TLV.
///
/// NOTE(review): presumably the mechanism list that callers feed into the
/// mechListMIC computation — confirm against the call sites.
pub(crate) const MECH_TYPE_LIST_NTLM: &[u8] = &[
0x30, 0x0c, 0x06, 0x0a, 0x2b, 0x06, 0x01, 0x04, 0x01, 0x82, 0x37, 0x02, 0x02, 0x0a,
];
/// Parses a DER definite-length field from the start of `data`.
///
/// Returns `(length, bytes_consumed)`.
///
/// # Errors
///
/// Fails when `data` is empty, when the long form announces zero or more than
/// three length bytes, or when the announced length bytes are not all present.
fn decode_length(data: &[u8]) -> Result<(usize, usize), CredSspError> {
    let (&first, rest) = match data.split_first() {
        Some(parts) => parts,
        None => return Err(CredSspError::Asn1Decode("empty length".into())),
    };

    // Short form: the byte itself is the length.
    if first & 0x80 == 0 {
        return Ok((usize::from(first), 1));
    }

    // Long form: low seven bits give the number of following length bytes.
    let count = usize::from(first & 0x7F);
    if !(1..=3).contains(&count) || rest.len() < count {
        return Err(CredSspError::Asn1Decode("invalid length encoding".into()));
    }
    let len = rest[..count]
        .iter()
        .fold(0usize, |acc, &byte| (acc << 8) | usize::from(byte));
    Ok((len, 1 + count))
}
/// Reads one TLV from the start of `data`.
///
/// Returns `(tag, value, total_bytes_consumed)`, where `value` borrows from
/// `data` and the consumed count covers tag, length field and value.
///
/// # Errors
///
/// Fails when `data` is empty, when the length field is invalid, or when the
/// declared value length runs past the end of `data`.
fn read_tlv(data: &[u8]) -> Result<(u8, &[u8], usize), CredSspError> {
    let (&tag, after_tag) = match data.split_first() {
        Some(parts) => parts,
        None => return Err(CredSspError::Asn1Decode("unexpected end of data".into())),
    };
    let (len, len_bytes) = decode_length(after_tag)?;
    let start = 1 + len_bytes;
    let end = start + len;
    match data.get(start..end) {
        Some(value) => Ok((tag, value, end)),
        None => Err(CredSspError::Asn1Decode(format!(
            "TLV length {len} exceeds data ({})",
            data.len() - start
        ))),
    }
}
/// Scans the sibling TLVs in `data` for the first one whose identifier byte is
/// `0xA0 | tag` and returns its value.
///
/// Returns `None` when no such element exists or when a malformed TLV is
/// reached before a match is found.
fn find_context_tag(data: &[u8], tag: u8) -> Option<&[u8]> {
    let wanted = 0xA0 | tag;
    let mut rest = data;
    while !rest.is_empty() {
        match read_tlv(rest) {
            Ok((found, value, _)) if found == wanted => return Some(value),
            Ok((_, _, consumed)) => rest = &rest[consumed..],
            Err(_) => return None,
        }
    }
    None
}
/// Reads the TLV at the start of `data` and returns its value if it is an
/// OCTET STRING.
///
/// # Errors
///
/// Fails when the TLV cannot be read or carries a different tag.
fn decode_octet_string(data: &[u8]) -> Result<&[u8], CredSspError> {
    let (tag, payload, _) = read_tlv(data)?;
    if tag == TAG_OCTET_STRING {
        Ok(payload)
    } else {
        Err(CredSspError::Asn1Decode(format!(
            "expected OCTET STRING (0x04), got 0x{tag:02x}"
        )))
    }
}
/// Reads the TLV at the start of `data` as an INTEGER and returns it as `u32`.
///
/// Content bytes are accumulated as an unsigned big-endian number; no sign
/// extension is performed, and an empty value decodes to 0. Leading `0x00`
/// padding bytes are ignored.
///
/// # Errors
///
/// Fails when the TLV cannot be read, when its tag is not INTEGER, or when
/// more than four significant bytes remain after the leading zeros. Such a
/// value used to be truncated silently because the shift discarded high bits.
fn decode_integer(data: &[u8]) -> Result<u32, CredSspError> {
    let (tag, val, _) = read_tlv(data)?;
    if tag != TAG_INTEGER {
        return Err(CredSspError::Asn1Decode(format!(
            "expected INTEGER (0x02), got 0x{tag:02x}"
        )));
    }
    // Drop zero padding so that e.g. `00 C0 00 00 6D` still fits in four bytes.
    let first_significant = val.iter().position(|&b| b != 0).unwrap_or(val.len());
    let significant = &val[first_significant..];
    if significant.len() > 4 {
        return Err(CredSspError::Asn1Decode(format!(
            "INTEGER of {} bytes does not fit in u32",
            val.len()
        )));
    }
    Ok(significant
        .iter()
        .fold(0u32, |acc, &b| (acc << 8) | u32::from(b)))
}
pub(crate) fn decode_ts_request(data: &[u8]) -> Result<TsRequest, CredSspError> {
let (tag, seq_data, _) = read_tlv(data)?;
if tag != TAG_SEQUENCE {
return Err(CredSspError::Asn1Decode(
"TSRequest: expected SEQUENCE".into(),
));
}
let version = find_context_tag(seq_data, 0)
.map(decode_integer)
.transpose()?
.unwrap_or(2);
let nego_token = find_context_tag(seq_data, 1).and_then(|nego_data| {
let (_, seq_of, _) = read_tlv(nego_data).ok()?;
let (_, inner_seq, _) = read_tlv(seq_of).ok()?;
let token_data = find_context_tag(inner_seq, 0)?;
Some(decode_octet_string(token_data).ok()?.to_vec())
});
let auth_info = find_context_tag(seq_data, 2)
.map(|d| decode_octet_string(d).map(|v| v.to_vec()))
.transpose()?;
let pub_key_auth = find_context_tag(seq_data, 3)
.map(|d| decode_octet_string(d).map(|v| v.to_vec()))
.transpose()?;
let error_code = find_context_tag(seq_data, 4)
.map(decode_integer)
.transpose()?;
let client_nonce = find_context_tag(seq_data, 5)
.map(|d| decode_octet_string(d).map(|v| v.to_vec()))
.transpose()?;
Ok(TsRequest {
version,
nego_token,
auth_info,
pub_key_auth,
error_code,
client_nonce,
})
}
/// Extracts the inner mechanism token from a SPNEGO message.
///
/// Two outer forms are accepted:
/// * tag `0x60`: skips the leading OID, then reads field `[2]` of the
///   SEQUENCE inside `[0]`;
/// * tag `0xA1`: reads field `[2]` of the contained SEQUENCE.
///
/// # Errors
///
/// Fails on any other outer tag, on malformed TLVs, or when field `[2]` is
/// missing.
pub(crate) fn decode_spnego_token(data: &[u8]) -> Result<Vec<u8>, CredSspError> {
    let (tag, body, _) = read_tlv(data)?;
    match tag {
        0x60 => {
            let (_, _, oid_len) = read_tlv(body)
                .map_err(|_| CredSspError::Asn1Decode("bad OID in SPNEGO".into()))?;
            if let Some(init) = find_context_tag(&body[oid_len..], 0) {
                let (_, fields, _) = read_tlv(init)?;
                if let Some(wrapped) = find_context_tag(fields, 2) {
                    return decode_octet_string(wrapped).map(|token| token.to_vec());
                }
            }
            Err(CredSspError::Asn1Decode(
                "no mechToken in NegTokenInit".into(),
            ))
        }
        0xA1 => {
            let (_, fields, _) = read_tlv(body)?;
            match find_context_tag(fields, 2) {
                Some(wrapped) => decode_octet_string(wrapped).map(|token| token.to_vec()),
                None => Err(CredSspError::Asn1Decode(
                    "no responseToken in NegTokenResp".into(),
                )),
            }
        }
        _ => Err(CredSspError::Asn1Decode(format!(
            "unexpected SPNEGO tag: 0x{tag:02x}"
        ))),
    }
}
/// Pulls the subject public key bytes out of a DER certificate.
///
/// Walks the children of the certificate's first inner element. Starting at
/// the sixth child (index 5), every SEQUENCE child is searched for a BIT
/// STRING longer than one byte. The first hit is returned with its leading
/// byte removed.
///
/// # Errors
///
/// Fails when a TLV on the walked path is malformed or when no matching BIT
/// STRING is found.
pub(crate) fn extract_subject_public_key(cert_der: &[u8]) -> Result<Vec<u8>, CredSspError> {
    let (_, certificate, _) = read_tlv(cert_der)?;
    let (_, tbs, _) = read_tlv(certificate)?;

    let mut remaining = tbs;
    let mut index = 0usize;
    while !remaining.is_empty() {
        let (tag, field, used) = read_tlv(remaining)?;
        if index >= 5 && tag == TAG_SEQUENCE {
            let mut inner = field;
            while !inner.is_empty() {
                let (inner_tag, inner_val, inner_used) = read_tlv(inner)?;
                if inner_tag == TAG_BIT_STRING {
                    // Require at least one byte after the leading byte.
                    if let Some((_, key)) = inner_val.split_first() {
                        if !key.is_empty() {
                            return Ok(key.to_vec());
                        }
                    }
                }
                inner = &inner[inner_used..];
            }
        }
        remaining = &remaining[used..];
        index += 1;
    }

    Err(CredSspError::Asn1Decode(
        "SubjectPublicKey not found in certificate".into(),
    ))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A request with a nego token and nonce survives an encode/decode cycle.
    #[test]
    fn ts_request_encode_decode_roundtrip() {
        let nonce = [0xAA_u8; 32];
        let wire = encode_ts_request(6, Some(b"ntlm_token"), None, None, Some(&nonce));
        let parsed = decode_ts_request(&wire).unwrap();
        assert_eq!(parsed.version, 6);
        assert_eq!(parsed.nego_token, Some(b"ntlm_token".to_vec()));
        assert!(parsed.auth_info.is_none());
        assert!(parsed.pub_key_auth.is_none());
        assert_eq!(parsed.client_nonce, Some(nonce.to_vec()));
    }

    /// Only `pub_key_auth` set: it round-trips and the nego token stays absent.
    #[test]
    fn ts_request_with_pub_key_auth() {
        let wire = encode_ts_request(6, None, Some(b"encrypted_hash"), None, None);
        let parsed = decode_ts_request(&wire).unwrap();
        assert_eq!(parsed.version, 6);
        assert!(parsed.nego_token.is_none());
        assert_eq!(parsed.pub_key_auth, Some(b"encrypted_hash".to_vec()));
    }

    /// Only `auth_info` set: it round-trips.
    #[test]
    fn ts_request_with_auth_info() {
        let wire = encode_ts_request(6, None, None, Some(b"encrypted_creds"), None);
        let parsed = decode_ts_request(&wire).unwrap();
        assert_eq!(parsed.auth_info, Some(b"encrypted_creds".to_vec()));
    }

    /// Credentials encode to a SEQUENCE of plausible size.
    #[test]
    fn ts_credentials_encoding() {
        let encoded = encode_ts_credentials("DOMAIN", "user", "pass");
        assert_eq!(encoded[0], TAG_SEQUENCE);
        assert!(encoded.len() > 20);
    }

    /// The initial SPNEGO wrapper starts with 0x60 and embeds the SPNEGO OID.
    #[test]
    fn spnego_init_wraps_ntlm() {
        let type1 = b"NTLMSSP\x00\x01\x00\x00\x00";
        let message = encode_spnego_init(type1);
        assert_eq!(message[0], 0x60);
        let contains_oid = message
            .windows(SPNEGO_OID.len())
            .any(|window| window == SPNEGO_OID);
        assert!(contains_oid);
    }

    /// The token wrapped by `encode_spnego_init` is recovered by the decoder.
    #[test]
    fn spnego_init_roundtrip() {
        let ntlm_token = b"test_ntlm_type1_token";
        let message = encode_spnego_init(ntlm_token);
        let recovered = decode_spnego_token(&message).unwrap();
        assert_eq!(recovered, ntlm_token);
    }

    /// The token wrapped by `encode_spnego_response` is recovered by the decoder.
    #[test]
    fn spnego_response_roundtrip() {
        let ntlm_token = b"test_ntlm_type3_token";
        let message = encode_spnego_response(ntlm_token, None);
        let recovered = decode_spnego_token(&message).unwrap();
        assert_eq!(recovered, ntlm_token);
    }

    /// Lengths below 0x80 use the single-byte form.
    #[test]
    fn encode_length_short() {
        assert_eq!(encode_length(0), vec![0]);
        assert_eq!(encode_length(127), vec![127]);
    }

    /// Lengths 128..=255 use one length byte after the 0x81 prefix.
    #[test]
    fn encode_length_medium() {
        assert_eq!(encode_length(128), vec![0x81, 128]);
        assert_eq!(encode_length(255), vec![0x81, 255]);
    }

    /// Lengths 256..=65535 use two length bytes after the 0x82 prefix.
    #[test]
    fn encode_length_long() {
        assert_eq!(encode_length(256), vec![0x82, 1, 0]);
        assert_eq!(encode_length(65535), vec![0x82, 255, 255]);
    }

    /// Small integers encode minimally.
    #[test]
    fn integer_encoding() {
        assert_eq!(encode_integer_value(0), vec![TAG_INTEGER, 1, 0]);
        assert_eq!(encode_integer_value(6), vec![TAG_INTEGER, 1, 6]);
        assert_eq!(encode_integer_value(256), vec![TAG_INTEGER, 2, 1, 0]);
    }

    /// Lengths above 65535 use the three-byte form.
    #[test]
    fn encode_length_very_long() {
        let encoded = encode_length(70000);
        assert_eq!(encoded[0], 0x83);
        assert_eq!(encoded.len(), 4);
    }

    /// A value with its top bit set gets a 0x00 pad byte.
    #[test]
    fn integer_encoding_high_bit() {
        let encoded = encode_integer_value(128);
        assert_eq!(encoded, vec![TAG_INTEGER, 2, 0x00, 0x80]);
    }

    #[test]
    fn decode_length_error_empty() {
        let outcome = decode_length(&[]);
        assert!(outcome.is_err());
    }

    /// Two length bytes announced, only one supplied.
    #[test]
    fn decode_length_error_truncated() {
        let outcome = decode_length(&[0x82, 0x01]);
        assert!(outcome.is_err());
    }

    /// 0x80 announces zero length bytes, which is rejected.
    #[test]
    fn decode_length_error_zero_num_bytes() {
        let outcome = decode_length(&[0x80]);
        assert!(outcome.is_err());
    }

    #[test]
    fn read_tlv_error_empty() {
        let outcome = read_tlv(&[]);
        assert!(outcome.is_err());
    }

    /// Declared length 10 with only two value bytes present.
    #[test]
    fn read_tlv_error_truncated_value() {
        let outcome = read_tlv(&[0x30, 10, 0x00, 0x00]);
        assert!(outcome.is_err());
    }

    #[test]
    fn decode_octet_string_wrong_tag() {
        let integer = encode_integer_value(42);
        assert!(decode_octet_string(&integer).is_err());
    }

    #[test]
    fn decode_integer_wrong_tag() {
        let octets = encode_octet_string(b"nope");
        assert!(decode_integer(&octets).is_err());
    }

    #[test]
    fn decode_ts_request_not_sequence() {
        let octets = encode_octet_string(b"bad");
        assert!(decode_ts_request(&octets).is_err());
    }

    #[test]
    fn decode_spnego_token_bad_tag() {
        let octets = encode_octet_string(b"not spnego");
        assert!(decode_spnego_token(&octets).is_err());
    }

    /// Builds a minimal certificate-shaped structure and checks that the key
    /// bytes after the BIT STRING's leading byte are returned.
    #[test]
    fn extract_subject_public_key_from_self_signed_cert() {
        let algorithm = encode_sequence(&[TAG_OID, 3, 0x2a, 0x86, 0x48]);
        let empty = encode_sequence(&[]);

        // Key info: algorithm followed by the key BIT STRING.
        let key_bits = encode_tlv(TAG_BIT_STRING, &[0x00, 0x30, 0x0d]);
        let spki = encode_sequence(&[algorithm.clone(), key_bits].concat());

        // version, serial, signature algorithm, issuer, validity, subject, key info.
        let tbs = encode_sequence(
            &[
                encode_context_tag(0, &encode_integer_value(2)),
                encode_integer_value(1),
                algorithm.clone(),
                empty.clone(),
                empty.clone(),
                empty.clone(),
                spki,
            ]
            .concat(),
        );

        let signature = encode_tlv(TAG_BIT_STRING, &[0x00, 0xFF]);
        let cert = encode_sequence(&[tbs, algorithm, signature].concat());

        let key = extract_subject_public_key(&cert).unwrap();
        assert_eq!(key, vec![0x30, 0x0d]);
    }

    /// An empty SEQUENCE has no inner element to walk.
    #[test]
    fn extract_subject_public_key_bad_cert() {
        let outcome = extract_subject_public_key(&[0x30, 0x00]);
        assert!(outcome.is_err());
    }

    #[test]
    fn find_context_tag_not_found() {
        let only_tag_zero = encode_context_tag(0, &encode_integer_value(1));
        assert!(find_context_tag(&only_tag_zero, 5).is_none());
    }

    /// Every encodable field set at once round-trips.
    #[test]
    fn ts_request_all_fields() {
        let nonce = [0xBB_u8; 32];
        let wire = encode_ts_request(
            6,
            Some(b"nego"),
            Some(b"pubkey"),
            Some(b"creds"),
            Some(&nonce),
        );
        let parsed = decode_ts_request(&wire).unwrap();
        assert_eq!(parsed.version, 6);
        assert_eq!(parsed.nego_token, Some(b"nego".to_vec()));
        assert_eq!(parsed.pub_key_auth, Some(b"pubkey".to_vec()));
        assert_eq!(parsed.auth_info, Some(b"creds".to_vec()));
        assert_eq!(parsed.client_nonce, Some(nonce.to_vec()));
    }

    /// 255 still fits the 0x81 form; 256 switches to 0x82.
    #[test]
    fn encode_length_boundary_255_uses_one_byte_form() {
        let at_255 = encode_length(255);
        assert_eq!(at_255, vec![0x81, 255]);
        let at_256 = encode_length(256);
        assert_eq!(at_256, vec![0x82, 1, 0]);
    }

    /// Field number 3 must yield identifier byte 0xA3.
    #[test]
    fn encode_context_tag_uses_or_not_xor() {
        let encoded = encode_context_tag(3, &[0x42]);
        assert_eq!(encoded[0], 0xA3);
        assert_eq!(encoded[1], 1);
        assert_eq!(encoded[2], 0x42);
    }
}