use crate::error::Result;
use crate::pack::{Pack, ReadCursor, Unpack, WriteCursor};
use crate::types::flags::HeaderFlags;
use crate::types::status::NtStatus;
use crate::types::{Command, CreditCharge, MessageId, SessionId, TreeId};
use crate::Error;
/// Magic bytes that open every SMB2 header: `0xFE` followed by ASCII `SMB`.
pub const PROTOCOL_ID: [u8; 4] = *b"\xFESMB";
/// In-memory form of the fixed 64-byte SMB2 packet header.
///
/// Byte offsets below follow the order used by `Pack for Header`. Bytes 32..40 are
/// interpreted according to the async flag in `flags`:
/// - when the flag is set, they hold `async_id`;
/// - otherwise, they hold a 4-byte reserved word followed by `tree_id`.
#[derive(Debug, Clone)]
pub struct Header {
    /// CreditCharge field (bytes 6..8).
    pub credit_charge: CreditCharge,
    /// Status field (bytes 8..12); `NtStatus::SUCCESS` in headers built by `new_request`.
    pub status: NtStatus,
    /// Command code (bytes 12..14); decoding rejects codes `Command` does not recognise.
    pub command: Command,
    /// Credits field (bytes 14..16); per MS-SMB2, credits requested or granted.
    pub credits: u16,
    /// Header flags (bytes 16..20); the async flag selects the layout of bytes 32..40.
    pub flags: HeaderFlags,
    /// NextCommand field (bytes 20..24); per MS-SMB2, the offset of the next compounded
    /// command, 0 when there is none.
    pub next_command: u32,
    /// MessageId field (bytes 24..32).
    pub message_id: MessageId,
    /// TreeId (bytes 36..40) of a synchronous header. `unpack` leaves this `None` for async
    /// headers; `pack` writes 0 for `None` and ignores this field when the async flag is set.
    pub tree_id: Option<TreeId>,
    /// AsyncId (bytes 32..40) of an async header. `unpack` sets this only when the async flag
    /// is set; `pack` writes 0 for `None` and ignores this field when the flag is clear.
    pub async_id: Option<u64>,
    /// SessionId field (bytes 40..48).
    pub session_id: SessionId,
    /// Signature field (bytes 48..64), copied verbatim by pack/unpack; nothing here computes
    /// or verifies it.
    pub signature: [u8; 16],
}
impl Header {
    /// Value carried in the StructureSize field of every SMB2 header.
    pub const STRUCTURE_SIZE: u16 = 64;
    /// Number of bytes `pack` emits and `unpack` consumes for one header.
    pub const SIZE: usize = Self::STRUCTURE_SIZE as usize;

    /// Builds a request header for `command` with neutral starting values.
    ///
    /// The remaining fields are set as follows:
    /// - zero credit charge, `SUCCESS` status and a request for one credit;
    /// - default flags and no next command;
    /// - default message, session and tree IDs;
    /// - no async ID and an all-zero signature.
    pub fn new_request(command: Command) -> Self {
        let tree_id = Some(TreeId::default());
        Self {
            command,
            credit_charge: CreditCharge(0),
            status: NtStatus::SUCCESS,
            credits: 1,
            flags: HeaderFlags::default(),
            next_command: 0,
            message_id: MessageId::default(),
            tree_id,
            async_id: None,
            session_id: SessionId::default(),
            signature: [0; 16],
        }
    }

