use std::{sync::Arc, time::Duration};
use crate::{
connection::connection_info::ConnectionInfo, msg_handler::ReceiveOptions, sync_helpers::*,
Error,
};
use maybe_async::*;
use crate::{
connection::{netbios_client::NetBiosClient, transformer::Transformer},
msg_handler::{IncomingMessage, OutgoingMessage, SendMessageResult},
session::SessionInfo,
};
#[maybe_async(AFIT)]
#[allow(async_fn_in_trait)]
pub trait Worker: Sized + std::fmt::Debug {
async fn start(netbios_client: NetBiosClient, timeout: Duration) -> crate::Result<Arc<Self>>;
async fn stop(&self) -> crate::Result<()>;
async fn set_timeout(&self, timeout: Duration) -> crate::Result<()>;
async fn send(self: &Self, msg: OutgoingMessage) -> crate::Result<SendMessageResult>;
async fn receive_next(
self: &Self,
options: &ReceiveOptions<'_>,
) -> crate::Result<IncomingMessage>;
async fn receive(self: &Self, options: &ReceiveOptions<'_>) -> crate::Result<IncomingMessage> {
if options.msg_id == u64::MAX {
return Err(Error::InvalidArgument(
"Message ID -1 is not valid for receive()".to_string(),
));
}
let curr = self.receive_next(options).await?;
if !curr.message.header.flags.async_command() {
return Ok(curr);
}
if !options.allow_async {
return Err(Error::InvalidArgument(
"Async command is not allowed in this context.".to_string(),
));
}
if curr.message.header.status != crate::packets::smb2::Status::Pending as u32 {
return Ok(curr);
}
log::debug!(
"Received async pending message with ID {} and status {}.",
curr.message.header.message_id,
curr.message.header.status
);
let async_id = match curr.message.header.async_id {
Some(async_id) => async_id,
None => panic!("Async ID is None, but async command is set. This should not happen."),
};
if async_id == 0 {
return Ok(curr);
}
loop {
let msg = self.receive_next(options).await?;
if !msg.message.header.flags.async_command()
|| msg.message.header.async_id != Some(async_id)
{
return Err(Error::InvalidArgument(format!(
"Received message for msgid {} with async ID {} but expected async ID {}",
msg.message.header.message_id,
msg.message
.header
.async_id
.map(|x| x.to_string())
.unwrap_or("None".to_string()),
async_id
)));
}
if msg.message.header.status != crate::packets::smb2::Status::Pending as u32 {
return Ok(msg);
}
log::debug!(
"Received another async pending message with ID {} and status {}.",
msg.message.header.message_id,
msg.message.header.status
);
}
}
fn transformer(&self) -> &Transformer;
#[maybe_async]
async fn negotaite_complete(&self, neg: &ConnectionInfo) {
self.transformer().negotiated(neg).await.unwrap();
}
#[maybe_async]
async fn session_started(&self, session: Arc<Mutex<SessionInfo>>) -> crate::Result<()> {
self.transformer().session_started(session).await
}
#[maybe_async]
async fn session_ended(&self, session_id: u64) -> crate::Result<()> {
self.transformer().session_ended(session_id).await
}
}