use crate::compression::{Compressor, Decompressor};
use super::negotiation_state::SmbNegotiateState;
use super::netbios_client::NetBiosClient;
use super::preauth_hash::*;
use crate::packets::guid::Guid;
use crate::{
crypto,
msg_handler::*,
packets::{
netbios::{NetBiosMessageContent, NetBiosTcpMessage},
smb1::SMB1NegotiateMessage,
smb2::{header::*, message::*, negotiate::*, plain::*},
},
session::Session,
};
use binrw::prelude::*;
use core::panic;
use std::{cell::OnceCell, error::Error, fmt::Display, io::Cursor};
/// Client-side SMB connection.
///
/// Holds a shared reference to the [`ClientMessageHandler`] that owns the
/// transport and the per-connection protocol state.
pub struct Connection {
/// Shared handle to the message handler; a clone of it is handed to the
/// `Session` created by `authenticate`.
handler: HandlerReference<ClientMessageHandler>,
}
/// Error type whose message reports that the SMB client is not connected.
#[derive(Debug, Clone)]
pub struct SmbClientNotConnectedError;

impl Display for SmbClientNotConnectedError {
    /// Writes the fixed, human-readable description of this error.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str("SMB client is not connected")
    }
}

// No source error or additional context is attached to this error.
impl Error for SmbClientNotConnectedError {}
impl Connection {
    /// Name string passed as the first argument of `NegotiateRequest::new`.
    // NOTE(review): hard-coded value; consider letting the caller configure it.
    const NEGOTIATE_NAME: &'static str = "AVIV-MBP";

    /// Creates a connection with a fresh, not-yet-connected message handler.
    pub fn new() -> Connection {
        Connection {
            handler: HandlerReference::new(ClientMessageHandler::new()),
        }
    }

    /// Connects the underlying NetBIOS client to `address`.
    ///
    /// # Errors
    /// Propagates any error reported by the NetBIOS client while connecting.
    pub fn connect(&mut self, address: &str) -> Result<(), Box<dyn Error>> {
        self.handler.borrow_mut().netbios_client.connect(address)?;
        log::info!("Connected to {}", address);
        Ok(())
    }

    /// Sends an SMB1 negotiate message and expects an SMB2 negotiate response
    /// carrying the `Smb02Wildcard` dialect.
    ///
    /// # Errors
    /// Fails when sending or receiving fails, when the reply is not an SMB2
    /// negotiate response, or when the dialect revision is not the wildcard.
    fn negotiate_smb1(&mut self) -> Result<(), Box<dyn Error>> {
        log::debug!("Negotiating SMB1");
        self.handler
            .borrow_mut()
            .netbios_client
            .send(NetBiosMessageContent::SMB1Message(
                SMB1NegotiateMessage::new(),
            ))?;

        let smb2_response = self.handler.recv(Command::Negotiate)?;
        // The reply comes from the peer, so a mismatch is an error, not a bug.
        let Content::NegotiateResponse(negotiate_response) = smb2_response.message.content else {
            return Err("Expected an SMB2 negotiate response".into());
        };

        if negotiate_response.dialect_revision != NegotiateDialect::Smb02Wildcard {
            return Err("Unexpected SMB2 dialect revision".into());
        }
        Ok(())
    }

