use crate::error::Result;
use crate::pack::{Pack, ReadCursor, Unpack, WriteCursor};
use crate::types::FileId;
use crate::Error;
/// Bit in [`WriteRequest`]`::flags` requesting write-through semantics
/// (MS-SMB2 `SMB2_WRITEFLAG_WRITE_THROUGH`).
pub const SMB2_WRITEFLAG_WRITE_THROUGH: u32 = 1 << 0;
/// Bit in [`WriteRequest`]`::flags` requesting an unbuffered write
/// (MS-SMB2 `SMB2_WRITEFLAG_WRITE_UNBUFFERED`).
pub const SMB2_WRITEFLAG_WRITE_UNBUFFERED: u32 = 1 << 1;
/// Body of an SMB2 WRITE request (MS-SMB2 2.2.21).
///
/// The wire `Length` field has no counterpart here: it is derived from
/// `data.len()` when packing and used to size `data` when unpacking.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct WriteRequest {
    /// `DataOffset` exactly as carried on the wire. It is stored verbatim;
    /// unpacking reads `data` directly after the fixed fields instead of
    /// seeking to this offset.
    pub data_offset: u16,
    /// `Offset` field: position in the file targeted by the write.
    pub offset: u64,
    /// Target handle, encoded as the persistent half followed by the volatile half.
    pub file_id: FileId,
    /// `Channel` field, copied to and from the wire unchanged.
    pub channel: u32,
    /// `RemainingBytes` field, copied to and from the wire unchanged.
    pub remaining_bytes: u32,
    /// `WriteChannelInfoOffset`, carried verbatim; this codec does not read or
    /// write a channel-info buffer.
    pub write_channel_info_offset: u16,
    /// `WriteChannelInfoLength`, carried verbatim.
    pub write_channel_info_length: u16,
    /// `Flags` bitmask (see the `SMB2_WRITEFLAG_*` constants).
    pub flags: u32,
    /// Payload bytes. When empty, the encoder emits one zero padding byte in
    /// their place.
    pub data: Vec<u8>,
}

