Skip to main content

a3s_common/transport/
mod.rs

1//! Shared transport abstraction for the A3S ecosystem.
2//!
3//! This module provides:
4//! - [`Transport`] trait — async send/recv abstraction over any byte stream
5//! - [`Frame`] — unified wire format: `[type:u8][length:u32][payload]`
6//! - [`FrameReader`] / [`FrameWriter`] — async buffered frame I/O
7//! - [`UnixTransport`] — Unix domain socket transport (cross-platform)
8//! - [`MockTransport`] — in-memory transport for testing
9//! - TEE protocol types for secure communication
10
11use async_trait::async_trait;
12use std::collections::VecDeque;
13use tokio::sync::Mutex;
14
15pub mod codec;
16pub mod frame;
17pub mod tee;
18pub mod unix;
19
20// Re-exports for convenience
21pub use codec::{FrameCodec, FrameReader, FrameWriter};
22pub use frame::{Frame, FrameType, MAX_PAYLOAD_SIZE};
23pub use tee::{TeeMessage, TeeRequest, TeeRequestType, TeeResponse, TeeResponseStatus};
24pub use unix::{UnixListener, UnixTransport};
25
/// Well-known vsock port numbers, one constant per channel.
pub mod ports {
    /// Port of the gRPC agent control channel.
    pub const GRPC_AGENT: u32 = 4088;
    /// Port of the exec server, which executes commands in the guest.
    pub const EXEC_SERVER: u32 = 4089;
    /// Port of the PTY server, which provides an interactive terminal.
    pub const PTY_SERVER: u32 = 4090;
    /// Port of the TEE secure channel (SafeClaw <-> a3s-code).
    pub const TEE_CHANNEL: u32 = 4091;
}
37
38// ---------------------------------------------------------------------------
39// Transport trait
40// ---------------------------------------------------------------------------
41
42/// Error type for transport operations
43#[derive(Debug, thiserror::Error)]
44pub enum TransportError {
45    #[error("Connection failed: {0}")]
46    ConnectionFailed(String),
47    #[error("Not connected")]
48    NotConnected,
49    #[error("Send failed: {0}")]
50    SendFailed(String),
51    #[error("Receive failed: {0}")]
52    RecvFailed(String),
53    #[error("Connection closed")]
54    Closed,
55    #[error("Operation timed out")]
56    Timeout,
57    #[error("Frame error: {0}")]
58    FrameError(String),
59    #[error("Protocol error: {0}")]
60    Protocol(String),
61}
62
63/// Async transport trait for sending and receiving framed messages.
64#[async_trait]
65pub trait Transport: Send + Sync + std::fmt::Debug {
66    /// Establish the connection.
67    async fn connect(&mut self) -> Result<(), TransportError>;
68
69    /// Send raw bytes as a data frame.
70    async fn send(&mut self, data: &[u8]) -> Result<(), TransportError>;
71
72    /// Send a typed frame.
73    async fn send_frame(&mut self, frame: &Frame) -> Result<(), TransportError> {
74        // Default: encode frame and send payload only (for backward compat)
75        self.send(&frame.encode()?).await
76    }
77
78    /// Receive raw bytes (payload of the next data frame).
79    async fn recv(&mut self) -> Result<Vec<u8>, TransportError>;
80
81    /// Receive a typed frame. Returns `None` on clean EOF.
82    async fn recv_frame(&mut self) -> Result<Option<Frame>, TransportError> {
83        // Default: wrap recv() in a data frame
84        match self.recv().await {
85            Ok(data) => Ok(Some(Frame::data(data))),
86            Err(TransportError::Closed) => Ok(None),
87            Err(e) => Err(e),
88        }
89    }
90
91    /// Close the connection.
92    async fn close(&mut self) -> Result<(), TransportError>;
93
94    /// Check if connected.
95    fn is_connected(&self) -> bool;
96}
97
98// ---------------------------------------------------------------------------
99// MockTransport
100// ---------------------------------------------------------------------------
101
102/// Handler function type for mock responses
103type ResponseHandler = Box<dyn Fn(&[u8]) -> Vec<u8> + Send + Sync>;
104
105/// In-memory transport for testing.
106pub struct MockTransport {
107    connected: bool,
108    recv_queue: Mutex<VecDeque<Vec<u8>>>,
109    sent: Mutex<Vec<Vec<u8>>>,
110    handler: Option<ResponseHandler>,
111}
112
113impl std::fmt::Debug for MockTransport {
114    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115        f.debug_struct("MockTransport")
116            .field("connected", &self.connected)
117            .finish_non_exhaustive()
118    }
119}
120
121impl MockTransport {
122    /// Create a new disconnected mock transport
123    pub fn new() -> Self {
124        Self {
125            connected: false,
126            recv_queue: Mutex::new(VecDeque::new()),
127            sent: Mutex::new(Vec::new()),
128            handler: None,
129        }
130    }
131
132    /// Create a mock transport with an auto-response handler.
133    pub fn with_handler<F>(handler: F) -> Self
134    where
135        F: Fn(&[u8]) -> Vec<u8> + Send + Sync + 'static,
136    {
137        Self {
138            connected: false,
139            recv_queue: Mutex::new(VecDeque::new()),
140            sent: Mutex::new(Vec::new()),
141            handler: Some(Box::new(handler)),
142        }
143    }
144
145    /// Push a message into the recv queue
146    pub fn push_recv(&self, data: Vec<u8>) {
147        if let Ok(mut queue) = self.recv_queue.try_lock() {
148            queue.push_back(data);
149        }
150    }
151
152    /// Get all sent messages
153    pub async fn sent_messages(&self) -> Vec<Vec<u8>> {
154        self.sent.lock().await.clone()
155    }
156}
157
158impl Default for MockTransport {
159    fn default() -> Self {
160        Self::new()
161    }
162}
163
164#[async_trait]
165impl Transport for MockTransport {
166    async fn connect(&mut self) -> Result<(), TransportError> {
167        self.connected = true;
168        Ok(())
169    }
170
171    async fn send(&mut self, data: &[u8]) -> Result<(), TransportError> {
172        if !self.connected {
173            return Err(TransportError::NotConnected);
174        }
175        self.sent.lock().await.push(data.to_vec());
176        if let Some(ref handler) = self.handler {
177            let response = handler(data);
178            self.recv_queue.lock().await.push_back(response);
179        }
180        Ok(())
181    }
182
183    async fn recv(&mut self) -> Result<Vec<u8>, TransportError> {
184        if !self.connected {
185            return Err(TransportError::NotConnected);
186        }
187        match self.recv_queue.lock().await.pop_front() {
188            Some(data) => Ok(data),
189            None => Err(TransportError::Closed),
190        }
191    }
192
193    async fn close(&mut self) -> Result<(), TransportError> {
194        self.connected = false;
195        Ok(())
196    }
197
198    fn is_connected(&self) -> bool {
199        self.connected
200    }
201}
202
#[cfg(test)]
mod tests {
    use super::*;

