1use std::sync::Arc;
20use std::sync::atomic::{AtomicBool, Ordering};
21
22use parking_lot::Mutex as SyncMutex;
23use rapace_core::{
24 DecodeError, EncodeCtx, EncodeError, Frame, FrameView, INLINE_PAYLOAD_SIZE,
25 INLINE_PAYLOAD_SLOT, MsgDescHot, Transport, TransportError,
26};
27use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
28use tokio::sync::Mutex as AsyncMutex;
29
30const DESC_SIZE: usize = 64;
32
33const _: () = assert!(std::mem::size_of::<MsgDescHot>() == DESC_SIZE);
35
/// Frame transport over a byte stream whose read and write halves are locked
/// independently.
///
/// Each frame on the wire is a 4-byte little-endian length (descriptor bytes
/// plus payload bytes), followed by the raw `DESC_SIZE`-byte descriptor,
/// followed by the payload.
pub struct StreamTransport<R, W> {
    /// Shared reader/writer state, held behind an `Arc`.
    inner: Arc<StreamInner<R, W>>,
}
43
/// State shared by a `StreamTransport`.
struct StreamInner<R, W> {
    /// Read half; `recv_frame` holds this lock while reading one whole frame.
    reader: AsyncMutex<R>,
    /// Write half; `send_frame` holds this lock while writing and flushing one
    /// whole frame, so frames from concurrent senders do not interleave.
    writer: AsyncMutex<W>,
    /// Storage for the most recently received frame. `recv_frame` returns a view
    /// that points into this slot, and the next `recv_frame` overwrites it.
    last_frame: SyncMutex<Option<ReceivedFrame>>,
    /// Set by `close`; checked at the start of `send_frame` and `recv_frame`.
    closed: AtomicBool,
}
54
/// A decoded frame owned by the transport so that `recv_frame` can hand out
/// borrowed views of it.
struct ReceivedFrame {
    /// Descriptor read from the wire, with `payload_len` and `payload_slot`
    /// rewritten by `recv_frame`.
    desc: MsgDescHot,
    /// Backing storage for out-of-line payloads; the view only reads this
    /// buffer when the descriptor is not inline.
    payload: Vec<u8>,
}
60
61impl<S> StreamTransport<ReadHalf<S>, WriteHalf<S>>
62where
63 S: AsyncRead + AsyncWrite + Send + 'static,
64{
65 pub fn new(stream: S) -> Self {
70 let (reader, writer) = tokio::io::split(stream);
71 Self {
72 inner: Arc::new(StreamInner {
73 reader: AsyncMutex::new(reader),
74 writer: AsyncMutex::new(writer),
75 last_frame: SyncMutex::new(None),
76 closed: AtomicBool::new(false),
77 }),
78 }
79 }
80}
81
82impl StreamTransport<ReadHalf<tokio::io::DuplexStream>, WriteHalf<tokio::io::DuplexStream>> {
83 pub fn pair() -> (Self, Self) {
87 let (a, b) = tokio::io::duplex(65536);
89 (Self::new(a), Self::new(b))
90 }
91}
92
93fn desc_to_bytes(desc: &MsgDescHot) -> [u8; DESC_SIZE] {
105 unsafe { std::mem::transmute_copy(desc) }
108}
109
110fn bytes_to_desc(bytes: &[u8; DESC_SIZE]) -> MsgDescHot {
116 unsafe { std::mem::transmute_copy(bytes) }
118}
119
120impl<R, W> Transport for StreamTransport<R, W>
121where
122 R: AsyncRead + Unpin + Send + Sync + 'static,
123 W: AsyncWrite + Unpin + Send + Sync + 'static,
124{
125 async fn send_frame(&self, frame: &Frame) -> Result<(), TransportError> {
126 if self.is_closed() {
127 return Err(TransportError::Closed);
128 }
129
130 let payload = frame.payload();
131 let frame_len = DESC_SIZE + payload.len();
132
133 let desc_bytes = desc_to_bytes(&frame.desc);
135
136 let mut writer = self.inner.writer.lock().await;
138
139 writer
141 .write_all(&(frame_len as u32).to_le_bytes())
142 .await
143 .map_err(TransportError::Io)?;
144
145 writer
147 .write_all(&desc_bytes)
148 .await
149 .map_err(TransportError::Io)?;
150
151 if !payload.is_empty() {
153 writer
154 .write_all(payload)
155 .await
156 .map_err(TransportError::Io)?;
157 }
158
159 writer.flush().await.map_err(TransportError::Io)?;
161
162 Ok(())
163 }
164
165 async fn recv_frame(&self) -> Result<FrameView<'_>, TransportError> {
166 if self.is_closed() {
167 return Err(TransportError::Closed);
168 }
169
170 let mut reader = self.inner.reader.lock().await;
171
172 let mut len_buf = [0u8; 4];
174 reader.read_exact(&mut len_buf).await.map_err(|e| {
175 if e.kind() == std::io::ErrorKind::UnexpectedEof {
176 TransportError::Closed
177 } else {
178 TransportError::Io(e)
179 }
180 })?;
181 let frame_len = u32::from_le_bytes(len_buf) as usize;
182
183 if frame_len < DESC_SIZE {
185 return Err(TransportError::Io(std::io::Error::new(
186 std::io::ErrorKind::InvalidData,
187 format!("frame too small: {} < {}", frame_len, DESC_SIZE),
188 )));
189 }
190
191 let mut desc_buf = [0u8; DESC_SIZE];
193 reader
194 .read_exact(&mut desc_buf)
195 .await
196 .map_err(TransportError::Io)?;
197
198 let mut desc = bytes_to_desc(&desc_buf);
199
200 let payload_len = frame_len - DESC_SIZE;
202 let payload = if payload_len > 0 {
203 let mut buf = vec![0u8; payload_len];
204 reader
205 .read_exact(&mut buf)
206 .await
207 .map_err(TransportError::Io)?;
208 buf
209 } else {
210 Vec::new()
211 };
212
213 drop(reader);
215
216 desc.payload_len = payload_len as u32;
218
219 if payload_len <= INLINE_PAYLOAD_SIZE {
221 desc.payload_slot = INLINE_PAYLOAD_SLOT;
222 desc.inline_payload[..payload_len].copy_from_slice(&payload);
223 } else {
224 desc.payload_slot = 0;
226 }
227
228 {
230 let mut last = self.inner.last_frame.lock();
231 *last = Some(ReceivedFrame { desc, payload });
232 }
233
234 let last = self.inner.last_frame.lock();
238 let frame_ref = last.as_ref().unwrap();
239
240 let desc_ptr = &frame_ref.desc as *const MsgDescHot;
241 let payload_slice = if frame_ref.desc.is_inline() {
242 frame_ref.desc.inline_payload()
243 } else {
244 &frame_ref.payload
245 };
246 let payload_ptr = payload_slice.as_ptr();
247 let payload_len = payload_slice.len();
248
249 let desc: &MsgDescHot = unsafe { &*desc_ptr };
253 let payload: &[u8] = unsafe { std::slice::from_raw_parts(payload_ptr, payload_len) };
254
255 Ok(FrameView::new(desc, payload))
256 }
257
258 fn encoder(&self) -> Box<dyn EncodeCtx + '_> {
259 Box::new(StreamEncoder::new())
260 }
261
262 async fn close(&self) -> Result<(), TransportError> {
263 self.inner.closed.store(true, Ordering::Release);
264 Ok(())
265 }
266}
267
268impl<R, W> StreamTransport<R, W> {
269 pub fn is_closed(&self) -> bool {
271 self.inner.closed.load(Ordering::Acquire)
272 }
273}
274
/// Encoder that accumulates payload bytes in memory and builds a [`Frame`] with
/// `Frame::with_payload` when finished.
pub struct StreamEncoder {
    /// Descriptor attached to the finished frame; starts as `MsgDescHot::new()`
    /// and can be replaced with `set_desc`.
    desc: MsgDescHot,
    /// Bytes appended by `encode_bytes`, in call order.
    payload: Vec<u8>,
}
282
283impl StreamEncoder {
284 fn new() -> Self {
285 Self {
286 desc: MsgDescHot::new(),
287 payload: Vec::new(),
288 }
289 }
290
291 pub fn set_desc(&mut self, desc: MsgDescHot) {
293 self.desc = desc;
294 }
295}
296
297impl EncodeCtx for StreamEncoder {
298 fn encode_bytes(&mut self, bytes: &[u8]) -> Result<(), EncodeError> {
299 self.payload.extend_from_slice(bytes);
300 Ok(())
301 }
302
303 fn finish(self: Box<Self>) -> Result<Frame, EncodeError> {
304 Ok(Frame::with_payload(self.desc, self.payload))
305 }
306}
307
/// Decoder over a borrowed byte slice; `decode_bytes` consumes everything that
/// has not been consumed yet.
pub struct StreamDecoder<'a> {
    /// Input being decoded.
    data: &'a [u8],
    /// Offset of the first unconsumed byte in `data`.
    pos: usize,
}
313
314impl<'a> StreamDecoder<'a> {
315 pub fn new(data: &'a [u8]) -> Self {
317 Self { data, pos: 0 }
318 }
319}
320
321impl<'a> rapace_core::DecodeCtx<'a> for StreamDecoder<'a> {
322 fn decode_bytes(&mut self) -> Result<&'a [u8], DecodeError> {
323 let result = &self.data[self.pos..];
324 self.pos = self.data.len();
325 Ok(result)
326 }
327
328 fn remaining(&self) -> &'a [u8] {
329 &self.data[self.pos..]
330 }
331}
332
#[cfg(test)]
mod tests {
    //! Unit tests for `StreamTransport` over an in-memory duplex pipe.
    use super::*;
    use rapace_core::FrameFlags;

