Skip to main content

matrixcode_core/matrixrpc/transport/
tcp.rs

1//! TCP Transport for MatrixRPC
2//!
3//! Provides TCP-based transport for JSON-RPC communication with external services.
4//! Uses binary frame format: [4 bytes length][JSON payload]
5//!
6//! # Ports
7//!
8//! - Registry Port (9527): Accepts external service registration
9//! - Callback Port (9528): Accepts callback requests from external services
10
11use std::io;
12use std::net::SocketAddr;
13use std::sync::Arc;
14use std::time::Duration;
15
16use async_trait::async_trait;
17use tokio::net::{TcpListener as TokioTcpListener, TcpStream as TokioTcpStream, tcp::{OwnedReadHalf, OwnedWriteHalf}};
18use tokio::io::{AsyncReadExt, AsyncWriteExt};
19use tokio::sync::Mutex;
20use tokio::time::timeout;
21
22use crate::matrixrpc::protocol::JsonRpcMessage;
23use super::{Transport, TransportConfig};
24
/// Number of bytes in the frame header: a single big-endian `u32` holding the
/// length of the JSON payload that follows it.
///
/// A fixed binary prefix is cheaper to parse on a TCP stream than a textual
/// `Content-Length` header.
const FRAME_HEADER_SIZE: usize = std::mem::size_of::<u32>();
28
29/// TCP Transport implementation
30///
31/// Supports both client and server modes:
32/// - Client mode: Connect to external service
33/// - Server mode: Accept connections from external services
34pub struct TcpTransport {
35    /// Read half of the TCP stream
36    reader: Arc<Mutex<Option<OwnedReadHalf>>>,
37    /// Write half of the TCP stream
38    writer: Arc<Mutex<Option<OwnedWriteHalf>>>,
39    /// Transport configuration
40    config: TransportConfig,
41    /// Remote address (for logging/debugging)
42    remote_addr: Option<SocketAddr>,
43    /// Connection state
44    is_closed: bool,
45}
46
47impl TcpTransport {
48    /// Create a new TCP transport by connecting to an address
49    pub async fn connect(addr: &str) -> io::Result<Self> {
50        Self::connect_with_config(addr, TransportConfig::default()).await
51    }
52
53    /// Create a new TCP transport with custom configuration
54    pub async fn connect_with_config(addr: &str, config: TransportConfig) -> io::Result<Self> {
55        let stream = TokioTcpStream::connect(addr).await?;
56        let remote_addr = stream.peer_addr().ok();
57        let (reader, writer) = stream.into_split();
58
59        Ok(Self {
60            reader: Arc::new(Mutex::new(Some(reader))),
61            writer: Arc::new(Mutex::new(Some(writer))),
62            config,
63            remote_addr,
64            is_closed: false,
65        })
66    }
67
68    /// Create a transport from an existing TcpStream (server mode)
69    pub fn from_stream(stream: TokioTcpStream, config: TransportConfig) -> Self {
70        let remote_addr = stream.peer_addr().ok();
71        let (reader, writer) = stream.into_split();
72
73        Self {
74            reader: Arc::new(Mutex::new(Some(reader))),
75            writer: Arc::new(Mutex::new(Some(writer))),
76            config,
77            remote_addr,
78            is_closed: false,
79        }
80    }
81
82    /// Get the remote address
83    pub fn remote_addr(&self) -> Option<SocketAddr> {
84        self.remote_addr
85    }
86
87    /// Encode message with binary frame format
88    fn encode_frame(message: &JsonRpcMessage) -> io::Result<Vec<u8>> {
89        let json = message.to_json().map_err(|e| {
90            io::Error::new(
91                io::ErrorKind::InvalidData,
92                format!("JSON encode error: {}", e),
93            )
94        })?;
95
96        let json_bytes = json.into_bytes();
97        let length = json_bytes.len() as u32;
98
99        // Create frame: 4-byte length + JSON
100        let mut frame = Vec::with_capacity(FRAME_HEADER_SIZE + json_bytes.len());
101        frame.extend_from_slice(&length.to_be_bytes());
102        frame.extend(json_bytes);
103
104        Ok(frame)
105    }
106
107    /// Decode message from binary frame
108    async fn decode_frame(
109        reader: &mut OwnedReadHalf,
110        max_size: usize,
111    ) -> io::Result<Option<JsonRpcMessage>> {
112        // Read 4-byte length header
113        let mut header_buf = [0u8; FRAME_HEADER_SIZE];
114        match reader.read_exact(&mut header_buf).await {
115            Ok(_) => {}
116            Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
117            Err(e) => return Err(e),
118        }
119
120        let length = u32::from_be_bytes(header_buf) as usize;
121
122        // Validate size
123        if length > max_size {
124            return Err(io::Error::new(
125                io::ErrorKind::InvalidData,
126                format!("Frame size {} exceeds maximum {}", length, max_size),
127            ));
128        }
129
130        if length == 0 {
131            return Ok(None);
132        }
133
134        // Read payload
135        let mut payload_buf = vec![0u8; length];
136        reader.read_exact(&mut payload_buf).await?;
137
138        // Parse JSON
139        let json_str = String::from_utf8(payload_buf).map_err(|e| {
140            io::Error::new(
141                io::ErrorKind::InvalidData,
142                format!("UTF-8 decode error: {}", e),
143            )
144        })?;
145
146        let message = JsonRpcMessage::from_json(&json_str).map_err(|e| {
147            io::Error::new(
148                io::ErrorKind::InvalidData,
149                format!("JSON parse error: {}", e),
150            )
151        })?;
152
153        Ok(Some(message))
154    }
155}
156
157#[async_trait]
158impl Transport for TcpTransport {
159    async fn send(&mut self, message: &JsonRpcMessage) -> io::Result<()> {
160        if self.is_closed {
161            return Err(io::Error::new(
162                io::ErrorKind::BrokenPipe,
163                "Transport is closed",
164            ));
165        }
166
167        let writer_guard = self.writer.lock().await;
168        let _writer = writer_guard.as_ref().ok_or_else(|| {
169            io::Error::new(io::ErrorKind::BrokenPipe, "No stream available")
170        })?;
171
172        let frame = Self::encode_frame(message)?;
173
174        // Write with timeout (need to clone writer for timeout usage)
175        // Since OwnedWriteHalf can't be cloned, we use a different approach
176        let result = timeout(
177            Duration::from_millis(self.config.write_timeout_ms),
178            async {
179                // We need mutable access, so we need to lock mutably
180                drop(writer_guard);
181                let mut writer_guard = self.writer.lock().await;
182                let writer = writer_guard.as_mut().ok_or_else(|| {
183                    io::Error::new(io::ErrorKind::BrokenPipe, "No stream available")
184                })?;
185                writer.write_all(&frame).await
186            }
187        )
188        .await;
189
190        match result {
191            Ok(Ok(_)) => Ok(()),
192            Ok(Err(e)) => Err(e),
193            Err(_) => Err(io::Error::new(
194                io::ErrorKind::TimedOut,
195                "Write timeout",
196            )),
197        }
198    }
199
200    async fn receive(&mut self) -> io::Result<Option<JsonRpcMessage>> {
201        if self.is_closed {
202            return Ok(None);
203        }
204
205        // Read with timeout
206        let read_result = timeout(
207            Duration::from_millis(self.config.read_timeout_ms),
208            async {
209                let mut reader_guard = self.reader.lock().await;
210                let reader = reader_guard.as_mut().ok_or_else(|| {
211                    io::Error::new(io::ErrorKind::BrokenPipe, "No stream available")
212                })?;
213                Self::decode_frame(reader, self.config.max_message_size).await
214            }
215        )
216        .await;
217
218        match read_result {
219            Ok(Ok(message)) => Ok(message),
220            Ok(Err(e)) => Err(e),
221            Err(_) => Err(io::Error::new(
222                io::ErrorKind::TimedOut,
223                "Read timeout",
224            )),
225        }
226    }
227
228    async fn close(&mut self) -> io::Result<()> {
229        if self.is_closed {
230            return Ok(());
231        }
232
233        self.is_closed = true;
234
235        // Drop the stream by taking it out
236        let mut reader_guard = self.reader.lock().await;
237        let mut writer_guard = self.writer.lock().await;
238        reader_guard.take();
239        writer_guard.take();
240
241        Ok(())
242    }
243
244    fn is_closed(&self) -> bool {
245        self.is_closed
246    }
247}
248
249/// TCP Listener for accepting incoming connections
250///
251/// Used by Extension Gateway to accept external service registrations
252/// and callback requests.
253pub struct TcpListener {
254    /// Tokio TCP listener
255    listener: TokioTcpListener,
256    /// Local address
257    local_addr: SocketAddr,
258    /// Transport config for accepted connections
259    config: TransportConfig,
260}
261
262impl TcpListener {
263    /// Create a new TCP listener on the specified port
264    pub async fn bind(port: u16) -> io::Result<Self> {
265        Self::bind_with_config(port, TransportConfig::default()).await
266    }
267
268    /// Create a new TCP listener with custom configuration
269    pub async fn bind_with_config(port: u16, config: TransportConfig) -> io::Result<Self> {
270        let addr: SocketAddr = format!("127.0.0.1:{}", port).parse().unwrap();
271        let listener = TokioTcpListener::bind(addr).await?;
272        let local_addr = listener.local_addr()?;
273
274        Ok(Self {
275            listener,
276            local_addr,
277            config,
278        })
279    }
280
281    /// Get the local address
282    pub fn local_addr(&self) -> SocketAddr {
283        self.local_addr
284    }
285
286    /// Accept a new connection
287    ///
288    /// Returns a TcpTransport for the accepted connection.
289    pub async fn accept(&self) -> io::Result<TcpTransport> {
290        let (stream, _addr) = self.listener.accept().await?;
291        Ok(TcpTransport::from_stream(stream, self.config.clone()))
292    }
293
294    /// Get port number
295    pub fn port(&self) -> u16 {
296        self.local_addr.port()
297    }
298}
299
/// Well-known port (9527) on which external services register themselves.
pub const REGISTRY_PORT: u16 = 9527;

