use crate::error::Result;
use crate::msg::header::Header;
use crate::pack::{Guid, Pack, ReadCursor, Unpack, WriteCursor};
use crate::types::flags::{Capabilities, SecurityMode};
use crate::types::Dialect;
use crate::Error;
// Negotiate context type identifiers. These are the values produced by
// `context_type_id` and matched on by `unpack_negotiate_context`; any other
// value is carried as `NegotiateContext::Unknown`.
pub const NEGOTIATE_CONTEXT_PREAUTH_INTEGRITY: u16 = 0x0001;
pub const NEGOTIATE_CONTEXT_ENCRYPTION: u16 = 0x0002;
pub const NEGOTIATE_CONTEXT_COMPRESSION: u16 = 0x0003;
pub const NEGOTIATE_CONTEXT_SIGNING: u16 = 0x0008;
// Hash algorithm identifier for `NegotiateContext::PreauthIntegrity::hash_algorithms`.
pub const HASH_ALGORITHM_SHA512: u16 = 0x0001;
// Cipher identifiers for `NegotiateContext::Encryption::ciphers`.
// NOTE(review): values presumably mirror the MS-SMB2 cipher IDs — confirm against the spec.
pub const CIPHER_AES_128_CCM: u16 = 0x0001;
pub const CIPHER_AES_128_GCM: u16 = 0x0002;
pub const CIPHER_AES_256_CCM: u16 = 0x0003;
pub const CIPHER_AES_256_GCM: u16 = 0x0004;
// Signing algorithm identifiers for `NegotiateContext::Signing::algorithms`.
pub const SIGNING_HMAC_SHA256: u16 = 0x0000;
pub const SIGNING_AES_CMAC: u16 = 0x0001;
pub const SIGNING_AES_GMAC: u16 = 0x0002;
// Compression algorithm identifiers for `NegotiateContext::Compression::algorithms`.
pub const COMPRESSION_NONE: u16 = 0x0000;
pub const COMPRESSION_LZNT1: u16 = 0x0001;
pub const COMPRESSION_LZ77: u16 = 0x0002;
pub const COMPRESSION_LZ77_HUFFMAN: u16 = 0x0003;
pub const COMPRESSION_PATTERN_V1: u16 = 0x0004;
pub const COMPRESSION_LZ4: u16 = 0x0005;
// Values for `NegotiateContext::Compression::flags` (a 32-bit field on the wire).
pub const COMPRESSION_FLAG_NONE: u32 = 0x0000_0000;
pub const COMPRESSION_FLAG_CHAINED: u32 = 0x0000_0001;
/// One negotiate context as carried at the end of a negotiate request or
/// response.
///
/// On the wire each context is an 8-byte header (type `u16`, data length
/// `u16`, reserved `u32`, all little-endian) followed by a type-specific
/// payload; see `pack_negotiate_contexts` and `unpack_negotiate_context`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NegotiateContext {
/// Context type `NEGOTIATE_CONTEXT_PREAUTH_INTEGRITY`.
///
/// Payload: algorithm count `u16`, salt length `u16`, the algorithm IDs,
/// then the salt bytes.
PreauthIntegrity {
/// Hash algorithm identifiers, each written as a little-endian `u16`.
hash_algorithms: Vec<u16>,
/// Salt bytes, written verbatim after the algorithm list.
salt: Vec<u8>,
},
/// Context type `NEGOTIATE_CONTEXT_ENCRYPTION`.
///
/// Payload: cipher count `u16` followed by the cipher IDs.
Encryption {
/// Cipher identifiers, each written as a little-endian `u16`.
ciphers: Vec<u16>,
},
/// Context type `NEGOTIATE_CONTEXT_COMPRESSION`.
///
/// Payload: algorithm count `u16`, two zero padding bytes, `flags` as
/// `u32`, then the algorithm IDs.
Compression {
/// Raw 32-bit flags field (see the `COMPRESSION_FLAG_*` constants).
flags: u32,
/// Compression algorithm identifiers, each a little-endian `u16`.
algorithms: Vec<u16>,
},
/// Context type `NEGOTIATE_CONTEXT_SIGNING`.
///
/// Payload: algorithm count `u16` followed by the algorithm IDs.
Signing {
/// Signing algorithm identifiers, each a little-endian `u16`.
algorithms: Vec<u16>,
},
/// Any context whose type is not one of the four recognized constants.
/// The payload is kept as opaque bytes and written back unchanged.
Unknown {
/// The raw context type value read from (or written to) the header.
context_type: u16,
/// The raw payload bytes.
data: Vec<u8>,
},
}
/// Writes the type-specific payload of `ctx` to `cursor`.
///
/// Only the payload is emitted; the 8-byte context header (type, data length,
/// reserved) is the caller's responsibility. All multi-byte integers are
/// little-endian. Element counts and the salt length are narrowed to `u16`
/// with `as`, so collections longer than `u16::MAX` are silently truncated in
/// the count field.
fn pack_context_data(ctx: &NegotiateContext, cursor: &mut WriteCursor) {
    /// Emits every identifier in `ids` as a little-endian `u16`.
    fn put_ids(cursor: &mut WriteCursor, ids: &[u16]) {
        for id in ids.iter().copied() {
            cursor.write_u16_le(id);
        }
    }

    match ctx {
        NegotiateContext::Unknown { data, .. } => {
            // Opaque payload: written back exactly as stored.
            cursor.write_bytes(data);
        }
        NegotiateContext::Signing { algorithms } => {
            cursor.write_u16_le(algorithms.len() as u16);
            put_ids(cursor, algorithms);
        }
        NegotiateContext::Encryption { ciphers } => {
            cursor.write_u16_le(ciphers.len() as u16);
            put_ids(cursor, ciphers);
        }
        NegotiateContext::Compression { flags, algorithms } => {
            cursor.write_u16_le(algorithms.len() as u16);
            // Two zero padding bytes sit between the count and the flags.
            cursor.write_u16_le(0);
            cursor.write_u32_le(*flags);
            put_ids(cursor, algorithms);
        }
        NegotiateContext::PreauthIntegrity {
            hash_algorithms,
            salt,
        } => {
            // Both lengths come first, then the two variable-length parts.
            cursor.write_u16_le(hash_algorithms.len() as u16);
            cursor.write_u16_le(salt.len() as u16);
            put_ids(cursor, hash_algorithms);
            cursor.write_bytes(salt);
        }
    }
}
/// Returns the 16-bit context type value written in the header for `ctx`.
///
/// Recognized variants map to their `NEGOTIATE_CONTEXT_*` constant;
/// `Unknown` reports the type value it was constructed with.
fn context_type_id(ctx: &NegotiateContext) -> u16 {
    type Ctx = NegotiateContext;
    match *ctx {
        Ctx::Unknown { context_type, .. } => context_type,
        Ctx::Signing { .. } => NEGOTIATE_CONTEXT_SIGNING,
        Ctx::Compression { .. } => NEGOTIATE_CONTEXT_COMPRESSION,
        Ctx::Encryption { .. } => NEGOTIATE_CONTEXT_ENCRYPTION,
        Ctx::PreauthIntegrity { .. } => NEGOTIATE_CONTEXT_PREAUTH_INTEGRITY,
    }
}
/// Computes the payload size in bytes that `pack_context_data` emits for
/// `ctx` (the 8-byte context header is not included).
fn context_data_len(ctx: &NegotiateContext) -> usize {
    /// Bytes occupied by a list of `u16` identifiers.
    fn ids_bytes(ids: &[u16]) -> usize {
        ids.len() * 2
    }

    match ctx {
        NegotiateContext::Unknown { data, .. } => data.len(),
        // count(2) + ids
        NegotiateContext::Signing { algorithms } => 2 + ids_bytes(algorithms),
        // count(2) + ids
        NegotiateContext::Encryption { ciphers } => 2 + ids_bytes(ciphers),
        // count(2) + padding(2) + flags(4) + ids
        NegotiateContext::Compression { algorithms, .. } => 2 + 2 + 4 + ids_bytes(algorithms),
        // count(2) + salt length(2) + ids + salt
        NegotiateContext::PreauthIntegrity {
            hash_algorithms,
            salt,
        } => 2 + 2 + ids_bytes(hash_algorithms) + salt.len(),
    }
}
/// Writes every context in `contexts` to `cursor`, header and payload.
///
/// Each context after the first is preceded by `cursor.align_to(8)`; nothing
/// is appended after the last one. The header is the context type (`u16`),
/// the payload length narrowed to `u16`, and a zero `u32` reserved field.
fn pack_negotiate_contexts(contexts: &[NegotiateContext], cursor: &mut WriteCursor) {
    /// Emits one context: 8-byte header followed by its payload.
    fn emit(ctx: &NegotiateContext, cursor: &mut WriteCursor) {
        cursor.write_u16_le(context_type_id(ctx));
        cursor.write_u16_le(context_data_len(ctx) as u16);
        cursor.write_u32_le(0);
        pack_context_data(ctx, cursor);
    }

    let mut remaining = contexts.iter();
    if let Some(first) = remaining.next() {
        // The caller positions the cursor for the first context.
        emit(first, cursor);
    }
    for ctx in remaining {
        cursor.align_to(8);
        emit(ctx, cursor);
    }
}
fn unpack_negotiate_context(cursor: &mut ReadCursor<'_>) -> Result<NegotiateContext> {
let context_type = cursor.read_u16_le()?;
let data_length = cursor.read_u16_le()? as usize;
let _reserved = cursor.read_u32_le()?;
match context_type {
NEGOTIATE_CONTEXT_PREAUTH_INTEGRITY => {
let hash_count = cursor.read_u16_le()? as usize;
let salt_length = cursor.read_u16_le()? as usize;
let mut hash_algorithms = Vec::with_capacity(hash_count);
for _ in 0..hash_count {
hash_algorithms.push(cursor.read_u16_le()?);
}
let salt = cursor.read_bytes_bounded(salt_length)?.to_vec();
Ok(NegotiateContext::PreauthIntegrity {
hash_algorithms,
salt,
})
}
NEGOTIATE_CONTEXT_ENCRYPTION => {
let cipher_count = cursor.read_u16_le()? as usize;
let mut ciphers = Vec::with_capacity(cipher_count);
for _ in 0..cipher_count {
ciphers.push(cursor.read_u16_le()?);
}
Ok(NegotiateContext::Encryption { ciphers })
}
NEGOTIATE_CONTEXT_COMPRESSION => {
let alg_count = cursor.read_u16_le()? as usize;
let _padding = cursor.read_u16_le()?;
let flags = cursor.read_u32_le()?;
let mut algorithms = Vec::with_capacity(alg_count);
for _ in 0..alg_count {
algorithms.push(cursor.read_u16_le()?);
}
Ok(NegotiateContext::Compression { flags, algorithms })
}
NEGOTIATE_CONTEXT_SIGNING => {
let alg_count = cursor.read_u16_le()? as usize;
let mut algorithms = Vec::with_capacity(alg_count);
for _ in 0..alg_count {
algorithms.push(cursor.read_u16_le()?);
}
Ok(NegotiateContext::Signing { algorithms })
}
_ => {
let data = cursor.read_bytes_bounded(data_length)?.to_vec();
Ok(NegotiateContext::Unknown { context_type, data })
}
}
}
/// Parses `count` negotiate contexts from `cursor`.
///
/// Before every context except the first, the cursor is advanced to the next
/// position that is a multiple of 8 (based on `cursor.position()`).
///
/// # Errors
///
/// Returns the first error produced while skipping padding or parsing a
/// context.
fn unpack_negotiate_contexts(
    cursor: &mut ReadCursor<'_>,
    count: usize,
) -> Result<Vec<NegotiateContext>> {
    let mut parsed = Vec::with_capacity(count);
    while parsed.len() < count {
        if !parsed.is_empty() {
            // Skip the padding the writer inserted between contexts.
            let misalignment = cursor.position() % 8;
            if misalignment != 0 {
                cursor.skip(8 - misalignment)?;
            }
        }
        let ctx = unpack_negotiate_context(cursor)?;
        parsed.push(ctx);
    }
    Ok(parsed)
}
/// Body of a negotiate request, serialized via `Pack` and parsed via `Unpack`.
///
/// Offsets inside the body are computed as `Header::SIZE` plus the distance
/// from the start of this structure.
/// NOTE(review): this presumably assumes the message header immediately
/// precedes the body — confirm against the caller.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NegotiateRequest {
/// Security mode flags; written as a little-endian `u16` from `bits()`.
pub security_mode: SecurityMode,
/// Capability flags; written as a little-endian `u32` from `bits()`.
pub capabilities: Capabilities,
/// Client GUID, packed with `Guid::pack`.
pub client_guid: Guid,
/// Dialects offered; each is written as a little-endian `u16`.
pub dialects: Vec<Dialect>,
/// Negotiate contexts. They are only written (and only parsed) when
/// `dialects` contains `Dialect::Smb3_1_1`.
pub negotiate_contexts: Vec<NegotiateContext>,
}
impl NegotiateRequest {
    /// Value written to, and required in, the leading structure-size field.
    pub const STRUCTURE_SIZE: u16 = 36;

