pub mod mock;
pub mod tcp;
pub use mock::MockTransport;
pub use tcp::TcpTransport;
use crate::error::Result;
use async_trait::async_trait;
/// Outbound half of a transport.
///
/// Implementors must be both `Sync` and `Send`, so a shared reference can be
/// used from any async task.
#[async_trait]
pub trait TransportSend: Sync + Send {
    /// Send the bytes in `data` over the transport.
    ///
    /// In this module's tests, `data` is one fully packed message
    /// (header followed by body).
    ///
    /// # Errors
    ///
    /// Returns the crate's error type if the underlying transport fails.
    async fn send(&self, data: &[u8]) -> Result<()>;
}
/// Inbound half of a transport.
///
/// Like [`TransportSend`], implementors must be `Sync` and `Send`.
#[async_trait]
pub trait TransportReceive: Sync + Send {
    /// Wait for the next incoming buffer and return its bytes as an owned `Vec`.
    ///
    /// # Errors
    ///
    /// Returns the crate's error type if nothing can be received.
    async fn receive(&self) -> Result<Vec<u8>>;
}
/// A bidirectional transport: anything that can both send and receive.
///
/// This is a marker trait. The blanket impl below covers every (sized) type
/// that implements both [`TransportSend`] and [`TransportReceive`], so it
/// never needs to be implemented by hand.
pub trait Transport: TransportSend + TransportReceive {}