    /// Performs the SMB2 negotiate exchange and stores the negotiated state
    /// on the handler.
    ///
    /// The preauth hash state is reset before the request is sent so that the
    /// negotiate exchange is hashed. Compression is treated as optional: when
    /// the response carries no compression information, the stored state has
    /// neither a compressor nor a decompressor.
    ///
    /// # Errors
    /// Fails when the exchange itself fails, when the reply is not an SMB2
    /// negotiate response, when the dialect is not `Smb0311`, when the
    /// negotiate context list or the signing algorithm is missing or
    /// unsupported, when SHA-512 is not among the offered preauth integrity
    /// algorithms, or when the negotiate state has already been set.
    fn negotiate_smb2(&mut self) -> Result<(), Box<dyn Error>> {
        log::debug!("Negotiating SMB2");
        self.handler.borrow_mut().preauth_hash = Some(PreauthHashState::default());

        let client_guid = self.handler.borrow().client_guid;
        let response = self
            .handler
            .send_recv(Content::NegotiateRequest(NegotiateRequest::new(
                Self::NEGOTIATE_NAME.to_string(),
                client_guid,
                crypto::SIGNING_ALGOS.into(),
                crypto::ENCRYPTING_ALGOS.to_vec(),
            )))?;

        // The reply comes from the peer, so a mismatch is an error, not a bug.
        let Content::NegotiateResponse(negotiate_response) = response.message.content else {
            return Err("Expected an SMB2 negotiate response".into());
        };

        if negotiate_response.dialect_revision != NegotiateDialect::Smb0311 {
            return Err("Unexpected SMB2 dialect revision".into());
        }
        if negotiate_response.negotiate_context_list.is_none() {
            return Err("Negotiate context list is missing".into());
        }

        let selected_signing_algo: SigningAlgorithmId = negotiate_response
            .get_signing_algo()
            .ok_or("Server did not select a signing algorithm")?;
        if !crypto::SIGNING_ALGOS.contains(&selected_signing_algo) {
            return Err(
                format!("Unsupported signing algorithm {:?}", selected_signing_algo).into(),
            );
        }

        if let Some(algos) = negotiate_response.get_preauth_integrity_algos() {
            if !algos.contains(&HashAlgorithm::Sha512) {
                return Err("SHA-512 preauth integrity not supported".into());
            }
        }

        // Only build the (de)compressor pair when compression was negotiated.
        let (compressor, decompressor) = match negotiate_response.get_compression() {
            Some(compression) => (
                Some(Compressor::new(
                    compression.compression_algorithms.clone(),
                    compression.flags.chained(),
                )),
                Some(Decompressor::new()),
            ),
            None => (None, None),
        };

        let negotiate_state = SmbNegotiateState {
            server_guid: negotiate_response.server_guid,
            max_transact_size: negotiate_response.max_transact_size,
            max_read_size: negotiate_response.max_read_size,
            max_write_size: negotiate_response.max_write_size,
            gss_negotiate_token: negotiate_response.buffer,
            selected_dialect: negotiate_response.dialect_revision.try_into()?,
            signing_algo: selected_signing_algo,
            compressor,
            decompressor,
        };
        log::trace!(
            "Negotiated SMB results: dialect={:?}, state={:?}",
            negotiate_state.selected_dialect,
            &negotiate_state
        );

        self.handler
            .borrow_mut()
            .negotiate_state
            .set(negotiate_state)
            .map_err(|_| "Negotiate state already set")?;
        Ok(())
    }

    /// Runs the SMB1 negotiate step followed by the SMB2 negotiate step.
    ///
    /// # Errors
    /// Returns the first error produced by either step.
    pub fn negotiate(&mut self) -> Result<(), Box<dyn Error>> {
        self.negotiate_smb1()?;
        self.negotiate_smb2()?;
        log::info!("Negotiation successful");
        Ok(())
    }

    /// Creates a [`Session`] that shares this connection's handler and runs
    /// its setup with the given credentials.
    ///
    /// # Errors
    /// Propagates any error returned by the session setup.
    pub fn authenticate(
        &mut self,
        user_name: String,
        password: String,
    ) -> Result<Session, Box<dyn Error>> {
        let mut session = Session::new(self.handler.clone());
        session.setup(user_name, password)?;
        Ok(session)
    }
}
/// State behind a [`Connection`]: the transport plus the per-connection
/// values that are stamped onto, or derived from, SMB2 messages.
pub struct ClientMessageHandler {
/// GUID of this client; generated when the handler is created and passed in
/// the SMB2 negotiate request.
client_guid: Guid,
/// NetBIOS client used to connect, send and receive messages.
netbios_client: NetBiosClient,
/// Message id of the most recently sent SMB2 message; starts at 0 and is
/// incremented before each send.
current_message_id: u64,
/// Running credit balance; starts at 1 and is adjusted from the credit
/// fields of each received response header.
credits_balance: u16,
/// Preauth integrity hash state; `None` when no hashing is in progress.
preauth_hash: Option<PreauthHashState>,
/// Results of the SMB2 negotiation; can be set only once.
negotiate_state: OnceCell<SmbNegotiateState>,
}
impl ClientMessageHandler {
    /// Serialized messages of at most this many bytes are never compressed,
    /// even when compression was requested for them.
    const COMPRESSION_THRESHOLD: usize = 1024;

