matrixcode_core/matrixrpc/transport/
tcp.rs1use std::io;
12use std::net::SocketAddr;
13use std::sync::Arc;
14use std::time::Duration;
15
16use async_trait::async_trait;
17use tokio::net::{TcpListener as TokioTcpListener, TcpStream as TokioTcpStream, tcp::{OwnedReadHalf, OwnedWriteHalf}};
18use tokio::io::{AsyncReadExt, AsyncWriteExt};
19use tokio::sync::Mutex;
20use tokio::time::timeout;
21
22use crate::matrixrpc::protocol::JsonRpcMessage;
23use super::{Transport, TransportConfig};
24
/// Number of bytes in the big-endian `u32` length prefix written before every frame payload.
const FRAME_HEADER_SIZE: usize = std::mem::size_of::<u32>();
29pub struct TcpTransport {
35 reader: Arc<Mutex<Option<OwnedReadHalf>>>,
37 writer: Arc<Mutex<Option<OwnedWriteHalf>>>,
39 config: TransportConfig,
41 remote_addr: Option<SocketAddr>,
43 is_closed: bool,
45}
46
47impl TcpTransport {
48 pub async fn connect(addr: &str) -> io::Result<Self> {
50 Self::connect_with_config(addr, TransportConfig::default()).await
51 }
52
53 pub async fn connect_with_config(addr: &str, config: TransportConfig) -> io::Result<Self> {
55 let stream = TokioTcpStream::connect(addr).await?;
56 let remote_addr = stream.peer_addr().ok();
57 let (reader, writer) = stream.into_split();
58
59 Ok(Self {
60 reader: Arc::new(Mutex::new(Some(reader))),
61 writer: Arc::new(Mutex::new(Some(writer))),
62 config,
63 remote_addr,
64 is_closed: false,
65 })
66 }
67
68 pub fn from_stream(stream: TokioTcpStream, config: TransportConfig) -> Self {
70 let remote_addr = stream.peer_addr().ok();
71 let (reader, writer) = stream.into_split();
72
73 Self {
74 reader: Arc::new(Mutex::new(Some(reader))),
75 writer: Arc::new(Mutex::new(Some(writer))),
76 config,
77 remote_addr,
78 is_closed: false,
79 }
80 }
81
82 pub fn remote_addr(&self) -> Option<SocketAddr> {
84 self.remote_addr
85 }
86
87 fn encode_frame(message: &JsonRpcMessage) -> io::Result<Vec<u8>> {
89 let json = message.to_json().map_err(|e| {
90 io::Error::new(
91 io::ErrorKind::InvalidData,
92 format!("JSON encode error: {}", e),
93 )
94 })?;
95
96 let json_bytes = json.into_bytes();
97 let length = json_bytes.len() as u32;
98
99 let mut frame = Vec::with_capacity(FRAME_HEADER_SIZE + json_bytes.len());
101 frame.extend_from_slice(&length.to_be_bytes());
102 frame.extend(json_bytes);
103
104 Ok(frame)
105 }
106
107 async fn decode_frame(
109 reader: &mut OwnedReadHalf,
110 max_size: usize,
111 ) -> io::Result<Option<JsonRpcMessage>> {
112 let mut header_buf = [0u8; FRAME_HEADER_SIZE];
114 match reader.read_exact(&mut header_buf).await {
115 Ok(_) => {}
116 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
117 Err(e) => return Err(e),
118 }
119
120 let length = u32::from_be_bytes(header_buf) as usize;
121
122 if length > max_size {
124 return Err(io::Error::new(
125 io::ErrorKind::InvalidData,
126 format!("Frame size {} exceeds maximum {}", length, max_size),
127 ));
128 }
129
130 if length == 0 {
131 return Ok(None);
132 }
133
134 let mut payload_buf = vec![0u8; length];
136 reader.read_exact(&mut payload_buf).await?;
137
138 let json_str = String::from_utf8(payload_buf).map_err(|e| {
140 io::Error::new(
141 io::ErrorKind::InvalidData,
142 format!("UTF-8 decode error: {}", e),
143 )
144 })?;
145
146 let message = JsonRpcMessage::from_json(&json_str).map_err(|e| {
147 io::Error::new(
148 io::ErrorKind::InvalidData,
149 format!("JSON parse error: {}", e),
150 )
151 })?;
152
153 Ok(Some(message))
154 }
155}
156
157#[async_trait]
158impl Transport for TcpTransport {
159 async fn send(&mut self, message: &JsonRpcMessage) -> io::Result<()> {
160 if self.is_closed {
161 return Err(io::Error::new(
162 io::ErrorKind::BrokenPipe,
163 "Transport is closed",
164 ));
165 }
166
167 let writer_guard = self.writer.lock().await;
168 let _writer = writer_guard.as_ref().ok_or_else(|| {
169 io::Error::new(io::ErrorKind::BrokenPipe, "No stream available")
170 })?;
171
172 let frame = Self::encode_frame(message)?;
173
174 let result = timeout(
177 Duration::from_millis(self.config.write_timeout_ms),
178 async {
179 drop(writer_guard);
181 let mut writer_guard = self.writer.lock().await;
182 let writer = writer_guard.as_mut().ok_or_else(|| {
183 io::Error::new(io::ErrorKind::BrokenPipe, "No stream available")
184 })?;
185 writer.write_all(&frame).await
186 }
187 )
188 .await;
189
190 match result {
191 Ok(Ok(_)) => Ok(()),
192 Ok(Err(e)) => Err(e),
193 Err(_) => Err(io::Error::new(
194 io::ErrorKind::TimedOut,
195 "Write timeout",
196 )),
197 }
198 }
199
200 async fn receive(&mut self) -> io::Result<Option<JsonRpcMessage>> {
201 if self.is_closed {
202 return Ok(None);
203 }
204
205 let read_result = timeout(
207 Duration::from_millis(self.config.read_timeout_ms),
208 async {
209 let mut reader_guard = self.reader.lock().await;
210 let reader = reader_guard.as_mut().ok_or_else(|| {
211 io::Error::new(io::ErrorKind::BrokenPipe, "No stream available")
212 })?;
213 Self::decode_frame(reader, self.config.max_message_size).await
214 }
215 )
216 .await;
217
218 match read_result {
219 Ok(Ok(message)) => Ok(message),
220 Ok(Err(e)) => Err(e),
221 Err(_) => Err(io::Error::new(
222 io::ErrorKind::TimedOut,
223 "Read timeout",
224 )),
225 }
226 }
227
228 async fn close(&mut self) -> io::Result<()> {
229 if self.is_closed {
230 return Ok(());
231 }
232
233 self.is_closed = true;
234
235 let mut reader_guard = self.reader.lock().await;
237 let mut writer_guard = self.writer.lock().await;
238 reader_guard.take();
239 writer_guard.take();
240
241 Ok(())
242 }
243
244 fn is_closed(&self) -> bool {
245 self.is_closed
246 }
247}
248
249pub struct TcpListener {
254 listener: TokioTcpListener,
256 local_addr: SocketAddr,
258 config: TransportConfig,
260}
261
262impl TcpListener {
263 pub async fn bind(port: u16) -> io::Result<Self> {
265 Self::bind_with_config(port, TransportConfig::default()).await
266 }
267
268 pub async fn bind_with_config(port: u16, config: TransportConfig) -> io::Result<Self> {
270 let addr: SocketAddr = format!("127.0.0.1:{}", port).parse().unwrap();
271 let listener = TokioTcpListener::bind(addr).await?;
272 let local_addr = listener.local_addr()?;
273
274 Ok(Self {
275 listener,
276 local_addr,
277 config,
278 })
279 }
280
281 pub fn local_addr(&self) -> SocketAddr {
283 self.local_addr
284 }
285
286 pub async fn accept(&self) -> io::Result<TcpTransport> {
290 let (stream, _addr) = self.listener.accept().await?;
291 Ok(TcpTransport::from_stream(stream, self.config.clone()))
292 }
293
294 pub fn port(&self) -> u16 {
296 self.local_addr.port()
297 }
298}
299
/// Fixed TCP port for the registry endpoint. The role is inferred from the
/// name; confirm it against the code that binds or connects to this port.
pub const REGISTRY_PORT: u16 = 9527;