    /// Returns `true` when `Dialect::Smb3_1_1` is among the offered dialects,
    /// which is what enables the negotiate-context section.
    fn has_smb311(&self) -> bool {
        self.dialects
            .iter()
            .any(|dialect| *dialect == Dialect::Smb3_1_1)
    }
}
impl Pack for NegotiateRequest {
    /// Serializes the request body at the cursor's current position.
    ///
    /// Layout: structure size, dialect count, security mode, a zero `u16`,
    /// capabilities, client GUID, then an 8-byte field, then the dialect list.
    /// When `Dialect::Smb3_1_1` is offered, the 8-byte field holds the context
    /// offset (`u32`, back-patched), the context count (`u16`) and a zero
    /// `u16`, and the contexts follow the dialect list after padding to an
    /// 8-byte boundary. Otherwise the 8-byte field is all zeros and nothing
    /// follows the dialects.
    fn pack(&self, cursor: &mut WriteCursor) {
        let body_start = cursor.position();
        let with_contexts = self.has_smb311();

        cursor.write_u16_le(Self::STRUCTURE_SIZE);
        cursor.write_u16_le(self.dialects.len() as u16);
        cursor.write_u16_le(self.security_mode.bits());
        cursor.write_u16_le(0);
        cursor.write_u32_le(self.capabilities.bits());
        self.client_guid.pack(cursor);

        // Remember where the context offset lives so it can be patched later.
        let offset_slot = cursor.position();
        if with_contexts {
            cursor.write_u32_le(0);
            cursor.write_u16_le(self.negotiate_contexts.len() as u16);
            cursor.write_u16_le(0);
        } else {
            cursor.write_u64_le(0);
        }

        for dialect in self.dialects.iter().copied() {
            cursor.write_u16_le(dialect.into());
        }

        if !with_contexts {
            return;
        }

        // Pad so the first context starts on an 8-byte boundary, measured
        // from the start of the header rather than the start of this body.
        let misalignment = (Header::SIZE + (cursor.position() - body_start)) % 8;
        if misalignment != 0 {
            cursor.write_zeros(8 - misalignment);
        }

        let ctx_offset = Header::SIZE + (cursor.position() - body_start);
        cursor.set_u32_le_at(offset_slot, ctx_offset as u32);
        pack_negotiate_contexts(&self.negotiate_contexts, cursor);
    }
}
impl Unpack for NegotiateRequest {
    /// Parses a request body starting at the cursor's current position.
    ///
    /// Negotiate contexts are parsed only when the dialect list contains
    /// `Dialect::Smb3_1_1`; in that case the 8 bytes preceding the dialect
    /// list are interpreted as context offset (`u32`) and context count
    /// (`u16`). If the offset lies beyond the current position the gap is
    /// skipped; an offset at or before the current position is ignored.
    ///
    /// # Errors
    ///
    /// Fails when the structure size is not `STRUCTURE_SIZE`, when a dialect
    /// value is not recognized, or when any cursor read fails.
    fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
        let body_start = cursor.position();

