pub mod config;
pub mod connection_info;
pub mod netbios_client;
pub mod preauth_hash;
pub mod transformer;
pub mod worker;
use crate::dialects::DialectImpl;
use crate::packets::guid::Guid;
use crate::packets::smb2::{Command, Message};
use crate::Error;
use crate::{compression, sync_helpers::*};
use crate::{
crypto,
msg_handler::*,
packets::{
netbios::NetBiosMessageContent,
smb1::SMB1NegotiateMessage,
smb2::{negotiate::*, plain::*},
},
session::Session,
};
use binrw::prelude::*;
pub use config::*;
use connection_info::{ConnectionInfo, NegotiatedProperties};
use maybe_async::*;
use netbios_client::NetBiosClient;
use std::cmp::max;
use std::sync::atomic::{AtomicU16, AtomicU64};
use std::sync::Arc;
use std::time::Duration;
pub use transformer::TransformError;
use worker::{Worker, WorkerImpl};
/// A client-side SMB connection.
///
/// Created with [`Connection::build`], then connected and negotiated with
/// [`Connection::connect`]; sessions are created from it via `authenticate`.
pub struct Connection {
    /// Connection-level message handler (worker, negotiated info, credit/message-id state);
    /// a reference to it is passed to `Session::setup`.
    handler: HandlerReference<ConnectionMessageHandler>,
    /// Configuration supplied to `build`; its `timeout` can be changed via `set_timeout`.
    config: ConnectionConfig,
}
impl Connection {
pub fn build(config: ConnectionConfig) -> crate::Result<Connection> {
config.validate()?;
let client_guid = config.client_guid.unwrap_or_else(Guid::gen);
Ok(Connection {
handler: HandlerReference::new(ConnectionMessageHandler::new(client_guid)),
config,
})
}
#[maybe_async]
pub async fn set_timeout(&mut self, timeout: Option<Duration>) -> crate::Result<()> {
self.config.timeout = timeout;
if let Some(worker) = self.handler.worker.get() {
worker.set_timeout(timeout).await?;
}
Ok(())
}
#[maybe_async]
pub async fn connect(&mut self, address: &str) -> crate::Result<()> {
if self.handler.worker().is_some() {
return Err(Error::InvalidState("Already connected".into()));
}
let mut netbios_client = NetBiosClient::new(self.config.timeout);
log::debug!("Connecting to {}...", address);
netbios_client.connect(address).await?;
log::info!("Connected to {}. Negotiating.", address);
self.negotiate(netbios_client, true).await?;
Ok(())
}
#[maybe_async]
pub async fn close(&self) -> crate::Result<()> {
match self.handler.worker().take() {
Some(c) => c.stop().await,
None => Ok(()),
}
}
#[maybe_async]
async fn negotiate_switch_to_smb2(
&mut self,
mut netbios_client: NetBiosClient,
negotiate_smb1: bool,
) -> crate::Result<Arc<WorkerImpl>> {
if negotiate_smb1 {
log::debug!("Negotiating multi-protocol");
netbios_client
.send(NetBiosMessageContent::SMB1Message(
SMB1NegotiateMessage::new(),
))
.await?;
let response = netbios_client.received_bytes().await?.parse_content()?;
let message = match response {
NetBiosMessageContent::SMB2Message(Message::Plain(m)) => m,
_ => {
return Err(Error::InvalidMessage(
"Expected SMB2 negotiate response, got SMB1".to_string(),
))
}
};
let smb2_negotiate_response = message.content.to_negotiateresponse()?;
if smb2_negotiate_response.dialect_revision != NegotiateDialect::Smb02Wildcard {
return Err(Error::InvalidMessage(
"Expected SMB2 wildcard dialect".to_string(),
));
}
if message.header.message_id != 0 {
return Err(Error::InvalidMessage("Expected message ID 0".to_string()));
}
}
Ok(WorkerImpl::start(netbios_client, self.config.timeout).await?)
}
#[maybe_async]
async fn negotiate_smb2(&mut self) -> crate::Result<ConnectionInfo> {
if self.handler.conn_info.get().is_some() {
return Err(Error::InvalidState("Already negotiated".into()));
}
log::debug!("Negotiating SMB2");
let min_dialect = self.config.min_dialect.unwrap_or(Dialect::MIN);
let max_dialect = self.config.max_dialect.unwrap_or(Dialect::MAX);
let dialects: Vec<Dialect> = Dialect::ALL
.iter()
.filter(|dialect| **dialect >= min_dialect && **dialect <= max_dialect)
.copied()
.collect();
if dialects.is_empty() {
return Err(Error::InvalidConfiguration(
"No dialects to negotiate".to_string(),
));
}
let encryption_algos = if !self.config.encryption_mode.is_disabled() {
crypto::SIGNING_ALGOS.into()
} else {
vec![]
};
let client_guid = self.handler.client_guid;
let hostname = self
.config
.client_name
.clone()
.unwrap_or_else(|| "smb-client".to_string());
let response = self
.handler
.send_recv(Content::NegotiateRequest(NegotiateRequest::new(
hostname,
client_guid,
dialects,
encryption_algos,
crypto::ENCRYPTING_ALGOS.to_vec(),
compression::SUPPORTED_ALGORITHMS.to_vec(),
)))
.await?;
let smb2_negotiate_response = response.message.content.to_negotiateresponse()?;
let dialect_rev = smb2_negotiate_response.dialect_revision.try_into()?;
if dialect_rev > max_dialect || dialect_rev < min_dialect {
return Err(Error::NegotiationError(
"Server selected an unsupported dialect.".into(),
));
}
let dialect_impl = DialectImpl::new(dialect_rev);
let mut negotiation = NegotiatedProperties {
server_guid: smb2_negotiate_response.server_guid,
caps: smb2_negotiate_response.capabilities.clone(),
max_transact_size: smb2_negotiate_response.max_transact_size,
max_read_size: smb2_negotiate_response.max_read_size,
max_write_size: smb2_negotiate_response.max_write_size,
auth_buffer: smb2_negotiate_response.buffer.clone(),
signing_algo: None,
encryption_cipher: None,
compression: None,
dialect_rev,
};
dialect_impl.process_negotiate_request(
&smb2_negotiate_response,
&mut negotiation,
&self.config,
)?;
if ((!u32::from_le_bytes(dialect_impl.get_negotiate_caps_mask().into_bytes()))
& u32::from_le_bytes(negotiation.caps.into_bytes()))
!= 0
{
return Err(Error::NegotiationError(
"Server capabilities are invalid for the selected dialect.".into(),
));
}
log::trace!(
"Negotiated SMB results: dialect={:?}, state={:?}",
dialect_rev,
&negotiation
);
Ok(ConnectionInfo {
negotiation,
dialect: dialect_impl,
config: self.config.clone(),
})
}
#[maybe_async]
async fn negotiate(
&mut self,
netbios_client: NetBiosClient,
multi_protocol: bool,
) -> crate::Result<()> {
if self.handler.conn_info.get().is_some() {
return Err(Error::InvalidState("Already negotiated".into()));
}
let worker = self
.negotiate_switch_to_smb2(netbios_client, multi_protocol)
.await?;
self.handler.worker.set(worker).unwrap();
let info = self.negotiate_smb2().await?;
self.handler
.worker
.get()
.ok_or("Worker is uninitialized")
.unwrap()
.negotaite_complete(&info)
.await;
self.handler.conn_info.set(Arc::new(info)).unwrap();
log::info!("Negotiation successful");
Ok(())
}
#[maybe_async]
pub async fn authenticate(
self: &mut Connection,
user_name: &str,
password: String,
) -> crate::Result<Session> {
Session::setup(
user_name,
password,
&self.handler,
self.handler.conn_info.get().unwrap(),
)
.await
}
}
/// Connection-level [`MessageHandler`].
///
/// Assigns message IDs and (with large-MTU) credit fields to outgoing messages, validates
/// incoming responses, and forwards both to the worker.
pub struct ConnectionMessageHandler {
    /// Client GUID sent in the SMB2 NEGOTIATE request.
    client_guid: Guid,
    /// Target size for `credit_pool`; any shortfall is added to the credits requested on
    /// the next outgoing large-MTU message.
    extra_credits_to_request: u16,
    /// Worker carrying messages to/from the server; set once during negotiation.
    worker: OnceCell<Arc<WorkerImpl>>,
    /// Negotiated connection state; set once negotiation succeeds.
    conn_info: OnceCell<Arc<ConnectionInfo>>,
    /// Send credits (large-MTU only): permits are acquired and forgotten per request and
    /// added back by the amount each response grants.
    curr_credits: Semaphore,
    /// Next message ID to assign.
    curr_msg_id: AtomicU64,
    /// Running tally of granted-minus-charged credits, used to size credit requests.
    credit_pool: AtomicU16,
}
impl ConnectionMessageHandler {
fn new(client_guid: Guid) -> ConnectionMessageHandler {
ConnectionMessageHandler {
client_guid,
worker: OnceCell::new(),
conn_info: OnceCell::new(),
extra_credits_to_request: 4,
curr_credits: Semaphore::new(1),
curr_msg_id: AtomicU64::new(1),
credit_pool: AtomicU16::new(1),
}
}
pub fn worker(&self) -> Option<&Arc<WorkerImpl>> {
self.worker.get()
}
const SET_CREDIT_CHARGE_CMDS: &[Command] = &[
Command::Read,
Command::Write,
Command::Ioctl,
Command::QueryDirectory,
];
const CREDIT_CALC_RATIO: u32 = 65536;
#[maybe_async]
async fn process_sequence_outgoing(&self, msg: &mut OutgoingMessage) -> crate::Result<()> {
if let Some(neg) = self.conn_info.get() {
if neg.negotiation.caps.large_mtu() {
let cost = if Self::SET_CREDIT_CHARGE_CMDS
.iter()
.any(|&cmd| cmd == msg.message.header.command)
{
let send_payload_size = msg.message.content.req_payload_size();
let expected_response_payload_size = msg.message.content.expected_resp_size();
(1 + (max(send_payload_size, expected_response_payload_size) - 1)
/ Self::CREDIT_CALC_RATIO)
.try_into()
.unwrap()
} else {
1
};
self.curr_credits.acquire_many(cost as u32).await?.forget();
let mut request = cost;
let current_pool_size = self.credit_pool.load(std::sync::atomic::Ordering::SeqCst);
if current_pool_size < self.extra_credits_to_request {
request += self.extra_credits_to_request - current_pool_size;
}
msg.message.header.credit_charge = cost;
msg.message.header.credit_request = request;
msg.message.header.message_id = self
.curr_msg_id
.fetch_add(cost as u64, std::sync::atomic::Ordering::SeqCst);
return Ok(());
} else {
debug_assert_eq!(msg.message.header.credit_request, 0);
debug_assert_eq!(msg.message.header.credit_charge, 0);
}
}
{
msg.message.header.message_id = self
.curr_msg_id
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
}
Ok(())
}
#[maybe_async]
async fn process_sequence_incoming(&self, msg: &IncomingMessage) -> crate::Result<()> {
if let Some(neg) = self.conn_info.get() {
if neg.negotiation.caps.large_mtu() {
let granted_credits = msg.message.header.credit_request;
let charged_credits = msg.message.header.credit_charge;
if charged_credits > granted_credits {
self.credit_pool.fetch_sub(
charged_credits - granted_credits,
std::sync::atomic::Ordering::SeqCst,
);
} else {
self.credit_pool.fetch_add(
granted_credits - charged_credits,
std::sync::atomic::Ordering::SeqCst,
);
}
self.curr_credits.add_permits(granted_credits as usize);
}
}
Ok(())
}
}
impl MessageHandler for ConnectionMessageHandler {
#[maybe_async]
async fn sendo(&self, mut msg: OutgoingMessage) -> crate::Result<SendMessageResult> {
let priority_value = match self.conn_info.get() {
Some(neg_info) => match neg_info.negotiation.dialect_rev {
Dialect::Smb0311 => 1,
_ => 0,
},
None => 0,
};
msg.message.header.flags = msg.message.header.flags.with_priority_mask(priority_value);
self.process_sequence_outgoing(&mut msg).await?;
Ok(self
.worker
.get()
.ok_or(Error::InvalidState("Worker is uninitialized".into()))?
.send(msg)
.await?)
}
#[maybe_async]
async fn recvo(&self, options: ReceiveOptions) -> crate::Result<IncomingMessage> {
let msg = self
.worker
.get()
.unwrap()
.receive(options.msg_id_filter)
.await?;
if let Some(cmd) = options.cmd {
if msg.message.header.command != cmd {
return Err(Error::UnexpectedCommand(msg.message.header.command));
}
}
if !msg.message.header.flags.server_to_redir() {
return Err(Error::InvalidMessage(
"Expected server-to-redir message".into(),
));
}
if msg.message.header.status != options.status {
if let Content::ErrorResponse(error_res) = msg.message.content {
return Err(Error::ReceivedErrorMessage(
msg.message.header.status,
error_res,
));
}
return Err(Error::UnexpectedMessageStatus(msg.message.header.status));
}
self.process_sequence_incoming(&msg).await?;
Ok(msg)
}
}