    #[tokio::test]
    async fn test_pair_creation() {
        let (a, b) = StreamTransport::pair();
        assert!(!a.is_closed());
        assert!(!b.is_closed());
    }

    #[tokio::test]
    async fn test_send_recv_inline() {
        let (a, b) = StreamTransport::pair();

        let mut desc = MsgDescHot::new();
        desc.msg_id = 1;
        desc.channel_id = 1;
        desc.method_id = 42;
        desc.flags = FrameFlags::DATA;

        let frame = Frame::with_inline_payload(desc, b"hello").unwrap();

        a.send_frame(&frame).await.unwrap();

        let view = b.recv_frame().await.unwrap();
        assert_eq!(view.desc.msg_id, 1);
        assert_eq!(view.desc.channel_id, 1);
        assert_eq!(view.desc.method_id, 42);
        assert_eq!(view.payload, b"hello");
    }

    #[tokio::test]
    async fn test_send_recv_external_payload() {
        let (a, b) = StreamTransport::pair();

        let mut desc = MsgDescHot::new();
        desc.msg_id = 2;
        desc.flags = FrameFlags::DATA;

        let payload = vec![0u8; 1000];
        let frame = Frame::with_payload(desc, payload.clone());

        a.send_frame(&frame).await.unwrap();

        let view = b.recv_frame().await.unwrap();
        assert_eq!(view.desc.msg_id, 2);
        assert_eq!(view.payload.len(), 1000);
        assert_eq!(view.payload, &payload[..]);
    }

