use std::ops::ControlFlow;
use std::sync::Arc;
use log::debug;
use crate::client::connection::Connection;
use crate::client::tree::Tree;
use crate::error::Result;
use crate::msg::read::{ReadRequest, ReadResponse, SMB2_CHANNEL_NONE};
use crate::msg::write::{WriteRequest, WriteResponse};
use crate::pack::{ReadCursor, Unpack};
use crate::types::status::NtStatus;
use crate::types::{Command, FileId};
use crate::Error;
/// Upper bound on the number of WRITE requests a `FileWriter` keeps in flight
/// at once; once reached, a response is drained before another is launched.
const MAX_PIPELINE_WINDOW: usize = 32;
/// Point-in-time snapshot of how far a transfer has advanced.
#[derive(Debug, Clone, Copy)]
pub struct Progress {
    /// Number of bytes moved so far.
    pub bytes_transferred: u64,
    /// Expected size of the whole transfer, or `None` when it is not known.
    pub total_bytes: Option<u64>,
}

impl Progress {
    /// Completion as a percentage, i.e. [`Progress::fraction`] scaled by 100.
    #[must_use]
    pub fn percent(&self) -> f64 {
        100.0 * self.fraction()
    }

    /// Completion as the ratio of transferred bytes to total bytes.
    ///
    /// Yields `0.0` when the total is unknown and `1.0` when the total is
    /// zero (an empty transfer counts as complete). The ratio is not clamped,
    /// so it exceeds `1.0` if more bytes were transferred than expected.
    #[must_use]
    pub fn fraction(&self) -> f64 {
        let Some(total) = self.total_bytes else {
            return 0.0;
        };
        if total == 0 {
            return 1.0;
        }
        self.bytes_transferred as f64 / total as f64
    }
}
/// Chunk-by-chunk reader over an already-open file handle.
///
/// Borrows the tree and the connection for its whole lifetime; data is pulled
/// with `next_chunk` or gathered in one go with `collect`.
pub struct FileDownload<'a> {
/// Tree the handle belongs to; supplies the tree id and closes the handle.
tree: &'a Tree,
/// Connection the READ requests are executed on.
conn: &'a mut Connection,
/// Handle of the file being read.
file_id: FileId,
/// Number of bytes the download expects to read in total.
file_size: u64,
/// Bytes received so far; also used as the offset of the next READ.
bytes_received: u64,
/// Largest length requested by a single READ.
chunk_size: u32,
/// Set once the download has finished or failed; no further requests are sent.
done: bool,
}
impl<'a> FileDownload<'a> {
pub fn new(
tree: &'a Tree,
conn: &'a mut Connection,
file_id: FileId,
file_size: u64,
chunk_size: u32,
) -> Self {
Self {
tree,
conn,
file_id,
file_size,
bytes_received: 0,
chunk_size,
done: false,
}
}
#[must_use]
pub fn size(&self) -> u64 {
self.file_size
}
#[must_use]
pub fn bytes_received(&self) -> u64 {
self.bytes_received
}
#[must_use]
pub fn progress(&self) -> Progress {
Progress {
bytes_transferred: self.bytes_received,
total_bytes: Some(self.file_size),
}
}
pub async fn next_chunk(&mut self) -> Option<Result<Vec<u8>>> {
if self.done {
return None;
}
let remaining = self.file_size.saturating_sub(self.bytes_received);
if remaining == 0 {
let close_result = self.close().await;
if let Err(e) = close_result {
return Some(Err(e));
}
return None;
}
let this_chunk = remaining.min(self.chunk_size as u64) as u32;
let req = ReadRequest {
padding: 0x50,
flags: 0,
length: this_chunk,
offset: self.bytes_received,
file_id: self.file_id,
minimum_count: 0,
channel: SMB2_CHANNEL_NONE,
remaining_bytes: 0,
read_channel_info: vec![],
};
let credit_charge = (this_chunk as u64).div_ceil(65536).max(1) as u16;
let exec_result = self
.conn
.execute_with_credits(
Command::Read,
&req,
Some(self.tree.tree_id),
crate::types::CreditCharge(credit_charge),
)
.await;
match exec_result {
Err(e) => {
self.done = true;
Some(Err(e))
}
Ok(frame) => {
if frame.header.status == NtStatus::END_OF_FILE {
let _ = self.close().await;
return None;
}
if frame.header.status != NtStatus::SUCCESS {
self.done = true;
return Some(Err(Error::Protocol {
status: frame.header.status,
command: Command::Read,
}));
}
let mut cursor = ReadCursor::new(&frame.body);
match ReadResponse::unpack(&mut cursor) {
Err(e) => {
self.done = true;
Some(Err(e))
}
Ok(resp) => {
if resp.data.is_empty() {
let _ = self.close().await;
return None;
}
self.bytes_received += resp.data.len() as u64;
if self.bytes_received >= self.file_size {
if let Err(e) = self.close().await {
return Some(Err(e));
}
}
Some(Ok(resp.data))
}
}
}
}
}
pub async fn collect_with_progress<F>(mut self, mut on_progress: F) -> Result<Vec<u8>>
where
F: FnMut(Progress) -> ControlFlow<()>,
{
let mut data = Vec::with_capacity(self.file_size as usize);
while let Some(result) = self.next_chunk().await {
let chunk = result?;
data.extend_from_slice(&chunk);
if let ControlFlow::Break(()) = on_progress(self.progress()) {
let _ = self.close().await;
return Err(Error::Cancelled);
}
}
Ok(data)
}
pub async fn collect(mut self) -> Result<Vec<u8>> {
let mut data = Vec::with_capacity(self.file_size as usize);
while let Some(result) = self.next_chunk().await {
let chunk = result?;
data.extend_from_slice(&chunk);
}
Ok(data)
}
async fn close(&mut self) -> Result<()> {
if self.done {
return Ok(());
}
self.done = true;
self.tree.close_handle(self.conn, self.file_id).await
}
}
impl Drop for FileDownload<'_> {
    /// Emits a debug log when the download is dropped before it finished;
    /// nothing is sent from here, so the handle may stay open.
    fn drop(&mut self) {
        if self.done {
            return;
        }
        let (received, expected) = (self.bytes_received, self.file_size);
        debug!(
            "stream: FileDownload dropped before completion, file handle may leak \
             (bytes_received={}/{})",
            received, expected
        );
    }
}
/// Chunk-by-chunk writer of an in-memory buffer to an already-open file handle.
///
/// Each call to `write_next_chunk` sends one WRITE and waits for its response.
pub struct FileUpload<'a> {
/// Tree the handle belongs to; supplies the tree id and flushes/closes the handle.
tree: &'a Tree,
/// Connection the WRITE requests are executed on.
conn: &'a mut Connection,
/// Handle of the file being written.
file_id: FileId,
/// The complete payload to upload.
data: &'a [u8],
/// Size of the upload as reported by `progress`.
total_bytes: u64,
/// Bytes confirmed by WRITE responses; also the offset of the next WRITE.
bytes_written: u64,
/// Largest payload sent in a single WRITE.
chunk_size: u32,
/// Set once the upload has finished or failed; no further requests are sent.
done: bool,
}
impl<'a> FileUpload<'a> {
    /// Creates an upload of `data` to an already-open file handle, sending at
    /// most `chunk_size` bytes per WRITE.
    pub(crate) fn new(
        tree: &'a Tree,
        conn: &'a mut Connection,
        file_id: FileId,
        data: &'a [u8],
        chunk_size: u32,
    ) -> Self {
        Self {
            tree,
            conn,
            file_id,
            data,
            total_bytes: data.len() as u64,
            bytes_written: 0,
            chunk_size,
            done: false,
        }
    }

