agent_tui/daemon/transport/unix_socket.rs

1use std::io::{BufRead, BufReader, Write};
2use std::os::unix::net::{UnixListener, UnixStream};
3use std::path::Path;
4use std::time::Duration;
5
6use crate::ipc::{RpcRequest, RpcResponse};
7
8use super::{TransportConnection, TransportError, TransportListener};
9
/// Byte budget (1 MiB) given to each connection's `SizeLimitedReader`.
///
/// The reader counts bytes cumulatively across every line it reads, so this
/// caps the total request data accepted over one connection, not the size of
/// a single request.
/// NOTE(review): the name suggests a per-request limit — confirm the
/// cumulative behaviour is intended for long-lived connections.
const MAX_REQUEST_SIZE: usize = 1 << 20;
11
12struct SizeLimitedReader<R> {
13    inner: R,
14    max_size: usize,
15    read_count: usize,
16}
17
18impl<R> SizeLimitedReader<R> {
19    fn new(inner: R, max_size: usize) -> Self {
20        Self {
21            inner,
22            max_size,
23            read_count: 0,
24        }
25    }
26}
27
28impl<R: BufRead> SizeLimitedReader<R> {
29    fn read_line(&mut self) -> Result<Option<String>, TransportError> {
30        let mut line = String::new();
31        match self.inner.read_line(&mut line) {
32            Ok(0) => Ok(None),
33            Ok(n) => {
34                self.read_count += n;
35                if self.read_count > self.max_size {
36                    return Err(TransportError::SizeLimit {
37                        max_bytes: self.max_size,
38                    });
39                }
40                if line.ends_with('\n') {
41                    line.pop();
42                    if line.ends_with('\r') {
43                        line.pop();
44                    }
45                }
46                Ok(Some(line))
47            }
48            Err(e) => Err(TransportError::from(e)),
49        }
50    }
51}
52
53pub struct UnixSocketConnection {
54    reader: SizeLimitedReader<BufReader<UnixStream>>,
55    writer: UnixStream,
56}
57
58impl UnixSocketConnection {
59    pub fn new(stream: UnixStream) -> Result<Self, TransportError> {
60        let reader_stream = stream.try_clone()?;
61        Ok(Self {
62            reader: SizeLimitedReader::new(BufReader::new(reader_stream), MAX_REQUEST_SIZE),
63            writer: stream,
64        })
65    }
66}
67
68impl TransportConnection for UnixSocketConnection {
69    fn read_request(&mut self) -> Result<RpcRequest, TransportError> {
70        loop {
71            match self.reader.read_line()? {
72                None => return Err(TransportError::ConnectionClosed),
73                Some(line) if line.trim().is_empty() => continue,
74                Some(line) => {
75                    return serde_json::from_str(&line)
76                        .map_err(|e| TransportError::Parse(e.to_string()));
77                }
78            }
79        }
80    }
81
82    fn write_response(&mut self, response: &RpcResponse) -> Result<(), TransportError> {
83        let json = serde_json::to_string(response)
84            .map_err(|e| TransportError::Parse(format!("Failed to serialize response: {}", e)))?;
85        writeln!(self.writer, "{}", json)?;
86        Ok(())
87    }
88
89    fn set_read_timeout(&mut self, timeout: Option<Duration>) -> Result<(), TransportError> {
90        self.writer.set_read_timeout(timeout)?;
91        Ok(())
92    }
93
94    fn set_write_timeout(&mut self, timeout: Option<Duration>) -> Result<(), TransportError> {
95        self.writer.set_write_timeout(timeout)?;
96        Ok(())
97    }
98}
99
100pub struct UnixSocketListener {
101    inner: UnixListener,
102}
103
104impl UnixSocketListener {
105    pub fn bind(path: &Path) -> Result<Self, TransportError> {
106        let listener = UnixListener::bind(path)?;
107        Ok(Self { inner: listener })
108    }
109
110    pub fn into_inner(self) -> UnixListener {
111        self.inner
112    }
113}
114
115impl TransportListener for UnixSocketListener {
116    type Connection = UnixSocketConnection;
117
118    fn accept(&self) -> Result<Self::Connection, TransportError> {
119        let (stream, _addr) = self.inner.accept()?;
120        UnixSocketConnection::new(stream)
121    }
122
123    fn set_nonblocking(&self, nonblocking: bool) -> Result<(), TransportError> {
124        self.inner.set_nonblocking(nonblocking)?;
125        Ok(())
126    }
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    #[test]
    fn test_size_limited_reader_within_limit() {
        let data = "hello\nworld\n";
        let cursor = Cursor::new(data);
        let buf_reader = BufReader::new(cursor);
        let mut reader = SizeLimitedReader::new(buf_reader, 100);

        assert_eq!(reader.read_line().unwrap(), Some("hello".to_string()));
        assert_eq!(reader.read_line().unwrap(), Some("world".to_string()));
        assert_eq!(reader.read_line().unwrap(), None);
    }

    #[test]
    fn test_size_limited_reader_exceeds_limit() {
        let data = "this is a long line that exceeds the limit\n";
        let cursor = Cursor::new(data);
        let buf_reader = BufReader::new(cursor);
        let mut reader = SizeLimitedReader::new(buf_reader, 10);

        let result = reader.read_line();
        assert!(matches!(result, Err(TransportError::SizeLimit { .. })));
    }

    #[test]
    fn test_size_limited_reader_strips_newlines() {
        let data = "line with crlf\r\n";
        let cursor = Cursor::new(data);
        let buf_reader = BufReader::new(cursor);
        let mut reader = SizeLimitedReader::new(buf_reader, 100);

        assert_eq!(
            reader.read_line().unwrap(),
            Some("line with crlf".to_string())
        );
    }

    #[test]
    fn test_transport_error_display() {
        let io_err = TransportError::Io(std::io::Error::new(
            std::io::ErrorKind::NotFound,
            "test error",
        ));
        assert!(io_err.to_string().contains("I/O error"));

        let parse_err = TransportError::Parse("invalid json".to_string());
        assert!(parse_err.to_string().contains("Parse error"));

        let size_err = TransportError::SizeLimit { max_bytes: 1024 };
        assert!(size_err.to_string().contains("1024"));

        let timeout_err = TransportError::Timeout;
        assert_eq!(timeout_err.to_string(), "Connection timeout");

        let closed_err = TransportError::ConnectionClosed;
        assert_eq!(closed_err.to_string(), "Connection closed");
    }

    #[test]
    fn test_transport_error_from_io() {
        let timeout = std::io::Error::new(std::io::ErrorKind::TimedOut, "timed out");
        assert!(matches!(
            TransportError::from(timeout),
            TransportError::Timeout
        ));

        let would_block = std::io::Error::new(std::io::ErrorKind::WouldBlock, "would block");
        assert!(matches!(
            TransportError::from(would_block),
            TransportError::Timeout
        ));

        let eof = std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "eof");
        assert!(matches!(
            TransportError::from(eof),
            TransportError::ConnectionClosed
        ));

        let broken_pipe = std::io::Error::new(std::io::ErrorKind::BrokenPipe, "broken");
        assert!(matches!(
            TransportError::from(broken_pipe),
            TransportError::ConnectionClosed
        ));

        let other = std::io::Error::other("other");
        assert!(matches!(TransportError::from(other), TransportError::Io(_)));
    }

    #[test]
    fn test_unix_socket_roundtrip() {
        use std::thread;

        let (client_stream, server_stream) = UnixStream::pair().unwrap();

        let server_handle = thread::spawn(move || {
            let mut conn = UnixSocketConnection::new(server_stream).unwrap();
            let request = conn.read_request().unwrap();
            assert_eq!(request.method, "test_method");

            let response = RpcResponse::success(request.id, serde_json::json!({"ok": true}));
            conn.write_response(&response).unwrap();
        });

        let mut client_writer = client_stream.try_clone().unwrap();
        let mut client_reader = BufReader::new(client_stream);

        let request_json = r#"{"jsonrpc":"2.0","id":1,"method":"test_method"}"#;
        writeln!(client_writer, "{}", request_json).unwrap();

        // Read the raw response frame; going through `read_request` would try
        // to parse the response as a request.
        let mut response_line = String::new();
        client_reader.read_line(&mut response_line).unwrap();

        // Join first so a server-side assertion failure is reported directly
        // instead of surfacing as a confusing empty response.
        server_handle.join().unwrap();

        assert!(
            response_line.ends_with('\n'),
            "response frame must be newline-terminated"
        );
        let response: serde_json::Value =
            serde_json::from_str(response_line.trim_end()).unwrap();
        assert!(response.is_object());
        assert!(response_line.contains(r#""ok":true"#));
    }

    #[test]
    fn test_read_request_skips_blank_lines() {
        let (mut client_stream, server_stream) = UnixStream::pair().unwrap();
        write!(
            client_stream,
            "\n  \r\n{}\n",
            r#"{"jsonrpc":"2.0","id":1,"method":"after_blank"}"#
        )
        .unwrap();

        let mut conn = UnixSocketConnection::new(server_stream).unwrap();
        let request = conn.read_request().unwrap();
        assert_eq!(request.method, "after_blank");
    }

    #[test]
    fn test_read_request_invalid_json() {
        let (mut client_stream, server_stream) = UnixStream::pair().unwrap();
        writeln!(client_stream, "not json").unwrap();

        let mut conn = UnixSocketConnection::new(server_stream).unwrap();
        assert!(matches!(
            conn.read_request(),
            Err(TransportError::Parse(_))
        ));
    }

    #[test]
    fn test_unix_socket_connection_closed() {
        let (client_stream, server_stream) = UnixStream::pair().unwrap();
        drop(server_stream);

        let mut conn = UnixSocketConnection::new(client_stream).unwrap();
        let result = conn.read_request();
        assert!(matches!(result, Err(TransportError::ConnectionClosed)));
    }

    #[test]
    fn test_size_limited_reader_cumulative_limit() {
        let data = "aaa\nbbb\nccc\n";
        let cursor = Cursor::new(data);
        let buf_reader = BufReader::new(cursor);
        let mut reader = SizeLimitedReader::new(buf_reader, 8);

        assert_eq!(reader.read_line().unwrap(), Some("aaa".to_string()));
        assert_eq!(reader.read_line().unwrap(), Some("bbb".to_string()));
        let result = reader.read_line();
        assert!(matches!(result, Err(TransportError::SizeLimit { .. })));
    }
}