/// Fixed TCP port for the callback endpoint (9528, directly after [`REGISTRY_PORT`]).
pub const CALLBACK_PORT: u16 = REGISTRY_PORT + 1;

#[cfg(test)]
mod tests {
    use super::*;
    use crate::matrixrpc::protocol::{JsonRpcId, JsonRpcRequest};

    /// A small request must encode to a header plus exactly the declared number of payload bytes.
    #[test]
    fn test_encode_frame_simple() {
        let id = JsonRpcId::String("test-1".to_string());
        let request = JsonRpcRequest::with_id("test.method", id)
            .params(serde_json::json!({"param": "value"}));

        let frame = TcpTransport::encode_frame(&JsonRpcMessage::Request(request)).unwrap();
        assert!(frame.len() > FRAME_HEADER_SIZE);

        let (header, payload) = frame.split_at(FRAME_HEADER_SIZE);
        let declared = u32::from_be_bytes([header[0], header[1], header[2], header[3]]);
        assert!(declared > 0);
        assert_eq!(payload.len(), declared as usize);
    }

    /// Builder setters must land in the corresponding config fields.
    #[test]
    fn test_tcp_config() {
        let config = TransportConfig::new()
            .max_message_size(1024)
            .read_timeout(5000);

        assert_eq!(
            (config.max_message_size, config.read_timeout_ms),
            (1024, 5000)
        );
    }

    /// The frame header is a single `u32` length prefix.
    #[test]
    fn test_frame_header_size() {
        assert_eq!(FRAME_HEADER_SIZE, std::mem::size_of::<u32>());
    }

    /// The well-known ports are part of the wire contract and must not drift.
    #[test]
    fn test_port_constants() {
        assert_eq!((REGISTRY_PORT, CALLBACK_PORT), (9527, 9528));
    }
}