use maybe_async::*;
use std::{io::Cursor, time::Duration};
#[cfg(feature = "sync")]
use std::{
io::{self, Read, Write},
net::{TcpStream, ToSocketAddrs},
};
#[cfg(feature = "async")]
use tokio::{
io::{self, AsyncReadExt, AsyncWriteExt},
net::{tcp, TcpStream},
select,
};
use binrw::prelude::*;
use crate::packets::netbios::{NetBiosMessageContent, NetBiosTcpMessage, NetBiosTcpMessageHeader};
/// Read half of the transport held by [`NetBiosClient`].
#[cfg(feature = "sync")]
type TcpRead = TcpStream;
#[cfg(feature = "async")]
type TcpRead = tcp::OwnedReadHalf;

/// Write half of the transport held by [`NetBiosClient`].
#[cfg(feature = "sync")]
type TcpWrite = TcpStream;
#[cfg(feature = "async")]
type TcpWrite = tcp::OwnedWriteHalf;
/// A client that exchanges `NetBiosTcpMessage`s over a TCP connection.
///
/// The connection is stored as separate read and write halves so that
/// [`NetBiosClient::split`] can move each half into its own client.
#[derive(Debug)]
pub struct NetBiosClient {
/// Read half of the connection; `None` before `connect`, after `disconnect`,
/// or in the write-only client produced by `split`.
reader: Option<TcpRead>,
/// Write half of the connection; `None` before `connect`, after `disconnect`,
/// or in the read-only client produced by `split`.
writer: Option<TcpWrite>,
/// Optional upper bound on how long establishing the TCP connection may take.
timeout: Option<Duration>,
}
impl NetBiosClient {
pub fn new(timeout: Option<Duration>) -> NetBiosClient {
NetBiosClient {
reader: None,
writer: None,
timeout,
}
}
#[maybe_async]
pub async fn connect(&mut self, address: &str) -> crate::Result<()> {
let socket = self.connect_timeout(address).await?;
let (r, w) = Self::split_socket(socket);
self.reader = Some(r);
self.writer = Some(w);
Ok(())
}
#[cfg(feature = "sync")]
fn connect_timeout(&mut self, address: &str) -> crate::Result<TcpStream> {
if let Some(t) = self.timeout {
log::debug!("Connecting to {} with timeout {:?}.", address, t);
let address = address
.to_socket_addrs()?
.next()
.ok_or(crate::Error::InvalidAddress(address.to_string()))?;
TcpStream::connect_timeout(&address, t).map_err(Into::into)
} else {
log::debug!("Connecting to {}.", address);
TcpStream::connect(&address).map_err(Into::into)
}
}
#[cfg(feature = "async")]
async fn connect_timeout(&mut self, address: &str) -> crate::Result<TcpStream> {
if let None = self.timeout {
log::debug!("Connecting to {}.", address);
return TcpStream::connect(&address).await.map_err(Into::into);
}
select! {
res = TcpStream::connect(&address) => res.map_err(Into::into),
_ = tokio::time::sleep(self.timeout.unwrap()) => Err(crate::Error::OperationTimeout("Tcp connect".to_string(), self.timeout.unwrap())),
}
}
pub fn disconnect(&mut self) {
self.reader.take();
self.writer.take();
}
#[maybe_async]
pub async fn send(&mut self, data: NetBiosMessageContent) -> crate::Result<()> {
let raw_message = NetBiosTcpMessage::from_content(&data)?;
Ok(self.send_raw(raw_message).await?)
}
#[maybe_async]
pub async fn send_raw(&mut self, data: NetBiosTcpMessage) -> crate::Result<()> {
log::trace!("Sending message of size {}.", data.content.len());
Self::write_all(
self.writer.as_mut().ok_or(crate::Error::NotConnected)?,
&data.to_bytes()?,
)
.await?;
Ok(())
}
#[maybe_async]
pub async fn received_bytes(&mut self) -> crate::Result<NetBiosTcpMessage> {
let tcp = self.reader.as_mut().ok_or(crate::Error::NotConnected)?;
let mut header_data = vec![0; NetBiosTcpMessageHeader::SIZE];
Self::read_exact(tcp, &mut header_data).await?;
let header = NetBiosTcpMessageHeader::read(&mut Cursor::new(header_data))?;
if header.stream_protocol_length.value > 2u32.pow(3 * 8) - 1 {
return Err(crate::Error::InvalidMessage("Message too large.".into()));
}
let mut data = vec![0; header.stream_protocol_length.value as usize];
Self::read_exact(tcp, &mut data).await?;
Ok(NetBiosTcpMessage { content: data })
}
#[cfg(feature = "sync")]
pub fn set_read_timeout(&self, timeout: Option<std::time::Duration>) -> crate::Result<()> {
if !self.can_read() {
return Err(crate::Error::NotConnected);
}
self.reader
.as_ref()
.ok_or(crate::Error::NotConnected)?
.set_read_timeout(timeout)
.map_err(|e| e.into())
}
#[cfg(feature = "sync")]
pub fn read_timeout(&self) -> crate::Result<Option<std::time::Duration>> {
if !self.can_read() {
return Err(crate::Error::NotConnected);
}
self.reader
.as_ref()
.ok_or(crate::Error::NotConnected)?
.read_timeout()
.map_err(|e| e.into())
}
#[maybe_async]
async fn read_exact(tcp: &mut TcpRead, buf: &mut [u8]) -> crate::Result<()> {
log::trace!("Reading {} bytes.", buf.len());
tcp.read_exact(buf).await.map_err(Self::map_tcp_error)?;
log::trace!("Read {} bytes OK.", buf.len());
Ok(())
}
#[maybe_async]
async fn write_all(tcp: &mut TcpWrite, buf: &[u8]) -> crate::Result<()> {
tcp.write_all(buf).await.map_err(Self::map_tcp_error)?;
Ok(())
}
#[inline]
fn map_tcp_error(e: io::Error) -> crate::Error {
if e.kind() == io::ErrorKind::ConnectionAborted || e.kind() == io::ErrorKind::UnexpectedEof
{
log::error!(
"Got IO error: {} -- Connection Error, notify NotConnected!",
e
);
return crate::Error::NotConnected;
}
if e.kind() == io::ErrorKind::WouldBlock {
log::debug!("Got IO error: {} -- with ErrorKind::WouldBlock.", e);
} else {
log::error!("Got IO error: {} -- Mapping to IO error.", e);
}
e.into()
}
#[cfg(feature = "async")]
fn split_socket(socket: TcpStream) -> (TcpRead, TcpWrite) {
let (r, w) = socket.into_split();
(r, w)
}
#[cfg(feature = "sync")]
fn split_socket(socket: TcpStream) -> (TcpRead, TcpWrite) {
let rsocket = socket.try_clone().unwrap();
let wsocket = socket;
(rsocket, wsocket)
}
pub fn split(self) -> crate::Result<(NetBiosClient, NetBiosClient)> {
if !self.can_read() || !self.can_write() {
return Err(crate::Error::InvalidState(
"Cannot split a non-connected client.".into(),
));
}
Ok((
NetBiosClient {
reader: self.reader,
writer: None,
timeout: self.timeout,
},
NetBiosClient {
reader: None,
writer: self.writer,
timeout: self.timeout,
},
))
}
pub fn can_read(&self) -> bool {
self.reader.is_some()
}
pub fn can_write(&self) -> bool {
self.writer.is_some()
}
}