    /// Returns `true` when `flags` marks this header as a response.
    pub fn is_response(&self) -> bool {
        self.flags.is_response()
    }
}
impl Pack for Header {
    /// Serializes the header as `Header::SIZE` bytes, every integer little-endian.
    ///
    /// StructureSize is always written as `Header::STRUCTURE_SIZE`. Bytes 32..40 depend on
    /// the async flag:
    /// - async: the async ID (0 when `None`);
    /// - sync: a zero reserved word followed by the tree ID (0 when `None`).
    fn pack(&self, cursor: &mut WriteCursor) {
        // Magic and StructureSize.
        cursor.write_bytes(&PROTOCOL_ID);
        cursor.write_u16_le(Self::STRUCTURE_SIZE);

        cursor.write_u16_le(self.credit_charge.0);
        cursor.write_u32_le(self.status.0);
        cursor.write_u16_le(self.command.into());
        cursor.write_u16_le(self.credits);
        cursor.write_u32_le(self.flags.bits());
        cursor.write_u32_le(self.next_command);
        cursor.write_u64_le(self.message_id.0);

        // Bytes 32..40. In the sync layout the first (low) four bytes are Reserved = 0 and the
        // next (high) four bytes are TreeId, so both layouts fit one little-endian u64 write.
        let id_field = if self.flags.is_async() {
            self.async_id.unwrap_or(0)
        } else {
            u64::from(self.tree_id.map_or(0, |tree| tree.0)) << 32
        };
        cursor.write_u64_le(id_field);

        cursor.write_u64_le(self.session_id.0);
        cursor.write_bytes(&self.signature);
    }
}
impl Unpack for Header {
fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
let proto = cursor.read_bytes(4)?;
if proto != PROTOCOL_ID {
return Err(Error::invalid_data(format!(
"invalid SMB2 protocol ID: expected {:02X?}, got {:02X?}",
PROTOCOL_ID, proto
)));
}
let structure_size = cursor.read_u16_le()?;
if structure_size != Header::STRUCTURE_SIZE {
return Err(Error::invalid_data(format!(
"invalid SMB2 header structure size: expected {}, got {}",
Header::STRUCTURE_SIZE,
structure_size
)));
}
let credit_charge = CreditCharge(cursor.read_u16_le()?);
let status = NtStatus(cursor.read_u32_le()?);
let command_raw = cursor.read_u16_le()?;
let command = Command::try_from(command_raw).map_err(|_| {
Error::invalid_data(format!("invalid SMB2 command code: 0x{:04X}", command_raw))
})?;
let credits = cursor.read_u16_le()?;
let flags = HeaderFlags::new(cursor.read_u32_le()?);
let next_command = cursor.read_u32_le()?;
let message_id = MessageId(cursor.read_u64_le()?);
let (tree_id, async_id) = if flags.is_async() {
let async_id = cursor.read_u64_le()?;
(None, Some(async_id))
} else {
let _reserved = cursor.read_u32_le()?;
let tree_id = TreeId(cursor.read_u32_le()?);
(Some(tree_id), None)
};
let session_id = SessionId(cursor.read_u64_le()?);
let sig_bytes = cursor.read_bytes(16)?;
let mut signature = [0u8; 16];
signature.copy_from_slice(sig_bytes);
Ok(Header {
credit_charge,
status,
command,
credits,
flags,
next_command,
message_id,
tree_id,
async_id,
session_id,
signature,
})
}
}
/// Body of an SMB2 ERROR response.
///
/// `pack` and `unpack` handle only these bytes, not the packet header. The wire layout is:
/// 1. StructureSize (u16);
/// 2. ErrorContextCount (u8);
/// 3. Reserved (u8);
/// 4. ByteCount (u32);
/// 5. ByteCount bytes of error data.
///
/// `PartialEq`/`Eq` are derived so decoded responses can be compared by value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ErrorResponse {
    /// ErrorContextCount field. Per MS-SMB2, a nonzero value means `error_data` holds that many
    /// error-context structures; they are kept as raw bytes and not parsed here.
    pub error_context_count: u8,
    /// Raw ErrorData bytes; its length is what gets written as the ByteCount field.
    pub error_data: Vec<u8>,
}

impl ErrorResponse {
    /// Fixed StructureSize value of an ERROR response. Per MS-SMB2 §2.2.2 it is always 9,
    /// regardless of how many bytes of error data follow.
    pub const STRUCTURE_SIZE: u16 = 9;
}
impl Pack for ErrorResponse {
    /// Serializes the ERROR response body, integers little-endian, in this order:
    /// 1. StructureSize (always `STRUCTURE_SIZE`);
    /// 2. ErrorContextCount;
    /// 3. a zero reserved byte;
    /// 4. ByteCount;
    /// 5. the error data, verbatim.
    ///
    /// An empty `error_data` yields an 8-byte body with no trailing data byte.
    fn pack(&self, cursor: &mut WriteCursor) {
        // ByteCount is a u32 on the wire; this cast would truncate if `error_data` ever
        // exceeded u32::MAX bytes, which is assumed not to happen for an error payload.
        let byte_count = self.error_data.len() as u32;

        cursor.write_u16_le(Self::STRUCTURE_SIZE);
        cursor.write_u8(self.error_context_count);
        cursor.write_u8(0); // Reserved
        cursor.write_u32_le(byte_count);
        cursor.write_bytes(&self.error_data);
    }
}
impl Unpack for ErrorResponse {
    /// Decodes an ERROR response body from `cursor`.
    ///
    /// It reads, in order:
    /// 1. StructureSize;
    /// 2. ErrorContextCount;
    /// 3. one reserved byte;
    /// 4. ByteCount;
    /// 5. ByteCount bytes of error data (read through `read_bytes_bounded`).
    ///
    /// When ByteCount is 0 nothing further is consumed. NOTE(review): MS-SMB2 may pad a
    /// zero-length ErrorData with a single byte; that byte is left in the cursor — confirm
    /// callers do not rely on it being consumed.
    ///
    /// # Errors
    ///
    /// Returns an invalid-data error if StructureSize is not `STRUCTURE_SIZE`. Errors from
    /// the cursor reads are propagated unchanged.
    fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
        let declared_size = cursor.read_u16_le()?;
        if declared_size != Self::STRUCTURE_SIZE {
            return Err(Error::invalid_data(format!(
                "invalid ErrorResponse structure size: expected {}, got {}",
                Self::STRUCTURE_SIZE,
                declared_size
            )));
        }