        let declared_size = cursor.read_u16_le()?;
        if declared_size != Self::STRUCTURE_SIZE {
            return Err(Error::invalid_data(format!(
                "invalid NegotiateRequest structure size: expected {}, got {}",
                Self::STRUCTURE_SIZE,
                declared_size
            )));
        }

        let dialect_count = cursor.read_u16_le()? as usize;
        let security_mode = SecurityMode::new(cursor.read_u16_le()?);
        let _reserved = cursor.read_u16_le()?;
        let capabilities = Capabilities::new(cursor.read_u32_le()?);
        let client_guid = Guid::unpack(cursor)?;
        // Meaning depends on whether SMB 3.1.1 is offered, which is only
        // known once the dialect list has been read; keep the raw bytes.
        let tail = cursor.read_bytes(8)?;

        let mut dialects = Vec::with_capacity(dialect_count);
        for _ in 0..dialect_count {
            let raw = cursor.read_u16_le()?;
            match Dialect::try_from(raw) {
                Ok(dialect) => dialects.push(dialect),
                Err(_) => {
                    return Err(Error::invalid_data(format!(
                        "invalid dialect: 0x{:04X}",
                        raw
                    )))
                }
            }
        }

        let negotiate_contexts = if !dialects.contains(&Dialect::Smb3_1_1) {
            Vec::new()
        } else {
            let ctx_offset = u32::from_le_bytes([tail[0], tail[1], tail[2], tail[3]]) as usize;
            let ctx_count = u16::from_le_bytes([tail[4], tail[5]]) as usize;
            let here = Header::SIZE + (cursor.position() - body_start);
            let gap = ctx_offset.saturating_sub(here);
            if gap > 0 {
                cursor.skip(gap)?;
            }
            unpack_negotiate_contexts(cursor, ctx_count)?
        };

