use crate::error::Result;
use crate::pack::{Pack, ReadCursor, Unpack, WriteCursor};
use crate::types::SessionId;
use crate::Error;
/// Protocol identifier (`0xFD 'S' 'M' 'B'`) written first when a [`TransformHeader`] is packed
/// and required as the first four bytes when one is unpacked.
pub const TRANSFORM_PROTOCOL_ID: [u8; 4] = [0xFD, b'S', b'M', b'B'];
/// Protocol identifier (`0xFC 'S' 'M' 'B'`) written first when a
/// [`CompressionTransformHeader`] is packed and required as the first four bytes on unpack.
pub const COMPRESSION_PROTOCOL_ID: [u8; 4] = [0xFC, b'S', b'M', b'B'];
/// Value for `TransformHeader::flags`.
/// NOTE(review): going by the name this marks the message as encrypted — confirm against MS-SMB2.
pub const SMB2_TRANSFORM_HEADER_FLAG_ENCRYPTED: u16 = 0x0001;
// Values for `CompressionTransformHeader::compression_algorithm`.
// NOTE(review): the algorithm each value selects is inferred from the constant names only;
// verify the numeric assignments against MS-SMB2 before relying on them.
/// Compression algorithm identifier `0x0000` (presumably "no compression").
pub const COMPRESSION_ALGORITHM_NONE: u16 = 0x0000;
/// Compression algorithm identifier `0x0001` (presumably LZNT1).
pub const COMPRESSION_ALGORITHM_LZNT1: u16 = 0x0001;
/// Compression algorithm identifier `0x0002` (presumably plain LZ77).
pub const COMPRESSION_ALGORITHM_LZ77: u16 = 0x0002;
/// Compression algorithm identifier `0x0003` (presumably LZ77 with Huffman coding).
pub const COMPRESSION_ALGORITHM_LZ77_HUFFMAN: u16 = 0x0003;
/// Compression algorithm identifier `0x0004` (presumably the Pattern_V1 scanner).
pub const COMPRESSION_ALGORITHM_PATTERN_V1: u16 = 0x0004;
/// Compression algorithm identifier `0x0005` (presumably LZ4).
pub const COMPRESSION_ALGORITHM_LZ4: u16 = 0x0005;
// Values for `CompressionTransformHeader::flags`.
/// Flags value with no bits set.
pub const SMB2_COMPRESSION_FLAG_NONE: u16 = 0x0000;
/// Flags value `0x0001`.
/// NOTE(review): presumably signals a chained compression payload — confirm against MS-SMB2.
pub const SMB2_COMPRESSION_FLAG_CHAINED: u16 = 0x0001;
/// In-memory form of the 52-byte transform header that starts with [`TRANSFORM_PROTOCOL_ID`].
///
/// The protocol ID and the 2-byte reserved field are not stored here: packing writes them as
/// fixed values, while unpacking validates the former and discards the latter.
#[derive(Debug, Clone)]
pub struct TransformHeader {
/// 16-byte signature, serialized immediately after the protocol ID.
/// NOTE(review): presumably the authentication tag of the encrypted message — confirm.
pub signature: [u8; 16],
/// 16-byte nonce field, serialized after the signature.
/// NOTE(review): the round-trip test only populates the first 12 bytes; how many bytes are
/// significant presumably depends on the negotiated cipher — confirm.
pub nonce: [u8; 16],
/// Serialized as a little-endian `u32`.
/// NOTE(review): presumably the size of the message before the transform — confirm.
pub original_message_size: u32,
/// Serialized as a little-endian `u16` after the reserved field; the tests store
/// [`SMB2_TRANSFORM_HEADER_FLAG_ENCRYPTED`] here.
pub flags: u16,
/// Session identifier; its inner value is serialized as a little-endian `u64`.
pub session_id: SessionId,
}
impl TransformHeader {
    /// Encoded length in bytes: protocol ID (4) + signature (16) + nonce (16) +
    /// original message size (4) + reserved (2) + flags (2) + session ID (8).
    pub const SIZE: usize = 4 + 16 + 16 + 4 + 2 + 2 + 8;
}
impl Pack for TransformHeader {
    /// Appends the header to `cursor` in wire order: protocol ID, signature, nonce,
    /// original message size (`u32` LE), a zeroed 2-byte reserved field, flags (`u16` LE)
    /// and the session ID (`u64` LE).
    fn pack(&self, cursor: &mut WriteCursor) {
        // The reserved slot between the message size and the flags is always written as zero.
        const RESERVED: u16 = 0;

        let Self {
            signature,
            nonce,
            original_message_size,
            flags,
            session_id,
        } = self;

        cursor.write_bytes(&TRANSFORM_PROTOCOL_ID);
        cursor.write_bytes(signature);
        cursor.write_bytes(nonce);
        cursor.write_u32_le(*original_message_size);
        cursor.write_u16_le(RESERVED);
        cursor.write_u16_le(*flags);
        cursor.write_u64_le(session_id.0);
    }
}
impl Unpack for TransformHeader {
    /// Reads a transform header from `cursor`.
    ///
    /// # Errors
    ///
    /// Returns an invalid-data error when the first four bytes are not
    /// [`TRANSFORM_PROTOCOL_ID`], and propagates any error the cursor reports while reading.
    fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
        // Validate the magic before touching any other field.
        match cursor.read_bytes(4)? {
            magic if magic == TRANSFORM_PROTOCOL_ID => {}
            unexpected => {
                return Err(Error::invalid_data(format!(
                    "invalid transform header protocol ID: expected {:02X?}, got {:02X?}",
                    TRANSFORM_PROTOCOL_ID, unexpected
                )));
            }
        }