impl WriteRequest {
    /// `StructureSize` value: 48 bytes of fixed fields plus a one-byte
    /// variable-length buffer.
    pub const STRUCTURE_SIZE: u16 = 48 + 1;
}
impl Pack for WriteRequest {
    /// Encodes the request body: 48 bytes of fixed fields followed by `data`.
    ///
    /// The wire `Length` field is taken from `data.len()`. An empty `data` is
    /// replaced by a single zero byte so the body still covers the one-byte
    /// buffer counted by [`WriteRequest::STRUCTURE_SIZE`].
    ///
    /// # Panics
    ///
    /// Panics if `data` is longer than `u32::MAX` bytes. `Length` is a 32-bit
    /// field, and silently truncating it would emit a header that disagrees
    /// with the payload written after it.
    fn pack(&self, cursor: &mut WriteCursor) {
        assert!(
            self.data.len() <= u32::MAX as usize,
            "WriteRequest data ({} bytes) does not fit in the 32-bit Length field",
            self.data.len()
        );
        // Lossless: bounded by the assertion above.
        let length = self.data.len() as u32;

        cursor.write_u16_le(Self::STRUCTURE_SIZE);
        cursor.write_u16_le(self.data_offset);
        cursor.write_u32_le(length);
        cursor.write_u64_le(self.offset);
        cursor.write_u64_le(self.file_id.persistent);
        cursor.write_u64_le(self.file_id.volatile);
        cursor.write_u32_le(self.channel);
        cursor.write_u32_le(self.remaining_bytes);
        cursor.write_u16_le(self.write_channel_info_offset);
        cursor.write_u16_le(self.write_channel_info_length);
        cursor.write_u32_le(self.flags);
        if self.data.is_empty() {
            // StructureSize counts one buffer byte; emit it as padding.
            cursor.write_u8(0);
        } else {
            cursor.write_bytes(&self.data);
        }
    }
}
impl Unpack for WriteRequest {
    /// Decodes a WRITE request body from `cursor`.
    ///
    /// `Length` payload bytes are read immediately after the 48 bytes of fixed
    /// fields; `data_offset` is returned as-is and is not used to seek.
    ///
    /// For a zero-length write, the single padding byte counted by
    /// `StructureSize` is consumed if present. A body that ends right after
    /// the fixed fields is also accepted, just as the WRITE response is
    /// decoded without its trailing buffer byte.
    ///
    /// # Errors
    ///
    /// Returns an error if `StructureSize` is not
    /// [`WriteRequest::STRUCTURE_SIZE`]. Otherwise propagates any error from
    /// the underlying cursor reads.
    fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
        let structure_size = cursor.read_u16_le()?;
        if structure_size != Self::STRUCTURE_SIZE {
            return Err(Error::invalid_data(format!(
                "invalid WriteRequest structure size: expected {}, got {}",
                Self::STRUCTURE_SIZE,
                structure_size
            )));
        }
        let data_offset = cursor.read_u16_le()?;
        let length = cursor.read_u32_le()?;
        let offset = cursor.read_u64_le()?;
        let persistent = cursor.read_u64_le()?;
        let volatile = cursor.read_u64_le()?;
        let channel = cursor.read_u32_le()?;
        let remaining_bytes = cursor.read_u32_le()?;
        let write_channel_info_offset = cursor.read_u16_le()?;
        let write_channel_info_length = cursor.read_u16_le()?;
        let flags = cursor.read_u32_le()?;
        let data = if length > 0 {
            // u32 -> usize widens on all supported (>= 32-bit) targets.
            cursor.read_bytes_bounded(length as usize)?.to_vec()
        } else {
            // Tolerate a missing pad byte rather than failing the whole request.
            if !cursor.is_empty() {
                cursor.skip(1)?;
            }
            Vec::new()
        };
        Ok(WriteRequest {
            data_offset,
            offset,
            file_id: FileId {
                persistent,
                volatile,
            },
            channel,
            remaining_bytes,
            write_channel_info_offset,
            write_channel_info_length,
            flags,
            data,
        })
    }
}
/// Body of an SMB2 WRITE response (MS-SMB2 2.2.22).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct WriteResponse {
    /// `Count` field: number of bytes reported as written.
    pub count: u32,
    /// `Remaining` field, copied to and from the wire unchanged.
    pub remaining: u32,
    /// `WriteChannelInfoOffset`, copied to and from the wire unchanged.
    pub write_channel_info_offset: u16,
    /// `WriteChannelInfoLength`, copied to and from the wire unchanged.
    pub write_channel_info_length: u16,
}

