use std::fmt;
use std::time::Duration;
use crate::crypto::encryption::Cipher;
use crate::crypto::signing::SigningAlgorithm;
use crate::pack::Guid;
use crate::types::flags::Capabilities;
use crate::types::{Dialect, SessionId, TreeId};
/// Diagnostic snapshot of an SMB client: client-wide settings and counters,
/// the primary connection, any extra connections, and the DFS cache.
///
/// Rendered as a multi-line human-readable report via [`fmt::Display`]; with
/// the `serde` feature enabled it also derives `serde::Serialize`.
#[non_exhaustive]
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct Diagnostics {
/// Client-wide configuration flags and counters.
pub client: ClientInfo,
/// Diagnostics for the primary connection.
pub primary: ConnectionDiagnostics,
/// Additional connections, listed under "DFS extra connections" in the report.
pub extra_connections: Vec<ConnectionDiagnostics>,
/// DFS cache entries; the report shows only their count ("cache entries").
pub dfs_cache: Vec<DfsCacheEntry>,
}
/// Client-level settings and counters included in [`Diagnostics`].
#[non_exhaustive]
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct ClientInfo {
/// Server the client targets; printed in the report header ("SMB client → …").
pub primary_server: String,
/// Configured timeout (NOTE(review): presumably the per-request timeout from
/// the client config — confirm at the construction site).
pub timeout: Duration,
/// Whether automatic reconnection is enabled.
pub auto_reconnect: bool,
/// Whether DFS support is enabled; rendered as "enabled"/"disabled".
pub dfs_enabled: bool,
/// Client-wide counters (reconnects, DFS referrals, DFS cache hits).
pub metrics: ClientMetricsSnapshot,
}
/// Diagnostics for a single connection to one server.
#[non_exhaustive]
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct ConnectionDiagnostics {
/// Server name this connection targets.
pub server: String,
/// Negotiated parameters; `None` before negotiation has completed (the
/// report then prints a "pre-negotiate" placeholder).
pub negotiated: Option<NegotiatedSummary>,
/// Credit and message-id accounting.
pub credits: CreditInfo,
/// Signing state.
pub signing: SigningInfo,
/// Encryption state.
pub encryption: EncryptionInfo,
/// Compression requested vs. negotiated.
pub compression: CompressionInfo,
/// Round-trip-time estimate, if one is available; rendered in milliseconds.
pub rtt_estimate: Option<Duration>,
/// Set when the connection is disconnected; the report adds a
/// "status: DISCONNECTED" line.
pub disconnected: bool,
/// Tree ids associated with DFS (NOTE(review): presumably trees connected to
/// DFS shares — confirm against the producer).
pub dfs_trees: Vec<TreeId>,
/// Session summary, if a session is attached to this connection.
pub session: Option<SessionDiagnostics>,
/// Per-connection counters.
pub metrics: MetricsSnapshot,
}
/// Summary of the parameters recorded from protocol negotiation.
#[non_exhaustive]
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct NegotiatedSummary {
/// Negotiated dialect; printed with `{:?}` in the report.
pub dialect: Dialect,
/// Maximum read size from negotiation.
pub max_read_size: u32,
/// Maximum write size from negotiation.
pub max_write_size: u32,
/// Maximum transact size from negotiation.
pub max_transact_size: u32,
/// Server GUID.
pub server_guid: Guid,
/// Whether the server requires signing.
pub signing_required: bool,
/// Capability flags.
pub capabilities: Capabilities,
/// Whether GMAC was negotiated (NOTE(review): presumably the GMAC signing
/// algorithm — confirm).
pub gmac_negotiated: bool,
/// Negotiated cipher, if any.
pub cipher: Option<Cipher>,
/// Whether compression is supported by the negotiation outcome.
pub compression_supported: bool,
}
/// Credit and message-id accounting for a connection.
///
/// Rendered as "credits: N available · N in flight · next msg_id N".
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct CreditInfo {
/// Credits currently available for new requests.
pub available: u16,
/// Number of requests in flight.
pub in_flight: usize,
/// Message id the next request will use.
pub next_message_id: u64,
}
/// Signing state of a connection.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct SigningInfo {
/// Whether signing is active.
pub active: bool,
/// Signing algorithm, if known; shown in the report only while `active`.
pub algorithm: Option<SigningAlgorithm>,
}
/// Encryption state of a connection.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct EncryptionInfo {
/// Whether encryption is active.
pub active: bool,
/// Cipher in use, if known; shown in the report only while `active`.
pub cipher: Option<Cipher>,
}
/// Compression state of a connection: what the client asked for vs. what was
/// negotiated. All four combinations are reported distinctly.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct CompressionInfo {
/// Whether the client requested compression.
pub requested: bool,
/// Whether compression was negotiated.
pub negotiated: bool,
}
/// Non-secret summary of an authenticated session.
///
/// Carries only ids, flags and the algorithm; no key material is included.
#[non_exhaustive]
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct SessionDiagnostics {
/// Session id.
pub session_id: SessionId,
/// Whether requests on this session should be signed.
pub should_sign: bool,
/// Whether requests on this session should be encrypted.
pub should_encrypt: bool,
/// Signing algorithm selected for the session.
pub signing_algorithm: SigningAlgorithm,
}
/// One entry of the client's DFS cache.
#[non_exhaustive]
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct DfsCacheEntry {
/// Path prefix the entry applies to.
pub path_prefix: String,
/// Number of targets recorded for the prefix.
pub target_count: usize,
/// Remaining lifetime (NOTE(review): meaning of `None` — no expiry vs.
/// already expired — is not visible here; confirm with the producer).
pub expires_in: Option<Duration>,
}
/// Copy of a connection's counters at the time of the snapshot.
///
/// The four `responses_*` routing counters partition the final responses a
/// connection receives: each one lands in exactly one of ok / err / late /
/// stray. The tests below check that the counters are not reset when the
/// transport is closed.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct MetricsSnapshot {
/// Requests sent, counted once per allocated message id (a compound of N
/// sub-operations counts N; explicit cancels are not counted).
pub requests_sent: u64,
/// Compound chains sent (one per chain regardless of its length).
pub compound_requests_sent: u64,
/// Bytes written to the transport.
pub wire_bytes_sent: u64,
/// Explicit cancel requests sent.
pub explicit_cancels_sent: u64,
/// Final responses delivered to a waiting caller with a success status.
pub responses_routed_ok: u64,
/// Final responses delivered to a waiting caller with an error status.
pub responses_routed_err: u64,
/// Responses for a registered message id whose caller had already been dropped.
pub responses_late_after_drop: u64,
/// Responses for a message id that was never registered.
pub responses_stray: u64,
/// Bytes read from the transport.
pub wire_bytes_received: u64,
/// Interim STATUS_PENDING responses observed before a final response.
pub status_pending_loops: u64,
/// Frames carrying the unsolicited message id; not counted as routed or stray.
pub unsolicited_notifications_received: u64,
/// Signature verification failures.
pub signature_failures: u64,
/// Decryption failures.
pub decrypt_failures: u64,
/// Decompression failures.
pub decompress_failures: u64,
/// Frames that could not be parsed.
pub malformed_frames: u64,
/// Responses carrying NETWORK_SESSION_EXPIRED; a subset of `responses_routed_err`.
pub session_expired_events: u64,
/// Callers that ran their request to completion and received `Err`.
pub requests_returned_err: u64,
}
/// Copy of the client-wide counters at the time of the snapshot.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct ClientMetricsSnapshot {
/// Number of reconnects performed.
pub reconnects: u64,
/// Number of DFS referrals resolved.
pub dfs_referrals_resolved: u64,
/// Number of DFS cache hits.
pub dfs_cache_hits: u64,
}
impl fmt::Display for Diagnostics {
    /// Writes the full report: a client summary, the primary connection, and
    /// a "DFS extra connections" section listing every extra connection.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let client = &self.client;
        let counters = &client.metrics;
        let dfs_state = if client.dfs_enabled { "enabled" } else { "disabled" };

