1use super::traits::{Transport, TransportError, TransportReader, TransportWriter};
8use crate::protocol::{Request, Response};
9use tokio::sync::mpsc;
10
11pub fn channel() -> (InProcessClient, InProcessServer) {
29 let (request_tx, request_rx) = mpsc::unbounded_channel();
31 let (response_tx, response_rx) = mpsc::unbounded_channel();
33
34 let client = InProcessClient { request_tx, response_rx, closed: false };
35
36 let server = InProcessServer { request_rx, response_tx, closed: false };
37
38 (client, server)
39}
40
41#[derive(Debug)]
45pub struct InProcessClient {
46 request_tx: mpsc::UnboundedSender<Request>,
47 response_rx: mpsc::UnboundedReceiver<Response>,
48 closed: bool,
49}
50
51#[async_trait::async_trait]
52impl Transport for InProcessClient {
53 async fn send_request(&mut self, request: &Request) -> Result<(), TransportError> {
54 if self.closed {
55 return Err(TransportError::ConnectionClosed);
56 }
57 self.request_tx.send(request.clone()).map_err(|_| TransportError::ConnectionClosed)
58 }
59
60 async fn send_response(&mut self, _response: &Response) -> Result<(), TransportError> {
61 Err(TransportError::Io(std::io::Error::new(
63 std::io::ErrorKind::Unsupported,
64 "Clients cannot send responses",
65 )))
66 }
67
68 async fn recv_request(&mut self) -> Result<Request, TransportError> {
69 Err(TransportError::Io(std::io::Error::new(
71 std::io::ErrorKind::Unsupported,
72 "Clients cannot receive requests",
73 )))
74 }
75
76 async fn recv_response(&mut self) -> Result<Response, TransportError> {
77 if self.closed {
78 return Err(TransportError::ConnectionClosed);
79 }
80 self.response_rx.recv().await.ok_or(TransportError::ConnectionClosed)
81 }
82
83 async fn close(&mut self) -> Result<(), TransportError> {
84 self.closed = true;
85 Ok(())
86 }
87}
88
89#[derive(Debug)]
93pub struct InProcessServer {
94 request_rx: mpsc::UnboundedReceiver<Request>,
95 response_tx: mpsc::UnboundedSender<Response>,
96 closed: bool,
97}
98
99#[async_trait::async_trait]
100impl Transport for InProcessServer {
101 async fn send_request(&mut self, _request: &Request) -> Result<(), TransportError> {
102 Err(TransportError::Io(std::io::Error::new(
104 std::io::ErrorKind::Unsupported,
105 "Servers cannot send requests",
106 )))
107 }
108
109 async fn send_response(&mut self, response: &Response) -> Result<(), TransportError> {
110 if self.closed {
111 return Err(TransportError::ConnectionClosed);
112 }
113 self.response_tx.send(response.clone()).map_err(|_| TransportError::ConnectionClosed)
114 }
115
116 async fn recv_request(&mut self) -> Result<Request, TransportError> {
117 if self.closed {
118 return Err(TransportError::ConnectionClosed);
119 }
120 self.request_rx.recv().await.ok_or(TransportError::ConnectionClosed)
121 }
122
123 async fn recv_response(&mut self) -> Result<Response, TransportError> {
124 Err(TransportError::Io(std::io::Error::new(
126 std::io::ErrorKind::Unsupported,
127 "Servers cannot receive responses",
128 )))
129 }
130
131 async fn close(&mut self) -> Result<(), TransportError> {
132 self.closed = true;
133 Ok(())
134 }
135}
136
137pub struct InProcessClientReader {
139 response_rx: mpsc::UnboundedReceiver<Response>,
140}
141
142pub struct InProcessClientWriter {
143 request_tx: mpsc::UnboundedSender<Request>,
144}
145
146#[async_trait::async_trait]
147impl TransportReader for InProcessClientReader {
148 async fn recv_request(&mut self) -> Result<Request, TransportError> {
149 Err(TransportError::Io(std::io::Error::new(
150 std::io::ErrorKind::Unsupported,
151 "Client reader cannot receive requests",
152 )))
153 }
154
155 async fn recv_response(&mut self) -> Result<Response, TransportError> {
156 self.response_rx.recv().await.ok_or(TransportError::ConnectionClosed)
157 }
158}
159
160#[async_trait::async_trait]
161impl TransportWriter for InProcessClientWriter {
162 async fn send_request(&mut self, request: &Request) -> Result<(), TransportError> {
163 self.request_tx.send(request.clone()).map_err(|_| TransportError::ConnectionClosed)
164 }
165
166 async fn send_response(&mut self, _response: &Response) -> Result<(), TransportError> {
167 Err(TransportError::Io(std::io::Error::new(
168 std::io::ErrorKind::Unsupported,
169 "Client writer cannot send responses",
170 )))
171 }
172
173 async fn flush(&mut self) -> Result<(), TransportError> {
174 Ok(())
176 }
177}
178
179pub struct InProcessServerReader {
181 request_rx: mpsc::UnboundedReceiver<Request>,
182}
183
184pub struct InProcessServerWriter {
185 response_tx: mpsc::UnboundedSender<Response>,
186}
187
188#[async_trait::async_trait]
189impl TransportReader for InProcessServerReader {
190 async fn recv_request(&mut self) -> Result<Request, TransportError> {
191 self.request_rx.recv().await.ok_or(TransportError::ConnectionClosed)
192 }
193
194 async fn recv_response(&mut self) -> Result<Response, TransportError> {
195 Err(TransportError::Io(std::io::Error::new(
196 std::io::ErrorKind::Unsupported,
197 "Server reader cannot receive responses",
198 )))
199 }
200}
201
202#[async_trait::async_trait]
203impl TransportWriter for InProcessServerWriter {
204 async fn send_request(&mut self, _request: &Request) -> Result<(), TransportError> {
205 Err(TransportError::Io(std::io::Error::new(
206 std::io::ErrorKind::Unsupported,
207 "Server writer cannot send requests",
208 )))
209 }
210
211 async fn send_response(&mut self, response: &Response) -> Result<(), TransportError> {
212 self.response_tx.send(response.clone()).map_err(|_| TransportError::ConnectionClosed)
213 }
214
215 async fn flush(&mut self) -> Result<(), TransportError> {
216 Ok(())
218 }
219}
220
221impl InProcessClient {
222 pub fn split(self) -> (InProcessClientReader, InProcessClientWriter) {
224 (
225 InProcessClientReader { response_rx: self.response_rx },
226 InProcessClientWriter { request_tx: self.request_tx },
227 )
228 }
229}
230
231impl InProcessServer {
232 pub fn split(self) -> (InProcessServerReader, InProcessServerWriter) {
234 (
235 InProcessServerReader { request_rx: self.request_rx },
236 InProcessServerWriter { response_tx: self.response_tx },
237 )
238 }
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::{MessageId, Operation, OperationResult, ReplMode, SessionId, Status};

    /// Builds a Lisp `Eval` request.
    fn eval_request(id: MessageId, session_id: SessionId, code: impl Into<String>) -> Request {
        Request {
            id,
            session_id,
            operation: Operation::Eval { code: code.into(), mode: ReplMode::Lisp },
        }
    }

    /// Builds an `LsSessions` request (an operation with no payload).
    fn ls_sessions_request(id: MessageId, session_id: SessionId) -> Request {
        Request { id, session_id, operation: Operation::LsSessions }
    }

    /// Builds a successful response answering `request_id`, optionally carrying `value`.
    fn success_response(
        request_id: MessageId,
        session_id: SessionId,
        value: Option<&str>,
    ) -> Response {
        Response {
            request_id,
            session_id,
            result: OperationResult::Success {
                status: Status { tier: 1, cached: false, duration_ms: 0 },
                value: value.map(str::to_string),
                stdout: None,
                stderr: None,
            },
        }
    }

    /// True when `result` is the `Io(Unsupported)` error the transports return for calls
    /// made in the wrong direction (e.g. a client sending a response).
    fn is_unsupported<T>(result: &Result<T, TransportError>) -> bool {
        matches!(
            result,
            Err(TransportError::Io(err)) if err.kind() == std::io::ErrorKind::Unsupported
        )
    }

    #[tokio::test]
    async fn test_channel_creation() {
        let (_client, _server) = channel();
    }

    #[tokio::test]
    async fn test_send_request_recv_request() {
        let (mut client, mut server) = channel();

        let request = eval_request(MessageId::new(1), SessionId::new("test-session"), "(+ 1 2)");
        client.send_request(&request).await.unwrap();

        let received = server.recv_request().await.unwrap();
        assert_eq!(received.id, MessageId::new(1));
        assert_eq!(received.session_id, SessionId::new("test-session"));
    }

    #[tokio::test]
    async fn test_send_response_recv_response() {
        let (mut client, mut server) = channel();

        let response =
            success_response(MessageId::new(42), SessionId::new("test-session"), Some("3"));
        server.send_response(&response).await.unwrap();

        let received = client.recv_response().await.unwrap();
        assert_eq!(received.request_id, MessageId::new(42));
    }

    #[tokio::test]
    async fn test_bidirectional_communication() {
        let (mut client, mut server) = channel();

        let request = eval_request(MessageId::new(1), SessionId::new("session-1"), "(+ 2 3)");
        let response = success_response(MessageId::new(1), SessionId::new("session-1"), Some("5"));

        client.send_request(&request).await.unwrap();
        let recv_request = server.recv_request().await.unwrap();
        assert_eq!(recv_request.id, MessageId::new(1));

        server.send_response(&response).await.unwrap();
        let recv_response = client.recv_response().await.unwrap();
        assert_eq!(recv_response.request_id, MessageId::new(1));
        match recv_response.result {
            OperationResult::Success { value, .. } => assert_eq!(value, Some("5".to_string())),
            _ => panic!("Expected success result"),
        }
    }

    #[tokio::test]
    async fn test_multiple_requests() {
        let (mut client, mut server) = channel();

        for i in 1..=5 {
            let request =
                eval_request(MessageId::new(i), SessionId::new("test"), format!("(+ {} {})", i, i));
            client.send_request(&request).await.unwrap();
        }

        // Channels are FIFO, so requests arrive in the order they were sent.
        for i in 1..=5 {
            let received = server.recv_request().await.unwrap();
            assert_eq!(received.id, MessageId::new(i));
        }
    }

    #[tokio::test]
    async fn test_connection_closed_on_drop() {
        let (client, mut server) = channel();
        drop(client);

        let result = server.recv_request().await;
        assert!(matches!(result, Err(TransportError::ConnectionClosed)));
    }

    #[tokio::test]
    async fn test_client_send_fails_after_server_dropped() {
        let (mut client, server) = channel();
        drop(server);

        let request = ls_sessions_request(MessageId::new(1), SessionId::new("test"));
        let result = client.send_request(&request).await;
        assert!(matches!(result, Err(TransportError::ConnectionClosed)));
    }

    #[tokio::test]
    async fn test_server_send_fails_after_client_dropped() {
        let (client, mut server) = channel();
        drop(client);

        let response = success_response(MessageId::new(1), SessionId::new("test"), None);
        let result = server.send_response(&response).await;
        assert!(matches!(result, Err(TransportError::ConnectionClosed)));
    }

    #[tokio::test]
    async fn test_client_cannot_send_response() {
        let (mut client, _server) = channel();

        let response = success_response(MessageId::new(1), SessionId::new("test"), None);
        let result = client.send_response(&response).await;
        assert!(is_unsupported(&result));
    }

    #[tokio::test]
    async fn test_client_cannot_recv_request() {
        let (mut client, _server) = channel();
        assert!(is_unsupported(&client.recv_request().await));
    }

    #[tokio::test]
    async fn test_server_cannot_send_request() {
        let (_client, mut server) = channel();

        let request = ls_sessions_request(MessageId::new(1), SessionId::new("test"));
        let result = server.send_request(&request).await;
        assert!(is_unsupported(&result));
    }

    #[tokio::test]
    async fn test_server_cannot_recv_response() {
        let (_client, mut server) = channel();
        assert!(is_unsupported(&server.recv_response().await));
    }

    #[tokio::test]
    async fn test_close() {
        let (mut client, mut server) = channel();

        client.close().await.unwrap();
        let request = ls_sessions_request(MessageId::new(1), SessionId::new("test"));
        let result = client.send_request(&request).await;
        assert!(matches!(result, Err(TransportError::ConnectionClosed)));

        server.close().await.unwrap();
        let result = server.recv_request().await;
        assert!(matches!(result, Err(TransportError::ConnectionClosed)));
    }

    #[tokio::test]
    async fn test_split_client() {
        let (client, mut server) = channel();
        let (mut reader, mut writer) = client.split();

        let request = ls_sessions_request(MessageId::new(99), SessionId::new("split-test"));
        let response = success_response(MessageId::new(99), SessionId::new("split-test"), None);

        writer.send_request(&request).await.unwrap();
        let _ = server.recv_request().await.unwrap();
        server.send_response(&response).await.unwrap();

        let recv = reader.recv_response().await.unwrap();
        assert_eq!(recv.request_id, MessageId::new(99));
    }

    #[tokio::test]
    async fn test_split_server() {
        let (mut client, server) = channel();
        let (mut reader, mut writer) = server.split();

        let request = Request {
            id: MessageId::new(77),
            session_id: SessionId::new("split-server-test"),
            operation: Operation::CreateSession { mode: ReplMode::Lisp },
        };
        let response =
            success_response(MessageId::new(77), SessionId::new("split-server-test"), None);

        client.send_request(&request).await.unwrap();
        let recv_req = reader.recv_request().await.unwrap();
        assert_eq!(recv_req.id, MessageId::new(77));

        writer.send_response(&response).await.unwrap();
        let recv_resp = client.recv_response().await.unwrap();
        assert_eq!(recv_resp.request_id, MessageId::new(77));
    }

    #[tokio::test]
    async fn test_split_halves_reject_wrong_direction() {
        let (client, server) = channel();
        let (mut client_reader, mut client_writer) = client.split();
        let (mut server_reader, mut server_writer) = server.split();

        let request = ls_sessions_request(MessageId::new(1), SessionId::new("test"));
        let response = success_response(MessageId::new(1), SessionId::new("test"), None);

        assert!(is_unsupported(&client_reader.recv_request().await));
        assert!(is_unsupported(&client_writer.send_response(&response).await));
        assert!(is_unsupported(&server_reader.recv_response().await));
        assert!(is_unsupported(&server_writer.send_request(&request).await));

        // Writers never buffer, so flushing is always a successful no-op.
        client_writer.flush().await.unwrap();
        server_writer.flush().await.unwrap();
    }

    #[tokio::test]
    async fn test_concurrent_usage() {
        let (client, server) = channel();
        let (mut client_reader, mut client_writer) = client.split();
        let (mut server_reader, mut server_writer) = server.split();

        let client_write_handle = tokio::spawn(async move {
            for i in 1..=10 {
                let request = eval_request(
                    MessageId::new(i),
                    SessionId::new("concurrent"),
                    format!("(+ {} 1)", i),
                );
                client_writer.send_request(&request).await.unwrap();
            }
        });

        let server_handle = tokio::spawn(async move {
            for _ in 1..=10 {
                let req = server_reader.recv_request().await.unwrap();
                let response = success_response(req.id, req.session_id, Some("ok"));
                server_writer.send_response(&response).await.unwrap();
            }
        });

        let client_read_handle = tokio::spawn(async move {
            let mut count = 0;
            for i in 1..=10 {
                let resp = client_reader.recv_response().await.unwrap();
                // One writer, one server loop and FIFO channels: responses keep request order.
                assert_eq!(resp.request_id, MessageId::new(i));
                count += 1;
            }
            count
        });

        client_write_handle.await.unwrap();
        server_handle.await.unwrap();
        let received_count = client_read_handle.await.unwrap();

        assert_eq!(received_count, 10);
    }
}