use std::ops::ControlFlow;
use std::sync::Arc;
use log::{debug, info, trace, warn};
use crate::client::connection::{CompoundOp, Connection};
use crate::client::stream::{FileDownload, Progress};
use crate::error::Result;
use crate::msg::close::CloseRequest;
use crate::msg::create::{
CreateDisposition, CreateRequest, CreateResponse, ImpersonationLevel, ShareAccess,
};
use crate::msg::flush::FlushRequest;
use crate::msg::query_directory::{
FileInformationClass, QueryDirectoryFlags, QueryDirectoryRequest, QueryDirectoryResponse,
};
use crate::msg::query_info::{InfoType, QueryInfoRequest, QueryInfoResponse};
use crate::msg::read::{ReadRequest, ReadResponse, SMB2_CHANNEL_NONE};
use crate::msg::set_info::SetInfoRequest;
use crate::msg::tree_connect::{TreeConnectRequest, TreeConnectRequestFlags, TreeConnectResponse};
use crate::msg::tree_disconnect::TreeDisconnectRequest;
use crate::msg::write::{WriteRequest, WriteResponse};
use crate::pack::{FileTime, ReadCursor, Unpack};
use crate::types::flags::FileAccessMask;
use crate::types::status::NtStatus;
#[cfg(test)]
use crate::types::MessageId;
use crate::types::{Command, CreditCharge, FileId, OplockLevel, TreeId};
use crate::Error;
/// Upper bound on a pipelining window.
// NOTE(review): the code that reads this constant is outside this view — presumably it caps
// the number of in-flight requests for the pipelined read/write paths; confirm.
const MAX_PIPELINE_WINDOW: usize = 32;
/// Collapses the per-frame results of a compound exchange into a single result.
///
/// Returns all frames, in their original order, when every entry is `Ok`; otherwise returns
/// the first error encountered and discards the remaining entries.
fn all_or_first_err(
    frames: Vec<Result<crate::client::connection::Frame>>,
) -> Result<Vec<crate::client::connection::Frame>> {
    // Collecting into `Result<Vec<_>, _>` short-circuits on the first `Err`.
    frames.into_iter().collect()
}
/// File-attribute bit tested by `stat` to decide whether an object is a directory.
const FILE_ATTRIBUTE_DIRECTORY: u32 = 0x0000_0010;
/// CREATE option passed by `fs_info` when opening the share root.
const FILE_DIRECTORY_FILE: u32 = 0x0000_0001;
/// CREATE option passed by the file delete and write paths.
const FILE_NON_DIRECTORY_FILE: u32 = 0x0000_0040;
/// CREATE option combined with `FILE_NON_DIRECTORY_FILE` by `delete_files`.
const FILE_DELETE_ON_CLOSE: u32 = 0x0000_1000;
/// `file_info_class` value used for the basic-information query (timestamps, attributes).
const FILE_BASIC_INFORMATION: u8 = 4;
/// `file_info_class` value used for the standard-information query (size, directory flag).
const FILE_STANDARD_INFORMATION: u8 = 5;
/// `file_info_class` value used for the SET_INFO request issued by `rename`.
const FILE_RENAME_INFORMATION: u8 = 10;
/// `file_info_class` value used with `InfoType::Filesystem` by `fs_info`.
const FILE_FS_FULL_SIZE_INFORMATION: u8 = 7;
/// One entry of a directory listing, as returned by [`Tree::list_directory`].
// NOTE(review): the code that fills these fields is outside this view; the per-field notes
// below follow the field names and types — confirm against `query_directory_loop`.
#[derive(Debug, Clone)]
pub struct DirectoryEntry {
/// Name of the entry.
pub name: String,
/// Size of the entry (presumably in bytes).
pub size: u64,
/// Whether the entry is a directory.
pub is_directory: bool,
/// Creation timestamp.
pub created: FileTime,
/// Modification timestamp.
pub modified: FileTime,
}
/// Metadata for a single path, as returned by [`Tree::stat`] and [`Tree::stat_files`].
#[derive(Debug, Clone)]
pub struct FileInfo {
/// End-of-file value taken from the standard-information query.
pub size: u64,
/// True when either the standard-information directory byte or the directory attribute
/// bit is set.
pub is_directory: bool,
/// Creation time (first 8 bytes of the basic-information buffer).
pub created: FileTime,
/// Last-write time (bytes 16..24 of the basic-information buffer).
pub modified: FileTime,
/// Last-access time (bytes 8..16 of the basic-information buffer).
pub accessed: FileTime,
}
/// Capacity information for a share, as returned by [`Tree::fs_info`].
///
/// The byte figures are derived from allocation-unit counts multiplied by
/// `sectors_per_unit * bytes_per_sector`.
#[derive(Debug, Clone)]
pub struct FsInfo {
/// Total allocation units converted to bytes.
pub total_bytes: u64,
/// Caller-available allocation units converted to bytes.
pub free_bytes: u64,
/// Actual available allocation units converted to bytes.
pub total_free_bytes: u64,
/// Bytes per sector as reported by the server.
pub bytes_per_sector: u32,
/// Sectors per allocation unit as reported by the server.
pub sectors_per_unit: u32,
}
/// Handle to a connected share, as produced by [`Tree::connect`].
///
/// The handle does not own the connection; every operation takes the `Connection` to use.
#[derive(Clone)]
pub struct Tree {
/// Tree ID taken from the header of the TREE_CONNECT response.
pub tree_id: TreeId,
/// Share name exactly as passed to `connect`.
pub share_name: String,
/// Server name reported by the connection at connect time.
pub server: String,
/// Set when the TREE_CONNECT response capabilities contain the DFS flag.
pub is_dfs: bool,
/// Set when the TREE_CONNECT response share flags contain `ENCRYPT_DATA`.
pub encrypt_data: bool,
}
impl Tree {
/// Connects to `share_name` on the server that `conn` is attached to.
///
/// Sends a TREE_CONNECT for `\\<server>\<share_name>` and builds a [`Tree`] from the
/// response. DFS-capable trees are additionally registered with the connection.
///
/// # Errors
/// Fails when the request cannot be executed, when the response is not a TREE_CONNECT
/// response, when its status is not `SUCCESS`, when its body cannot be unpacked, or when the
/// response header carries no tree ID.
pub async fn connect(conn: &mut Connection, share_name: &str) -> Result<Tree> {
    let server = conn.server_name().to_string();
    let request = TreeConnectRequest {
        flags: TreeConnectRequestFlags::default(),
        path: format!(r"\\{}\{}", server, share_name),
    };
    let response = conn.execute(Command::TreeConnect, &request, None).await?;

    let received = response.header.command;
    if received != Command::TreeConnect {
        return Err(Error::invalid_data(format!(
            "expected TreeConnect response, got {:?}",
            received
        )));
    }
    let status = response.header.status;
    if status != NtStatus::SUCCESS {
        return Err(Error::Protocol {
            status,
            command: Command::TreeConnect,
        });
    }

    let payload = TreeConnectResponse::unpack(&mut ReadCursor::new(&response.body))?;
    let tree_id = response
        .header
        .tree_id
        .ok_or_else(|| Error::invalid_data("TreeConnect response missing tree ID"))?;

    let is_dfs = payload
        .capabilities
        .contains(crate::types::flags::ShareCapabilities::DFS);
    let encrypt_data = payload
        .share_flags
        .contains(crate::types::flags::ShareFlags::ENCRYPT_DATA);
    info!("tree: connected share={}, tree_id={}", share_name, tree_id);
    debug!("tree: is_dfs={}, encrypt_data={}", is_dfs, encrypt_data);

    if is_dfs {
        conn.register_dfs_tree(tree_id);
    }
    Ok(Tree {
        tree_id,
        share_name: share_name.to_string(),
        server,
        is_dfs,
        encrypt_data,
    })
}
/// Converts a caller-supplied path into the form sent on the wire for this tree.
///
/// The path is first passed through `normalize_path`. For non-DFS trees that result is
/// returned unchanged. For DFS trees it is prefixed with `<host>\<share>`, where `<host>` is
/// the part of `self.server` before the first `:`; an empty normalized path yields just
/// `<host>\<share>`.
pub(crate) fn format_path(&self, path: &str) -> String {
    let relative = normalize_path(path);
    if !self.is_dfs {
        return relative;
    }
    // Keep only what precedes the first ':' of the server string.
    let host = self.server.split(':').next().unwrap_or(&self.server);
    let mut full = format!("{}\\{}", host, self.share_name);
    if !relative.is_empty() {
        full.push('\\');
        full.push_str(&relative);
    }
    full
}
/// Lists the entries of the directory at `path`.
///
/// Opens the directory, runs the query loop, and always attempts to close the handle
/// afterwards, even when the query failed.
///
/// # Errors
/// Returns the open error, otherwise the query error, otherwise the close error.
pub async fn list_directory(
    &self,
    conn: &mut Connection,
    path: &str,
) -> Result<Vec<DirectoryEntry>> {
    let dir_path = self.format_path(path);
    debug!("tree: list_directory path={}", dir_path);
    let handle = self.open_directory(conn, &dir_path).await?;

    // Close before inspecting the query outcome so the handle is released on both paths.
    let listing = self.query_directory_loop(conn, handle).await;
    let closed = self.close_handle(conn, handle).await;

    // A query failure takes precedence over a close failure.
    let entries = listing?;
    closed?;
    debug!("tree: list_directory done, entries={}", entries.len());
    Ok(entries)
}
/// Reads a file with a single CREATE + READ + CLOSE compound request.
///
/// Exactly one READ of `max_read_size` bytes (65536 when the connection has no negotiated
/// parameters) is issued at offset 0, so at most that many bytes are returned. A file longer
/// than that is NOT read past that point; callers that need the whole file must handle the
/// case where the returned length equals the read size.
///
/// An `END_OF_FILE` status on the READ is treated as an empty file. A non-success CLOSE
/// status is only logged, because the data has already been received.
///
/// # Errors
/// Fails when the compound exchange fails, when CREATE or READ return an error status, or
/// when a response body cannot be unpacked. When READ fails after a successful CREATE, a
/// standalone CLOSE is attempted (its result is ignored) before the error is returned.
pub async fn read_file_compound(&self, conn: &mut Connection, path: &str) -> Result<Vec<u8>> {
    let normalized = self.format_path(path);
    let max_read = conn.params().map(|p| p.max_read_size).unwrap_or(65536);
    debug!(
        "tree: read_file_compound path={}, max_read={}",
        normalized, max_read
    );
    let create_req = CreateRequest {
        requested_oplock_level: OplockLevel::None,
        impersonation_level: ImpersonationLevel::Impersonation,
        desired_access: FileAccessMask::new(
            FileAccessMask::FILE_READ_DATA
                | FileAccessMask::FILE_READ_ATTRIBUTES
                | FileAccessMask::SYNCHRONIZE,
        ),
        file_attributes: 0,
        share_access: ShareAccess(
            ShareAccess::FILE_SHARE_READ
                | ShareAccess::FILE_SHARE_WRITE
                | ShareAccess::FILE_SHARE_DELETE,
        ),
        create_disposition: CreateDisposition::FileOpen,
        create_options: 0,
        // `normalized` is not needed after this point, so it is moved rather than cloned.
        name: normalized,
        create_contexts: vec![],
    };
    // One credit per started 64 KiB of payload, never less than one — the same rule
    // `write_file_compound` applies to its WRITE.
    let read_credit_charge = (max_read as u64).div_ceil(65536).max(1) as u16;
    let read_req = ReadRequest {
        padding: 0x50,
        flags: 0,
        length: max_read,
        offset: 0,
        // The sentinel makes the server substitute the handle opened by the CREATE.
        file_id: FileId::SENTINEL,
        minimum_count: 0,
        channel: SMB2_CHANNEL_NONE,
        remaining_bytes: 0,
        read_channel_info: vec![],
    };
    let close_req = CloseRequest {
        flags: 0,
        file_id: FileId::SENTINEL,
    };
    let ops = [
        CompoundOp {
            command: Command::Create,
            body: &create_req,
            tree_id: Some(self.tree_id),
            credit_charge: CreditCharge(1),
        },
        CompoundOp {
            command: Command::Read,
            body: &read_req,
            tree_id: Some(self.tree_id),
            credit_charge: CreditCharge(read_credit_charge),
        },
        CompoundOp {
            command: Command::Close,
            body: &close_req,
            tree_id: Some(self.tree_id),
            credit_charge: CreditCharge(1),
        },
    ];
    let responses = all_or_first_err(conn.execute_compound(&ops).await?)?;
    let create_header = &responses[0].header;
    let create_body = &responses[0].body;
    let read_header = &responses[1].header;
    let read_body = &responses[1].body;
    let close_header = &responses[2].header;
    if create_header.status != NtStatus::SUCCESS {
        return Err(Error::Protocol {
            status: create_header.status,
            command: Command::Create,
        });
    }
    let mut cursor = ReadCursor::new(create_body);
    let create_resp = CreateResponse::unpack(&mut cursor)?;
    let file_id = create_resp.file_id;
    if read_header.status != NtStatus::SUCCESS && read_header.status != NtStatus::END_OF_FILE {
        debug!(
            "tree: compound READ failed ({:?}), issuing standalone CLOSE",
            read_header.status
        );
        // Best effort: the compound CLOSE may not have run, so try to release the handle.
        let _ = self.close_handle(conn, file_id).await;
        return Err(Error::Protocol {
            status: read_header.status,
            command: Command::Read,
        });
    }
    let data = if read_header.status == NtStatus::END_OF_FILE {
        Vec::new()
    } else {
        let mut cursor = ReadCursor::new(read_body);
        let read_resp = ReadResponse::unpack(&mut cursor)?;
        read_resp.data
    };
    if close_header.status != NtStatus::SUCCESS {
        debug!(
            "tree: compound CLOSE returned {:?} (non-fatal, data already read)",
            close_header.status,
        );
    }
    debug!("tree: read_file_compound done, read {} bytes", data.len());
    Ok(data)
}
pub async fn read_file(&self, conn: &mut Connection, path: &str) -> Result<Vec<u8>> {
self.read_file_compound(conn, path).await
}
/// Disconnects this tree from the server.
///
/// On success the tree is also removed from the connection's DFS tree registry.
///
/// # Errors
/// Fails when the request cannot be executed or when the server answers with a status other
/// than `SUCCESS`; in both cases the DFS registration is left untouched.
pub async fn disconnect(&self, conn: &mut Connection) -> Result<()> {
    debug!(
        "tree: disconnecting share={}, tree_id={}",
        self.share_name, self.tree_id
    );
    let response = conn
        .execute(
            Command::TreeDisconnect,
            &TreeDisconnectRequest,
            Some(self.tree_id),
        )
        .await?;
    let status = response.header.status;
    if status != NtStatus::SUCCESS {
        return Err(Error::Protocol {
            status,
            command: Command::TreeDisconnect,
        });
    }
    conn.deregister_dfs_tree(self.tree_id);
    info!(
        "tree: disconnected share={}, tree_id={}",
        self.share_name, self.tree_id
    );
    Ok(())
}
/// Opens the directory at `path` and wraps the handle in a `Watcher`.
///
/// The watcher receives clones of this tree and of `conn`, the opened directory handle, and
/// the `recursive` flag.
///
/// # Errors
/// Fails when the directory cannot be opened.
pub async fn watch(
    &self,
    conn: &mut Connection,
    path: &str,
    recursive: bool,
) -> Result<crate::client::watcher::Watcher> {
    use crate::client::watcher::Watcher;

    let dir_path = self.format_path(path);
    debug!(
        "tree: watch path={}, recursive={}, tree_id={}",
        dir_path, recursive, self.tree_id
    );
    let dir_handle = self.open_directory(conn, &dir_path).await?;
    let watcher = Watcher::new(self.clone(), conn.clone(), dir_handle, recursive);
    Ok(watcher)
}
/// Deletes the file at `path`.
///
/// Delegates to `delete_compound` with the non-directory create option and the label
/// `"file"`.
pub async fn delete_file(&self, conn: &mut Connection, path: &str) -> Result<()> {
    let create_options = FILE_NON_DIRECTORY_FILE;
    self.delete_compound(conn, path, create_options, "file").await
}
pub async fn delete_files(&self, conn: &mut Connection, paths: &[&str]) -> Vec<Result<()>> {
if paths.is_empty() {
return vec![];
}
debug!("tree: delete_files batch, count={}", paths.len());
let mut results: Vec<Result<()>> = Vec::with_capacity(paths.len());
let mut cleanup_handles: Vec<FileId> = Vec::new();
for path in paths {
let normalized = self.format_path(path);
let create_req = CreateRequest {
requested_oplock_level: OplockLevel::None,
impersonation_level: ImpersonationLevel::Impersonation,
desired_access: FileAccessMask::new(
FileAccessMask::DELETE | FileAccessMask::FILE_READ_ATTRIBUTES,
),
file_attributes: 0,
share_access: ShareAccess(
ShareAccess::FILE_SHARE_READ
| ShareAccess::FILE_SHARE_WRITE
| ShareAccess::FILE_SHARE_DELETE,
),
create_disposition: CreateDisposition::FileOpen,
create_options: FILE_DELETE_ON_CLOSE | FILE_NON_DIRECTORY_FILE,
name: normalized,
create_contexts: vec![],
};
let close_req = CloseRequest {
flags: 0,
file_id: FileId::SENTINEL,
};
let ops = [
CompoundOp {
command: Command::Create,
body: &create_req,
tree_id: Some(self.tree_id),
credit_charge: CreditCharge(1),
},
CompoundOp {
command: Command::Close,
body: &close_req,
tree_id: Some(self.tree_id),
credit_charge: CreditCharge(1),
},
];
let frames = match conn.execute_compound(&ops).await {
Ok(v) => v,
Err(e) => {
results.push(Err(e));
continue;
}
};
let responses = match all_or_first_err(frames) {
Ok(v) => v,
Err(e) => {
results.push(Err(e));
continue;
}
};
let create_header = &responses[0].header;
let create_body = &responses[0].body;
let close_header = &responses[1].header;
if create_header.status != NtStatus::SUCCESS {
results.push(Err(Error::Protocol {
status: create_header.status,
command: Command::Create,
}));
} else if close_header.status != NtStatus::SUCCESS {
if let Ok(create_resp) = CreateResponse::unpack(&mut ReadCursor::new(create_body)) {
cleanup_handles.push(create_resp.file_id);
}
results.push(Err(Error::Protocol {
status: close_header.status,
command: Command::Close,
}));
} else {
info!("tree: batch deleted file={}", path);
results.push(Ok(()));
}
}
for file_id in &cleanup_handles {
warn!(
"tree: batch delete cleanup, issuing standalone CLOSE for {:?}",
file_id
);
let _ = self.close_handle(conn, *file_id).await;
}
debug!(
"tree: delete_files batch done, {}/{} succeeded",
results.iter().filter(|r| r.is_ok()).count(),
paths.len()
);
results
}
pub async fn stat(&self, conn: &mut Connection, path: &str) -> Result<FileInfo> {
let normalized = self.format_path(path);
debug!("tree: stat (compound) path={}", normalized);
let create_req = CreateRequest {
requested_oplock_level: OplockLevel::None,
impersonation_level: ImpersonationLevel::Impersonation,
desired_access: FileAccessMask::new(
FileAccessMask::FILE_READ_ATTRIBUTES | FileAccessMask::SYNCHRONIZE,
),
file_attributes: 0,
share_access: ShareAccess(
ShareAccess::FILE_SHARE_READ
| ShareAccess::FILE_SHARE_WRITE
| ShareAccess::FILE_SHARE_DELETE,
),
create_disposition: CreateDisposition::FileOpen,
create_options: 0,
name: normalized.clone(),
create_contexts: vec![],
};
let basic_req = QueryInfoRequest {
info_type: InfoType::File,
file_info_class: FILE_BASIC_INFORMATION,
output_buffer_length: 40,
additional_information: 0,
flags: 0,
file_id: FileId::SENTINEL,
input_buffer: vec![],
};
let std_req = QueryInfoRequest {
info_type: InfoType::File,
file_info_class: FILE_STANDARD_INFORMATION,
output_buffer_length: 24,
additional_information: 0,
flags: 0,
file_id: FileId::SENTINEL,
input_buffer: vec![],
};
let close_req = CloseRequest {
flags: 0,
file_id: FileId::SENTINEL,
};
let ops = [
CompoundOp {
command: Command::Create,
body: &create_req,
tree_id: Some(self.tree_id),
credit_charge: CreditCharge(1),
},
CompoundOp {
command: Command::QueryInfo,
body: &basic_req,
tree_id: Some(self.tree_id),
credit_charge: CreditCharge(1),
},
CompoundOp {
command: Command::QueryInfo,
body: &std_req,
tree_id: Some(self.tree_id),
credit_charge: CreditCharge(1),
},
CompoundOp {
command: Command::Close,
body: &close_req,
tree_id: Some(self.tree_id),
credit_charge: CreditCharge(1),
},
];
let responses = all_or_first_err(conn.execute_compound(&ops).await?)?;
let create_header = &responses[0].header;
let create_body = &responses[0].body;
let basic_header = &responses[1].header;
let basic_body = &responses[1].body;
let std_header = &responses[2].header;
let std_body = &responses[2].body;
let close_header = &responses[3].header;
if create_header.status != NtStatus::SUCCESS {
return Err(Error::Protocol {
status: create_header.status,
command: Command::Create,
});
}
if !basic_header.status.is_success_or_partial() {
let mut cursor = ReadCursor::new(create_body);
let create_resp = CreateResponse::unpack(&mut cursor)?;
warn!(
"tree: compound QUERY_INFO (basic) failed ({:?}), issuing standalone CLOSE",
basic_header.status
);
let _ = self.close_handle(conn, create_resp.file_id).await;
return Err(Error::Protocol {
status: basic_header.status,
command: Command::QueryInfo,
});
}
if basic_header.status == NtStatus::BUFFER_OVERFLOW {
warn!("recv: STATUS_BUFFER_OVERFLOW on FileBasicInformation, response data may be truncated");
}
let mut cursor = ReadCursor::new(basic_body);
let basic_resp = QueryInfoResponse::unpack(&mut cursor)?;
let basic_buf = &basic_resp.output_buffer;
if basic_buf.len() < 36 {
return Err(Error::invalid_data(format!(
"FileBasicInformation too short: {} bytes",
basic_buf.len()
)));
}
let created = FileTime(u64::from_le_bytes(basic_buf[0..8].try_into().unwrap()));
let accessed = FileTime(u64::from_le_bytes(basic_buf[8..16].try_into().unwrap()));
let modified = FileTime(u64::from_le_bytes(basic_buf[16..24].try_into().unwrap()));
let _change_time = u64::from_le_bytes(basic_buf[24..32].try_into().unwrap());
let file_attributes = u32::from_le_bytes(basic_buf[32..36].try_into().unwrap());
if !std_header.status.is_success_or_partial() {
let mut cursor = ReadCursor::new(create_body);
let create_resp = CreateResponse::unpack(&mut cursor)?;
warn!(
"tree: compound QUERY_INFO (standard) failed ({:?}), issuing standalone CLOSE",
std_header.status
);
let _ = self.close_handle(conn, create_resp.file_id).await;
return Err(Error::Protocol {
status: std_header.status,
command: Command::QueryInfo,
});
}
if std_header.status == NtStatus::BUFFER_OVERFLOW {
warn!("recv: STATUS_BUFFER_OVERFLOW on FileStandardInformation, response data may be truncated");
}
let mut cursor = ReadCursor::new(std_body);
let std_resp = QueryInfoResponse::unpack(&mut cursor)?;
let std_buf = &std_resp.output_buffer;
if std_buf.len() < 22 {
return Err(Error::invalid_data(format!(
"FileStandardInformation too short: {} bytes",
std_buf.len()
)));
}
let _allocation_size = u64::from_le_bytes(std_buf[0..8].try_into().unwrap());
let end_of_file = u64::from_le_bytes(std_buf[8..16].try_into().unwrap());
let _number_of_links = u32::from_le_bytes(std_buf[16..20].try_into().unwrap());
let _delete_pending = std_buf[20];
let is_directory_byte = std_buf[21];
let is_directory =
is_directory_byte != 0 || (file_attributes & FILE_ATTRIBUTE_DIRECTORY) != 0;
if close_header.status != NtStatus::SUCCESS {
debug!(
"tree: compound CLOSE returned {:?} (non-fatal, stat data already read)",
close_header.status,
);
}
debug!(
"tree: stat done, size={}, is_dir={}",
end_of_file, is_directory
);
Ok(FileInfo {
size: end_of_file,
is_directory,
created,
modified,
accessed,
})
}
pub async fn stat_files(&self, conn: &mut Connection, paths: &[&str]) -> Vec<Result<FileInfo>> {
if paths.is_empty() {
return vec![];
}
debug!("tree: stat_files batch, count={}", paths.len());
let mut results: Vec<Result<FileInfo>> = Vec::with_capacity(paths.len());
let mut cleanup_handles: Vec<FileId> = Vec::new();
for path in paths {
let normalized = self.format_path(path);
let create_req = CreateRequest {
requested_oplock_level: OplockLevel::None,
impersonation_level: ImpersonationLevel::Impersonation,
desired_access: FileAccessMask::new(
FileAccessMask::FILE_READ_ATTRIBUTES | FileAccessMask::SYNCHRONIZE,
),
file_attributes: 0,
share_access: ShareAccess(
ShareAccess::FILE_SHARE_READ
| ShareAccess::FILE_SHARE_WRITE
| ShareAccess::FILE_SHARE_DELETE,
),
create_disposition: CreateDisposition::FileOpen,
create_options: 0,
name: normalized,
create_contexts: vec![],
};
let basic_req = QueryInfoRequest {
info_type: InfoType::File,
file_info_class: FILE_BASIC_INFORMATION,
output_buffer_length: 40,
additional_information: 0,
flags: 0,
file_id: FileId::SENTINEL,
input_buffer: vec![],
};
let std_req = QueryInfoRequest {
info_type: InfoType::File,
file_info_class: FILE_STANDARD_INFORMATION,
output_buffer_length: 24,
additional_information: 0,
flags: 0,
file_id: FileId::SENTINEL,
input_buffer: vec![],
};
let close_req = CloseRequest {
flags: 0,
file_id: FileId::SENTINEL,
};
let ops = [
CompoundOp {
command: Command::Create,
body: &create_req,
tree_id: Some(self.tree_id),
credit_charge: CreditCharge(1),
},
CompoundOp {
command: Command::QueryInfo,
body: &basic_req,
tree_id: Some(self.tree_id),
credit_charge: CreditCharge(1),
},
CompoundOp {
command: Command::QueryInfo,
body: &std_req,
tree_id: Some(self.tree_id),
credit_charge: CreditCharge(1),
},
CompoundOp {
command: Command::Close,
body: &close_req,
tree_id: Some(self.tree_id),
credit_charge: CreditCharge(1),
},
];
let frames = match conn.execute_compound(&ops).await {
Ok(v) => v,
Err(e) => {
results.push(Err(e));
continue;
}
};
let responses = match all_or_first_err(frames) {
Ok(v) => v,
Err(e) => {
results.push(Err(e));
continue;
}
};
let parsed = self.parse_stat_batch_response(&responses, &mut cleanup_handles);
if parsed.is_ok() {
debug!("tree: batch stat done for file={}", path);
}
results.push(parsed);
}
for file_id in &cleanup_handles {
warn!(
"tree: batch stat cleanup, issuing standalone CLOSE for {:?}",
file_id
);
let _ = self.close_handle(conn, *file_id).await;
}
debug!(
"tree: stat_files batch done, {}/{} succeeded",
results.iter().filter(|r| r.is_ok()).count(),
paths.len()
);
results
}
/// Turns the sub-responses of one stat compound (CREATE, QUERY_INFO basic, QUERY_INFO
/// standard, CLOSE) into a [`FileInfo`].
///
/// `cleanup_handles` receives the file ID from the CREATE response whenever CREATE succeeded
/// but anything afterwards failed, so the caller can issue a standalone CLOSE. This includes
/// QUERY_INFO bodies that cannot be unpacked, which previously returned without scheduling
/// the handle.
///
/// # Errors
/// Fails when fewer than three sub-responses are present (returned as an error instead of
/// panicking on an out-of-range index), when CREATE or either QUERY_INFO carries an error
/// status, or when a QUERY_INFO body is malformed or too short.
fn parse_stat_batch_response(
    &self,
    responses: &[crate::client::connection::Frame],
    cleanup_handles: &mut Vec<FileId>,
) -> Result<FileInfo> {
    /// Decodes the two QUERY_INFO sub-responses. Does no cleanup bookkeeping.
    fn decode(
        basic: &crate::client::connection::Frame,
        standard: &crate::client::connection::Frame,
    ) -> Result<FileInfo> {
        if !basic.header.status.is_success_or_partial() {
            return Err(Error::Protocol {
                status: basic.header.status,
                command: Command::QueryInfo,
            });
        }
        let basic_resp = QueryInfoResponse::unpack(&mut ReadCursor::new(&basic.body))?;
        let basic_buf = &basic_resp.output_buffer;
        if basic_buf.len() < 36 {
            return Err(Error::invalid_data(format!(
                "FileBasicInformation too short: {} bytes",
                basic_buf.len()
            )));
        }
        // The length check above guarantees every fixed-width slice below is in range.
        let created = FileTime(u64::from_le_bytes(basic_buf[0..8].try_into().unwrap()));
        let accessed = FileTime(u64::from_le_bytes(basic_buf[8..16].try_into().unwrap()));
        let modified = FileTime(u64::from_le_bytes(basic_buf[16..24].try_into().unwrap()));
        let file_attributes = u32::from_le_bytes(basic_buf[32..36].try_into().unwrap());

        if !standard.header.status.is_success_or_partial() {
            return Err(Error::Protocol {
                status: standard.header.status,
                command: Command::QueryInfo,
            });
        }
        let std_resp = QueryInfoResponse::unpack(&mut ReadCursor::new(&standard.body))?;
        let std_buf = &std_resp.output_buffer;
        if std_buf.len() < 22 {
            return Err(Error::invalid_data(format!(
                "FileStandardInformation too short: {} bytes",
                std_buf.len()
            )));
        }
        let end_of_file = u64::from_le_bytes(std_buf[8..16].try_into().unwrap());
        let is_directory =
            std_buf[21] != 0 || (file_attributes & FILE_ATTRIBUTE_DIRECTORY) != 0;
        Ok(FileInfo {
            size: end_of_file,
            is_directory,
            created,
            modified,
            accessed,
        })
    }

    // Indices 0..=2 are read below; a shorter list must not panic.
    if responses.len() < 3 {
        return Err(Error::invalid_data(format!(
            "stat compound returned {} sub-responses, expected 4",
            responses.len()
        )));
    }
    debug_assert_eq!(
        responses.len(),
        4,
        "stat compound must have 4 sub-responses"
    );

    let create = &responses[0];
    if create.header.status != NtStatus::SUCCESS {
        return Err(Error::Protocol {
            status: create.header.status,
            command: Command::Create,
        });
    }
    // If the CREATE body itself cannot be parsed there is no handle to schedule.
    let file_id = CreateResponse::unpack(&mut ReadCursor::new(&create.body))
        .map(|r| r.file_id)
        .ok();

    let parsed = decode(&responses[1], &responses[2]);
    if parsed.is_err() {
        // Any failure after a successful CREATE may leave the handle open.
        if let Some(fid) = file_id {
            cleanup_handles.push(fid);
        }
    }
    parsed
}
pub async fn fs_info(&self, conn: &mut Connection) -> Result<FsInfo> {
debug!("tree: fs_info on share={}", self.share_name);
let create_req = CreateRequest {
requested_oplock_level: OplockLevel::None,
impersonation_level: ImpersonationLevel::Impersonation,
desired_access: FileAccessMask::new(
FileAccessMask::FILE_READ_ATTRIBUTES | FileAccessMask::SYNCHRONIZE,
),
file_attributes: 0,
share_access: ShareAccess(
ShareAccess::FILE_SHARE_READ
| ShareAccess::FILE_SHARE_WRITE
| ShareAccess::FILE_SHARE_DELETE,
),
create_disposition: CreateDisposition::FileOpen,
create_options: FILE_DIRECTORY_FILE,
name: String::new(), create_contexts: vec![],
};
let query_req = QueryInfoRequest {
info_type: InfoType::Filesystem,
file_info_class: FILE_FS_FULL_SIZE_INFORMATION,
output_buffer_length: 32, additional_information: 0,
flags: 0,
file_id: FileId::SENTINEL,
input_buffer: vec![],
};
let close_req = CloseRequest {
flags: 0,
file_id: FileId::SENTINEL,
};
let ops = [
CompoundOp {
command: Command::Create,
body: &create_req,
tree_id: Some(self.tree_id),
credit_charge: CreditCharge(1),
},
CompoundOp {
command: Command::QueryInfo,
body: &query_req,
tree_id: Some(self.tree_id),
credit_charge: CreditCharge(1),
},
CompoundOp {
command: Command::Close,
body: &close_req,
tree_id: Some(self.tree_id),
credit_charge: CreditCharge(1),
},
];
let responses = all_or_first_err(conn.execute_compound(&ops).await?)?;
let create_header = &responses[0].header;
let query_header = &responses[1].header;
let query_body = &responses[1].body;
let close_header = &responses[2].header;
if create_header.status != NtStatus::SUCCESS {
return Err(Error::Protocol {
status: create_header.status,
command: Command::Create,
});
}
if !query_header.status.is_success_or_partial() {
let mut cursor = ReadCursor::new(&responses[0].body);
let create_resp = CreateResponse::unpack(&mut cursor)?;
debug!(
"tree: compound QUERY_INFO failed ({:?}), issuing standalone CLOSE",
query_header.status
);
let _ = self.close_handle(conn, create_resp.file_id).await;
return Err(Error::Protocol {
status: query_header.status,
command: Command::QueryInfo,
});
}
if query_header.status == NtStatus::BUFFER_OVERFLOW {
warn!("recv: STATUS_BUFFER_OVERFLOW on FileFsFullSizeInformation, response data may be truncated");
}
let mut cursor = ReadCursor::new(query_body);
let query_resp = QueryInfoResponse::unpack(&mut cursor)?;
let buf = &query_resp.output_buffer;
if buf.len() < 32 {
return Err(Error::invalid_data(format!(
"FileFsFullSizeInformation too short: {} bytes",
buf.len()
)));
}
let total_allocation_units = i64::from_le_bytes(buf[0..8].try_into().unwrap()) as u64;
let caller_available_units = i64::from_le_bytes(buf[8..16].try_into().unwrap()) as u64;
let actual_available_units = i64::from_le_bytes(buf[16..24].try_into().unwrap()) as u64;
let sectors_per_unit = u32::from_le_bytes(buf[24..28].try_into().unwrap());
let bytes_per_sector = u32::from_le_bytes(buf[28..32].try_into().unwrap());
let bytes_per_unit = sectors_per_unit as u64 * bytes_per_sector as u64;
let total_bytes = total_allocation_units * bytes_per_unit;
let free_bytes = caller_available_units * bytes_per_unit;
let total_free_bytes = actual_available_units * bytes_per_unit;
if close_header.status != NtStatus::SUCCESS {
debug!(
"tree: compound CLOSE returned {:?} (non-fatal, fs_info already read)",
close_header.status,
);
}
debug!(
"tree: fs_info done, total={}, free={}, total_free={}",
total_bytes, free_bytes, total_free_bytes
);
Ok(FsInfo {
total_bytes,
free_bytes,
total_free_bytes,
bytes_per_sector,
sectors_per_unit,
})
}
pub async fn rename(&self, conn: &mut Connection, from: &str, to: &str) -> Result<()> {
let from_normalized = self.format_path(from);
let to_normalized = normalize_path(to);
debug!(
"tree: rename (compound) from={} to={}",
from_normalized, to_normalized
);
let create_req = CreateRequest {
requested_oplock_level: OplockLevel::None,
impersonation_level: ImpersonationLevel::Impersonation,
desired_access: FileAccessMask::new(
FileAccessMask::DELETE | FileAccessMask::FILE_READ_ATTRIBUTES,
),
file_attributes: 0,
share_access: ShareAccess(
ShareAccess::FILE_SHARE_READ
| ShareAccess::FILE_SHARE_WRITE
| ShareAccess::FILE_SHARE_DELETE,
),
create_disposition: CreateDisposition::FileOpen,
create_options: 0,
name: from_normalized.clone(),
create_contexts: vec![],
};
let setinfo_req = SetInfoRequest {
info_type: InfoType::File,
file_info_class: FILE_RENAME_INFORMATION,
additional_information: 0,
file_id: FileId::SENTINEL,
buffer: build_rename_info_buffer(&to_normalized),
};
let close_req = CloseRequest {
flags: 0,
file_id: FileId::SENTINEL,
};
let ops = [
CompoundOp {
command: Command::Create,
body: &create_req,
tree_id: Some(self.tree_id),
credit_charge: CreditCharge(1),
},
CompoundOp {
command: Command::SetInfo,
body: &setinfo_req,
tree_id: Some(self.tree_id),
credit_charge: CreditCharge(1),
},
CompoundOp {
command: Command::Close,
body: &close_req,
tree_id: Some(self.tree_id),
credit_charge: CreditCharge(1),
},
];
let responses = all_or_first_err(conn.execute_compound(&ops).await?)?;
let create_header = &responses[0].header;
let create_body = &responses[0].body;
let setinfo_header = &responses[1].header;
let close_header = &responses[2].header;
if create_header.status != NtStatus::SUCCESS {
return Err(Error::Protocol {
status: create_header.status,
command: Command::Create,
});
}
if setinfo_header.status != NtStatus::SUCCESS {
let mut cursor = ReadCursor::new(create_body);
let create_resp = CreateResponse::unpack(&mut cursor)?;
warn!(
"tree: compound SET_INFO failed ({:?}), issuing standalone CLOSE",
setinfo_header.status
);
let _ = self.close_handle(conn, create_resp.file_id).await;
return Err(Error::Protocol {
status: setinfo_header.status,
command: Command::SetInfo,
});
}
if close_header.status != NtStatus::SUCCESS {
debug!(
"tree: compound CLOSE returned {:?} (non-fatal, rename already done)",
close_header.status,
);
}
info!(
"tree: renamed from={} to={}",
from_normalized, to_normalized
);
Ok(())
}
pub async fn rename_files(
&self,
conn: &mut Connection,
renames: &[(&str, &str)],
) -> Vec<Result<()>> {
if renames.is_empty() {
return vec![];
}
debug!("tree: rename_files batch, count={}", renames.len());
let mut results: Vec<Result<()>> = Vec::with_capacity(renames.len());
let mut cleanup_handles: Vec<FileId> = Vec::new();
for (from, to) in renames {
let from_normalized = self.format_path(from);
let to_normalized = normalize_path(to);
let create_req = CreateRequest {
requested_oplock_level: OplockLevel::None,
impersonation_level: ImpersonationLevel::Impersonation,
desired_access: FileAccessMask::new(
FileAccessMask::DELETE | FileAccessMask::FILE_READ_ATTRIBUTES,
),
file_attributes: 0,
share_access: ShareAccess(
ShareAccess::FILE_SHARE_READ
| ShareAccess::FILE_SHARE_WRITE
| ShareAccess::FILE_SHARE_DELETE,
),
create_disposition: CreateDisposition::FileOpen,
create_options: 0,
name: from_normalized,
create_contexts: vec![],
};
let setinfo_req = SetInfoRequest {
info_type: InfoType::File,
file_info_class: FILE_RENAME_INFORMATION,
additional_information: 0,
file_id: FileId::SENTINEL,
buffer: build_rename_info_buffer(&to_normalized),
};
let close_req = CloseRequest {
flags: 0,
file_id: FileId::SENTINEL,
};
let ops = [
CompoundOp {
command: Command::Create,
body: &create_req,
tree_id: Some(self.tree_id),
credit_charge: CreditCharge(1),
},
CompoundOp {
command: Command::SetInfo,
body: &setinfo_req,
tree_id: Some(self.tree_id),
credit_charge: CreditCharge(1),
},
CompoundOp {
command: Command::Close,
body: &close_req,
tree_id: Some(self.tree_id),
credit_charge: CreditCharge(1),
},
];
let frames = match conn.execute_compound(&ops).await {
Ok(v) => v,
Err(e) => {
results.push(Err(e));
continue;
}
};
let responses = match all_or_first_err(frames) {
Ok(v) => v,
Err(e) => {
results.push(Err(e));
continue;
}
};
let create_header = &responses[0].header;
let create_body = &responses[0].body;
let setinfo_header = &responses[1].header;
let close_header = &responses[2].header;
if create_header.status != NtStatus::SUCCESS {
results.push(Err(Error::Protocol {
status: create_header.status,
command: Command::Create,
}));
} else if setinfo_header.status != NtStatus::SUCCESS {
if let Ok(create_resp) = CreateResponse::unpack(&mut ReadCursor::new(create_body)) {
cleanup_handles.push(create_resp.file_id);
}
results.push(Err(Error::Protocol {
status: setinfo_header.status,
command: Command::SetInfo,
}));
} else {
if close_header.status != NtStatus::SUCCESS {
debug!(
"tree: batch rename CLOSE returned {:?} (non-fatal)",
close_header.status,
);
}
info!("tree: batch renamed from={} to={}", from, to);
results.push(Ok(()));
}
}
for file_id in &cleanup_handles {
warn!(
"tree: batch rename cleanup, issuing standalone CLOSE for {:?}",
file_id
);
let _ = self.close_handle(conn, *file_id).await;
}
debug!(
"tree: rename_files batch done, {}/{} succeeded",
results.iter().filter(|r| r.is_ok()).count(),
renames.len()
);
results
}
/// Writes `data` to `path` with one CREATE + WRITE + FLUSH + CLOSE compound request.
///
/// The file is opened with the `FileOverwriteIf` disposition and no sharing, and the whole
/// buffer is sent in a single WRITE at offset 0. Returns the byte count reported by the
/// WRITE response. Non-success FLUSH or CLOSE statuses are only logged.
///
/// # Errors
/// Fails when the compound exchange fails, when CREATE or WRITE returns a status other than
/// `SUCCESS`, or when a response body cannot be unpacked. When WRITE fails after a
/// successful CREATE, a standalone CLOSE is attempted (result ignored) before returning.
pub async fn write_file_compound(
    &self,
    conn: &mut Connection,
    path: &str,
    data: &[u8],
) -> Result<u64> {
    let target = self.format_path(path);
    debug!(
        "tree: write_file_compound path={}, len={}",
        target,
        data.len()
    );

    let create_req = CreateRequest {
        requested_oplock_level: OplockLevel::None,
        impersonation_level: ImpersonationLevel::Impersonation,
        desired_access: FileAccessMask::new(
            FileAccessMask::FILE_WRITE_DATA
                | FileAccessMask::FILE_WRITE_ATTRIBUTES
                | FileAccessMask::SYNCHRONIZE,
        ),
        file_attributes: 0x80,
        share_access: ShareAccess(0),
        create_disposition: CreateDisposition::FileOverwriteIf,
        create_options: FILE_NON_DIRECTORY_FILE,
        name: target,
        create_contexts: vec![],
    };
    // One credit per started 64 KiB of payload, never less than one.
    let payload_len = data.len() as u64;
    let write_credits = payload_len.div_ceil(65536).max(1) as u16;
    let write_req = WriteRequest {
        data_offset: 0x70,
        offset: 0,
        file_id: FileId::SENTINEL,
        channel: 0,
        remaining_bytes: 0,
        write_channel_info_offset: 0,
        write_channel_info_length: 0,
        flags: 0,
        data: data.to_vec(),
    };
    let flush_req = FlushRequest {
        file_id: FileId::SENTINEL,
    };
    let close_req = CloseRequest {
        flags: 0,
        file_id: FileId::SENTINEL,
    };
    let ops = [
        CompoundOp {
            command: Command::Create,
            body: &create_req,
            tree_id: Some(self.tree_id),
            credit_charge: CreditCharge(1),
        },
        CompoundOp {
            command: Command::Write,
            body: &write_req,
            tree_id: Some(self.tree_id),
            credit_charge: CreditCharge(write_credits),
        },
        CompoundOp {
            command: Command::Flush,
            body: &flush_req,
            tree_id: Some(self.tree_id),
            credit_charge: CreditCharge(1),
        },
        CompoundOp {
            command: Command::Close,
            body: &close_req,
            tree_id: Some(self.tree_id),
            credit_charge: CreditCharge(1),
        },
    ];
    let responses = all_or_first_err(conn.execute_compound(&ops).await?)?;
    let (create, write, flush, close) = (
        &responses[0],
        &responses[1],
        &responses[2],
        &responses[3],
    );

    if create.header.status != NtStatus::SUCCESS {
        return Err(Error::Protocol {
            status: create.header.status,
            command: Command::Create,
        });
    }
    let opened = CreateResponse::unpack(&mut ReadCursor::new(&create.body))?;

    let write_status = write.header.status;
    if write_status != NtStatus::SUCCESS {
        debug!(
            "tree: compound WRITE failed ({:?}), issuing standalone CLOSE",
            write_status
        );
        let _ = self.close_handle(conn, opened.file_id).await;
        return Err(Error::Protocol {
            status: write_status,
            command: Command::Write,
        });
    }
    let written = WriteResponse::unpack(&mut ReadCursor::new(&write.body))?.count as u64;

    if flush.header.status != NtStatus::SUCCESS {
        debug!(
            "tree: compound FLUSH returned {:?} (data written but may not be persisted)",
            flush.header.status,
        );
    }
    if close.header.status != NtStatus::SUCCESS {
        debug!(
            "tree: compound CLOSE returned {:?} (non-fatal, data already written)",
            close.header.status,
        );
    }
    debug!("tree: write_file_compound done, wrote {} bytes", written);
    Ok(written)
}
/// Writes `data` to `path`, picking the strategy by payload size.
///
/// Payloads no larger than the negotiated `max_write_size` (65536 when the
/// connection has no parameters) go through `write_file_compound`;
/// anything larger goes through `write_file_pipelined`.
pub async fn write_file(&self, conn: &mut Connection, path: &str, data: &[u8]) -> Result<u64> {
    let single_request_limit = match conn.params() {
        Some(p) => p.max_write_size as usize,
        None => 65536,
    };
    if data.len() > single_request_limit {
        return self.write_file_pipelined(conn, path, data).await;
    }
    self.write_file_compound(conn, path, data).await
}
/// Reads the whole file at `path` into memory using pipelined READs.
///
/// The file is opened via `open_file`, read by `read_pipelined_loop`, and
/// the handle is closed whether or not the read succeeded. A read error
/// takes precedence over a close error. Empty files are closed and
/// returned as an empty vector without issuing any READ.
pub async fn read_file_pipelined(&self, conn: &mut Connection, path: &str) -> Result<Vec<u8>> {
    // Upper bound for a single pipelined READ when the file needs several.
    const PIPELINE_CHUNK: u32 = 512 * 1024;

    let target = self.format_path(path);
    let (file_id, file_size) = self.open_file(conn, &target).await?;
    if file_size == 0 {
        debug!(
            "tree: read_file_pipelined path={}, size=0 (empty file)",
            target
        );
        self.close_handle(conn, file_id).await?;
        return Ok(Vec::new());
    }

    let max_read = match conn.params() {
        Some(p) => p.max_read_size,
        None => 65536,
    };
    let fits_in_one_read = file_size <= max_read as u64;
    let chunk_size = if fits_in_one_read {
        (file_size as u32).min(max_read)
    } else {
        PIPELINE_CHUNK.min(max_read)
    };
    // One credit per started 64 KiB of a chunk.
    let credit_charge = chunk_size.div_ceil(65536) as u16;
    let total_chunks = file_size.div_ceil(chunk_size as u64) as usize;
    debug!(
        "tree: read_file_pipelined path={}, size={}, chunk_size={}, credit_charge={}, total_chunks={}, credits={}",
        target, file_size, chunk_size, credit_charge, total_chunks, conn.credits()
    );

    let started = std::time::Instant::now();
    let read_outcome = self
        .read_pipelined_loop(
            conn,
            file_id,
            file_size,
            chunk_size,
            credit_charge,
            total_chunks,
        )
        .await;
    // Close before propagating so the handle is released on failure too.
    let close_outcome = self.close_handle(conn, file_id).await;
    let data = read_outcome?;
    close_outcome?;

    let elapsed = started.elapsed();
    let seconds = elapsed.as_secs_f64();
    let megabytes = data.len() as f64 / (1024.0 * 1024.0);
    let mbps = if seconds > 0.0 { megabytes / seconds } else { 0.0 };
    debug!(
        "tree: read_file_pipelined done, read {} bytes in {:.2?} ({:.1} MB/s)",
        data.len(),
        elapsed,
        mbps
    );
    Ok(data)
}
/// Same as `read_file_pipelined`, but reports progress through
/// `on_progress` as chunks arrive.
///
/// For an empty file the handle is closed, one `Progress` with zero bytes
/// and `total_bytes: Some(0)` is reported (its return value is ignored),
/// and an empty vector is returned. Otherwise the work is delegated to
/// `read_pipelined_loop_with_progress`; the handle is closed afterwards
/// regardless of the outcome, and a read error takes precedence over a
/// close error.
pub async fn read_file_pipelined_with_progress<F>(
    &self,
    conn: &mut Connection,
    path: &str,
    mut on_progress: F,
) -> Result<Vec<u8>>
where
    F: FnMut(Progress) -> ControlFlow<()>,
{
    // Upper bound for a single pipelined READ when the file needs several.
    const PIPELINE_CHUNK: u32 = 512 * 1024;

    let target = self.format_path(path);
    let (file_id, file_size) = self.open_file(conn, &target).await?;
    if file_size == 0 {
        debug!(
            "tree: read_file_pipelined_with_progress path={}, size=0 (empty file)",
            target
        );
        self.close_handle(conn, file_id).await?;
        let finished = Progress {
            bytes_transferred: 0,
            total_bytes: Some(0),
        };
        let _ = on_progress(finished);
        return Ok(Vec::new());
    }

    let max_read = match conn.params() {
        Some(p) => p.max_read_size,
        None => 65536,
    };
    let fits_in_one_read = file_size <= max_read as u64;
    let chunk_size = if fits_in_one_read {
        (file_size as u32).min(max_read)
    } else {
        PIPELINE_CHUNK.min(max_read)
    };
    let credit_charge = chunk_size.div_ceil(65536) as u16;
    let total_chunks = file_size.div_ceil(chunk_size as u64) as usize;
    debug!(
        "tree: read_file_pipelined_with_progress path={}, size={}, chunk_size={}, total_chunks={}",
        target, file_size, chunk_size, total_chunks
    );

    let read_outcome = self
        .read_pipelined_loop_with_progress(
            conn,
            file_id,
            file_size,
            chunk_size,
            credit_charge,
            total_chunks,
            &mut on_progress,
        )
        .await;
    // Close before propagating so the handle is released on failure too.
    let close_outcome = self.close_handle(conn, file_id).await;
    let data = read_outcome?;
    close_outcome?;

    debug!(
        "tree: read_file_pipelined_with_progress done, read {} bytes",
        data.len()
    );
    Ok(data)
}
/// Opens `path` for reading and returns a `FileDownload` bound to this
/// tree and connection.
///
/// Forward slashes are turned into backslashes and leading backslashes are
/// stripped before the file is opened. The download's chunk size is the
/// negotiated `max_read_size`, or 65536 when the connection has no
/// parameters.
pub async fn download<'a>(
    &'a self,
    conn: &'a mut Connection,
    path: &str,
) -> Result<FileDownload<'a>> {
    let backslashed = path.replace('/', "\\");
    let relative = backslashed.trim_start_matches('\\');
    let (file_id, file_size) = self.open_file(conn, relative).await?;
    let chunk_size = match conn.params() {
        Some(p) => p.max_read_size,
        None => 65536,
    };
    let transfer = FileDownload::new(self, conn, file_id, file_size, chunk_size);
    Ok(transfer)
}
/// Writes `data` to `path` using pipelined WRITE requests of
/// `max_write_size` bytes each (65536 when the connection has no
/// parameters).
///
/// Empty input is delegated to `write_file_compound`. Otherwise the file
/// is opened with `open_file_for_write`, the chunks are written by
/// `write_pipelined_loop`, the handle is flushed if the write succeeded,
/// and the handle is always closed.
///
/// Returns the number of bytes the server reported as written.
///
/// # Errors
/// The first failure in this order is returned: write, flush, close. The
/// handle is closed in every case, including when the flush fails.
pub async fn write_file_pipelined(
    &self,
    conn: &mut Connection,
    path: &str,
    data: &[u8],
) -> Result<u64> {
    let normalized = self.format_path(path);
    if data.is_empty() {
        debug!(
            "tree: write_file_pipelined path={}, len=0 (empty write)",
            normalized
        );
        return self.write_file_compound(conn, path, data).await;
    }

    // Same CREATE as the streamed writer: overwrite-if, write access, no sharing.
    let file_id = self.open_file_for_write(conn, &normalized).await?;

    let max_write = conn.params().map(|p| p.max_write_size).unwrap_or(65536);
    let chunk_size = max_write;
    // One credit per started 64 KiB of a chunk.
    let credit_charge = chunk_size.div_ceil(65536) as u16;
    let total_chunks = data.len().div_ceil(chunk_size as usize);
    debug!(
        "tree: write_file_pipelined path={}, len={}, chunk_size={}, credit_charge={}, total_chunks={}, credits={}",
        normalized, data.len(), chunk_size, credit_charge, total_chunks, conn.credits()
    );

    let start = std::time::Instant::now();
    let result = self
        .write_pipelined_loop(conn, file_id, data, chunk_size, credit_charge, total_chunks)
        .await;

    // Flush only after a successful write, but keep the outcome instead of
    // returning early: the handle must be closed even if the flush fails.
    let flush_result = if result.is_ok() {
        self.flush_handle(conn, file_id).await
    } else {
        Ok(())
    };
    let close_result = self.close_handle(conn, file_id).await;

    let bytes_written = result?;
    flush_result?;
    close_result?;

    let elapsed = start.elapsed();
    let mb = bytes_written as f64 / (1024.0 * 1024.0);
    let mbps = if elapsed.as_secs_f64() > 0.0 {
        mb / elapsed.as_secs_f64()
    } else {
        0.0
    };
    debug!(
        "tree: write_file_pipelined done, wrote {} bytes in {:.2?} ({:.1} MB/s)",
        bytes_written, elapsed, mbps
    );
    Ok(bytes_written)
}
/// Writes a file whose content is produced incrementally by `next_chunk`.
///
/// `next_chunk` is polled for data until it returns `None` (or an empty
/// buffer); an `Err` from it aborts the transfer. The file is opened with
/// `open_file_for_write`, written by `write_streamed_loop`, flushed if the
/// write succeeded, and the handle is always closed.
///
/// Returns the number of bytes the server reported as written.
///
/// # Errors
/// The first failure in this order is returned: write, flush, close. The
/// handle is closed in every case, including when the flush fails.
pub async fn write_file_streamed<F>(
    &self,
    conn: &mut Connection,
    path: &str,
    next_chunk: &mut F,
) -> Result<u64>
where
    F: FnMut() -> Option<std::result::Result<Vec<u8>, std::io::Error>>,
{
    let normalized = self.format_path(path);
    debug!("tree: write_file_streamed path={}", normalized);
    let file_id = self.open_file_for_write(conn, &normalized).await?;
    let max_write = conn.params().map(|p| p.max_write_size).unwrap_or(65536);

    let start = std::time::Instant::now();
    let result = self
        .write_streamed_loop(conn, file_id, next_chunk, max_write)
        .await;

    // Flush only after a successful write, but keep the outcome instead of
    // returning early: the handle must be closed even if the flush fails.
    let flush_result = if result.is_ok() {
        self.flush_handle(conn, file_id).await
    } else {
        Ok(())
    };
    let close_result = self.close_handle(conn, file_id).await;

    let bytes_written = result?;
    flush_result?;
    close_result?;

    let elapsed = start.elapsed();
    let mb = bytes_written as f64 / (1024.0 * 1024.0);
    let mbps = if elapsed.as_secs_f64() > 0.0 {
        mb / elapsed.as_secs_f64()
    } else {
        0.0
    };
    debug!(
        "tree: write_file_streamed done, wrote {} bytes in {:.2?} ({:.1} MB/s)",
        bytes_written, elapsed, mbps
    );
    Ok(bytes_written)
}
/// Opens a `FileWriter` for `path`.
///
/// Takes ownership of `conn` and hands a cloned `Arc` of this tree to
/// `super::stream::open_file_writer`, which does the actual work.
pub async fn create_file_writer(
    self: &Arc<Self>,
    conn: Connection,
    path: &str,
) -> Result<super::stream::FileWriter> {
    let tree = Arc::clone(self);
    super::stream::open_file_writer(tree, conn, path).await
}
/// Creates the directory `path` and closes the handle that the CREATE
/// returns.
///
/// The request uses `FileCreate`, so the server is asked to create a new
/// entry rather than open an existing one.
///
/// # Errors
/// `Error::Protocol` for a non-success CREATE status, plus any error from
/// sending the request, decoding the response, or closing the handle.
pub async fn create_directory(&self, conn: &mut Connection, path: &str) -> Result<()> {
    let target = self.format_path(path);
    debug!("tree: create_directory path={}", target);

    let access_bits = FileAccessMask::FILE_READ_ATTRIBUTES | FileAccessMask::SYNCHRONIZE;
    let share_bits = ShareAccess::FILE_SHARE_READ
        | ShareAccess::FILE_SHARE_WRITE
        | ShareAccess::FILE_SHARE_DELETE;
    let request = CreateRequest {
        requested_oplock_level: OplockLevel::None,
        impersonation_level: ImpersonationLevel::Impersonation,
        desired_access: FileAccessMask::new(access_bits),
        file_attributes: FILE_ATTRIBUTE_DIRECTORY,
        share_access: ShareAccess(share_bits),
        create_disposition: CreateDisposition::FileCreate,
        create_options: FILE_DIRECTORY_FILE,
        name: target.clone(),
        create_contexts: vec![],
    };

    let reply = conn
        .execute(Command::Create, &request, Some(self.tree_id))
        .await?;
    let status = reply.header.status;
    if status != NtStatus::SUCCESS {
        return Err(Error::Protocol {
            status,
            command: Command::Create,
        });
    }

    let mut cursor = ReadCursor::new(&reply.body);
    let opened = CreateResponse::unpack(&mut cursor)?;
    // The handle exists only as a by-product of the CREATE; release it.
    self.close_handle(conn, opened.file_id).await?;
    info!("tree: created directory={}", target);
    Ok(())
}
/// Deletes the directory at `path` by delegating to `delete_compound` with
/// the `FILE_DIRECTORY_FILE` create option.
pub async fn delete_directory(&self, conn: &mut Connection, path: &str) -> Result<()> {
    let removal = self.delete_compound(conn, path, FILE_DIRECTORY_FILE, "directory");
    removal.await
}
async fn delete_compound(
&self,
conn: &mut Connection,
path: &str,
type_option: u32,
kind: &str,
) -> Result<()> {
let normalized = self.format_path(path);
debug!("tree: delete_{} (compound) path={}", kind, normalized);
let create_req = CreateRequest {
requested_oplock_level: OplockLevel::None,
impersonation_level: ImpersonationLevel::Impersonation,
desired_access: FileAccessMask::new(
FileAccessMask::DELETE | FileAccessMask::FILE_READ_ATTRIBUTES,
),
file_attributes: 0,
share_access: ShareAccess(
ShareAccess::FILE_SHARE_READ
| ShareAccess::FILE_SHARE_WRITE
| ShareAccess::FILE_SHARE_DELETE,
),
create_disposition: CreateDisposition::FileOpen,
create_options: FILE_DELETE_ON_CLOSE | type_option,
name: normalized.clone(),
create_contexts: vec![],
};
let close_req = CloseRequest {
flags: 0,
file_id: FileId::SENTINEL,
};
let ops = [
CompoundOp {
command: Command::Create,
body: &create_req,
tree_id: Some(self.tree_id),
credit_charge: CreditCharge(1),
},
CompoundOp {
command: Command::Close,
body: &close_req,
tree_id: Some(self.tree_id),
credit_charge: CreditCharge(1),
},
];
let responses = all_or_first_err(conn.execute_compound(&ops).await?)?;
let create_header = &responses[0].header;
let create_body = &responses[0].body;
let close_header = &responses[1].header;
if create_header.status != NtStatus::SUCCESS {
return Err(Error::Protocol {
status: create_header.status,
command: Command::Create,
});
}
if close_header.status != NtStatus::SUCCESS {
let mut cursor = ReadCursor::new(create_body);
let create_resp = CreateResponse::unpack(&mut cursor)?;
warn!(
"tree: compound CLOSE failed ({:?}), issuing standalone CLOSE",
close_header.status
);
let _ = self.close_handle(conn, create_resp.file_id).await;
return Err(Error::Protocol {
status: close_header.status,
command: Command::Close,
});
}
info!("tree: deleted {}={}", kind, normalized);
Ok(())
}
/// Opens the existing directory `path` for reading and returns its handle.
///
/// `path` is sent as given; no normalization happens here.
///
/// # Errors
/// `Error::Protocol` for a non-success CREATE status, plus any error from
/// sending the request or decoding the response.
async fn open_directory(&self, conn: &mut Connection, path: &str) -> Result<FileId> {
    let access_bits = FileAccessMask::FILE_READ_DATA
        | FileAccessMask::FILE_READ_ATTRIBUTES
        | FileAccessMask::SYNCHRONIZE;
    let share_bits = ShareAccess::FILE_SHARE_READ
        | ShareAccess::FILE_SHARE_WRITE
        | ShareAccess::FILE_SHARE_DELETE;
    let request = CreateRequest {
        requested_oplock_level: OplockLevel::None,
        impersonation_level: ImpersonationLevel::Impersonation,
        desired_access: FileAccessMask::new(access_bits),
        file_attributes: 0,
        share_access: ShareAccess(share_bits),
        create_disposition: CreateDisposition::FileOpen,
        create_options: FILE_DIRECTORY_FILE,
        name: path.to_string(),
        create_contexts: vec![],
    };

    let reply = conn
        .execute(Command::Create, &request, Some(self.tree_id))
        .await?;
    let status = reply.header.status;
    if status != NtStatus::SUCCESS {
        return Err(Error::Protocol {
            status,
            command: Command::Create,
        });
    }

    let mut cursor = ReadCursor::new(&reply.body);
    let opened = CreateResponse::unpack(&mut cursor)?;
    Ok(opened.file_id)
}
/// Opens the existing file `path` for reading.
///
/// Returns the file handle together with the `end_of_file` value from the
/// CREATE response, which callers in this file treat as the file size.
/// `path` is sent as given; no normalization happens here.
///
/// # Errors
/// `Error::Protocol` for a non-success CREATE status, plus any error from
/// sending the request or decoding the response.
pub async fn open_file(&self, conn: &mut Connection, path: &str) -> Result<(FileId, u64)> {
    let access_bits = FileAccessMask::FILE_READ_DATA
        | FileAccessMask::FILE_READ_ATTRIBUTES
        | FileAccessMask::SYNCHRONIZE;
    let share_bits = ShareAccess::FILE_SHARE_READ
        | ShareAccess::FILE_SHARE_WRITE
        | ShareAccess::FILE_SHARE_DELETE;
    let request = CreateRequest {
        requested_oplock_level: OplockLevel::None,
        impersonation_level: ImpersonationLevel::Impersonation,
        desired_access: FileAccessMask::new(access_bits),
        file_attributes: 0,
        share_access: ShareAccess(share_bits),
        create_disposition: CreateDisposition::FileOpen,
        create_options: 0,
        name: path.to_string(),
        create_contexts: vec![],
    };

    let reply = conn
        .execute(Command::Create, &request, Some(self.tree_id))
        .await?;
    let status = reply.header.status;
    if status != NtStatus::SUCCESS {
        return Err(Error::Protocol {
            status,
            command: Command::Create,
        });
    }

    let mut cursor = ReadCursor::new(&reply.body);
    let opened = CreateResponse::unpack(&mut cursor)?;
    Ok((opened.file_id, opened.end_of_file))
}
/// Opens `path` for writing and returns the handle.
///
/// The CREATE uses `FileOverwriteIf`, requests write access and grants no
/// sharing. `path` is sent as given; no normalization happens here.
///
/// # Errors
/// `Error::Protocol` for a non-success CREATE status, plus any error from
/// sending the request or decoding the response.
pub(crate) async fn open_file_for_write(
    &self,
    conn: &mut Connection,
    path: &str,
) -> Result<FileId> {
    let access_bits = FileAccessMask::FILE_WRITE_DATA
        | FileAccessMask::FILE_WRITE_ATTRIBUTES
        | FileAccessMask::SYNCHRONIZE;
    let request = CreateRequest {
        requested_oplock_level: OplockLevel::None,
        impersonation_level: ImpersonationLevel::Impersonation,
        desired_access: FileAccessMask::new(access_bits),
        file_attributes: 0x80,
        share_access: ShareAccess(0),
        create_disposition: CreateDisposition::FileOverwriteIf,
        create_options: FILE_NON_DIRECTORY_FILE,
        name: path.to_string(),
        create_contexts: vec![],
    };

    let reply = conn
        .execute(Command::Create, &request, Some(self.tree_id))
        .await?;
    let status = reply.header.status;
    if status != NtStatus::SUCCESS {
        return Err(Error::Protocol {
            status,
            command: Command::Create,
        });
    }

    let mut cursor = ReadCursor::new(&reply.body);
    let opened = CreateResponse::unpack(&mut cursor)?;
    Ok(opened.file_id)
}
/// Lists every entry of the directory behind `file_id`.
///
/// QUERY_DIRECTORY is issued repeatedly with the pattern `*` until the
/// server answers `NO_MORE_FILES`. Only the first request carries
/// `RESTART_SCANS`. The output buffer size is the negotiated
/// `max_transact_size`, capped at 65536 (65536 when the connection has no
/// parameters).
///
/// # Errors
/// `Error::Protocol` for any status other than `SUCCESS` or
/// `NO_MORE_FILES`, plus transport and decoding errors.
async fn query_directory_loop(
    &self,
    conn: &mut Connection,
    file_id: FileId,
) -> Result<Vec<DirectoryEntry>> {
    let max_output = match conn.params() {
        Some(p) => p.max_transact_size.min(65536),
        None => 65536,
    };

    let mut collected = Vec::new();
    let mut restart_scan = true;
    loop {
        let flag_bits = if restart_scan {
            QueryDirectoryFlags::RESTART_SCANS
        } else {
            0
        };
        restart_scan = false;

        let request = QueryDirectoryRequest {
            file_information_class: FileInformationClass::FileBothDirectoryInformation,
            flags: QueryDirectoryFlags(flag_bits),
            file_index: 0,
            file_id,
            output_buffer_length: max_output,
            file_name: "*".to_string(),
        };
        let reply = conn
            .execute(Command::QueryDirectory, &request, Some(self.tree_id))
            .await?;

        let status = reply.header.status;
        if status == NtStatus::NO_MORE_FILES {
            // Enumeration finished.
            break;
        }
        if status != NtStatus::SUCCESS {
            return Err(Error::Protocol {
                status,
                command: Command::QueryDirectory,
            });
        }

        let mut cursor = ReadCursor::new(&reply.body);
        let page = QueryDirectoryResponse::unpack(&mut cursor)?;
        let batch = parse_file_both_directory_info(&page.output_buffer)?;
        batch.iter().for_each(|entry| {
            trace!(
                "tree: dir_entry name={}, size={}, is_dir={}",
                entry.name,
                entry.size,
                entry.is_directory
            );
        });
        collected.extend(batch);
    }
    Ok(collected)
}
/// Reads `file_size` bytes from `file_id` sequentially, one READ at a time.
///
/// Each request asks for at most `max_read_size` bytes (65536 when the
/// connection has no parameters). Reading stops early on an `END_OF_FILE`
/// status or an empty response, so fewer than `file_size` bytes may be
/// returned.
///
/// # Errors
/// `Error::Protocol` for any other non-success READ status, plus transport
/// and decoding errors.
#[allow(dead_code)]
async fn read_loop(
    &self,
    conn: &mut Connection,
    file_id: FileId,
    file_size: u64,
) -> Result<Vec<u8>> {
    let max_read = match conn.params() {
        Some(p) => p.max_read_size,
        None => 65536,
    };
    let mut contents = Vec::with_capacity(file_size as usize);
    let mut offset = 0u64;

    while offset < file_size {
        let still_needed = file_size - offset;
        let length = still_needed.min(max_read as u64) as u32;
        let request = ReadRequest {
            padding: 0x50,
            flags: 0,
            length,
            offset,
            file_id,
            minimum_count: 0,
            channel: SMB2_CHANNEL_NONE,
            remaining_bytes: 0,
            read_channel_info: vec![],
        };
        let reply = conn
            .execute(Command::Read, &request, Some(self.tree_id))
            .await?;

        let status = reply.header.status;
        if status == NtStatus::END_OF_FILE {
            break;
        }
        if status != NtStatus::SUCCESS {
            return Err(Error::Protocol {
                status,
                command: Command::Read,
            });
        }

        let mut cursor = ReadCursor::new(&reply.body);
        let payload = ReadResponse::unpack(&mut cursor)?.data;
        if payload.is_empty() {
            break;
        }
        // Advance by what the server actually returned, not by what was asked.
        offset += payload.len() as u64;
        contents.extend_from_slice(&payload);
    }
    Ok(contents)
}
/// Reads `file_size` bytes from `file_id` with several READ requests in
/// flight at once.
///
/// The file is split into `total_chunks` chunks of `chunk_size` bytes (the
/// last one may be shorter). Up to `MAX_PIPELINE_WINDOW` requests are
/// launched initially, limited by the credits currently available; each
/// completed response launches at most one further request. Responses may
/// arrive in any order, so each payload is copied to the offset derived
/// from its chunk index.
///
/// Chunks answered with `END_OF_FILE` are skipped and leave their part of
/// the buffer zero-filled.
///
/// # Errors
/// * `Error::invalid_data` if no credits are available at the start, or if
///   the window drains completely while chunks are still unsent because no
///   credits are available (previously this returned `Ok` with zero-filled
///   holes).
/// * `Error::Protocol` for a READ status other than `SUCCESS` or
///   `END_OF_FILE`.
/// * Transport and decoding errors.
async fn read_pipelined_loop(
    &self,
    conn: &mut Connection,
    file_id: FileId,
    file_size: u64,
    chunk_size: u32,
    credit_charge: u16,
    total_chunks: usize,
) -> Result<Vec<u8>> {
    use futures_util::stream::{FuturesUnordered, StreamExt};
    let mut data = vec![0u8; file_size as usize];
    let mut chunks_sent = 0usize;
    let mut chunks_received = 0usize;
    let max_from_credits = conn.credits() as usize / credit_charge.max(1) as usize;
    let initial_window = total_chunks.min(max_from_credits).min(MAX_PIPELINE_WINDOW);
    if initial_window == 0 {
        return Err(Error::invalid_data(
            "no credits available for pipelined read",
        ));
    }
    debug!(
        "tree: pipeline read sliding window: initial_window={}, total_chunks={}, credits={}",
        initial_window,
        total_chunks,
        conn.credits()
    );
    let mut in_flight = FuturesUnordered::new();
    // Builds the READ for one chunk; only the final chunk may be short.
    let build_req = |chunk_index: usize| -> ReadRequest {
        let offset = chunk_index as u64 * chunk_size as u64;
        let this_chunk = if chunk_index == total_chunks - 1 {
            (file_size - offset) as u32
        } else {
            chunk_size
        };
        ReadRequest {
            padding: 0x50,
            flags: 0,
            length: this_chunk,
            offset,
            file_id,
            minimum_count: 0,
            channel: SMB2_CHANNEL_NONE,
            remaining_bytes: 0,
            read_channel_info: vec![],
        }
    };
    // Each request runs on its own clone of the connection and is tagged
    // with its chunk index so out-of-order completions can be placed.
    let launch_chunk = |conn: &Connection, chunk_index: usize, tree_id: TreeId| -> _ {
        let c = conn.clone();
        let req = build_req(chunk_index);
        async move {
            let frame = c
                .execute_with_credits(
                    Command::Read,
                    &req,
                    Some(tree_id),
                    CreditCharge(credit_charge),
                )
                .await;
            (chunk_index, frame)
        }
    };
    for _ in 0..initial_window {
        in_flight.push(launch_chunk(conn, chunks_sent, self.tree_id));
        chunks_sent += 1;
    }
    while chunks_received < total_chunks {
        let Some((chunk_index, frame_result)) = in_flight.next().await else {
            break;
        };
        chunks_received += 1;
        let frame = frame_result?;
        if frame.header.status == NtStatus::END_OF_FILE {
            continue;
        }
        if frame.header.status != NtStatus::SUCCESS {
            return Err(Error::Protocol {
                status: frame.header.status,
                command: Command::Read,
            });
        }
        let mut cursor = ReadCursor::new(&frame.body);
        let resp = ReadResponse::unpack(&mut cursor)?;
        if !resp.data.is_empty() {
            // Clamp to the buffer so an oversized payload cannot overrun it.
            let dest_offset = chunk_index as u64 * chunk_size as u64;
            let dest_end = (dest_offset as usize + resp.data.len()).min(data.len());
            let src_len = dest_end - dest_offset as usize;
            data[dest_offset as usize..dest_end].copy_from_slice(&resp.data[..src_len]);
        }
        if chunks_sent < total_chunks {
            let credits_available = conn.credits() as usize / credit_charge.max(1) as usize;
            if credits_available > 0 {
                in_flight.push(launch_chunk(conn, chunks_sent, self.tree_id));
                chunks_sent += 1;
            } else if in_flight.is_empty() {
                // Nothing in flight and nothing can be sent: continuing
                // would return a buffer with unread, zero-filled chunks.
                return Err(Error::invalid_data(
                    "no credits available for pipelined read",
                ));
            }
        }
    }
    Ok(data)
}
/// Pipelined read of `file_size` bytes from `file_id` that reports progress
/// after every successfully decoded chunk.
///
/// Chunking, windowing and out-of-order placement work as in
/// `read_pipelined_loop`. After each chunk `on_progress` receives the bytes
/// copied so far and `Some(file_size)` as the total; returning
/// `ControlFlow::Break` cancels the transfer.
///
/// Chunks answered with `END_OF_FILE` are skipped: they leave their part of
/// the buffer zero-filled and trigger no progress callback.
///
/// # Errors
/// * `Error::Cancelled` if `on_progress` breaks.
/// * `Error::invalid_data` if no credits are available at the start, or if
///   the window drains completely while chunks are still unsent because no
///   credits are available (previously this returned `Ok` with zero-filled
///   holes).
/// * `Error::Protocol` for a READ status other than `SUCCESS` or
///   `END_OF_FILE`.
/// * Transport and decoding errors.
async fn read_pipelined_loop_with_progress<F>(
    &self,
    conn: &mut Connection,
    file_id: FileId,
    file_size: u64,
    chunk_size: u32,
    credit_charge: u16,
    total_chunks: usize,
    on_progress: &mut F,
) -> Result<Vec<u8>>
where
    F: FnMut(Progress) -> ControlFlow<()>,
{
    use futures_util::stream::{FuturesUnordered, StreamExt};
    let mut data = vec![0u8; file_size as usize];
    let mut chunks_sent = 0usize;
    let mut chunks_received = 0usize;
    let mut bytes_received = 0u64;
    let max_from_credits = conn.credits() as usize / credit_charge.max(1) as usize;
    let initial_window = total_chunks.min(max_from_credits).min(MAX_PIPELINE_WINDOW);
    if initial_window == 0 {
        return Err(Error::invalid_data(
            "no credits available for pipelined read",
        ));
    }
    let mut in_flight = FuturesUnordered::new();
    // Builds the READ for one chunk; only the final chunk may be short.
    let build_req = |chunk_index: usize| -> ReadRequest {
        let offset = chunk_index as u64 * chunk_size as u64;
        let this_chunk = if chunk_index == total_chunks - 1 {
            (file_size - offset) as u32
        } else {
            chunk_size
        };
        ReadRequest {
            padding: 0x50,
            flags: 0,
            length: this_chunk,
            offset,
            file_id,
            minimum_count: 0,
            channel: SMB2_CHANNEL_NONE,
            remaining_bytes: 0,
            read_channel_info: vec![],
        }
    };
    // Each request runs on its own clone of the connection and is tagged
    // with its chunk index so out-of-order completions can be placed.
    let launch_chunk = |conn: &Connection, chunk_index: usize, tree_id: TreeId| {
        let c = conn.clone();
        let req = build_req(chunk_index);
        async move {
            let frame = c
                .execute_with_credits(
                    Command::Read,
                    &req,
                    Some(tree_id),
                    CreditCharge(credit_charge),
                )
                .await;
            (chunk_index, frame)
        }
    };
    for _ in 0..initial_window {
        in_flight.push(launch_chunk(conn, chunks_sent, self.tree_id));
        chunks_sent += 1;
    }
    while chunks_received < total_chunks {
        let Some((chunk_index, frame_result)) = in_flight.next().await else {
            break;
        };
        chunks_received += 1;
        let frame = frame_result?;
        if frame.header.status == NtStatus::END_OF_FILE {
            continue;
        }
        if frame.header.status != NtStatus::SUCCESS {
            return Err(Error::Protocol {
                status: frame.header.status,
                command: Command::Read,
            });
        }
        let mut cursor = ReadCursor::new(&frame.body);
        let resp = ReadResponse::unpack(&mut cursor)?;
        if !resp.data.is_empty() {
            // Clamp to the buffer so an oversized payload cannot overrun it.
            let dest_offset = chunk_index as u64 * chunk_size as u64;
            let dest_end = (dest_offset as usize + resp.data.len()).min(data.len());
            let src_len = dest_end - dest_offset as usize;
            data[dest_offset as usize..dest_end].copy_from_slice(&resp.data[..src_len]);
            bytes_received += src_len as u64;
        }
        let progress = Progress {
            bytes_transferred: bytes_received,
            total_bytes: Some(file_size),
        };
        if let ControlFlow::Break(()) = on_progress(progress) {
            return Err(Error::Cancelled);
        }
        if chunks_sent < total_chunks {
            let credits_available = conn.credits() as usize / credit_charge.max(1) as usize;
            if credits_available > 0 {
                in_flight.push(launch_chunk(conn, chunks_sent, self.tree_id));
                chunks_sent += 1;
            } else if in_flight.is_empty() {
                // Nothing in flight and nothing can be sent: continuing
                // would return a buffer with unread, zero-filled chunks.
                return Err(Error::invalid_data(
                    "no credits available for pipelined read",
                ));
            }
        }
    }
    Ok(data)
}
/// Writes `data` to `file_id` with several WRITE requests in flight at once.
///
/// `data` is split into `total_chunks` chunks of `chunk_size` bytes (the
/// last one may be shorter), each written at its own file offset. Up to
/// `MAX_PIPELINE_WINDOW` requests are launched initially, limited by the
/// credits currently available; each completed response launches at most
/// one further request.
///
/// Returns the sum of the `count` fields of all WRITE responses.
///
/// # Errors
/// * `Error::invalid_data` if no credits are available at the start, or if
///   the window drains completely while chunks are still unsent because no
///   credits are available (previously this returned `Ok` with a short
///   count).
/// * `Error::Protocol` for a non-success WRITE status.
/// * Transport and decoding errors.
async fn write_pipelined_loop(
    &self,
    conn: &mut Connection,
    file_id: FileId,
    data: &[u8],
    chunk_size: u32,
    credit_charge: u16,
    total_chunks: usize,
) -> Result<u64> {
    use futures_util::stream::{FuturesUnordered, StreamExt};
    let mut chunks_sent = 0usize;
    let mut chunks_received = 0usize;
    let mut total_written = 0u64;
    let max_from_credits = conn.credits() as usize / credit_charge.max(1) as usize;
    let initial_window = total_chunks.min(max_from_credits).min(MAX_PIPELINE_WINDOW);
    if initial_window == 0 {
        return Err(Error::invalid_data(
            "no credits available for pipelined write",
        ));
    }
    debug!(
        "tree: pipeline write sliding window: initial_window={}, total_chunks={}, credits={}",
        initial_window,
        total_chunks,
        conn.credits()
    );
    let mut in_flight = FuturesUnordered::new();
    // Builds the WRITE for one chunk; the slice end is clamped so the last
    // chunk may be shorter than `chunk_size`.
    let build_req = |chunk_index: usize| -> WriteRequest {
        let offset = chunk_index * chunk_size as usize;
        let end = (offset + chunk_size as usize).min(data.len());
        let chunk = &data[offset..end];
        WriteRequest {
            data_offset: 0x70,
            offset: offset as u64,
            file_id,
            channel: 0,
            remaining_bytes: 0,
            write_channel_info_offset: 0,
            write_channel_info_length: 0,
            flags: 0,
            data: chunk.to_vec(),
        }
    };
    // Each request runs on its own clone of the connection.
    let launch_chunk = |conn: &Connection, chunk_index: usize, tree_id: TreeId| {
        let c = conn.clone();
        let req = build_req(chunk_index);
        async move {
            let frame = c
                .execute_with_credits(
                    Command::Write,
                    &req,
                    Some(tree_id),
                    CreditCharge(credit_charge),
                )
                .await;
            (chunk_index, frame)
        }
    };
    for _ in 0..initial_window {
        in_flight.push(launch_chunk(conn, chunks_sent, self.tree_id));
        chunks_sent += 1;
    }
    while chunks_received < total_chunks {
        let Some((_chunk_index, frame_result)) = in_flight.next().await else {
            break;
        };
        chunks_received += 1;
        let frame = frame_result?;
        if frame.header.status != NtStatus::SUCCESS {
            return Err(Error::Protocol {
                status: frame.header.status,
                command: Command::Write,
            });
        }
        let mut cursor = ReadCursor::new(&frame.body);
        let resp = WriteResponse::unpack(&mut cursor)?;
        total_written += resp.count as u64;
        if chunks_sent < total_chunks {
            let credits_available = conn.credits() as usize / credit_charge.max(1) as usize;
            if credits_available > 0 {
                in_flight.push(launch_chunk(conn, chunks_sent, self.tree_id));
                chunks_sent += 1;
            } else if in_flight.is_empty() {
                // Nothing in flight and nothing can be sent: continuing
                // would report success with part of `data` never written.
                return Err(Error::invalid_data(
                    "no credits available for pipelined write",
                ));
            }
        }
    }
    Ok(total_written)
}
async fn write_streamed_loop<F>(
&self,
conn: &mut Connection,
file_id: FileId,
next_chunk: &mut F,
max_write: u32,
) -> Result<u64>
where
F: FnMut() -> Option<std::result::Result<Vec<u8>, std::io::Error>>,
{
use futures_util::stream::{FuturesUnordered, StreamExt};
type BoxedExecute = std::pin::Pin<
Box<dyn std::future::Future<Output = Result<crate::client::connection::Frame>> + Send>,
>;
let mut offset = 0u64;
let mut in_flight = 0usize;
let mut total_written = 0u64;
let mut done = false; let mut callback_err: Option<std::io::Error> = None;
let mut in_flight_futs: FuturesUnordered<BoxedExecute> = FuturesUnordered::new();
let mut pending_data: Vec<u8> = Vec::new();
let mut pending_offset = 0usize;
let mut stashed_chunk: Option<Vec<u8>> = None;
let next_wire_chunk = |pending_data: &mut Vec<u8>,
pending_offset: &mut usize,
done: &mut bool,
callback_err: &mut Option<std::io::Error>,
next_chunk: &mut F|
-> Option<Vec<u8>> {
if *pending_offset < pending_data.len() {
let end = (*pending_offset + max_write as usize).min(pending_data.len());
let slice = pending_data[*pending_offset..end].to_vec();
*pending_offset = end;
if *pending_offset >= pending_data.len() {
pending_data.clear();
*pending_offset = 0;
}
return Some(slice);
}
if *done {
return None;
}
match next_chunk() {
None => {
*done = true;
None
}
Some(Err(e)) => {
*done = true;
*callback_err = Some(e);
None
}
Some(Ok(data)) => {
if data.is_empty() {
*done = true;
return None;
}
if data.len() <= max_write as usize {
Some(data)
} else {
let first = data[..max_write as usize].to_vec();
*pending_data = data;
*pending_offset = max_write as usize;
Some(first)
}
}
}
};
loop {
let credit_charge_per = max_write.div_ceil(65536).max(1) as u16;
let max_from_credits = conn.credits() as usize / credit_charge_per.max(1) as usize;
let can_send = max_from_credits.min(MAX_PIPELINE_WINDOW.saturating_sub(in_flight));
if can_send == 0 {
break;
}
let chunk = next_wire_chunk(
&mut pending_data,
&mut pending_offset,
&mut done,
&mut callback_err,
next_chunk,
);
match chunk {
None => break,
Some(chunk_data) => {
let data_len = chunk_data.len() as u64;
let cc = data_len.div_ceil(65536).max(1) as u16;
let c = conn.clone();
let tree_id = self.tree_id;
let req = WriteRequest {
data_offset: 0x70,
offset,
file_id,
channel: 0,
remaining_bytes: 0,
write_channel_info_offset: 0,
write_channel_info_length: 0,
flags: 0,
data: chunk_data,
};
in_flight_futs.push(Box::pin(async move {
c.execute_with_credits(
Command::Write,
&req,
Some(tree_id),
CreditCharge(cc),
)
.await
}));
offset += data_len;
in_flight += 1;
}
}
}
while in_flight > 0 {
let frame_result = match in_flight_futs.next().await {
Some(r) => r,
None => break,
};
in_flight -= 1;
let frame = frame_result?;
if frame.header.status != NtStatus::SUCCESS {
while in_flight_futs.next().await.is_some() {}
return Err(Error::Protocol {
status: frame.header.status,
command: Command::Write,
});
}
let mut cursor = ReadCursor::new(&frame.body);
let resp = WriteResponse::unpack(&mut cursor)?;
total_written += resp.count as u64;
if callback_err.is_none() && stashed_chunk.is_none() {
let chunk = next_wire_chunk(
&mut pending_data,
&mut pending_offset,
&mut done,
&mut callback_err,
next_chunk,
);
if let Some(chunk_data) = chunk {
let data_len = chunk_data.len() as u64;
let cc = data_len.div_ceil(65536).max(1) as u16;
let credits_available = conn.credits() as usize / cc.max(1) as usize;
if credits_available > 0 {
let c = conn.clone();
let tree_id = self.tree_id;
let req = WriteRequest {
data_offset: 0x70,
offset,
file_id,
channel: 0,
remaining_bytes: 0,
write_channel_info_offset: 0,
write_channel_info_length: 0,
flags: 0,
data: chunk_data,
};
in_flight_futs.push(Box::pin(async move {
c.execute_with_credits(
Command::Write,
&req,
Some(tree_id),
CreditCharge(cc),
)
.await
}));
offset += data_len;
in_flight += 1;
} else {
stashed_chunk = Some(chunk_data);
}
}
} else if let Some(chunk_data) = stashed_chunk.take() {
let data_len = chunk_data.len() as u64;
let cc = data_len.div_ceil(65536).max(1) as u16;
let credits_available = conn.credits() as usize / cc.max(1) as usize;
if credits_available > 0 {
let c = conn.clone();
let tree_id = self.tree_id;
let req = WriteRequest {
data_offset: 0x70,
offset,
file_id,
channel: 0,
remaining_bytes: 0,
write_channel_info_offset: 0,
write_channel_info_length: 0,
flags: 0,
data: chunk_data,
};
in_flight_futs.push(Box::pin(async move {
c.execute_with_credits(
Command::Write,
&req,
Some(tree_id),
CreditCharge(cc),
)
.await
}));
offset += data_len;
in_flight += 1;
} else {
stashed_chunk = Some(chunk_data);
}
}
}
if let Some(io_err) = callback_err {
return Err(Error::Io(io_err));
}
Ok(total_written)
}
pub(crate) async fn flush_handle(&self, conn: &mut Connection, file_id: FileId) -> Result<()> {
debug!("tree: flushing file handle");
let req = FlushRequest { file_id };
let frame = conn
.execute(Command::Flush, &req, Some(self.tree_id))
.await?;
if frame.header.status != NtStatus::SUCCESS {
return Err(Error::Protocol {
status: frame.header.status,
command: Command::Flush,
});
}
Ok(())
}
pub(crate) async fn close_handle(&self, conn: &mut Connection, file_id: FileId) -> Result<()> {
let req = CloseRequest { flags: 0, file_id };
let frame = conn
.execute(Command::Close, &req, Some(self.tree_id))
.await?;
if frame.header.status != NtStatus::SUCCESS {
return Err(Error::Protocol {
status: frame.header.status,
command: Command::Close,
});
}
Ok(())
}
#[allow(dead_code)]
async fn write_loop(&self, conn: &mut Connection, file_id: FileId, data: &[u8]) -> Result<u64> {
let max_write = conn.params().map(|p| p.max_write_size).unwrap_or(65536);
let mut total_written = 0u64;
let mut offset = 0usize;
while offset < data.len() {
let remaining = data.len() - offset;
let chunk_size = remaining.min(max_write as usize);
let chunk = &data[offset..offset + chunk_size];
let req = WriteRequest {
data_offset: 0x70,
offset: offset as u64,
file_id,
channel: 0,
remaining_bytes: 0,
write_channel_info_offset: 0,
write_channel_info_length: 0,
flags: 0,
data: chunk.to_vec(),
};
let frame = conn
.execute(Command::Write, &req, Some(self.tree_id))
.await?;
if frame.header.status != NtStatus::SUCCESS {
return Err(Error::Protocol {
status: frame.header.status,
command: Command::Write,
});
}
let mut cursor = ReadCursor::new(&frame.body);
let resp = WriteResponse::unpack(&mut cursor)?;
total_written += resp.count as u64;
offset += chunk_size;
}
Ok(total_written)
}
}
/// Serializes a rename-information buffer for `new_name`.
///
/// Layout produced (all integers little-endian):
/// * 1 zero byte followed by 7 zero bytes,
/// * an 8-byte zero value,
/// * a 4-byte length of the name in bytes,
/// * the name encoded as UTF-16LE.
///
/// This appears to follow the FILE_RENAME_INFORMATION layout (replace
/// flag, reserved, root directory, name length, name) — confirm against
/// MS-FSCC before changing any field.
fn build_rename_info_buffer(new_name: &str) -> Vec<u8> {
    // Bytes that precede the name: 1 + 7 + 8 + 4.
    const FIXED_PREFIX_LEN: usize = 20;

    let units: Vec<u16> = new_name.encode_utf16().collect();
    let name_len_bytes = units.len() * 2;

    let mut out = Vec::with_capacity(FIXED_PREFIX_LEN + name_len_bytes);
    out.extend_from_slice(&[0u8; 8]);
    out.extend_from_slice(&0u64.to_le_bytes());
    out.extend_from_slice(&(name_len_bytes as u32).to_le_bytes());
    out.extend(units.iter().flat_map(|unit| unit.to_le_bytes()));
    out
}
/// Converts `path` to the backslash-separated, share-relative form.
///
/// Every `/` becomes `\`, then all leading backslashes are removed.
/// Trailing separators are left untouched.
fn normalize_path(path: &str) -> String {
    let backslashed: String = path
        .chars()
        .map(|ch| if ch == '/' { '\\' } else { ch })
        .collect();
    backslashed.trim_start_matches('\\').to_owned()
}
/// Decodes a buffer of chained directory-information entries into
/// `DirectoryEntry` values.
///
/// Each entry starts with the offset to the next entry (0 marks the last
/// one), followed by fixed fields and the file name. Parsing stops quietly
/// when fewer than 94 bytes remain at the current offset.
///
/// # Errors
/// Propagates any cursor error raised while reading an entry, for example
/// when the declared name length runs past the buffer.
fn parse_file_both_directory_info(data: &[u8]) -> Result<Vec<DirectoryEntry>> {
    // Size of the fields read before the file name:
    // 4 + 4 + 4 timestamps + 8 + 8 + 4 + 4 + 4 + 1 + 1 + 24.
    // NOTE(review): assumes `FileTime::unpack` consumes 8 bytes each.
    const FIXED_ENTRY_LEN: usize = 94;

    let mut parsed = Vec::new();
    let mut entry_start = 0usize;
    while entry_start + FIXED_ENTRY_LEN <= data.len() {
        let mut cursor = ReadCursor::new(&data[entry_start..]);

        let next_entry_offset = cursor.read_u32_le()? as usize;
        let _file_index = cursor.read_u32_le()?;
        let created = FileTime::unpack(&mut cursor)?;
        let _last_access_time = FileTime::unpack(&mut cursor)?;
        let modified = FileTime::unpack(&mut cursor)?;
        let _change_time = FileTime::unpack(&mut cursor)?;
        let size = cursor.read_u64_le()?;
        let _allocation_size = cursor.read_u64_le()?;
        let attributes = cursor.read_u32_le()?;
        let name_len = cursor.read_u32_le()? as usize;
        let _ea_size = cursor.read_u32_le()?;
        let _short_name_length = cursor.read_u8()?;
        let _reserved = cursor.read_u8()?;
        // 24 bytes between the reserved byte and the long name are not used.
        cursor.skip(24)?;

        let name = match name_len {
            0 => String::new(),
            len => cursor.read_utf16_le(len)?,
        };

        parsed.push(DirectoryEntry {
            name,
            size,
            is_directory: attributes & FILE_ATTRIBUTE_DIRECTORY != 0,
            created,
            modified,
        });

        if next_entry_offset == 0 {
            break;
        }
        entry_start += next_entry_offset;
    }
    Ok(parsed)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::client::connection::pack_message;
use crate::client::test_helpers::{
build_close_response, build_create_response, build_tree_connect_response, setup_connection,
};
use crate::msg::create::{CreateAction, CreateResponse};
use crate::msg::header::Header;
use crate::msg::query_directory::QueryDirectoryResponse;
use crate::msg::tree_connect::ShareType;
use crate::transport::MockTransport;
use crate::types::status::NtStatus;
use crate::types::{Command, TreeId};
use std::sync::Arc;
/// Packs a FLUSH response message granting 32 credits; the status is left as
/// produced by `Header::new_request`.
fn build_flush_response() -> Vec<u8> {
    let mut header = Header::new_request(Command::Flush);
    header.flags.set_response();
    header.credits = 32;
    pack_message(&header, &crate::msg::flush::FlushResponse)
}
/// Packs a QUERY_DIRECTORY response carrying `status` and 32 credits.
///
/// For `NtStatus::NO_MORE_FILES` the body is an empty `ErrorResponse` and
/// `entries_data` is ignored; for any other status the body is a
/// `QueryDirectoryResponse` wrapping `entries_data`.
fn build_query_directory_response(status: NtStatus, entries_data: Vec<u8>) -> Vec<u8> {
    let mut header = Header::new_request(Command::QueryDirectory);
    header.flags.set_response();
    header.credits = 32;
    header.status = status;

    if status != NtStatus::NO_MORE_FILES {
        let listing = QueryDirectoryResponse {
            output_buffer: entries_data,
        };
        return pack_message(&header, &listing);
    }

    let end_of_listing = crate::msg::header::ErrorResponse {
        error_context_count: 0,
        error_data: vec![],
    };
    pack_message(&header, &end_of_listing)
}
/// Packs a READ response carrying `status` and 32 credits.
///
/// For `NtStatus::END_OF_FILE` the body is an empty `ErrorResponse` and `data`
/// is ignored; otherwise the body is a `ReadResponse` holding `data` with a
/// data offset of `0x50` and no remaining bytes.
fn build_read_response(status: NtStatus, data: Vec<u8>) -> Vec<u8> {
    let mut header = Header::new_request(Command::Read);
    header.flags.set_response();
    header.credits = 32;
    header.status = status;

    if status != NtStatus::END_OF_FILE {
        let payload = ReadResponse {
            data_offset: 0x50,
            data_remaining: 0,
            flags: 0,
            data,
        };
        return pack_message(&header, &payload);
    }

    let eof = crate::msg::header::ErrorResponse {
        error_context_count: 0,
        error_data: vec![],
    };
    pack_message(&header, &eof)
}
/// Builds one serialized directory record in the layout that
/// `parse_file_both_directory_info` reads.
///
/// `next_offset` is written as the leading next-entry offset (0 for the last
/// record). Timestamps are fixed test values; the allocation size is `size`
/// rounded up to a multiple of 4096. Non-directories get attribute `0x20`
/// (FILE_ATTRIBUTE_ARCHIVE in MS-FSCC).
fn build_file_both_dir_info(
    name: &str,
    size: u64,
    is_directory: bool,
    next_offset: u32,
) -> Vec<u8> {
    let name_bytes: Vec<u8> = name
        .encode_utf16()
        .flat_map(u16::to_le_bytes)
        .collect();
    let attributes: u32 = if is_directory {
        FILE_ATTRIBUTE_DIRECTORY
    } else {
        0x00000020
    };
    let allocation_size = (size + 4095) & !4095;

    let mut out = Vec::new();
    out.extend_from_slice(&next_offset.to_le_bytes());
    // File index.
    out.extend_from_slice(&0u32.to_le_bytes());
    // Creation, last-access, last-write and change times, in that order.
    for stamp in [
        132_000_000_000_000_000u64,
        132_000_000_000_000_000u64,
        133_000_000_000_000_000u64,
        133_000_000_000_000_000u64,
    ] {
        out.extend_from_slice(&stamp.to_le_bytes());
    }
    out.extend_from_slice(&size.to_le_bytes());
    out.extend_from_slice(&allocation_size.to_le_bytes());
    out.extend_from_slice(&attributes.to_le_bytes());
    out.extend_from_slice(&(name_bytes.len() as u32).to_le_bytes());
    // EA size, short-name length, reserved byte and the 24-byte short-name area.
    out.extend_from_slice(&0u32.to_le_bytes());
    out.extend_from_slice(&[0u8, 0u8]);
    out.extend_from_slice(&[0u8; 24]);
    out.extend_from_slice(&name_bytes);
    out
}
/// `Tree::connect` keeps the tree id from the server response and the share
/// name it was asked to connect to.
#[tokio::test]
async fn tree_connect_stores_tree_id() {
    let transport = Arc::new(MockTransport::new());
    let expected_id = TreeId(42);
    transport.queue_response(build_tree_connect_response(expected_id, ShareType::Disk));
    let mut conn = setup_connection(&transport);

    let connected = Tree::connect(&mut conn, "naspi").await.unwrap();

    assert_eq!(connected.tree_id, expected_id);
    assert_eq!(connected.share_name, "naspi");
}
/// The TREE_CONNECT request sent on the wire carries the UNC path
/// `\\<server>\<share>`.
#[tokio::test]
async fn tree_connect_sends_unc_path() {
    let transport = Arc::new(MockTransport::new());
    transport.queue_response(build_tree_connect_response(TreeId(1), ShareType::Disk));
    let mut conn = setup_connection(&transport);
    let _tree = Tree::connect(&mut conn, "myshare").await.unwrap();

    // Decode the first message the client sent: header, then request body.
    let wire = transport.sent_message(0).unwrap();
    let mut reader = ReadCursor::new(&wire);
    let _header = Header::unpack(&mut reader).unwrap();
    let request = TreeConnectRequest::unpack(&mut reader).unwrap();

    assert_eq!(request.path, r"\\test-server\myshare");
}
/// `list_directory` returns every entry from the server's listing, here one
/// file and one sub-directory delivered in a single QUERY_DIRECTORY response.
#[tokio::test]
async fn list_directory_returns_entries() {
    let transport = Arc::new(MockTransport::new());
    let tree_id = TreeId(10);
    let dir_handle = FileId {
        persistent: 0x1111,
        volatile: 0x2222,
    };

    // Encode the first record once to learn its length, then again with that
    // length as its next-entry offset so the second record can follow it.
    let first_len = build_file_both_dir_info("file1.txt", 1024, false, 0).len();
    let mut listing = build_file_both_dir_info("file1.txt", 1024, false, first_len as u32);
    let last_record = build_file_both_dir_info("subdir", 0, true, 0);
    listing.extend_from_slice(&last_record);

    // Responses for: open directory, first listing, end of listing, close.
    transport.queue_response(build_create_response(dir_handle, 0));
    transport.queue_response(build_query_directory_response(NtStatus::SUCCESS, listing));
    transport.queue_response(build_query_directory_response(
        NtStatus::NO_MORE_FILES,
        vec![],
    ));
    transport.queue_response(build_close_response());

    let mut conn = setup_connection(&transport);
    let tree = Tree {
        tree_id,
        share_name: String::from("test"),
        server: String::from("test-server"),
        is_dfs: false,
        encrypt_data: false,
    };

    let listed = tree.list_directory(&mut conn, "somedir").await.unwrap();

    assert_eq!(listed.len(), 2);
    assert_eq!(listed[0].name, "file1.txt");
    assert_eq!(listed[0].size, 1024);
    assert!(!listed[0].is_directory);
    assert_eq!(listed[1].name, "subdir");
    assert!(listed[1].is_directory);
}
/// `read_file` returns the bytes carried by the READ response of a
/// create/read/close compound reply.
#[tokio::test]
async fn read_file_returns_data() {
    let transport = Arc::new(MockTransport::new());
    let tree_id = TreeId(20);
    let handle = FileId {
        persistent: 0x3333,
        volatile: 0x4444,
    };
    let payload = b"Hello, SMB world!";

    let replies = [
        build_create_response(handle, payload.len() as u64),
        build_read_response(NtStatus::SUCCESS, payload.to_vec()),
        build_close_response(),
    ];
    transport.queue_response(build_compound_response_frame(&replies));

    let mut conn = setup_connection(&transport);
    let tree = Tree {
        tree_id,
        share_name: String::from("test"),
        server: String::from("test-server"),
        is_dfs: false,
        encrypt_data: false,
    };

    let contents = tree.read_file(&mut conn, "test.txt").await.unwrap();
    assert_eq!(contents, payload);
}
/// `normalize_path` turns `/` into `\` and strips leading separators.
///
/// This is a plain synchronous test: nothing here awaits, so no async runtime
/// is needed.
#[test]
fn normalize_path_converts_slashes() {
    assert_eq!(normalize_path("foo/bar/baz"), "foo\\bar\\baz");
    assert_eq!(normalize_path("/leading/slash"), "leading\\slash");
    assert_eq!(normalize_path("\\leading\\backslash"), "leading\\backslash");
    assert_eq!(normalize_path("no_change"), "no_change");
    // Edge cases: repeated/mixed leading separators and the empty path.
    assert_eq!(normalize_path("//\\mixed/leading"), "mixed\\leading");
    assert_eq!(normalize_path(""), "");
}
/// On a DFS tree, `format_path` prefixes paths with `server\share`.
///
/// Synchronous test: `format_path` is called without awaiting, so no async
/// runtime is needed.
#[test]
fn format_path_prepends_dfs_prefix() {
    let tree = Tree {
        tree_id: TreeId(1),
        share_name: "dfs".to_string(),
        server: "server1".to_string(),
        is_dfs: true,
        encrypt_data: false,
    };
    assert_eq!(
        tree.format_path("data/hello.txt"),
        "server1\\dfs\\data\\hello.txt"
    );
    // An empty path yields just the prefix, with no trailing separator.
    assert_eq!(tree.format_path(""), "server1\\dfs");
    assert_eq!(
        tree.format_path("nested/path"),
        "server1\\dfs\\nested\\path"
    );
}
/// On a DFS tree whose server string includes a `:port` suffix, the prefix
/// produced by `format_path` contains only the host part.
///
/// Synchronous test: nothing is awaited, so no async runtime is needed.
#[test]
fn format_path_strips_port_from_dfs_prefix() {
    let tree = Tree {
        tree_id: TreeId(1),
        share_name: "dfs".to_string(),
        server: "server1:10456".to_string(),
        is_dfs: true,
        encrypt_data: false,
    };
    assert_eq!(
        tree.format_path("data/hello.txt"),
        "server1\\dfs\\data\\hello.txt"
    );
}
/// On a non-DFS tree, `format_path` only normalizes separators and adds no
/// server/share prefix.
///
/// Synchronous test: nothing is awaited, so no async runtime is needed.
#[test]
fn format_path_no_prefix_when_not_dfs() {
    let tree = Tree {
        tree_id: TreeId(1),
        share_name: "public".to_string(),
        server: "server1".to_string(),
        is_dfs: false,
        encrypt_data: false,
    };
    assert_eq!(tree.format_path("data/hello.txt"), "data\\hello.txt");
    assert_eq!(tree.format_path(""), "");
}
/// A buffer holding exactly one record (next-entry offset 0) parses into one
/// entry whose name, size, kind and timestamps match what was encoded.
///
/// Synchronous test: the parser is a plain function, so no async runtime is
/// needed.
#[test]
fn parse_file_both_dir_info_single_entry() {
    let data = build_file_both_dir_info("test.txt", 42, false, 0);
    let entries = parse_file_both_directory_info(&data).unwrap();
    assert_eq!(entries.len(), 1);
    assert_eq!(entries[0].name, "test.txt");
    assert_eq!(entries[0].size, 42);
    assert!(!entries[0].is_directory);
    // The builder writes these fixed creation / last-write timestamps.
    assert_eq!(entries[0].created, FileTime(132_000_000_000_000_000));
    assert_eq!(entries[0].modified, FileTime(133_000_000_000_000_000));
}
/// `disconnect` succeeds against a TREE_DISCONNECT response and sends exactly
/// one message.
#[tokio::test]
async fn tree_disconnect_sends_request() {
    use crate::msg::tree_disconnect::TreeDisconnectResponse;

    let transport = Arc::new(MockTransport::new());
    let mut header = Header::new_request(Command::TreeDisconnect);
    header.flags.set_response();
    header.credits = 32;
    transport.queue_response(pack_message(&header, &TreeDisconnectResponse));

    let mut conn = setup_connection(&transport);
    let tree = Tree {
        tree_id: TreeId(99),
        share_name: String::from("test"),
        server: String::from("test-server"),
        is_dfs: false,
        encrypt_data: false,
    };

    tree.disconnect(&mut conn).await.unwrap();
    assert_eq!(transport.sent_count(), 1);
}
/// Packs a WRITE response reporting `count` bytes written, with 32 credits and
/// all other body fields zero.
fn build_write_response(count: u32) -> Vec<u8> {
    let mut header = Header::new_request(Command::Write);
    header.flags.set_response();
    header.credits = 32;
    pack_message(
        &header,
        &crate::msg::write::WriteResponse {
            count,
            remaining: 0,
            write_channel_info_offset: 0,
            write_channel_info_length: 0,
        },
    )
}
/// Packs a successful QUERY_INFO response wrapping `output_buffer`.
///
/// Convenience wrapper around `build_query_info_response_with_status` with
/// `NtStatus::SUCCESS`.
fn build_query_info_response(output_buffer: Vec<u8>) -> Vec<u8> {
    let status = NtStatus::SUCCESS;
    build_query_info_response_with_status(status, output_buffer)
}
/// Packs a QUERY_INFO response with the given `status`, 32 credits and
/// `output_buffer` as its body payload.
fn build_query_info_response_with_status(status: NtStatus, output_buffer: Vec<u8>) -> Vec<u8> {
    let mut header = Header::new_request(Command::QueryInfo);
    header.flags.set_response();
    header.credits = 32;
    header.status = status;
    pack_message(
        &header,
        &crate::msg::query_info::QueryInfoResponse { output_buffer },
    )
}
/// Packs a SET_INFO response granting 32 credits; the status is left as
/// produced by `Header::new_request`.
fn build_set_info_response() -> Vec<u8> {
    let mut header = Header::new_request(Command::SetInfo);
    header.flags.set_response();
    header.credits = 32;
    pack_message(&header, &crate::msg::set_info::SetInfoResponse)
}
/// Serializes a 40-byte basic-information buffer: four little-endian `u64`
/// timestamps (creation, last access, last write, change), the `u32`
/// attributes, and four trailing zero bytes.
fn build_file_basic_info(
    creation_time: u64,
    last_access_time: u64,
    last_write_time: u64,
    change_time: u64,
    file_attributes: u32,
) -> Vec<u8> {
    let mut out = Vec::new();
    for stamp in [creation_time, last_access_time, last_write_time, change_time] {
        out.extend_from_slice(&stamp.to_le_bytes());
    }
    out.extend_from_slice(&file_attributes.to_le_bytes());
    // Trailing 4 zero bytes pad the structure to 40 bytes.
    out.extend_from_slice(&[0u8; 4]);
    out
}
/// Serializes a 24-byte standard-information buffer: little-endian `u64`
/// allocation size and end-of-file, `u32` link count, one byte each for the
/// delete-pending and directory flags (1 = true), and two trailing zero bytes.
fn build_file_standard_info(
    allocation_size: u64,
    end_of_file: u64,
    number_of_links: u32,
    delete_pending: bool,
    directory: bool,
) -> Vec<u8> {
    let mut out = Vec::new();
    out.extend_from_slice(&allocation_size.to_le_bytes());
    out.extend_from_slice(&end_of_file.to_le_bytes());
    out.extend_from_slice(&number_of_links.to_le_bytes());
    out.extend_from_slice(&[u8::from(delete_pending), u8::from(directory)]);
    // Two zero bytes round the structure out to 24 bytes.
    out.extend_from_slice(&[0u8; 2]);
    out
}
/// `delete_file` sends a single message whose CREATE request asks for DELETE
/// access with the delete-on-close and non-directory options set.
#[tokio::test]
async fn delete_file_sends_compound_create_and_close() {
    let transport = Arc::new(MockTransport::new());
    let handle = FileId {
        persistent: 0xAA,
        volatile: 0xBB,
    };
    let replies = [build_create_response(handle, 0), build_close_response()];
    transport.queue_response(build_compound_response_frame(&replies));

    let mut conn = setup_connection(&transport);
    let tree = Tree {
        tree_id: TreeId(10),
        share_name: String::from("test"),
        server: String::from("test-server"),
        is_dfs: false,
        encrypt_data: false,
    };
    tree.delete_file(&mut conn, "remove.txt").await.unwrap();
    assert_eq!(transport.sent_count(), 1);

    // Inspect the first request in the sent message.
    let wire = transport.sent_message(0).unwrap();
    let mut reader = ReadCursor::new(&wire);
    let _header = Header::unpack(&mut reader).unwrap();
    let create = CreateRequest::unpack(&mut reader).unwrap();
    assert!(create.desired_access.contains(FileAccessMask::DELETE));
    assert_ne!(create.create_options & FILE_DELETE_ON_CLOSE, 0);
    assert_ne!(create.create_options & FILE_NON_DIRECTORY_FILE, 0);
}
#[tokio::test]
async fn delete_file_create_failure_returns_error() {
let mock = Arc::new(MockTransport::new());
let mut create_hdr = Header::new_request(Command::Create);
create_hdr.flags.set_response();
create_hdr.credits = 32;
create_hdr.status = NtStatus::OBJECT_NAME_NOT_FOUND;
let create_resp = pack_message(
&create_hdr,
&crate::msg::header::ErrorResponse {
error_context_count: 0,
error_data: vec![],
},
);
let mut close_hdr = Header::new_request(Command::Close);
close_hdr.flags.set_response();
close_hdr.credits = 32;
close_hdr.status = NtStatus::OBJECT_NAME_NOT_FOUND;
let close_resp = pack_message(
&close_hdr,
&crate::msg::header::ErrorResponse {
error_context_count: 0,
error_data: vec![],
},
);
let frame = build_compound_response_frame(&[create_resp, close_resp]);
mock.queue_response(frame);
let mut conn = setup_connection(&mock);
let tree = Tree {
tree_id: TreeId(10),
share_name: "test".to_string(),
server: "test-server".to_string(),
is_dfs: false,
encrypt_data: false,
};
let result = tree.delete_file(&mut conn, "nonexistent.txt").await;
assert!(result.is_err());
assert_eq!(
result.unwrap_err().status(),
Some(NtStatus::OBJECT_NAME_NOT_FOUND)
);
assert_eq!(mock.sent_count(), 1);
}
/// When CREATE succeeds but the compounded CLOSE fails, `delete_file` reports
/// an error and sends a second message (a successful close reply is queued
/// for it).
#[tokio::test]
async fn delete_file_close_failure_issues_standalone_close() {
    let transport = Arc::new(MockTransport::new());
    let handle = FileId {
        persistent: 0xAA,
        volatile: 0xBB,
    };
    let opened = build_create_response(handle, 0);

    let mut failed_close_hdr = Header::new_request(Command::Close);
    failed_close_hdr.flags.set_response();
    failed_close_hdr.credits = 32;
    failed_close_hdr.status = NtStatus::UNSUCCESSFUL;
    let failed_close = pack_message(
        &failed_close_hdr,
        &crate::msg::header::ErrorResponse {
            error_context_count: 0,
            error_data: vec![],
        },
    );

    // First reply: the compound; second reply: answer to the follow-up message.
    transport.queue_response(build_compound_response_frame(&[opened, failed_close]));
    transport.queue_response(build_close_response());

    let mut conn = setup_connection(&transport);
    let tree = Tree {
        tree_id: TreeId(10),
        share_name: String::from("test"),
        server: String::from("test-server"),
        is_dfs: false,
        encrypt_data: false,
    };

    let outcome = tree.delete_file(&mut conn, "tricky.txt").await;
    assert!(outcome.is_err());
    assert_eq!(transport.sent_count(), 2);
}
/// `write_file` with a small payload sends one message and returns the byte
/// count reported by the WRITE response of the create/write/flush/close reply.
#[tokio::test]
async fn write_file_sends_create_write_close() {
    let transport = Arc::new(MockTransport::new());
    let handle = FileId {
        persistent: 0xCC,
        volatile: 0xDD,
    };
    let replies = [
        build_create_response(handle, 0),
        build_write_response(5),
        build_flush_response(),
        build_close_response(),
    ];
    transport.queue_response(build_compound_response_frame(&replies));

    let mut conn = setup_connection(&transport);
    let tree = Tree {
        tree_id: TreeId(10),
        share_name: String::from("test"),
        server: String::from("test-server"),
        is_dfs: false,
        encrypt_data: false,
    };

    let bytes_written = tree
        .write_file(&mut conn, "out.txt", b"hello")
        .await
        .unwrap();
    assert_eq!(bytes_written, 5);
    assert_eq!(transport.sent_count(), 1);
}
/// `stat` sends one message and maps the basic/standard info replies onto the
/// returned size, directory flag and timestamps.
#[tokio::test]
async fn stat_sends_compound_and_returns_file_info() {
    let transport = Arc::new(MockTransport::new());
    let handle = FileId {
        persistent: 0xEE,
        volatile: 0xFF,
    };
    let opened = build_create_response(handle, 0);
    let basic_reply = build_query_info_response(build_file_basic_info(
        132_000_000_000_000_000,
        132_100_000_000_000_000,
        133_000_000_000_000_000,
        133_000_000_000_000_000,
        0x20,
    ));
    let standard_reply =
        build_query_info_response(build_file_standard_info(4096, 2048, 1, false, false));
    let closed = build_close_response();
    transport.queue_response(build_compound_response_frame(&[
        opened,
        basic_reply,
        standard_reply,
        closed,
    ]));

    let mut conn = setup_connection(&transport);
    let tree = Tree {
        tree_id: TreeId(10),
        share_name: String::from("test"),
        server: String::from("test-server"),
        is_dfs: false,
        encrypt_data: false,
    };

    let stat = tree.stat(&mut conn, "doc.txt").await.unwrap();
    assert_eq!(stat.size, 2048);
    assert!(!stat.is_directory);
    assert_eq!(stat.created, FileTime(132_000_000_000_000_000));
    assert_eq!(stat.modified, FileTime(133_000_000_000_000_000));
    assert_eq!(stat.accessed, FileTime(132_100_000_000_000_000));
    assert_eq!(transport.sent_count(), 1);
}
#[tokio::test]
async fn stat_create_failure_returns_error() {
let mock = Arc::new(MockTransport::new());
let mut create_hdr = Header::new_request(Command::Create);
create_hdr.flags.set_response();
create_hdr.credits = 32;
create_hdr.status = NtStatus::OBJECT_NAME_NOT_FOUND;
let err_body = crate::msg::header::ErrorResponse {
error_context_count: 0,
error_data: vec![],
};
let create_resp = pack_message(&create_hdr, &err_body);
let mut q1_hdr = Header::new_request(Command::QueryInfo);
q1_hdr.flags.set_response();
q1_hdr.credits = 32;
q1_hdr.status = NtStatus::OBJECT_NAME_NOT_FOUND;
let q1_resp = pack_message(&q1_hdr, &err_body);
let mut q2_hdr = Header::new_request(Command::QueryInfo);
q2_hdr.flags.set_response();
q2_hdr.credits = 32;
q2_hdr.status = NtStatus::OBJECT_NAME_NOT_FOUND;
let q2_resp = pack_message(&q2_hdr, &err_body);
let mut close_hdr = Header::new_request(Command::Close);
close_hdr.flags.set_response();
close_hdr.credits = 32;
close_hdr.status = NtStatus::OBJECT_NAME_NOT_FOUND;
let close_resp = pack_message(&close_hdr, &err_body);
let frame = build_compound_response_frame(&[create_resp, q1_resp, q2_resp, close_resp]);
mock.queue_response(frame);
let mut conn = setup_connection(&mock);
let tree = Tree {
tree_id: TreeId(10),
share_name: "test".to_string(),
server: "test-server".to_string(),
is_dfs: false,
encrypt_data: false,
};
let result = tree.stat(&mut conn, "nonexistent.txt").await;
assert!(result.is_err());
assert_eq!(
result.unwrap_err().status(),
Some(NtStatus::OBJECT_NAME_NOT_FOUND)
);
assert_eq!(mock.sent_count(), 1);
}
#[tokio::test]
async fn stat_query_failure_issues_standalone_close() {
let mock = Arc::new(MockTransport::new());
let file_id = FileId {
persistent: 0xEE,
volatile: 0xFF,
};
let create_resp = build_create_response(file_id, 0);
let err_body = crate::msg::header::ErrorResponse {
error_context_count: 0,
error_data: vec![],
};
let mut q1_hdr = Header::new_request(Command::QueryInfo);
q1_hdr.flags.set_response();
q1_hdr.credits = 32;
q1_hdr.status = NtStatus::UNSUCCESSFUL;
let q1_resp = pack_message(&q1_hdr, &err_body);
let mut q2_hdr = Header::new_request(Command::QueryInfo);
q2_hdr.flags.set_response();
q2_hdr.credits = 32;
q2_hdr.status = NtStatus::UNSUCCESSFUL;
let q2_resp = pack_message(&q2_hdr, &err_body);
let mut close_hdr = Header::new_request(Command::Close);
close_hdr.flags.set_response();
close_hdr.credits = 32;
close_hdr.status = NtStatus::UNSUCCESSFUL;
let close_resp = pack_message(&close_hdr, &err_body);
let frame = build_compound_response_frame(&[create_resp, q1_resp, q2_resp, close_resp]);
mock.queue_response(frame);
mock.queue_response(build_close_response());
let mut conn = setup_connection(&mock);
let tree = Tree {
tree_id: TreeId(10),
share_name: "test".to_string(),
server: "test-server".to_string(),
is_dfs: false,
encrypt_data: false,
};
let result = tree.stat(&mut conn, "tricky.txt").await;
assert!(result.is_err());
assert_eq!(mock.sent_count(), 2);
}
/// `stat_files` over three paths returns three successful results in input
/// order and sends three messages, one per queued compound reply.
#[tokio::test]
async fn stat_files_batch_happy_path() {
    let transport = Arc::new(MockTransport::new());
    for idx in 0..3u64 {
        let handle = FileId {
            persistent: idx + 1,
            volatile: idx + 100,
        };
        let opened = build_create_response(handle, 0);
        let basic_reply = build_query_info_response(build_file_basic_info(
            132_000_000_000_000_000 + idx,
            132_100_000_000_000_000 + idx,
            133_000_000_000_000_000 + idx,
            133_000_000_000_000_000 + idx,
            0x20,
        ));
        // End-of-file grows with the index: 1024, 2048, 3072.
        let standard_reply = build_query_info_response(build_file_standard_info(
            4096,
            1024 * (idx + 1),
            1,
            false,
            false,
        ));
        let closed = build_close_response();
        transport.queue_response(build_compound_response_frame(&[
            opened,
            basic_reply,
            standard_reply,
            closed,
        ]));
    }

    let mut conn = setup_connection(&transport);
    let tree = Tree {
        tree_id: TreeId(10),
        share_name: String::from("test"),
        server: String::from("test-server"),
        is_dfs: false,
        encrypt_data: false,
    };

    let stats = tree
        .stat_files(&mut conn, &["a.txt", "b.txt", "c.txt"])
        .await;
    assert_eq!(stats.len(), 3);
    assert_eq!(stats[0].as_ref().unwrap().size, 1024);
    assert_eq!(stats[1].as_ref().unwrap().size, 2048);
    assert_eq!(stats[2].as_ref().unwrap().size, 3072);
    assert_eq!(transport.sent_count(), 3);
}
#[tokio::test]
async fn stat_files_batch_partial_failure() {
let mock = Arc::new(MockTransport::new());
let err_body = crate::msg::header::ErrorResponse {
error_context_count: 0,
error_data: vec![],
};
let file_id = FileId {
persistent: 1,
volatile: 100,
};
let create_resp = build_create_response(file_id, 0);
let basic = build_file_basic_info(
132_000_000_000_000_000,
132_100_000_000_000_000,
133_000_000_000_000_000,
133_000_000_000_000_000,
0x20,
);
let basic_resp = build_query_info_response(basic);
let std_info = build_file_standard_info(4096, 512, 1, false, false);
let std_resp = build_query_info_response(std_info);
let close_resp = build_close_response();
mock.queue_response(build_compound_response_frame(&[
create_resp,
basic_resp,
std_resp,
close_resp,
]));
let mut create_hdr = Header::new_request(Command::Create);
create_hdr.flags.set_response();
create_hdr.credits = 32;
create_hdr.status = NtStatus::OBJECT_NAME_NOT_FOUND;
let create_err = pack_message(&create_hdr, &err_body);
let mut q1_hdr = Header::new_request(Command::QueryInfo);
q1_hdr.flags.set_response();
q1_hdr.credits = 32;
q1_hdr.status = NtStatus::OBJECT_NAME_NOT_FOUND;
let q1_err = pack_message(&q1_hdr, &err_body);
let mut q2_hdr = Header::new_request(Command::QueryInfo);
q2_hdr.flags.set_response();
q2_hdr.credits = 32;
q2_hdr.status = NtStatus::OBJECT_NAME_NOT_FOUND;
let q2_err = pack_message(&q2_hdr, &err_body);
let mut close_hdr = Header::new_request(Command::Close);
close_hdr.flags.set_response();
close_hdr.credits = 32;
close_hdr.status = NtStatus::OBJECT_NAME_NOT_FOUND;
let close_err = pack_message(&close_hdr, &err_body);
mock.queue_response(build_compound_response_frame(&[
create_err, q1_err, q2_err, close_err,
]));
let mut conn = setup_connection(&mock);
let tree = Tree {
tree_id: TreeId(10),
share_name: "test".to_string(),
server: "test-server".to_string(),
is_dfs: false,
encrypt_data: false,
};
let results = tree
.stat_files(&mut conn, &["exists.txt", "missing.txt"])
.await;
assert_eq!(results.len(), 2);
assert_eq!(results[0].as_ref().unwrap().size, 512);
assert!(results[1].is_err());
assert_eq!(
results[1].as_ref().unwrap_err().status(),
Some(NtStatus::OBJECT_NAME_NOT_FOUND)
);
}
/// `stat_files` with no paths returns an empty result list; no responses are
/// queued on the mock transport.
#[tokio::test]
async fn stat_files_empty_returns_empty() {
    let transport = Arc::new(MockTransport::new());
    let mut conn = setup_connection(&transport);
    let tree = Tree {
        tree_id: TreeId(10),
        share_name: String::from("test"),
        server: String::from("test-server"),
        is_dfs: false,
        encrypt_data: false,
    };

    let stats: Vec<Result<FileInfo>> = tree.stat_files(&mut conn, &[]).await;
    assert!(stats.is_empty());
}
/// `rename` sends a single message whose CREATE request asks for DELETE
/// access.
#[tokio::test]
async fn rename_sends_compound_create_setinfo_close() {
    let transport = Arc::new(MockTransport::new());
    let handle = FileId {
        persistent: 0x11,
        volatile: 0x22,
    };
    let replies = [
        build_create_response(handle, 0),
        build_set_info_response(),
        build_close_response(),
    ];
    transport.queue_response(build_compound_response_frame(&replies));

    let mut conn = setup_connection(&transport);
    let tree = Tree {
        tree_id: TreeId(10),
        share_name: String::from("test"),
        server: String::from("test-server"),
        is_dfs: false,
        encrypt_data: false,
    };
    tree.rename(&mut conn, "old.txt", "new.txt").await.unwrap();
    assert_eq!(transport.sent_count(), 1);

    // Inspect the first request in the sent message.
    let wire = transport.sent_message(0).unwrap();
    let mut reader = ReadCursor::new(&wire);
    let _header = Header::unpack(&mut reader).unwrap();
    let create = CreateRequest::unpack(&mut reader).unwrap();
    assert!(create.desired_access.contains(FileAccessMask::DELETE));
}
#[tokio::test]
async fn rename_create_failure_returns_error() {
let mock = Arc::new(MockTransport::new());
let mut create_hdr = Header::new_request(Command::Create);
create_hdr.flags.set_response();
create_hdr.credits = 32;
create_hdr.status = NtStatus::OBJECT_NAME_NOT_FOUND;
let create_resp = pack_message(
&create_hdr,
&crate::msg::header::ErrorResponse {
error_context_count: 0,
error_data: vec![],
},
);
let mut setinfo_hdr = Header::new_request(Command::SetInfo);
setinfo_hdr.flags.set_response();
setinfo_hdr.credits = 32;
setinfo_hdr.status = NtStatus::OBJECT_NAME_NOT_FOUND;
let setinfo_resp = pack_message(
&setinfo_hdr,
&crate::msg::header::ErrorResponse {
error_context_count: 0,
error_data: vec![],
},
);
let mut close_hdr = Header::new_request(Command::Close);
close_hdr.flags.set_response();
close_hdr.credits = 32;
close_hdr.status = NtStatus::OBJECT_NAME_NOT_FOUND;
let close_resp = pack_message(
&close_hdr,
&crate::msg::header::ErrorResponse {
error_context_count: 0,
error_data: vec![],
},
);
let frame = build_compound_response_frame(&[create_resp, setinfo_resp, close_resp]);
mock.queue_response(frame);
let mut conn = setup_connection(&mock);
let tree = Tree {
tree_id: TreeId(10),
share_name: "test".to_string(),
server: "test-server".to_string(),
is_dfs: false,
encrypt_data: false,
};
let result = tree.rename(&mut conn, "old.txt", "new.txt").await;
assert!(result.is_err());
assert_eq!(
result.unwrap_err().status(),
Some(NtStatus::OBJECT_NAME_NOT_FOUND)
);
assert_eq!(mock.sent_count(), 1);
}
#[tokio::test]
async fn rename_setinfo_failure_issues_standalone_close() {
let mock = Arc::new(MockTransport::new());
let file_id = FileId {
persistent: 0x11,
volatile: 0x22,
};
let create_resp = build_create_response(file_id, 0);
let mut setinfo_hdr = Header::new_request(Command::SetInfo);
setinfo_hdr.flags.set_response();
setinfo_hdr.credits = 32;
setinfo_hdr.status = NtStatus::UNSUCCESSFUL;
let setinfo_resp = pack_message(
&setinfo_hdr,
&crate::msg::header::ErrorResponse {
error_context_count: 0,
error_data: vec![],
},
);
let mut close_hdr = Header::new_request(Command::Close);
close_hdr.flags.set_response();
close_hdr.credits = 32;
close_hdr.status = NtStatus::UNSUCCESSFUL;
let close_resp = pack_message(
&close_hdr,
&crate::msg::header::ErrorResponse {
error_context_count: 0,
error_data: vec![],
},
);
let frame = build_compound_response_frame(&[create_resp, setinfo_resp, close_resp]);
mock.queue_response(frame);
mock.queue_response(build_close_response());
let mut conn = setup_connection(&mock);
let tree = Tree {
tree_id: TreeId(10),
share_name: "test".to_string(),
server: "test-server".to_string(),
is_dfs: false,
encrypt_data: false,
};
let result = tree.rename(&mut conn, "old.txt", "new.txt").await;
assert!(result.is_err());
assert_eq!(result.unwrap_err().status(), Some(NtStatus::UNSUCCESSFUL));
assert_eq!(mock.sent_count(), 2);
}
/// `rename_files` over three pairs returns three `Ok` results and sends three
/// messages, one per queued compound reply.
#[tokio::test]
async fn rename_files_batch_happy_path() {
    let transport = Arc::new(MockTransport::new());
    for idx in 0..3u64 {
        let handle = FileId {
            persistent: idx + 1,
            volatile: idx + 100,
        };
        let replies = [
            build_create_response(handle, 0),
            build_set_info_response(),
            build_close_response(),
        ];
        transport.queue_response(build_compound_response_frame(&replies));
    }

    let mut conn = setup_connection(&transport);
    let tree = Tree {
        tree_id: TreeId(10),
        share_name: String::from("test"),
        server: String::from("test-server"),
        is_dfs: false,
        encrypt_data: false,
    };

    let outcomes = tree
        .rename_files(
            &mut conn,
            &[
                ("a.txt", "a2.txt"),
                ("b.txt", "b2.txt"),
                ("c.txt", "c2.txt"),
            ],
        )
        .await;
    assert_eq!(outcomes.len(), 3);
    assert!(outcomes[0].is_ok());
    assert!(outcomes[1].is_ok());
    assert!(outcomes[2].is_ok());
    assert_eq!(transport.sent_count(), 3);
}
#[tokio::test]
async fn rename_files_batch_partial_failure() {
let mock = Arc::new(MockTransport::new());
let err_body = crate::msg::header::ErrorResponse {
error_context_count: 0,
error_data: vec![],
};
let file_id = FileId {
persistent: 1,
volatile: 100,
};
let create_resp = build_create_response(file_id, 0);
let setinfo_resp = build_set_info_response();
let close_resp = build_close_response();
mock.queue_response(build_compound_response_frame(&[
create_resp,
setinfo_resp,
close_resp,
]));
let mut create_hdr = Header::new_request(Command::Create);
create_hdr.flags.set_response();
create_hdr.credits = 32;
create_hdr.status = NtStatus::OBJECT_NAME_NOT_FOUND;
let create_err = pack_message(&create_hdr, &err_body);
let mut si_hdr = Header::new_request(Command::SetInfo);
si_hdr.flags.set_response();
si_hdr.credits = 32;
si_hdr.status = NtStatus::OBJECT_NAME_NOT_FOUND;
let si_err = pack_message(&si_hdr, &err_body);
let mut close_hdr = Header::new_request(Command::Close);
close_hdr.flags.set_response();
close_hdr.credits = 32;
close_hdr.status = NtStatus::OBJECT_NAME_NOT_FOUND;
let close_err = pack_message(&close_hdr, &err_body);
mock.queue_response(build_compound_response_frame(&[
create_err, si_err, close_err,
]));
let file_id = FileId {
persistent: 3,
volatile: 102,
};
let create_resp = build_create_response(file_id, 0);
let setinfo_resp = build_set_info_response();
let close_resp = build_close_response();
mock.queue_response(build_compound_response_frame(&[
create_resp,
setinfo_resp,
close_resp,
]));
let mut conn = setup_connection(&mock);
let tree = Tree {
tree_id: TreeId(10),
share_name: "test".to_string(),
server: "test-server".to_string(),
is_dfs: false,
encrypt_data: false,
};
let results = tree
.rename_files(
&mut conn,
&[
("a.txt", "a2.txt"),
("missing.txt", "m2.txt"),
("c.txt", "c2.txt"),
],
)
.await;
assert_eq!(results.len(), 3);
assert!(results[0].is_ok());
assert!(results[1].is_err());
assert_eq!(
results[1].as_ref().unwrap_err().status(),
Some(NtStatus::OBJECT_NAME_NOT_FOUND)
);
assert!(results[2].is_ok());
}
/// `rename_files` with no pairs returns an empty result list and sends
/// nothing.
#[tokio::test]
async fn rename_files_empty_returns_empty() {
    let transport = Arc::new(MockTransport::new());
    let mut conn = setup_connection(&transport);
    let tree = Tree {
        tree_id: TreeId(10),
        share_name: String::from("test"),
        server: String::from("test-server"),
        is_dfs: false,
        encrypt_data: false,
    };

    let outcomes: Vec<Result<()>> = tree.rename_files(&mut conn, &[]).await;
    assert!(outcomes.is_empty());
    assert_eq!(transport.sent_count(), 0);
}
/// `create_directory` sends two separate messages (create, then close); the
/// CREATE uses the `FileCreate` disposition with the directory option set.
#[tokio::test]
async fn create_directory_sends_create_and_close() {
    let transport = Arc::new(MockTransport::new());
    let handle = FileId {
        persistent: 0x33,
        volatile: 0x44,
    };
    // Two independent replies rather than one compound frame.
    transport.queue_response(build_create_response(handle, 0));
    transport.queue_response(build_close_response());

    let mut conn = setup_connection(&transport);
    let tree = Tree {
        tree_id: TreeId(10),
        share_name: String::from("test"),
        server: String::from("test-server"),
        is_dfs: false,
        encrypt_data: false,
    };
    tree.create_directory(&mut conn, "new_dir").await.unwrap();
    assert_eq!(transport.sent_count(), 2);

    let wire = transport.sent_message(0).unwrap();
    let mut reader = ReadCursor::new(&wire);
    let _header = Header::unpack(&mut reader).unwrap();
    let create = CreateRequest::unpack(&mut reader).unwrap();
    assert_eq!(create.create_disposition, CreateDisposition::FileCreate);
    assert_ne!(create.create_options & FILE_DIRECTORY_FILE, 0);
}
/// `delete_directory` sends a single message whose CREATE request has both
/// the delete-on-close and directory options set.
#[tokio::test]
async fn delete_directory_sends_compound_create_and_close() {
    let transport = Arc::new(MockTransport::new());
    let handle = FileId {
        persistent: 0x55,
        volatile: 0x66,
    };
    let replies = [build_create_response(handle, 0), build_close_response()];
    transport.queue_response(build_compound_response_frame(&replies));

    let mut conn = setup_connection(&transport);
    let tree = Tree {
        tree_id: TreeId(10),
        share_name: String::from("test"),
        server: String::from("test-server"),
        is_dfs: false,
        encrypt_data: false,
    };
    tree.delete_directory(&mut conn, "old_dir").await.unwrap();
    assert_eq!(transport.sent_count(), 1);

    let wire = transport.sent_message(0).unwrap();
    let mut reader = ReadCursor::new(&wire);
    let _header = Header::unpack(&mut reader).unwrap();
    let create = CreateRequest::unpack(&mut reader).unwrap();
    assert_ne!(create.create_options & FILE_DELETE_ON_CLOSE, 0);
    assert_ne!(create.create_options & FILE_DIRECTORY_FILE, 0);
}
/// `delete_files` over three paths returns three `Ok` results and sends three
/// messages, one per queued compound reply.
#[tokio::test]
async fn delete_files_batch_happy_path() {
    let transport = Arc::new(MockTransport::new());
    for idx in 0..3u64 {
        let handle = FileId {
            persistent: idx + 1,
            volatile: idx + 100,
        };
        let replies = [build_create_response(handle, 0), build_close_response()];
        transport.queue_response(build_compound_response_frame(&replies));
    }

    let mut conn = setup_connection(&transport);
    let tree = Tree {
        tree_id: TreeId(10),
        share_name: String::from("test"),
        server: String::from("test-server"),
        is_dfs: false,
        encrypt_data: false,
    };

    let outcomes = tree
        .delete_files(&mut conn, &["a.txt", "b.txt", "c.txt"])
        .await;
    assert_eq!(outcomes.len(), 3);
    assert!(outcomes[0].is_ok());
    assert!(outcomes[1].is_ok());
    assert!(outcomes[2].is_ok());
    assert_eq!(transport.sent_count(), 3);
}
#[tokio::test]
async fn delete_files_batch_partial_failure() {
let mock = Arc::new(MockTransport::new());
let err_body = crate::msg::header::ErrorResponse {
error_context_count: 0,
error_data: vec![],
};
let file_id = FileId {
persistent: 1,
volatile: 100,
};
let create_resp = build_create_response(file_id, 0);
let close_resp = build_close_response();
mock.queue_response(build_compound_response_frame(&[create_resp, close_resp]));
let mut create_hdr = Header::new_request(Command::Create);
create_hdr.flags.set_response();
create_hdr.credits = 32;
create_hdr.status = NtStatus::OBJECT_NAME_NOT_FOUND;
let create_err = pack_message(&create_hdr, &err_body);
let mut close_hdr = Header::new_request(Command::Close);
close_hdr.flags.set_response();
close_hdr.credits = 32;
close_hdr.status = NtStatus::OBJECT_NAME_NOT_FOUND;
let close_err = pack_message(&close_hdr, &err_body);
mock.queue_response(build_compound_response_frame(&[create_err, close_err]));
let file_id = FileId {
persistent: 3,
volatile: 102,
};
let create_resp = build_create_response(file_id, 0);
let close_resp = build_close_response();
mock.queue_response(build_compound_response_frame(&[create_resp, close_resp]));
let mut conn = setup_connection(&mock);
let tree = Tree {
tree_id: TreeId(10),
share_name: "test".to_string(),
server: "test-server".to_string(),
is_dfs: false,
encrypt_data: false,
};
let results = tree
.delete_files(&mut conn, &["a.txt", "missing.txt", "c.txt"])
.await;
assert_eq!(results.len(), 3);
assert!(results[0].is_ok());
assert!(results[1].is_err());
assert_eq!(
results[1].as_ref().unwrap_err().status(),
Some(NtStatus::OBJECT_NAME_NOT_FOUND)
);
assert!(results[2].is_ok());
}
/// When CREATE succeeds but the compounded CLOSE fails, the file's result is
/// an error and one additional message is sent (three in total for two files).
#[tokio::test]
async fn delete_files_batch_close_failure_issues_cleanup() {
    let transport = Arc::new(MockTransport::new());

    // First file: CREATE succeeds, the CLOSE in the same compound fails.
    let leaky_handle = FileId {
        persistent: 0xAA,
        volatile: 0xBB,
    };
    let failed_close = {
        let mut hdr = Header::new_request(Command::Close);
        hdr.flags.set_response();
        hdr.credits = 32;
        hdr.status = NtStatus::UNSUCCESSFUL;
        pack_message(
            &hdr,
            &crate::msg::header::ErrorResponse {
                error_context_count: 0,
                error_data: vec![],
            },
        )
    };
    transport.queue_response(build_compound_response_frame(&[
        build_create_response(leaky_handle, 0),
        failed_close,
    ]));

    // Second file: both CREATE and CLOSE succeed.
    let clean_handle = FileId {
        persistent: 0xCC,
        volatile: 0xDD,
    };
    transport.queue_response(build_compound_response_frame(&[
        build_create_response(clean_handle, 0),
        build_close_response(),
    ]));

    // Reply for the third message; presumably the standalone cleanup CLOSE
    // named in the test title.
    transport.queue_response(build_close_response());

    let mut conn = setup_connection(&transport);
    let tree = Tree {
        tree_id: TreeId(10),
        share_name: "test".to_string(),
        server: "test-server".to_string(),
        is_dfs: false,
        encrypt_data: false,
    };

    let outcomes = tree.delete_files(&mut conn, &["leaky.txt", "ok.txt"]).await;

    assert_eq!(outcomes.len(), 2);
    assert!(outcomes[0].is_err());
    assert!(outcomes[1].is_ok());
    assert_eq!(transport.sent_count(), 3);
}
/// An empty path list produces no results and sends nothing.
#[tokio::test]
async fn delete_files_empty_returns_empty() {
    let transport = Arc::new(MockTransport::new());
    let mut conn = setup_connection(&transport);
    let tree = Tree {
        tree_id: TreeId(10),
        share_name: "test".to_string(),
        server: "test-server".to_string(),
        is_dfs: false,
        encrypt_data: false,
    };

    let outcomes = tree.delete_files(&mut conn, &[]).await;

    assert_eq!(transport.sent_count(), 0);
    assert!(outcomes.is_empty());
}
/// Packs a READ response with the given status and message id.
///
/// For `END_OF_FILE` the body is an empty `ErrorResponse` and `data` is
/// ignored; for every other status the body is a `ReadResponse` carrying
/// `data`. The header always has the response flag set and 32 credits.
fn build_read_response_with_msg_id(
    status: NtStatus,
    msg_id: MessageId,
    data: Vec<u8>,
) -> Vec<u8> {
    let mut header = Header::new_request(Command::Read);
    header.flags.set_response();
    header.credits = 32;
    header.status = status;
    header.message_id = msg_id;

    if status != NtStatus::END_OF_FILE {
        let payload = ReadResponse {
            data_offset: 0x50,
            data_remaining: 0,
            flags: 0,
            data,
        };
        pack_message(&header, &payload)
    } else {
        let payload = crate::msg::header::ErrorResponse {
            error_context_count: 0,
            error_data: vec![],
        };
        pack_message(&header, &payload)
    }
}
/// Packs a WRITE response reporting `count` bytes written, tagged with
/// `msg_id`. The header has the response flag set and 32 credits.
fn build_write_response_with_msg_id(msg_id: MessageId, count: u32) -> Vec<u8> {
    let mut header = Header::new_request(Command::Write);
    header.flags.set_response();
    header.credits = 32;
    header.message_id = msg_id;

    pack_message(
        &header,
        &crate::msg::write::WriteResponse {
            count,
            remaining: 0,
            write_channel_info_offset: 0,
            write_channel_info_length: 0,
        },
    )
}
/// Reads a 256 KiB file as four 64 KiB chunks and checks that the returned
/// bytes match the expected pattern.
///
/// Six messages are expected in total: the mock supplies one CREATE reply,
/// four READ replies and one CLOSE reply.
#[tokio::test]
async fn pipelined_read_four_chunks() {
    const CHUNK: usize = 65536;

    let tree_id = TreeId(20);
    let file_id = FileId {
        persistent: 0x100,
        volatile: 0x200,
    };
    let file_size = 256 * 1024u64;
    // Non-repeating-per-chunk pattern so misplaced chunks are detected.
    let mut expected_data = vec![0u8; file_size as usize];
    for (i, byte) in expected_data.iter_mut().enumerate() {
        *byte = (i % 251) as u8;
    }

    // A single mock is set up. The earlier version built a first mock that
    // was immediately shadowed (and tagged every reply with message id 1).
    let mock = Arc::new(MockTransport::new());
    mock.queue_response(build_create_response(file_id, file_size));
    // READ replies carry message ids 1..=4, one per 64 KiB chunk in order.
    for (msg_id, chunk) in (1u64..).zip(expected_data.chunks(CHUNK)) {
        mock.queue_response(build_read_response_with_msg_id(
            NtStatus::SUCCESS,
            MessageId(msg_id),
            chunk.to_vec(),
        ));
    }
    mock.queue_response(build_close_response());

    let mut conn = setup_connection(&mock);
    let tree = Tree {
        tree_id,
        share_name: "test".to_string(),
        server: "test-server".to_string(),
        is_dfs: false,
        encrypt_data: false,
    };

    let data = tree
        .read_file_pipelined(&mut conn, "big.bin")
        .await
        .unwrap();

    assert_eq!(data.len(), expected_data.len());
    assert_eq!(data, expected_data);
    assert_eq!(mock.sent_count(), 6);
}
/// READ replies queued in reverse message-id order must still be assembled
/// into the file's original byte order.
#[tokio::test]
async fn pipelined_read_responses_out_of_order() {
    let transport = Arc::new(MockTransport::new());
    let tree_id = TreeId(20);
    let handle = FileId {
        persistent: 0x300,
        volatile: 0x400,
    };
    let total_len = 192 * 1024u64;
    let expected: Vec<u8> = (0..total_len as usize)
        .map(|pos| (pos % 199) as u8)
        .collect();

    transport.queue_response(build_create_response(handle, total_len));
    // Queue chunk 3 first, then 2, then 1.
    for seq in (0u64..3).rev() {
        let start = (seq * 65536) as usize;
        let reply = build_read_response_with_msg_id(
            NtStatus::SUCCESS,
            MessageId(seq + 1),
            expected[start..start + 65536].to_vec(),
        );
        transport.queue_response(reply);
    }
    transport.queue_response(build_close_response());

    let mut conn = setup_connection(&transport);
    let tree = Tree {
        tree_id,
        share_name: "test".to_string(),
        server: "test-server".to_string(),
        is_dfs: false,
        encrypt_data: false,
    };

    let data = tree
        .read_file_pipelined(&mut conn, "reverse.bin")
        .await
        .unwrap();

    assert_eq!(data.len(), expected.len());
    assert_eq!(data, expected);
}
/// A zero-length file returns empty data after exactly two messages; the
/// mock only supplies a CREATE reply and a CLOSE reply.
#[tokio::test]
async fn pipelined_read_zero_byte_file() {
    let transport = Arc::new(MockTransport::new());
    let handle = FileId {
        persistent: 0x500,
        volatile: 0x600,
    };
    transport.queue_response(build_create_response(handle, 0));
    transport.queue_response(build_close_response());

    let mut conn = setup_connection(&transport);
    let tree = Tree {
        tree_id: TreeId(20),
        share_name: "test".to_string(),
        server: "test-server".to_string(),
        is_dfs: false,
        encrypt_data: false,
    };

    let data = tree
        .read_file_pipelined(&mut conn, "empty.bin")
        .await
        .unwrap();

    assert!(data.is_empty());
    assert_eq!(transport.sent_count(), 2);
}
/// The second READ answers END_OF_FILE even though CREATE reported 128 KiB.
/// The read must still succeed, return a buffer of the reported size, and
/// carry the first chunk's bytes at the front.
#[tokio::test]
async fn pipelined_read_end_of_file_mid_window() {
    let transport = Arc::new(MockTransport::new());
    let handle = FileId {
        persistent: 0x700,
        volatile: 0x800,
    };
    let reported_len = 128 * 1024u64;
    let first_chunk = vec![0xAA; 65536];

    transport.queue_response(build_create_response(handle, reported_len));
    // Replies in order: (status, message id, payload).
    let read_replies = [
        (NtStatus::SUCCESS, 1, first_chunk.clone()),
        (NtStatus::END_OF_FILE, 2, vec![]),
    ];
    for (status, msg_id, payload) in read_replies {
        transport.queue_response(build_read_response_with_msg_id(
            status,
            MessageId(msg_id),
            payload,
        ));
    }
    transport.queue_response(build_close_response());

    let mut conn = setup_connection(&transport);
    let tree = Tree {
        tree_id: TreeId(20),
        share_name: "test".to_string(),
        server: "test-server".to_string(),
        is_dfs: false,
        encrypt_data: false,
    };

    let data = tree
        .read_file_pipelined(&mut conn, "truncated.bin")
        .await
        .unwrap();

    assert_eq!(data.len(), reported_len as usize);
    assert_eq!(&data[..65536], &first_chunk);
}
/// Reads a three-chunk file while every reply grants only 2 credits, and
/// checks the data plus the total message count (five).
#[tokio::test]
async fn pipelined_read_window_sliding() {
    let handle = FileId {
        persistent: 0x900,
        volatile: 0xA00,
    };
    let total_len = 192 * 1024u64;
    let expected: Vec<u8> = (0..total_len as usize)
        .map(|pos| (pos % 173) as u8)
        .collect();
    let transport = Arc::new(MockTransport::new());

    // CREATE reply that grants just 2 credits.
    let mut create_hdr = Header::new_request(Command::Create);
    create_hdr.flags.set_response();
    create_hdr.credits = 2;
    let create_body = CreateResponse {
        oplock_level: OplockLevel::None,
        flags: 0,
        create_action: CreateAction::FileOpened,
        creation_time: FileTime::ZERO,
        last_access_time: FileTime::ZERO,
        last_write_time: FileTime::ZERO,
        change_time: FileTime::ZERO,
        allocation_size: 0,
        end_of_file: total_len,
        file_attributes: 0,
        file_id: handle,
        create_contexts: vec![],
    };
    transport.queue_response(pack_message(&create_hdr, &create_body));

    // Three 64 KiB READ replies with message ids 1..=3, each granting 2 credits.
    for seq in 0u64..3 {
        let start = (seq * 65536) as usize;
        let mut read_hdr = Header::new_request(Command::Read);
        read_hdr.flags.set_response();
        read_hdr.credits = 2;
        read_hdr.message_id = MessageId(seq + 1);
        let read_body = ReadResponse {
            data_offset: 0x50,
            data_remaining: 0,
            flags: 0,
            data: expected[start..start + 65536].to_vec(),
        };
        transport.queue_response(pack_message(&read_hdr, &read_body));
    }
    transport.queue_response(build_close_response());

    let mut conn = setup_connection(&transport);
    let tree = Tree {
        tree_id: TreeId(20),
        share_name: "test".to_string(),
        server: "test-server".to_string(),
        is_dfs: false,
        encrypt_data: false,
    };

    let data = tree
        .read_file_pipelined(&mut conn, "sliding.bin")
        .await
        .unwrap();

    assert_eq!(data.len(), expected.len());
    assert_eq!(data, expected);
    assert_eq!(transport.sent_count(), 5);
}
/// Reads an eight-chunk file where CREATE grants 4 credits and each READ
/// reply grants 1, then checks the data and the total message count (ten).
#[tokio::test]
async fn sliding_window_sends_immediately_after_receive() {
    let handle = FileId {
        persistent: 0xF00,
        volatile: 0xF01,
    };
    let total_len = 8 * 65536u64;
    let expected: Vec<u8> = (0..total_len as usize)
        .map(|pos| (pos % 137) as u8)
        .collect();
    let transport = Arc::new(MockTransport::new());

    // CREATE reply granting 4 credits.
    let mut create_hdr = Header::new_request(Command::Create);
    create_hdr.flags.set_response();
    create_hdr.credits = 4;
    let create_body = CreateResponse {
        oplock_level: OplockLevel::None,
        flags: 0,
        create_action: CreateAction::FileOpened,
        creation_time: FileTime::ZERO,
        last_access_time: FileTime::ZERO,
        last_write_time: FileTime::ZERO,
        change_time: FileTime::ZERO,
        allocation_size: 0,
        end_of_file: total_len,
        file_attributes: 0,
        file_id: handle,
        create_contexts: vec![],
    };
    transport.queue_response(pack_message(&create_hdr, &create_body));

    // Eight READ replies (message ids 1..=8), each granting a single credit.
    for seq in 0u64..8 {
        let start = (seq * 65536) as usize;
        let mut read_hdr = Header::new_request(Command::Read);
        read_hdr.flags.set_response();
        read_hdr.credits = 1;
        read_hdr.message_id = MessageId(seq + 1);
        let read_body = ReadResponse {
            data_offset: 0x50,
            data_remaining: 0,
            flags: 0,
            data: expected[start..start + 65536].to_vec(),
        };
        transport.queue_response(pack_message(&read_hdr, &read_body));
    }
    transport.queue_response(build_close_response());

    let mut conn = setup_connection(&transport);
    let tree = Tree {
        tree_id: TreeId(20),
        share_name: "test".to_string(),
        server: "test-server".to_string(),
        is_dfs: false,
        encrypt_data: false,
    };

    let data = tree
        .read_file_pipelined(&mut conn, "sliding_test.bin")
        .await
        .unwrap();

    assert_eq!(data.len(), expected.len());
    assert_eq!(data, expected);
    assert_eq!(transport.sent_count(), 10);
}
/// The progress callback is invoked once per chunk with the running byte
/// total: 64 KiB after the first chunk and the full size after the second.
#[tokio::test]
async fn read_pipelined_with_progress_reports_progress() {
    let transport = Arc::new(MockTransport::new());
    let handle = FileId {
        persistent: 0xF1,
        volatile: 0xF2,
    };
    let total_len = 65536u64 * 2;
    let expected = vec![0xABu8; total_len as usize];

    let mut create_hdr = Header::new_request(Command::Create);
    create_hdr.flags.set_response();
    create_hdr.credits = 32;
    let create_body = CreateResponse {
        oplock_level: OplockLevel::None,
        flags: 0,
        create_action: CreateAction::FileOpened,
        creation_time: FileTime::ZERO,
        last_access_time: FileTime::ZERO,
        last_write_time: FileTime::ZERO,
        change_time: FileTime::ZERO,
        allocation_size: 0,
        end_of_file: total_len,
        file_attributes: 0,
        file_id: handle,
        create_contexts: vec![],
    };
    transport.queue_response(pack_message(&create_hdr, &create_body));

    // Two 64 KiB READ replies with message ids 1 and 2.
    for seq in 0..2u64 {
        let start = (seq * 65536) as usize;
        transport.queue_response(build_read_response_with_msg_id(
            NtStatus::SUCCESS,
            MessageId(seq + 1),
            expected[start..start + 65536].to_vec(),
        ));
    }
    transport.queue_response(build_close_response());

    let mut conn = setup_connection(&transport);
    let tree = Tree {
        tree_id: TreeId(20),
        share_name: "test".to_string(),
        server: "test-server".to_string(),
        is_dfs: false,
        encrypt_data: false,
    };

    // Record the cumulative byte count seen at every callback.
    let mut seen_totals = Vec::new();
    let data = tree
        .read_file_pipelined_with_progress(&mut conn, "progress_test.bin", |p| {
            seen_totals.push(p.bytes_transferred);
            ControlFlow::Continue(())
        })
        .await
        .unwrap();

    assert_eq!(data.len(), total_len as usize);
    assert_eq!(seen_totals.len(), 2);
    assert_eq!(seen_totals[0], 65536);
    assert_eq!(seen_totals[1], total_len);
}
#[tokio::test]
async fn read_pipelined_with_progress_cancellation() {
let mock = Arc::new(MockTransport::new());
let file_id = FileId {
persistent: 0xF3,
volatile: 0xF4,
};
let file_size = 65536u64 * 4;
let create_resp = {
let mut h = Header::new_request(Command::Create);
h.flags.set_response();
h.credits = 32;
let body = CreateResponse {
oplock_level: OplockLevel::None,
flags: 0,
create_action: CreateAction::FileOpened,
creation_time: FileTime::ZERO,
last_access_time: FileTime::ZERO,
last_write_time: FileTime::ZERO,
change_time: FileTime::ZERO,
allocation_size: 0,
end_of_file: file_size,
file_attributes: 0,
file_id,
create_contexts: vec![],
};
pack_message(&h, &body)
};
mock.queue_response(create_resp);
for i in 0..4u64 {
let chunk = vec![0x42u8; 65536];
let resp = build_read_response_with_msg_id(NtStatus::SUCCESS, MessageId(i + 1), chunk);
mock.queue_response(resp);
}
mock.queue_response(build_close_response());
let mut conn = setup_connection(&mock);
let tree = Tree {
tree_id: TreeId(20),
share_name: "test".to_string(),
server: "test-server".to_string(),
is_dfs: false,
encrypt_data: false,
};
let result = tree
.read_file_pipelined_with_progress(&mut conn, "cancel_test.bin", |_p| {
ControlFlow::Break(())
})
.await;
assert!(result.is_err());
match result.unwrap_err() {
Error::Cancelled => {} other => panic!("expected Cancelled, got {:?}", other),
}
}
/// For an empty file the progress callback still fires, reporting zero bytes
/// transferred out of a known total of zero.
#[tokio::test]
async fn read_pipelined_with_progress_empty_file() {
    let transport = Arc::new(MockTransport::new());
    let handle = FileId {
        persistent: 0xF5,
        volatile: 0xF6,
    };

    let mut create_hdr = Header::new_request(Command::Create);
    create_hdr.flags.set_response();
    create_hdr.credits = 32;
    let create_body = CreateResponse {
        oplock_level: OplockLevel::None,
        flags: 0,
        create_action: CreateAction::FileOpened,
        creation_time: FileTime::ZERO,
        last_access_time: FileTime::ZERO,
        last_write_time: FileTime::ZERO,
        change_time: FileTime::ZERO,
        allocation_size: 0,
        end_of_file: 0,
        file_attributes: 0,
        file_id: handle,
        create_contexts: vec![],
    };
    transport.queue_response(pack_message(&create_hdr, &create_body));
    transport.queue_response(build_close_response());

    let mut conn = setup_connection(&transport);
    let tree = Tree {
        tree_id: TreeId(20),
        share_name: "test".to_string(),
        server: "test-server".to_string(),
        is_dfs: false,
        encrypt_data: false,
    };

    let mut callback_ran = false;
    let data = tree
        .read_file_pipelined_with_progress(&mut conn, "empty.bin", |p| {
            callback_ran = true;
            assert_eq!(p.bytes_transferred, 0);
            assert_eq!(p.total_bytes, Some(0));
            ControlFlow::Continue(())
        })
        .await
        .unwrap();

    assert!(data.is_empty());
    assert!(callback_ran);
}
/// Writes 256 KiB and checks the byte count, the total message count
/// (seven), and that sent messages 1..=4 are 64 KiB WRITE requests at
/// consecutive offsets carrying the original bytes.
#[tokio::test]
async fn pipelined_write_four_chunks() {
    let transport = Arc::new(MockTransport::new());
    let tree_id = TreeId(20);
    let handle = FileId {
        persistent: 0xB00,
        volatile: 0xC00,
    };
    let payload = vec![0x42u8; 256 * 1024];

    transport.queue_response(build_create_response(handle, 0));
    for seq in 0u64..4 {
        let ack = build_write_response_with_msg_id(MessageId(seq + 1), 65536);
        transport.queue_response(ack);
    }
    transport.queue_response(build_flush_response());
    transport.queue_response(build_close_response());

    let mut conn = setup_connection(&transport);
    let tree = Tree {
        tree_id,
        share_name: "test".to_string(),
        server: "test-server".to_string(),
        is_dfs: false,
        encrypt_data: false,
    };

    let written = tree
        .write_file_pipelined(&mut conn, "big_write.bin", &payload)
        .await
        .unwrap();

    assert_eq!(written, 256 * 1024);
    assert_eq!(transport.sent_count(), 7);

    // Sent message 0 is skipped; messages 1..=4 are parsed as the WRITEs.
    for chunk_no in 0..4 {
        let wire = transport.sent_message(chunk_no + 1).unwrap();
        let mut reader = ReadCursor::new(&wire);
        let _ = Header::unpack(&mut reader).unwrap();
        let write = WriteRequest::unpack(&mut reader).unwrap();
        assert_eq!(write.data.len(), 65536);
        assert_eq!(write.offset, chunk_no as u64 * 65536);
        assert!(write.data.iter().all(|&b| b == 0x42));
    }
}
/// Writes 100 KiB, acknowledged as one 64 KiB write plus one 36 KiB write,
/// and checks the summed byte count and the total message count (five).
#[tokio::test]
async fn pipelined_write_last_chunk_smaller() {
    let transport = Arc::new(MockTransport::new());
    let handle = FileId {
        persistent: 0xD00,
        volatile: 0xE00,
    };
    let payload = vec![0x55u8; 100 * 1024];

    transport.queue_response(build_create_response(handle, 0));
    // Acknowledgements: full first chunk, shorter trailing chunk.
    transport.queue_response(build_write_response_with_msg_id(MessageId(1), 65536));
    transport.queue_response(build_write_response_with_msg_id(
        MessageId(2),
        36 * 1024,
    ));
    transport.queue_response(build_flush_response());
    transport.queue_response(build_close_response());

    let mut conn = setup_connection(&transport);
    let tree = Tree {
        tree_id: TreeId(20),
        share_name: "test".to_string(),
        server: "test-server".to_string(),
        is_dfs: false,
        encrypt_data: false,
    };

    let written = tree
        .write_file_pipelined(&mut conn, "partial.bin", &payload)
        .await
        .unwrap();

    assert_eq!(written, 65536 + 36 * 1024);
    assert_eq!(transport.sent_count(), 5);
}
/// Concatenates packed responses into one compound frame.
///
/// Every response except the last is zero-padded up to a multiple of 8 bytes,
/// and its padded length is written little-endian into bytes 20..24 of that
/// response (the header's next-command field). The last response is appended
/// unchanged.
///
/// Panics if a non-final response is shorter than 24 bytes after padding.
fn build_compound_response_frame(responses: &[Vec<u8>]) -> Vec<u8> {
    let mut frame = Vec::new();
    for (index, response) in responses.iter().enumerate() {
        let start = frame.len();
        frame.extend_from_slice(response);

        let has_successor = index + 1 < responses.len();
        if has_successor {
            // Round the length up to the next multiple of 8.
            let padded_len = (response.len() + 7) / 8 * 8;
            frame.resize(start + padded_len, 0);
            // Chain to the next response via the offset stored at 20..24.
            let next_command = padded_len as u32;
            frame[start + 20..start + 24].copy_from_slice(&next_command.to_le_bytes());
        }
    }
    frame
}
/// A compound CREATE + READ + CLOSE reply yields the file contents, and two
/// messages are sent across the `Tree::connect` and the read.
#[tokio::test]
async fn read_file_compound_returns_file_data() {
    let transport = Arc::new(MockTransport::new());
    let mut conn = setup_connection(&transport);
    transport.queue_response(build_tree_connect_response(TreeId(7), ShareType::Disk));
    let tree = Tree::connect(&mut conn, "share").await.unwrap();

    let handle = FileId {
        persistent: 0x42,
        volatile: 0x99,
    };
    let contents = b"Hello, compound!".to_vec();
    transport.queue_response(build_compound_response_frame(&[
        build_create_response(handle, contents.len() as u64),
        build_read_response(NtStatus::SUCCESS, contents.clone()),
        build_close_response(),
    ]));

    let data = tree
        .read_file_compound(&mut conn, "test.txt")
        .await
        .unwrap();

    assert_eq!(data, b"Hello, compound!");
    assert_eq!(transport.sent_count(), 2);
}
/// An END_OF_FILE status on the compounded READ is treated as an empty file
/// rather than as an error.
#[tokio::test]
async fn read_file_compound_handles_empty_file() {
    let transport = Arc::new(MockTransport::new());
    let mut conn = setup_connection(&transport);
    transport.queue_response(build_tree_connect_response(TreeId(7), ShareType::Disk));
    let tree = Tree::connect(&mut conn, "share").await.unwrap();

    let handle = FileId {
        persistent: 1,
        volatile: 2,
    };
    transport.queue_response(build_compound_response_frame(&[
        build_create_response(handle, 0),
        build_read_response(NtStatus::END_OF_FILE, vec![]),
        build_close_response(),
    ]));

    let data = tree
        .read_file_compound(&mut conn, "empty.txt")
        .await
        .unwrap();

    assert!(data.is_empty());
}
#[tokio::test]
async fn read_file_compound_create_failure_returns_error() {
let mock = Arc::new(MockTransport::new());
let mut conn = setup_connection(&mock);
mock.queue_response(build_tree_connect_response(TreeId(7), ShareType::Disk));
let tree = Tree::connect(&mut conn, "share").await.unwrap();
let mut create_resp_header = Header::new_request(Command::Create);
create_resp_header.flags.set_response();
create_resp_header.credits = 32;
create_resp_header.status = NtStatus::OBJECT_NAME_NOT_FOUND;
let create_resp = pack_message(
&create_resp_header,
&crate::msg::header::ErrorResponse {
error_context_count: 0,
error_data: vec![],
},
);
let mut read_resp_header = Header::new_request(Command::Read);
read_resp_header.flags.set_response();
read_resp_header.credits = 32;
read_resp_header.status = NtStatus::OBJECT_NAME_NOT_FOUND;
let read_resp = pack_message(
&read_resp_header,
&crate::msg::header::ErrorResponse {
error_context_count: 0,
error_data: vec![],
},
);
let mut close_resp_header = Header::new_request(Command::Close);
close_resp_header.flags.set_response();
close_resp_header.credits = 32;
close_resp_header.status = NtStatus::OBJECT_NAME_NOT_FOUND;
let close_resp = pack_message(
&close_resp_header,
&crate::msg::header::ErrorResponse {
error_context_count: 0,
error_data: vec![],
},
);
let frame = build_compound_response_frame(&[create_resp, read_resp, close_resp]);
mock.queue_response(frame);
let result = tree.read_file_compound(&mut conn, "nonexistent.txt").await;
assert!(result.is_err());
let err = result.unwrap_err();
assert_eq!(err.status(), Some(NtStatus::OBJECT_NAME_NOT_FOUND));
}
#[tokio::test]
async fn read_file_compound_read_failure_issues_standalone_close() {
let mock = Arc::new(MockTransport::new());
let mut conn = setup_connection(&mock);
mock.queue_response(build_tree_connect_response(TreeId(7), ShareType::Disk));
let tree = Tree::connect(&mut conn, "share").await.unwrap();
let file_id = FileId {
persistent: 0x42,
volatile: 0x99,
};
let create_resp = build_create_response(file_id, 1024);
let mut read_resp_header = Header::new_request(Command::Read);
read_resp_header.flags.set_response();
read_resp_header.credits = 32;
read_resp_header.status = NtStatus::INSUFFICIENT_RESOURCES;
let read_resp = pack_message(
&read_resp_header,
&crate::msg::header::ErrorResponse {
error_context_count: 0,
error_data: vec![],
},
);
let mut close_resp_header = Header::new_request(Command::Close);
close_resp_header.flags.set_response();
close_resp_header.credits = 32;
close_resp_header.status = NtStatus::INSUFFICIENT_RESOURCES;
let close_resp = pack_message(
&close_resp_header,
&crate::msg::header::ErrorResponse {
error_context_count: 0,
error_data: vec![],
},
);
let frame = build_compound_response_frame(&[create_resp, read_resp, close_resp]);
mock.queue_response(frame);
mock.queue_response(build_close_response());
let result = tree.read_file_compound(&mut conn, "problem.txt").await;
assert!(result.is_err());
let err = result.unwrap_err();
assert_eq!(err.status(), Some(NtStatus::INSUFFICIENT_RESOURCES));
assert_eq!(mock.sent_count(), 3);
}
/// Checks the wire layout of the compound sent by `read_file_compound`:
/// CREATE (unrelated), then READ and CLOSE (both related, both addressing
/// `FileId::SENTINEL`), chained through each header's `next_command`.
#[tokio::test]
async fn read_file_compound_sends_correct_request_structure() {
let mock = Arc::new(MockTransport::new());
let mut conn = setup_connection(&mock);
mock.queue_response(build_tree_connect_response(TreeId(7), ShareType::Disk));
let tree = Tree::connect(&mut conn, "share").await.unwrap();
let file_id = FileId {
persistent: 1,
volatile: 2,
};
let create_resp = build_create_response(file_id, 5);
let read_resp = build_read_response(NtStatus::SUCCESS, vec![1, 2, 3, 4, 5]);
let close_resp = build_close_response();
let frame = build_compound_response_frame(&[create_resp, read_resp, close_resp]);
mock.queue_response(frame);
tree.read_file_compound(&mut conn, "verify.txt")
.await
.unwrap();
// Sent message 0 is skipped; message 1 is parsed as the compound frame.
let compound = mock.sent_message(1).unwrap();
// First request: CREATE, not related, with a non-zero link to the next
// request and the tree id used when the mock answered the tree connect.
let mut cursor = ReadCursor::new(&compound);
let h1 = Header::unpack(&mut cursor).unwrap();
assert_eq!(h1.command, Command::Create);
assert!(!h1.flags.is_related());
assert!(h1.next_command > 0);
assert_eq!(h1.tree_id, Some(TreeId(7)));
// Second request starts `next_command` bytes after the first header.
let off2 = h1.next_command as usize;
let mut cursor2 = ReadCursor::new(&compound[off2..]);
let h2 = Header::unpack(&mut cursor2).unwrap();
assert_eq!(h2.command, Command::Read);
assert!(h2.flags.is_related());
assert!(h2.next_command > 0);
let read_parsed = ReadRequest::unpack(&mut cursor2).unwrap();
assert_eq!(read_parsed.file_id, FileId::SENTINEL);
// Offsets are relative to each header, so they accumulate.
let off3 = off2 + h2.next_command as usize;
let mut cursor3 = ReadCursor::new(&compound[off3..]);
let h3 = Header::unpack(&mut cursor3).unwrap();
assert_eq!(h3.command, Command::Close);
assert!(h3.flags.is_related());
// The final request terminates the chain with a zero next_command.
assert_eq!(h3.next_command, 0);
let close_parsed = CloseRequest::unpack(&mut cursor3).unwrap();
assert_eq!(close_parsed.file_id, FileId::SENTINEL);
}
/// A compound CREATE + WRITE + FLUSH + CLOSE reply yields the number of bytes
/// written, and two messages are sent across the `Tree::connect` and the write.
#[tokio::test]
async fn write_file_compound_returns_bytes_written() {
    let transport = Arc::new(MockTransport::new());
    let mut conn = setup_connection(&transport);
    transport.queue_response(build_tree_connect_response(TreeId(7), ShareType::Disk));
    let tree = Tree::connect(&mut conn, "share").await.unwrap();

    let handle = FileId {
        persistent: 0x42,
        volatile: 0x99,
    };
    let payload = b"Hello, compound write!";
    transport.queue_response(build_compound_response_frame(&[
        build_create_response(handle, 0),
        build_write_response(payload.len() as u32),
        build_flush_response(),
        build_close_response(),
    ]));

    let written = tree
        .write_file_compound(&mut conn, "test.txt", payload)
        .await
        .unwrap();

    assert_eq!(written, payload.len() as u64);
    assert_eq!(transport.sent_count(), 2);
}
/// Writing an empty buffer succeeds and reports zero bytes written.
#[tokio::test]
async fn write_file_compound_empty_file() {
    let transport = Arc::new(MockTransport::new());
    let mut conn = setup_connection(&transport);
    transport.queue_response(build_tree_connect_response(TreeId(7), ShareType::Disk));
    let tree = Tree::connect(&mut conn, "share").await.unwrap();

    let handle = FileId {
        persistent: 1,
        volatile: 2,
    };
    transport.queue_response(build_compound_response_frame(&[
        build_create_response(handle, 0),
        build_write_response(0),
        build_flush_response(),
        build_close_response(),
    ]));

    let written = tree
        .write_file_compound(&mut conn, "empty.txt", b"")
        .await
        .unwrap();

    assert_eq!(written, 0);
}
#[tokio::test]
async fn write_file_compound_create_failure_returns_error() {
let mock = Arc::new(MockTransport::new());
let mut conn = setup_connection(&mock);
mock.queue_response(build_tree_connect_response(TreeId(7), ShareType::Disk));
let tree = Tree::connect(&mut conn, "share").await.unwrap();
let mut create_h = Header::new_request(Command::Create);
create_h.flags.set_response();
create_h.credits = 32;
create_h.status = NtStatus::OBJECT_NAME_NOT_FOUND;
let create_resp = pack_message(
&create_h,
&crate::msg::header::ErrorResponse {
error_context_count: 0,
error_data: vec![],
},
);
let mut write_h = Header::new_request(Command::Write);
write_h.flags.set_response();
write_h.credits = 32;
write_h.status = NtStatus::OBJECT_NAME_NOT_FOUND;
let write_resp = pack_message(
&write_h,
&crate::msg::header::ErrorResponse {
error_context_count: 0,
error_data: vec![],
},
);
let mut flush_h = Header::new_request(Command::Flush);
flush_h.flags.set_response();
flush_h.credits = 32;
flush_h.status = NtStatus::OBJECT_NAME_NOT_FOUND;
let flush_resp = pack_message(
&flush_h,
&crate::msg::header::ErrorResponse {
error_context_count: 0,
error_data: vec![],
},
);
let mut close_h = Header::new_request(Command::Close);
close_h.flags.set_response();
close_h.credits = 32;
close_h.status = NtStatus::OBJECT_NAME_NOT_FOUND;
let close_resp = pack_message(
&close_h,
&crate::msg::header::ErrorResponse {
error_context_count: 0,
error_data: vec![],
},
);
let frame =
build_compound_response_frame(&[create_resp, write_resp, flush_resp, close_resp]);
mock.queue_response(frame);
let result = tree
.write_file_compound(&mut conn, "bad/path.txt", b"data")
.await;
assert!(result.is_err());
let err = result.unwrap_err();
assert_eq!(err.status(), Some(NtStatus::OBJECT_NAME_NOT_FOUND));
}
#[tokio::test]
async fn write_file_compound_write_failure_issues_standalone_close() {
let mock = Arc::new(MockTransport::new());
let mut conn = setup_connection(&mock);
mock.queue_response(build_tree_connect_response(TreeId(7), ShareType::Disk));
let tree = Tree::connect(&mut conn, "share").await.unwrap();
let file_id = FileId {
persistent: 0x42,
volatile: 0x99,
};
let create_resp = build_create_response(file_id, 0);
let mut write_h = Header::new_request(Command::Write);
write_h.flags.set_response();
write_h.credits = 32;
write_h.status = NtStatus::INSUFFICIENT_RESOURCES;
let write_resp = pack_message(
&write_h,
&crate::msg::header::ErrorResponse {
error_context_count: 0,
error_data: vec![],
},
);
let mut flush_h = Header::new_request(Command::Flush);
flush_h.flags.set_response();
flush_h.credits = 32;
flush_h.status = NtStatus::INSUFFICIENT_RESOURCES;
let flush_resp = pack_message(
&flush_h,
&crate::msg::header::ErrorResponse {
error_context_count: 0,
error_data: vec![],
},
);
let mut close_h = Header::new_request(Command::Close);
close_h.flags.set_response();
close_h.credits = 32;
close_h.status = NtStatus::INSUFFICIENT_RESOURCES;
let close_resp = pack_message(
&close_h,
&crate::msg::header::ErrorResponse {
error_context_count: 0,
error_data: vec![],
},
);
let frame =
build_compound_response_frame(&[create_resp, write_resp, flush_resp, close_resp]);
mock.queue_response(frame);
mock.queue_response(build_close_response());
let result = tree
.write_file_compound(&mut conn, "problem.txt", b"data")
.await;
assert!(result.is_err());
let err = result.unwrap_err();
assert_eq!(err.status(), Some(NtStatus::INSUFFICIENT_RESOURCES));
assert_eq!(mock.sent_count(), 3);
}
/// Checks the wire layout of the compound sent by `write_file_compound`:
/// CREATE (unrelated), then WRITE, FLUSH and CLOSE (all related, all
/// addressing `FileId::SENTINEL`), chained through each header's
/// `next_command`.
#[tokio::test]
async fn write_file_compound_sends_correct_request_structure() {
let mock = Arc::new(MockTransport::new());
let mut conn = setup_connection(&mock);
mock.queue_response(build_tree_connect_response(TreeId(7), ShareType::Disk));
let tree = Tree::connect(&mut conn, "share").await.unwrap();
let file_id = FileId {
persistent: 1,
volatile: 2,
};
let create_resp = build_create_response(file_id, 0);
let write_resp = build_write_response(5);
let flush_resp = build_flush_response();
let close_resp = build_close_response();
let frame =
build_compound_response_frame(&[create_resp, write_resp, flush_resp, close_resp]);
mock.queue_response(frame);
tree.write_file_compound(&mut conn, "verify.txt", &[1, 2, 3, 4, 5])
.await
.unwrap();
// Sent message 0 is skipped; message 1 is parsed as the compound frame.
let compound = mock.sent_message(1).unwrap();
// First request: CREATE, not related, with a non-zero link to the next
// request and the tree id used when the mock answered the tree connect.
let mut cursor = ReadCursor::new(&compound);
let h1 = Header::unpack(&mut cursor).unwrap();
assert_eq!(h1.command, Command::Create);
assert!(!h1.flags.is_related());
assert!(h1.next_command > 0);
assert_eq!(h1.tree_id, Some(TreeId(7)));
// Second request: WRITE carrying the caller's bytes unchanged.
let off2 = h1.next_command as usize;
let mut cursor2 = ReadCursor::new(&compound[off2..]);
let h2 = Header::unpack(&mut cursor2).unwrap();
assert_eq!(h2.command, Command::Write);
assert!(h2.flags.is_related());
assert!(h2.next_command > 0);
let write_parsed = WriteRequest::unpack(&mut cursor2).unwrap();
assert_eq!(write_parsed.file_id, FileId::SENTINEL);
assert_eq!(write_parsed.data, vec![1, 2, 3, 4, 5]);
// Offsets are relative to each header, so they accumulate.
let off3 = off2 + h2.next_command as usize;
let mut cursor3 = ReadCursor::new(&compound[off3..]);
let h3 = Header::unpack(&mut cursor3).unwrap();
assert_eq!(h3.command, Command::Flush);
assert!(h3.flags.is_related());
assert!(h3.next_command > 0);
let flush_parsed = FlushRequest::unpack(&mut cursor3).unwrap();
assert_eq!(flush_parsed.file_id, FileId::SENTINEL);
let off4 = off3 + h3.next_command as usize;
let mut cursor4 = ReadCursor::new(&compound[off4..]);
let h4 = Header::unpack(&mut cursor4).unwrap();
assert_eq!(h4.command, Command::Close);
assert!(h4.flags.is_related());
// The final request terminates the chain with a zero next_command.
assert_eq!(h4.next_command, 0);
let close_parsed = CloseRequest::unpack(&mut cursor4).unwrap();
assert_eq!(close_parsed.file_id, FileId::SENTINEL);
}
/// `Tree::stat` must succeed when the basic-info query inside the compound
/// reply carries `NtStatus::BUFFER_OVERFLOW`, and must still surface the
/// values from both info responses.
#[tokio::test]
async fn stat_accepts_buffer_overflow_as_partial_data() {
    let transport = Arc::new(MockTransport::new());
    let handle = FileId {
        persistent: 0xCC,
        volatile: 0xDD,
    };

    // CREATE + basic info (flagged BUFFER_OVERFLOW) + standard info + CLOSE,
    // all delivered in a single compound frame.
    let responses = [
        build_create_response(handle, 0),
        build_query_info_response_with_status(
            NtStatus::BUFFER_OVERFLOW,
            build_file_basic_info(
                // First timestamp is the creation time asserted below.
                132_000_000_000_000_000,
                132_100_000_000_000_000,
                133_000_000_000_000_000,
                133_000_000_000_000_000,
                0x20,
            ),
        ),
        // The second argument (1024) is the size asserted below.
        build_query_info_response(build_file_standard_info(4096, 1024, 1, false, false)),
        build_close_response(),
    ];
    transport.queue_response(build_compound_response_frame(&responses));

    let mut conn = setup_connection(&transport);
    let tree = Tree {
        tree_id: TreeId(10),
        share_name: String::from("test"),
        server: String::from("test-server"),
        is_dfs: false,
        encrypt_data: false,
    };

    let meta = tree.stat(&mut conn, "partial.txt").await.unwrap();

    assert_eq!(meta.size, 1024);
    assert!(!meta.is_directory);
    assert_eq!(meta.created, FileTime(132_000_000_000_000_000));
    // The whole stat travels as one message.
    assert_eq!(transport.sent_count(), 1);
}
/// Streams three chunks through `Tree::write_file_streamed` and verifies the
/// byte total, the number of messages sent, the offset and payload of each
/// WRITE, and that the final message is a CLOSE.
#[tokio::test]
async fn write_file_streamed_basic() {
    let transport = Arc::new(MockTransport::new());
    let handle = FileId {
        persistent: 0xAA,
        volatile: 0xBB,
    };

    let first = vec![0x01; 100];
    let second = vec![0x02; 200];
    let third = vec![0x03; 150];
    // Chunk source handed to the writer; yields `None` once exhausted.
    let mut source = vec![Ok(first.clone()), Ok(second.clone()), Ok(third.clone())].into_iter();

    // One reply per expected request: CREATE, three WRITEs, FLUSH, CLOSE.
    for reply in [
        build_create_response(handle, 0),
        build_write_response(100),
        build_write_response(200),
        build_write_response(150),
        build_flush_response(),
        build_close_response(),
    ] {
        transport.queue_response(reply);
    }

    let mut conn = setup_connection(&transport);
    let tree = Tree {
        tree_id: TreeId(30),
        share_name: String::from("test"),
        server: String::from("test-server"),
        is_dfs: false,
        encrypt_data: false,
    };

    let mut next_chunk =
        move || -> Option<std::result::Result<Vec<u8>, std::io::Error>> { source.next() };
    let total = tree
        .write_file_streamed(&mut conn, "streamed.bin", &mut next_chunk)
        .await
        .unwrap();
    assert_eq!(total, 450);
    assert_eq!(transport.sent_count(), 6);

    // Messages 1..=3 are the WRITEs; each offset is the running total of the
    // preceding chunk lengths.
    for (index, expected_offset, expected_data) in
        [(1, 0, &first), (2, 100, &second), (3, 300, &third)]
    {
        let raw = transport.sent_message(index).unwrap();
        let mut cur = ReadCursor::new(&raw);
        let _ = Header::unpack(&mut cur).unwrap();
        let request = WriteRequest::unpack(&mut cur).unwrap();
        assert_eq!(request.offset, expected_offset);
        assert_eq!(request.data, *expected_data);
    }

    // The last message closes the handle.
    let raw_close = transport.sent_message(5).unwrap();
    let mut close_cur = ReadCursor::new(&raw_close);
    let close_hdr = Header::unpack(&mut close_cur).unwrap();
    assert_eq!(close_hdr.command, Command::Close);
}
/// A chunk source that is immediately exhausted must still produce a file:
/// `Tree::write_file_streamed` reports zero bytes and sends exactly CREATE,
/// FLUSH and CLOSE, in that order.
#[tokio::test]
async fn write_file_streamed_empty() {
    let transport = Arc::new(MockTransport::new());
    let handle = FileId {
        persistent: 0xCC,
        volatile: 0xDD,
    };

    // No WRITE replies are queued because no WRITE is expected.
    for reply in [
        build_create_response(handle, 0),
        build_flush_response(),
        build_close_response(),
    ] {
        transport.queue_response(reply);
    }

    let mut conn = setup_connection(&transport);
    let tree = Tree {
        tree_id: TreeId(31),
        share_name: String::from("test"),
        server: String::from("test-server"),
        is_dfs: false,
        encrypt_data: false,
    };

    let mut next_chunk = || -> Option<std::result::Result<Vec<u8>, std::io::Error>> { None };
    let total = tree
        .write_file_streamed(&mut conn, "empty_stream.bin", &mut next_chunk)
        .await
        .unwrap();
    assert_eq!(total, 0);
    assert_eq!(transport.sent_count(), 3);

    // Verify the command carried by each sent message, in send order.
    for (index, expected_command) in [
        (0, Command::Create),
        (1, Command::Flush),
        (2, Command::Close),
    ] {
        let raw = transport.sent_message(index).unwrap();
        let mut cur = ReadCursor::new(&raw);
        let header = Header::unpack(&mut cur).unwrap();
        assert_eq!(header.command, expected_command);
    }
}
/// When the chunk callback fails after one successful chunk, the error must
/// reach the caller and exactly three messages must have been sent, the last
/// of which is a CLOSE.
#[tokio::test]
async fn write_file_streamed_callback_error() {
    let transport = Arc::new(MockTransport::new());
    let handle = FileId {
        persistent: 0xEE,
        volatile: 0xFF,
    };

    // Replies for CREATE, the single WRITE, and CLOSE; no FLUSH reply is queued.
    for reply in [
        build_create_response(handle, 0),
        build_write_response(64),
        build_close_response(),
    ] {
        transport.queue_response(reply);
    }

    let mut conn = setup_connection(&transport);
    let tree = Tree {
        tree_id: TreeId(32),
        share_name: String::from("test"),
        server: String::from("test-server"),
        is_dfs: false,
        encrypt_data: false,
    };

    // Scripted source: one good chunk, then a failure, then end-of-stream.
    let mut script = vec![
        Ok(vec![0x42; 64]),
        Err(std::io::Error::new(
            std::io::ErrorKind::BrokenPipe,
            "source stream broke",
        )),
    ]
    .into_iter();
    let mut next_chunk =
        move || -> Option<std::result::Result<Vec<u8>, std::io::Error>> { script.next() };

    let outcome = tree
        .write_file_streamed(&mut conn, "error_stream.bin", &mut next_chunk)
        .await;
    assert!(outcome.is_err(), "expected error from callback to propagate");
    assert_eq!(transport.sent_count(), 3);

    // The final message must be the CLOSE for the opened handle.
    let raw_last = transport.sent_message(2).unwrap();
    let mut last_cur = ReadCursor::new(&raw_last);
    let last_hdr = Header::unpack(&mut last_cur).unwrap();
    assert_eq!(last_hdr.command, Command::Close);
}
/// A failure raised by the chunk callback must not be reported with
/// `ErrorKind::ConnectionLost`.
#[tokio::test]
async fn write_file_streamed_callback_error_is_not_connection_lost() {
    let transport = Arc::new(MockTransport::new());
    let handle = FileId {
        persistent: 0x11,
        volatile: 0x22,
    };

    // Replies for CREATE, the one successful WRITE, and CLOSE.
    for reply in [
        build_create_response(handle, 0),
        build_write_response(64),
        build_close_response(),
    ] {
        transport.queue_response(reply);
    }

    let mut conn = setup_connection(&transport);
    let tree = Tree {
        tree_id: TreeId(40),
        share_name: String::from("test"),
        server: String::from("test-server"),
        is_dfs: false,
        encrypt_data: false,
    };

    // Scripted source: one good chunk, then a simulated user cancellation.
    let mut script = vec![
        Ok(vec![0x42; 64]),
        Err(std::io::Error::new(
            std::io::ErrorKind::Interrupted,
            "user cancelled",
        )),
    ]
    .into_iter();
    let mut next_chunk =
        move || -> Option<std::result::Result<Vec<u8>, std::io::Error>> { script.next() };

    let failure = tree
        .write_file_streamed(&mut conn, "cancel_test.bin", &mut next_chunk)
        .await
        .unwrap_err();

    assert_ne!(
        failure.kind(),
        crate::ErrorKind::ConnectionLost,
        "callback error should not be classified as ConnectionLost; the connection is still healthy"
    );
}
/// After a streamed write is aborted by a callback error, the same connection
/// must still serve a subsequent `read_file_compound` call.
#[tokio::test]
async fn write_file_streamed_callback_error_connection_still_usable() {
    let transport = Arc::new(MockTransport::new());
    let upload_handle = FileId {
        persistent: 0x33,
        volatile: 0x44,
    };
    let download_handle = FileId {
        persistent: 0x55,
        volatile: 0x66,
    };

    // Replies for the aborted upload: CREATE, two WRITEs, CLOSE.
    for reply in [
        build_create_response(upload_handle, 0),
        build_write_response(100),
        build_write_response(200),
        build_close_response(),
    ] {
        transport.queue_response(reply);
    }
    // Reply for the follow-up compound read on the same connection.
    let read_data = b"hello from the server";
    transport.queue_response(build_compound_read_response(
        download_handle,
        read_data.to_vec(),
    ));

    let mut conn = setup_connection(&transport);
    let tree = Tree {
        tree_id: TreeId(41),
        share_name: String::from("test"),
        server: String::from("test-server"),
        is_dfs: false,
        encrypt_data: false,
    };

    // Scripted source: two good chunks, then a simulated cancellation.
    let mut script = vec![
        Ok(vec![0xAA; 100]),
        Ok(vec![0xBB; 200]),
        Err(std::io::Error::new(
            std::io::ErrorKind::Interrupted,
            "cancelled by user",
        )),
    ]
    .into_iter();
    let mut next_chunk =
        move || -> Option<std::result::Result<Vec<u8>, std::io::Error>> { script.next() };

    let upload_outcome = tree
        .write_file_streamed(&mut conn, "partial.bin", &mut next_chunk)
        .await;
    assert!(
        upload_outcome.is_err(),
        "write should fail due to callback error"
    );

    // The connection must still be able to carry a fresh operation.
    let fetched = tree
        .read_file_compound(&mut conn, "other.txt")
        .await
        .unwrap();
    assert_eq!(fetched, read_data);
}
/// Builds one compound response frame holding CREATE, READ and CLOSE
/// responses, in that order.
///
/// * `file_id` - handle reported by the CREATE response.
/// * `data` - payload returned by the READ response; its length is also
///   reported as the CREATE response's `end_of_file`.
///
/// The CREATE and READ messages are zero-padded to a multiple of 8 bytes and
/// their `NextCommand` header field is patched to that padded length. The
/// CLOSE message is appended exactly as `build_close_response` produced it.
fn build_compound_read_response(file_id: FileId, data: Vec<u8>) -> Vec<u8> {
    use crate::msg::read::ReadResponse;

    // Pads `message` to the next multiple of 8 bytes and stores the padded
    // length in its `NextCommand` field, so that another message can follow it
    // within the same frame.
    fn chain(mut message: Vec<u8>) -> Vec<u8> {
        // `NextCommand` is a little-endian u32 at byte offset 20 of the
        // 64-byte SMB2 header.
        const NEXT_COMMAND_OFFSET: usize = 20;

        let padded_len = (message.len() + 7) & !7;
        message.resize(padded_len, 0);
        message[NEXT_COMMAND_OFFSET..NEXT_COMMAND_OFFSET + 4]
            .copy_from_slice(&(padded_len as u32).to_le_bytes());
        message
    }

    let mut create_hdr = Header::new_request(Command::Create);
    create_hdr.flags.set_response();
    create_hdr.credits = 32;
    let create_body = CreateResponse {
        oplock_level: OplockLevel::None,
        flags: 0u8,
        create_action: CreateAction::FileOpened,
        creation_time: crate::pack::FileTime(0),
        last_access_time: crate::pack::FileTime(0),
        last_write_time: crate::pack::FileTime(0),
        change_time: crate::pack::FileTime(0),
        allocation_size: 0,
        end_of_file: data.len() as u64,
        file_attributes: 0x80,
        file_id,
        create_contexts: vec![],
    };
    let create_bytes = pack_message(&create_hdr, &create_body);

    let mut read_hdr = Header::new_request(Command::Read);
    read_hdr.flags.set_response();
    read_hdr.credits = 32;
    let read_body = ReadResponse {
        data_offset: 0x50,
        // `data` is not needed after this point, so it is moved, not cloned.
        data,
        data_remaining: 0,
        flags: 0,
    };
    let read_bytes = pack_message(&read_hdr, &read_body);

    // The padded CREATE buffer becomes the start of the frame; READ and CLOSE
    // are appended behind it.
    let mut frame = chain(create_bytes);
    frame.extend_from_slice(&chain(read_bytes));
    frame.extend_from_slice(&build_close_response());
    frame
}
/// Downloads a small file through `Tree::download` and checks the reported
/// size, the reassembled bytes, the number of messages sent, and that every
/// queued reply was consumed.
#[tokio::test]
async fn tree_download_streams_small_file() {
    let transport = Arc::new(MockTransport::new());
    let handle = FileId {
        persistent: 0xA1,
        volatile: 0xB2,
    };
    let payload = b"streaming hello from Tree::download".to_vec();
    let expected_size = payload.len() as u64;

    // Replies for CREATE, a single READ carrying the whole payload, and CLOSE.
    transport.queue_response(build_create_response(handle, expected_size));
    transport.queue_response(build_read_response(NtStatus::SUCCESS, payload.clone()));
    transport.queue_response(build_close_response());

    let mut conn = setup_connection(&transport);
    let tree = Tree {
        tree_id: TreeId(11),
        share_name: String::from("share"),
        server: String::from("test-server"),
        is_dfs: false,
        encrypt_data: false,
    };

    let mut download = tree
        .download(&mut conn, "hello.txt")
        .await
        .expect("download");
    assert_eq!(download.size(), expected_size);

    // Drain the stream until it signals completion with `None`.
    let mut collected = Vec::new();
    while let Some(chunk) = download.next_chunk().await {
        collected.extend_from_slice(&chunk.expect("chunk"));
    }

    assert_eq!(collected, payload);
    assert_eq!(transport.sent_count(), 3);
    transport.assert_fully_consumed();
}
/// When the CREATE is answered with `OBJECT_NAME_NOT_FOUND`, `Tree::download`
/// must fail with an error exposing that status.
#[tokio::test]
async fn tree_download_create_failure_returns_protocol_error() {
    let transport = Arc::new(MockTransport::new());

    // Hand-build a CREATE reply whose header carries the failure status and
    // whose body is an empty error response.
    let mut failed_hdr = Header::new_request(Command::Create);
    failed_hdr.flags.set_response();
    failed_hdr.credits = 32;
    failed_hdr.status = NtStatus::OBJECT_NAME_NOT_FOUND;
    let failed_body = crate::msg::header::ErrorResponse {
        error_context_count: 0,
        error_data: vec![],
    };
    transport.queue_response(pack_message(&failed_hdr, &failed_body));

    let mut conn = setup_connection(&transport);
    let tree = Tree {
        tree_id: TreeId(12),
        share_name: String::from("share"),
        server: String::from("test-server"),
        is_dfs: false,
        encrypt_data: false,
    };

    let failure = tree
        .download(&mut conn, "missing.txt")
        .await
        .err()
        .expect("expected error");
    assert_eq!(failure.status(), Some(NtStatus::OBJECT_NAME_NOT_FOUND));
}
/// Dropping a download after reading only the first of several chunks must
/// not panic.
#[tokio::test]
async fn tree_download_drop_mid_stream_does_not_panic() {
    // Size of the one chunk the mock serves; the file claims three of them.
    const CHUNK_LEN: usize = 65536;

    let transport = Arc::new(MockTransport::new());
    let handle = FileId {
        persistent: 0xC3,
        volatile: 0xD4,
    };
    let advertised_len = 3 * CHUNK_LEN;

    // Only CREATE and the first READ get replies; nothing further is queued.
    transport.queue_response(build_create_response(handle, advertised_len as u64));
    transport.queue_response(build_read_response(
        NtStatus::SUCCESS,
        vec![0xAB; CHUNK_LEN],
    ));

    let mut conn = setup_connection(&transport);
    let tree = Tree {
        tree_id: TreeId(13),
        share_name: String::from("share"),
        server: String::from("test-server"),
        is_dfs: false,
        encrypt_data: false,
    };

    let mut download = tree.download(&mut conn, "big.bin").await.expect("download");
    let head = download
        .next_chunk()
        .await
        .expect("first chunk exists")
        .expect("first chunk ok");
    assert_eq!(head.len(), CHUNK_LEN);

    // Abandon the transfer with data still outstanding.
    drop(download);
}
#[tokio::test(flavor = "multi_thread")]
async fn tree_download_concurrent_on_cloned_connections() {
use std::time::{Duration, Instant};
let mock = Arc::new(MockTransport::new());
mock.enable_auto_rewrite_msg_id();
let params = crate::client::connection::NegotiatedParams {
dialect: crate::types::Dialect::Smb2_0_2,
max_read_size: 65536,
max_write_size: 65536,
max_transact_size: 65536,
server_guid: crate::pack::Guid::ZERO,
signing_required: false,
capabilities: crate::types::flags::Capabilities::default(),
gmac_negotiated: false,
cipher: None,
compression_supported: false,
};
let mut conn_primary = crate::client::connection::Connection::from_transport(
Box::new(mock.clone()),
Box::new(mock.clone()),
"test-server",
);
conn_primary.set_test_params(params);
conn_primary.set_session_id(crate::types::SessionId(0x1234));
let mut conn_secondary = conn_primary.clone();
let tree = Arc::new(Tree {
tree_id: TreeId(14),
share_name: "share".to_string(),
server: "test-server".to_string(),
is_dfs: false,
encrypt_data: false,
});
let payload = b"shared-body-for-both-readers".to_vec();
let file_id_1 = FileId {
persistent: 0x0A,
volatile: 0x1A,
};
let file_id_2 = FileId {
persistent: 0x0B,
volatile: 0x1B,
};
let tree_a = Arc::clone(&tree);
let payload_a = payload.clone();
let handle_a = tokio::spawn(async move {
let mut dl = tree_a
.download(&mut conn_primary, "same.txt")
.await
.expect("download a");
let mut buf = Vec::new();
while let Some(c) = dl.next_chunk().await {
buf.extend_from_slice(&c.expect("chunk a"));
}
assert_eq!(buf, payload_a);
});
let tree_b = Arc::clone(&tree);
let payload_b = payload.clone();
let handle_b = tokio::spawn(async move {
let mut dl = tree_b
.download(&mut conn_secondary, "same.txt")
.await
.expect("download b");
let mut buf = Vec::new();
while let Some(c) = dl.next_chunk().await {
buf.extend_from_slice(&c.expect("chunk b"));
}
assert_eq!(buf, payload_b);
});
let deadline = Instant::now() + Duration::from_secs(5);
while mock.sent_count() < 2 {
if Instant::now() > deadline {
panic!("CREATE sends did not land: {}", mock.sent_count());
}
tokio::time::sleep(Duration::from_millis(5)).await;
}
mock.queue_response(build_create_response(file_id_1, payload.len() as u64));
mock.queue_response(build_create_response(file_id_2, payload.len() as u64));
let deadline = Instant::now() + Duration::from_secs(5);
while mock.sent_count() < 4 {
if Instant::now() > deadline {
panic!("READ sends did not land: {}", mock.sent_count());
}
tokio::time::sleep(Duration::from_millis(5)).await;
}
mock.queue_response(build_read_response(NtStatus::SUCCESS, payload.clone()));
mock.queue_response(build_read_response(NtStatus::SUCCESS, payload.clone()));
let deadline = Instant::now() + Duration::from_secs(5);
while mock.sent_count() < 6 {
if Instant::now() > deadline {
panic!("CLOSE sends did not land: {}", mock.sent_count());
}
tokio::time::sleep(Duration::from_millis(5)).await;
}
mock.queue_response(build_close_response());
mock.queue_response(build_close_response());
handle_a.await.expect("task a panicked");
handle_b.await.expect("task b panicked");
assert_eq!(mock.sent_count(), 6); }
}