        Ok(NegotiateRequest {
            security_mode,
            capabilities,
            client_guid,
            dialects,
            negotiate_contexts,
        })
    }
}
/// Body of a negotiate response, serialized via `Pack` and parsed via `Unpack`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NegotiateResponse {
/// Security mode flags; written as a little-endian `u16` from `bits()`.
pub security_mode: SecurityMode,
/// The single dialect carried in the response. Negotiate contexts are only
/// written and parsed when this is `Dialect::Smb3_1_1`.
pub dialect_revision: Dialect,
/// Server GUID, packed with `Guid::pack`.
pub server_guid: Guid,
/// Capability flags; written as a little-endian `u32` from `bits()`.
pub capabilities: Capabilities,
/// Written as a little-endian `u32`.
pub max_transact_size: u32,
/// Written as a little-endian `u32`.
pub max_read_size: u32,
/// Written as a little-endian `u32`.
pub max_write_size: u32,
/// Written as a little-endian `u64`.
/// NOTE(review): presumably a Windows FILETIME value — confirm.
pub system_time: u64,
/// Written as a little-endian `u64`.
/// NOTE(review): presumably a Windows FILETIME value — confirm.
pub server_start_time: u64,
/// Opaque security blob written right after the fixed fields; its length
/// is stored in a `u16` field.
pub security_buffer: Vec<u8>,
/// Negotiate contexts appended after the security buffer (SMB 3.1.1 only).
pub negotiate_contexts: Vec<NegotiateContext>,
}
impl NegotiateResponse {
    /// Value written to, and required in, the leading structure-size field.
    pub const STRUCTURE_SIZE: u16 = 65;

    /// Returns `true` when the selected dialect is `Dialect::Smb3_1_1`, the
    /// only dialect for which negotiate contexts are emitted.
    fn is_smb311(&self) -> bool {
        Dialect::Smb3_1_1 == self.dialect_revision
    }
}
impl Pack for NegotiateResponse {
    /// Serializes the response body at the cursor's current position.
    ///
    /// The fixed fields are followed by the security buffer. For SMB 3.1.1
    /// with at least one context, zero padding up to an 8-byte boundary
    /// (measured from the start of the header) and the contexts follow, and
    /// the context offset field is back-patched; otherwise that field stays 0.
    fn pack(&self, cursor: &mut WriteCursor) {
        let body_start = cursor.position();
        let smb311 = self.is_smb311();

        cursor.write_u16_le(Self::STRUCTURE_SIZE);
        cursor.write_u16_le(self.security_mode.bits());
        cursor.write_u16_le(self.dialect_revision.into());
        // The count field is only meaningful for SMB 3.1.1; zero otherwise.
        let context_count = if smb311 {
            self.negotiate_contexts.len() as u16
        } else {
            0
        };
        cursor.write_u16_le(context_count);

        self.server_guid.pack(cursor);
        cursor.write_u32_le(self.capabilities.bits());
        cursor.write_u32_le(self.max_transact_size);
        cursor.write_u32_le(self.max_read_size);
        cursor.write_u32_le(self.max_write_size);
        cursor.write_u64_le(self.system_time);
        cursor.write_u64_le(self.server_start_time);

        // The security buffer directly follows the 64 bytes of fixed fields.
        let sec_buf_offset = (Header::SIZE + 64) as u16;
        cursor.write_u16_le(sec_buf_offset);
        cursor.write_u16_le(self.security_buffer.len() as u16);

        let offset_slot = cursor.position();
        cursor.write_u32_le(0);
        cursor.write_bytes(&self.security_buffer);

        if !smb311 || self.negotiate_contexts.is_empty() {
            return;
        }

        let misalignment = (Header::SIZE + (cursor.position() - body_start)) % 8;
        if misalignment != 0 {
            cursor.write_zeros(8 - misalignment);
        }

        let ctx_offset = Header::SIZE + (cursor.position() - body_start);
        cursor.set_u32_le_at(offset_slot, ctx_offset as u32);
        pack_negotiate_contexts(&self.negotiate_contexts, cursor);
    }
}
impl Unpack for NegotiateResponse {
    /// Parses a response body starting at the cursor's current position.
    ///
    /// The security buffer is read immediately after the fixed fields; the
    /// security buffer offset field is read but not used. Negotiate contexts
    /// are parsed only when the dialect is `Dialect::Smb3_1_1` and the
    /// context count is non-zero. If the context offset lies beyond the
    /// current position the gap is skipped.
    ///
    /// # Errors
    ///
    /// Fails when the structure size is not `STRUCTURE_SIZE`, when the
    /// dialect value is not recognized, or when any cursor read fails.
    fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
        let body_start = cursor.position();

        let declared_size = cursor.read_u16_le()?;
        if declared_size != Self::STRUCTURE_SIZE {
            return Err(Error::invalid_data(format!(
                "invalid NegotiateResponse structure size: expected {}, got {}",
                Self::STRUCTURE_SIZE,
                declared_size
            )));
        }

