use maybe_async::*;
use smb_msg::{Command, PlainRequest, PlainResponse, RequestContent, Status};
use smb_transport::IoVec;
#[cfg(not(feature = "async"))]
use std::sync::atomic::AtomicBool;
use std::sync::{Arc, atomic::AtomicU64};
#[cfg(feature = "async")]
use tokio_util::sync::CancellationToken;
/// A request plus the per-message options used when sending it.
///
/// [`OutgoingMessage::new`] fills in the default for every option field.
#[derive(Debug)]
pub struct OutgoingMessage {
/// The request to send; built with `PlainRequest::new` in [`OutgoingMessage::new`].
pub message: PlainRequest,
/// Defaults to `false`; set with [`OutgoingMessage::with_return_raw_data`].
/// NOTE(review): presumably asks the sender to hand back the sent bytes in
/// `SendMessageResult::raw` -- confirm against the sender implementation.
pub return_raw_data: bool,
/// Defaults to `true`.
/// NOTE(review): presumably a request to compress the message when possible -- confirm.
pub compress: bool,
/// Defaults to `false`; set with [`OutgoingMessage::with_encrypt`].
/// NOTE(review): presumably a request to encrypt the message -- confirm.
pub encrypt: bool,
/// Defaults to `true`.
/// NOTE(review): presumably whether a response is expected for this request -- confirm.
pub has_response: bool,
/// Optional shared byte buffer, `None` by default; set with
/// [`OutgoingMessage::with_additional_data`].
/// NOTE(review): how the sender uses these bytes is not visible here -- confirm.
pub additional_data: Option<Arc<[u8]>>,
/// Optional channel id, `None` by default; set with
/// [`OutgoingMessage::with_channel_id`]. `MessageHandler::sendor_recvo` copies it
/// into `ReceiveOptions::channel_id` before receiving.
pub channel_id: Option<u32>,
}
impl OutgoingMessage {
    /// Wraps `content` in a [`PlainRequest`] and applies the default options.
    ///
    /// Defaults: `compress` and `has_response` are `true`; `return_raw_data`
    /// and `encrypt` are `false`; `additional_data` and `channel_id` are `None`.
    pub fn new(content: RequestContent) -> OutgoingMessage {
        let message = PlainRequest::new(content);
        Self {
            message,
            return_raw_data: false,
            compress: true,
            encrypt: false,
            has_response: true,
            additional_data: None,
            channel_id: None,
        }
    }

    /// Returns the message with `additional_data` set to `Some(data)`,
    /// replacing any previous value.
    pub fn with_additional_data(self, data: Arc<[u8]>) -> Self {
        Self {
            additional_data: Some(data),
            ..self
        }
    }

    /// Returns the message with `return_raw_data` set to the given value.
    pub fn with_return_raw_data(self, return_raw_data: bool) -> Self {
        Self {
            return_raw_data,
            ..self
        }
    }

    /// Returns the message with `encrypt` set to the given value.
    pub fn with_encrypt(self, encrypt: bool) -> Self {
        Self { encrypt, ..self }
    }

    /// Returns the message with `channel_id` set to the given value
    /// (`None` clears it).
    pub fn with_channel_id(self, channel_id: Option<u32>) -> Self {
        Self { channel_id, ..self }
    }
}
/// Outcome of sending a message.
#[derive(Debug)]
pub struct SendMessageResult {
    /// Message id reported for the send. `MessageHandler::sendor_recvo` copies
    /// it into `ReceiveOptions::msg_id` before receiving the response.
    pub msg_id: u64,
    /// Optional raw data of the sent message.
    /// NOTE(review): presumably only populated when
    /// `OutgoingMessage::return_raw_data` is set -- confirm against the sender.
    pub raw: Option<IoVec>,
}

impl SendMessageResult {
    /// Builds a result from the message id and the optional raw data.
    pub fn new(msg_id: u64, raw: Option<IoVec>) -> SendMessageResult {
        Self { raw, msg_id }
    }
}
/// A received response together with its raw data and a description of the
/// form it arrived in.
#[derive(Debug)]
pub struct IncomingMessage {
    /// The parsed response.
    pub message: PlainResponse,
    /// Raw data associated with the response.
    pub raw: IoVec,
    /// Flags describing how the message was wrapped (see [`MessageForm`]).
    pub form: MessageForm,
    /// Optional channel id; [`IncomingMessage::new`] always leaves it `None`.
    /// NOTE(review): presumably filled in later by whichever layer knows the
    /// receiving channel -- confirm.
    pub source_channel_id: Option<u32>,
}

