use compio::buf::{BufResult, IntoInner, IoBuf, IoBufMut};
use compio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use compio::net::TcpStream;
#[cfg(unix)]
use compio::net::UnixStream;
#[cfg(feature = "compio-tls")]
use compio::tls::TlsStream;
/// Initial capacity, in bytes, of a stream's internal read-ahead buffer.
const READ_BUF_CAPACITY: usize = 8 * 1024;
/// Transport backing a [`Stream`]: a Unix domain socket (Unix targets only),
/// plain TCP, or TLS layered over TCP (with the `compio-tls` feature).
enum StreamInner {
    #[cfg(unix)]
    Unix(UnixStream),
    Tcp(TcpStream),
    #[cfg(feature = "compio-tls")]
    Tls(TlsStream<TcpStream>),
}
/// Connection stream with an internal read-ahead buffer.
///
/// Bytes in `read_buf[read_pos..]` have been received from the transport but
/// not yet handed to a caller.
pub struct Stream {
    /// Underlying transport.
    inner: StreamInner,
    /// Index of the first unconsumed byte in `read_buf`.
    read_pos: usize,
    /// Bytes received from `inner`; only `read_buf[read_pos..]` is still pending.
    read_buf: Vec<u8>,
}
impl Stream {
pub fn tcp(stream: TcpStream) -> Self {
Self {
inner: StreamInner::Tcp(stream),
read_buf: Vec::with_capacity(READ_BUF_CAPACITY),
read_pos: 0,
}
}
#[cfg(unix)]
pub fn unix(stream: UnixStream) -> Self {
Self {
inner: StreamInner::Unix(stream),
read_buf: Vec::with_capacity(READ_BUF_CAPACITY),
read_pos: 0,
}
}
#[cfg(feature = "compio-tls")]
pub async fn upgrade_to_tls(self, host: &str) -> Result<Self, crate::error::Error> {
match self.inner {
StreamInner::Tcp(tcp_stream) => {
let native_connector =
compio::native_tls::TlsConnector::new().map_err(crate::error::Error::Tls)?;
let connector = compio::tls::TlsConnector::from(native_connector);
let tls_stream = connector.connect(host, tcp_stream).await?;
Ok(Self {
inner: StreamInner::Tls(tls_stream),
read_buf: Vec::with_capacity(READ_BUF_CAPACITY),
read_pos: 0,
})
}
StreamInner::Tls(_) => Err(crate::error::Error::InvalidUsage(
"Stream is already TLS".into(),
)),
#[cfg(unix)]
StreamInner::Unix(_) => Err(crate::error::Error::InvalidUsage(
"Cannot upgrade Unix socket to TLS".into(),
)),
}
}
fn available(&self) -> usize {
self.read_buf.len() - self.read_pos
}
async fn fill_buf(&mut self) -> std::io::Result<()> {
if self.read_pos > 0 {
let valid = self.available();
self.read_buf
.copy_within(self.read_pos..self.read_pos + valid, 0);
self.read_buf.truncate(valid);
self.read_pos = 0;
}
let buf = std::mem::take(&mut self.read_buf);
let BufResult(result, buf) = self.read_raw(buf).await;
self.read_buf = buf;
let n = result?;
if n == 0 {
return Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"connection closed",
));
}
Ok(())
}
async fn ensure(&mut self, n: usize) -> std::io::Result<()> {
while self.available() < n {
self.fill_buf().await?;
}
Ok(())
}
pub async fn read_u8(&mut self) -> std::io::Result<u8> {
self.ensure(1).await?;
let byte = self.read_buf[self.read_pos];
self.read_pos += 1;
Ok(byte)
}
pub async fn read_message(
&mut self,
buffer_set: &mut crate::buffer_set::BufferSet,
) -> std::io::Result<()> {
self.ensure(5).await?;
buffer_set.type_byte = self.read_buf[self.read_pos];
let (len_bytes, _) = self.read_buf[self.read_pos + 1..]
.split_first_chunk::<4>()
.ok_or_else(|| std::io::Error::other("protocol: header shorter than 5 bytes"))?;
let length = u32::from_be_bytes(*len_bytes) as usize;
self.read_pos += 5;
let payload_len = length.saturating_sub(4);
if payload_len == 0 {
buffer_set.read_buffer.clear();
return Ok(());
}
buffer_set.read_buffer.clear();
buffer_set.read_buffer.reserve(payload_len);
let from_buf = self.available().min(payload_len);
buffer_set
.read_buffer
.extend_from_slice(&self.read_buf[self.read_pos..self.read_pos + from_buf]);
self.read_pos += from_buf;
let remaining = payload_len - from_buf;
if remaining > 0 {
let buf = std::mem::take(&mut buffer_set.read_buffer);
let BufResult(res, slice) = self.read_exact_raw(buf.slice(from_buf..payload_len)).await;
buffer_set.read_buffer = slice.into_inner();
res?;
}
Ok(())
}
async fn read_raw(&mut self, buf: Vec<u8>) -> BufResult<usize, Vec<u8>> {
match &mut self.inner {
StreamInner::Tcp(r) => r.read(buf).await,
#[cfg(feature = "compio-tls")]
StreamInner::Tls(r) => r.read(buf).await,
#[cfg(unix)]
StreamInner::Unix(r) => r.read(buf).await,
}
}
async fn read_exact_raw<B: IoBufMut>(&mut self, buf: B) -> BufResult<(), B> {
match &mut self.inner {
StreamInner::Tcp(r) => r.read_exact(buf).await,
#[cfg(feature = "compio-tls")]
StreamInner::Tls(r) => r.read_exact(buf).await,
#[cfg(unix)]
StreamInner::Unix(r) => r.read_exact(buf).await,
}
}
pub async fn write_all_owned(&mut self, buf: Vec<u8>) -> BufResult<(), Vec<u8>> {
match &mut self.inner {
StreamInner::Tcp(r) => r.write_all(buf).await,
#[cfg(feature = "compio-tls")]
StreamInner::Tls(r) => r.write_all(buf).await,
#[cfg(unix)]
StreamInner::Unix(r) => r.write_all(buf).await,
}
}
pub async fn flush(&mut self) -> std::io::Result<()> {
match &mut self.inner {
StreamInner::Tcp(r) => r.flush().await,
#[cfg(feature = "compio-tls")]
StreamInner::Tls(r) => r.flush().await,
#[cfg(unix)]
StreamInner::Unix(r) => r.flush().await,
}
}
pub fn is_tcp_loopback(&self) -> bool {
match &self.inner {
StreamInner::Tcp(r) => r
.peer_addr()
.map(|addr| addr.ip().is_loopback())
.unwrap_or(false),
#[cfg(feature = "compio-tls")]
StreamInner::Tls(_) => false,
#[cfg(unix)]
StreamInner::Unix(_) => false,
}
}
}