use crate::connection::transformer::Transformer;
use crate::connection::worker::Worker;
use crate::msg_handler::ReceiveOptions;
use crate::sync_helpers::*;
use maybe_async::*;
use smb_msg::ResponseContent;
use smb_transport::{IoVec, SmbTransport, SmbTransportWrite, TransportError};
use std::sync::atomic::AtomicBool;
use std::time::Duration;
use std::{collections::HashMap, sync::Arc};
use crate::{
Error,
msg_handler::{IncomingMessage, OutgoingMessage, SendMessageResult},
};
use super::backend_trait::MultiWorkerBackend;
/// Worker that hands outgoing messages to a backend through a channel and
/// matches incoming messages to their callers by message ID.
///
/// `BackendImplT` supplies the channel and notifier primitives used here, and
/// is the backend that `Worker::start` starts and `Worker::stop` stops.
pub struct ParallelWorker<BackendImplT>
where
BackendImplT: MultiWorkerBackend + std::fmt::Debug,
BackendImplT::AwaitingNotifier: std::fmt::Debug,
{
/// Registered waiters and not-yet-claimed responses, both keyed by message ID.
pub(crate) state: Mutex<WorkerAwaitState<BackendImplT>>,
/// Running backend: filled in by `start` and taken out again by `stop`.
backend_impl: Mutex<Option<Arc<BackendImplT>>>,
/// Transforms outgoing messages before they are sent and incoming bytes
/// after they are received.
transformer: Transformer,
/// Optional sink for incoming messages whose message ID is `u64::MAX`.
/// Can be set only once, through `start_notify_channel`.
notify_messages_channel: OnceCell<mpsc::Sender<IncomingMessage>>,
/// Sending half of the channel pair created in `start`; `send` pushes the
/// wrapped outgoing message into it.
pub(crate) sender: mpsc::Sender<BackendImplT::SendMessage>,
/// Set to `true` by `stop`; consulted before waiting for a response.
stopped: AtomicBool,
/// Value handed to the backend's `wait_on_waiter` when waiting for a response.
timeout: RwLock<Duration>,
}
/// Bookkeeping that pairs incoming messages with the callers waiting for them.
///
/// Both maps are keyed by message ID.
#[derive(Debug)]
pub struct WorkerAwaitState<T>
where
T: MultiWorkerBackend,
T::AwaitingNotifier: std::fmt::Debug,
{
/// Notifiers registered by `receive_next` callers that are still waiting.
/// An entry is removed when the message with that ID arrives.
pub awaiting: HashMap<u64, T::AwaitingNotifier>,
/// Results that arrived while nobody was waiting for them.
/// `receive_next` removes and returns the entry for the requested ID.
pub pending: HashMap<u64, crate::Result<IncomingMessage>>,
}
impl<T> WorkerAwaitState<T>
where
    T: MultiWorkerBackend,
    T::AwaitingNotifier: std::fmt::Debug,
{
    /// Creates a state with no registered waiters and no stored responses.
    fn new() -> Self {
        let awaiting = HashMap::default();
        let pending = HashMap::default();
        Self { awaiting, pending }
    }
}
#[maybe_async(AFIT)]
impl<T> ParallelWorker<T>
where
T: MultiWorkerBackend + std::fmt::Debug,
T::AwaitingNotifier: std::fmt::Debug,
{
pub fn stopped(&self) -> bool {
self.stopped.load(std::sync::atomic::Ordering::SeqCst)
}
pub(crate) async fn incoming_data_callback(
self: &Arc<Self>,
message: Result<Vec<u8>, TransportError>,
) -> crate::Result<()> {
log::trace!("Received message from server.");
let message = message?;
let msg = self.transformer.transform_incoming(message).await;
let (msg, msg_id) = match msg {
Ok(msg) => {
let msg_id = msg.message.header.message_id;
(Ok(msg), msg_id)
}
Err(crate::Error::TranformFailed(e)) => match e.msg_id {
Some(msg_id) => (Err(crate::Error::TranformFailed(e)), msg_id),
None => return Err(Error::TranformFailed(e)),
},
Err(e) => {
log::error!("Failed to transform message: {e:?}");
return Err(e);
}
};
if msg_id == u64::MAX {
let msg = msg?;
if !matches!(
msg.message.content,
ResponseContent::OplockBreakNotify(_)
| ResponseContent::ServerToClientNotification(_)
) {
return Err(Error::MessageProcessingError(
"Received notification message, but not an OPLOCK_BREAK or SERVER_TO_CLIENT_NOTIFICATION.".to_string(),
));
}
if let Some(s2c_channel) = self.notify_messages_channel.get() {
log::trace!("Sending notification message to notify channel.");
s2c_channel.send(msg).await.map_err(|_| {
Error::MessageProcessingError(
"Failed to send notification message to notify channel.".to_string(),
)
})?;
} else {
log::warn!("Received notification message, but no notify channel is set.");
}
return Ok(());
}
let mut state = self.state.lock().await?;
let message_waiter = state.awaiting.remove(&msg_id);
match message_waiter {
Some(tx) => {
log::trace!("Waking up awaiting task for key {msg_id}.");
T::send_notify(tx, msg)?;
}
None => {
log::trace!("Storing message until awaited: {msg_id}.",);
state.pending.insert(msg_id, msg);
}
}
Ok(())
}
pub fn start_notify_channel(
self: &Arc<Self>,
notify_channel: mpsc::Sender<IncomingMessage>,
) -> crate::Result<()> {
self.notify_messages_channel
.set(notify_channel)
.map_err(|_| Error::InvalidState("Notify channel is already set.".to_string()))?;
Ok(())
}
pub async fn outgoing_data_callback(
self: &Arc<Self>,
message: Option<IoVec>,
wtransport: &mut dyn SmbTransportWrite,
) -> crate::Result<()> {
let message = match message {
Some(m) => m,
None => {
if self.stopped() {
return Err(Error::ConnectionStopped);
} else {
return Err(Error::MessageProcessingError(
"Empty message cannot be sent to the server.".to_string(),
));
}
}
};
wtransport.send(&message).await?;
Ok(())
}
}
#[maybe_async(AFIT)]
impl<T> Worker for ParallelWorker<T>
where
T: MultiWorkerBackend + std::fmt::Debug,
T::AwaitingNotifier: std::fmt::Debug,
{
async fn start(
transport: Box<dyn SmbTransport>,
timeout: Duration,
) -> crate::Result<Arc<Self>> {
let (tx, rx) = T::make_send_channel_pair();
let worker = Arc::new(ParallelWorker::<T> {
state: Mutex::new(WorkerAwaitState::new()),
backend_impl: Default::default(),
transformer: Transformer::default(),
notify_messages_channel: Default::default(),
sender: tx,
stopped: AtomicBool::new(false),
timeout: RwLock::new(timeout),
});
worker
.backend_impl
.lock()
.await?
.replace(T::start(transport, worker.clone(), rx).await?);
Ok(worker)
}
async fn stop(&self) -> crate::Result<()> {
self.stopped
.store(true, std::sync::atomic::Ordering::SeqCst);
{
self.backend_impl
.lock()
.await?
.take()
.ok_or(Error::InvalidState(
"No backend present for worker.".to_string(),
))?
}
.stop()
.await
}
async fn send(&self, msg: OutgoingMessage) -> crate::Result<SendMessageResult> {
log::trace!("ParallelWorker::send({msg:?}) called");
let return_raw_data = msg.return_raw_data;
let id = msg.message.header.message_id;
let message = { self.transformer.transform_outgoing(msg).await? };
log::trace!("Message with ID {id} is passed to the worker for sending",);
let raw_message_copy = if return_raw_data {
Some(message.clone())
} else {
None
};
let message = T::wrap_msg_to_send(message);
self.sender.send(message).await.map_err(|_| {
Error::MessageProcessingError("Failed to send message to worker!".to_string())
})?;
Ok(SendMessageResult::new(id, raw_message_copy))
}
async fn receive_next(&self, options: &ReceiveOptions<'_>) -> crate::Result<IncomingMessage> {
let wait_for_receive = {
let mut state = self.state.lock().await?;
if self.stopped() {
log::trace!("Connection is closed, avoid receiving.");
return Err(Error::ConnectionStopped);
}
if state.pending.contains_key(&options.msg_id) {
log::trace!(
"Message with ID {} is already received, remove from pending.",
&options.msg_id
);
let data = state.pending.remove(&options.msg_id).ok_or_else(|| {
Error::InvalidState("Message ID not found in pending messages.".to_string())
})?;
return data;
}
log::trace!(
"Message with ID {} is not received yet, insert channel and await.",
options.msg_id
);
let (tx, rx) = T::make_notifier_awaiter_pair();
state.awaiting.insert(options.msg_id, tx);
rx
};
let timeout = { *self.timeout.read().await? };
let result = T::wait_on_waiter(wait_for_receive, timeout).await?;
log::trace!("Received message {result:?}");
Ok(result)
}
fn transformer(&self) -> &Transformer {
&self.transformer
}
}
impl<T> std::fmt::Debug for ParallelWorker<T>
where
    T: MultiWorkerBackend + std::fmt::Debug,
    T::AwaitingNotifier: std::fmt::Debug,
{
    /// Formats the worker's state, backend slot, sender and stopped flag.
    ///
    /// The other fields are not printed, so the output is marked as
    /// non-exhaustive (it ends with `..`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ParallelWorker")
            .field("state", &self.state)
            .field("backend", &self.backend_impl)
            .field("sender", &self.sender)
            .field("stopped", &self.stopped)
            // `transformer`, `notify_messages_channel` and `timeout` are omitted.
            .finish_non_exhaustive()
    }
}