pub mod config;
pub mod connection_info;
#[cfg(not(feature = "single_threaded"))]
pub mod notification_handler;
pub mod preauth_hash;
pub mod transformer;
pub mod transport;
pub mod worker;
use crate::dialects::DialectImpl;
use crate::packets::guid::Guid;
use crate::packets::smb2::{Command, Response};
use crate::Error;
use crate::{compression, sync_helpers::*};
use crate::{
crypto,
msg_handler::*,
packets::{
smb1::SMB1NegotiateMessage,
smb2::{negotiate::*, plain::*},
},
session::Session,
};
use binrw::prelude::*;
pub use config::*;
use connection_info::{ConnectionInfo, NegotiatedProperties};
use maybe_async::*;
#[cfg(not(feature = "single_threaded"))]
use notification_handler::NotificationHandler;
use rand::rngs::OsRng;
use rand::Rng;
use std::cmp::max;
use std::sync::atomic::{AtomicU16, AtomicU64};
use std::sync::Arc;
use std::time::Duration;
pub use transformer::TransformError;
use transport::{make_transport, SmbTransport};
use worker::{Worker, WorkerImpl};
/// A client connection to a single SMB server.
///
/// Created with [`Connection::build`]; nothing is sent on the network until
/// [`Connection::connect`] is called.
pub struct Connection {
    /// Message handler for this connection (credits, message ids, worker access);
    /// also handed to `Session::setup` when authenticating.
    handler: HandlerReference<ConnectionMessageHandler>,
    /// Connection settings (transport, port, dialect range, timeout, ...).
    config: ConnectionConfig,
    /// Server host name or address; combined with the port to form the endpoint.
    server: String,
}
impl Connection {
pub fn build(server: String, config: ConnectionConfig) -> crate::Result<Connection> {
config.validate()?;
let client_guid = config.client_guid.unwrap_or_else(Guid::gen);
Ok(Connection {
handler: HandlerReference::new(ConnectionMessageHandler::new(client_guid)),
config,
server,
})
}
#[maybe_async]
pub async fn set_timeout(&mut self, timeout: Duration) -> crate::Result<()> {
self.config.timeout = Some(timeout);
if let Some(worker) = self.handler.worker.get() {
worker.set_timeout(timeout).await?;
}
Ok(())
}
#[maybe_async]
pub async fn connect(&mut self) -> crate::Result<()> {
if self.handler.worker().is_some() {
return Err(Error::InvalidState("Already connected".into()));
}
let mut transport = make_transport(&self.config.transport, self.config.timeout())?;
let port = self.config.port.unwrap_or_else(|| transport.default_port());
let endpoint = format!("{}:{}", self.server, port);
log::debug!("Connecting to {}...", &endpoint);
transport.connect(endpoint.as_str()).await?;
log::info!("Connected to {}. Negotiating.", &endpoint);
self.negotiate(transport, self.config.smb2_only_negotiate)
.await?;
Ok(())
}
#[maybe_async]
pub async fn close(&self) -> crate::Result<()> {
match self.handler.worker() {
Some(c) => c.stop().await,
None => Ok(()),
}
}
#[maybe_async]
async fn negotiate_switch_to_smb2(
&mut self,
mut transport: Box<dyn SmbTransport>,
smb2_only_neg: bool,
) -> crate::Result<Arc<WorkerImpl>> {
if !smb2_only_neg {
log::debug!("Negotiating multi-protocol: Sending SMB1");
let msg_bytes: Vec<u8> = SMB1NegotiateMessage::default().try_into()?;
transport.send(&msg_bytes).await?;
log::debug!("Sent SMB1 negotiate request, Receieving SMB2 response");
let recieved_bytes = transport.receive().await?;
let response = Response::try_from(recieved_bytes.as_ref())?;
let message = match response {
Response::Plain(m) => m,
_ => {
return Err(Error::InvalidMessage(
"Expected SMB2 negotiate response, got SMB1".to_string(),
))
}
};
let smb2_negotiate_response = message.content.to_negotiate()?;
if smb2_negotiate_response.dialect_revision != NegotiateDialect::Smb02Wildcard {
return Err(Error::InvalidMessage(
"Expected SMB2 wildcard dialect".to_string(),
));
}
if message.header.message_id != 0 {
return Err(Error::InvalidMessage("Expected message ID 0".to_string()));
}
if message.header.credit_charge != 0 || message.header.credit_request != 1 {
return Err(Error::InvalidMessage(
"Expected credit charge 0 and request 1 for initial message.".to_string(),
));
}
self.handler
.curr_msg_id
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
}
WorkerImpl::start(transport, self.config.timeout()).await
}
#[maybe_async]
async fn negotiate_smb2(&mut self) -> crate::Result<ConnectionInfo> {
if self.handler.conn_info.get().is_some() {
return Err(Error::InvalidState("Already negotiated".into()));
}
log::debug!("Negotiating SMB2");
let min_dialect = self.config.min_dialect.unwrap_or(Dialect::MIN);
let max_dialect = self.config.max_dialect.unwrap_or(Dialect::MAX);
let dialects: Vec<Dialect> = Dialect::ALL
.iter()
.filter(|dialect| **dialect >= min_dialect && **dialect <= max_dialect)
.copied()
.collect();
if dialects.is_empty() {
return Err(Error::InvalidConfiguration(
"No dialects to negotiate".to_string(),
));
}
let encryption_algos = if !self.config.encryption_mode.is_disabled() {
crypto::ENCRYPTING_ALGOS.into()
} else {
vec![]
};
let response = self
.handler
.send_recv(
self.make_smb2_neg_request(
dialects,
crypto::SIGNING_ALGOS.to_vec(),
encryption_algos,
compression::SUPPORTED_ALGORITHMS.to_vec(),
)
.into(),
)
.await?;
let smb2_negotiate_response = response.message.content.to_negotiate()?;
let dialect_rev = smb2_negotiate_response.dialect_revision.try_into()?;
if dialect_rev > max_dialect || dialect_rev < min_dialect {
return Err(Error::NegotiationError(
"Server selected an unsupported dialect.".into(),
));
}
let dialect_impl = DialectImpl::new(dialect_rev);
let mut negotiation = NegotiatedProperties {
server_guid: smb2_negotiate_response.server_guid,
caps: smb2_negotiate_response.capabilities,
max_transact_size: smb2_negotiate_response.max_transact_size,
max_read_size: smb2_negotiate_response.max_read_size,
max_write_size: smb2_negotiate_response.max_write_size,
auth_buffer: smb2_negotiate_response.buffer.clone(),
signing_algo: None,
encryption_cipher: None,
compression: None,
dialect_rev,
};
dialect_impl.process_negotiate_request(
&smb2_negotiate_response,
&mut negotiation,
&self.config,
)?;
if ((!u32::from_le_bytes(dialect_impl.get_negotiate_caps_mask().into_bytes()))
& u32::from_le_bytes(negotiation.caps.into_bytes()))
!= 0
{
return Err(Error::NegotiationError(
"Server capabilities are invalid for the selected dialect.".into(),
));
}
log::trace!(
"Negotiated SMB results: dialect={:?}, state={:?}",
dialect_rev,
&negotiation
);
Ok(ConnectionInfo {
negotiation,
dialect: dialect_impl,
config: self.config.clone(),
server: self.server.clone(),
})
}
fn make_smb2_neg_request(
&self,
supported_dialects: Vec<Dialect>,
signing_algorithms: Vec<SigningAlgorithmId>,
encrypting_algorithms: Vec<EncryptionCipher>,
compression_algorithms: Vec<CompressionAlgorithm>,
) -> NegotiateRequest {
let client_guid = self.handler.client_guid;
let client_netname = self
.config
.client_name
.clone()
.unwrap_or_else(|| "smb-client".to_string());
let has_signing = !signing_algorithms.is_empty();
let has_encryption = !encrypting_algorithms.is_empty();
let ctx_list = if supported_dialects.contains(&Dialect::Smb0311) {
let mut ctx_list = vec![
NegotiateContext {
context_type: NegotiateContextType::PreauthIntegrityCapabilities,
data: NegotiateContextValue::PreauthIntegrityCapabilities(
PreauthIntegrityCapabilities {
hash_algorithms: vec![HashAlgorithm::Sha512],
salt: (0..32).map(|_| OsRng.gen()).collect(),
},
),
},
NegotiateContext {
context_type: NegotiateContextType::NetnameNegotiateContextId,
data: NegotiateContextValue::NetnameNegotiateContextId(
NetnameNegotiateContextId {
netname: client_netname.into(),
},
),
},
NegotiateContext {
context_type: NegotiateContextType::EncryptionCapabilities,
data: NegotiateContextValue::EncryptionCapabilities(EncryptionCapabilities {
ciphers: encrypting_algorithms,
}),
},
NegotiateContext {
context_type: NegotiateContextType::CompressionCapabilities,
data: NegotiateContextValue::CompressionCapabilities(CompressionCapabilities {
flags: CompressionCapsFlags::new()
.with_chained(!compression_algorithms.is_empty()),
compression_algorithms,
}),
},
NegotiateContext {
context_type: NegotiateContextType::SigningCapabilities,
data: NegotiateContextValue::SigningCapabilities(SigningCapabilities {
signing_algorithms,
}),
},
];
if matches!(self.config.transport, TransportConfig::Quic(_)) {
ctx_list.push(NegotiateContext {
context_type: NegotiateContextType::TransportCapabilities,
data: NegotiateContextValue::TransportCapabilities(
TransportCapabilities::new().with_accept_transport_layer_security(true),
),
});
}
Some(ctx_list)
} else {
None
};
let capabilities = if supported_dialects.iter().all(|d| !d.is_smb3()) {
GlobalCapabilities::new()
} else {
let capabilities = GlobalCapabilities::new()
.with_dfs(true)
.with_leasing(true)
.with_large_mtu(true)
.with_multi_channel(true)
.with_persistent_handles(true)
.with_directory_leasing(true);
if has_encryption {
capabilities.with_encryption(true);
}
if !self.config.disable_notifications
&& cfg!(not(feature = "single_threaded"))
&& supported_dialects.contains(&Dialect::Smb0311)
{
capabilities.with_notifications(true);
}
capabilities
};
let security_mode = NegotiateSecurityMode::new().with_signing_enabled(has_signing);
NegotiateRequest {
security_mode,
capabilities,
client_guid,
dialects: supported_dialects,
negotiate_context_list: ctx_list,
}
}
#[maybe_async]
async fn negotiate(
&mut self,
transport: Box<dyn SmbTransport>,
smb2_only_neg: bool,
) -> crate::Result<()> {
if self.handler.conn_info.get().is_some() {
return Err(Error::InvalidState("Already negotiated".into()));
}
let worker = self
.negotiate_switch_to_smb2(transport, smb2_only_neg)
.await?;
self.handler.worker.set(worker).unwrap();
let info = self.negotiate_smb2().await?;
self.handler
.worker
.get()
.ok_or("Worker is uninitialized")
.unwrap()
.negotaite_complete(&info)
.await;
#[cfg(not(feature = "single_threaded"))]
if !self.config.disable_notifications && info.negotiation.caps.notifications() {
self.handler.start_notification_handler().await?;
}
self.handler.conn_info.set(Arc::new(info)).unwrap();
log::info!("Negotiation successful");
Ok(())
}
#[maybe_async]
pub async fn authenticate(&self, user_name: &str, password: String) -> crate::Result<Session> {
Session::setup(
user_name,
password,
&self.handler,
self.handler.conn_info.get().unwrap(),
)
.await
}
}
/// Per-connection message handler: assigns message ids and credits to outgoing
/// messages, accounts for credits in responses, and routes traffic through the worker.
pub struct ConnectionMessageHandler {
    /// Client GUID sent in the SMB2 NEGOTIATE request.
    client_guid: Guid,
    /// Target size of `credit_pool`. While the pool is smaller, outgoing requests ask for
    /// the difference on top of their own charge (large-MTU connections only).
    extra_credits_to_request: u16,
    /// Worker performing the transport I/O; set once during negotiation.
    worker: OnceCell<Arc<WorkerImpl>>,
    /// Server notification handler; started only when notifications were negotiated and
    /// not disabled in the configuration.
    #[cfg(not(feature = "single_threaded"))]
    notification_handler: OnceCell<NotificationHandler>,
    /// Negotiated connection info; set once negotiation completes. Until then, credit
    /// accounting is skipped and message ids advance by one per message.
    conn_info: OnceCell<Arc<ConnectionInfo>>,
    /// Send credits currently available. Permits are consumed when a request is sent and
    /// replenished by the credits granted in responses (large-MTU connections only).
    curr_credits: Semaphore,
    /// Next message id to assign.
    curr_msg_id: AtomicU64,
    /// Running balance of credits granted minus credits charged, as reported by responses.
    credit_pool: AtomicU16,
}
impl ConnectionMessageHandler {
fn new(client_guid: Guid) -> ConnectionMessageHandler {
ConnectionMessageHandler {
client_guid,
worker: OnceCell::new(),
conn_info: OnceCell::new(),
extra_credits_to_request: 4,
curr_credits: Semaphore::new(1),
curr_msg_id: AtomicU64::new(0),
credit_pool: AtomicU16::new(1),
#[cfg(not(feature = "single_threaded"))]
notification_handler: OnceCell::new(),
}
}
pub fn worker(&self) -> Option<&Arc<WorkerImpl>> {
self.worker.get()
}
const SET_CREDIT_CHARGE_CMDS: &'static [Command] = &[
Command::Read,
Command::Write,
Command::Ioctl,
Command::QueryDirectory,
];
const CREDIT_CALC_RATIO: u32 = 65536;
#[maybe_async]
async fn process_sequence_outgoing(&self, msg: &mut OutgoingMessage) -> crate::Result<()> {
if let Some(neg) = self.conn_info.get() {
if neg.negotiation.caps.large_mtu() {
let cost = if Self::SET_CREDIT_CHARGE_CMDS.contains(&msg.message.header.command) {
let send_payload_size = msg.message.content.req_payload_size();
let expected_response_payload_size = msg.message.content.expected_resp_size();
(1 + (max(send_payload_size, expected_response_payload_size) - 1)
/ Self::CREDIT_CALC_RATIO)
.try_into()
.unwrap()
} else {
1
};
self.curr_credits.acquire_many(cost as u32).await?.forget();
let mut request = cost;
let current_pool_size = self.credit_pool.load(std::sync::atomic::Ordering::SeqCst);
if current_pool_size < self.extra_credits_to_request {
request += self.extra_credits_to_request - current_pool_size;
}
msg.message.header.credit_charge = cost;
msg.message.header.credit_request = request;
msg.message.header.message_id = self
.curr_msg_id
.fetch_add(cost as u64, std::sync::atomic::Ordering::SeqCst);
return Ok(());
} else {
debug_assert_eq!(msg.message.header.credit_request, 0);
debug_assert_eq!(msg.message.header.credit_charge, 0);
}
}
{
msg.message.header.message_id = self
.curr_msg_id
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
}
Ok(())
}
#[maybe_async]
async fn process_sequence_incoming(&self, msg: &IncomingMessage) -> crate::Result<()> {
if let Some(neg) = self.conn_info.get() {
if neg.negotiation.caps.large_mtu() {
let granted_credits = msg.message.header.credit_request;
let charged_credits = msg.message.header.credit_charge;
if charged_credits > granted_credits {
self.credit_pool.fetch_sub(
charged_credits - granted_credits,
std::sync::atomic::Ordering::SeqCst,
);
} else {
self.credit_pool.fetch_add(
granted_credits - charged_credits,
std::sync::atomic::Ordering::SeqCst,
);
}
self.curr_credits.add_permits(granted_credits as usize);
}
}
Ok(())
}
#[cfg(not(feature = "single_threaded"))]
#[maybe_async]
async fn start_notification_handler(&self) -> crate::Result<()> {
let worker = self.worker.get().unwrap();
let handler = NotificationHandler::start(worker)?;
self.notification_handler
.set(handler)
.map_err(|_| Error::InvalidState("Notification handler already started".into()))?;
Ok(())
}
}
impl MessageHandler for ConnectionMessageHandler {
#[maybe_async]
async fn sendo(&self, mut msg: OutgoingMessage) -> crate::Result<SendMessageResult> {
let priority_value = match self.conn_info.get() {
Some(neg_info) => match neg_info.negotiation.dialect_rev {
Dialect::Smb0311 => 1,
_ => 0,
},
None => 0,
};
msg.message.header.flags = msg.message.header.flags.with_priority_mask(priority_value);
self.process_sequence_outgoing(&mut msg).await?;
self.worker
.get()
.ok_or(Error::InvalidState("Worker is uninitialized".into()))?
.send(msg)
.await
}
#[maybe_async]
async fn recvo(&self, options: ReceiveOptions<'_>) -> crate::Result<IncomingMessage> {
let msg = self.worker.get().unwrap().receive(&options).await?;
if let Some(cmd) = options.cmd {
if msg.message.header.command != cmd {
return Err(Error::UnexpectedMessageCommand(msg.message.header.command));
}
}
if !msg.message.header.flags.server_to_redir() {
return Err(Error::InvalidMessage(
"Expected server-to-redir message".into(),
));
}
self.process_sequence_incoming(&msg).await?;
if !options
.status
.iter()
.any(|s| msg.message.header.status == *s as u32)
{
if let ResponseContent::Error(error_res) = msg.message.content {
return Err(Error::ReceivedErrorMessage(
msg.message.header.status,
error_res,
));
}
return Err(Error::UnexpectedMessageStatus(msg.message.header.status));
}
Ok(msg)
}
}