        let mut signature = [0u8; 16];
        signature.copy_from_slice(cursor.read_bytes(16)?);

        let mut nonce = [0u8; 16];
        nonce.copy_from_slice(cursor.read_bytes(16)?);

        let original_message_size = cursor.read_u32_le()?;
        // Two reserved bytes sit between the size and the flags; consume and ignore them.
        let _ = cursor.read_u16_le()?;
        let flags = cursor.read_u16_le()?;
        let raw_session_id = cursor.read_u64_le()?;

        Ok(Self {
            signature,
            nonce,
            original_message_size,
            flags,
            session_id: SessionId(raw_session_id),
        })
    }
}
/// In-memory form of the 16-byte compression transform header that starts with
/// [`COMPRESSION_PROTOCOL_ID`].
///
/// The protocol ID itself is not stored: packing writes it and unpacking validates it.
#[derive(Debug, Clone)]
pub struct CompressionTransformHeader {
/// Serialized as a little-endian `u32` right after the protocol ID.
/// NOTE(review): presumably the size of the segment before compression — confirm.
pub original_compressed_segment_size: u32,
/// Serialized as a little-endian `u16`; the tests store `COMPRESSION_ALGORITHM_*` values here.
pub compression_algorithm: u16,
/// Serialized as a little-endian `u16`; the tests store `SMB2_COMPRESSION_FLAG_*` values here.
pub flags: u16,
/// Serialized as a little-endian `u32`.
/// NOTE(review): presumably read as an offset or as a length depending on whether the
/// message is chained — confirm against MS-SMB2.
pub offset_or_length: u32,
}
impl CompressionTransformHeader {
    /// Encoded length in bytes: protocol ID (4) + original compressed segment size (4) +
    /// compression algorithm (2) + flags (2) + offset/length (4).
    pub const SIZE: usize = 4 + 4 + 2 + 2 + 4;
}
impl Pack for CompressionTransformHeader {
    /// Appends the header to `cursor` in wire order: protocol ID, original compressed segment
    /// size (`u32` LE), compression algorithm (`u16` LE), flags (`u16` LE) and the
    /// offset/length field (`u32` LE).
    fn pack(&self, cursor: &mut WriteCursor) {
        // Every field is a plain integer, so copy them out of `self` in one go.
        let Self {
            original_compressed_segment_size: segment_size,
            compression_algorithm: algorithm,
            flags,
            offset_or_length,
        } = *self;

        cursor.write_bytes(&COMPRESSION_PROTOCOL_ID);
        cursor.write_u32_le(segment_size);
        cursor.write_u16_le(algorithm);
        cursor.write_u16_le(flags);
        cursor.write_u32_le(offset_or_length);
    }
}
impl Unpack for CompressionTransformHeader {
    /// Reads a compression transform header from `cursor`.
    ///
    /// # Errors
    ///
    /// Returns an invalid-data error when the first four bytes are not
    /// [`COMPRESSION_PROTOCOL_ID`], and propagates any error the cursor reports while reading.
    fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
        // Validate the magic before touching any other field.
        match cursor.read_bytes(4)? {
            magic if magic == COMPRESSION_PROTOCOL_ID => {}
            unexpected => {
                return Err(Error::invalid_data(format!(
                    "invalid compression transform header protocol ID: expected {:02X?}, got {:02X?}",
                    COMPRESSION_PROTOCOL_ID, unexpected
                )));
            }
        }

