use crate::session::{SessionAndChannel, SessionInfo};
use crate::sync_helpers::*;
use crate::{compression::*, msg_handler::*};
use binrw::prelude::*;
use maybe_async::*;
use smb_msg::*;
use smb_transport::IoVec;
use std::{collections::HashMap, io::Cursor, sync::Arc};
use super::connection_info::ConnectionInfo;
/// Per-connection message transformer: signs, compresses and encrypts
/// outgoing messages, and reverses those steps for incoming ones.
///
/// A session must be registered with `session_started` before messages
/// carrying its session ID can be signed/verified or encrypted/decrypted.
#[derive(Default)]
pub struct Transformer {
    /// Registered sessions, keyed by the session ID found in message headers.
    sessions: RwLock<HashMap<u64, Arc<RwLock<SessionAndChannel>>>>,
    /// Negotiation state and (optional) compression context.
    config: RwLock<TransformerConfig>,
}
/// Connection-wide transform settings, filled in by `Transformer::negotiated`.
#[derive(Default, Debug)]
struct TransformerConfig {
    /// `(compressor, decompressor)` pair; `None` unless compression is supported
    /// by the dialect, enabled in the connection config, and present in the
    /// negotiation result. `.0` is used for outgoing data, `.1` for incoming.
    compress: Option<(Compressor, Decompressor)>,
    /// Set once negotiation has completed; `session_started` refuses to run
    /// while this is still `false`.
    negotiated: bool,
}
#[maybe_async(AFIT)]
impl Transformer {
pub async fn negotiated(&self, neg_info: &ConnectionInfo) -> crate::Result<()> {
{
let config = self.config.read().await?;
if config.negotiated {
return Err(crate::Error::InvalidState(
"Connection is already negotiated!".into(),
));
}
}
let mut config = self.config.write().await?;
if neg_info.dialect.supports_compression() && neg_info.config.compression_enabled {
let compress = neg_info
.negotiation
.compression
.as_ref()
.map(|c| (Compressor::new(c), Decompressor::new(c)));
config.compress = compress;
}
config.negotiated = true;
Ok(())
}
pub async fn session_started(
&self,
session: &Arc<RwLock<SessionAndChannel>>,
) -> crate::Result<()> {
let rconfig = self.config.read().await?;
if !rconfig.negotiated {
return Err(crate::Error::InvalidState(
"Connection is not negotiated yet!".to_string(),
));
}
let session_id = { session.read().await?.session_id };
self.sessions
.write()
.await?
.insert(session_id, session.clone());
log::trace!(
"Session {} started and inserted to worker {:p}.",
session_id,
self
);
Ok(())
}
pub async fn session_ended(
&self,
session: &Arc<RwLock<SessionAndChannel>>,
) -> crate::Result<()> {
let session_id = { session.read().await?.session_id };
self.sessions
.write()
.await?
.remove(&session_id)
.ok_or(crate::Error::InvalidState(format!(
"Session {session_id} not found!",
)))?;
log::trace!(
"Session {} ended and removed from worker {:p}.",
session_id,
self
);
Ok(())
}
#[maybe_async]
#[inline]
async fn _with_channel<F, R>(&self, session_id: u64, f: F) -> crate::Result<R>
where
F: FnOnce(&SessionAndChannel) -> crate::Result<R>,
{
let sessions = self.sessions.read().await?;
let session = sessions
.get(&session_id)
.ok_or(crate::Error::InvalidState(format!(
"Session {session_id} not found!",
)))?;
let session = session.read().await?;
f(&session)
}
#[maybe_async]
#[inline]
async fn _with_session<F, R>(&self, session_id: u64, f: F) -> crate::Result<R>
where
F: FnOnce(&SessionInfo) -> crate::Result<R>,
{
let sessions = self.sessions.read().await?;
let session = sessions
.get(&session_id)
.ok_or(crate::Error::InvalidState(format!(
"Session {session_id} not found!",
)))?;
let session = session.read().await?;
let session_info = session.session.read().await?;
f(&session_info)
}
pub async fn transform_outgoing(&self, mut msg: OutgoingMessage) -> crate::Result<IoVec> {
let should_encrypt = msg.encrypt;
let should_sign = msg.message.header.flags.signed();
let session_id = msg.message.header.session_id;
let mut outgoing_data = IoVec::default();
{
let buffer = outgoing_data.add_owned(Vec::with_capacity(Header::STRUCT_SIZE));
msg.message.write(&mut Cursor::new(buffer))?;
}
if msg.additional_data.as_ref().is_some_and(|d| !d.is_empty()) {
outgoing_data.add_shared(msg.additional_data.unwrap().clone());
}
if should_sign {
debug_assert!(
!should_encrypt,
"Should not sign and encrypt at the same time!"
);
let mut signer = self
._with_channel(session_id, |session| {
let channel_info =
session
.channel
.as_ref()
.ok_or(crate::Error::TranformFailed(TransformError {
outgoing: true,
phase: TransformPhase::SignVerify,
session_id: Some(session_id),
why: "Message is required to be signed, but no channel is set up!",
msg_id: Some(msg.message.header.message_id),
}))?;
Ok(channel_info.signer()?.clone())
})
.await?;
signer.sign_message(&mut msg.message.header, &mut outgoing_data)?;
log::debug!(
"Message #{} signed (signature={}).",
msg.message.header.message_id,
msg.message.header.signature
);
};
const COMPRESSION_THRESHOLD: usize = 1024;
outgoing_data = {
if msg.compress && outgoing_data.total_size() > COMPRESSION_THRESHOLD {
let rconfig = self.config.read().await?;
if let Some(compress) = &rconfig.compress {
outgoing_data.consolidate();
let compressed = compress.0.compress(outgoing_data.first().unwrap())?;
let mut compressed_result = IoVec::default();
let write_compressed =
compressed_result.add_owned(Vec::with_capacity(compressed.total_size()));
compressed.write(&mut Cursor::new(write_compressed))?;
compressed_result
} else {
outgoing_data
}
} else {
outgoing_data
}
};
if should_encrypt {
let mut encryptor = self
._with_session(session_id, |session| {
let encryptor = session.encryptor()?.ok_or(crate::Error::TranformFailed(
TransformError {
outgoing: true,
phase: TransformPhase::EncryptDecrypt,
session_id: Some(session_id),
why: "Message is required to be encrypted, but no encryptor is set up!",
msg_id: Some(msg.message.header.message_id),
},
))?;
Ok(encryptor.clone())
})
.await?;
debug_assert!(should_encrypt && !should_sign);
let encrypted_header = encryptor.encrypt_message(&mut outgoing_data, session_id)?;
let write_encryption_header =
outgoing_data.insert_owned(0, Vec::with_capacity(EncryptedHeader::STRUCTURE_SIZE));
encrypted_header.write(&mut Cursor::new(write_encryption_header))?;
}
Ok(outgoing_data)
}
pub async fn transform_incoming(&self, data: Vec<u8>) -> crate::Result<IncomingMessage> {
let message = Response::try_from(data.as_ref())?;
let mut form = MessageForm::default();
let (message, raw) = if let Response::Encrypted(encrypted_message) = message {
let session_id = encrypted_message.header.session_id;
let mut decryptor = self
._with_session(session_id, |session| {
let decryptor = session.decryptor()?.ok_or(crate::Error::TranformFailed(
TransformError {
outgoing: false,
phase: TransformPhase::EncryptDecrypt,
session_id: Some(session_id),
why: "Message is required to be encrypted, but no decryptor is set up!",
msg_id: None,
},
))?;
Ok(decryptor.clone())
})
.await?;
form.encrypted = true;
decryptor.decrypt_message(encrypted_message)?
} else {
(message, data)
};
debug_assert!(!matches!(message, Response::Encrypted(_)));
let (message, raw) = if let Response::Compressed(compressed_message) = message {
let rconfig = self.config.read().await?;
form.compressed = true;
match &rconfig.compress {
Some(compress) => compress.1.decompress(&compressed_message)?,
None => {
return Err(crate::Error::TranformFailed(TransformError {
outgoing: false,
phase: TransformPhase::CompressDecompress,
session_id: None,
why: "Compression is requested, but no decompressor is set up!",
msg_id: None,
}));
}
}
} else {
(message, raw)
};
let mut message = match message {
Response::Plain(message) => message,
_ => panic!("Unexpected message type"),
};
let iovec = IoVec::from(raw);
match self
.verify_plain_incoming(&mut message, &iovec, &mut form)
.await
{
Ok(_) => {}
Err(e) => {
log::error!("Failed to verify incoming message: {e:?}",);
return Err(crate::Error::TranformFailed(TransformError {
outgoing: false,
phase: TransformPhase::SignVerify,
session_id: Some(message.header.session_id),
why: "Failed to verify incoming message!",
msg_id: Some(message.header.message_id),
}));
}
};
Ok(IncomingMessage::new(message, iovec, form))
}
#[maybe_async]
async fn verify_plain_incoming(
&self,
message: &mut PlainResponse,
raw: &IoVec,
form: &mut MessageForm,
) -> crate::Result<()> {
if form.encrypted
|| message.header.message_id == u64::MAX
|| message.header.status == Status::Pending as u32
|| !(message.header.flags.signed() || self.is_message_signed_ksmbd(message).await)
{
return Ok(());
}
let session_id = message.header.session_id;
let mut signer = self
._with_channel(session_id, |session| {
let channel_info = session
.channel
.as_ref()
.ok_or(crate::Error::TranformFailed(TransformError {
outgoing: false,
phase: TransformPhase::SignVerify,
session_id: Some(session_id),
why: "Message is required to be signed, but no channel is set up!",
msg_id: Some(message.header.message_id),
}))?;
Ok(channel_info.signer()?.clone())
})
.await?;
signer.verify_signature(&mut message.header, raw)?;
log::debug!(
"Message #{} verified (signature={}).",
message.header.message_id,
message.header.signature
);
form.signed = true;
Ok(())
}
#[maybe_async]
async fn is_message_signed_ksmbd(&self, _message: &PlainResponse) -> bool {
#[cfg(feature = "ksmbd-multichannel-compat")]
{
if _message.header.command != Command::SessionSetup || _message.header.signature == 0 {
return false;
}
let session_id = _message.header.session_id;
let is_binding = self
._with_channel(session_id, |session| {
let channel_info = session.channel.as_ref().ok_or(crate::Error::Other(
"Get channel info for ksmbd sign test failed",
))?;
Ok(channel_info.is_binding())
})
.await;
return matches!(is_binding, Ok(true));
}
#[cfg(not(feature = "ksmbd-multichannel-compat"))]
return false;
}
}
/// Describes why transforming (sign/verify, compress/decompress,
/// encrypt/decrypt) a message failed.
#[derive(Debug)]
pub struct TransformError {
    /// `true` if the failure happened on an outgoing message, `false` for incoming.
    pub outgoing: bool,
    /// Pipeline stage that failed.
    pub phase: TransformPhase,
    /// Session the message belongs to, when known.
    pub session_id: Option<u64>,
    /// Static, human-readable reason.
    pub why: &'static str,
    /// Message ID, when known. Note: not included in the `Display` output.
    pub msg_id: Option<u64>,
}
impl std::fmt::Display for TransformError {
    /// Formats the error as one line: direction, failed phase, session ID
    /// (if any) and reason. `msg_id` is intentionally left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            outgoing,
            phase,
            session_id,
            why,
            ..
        } = self;
        match *outgoing {
            true => write!(
                f,
                "Failed to transform outgoing message: {:?} (session_id: {:?}) - {}",
                phase, session_id, why
            ),
            false => write!(
                f,
                "Failed to transform incoming message: {:?} (session_id: {:?}) - {}",
                phase, session_id, why
            ),
        }
    }
}
/// Stage of the message transform pipeline at which a failure occurred.
///
/// Fieldless, so it is cheap to copy and can be compared or used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TransformPhase {
    /// Encoding or decoding the message itself.
    EncodeDecode,
    /// Signing an outgoing message or verifying an incoming signature.
    SignVerify,
    /// Compressing an outgoing or decompressing an incoming message.
    CompressDecompress,
    /// Encrypting an outgoing or decrypting an incoming message.
    EncryptDecrypt,
}