// rapace_core/transport/stream.rs

1use std::sync::Arc;
2use std::sync::atomic::{AtomicBool, Ordering};
3
4use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
5use tokio::sync::Mutex as AsyncMutex;
6
7use crate::{Frame, INLINE_PAYLOAD_SIZE, INLINE_PAYLOAD_SLOT, MsgDescHot, Payload, TransportError};
8
9use super::TransportBackend;
10
11/// Size of MsgDescHot in bytes (must be 64).
12const DESC_SIZE: usize = 64;
13
14const _: () = assert!(std::mem::size_of::<MsgDescHot>() == DESC_SIZE);
15
16#[derive(Clone)]
17pub struct StreamTransport {
18    inner: Arc<StreamInner>,
19}
20
21impl std::fmt::Debug for StreamTransport {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        f.debug_struct("StreamTransport").finish_non_exhaustive()
24    }
25}
26
27struct StreamInner {
28    reader: AsyncMutex<Box<dyn AsyncRead + Unpin + Send + Sync>>,
29    writer: AsyncMutex<Box<dyn AsyncWrite + Unpin + Send + Sync>>,
30    closed: AtomicBool,
31}
32
33impl StreamTransport {
34    pub fn new<S>(stream: S) -> Self
35    where
36        S: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
37    {
38        let (reader, writer) = tokio::io::split(stream);
39        Self {
40            inner: Arc::new(StreamInner {
41                reader: AsyncMutex::new(Box::new(reader)),
42                writer: AsyncMutex::new(Box::new(writer)),
43                closed: AtomicBool::new(false),
44            }),
45        }
46    }
47
48    pub fn pair() -> (Self, Self) {
49        let (a, b) = tokio::io::duplex(65536);
50        (Self::new(a), Self::new(b))
51    }
52
53    fn is_closed_inner(&self) -> bool {
54        self.inner.closed.load(Ordering::Acquire)
55    }
56}
57
58fn desc_to_bytes(desc: &MsgDescHot) -> [u8; DESC_SIZE] {
59    unsafe { std::mem::transmute_copy(desc) }
60}
61
62fn bytes_to_desc(bytes: &[u8; DESC_SIZE]) -> MsgDescHot {
63    unsafe { std::mem::transmute_copy(bytes) }
64}
65
66impl TransportBackend for StreamTransport {
67    async fn send_frame(&self, frame: Frame) -> Result<(), TransportError> {
68        if self.is_closed_inner() {
69            return Err(TransportError::Closed);
70        }
71
72        let payload = frame.payload_bytes();
73        let frame_len = DESC_SIZE + payload.len();
74        let desc_bytes = desc_to_bytes(&frame.desc);
75
76        let mut writer = self.inner.writer.lock().await;
77        writer
78            .write_all(&(frame_len as u32).to_le_bytes())
79            .await
80            .map_err(TransportError::Io)?;
81        writer
82            .write_all(&desc_bytes)
83            .await
84            .map_err(TransportError::Io)?;
85        if !payload.is_empty() {
86            writer
87                .write_all(payload)
88                .await
89                .map_err(TransportError::Io)?;
90        }
91        writer.flush().await.map_err(TransportError::Io)?;
92        Ok(())
93    }
94
95    async fn recv_frame(&self) -> Result<Frame, TransportError> {
96        if self.is_closed_inner() {
97            return Err(TransportError::Closed);
98        }
99
100        let mut reader = self.inner.reader.lock().await;
101
102        let mut len_buf = [0u8; 4];
103        reader.read_exact(&mut len_buf).await.map_err(|e| {
104            if e.kind() == std::io::ErrorKind::UnexpectedEof {
105                TransportError::Closed
106            } else {
107                TransportError::Io(e)
108            }
109        })?;
110        let frame_len = u32::from_le_bytes(len_buf) as usize;
111        if frame_len < DESC_SIZE {
112            return Err(TransportError::Io(std::io::Error::new(
113                std::io::ErrorKind::InvalidData,
114                format!("frame too small: {} < {}", frame_len, DESC_SIZE),
115            )));
116        }
117
118        let mut desc_buf = [0u8; DESC_SIZE];
119        reader
120            .read_exact(&mut desc_buf)
121            .await
122            .map_err(TransportError::Io)?;
123        let mut desc = bytes_to_desc(&desc_buf);
124
125        let payload_len = frame_len - DESC_SIZE;
126        let payload = if payload_len > 0 {
127            let mut buf = vec![0u8; payload_len];
128            reader
129                .read_exact(&mut buf)
130                .await
131                .map_err(TransportError::Io)?;
132            buf
133        } else {
134            Vec::new()
135        };
136
137        desc.payload_len = payload_len as u32;
138
139        if payload_len <= INLINE_PAYLOAD_SIZE {
140            desc.payload_slot = INLINE_PAYLOAD_SLOT;
141            desc.payload_generation = 0;
142            desc.payload_offset = 0;
143            desc.inline_payload[..payload_len].copy_from_slice(&payload);
144            Ok(Frame {
145                desc,
146                payload: Payload::Inline,
147            })
148        } else {
149            desc.payload_slot = 0;
150            desc.payload_generation = 0;
151            desc.payload_offset = 0;
152            Ok(Frame {
153                desc,
154                payload: Payload::Owned(payload),
155            })
156        }
157    }
158
159    fn close(&self) {
160        self.inner.closed.store(true, Ordering::Release);
161    }
162
163    fn is_closed(&self) -> bool {
164        self.is_closed_inner()
165    }
166}