        let security_mode = SecurityMode::new(cursor.read_u16_le()?);

        let raw_dialect = cursor.read_u16_le()?;
        let dialect_revision = match Dialect::try_from(raw_dialect) {
            Ok(dialect) => dialect,
            Err(_) => {
                return Err(Error::invalid_data(format!(
                    "invalid dialect revision: 0x{:04X}",
                    raw_dialect
                )))
            }
        };

        let negotiate_context_count = cursor.read_u16_le()? as usize;
        let server_guid = Guid::unpack(cursor)?;
        let capabilities = Capabilities::new(cursor.read_u32_le()?);
        let max_transact_size = cursor.read_u32_le()?;
        let max_read_size = cursor.read_u32_le()?;
        let max_write_size = cursor.read_u32_le()?;
        let system_time = cursor.read_u64_le()?;
        let server_start_time = cursor.read_u64_le()?;
        let _sec_buf_offset = cursor.read_u16_le()?;
        let sec_buf_length = cursor.read_u16_le()? as usize;
        let negotiate_context_offset = cursor.read_u32_le()? as usize;

        let security_buffer = match sec_buf_length {
            0 => Vec::new(),
            len => cursor.read_bytes_bounded(len)?.to_vec(),
        };

        let expects_contexts =
            dialect_revision == Dialect::Smb3_1_1 && negotiate_context_count > 0;
        let negotiate_contexts = if !expects_contexts {
            Vec::new()
        } else {
            let here = Header::SIZE + (cursor.position() - body_start);
            let gap = negotiate_context_offset.saturating_sub(here);
            if gap > 0 {
                cursor.skip(gap)?;
            }
            unpack_negotiate_contexts(cursor, negotiate_context_count)?
        };

        Ok(NegotiateResponse {
            security_mode,
            dialect_revision,
            server_guid,
            capabilities,
            max_transact_size,
            max_read_size,
            max_write_size,
            system_time,
            server_start_time,
            security_buffer,
            negotiate_contexts,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed, non-zero GUID shared by the tests that need one.
    fn sample_guid() -> Guid {
        Guid {
            data1: 0x6BA7B810,
            data2: 0x9DAD,
            data3: 0x11D1,
            data4: [0x80, 0xB4, 0x00, 0xC0, 0x4F, 0xD4, 0x30, 0xC8],
        }
    }

    /// Packs `value` into a fresh cursor and unpacks it again, panicking if
    /// decoding fails.
    fn roundtrip<T: Pack + Unpack>(value: &T) -> T {
        let mut w = WriteCursor::new();
        value.pack(&mut w);
        let bytes = w.into_inner();
        let mut r = ReadCursor::new(&bytes);
        let decoded = T::unpack(&mut r).unwrap();
        decoded
    }

    /// Packs `contexts` as a bare context list and parses the same number of
    /// contexts back, panicking if decoding fails.
    fn roundtrip_contexts(contexts: &[NegotiateContext]) -> Vec<NegotiateContext> {
        let mut w = WriteCursor::new();
        pack_negotiate_contexts(contexts, &mut w);
        let bytes = w.into_inner();
        let mut r = ReadCursor::new(&bytes);
        let decoded = unpack_negotiate_contexts(&mut r, contexts.len()).unwrap();
        decoded
    }

    /// Asserts that `result` is an error whose message mentions the
    /// structure size.
    fn assert_structure_size_error<T: std::fmt::Debug>(result: Result<T>) {
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("structure size"), "error was: {err}");
    }

    /// A response with zeroed sizes/times, no security buffer and no contexts.
    fn minimal_response(dialect_revision: Dialect) -> NegotiateResponse {
        NegotiateResponse {
            security_mode: SecurityMode::default(),
            dialect_revision,
            server_guid: Guid::ZERO,
            capabilities: Capabilities::default(),
            max_transact_size: 0,
            max_read_size: 0,
            max_write_size: 0,
            system_time: 0,
            server_start_time: 0,
            security_buffer: Vec::new(),
            negotiate_contexts: Vec::new(),
        }
    }

    #[test]
    fn negotiate_request_roundtrip_without_contexts() {
        let original = NegotiateRequest {
            security_mode: SecurityMode::new(SecurityMode::SIGNING_ENABLED),
            capabilities: Capabilities::new(Capabilities::DFS | Capabilities::LARGE_MTU),
            client_guid: sample_guid(),
            dialects: vec![Dialect::Smb2_0_2, Dialect::Smb2_1, Dialect::Smb3_0],
            negotiate_contexts: Vec::new(),
        };
        let decoded = roundtrip(&original);
        assert_eq!(decoded.security_mode.bits(), original.security_mode.bits());
        assert_eq!(decoded.capabilities.bits(), original.capabilities.bits());
        assert_eq!(decoded.client_guid, original.client_guid);
        assert_eq!(decoded.dialects, original.dialects);
        assert!(decoded.negotiate_contexts.is_empty());
    }

    #[test]
    fn negotiate_request_roundtrip_with_contexts() {
        let original = NegotiateRequest {
            security_mode: SecurityMode::new(
                SecurityMode::SIGNING_ENABLED | SecurityMode::SIGNING_REQUIRED,
            ),
            capabilities: Capabilities::new(
                Capabilities::DFS
                    | Capabilities::LEASING
                    | Capabilities::LARGE_MTU
                    | Capabilities::ENCRYPTION,
            ),
            client_guid: sample_guid(),
            dialects: vec![
                Dialect::Smb2_0_2,
                Dialect::Smb2_1,
                Dialect::Smb3_0,
                Dialect::Smb3_0_2,
                Dialect::Smb3_1_1,
            ],
            negotiate_contexts: vec![
                NegotiateContext::PreauthIntegrity {
                    hash_algorithms: vec![HASH_ALGORITHM_SHA512],
                    salt: vec![0xDE, 0xAD, 0xBE, 0xEF],
                },
                NegotiateContext::Encryption {
                    ciphers: vec![CIPHER_AES_128_GCM, CIPHER_AES_128_CCM],
                },
                NegotiateContext::Signing {
                    algorithms: vec![SIGNING_AES_GMAC, SIGNING_AES_CMAC],
                },
                NegotiateContext::Compression {
                    flags: COMPRESSION_FLAG_CHAINED,
                    algorithms: vec![COMPRESSION_LZ77, COMPRESSION_LZNT1],
                },
            ],
        };
        let decoded = roundtrip(&original);
        assert_eq!(decoded.security_mode.bits(), original.security_mode.bits());
        assert_eq!(decoded.capabilities.bits(), original.capabilities.bits());
        assert_eq!(decoded.client_guid, original.client_guid);
        assert_eq!(decoded.dialects, original.dialects);
        assert_eq!(decoded.negotiate_contexts.len(), 4);
        assert_eq!(decoded.negotiate_contexts, original.negotiate_contexts);
    }

    #[test]
    fn negotiate_request_structure_size_field() {
        let req = NegotiateRequest {
            security_mode: SecurityMode::default(),
            capabilities: Capabilities::default(),
            client_guid: Guid::ZERO,
            dialects: vec![Dialect::Smb2_0_2],
            negotiate_contexts: Vec::new(),
        };
        let mut w = WriteCursor::new();
        req.pack(&mut w);
        let bytes = w.into_inner();
        assert_eq!(u16::from_le_bytes([bytes[0], bytes[1]]), 36);
    }

    #[test]
    fn negotiate_request_wrong_structure_size() {
        let mut buf = [0u8; 48];
        buf[0..2].copy_from_slice(&99u16.to_le_bytes());
        let mut cursor = ReadCursor::new(&buf);
        assert_structure_size_error(NegotiateRequest::unpack(&mut cursor));
    }

    #[test]
    fn negotiate_request_single_dialect() {
        let original = NegotiateRequest {
            security_mode: SecurityMode::default(),
            capabilities: Capabilities::default(),
            client_guid: Guid::ZERO,
            dialects: vec![Dialect::Smb3_0_2],
            negotiate_contexts: Vec::new(),
        };
        let decoded = roundtrip(&original);
        assert_eq!(decoded.dialects, vec![Dialect::Smb3_0_2]);
    }

    #[test]
    fn negotiate_request_smb311_only() {
        let original = NegotiateRequest {
            security_mode: SecurityMode::new(SecurityMode::SIGNING_ENABLED),
            capabilities: Capabilities::default(),
            client_guid: sample_guid(),
            dialects: vec![Dialect::Smb3_1_1],
            negotiate_contexts: vec![NegotiateContext::PreauthIntegrity {
                hash_algorithms: vec![HASH_ALGORITHM_SHA512],
                salt: vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08],
            }],
        };
        let decoded = roundtrip(&original);
        assert_eq!(decoded.dialects, vec![Dialect::Smb3_1_1]);
        assert_eq!(decoded.negotiate_contexts.len(), 1);
        assert_eq!(
            decoded.negotiate_contexts[0],
            original.negotiate_contexts[0]
        );
    }