    /// Creates a handler with a freshly generated client GUID, an unconnected
    /// NetBIOS client, message id 0, one credit and no negotiated state.
    fn new() -> ClientMessageHandler {
        ClientMessageHandler {
            client_guid: Guid::gen(),
            netbios_client: NetBiosClient::new(),
            negotiate_state: OnceCell::new(),
            current_message_id: 0,
            credits_balance: 1,
            preauth_hash: None,
        }
    }

    /// Feeds `raw` into the preauth hash when hashing is in progress.
    ///
    /// Does nothing when no hash state is present. A state that is already
    /// `Finished` is dropped and the field is left empty.
    fn step_preauth_hash(&mut self, raw: &Vec<u8>) {
        match self.preauth_hash.take() {
            None | Some(PreauthHashState::Finished(_)) => {}
            Some(state) => self.preauth_hash = Some(state.next(raw)),
        }
    }

    /// Finishes the preauth hash and returns its value, leaving the handler
    /// without a hash state.
    ///
    /// # Panics
    /// Panics when no preauth hash is in progress, or when finishing the
    /// state does not yield a `Finished` value; both indicate a caller bug.
    pub fn finalize_preauth_hash(&mut self) -> PreauthHashValue {
        let finished = self
            .preauth_hash
            .take()
            .expect("preauth hash must be in progress before it can be finalized")
            .finish();
        match finished {
            PreauthHashState::Finished(hash) => hash,
            _ => panic!("Preauth hash not finished"),
        }
    }

    /// Returns the negotiated state, or `None` before negotiation completed.
    pub fn negotiate_state(&self) -> Option<&SmbNegotiateState> {
        self.negotiate_state.get()
    }

    /// Serializes `msg` and applies signing, compression and encryption as
    /// requested by the message, producing the NetBIOS frame to send.
    ///
    /// The serialized (pre-signing) bytes are fed into the preauth hash.
    /// Signing is skipped when the message is going to be encrypted.
    /// Compression is applied only when requested, when the payload exceeds
    /// [`Self::COMPRESSION_THRESHOLD`] bytes and when a compressor was
    /// negotiated; otherwise the payload is sent uncompressed.
    ///
    /// # Errors
    /// Propagates serialization, signing, compression, encryption and framing
    /// errors.
    fn tranform_outgoing(
        &mut self,
        mut msg: OutgoingMessage,
    ) -> Result<NetBiosTcpMessage, Box<dyn Error>> {
        let should_encrypt = msg.encryptor.is_some();
        let should_sign = msg.signer.is_some() && !should_encrypt;
        let is_signed_set = msg.message.header.flags.signed();
        let set_session_id = msg.message.header.session_id;

        // 1. Serialize, hash and (optionally) sign.
        let mut data = Vec::new();
        msg.message.write(&mut Cursor::new(&mut data))?;
        self.step_preauth_hash(&data);
        if should_sign {
            debug_assert!(!should_encrypt && is_signed_set);
            let mut header_copy = msg.message.header.clone();
            if let Some(mut signer) = msg.signer.take() {
                signer.sign_message(&mut header_copy, &mut data)?;
            }
        }

        // 2. Compress. A missing negotiate state or compressor simply means
        //    the payload goes out uncompressed instead of panicking.
        if msg.compress && data.len() > Self::COMPRESSION_THRESHOLD {
            if let Some(compressor) = self
                .negotiate_state()
                .and_then(|state| state.compressor.as_ref())
            {
                let compressed = compressor.compress(&data)?;
                data.clear();
                Message::Compressed(compressed).write(&mut Cursor::new(&mut data))?;
            }
        }

        // 3. Encrypt.
        let data = match msg.encryptor.take() {
            Some(mut encryptor) => {
                debug_assert!(should_encrypt && !should_sign);
                let encrypted = encryptor.encrypt_message(data, set_session_id)?;
                let mut cursor = Cursor::new(Vec::new());
                Message::Encrypted(encrypted).write(&mut cursor)?;
                cursor.into_inner()
            }
            None => data,
        };

        Ok(NetBiosTcpMessage::from_content_bytes(data)?)
    }

