rapace_core/transport/stream.rs

use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};

use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::Mutex as AsyncMutex;

use crate::{Frame, INLINE_PAYLOAD_SIZE, INLINE_PAYLOAD_SLOT, MsgDescHot, Payload, TransportError};

use super::TransportBackend;
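
// Serialized size of the hot-path descriptor on the wire; the compile-time
// assertion below keeps this constant in sync with the layout of `MsgDescHot`.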
const DESC_SIZE: usize = 64;

const _: () = assert!(std::mem::size_of::<MsgDescHot>() == DESC_SIZE);
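
/// Transport backend that frames messages over a single bidirectional byte
/// stream (any `AsyncRead + AsyncWrite`), such as a TCP or Unix socket or the
/// in-memory duplex created by [`StreamTransport::pair`].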
#[derive(Clone)]
pub struct StreamTransport {
    inner: Arc<StreamInner>,
}

impl std::fmt::Debug for StreamTransport {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StreamTransport").finish_non_exhaustive()
    }
}

struct StreamInner {
    reader: AsyncMutex<Box<dyn AsyncRead + Unpin + Send + Sync>>,
    writer: AsyncMutex<Box<dyn AsyncWrite + Unpin + Send + Sync>>,
    closed: AtomicBool,
}

impl StreamTransport {
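    /// Wraps an async byte stream. The stream is split into read and write
    /// halves, each behind its own async mutex so sends and receives can
    /// proceed independently.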
    pub fn new<S>(stream: S) -> Self
    where
        S: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
    {
        let (reader, writer) = tokio::io::split(stream);
        Self {
            inner: Arc::new(StreamInner {
                reader: AsyncMutex::new(Box::new(reader)),
                writer: AsyncMutex::new(Box::new(writer)),
                closed: AtomicBool::new(false),
            }),
        }
    }
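    /// Creates a connected pair of transports backed by an in-memory duplex
    /// pipe with a 64 KiB buffer; useful for tests and in-process wiring.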
    pub fn pair() -> (Self, Self) {
        let (a, b) = tokio::io::duplex(65536);
        (Self::new(a), Self::new(b))
    }

    fn is_closed_inner(&self) -> bool {
        self.inner.closed.load(Ordering::Acquire)
    }
}
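
// Both conversions rely on `MsgDescHot` being exactly `DESC_SIZE` bytes
// (enforced by the assertion above) and on it being a plain-old-data struct
// with a stable layout; `transmute_copy` reinterprets the raw bytes as-is.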
fn desc_to_bytes(desc: &MsgDescHot) -> [u8; DESC_SIZE] {
    unsafe { std::mem::transmute_copy(desc) }
}

fn bytes_to_desc(bytes: &[u8; DESC_SIZE]) -> MsgDescHot {
    unsafe { std::mem::transmute_copy(bytes) }
}
impl TransportBackend for StreamTransport {
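    /// Writes one frame: a little-endian `u32` length prefix covering the
    /// descriptor plus payload, then the 64-byte descriptor, then the payload
    /// bytes, followed by a flush.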
    async fn send_frame(&self, frame: Frame) -> Result<(), TransportError> {
        if self.is_closed_inner() {
            return Err(TransportError::Closed);
        }

        let payload = frame.payload_bytes();
        let frame_len = DESC_SIZE + payload.len();
        let desc_bytes = desc_to_bytes(&frame.desc);

        let mut writer = self.inner.writer.lock().await;
        writer
            .write_all(&(frame_len as u32).to_le_bytes())
            .await
            .map_err(TransportError::Io)?;
        writer
            .write_all(&desc_bytes)
            .await
            .map_err(TransportError::Io)?;
        if !payload.is_empty() {
            writer
                .write_all(payload)
                .await
                .map_err(TransportError::Io)?;
        }
        writer.flush().await.map_err(TransportError::Io)?;
        Ok(())
    }
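    /// Reads one length-prefixed frame. Payloads that fit the inline slot are
    /// copied into the descriptor; larger payloads come back as an owned
    /// buffer. An EOF while reading the length prefix is reported as
    /// `TransportError::Closed`.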
    async fn recv_frame(&self) -> Result<Frame, TransportError> {
        if self.is_closed_inner() {
            return Err(TransportError::Closed);
        }

        let mut reader = self.inner.reader.lock().await;

        let mut len_buf = [0u8; 4];
        reader.read_exact(&mut len_buf).await.map_err(|e| {
            if e.kind() == std::io::ErrorKind::UnexpectedEof {
                TransportError::Closed
            } else {
                TransportError::Io(e)
            }
        })?;
        let frame_len = u32::from_le_bytes(len_buf) as usize;
        if frame_len < DESC_SIZE {
            return Err(TransportError::Io(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!("frame too small: {} < {}", frame_len, DESC_SIZE),
            )));
        }

        let mut desc_buf = [0u8; DESC_SIZE];
        reader
            .read_exact(&mut desc_buf)
            .await
            .map_err(TransportError::Io)?;
        let mut desc = bytes_to_desc(&desc_buf);

        let payload_len = frame_len - DESC_SIZE;
        let payload = if payload_len > 0 {
            let mut buf = vec![0u8; payload_len];
            reader
                .read_exact(&mut buf)
                .await
                .map_err(TransportError::Io)?;
            buf
        } else {
            Vec::new()
        };

        desc.payload_len = payload_len as u32;

        if payload_len <= INLINE_PAYLOAD_SIZE {
            desc.payload_slot = INLINE_PAYLOAD_SLOT;
            desc.payload_generation = 0;
            desc.payload_offset = 0;
            desc.inline_payload[..payload_len].copy_from_slice(&payload);
            Ok(Frame {
                desc,
                payload: Payload::Inline,
            })
        } else {
            desc.payload_slot = 0;
            desc.payload_generation = 0;
            desc.payload_offset = 0;
            Ok(Frame {
                desc,
                payload: Payload::Owned(payload),
            })
        }
    }
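    /// Marks the transport as closed. This only flips the flag checked at the
    /// top of `send_frame`/`recv_frame`; it does not shut down the underlying
    /// stream, so a receive already blocked on the reader is not interrupted.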
    fn close(&self) {
        self.inner.closed.store(true, Ordering::Release);
    }

    fn is_closed(&self) -> bool {
        self.is_closed_inner()
    }
}
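
// A minimal round-trip sketch, not part of the original file. It assumes
// `MsgDescHot` implements `Default` and that `Frame`'s fields are reachable
// from a child module here; adjust the construction if the real API differs.
#[cfg(test)]
mod stream_round_trip_sketch {
    use super::*;

    #[tokio::test]
    async fn inline_payload_round_trip() {
        // Two transports joined by an in-memory duplex pipe.
        let (a, b) = StreamTransport::pair();

        // Build a frame whose payload fits in the descriptor's inline slot.
        let payload = b"hello";
        let mut desc = MsgDescHot::default(); // assumed `Default` impl
        desc.payload_len = payload.len() as u32;
        desc.payload_slot = INLINE_PAYLOAD_SLOT;
        desc.inline_payload[..payload.len()].copy_from_slice(payload);
        let frame = Frame {
            desc,
            payload: Payload::Inline,
        };

        a.send_frame(frame).await.expect("send_frame failed");
        let received = b.recv_frame().await.expect("recv_frame failed");
        assert_eq!(received.payload_bytes(), payload);
    }
}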