    /// Creates an upload that is already complete.
    ///
    /// It holds no data and a sentinel file id, reports `total_bytes` as
    /// fully written, and never sends a request.
    pub(crate) fn new_done(tree: &'a Tree, conn: &'a mut Connection, total_bytes: u64) -> Self {
        Self {
            tree,
            conn,
            file_id: FileId::SENTINEL,
            data: &[],
            total_bytes,
            bytes_written: total_bytes,
            chunk_size: 0,
            done: true,
        }
    }

    /// Total number of bytes this upload covers.
    #[must_use]
    pub fn total_bytes(&self) -> u64 {
        self.total_bytes
    }

    /// Number of bytes confirmed written so far.
    #[must_use]
    pub fn bytes_written(&self) -> u64 {
        self.bytes_written
    }

    /// Snapshot of the transfer, using the payload size as the total.
    #[must_use]
    pub fn progress(&self) -> Progress {
        Progress {
            bytes_transferred: self.bytes_written,
            total_bytes: Some(self.total_bytes),
        }
    }

    /// Sends the next chunk and waits for its response.
    ///
    /// Returns `Ok(true)` while more data remains and `Ok(false)` once the
    /// upload is complete (the handle has then been flushed and closed) or
    /// was already finished. The next chunk starts at the number of bytes
    /// confirmed so far, so a short write is continued on the next call.
    ///
    /// # Errors
    ///
    /// Fails when the request fails, the server answers with a non-success
    /// status, the response cannot be parsed, or the final flush/close fails.
    /// After any error the upload is finished and later calls return
    /// `Ok(false)`.
    pub async fn write_next_chunk(&mut self) -> Result<bool> {
        if self.done {
            return Ok(false);
        }

        let offset = self.bytes_written as usize;
        if offset >= self.data.len() {
            self.flush_and_close().await?;
            return Ok(false);
        }

        let remaining = self.data.len() - offset;
        // A zero chunk size would never advance and would loop forever for
        // callers that poll until `false`, so always send at least one byte.
        let this_chunk = remaining.min((self.chunk_size as usize).max(1));
        let chunk = &self.data[offset..offset + this_chunk];
        let write_req = WriteRequest {
            data_offset: 0x70,
            offset: offset as u64,
            file_id: self.file_id,
            channel: 0,
            remaining_bytes: 0,
            write_channel_info_offset: 0,
            write_channel_info_length: 0,
            flags: 0,
            data: chunk.to_vec(),
        };
        // One credit per started 65536-byte unit, never fewer than one.
        let credit_charge = (this_chunk as u64).div_ceil(65536).max(1) as u16;
        let exec_result = self
            .conn
            .execute_with_credits(
                Command::Write,
                &write_req,
                Some(self.tree.tree_id),
                crate::types::CreditCharge(credit_charge),
            )
            .await;

        let frame = match exec_result {
            Ok(frame) => frame,
            Err(e) => {
                // The request itself failed, so no close is attempted here.
                self.done = true;
                return Err(e);
            }
        };

        if frame.header.status != NtStatus::SUCCESS {
            self.done = true;
            let _ = self.tree.close_handle(self.conn, self.file_id).await;
            return Err(Error::Protocol {
                status: frame.header.status,
                command: Command::Write,
            });
        }

        let mut cursor = ReadCursor::new(&frame.body);
        let resp = match WriteResponse::unpack(&mut cursor) {
            Ok(resp) => resp,
            Err(e) => {
                // Without a parsed byte count the upload cannot continue
                // safely: finish it and release the handle (best effort)
                // rather than letting the next call re-send the same range.
                self.done = true;
                let _ = self.tree.close_handle(self.conn, self.file_id).await;
                return Err(e);
            }
        };

        self.bytes_written += resp.count as u64;
        if self.bytes_written >= self.total_bytes {
            self.flush_and_close().await?;
            return Ok(false);
        }
        Ok(true)
    }