impl<T> Transport for T where T: TransportSend + TransportReceive {}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::msg::header::{Header, PROTOCOL_ID};
    use crate::msg::negotiate::{
        NegotiateContext, NegotiateRequest, NegotiateResponse, HASH_ALGORITHM_SHA512,
    };
    use crate::pack::{Guid, Pack, ReadCursor, Unpack, WriteCursor};
    use crate::types::flags::{Capabilities, SecurityMode};
    use crate::types::{Command, Dialect};

    /// Environment variable holding the `host:port` of a live SMB server
    /// for the ignored network test.
    const SMB_SERVER_ENV: &str = "SMB_TEST_SERVER";

    /// Server address used by the network test when `SMB_TEST_SERVER` is unset.
    const DEFAULT_SMB_SERVER: &str = "192.168.1.111:445";

    /// Serialize `header` followed by `body` into one contiguous message buffer.
    fn pack_message(header: &Header, body: &dyn Pack) -> Vec<u8> {
        let mut cursor = WriteCursor::new();
        header.pack(&mut cursor);
        body.pack(&mut cursor);
        cursor.into_inner()
    }

    /// Round-trips a NEGOTIATE request/response pair through `MockTransport`.
    ///
    /// This exercises header/body packing, the transport traits, and unpacking
    /// together.
    #[tokio::test]
    async fn cross_module_negotiate_via_mock_transport() {
        let mock = MockTransport::new();

        let req_header = Header::new_request(Command::Negotiate);
        let req_body = NegotiateRequest {
            security_mode: SecurityMode::new(SecurityMode::SIGNING_ENABLED),
            capabilities: Capabilities::default(),
            client_guid: Guid {
                data1: 0xDEAD_BEEF,
                data2: 0xCAFE,
                data3: 0xF00D,
                data4: [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08],
            },
            dialects: vec![Dialect::Smb2_0_2, Dialect::Smb2_1, Dialect::Smb3_1_1],
            negotiate_contexts: vec![NegotiateContext::PreauthIntegrity {
                hash_algorithms: vec![HASH_ALGORITHM_SHA512],
                salt: vec![0xAA; 32],
            }],
        };
        let req_msg = pack_message(&req_header, &req_body);

        // Build the response from a request header and flip it to a response.
        let resp_header = {
            let mut h = Header::new_request(Command::Negotiate);
            h.flags.set_response();
            h.credits = 1;
            h
        };
        let resp_body = NegotiateResponse {
            security_mode: SecurityMode::new(SecurityMode::SIGNING_ENABLED),
            dialect_revision: Dialect::Smb3_1_1,
            server_guid: Guid {
                data1: 0x1111_2222,
                data2: 0x3333,
                data3: 0x4444,
                data4: [0x55, 0x66, 0x77, 0x88, 0x99, 0xAA, 0xBB, 0xCC],
            },
            capabilities: Capabilities::new(Capabilities::DFS | Capabilities::LEASING),
            max_transact_size: 65536,
            max_read_size: 65536,
            max_write_size: 65536,
            system_time: 132_000_000_000_000_000,
            server_start_time: 131_000_000_000_000_000,
            security_buffer: vec![0x60, 0x00],
            negotiate_contexts: vec![NegotiateContext::PreauthIntegrity {
                hash_algorithms: vec![HASH_ALGORITHM_SHA512],
                salt: vec![0xBB; 32],
            }],
        };
        let resp_msg = pack_message(&resp_header, &resp_body);

        // Stage the canned response so the following `receive` returns it.
        mock.queue_response(resp_msg);
        mock.send(&req_msg).await.unwrap();
        let received = mock.receive().await.unwrap();

        let mut cursor = ReadCursor::new(&received);
        let hdr = Header::unpack(&mut cursor).unwrap();
        assert!(hdr.is_response());
        assert_eq!(hdr.command, Command::Negotiate);
        let body = NegotiateResponse::unpack(&mut cursor).unwrap();
        assert_eq!(body.dialect_revision, Dialect::Smb3_1_1);
        assert_eq!(body.max_read_size, 65536);
        assert!(body.security_mode.signing_enabled());

        // The mock must have recorded exactly the bytes we handed to `send`.
        assert_eq!(mock.sent_count(), 1);
        let sent = mock.sent_message(0).unwrap();
        assert_eq!(&sent[..], &req_msg[..]);

        let mut cursor = ReadCursor::new(&sent);
        let sent_hdr = Header::unpack(&mut cursor).unwrap();
        assert_eq!(sent_hdr.command, Command::Negotiate);
        assert!(!sent_hdr.is_response());
        let sent_body = NegotiateRequest::unpack(&mut cursor).unwrap();
        assert_eq!(sent_body.dialects.len(), 3);
        assert!(sent_body.dialects.contains(&Dialect::Smb3_1_1));
    }

    /// Sends a NEGOTIATE request to a real server over TCP and sanity-checks
    /// the reply.
    ///
    /// The target is taken from `SMB_TEST_SERVER` (`host:port`), falling back
    /// to `DEFAULT_SMB_SERVER`.
    #[tokio::test]
    #[ignore = "requires a live SMB server; set SMB_TEST_SERVER=host:port"]
    async fn negotiate_via_tcp_transport() {
        use std::time::Duration;

        let addr =
            std::env::var(SMB_SERVER_ENV).unwrap_or_else(|_| DEFAULT_SMB_SERVER.to_owned());
        let transport = TcpTransport::connect(addr.as_str(), Duration::from_secs(5))
            .await
            .unwrap_or_else(|e| panic!("failed to connect to SMB server at {}: {:?}", addr, e));

        let header = Header::new_request(Command::Negotiate);
        let request = NegotiateRequest {
            security_mode: SecurityMode::new(SecurityMode::SIGNING_ENABLED),
            capabilities: Capabilities::new(
                Capabilities::DFS | Capabilities::LEASING | Capabilities::LARGE_MTU,
            ),
            client_guid: Guid {
                data1: 0xDEAD_BEEF,
                data2: 0xCAFE,
                data3: 0xF00D,
                data4: [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08],
            },
            dialects: vec![
                Dialect::Smb2_0_2,
                Dialect::Smb2_1,
                Dialect::Smb3_0,
                Dialect::Smb3_0_2,
                Dialect::Smb3_1_1,
            ],
            negotiate_contexts: vec![NegotiateContext::PreauthIntegrity {
                hash_algorithms: vec![HASH_ALGORITHM_SHA512],
                salt: vec![0xAA; 32],
            }],
        };
        let msg = pack_message(&header, &request);
        transport.send(&msg).await.unwrap();
        let resp_bytes = transport.receive().await.unwrap();

        // `starts_with` fails the assertion cleanly on a short reply instead of
        // panicking with an index-out-of-range error like `resp_bytes[0..4]` would.
        assert!(
            resp_bytes.starts_with(&PROTOCOL_ID),
            "response does not begin with PROTOCOL_ID ({} bytes received)",
            resp_bytes.len()
        );

        let mut cursor = ReadCursor::new(&resp_bytes);
        let resp_header = Header::unpack(&mut cursor).unwrap();
        assert!(resp_header.is_response());
        assert_eq!(resp_header.command, Command::Negotiate);
        let resp_body = NegotiateResponse::unpack(&mut cursor).unwrap();
        assert!(Dialect::ALL.contains(&resp_body.dialect_revision));
        assert!(resp_body.max_read_size >= 65536);
    }
}