    #[tokio::test]
    async fn test_send_recv_empty_payload() {
        let (a, b) = StreamTransport::pair();

        let mut desc = MsgDescHot::new();
        desc.msg_id = 3;
        a.send_frame(&Frame::new(desc)).await.unwrap();

        let view = b.recv_frame().await.unwrap();
        assert_eq!(view.desc.msg_id, 3);
        assert!(view.payload.is_empty());
    }

    #[tokio::test]
    async fn test_inline_boundary_round_trip() {
        // Exactly inline-sized payloads stay inline; one byte more goes external.
        // Both must arrive byte-for-byte.
        let (a, b) = StreamTransport::pair();

        for &len in &[INLINE_PAYLOAD_SIZE, INLINE_PAYLOAD_SIZE + 1] {
            let payload: Vec<u8> = (0..len).map(|i| i as u8).collect();
            let frame = Frame::with_payload(MsgDescHot::new(), payload.clone());
            a.send_frame(&frame).await.unwrap();

            let view = b.recv_frame().await.unwrap();
            assert_eq!(view.payload, &payload[..]);
        }
    }

    #[tokio::test]
    async fn test_bidirectional() {
        let (a, b) = StreamTransport::pair();

        let mut desc_a = MsgDescHot::new();
        desc_a.msg_id = 1;
        let frame_a = Frame::with_inline_payload(desc_a, b"from A").unwrap();
        a.send_frame(&frame_a).await.unwrap();

        let mut desc_b = MsgDescHot::new();
        desc_b.msg_id = 2;
        let frame_b = Frame::with_inline_payload(desc_b, b"from B").unwrap();
        b.send_frame(&frame_b).await.unwrap();

        let view_b = b.recv_frame().await.unwrap();
        assert_eq!(view_b.payload, b"from A");

        let view_a = a.recv_frame().await.unwrap();
        assert_eq!(view_a.payload, b"from B");
    }

    #[tokio::test]
    async fn test_concurrent_send_recv() {
        let (a, b) = StreamTransport::pair();
        let a = Arc::new(a);
        let b = Arc::new(b);

        // A sends pings...
        let a_sender = a.clone();
        let send_handle = tokio::spawn(async move {
            for i in 0..10u64 {
                let mut desc = MsgDescHot::new();
                desc.msg_id = i;
                let frame = Frame::with_inline_payload(desc, b"ping").unwrap();
                a_sender.send_frame(&frame).await.unwrap();
            }
        });

        // ...B answers each one with a pong...
        let b_clone = b.clone();
        let echo_handle = tokio::spawn(async move {
            for _ in 0..10 {
                let view = b_clone.recv_frame().await.unwrap();
                let mut desc = MsgDescHot::new();
                desc.msg_id = view.desc.msg_id;
                let frame = Frame::with_inline_payload(desc, b"pong").unwrap();
                b_clone.send_frame(&frame).await.unwrap();
            }
        });

        // ...while A concurrently receives the pongs.
        let a_receiver = a.clone();
        let recv_handle = tokio::spawn(async move {
            for _ in 0..10 {
                let view = a_receiver.recv_frame().await.unwrap();
                assert_eq!(view.payload, b"pong");
            }
        });

        send_handle.await.unwrap();
        echo_handle.await.unwrap();
        recv_handle.await.unwrap();
    }