        writeln!(f, "SMB client → {}", client.primary_server)?;
        writeln!(
            f,
            " reconnects: {} dfs: {} (hits: {}, referrals resolved: {}, cache entries: {})",
            counters.reconnects,
            dfs_state,
            counters.dfs_cache_hits,
            counters.dfs_referrals_resolved,
            self.dfs_cache.len(),
        )?;

        writeln!(f)?;
        writeln!(f, "Primary connection ({})", self.primary.server)?;
        fmt_connection_body(&self.primary, f)?;

        // The section header is always written, so an empty list still shows
        // an explicit "(0)" count.
        writeln!(f)?;
        writeln!(
            f,
            "DFS extra connections: ({})",
            self.extra_connections.len()
        )?;
        for extra in &self.extra_connections {
            writeln!(f)?;
            writeln!(f, " ↳ {}", extra.server)?;
            fmt_connection_body(extra, f)?;
        }
        Ok(())
    }
}
/// Writes the indented body of one connection's section of the report.
///
/// Once negotiation has completed, a dialect/RTT line and a
/// signing/encryption/compression line are written. Before that, a single
/// placeholder line is written instead. The credit, wire-byte, response,
/// protocol-event and error lines are always written. A trailing
/// "status: DISCONNECTED" line appears only when `c.disconnected` is set.
fn fmt_connection_body(c: &ConnectionDiagnostics, f: &mut fmt::Formatter<'_>) -> fmt::Result {
    if let Some(neg) = &c.negotiated {
        // Milliseconds with one decimal place; an em dash when no estimate exists.
        let rtt = match c.rtt_estimate {
            Some(d) => format!("{:.1} ms", d.as_secs_f64() * 1000.0),
            None => "—".to_string(),
        };
        writeln!(f, " dialect: {:?} rtt: {}", neg.dialect, rtt)?;

        let signing = fmt_signing(&c.signing);
        let encryption = fmt_encryption(&c.encryption);
        let compression = fmt_compression(&c.compression);
        writeln!(
            f,
            " signing: {} encryption: {} compression: {}",
            signing, encryption, compression,
        )?;
    } else {
        writeln!(
            f,
            " (pre-negotiate — no dialect / signing / encryption yet)"
        )?;
    }

    let credits = &c.credits;
    writeln!(
        f,
        " credits: {} available · {} in flight · next msg_id {}",
        credits.available, credits.in_flight, credits.next_message_id
    )?;

    let metrics = &c.metrics;
    writeln!(
        f,
        " wire bytes: {} sent · {} received",
        metrics.wire_bytes_sent, metrics.wire_bytes_received
    )?;
    writeln!(
        f,
        " responses: {} ok · {} wire-err · {} late · {} stray (sent: {}, caller-err: {})",
        metrics.responses_routed_ok,
        metrics.responses_routed_err,
        metrics.responses_late_after_drop,
        metrics.responses_stray,
        metrics.requests_sent,
        metrics.requests_returned_err,
    )?;
    writeln!(
        f,
        " protocol events: {} status-pending · {} unsolicited · {} compound chains · {} cancels",
        metrics.status_pending_loops,
        metrics.unsolicited_notifications_received,
        metrics.compound_requests_sent,
        metrics.explicit_cancels_sent,
    )?;
    writeln!(
        f,
        " errors: {} signature · {} decrypt · {} decompress · {} malformed · {} session-expired",
        metrics.signature_failures,
        metrics.decrypt_failures,
        metrics.decompress_failures,
        metrics.malformed_frames,
        metrics.session_expired_events,
    )?;

    if c.disconnected {
        writeln!(f, " status: DISCONNECTED")?;
    }
    Ok(())
}
/// Short label for the signing state: "inactive", "active", or
/// "active (<algorithm>)" when the algorithm is known.
fn fmt_signing(s: &SigningInfo) -> String {
    if !s.active {
        return "inactive".to_string();
    }
    s.algorithm
        .map_or_else(|| "active".to_string(), |algo| format!("active ({:?})", algo))
}
/// Short label for the encryption state: "inactive", "active", or
/// "active (<cipher>)" when the cipher is known.
fn fmt_encryption(e: &EncryptionInfo) -> String {
    if !e.active {
        return "inactive".to_string();
    }
    e.cipher
        .map_or_else(|| "active".to_string(), |cipher| format!("active ({:?})", cipher))
}
/// Short label for the compression state, distinguishing all four
/// requested/negotiated combinations.
fn fmt_compression(c: &CompressionInfo) -> String {
    let label = if c.negotiated {
        if c.requested {
            "active"
        } else {
            "active (not requested)"
        }
    } else if c.requested {
        "requested, not negotiated"
    } else {
        "off"
    };
    label.to_string()
}
#[cfg(test)]
mod tests {
    // Tests for the per-connection metrics counters and the diagnostics
    // snapshot, driven over a `MockTransport`.
    use std::sync::Arc;
    use std::time::Duration;

