use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::Mutex as AsyncMutex;
use crate::{Frame, INLINE_PAYLOAD_SIZE, INLINE_PAYLOAD_SLOT, MsgDescHot, Payload, TransportError};
use super::TransportBackend;
/// On-the-wire size of the hot message descriptor, in bytes.
const DESC_SIZE: usize = 64;
// Compile-time guard: the (de)serialization helpers below transmute
// `MsgDescHot` to/from a fixed 64-byte array, so its size must match exactly.
const _: () = assert!(std::mem::size_of::<MsgDescHot>() == DESC_SIZE);
/// Length-prefixed frame transport over any bidirectional async stream.
///
/// Cheap to clone: all clones share the same stream halves and closed flag
/// through the inner `Arc`.
#[derive(Clone)]
pub struct StreamTransport {
    // Shared state; see `StreamInner` for the locking scheme.
    inner: Arc<StreamInner>,
}
impl std::fmt::Debug for StreamTransport {
    /// Opaque debug representation: the boxed stream halves carry no
    /// useful `Debug` output of their own, so only the type name is shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("StreamTransport");
        builder.finish_non_exhaustive()
    }
}
/// State shared by all clones of a `StreamTransport`.
struct StreamInner {
    /// Read half; locked for the duration of one `recv_frame` so a frame
    /// is always consumed as a whole.
    reader: AsyncMutex<Box<dyn AsyncRead + Unpin + Send + Sync>>,
    /// Write half; locked for the duration of one `send_frame` so the
    /// length prefix, descriptor, and payload stay contiguous on the wire.
    writer: AsyncMutex<Box<dyn AsyncWrite + Unpin + Send + Sync>>,
    /// Soft close flag checked at the start of send/recv. Setting it does
    /// not shut down the underlying stream (see `close`).
    closed: AtomicBool,
}
impl StreamTransport {
    /// Wrap a bidirectional async stream in a transport, splitting it into
    /// independently lockable read and write halves.
    pub fn new<S>(stream: S) -> Self
    where
        S: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
    {
        let (read_half, write_half) = tokio::io::split(stream);
        let inner = StreamInner {
            reader: AsyncMutex::new(Box::new(read_half)),
            writer: AsyncMutex::new(Box::new(write_half)),
            closed: AtomicBool::new(false),
        };
        Self {
            inner: Arc::new(inner),
        }
    }

    /// Build two transports joined by an in-memory duplex pipe
    /// (64 KiB buffered per direction); useful for tests.
    pub fn pair() -> (Self, Self) {
        let (left, right) = tokio::io::duplex(65536);
        (Self::new(left), Self::new(right))
    }

    /// True once `close` has been called on any clone of this transport.
    fn is_closed_inner(&self) -> bool {
        self.inner.closed.load(Ordering::Acquire)
    }
}
/// Reinterpret the descriptor as its raw 64-byte wire representation.
///
/// Size equality is enforced by the `const` assert at the top of the file.
/// NOTE(review): soundness additionally requires `MsgDescHot` to be free of
/// padding bytes (copying padding through `transmute_copy` yields
/// uninitialized bytes) — confirm the struct is `#[repr(C)]`/padding-free.
fn desc_to_bytes(desc: &MsgDescHot) -> [u8; DESC_SIZE] {
    // SAFETY: sizes match by the compile-time assert; layout assumptions
    // are noted above.
    unsafe { std::mem::transmute_copy(desc) }
}
/// Rebuild a descriptor from its raw 64-byte wire representation.
///
/// Size equality is enforced by the `const` assert at the top of the file.
/// NOTE(review): soundness additionally requires every bit pattern of
/// 64 bytes to be a valid `MsgDescHot` (i.e. plain-old-data fields, no
/// niches/enums/references) — confirm against the struct definition.
fn bytes_to_desc(bytes: &[u8; DESC_SIZE]) -> MsgDescHot {
    // SAFETY: sizes match by the compile-time assert; validity assumptions
    // are noted above.
    unsafe { std::mem::transmute_copy(bytes) }
}
impl TransportBackend for StreamTransport {
async fn send_frame(&self, frame: Frame) -> Result<(), TransportError> {
if self.is_closed_inner() {
return Err(TransportError::Closed);
}
let payload = frame.payload_bytes();
let frame_len = DESC_SIZE + payload.len();
let desc_bytes = desc_to_bytes(&frame.desc);
let mut writer = self.inner.writer.lock().await;
writer
.write_all(&(frame_len as u32).to_le_bytes())
.await
.map_err(TransportError::Io)?;
writer
.write_all(&desc_bytes)
.await
.map_err(TransportError::Io)?;
if !payload.is_empty() {
writer
.write_all(payload)
.await
.map_err(TransportError::Io)?;
}
writer.flush().await.map_err(TransportError::Io)?;
Ok(())
}
async fn recv_frame(&self) -> Result<Frame, TransportError> {
if self.is_closed_inner() {
return Err(TransportError::Closed);
}
let mut reader = self.inner.reader.lock().await;
let mut len_buf = [0u8; 4];
reader.read_exact(&mut len_buf).await.map_err(|e| {
if e.kind() == std::io::ErrorKind::UnexpectedEof {
TransportError::Closed
} else {
TransportError::Io(e)
}
})?;
let frame_len = u32::from_le_bytes(len_buf) as usize;
if frame_len < DESC_SIZE {
return Err(TransportError::Io(std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("frame too small: {} < {}", frame_len, DESC_SIZE),
)));
}
let mut desc_buf = [0u8; DESC_SIZE];
reader
.read_exact(&mut desc_buf)
.await
.map_err(TransportError::Io)?;
let mut desc = bytes_to_desc(&desc_buf);
let payload_len = frame_len - DESC_SIZE;
let payload = if payload_len > 0 {
let mut buf = vec![0u8; payload_len];
reader
.read_exact(&mut buf)
.await
.map_err(TransportError::Io)?;
buf
} else {
Vec::new()
};
desc.payload_len = payload_len as u32;
if payload_len <= INLINE_PAYLOAD_SIZE {
desc.payload_slot = INLINE_PAYLOAD_SLOT;
desc.payload_generation = 0;
desc.payload_offset = 0;
desc.inline_payload[..payload_len].copy_from_slice(&payload);
Ok(Frame {
desc,
payload: Payload::Inline,
})
} else {
desc.payload_slot = 0;
desc.payload_generation = 0;
desc.payload_offset = 0;
Ok(Frame {
desc,
payload: Payload::Owned(payload),
})
}
}
fn close(&self) {
self.inner.closed.store(true, Ordering::Release);
}
fn is_closed(&self) -> bool {
self.is_closed_inner()
}
}