        let error_context_count = cursor.read_u8()?;
        cursor.read_u8()?; // Reserved
        let error_data = match cursor.read_u32_le()? as usize {
            0 => Vec::new(),
            len => cursor.read_bytes_bounded(len)?.to_vec(),
        };

        Ok(Self {
            error_context_count,
            error_data,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Packs `header`, unpacks the produced bytes, and returns the encoded length together
    /// with the decoded header.
    fn encode_decode_header(header: &Header) -> (usize, Header) {
        let mut writer = WriteCursor::new();
        header.pack(&mut writer);
        let encoded = writer.into_inner();
        let decoded = Header::unpack(&mut ReadCursor::new(&encoded)).unwrap();
        (encoded.len(), decoded)
    }

    /// Packs `response`, unpacks the produced bytes, and returns the encoded length together
    /// with the decoded response.
    fn encode_decode_error(response: &ErrorResponse) -> (usize, ErrorResponse) {
        let mut writer = WriteCursor::new();
        response.pack(&mut writer);
        let encoded = writer.into_inner();
        let decoded = ErrorResponse::unpack(&mut ReadCursor::new(&encoded)).unwrap();
        (encoded.len(), decoded)
    }

    /// Runs `Header::unpack` over `buf`, requires it to fail, and returns the error text.
    fn header_unpack_error(buf: &[u8]) -> String {
        let mut cursor = ReadCursor::new(buf);
        Header::unpack(&mut cursor)
            .expect_err("Header::unpack should reject this buffer")
            .to_string()
    }

    /// Copies `bytes` into `buf` starting at `offset`.
    fn write_at(buf: &mut [u8], offset: usize, bytes: &[u8]) {
        buf[offset..offset + bytes.len()].copy_from_slice(bytes);
    }

    #[test]
    fn pack_request_header_produces_64_bytes_with_correct_magic() {
        let mut writer = WriteCursor::new();
        Header::new_request(Command::Negotiate).pack(&mut writer);
        let encoded = writer.into_inner();
        assert_eq!(encoded.len(), Header::SIZE);
        assert_eq!(&encoded[..4], &PROTOCOL_ID);
    }

    #[test]
    fn unpack_known_64_byte_buffer() {
        // Only the non-zero fields are written. Status, Command (0 = Negotiate), Flags,
        // NextCommand, Reserved and Signature keep their zero bytes.
        let mut buf = [0u8; 64];
        write_at(&mut buf, 0, &PROTOCOL_ID);
        write_at(&mut buf, 4, &64u16.to_le_bytes()); // StructureSize
        write_at(&mut buf, 6, &1u16.to_le_bytes()); // CreditCharge
        write_at(&mut buf, 14, &31u16.to_le_bytes()); // Credits
        write_at(&mut buf, 24, &42u64.to_le_bytes()); // MessageId
        write_at(&mut buf, 36, &7u32.to_le_bytes()); // TreeId
        write_at(&mut buf, 40, &0x1234u64.to_le_bytes()); // SessionId

        let header = Header::unpack(&mut ReadCursor::new(&buf)).unwrap();
        assert_eq!(header.credit_charge, CreditCharge(1));
        assert_eq!(header.status, NtStatus::SUCCESS);
        assert_eq!(header.command, Command::Negotiate);
        assert_eq!(header.credits, 31);
        assert!(!header.flags.is_async());
        assert!(!header.flags.is_response());
        assert_eq!(header.next_command, 0);
        assert_eq!(header.message_id, MessageId(42));
        assert_eq!(header.tree_id, Some(TreeId(7)));
        assert_eq!(header.async_id, None);
        assert_eq!(header.session_id, SessionId(0x1234));
        assert_eq!(header.signature, [0u8; 16]);
    }

    #[test]
    fn roundtrip_sync_header() {
        let mut flags = HeaderFlags::default();
        flags.set_response();
        let original = Header {
            credit_charge: CreditCharge(3),
            status: NtStatus::ACCESS_DENIED,
            command: Command::Read,
            credits: 10,
            flags,
            next_command: 0,
            message_id: MessageId(99),
            tree_id: Some(TreeId(42)),
            async_id: None,
            session_id: SessionId(0xDEAD_BEEF),
            signature: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
        };

        let (len, decoded) = encode_decode_header(&original);
        assert_eq!(len, Header::SIZE);
        assert_eq!(decoded.credit_charge, original.credit_charge);
        assert_eq!(decoded.status, original.status);
        assert_eq!(decoded.command, original.command);
        assert_eq!(decoded.credits, original.credits);
        assert_eq!(decoded.flags.bits(), original.flags.bits());
        assert_eq!(decoded.next_command, original.next_command);
        assert_eq!(decoded.message_id, original.message_id);
        assert_eq!(decoded.tree_id, original.tree_id);
        assert_eq!(decoded.async_id, original.async_id);
        assert_eq!(decoded.session_id, original.session_id);
        assert_eq!(decoded.signature, original.signature);
    }

    #[test]
    fn wrong_magic_bytes_returns_error() {
        let mut buf = [0u8; 64];
        write_at(&mut buf, 0, &[0xFF, b'X', b'Y', b'Z']);
        write_at(&mut buf, 4, &64u16.to_le_bytes());
        let err = header_unpack_error(&buf);
        assert!(err.contains("protocol ID"), "error was: {err}");
    }

    #[test]
    fn wrong_structure_size_returns_error() {
        let mut buf = [0u8; 64];
        write_at(&mut buf, 0, &PROTOCOL_ID);
        write_at(&mut buf, 4, &32u16.to_le_bytes());
        let err = header_unpack_error(&buf);
        assert!(err.contains("structure size"), "error was: {err}");
    }

    #[test]
    fn async_header_pack_unpack() {
        let mut flags = HeaderFlags::default();
        flags.set_async();
        flags.set_response();
        let original = Header {
            credit_charge: CreditCharge(0),
            status: NtStatus::PENDING,
            command: Command::ChangeNotify,
            credits: 1,
            flags,
            next_command: 0,
            message_id: MessageId(8),
            tree_id: None,
            async_id: Some(0x0000_0000_0000_0008),
            session_id: SessionId(0x0000_0000_0853_27D7),
            signature: [0u8; 16],
        };

        let (len, decoded) = encode_decode_header(&original);
        assert_eq!(len, Header::SIZE);
        assert!(decoded.flags.is_async());
        assert_eq!(decoded.async_id, Some(8));
        assert_eq!(decoded.tree_id, None);
        assert_eq!(decoded.command, Command::ChangeNotify);
        assert_eq!(decoded.status, NtStatus::PENDING);
        assert_eq!(decoded.session_id, SessionId(0x0000_0000_0853_27D7));
    }

    #[test]
    fn sync_header_has_tree_id_and_no_async_id() {
        let (_, decoded) = encode_decode_header(&Header::new_request(Command::Create));
        assert!(!decoded.flags.is_async());
        assert!(decoded.tree_id.is_some());
        assert_eq!(decoded.async_id, None);
    }

    #[test]
    fn signature_field_preserved() {
        let sig = [
            0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88,
            0x99, 0x00,
        ];
        let mut header = Header::new_request(Command::Echo);
        header.signature = sig;
        let (_, decoded) = encode_decode_header(&header);
        assert_eq!(decoded.signature, sig);
    }

    #[test]
    fn new_request_produces_correct_defaults() {
        let header = Header::new_request(Command::Write);
        assert_eq!(header.command, Command::Write);
        assert_eq!(header.credit_charge, CreditCharge(0));
        assert_eq!(header.status, NtStatus::SUCCESS);
        assert_eq!(header.credits, 1);
        assert!(!header.flags.is_response());
        assert!(!header.flags.is_async());
        assert_eq!(header.next_command, 0);
        assert_eq!(header.message_id, MessageId(0));
        assert_eq!(header.tree_id, Some(TreeId(0)));
        assert_eq!(header.async_id, None);
        assert_eq!(header.session_id, SessionId(0));
        assert_eq!(header.signature, [0u8; 16]);
        assert!(!header.is_response());
    }

    #[test]
    fn error_response_pack_unpack_empty() {
        let original = ErrorResponse {
            error_context_count: 0,
            error_data: Vec::new(),
        };
        let (len, decoded) = encode_decode_error(&original);
        assert_eq!(len, 8);
        assert_eq!(decoded.error_context_count, 0);
        assert!(decoded.error_data.is_empty());
    }

    #[test]
    fn error_response_pack_unpack_with_data() {
        let data = vec![0xDE, 0xAD, 0xBE, 0xEF, 0xCA, 0xFE];
        let original = ErrorResponse {
            error_context_count: 1,
            error_data: data.clone(),
        };
        let (len, decoded) = encode_decode_error(&original);
        assert_eq!(len, 14);
        assert_eq!(decoded.error_context_count, 1);
        assert_eq!(decoded.error_data, data);
    }

    #[test]
    fn error_response_roundtrip() {
        let original = ErrorResponse {
            error_context_count: 2,
            error_data: (1..=10).collect(),
        };
        let (_, decoded) = encode_decode_error(&original);
        assert_eq!(decoded.error_context_count, original.error_context_count);
        assert_eq!(decoded.error_data, original.error_data);
    }
}