    /// Marks the upload finished, then flushes and closes the handle.
    ///
    /// The close is attempted even when the flush fails, so a flush error
    /// does not leak the handle. The flush error takes precedence in the
    /// returned result. Does nothing when the upload is already finished.
    async fn flush_and_close(&mut self) -> Result<()> {
        if self.done {
            return Ok(());
        }
        self.done = true;
        let flushed = self.tree.flush_handle(self.conn, self.file_id).await;
        let closed = self.tree.close_handle(self.conn, self.file_id).await;
        flushed?;
        closed
    }
}
impl Drop for FileUpload<'_> {
    /// Emits a debug log when the upload is dropped before it finished;
    /// nothing is sent from here, so the handle may stay open.
    fn drop(&mut self) {
        if self.done {
            return;
        }
        let (written, expected) = (self.bytes_written, self.total_bytes);
        debug!(
            "stream: FileUpload dropped before completion, file handle may leak \
             (bytes_written={}/{})",
            written, expected
        );
    }
}
/// Boxed, pinned, `Send` future that resolves to the response frame of one
/// request; this is the element type of `FileWriter`'s in-flight queue.
type BoxedWriteFut = std::pin::Pin<
Box<dyn std::future::Future<Output = Result<crate::client::connection::Frame>> + Send>,
>;
/// Pipelined writer for a file whose total size is not known up front.
///
/// Caller data is buffered, cut into pieces of at most `max_write_size`
/// bytes, and sent as WRITE requests that may be outstanding concurrently.
pub struct FileWriter {
/// Tree the handle belongs to; supplies the tree id and flushes/closes the handle.
tree: Arc<Tree>,
/// Owned connection; cloned for every WRITE that is launched.
conn: Connection,
/// Handle of the file being written.
file_id: FileId,
/// Largest payload placed in a single WRITE.
max_write_size: u32,
/// File offset the next launched WRITE will use.
offset: u64,
/// WRITE requests that were launched but whose responses are not yet processed.
in_flight: futures_util::stream::FuturesUnordered<BoxedWriteFut>,
/// Sum of the byte counts reported by processed WRITE responses.
total_written: u64,
/// Caller data that has not been fully cut into wire chunks yet.
pending_data: Vec<u8>,
/// Index into `pending_data` of the first byte not yet handed out as a wire chunk.
pending_offset: usize,
/// A wire chunk that could not be launched yet and is waiting to be retried.
stashed_chunk: Option<Vec<u8>>,
/// Set after `finish`, `abort`, or a WRITE rejected by the server; silences the drop log.
done: bool,
}
/// Opens `path` on `tree` for writing and wraps the handle in a [`FileWriter`].
///
/// The path is first passed through `Tree::format_path`. The writer's
/// maximum write size is taken from the connection parameters, falling back
/// to 65536 when none are available.
///
/// # Errors
///
/// Returns the error from opening the file.
pub async fn open_file_writer(
    tree: Arc<Tree>,
    mut conn: Connection,
    path: &str,
) -> Result<FileWriter> {
    let wire_path = tree.format_path(path);
    debug!("stream: open_file_writer path={}", wire_path);

    let handle = tree.open_file_for_write(&mut conn, &wire_path).await?;
    let write_limit = conn
        .params()
        .map_or(65536, |negotiated| negotiated.max_write_size);

    let writer = FileWriter::new(tree, conn, handle, write_limit);
    Ok(writer)
}
impl FileWriter {
pub(crate) fn new(
tree: Arc<Tree>,
conn: Connection,
file_id: FileId,
max_write_size: u32,
) -> Self {
Self {
tree,
conn,
file_id,
max_write_size,
offset: 0,
in_flight: futures_util::stream::FuturesUnordered::new(),
total_written: 0,
pending_data: Vec::new(),
pending_offset: 0,
stashed_chunk: None,
done: false,
}
}
pub async fn write_chunk(&mut self, data: &[u8]) -> Result<()> {
if data.is_empty() {
return Ok(());
}
if self.pending_offset < self.pending_data.len() {
let leftover = self.pending_data[self.pending_offset..].to_vec();
self.pending_data = leftover;
self.pending_offset = 0;
self.pending_data.extend_from_slice(data);
} else {
self.pending_data = data.to_vec();
self.pending_offset = 0;
}
self.flush_stash().await?;
while let Some(wire_chunk) = self.next_pending_chunk() {
if !self.send_or_stash(wire_chunk).await? {
return Ok(()); }
}
Ok(())
}
pub async fn finish(mut self) -> Result<u64> {
self.flush_stash().await?;
while let Some(wire_chunk) = self.next_pending_chunk() {
if !self.send_or_stash(wire_chunk).await? {
self.flush_stash().await?;
}
}
self.drain_all().await?;
self.tree.flush_handle(&mut self.conn, self.file_id).await?;
self.tree.close_handle(&mut self.conn, self.file_id).await?;
self.done = true;
Ok(self.total_written)
}
pub async fn abort(mut self) -> Result<u64> {
use futures_util::stream::StreamExt;
self.pending_data.clear();
self.pending_offset = 0;
self.stashed_chunk = None;
while let Some(result) = self.in_flight.next().await {
match result {
Ok(frame) => {
if frame.header.status == NtStatus::SUCCESS {
let mut cursor = ReadCursor::new(&frame.body);
if let Ok(resp) = WriteResponse::unpack(&mut cursor) {
self.total_written += resp.count as u64;
}
} else {
debug!(
"stream: FileWriter::abort() ignoring WRITE error status {:?}",
frame.header.status
);
}
}
Err(e) => {
debug!(
"stream: FileWriter::abort() giving up on remaining in-flight \
response(s) after transport error: {}",
e
);
break;
}
}
}
if let Err(e) = self.tree.close_handle(&mut self.conn, self.file_id).await {
debug!(
"stream: FileWriter::abort() best-effort CLOSE failed, handle may leak \
server-side until session teardown: {}",
e
);
}
self.done = true;
Ok(self.total_written)
}
#[must_use]
pub fn bytes_written(&self) -> u64 {
self.total_written
}
#[must_use]
pub fn progress(&self) -> Progress {
Progress {
bytes_transferred: self.total_written,
total_bytes: None,
}
}
fn next_pending_chunk(&mut self) -> Option<Vec<u8>> {
if self.pending_offset >= self.pending_data.len() {
return None;
}
let end = (self.pending_offset + self.max_write_size as usize).min(self.pending_data.len());
let slice = self.pending_data[self.pending_offset..end].to_vec();
self.pending_offset = end;
if self.pending_offset >= self.pending_data.len() {
self.pending_data.clear();
self.pending_offset = 0;
}
Some(slice)
}
fn launch_wire_chunk(&mut self, data: Vec<u8>) {
let data_len = data.len() as u64;
let credit_charge = data_len.div_ceil(65536).max(1) as u16;
let req = WriteRequest {
data_offset: 0x70,
offset: self.offset,
file_id: self.file_id,
channel: 0,
remaining_bytes: 0,
write_channel_info_offset: 0,
write_channel_info_length: 0,
flags: 0,
data,
};
let c = self.conn.clone();
let tree_id = self.tree.tree_id;
self.in_flight.push(Box::pin(async move {
c.execute_with_credits(
Command::Write,
&req,
Some(tree_id),
crate::types::CreditCharge(credit_charge),
)
.await
}));
self.offset += data_len;
}
async fn drain_one(&mut self) -> Result<()> {
use futures_util::stream::StreamExt;
let Some(result) = self.in_flight.next().await else {
return Ok(());
};
let frame = result?;
if frame.header.status != NtStatus::SUCCESS {
while self.in_flight.next().await.is_some() {}
let _ = self.tree.close_handle(&mut self.conn, self.file_id).await;
self.done = true;
return Err(Error::Protocol {
status: frame.header.status,
command: Command::Write,
});
}
let mut cursor = ReadCursor::new(&frame.body);
let resp = WriteResponse::unpack(&mut cursor)?;
self.total_written += resp.count as u64;
Ok(())
}
async fn drain_all(&mut self) -> Result<()> {
while !self.in_flight.is_empty() {
self.drain_one().await?;
}
Ok(())
}
fn can_send(&self, data: &[u8]) -> bool {
let credit_charge = (data.len() as u64).div_ceil(65536).max(1) as u16;
let credits_available = self.conn.credits() as usize / credit_charge.max(1) as usize;
credits_available > 0 && self.in_flight.len() < MAX_PIPELINE_WINDOW
}
async fn send_or_stash(&mut self, data: Vec<u8>) -> Result<bool> {
if self.in_flight.len() >= MAX_PIPELINE_WINDOW {
self.drain_one().await?;
}
if self.can_send(&data) {
self.launch_wire_chunk(data);
return Ok(true);
}
if !self.in_flight.is_empty() {
self.drain_one().await?;
if self.can_send(&data) {
self.launch_wire_chunk(data);
return Ok(true);
}
}
self.stashed_chunk = Some(data);
Ok(false)
}
async fn flush_stash(&mut self) -> Result<()> {
if let Some(stashed) = self.stashed_chunk.take() {
if !self.in_flight.is_empty() && !self.can_send(&stashed) {
self.drain_one().await?;
}
if self.can_send(&stashed) {
self.launch_wire_chunk(stashed);
} else {
self.stashed_chunk = Some(stashed);
}
}
Ok(())
}
}
impl Drop for FileWriter {
    /// Emits a debug log when the writer is dropped without having been
    /// marked done; nothing is sent from here, so the handle may stay open.
    fn drop(&mut self) {
        if self.done {
            return;
        }
        let confirmed = self.total_written;
        debug!(
            "stream: FileWriter dropped without finish(), file handle may leak \
             (bytes_written={})",
            confirmed
        );
    }
}
// Unit tests for `FileWriter` and `Progress`.
//
// Each `FileWriter` test queues canned responses on a `MockTransport` in the
// order the requests are expected (create, writes, then flush and/or close)
// and drives the writer through `Tree::create_file_writer`.
#[cfg(test)]
mod tests {
use super::*;
use crate::client::test_helpers::{
build_close_error_response, build_close_response, build_create_response,
build_flush_response, build_write_error_response, build_write_response, setup_connection,
};
use crate::transport::MockTransport;
use crate::types::status::NtStatus;
use crate::types::{FileId, TreeId};
use std::sync::Arc;
// Minimal tree fixture shared by every test.
fn test_tree() -> Arc<Tree> {
Arc::new(Tree {
tree_id: TreeId(10),
share_name: "test".to_string(),
server: "test-server".to_string(),
is_dfs: false,
encrypt_data: false,
})
}
// Fixed file id placed in the mocked create response.
fn test_file_id() -> FileId {
FileId {
persistent: 0xAA,
volatile: 0xBB,
}
}
// One write: nothing is confirmed until `finish` processes the response.
#[tokio::test]
async fn file_writer_single_chunk() {
let mock = Arc::new(MockTransport::new());
let file_id = test_file_id();
mock.queue_response(build_create_response(file_id, 0));
mock.queue_response(build_write_response(100));
mock.queue_response(build_flush_response());
mock.queue_response(build_close_response());
let conn = setup_connection(&mock);
let tree = test_tree();
let mut writer = tree.create_file_writer(conn, "out.bin").await.unwrap();
writer.write_chunk(&[0u8; 100]).await.unwrap();
assert_eq!(writer.bytes_written(), 0); let total = writer.finish().await.unwrap();
assert_eq!(total, 100);
}
// Three writes are summed into the total returned by `finish`.
#[tokio::test]
async fn file_writer_multiple_chunks() {
let mock = Arc::new(MockTransport::new());
let file_id = test_file_id();
mock.queue_response(build_create_response(file_id, 0));
mock.queue_response(build_write_response(100));
mock.queue_response(build_write_response(100));
mock.queue_response(build_write_response(100));
mock.queue_response(build_flush_response());
mock.queue_response(build_close_response());
let conn = setup_connection(&mock);
let tree = test_tree();
let mut writer = tree.create_file_writer(conn, "out.bin").await.unwrap();
writer.write_chunk(&[1u8; 100]).await.unwrap();
writer.write_chunk(&[2u8; 100]).await.unwrap();
writer.write_chunk(&[3u8; 100]).await.unwrap();
let total = writer.finish().await.unwrap();
assert_eq!(total, 300);
}
// Finishing without writing sends only three requests (one per queued response).
#[tokio::test]
async fn file_writer_empty_finish() {
let mock = Arc::new(MockTransport::new());
let file_id = test_file_id();
mock.queue_response(build_create_response(file_id, 0));
mock.queue_response(build_flush_response());
mock.queue_response(build_close_response());
let conn = setup_connection(&mock);
let tree = test_tree();
let writer = tree.create_file_writer(conn, "empty.bin").await.unwrap();
let total = writer.finish().await.unwrap();
assert_eq!(total, 0);
assert_eq!(mock.sent_count(), 3);
}
// An empty `write_chunk` sends nothing: only one write shows up in the sent count.
#[tokio::test]
async fn file_writer_empty_chunk_noop() {
let mock = Arc::new(MockTransport::new());
let file_id = test_file_id();
mock.queue_response(build_create_response(file_id, 0));
mock.queue_response(build_write_response(50));
mock.queue_response(build_flush_response());
mock.queue_response(build_close_response());
let conn = setup_connection(&mock);
let tree = test_tree();
let mut writer = tree.create_file_writer(conn, "out.bin").await.unwrap();
writer.write_chunk(&[]).await.unwrap(); writer.write_chunk(&[0u8; 50]).await.unwrap();
let total = writer.finish().await.unwrap();
assert_eq!(total, 50);
assert_eq!(mock.sent_count(), 4);
}
// A 200 KiB chunk is split into four wire writes (3 x 65536 bytes plus the
// remainder), which assumes the mocked connection yields a 65536-byte write size.
#[tokio::test]
async fn file_writer_chunk_splitting() {
let mock = Arc::new(MockTransport::new());
let file_id = test_file_id();
let chunk_size = 200 * 1024;
let wire_1 = 65536u32;
let wire_2 = 65536u32;
let wire_3 = 65536u32;
let wire_4 = (chunk_size - 3 * 65536) as u32;
mock.queue_response(build_create_response(file_id, 0));
mock.queue_response(build_write_response(wire_1));
mock.queue_response(build_write_response(wire_2));
mock.queue_response(build_write_response(wire_3));
mock.queue_response(build_write_response(wire_4));
mock.queue_response(build_flush_response());
mock.queue_response(build_close_response());
let conn = setup_connection(&mock);
let tree = test_tree();
let mut writer = tree.create_file_writer(conn, "big.bin").await.unwrap();
writer.write_chunk(&vec![0u8; chunk_size]).await.unwrap();
let total = writer.finish().await.unwrap();
assert_eq!(total, (wire_1 + wire_2 + wire_3 + wire_4) as u64);
assert_eq!(mock.sent_count(), 7);
}
// A streaming writer reports no total size in its progress.
#[tokio::test]
async fn file_writer_progress_none_total() {
let mock = Arc::new(MockTransport::new());
let file_id = test_file_id();
mock.queue_response(build_create_response(file_id, 0));
mock.queue_response(build_flush_response());
mock.queue_response(build_close_response());
let conn = setup_connection(&mock);
let tree = test_tree();
let writer = tree.create_file_writer(conn, "out.bin").await.unwrap();
let progress = writer.progress();
assert!(progress.total_bytes.is_none());
assert_eq!(progress.bytes_transferred, 0);
writer.finish().await.unwrap();
}
// `bytes_written` counts confirmed bytes only, so it stays 0 until `finish`.
#[tokio::test]
async fn file_writer_bytes_written_tracks_confirmed() {
let mock = Arc::new(MockTransport::new());
let file_id = test_file_id();
mock.queue_response(build_create_response(file_id, 0));
mock.queue_response(build_write_response(100));
mock.queue_response(build_write_response(200));
mock.queue_response(build_flush_response());
mock.queue_response(build_close_response());
let conn = setup_connection(&mock);
let tree = test_tree();
let mut writer = tree.create_file_writer(conn, "out.bin").await.unwrap();
writer.write_chunk(&[0u8; 100]).await.unwrap();
assert_eq!(writer.bytes_written(), 0);
writer.write_chunk(&[0u8; 200]).await.unwrap();
assert_eq!(writer.bytes_written(), 0);
let total = writer.finish().await.unwrap();
assert_eq!(total, 300);
}
// Writing one chunk more than the pipeline window forces at least one
// response to be processed before `finish`.
#[tokio::test]
async fn file_writer_backpressure() {
let mock = Arc::new(MockTransport::new());
let file_id = test_file_id();
mock.queue_response(build_create_response(file_id, 0));
for _ in 0..MAX_PIPELINE_WINDOW + 1 {
mock.queue_response(build_write_response(64));
}
mock.queue_response(build_flush_response());
mock.queue_response(build_close_response());
let conn = setup_connection(&mock);
let tree = test_tree();
let mut writer = tree.create_file_writer(conn, "out.bin").await.unwrap();
for _ in 0..MAX_PIPELINE_WINDOW {
writer.write_chunk(&[0u8; 64]).await.unwrap();
}
writer.write_chunk(&[0u8; 64]).await.unwrap();
assert!(writer.bytes_written() >= 64);
let total = writer.finish().await.unwrap();
assert_eq!(total, (MAX_PIPELINE_WINDOW as u64 + 1) * 64);
}
// A write rejected by the server surfaces as an error from `finish`.
#[tokio::test]
async fn file_writer_server_error() {
let mock = Arc::new(MockTransport::new());
let file_id = test_file_id();
mock.queue_response(build_create_response(file_id, 0));
mock.queue_response(build_write_error_response(NtStatus::DISK_FULL));
mock.queue_response(build_close_response());
let conn = setup_connection(&mock);
let tree = test_tree();
let mut writer = tree.create_file_writer(conn, "out.bin").await.unwrap();
writer.write_chunk(&[0u8; 100]).await.unwrap();
let result = writer.finish().await;
assert!(result.is_err());
let err = result.unwrap_err();
assert!(
format!("{err:?}").contains("DISK_FULL"),
"expected DISK_FULL, got: {err:?}"
);
}
// `finish` processes every outstanding response before returning the total.
#[tokio::test]
async fn file_writer_finish_drains_all() {
let mock = Arc::new(MockTransport::new());
let file_id = test_file_id();
mock.queue_response(build_create_response(file_id, 0));
mock.queue_response(build_write_response(50));
mock.queue_response(build_write_response(75));
mock.queue_response(build_write_response(25));
mock.queue_response(build_flush_response());
mock.queue_response(build_close_response());
let conn = setup_connection(&mock);
let tree = test_tree();
let mut writer = tree.create_file_writer(conn, "out.bin").await.unwrap();
writer.write_chunk(&[0u8; 50]).await.unwrap();
writer.write_chunk(&[0u8; 75]).await.unwrap();
writer.write_chunk(&[0u8; 25]).await.unwrap();
assert_eq!(writer.bytes_written(), 0);
let total = writer.finish().await.unwrap();
assert_eq!(total, 150);
}
// Aborting an idle writer sends only the close (no flush response is queued).
#[tokio::test]
async fn file_writer_abort_no_in_flight() {
let mock = Arc::new(MockTransport::new());
let file_id = test_file_id();
mock.queue_response(build_create_response(file_id, 0));
mock.queue_response(build_close_response());
let conn = setup_connection(&mock);
let tree = test_tree();
let writer = tree.create_file_writer(conn, "out.bin").await.unwrap();
let total = writer.abort().await.unwrap();
assert_eq!(total, 0);
assert_eq!(mock.sent_count(), 2);
}
// `abort` still waits for writes already in flight and counts their bytes.
#[tokio::test]
async fn file_writer_abort_drains_in_flight() {
let mock = Arc::new(MockTransport::new());
let file_id = test_file_id();
mock.queue_response(build_create_response(file_id, 0));
mock.queue_response(build_write_response(50));
mock.queue_response(build_write_response(75));
mock.queue_response(build_write_response(25));
mock.queue_response(build_close_response());
let conn = setup_connection(&mock);
let tree = test_tree();
let mut writer = tree.create_file_writer(conn, "out.bin").await.unwrap();
writer.write_chunk(&[0u8; 50]).await.unwrap();
writer.write_chunk(&[0u8; 75]).await.unwrap();
writer.write_chunk(&[0u8; 25]).await.unwrap();
assert_eq!(writer.bytes_written(), 0);
let total = writer.abort().await.unwrap();
assert_eq!(total, 150);
assert_eq!(mock.sent_count(), 5);
}
// A write error seen during `abort` is ignored; only the successful write counts.
#[tokio::test]
async fn file_writer_abort_swallows_write_errors() {
let mock = Arc::new(MockTransport::new());
let file_id = test_file_id();
mock.queue_response(build_create_response(file_id, 0));
mock.queue_response(build_write_response(100));
mock.queue_response(build_write_error_response(NtStatus::DISK_FULL));
mock.queue_response(build_close_response());
let conn = setup_connection(&mock);
let tree = test_tree();
let mut writer = tree.create_file_writer(conn, "out.bin").await.unwrap();
writer.write_chunk(&[0u8; 100]).await.unwrap();
writer.write_chunk(&[0u8; 100]).await.unwrap();
let total = writer.abort().await.unwrap();
assert_eq!(total, 100);
assert_eq!(mock.sent_count(), 4);
}
// Buffered and stashed data is discarded by `abort`: no write is sent for it.
#[tokio::test]
async fn file_writer_abort_discards_stashed_chunk() {
let mock = Arc::new(MockTransport::new());
let file_id = test_file_id();
mock.queue_response(build_create_response(file_id, 0));
mock.queue_response(build_close_response());
let conn = setup_connection(&mock);
let tree = test_tree();
let mut writer = tree.create_file_writer(conn, "out.bin").await.unwrap();
// Inject buffered state directly instead of driving the writer into it.
writer.stashed_chunk = Some(vec![0u8; 500]);
writer.pending_data = vec![0u8; 1000];
writer.pending_offset = 0;
let total = writer.abort().await.unwrap();
assert_eq!(total, 0);
assert_eq!(mock.sent_count(), 2);
}
// A failing close during `abort` does not turn the result into an error.
#[tokio::test]
async fn file_writer_abort_close_error_is_swallowed() {
let mock = Arc::new(MockTransport::new());
let file_id = test_file_id();
mock.queue_response(build_create_response(file_id, 0));
mock.queue_response(build_write_response(100));
mock.queue_response(build_close_error_response(NtStatus::FILE_CLOSED));
let conn = setup_connection(&mock);
let tree = test_tree();
let mut writer = tree.create_file_writer(conn, "out.bin").await.unwrap();
writer.write_chunk(&[0u8; 100]).await.unwrap();
let result = writer.abort().await;
assert!(
result.is_ok(),
"abort() should swallow CLOSE errors, got: {result:?}"
);
assert_eq!(result.unwrap(), 100);
assert_eq!(mock.sent_count(), 3);
}
// Smoke test: `abort` succeeds on a fresh writer. The `done` flag named in
// the test title is not asserted directly here.
#[tokio::test]
async fn file_writer_abort_sets_done_so_drop_is_silent() {
let mock = Arc::new(MockTransport::new());
let file_id = test_file_id();
mock.queue_response(build_create_response(file_id, 0));
mock.queue_response(build_close_response());
let conn = setup_connection(&mock);
let tree = test_tree();
let writer = tree.create_file_writer(conn, "out.bin").await.unwrap();
let result = writer.abort().await;
assert!(result.is_ok());
}
// Table-driven check of `Progress::percent` / `Progress::fraction`, including
// the zero-total and unknown-total cases, plus a large-value sanity check.
#[test]
fn progress_calculations() {
// Columns: transferred, total, expected percent, expected fraction.
let cases = [
(50, Some(100), 50.0, 0.5),
(100, Some(100), 100.0, 1.0),
(25, Some(100), 25.0, 0.25),
(0, Some(0), 100.0, 1.0), (50, None, 0.0, 0.0), ];
for (transferred, total, expected_pct, expected_frac) in cases {
let p = Progress {
bytes_transferred: transferred,
total_bytes: total,
};
assert_eq!(
p.percent(),
expected_pct,
"percent failed for {transferred}/{total:?}"
);
assert_eq!(
p.fraction(),
expected_frac,
"fraction failed for {transferred}/{total:?}"
);
}
let large = Progress {
bytes_transferred: u64::MAX / 2,
total_bytes: Some(u64::MAX),
};
let frac = large.fraction();
assert!(frac > 0.49 && frac < 0.51);
}
}