    /// A handler-equipped mock answers each send with the handler's output.
    #[tokio::test]
    async fn test_mock_transport_handler() {
        let mut mock = MockTransport::with_handler(|payload| {
            [&b"echo: "[..], payload].concat()
        });
        mock.connect().await.unwrap();
        mock.send(b"ping").await.unwrap();
        let reply = mock.recv().await.unwrap();
        assert_eq!(reply, b"echo: ping");
    }

    /// Sending and receiving both fail before `connect` is called.
    #[tokio::test]
    async fn test_mock_transport_not_connected() {
        let mut mock = MockTransport::new();
        assert!(!mock.is_connected());
        let send_result = mock.send(b"data").await;
        assert!(send_result.is_err());
        let recv_result = mock.recv().await;
        assert!(recv_result.is_err());
    }

    /// A message queued with `push_recv` is returned by the next `recv`.
    #[tokio::test]
    async fn test_mock_transport_push_recv() {
        let mut mock = MockTransport::new();
        mock.connect().await.unwrap();
        mock.push_recv(b"queued".to_vec());
        let received = mock.recv().await.unwrap();
        assert_eq!(received, b"queued");
    }

    /// `close` flips the connection flag back to disconnected.
    #[tokio::test]
    async fn test_mock_transport_close() {
        let mut mock = MockTransport::new();
        mock.connect().await.unwrap();
        assert!(mock.is_connected());
        mock.close().await.unwrap();
        assert!(!mock.is_connected());
    }

    /// Every sent message is logged in the order it was sent.
    #[tokio::test]
    async fn test_mock_sent_messages() {
        let mut mock = MockTransport::new();
        mock.connect().await.unwrap();
        mock.send(b"msg1").await.unwrap();
        mock.send(b"msg2").await.unwrap();
        let log = mock.sent_messages().await;
        assert_eq!(log.len(), 2);
        assert_eq!(log[0], b"msg1");
        assert_eq!(log[1], b"msg2");
    }
}