    use crate::client::connection::Connection;
    use crate::msg::echo::{EchoRequest, EchoResponse};
    use crate::msg::header::Header;
    use crate::pack::Pack;
    use crate::transport::mock::MockTransport;
    use crate::types::status::NtStatus;
    use crate::types::{Command, MessageId};

    /// Serializes `header` followed by `body` into one frame.
    fn pack(header: &Header, body: &dyn Pack) -> Vec<u8> {
        let mut cursor = crate::pack::WriteCursor::with_capacity(64 + 16);
        header.pack(&mut cursor);
        body.pack(&mut cursor);
        cursor.into_inner()
    }

    /// Builds an ECHO response frame for `msg_id` with the given status and 10 credits.
    fn echo_response(msg_id: MessageId, status: NtStatus) -> Vec<u8> {
        let mut h = Header::new_request(Command::Echo);
        h.flags.set_response();
        h.credits = 10;
        h.message_id = msg_id;
        h.status = status;
        pack(&h, &EchoResponse)
    }

    /// Successful ECHO response frame for `msg_id`.
    fn echo_ok(msg_id: MessageId) -> Vec<u8> {
        echo_response(msg_id, NtStatus::SUCCESS)
    }

    /// Waits until the mock has recorded at least `n` sent messages; panics after 5 s.
    async fn wait_for_sent(mock: &MockTransport, n: usize) {
        let deadline = std::time::Instant::now() + Duration::from_secs(5);
        while mock.sent_count() < n {
            if std::time::Instant::now() > deadline {
                panic!("expected {n} sent messages, got {}", mock.sent_count());
            }
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
    }

    /// Polls `cond` every 10 ms until it holds or `timeout` elapses.
    ///
    /// Deliberately does not panic on timeout: the caller's assertion that
    /// follows reports the value actually observed.
    async fn wait_until(timeout: Duration, mut cond: impl FnMut() -> bool) {
        let deadline = std::time::Instant::now() + timeout;
        while !cond() && std::time::Instant::now() < deadline {
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
    }

    /// Connection over a fresh mock with `enable_auto_rewrite_msg_id` turned on.
    /// Tests using it queue responses as `MessageId(0)` (NOTE(review):
    /// presumably rewritten by the mock to the allocated id).
    fn fresh_conn() -> (Connection, Arc<MockTransport>) {
        let mock = Arc::new(MockTransport::new());
        mock.enable_auto_rewrite_msg_id();
        let conn = Connection::from_transport(
            Box::new(mock.clone()),
            Box::new(mock.clone()),
            "test-server",
        );
        (conn, mock)
    }

    /// Connection over a fresh mock without msg-id rewriting, for tests that
    /// need specific response ids (e.g. unregistered or unsolicited ids).
    fn plain_conn() -> (Connection, Arc<MockTransport>) {
        let mock = Arc::new(MockTransport::new());
        let conn = Connection::from_transport(
            Box::new(mock.clone()),
            Box::new(mock.clone()),
            "test-server",
        );
        (conn, mock)
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn requests_sent_and_wire_bytes_sent_tick_for_one_execute() {
        let (conn, mock) = fresh_conn();
        let c = conn.clone();
        let handle =
            tokio::spawn(async move { c.execute(Command::Echo, &EchoRequest, None).await });
        wait_for_sent(&mock, 1).await;
        mock.queue_response(echo_ok(MessageId(0)));
        handle.await.unwrap().unwrap();
        let m = conn.metrics();
        assert_eq!(m.requests_sent, 1, "one msg_id allocated → one request");
        assert!(m.wire_bytes_sent > 0, "send wrote some bytes to the wire");
        assert!(
            m.wire_bytes_received > 0,
            "receive read some bytes from the wire"
        );
        assert_eq!(m.responses_routed_ok, 1);
        assert_eq!(m.responses_routed_err, 0);
        assert_eq!(m.responses_late_after_drop, 0);
        assert_eq!(m.responses_stray, 0);
        assert_eq!(m.requests_returned_err, 0);
        mock.close();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn requests_sent_ticks_per_sub_op_in_compound_and_compound_chain_counted() {
        use crate::client::connection::CompoundOp;
        let (conn, mock) = fresh_conn();
        let c = conn.clone();
        let handle = tokio::spawn(async move {
            let ops = vec![
                CompoundOp::new(Command::Echo, &EchoRequest, None),
                CompoundOp::new(Command::Echo, &EchoRequest, None),
                CompoundOp::new(Command::Echo, &EchoRequest, None),
            ];
            c.execute_compound(&ops).await
        });
        wait_for_sent(&mock, 1).await;
        mock.queue_response(echo_ok(MessageId(0)));
        mock.queue_response(echo_ok(MessageId(0)));
        mock.queue_response(echo_ok(MessageId(0)));
        handle.await.unwrap().unwrap();
        let m = conn.metrics();
        assert_eq!(m.requests_sent, 3, "three sub-ops → requests_sent += 3");
        assert_eq!(m.compound_requests_sent, 1, "one compound chain");
        assert_eq!(m.responses_routed_ok, 3);
        assert_eq!(m.requests_returned_err, 0);
        mock.close();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn requests_returned_err_ticks_on_outer_err_to_completed_caller() {
        let (conn, mock) = fresh_conn();
        let c = conn.clone();
        let handle =
            tokio::spawn(async move { c.execute(Command::Echo, &EchoRequest, None).await });
        wait_for_sent(&mock, 1).await;
        mock.close();
        let result = handle.await.unwrap();
        assert!(result.is_err(), "execute should error after close");
        wait_until(Duration::from_secs(2), || {
            conn.metrics().requests_returned_err > 0
        })
        .await;
        assert_eq!(conn.metrics().requests_returned_err, 1);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn responses_late_after_drop_ticks_when_caller_dropped() {
        let (conn, mock) = fresh_conn();
        let c = conn.clone();
        let handle =
            tokio::spawn(async move { c.execute(Command::Echo, &EchoRequest, None).await });
        wait_for_sent(&mock, 1).await;
        handle.abort();
        let _ = handle.await;
        mock.queue_response(echo_ok(MessageId(0)));
        wait_until(Duration::from_secs(2), || {
            conn.metrics().responses_late_after_drop > 0
        })
        .await;
        let m = conn.metrics();
        assert_eq!(m.responses_late_after_drop, 1, "caller-drop should tick");
        assert_eq!(m.responses_stray, 0, "stray is for unregistered ids only");
        assert_eq!(m.responses_routed_ok, 0);
        mock.close();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn responses_stray_ticks_for_unregistered_msg_id() {
        // A plain mock: the stray id must reach the connection unchanged.
        let (conn, mock) = plain_conn();
        mock.queue_response(echo_ok(MessageId(999_999)));
        // Poll the counter itself rather than the mock queue: a frame can be
        // dequeued before the receive path has bumped its counter.
        wait_until(Duration::from_secs(2), || conn.metrics().responses_stray > 0).await;
        let m = conn.metrics();
        assert_eq!(m.responses_stray, 1);
        assert_eq!(m.responses_late_after_drop, 0);
        assert_eq!(m.responses_routed_ok, 0);
        mock.close();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn unsolicited_notifications_received_ticks_for_unsolicited_msg_id() {
        let (conn, mock) = plain_conn();
        let mut h = Header::new_request(Command::OplockBreak);
        h.flags.set_response();
        h.credits = 0;
        h.message_id = MessageId::UNSOLICITED;
        let frame = pack(&h, &EchoResponse);
        mock.queue_response(frame);
        wait_until(Duration::from_secs(2), || {
            conn.metrics().unsolicited_notifications_received > 0
        })
        .await;
        assert_eq!(conn.metrics().unsolicited_notifications_received, 1);
        assert_eq!(conn.metrics().responses_routed_ok, 0);
        assert_eq!(conn.metrics().responses_stray, 0);
        mock.close();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn status_pending_loops_ticks_for_interim_pending_then_final() {
        let (conn, mock) = plain_conn();
        let c = conn.clone();
        let handle =
            tokio::spawn(async move { c.execute(Command::Echo, &EchoRequest, None).await });
        wait_for_sent(&mock, 1).await;
        mock.queue_response(echo_response(MessageId(0), NtStatus::PENDING));
        mock.queue_response(echo_response(MessageId(0), NtStatus::SUCCESS));
        handle.await.unwrap().unwrap();
        let m = conn.metrics();
        assert_eq!(m.status_pending_loops, 1, "one interim PENDING observed");
        assert_eq!(m.responses_routed_ok, 1, "one final response routed");
        mock.close();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn session_expired_events_ticks_and_also_routes_err() {
        let (conn, mock) = plain_conn();
        let c = conn.clone();
        let handle =
            tokio::spawn(async move { c.execute(Command::Echo, &EchoRequest, None).await });
        wait_for_sent(&mock, 1).await;
        mock.queue_response(echo_response(
            MessageId(0),
            NtStatus::NETWORK_SESSION_EXPIRED,
        ));
        let result = handle.await.unwrap();
        assert!(result.is_err(), "session-expired should surface as Err");
        let m = conn.metrics();
        assert_eq!(m.session_expired_events, 1);
        assert_eq!(
            m.responses_routed_err, 1,
            "session_expired_events is a subset of responses_routed_err"
        );
        assert_eq!(m.responses_routed_ok, 0);
        assert_eq!(
            m.requests_returned_err, 1,
            "caller polled to completion and got Err"
        );
        mock.close();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn explicit_cancels_sent_ticks_on_send_cancel() {
        let (mut conn, mock) = plain_conn();
        conn.send_cancel(MessageId(42), None).await.unwrap();
        assert_eq!(conn.metrics().explicit_cancels_sent, 1);
        assert_eq!(conn.metrics().requests_sent, 0);
        mock.close();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn dispatch_path_is_counted() {
        let (conn, mock) = fresh_conn();
        let c = conn.clone();
        let handle =
            tokio::spawn(async move { c.dispatch(Command::Echo, &EchoRequest, None).await });
        wait_for_sent(&mock, 1).await;
        let rx = handle.await.unwrap().unwrap();
        mock.queue_response(echo_ok(MessageId(0)));
        let _ = rx.await.unwrap().unwrap();
        let m = conn.metrics();
        assert_eq!(m.requests_sent, 1, "dispatch funnel-counts via allocate");
        assert!(m.wire_bytes_sent > 0);
        assert_eq!(m.responses_routed_ok, 1);
        mock.close();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn counters_survive_teardown() {
        let (conn, mock) = fresh_conn();
        let c = conn.clone();
        let handle =
            tokio::spawn(async move { c.execute(Command::Echo, &EchoRequest, None).await });
        wait_for_sent(&mock, 1).await;
        mock.queue_response(echo_ok(MessageId(0)));
        handle.await.unwrap().unwrap();
        let before = conn.metrics();
        assert_eq!(before.responses_routed_ok, 1);
        mock.close();
        // Give teardown a moment to run before re-reading the counters.
        tokio::time::sleep(Duration::from_millis(50)).await;
        let after = conn.metrics();
        assert_eq!(after.responses_routed_ok, before.responses_routed_ok);
        assert_eq!(after.requests_sent, before.requests_sent);
    }

    /// Wraps `conn` and `session` in an `SmbClient` with a fixed test config.
    fn fake_client(conn: Connection, session: crate::client::Session) -> crate::SmbClient {
        let cfg = crate::ClientConfig {
            addr: conn.server_name().to_string(),
            timeout: Duration::from_secs(30),
            username: String::new(),
            password: String::new(),
            domain: String::new(),
            auto_reconnect: false,
            compression: true,
            dfs_enabled: true,
            dfs_target_overrides: std::collections::HashMap::new(),
        };
        crate::SmbClient::from_parts(cfg, conn, session)
    }

    /// Session with a recognizable id and no keys, signing or encryption.
    fn fake_session() -> crate::client::Session {
        crate::client::Session {
            session_id: crate::types::SessionId(0x1234_5678_9ABC_DEF0),
            signing_key: vec![],
            encryption_key: None,
            decryption_key: None,
            signing_algorithm: crate::crypto::signing::SigningAlgorithm::HmacSha256,
            should_sign: false,
            should_encrypt: false,
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn display_contains_key_labels() {
        let (conn, mock) = fresh_conn();
        let c = conn.clone();
        let handle =
            tokio::spawn(async move { c.execute(Command::Echo, &EchoRequest, None).await });
        wait_for_sent(&mock, 1).await;
        mock.queue_response(echo_ok(MessageId(0)));
        handle.await.unwrap().unwrap();
        let client = fake_client(conn, fake_session());
        let d = client.diagnostics();
        let text = format!("{}", d);
        for label in [
            "SMB client",
            "test-server",
            "credits:",
            "wire bytes:",
            "responses:",
            "protocol events:",
            "errors:",
            "DFS extra connections",
        ] {
            assert!(
                text.contains(label),
                "Display missing {label:?} in:\n{text}"
            );
        }
        mock.close();
    }

    #[cfg(feature = "serde")]
    #[tokio::test(flavor = "multi_thread")]
    async fn serde_round_trip_into_json_value() {
        let (conn, mock) = fresh_conn();
        let c = conn.clone();
        let handle =
            tokio::spawn(async move { c.execute(Command::Echo, &EchoRequest, None).await });
        wait_for_sent(&mock, 1).await;
        mock.queue_response(echo_ok(MessageId(0)));
        handle.await.unwrap().unwrap();
        let client = fake_client(conn, fake_session());
        let d = client.diagnostics();
        let json = serde_json::to_string(&d).expect("serialize");
        let v: serde_json::Value = serde_json::from_str(&json).expect("re-parse");
        assert_eq!(v["client"]["primary_server"], "test-server", "json: {json}");
        assert_eq!(v["primary"]["server"], "test-server");
        assert_eq!(v["primary"]["metrics"]["requests_sent"], 1);
        assert_eq!(v["primary"]["metrics"]["responses_routed_ok"], 1);
        assert!(v["primary"]["disconnected"].is_boolean());
        assert!(v["primary"]["credits"]["available"].is_number());
        assert_eq!(
            v["primary"]["session"]["session_id"], 0x1234_5678_9ABC_DEF0_u64,
            "json: {json}"
        );
        mock.close();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn snapshot_releases_all_locks_before_returning() {
        let (conn, mock) = fresh_conn();
        // Holding one snapshot while taking many more must not deadlock.
        let _d = conn.diagnostics();
        for _ in 0..100 {
            let _ = conn.diagnostics();
        }
        mock.close();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn routing_partition_is_disjoint_and_complete() {
        let (conn, mock) = plain_conn();
        let c1 = conn.clone();
        let h1 = tokio::spawn(async move { c1.execute(Command::Echo, &EchoRequest, None).await });
        wait_for_sent(&mock, 1).await;
        let c2 = conn.clone();
        let h2 = tokio::spawn(async move { c2.execute(Command::Echo, &EchoRequest, None).await });
        wait_for_sent(&mock, 2).await;
        mock.queue_response(echo_ok(MessageId(0)));
        h1.await.unwrap().unwrap();
        h2.abort();
        let _ = h2.await;
        mock.queue_response(echo_ok(MessageId(1)));
        mock.queue_response(echo_ok(MessageId(999_999)));

        let partition_total = || {
            let m = conn.metrics();
            m.responses_routed_ok
                + m.responses_routed_err
                + m.responses_late_after_drop
                + m.responses_stray
        };
        // Wait on the counters themselves: the mock queue can drain before the
        // last frame's counter has been bumped.
        wait_until(Duration::from_secs(2), || partition_total() >= 3).await;
        // Brief settle so an erroneous extra tick (double counting) would show up.
        tokio::time::sleep(Duration::from_millis(50)).await;

        let m = conn.metrics();
        assert_eq!(m.responses_routed_ok, 1);
        assert_eq!(m.responses_routed_err, 0);
        assert_eq!(m.responses_late_after_drop, 1);
        assert_eq!(m.responses_stray, 1);
        assert_eq!(
            m.responses_routed_ok
                + m.responses_routed_err
                + m.responses_late_after_drop
                + m.responses_stray,
            3
        );
        mock.close();
    }
}