use std::collections::{HashMap, HashSet};
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
use std::sync::{Arc, Mutex as StdMutex, OnceLock};
use std::time::Duration;
use log::{debug, info, trace, warn};
use tokio::sync::oneshot;
use crate::crypto::compression::{compress_message, decompress_message, CompressedMessage};
use crate::crypto::encryption::{self, Cipher, NonceGenerator};
use crate::crypto::kdf::PreauthHasher;
use crate::crypto::signing::{self, SigningAlgorithm};
use crate::error::{Error, Result};
use crate::msg::header::Header;
use crate::msg::negotiate::{
NegotiateContext, NegotiateRequest, NegotiateResponse, CIPHER_AES_128_CCM, CIPHER_AES_128_GCM,
CIPHER_AES_256_CCM, CIPHER_AES_256_GCM, COMPRESSION_LZ4, HASH_ALGORITHM_SHA512,
SIGNING_AES_CMAC, SIGNING_AES_GMAC, SIGNING_HMAC_SHA256,
};
use crate::msg::transform::{
CompressionTransformHeader, TransformHeader, COMPRESSION_ALGORITHM_LZ4,
COMPRESSION_PROTOCOL_ID, SMB2_COMPRESSION_FLAG_NONE, TRANSFORM_PROTOCOL_ID,
};
use crate::pack::{Guid, Pack, ReadCursor, Unpack, WriteCursor};
use crate::transport::{TcpTransport, TransportReceive, TransportSend};
use crate::types::flags::{Capabilities, HeaderFlags, SecurityMode};
use crate::types::status::NtStatus;
use crate::types::{Command, CreditCharge, Dialect, MessageId, SessionId, TreeId};
/// Parameters agreed with the server during the SMB2 NEGOTIATE exchange.
///
/// Stored by [`Connection::negotiate`] and read back via [`Connection::params`].
#[derive(Debug, Clone)]
pub struct NegotiatedParams {
    /// Dialect revision selected by the server.
    pub dialect: Dialect,
    /// Server MaxReadSize (negotiate rejects values below 65536).
    pub max_read_size: u32,
    /// Server MaxWriteSize (negotiate rejects values below 65536).
    pub max_write_size: u32,
    /// Server MaxTransactSize (not range-checked by negotiate).
    pub max_transact_size: u32,
    /// ServerGuid from the negotiate response.
    pub server_guid: Guid,
    /// Whether the server's SecurityMode reports signing as required.
    pub signing_required: bool,
    /// Capabilities advertised by the server.
    pub capabilities: Capabilities,
    /// True when the server's signing negotiate context listed AES-GMAC.
    pub gmac_negotiated: bool,
    /// Cipher taken from the first entry of the server's encryption context;
    /// `None` if the context is absent or the cipher ID is unrecognized.
    pub cipher: Option<Cipher>,
    /// True when the server's compression negotiate context listed LZ4.
    pub compression_supported: bool,
}
/// One SMB2 message received from the server, after any decryption,
/// decompression and compound splitting.
#[derive(Debug)]
pub struct Frame {
    /// Parsed SMB2 header.
    pub header: Header,
    /// Bytes following the fixed-size header (empty when there are none).
    pub body: Vec<u8>,
    /// The complete sub-message (header + body), e.g. for preauth hashing.
    pub raw: Vec<u8>,
}
/// One request inside a compound sent via [`Connection::execute_compound`].
pub struct CompoundOp<'a> {
    /// SMB2 command for this sub-request.
    pub command: Command,
    /// Request body, serialized directly after the header.
    pub body: &'a dyn Pack,
    /// Tree ID copied as-is into the sub-request header.
    pub tree_id: Option<TreeId>,
    /// Credit charge; values below 1 are sent as 1.
    pub credit_charge: CreditCharge,
}
impl<'a> CompoundOp<'a> {
pub fn new(command: Command, body: &'a dyn Pack, tree_id: Option<TreeId>) -> Self {
Self {
command,
body,
tree_id,
credit_charge: CreditCharge(1),
}
}
}
/// Per-connection signing/encryption state, kept behind `Inner::crypto`.
struct CryptoState {
    /// Key used to sign outgoing and verify incoming messages.
    signing_key: Option<Vec<u8>>,
    /// Algorithm paired with `signing_key`.
    signing_algorithm: Option<SigningAlgorithm>,
    /// Set by `activate_signing`; on send it only applies when encryption is off.
    should_sign: bool,
    /// Key used by `encrypt_bytes` for outgoing messages.
    encryption_key: Option<Vec<u8>>,
    /// Key used by `decrypt_frame` for incoming transform messages.
    decryption_key: Option<Vec<u8>>,
    /// Cipher used for both encryption and decryption.
    encryption_cipher: Option<Cipher>,
    /// Set by `activate_encryption`; takes precedence over signing on send.
    should_encrypt: bool,
    /// Produces the per-message nonce for each encrypted send.
    nonce_gen: Option<NonceGenerator>,
    /// Session ID stamped into request headers and transform headers.
    session_id: SessionId,
}
impl CryptoState {
fn new() -> Self {
Self {
signing_key: None,
signing_algorithm: None,
should_sign: false,
encryption_key: None,
decryption_key: None,
encryption_cipher: None,
should_encrypt: false,
nonce_gen: None,
session_id: SessionId::NONE,
}
}
}
/// Shared connection state, held via `Arc` by every `Connection` clone and by
/// the spawned receiver task.
struct Inner {
    /// Pending requests awaiting a response, keyed by message ID.
    waiters: StdMutex<HashMap<MessageId, oneshot::Sender<Result<Frame>>>>,
    /// Credit balance tracked from response headers (only ever stores u16 values).
    credits: AtomicU32,
    /// Next message ID to hand out; advanced by each request's credit charge.
    next_message_id: AtomicU64,
    /// Signing/encryption keys and the session ID.
    crypto: StdMutex<CryptoState>,
    /// Set once the receiver loop fails; new requests are then rejected.
    disconnected: AtomicBool,
    /// Outgoing half of the transport.
    sender: Arc<dyn TransportSend>,
    /// Handle of the spawned receiver task, aborted on drop.
    receiver_task: StdMutex<Option<tokio::task::JoinHandle<()>>>,
    /// Server name supplied at construction (the host part of the address for `connect`).
    server_name: String,
    /// Negotiated parameters; set at most once.
    params: OnceLock<NegotiatedParams>,
    /// Round-trip time measured for the negotiate exchange.
    estimated_rtt: StdMutex<Option<Duration>>,
    /// Whether outgoing messages may be compressed (requested AND server supports LZ4).
    compression_enabled: AtomicBool,
    /// Whether to offer compression in the negotiate request.
    compression_requested: AtomicBool,
    /// Running preauth integrity hash (negotiate messages plus caller-fed data).
    preauth_hasher: StdMutex<PreauthHasher>,
    /// Tree IDs whose requests get the DFS_OPERATIONS header flag.
    dfs_trees: StdMutex<HashSet<TreeId>>,
}
impl Inner {
fn new(sender: Arc<dyn TransportSend>, server_name: String) -> Self {
Self {
waiters: StdMutex::new(HashMap::new()),
credits: AtomicU32::new(1),
next_message_id: AtomicU64::new(0),
crypto: StdMutex::new(CryptoState::new()),
disconnected: AtomicBool::new(false),
sender,
receiver_task: StdMutex::new(None),
server_name,
params: OnceLock::new(),
estimated_rtt: StdMutex::new(None),
compression_enabled: AtomicBool::new(false),
compression_requested: AtomicBool::new(true),
preauth_hasher: StdMutex::new(PreauthHasher::new()),
dfs_trees: StdMutex::new(HashSet::new()),
}
}
}
impl Drop for Inner {
    /// Aborts the background receiver task, if one was attached.
    ///
    /// NOTE(review): `Connection::from_transport` moves its own `Arc<Inner>`
    /// into the spawned task, so this destructor presumably only runs once that
    /// task has finished — confirm whether the abort here is ever load-bearing.
    fn drop(&mut self) {
        // `&mut self` already guarantees exclusive access, so no locking is
        // needed. Recover from poisoning instead of unwrapping: a panic inside
        // `drop` while another panic is unwinding aborts the whole process.
        let slot = self
            .receiver_task
            .get_mut()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        if let Some(handle) = slot.take() {
            handle.abort();
        }
    }
}
/// Cheaply cloneable handle to a multiplexed SMB2 connection.
///
/// All clones share one `Inner`: the transport, credit and message-ID
/// counters, crypto state, and pending-request table.
#[derive(Clone)]
pub struct Connection {
    inner: Arc<Inner>,
}
impl Connection {
pub fn from_transport(
sender: Box<dyn TransportSend>,
receiver: Box<dyn TransportReceive>,
server_name: impl Into<String>,
) -> Self {
let sender: Arc<dyn TransportSend> = Arc::from(sender);
let inner = Arc::new(Inner::new(sender, server_name.into()));
let inner_for_task = Arc::clone(&inner);
let handle = tokio::spawn(async move {
receiver_loop(receiver, inner_for_task).await;
});
*inner.receiver_task.lock().unwrap() = Some(handle);
Self { inner }
}
/// Opens a TCP connection to `addr` and wraps it in a [`Connection`].
///
/// `addr` is passed unchanged to `TcpTransport::connect`. The host part of it
/// becomes the connection's server name. Accepted forms are `host`,
/// `host:port`, `[ipv6]` and `[ipv6]:port`. An unbracketed IPv6 literal
/// (more than one colon) is kept whole.
///
/// # Errors
/// Propagates any error from `TcpTransport::connect`.
pub async fn connect(addr: &str, timeout: Duration) -> Result<Self> {
    /// Extracts the host portion of `addr` without mangling IPv6 literals.
    fn host_of(addr: &str) -> &str {
        if let Some(rest) = addr.strip_prefix('[') {
            if let Some((host, _)) = rest.split_once(']') {
                return host;
            }
        }
        match addr.split_once(':') {
            // Exactly one colon: an ordinary `host:port`.
            Some((host, port)) if !port.contains(':') => host,
            // No colon at all, or a bare IPv6 literal such as `fe80::1`.
            _ => addr,
        }
    }
    let server_name = host_of(addr).to_string();
    let transport = TcpTransport::connect(addr, timeout).await?;
    info!("connection: connected to {}", addr);
    let transport = Arc::new(transport);
    Ok(Self::from_transport(
        Box::new(Arc::clone(&transport)),
        Box::new(transport),
        server_name,
    ))
}
/// Performs the SMB2 NEGOTIATE exchange and records the outcome.
///
/// Offers every dialect in `Dialect::ALL` together with preauth-integrity
/// (SHA-512), encryption, signing and, when requested, LZ4 compression
/// negotiate contexts. The request and response bytes are fed into the preauth
/// hasher, the round-trip time is recorded, and the validated response is
/// stored as [`NegotiatedParams`]. Because `params` is a `OnceLock`, only the
/// first successful call's values are kept.
///
/// # Errors
/// Fails if the send fails, the connection drops before a reply, the reply is
/// not a successful Negotiate response, the server selects a dialect that was
/// not offered, or MaxReadSize/MaxWriteSize is below 65536.
pub async fn negotiate(&mut self) -> Result<()> {
    debug!("negotiate: sending request, dialects={:?}", Dialect::ALL);
    let client_guid = generate_guid();
    let mut negotiate_contexts = vec![
        NegotiateContext::PreauthIntegrity {
            hash_algorithms: vec![HASH_ALGORITHM_SHA512],
            salt: generate_salt(),
        },
        NegotiateContext::Encryption {
            ciphers: vec![
                CIPHER_AES_128_GCM,
                CIPHER_AES_128_CCM,
                CIPHER_AES_256_GCM,
                CIPHER_AES_256_CCM,
            ],
        },
        NegotiateContext::Signing {
            algorithms: vec![SIGNING_AES_GMAC, SIGNING_AES_CMAC, SIGNING_HMAC_SHA256],
        },
    ];
    // Compression is only advertised when the caller has not opted out.
    if self.inner.compression_requested.load(Ordering::Acquire) {
        negotiate_contexts.push(NegotiateContext::Compression {
            flags: 0,
            algorithms: vec![COMPRESSION_LZ4],
        });
    }
    let request = NegotiateRequest {
        security_mode: SecurityMode::new(SecurityMode::SIGNING_ENABLED),
        capabilities: Capabilities::new(
            Capabilities::DFS | Capabilities::LEASING | Capabilities::LARGE_MTU,
        ),
        client_guid,
        dialects: Dialect::ALL.to_vec(),
        negotiate_contexts,
    };
    let mut header = Header::new_request(Command::Negotiate);
    let msg_id = self.allocate_msg_id(1);
    header.message_id = msg_id;
    header.credits = 1;
    let req_bytes = pack_message(&header, &request);
    // Hash the request exactly as it goes on the wire (before sending).
    self.inner.preauth_hasher.lock().unwrap().update(&req_bytes);
    // Register before sending so a fast reply cannot arrive without a waiter.
    let rx = self.register_waiter(msg_id)?;
    let rtt_start = std::time::Instant::now();
    if let Err(e) = self.inner.sender.send(&req_bytes).await {
        self.remove_waiter(msg_id);
        return Err(e);
    }
    let frame = await_frame(rx).await?;
    *self.inner.estimated_rtt.lock().unwrap() = Some(rtt_start.elapsed());
    // The response is hashed before any validation below.
    self.inner.preauth_hasher.lock().unwrap().update(&frame.raw);
    let resp_header = &frame.header;
    if !resp_header.is_response() {
        return Err(Error::invalid_data("expected a response but got a request"));
    }
    if resp_header.command != Command::Negotiate {
        return Err(Error::invalid_data(format!(
            "expected Negotiate response, got {:?}",
            resp_header.command
        )));
    }
    if resp_header.status != NtStatus::SUCCESS {
        return Err(Error::Protocol {
            status: resp_header.status,
            command: Command::Negotiate,
        });
    }
    let mut cursor = ReadCursor::new(&frame.body);
    let resp = NegotiateResponse::unpack(&mut cursor)?;
    if !Dialect::ALL.contains(&resp.dialect_revision) {
        return Err(Error::invalid_data(format!(
            "server selected dialect 0x{:04X} which we did not offer",
            u16::from(resp.dialect_revision)
        )));
    }
    if resp.max_read_size < 65536 {
        return Err(Error::invalid_data(format!(
            "MaxReadSize {} is below minimum 65536",
            resp.max_read_size
        )));
    }
    if resp.max_write_size < 65536 {
        return Err(Error::invalid_data(format!(
            "MaxWriteSize {} is below minimum 65536",
            resp.max_write_size
        )));
    }
    let mut gmac_negotiated = false;
    let mut cipher = None;
    let mut compression_supported = false;
    for ctx in &resp.negotiate_contexts {
        match ctx {
            NegotiateContext::Signing { algorithms }
                if algorithms.contains(&SIGNING_AES_GMAC) =>
            {
                gmac_negotiated = true;
            }
            // Only the first listed cipher is considered; unknown IDs yield None.
            NegotiateContext::Encryption { ciphers } => {
                if let Some(&c) = ciphers.first() {
                    cipher = match c {
                        CIPHER_AES_128_CCM => Some(Cipher::Aes128Ccm),
                        CIPHER_AES_128_GCM => Some(Cipher::Aes128Gcm),
                        CIPHER_AES_256_CCM => Some(Cipher::Aes256Ccm),
                        CIPHER_AES_256_GCM => Some(Cipher::Aes256Gcm),
                        _ => None,
                    };
                }
            }
            NegotiateContext::Compression { algorithms, .. }
                if algorithms.contains(&COMPRESSION_LZ4) =>
            {
                compression_supported = true;
            }
            _ => {}
        }
    }
    let signing_required = resp.security_mode.signing_required();
    // Compression is used only if we asked for it AND the server accepted LZ4.
    let compression_enabled =
        self.inner.compression_requested.load(Ordering::Acquire) && compression_supported;
    self.inner
        .compression_enabled
        .store(compression_enabled, Ordering::Release);
    // OnceLock: a repeated negotiate leaves the first stored params in place.
    let _ = self.inner.params.set(NegotiatedParams {
        dialect: resp.dialect_revision,
        max_read_size: resp.max_read_size,
        max_write_size: resp.max_write_size,
        max_transact_size: resp.max_transact_size,
        server_guid: resp.server_guid,
        signing_required,
        capabilities: resp.capabilities,
        gmac_negotiated,
        cipher,
        compression_supported,
    });
    info!(
        "negotiate: dialect={}, signing_required={}, capabilities={:?}",
        resp.dialect_revision, signing_required, resp.capabilities
    );
    debug!(
        "negotiate: max_read={}, max_write={}, max_transact={}, server_guid={:?}, cipher={:?}, gmac={}, compression={}",
        resp.max_read_size, resp.max_write_size, resp.max_transact_size,
        resp.server_guid, cipher, gmac_negotiated, compression_enabled
    );
    Ok(())
}
/// Round-trip time measured during `negotiate`, once it has completed.
pub fn estimated_rtt(&self) -> Option<Duration> {
    let rtt = self.inner.estimated_rtt.lock().unwrap();
    *rtt
}
/// Negotiated parameters, or `None` until `negotiate` has succeeded.
pub fn params(&self) -> Option<&NegotiatedParams> {
    let slot = &self.inner.params;
    slot.get()
}
/// Returns a snapshot copy of the current preauth integrity hash state.
pub fn preauth_hasher(&self) -> PreauthHasher {
    let guard = self.inner.preauth_hasher.lock().unwrap();
    guard.clone()
}
/// Runs `f` with exclusive access to the preauth hasher; the lock is held for
/// the duration of the call.
#[doc(hidden)]
pub fn with_preauth_hasher_mut<R>(&self, f: impl FnOnce(&mut PreauthHasher) -> R) -> R {
    f(&mut *self.inner.preauth_hasher.lock().unwrap())
}
/// Sets the session ID stamped into subsequent request headers.
pub fn set_session_id(&mut self, id: SessionId) {
    let mut crypto = self.inner.crypto.lock().unwrap();
    crypto.session_id = id;
}
/// Current session ID.
pub fn session_id(&self) -> SessionId {
    let crypto = self.inner.crypto.lock().unwrap();
    crypto.session_id
}
/// Installs a signing key and turns signing on for subsequent messages.
pub fn activate_signing(&mut self, key: Vec<u8>, algorithm: SigningAlgorithm) {
    let key_len = key.len();
    debug!("signing: activated, algo={:?}, key_len={}", algorithm, key_len);
    let mut crypto = self.inner.crypto.lock().unwrap();
    crypto.should_sign = true;
    crypto.signing_algorithm = Some(algorithm);
    crypto.signing_key = Some(key);
}
/// Installs encryption/decryption keys with a fresh nonce generator and turns
/// encryption on for subsequent messages.
pub fn activate_encryption(&mut self, enc_key: Vec<u8>, dec_key: Vec<u8>, cipher: Cipher) {
    let (enc_len, dec_len) = (enc_key.len(), dec_key.len());
    debug!(
        "encryption: activated, cipher={:?}, enc_key_len={}, dec_key_len={}",
        cipher, enc_len, dec_len
    );
    let mut crypto = self.inner.crypto.lock().unwrap();
    crypto.should_encrypt = true;
    crypto.nonce_gen = Some(NonceGenerator::new());
    crypto.encryption_cipher = Some(cipher);
    crypto.decryption_key = Some(dec_key);
    crypto.encryption_key = Some(enc_key);
}
/// Whether outgoing messages are currently encrypted.
pub fn should_encrypt(&self) -> bool {
    let crypto = self.inner.crypto.lock().unwrap();
    crypto.should_encrypt
}
/// Current credit balance as tracked from response headers.
pub fn credits(&self) -> u16 {
    let raw = self.inner.credits.load(Ordering::Acquire);
    raw as u16
}
/// Message ID that the next request will be assigned.
pub fn next_message_id(&self) -> u64 {
    let counter = &self.inner.next_message_id;
    counter.load(Ordering::Acquire)
}
/// Server name this connection was created with.
pub fn server_name(&self) -> &str {
    self.inner.server_name.as_str()
}
/// Chooses whether the next `negotiate` offers compression.
pub fn set_compression_requested(&mut self, requested: bool) {
    let flag = &self.inner.compression_requested;
    flag.store(requested, Ordering::Release);
}
/// Whether compression was negotiated and is in use for outgoing messages.
pub fn compression_enabled(&self) -> bool {
    let flag = &self.inner.compression_enabled;
    flag.load(Ordering::Acquire)
}
/// Sends one request with a credit charge of 1 and waits for its response.
pub async fn execute(
    &self,
    command: Command,
    body: &dyn Pack,
    tree_id: Option<TreeId>,
) -> Result<Frame> {
    let single = CreditCharge(1);
    self.execute_with_credits(command, body, tree_id, single).await
}
/// Like [`Connection::execute`], but also returns the packed request bytes.
pub(crate) async fn execute_capturing_request(
    &self,
    command: Command,
    body: &dyn Pack,
    tree_id: Option<TreeId>,
) -> Result<(Frame, Vec<u8>)> {
    let single = CreditCharge(1);
    self.execute_with_credits_capturing_request(command, body, tree_id, single)
        .await
}
/// Like [`Connection::execute_with_credits`], but also returns a copy of the
/// packed request.
///
/// The copy is taken before the message is signed or encrypted. Unlike
/// `execute_with_credits`, this path never compresses.
/// NOTE(review): presumably the capture feeds the preauth integrity hash —
/// confirm that hashing the pre-signature bytes is intended when signing is on.
///
/// # Errors
/// `Error::Disconnected` if the connection is already down, or any error from
/// waiter registration, encryption, signing, sending, or the response.
pub(crate) async fn execute_with_credits_capturing_request(
    &self,
    command: Command,
    body: &dyn Pack,
    tree_id: Option<TreeId>,
    credit_charge: CreditCharge,
) -> Result<(Frame, Vec<u8>)> {
    if self.inner.disconnected.load(Ordering::Acquire) {
        return Err(Error::Disconnected);
    }
    // A zero charge is treated as 1; the request consumes `charge` message IDs.
    let charge = credit_charge.0.max(1);
    let msg_id = self.allocate_msg_id(charge as u64);
    let mut header = Header::new_request(command);
    header.message_id = msg_id;
    header.credits = 256;
    header.credit_charge = CreditCharge(charge);
    header.session_id = self.session_id();
    if let Some(tid) = tree_id {
        header.tree_id = Some(tid);
    }
    let (should_sign, should_encrypt) = {
        let c = self.inner.crypto.lock().unwrap();
        (c.should_sign, c.should_encrypt)
    };
    // Encryption supersedes signing, so the SIGNED flag is only set without it.
    if should_sign && !should_encrypt {
        header.flags.set_signed();
    }
    if self.should_set_dfs_flag(tree_id) {
        header.flags |= HeaderFlags::new(HeaderFlags::DFS_OPERATIONS);
    }
    let mut msg_bytes = pack_message(&header, body);
    // Captured before signing/encryption below.
    let captured = msg_bytes.clone();
    // Register before sending; every early return after this removes the waiter.
    let rx = self.register_waiter(msg_id)?;
    let wire_bytes = if should_encrypt {
        match self.encrypt_bytes(&msg_bytes) {
            Ok(enc) => enc,
            Err(e) => {
                self.remove_waiter(msg_id);
                return Err(e);
            }
        }
    } else {
        if should_sign {
            let c = self.inner.crypto.lock().unwrap();
            if let (Some(key), Some(algo)) = (&c.signing_key, &c.signing_algorithm) {
                if let Err(e) =
                    signing::sign_message(&mut msg_bytes, key, *algo, msg_id.0, false)
                {
                    // Release the crypto lock before touching the waiters lock.
                    drop(c);
                    self.remove_waiter(msg_id);
                    return Err(e);
                }
            }
        }
        msg_bytes
    };
    if let Err(e) = self.inner.sender.send(&wire_bytes).await {
        self.remove_waiter(msg_id);
        return Err(e);
    }
    debug!(
        "execute_cap: cmd={:?}, msg_id={}, credit_charge={}, tree_id={:?}, signed={}, encrypted={}",
        command, msg_id.0, charge, tree_id, should_sign, should_encrypt
    );
    let frame = await_frame(rx).await?;
    Ok((frame, captured))
}
/// Sends one request charging `credit_charge` credits (minimum 1) and waits
/// for its response.
///
/// When encryption is active the request is encrypted. Otherwise it is signed
/// (if signing is active). Then, if compression is enabled and the message has
/// bytes beyond the header, it is sent LZ4-compressed whenever
/// `compress_message` returns a result.
///
/// # Errors
/// `Error::Disconnected` if the connection is already down, or any error from
/// waiter registration, encryption, signing, sending, or the response.
pub async fn execute_with_credits(
    &self,
    command: Command,
    body: &dyn Pack,
    tree_id: Option<TreeId>,
    credit_charge: CreditCharge,
) -> Result<Frame> {
    if self.inner.disconnected.load(Ordering::Acquire) {
        return Err(Error::Disconnected);
    }
    // A zero charge is treated as 1; the request consumes `charge` message IDs.
    let charge = credit_charge.0.max(1);
    let msg_id = self.allocate_msg_id(charge as u64);
    let mut header = Header::new_request(command);
    header.message_id = msg_id;
    header.credits = 256;
    header.credit_charge = CreditCharge(charge);
    header.session_id = self.session_id();
    if let Some(tid) = tree_id {
        header.tree_id = Some(tid);
    }
    let (should_sign, should_encrypt) = {
        let c = self.inner.crypto.lock().unwrap();
        (c.should_sign, c.should_encrypt)
    };
    // Encryption supersedes signing, so the SIGNED flag is only set without it.
    if should_sign && !should_encrypt {
        header.flags.set_signed();
    }
    if self.should_set_dfs_flag(tree_id) {
        header.flags |= HeaderFlags::new(HeaderFlags::DFS_OPERATIONS);
    }
    let mut msg_bytes = pack_message(&header, body);
    // Register before sending; every early return after this removes the waiter.
    let rx = self.register_waiter(msg_id)?;
    let wire_bytes = if should_encrypt {
        match self.encrypt_bytes(&msg_bytes) {
            Ok(enc) => enc,
            Err(e) => {
                self.remove_waiter(msg_id);
                return Err(e);
            }
        }
    } else {
        if should_sign {
            let c = self.inner.crypto.lock().unwrap();
            if let (Some(key), Some(algo)) = (&c.signing_key, &c.signing_algorithm) {
                if let Err(e) =
                    signing::sign_message(&mut msg_bytes, key, *algo, msg_id.0, false)
                {
                    // Release the crypto lock before touching the waiters lock.
                    drop(c);
                    self.remove_waiter(msg_id);
                    return Err(e);
                }
            }
        }
        // Compression is applied to the already-signed plaintext. If
        // `compress_message` yields nothing, the uncompressed bytes are sent.
        if self.compression_enabled() && msg_bytes.len() > Header::SIZE {
            if let Some(compressed) = compress_message(&msg_bytes, Header::SIZE) {
                let framed = build_compressed_frame(&compressed);
                match self.inner.sender.send(&framed).await {
                    Ok(()) => {
                        debug!(
                            "execute: cmd={:?}, msg_id={}, credit_charge={}, tree_id={:?}, signed={}, compressed {}->{} bytes",
                            command, msg_id.0, charge, tree_id, should_sign,
                            msg_bytes.len(), framed.len()
                        );
                        return await_frame(rx).await;
                    }
                    Err(e) => {
                        self.remove_waiter(msg_id);
                        return Err(e);
                    }
                }
            }
        }
        msg_bytes
    };
    if let Err(e) = self.inner.sender.send(&wire_bytes).await {
        self.remove_waiter(msg_id);
        return Err(e);
    }
    debug!(
        "execute: cmd={:?}, msg_id={}, credit_charge={}, tree_id={:?}, signed={}, encrypted={}, len={}",
        command, msg_id.0, charge, tree_id, should_sign, should_encrypt, wire_bytes.len()
    );
    await_frame(rx).await
}
pub async fn execute_compound(&self, ops: &[CompoundOp<'_>]) -> Result<Vec<Result<Frame>>> {
if ops.is_empty() {
return Err(Error::invalid_data(
"compound request must have at least one operation",
));
}
if self.inner.disconnected.load(Ordering::Acquire) {
return Err(Error::Disconnected);
}
let (should_sign, should_encrypt) = {
let c = self.inner.crypto.lock().unwrap();
(c.should_sign, c.should_encrypt)
};
let session_id = self.session_id();
let mut message_ids: Vec<MessageId> = Vec::with_capacity(ops.len());
let mut sub_requests: Vec<Vec<u8>> = Vec::with_capacity(ops.len());
for (i, op) in ops.iter().enumerate() {
let charge = op.credit_charge.0.max(1);
let msg_id = self.allocate_msg_id(charge as u64);
let mut header = Header::new_request(op.command);
header.message_id = msg_id;
header.credits = 256;
header.credit_charge = CreditCharge(charge);
header.session_id = session_id;
header.tree_id = op.tree_id;
if i > 0 {
header.flags.set_related();
}
if should_sign && !should_encrypt {
header.flags.set_signed();
}
if self.should_set_dfs_flag(op.tree_id) {
header.flags |= HeaderFlags::new(HeaderFlags::DFS_OPERATIONS);
}
message_ids.push(msg_id);
sub_requests.push(pack_message(&header, op.body));
}
let last_idx = sub_requests.len() - 1;
for sub_req in sub_requests.iter_mut().take(last_idx) {
let rem = sub_req.len() % 8;
if rem != 0 {
let pad = 8 - rem;
let new_len = sub_req.len() + pad;
sub_req.resize(new_len, 0);
}
}
for sub_req in sub_requests.iter_mut().take(last_idx) {
let next_cmd = sub_req.len() as u32;
sub_req[20..24].copy_from_slice(&next_cmd.to_le_bytes());
}
if should_sign && !should_encrypt {
let c = self.inner.crypto.lock().unwrap();
if let (Some(key), Some(algo)) = (&c.signing_key, &c.signing_algorithm) {
for (i, sub_req) in sub_requests.iter_mut().enumerate() {
signing::sign_message(sub_req, key, *algo, message_ids[i].0, false)?;
}
}
}
let mut receivers: Vec<oneshot::Receiver<Result<Frame>>> =
Vec::with_capacity(message_ids.len());
let mut registered: Vec<MessageId> = Vec::with_capacity(message_ids.len());
for id in &message_ids {
match self.register_waiter(*id) {
Ok(rx) => {
receivers.push(rx);
registered.push(*id);
}
Err(e) => {
for done in ®istered {
self.remove_waiter(*done);
}
return Err(e);
}
}
}
let total_len: usize = sub_requests.iter().map(|r| r.len()).sum();
let mut compound_buf = Vec::with_capacity(total_len);
for sub_req in &sub_requests {
compound_buf.extend_from_slice(sub_req);
}
let send_result = if should_encrypt {
match self.encrypt_bytes(&compound_buf) {
Ok(enc) => self.inner.sender.send(&enc).await,
Err(e) => {
for id in ®istered {
self.remove_waiter(*id);
}
return Err(e);
}
}
} else {
self.inner.sender.send(&compound_buf).await
};
if let Err(e) = send_result {
for id in ®istered {
self.remove_waiter(*id);
}
return Err(e);
}
debug!(
"execute_compound: {} operations, total_len={}, msg_ids={:?}, signed={}, encrypted={}",
ops.len(),
compound_buf.len(),
message_ids.iter().map(|m| m.0).collect::<Vec<_>>(),
should_sign,
should_encrypt,
);
let mut results: Vec<Result<Frame>> = Vec::with_capacity(receivers.len());
for rx in receivers {
results.push(await_frame(rx).await);
}
Ok(results)
}
/// Sends an SMB2 CANCEL for a previously issued request; no reply is awaited.
///
/// The cancel reuses `original_msg_id`, with zero credit charge and zero
/// credits requested. When `async_id` is given, the header switches to the
/// async form carrying that ID and the tree ID is cleared.
///
/// # Errors
/// Propagates encryption, signing, and transport send failures.
pub async fn send_cancel(
    &mut self,
    original_msg_id: MessageId,
    async_id: Option<u64>,
) -> Result<()> {
    use crate::msg::cancel::CancelRequest;
    let (should_sign, should_encrypt) = {
        let crypto = self.inner.crypto.lock().unwrap();
        (crypto.should_sign, crypto.should_encrypt)
    };
    let mut header = Header::new_request(Command::Cancel);
    header.message_id = original_msg_id;
    header.credit_charge = CreditCharge(0);
    header.credits = 0;
    header.session_id = self.session_id();
    if let Some(aid) = async_id {
        header.flags.set_async();
        header.async_id = Some(aid);
        header.tree_id = None;
    }
    if should_sign && !should_encrypt {
        header.flags.set_signed();
    }
    let mut msg_bytes = pack_message(&header, &CancelRequest);
    if should_encrypt {
        let encrypted = self.encrypt_bytes(&msg_bytes)?;
        self.inner.sender.send(&encrypted).await?;
        debug!(
            "send_cancel: msg_id={}, async_id={:?}, encrypted",
            original_msg_id.0, async_id
        );
        return Ok(());
    }
    if should_sign {
        // Scoped so the crypto guard is released before the send below.
        let crypto = self.inner.crypto.lock().unwrap();
        if let (Some(key), Some(algo)) = (&crypto.signing_key, &crypto.signing_algorithm) {
            signing::sign_message(&mut msg_bytes, key, *algo, original_msg_id.0, false)?;
        }
    }
    self.inner.sender.send(&msg_bytes).await?;
    debug!(
        "send_cancel: msg_id={}, async_id={:?}, signed={}",
        original_msg_id.0, async_id, should_sign
    );
    Ok(())
}
fn encrypt_bytes(&self, plaintext: &[u8]) -> Result<Vec<u8>> {
let mut c = self.inner.crypto.lock().unwrap();
let enc_key = c
.encryption_key
.as_ref()
.ok_or_else(|| Error::invalid_data("encryption active but no encryption key"))?
.clone();
let cipher = c
.encryption_cipher
.ok_or_else(|| Error::invalid_data("encryption active but no cipher"))?;
let session_id = c.session_id.0;
let nonce = c
.nonce_gen
.as_mut()
.ok_or_else(|| Error::invalid_data("encryption active but no nonce generator"))?
.next(cipher);
drop(c);
let (transform_header, ciphertext) =
encryption::encrypt_message(plaintext, &enc_key, cipher, &nonce, session_id)?;
let mut encrypted = transform_header;
encrypted.extend_from_slice(&ciphertext);
trace!(
"encrypt: plaintext={} bytes, encrypted={} bytes, nonce={:02X?}",
plaintext.len(),
encrypted.len(),
&nonce[..cipher.nonce_len()]
);
Ok(encrypted)
}
/// Marks `tree_id` as a DFS tree so its requests carry DFS_OPERATIONS.
pub fn register_dfs_tree(&mut self, tree_id: TreeId) {
    let mut trees = self.inner.dfs_trees.lock().unwrap();
    trees.insert(tree_id);
}
/// Stops treating `tree_id` as a DFS tree.
pub fn deregister_dfs_tree(&mut self, tree_id: TreeId) {
    let mut trees = self.inner.dfs_trees.lock().unwrap();
    trees.remove(&tree_id);
}
/// True when `tree_id` is present and registered as a DFS tree.
fn should_set_dfs_flag(&self, tree_id: Option<TreeId>) -> bool {
    match tree_id {
        Some(id) => self.inner.dfs_trees.lock().unwrap().contains(&id),
        None => false,
    }
}
/// Reserves `charge` consecutive message IDs and returns the first one.
fn allocate_msg_id(&self, charge: u64) -> MessageId {
    MessageId(self.inner.next_message_id.fetch_add(charge, Ordering::SeqCst))
}
/// Registers a one-shot slot that the receiver task fills for `msg_id`.
///
/// The disconnected flag is checked while holding the waiters lock; the
/// receiver's failure path sets it under that same lock, so no waiter can be
/// added after the pending set has been drained.
fn register_waiter(&self, msg_id: MessageId) -> Result<oneshot::Receiver<Result<Frame>>> {
    let mut table = self.inner.waiters.lock().unwrap();
    if self.inner.disconnected.load(Ordering::Acquire) {
        return Err(Error::Disconnected);
    }
    let (sender, receiver) = oneshot::channel();
    table.insert(msg_id, sender);
    Ok(receiver)
}
/// Forgets the waiter for `msg_id`, if any.
fn remove_waiter(&self, msg_id: MessageId) {
    let mut table = self.inner.waiters.lock().unwrap();
    table.remove(&msg_id);
}
/// Test hook: stores negotiated params (ignored if already set).
#[cfg(test)]
pub(crate) fn set_test_params(&mut self, params: NegotiatedParams) {
    let _ = self.inner.params.set(params);
}
/// Test hook: overwrites the tracked credit balance.
#[cfg(test)]
pub(crate) fn set_credits(&mut self, credits: u16) {
    let widened = u32::from(credits);
    self.inner.credits.store(widened, Ordering::Release);
}
/// Test hook: overwrites the next message ID to be allocated.
#[cfg(test)]
pub(crate) fn set_next_message_id(&mut self, id: u64) {
    let counter = &self.inner.next_message_id;
    counter.store(id, Ordering::Release);
}
}
/// Background task that reads frames from the transport and routes each
/// response to the waiter registered for its message ID.
///
/// Each frame is decrypted if it starts with the transform protocol ID, and
/// decompressed if it starts with the compression protocol ID. It is then
/// split into compound sub-frames and processed.
///
/// Any of these failures marks the connection disconnected, fails every
/// pending waiter, and ends the task:
/// - a transport error,
/// - a decrypt or decompress failure,
/// - a split failure,
/// - a sub-frame header parse failure.
async fn receiver_loop(transport_recv: Box<dyn TransportReceive>, inner: Arc<Inner>) {
    loop {
        let raw = match transport_recv.receive().await {
            Ok(bytes) => bytes,
            Err(e) => {
                debug!("receiver_loop: transport error: {}, shutting down", e);
                fan_error_to_waiters(&inner, &e);
                return;
            }
        };
        trace!("receiver_loop: received {} bytes", raw.len());
        // `was_encrypted` suppresses per-message signature checks later on.
        let (decoded, was_encrypted) = if raw.len() >= 4 && raw[0..4] == TRANSFORM_PROTOCOL_ID {
            match decrypt_frame(&raw, &inner) {
                Ok(plain) => (plain, true),
                Err(e) => {
                    warn!(
                        "receiver_loop: decrypt failed: {}; tearing down connection",
                        e
                    );
                    fan_error_to_waiters(&inner, &e);
                    return;
                }
            }
        } else {
            (raw, false)
        };
        // Decompression runs after decryption, so compressed payloads may arrive
        // inside an encrypted frame.
        let decoded = if decoded.len() >= 4 && decoded[0..4] == COMPRESSION_PROTOCOL_ID {
            match decompress_response(&decoded) {
                Ok(plain) => plain,
                Err(e) => {
                    warn!(
                        "receiver_loop: decompress failed: {}; tearing down connection",
                        e
                    );
                    fan_error_to_waiters(&inner, &e);
                    return;
                }
            }
        } else {
            decoded
        };
        let sub_frames = match split_compound(&decoded) {
            Ok(subs) => subs,
            Err(e) => {
                warn!(
                    "receiver_loop: malformed frame: {}; tearing down connection",
                    e
                );
                fan_error_to_waiters(&inner, &e);
                return;
            }
        };
        // Every sub-frame is processed before any is delivered, so a parse
        // failure in a later sub-frame tears down without partial delivery.
        let mut routable: Vec<(MessageId, Result<Frame>)> = Vec::new();
        for sub in sub_frames {
            match prepare_sub_frame(&sub, was_encrypted, &inner) {
                Ok(SubFrameAction::Route(msg_id, result)) => routable.push((msg_id, result)),
                Ok(SubFrameAction::Skip) => { }
                Err(e) => {
                    warn!(
                        "receiver_loop: sub-frame parse failed: {}; tearing down connection",
                        e
                    );
                    fan_error_to_waiters(&inner, &e);
                    return;
                }
            }
        }
        if routable.is_empty() {
            continue;
        }
        for (msg_id, result) in routable {
            // Take the sender out under the lock; send after the guard drops.
            let maybe_tx = inner.waiters.lock().unwrap().remove(&msg_id);
            match maybe_tx {
                Some(tx) => {
                    if tx.send(result).is_err() {
                        trace!("recv: late arrival for dropped waiter, msg_id={}", msg_id.0);
                    }
                }
                None => {
                    debug!("recv: orphan dropped, msg_id={}", msg_id.0);
                }
            }
        }
    }
}
/// Outcome of processing one sub-frame in the receiver loop.
pub(crate) enum SubFrameAction {
    /// Deliver this result to the waiter registered for the message ID.
    Route(MessageId, std::result::Result<Frame, Error>),
    /// Nothing to deliver (unsolicited message or interim STATUS_PENDING).
    Skip,
}
/// Parses one sub-frame, updates the credit balance, and decides where it goes.
///
/// Processing steps:
/// - Credits granted in the header are added for every sub-frame.
/// - Unsolicited messages (message ID `UNSOLICITED`) and interim
///   STATUS_PENDING responses are skipped.
/// - For final responses, `max(credit_charge, 1)` credits are subtracted
///   (saturating).
/// - The signature is verified when signing is active, the frame arrived
///   unencrypted, and the SIGNED flag is set.
/// - STATUS_NETWORK_SESSION_EXPIRED is routed as `Error::SessionExpired`.
///
/// NOTE(review): a response without the SIGNED flag is accepted even while
/// signing is active — confirm this is intended.
///
/// # Errors
/// Only a header parse failure is returned as `Err`, which tears the
/// connection down. Signature failures are routed to the individual waiter.
fn prepare_sub_frame(sub: &[u8], was_encrypted: bool, inner: &Inner) -> Result<SubFrameAction> {
    let mut cursor = ReadCursor::new(sub);
    let header = match Header::unpack(&mut cursor) {
        Ok(h) => h,
        Err(e) => {
            return Err(Error::invalid_data(format!(
                "sub-frame header parse failed: {}",
                e
            )));
        }
    };
    // Load/store (not fetch_add) is fine only because this task is the sole
    // writer of `credits` outside of test hooks.
    if header.credits > 0 {
        let prev = inner.credits.load(Ordering::Relaxed) as u16;
        let next = prev.saturating_add(header.credits);
        inner.credits.store(next as u32, Ordering::Release);
    }
    if header.message_id == MessageId::UNSOLICITED {
        debug!(
            "recv: skipping unsolicited oplock break notification, cmd={:?}",
            header.command
        );
        return Ok(SubFrameAction::Skip);
    }
    if header.status.is_pending() {
        debug!(
            "recv: STATUS_PENDING (interim), cmd={:?}, msg_id={}",
            header.command, header.message_id.0
        );
        return Ok(SubFrameAction::Skip);
    }
    let consume = header.credit_charge.0.max(1);
    let prev = inner.credits.load(Ordering::Relaxed) as u16;
    inner
        .credits
        .store(prev.saturating_sub(consume) as u32, Ordering::Release);
    let (should_sign, signing_key, signing_algorithm) = {
        let c = inner.crypto.lock().unwrap();
        (c.should_sign, c.signing_key.clone(), c.signing_algorithm)
    };
    // Encrypted frames are authenticated by decryption, so only plaintext
    // frames are signature-checked. Flags live at header bytes 16..20 and
    // Status at 8..12.
    if should_sign && !was_encrypted && sub.len() >= Header::SIZE {
        let flags = u32::from_le_bytes(sub[16..20].try_into().unwrap());
        let is_signed = (flags & HeaderFlags::SIGNED) != 0;
        let status = u32::from_le_bytes(sub[8..12].try_into().unwrap());
        // Always false here: pending frames already returned above.
        let is_pending = status == NtStatus::PENDING.0;
        if is_signed && !is_pending {
            if let (Some(key), Some(algo)) = (signing_key, signing_algorithm) {
                if let Err(e) =
                    signing::verify_signature(sub, &key, algo, header.message_id.0, false)
                {
                    return Ok(SubFrameAction::Route(header.message_id, Err(e)));
                }
            }
        }
    }
    if header.status == NtStatus::NETWORK_SESSION_EXPIRED {
        warn!(
            "recv: session expired (STATUS_NETWORK_SESSION_EXPIRED), cmd={:?}, msg_id={}",
            header.command, header.message_id.0
        );
        return Ok(SubFrameAction::Route(
            header.message_id,
            Err(Error::SessionExpired),
        ));
    }
    let body = if sub.len() > Header::SIZE {
        sub[Header::SIZE..].to_vec()
    } else {
        Vec::new()
    };
    let raw = sub.to_vec();
    let msg_id = header.message_id;
    Ok(SubFrameAction::Route(
        msg_id,
        Ok(Frame { header, body, raw }),
    ))
}
/// Marks the connection disconnected and fails every pending request.
///
/// The flag is set while holding the waiters lock (the lock `register_waiter`
/// checks it under). The senders are notified only after that lock is
/// released. Every waiter receives `Error::Disconnected`, whatever `e` was.
fn fan_error_to_waiters(inner: &Inner, e: &Error) {
    let pending: Vec<oneshot::Sender<Result<Frame>>> = {
        let mut table = inner.waiters.lock().unwrap();
        inner.disconnected.store(true, Ordering::Release);
        table.drain().map(|(_, tx)| tx).collect()
    };
    pending.into_iter().for_each(|tx| {
        let _ = tx.send(Err(clone_err_as_disconnected(e)));
    });
}
/// Maps any receiver-side error to `Error::Disconnected` for delivery to waiters.
fn clone_err_as_disconnected(_e: &Error) -> Error {
    Error::Disconnected
}
fn decrypt_frame(data: &[u8], inner: &Inner) -> Result<Vec<u8>> {
let c = inner.crypto.lock().unwrap();
let dec_key = c
.decryption_key
.as_ref()
.ok_or_else(|| Error::invalid_data("received encrypted message but no decryption key"))?
.clone();
let cipher = c
.encryption_cipher
.ok_or_else(|| Error::invalid_data("received encrypted message but no cipher"))?;
drop(c);
if data.len() < TransformHeader::SIZE {
return Err(Error::invalid_data(
"encrypted message too short for TransformHeader",
));
}
let transform_header = &data[..TransformHeader::SIZE];
let ciphertext = &data[TransformHeader::SIZE..];
let plaintext = encryption::decrypt_message(transform_header, ciphertext, &dec_key, cipher)?;
Ok(plaintext)
}
/// Splits a possibly-compound SMB2 response into its individual messages.
///
/// Each message's NextCommand field (header bytes 20..24) gives the offset of
/// the next message relative to the current one; 0 marks the last message.
/// Every message after the first must begin on an 8-byte boundary.
///
/// # Errors
/// Returns `InvalidData` in these cases:
/// - a header would be truncated;
/// - a follow-on message is misaligned;
/// - a non-zero NextCommand is smaller than a header;
/// - a NextCommand points past the end of `data`.
///
/// Offsets use checked arithmetic, so a hostile NextCommand cannot overflow
/// `usize` (possible on 32-bit targets) and cause a slicing panic.
fn split_compound(data: &[u8]) -> Result<Vec<Vec<u8>>> {
    let mut results = Vec::new();
    let mut offset = 0usize;
    loop {
        if offset + Header::SIZE > data.len() {
            return Err(Error::invalid_data(format!(
                "compound response truncated at offset {}: need {} bytes for header, but only {} remain",
                offset,
                Header::SIZE,
                data.len() - offset,
            )));
        }
        if !results.is_empty() && offset % 8 != 0 {
            return Err(Error::invalid_data(format!(
                "compound response at offset {} is not 8-byte aligned -- must disconnect",
                offset,
            )));
        }
        let next_cmd = u32::from_le_bytes(data[offset + 20..offset + 24].try_into().unwrap());
        if next_cmd == 0 {
            // Last message: it extends to the end of the buffer.
            results.push(data[offset..].to_vec());
            break;
        }
        let step = next_cmd as usize;
        // A linked message must at least contain its own header.
        if step < Header::SIZE {
            return Err(Error::invalid_data(format!(
                "compound NextCommand offset {} at position {} is smaller than the {}-byte header",
                next_cmd,
                offset,
                Header::SIZE,
            )));
        }
        let sub_end = match offset.checked_add(step) {
            Some(end) if end <= data.len() => end,
            _ => {
                return Err(Error::invalid_data(format!(
                    "compound NextCommand offset {} at position {} exceeds response length {}",
                    next_cmd,
                    offset,
                    data.len(),
                )));
            }
        };
        results.push(data[offset..sub_end].to_vec());
        offset = sub_end;
    }
    Ok(results)
}
async fn await_frame(rx: oneshot::Receiver<Result<Frame>>) -> Result<Frame> {
match rx.await {
Ok(Ok(frame)) => Ok(frame),
Ok(Err(e)) => Err(e),
Err(_canceled) => Err(Error::Disconnected),
}
}
/// Serializes `header` immediately followed by `body` into a fresh buffer.
pub(crate) fn pack_message(header: &Header, body: &dyn Pack) -> Vec<u8> {
    let mut writer = WriteCursor::new();
    header.pack(&mut writer);
    body.pack(&mut writer);
    writer.into_inner()
}
/// Produces a random client GUID from 16 OS-random bytes.
///
/// The first three fields are decoded little-endian; `data4` is the raw
/// trailing 8 bytes.
///
/// # Panics
/// Panics if the OS random source fails.
fn generate_guid() -> Guid {
    let mut random = [0u8; 16];
    getrandom::fill(&mut random).expect("failed to generate random GUID");
    let (head, tail) = random.split_at(8);
    Guid {
        data1: u32::from_le_bytes(head[0..4].try_into().unwrap()),
        data2: u16::from_le_bytes(head[4..6].try_into().unwrap()),
        data3: u16::from_le_bytes(head[6..8].try_into().unwrap()),
        data4: tail.try_into().unwrap(),
    }
}
/// Produces a 32-byte random salt for the preauth integrity context.
///
/// # Panics
/// Panics if the OS random source fails.
fn generate_salt() -> Vec<u8> {
    let mut bytes = [0u8; 32];
    getrandom::fill(&mut bytes).expect("failed to generate random salt");
    bytes.to_vec()
}
/// Builds an unchained LZ4 compression frame.
///
/// The frame is the compression transform header, then the uncompressed
/// prefix, then the compressed data.
fn build_compressed_frame(compressed: &CompressedMessage) -> Vec<u8> {
    let mut cursor = WriteCursor::new();
    CompressionTransformHeader {
        original_compressed_segment_size: compressed.original_size,
        compression_algorithm: COMPRESSION_ALGORITHM_LZ4,
        flags: SMB2_COMPRESSION_FLAG_NONE,
        offset_or_length: compressed.offset,
    }
    .pack(&mut cursor);
    let mut out = cursor.into_inner();
    out.extend_from_slice(&compressed.uncompressed_prefix);
    out.extend_from_slice(&compressed.compressed_data);
    out
}
/// Decompresses an unchained LZ4 compression-transform message.
///
/// After the header, the first `offset_or_length` bytes are an uncompressed
/// prefix and the remainder is LZ4 data.
///
/// # Errors
/// Fails if `data` is shorter than the header, the algorithm is not LZ4, the
/// flags are not "unchained", the offset exceeds the payload, or
/// decompression itself fails.
fn decompress_response(data: &[u8]) -> Result<Vec<u8>> {
    if data.len() < CompressionTransformHeader::SIZE {
        return Err(Error::invalid_data(
            "compressed response too short for CompressionTransformHeader",
        ));
    }
    let header = CompressionTransformHeader::unpack(&mut ReadCursor::new(data))?;
    if header.compression_algorithm != COMPRESSION_ALGORITHM_LZ4 {
        return Err(Error::invalid_data(format!(
            "unsupported compression algorithm 0x{:04X}, only LZ4 (0x{:04X}) is supported",
            header.compression_algorithm, COMPRESSION_ALGORITHM_LZ4
        )));
    }
    if header.flags != SMB2_COMPRESSION_FLAG_NONE {
        return Err(Error::invalid_data(format!(
            "unsupported compression flags 0x{:04X}, only unchained (0x0000) is supported",
            header.flags
        )));
    }
    let payload = &data[CompressionTransformHeader::SIZE..];
    let split = header.offset_or_length as usize;
    if split > payload.len() {
        return Err(Error::invalid_data(format!(
            "compression offset {} exceeds remaining data length {}",
            split,
            payload.len()
        )));
    }
    let (uncompressed_prefix, compressed_data) = payload.split_at(split);
    decompress_message(
        uncompressed_prefix,
        compressed_data,
        header.original_compressed_segment_size,
    )
}
/// Lets one `Arc`-shared transport serve as the sending half.
#[async_trait::async_trait]
impl<T: TransportSend> TransportSend for Arc<T> {
    async fn send(&self, data: &[u8]) -> Result<()> {
        T::send(self.as_ref(), data).await
    }
}
/// Lets one `Arc`-shared transport serve as the receiving half.
#[async_trait::async_trait]
impl<T: TransportReceive> TransportReceive for Arc<T> {
    async fn receive(&self) -> Result<Vec<u8>> {
        T::receive(self.as_ref()).await
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::msg::negotiate::{NegotiateContext, HASH_ALGORITHM_SHA512};
use crate::transport::MockTransport;
use crate::types::flags::HeaderFlags;
/// Concatenates `responses` into one compound frame.
///
/// Every response except the last is zero-padded to a multiple of 8 bytes and
/// gets its NextCommand field (bytes 20..24) set to its padded length.
fn build_compound_response_frame(responses: &[Vec<u8>]) -> Vec<u8> {
    let last = responses.len().saturating_sub(1);
    let mut frame = Vec::new();
    for (i, resp) in responses.iter().enumerate() {
        let mut chunk = resp.clone();
        if i != last {
            let pad = (8 - chunk.len() % 8) % 8;
            chunk.resize(chunk.len() + pad, 0);
            let next_cmd = chunk.len() as u32;
            chunk[20..24].copy_from_slice(&next_cmd.to_le_bytes());
        }
        frame.extend_from_slice(&chunk);
    }
    frame
}
fn build_negotiate_response(dialect: Dialect) -> Vec<u8> {
let resp_header = {
let mut h = Header::new_request(Command::Negotiate);
h.flags.set_response();
h.credits = 32;
h
};
let resp_body = NegotiateResponse {
security_mode: SecurityMode::new(SecurityMode::SIGNING_ENABLED),
dialect_revision: dialect,
server_guid: Guid::ZERO,
capabilities: Capabilities::new(Capabilities::DFS | Capabilities::LEASING),
max_transact_size: 65536,
max_read_size: 65536,
max_write_size: 65536,
system_time: 132_000_000_000_000_000,
server_start_time: 131_000_000_000_000_000,
security_buffer: vec![0x60, 0x00],
negotiate_contexts: if dialect == Dialect::Smb3_1_1 {
vec![NegotiateContext::PreauthIntegrity {
hash_algorithms: vec![HASH_ALGORITHM_SHA512],
salt: vec![0xBB; 32],
}]
} else {
vec![]
},
};
pack_message(&resp_header, &resp_body)
}
#[tokio::test]
async fn negotiate_stores_params_correctly() {
let mock = Arc::new(MockTransport::new());
mock.enable_auto_rewrite_msg_id();
mock.queue_response(build_negotiate_response(Dialect::Smb3_1_1));
let mut conn = Connection::from_transport(
Box::new(mock.clone()),
Box::new(mock.clone()),
"test-server",
);
conn.negotiate().await.unwrap();
let params = conn.params().unwrap();
assert_eq!(params.dialect, Dialect::Smb3_1_1);
assert_eq!(params.max_read_size, 65536);
assert_eq!(params.max_write_size, 65536);
assert_eq!(params.max_transact_size, 65536);
assert!(!params.signing_required);
}
#[tokio::test]
async fn negotiate_updates_credits() {
let mock = Arc::new(MockTransport::new());
mock.enable_auto_rewrite_msg_id();
mock.queue_response(build_negotiate_response(Dialect::Smb3_0));
let mut conn = Connection::from_transport(
Box::new(mock.clone()),
Box::new(mock.clone()),
"test-server",
);
conn.negotiate().await.unwrap();
assert_eq!(conn.credits(), 32);
}
#[tokio::test]
async fn negotiate_increments_message_id() {
let mock = Arc::new(MockTransport::new());
mock.enable_auto_rewrite_msg_id();
mock.queue_response(build_negotiate_response(Dialect::Smb2_0_2));
let mut conn = Connection::from_transport(
Box::new(mock.clone()),
Box::new(mock.clone()),
"test-server",
);
assert_eq!(conn.next_message_id(), 0);
conn.negotiate().await.unwrap();
assert_eq!(conn.next_message_id(), 1);
}
#[tokio::test]
async fn negotiate_updates_preauth_hash() {
let mock = Arc::new(MockTransport::new());
mock.enable_auto_rewrite_msg_id();
mock.queue_response(build_negotiate_response(Dialect::Smb3_1_1));
let mut conn = Connection::from_transport(
Box::new(mock.clone()),
Box::new(mock.clone()),
"test-server",
);
let initial_hash = *conn.preauth_hasher().value();
conn.negotiate().await.unwrap();
assert_ne!(conn.preauth_hasher().value(), &initial_hash);
}
#[tokio::test]
async fn negotiate_rejects_invalid_max_read_size() {
let resp_header = {
let mut h = Header::new_request(Command::Negotiate);
h.flags.set_response();
h.credits = 1;
h
};
let resp_body = NegotiateResponse {
security_mode: SecurityMode::new(SecurityMode::SIGNING_ENABLED),
dialect_revision: Dialect::Smb2_0_2,
server_guid: Guid::ZERO,
capabilities: Capabilities::default(),
max_transact_size: 65536,
max_read_size: 1024, max_write_size: 65536,
system_time: 0,
server_start_time: 0,
security_buffer: vec![],
negotiate_contexts: vec![],
};
let mock = Arc::new(MockTransport::new());
mock.enable_auto_rewrite_msg_id();
mock.queue_response(pack_message(&resp_header, &resp_body));
let mut conn = Connection::from_transport(
Box::new(mock.clone()),
Box::new(mock.clone()),
"test-server",
);
let result = conn.negotiate().await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("MaxReadSize"));
}
#[tokio::test]
async fn message_id_increments_on_send_request() {
let mock = Arc::new(MockTransport::new());
mock.enable_auto_rewrite_msg_id();
let mut conn = Connection::from_transport(
Box::new(mock.clone()),
Box::new(mock.clone()),
"test-server",
);
conn.set_next_message_id(5);
use crate::msg::tree_disconnect::TreeDisconnectRequest;
let body = TreeDisconnectRequest;
assert_eq!(conn.next_message_id(), 5);
let _ = tokio::time::timeout(
std::time::Duration::from_millis(50),
conn.execute(Command::TreeDisconnect, &body, None),
)
.await;
assert_eq!(conn.next_message_id(), 6);
}
#[tokio::test]
async fn signing_applied_to_outgoing_messages() {
let mock = Arc::new(MockTransport::new());
mock.enable_auto_rewrite_msg_id();
let mut conn = Connection::from_transport(
Box::new(mock.clone()),
Box::new(mock.clone()),
"test-server",
);
let key = vec![0xAA; 16];
conn.activate_signing(key, SigningAlgorithm::HmacSha256);
conn.set_session_id(SessionId(0x1234));
use crate::msg::tree_disconnect::TreeDisconnectRequest;
let body = TreeDisconnectRequest;
let _ = tokio::time::timeout(
std::time::Duration::from_millis(50),
conn.execute(Command::TreeDisconnect, &body, None),
)
.await;
let msg_bytes = mock.sent_message(0).expect("one send recorded");
let flags = u32::from_le_bytes(msg_bytes[16..20].try_into().unwrap());
assert!(flags & HeaderFlags::SIGNED != 0, "message should be signed");
let sig = &msg_bytes[48..64];
assert_ne!(sig, &[0u8; 16], "signature should not be all zeros");
}
#[tokio::test]
async fn negotiate_with_smb2_dialect() {
let mock = Arc::new(MockTransport::new());
mock.enable_auto_rewrite_msg_id();
mock.queue_response(build_negotiate_response(Dialect::Smb2_0_2));
let mut conn = Connection::from_transport(
Box::new(mock.clone()),
Box::new(mock.clone()),
"test-server",
);
conn.negotiate().await.unwrap();
let params = conn.params().unwrap();
assert_eq!(params.dialect, Dialect::Smb2_0_2);
assert!(!params.gmac_negotiated);
assert!(params.cipher.is_none());
}
#[tokio::test]
async fn negotiate_sends_all_five_dialects() {
let mock = Arc::new(MockTransport::new());
mock.enable_auto_rewrite_msg_id();
mock.queue_response(build_negotiate_response(Dialect::Smb3_1_1));
let mut conn = Connection::from_transport(
Box::new(mock.clone()),
Box::new(mock.clone()),
"test-server",
);
conn.negotiate().await.unwrap();
let sent = mock.sent_message(0).unwrap();
let mut cursor = ReadCursor::new(&sent);
let _header = Header::unpack(&mut cursor).unwrap();
let req = NegotiateRequest::unpack(&mut cursor).unwrap();
assert_eq!(req.dialects.len(), 5);
assert!(req.dialects.contains(&Dialect::Smb2_0_2));
assert!(req.dialects.contains(&Dialect::Smb2_1));
assert!(req.dialects.contains(&Dialect::Smb3_0));
assert!(req.dialects.contains(&Dialect::Smb3_0_2));
assert!(req.dialects.contains(&Dialect::Smb3_1_1));
}
use crate::msg::negotiate::COMPRESSION_LZ4;
use crate::msg::transform::{
CompressionTransformHeader, COMPRESSION_ALGORITHM_LZ4, COMPRESSION_PROTOCOL_ID,
SMB2_COMPRESSION_FLAG_NONE,
};
fn build_negotiate_response_with_compression(dialect: Dialect) -> Vec<u8> {
let resp_header = {
let mut h = Header::new_request(Command::Negotiate);
h.flags.set_response();
h.credits = 32;
h
};
let resp_body = NegotiateResponse {
security_mode: SecurityMode::new(SecurityMode::SIGNING_ENABLED),
dialect_revision: dialect,
server_guid: Guid::ZERO,
capabilities: Capabilities::new(Capabilities::DFS | Capabilities::LEASING),
max_transact_size: 65536,
max_read_size: 65536,
max_write_size: 65536,
system_time: 132_000_000_000_000_000,
server_start_time: 131_000_000_000_000_000,
security_buffer: vec![0x60, 0x00],
negotiate_contexts: vec![
NegotiateContext::PreauthIntegrity {
hash_algorithms: vec![HASH_ALGORITHM_SHA512],
salt: vec![0xBB; 32],
},
NegotiateContext::Compression {
flags: 0,
algorithms: vec![COMPRESSION_LZ4],
},
],
};
pack_message(&resp_header, &resp_body)
}
#[tokio::test]
async fn negotiate_detects_compression_support() {
let mock = Arc::new(MockTransport::new());
mock.enable_auto_rewrite_msg_id();
mock.queue_response(build_negotiate_response_with_compression(Dialect::Smb3_1_1));
let mut conn = Connection::from_transport(
Box::new(mock.clone()),
Box::new(mock.clone()),
"test-server",
);
conn.negotiate().await.unwrap();
let params = conn.params().unwrap();
assert!(params.compression_supported);
assert!(conn.compression_enabled());
}
#[tokio::test]
async fn negotiate_without_compression_context_disables_compression() {
let mock = Arc::new(MockTransport::new());
mock.enable_auto_rewrite_msg_id();
mock.queue_response(build_negotiate_response(Dialect::Smb3_1_1));
let mut conn = Connection::from_transport(
Box::new(mock.clone()),
Box::new(mock.clone()),
"test-server",
);
conn.negotiate().await.unwrap();
let params = conn.params().unwrap();
assert!(!params.compression_supported);
assert!(!conn.compression_enabled());
}
#[tokio::test]
async fn compression_disabled_when_client_config_says_no() {
let mock = Arc::new(MockTransport::new());
mock.enable_auto_rewrite_msg_id();
mock.queue_response(build_negotiate_response_with_compression(Dialect::Smb3_1_1));
let mut conn = Connection::from_transport(
Box::new(mock.clone()),
Box::new(mock.clone()),
"test-server",
);
conn.set_compression_requested(false);
conn.negotiate().await.unwrap();
let params = conn.params().unwrap();
assert!(params.compression_supported);
assert!(!conn.compression_enabled());
}
#[tokio::test]
async fn negotiate_offers_compression_context_when_requested() {
let mock = Arc::new(MockTransport::new());
mock.enable_auto_rewrite_msg_id();
mock.queue_response(build_negotiate_response_with_compression(Dialect::Smb3_1_1));
let mut conn = Connection::from_transport(
Box::new(mock.clone()),
Box::new(mock.clone()),
"test-server",
);
conn.negotiate().await.unwrap();
let sent = mock.sent_message(0).unwrap();
let mut cursor = ReadCursor::new(&sent);
let _header = Header::unpack(&mut cursor).unwrap();
let req = NegotiateRequest::unpack(&mut cursor).unwrap();
let has_compression = req.negotiate_contexts.iter().any(|ctx| {
matches!(ctx, NegotiateContext::Compression { algorithms, .. }
if algorithms.contains(&COMPRESSION_LZ4))
});
assert!(
has_compression,
"negotiate request should include compression context with LZ4"
);
}
#[tokio::test]
async fn negotiate_does_not_offer_compression_when_disabled() {
let mock = Arc::new(MockTransport::new());
mock.enable_auto_rewrite_msg_id();
mock.queue_response(build_negotiate_response(Dialect::Smb3_1_1));
let mut conn = Connection::from_transport(
Box::new(mock.clone()),
Box::new(mock.clone()),
"test-server",
);
conn.set_compression_requested(false);
conn.negotiate().await.unwrap();
let sent = mock.sent_message(0).unwrap();
let mut cursor = ReadCursor::new(&sent);
let _header = Header::unpack(&mut cursor).unwrap();
let req = NegotiateRequest::unpack(&mut cursor).unwrap();
let has_compression = req
.negotiate_contexts
.iter()
.any(|ctx| matches!(ctx, NegotiateContext::Compression { .. }));
assert!(
!has_compression,
"negotiate request should not include compression context"
);
}
#[test]
fn build_compressed_frame_roundtrip() {
let mut message = vec![0xFE; Header::SIZE]; let payload: Vec<u8> = b"COMPRESS_ME_".iter().copied().cycle().take(2048).collect();
message.extend_from_slice(&payload);
let compressed = compress_message(&message, Header::SIZE).expect("should compress");
let framed = build_compressed_frame(&compressed);
assert_eq!(&framed[0..4], &COMPRESSION_PROTOCOL_ID);
let decompressed = decompress_response(&framed).expect("should decompress");
assert_eq!(decompressed, message);
}
#[test]
fn decompress_response_rejects_unsupported_algorithm() {
let header = CompressionTransformHeader {
original_compressed_segment_size: 100,
compression_algorithm: 0x0001, flags: SMB2_COMPRESSION_FLAG_NONE,
offset_or_length: 0,
};
let mut cursor = WriteCursor::new();
header.pack(&mut cursor);
let mut frame = cursor.into_inner();
frame.extend_from_slice(&[0u8; 10]);
let result = decompress_response(&frame);
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("unsupported compression algorithm"));
}
#[test]
fn decompress_response_rejects_chained_compression() {
let header = CompressionTransformHeader {
original_compressed_segment_size: 100,
compression_algorithm: COMPRESSION_ALGORITHM_LZ4,
flags: 0x0001, offset_or_length: 0,
};
let mut cursor = WriteCursor::new();
header.pack(&mut cursor);
let mut frame = cursor.into_inner();
frame.extend_from_slice(&[0u8; 10]);
let result = decompress_response(&frame);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("unchained"));
}
#[test]
fn decompress_response_rejects_too_short_data() {
let result = decompress_response(&[0xFC, b'S', b'M']);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("too short"));
}
#[tokio::test]
async fn phase3_decrypt_failure_errors_waiter_not_hangs() {
use crate::crypto::encryption::Cipher;
let mock = Arc::new(MockTransport::new());
mock.enable_auto_rewrite_msg_id();
let mut conn = Connection::from_transport(
Box::new(mock.clone()),
Box::new(mock.clone()),
"test-server",
);
conn.set_credits(10);
let enc_key = vec![0x42; 16];
let dec_key = vec![0x99; 16]; conn.activate_encryption(enc_key, dec_key, Cipher::Aes128Gcm);
let rx = conn.register_waiter(MessageId(4)).unwrap();
let mut frame = Vec::new();
frame.extend_from_slice(&TRANSFORM_PROTOCOL_ID); frame.extend_from_slice(&[0u8; 16]); frame.extend_from_slice(&[0u8; 16]); frame.extend_from_slice(&64u32.to_le_bytes()); frame.extend_from_slice(&0u16.to_le_bytes()); frame.extend_from_slice(&1u16.to_le_bytes()); frame.extend_from_slice(&0xDEADu64.to_le_bytes()); frame.extend_from_slice(&[0xAAu8; 64]);
mock.queue_response(frame);
let result = tokio::time::timeout(Duration::from_secs(2), await_frame(rx)).await;
assert!(
result.is_ok(),
"waiter hung forever on a decrypt-failed frame — Phase 3's silent-discard \
fix must tear down the connection on unrecoverable frame errors and propagate \
Err(Disconnected) to pending waiters. Instead the receiver task silently discards \
the frame and the waiter never resolves. (P3.4 fixes this.)"
);
let waiter_result = result.unwrap();
assert!(
waiter_result.is_err(),
"waiter should return an error on decrypt failure, not Ok"
);
}
#[tokio::test]
async fn send_cancel_does_not_consume_credit_or_advance_message_id() {
let mock = Arc::new(MockTransport::new());
mock.enable_auto_rewrite_msg_id();
let mut conn = Connection::from_transport(
Box::new(mock.clone()),
Box::new(mock.clone()),
"test-server",
);
conn.set_next_message_id(10);
conn.set_credits(5);
conn.send_cancel(MessageId(7), None).await.unwrap();
assert_eq!(conn.next_message_id(), 10);
assert_eq!(conn.credits(), 5);
}
#[tokio::test]
async fn send_cancel_sync_uses_original_message_id() {
let mock = Arc::new(MockTransport::new());
mock.enable_auto_rewrite_msg_id();
let mut conn = Connection::from_transport(
Box::new(mock.clone()),
Box::new(mock.clone()),
"test-server",
);
conn.set_session_id(SessionId(0xAAAA));
conn.send_cancel(MessageId(42), None).await.unwrap();
let sent = mock.sent_message(0).unwrap();
let mut cursor = ReadCursor::new(&sent);
let header = Header::unpack(&mut cursor).unwrap();
assert_eq!(header.command, Command::Cancel);
assert_eq!(header.message_id, MessageId(42));
assert_eq!(header.credit_charge, CreditCharge(0));
assert_eq!(header.credits, 0);
assert_eq!(header.session_id, SessionId(0xAAAA));
assert!(!header.flags.is_async());
assert_eq!(sent.len(), Header::SIZE + 4);
let body_structure_size = u16::from_le_bytes(sent[64..66].try_into().unwrap());
assert_eq!(body_structure_size, 4);
}
#[tokio::test]
async fn send_cancel_async_sets_async_flag_and_async_id() {
let mock = Arc::new(MockTransport::new());
mock.enable_auto_rewrite_msg_id();
let mut conn = Connection::from_transport(
Box::new(mock.clone()),
Box::new(mock.clone()),
"test-server",
);
conn.set_session_id(SessionId(0xBBBB));
let async_id = 0x1234_5678_9ABC_DEF0u64;
conn.send_cancel(MessageId(99), Some(async_id))
.await
.unwrap();
let sent = mock.sent_message(0).unwrap();
let mut cursor = ReadCursor::new(&sent);
let header = Header::unpack(&mut cursor).unwrap();
assert_eq!(header.command, Command::Cancel);
assert_eq!(header.message_id, MessageId(99));
assert!(header.flags.is_async());
assert_eq!(header.async_id, Some(async_id));
assert_eq!(header.tree_id, None);
assert_eq!(header.credit_charge, CreditCharge(0));
assert_eq!(header.credits, 0);
}
#[tokio::test]
async fn send_cancel_signs_message_when_signing_active() {
let mock = Arc::new(MockTransport::new());
mock.enable_auto_rewrite_msg_id();
let mut conn = Connection::from_transport(
Box::new(mock.clone()),
Box::new(mock.clone()),
"test-server",
);
let key = vec![0xCC; 16];
conn.activate_signing(key, SigningAlgorithm::HmacSha256);
conn.set_session_id(SessionId(0xDDDD));
conn.send_cancel(MessageId(50), None).await.unwrap();
let sent = mock.sent_message(0).unwrap();
let flags = u32::from_le_bytes(sent[16..20].try_into().unwrap());
assert!(flags & HeaderFlags::SIGNED != 0, "CANCEL should be signed");
let sig = &sent[48..64];
assert_ne!(sig, &[0u8; 16], "signature should not be all zeros");
}
#[tokio::test]
async fn no_encryption_when_not_activated() {
use crate::msg::echo::EchoRequest;
let mock = Arc::new(MockTransport::new());
mock.enable_auto_rewrite_msg_id();
let mut conn = Connection::from_transport(
Box::new(mock.clone()),
Box::new(mock.clone()),
"test-server",
);
conn.set_test_params(NegotiatedParams {
dialect: Dialect::Smb3_1_1,
max_read_size: 65536,
max_write_size: 65536,
max_transact_size: 65536,
server_guid: Guid::ZERO,
signing_required: false,
capabilities: Capabilities::default(),
gmac_negotiated: false,
cipher: Some(Cipher::Aes128Gcm),
compression_supported: false,
});
conn.set_session_id(SessionId(1));
conn.set_credits(5);
let _ = tokio::time::timeout(
std::time::Duration::from_millis(50),
conn.execute(Command::Echo, &EchoRequest, None),
)
.await;
let sent = mock.sent_message(0).unwrap();
assert_eq!(
sent[0], 0xFE,
"without encryption, message must start with 0xFE"
);
}
#[tokio::test]
async fn activate_encryption_sets_state() {
let mock = Arc::new(MockTransport::new());
mock.enable_auto_rewrite_msg_id();
let mut conn = Connection::from_transport(
Box::new(mock.clone()),
Box::new(mock.clone()),
"test-server",
);
assert!(!conn.should_encrypt());
conn.activate_encryption(vec![0x42; 16], vec![0x42; 16], Cipher::Aes128Gcm);
assert!(conn.should_encrypt());
}
#[tokio::test]
async fn dfs_flag_set_for_registered_tree() {
let mock = Arc::new(MockTransport::new());
mock.enable_auto_rewrite_msg_id();
let mut conn = Connection::from_transport(
Box::new(mock.clone()),
Box::new(mock.clone()),
"test-server",
);
conn.set_credits(256);
let tree_id = TreeId(7);
conn.register_dfs_tree(tree_id);
use crate::msg::echo::EchoRequest;
let body = EchoRequest;
let _ = tokio::time::timeout(
std::time::Duration::from_millis(50),
conn.execute_with_credits(Command::Echo, &body, Some(tree_id), CreditCharge(1)),
)
.await;
let msg_bytes = mock.sent_message(0).expect("one send recorded");
let flags_raw = u32::from_le_bytes(msg_bytes[16..20].try_into().unwrap());
assert_ne!(
flags_raw & HeaderFlags::DFS_OPERATIONS,
0,
"DFS_OPERATIONS flag must be set for registered tree"
);
}
#[tokio::test]
async fn dfs_flag_not_set_for_unregistered_tree() {
let mock = Arc::new(MockTransport::new());
mock.enable_auto_rewrite_msg_id();
let mut conn = Connection::from_transport(
Box::new(mock.clone()),
Box::new(mock.clone()),
"test-server",
);
conn.set_credits(256);
use crate::msg::echo::EchoRequest;
let body = EchoRequest;
let _ = tokio::time::timeout(
std::time::Duration::from_millis(50),
conn.execute_with_credits(Command::Echo, &body, Some(TreeId(7)), CreditCharge(1)),
)
.await;
let msg_bytes = mock.sent_message(0).expect("one send recorded");
let flags_raw = u32::from_le_bytes(msg_bytes[16..20].try_into().unwrap());
assert_eq!(
flags_raw & HeaderFlags::DFS_OPERATIONS,
0,
"DFS_OPERATIONS flag must NOT be set for unregistered tree"
);
}
#[tokio::test]
async fn dfs_flag_cleared_after_deregister() {
let mock = Arc::new(MockTransport::new());
mock.enable_auto_rewrite_msg_id();
let mut conn = Connection::from_transport(
Box::new(mock.clone()),
Box::new(mock.clone()),
"test-server",
);
conn.set_credits(256);
let tree_id = TreeId(7);
conn.register_dfs_tree(tree_id);
conn.deregister_dfs_tree(tree_id);
use crate::msg::echo::EchoRequest;
let body = EchoRequest;
let _ = tokio::time::timeout(
std::time::Duration::from_millis(50),
conn.execute_with_credits(Command::Echo, &body, Some(tree_id), CreditCharge(1)),
)
.await;
let msg_bytes = mock.sent_message(0).expect("one send recorded");
let flags_raw = u32::from_le_bytes(msg_bytes[16..20].try_into().unwrap());
assert_eq!(
flags_raw & HeaderFlags::DFS_OPERATIONS,
0,
"DFS_OPERATIONS flag must NOT be set after deregister"
);
}
#[tokio::test]
async fn connection_is_cloneable_and_clones_share_state() {
let mock = Arc::new(MockTransport::new());
mock.enable_auto_rewrite_msg_id();
let mut original = Connection::from_transport(
Box::new(mock.clone()),
Box::new(mock.clone()),
"test-server",
);
original.set_credits(42);
original.set_session_id(SessionId(0x1234_5678_9ABC_DEF0));
original.set_next_message_id(100);
let cloned = original.clone();
assert_eq!(cloned.credits(), 42);
assert_eq!(cloned.session_id(), SessionId(0x1234_5678_9ABC_DEF0));
assert_eq!(cloned.next_message_id(), 100);
assert_eq!(cloned.server_name(), "test-server");
cloned.inner.credits.store(7, Ordering::Release);
assert_eq!(original.credits(), 7);
}
fn build_echo_response_with_msg_id(msg_id: MessageId) -> Vec<u8> {
let mut h = Header::new_request(Command::Echo);
h.flags.set_response();
h.credits = 10;
h.message_id = msg_id;
pack_message(&h, &crate::msg::echo::EchoResponse)
}
#[tokio::test(flavor = "multi_thread")]
async fn execute_returns_correct_frame_for_sent_request() {
let mock = Arc::new(MockTransport::new());
mock.enable_auto_rewrite_msg_id();
let conn = Connection::from_transport(
Box::new(mock.clone()),
Box::new(mock.clone()),
"test-server",
);
let c = conn.clone();
let handle = tokio::spawn(async move {
c.execute(Command::Echo, &crate::msg::echo::EchoRequest, None)
.await
});
let deadline = std::time::Instant::now() + Duration::from_secs(5);
while mock.sent_count() < 1 {
if std::time::Instant::now() > deadline {
panic!("execute task did not send its request in 5s");
}
tokio::time::sleep(Duration::from_millis(10)).await;
}
mock.queue_response(build_echo_response_with_msg_id(MessageId(0)));
let frame = handle.await.unwrap().unwrap();
assert_eq!(frame.header.command, Command::Echo);
assert_eq!(frame.header.message_id, MessageId(0));
assert!(frame.header.is_response());
let mut cursor = ReadCursor::new(&frame.body);
crate::msg::echo::EchoResponse::unpack(&mut cursor).unwrap();
mock.assert_fully_consumed();
}
#[tokio::test(flavor = "multi_thread")]
async fn concurrent_execute_on_one_connection_all_succeed() {
const N: u64 = 20;
let mock = Arc::new(MockTransport::new());
mock.enable_auto_rewrite_msg_id();
let conn = Connection::from_transport(
Box::new(mock.clone()),
Box::new(mock.clone()),
"test-server",
);
let mut handles = Vec::with_capacity(N as usize);
for _ in 0..N {
let c = conn.clone();
handles.push(tokio::spawn(async move {
c.execute(Command::Echo, &crate::msg::echo::EchoRequest, None)
.await
}));
}
let deadline = std::time::Instant::now() + Duration::from_secs(5);
while mock.sent_count() < N as usize {
if std::time::Instant::now() > deadline {
panic!(
"tasks did not send all {} requests in 5s (got {})",
N,
mock.sent_count()
);
}
tokio::time::sleep(Duration::from_millis(10)).await;
}
for i in 0..N {
mock.queue_response(build_echo_response_with_msg_id(MessageId(i)));
}
let mut got_ids: Vec<u64> = Vec::with_capacity(N as usize);
for h in handles {
let frame = h.await.unwrap().unwrap();
assert_eq!(frame.header.command, Command::Echo);
got_ids.push(frame.header.message_id.0);
}
got_ids.sort_unstable();
assert_eq!(got_ids, (0..N).collect::<Vec<_>>());
mock.assert_fully_consumed();
}
#[tokio::test(flavor = "multi_thread")]
async fn dropped_execute_future_does_not_affect_others() {
let mock = Arc::new(MockTransport::new());
mock.enable_auto_rewrite_msg_id();
let conn = Connection::from_transport(
Box::new(mock.clone()),
Box::new(mock.clone()),
"test-server",
);
let mut handles = Vec::new();
for idx in 0..5 {
let c = conn.clone();
let h = tokio::spawn(async move {
c.execute(Command::Echo, &crate::msg::echo::EchoRequest, None)
.await
});
handles.push(h);
let deadline = std::time::Instant::now() + Duration::from_secs(5);
while mock.sent_count() < idx + 1 {
if std::time::Instant::now() > deadline {
panic!(
"task {} did not send its request in 5s (sent_count={})",
idx,
mock.sent_count()
);
}
tokio::time::sleep(Duration::from_millis(10)).await;
}
}
assert_eq!(mock.sent_count(), 5);
handles[1].abort();
handles[3].abort();
for i in 0..5u64 {
mock.queue_response(build_echo_response_with_msg_id(MessageId(i)));
}
for (idx, h) in handles.into_iter().enumerate() {
let res = h.await;
if idx == 1 || idx == 3 {
assert!(res.is_err(), "task {} should have been aborted", idx);
} else {
let frame = res.unwrap().unwrap();
assert_eq!(frame.header.command, Command::Echo);
assert_eq!(frame.header.message_id, MessageId(idx as u64));
}
}
mock.assert_fully_consumed();
}
#[tokio::test(flavor = "multi_thread")]
async fn execute_compound_partial_failure_routes_correctly() {
let mock = Arc::new(MockTransport::new());
mock.enable_auto_rewrite_msg_id();
let echo_ok_0 = build_echo_response_with_msg_id(MessageId(0));
let mut err_hdr = Header::new_request(Command::Echo);
err_hdr.flags.set_response();
err_hdr.credits = 10;
err_hdr.message_id = MessageId(1);
err_hdr.status = NtStatus::OBJECT_NAME_NOT_FOUND;
let err_body = pack_message(
&err_hdr,
&crate::msg::header::ErrorResponse {
error_context_count: 0,
error_data: vec![],
},
);
let echo_ok_2 = build_echo_response_with_msg_id(MessageId(2));
let compound_response = build_compound_response_frame(&[echo_ok_0, err_body, echo_ok_2]);
let conn = Connection::from_transport(
Box::new(mock.clone()),
Box::new(mock.clone()),
"test-server",
);
let c = conn.clone();
let handle = tokio::spawn(async move {
let ops = [
CompoundOp::new(Command::Echo, &crate::msg::echo::EchoRequest, None),
CompoundOp::new(Command::Echo, &crate::msg::echo::EchoRequest, None),
CompoundOp::new(Command::Echo, &crate::msg::echo::EchoRequest, None),
];
c.execute_compound(&ops).await
});
let deadline = std::time::Instant::now() + Duration::from_secs(5);
while mock.sent_count() < 1 {
if std::time::Instant::now() > deadline {
panic!("execute_compound did not send in 5s");
}
tokio::time::sleep(Duration::from_millis(10)).await;
}
mock.queue_response(compound_response);
let results = handle.await.unwrap().unwrap();
assert_eq!(results.len(), 3);
let f0 = results[0].as_ref().expect("op 0 should be Ok");
assert_eq!(f0.header.status, NtStatus::SUCCESS);
assert_eq!(f0.header.message_id, MessageId(0));
let f1 = results[1]
.as_ref()
.expect("op 1 still carries a Frame — error status in header");
assert_eq!(f1.header.status, NtStatus::OBJECT_NAME_NOT_FOUND);
assert_eq!(f1.header.message_id, MessageId(1));
let f2 = results[2].as_ref().expect("op 2 should be Ok");
assert_eq!(f2.header.status, NtStatus::SUCCESS);
assert_eq!(f2.header.message_id, MessageId(2));
mock.assert_fully_consumed();
}
#[tokio::test(flavor = "multi_thread")]
async fn execute_on_clone_works_after_original_dropped() {
let mock = Arc::new(MockTransport::new());
mock.enable_auto_rewrite_msg_id();
let original = Connection::from_transport(
Box::new(mock.clone()),
Box::new(mock.clone()),
"test-server",
);
let cloned = original.clone();
drop(original);
let c = cloned.clone();
let handle = tokio::spawn(async move {
c.execute(Command::Echo, &crate::msg::echo::EchoRequest, None)
.await
});
let deadline = std::time::Instant::now() + Duration::from_secs(5);
while mock.sent_count() < 1 {
if std::time::Instant::now() > deadline {
panic!("execute on clone did not send in 5s");
}
tokio::time::sleep(Duration::from_millis(10)).await;
}
mock.queue_response(build_echo_response_with_msg_id(MessageId(0)));
let frame = handle.await.unwrap().unwrap();
assert_eq!(frame.header.command, Command::Echo);
assert_eq!(frame.header.message_id, MessageId(0));
mock.assert_fully_consumed();
}
#[tokio::test]
async fn connection_is_cloneable_clone_outlives_original() {
let mock = Arc::new(MockTransport::new());
mock.enable_auto_rewrite_msg_id();
let mut original = Connection::from_transport(
Box::new(mock.clone()),
Box::new(mock.clone()),
"test-server",
);
original.set_credits(9);
let cloned = original.clone();
drop(original);
assert_eq!(cloned.credits(), 9);
assert_eq!(cloned.server_name(), "test-server");
cloned
.inner
.sender
.send(b"\x00\x00\x00\x10ignore-me")
.await
.unwrap();
assert_eq!(mock.sent_count(), 1);
}
}