use std::sync::Arc;
use std::time::{Duration, Instant};
use smb2::{ClientConfig, SmbClient};
use tokio::time::timeout;
// --- Fixture endpoints -------------------------------------------------------

/// Fixture server used by the N=8, N=32 and interleaved-stat tests.
/// NOTE(review): the name suggests a server configured for large max-read sizes — confirm.
const MAXREAD_ADDR: &str = "127.0.0.1:10454";
/// Fixture server used by the slow-fixture test (its failure message cites a 200ms RTT).
const SLOW_ADDR: &str = "127.0.0.1:10451";
/// Share every test mounts after connecting as guest.
const SHARE: &str = "public";

// --- Time budgets ------------------------------------------------------------

/// Value placed in `ClientConfig::timeout` (presumably bounds connect/requests — confirm in smb2).
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
/// Wall-clock budget for a whole concurrent-writer run against the regular fixture.
const HANG_BUDGET: Duration = Duration::from_secs(30);
/// Wall-clock budget for the slow fixture: twice `HANG_BUDGET`.
const SLOW_HANG_BUDGET: Duration = Duration::from_secs(60);
/// Builds a guest-login client configuration pointed at `addr`.
///
/// Username, password and domain are all empty. Auto-reconnect and DFS are
/// disabled, compression is enabled, and the DFS target-override map starts empty.
/// The timeout is `CONNECT_TIMEOUT`.
fn guest_config(addr: &str) -> ClientConfig {
    // Empty credentials are what make this a guest login.
    let no_credential = String::new;
    ClientConfig {
        addr: addr.to_owned(),
        timeout: CONNECT_TIMEOUT,
        username: no_credential(),
        password: no_credential(),
        domain: no_credential(),
        auto_reconnect: false,
        compression: true,
        dfs_enabled: false,
        dfs_target_overrides: Default::default(),
    }
}
/// Produces `size` bytes of deterministic filler: the sequence `0, 1, ..., 198`
/// repeated until `size` bytes have been emitted.
fn make_payload(size: usize) -> Vec<u8> {
    (0u8..199).cycle().take(size).collect()
}
/// Opens a brand-new guest connection to `addr` and mounts `SHARE` on it.
///
/// Every call creates its own `SmbClient`, so each caller gets a separate
/// connection. Panics (failing the test) if either the connect or the share
/// mount returns an error.
async fn fresh_client(addr: &str) -> (SmbClient, smb2::Tree) {
    let config = guest_config(addr);
    let connected = SmbClient::connect(config).await;
    let mut client =
        connected.expect("SmbClient::connect failed (is the Docker container up?)");
    let mounted = client.connect_share(SHARE).await;
    let tree = mounted.expect("connect_share('public') failed");
    (client, tree)
}
/// Streams `payload` to `path` on `tree` through a file writer, `chunk_size`
/// bytes per `write_chunk` call, then returns the `u64` reported by `finish()`.
///
/// The caller compares that value against the payload length.
///
/// # Panics
/// - If opening the writer, any `write_chunk`, or `finish()` fails; the message
///   names the step, the path, and the error.
/// - If `chunk_size` is zero (via `slice::chunks`).
async fn drive_writer(
    client: &mut SmbClient,
    tree: &smb2::Tree,
    path: &str,
    payload: &[u8],
    chunk_size: usize,
) -> u64 {
    let opened = client.create_file_writer(tree, path).await;
    let mut writer = match opened {
        Ok(w) => w,
        Err(e) => panic!("create_file_writer({}) failed: {:?}", path, e),
    };

    for piece in payload.chunks(chunk_size) {
        if let Err(e) = writer.write_chunk(piece).await {
            panic!("write_chunk on {} failed: {:?}", path, e);
        }
    }

    match writer.finish().await {
        Ok(written) => written,
        Err(e) => panic!("finish() on {} failed: {:?}", path, e),
    }
}
/// Spawns `n` independent writer tasks against `addr` and waits for all of them.
/// Returns `Err(())` if the whole batch has not finished within `budget`.
///
/// Task `i` does the following:
/// - opens its own connection via `fresh_client`;
/// - writes `payload` to `{prefix}_{i}.bin` in `chunk_size` pieces;
/// - asserts that `finish()` reported exactly `payload.len()` bytes.
///
/// On success, returns the wall-clock time from entry until the last task joined.
///
/// On timeout, every spawned task is aborted before `Err(())` is returned.
/// Otherwise the stuck writers (each holding its own `SmbClient`) would keep
/// running detached in the background while the caller reports the hang.
///
/// # Panics
/// If a writer task panics, its original panic payload is re-raised here. The
/// test therefore reports the writer's own message (failed `write_chunk`,
/// byte-count mismatch, …) rather than a generic join error.
async fn run_concurrent_writers(
    addr: &str,
    n: usize,
    prefix: &str,
    payload: Arc<Vec<u8>>,
    chunk_size: usize,
    budget: Duration,
) -> Result<Duration, ()> {
    let start = Instant::now();

    let mut joins = Vec::with_capacity(n);
    for i in 0..n {
        let payload = Arc::clone(&payload);
        let path = format!("{}_{}.bin", prefix, i);
        let addr = addr.to_string();
        joins.push(tokio::spawn(async move {
            let (mut client, tree) = fresh_client(&addr).await;
            let bytes = drive_writer(&mut client, &tree, &path, &payload, chunk_size).await;
            assert_eq!(
                bytes,
                payload.len() as u64,
                "writer {} wrote wrong byte count",
                i
            );
        }));
    }

    // Await through `&mut JoinHandle` (JoinHandle is Unpin) so the handles outlive
    // the timed future and can still be aborted if the budget expires.
    let outcome = timeout(budget, async {
        for join in joins.iter_mut() {
            if let Err(err) = join.await {
                if err.is_panic() {
                    // Preserve the writer's own panic message instead of masking it.
                    std::panic::resume_unwind(err.into_panic());
                }
                panic!("writer task did not complete: {:?}", err);
            }
        }
    })
    .await;

    match outcome {
        Ok(()) => Ok(start.elapsed()),
        Err(_) => {
            // Dropping a JoinHandle only detaches its task; abort explicitly so
            // hung writers are torn down instead of lingering.
            for join in &joins {
                join.abort();
            }
            Err(())
        }
    }
}
/// Eight concurrent writers, each on its own connection, must all get through
/// `finish()` within `HANG_BUDGET`. A timeout is reported as the
/// `FileWriter::finish` deadlock.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
#[ignore]
async fn concurrent_writers_via_smbclient_finish_all_n8() {
    let _ = env_logger::try_init();
    let n: usize = 8;
    // Chunk size equals payload size, so each writer issues a single write_chunk.
    let size = 256 * 1024;
    let elapsed = run_concurrent_writers(
        MAXREAD_ADDR,
        n,
        "concurrent_test_n8",
        Arc::new(make_payload(size)),
        size,
        HANG_BUDGET,
    )
    .await
    .unwrap_or_else(|()| {
        panic!(
            "concurrent writers (N={}) hung past {:?} — FileWriter::finish deadlock reproduced",
            n, HANG_BUDGET
        )
    });
    eprintln!(
        "concurrent_writers_via_smbclient_finish_all_n8: PASSED in {:?} (n={})",
        elapsed, n
    );
}
/// Same shape as the N=8 test but with 32 writers, i.e. far more writers than
/// the runtime's 4 worker threads.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
#[ignore]
async fn concurrent_writers_via_smbclient_high_concurrency_finish_all() {
    const BYTES: usize = 256 * 1024;
    let _ = env_logger::try_init();
    let n: usize = 32;
    let outcome = run_concurrent_writers(
        MAXREAD_ADDR,
        n,
        "concurrent_test_n32",
        Arc::new(make_payload(BYTES)),
        BYTES,
        HANG_BUDGET,
    )
    .await;
    let elapsed = match outcome {
        Err(()) => panic!(
            "concurrent writers (N={}) hung past {:?} — high-concurrency deadlock reproduced",
            n, HANG_BUDGET
        ),
        Ok(elapsed) => elapsed,
    };
    eprintln!(
        "concurrent_writers_via_smbclient_high_concurrency_finish_all: PASSED in {:?} (n={})",
        elapsed, n
    );
}
/// Runs eight concurrent writers while a separate connection issues a continuous
/// stream of `stat` calls, checking that mixed traffic does not hang the writers.
///
/// The stat task works as follows:
/// - It cycles through a fixed list of paths, including three of the writers'
///   own output files and one presumably nonexistent path.
/// - It ignores every result and only counts how many stats it issued.
/// - It runs until signalled, which happens after the writer batch ends.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
#[ignore]
async fn concurrent_writers_interleaved_with_stats() {
    let _ = env_logger::try_init();
    let n: usize = 8;
    let payload = Arc::new(make_payload(256 * 1024));
    let stop = Arc::new(tokio::sync::Notify::new());
    let stop_for_stats = Arc::clone(&stop);
    let stat_handle = tokio::spawn(async move {
        let (mut client, mut tree) = fresh_client(MAXREAD_ADDR).await;
        let stat_paths = [
            "stat_target_a.bin",
            "stat_target_b.bin",
            "stat_target_c.bin",
            "concurrent_test_mixed_0.bin",
            "concurrent_test_mixed_4.bin",
            "concurrent_test_mixed_7.bin",
            "this_file_does_not_exist.bin",
        ];
        let mut i = 0usize;
        let mut stats = 0u64;
        loop {
            tokio::select! {
                // Check the stop signal first so a pending stop always wins.
                biased;
                _ = stop_for_stats.notified() => break,
                _ = async {
                    let path = stat_paths[i % stat_paths.len()];
                    let _ = client.stat(&mut tree, path).await;
                    i = i.wrapping_add(1);
                    stats += 1;
                    tokio::task::yield_now().await;
                } => {}
            }
        }
        stats
    });
    let writer_result = run_concurrent_writers(
        MAXREAD_ADDR,
        n,
        "concurrent_test_mixed",
        payload,
        256 * 1024,
        HANG_BUDGET,
    )
    .await;
    // Use `notify_one`, not `notify_waiters`:
    // - `notify_waiters` wakes only `Notified` futures that already exist and
    //   stores no permit. If the stat task were still connecting, or between two
    //   loop iterations (no `Notified` alive), the stop would be lost. The stat
    //   loop would then spin forever and the join below would hang the test.
    // - `notify_one` stores a permit for the next `notified()` when nobody is
    //   waiting, so the stop cannot be lost.
    stop.notify_one();
    let stats_done = match stat_handle.await {
        Ok(count) => count,
        Err(e) => {
            // Best effort: a failed stat side should not mask the writer verdict,
            // but it should not disappear silently either.
            eprintln!(
                "concurrent_writers_interleaved_with_stats: stat task failed: {:?}",
                e
            );
            0
        }
    };
    match writer_result {
        Ok(elapsed) => {
            eprintln!(
                "concurrent_writers_interleaved_with_stats: PASSED in {:?} (n={}, stats={})",
                elapsed, n, stats_done
            );
        }
        Err(()) => panic!(
            "interleaved writers+stats (N={}) hung past {:?} — mixed-traffic deadlock reproduced",
            n, HANG_BUDGET
        ),
    }
}
/// Same scenario as the N=8 test (eight writers, one 256 KiB chunk each), but
/// with these differences:
/// - it targets the slow fixture at `SLOW_ADDR`, which the failure message
///   below describes as 200ms RTT;
/// - it gets the longer `SLOW_HANG_BUDGET`.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
#[ignore]
async fn concurrent_writers_slow_fixture_finish_all() {
    let _ = env_logger::try_init();
    let n: usize = 8;
    let chunk_size = 256 * 1024;
    let payload = Arc::new(make_payload(chunk_size));
    let result = run_concurrent_writers(
        SLOW_ADDR,
        n,
        "concurrent_test_slow",
        payload,
        chunk_size,
        SLOW_HANG_BUDGET,
    )
    .await;
    if let Ok(elapsed) = result {
        eprintln!(
            "concurrent_writers_slow_fixture_finish_all: PASSED in {:?} (n={})",
            elapsed, n
        );
    } else {
        panic!(
            "slow-fixture concurrent writers (N={}) hung past {:?} — \
             deadlock reproduced with 200ms RTT",
            n, SLOW_HANG_BUDGET
        );
    }
}