rapace_transport_stream/
lib.rs

1//! rapace-transport-stream: TCP/Unix socket transport for rapace.
2//!
3//! For cross-machine or cross-container communication.
4//!
5//! # Wire Format
6//!
7//! Each frame is sent as:
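//! - `u32 LE`: frame length, not counting this prefix (64 + payload_len)
//! - `[u8; 64]`: MsgDescHot as raw in-memory bytes (repr(C), POD), i.e. in the
//!   sender's native byte order and layout — both peers must agree on both
//! - `[u8; payload_len]`: payload bytes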
11//!
12//! # Characteristics
13//!
14//! - Length-prefixed frames for easy framing
//! - Received frames are copied into owned buffers (no zero-copy on receive)
16//! - Full-duplex: send and receive can happen concurrently
17//! - Same RPC semantics as other transports
18
19use std::sync::Arc;
20use std::sync::atomic::{AtomicBool, Ordering};
21
22use parking_lot::Mutex as SyncMutex;
23use rapace_core::{
24    DecodeError, EncodeCtx, EncodeError, Frame, FrameView, INLINE_PAYLOAD_SIZE,
25    INLINE_PAYLOAD_SLOT, MsgDescHot, Transport, TransportError,
26};
27use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
28use tokio::sync::Mutex as AsyncMutex;
29
30/// Size of MsgDescHot in bytes (must be 64).
31const DESC_SIZE: usize = 64;
32
33// Compile-time check that MsgDescHot is exactly 64 bytes
34const _: () = assert!(std::mem::size_of::<MsgDescHot>() == DESC_SIZE);
35
36/// Stream-based transport implementation.
37///
38/// Works with any `AsyncRead + AsyncWrite` stream (TCP, Unix socket, duplex, etc.).
39/// Uses split read/write halves for full-duplex operation.
40pub struct StreamTransport<R, W> {
41    inner: Arc<StreamInner<R, W>>,
42}
43
44struct StreamInner<R, W> {
45    /// Read half of the stream (async mutex for holding across awaits).
46    reader: AsyncMutex<R>,
47    /// Write half of the stream (async mutex for holding across awaits).
48    writer: AsyncMutex<W>,
49    /// Buffer for the last received frame (for FrameView lifetime).
50    last_frame: SyncMutex<Option<ReceivedFrame>>,
51    /// Whether the transport is closed.
52    closed: AtomicBool,
53}
54
55/// Internal storage for a received frame.
56struct ReceivedFrame {
57    desc: MsgDescHot,
58    payload: Vec<u8>,
59}
60
61impl<S> StreamTransport<ReadHalf<S>, WriteHalf<S>>
62where
63    S: AsyncRead + AsyncWrite + Send + 'static,
64{
65    /// Create a new stream transport by splitting the given stream.
66    ///
67    /// The stream is split into read and write halves, allowing concurrent
68    /// send and receive operations.
69    pub fn new(stream: S) -> Self {
70        let (reader, writer) = tokio::io::split(stream);
71        Self {
72            inner: Arc::new(StreamInner {
73                reader: AsyncMutex::new(reader),
74                writer: AsyncMutex::new(writer),
75                last_frame: SyncMutex::new(None),
76                closed: AtomicBool::new(false),
77            }),
78        }
79    }
80}
81
82impl StreamTransport<ReadHalf<tokio::io::DuplexStream>, WriteHalf<tokio::io::DuplexStream>> {
83    /// Create a connected pair of stream transports for testing.
84    ///
85    /// Uses `tokio::io::duplex` internally.
86    pub fn pair() -> (Self, Self) {
87        // 64KB buffer should be plenty for testing
88        let (a, b) = tokio::io::duplex(65536);
89        (Self::new(a), Self::new(b))
90    }
91}
92
93/// Convert MsgDescHot to raw bytes.
94///
95/// # Safety
96///
97/// MsgDescHot is `#[repr(C, align(64))]` and contains only POD types
98/// (integers, bitflags which is a u32, and a byte array). This makes
99/// it safe to transmute to/from bytes on the same platform.
100///
101/// Note: This is NOT portable across platforms with different endianness
102/// or struct padding. For cross-platform wire format, use explicit
103/// field serialization instead.
104fn desc_to_bytes(desc: &MsgDescHot) -> [u8; DESC_SIZE] {
105    // SAFETY: MsgDescHot is repr(C), Copy, and exactly 64 bytes.
106    // All fields are primitive types with well-defined layout.
107    unsafe { std::mem::transmute_copy(desc) }
108}
109
110/// Convert raw bytes to MsgDescHot.
111///
112/// # Safety
113///
114/// See `desc_to_bytes` for safety discussion.
115fn bytes_to_desc(bytes: &[u8; DESC_SIZE]) -> MsgDescHot {
116    // SAFETY: Same as desc_to_bytes - MsgDescHot is repr(C), Copy, 64 bytes.
117    unsafe { std::mem::transmute_copy(bytes) }
118}
119
120impl<R, W> Transport for StreamTransport<R, W>
121where
122    R: AsyncRead + Unpin + Send + Sync + 'static,
123    W: AsyncWrite + Unpin + Send + Sync + 'static,
124{
125    async fn send_frame(&self, frame: &Frame) -> Result<(), TransportError> {
126        if self.is_closed() {
127            return Err(TransportError::Closed);
128        }
129
130        let payload = frame.payload();
131        let frame_len = DESC_SIZE + payload.len();
132
133        // Serialize descriptor
134        let desc_bytes = desc_to_bytes(&frame.desc);
135
136        // Write: length prefix + descriptor + payload
137        let mut writer = self.inner.writer.lock().await;
138
139        // Length prefix (u32 LE)
140        writer
141            .write_all(&(frame_len as u32).to_le_bytes())
142            .await
143            .map_err(TransportError::Io)?;
144
145        // Descriptor (64 bytes)
146        writer
147            .write_all(&desc_bytes)
148            .await
149            .map_err(TransportError::Io)?;
150
151        // Payload
152        if !payload.is_empty() {
153            writer
154                .write_all(payload)
155                .await
156                .map_err(TransportError::Io)?;
157        }
158
159        // Flush to ensure frame is sent
160        writer.flush().await.map_err(TransportError::Io)?;
161
162        Ok(())
163    }
164
165    async fn recv_frame(&self) -> Result<FrameView<'_>, TransportError> {
166        if self.is_closed() {
167            return Err(TransportError::Closed);
168        }
169
170        let mut reader = self.inner.reader.lock().await;
171
172        // Read length prefix
173        let mut len_buf = [0u8; 4];
174        reader.read_exact(&mut len_buf).await.map_err(|e| {
175            if e.kind() == std::io::ErrorKind::UnexpectedEof {
176                TransportError::Closed
177            } else {
178                TransportError::Io(e)
179            }
180        })?;
181        let frame_len = u32::from_le_bytes(len_buf) as usize;
182
183        // Validate frame length
184        if frame_len < DESC_SIZE {
185            return Err(TransportError::Io(std::io::Error::new(
186                std::io::ErrorKind::InvalidData,
187                format!("frame too small: {} < {}", frame_len, DESC_SIZE),
188            )));
189        }
190
191        // Read descriptor
192        let mut desc_buf = [0u8; DESC_SIZE];
193        reader
194            .read_exact(&mut desc_buf)
195            .await
196            .map_err(TransportError::Io)?;
197
198        let mut desc = bytes_to_desc(&desc_buf);
199
200        // Read payload
201        let payload_len = frame_len - DESC_SIZE;
202        let payload = if payload_len > 0 {
203            let mut buf = vec![0u8; payload_len];
204            reader
205                .read_exact(&mut buf)
206                .await
207                .map_err(TransportError::Io)?;
208            buf
209        } else {
210            Vec::new()
211        };
212
213        // Drop reader lock before storing frame
214        drop(reader);
215
216        // Update desc.payload_len to match actual received payload
217        desc.payload_len = payload_len as u32;
218
219        // If payload fits inline, mark it as inline
220        if payload_len <= INLINE_PAYLOAD_SIZE {
221            desc.payload_slot = INLINE_PAYLOAD_SLOT;
222            desc.inline_payload[..payload_len].copy_from_slice(&payload);
223        } else {
224            // Mark as external payload
225            desc.payload_slot = 0;
226        }
227
228        // Store frame for FrameView lifetime
229        {
230            let mut last = self.inner.last_frame.lock();
231            *last = Some(ReceivedFrame { desc, payload });
232        }
233
234        // Create FrameView from stored frame
235        // SAFETY: The frame is stored in self.inner which lives as long as self.
236        // The returned FrameView borrows &self, preventing another recv_frame call.
237        let last = self.inner.last_frame.lock();
238        let frame_ref = last.as_ref().unwrap();
239
240        let desc_ptr = &frame_ref.desc as *const MsgDescHot;
241        let payload_slice = if frame_ref.desc.is_inline() {
242            frame_ref.desc.inline_payload()
243        } else {
244            &frame_ref.payload
245        };
246        let payload_ptr = payload_slice.as_ptr();
247        let payload_len = payload_slice.len();
248
249        // SAFETY: Extending lifetime is safe because:
250        // - Data lives in Arc<StreamInner> which outlives &self
251        // - FrameView borrows &self, preventing concurrent recv_frame
252        let desc: &MsgDescHot = unsafe { &*desc_ptr };
253        let payload: &[u8] = unsafe { std::slice::from_raw_parts(payload_ptr, payload_len) };
254
255        Ok(FrameView::new(desc, payload))
256    }
257
258    fn encoder(&self) -> Box<dyn EncodeCtx + '_> {
259        Box::new(StreamEncoder::new())
260    }
261
262    async fn close(&self) -> Result<(), TransportError> {
263        self.inner.closed.store(true, Ordering::Release);
264        Ok(())
265    }
266}
267
268impl<R, W> StreamTransport<R, W> {
269    /// Check if the transport is closed.
270    pub fn is_closed(&self) -> bool {
271        self.inner.closed.load(Ordering::Acquire)
272    }
273}
274
275/// Encoder for stream transport.
276///
277/// Simply accumulates bytes into a Vec.
278pub struct StreamEncoder {
279    desc: MsgDescHot,
280    payload: Vec<u8>,
281}
282
283impl StreamEncoder {
284    fn new() -> Self {
285        Self {
286            desc: MsgDescHot::new(),
287            payload: Vec::new(),
288        }
289    }
290
291    /// Set the descriptor for this frame.
292    pub fn set_desc(&mut self, desc: MsgDescHot) {
293        self.desc = desc;
294    }
295}
296
297impl EncodeCtx for StreamEncoder {
298    fn encode_bytes(&mut self, bytes: &[u8]) -> Result<(), EncodeError> {
299        self.payload.extend_from_slice(bytes);
300        Ok(())
301    }
302
303    fn finish(self: Box<Self>) -> Result<Frame, EncodeError> {
304        Ok(Frame::with_payload(self.desc, self.payload))
305    }
306}
307
/// Decoder over a received payload slice.
///
/// Keeps a read position into `data`. Bytes before it have been consumed.
pub struct StreamDecoder<'a> {
    /// Input being decoded.
    data: &'a [u8],
    /// Offset of the first unconsumed byte.
    pos: usize,
}

impl<'a> StreamDecoder<'a> {
    /// Start decoding `data` from its first byte.
    pub fn new(data: &'a [u8]) -> Self {
        StreamDecoder { pos: 0, data }
    }
}
320
321impl<'a> rapace_core::DecodeCtx<'a> for StreamDecoder<'a> {
322    fn decode_bytes(&mut self) -> Result<&'a [u8], DecodeError> {
323        let result = &self.data[self.pos..];
324        self.pos = self.data.len();
325        Ok(result)
326    }
327
328    fn remaining(&self) -> &'a [u8] {
329        &self.data[self.pos..]
330    }
331}
332
#[cfg(test)]
mod tests {
    use super::*;
    use rapace_core::FrameFlags;

    /// Assert that `result` is an I/O transport error of the given kind.
    fn assert_io_kind<T>(result: Result<T, TransportError>, kind: std::io::ErrorKind) {
        match result {
            Err(TransportError::Io(e)) => assert_eq!(e.kind(), kind),
            Err(_) => panic!("expected I/O error {:?}, got another transport error", kind),
            Ok(_) => panic!("expected I/O error {:?}, got a frame", kind),
        }
    }

    #[tokio::test]
    async fn test_pair_creation() {
        let (a, b) = StreamTransport::pair();
        assert!(!a.is_closed());
        assert!(!b.is_closed());
    }

    #[tokio::test]
    async fn test_send_recv_inline() {
        let (a, b) = StreamTransport::pair();

        // Create a frame with inline payload
        let mut desc = MsgDescHot::new();
        desc.msg_id = 1;
        desc.channel_id = 1;
        desc.method_id = 42;
        desc.flags = FrameFlags::DATA;

        let frame = Frame::with_inline_payload(desc, b"hello").unwrap();

        // Send from A
        a.send_frame(&frame).await.unwrap();

        // Receive on B
        let view = b.recv_frame().await.unwrap();
        assert_eq!(view.desc.msg_id, 1);
        assert_eq!(view.desc.channel_id, 1);
        assert_eq!(view.desc.method_id, 42);
        assert_eq!(view.payload, b"hello");
    }

    #[tokio::test]
    async fn test_send_recv_external_payload() {
        let (a, b) = StreamTransport::pair();

        let mut desc = MsgDescHot::new();
        desc.msg_id = 2;
        desc.flags = FrameFlags::DATA;

        // Larger than inline; a non-constant pattern so content errors show up.
        let payload: Vec<u8> = (0..1000u32).map(|i| (i % 251) as u8).collect();
        let frame = Frame::with_payload(desc, payload.clone());

        a.send_frame(&frame).await.unwrap();

        let view = b.recv_frame().await.unwrap();
        assert_eq!(view.desc.msg_id, 2);
        assert_eq!(view.payload.len(), 1000);
        assert_eq!(view.payload, &payload[..]);
    }

    #[tokio::test]
    async fn test_bidirectional() {
        let (a, b) = StreamTransport::pair();

        // A -> B
        let mut desc_a = MsgDescHot::new();
        desc_a.msg_id = 1;
        let frame_a = Frame::with_inline_payload(desc_a, b"from A").unwrap();
        a.send_frame(&frame_a).await.unwrap();

        // B -> A
        let mut desc_b = MsgDescHot::new();
        desc_b.msg_id = 2;
        let frame_b = Frame::with_inline_payload(desc_b, b"from B").unwrap();
        b.send_frame(&frame_b).await.unwrap();

        // Receive both
        let view_b = b.recv_frame().await.unwrap();
        assert_eq!(view_b.payload, b"from A");

        let view_a = a.recv_frame().await.unwrap();
        assert_eq!(view_a.payload, b"from B");
    }

    #[tokio::test]
    async fn test_concurrent_send_recv() {
        // This test verifies that send and recv can happen concurrently
        let (a, b) = StreamTransport::pair();
        let a = Arc::new(a);
        let b = Arc::new(b);

        // Spawn a task that sends multiple frames
        let a_sender = a.clone();
        let send_handle = tokio::spawn(async move {
            for i in 0..10u64 {
                let mut desc = MsgDescHot::new();
                desc.msg_id = i;
                let frame = Frame::with_inline_payload(desc, b"ping").unwrap();
                a_sender.send_frame(&frame).await.unwrap();
            }
        });

        // Spawn a task that receives and echoes back
        let b_clone = b.clone();
        let echo_handle = tokio::spawn(async move {
            for _ in 0..10 {
                let view = b_clone.recv_frame().await.unwrap();
                let mut desc = MsgDescHot::new();
                desc.msg_id = view.desc.msg_id;
                let frame = Frame::with_inline_payload(desc, b"pong").unwrap();
                b_clone.send_frame(&frame).await.unwrap();
            }
        });

        // Receive the echoed frames
        let a_receiver = a.clone();
        let recv_handle = tokio::spawn(async move {
            for _ in 0..10 {
                let view = a_receiver.recv_frame().await.unwrap();
                assert_eq!(view.payload, b"pong");
            }
        });

        // Wait for all tasks
        send_handle.await.unwrap();
        echo_handle.await.unwrap();
        recv_handle.await.unwrap();
    }

    #[tokio::test]
    async fn test_close() {
        let (a, _b) = StreamTransport::pair();

        a.close().await.unwrap();
        assert!(a.is_closed());

        // Sending on closed transport should fail
        let frame = Frame::new(MsgDescHot::new());
        assert!(matches!(
            a.send_frame(&frame).await,
            Err(TransportError::Closed)
        ));
    }

    #[tokio::test]
    async fn test_recv_after_close() {
        let (a, _b) = StreamTransport::pair();
        a.close().await.unwrap();
        assert!(matches!(a.recv_frame().await, Err(TransportError::Closed)));
    }

    #[tokio::test]
    async fn test_peer_drop_reports_closed() {
        let (a, b) = StreamTransport::pair();
        // Dropping A drops its stream, so B hits EOF before any length prefix.
        drop(a);
        assert!(matches!(b.recv_frame().await, Err(TransportError::Closed)));
    }

    #[tokio::test]
    async fn test_rejects_undersized_frame() {
        let (mut raw, stream) = tokio::io::duplex(1024);
        let transport = StreamTransport::new(stream);

        // A length prefix smaller than the 64-byte descriptor.
        raw.write_all(&8u32.to_le_bytes()).await.unwrap();

        assert_io_kind(transport.recv_frame().await, std::io::ErrorKind::InvalidData);
    }

    #[tokio::test]
    async fn test_truncated_payload_is_error() {
        let (mut raw, stream) = tokio::io::duplex(1024);
        let transport = StreamTransport::new(stream);

        // Announce a 10-byte payload but deliver only 3 bytes before EOF.
        let frame_len = (DESC_SIZE + 10) as u32;
        raw.write_all(&frame_len.to_le_bytes()).await.unwrap();
        raw.write_all(&desc_to_bytes(&MsgDescHot::new()))
            .await
            .unwrap();
        raw.write_all(&[1, 2, 3]).await.unwrap();
        drop(raw);

        assert_io_kind(transport.recv_frame().await, std::io::ErrorKind::UnexpectedEof);
    }

    #[tokio::test]
    async fn test_encoder() {
        let (a, _b) = StreamTransport::pair();

        let mut encoder = a.encoder();
        encoder.encode_bytes(b"test data").unwrap();
        let frame = encoder.finish().unwrap();

        assert_eq!(frame.payload(), b"test data");
    }
}
482
/// Conformance tests: runs the shared rapace-testkit scenarios against this transport.
#[cfg(test)]
mod conformance_tests {
    use super::*;
    use rapace_testkit::{TestError, TransportFactory};
    use tokio::io::{ReadHalf, WriteHalf};

    /// Supplies in-memory `StreamTransport` pairs to the testkit scenarios.
    struct StreamFactory;

    impl TransportFactory for StreamFactory {
        type Transport =
            StreamTransport<ReadHalf<tokio::io::DuplexStream>, WriteHalf<tokio::io::DuplexStream>>;

        async fn connect_pair() -> Result<(Self::Transport, Self::Transport), TestError> {
            Ok(StreamTransport::pair())
        }
    }

    /// Expands each `test_name => runner` entry into a `#[tokio::test]` that
    /// awaits `rapace_testkit::runner::<StreamFactory>()`.
    macro_rules! stream_conformance {
        ($($test:ident => $runner:ident),+ $(,)?) => {
            $(
                #[tokio::test]
                async fn $test() {
                    rapace_testkit::$runner::<StreamFactory>().await;
                }
            )+
        };
    }

    stream_conformance! {
        // Basic RPC behavior
        unary_happy_path => run_unary_happy_path,
        unary_multiple_calls => run_unary_multiple_calls,
        ping_pong => run_ping_pong,
        deadline_success => run_deadline_success,
        deadline_exceeded => run_deadline_exceeded,
        cancellation => run_cancellation,
        credit_grant => run_credit_grant,
        error_response => run_error_response,

        // Session-level tests (semantic enforcement)
        session_credit_exhaustion => run_session_credit_exhaustion,
        session_cancelled_channel_drop => run_session_cancelled_channel_drop,
        session_cancel_control_frame => run_session_cancel_control_frame,
        session_grant_credits_control_frame => run_session_grant_credits_control_frame,
        session_deadline_check => run_session_deadline_check,

        // Streaming tests
        server_streaming_happy_path => run_server_streaming_happy_path,
        client_streaming_happy_path => run_client_streaming_happy_path,
        bidirectional_streaming => run_bidirectional_streaming,
        streaming_cancellation => run_streaming_cancellation,

        // Macro-generated streaming tests
        macro_server_streaming => run_macro_server_streaming,
    }
}