    #[test]
    fn negotiate_response_roundtrip_no_contexts() {
        let original = NegotiateResponse {
            security_mode: SecurityMode::new(SecurityMode::SIGNING_ENABLED),
            dialect_revision: Dialect::Smb3_0,
            server_guid: sample_guid(),
            capabilities: Capabilities::new(
                Capabilities::DFS | Capabilities::LEASING | Capabilities::LARGE_MTU,
            ),
            max_transact_size: 8_388_608,
            max_read_size: 8_388_608,
            max_write_size: 8_388_608,
            system_time: 133_485_408_000_000_000,
            server_start_time: 0,
            security_buffer: vec![0x60, 0x28, 0x06, 0x06],
            negotiate_contexts: Vec::new(),
        };
        let decoded = roundtrip(&original);
        assert_eq!(decoded.security_mode.bits(), original.security_mode.bits());
        assert_eq!(decoded.dialect_revision, Dialect::Smb3_0);
        assert_eq!(decoded.server_guid, original.server_guid);
        assert_eq!(decoded.capabilities.bits(), original.capabilities.bits());
        assert_eq!(decoded.max_transact_size, 8_388_608);
        assert_eq!(decoded.max_read_size, 8_388_608);
        assert_eq!(decoded.max_write_size, 8_388_608);
        assert_eq!(decoded.system_time, original.system_time);
        assert_eq!(decoded.server_start_time, 0);
        assert_eq!(decoded.security_buffer, original.security_buffer);
        assert!(decoded.negotiate_contexts.is_empty());
    }

    #[test]
    fn negotiate_response_roundtrip_with_contexts() {
        let original = NegotiateResponse {
            security_mode: SecurityMode::new(SecurityMode::SIGNING_ENABLED),
            dialect_revision: Dialect::Smb3_1_1,
            server_guid: sample_guid(),
            capabilities: Capabilities::new(Capabilities::DFS | Capabilities::ENCRYPTION),
            max_transact_size: 1_048_576,
            max_read_size: 1_048_576,
            max_write_size: 1_048_576,
            system_time: 133_485_408_000_000_000,
            server_start_time: 133_000_000_000_000_000,
            security_buffer: vec![0x60, 0x28],
            negotiate_contexts: vec![
                NegotiateContext::PreauthIntegrity {
                    hash_algorithms: vec![HASH_ALGORITHM_SHA512],
                    salt: vec![0xAA, 0xBB, 0xCC, 0xDD],
                },
                NegotiateContext::Encryption {
                    ciphers: vec![CIPHER_AES_128_GCM],
                },
                NegotiateContext::Signing {
                    algorithms: vec![SIGNING_AES_GMAC],
                },
            ],
        };
        let decoded = roundtrip(&original);
        assert_eq!(decoded.dialect_revision, Dialect::Smb3_1_1);
        assert_eq!(decoded.negotiate_contexts.len(), 3);
        assert_eq!(decoded.negotiate_contexts, original.negotiate_contexts);
        assert_eq!(decoded.security_buffer, original.security_buffer);
    }