    #[tokio::test]
    async fn test_close() {
        let (a, _b) = StreamTransport::pair();

        a.close().await.unwrap();
        assert!(a.is_closed());

        let frame = Frame::new(MsgDescHot::new());
        assert!(matches!(
            a.send_frame(&frame).await,
            Err(TransportError::Closed)
        ));
    }

    #[tokio::test]
    async fn test_recv_after_close() {
        let (_a, b) = StreamTransport::pair();

        b.close().await.unwrap();
        assert!(matches!(b.recv_frame().await, Err(TransportError::Closed)));
    }

    #[tokio::test]
    async fn test_peer_drop_yields_closed() {
        // Dropping the peer drops both of its stream halves, so the next read
        // hits EOF at a frame boundary.
        let (a, b) = StreamTransport::pair();
        drop(a);

        assert!(matches!(b.recv_frame().await, Err(TransportError::Closed)));
    }

    #[tokio::test]
    async fn test_encoder() {
        let (a, _b) = StreamTransport::pair();

        let mut encoder = a.encoder();
        encoder.encode_bytes(b"test data").unwrap();
        let frame = encoder.finish().unwrap();

        assert_eq!(frame.payload(), b"test data");
    }

    #[test]
    fn test_decoder_consumes_everything() {
        use rapace_core::DecodeCtx;

        let mut decoder = StreamDecoder::new(b"abc");
        assert_eq!(decoder.remaining(), b"abc");
        assert_eq!(decoder.decode_bytes().ok(), Some(&b"abc"[..]));
        assert!(decoder.remaining().is_empty());
    }
}
482
#[cfg(test)]
mod conformance_tests {
    //! Runs the shared `rapace_testkit` conformance suite against in-memory
    //! `StreamTransport` pairs.
    use super::*;
    use rapace_testkit::{TestError, TransportFactory};
    use tokio::io::{ReadHalf, WriteHalf};

    /// Hands connected duplex-backed transport pairs to the testkit.
    struct StreamFactory;

    impl TransportFactory for StreamFactory {
        type Transport =
            StreamTransport<ReadHalf<tokio::io::DuplexStream>, WriteHalf<tokio::io::DuplexStream>>;

        async fn connect_pair() -> Result<(Self::Transport, Self::Transport), TestError> {
            Ok(StreamTransport::pair())
        }
    }

    /// Expands each `test_name => runner` pair into a `#[tokio::test]` that
    /// awaits `rapace_testkit::runner::<StreamFactory>()`.
    macro_rules! conformance_case {
        ($($name:ident => $runner:ident),* $(,)?) => {
            $(
                #[tokio::test]
                async fn $name() {
                    rapace_testkit::$runner::<StreamFactory>().await;
                }
            )*
        };
    }

    // Unary calls, deadlines, cancellation, credits and errors.
    conformance_case! {
        unary_happy_path => run_unary_happy_path,
        unary_multiple_calls => run_unary_multiple_calls,
        ping_pong => run_ping_pong,
        deadline_success => run_deadline_success,
        deadline_exceeded => run_deadline_exceeded,
        cancellation => run_cancellation,
        credit_grant => run_credit_grant,
        error_response => run_error_response,
    }

    // Session-level behaviour.
    conformance_case! {
        session_credit_exhaustion => run_session_credit_exhaustion,
        session_cancelled_channel_drop => run_session_cancelled_channel_drop,
        session_cancel_control_frame => run_session_cancel_control_frame,
        session_grant_credits_control_frame => run_session_grant_credits_control_frame,
        session_deadline_check => run_session_deadline_check,
    }

    // Streaming RPCs.
    conformance_case! {
        server_streaming_happy_path => run_server_streaming_happy_path,
        client_streaming_happy_path => run_client_streaming_happy_path,
        bidirectional_streaming => run_bidirectional_streaming,
        streaming_cancellation => run_streaming_cancellation,
    }

    // Macro-generated service code.
    conformance_case! {
        macro_server_streaming => run_macro_server_streaming,
    }
}