tx2_link/
transport.rs

1use crate::error::{LinkError, Result};
2use crate::protocol::Message;
3use crate::serialization::{BinarySerializer, BinaryFormat};
4use bytes::Bytes;
5
6#[cfg(feature = "async")]
7use async_trait::async_trait;
8
9pub trait Transport {
10    fn send(&mut self, message: &Message) -> Result<()>;
11    fn receive(&mut self) -> Result<Option<Message>>;
12    fn close(&mut self) -> Result<()>;
13    fn is_connected(&self) -> bool;
14}
15
#[cfg(feature = "async")]
#[async_trait]
/// Asynchronous counterpart of [`Transport`]; implementors must be `Send + Sync`.
pub trait AsyncTransport: Send + Sync {
    /// Returns `true` while the transport still accepts `send`/`receive` calls.
    fn is_connected(&self) -> bool;

    /// Serializes `message` and writes it to the underlying channel.
    async fn send(&mut self, message: &Message) -> Result<()>;

    /// Waits for the next item from the underlying channel.
    ///
    /// `Ok(None)` means this call produced no message; the meaning is implementation-specific.
    async fn receive(&mut self) -> Result<Option<Message>>;

    /// Shuts the transport down.
    async fn close(&mut self) -> Result<()>;
}
24
25pub struct MemoryTransport {
26    serializer: BinarySerializer,
27    send_buffer: Vec<Bytes>,
28    receive_buffer: Vec<Bytes>,
29    connected: bool,
30}
31
32impl MemoryTransport {
33    pub fn new(format: BinaryFormat) -> Self {
34        Self {
35            serializer: BinarySerializer::new(format),
36            send_buffer: Vec::new(),
37            receive_buffer: Vec::new(),
38            connected: true,
39        }
40    }
41
42    pub fn create_pair(format: BinaryFormat) -> (Self, Self) {
43        let t1 = Self::new(format);
44        let t2 = Self::new(format);
45        (t1, t2)
46    }
47
48    pub fn connect_to(&mut self, other: &mut Self) {
49        std::mem::swap(&mut self.send_buffer, &mut other.receive_buffer);
50        std::mem::swap(&mut self.receive_buffer, &mut other.send_buffer);
51    }
52
53    pub fn get_send_buffer(&self) -> &[Bytes] {
54        &self.send_buffer
55    }
56
57    pub fn get_receive_buffer(&self) -> &[Bytes] {
58        &self.receive_buffer
59    }
60}
61
62impl Transport for MemoryTransport {
63    fn send(&mut self, message: &Message) -> Result<()> {
64        if !self.connected {
65            return Err(LinkError::ConnectionClosed);
66        }
67
68        let data = self.serializer.serialize_message(message)?;
69        self.send_buffer.push(data);
70        Ok(())
71    }
72
73    fn receive(&mut self) -> Result<Option<Message>> {
74        if !self.connected {
75            return Err(LinkError::ConnectionClosed);
76        }
77
78        if self.receive_buffer.is_empty() {
79            return Ok(None);
80        }
81
82        let data = self.receive_buffer.remove(0);
83        let message = self.serializer.deserialize_message(&data)?;
84        Ok(Some(message))
85    }
86
87    fn close(&mut self) -> Result<()> {
88        self.connected = false;
89        self.send_buffer.clear();
90        self.receive_buffer.clear();
91        Ok(())
92    }
93
94    fn is_connected(&self) -> bool {
95        self.connected
96    }
97}
98
99pub struct StdioTransport {
100    serializer: BinarySerializer,
101    connected: bool,
102}
103
104impl StdioTransport {
105    pub fn new(format: BinaryFormat) -> Self {
106        Self {
107            serializer: BinarySerializer::new(format),
108            connected: true,
109        }
110    }
111}
112
113impl Transport for StdioTransport {
114    fn send(&mut self, message: &Message) -> Result<()> {
115        if !self.connected {
116            return Err(LinkError::ConnectionClosed);
117        }
118
119        use std::io::Write;
120
121        let data = self.serializer.serialize_message(message)?;
122        let len = data.len() as u32;
123
124        let mut stdout = std::io::stdout();
125        stdout.write_all(&len.to_le_bytes())?;
126        stdout.write_all(&data)?;
127        stdout.flush()?;
128
129        Ok(())
130    }
131
132    fn receive(&mut self) -> Result<Option<Message>> {
133        if !self.connected {
134            return Err(LinkError::ConnectionClosed);
135        }
136
137        use std::io::Read;
138
139        let mut stdin = std::io::stdin();
140        let mut len_bytes = [0u8; 4];
141
142        match stdin.read_exact(&mut len_bytes) {
143            Ok(_) => {},
144            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
145                return Ok(None);
146            }
147            Err(e) => return Err(e.into()),
148        }
149
150        let len = u32::from_le_bytes(len_bytes) as usize;
151        let mut buffer = vec![0u8; len];
152
153        stdin.read_exact(&mut buffer)?;
154
155        let message = self.serializer.deserialize_message(&buffer)?;
156        Ok(Some(message))
157    }
158
159    fn close(&mut self) -> Result<()> {
160        self.connected = false;
161        Ok(())
162    }
163
164    fn is_connected(&self) -> bool {
165        self.connected
166    }
167}
168
#[cfg(feature = "websocket")]
pub mod websocket {
    //! [`AsyncTransport`] implementation over a `tokio-tungstenite` WebSocket connection.

