use crate::{
CbSmbFile, CbSmbFileStart, CbSmbFileStop, DirConfirmFn, Parser, ParserFactory, ParserFuture,
PktStrm, Prolens, SMB_PORT, SharedStateManager, packet::*,
};
use byteorder::{BigEndian, ByteOrder, LittleEndian};
use std::ffi::c_void;
use std::marker::PhantomData;
/// SMB2 parser state shared by the two direction parsers of one connection.
///
/// The struct itself only stores what has to be handed to the
/// client-to-server and server-to-client futures when they are created: the
/// two shared flags and the optional user callbacks.
pub struct SmbParser<T>
where
    T: Packet,
{
    // Shared "dialect accepted" flag: set by the server-to-client parser after
    // it inspected the NEGOTIATE response, awaited by the client-to-server one.
    version: ValidVer,
    // Shared "session is encrypted" flag: set by the server-to-client parser
    // from the SESSION_SETUP response, awaited by the client-to-server one.
    encrypt: EncryptMng,
    // Invoked with the header, length, offset and file id of a READ/WRITE request.
    cb_file_start: Option<CbSmbFileStart>,
    // Invoked with every chunk of file data together with its sequence number.
    cb_file: Option<CbSmbFile>,
    // Invoked once all data bytes of a READ response / WRITE request were delivered.
    cb_file_stop: Option<CbSmbFileStop>,
    // Ties the parser to the packet type without storing a packet.
    _phantom_t: PhantomData<T>,
}
impl<T> SmbParser<T>
where
T: Packet,
{
/// Creates a parser with fresh shared flags and no callbacks installed.
pub(crate) fn new() -> Self {
    let version = ValidVer::new();
    let encrypt = EncryptMng::new();
    Self {
        version,
        encrypt,
        cb_file_start: None,
        cb_file: None,
        cb_file_stop: None,
        _phantom_t: PhantomData,
    }
}
/// Parses the client-to-server byte stream of one SMB2 connection.
///
/// The future first waits on the two shared flags: it stops with `Err(())`
/// when `valid_ver` resolves to `false` or when `encrypt` resolves to `true`.
/// After that it loops over length-prefixed messages: a 4-byte prefix whose
/// last three bytes hold the big-endian message length, followed by SMB2
/// commands. READ and WRITE requests are handed to `read_req` / `write_req`;
/// the rest of the message is skipped for every other command.
///
/// All skip lengths are derived from the number of bytes actually consumed
/// since the prefix, so a length field smaller than an SMB2 header can no
/// longer underflow the subtraction, and a skip issued after an earlier
/// command in the same message can no longer run into the next message.
///
/// # Errors
///
/// Returns `Err(())` when a stream read fails, when one of the waits above
/// says to stop, or when a command handler consumed more bytes than the
/// length prefix announced (the message framing is lost at that point).
async fn c2s_parser_inner(
    strm: *mut PktStrm<T>,
    cb_smb: SmbCallbacks,
    cb_ctx: *mut c_void,
    valid_ver: ValidVer,
    encrypt: EncryptMng,
) -> Result<(), ()> {
    // SAFETY: NOTE(review) — relies on the caller passing a valid stream
    // pointer that nothing else accesses while this future runs; confirm.
    let stm = unsafe { &mut *strm };
    if !valid_ver.wait_for().await {
        return Err(());
    }
    if encrypt.wait_for().await {
        return Err(());
    }
    let header_len = usize::from(HEADER_SIZE);
    loop {
        let (byte, _seq) = stm.readn(4).await?;
        let netbios_len = BigEndian::read_u24(&byte[1..4]) as usize;
        let start_size = stm.get_read_size();
        loop {
            let consumed = stm.get_read_size().saturating_sub(start_size);
            // More bytes consumed than the prefix announced: we are no longer
            // aligned on a message boundary, so stop instead of parsing garbage.
            let remaining = netbios_len.checked_sub(consumed).ok_or(())?;
            if remaining < header_len {
                // Message finished, or the leftover cannot hold another header
                // (empty message, trailing padding): drop it and resynchronise
                // on the next length prefix.
                Self::consume(stm, remaining).await?;
                break;
            }
            let header = Self::header(stm).await?;
            match header.command {
                command::READ => Self::read_req(stm, header, &cb_smb, cb_ctx).await?,
                command::WRITE => Self::write_req(stm, header, &cb_smb, cb_ctx).await?,
                // `remaining >= header_len` was checked above and `header`
                // read exactly `header_len` bytes, so this cannot underflow.
                _ => Self::consume(stm, remaining - header_len).await?,
            }
        }
    }
}
/// Parses the server-to-client byte stream of one SMB2 connection.
///
/// The stream is walked message by message: a 4-byte prefix whose last three
/// bytes hold the big-endian message length, followed by SMB2 commands.
///
/// * NEGOTIATE: the dialect in the response is checked; the outcome is
///   published through `valid_ver`, and an unsupported dialect ends the
///   parser with `Err(())`.
/// * SESSION_SETUP: the encryption flag of the response is published through
///   `encrypt`; an encrypted session ends the parser with `Err(())`.
/// * READ: the response is handed to `read_rsp`.
/// * anything else: the rest of the message is skipped.
///
/// All skip lengths are derived from the number of bytes actually consumed
/// since the prefix, so a length field smaller than an SMB2 header can no
/// longer underflow the subtraction, and leftover bytes that cannot hold a
/// header are dropped instead of being parsed as one.
///
/// # Errors
///
/// Returns `Err(())` when a stream read fails, when publishing a shared flag
/// fails, in the two stop cases listed above, or when a command handler
/// consumed more bytes than the length prefix announced.
async fn s2c_parser_inner(
    strm: *mut PktStrm<T>,
    cb_smb: SmbCallbacks,
    cb_ctx: *mut c_void,
    valid_ver: ValidVer,
    encrypt: EncryptMng,
) -> Result<(), ()> {
    // SAFETY: NOTE(review) — relies on the caller passing a valid stream
    // pointer that nothing else accesses while this future runs; confirm.
    let stm = unsafe { &mut *strm };
    let header_len = usize::from(HEADER_SIZE);
    loop {
        let (byte, _seq) = stm.readn(4).await?;
        let netbios_len = BigEndian::read_u24(&byte[1..4]) as usize;
        let start_size = stm.get_read_size();
        loop {
            let consumed = stm.get_read_size().saturating_sub(start_size);
            // More bytes consumed than the prefix announced: framing is lost.
            let remaining = netbios_len.checked_sub(consumed).ok_or(())?;
            if remaining < header_len {
                // Message finished, or the leftover cannot hold another header:
                // drop it and resynchronise on the next length prefix.
                Self::consume(stm, remaining).await?;
                break;
            }
            let header = Self::header(stm).await?;
            // Bytes of this message that follow the header just read; cannot
            // underflow because `remaining >= header_len` was checked above.
            let body_len = remaining - header_len;
            match header.command {
                command::NEGOTIATE => {
                    let dialect_ok = Self::peek_negotiate_rsp(stm).await?;
                    if !dialect_ok {
                        valid_ver.set(false).map_err(|_| ())?;
                        return Err(());
                    }
                    valid_ver.set(true).map_err(|_| ())?;
                    // The response was only peeked at, so the whole body is
                    // still unread.
                    Self::consume(stm, body_len).await?
                }
                command::SESSION_SETUP => {
                    let encrypted = Self::session_setup_rsp(stm).await?;
                    if encrypted {
                        encrypt.set(true).map_err(|_| ())?;
                        return Err(());
                    }
                    encrypt.set(false).map_err(|_| ())?;
                }
                command::READ => Self::read_rsp(stm, header, &cb_smb, cb_ctx).await?,
                _ => Self::consume(stm, body_len).await?,
            }
        }
    }
}
/// Reads `HEADER_SIZE` bytes from the stream and decodes them as an SMB2
/// header.
///
/// Every field is little-endian and taken from a fixed byte range:
/// protocol id 0..4, structure size 4..6, credit charge 6..8, status 8..12,
/// command 12..14, credit 14..16, flags 16..20, next command 20..24,
/// message id 24..32, reserved 32..36, tree id 36..40, session id 40..48,
/// signature 48..64.
///
/// Returns `Err(())` when the stream read fails.
async fn header(stm: &mut PktStrm<T>) -> Result<SmbHeader, ()> {
    let (raw, _seq) = stm.readn(HEADER_SIZE as usize).await?;
    Ok(SmbHeader {
        protocol_id: LittleEndian::read_u32(&raw[0..4]),
        structure_size: LittleEndian::read_u16(&raw[4..6]),
        credit_charge: LittleEndian::read_u16(&raw[6..8]),
        status: LittleEndian::read_u32(&raw[8..12]),
        command: LittleEndian::read_u16(&raw[12..14]),
        credit: LittleEndian::read_u16(&raw[14..16]),
        flags: LittleEndian::read_u32(&raw[16..20]),
        next_command: LittleEndian::read_u32(&raw[20..24]),
        message_id: LittleEndian::read_u64(&raw[24..32]),
        reserved: LittleEndian::read_u32(&raw[32..36]),
        tree_id: LittleEndian::read_u32(&raw[36..40]),
        session_id: LittleEndian::read_u64(&raw[40..48]),
        signature: LittleEndian::read_u128(&raw[48..64]),
    })
}
async fn read_req(
stm: &mut PktStrm<T>,
header: SmbHeader,
cb_smb: &SmbCallbacks,
cb_ctx: *mut c_void,
) -> Result<(), ()> {
let (bytes, _seq) = stm.readn(48).await?;
let len: u32 = LittleEndian::read_u32(&bytes[4..8]);
let off: u64 = LittleEndian::read_u64(&bytes[8..16]);
let fid: u128 = LittleEndian::read_u128(&bytes[16..32]);
if let Some(ref cb) = cb_smb.file_start {
cb.borrow_mut()(&header, len, off, fid, cb_ctx);
}
let channel_len: u16 = LittleEndian::read_u16(&bytes[46..48]);
if channel_len == 0 {
let (_bytes, _seq) = stm.readn(1).await?;
} else {
Self::consume(stm, channel_len.into()).await?;
}
Ok(())
}
async fn write_req(
stm: &mut PktStrm<T>,
header: SmbHeader,
cb_smb: &SmbCallbacks,
cb_ctx: *mut c_void,
) -> Result<(), ()> {
let (bytes, _seq) = stm.readn(48).await?;
let len: u32 = LittleEndian::read_u32(&bytes[4..8]);
let off: u64 = LittleEndian::read_u64(&bytes[8..16]);
let fid: u128 = LittleEndian::read_u128(&bytes[16..32]);
if let Some(ref cb) = cb_smb.file_start {
cb.borrow_mut()(&header, len, off, fid, cb_ctx);
}
if len == 0 {
let (_bytes, _seq) = stm.readn(1).await?;
} else {
let mut remain_size = len as usize;
while remain_size > 0 {
let (bytes, seq) = stm.read(remain_size).await?;
if let Some(ref cb) = cb_smb.file {
cb.borrow_mut()(bytes, seq, cb_ctx);
}
remain_size -= bytes.len();
}
if let Some(ref cb) = cb_smb.file_stop {
cb.borrow_mut()(&header, cb_ctx);
}
}
Ok(())
}
/// Peeks at the first 6 bytes of a NEGOTIATE response without consuming them
/// and tells whether the little-endian dialect in bytes 4..6 is one of
/// 0x0202, 0x0210, 0x0300, 0x0302 or 0x0311.
///
/// Returns `Err(())` when the peek fails.
async fn peek_negotiate_rsp(stm: &mut PktStrm<T>) -> Result<bool, ()> {
    let head = stm.peekn(6).await?;
    match LittleEndian::read_u16(&head[4..6]) {
        0x0202 | 0x0210 | 0x0300 | 0x0302 | 0x0311 => Ok(true),
        _ => Ok(false),
    }
}
/// Consumes a SESSION_SETUP response and reports whether it announces an
/// encrypted session.
///
/// Reads the 8-byte fixed part, skips as many further bytes as the
/// little-endian length in bytes 6..8 says, and returns `true` when the
/// session flags in bytes 2..4 contain `SESSION_FLAG_ENCRYPT_DATA`.
///
/// Returns `Err(())` when a stream read fails.
async fn session_setup_rsp(stm: &mut PktStrm<T>) -> Result<bool, ()> {
    let (fixed, _seq) = stm.readn(8).await?;
    let session_flags = LittleEndian::read_u16(&fixed[2..4]);
    let buffer_len = LittleEndian::read_u16(&fixed[6..8]);
    Self::consume(stm, usize::from(buffer_len)).await?;
    Ok(session_flags & SESSION_FLAG_ENCRYPT_DATA == SESSION_FLAG_ENCRYPT_DATA)
}
async fn read_rsp(
stm: &mut PktStrm<T>,
header: SmbHeader,
cb_smb: &SmbCallbacks,
cb_ctx: *mut c_void,
) -> Result<(), ()> {
let (bytes, _seq) = stm.readn(16).await?;
let data_len: u32 = LittleEndian::read_u32(&bytes[4..8]);
if data_len == 0 {
let (_bytes, _seq) = stm.readn(1).await?;
} else {
let mut remain_size = data_len as usize;
while remain_size > 0 {
let (bytes, seq) = stm.read(remain_size).await?;
if let Some(ref cb) = cb_smb.file {
cb.borrow_mut()(bytes, seq, cb_ctx);
}
remain_size -= bytes.len();
}
if let Some(ref cb) = cb_smb.file_stop {
cb.borrow_mut()(&header, cb_ctx);
}
}
Ok(())
}
/// Reads and discards `size` bytes from the stream, in as many chunks as the
/// stream hands out. A `size` of zero returns immediately.
///
/// Returns `Err(())` when a stream read fails.
async fn consume(stm: &mut PktStrm<T>, size: usize) -> Result<(), ()> {
    let mut left = size;
    loop {
        if left == 0 {
            return Ok(());
        }
        let (chunk, _seq) = stm.read(left).await?;
        left -= chunk.len();
    }
}
}
impl<T> Parser for SmbParser<T>
where
T: Packet + 'static,
{
type T = T;
/// Returns the function used to decide whether the assumed stream
/// directions are right.
///
/// The decision is made in this order:
///
/// 1. Ports: `Some(true)` when the server-to-client port is `SMB_PORT`,
///    `Some(false)` when the client-to-server port is.
/// 2. `None` when neither stream has a payload to look at yet.
/// 3. The SMB2 header flags of the first available payload: on the
///    client-to-server stream the answer is `Some(true)` when
///    `HEADER_FLAGS_SERVER_TO_REDIR` is clear, on the server-to-client
///    stream when it is set; `Some(false)` otherwise.
/// 4. `Some(true)` when no payload was long enough to contain the flags.
///
/// NOTE(review): `Some(true)` presumably means "directions confirmed" and
/// `Some(false)` "directions are swapped" — confirm against the caller.
fn dir_confirm(&self) -> DirConfirmFn<Self::T> {
    // The flags sit at bytes 16..20 of the SMB2 header (see `header`), and the
    // header itself follows the 4-byte length prefix that starts each message.
    const FLAGS_START: usize = 4 + 16;
    const FLAGS_END: usize = FLAGS_START + 4;

    |c2s_strm, s2c_strm, c2s_port, s2c_port| {
        // Ports settle the question without touching the streams at all.
        if s2c_port == SMB_PORT {
            return Some(true);
        } else if c2s_port == SMB_PORT {
            return Some(false);
        }

        // SAFETY: NOTE(review) — relies on the caller passing valid, distinct
        // stream pointers; confirm.
        let stm_c2s = unsafe { &mut *c2s_strm };
        let stm_s2c = unsafe { &mut *s2c_strm };
        let payload_c2s = stm_c2s.peek_payload();
        let payload_s2c = stm_s2c.peek_payload();
        if payload_c2s.is_err() && payload_s2c.is_err() {
            return None;
        }

        if let Ok(payload) = payload_c2s {
            if payload.len() >= FLAGS_END {
                let flags = LittleEndian::read_u32(&payload[FLAGS_START..FLAGS_END]);
                // A request must not carry the server-to-client bit.
                return Some((flags & HEADER_FLAGS_SERVER_TO_REDIR) == 0);
            }
        }
        if let Ok(payload) = payload_s2c {
            if payload.len() >= FLAGS_END {
                let flags = LittleEndian::read_u32(&payload[FLAGS_START..FLAGS_END]);
                // A response must carry the server-to-client bit.
                return Some((flags & HEADER_FLAGS_SERVER_TO_REDIR) != 0);
            }
        }
        Some(true)
    }
}
/// Builds the boxed future that parses the client-to-server stream.
///
/// The future receives its own clones of the callbacks and of the two
/// shared flags; this method always returns `Some`.
fn c2s_parser(&self, strm: *mut PktStrm<T>, cb_ctx: *mut c_void) -> Option<ParserFuture> {
    let callbacks = SmbCallbacks {
        file_start: self.cb_file_start.clone(),
        file: self.cb_file.clone(),
        file_stop: self.cb_file_stop.clone(),
    };
    let fut = Self::c2s_parser_inner(
        strm,
        callbacks,
        cb_ctx,
        self.version.clone(),
        self.encrypt.clone(),
    );
    Some(Box::pin(fut))
}
/// Builds the boxed future that parses the server-to-client stream.
///
/// The future receives its own clones of the callbacks and of the two
/// shared flags; this method always returns `Some`.
fn s2c_parser(&self, strm: *mut PktStrm<T>, cb_ctx: *mut c_void) -> Option<ParserFuture> {
    let callbacks = SmbCallbacks {
        file_start: self.cb_file_start.clone(),
        file: self.cb_file.clone(),
        file_stop: self.cb_file_stop.clone(),
    };
    let fut = Self::s2c_parser_inner(
        strm,
        callbacks,
        cb_ctx,
        self.version.clone(),
        self.encrypt.clone(),
    );
    Some(Box::pin(fut))
}
}
/// Factory that produces [`SmbParser`] instances for packet type `T`.
pub(crate) struct SmbFactory<T> {
    // Ties the factory to the packet type without storing a packet.
    _phantom_t: PhantomData<T>,
}
impl<T> ParserFactory<T> for SmbFactory<T>
where
T: Packet + 'static,
{
fn new() -> Self {
Self {
_phantom_t: PhantomData,
}
}
fn create(&self, prolens: &Prolens<T>) -> Box<dyn Parser<T = T>> {
let mut parser = Box::new(SmbParser::new());
parser.cb_file_start = prolens.cb_smb_file_start.clone();
parser.cb_file = prolens.cb_smb_file.clone();
parser.cb_file_stop = prolens.cb_smb_file_stop.clone();
parser
}
}
/// Bundle of the optional user callbacks handed to a direction parser.
#[derive(Clone)]
pub(crate) struct SmbCallbacks {
    // Called with the header, length, offset and file id of a READ/WRITE request.
    file_start: Option<CbSmbFileStart>,
    // Called with each chunk of file data and its sequence number.
    file: Option<CbSmbFile>,
    // Called after the last data chunk of a READ response / WRITE request.
    file_stop: Option<CbSmbFileStop>,
}
/// Shared flag telling whether the negotiated dialect is one the parser accepts.
pub(crate) type ValidVer = SharedStateManager<bool>;
/// Shared flag telling whether the session was set up as encrypted.
pub(crate) type EncryptMng = SharedStateManager<bool>;
/// Bit of the SMB2 header flags tested by `dir_confirm` to tell a
/// server-to-client message from a client-to-server one.
const HEADER_FLAGS_SERVER_TO_REDIR: u32 = 1 << 0;
/// Number of bytes `header` reads for one SMB2 header.
const HEADER_SIZE: u16 = 64;
/// Bit of the SESSION_SETUP response flags that marks an encrypted session.
const SESSION_FLAG_ENCRYPT_DATA: u16 = 1 << 2;
/// Values of the SMB2 header `command` field that the parsers dispatch on.
#[rustfmt::skip]
mod command {
    /// Dialect negotiation.
    pub(crate) const NEGOTIATE: u16 = 0x00;
    /// Session setup.
    pub(crate) const SESSION_SETUP: u16 = 0x01;
    /// File read.
    pub(crate) const READ: u16 = 0x08;
    /// File write.
    pub(crate) const WRITE: u16 = 0x09;
}
/// Decoded 64-byte SMB2 header, one field per little-endian wire field.
///
/// The byte ranges below are the ones the parser's `header` routine reads.
#[derive(Clone, Debug)]
pub struct SmbHeader {
    /// Header bytes 0..4.
    pub protocol_id: u32,
    /// Header bytes 4..6.
    pub structure_size: u16,
    /// Header bytes 6..8.
    pub credit_charge: u16,
    /// Header bytes 8..12.
    pub status: u32,
    /// Header bytes 12..14; compared against the `command` constants.
    pub command: u16,
    /// Header bytes 14..16.
    pub credit: u16,
    /// Header bytes 16..20.
    pub flags: u32,
    /// Header bytes 20..24.
    pub next_command: u32,
    /// Header bytes 24..32.
    pub message_id: u64,
    /// Header bytes 32..36.
    pub reserved: u32,
    /// Header bytes 36..40.
    pub tree_id: u32,
    /// Header bytes 40..48.
    pub session_id: u64,
    /// Header bytes 48..64.
    pub signature: u128,
}
#[cfg(test)]
mod tests {
use super::*;
use crate::test_utils::*;
use std::cell::RefCell;
use std::collections::HashMap;
use std::env;
use std::ffi::c_void;
use std::rc::Rc;
use std::time::{SystemTime, UNIX_EPOCH};
// End-to-end check: replays tests/pcap/smb3.pcap through an SMB task and
// verifies the file start/data/stop events reported through the callbacks,
// including the content of one 3519-byte read against tests/pcap/smb_read.txt.
#[test]
fn test_smb_parser() {
    let root = env::current_dir().unwrap();
    let mut cap = Capture::init(root.join("tests/pcap/smb3.pcap")).unwrap();
    let expected_content = std::fs::read_to_string(root.join("tests/pcap/smb_read.txt"))
        .expect("Failed to read smb_read.txt");

    // Everything the callbacks observe, shared with the assertions below.
    let captured_file_starts: Rc<RefCell<Vec<(SmbHeader, u32, u64, u128)>>> = Rc::default();
    let captured_file_data: Rc<RefCell<HashMap<u128, Vec<u8>>>> = Rc::default();
    let captured_file_stops: Rc<RefCell<Vec<SmbHeader>>> = Rc::default();
    // File id announced by the most recent start event; data chunks are
    // attributed to it until a stop event clears it.
    let current_fid: Rc<RefCell<Option<u128>>> = Rc::default();

    let file_start_callback = {
        let starts = Rc::clone(&captured_file_starts);
        let active = Rc::clone(&current_fid);
        let files = Rc::clone(&captured_file_data);
        move |header: &SmbHeader, len: u32, offset: u64, fid: u128, _cb_ctx: *mut c_void| {
            starts.borrow_mut().push((header.clone(), len, offset, fid));
            *active.borrow_mut() = Some(fid);
            files.borrow_mut().entry(fid).or_default();
            println!(
                "SMB File Start: command=0x{:04x}, len={}, offset={}, fid=0x{:032x}",
                header.command, len, offset, fid
            );
        }
    };

    let file_data_callback = {
        let files = Rc::clone(&captured_file_data);
        let active = Rc::clone(&current_fid);
        move |data: &[u8], _seq: u32, _cb_ctx: *mut c_void| {
            if let Some(fid) = *active.borrow() {
                if let Some(buf) = files.borrow_mut().get_mut(&fid) {
                    buf.extend_from_slice(data);
                }
            }
            println!("SMB File Data: {} bytes received", data.len());
        }
    };

    let file_stop_callback = {
        let stops = Rc::clone(&captured_file_stops);
        let active = Rc::clone(&current_fid);
        move |header: &SmbHeader, _cb_ctx: *mut c_void| {
            stops.borrow_mut().push(header.clone());
            *active.borrow_mut() = None;
            println!("SMB File Stop: command=0x{:04x}\n", header.command);
        }
    };

    let mut protolens = Prolens::<CapPacket>::default();
    protolens.set_cb_smb_file_start(file_start_callback);
    protolens.set_cb_smb_file(file_data_callback);
    protolens.set_cb_smb_file_stop(file_stop_callback);
    let mut task = protolens.new_task(TransProto::Tcp);
    protolens.set_task_parser(&mut task, L7Proto::Smb);

    // Feed every decodable packet of the capture into the task.
    loop {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_millis();
        let pkt = match cap.next_packet(now) {
            Some(pkt) => pkt,
            None => break,
        };
        if pkt.decode().is_err() {
            continue;
        }
        protolens.run_task(&mut task, pkt);
    }

    let starts = captured_file_starts.borrow();
    let files = captured_file_data.borrow();
    let stops = captured_file_stops.borrow();

    println!("Captured {} file start events", starts.len());
    let total_data_bytes: usize = files.values().map(Vec::len).sum();
    println!(
        "Captured {} bytes of file data across {} files",
        total_data_bytes,
        files.len()
    );
    println!("Captured {} file stop events", stops.len());
    assert_eq!(5, starts.len());
    assert_eq!(3975, total_data_bytes);
    assert_eq!(5, stops.len());

    let read_count = starts
        .iter()
        .filter(|entry| entry.0.command == command::READ)
        .count();
    assert!(read_count > 0, "Should have at least one READ operation");
    println!("Found {} READ operations", read_count);

    let target_fid = 0x000000006843c46e0000000076045775u128;
    let has_large_read = starts.iter().any(|(header, len, offset, fid)| {
        header.command == command::READ && *len == 3519 && *offset == 0 && *fid == target_fid
    });
    assert!(
        has_large_read,
        "Should find the large READ operation with len=3519 and target FID"
    );
    println!(
        "Found target large READ operation: len=3519, fid=0x{:032x}",
        target_fid
    );

    match files.get(&target_fid) {
        Some(large_file_data) => {
            let received_content = String::from_utf8_lossy(large_file_data);
            println!("Large file content length: {}", received_content.len());
            println!("Expected file content length: {}", expected_content.len());
            assert_eq!(
                received_content.trim(),
                expected_content.trim(),
                "Large file read content should match expected content from smb_read.txt"
            );
            println!("✓ Large file content verification passed");
        }
        None => panic!("No data found for the target large file read operation"),
    }

    assert_eq!(
        starts.len(),
        stops.len(),
        "File start and stop events should be paired"
    );

    println!("\nFile operations summary:");
    for (idx, (header, len, offset, fid)) in starts.iter().enumerate() {
        let data_len = files.get(fid).map_or(0, Vec::len);
        println!(
            " {}: command=0x{:04x}, len={}, offset={}, fid=0x{:032x}, actual_data={} bytes",
            idx + 1,
            header.command,
            len,
            offset,
            fid,
            data_len
        );
    }
}
}