        // Field initializers run in the order written, which is the wire order.
        Ok(Self {
            original_compressed_segment_size: cursor.read_u32_le()?,
            compression_algorithm: cursor.read_u16_le()?,
            flags: cursor.read_u16_le()?,
            offset_or_length: cursor.read_u32_le()?,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Packing and then unpacking a transform header reproduces every stored field.
    #[test]
    fn transform_header_roundtrip() {
        // First twelve nonce bytes are 1..=12; the trailing four stay zero.
        let mut nonce = [0u8; 16];
        for (slot, value) in nonce.iter_mut().zip(1u8..=12) {
            *slot = value;
        }
        let header = TransformHeader {
            signature: [
                0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88,
                0x99, 0x00,
            ],
            nonce,
            original_message_size: 1024,
            flags: SMB2_TRANSFORM_HEADER_FLAG_ENCRYPTED,
            session_id: SessionId(0xDEAD_BEEF_CAFE_FACE),
        };

        let mut writer = WriteCursor::new();
        header.pack(&mut writer);
        let wire = writer.into_inner();
        assert_eq!(wire.len(), TransformHeader::SIZE);

        let mut reader = ReadCursor::new(&wire);
        let parsed = TransformHeader::unpack(&mut reader).unwrap();
        assert_eq!(parsed.signature, header.signature);
        assert_eq!(parsed.nonce, header.nonce);
        assert_eq!(parsed.original_message_size, 1024);
        assert_eq!(parsed.flags, SMB2_TRANSFORM_HEADER_FLAG_ENCRYPTED);
        assert_eq!(parsed.session_id, SessionId(0xDEAD_BEEF_CAFE_FACE));
    }

    /// A packed transform header starts with `0xFD 'S' 'M' 'B'`.
    #[test]
    fn transform_header_protocol_id_is_0xfd() {
        let zeroed = TransformHeader {
            signature: [0u8; 16],
            nonce: [0u8; 16],
            original_message_size: 0,
            flags: 0,
            session_id: SessionId(0),
        };
        let mut writer = WriteCursor::new();
        zeroed.pack(&mut writer);
        let bytes = writer.into_inner();

        for (index, expected) in [0xFD, b'S', b'M', b'B'].iter().enumerate() {
            assert_eq!(bytes[index], *expected);
        }
        assert_ne!(bytes[0], 0xFE, "transform header must use 0xFD, not 0xFE");
    }

    /// Unpacking refuses a buffer whose magic is `0xFE 'S' 'M' 'B'`.
    #[test]
    fn transform_header_wrong_protocol_id() {
        let mut buf = [0u8; TransformHeader::SIZE];
        buf[..4].copy_from_slice(b"\xFESMB");

        let mut cursor = ReadCursor::new(&buf);
        let result = TransformHeader::unpack(&mut cursor);
        assert!(result.is_err());

        let err = result.unwrap_err().to_string();
        assert!(err.contains("protocol ID"), "error was: {err}");
    }

    /// Packing and then unpacking an unchained compression header reproduces every field.
    #[test]
    fn compression_transform_header_roundtrip_unchained() {
        let header = CompressionTransformHeader {
            original_compressed_segment_size: 4096,
            compression_algorithm: COMPRESSION_ALGORITHM_LZ77,
            flags: SMB2_COMPRESSION_FLAG_NONE,
            offset_or_length: 64,
        };

        let mut writer = WriteCursor::new();
        header.pack(&mut writer);
        let wire = writer.into_inner();
        assert_eq!(wire.len(), CompressionTransformHeader::SIZE);

        let mut reader = ReadCursor::new(&wire);
        let parsed = CompressionTransformHeader::unpack(&mut reader).unwrap();
        assert_eq!(parsed.original_compressed_segment_size, 4096);
        assert_eq!(parsed.compression_algorithm, COMPRESSION_ALGORITHM_LZ77);
        assert_eq!(parsed.flags, SMB2_COMPRESSION_FLAG_NONE);
        assert_eq!(parsed.offset_or_length, 64);
    }

    /// A packed compression transform header starts with `0xFC 'S' 'M' 'B'`.
    #[test]
    fn compression_transform_header_protocol_id_is_0xfc() {
        let zeroed = CompressionTransformHeader {
            original_compressed_segment_size: 0,
            compression_algorithm: COMPRESSION_ALGORITHM_NONE,
            flags: SMB2_COMPRESSION_FLAG_NONE,
            offset_or_length: 0,
        };
        let mut writer = WriteCursor::new();
        zeroed.pack(&mut writer);
        let bytes = writer.into_inner();

        for (index, expected) in [0xFC, b'S', b'M', b'B'].iter().enumerate() {
            assert_eq!(bytes[index], *expected);
        }
        assert_ne!(
            bytes[0], 0xFE,
            "compression transform header must use 0xFC, not 0xFE"
        );
    }

    /// Unpacking refuses a buffer whose magic is `0xFE 'S' 'M' 'B'`.
    #[test]
    fn compression_transform_header_wrong_protocol_id() {
        let mut buf = [0u8; CompressionTransformHeader::SIZE];
        buf[..4].copy_from_slice(b"\xFESMB");

        let mut cursor = ReadCursor::new(&buf);
        let result = CompressionTransformHeader::unpack(&mut cursor);
        assert!(result.is_err());

        let err = result.unwrap_err().to_string();
        assert!(err.contains("protocol ID"), "error was: {err}");
    }

    /// The LZ77+Huffman algorithm value and the segment size survive a round trip.
    #[test]
    fn compression_transform_header_lz77_huffman() {
        let header = CompressionTransformHeader {
            original_compressed_segment_size: 8192,
            compression_algorithm: COMPRESSION_ALGORITHM_LZ77_HUFFMAN,
            flags: SMB2_COMPRESSION_FLAG_NONE,
            offset_or_length: 128,
        };

        let mut writer = WriteCursor::new();
        header.pack(&mut writer);
        let wire = writer.into_inner();

        let mut reader = ReadCursor::new(&wire);
        let parsed = CompressionTransformHeader::unpack(&mut reader).unwrap();
        assert_eq!(
            parsed.compression_algorithm,
            COMPRESSION_ALGORITHM_LZ77_HUFFMAN
        );
        assert_eq!(parsed.original_compressed_segment_size, 8192);
    }
}