    use super::*;
    use futures_util::{SinkExt, StreamExt};
    use tokio::net::TcpStream;
    use tokio_tungstenite::{tungstenite::Message as WsMessage, WebSocketStream};

    /// Wraps a WebSocket-layer error in `LinkError::Transport`, keeping its display text.
    fn transport_error(err: impl ToString) -> LinkError {
        LinkError::Transport(err.to_string())
    }

    /// Transport that sends each [`Message`] as a single binary WebSocket message.
    ///
    /// The connection counts as open while the wrapped stream is present. The stream is
    /// dropped by `close`, by a received Close frame, or when the peer ends the stream.
    pub struct WebSocketTransport {
        serializer: BinarySerializer,
        stream: Option<WebSocketStream<TcpStream>>,
    }

    impl WebSocketTransport {
        /// Wraps an already-established WebSocket `stream`, serializing with `format`.
        pub fn new(format: BinaryFormat, stream: WebSocketStream<TcpStream>) -> Self {
            let serializer = BinarySerializer::new(format);
            Self {
                serializer,
                stream: Some(stream),
            }
        }
    }

    #[async_trait]
    impl AsyncTransport for WebSocketTransport {
        /// Serializes `message` and sends it as one binary frame.
        async fn send(&mut self, message: &Message) -> Result<()> {
            let stream = match self.stream.as_mut() {
                Some(stream) => stream,
                None => return Err(LinkError::ConnectionClosed),
            };

            let payload = self.serializer.serialize_message(message)?;
            stream
                .send(WsMessage::Binary(payload.to_vec()))
                .await
                .map_err(transport_error)
        }

        /// Waits for the next WebSocket frame.
        ///
        /// Binary frames are decoded into a [`Message`]. Other non-Close frames yield
        /// `Ok(None)`. A Close frame or end of stream drops the connection and returns
        /// `LinkError::ConnectionClosed`.
        async fn receive(&mut self) -> Result<Option<Message>> {
            let stream = match self.stream.as_mut() {
                Some(stream) => stream,
                None => return Err(LinkError::ConnectionClosed),
            };

            let frame = match stream.next().await {
                Some(Ok(frame)) => frame,
                Some(Err(err)) => return Err(transport_error(err)),
                None => {
                    self.stream = None;
                    return Err(LinkError::ConnectionClosed);
                }
            };

            match frame {
                WsMessage::Binary(data) => {
                    Ok(Some(self.serializer.deserialize_message(&data)?))
                }
                WsMessage::Close(_) => {
                    self.stream = None;
                    Err(LinkError::ConnectionClosed)
                }
                _ => Ok(None),
            }
        }

        /// Sends a Close frame and drops the stream; a no-op if already closed.
        async fn close(&mut self) -> Result<()> {
            match self.stream.take() {
                Some(mut stream) => stream.close(None).await.map_err(transport_error),
                None => Ok(()),
            }
        }

        fn is_connected(&self) -> bool {
            self.stream.is_some()
        }
    }
}
241
/// Coarse classification of transport failures.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransportError {
    /// No connection is available.
    NotConnected,
    /// Sending a message failed.
    SendFailed,
    /// Receiving a message failed.
    ReceiveFailed,
    /// Closing the transport failed.
    CloseFailed,
}

impl TransportError {
    /// Human-readable text used by the `Display` impl.
    fn message(self) -> &'static str {
        match self {
            Self::NotConnected => "Not connected",
            Self::SendFailed => "Send failed",
            Self::ReceiveFailed => "Receive failed",
            Self::CloseFailed => "Close failed",
        }
    }
}

impl std::fmt::Display for TransportError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.message())
    }
}

impl std::error::Error for TransportError {}
262
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::MessageType;

    /// A frame sent on one transport is received on the other after `connect_to`, and
    /// the queue is then empty.
    #[test]
    fn test_memory_transport() {
        let mut sender = MemoryTransport::new(BinaryFormat::MessagePack);
        let mut receiver = MemoryTransport::new(BinaryFormat::MessagePack);

        let message = Message::ping(1);
        sender.send(&message).unwrap();

        sender.connect_to(&mut receiver);

        let received = receiver.receive().unwrap().unwrap();
        assert_eq!(message.header.msg_type, received.header.msg_type);
        assert!(receiver.receive().unwrap().is_none());
    }

    /// `connect_to` delivers in both directions and drains both send buffers.
    #[test]
    fn test_memory_transport_both_directions() {
        let (mut left, mut right) = MemoryTransport::create_pair(BinaryFormat::MessagePack);

        left.send(&Message::ping(1)).unwrap();
        right.send(&Message::ping(2)).unwrap();

        left.connect_to(&mut right);

        assert!(left.get_send_buffer().is_empty());
        assert!(right.get_send_buffer().is_empty());
        assert!(left.receive().unwrap().is_some());
        assert!(right.receive().unwrap().is_some());
    }

    /// Receiving with nothing queued is not an error.
    #[test]
    fn test_receive_on_empty_transport_returns_none() {
        let mut transport = MemoryTransport::new(BinaryFormat::MessagePack);
        assert!(transport.receive().unwrap().is_none());
    }

    /// After `close`, both `send` and `receive` are rejected.
    #[test]
    fn test_transport_close() {
        let mut transport = MemoryTransport::new(BinaryFormat::Json);

        assert!(transport.is_connected());

        transport.close().unwrap();

        assert!(!transport.is_connected());

        let message = Message::ping(1);
        assert!(transport.send(&message).is_err());
        assert!(transport.receive().is_err());
    }
}