use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use crate::DriverError;
/// Async byte stream that is either a plain TCP socket or, when the `tls`
/// feature is compiled in, a TLS client session layered over a TCP socket.
///
/// The inherent methods on this type dispatch on the active variant so that
/// callers do not need to care which transport is in use.
pub(crate) enum AsyncStream {
/// Unencrypted transport: the `TcpStream` is used directly.
Tcp(TcpStream),
/// TLS transport: a `tokio_rustls` client stream wrapping a `TcpStream`,
/// stored behind a `Box`. Only present when the `tls` feature is enabled.
#[cfg(feature = "tls")]
Tls(Box<tokio_rustls::client::TlsStream<TcpStream>>),
}
impl AsyncStream {
#[inline]
pub async fn read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
match self {
Self::Tcp(s) => s.read(buf).await,
#[cfg(feature = "tls")]
Self::Tls(s) => s.read(buf).await,
}
}
#[inline]
pub async fn write_all(&mut self, buf: &[u8]) -> Result<(), std::io::Error> {
match self {
Self::Tcp(s) => s.write_all(buf).await,
#[cfg(feature = "tls")]
Self::Tls(s) => s.write_all(buf).await,
}
}
#[allow(dead_code)] #[inline]
pub async fn flush(&mut self) -> Result<(), std::io::Error> {
match self {
Self::Tcp(s) => s.flush().await,
#[cfg(feature = "tls")]
Self::Tls(s) => s.flush().await,
}
}
#[allow(dead_code)] pub fn set_nodelay(&self, nodelay: bool) -> Result<(), DriverError> {
match self {
Self::Tcp(s) => s.set_nodelay(nodelay).map_err(DriverError::Io),
#[cfg(feature = "tls")]
Self::Tls(s) => {
let (tcp, _) = s.get_ref();
tcp.set_nodelay(nodelay).map_err(DriverError::Io)
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that `AsyncStream` is `Send`.
    ///
    /// The body does no work at runtime; the test fails to build if the
    /// `Send` bound on the helper is not satisfied by `AsyncStream`.
    #[test]
    fn async_stream_enum_variants_exist() {
        fn require_send<T>()
        where
            T: Send,
        {
        }

        require_send::<AsyncStream>();
    }
}