    #[test]
    fn negotiate_response_structure_size_field() {
        let resp = minimal_response(Dialect::Smb2_0_2);
        let mut w = WriteCursor::new();
        resp.pack(&mut w);
        let bytes = w.into_inner();
        assert_eq!(u16::from_le_bytes([bytes[0], bytes[1]]), 65);
    }

    #[test]
    fn negotiate_response_wrong_structure_size() {
        let mut buf = [0u8; 70];
        buf[0..2].copy_from_slice(&99u16.to_le_bytes());
        let mut cursor = ReadCursor::new(&buf);
        assert_structure_size_error(NegotiateResponse::unpack(&mut cursor));
    }

    #[test]
    fn negotiate_response_empty_security_buffer() {
        let original = NegotiateResponse {
            security_mode: SecurityMode::new(SecurityMode::SIGNING_ENABLED),
            dialect_revision: Dialect::Smb2_1,
            server_guid: Guid::ZERO,
            capabilities: Capabilities::default(),
            max_transact_size: 65536,
            max_read_size: 65536,
            max_write_size: 65536,
            system_time: 0,
            server_start_time: 0,
            security_buffer: Vec::new(),
            negotiate_contexts: Vec::new(),
        };
        let decoded = roundtrip(&original);
        assert!(decoded.security_buffer.is_empty());
        assert_eq!(decoded.dialect_revision, Dialect::Smb2_1);
    }

    #[test]
    fn context_preauth_integrity_roundtrip() {
        let ctx = NegotiateContext::PreauthIntegrity {
            hash_algorithms: vec![HASH_ALGORITHM_SHA512],
            salt: vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08],
        };
        let decoded = roundtrip_contexts(std::slice::from_ref(&ctx));
        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded[0], ctx);
    }

    #[test]
    fn context_encryption_roundtrip() {
        let ctx = NegotiateContext::Encryption {
            ciphers: vec![
                CIPHER_AES_128_GCM,
                CIPHER_AES_128_CCM,
                CIPHER_AES_256_GCM,
                CIPHER_AES_256_CCM,
            ],
        };
        let decoded = roundtrip_contexts(std::slice::from_ref(&ctx));
        assert_eq!(decoded[0], ctx);
    }

    #[test]
    fn context_signing_roundtrip() {
        let ctx = NegotiateContext::Signing {
            algorithms: vec![SIGNING_AES_GMAC, SIGNING_AES_CMAC, SIGNING_HMAC_SHA256],
        };
        let decoded = roundtrip_contexts(std::slice::from_ref(&ctx));
        assert_eq!(decoded[0], ctx);
    }

    #[test]
    fn context_compression_roundtrip() {
        let ctx = NegotiateContext::Compression {
            flags: COMPRESSION_FLAG_CHAINED,
            algorithms: vec![COMPRESSION_LZ77, COMPRESSION_LZNT1, COMPRESSION_LZ4],
        };
        let decoded = roundtrip_contexts(std::slice::from_ref(&ctx));
        assert_eq!(decoded[0], ctx);
    }

    #[test]
    fn context_unknown_roundtrip() {
        let ctx = NegotiateContext::Unknown {
            context_type: 0x00FF,
            data: vec![0x01, 0x02, 0x03, 0x04],
        };
        let decoded = roundtrip_contexts(std::slice::from_ref(&ctx));
        assert_eq!(decoded[0], ctx);
    }

    #[test]
    fn multiple_contexts_roundtrip() {
        let contexts = vec![
            NegotiateContext::PreauthIntegrity {
                hash_algorithms: vec![HASH_ALGORITHM_SHA512],
                salt: vec![0xAA; 32],
            },
            NegotiateContext::Encryption {
                ciphers: vec![CIPHER_AES_128_GCM],
            },
            NegotiateContext::Compression {
                flags: COMPRESSION_FLAG_NONE,
                algorithms: vec![COMPRESSION_NONE],
            },
            NegotiateContext::Signing {
                algorithms: vec![SIGNING_HMAC_SHA256],
            },
        ];
        let decoded = roundtrip_contexts(&contexts);
        assert_eq!(decoded, contexts);
    }

    #[test]
    fn context_alignment_is_8_bytes() {
        // First context: 8-byte header + 9-byte payload (2 + 2 + 2 + 3-byte
        // salt) = 17 bytes, so the second context must start at offset 24.
        let contexts = vec![
            NegotiateContext::PreauthIntegrity {
                hash_algorithms: vec![HASH_ALGORITHM_SHA512],
                salt: vec![0x01, 0x02, 0x03],
            },
            NegotiateContext::Encryption {
                ciphers: vec![CIPHER_AES_128_GCM],
            },
        ];
        let mut w = WriteCursor::new();
        pack_negotiate_contexts(&contexts, &mut w);
        let bytes = w.into_inner();

        // Verify the on-wire position of the second context header, not just
        // that the list round-trips.
        assert_eq!(
            u16::from_le_bytes([bytes[24], bytes[25]]),
            NEGOTIATE_CONTEXT_ENCRYPTION
        );
        // Encryption payload: count (2) + one cipher ID (2).
        assert_eq!(u16::from_le_bytes([bytes[26], bytes[27]]), 4);

        let mut r = ReadCursor::new(&bytes);
        let decoded = unpack_negotiate_contexts(&mut r, 2).unwrap();
        assert_eq!(decoded, contexts);
    }
}
#[cfg(test)]
mod roundtrip_props {
    use super::*;
    use crate::msg::roundtrip_strategies::{
        arb_capabilities, arb_dialect, arb_guid, arb_security_mode, arb_small_bytes,
    };
    use proptest::prelude::*;

