agent_tui/daemon/transport/
unix_socket.rs1use std::io::{BufRead, BufReader, Write};
2use std::os::unix::net::{UnixListener, UnixStream};
3use std::path::Path;
4use std::time::Duration;
5
6use crate::ipc::{RpcRequest, RpcResponse};
7
8use super::{TransportConnection, TransportError, TransportListener};
9
/// Byte budget (1 MiB) handed to the size-limited reader of every connection.
const MAX_REQUEST_SIZE: usize = 1 << 20;
11
/// Line reader that keeps a running total of the bytes it has consumed and
/// refuses to go past a fixed budget.
struct SizeLimitedReader<R> {
    /// Underlying source the lines are pulled from.
    inner: R,
    /// Maximum number of bytes that may be consumed in total.
    max_size: usize,
    /// Bytes consumed so far; starts at zero and only ever grows.
    read_count: usize,
}
17
18impl<R> SizeLimitedReader<R> {
19 fn new(inner: R, max_size: usize) -> Self {
20 Self {
21 inner,
22 max_size,
23 read_count: 0,
24 }
25 }
26}
27
28impl<R: BufRead> SizeLimitedReader<R> {
29 fn read_line(&mut self) -> Result<Option<String>, TransportError> {
30 let mut line = String::new();
31 match self.inner.read_line(&mut line) {
32 Ok(0) => Ok(None),
33 Ok(n) => {
34 self.read_count += n;
35 if self.read_count > self.max_size {
36 return Err(TransportError::SizeLimit {
37 max_bytes: self.max_size,
38 });
39 }
40 if line.ends_with('\n') {
41 line.pop();
42 if line.ends_with('\r') {
43 line.pop();
44 }
45 }
46 Ok(Some(line))
47 }
48 Err(e) => Err(TransportError::from(e)),
49 }
50 }
51}
52
53pub struct UnixSocketConnection {
54 reader: SizeLimitedReader<BufReader<UnixStream>>,
55 writer: UnixStream,
56}
57
58impl UnixSocketConnection {
59 pub fn new(stream: UnixStream) -> Result<Self, TransportError> {
60 let reader_stream = stream.try_clone()?;
61 Ok(Self {
62 reader: SizeLimitedReader::new(BufReader::new(reader_stream), MAX_REQUEST_SIZE),
63 writer: stream,
64 })
65 }
66}
67
68impl TransportConnection for UnixSocketConnection {
69 fn read_request(&mut self) -> Result<RpcRequest, TransportError> {
70 loop {
71 match self.reader.read_line()? {
72 None => return Err(TransportError::ConnectionClosed),
73 Some(line) if line.trim().is_empty() => continue,
74 Some(line) => {
75 return serde_json::from_str(&line)
76 .map_err(|e| TransportError::Parse(e.to_string()));
77 }
78 }
79 }
80 }
81
82 fn write_response(&mut self, response: &RpcResponse) -> Result<(), TransportError> {
83 let json = serde_json::to_string(response)
84 .map_err(|e| TransportError::Parse(format!("Failed to serialize response: {}", e)))?;
85 writeln!(self.writer, "{}", json)?;
86 Ok(())
87 }
88
89 fn set_read_timeout(&mut self, timeout: Option<Duration>) -> Result<(), TransportError> {
90 self.writer.set_read_timeout(timeout)?;
91 Ok(())
92 }
93
94 fn set_write_timeout(&mut self, timeout: Option<Duration>) -> Result<(), TransportError> {
95 self.writer.set_write_timeout(timeout)?;
96 Ok(())
97 }
98}
99
/// Transport listener backed by a standard-library [`UnixListener`].
pub struct UnixSocketListener {
    /// The wrapped listening socket.
    inner: UnixListener,
}
103
104impl UnixSocketListener {
105 pub fn bind(path: &Path) -> Result<Self, TransportError> {
106 let listener = UnixListener::bind(path)?;
107 Ok(Self { inner: listener })
108 }
109
110 pub fn into_inner(self) -> UnixListener {
111 self.inner
112 }
113}
114
115impl TransportListener for UnixSocketListener {
116 type Connection = UnixSocketConnection;
117
118 fn accept(&self) -> Result<Self::Connection, TransportError> {
119 let (stream, _addr) = self.inner.accept()?;
120 UnixSocketConnection::new(stream)
121 }
122
123 fn set_nonblocking(&self, nonblocking: bool) -> Result<(), TransportError> {
124 self.inner.set_nonblocking(nonblocking)?;
125 Ok(())
126 }
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;
    use std::os::unix::net::UnixStream;
    use std::thread;

    /// Builds a size-limited reader over an in-memory string.
    fn limited(
        data: &'static str,
        max_size: usize,
    ) -> SizeLimitedReader<BufReader<Cursor<&'static str>>> {
        SizeLimitedReader::new(BufReader::new(Cursor::new(data)), max_size)
    }

    #[test]
    fn test_size_limited_reader_within_limit() {
        let mut reader = limited("hello\nworld\n", 100);

        assert_eq!(reader.read_line().unwrap(), Some("hello".to_string()));
        assert_eq!(reader.read_line().unwrap(), Some("world".to_string()));
        assert_eq!(reader.read_line().unwrap(), None);
    }

    #[test]
    fn test_size_limited_reader_exceeds_limit() {
        let mut reader = limited("this is a long line that exceeds the limit\n", 10);

        let outcome = reader.read_line();
        assert!(matches!(outcome, Err(TransportError::SizeLimit { .. })));
    }

    #[test]
    fn test_size_limited_reader_strips_newlines() {
        let mut reader = limited("line with crlf\r\n", 100);

        assert_eq!(
            reader.read_line().unwrap(),
            Some("line with crlf".to_string())
        );
    }

    #[test]
    fn test_transport_error_display() {
        let io_err = TransportError::Io(std::io::Error::new(
            std::io::ErrorKind::NotFound,
            "test error",
        ));
        assert!(io_err.to_string().contains("I/O error"));

        let parse_err = TransportError::Parse("invalid json".to_string());
        assert!(parse_err.to_string().contains("Parse error"));

        let size_err = TransportError::SizeLimit { max_bytes: 1024 };
        assert!(size_err.to_string().contains("1024"));

        assert_eq!(TransportError::Timeout.to_string(), "Connection timeout");
        assert_eq!(
            TransportError::ConnectionClosed.to_string(),
            "Connection closed"
        );
    }

    #[test]
    fn test_transport_error_from_io() {
        use std::io::{Error, ErrorKind};

        // Both timeout-like kinds collapse into `Timeout`.
        let timed_out = TransportError::from(Error::new(ErrorKind::TimedOut, "timed out"));
        assert!(matches!(timed_out, TransportError::Timeout));

        let would_block = TransportError::from(Error::new(ErrorKind::WouldBlock, "would block"));
        assert!(matches!(would_block, TransportError::Timeout));

        // Both peer-gone kinds collapse into `ConnectionClosed`.
        let eof = TransportError::from(Error::new(ErrorKind::UnexpectedEof, "eof"));
        assert!(matches!(eof, TransportError::ConnectionClosed));

        let broken_pipe = TransportError::from(Error::new(ErrorKind::BrokenPipe, "broken"));
        assert!(matches!(broken_pipe, TransportError::ConnectionClosed));

        // Anything else is carried through as a plain I/O error.
        let other = TransportError::from(Error::other("other"));
        assert!(matches!(other, TransportError::Io(_)));
    }

    #[test]
    fn test_unix_socket_roundtrip() {
        let (client_end, server_end) = UnixStream::pair().unwrap();

        let server = thread::spawn(move || {
            let mut conn = UnixSocketConnection::new(server_end).unwrap();
            let request = conn.read_request().unwrap();
            assert_eq!(request.method, "test_method");

            let reply = RpcResponse::success(request.id, serde_json::json!({"ok": true}));
            conn.write_response(&reply).unwrap();
        });

        let mut raw_writer = client_end.try_clone().unwrap();
        let mut client_conn = UnixSocketConnection::new(client_end).unwrap();

        let request_json = r#"{"jsonrpc":"2.0","id":1,"method":"test_method"}"#;
        writeln!(raw_writer, "{}", request_json).unwrap();

        // The client side reads the server's reply through `read_request`, so
        // either a successful parse or a parse error is accepted here.
        let outcome = client_conn.read_request();
        assert!(outcome.is_ok() || matches!(outcome, Err(TransportError::Parse(_))));

        server.join().unwrap();
    }

    #[test]
    fn test_unix_socket_connection_closed() {
        let (client_end, server_end) = UnixStream::pair().unwrap();
        drop(server_end);

        let mut conn = UnixSocketConnection::new(client_end).unwrap();
        let outcome = conn.read_request();
        assert!(matches!(outcome, Err(TransportError::ConnectionClosed)));
    }

    #[test]
    fn test_size_limited_reader_cumulative_limit() {
        // Eight bytes of budget cover exactly the first two 4-byte lines.
        let mut reader = limited("aaa\nbbb\nccc\n", 8);

        assert_eq!(reader.read_line().unwrap(), Some("aaa".to_string()));
        assert_eq!(reader.read_line().unwrap(), Some("bbb".to_string()));
        let outcome = reader.read_line();
        assert!(matches!(outcome, Err(TransportError::SizeLimit { .. })));
    }
}