impl WriteResponse {
    /// `StructureSize` value. The packed body is 16 bytes; the extra 1 counts
    /// a variable-length buffer that this codec never emits.
    pub const STRUCTURE_SIZE: u16 = 16 + 1;
}
impl Pack for WriteResponse {
    /// Writes the 16-byte response body: `StructureSize`, a zero `Reserved`
    /// word, `Count`, `Remaining`, and the two `WriteChannelInfo` fields.
    ///
    /// No trailing buffer byte is emitted, even though `StructureSize` is 17.
    fn pack(&self, cursor: &mut WriteCursor) {
        let Self {
            count,
            remaining,
            write_channel_info_offset,
            write_channel_info_length,
        } = *self;

        cursor.write_u16_le(Self::STRUCTURE_SIZE);
        // Reserved.
        cursor.write_u16_le(0);
        cursor.write_u32_le(count);
        cursor.write_u32_le(remaining);
        cursor.write_u16_le(write_channel_info_offset);
        cursor.write_u16_le(write_channel_info_length);
    }
}
impl Unpack for WriteResponse {
fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
let structure_size = cursor.read_u16_le()?;
if structure_size != Self::STRUCTURE_SIZE {
return Err(Error::invalid_data(format!(
"invalid WriteResponse structure size: expected {}, got {}",
Self::STRUCTURE_SIZE,
structure_size
)));
}
let _reserved = cursor.read_u16_le()?;
let count = cursor.read_u32_le()?;
let remaining = cursor.read_u32_le()?;
let write_channel_info_offset = cursor.read_u16_le()?;
let write_channel_info_length = cursor.read_u16_le()?;
Ok(WriteResponse {
count,
remaining,
write_channel_info_offset,
write_channel_info_length,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn write_request_roundtrip() {
        let original = WriteRequest {
            data_offset: 0x70,
            offset: 0x2000,
            file_id: FileId {
                persistent: 0xAAAA_BBBB_CCCC_DDDD,
                volatile: 0x1111_2222_3333_4444,
            },
            channel: 0,
            remaining_bytes: 0,
            write_channel_info_offset: 0,
            write_channel_info_length: 0,
            flags: SMB2_WRITEFLAG_WRITE_THROUGH,
            data: b"Hello".to_vec(),
        };

        let mut writer = WriteCursor::new();
        original.pack(&mut writer);
        let encoded = writer.into_inner();
        // 48 fixed bytes followed by the 5-byte payload.
        assert_eq!(encoded.len(), 53);

        let mut reader = ReadCursor::new(&encoded);
        let decoded = WriteRequest::unpack(&mut reader).unwrap();
        assert_eq!(decoded.data_offset, original.data_offset);
        assert_eq!(decoded.offset, original.offset);
        assert_eq!(decoded.file_id, original.file_id);
        assert_eq!(decoded.channel, original.channel);
        assert_eq!(decoded.remaining_bytes, original.remaining_bytes);
        assert_eq!(decoded.flags, original.flags);
        assert_eq!(decoded.data, original.data);
    }

    #[test]
    fn write_request_empty_data_roundtrip() {
        let original = WriteRequest {
            data_offset: 0x70,
            offset: 0,
            file_id: FileId {
                persistent: 1,
                volatile: 2,
            },
            channel: 0,
            remaining_bytes: 0,
            write_channel_info_offset: 0,
            write_channel_info_length: 0,
            flags: 0,
            data: Vec::new(),
        };

        let mut writer = WriteCursor::new();
        original.pack(&mut writer);
        let encoded = writer.into_inner();
        // 48 fixed bytes plus the single padding byte emitted for empty data.
        assert_eq!(encoded.len(), 49);

        let mut reader = ReadCursor::new(&encoded);
        let decoded = WriteRequest::unpack(&mut reader).unwrap();
        assert!(decoded.data.is_empty());
        assert_eq!(decoded.file_id, original.file_id);
    }

    #[test]
    fn write_request_wrong_structure_size() {
        let mut buf = [0u8; 49];
        buf[..2].copy_from_slice(&48u16.to_le_bytes());
        let mut cursor = ReadCursor::new(&buf);
        let err = WriteRequest::unpack(&mut cursor)
            .expect_err("a StructureSize of 48 must be rejected")
            .to_string();
        assert!(err.contains("structure size"), "error was: {err}");
    }

    #[test]
    fn write_request_known_bytes() {
        let fields: [&[u8]; 12] = [
            &49u16.to_le_bytes(),   // StructureSize
            &0x70u16.to_le_bytes(), // DataOffset
            &2u32.to_le_bytes(),    // Length
            &0u64.to_le_bytes(),    // Offset
            &0x10u64.to_le_bytes(), // FileId.Persistent
            &0x20u64.to_le_bytes(), // FileId.Volatile
            &0u32.to_le_bytes(),    // Channel
            &0u32.to_le_bytes(),    // RemainingBytes
            &0u16.to_le_bytes(),    // WriteChannelInfoOffset
            &0u16.to_le_bytes(),    // WriteChannelInfoLength
            &1u32.to_le_bytes(),    // Flags
            &[0xAA, 0xBB],          // Buffer
        ];
        let buf = fields.concat();

        let mut cursor = ReadCursor::new(&buf);
        let req = WriteRequest::unpack(&mut cursor).unwrap();
        assert_eq!(req.data_offset, 0x70);
        assert_eq!(req.file_id.persistent, 0x10);
        assert_eq!(req.file_id.volatile, 0x20);
        assert_eq!(req.flags, SMB2_WRITEFLAG_WRITE_THROUGH);
        assert_eq!(req.data, vec![0xAA, 0xBB]);
    }

    #[test]
    fn write_response_roundtrip() {
        let original = WriteResponse {
            count: 65536,
            remaining: 0,
            write_channel_info_offset: 0,
            write_channel_info_length: 0,
        };

        let mut writer = WriteCursor::new();
        original.pack(&mut writer);
        let encoded = writer.into_inner();
        // The response body is encoded without a trailing buffer byte.
        assert_eq!(encoded.len(), 16);

        let mut reader = ReadCursor::new(&encoded);
        let decoded = WriteResponse::unpack(&mut reader).unwrap();
        assert_eq!(decoded.count, original.count);
        assert_eq!(decoded.remaining, original.remaining);
        assert_eq!(decoded.write_channel_info_offset, original.write_channel_info_offset);
        assert_eq!(decoded.write_channel_info_length, original.write_channel_info_length);
    }

    #[test]
    fn write_response_known_bytes() {
        let fields: [&[u8]; 6] = [
            &17u16.to_le_bytes(),   // StructureSize
            &0u16.to_le_bytes(),    // Reserved
            &1024u32.to_le_bytes(), // Count
            &0u32.to_le_bytes(),    // Remaining
            &0u16.to_le_bytes(),    // WriteChannelInfoOffset
            &0u16.to_le_bytes(),    // WriteChannelInfoLength
        ];
        let buf = fields.concat();

        let mut cursor = ReadCursor::new(&buf);
        let resp = WriteResponse::unpack(&mut cursor).unwrap();
        assert_eq!(resp.count, 1024);
        assert_eq!(resp.remaining, 0);
    }

    #[test]
    fn write_response_wrong_structure_size() {
        let mut buf = [0u8; 16];
        buf[..2].copy_from_slice(&16u16.to_le_bytes());
        let mut cursor = ReadCursor::new(&buf);
        let err = WriteResponse::unpack(&mut cursor)
            .expect_err("a StructureSize of 16 must be rejected")
            .to_string();
        assert!(err.contains("structure size"), "error was: {err}");
    }
}
#[cfg(test)]
mod roundtrip_props {
    use super::*;
    use crate::msg::roundtrip_strategies::{arb_bytes, arb_file_id};
    use proptest::prelude::*;

    /// Packs `value`, unpacks the produced bytes, and returns the decoded value
    /// together with whether the decoder consumed every byte.
    fn repack<T: Pack + Unpack>(value: &T) -> (T, bool) {
        let mut writer = WriteCursor::new();
        value.pack(&mut writer);
        let encoded = writer.into_inner();
        let mut reader = ReadCursor::new(&encoded);
        let decoded = T::unpack(&mut reader).unwrap();
        (decoded, reader.is_empty())
    }

    proptest! {
        #[test]
        fn write_request_pack_unpack(
            data_offset in any::<u16>(),
            offset in any::<u64>(),
            file_id in arb_file_id(),
            channel in any::<u32>(),
            remaining_bytes in any::<u32>(),
            write_channel_info_offset in any::<u16>(),
            write_channel_info_length in any::<u16>(),
            flags in any::<u32>(),
            data in arb_bytes(),
        ) {
            let original = WriteRequest {
                data_offset,
                offset,
                file_id,
                channel,
                remaining_bytes,
                write_channel_info_offset,
                write_channel_info_length,
                flags,
                data,
            };
            let (decoded, fully_consumed) = repack(&original);
            prop_assert_eq!(decoded, original);
            prop_assert!(fully_consumed);
        }

        #[test]
        fn write_response_pack_unpack(
            count in any::<u32>(),
            remaining in any::<u32>(),
            write_channel_info_offset in any::<u16>(),
            write_channel_info_length in any::<u16>(),
        ) {
            let original = WriteResponse {
                count,
                remaining,
                write_channel_info_offset,
                write_channel_info_length,
            };
            let (decoded, fully_consumed) = repack(&original);
            prop_assert_eq!(decoded, original);
            prop_assert!(fully_consumed);
        }
    }
}