    /// Strategy for a list of arbitrary `u16` values with at most `max` items.
    fn arb_u16_vec(max: usize) -> impl Strategy<Value = Vec<u16>> {
        prop::collection::vec(any::<u16>(), 0..=max)
    }

    /// Strategy covering every `NegotiateContext` variant. `Unknown` contexts
    /// never use a type value that belongs to a recognized variant, since
    /// those would decode as the recognized variant instead.
    fn arb_negotiate_context() -> impl Strategy<Value = NegotiateContext> {
        let reserved_types = [
            NEGOTIATE_CONTEXT_PREAUTH_INTEGRITY,
            NEGOTIATE_CONTEXT_ENCRYPTION,
            NEGOTIATE_CONTEXT_COMPRESSION,
            NEGOTIATE_CONTEXT_SIGNING,
        ];

        let preauth = (arb_u16_vec(8), prop::collection::vec(any::<u8>(), 0..=64)).prop_map(
            |(hash_algorithms, salt)| NegotiateContext::PreauthIntegrity {
                hash_algorithms,
                salt,
            },
        );
        let encryption =
            arb_u16_vec(8).prop_map(|ciphers| NegotiateContext::Encryption { ciphers });
        let compression = (any::<u32>(), arb_u16_vec(8))
            .prop_map(|(flags, algorithms)| NegotiateContext::Compression { flags, algorithms });
        let signing =
            arb_u16_vec(8).prop_map(|algorithms| NegotiateContext::Signing { algorithms });
        let unknown = (any::<u16>(), prop::collection::vec(any::<u8>(), 0..=64))
            .prop_filter(
                "type must not collide with a known variant",
                move |(candidate, _)| reserved_types.iter().all(|known| known != candidate),
            )
            .prop_map(|(context_type, data)| NegotiateContext::Unknown { context_type, data });

        prop_oneof![preauth, encryption, compression, signing, unknown]
    }

    /// Strategy for a dialect list plus a matching context list: contexts are
    /// only generated when the dialect list includes `Dialect::Smb3_1_1`,
    /// because the request codec drops them otherwise.
    fn arb_dialects_and_contexts() -> impl Strategy<Value = (Vec<Dialect>, Vec<NegotiateContext>)> {
        prop::collection::vec(arb_dialect(), 1..=5).prop_flat_map(|dialects| {
            let contexts: BoxedStrategy<Vec<NegotiateContext>> =
                if dialects.contains(&Dialect::Smb3_1_1) {
                    prop::collection::vec(arb_negotiate_context(), 0..=4).boxed()
                } else {
                    Just(Vec::new()).boxed()
                };
            (Just(dialects), contexts)
        })
    }

    /// Packs `value` into a fresh cursor and unpacks it again, panicking if
    /// decoding fails.
    fn roundtrip<T: Pack + Unpack>(value: &T) -> T {
        let mut writer = WriteCursor::new();
        value.pack(&mut writer);
        let bytes = writer.into_inner();
        let mut reader = ReadCursor::new(&bytes);
        let decoded = T::unpack(&mut reader).unwrap();
        decoded
    }

    proptest! {
        #[test]
        fn negotiate_context_list_roundtrip(
            contexts in prop::collection::vec(arb_negotiate_context(), 0..=6),
        ) {
            let mut writer = WriteCursor::new();
            pack_negotiate_contexts(&contexts, &mut writer);
            let bytes = writer.into_inner();
            let mut reader = ReadCursor::new(&bytes);
            let decoded = unpack_negotiate_contexts(&mut reader, contexts.len()).unwrap();
            prop_assert_eq!(decoded, contexts);
        }

        #[test]
        fn negotiate_request_pack_unpack(
            security_mode in arb_security_mode(),
            capabilities in arb_capabilities(),
            client_guid in arb_guid(),
            (dialects, negotiate_contexts) in arb_dialects_and_contexts(),
        ) {
            let original = NegotiateRequest {
                security_mode,
                capabilities,
                client_guid,
                dialects,
                negotiate_contexts,
            };
            let decoded = roundtrip(&original);
            prop_assert_eq!(decoded, original);
        }

        #[test]
        fn negotiate_response_pack_unpack(
            security_mode in arb_security_mode(),
            server_guid in arb_guid(),
            capabilities in arb_capabilities(),
            max_transact_size in any::<u32>(),
            max_read_size in any::<u32>(),
            max_write_size in any::<u32>(),
            system_time in any::<u64>(),
            server_start_time in any::<u64>(),
            security_buffer in arb_small_bytes(),
            dialect_revision in arb_dialect(),
            candidate_contexts in prop::collection::vec(arb_negotiate_context(), 0..=4),
        ) {
            // Contexts are only serialized for SMB 3.1.1, so other dialects
            // must start with an empty list for the comparison to hold.
            let negotiate_contexts = if dialect_revision != Dialect::Smb3_1_1 {
                Vec::new()
            } else {
                candidate_contexts
            };
            let original = NegotiateResponse {
                security_mode,
                dialect_revision,
                server_guid,
                capabilities,
                max_transact_size,
                max_read_size,
                max_write_size,
                system_time,
                server_start_time,
                security_buffer,
                negotiate_contexts,
            };
            let decoded = roundtrip(&original);
            prop_assert_eq!(decoded, original);
        }
    }
}