/// Well-known port (9528) on which external services send callback requests.
pub const CALLBACK_PORT: u16 = 9528;
305
#[cfg(test)]
mod tests {
    use super::*;
    use crate::matrixrpc::protocol::{JsonRpcId, JsonRpcRequest};

    /// An encoded frame is a 4-byte length header whose value equals the payload size.
    #[test]
    fn test_encode_frame_simple() {
        let request_id = JsonRpcId::String("test-1".to_string());
        let request = JsonRpcRequest::with_id("test.method", request_id)
            .params(serde_json::json!({"param": "value"}));

        let frame = TcpTransport::encode_frame(&JsonRpcMessage::Request(request))
            .expect("a well-formed request should encode");

        assert!(frame.len() > FRAME_HEADER_SIZE);

        let (header, payload) = frame.split_at(FRAME_HEADER_SIZE);
        let mut length_bytes = [0u8; FRAME_HEADER_SIZE];
        length_bytes.copy_from_slice(header);
        let declared_len = u32::from_be_bytes(length_bytes) as usize;

        assert!(declared_len > 0);
        assert_eq!(payload.len(), declared_len);
    }

    /// Builder setters on the config are reflected in its public fields.
    #[test]
    fn test_tcp_config() {
        let config = TransportConfig::new()
            .max_message_size(1024)
            .read_timeout(5000);

        assert_eq!(
            (config.max_message_size, config.read_timeout_ms),
            (1024, 5000)
        );
    }

    /// The header is exactly one `u32`.
    #[test]
    fn test_frame_header_size() {
        assert_eq!(FRAME_HEADER_SIZE, std::mem::size_of::<u32>());
    }

    /// Port numbers are part of the external contract and must not drift.
    #[test]
    fn test_port_constants() {
        assert_eq!([REGISTRY_PORT, CALLBACK_PORT], [9527, 9528]);
    }
}