    /// Parses a received NetBIOS frame and undoes encryption and compression,
    /// returning the plain message, its raw bytes and the form it arrived in.
    ///
    /// The decryptor, when needed, is taken out of `options`.
    ///
    /// # Errors
    /// Fails when the frame does not hold an SMB2 message, when an encrypted
    /// message arrives without a decryptor, when a compressed message arrives
    /// without a negotiated decompressor, when decryption or decompression
    /// fails, or when the result is still not a plain message (for example a
    /// nested encrypted or compressed payload). All of these depend on data
    /// sent by the peer, so they are reported as errors rather than panics.
    fn transform_incoming(
        &mut self,
        netbios: NetBiosTcpMessage,
        options: &mut ReceiveOptions,
    ) -> Result<(PlainMessage, Vec<u8>, MessageForm), Box<dyn Error>> {
        let NetBiosMessageContent::SMB2Message(message) = netbios.parse_content()? else {
            return Err("Expected SMB2 message".into());
        };
        let mut form = MessageForm::default();

        // 1. Decrypt.
        let (message, raw) = if let Message::Encrypted(encrypted_message) = &message {
            form.encrypted = true;
            let mut decryptor = options
                .decryptor
                .take()
                .ok_or("Encrypted message received without decryptor")?;
            decryptor.decrypt_message(encrypted_message)?
        } else {
            (message, netbios.content)
        };

        // 2. Decompress.
        let (message, raw) = if let Message::Compressed(compressed_message) = &message {
            form.compressed = true;
            let decompressor = self
                .negotiate_state()
                .and_then(|state| state.decompressor.as_ref())
                .ok_or("Compressed message received without decompressor!")?;
            decompressor.decompress(compressed_message)?
        } else {
            (message, raw)
        };

        // 3. Anything that is still wrapped at this point is rejected.
        match message {
            Message::Plain(plain) => Ok((plain, raw, form)),
            _ => Err("Unexpected message type".into()),
        }
    }
}
impl MessageHandler for ClientMessageHandler {
fn hsendo(
&mut self,
mut msg: OutgoingMessage,
) -> Result<SendMessageResult, Box<(dyn std::error::Error + 'static)>> {
self.current_message_id += 1;
let priority_value = match self.negotiate_state.get() {
Some(negotiate_state) => match negotiate_state.selected_dialect {
Dialect::Smb0311 => 1,
_ => 0,
},
None => 0,
};
msg.message.header.message_id = self.current_message_id;
msg.message.header.flags = msg.message.header.flags.with_priority_mask(priority_value);
msg.message.header.credit_charge = 1;
msg.message.header.credit_request = 1;
let finalize_hash_required = msg.finalize_preauth_hash;
let final_message = self.tranform_outgoing(msg)?;
self.netbios_client.send_raw(final_message)?;
let hash = match finalize_hash_required {
true => Some(self.finalize_preauth_hash()),
false => None,
};
Ok(SendMessageResult::new(hash.clone()))
}
fn hrecvo(
&mut self,
mut options: ReceiveOptions,
) -> Result<IncomingMessage, Box<dyn std::error::Error>> {
let netbios = self.netbios_client.recieve_bytes()?;
self.step_preauth_hash(&netbios.content);
let (message, raw, form) = self.transform_incoming(netbios, &mut options)?;
if let Some(cmd) = options.cmd {
if message.header.command != cmd {
return Err("Unexpected SMB2 command".into());
}
}
if !message.header.flags.server_to_redir() {
return Err("Unexpected SMB2 message direction (Not a response)".into());
}
if message.header.status != options.status {
if let Content::ErrorResponse(msg) = &message.content {
return Err(
format!("SMB2 error response {:?}: {:?}", message.header.status, msg).into(),
);
}
return Err(format!("Unexpected SMB2 status: {:?}", message.header.status).into());
}
self.credits_balance -= message.header.credit_charge;
self.credits_balance += message.header.credit_request;
Ok(IncomingMessage { message, raw, form })
}
}