impl IncomingMessage {
    /// Builds an incoming message with no source channel recorded.
    pub fn new(message: PlainResponse, raw: IoVec, form: MessageForm) -> IncomingMessage {
        let source_channel_id = None;
        Self {
            message,
            raw,
            form,
            source_channel_id,
        }
    }
}
/// Flags describing the form of a message; all default to `false`.
#[derive(Debug, Default)]
pub struct MessageForm {
    /// The message was compressed.
    pub compressed: bool,
    /// The message was encrypted.
    pub encrypted: bool,
    /// The message was signed.
    pub signed: bool,
}

impl MessageForm {
    /// Returns `true` when at least one of `signed` or `encrypted` is set.
    /// `compressed` does not influence the result.
    pub fn signed_or_encrypted(&self) -> bool {
        matches!(
            self,
            Self { signed: true, .. } | Self { encrypted: true, .. }
        )
    }
}
/// A message id / async id pair that can be updated through a shared reference.
///
/// Both ids hold `u64::MAX` after [`Default::default`] and after
/// [`AsyncMessageIds::reset`].
#[derive(Debug)]
pub struct AsyncMessageIds {
    /// Message id; `u64::MAX` in the default/reset state.
    pub msg_id: AtomicU64,
    /// Async id; `u64::MAX` in the default/reset state.
    pub async_id: AtomicU64,
}

impl AsyncMessageIds {
    /// Puts both ids back to `u64::MAX`.
    pub fn reset(&self) {
        let cleared = u64::MAX;
        self.set(cleared, cleared);
    }

    /// Stores `msg_id`, then `async_id`, each with sequentially-consistent ordering.
    ///
    /// The two stores are separate atomic operations, so a concurrent reader
    /// can observe the new `msg_id` alongside the old `async_id`.
    pub fn set(&self, msg_id: u64, async_id: u64) {
        use std::sync::atomic::Ordering::SeqCst;

        self.msg_id.store(msg_id, SeqCst);
        self.async_id.store(async_id, SeqCst);
    }
}

impl Default for AsyncMessageIds {
    /// Creates the pair with both ids set to `u64::MAX`.
    fn default() -> Self {
        let unset = || AtomicU64::new(u64::MAX);
        Self {
            msg_id: unset(),
            async_id: unset(),
        }
    }
}
/// Options passed to `MessageHandler::recvo` when receiving a message.
///
/// Build with [`ReceiveOptions::new`] (identical to `Default::default()`) and
/// refine with the `with_*` methods.
#[derive(Debug, Clone)]
pub struct ReceiveOptions<'a> {
/// Borrowed list of statuses; defaults to `[Status::Success]`.
/// NOTE(review): presumably the response statuses the receiver accepts -- confirm
/// against the `recvo` implementation.
pub status: &'a [Status],
/// Optional command, `None` by default; `MessageHandler::recv` and
/// `MessageHandler::sendo_recv` set it.
/// NOTE(review): presumably the command the response is expected to carry -- confirm.
pub cmd: Option<Command>,
/// Message id filter, `0` by default. `MessageHandler::sendor_recvo` overwrites
/// it with the id of the message it just sent.
pub msg_id: u64,
/// Optional channel id, `None` by default. `MessageHandler::sendor_recvo`
/// overwrites it with the outgoing message's `channel_id`.
pub channel_id: Option<u32>,
/// Defaults to `false`.
/// NOTE(review): presumably whether asynchronous/interim responses are tolerated -- confirm.
pub allow_async: bool,
/// Optional cancellation handle (`async` feature: a `CancellationToken`), `None` by default.
#[cfg(feature = "async")]
pub async_cancel: Option<CancellationToken>,
/// Optional cancellation handle (non-`async` builds: a shared `AtomicBool`), `None` by default.
#[cfg(not(feature = "async"))]
pub async_cancel: Option<Arc<AtomicBool>>,
/// Optional shared [`AsyncMessageIds`], `None` by default.
/// NOTE(review): presumably updated by the receiver with the ids of a pending
/// async operation -- confirm.
pub async_msg_ids: Option<Arc<AsyncMessageIds>>,
/// Optional timeout, `None` by default.
/// NOTE(review): presumably bounds how long the receive may wait -- confirm.
pub timeout: Option<std::time::Duration>,
}
impl<'a> ReceiveOptions<'a> {
    /// Creates the default options; identical to `Default::default()`.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns the options with `status` replaced by the given slice.
    pub fn with_status(self, status: &'a [Status]) -> Self {
        Self { status, ..self }
    }

    /// Returns the options with `cmd` replaced (`None` clears it).
    pub fn with_cmd(self, cmd: Option<Command>) -> Self {
        Self { cmd, ..self }
    }

    /// Returns the options with `msg_id` replaced by the given id.
    pub fn with_msg_id_filter(self, msg_id: u64) -> Self {
        Self { msg_id, ..self }
    }

    /// Returns the options with `allow_async` replaced by the given flag.
    pub fn with_allow_async(self, allow_async: bool) -> Self {
        Self {
            allow_async,
            ..self
        }
    }

    /// Returns the options with `async_cancel` set to `Some(token)`.
    /// Only available with the `async` feature.
    #[cfg(feature = "async")]
    pub fn with_cancellation_token(self, token: CancellationToken) -> Self {
        Self {
            async_cancel: Some(token),
            ..self
        }
    }

    /// Returns the options with `async_cancel` set to `Some(flag)`.
    /// Only available without the `async` feature.
    #[cfg(not(feature = "async"))]
    pub fn with_cancellation_flag(self, flag: Arc<AtomicBool>) -> Self {
        Self {
            async_cancel: Some(flag),
            ..self
        }
    }

    /// Returns the options with `async_msg_ids` set to `Some(async_msg_ids)`.
    pub fn with_async_msg_ids(self, async_msg_ids: Arc<AsyncMessageIds>) -> Self {
        Self {
            async_msg_ids: Some(async_msg_ids),
            ..self
        }
    }

    /// Returns the options with `timeout` set to `Some(timeout)`.
    pub fn with_timeout(self, timeout: std::time::Duration) -> Self {
        Self {
            timeout: Some(timeout),
            ..self
        }
    }
}
impl Default for ReceiveOptions<'_> {
    /// Default options: `status` is `[Status::Success]`, `msg_id` is `0`,
    /// `allow_async` is `false`, and every optional field is `None`.
    fn default() -> Self {
        Self {
            // Filters applied to the received message.
            status: &[Status::Success],
            cmd: None,
            msg_id: 0,
            channel_id: None,
            // Async handling and waiting behaviour.
            allow_async: false,
            async_cancel: None,
            async_msg_ids: None,
            timeout: None,
        }
    }
}
/// Interface for components that send requests and receive responses.
///
/// Implementors must provide [`sendo`](Self::sendo) and [`recvo`](Self::recvo).
/// All other methods are provided wrappers built on those two.
///
/// Naming used by the provided methods:
/// * `send` / `recv` take plain content or a command and use default options;
/// * the `o` suffix (`sendo`, `recvo`) takes an explicit [`OutgoingMessage`] /
///   [`ReceiveOptions`];
/// * `sendor` additionally returns the [`SendMessageResult`] of the send.
#[maybe_async(AFIT)]
// The trait uses `async fn` so `maybe_async` can process it; silence the lint
// that `async fn` in a public trait otherwise triggers.
#[allow(async_fn_in_trait)]
pub trait MessageHandler {
    /// Sends `msg` and returns the result of the send. Must be implemented.
    async fn sendo(&self, msg: OutgoingMessage) -> crate::Result<SendMessageResult>;

    /// Receives a message according to `options`. Must be implemented.
    async fn recvo(&self, options: ReceiveOptions) -> crate::Result<IncomingMessage>;

    /// Handles a notification message.
    ///
    /// The default implementation only logs the message at debug level and
    /// returns `Ok(())`.
    async fn notify(&self, msg: IncomingMessage) -> crate::Result<()> {
        log::debug!("Received notification message: {msg:?}");
        Ok(())
    }

    /// Sends `msg` wrapped in an [`OutgoingMessage`] with default options.
    #[maybe_async]
    #[inline]
    async fn send(&self, msg: RequestContent) -> crate::Result<SendMessageResult> {
        self.sendo(OutgoingMessage::new(msg)).await
    }

    /// Receives with default options, except that `cmd` is set to `Some(cmd)`.
    #[maybe_async]
    #[inline]
    async fn recv(&self, cmd: Command) -> crate::Result<IncomingMessage> {
        self.recvo(ReceiveOptions::new().with_cmd(Some(cmd))).await
    }

    /// Sends `msg`, then receives using `options`, returning both results.
    ///
    /// Before receiving, `options.msg_id` is overwritten with the id of the
    /// message just sent and `options.channel_id` with `msg.channel_id`; any
    /// values the caller put in those two fields are ignored. A send error is
    /// returned without attempting the receive.
    #[maybe_async]
    #[inline]
    async fn sendor_recvo(
        &self,
        msg: OutgoingMessage,
        mut options: ReceiveOptions<'_>,
    ) -> crate::Result<(SendMessageResult, IncomingMessage)> {
        // `msg` is consumed by `sendo`, so grab the channel id first.
        let channel_id = msg.channel_id;
        let send_result = self.sendo(msg).await?;
        options.msg_id = send_result.msg_id;
        options.channel_id = channel_id;
        let in_result = self.recvo(options).await?;
        Ok((send_result, in_result))
    }

    /// Like [`sendor_recvo`](Self::sendor_recvo), but returns only the
    /// received message.
    #[maybe_async]
    #[inline]
    async fn sendo_recvo(
        &self,
        msg: OutgoingMessage,
        options: ReceiveOptions<'_>,
    ) -> crate::Result<IncomingMessage> {
        self.sendor_recvo(msg, options).await.map(|(_, r)| r)
    }

    /// Like [`sendo_recvo`](Self::sendo_recvo), with `msg` wrapped in an
    /// [`OutgoingMessage`] with default options.
    #[maybe_async]
    #[inline]
    async fn send_recvo(
        &self,
        msg: RequestContent,
        options: ReceiveOptions<'_>,
    ) -> crate::Result<IncomingMessage> {
        self.sendo_recvo(OutgoingMessage::new(msg), options).await
    }

    /// Sends `msg` and receives with default options, except that `cmd` is set
    /// to the command associated with the request content.
    #[maybe_async]
    #[inline]
    async fn sendo_recv(&self, msg: OutgoingMessage) -> crate::Result<IncomingMessage> {
        let cmd = msg.message.content.associated_cmd();
        let options = ReceiveOptions::new().with_cmd(Some(cmd));
        self.sendo_recvo(msg, options).await
    }

    /// Like [`sendo_recv`](Self::sendo_recv), with `msg` wrapped in an
    /// [`OutgoingMessage`] with default options.
    #[maybe_async]
    #[inline]
    async fn send_recv(&self, msg: RequestContent) -> crate::Result<IncomingMessage> {
        self.sendo_recv(OutgoingMessage::new(msg)).await
    }

    /// Like [`sendo_recv`](Self::sendo_recv), but also returns the
    /// [`SendMessageResult`] of the send.
    #[maybe_async]
    #[inline]
    async fn sendor_recv(
        &self,
        msg: OutgoingMessage,
    ) -> crate::Result<(SendMessageResult, IncomingMessage)> {
        // Use the same receive options as `sendo_recv`, so the only difference
        // between the two is whether the send result is returned.
        let cmd = msg.message.content.associated_cmd();
        let options = ReceiveOptions::new().with_cmd(Some(cmd));
        self.sendor_recvo(msg, options).await
    }
}
/// Crate-internal wrapper that holds a [`MessageHandler`] behind an [`Arc`].
pub(crate) struct HandlerReference<T: MessageHandler + ?Sized> {
/// The shared handler.
pub handler: Arc<T>,
}
impl<T: MessageHandler> HandlerReference<T> {
    /// Moves `handler` into a new `Arc` and wraps it.
    pub(crate) fn new(handler: T) -> HandlerReference<T> {
        let handler = Arc::new(handler);
        Self { handler }
    }

    /// Returns a weak pointer to the wrapped handler.
    pub(crate) fn weak(&self) -> std::sync::Weak<T> {
        let strong: &Arc<T> = &self.handler;
        Arc::downgrade(strong)
    }
}
impl<T: MessageHandler> std::ops::Deref for HandlerReference<T> {
    type Target = T;

    /// Borrows the wrapped handler, so its methods can be called directly on
    /// the reference.
    fn deref(&self) -> &Self::Target {
        &*self.handler
    }
}
impl<T: MessageHandler> Clone for HandlerReference<T> {
fn clone(&self) -> Self {
HandlerReference {
handler: self.handler.clone(),
}
}
}