algocline_app/pool/client.rs
1//! UDS client for connecting to a pool worker process.
2//!
3//! [`PoolClient`] is the MCP-side (AppService-side) handle that opens a Unix
4//! domain socket connection to a worker subprocess and exchanges
5//! [`PoolRequest`] / [`PoolResponse`] messages in JSON-line format.
6
7use std::path::Path;
8
9use tokio::{
10 io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter},
11 net::UnixStream,
12 sync::Mutex,
13};
14
15use crate::pool::{
16 error::PoolError,
17 protocol::{PoolRequest, PoolResponse, PoolResponseData},
18};
19
20/// Version string embedded in every handshake to prevent client/server skew.
21pub const POOL_PROTOCOL_VERSION: &str = env!("CARGO_PKG_VERSION");
22
/// Upper bound on how long `PoolClient::connect` waits for the worker's
/// handshake reply.
///
/// When the window elapses without a reply, `PoolClient::connect` gives up
/// with `Err(PoolError::Handshake("handshake recv timeout (10s)"))` and the
/// connection is dropped. The bound exists so that callers (notably
/// `RunningService::cancel`) cannot hang forever on a worker that never
/// answers the handshake.
pub(crate) const HANDSHAKE_RECV_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
30
31// ─── Internal state ───────────────────────────────────────────────────────────
32
33#[derive(Debug)]
34struct Inner {
35 writer: BufWriter<tokio::net::unix::OwnedWriteHalf>,
36 reader: BufReader<tokio::net::unix::OwnedReadHalf>,
37}
38
39// ─── PoolClient ───────────────────────────────────────────────────────────────
40
41/// A thin UDS client for communicating with a pool worker process.
42///
43/// Each `PoolClient` instance owns a single Unix domain socket connection to
44/// one worker. Messages are JSON lines (`\n`-terminated).
45///
46/// # Lifecycle
47///
48/// 1. Call [`PoolClient::connect`] — opens the socket and runs the version
49/// handshake atomically. If the handshake fails the connection is dropped
50/// and an error is returned; no `PoolClient` instance is produced.
51/// 2. Call [`PoolClient::send_request`] for each message.
52/// 3. Drop the `PoolClient` when done; the underlying socket is closed.
53#[derive(Debug)]
54pub struct PoolClient {
55 inner: Mutex<Inner>,
56}
57
58impl PoolClient {
59 /// Open a connection to a worker at `sock_path` and verify protocol version.
60 ///
61 /// The handshake (`PoolRequest::Handshake`) is performed inside this call.
62 /// If the worker reports a different version, `Err(PoolError::VersionMismatch)`
63 /// is returned and no `PoolClient` is constructed.
64 ///
65 /// # Concurrency
66 ///
67 /// **Cancel safety**: this function is **not** cancel safe. If dropped before
68 /// `recv_line` completes, the internal `BufReader` may hold a partial line;
69 /// the partial connection is dropped and must not be reused.
70 ///
71 /// **Timeout**: the handshake recv is bounded by `HANDSHAKE_RECV_TIMEOUT`
72 /// (10 s). If the worker does not respond within this window, the function
73 /// returns `Err(PoolError::Handshake("handshake recv timeout (10s)"))` and the
74 /// connection is dropped. This prevents `RunningService::cancel` from hanging
75 /// when a worker fails to send the handshake.
76 ///
77 /// **Send + Sync**: `PoolClient` is `Send` (all fields are `Send`). It is
78 /// **not** `Sync` — callers sharing across tasks must wrap in
79 /// `Arc<tokio::sync::Mutex<PoolClient>>`.
80 ///
81 /// # Errors
82 ///
83 /// - `PoolError::Connect` — socket connect failed (wraps `std::io::Error`).
84 /// - `PoolError::Handshake` — response could not be parsed as valid JSON, or
85 /// the handshake recv timed out after 10 s.
86 /// - `PoolError::VersionMismatch` — worker version differs from client version.
87 pub async fn connect(sock_path: &Path) -> Result<Self, PoolError> {
88 let stream = UnixStream::connect(sock_path).await?;
89 let (read_half, write_half) = stream.into_split();
90
91 let mut inner = Inner {
92 writer: BufWriter::new(write_half),
93 reader: BufReader::new(read_half),
94 };
95
96 // Perform version handshake before returning.
97 let handshake_req = PoolRequest::Handshake {
98 version: POOL_PROTOCOL_VERSION.to_string(),
99 };
100 send_line(&mut inner, &handshake_req).await?;
101
102 let resp = match tokio::time::timeout(HANDSHAKE_RECV_TIMEOUT, recv_line(&mut inner)).await {
103 Ok(Ok(r)) => r,
104 Ok(Err(e)) => return Err(e),
105 Err(_elapsed) => {
106 return Err(PoolError::Handshake(
107 "handshake recv timeout (10s)".to_string(),
108 ));
109 }
110 };
111
112 match &resp.data {
113 Some(PoolResponseData::Handshake { version }) => {
114 if version != POOL_PROTOCOL_VERSION {
115 return Err(PoolError::VersionMismatch {
116 client: POOL_PROTOCOL_VERSION.to_string(),
117 server: version.clone(),
118 });
119 }
120 }
121 _ => {
122 return Err(PoolError::Handshake(
123 "unexpected handshake response".to_string(),
124 ));
125 }
126 }
127
128 Ok(Self {
129 inner: Mutex::new(inner),
130 })
131 }
132
133 /// Send a [`PoolRequest`] over the Unix domain socket and await the response.
134 ///
135 /// Serialises the request to a JSON line (`\n`-terminated), writes it via
136 /// `tokio::io::AsyncWriteExt::write_all`, then reads the response with
137 /// `tokio::io::AsyncBufReadExt::read_line`.
138 ///
139 /// # Concurrency
140 ///
141 /// **Cancel safety**: this function is **not** cancel safe.
142 /// `AsyncBufReadExt::read_line` is not cancel safe per tokio documentation:
143 /// if this future is dropped before `read_line` completes, the internal buffer
144 /// may hold a partial line. After cancellation the connection is no longer
145 /// usable; callers must drop this `PoolClient` and reconnect.
146 ///
147 /// **Mutex serialisation**: the internal `tokio::sync::Mutex<Inner>` serialises
148 /// concurrent callers. Cancelling a `lock().await` call loses queue position
149 /// (tokio docs: "Cancelling a call to `lock` makes you lose your place in the
150 /// queue"). Only one request can be in-flight per `PoolClient` instance at a
151 /// time. Holding the guard across `.await` (write + flush + read_line) is
152 /// intentional and correct with `tokio::sync::Mutex`.
153 ///
154 /// **Send + Sync**: `PoolClient` is `Send` (all fields are `Send`). It is
155 /// **not** `Sync` — callers sharing across tasks must wrap in
156 /// `Arc<tokio::sync::Mutex<PoolClient>>`.
157 ///
158 /// # Panics
159 ///
160 /// Does not panic.
161 pub async fn send_request(&mut self, req: PoolRequest) -> Result<PoolResponse, PoolError> {
162 let inner = self.inner.get_mut();
163 send_line(inner, &req).await?;
164 recv_line(inner).await
165 }
166}
167
168// ─── helpers ─────────────────────────────────────────────────────────────────
169
170/// Write a JSON-line for `req` and flush.
171async fn send_line(inner: &mut Inner, req: &PoolRequest) -> Result<(), PoolError> {
172 let mut line =
173 serde_json::to_string(req).map_err(|e| PoolError::ResponseParse(e.to_string()))?;
174 line.push('\n');
175 inner
176 .writer
177 .write_all(line.as_bytes())
178 .await
179 .map_err(|e| PoolError::IoWrite(e.to_string()))?;
180 inner
181 .writer
182 .flush()
183 .await
184 .map_err(|e| PoolError::IoWrite(e.to_string()))?;
185 Ok(())
186}
187
188/// Read one JSON line and deserialise it as [`PoolResponse`].
189async fn recv_line(inner: &mut Inner) -> Result<PoolResponse, PoolError> {
190 let mut buf = String::new();
191 inner
192 .reader
193 .read_line(&mut buf)
194 .await
195 .map_err(|e| PoolError::IoRead(e.to_string()))?;
196 serde_json::from_str(buf.trim_end_matches('\n'))
197 .map_err(|e| PoolError::ResponseParse(e.to_string()))
198}
199
200// ─── Tests ───────────────────────────────────────────────────────────────────
201
#[cfg(test)]
mod tests {
    use std::{path::PathBuf, sync::Arc};

    use tokio::net::UnixListener;

    use super::*;
    use crate::pool::protocol::PoolResponseData;

    // ── helpers ──────────────────────────────────────────────────────────────

    /// Return a temporary directory and a socket path within it.
    ///
    /// Callers keep the returned `TempDir` guard bound for the duration of the
    /// test so the directory outlives the socket.
    fn temp_sock() -> (tempfile::TempDir, PathBuf) {
        let dir = tempfile::tempdir().expect("tempdir");
        let sock = dir.path().join("worker.sock");
        (dir, sock)
    }

    /// Spawn a mock server that processes one connection.
    ///
    /// The task accepts a single connection, splits it into buffered halves
    /// and passes the resulting `Inner` **by value** to `handler`. The handler
    /// must process exactly the messages the test expects, then return.
    async fn spawn_server<F, Fut>(listener: UnixListener, handler: F) -> tokio::task::JoinHandle<()>
    where
        F: FnOnce(Inner) -> Fut + Send + 'static,
        Fut: std::future::Future<Output = ()> + Send + 'static,
    {
        tokio::spawn(async move {
            let (stream, _) = listener.accept().await.expect("accept");
            let (r, w) = stream.into_split();
            let inner = Inner {
                writer: BufWriter::new(w),
                reader: BufReader::new(r),
            };
            handler(inner).await;
        })
    }

    /// Write a `PoolResponse` as a JSON line to `inner` and flush it.
    async fn server_send(inner: &mut Inner, resp: &PoolResponse) {
        let mut line = serde_json::to_string(resp).expect("serialize");
        line.push('\n');
        inner
            .writer
            .write_all(line.as_bytes())
            .await
            .expect("write");
        inner.writer.flush().await.expect("flush");
    }

    /// Read one request line from `inner` and deserialise it.
    async fn server_recv(inner: &mut Inner) -> PoolRequest {
        let mut buf = String::new();
        inner
            .reader
            .read_line(&mut buf)
            .await
            .expect("server read_line");
        serde_json::from_str(buf.trim_end_matches('\n')).expect("server deserialize")
    }

    /// Send `count` `Run` requests through the shared client and collect the
    /// `session_id` of every `Feed` response, in the order received.
    ///
    /// The outer mutex is re-acquired for every request so that two tasks
    /// running this helper interleave their requests.
    async fn run_requests(
        client: Arc<tokio::sync::Mutex<PoolClient>>,
        count: usize,
    ) -> Vec<String> {
        let mut session_ids = Vec::with_capacity(count);
        for _ in 0..count {
            let mut guard = client.lock().await;
            let resp = guard
                .send_request(PoolRequest::Run {
                    code: String::new(),
                    ctx: None,
                    lib_paths: vec![],
                })
                .await
                .expect("send_request failed");
            match resp.data {
                Some(PoolResponseData::Feed { session_id, .. }) => session_ids.push(session_id),
                // Fail at the point of the problem instead of silently
                // dropping the response and tripping the count check later.
                _ => panic!("expected a Feed response to a Run request"),
            }
        }
        session_ids
    }

    // ── test: happy-path round-trip ───────────────────────────────────────────

    /// Full round-trip: handshake → run (paused) → continue → shutdown.
    ///
    /// The mock server matches the protocol version and returns canned
    /// responses for each op. Verifies that `PoolClient` correctly forwards
    /// every response back to the caller without re-interpreting the payload.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn round_trip_handshake_run_pause_continue_shutdown() {
        let (_dir, sock_path) = temp_sock();
        // Clean up any leftover socket from a previous run (EADDRINUSE guard).
        let _ = std::fs::remove_file(&sock_path);

        let listener = UnixListener::bind(&sock_path).expect("bind");

        let server_handle = spawn_server(listener, |mut inner| async move {
            // 1. Handshake
            let req = server_recv(&mut inner).await;
            assert!(matches!(req, PoolRequest::Handshake { .. }));
            server_send(
                &mut inner,
                &PoolResponse::success(PoolResponseData::Handshake {
                    version: POOL_PROTOCOL_VERSION.to_string(),
                }),
            )
            .await;

            // 2. Run → Paused (FeedResult represented as raw JSON value)
            let req = server_recv(&mut inner).await;
            assert!(matches!(req, PoolRequest::Run { .. }));
            let feed_result = serde_json::json!({
                "type": "paused",
                "session_id": "test-sid",
                "prompt": "hi",
                "query_id": "q1"
            });
            server_send(
                &mut inner,
                &PoolResponse::success(PoolResponseData::Feed {
                    session_id: "test-sid".to_string(),
                    feed_result,
                }),
            )
            .await;

            // 3. Continue → Finished
            let req = server_recv(&mut inner).await;
            assert!(matches!(req, PoolRequest::Continue { .. }));
            let feed_result = serde_json::json!({"type": "finished", "output": "done"});
            server_send(
                &mut inner,
                &PoolResponse::success(PoolResponseData::Feed {
                    session_id: "test-sid".to_string(),
                    feed_result,
                }),
            )
            .await;

            // 4. Shutdown
            let req = server_recv(&mut inner).await;
            assert!(matches!(req, PoolRequest::Shutdown));
            server_send(
                &mut inner,
                &PoolResponse::success(PoolResponseData::Shutdown),
            )
            .await;
        })
        .await;

        // --- Client side ---
        let mut client = PoolClient::connect(&sock_path).await.expect("connect");

        // Run
        let resp = client
            .send_request(PoolRequest::Run {
                code: "return alc.llm('hi')".to_string(),
                ctx: None,
                lib_paths: vec![],
            })
            .await
            .expect("run");
        assert!(resp.ok);
        assert!(matches!(resp.data, Some(PoolResponseData::Feed { .. })));

        // Continue
        let resp = client
            .send_request(PoolRequest::Continue {
                sid: "test-sid".to_string(),
                response: "ok".to_string(),
                query_id: Some("q1".to_string()),
                usage: None,
            })
            .await
            .expect("continue");
        assert!(resp.ok);
        assert!(matches!(resp.data, Some(PoolResponseData::Feed { .. })));

        // Shutdown
        let resp = client
            .send_request(PoolRequest::Shutdown)
            .await
            .expect("shutdown");
        assert!(resp.ok);
        assert!(matches!(resp.data, Some(PoolResponseData::Shutdown)));

        server_handle.await.expect("server task");
    }

    // ── test: version mismatch ────────────────────────────────────────────────

    /// When the worker responds with a different version, `PoolClient::connect`
    /// must return `Err(PoolError::VersionMismatch)` and no `PoolClient` is
    /// constructed.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn version_mismatch_returns_pool_error() {
        let (_dir, sock_path) = temp_sock();
        let _ = std::fs::remove_file(&sock_path);

        let listener = UnixListener::bind(&sock_path).expect("bind");

        let server_handle = spawn_server(listener, |mut inner| async move {
            // Consume the handshake request.
            let _ = server_recv(&mut inner).await;
            // Reply with a deliberately wrong version.
            server_send(
                &mut inner,
                &PoolResponse::success(PoolResponseData::Handshake {
                    version: "999.0.0".to_string(),
                }),
            )
            .await;
        })
        .await;

        let err = PoolClient::connect(&sock_path)
            .await
            .expect_err("should fail with version mismatch");

        assert!(
            matches!(
                err,
                PoolError::VersionMismatch {
                    ref client,
                    ref server
                } if client == POOL_PROTOCOL_VERSION && server == "999.0.0"
            ),
            "unexpected error: {err:?}"
        );

        server_handle.await.expect("server task");
    }

    // ── test G3: handshake timeout finite ────────────────────────────────────

    /// Verify that `PoolClient::connect` returns a `PoolError::Handshake`
    /// timeout error within a finite wall-clock bound when the worker never
    /// sends the handshake response.
    ///
    /// The mock server accepts the connection but does not send anything
    /// (sleeps for 30 s). The client must time out within
    /// `HANDSHAKE_RECV_TIMEOUT` and return an error rather than blocking
    /// indefinitely. Note that this test waits for the real timeout.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_connect_handshake_timeout_finite() {
        let (_dir, sock_path) = temp_sock();
        let _ = std::fs::remove_file(&sock_path);
        let listener = UnixListener::bind(&sock_path).expect("bind");

        // Fake server: accept then do nothing (sleep longer than the timeout).
        // Hold `inner` across the sleep so the server-side socket stays open;
        // otherwise the async-move block would not capture the unused
        // parameter, the socket halves would drop immediately, and the client
        // would observe EOF and fail with a non-timeout error instead of the
        // `Handshake` timeout this test is about.
        let _server = spawn_server(listener, |inner| async move {
            let _hold = inner;
            tokio::time::sleep(std::time::Duration::from_secs(30)).await;
        })
        .await;

        let start = tokio::time::Instant::now();
        let err = PoolClient::connect(&sock_path)
            .await
            .expect_err("should time out and return an error");
        let elapsed = start.elapsed();

        // Require the timeout flavour specifically: `Handshake` is also used
        // for "unexpected handshake response", which must not satisfy this test.
        assert!(
            matches!(err, PoolError::Handshake(ref msg) if msg.contains("timeout")),
            "expected PoolError::Handshake timeout, got {err:?}"
        );
        let bound = HANDSHAKE_RECV_TIMEOUT + std::time::Duration::from_secs(1);
        assert!(
            elapsed < bound,
            "connect must complete within {}s, took {:?}",
            bound.as_secs(),
            elapsed
        );
    }

    // ── test G4: concurrent two clients serialised via Mutex ─────────────────

    /// Verify that two tasks sharing a single
    /// `Arc<tokio::sync::Mutex<PoolClient>>` can concurrently call
    /// `send_request` without deadlocking or mixing up responses.
    ///
    /// The mock server handles one connection and answers each request with a
    /// unique `session_id` (incrementing counter). Two tasks each send 10
    /// requests; we assert that all 20 responses are received with no
    /// duplicates.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_connect_handshake_concurrent_two_clients() {
        let (_dir, sock_path) = temp_sock();
        let _ = std::fs::remove_file(&sock_path);
        let listener = UnixListener::bind(&sock_path).expect("bind");

        // Fake server: handshake then serve N Run requests with unique session_ids.
        let server_handle = spawn_server(listener, |mut inner| async move {
            // Handshake
            let _ = server_recv(&mut inner).await;
            server_send(
                &mut inner,
                &PoolResponse::success(PoolResponseData::Handshake {
                    version: POOL_PROTOCOL_VERSION.to_string(),
                }),
            )
            .await;

            // Serve requests sequentially (the client-side mutex ensures
            // serial delivery).
            let mut counter: u32 = 0;
            loop {
                let req = server_recv(&mut inner).await;
                match req {
                    PoolRequest::Shutdown => {
                        server_send(
                            &mut inner,
                            &PoolResponse::success(PoolResponseData::Shutdown),
                        )
                        .await;
                        break;
                    }
                    _ => {
                        let sid = format!("sid-{counter}");
                        counter += 1;
                        let feed_result = serde_json::json!({
                            "type": "finished",
                            "session_id": sid,
                        });
                        server_send(
                            &mut inner,
                            &PoolResponse::success(PoolResponseData::Feed {
                                session_id: sid,
                                feed_result,
                            }),
                        )
                        .await;
                    }
                }
            }
        })
        .await;

        let client = Arc::new(tokio::sync::Mutex::new(
            PoolClient::connect(&sock_path).await.expect("connect"),
        ));

        const REQS_PER_TASK: usize = 10;

        let task_a = tokio::spawn(run_requests(Arc::clone(&client), REQS_PER_TASK));
        let task_b = tokio::spawn(run_requests(Arc::clone(&client), REQS_PER_TASK));

        let mut results = task_a.await.expect("task_a panicked");
        results.extend(task_b.await.expect("task_b panicked"));

        // Shut down cleanly. The result is checked: if the request failed the
        // mock server would never leave its loop and the final
        // `server_handle.await` below would hang instead of failing.
        {
            let mut guard = client.lock().await;
            guard
                .send_request(PoolRequest::Shutdown)
                .await
                .expect("shutdown");
        }

        // All 20 session_ids must be present with no duplicates.
        assert_eq!(
            results.len(),
            REQS_PER_TASK * 2,
            "expected {} responses, got {}",
            REQS_PER_TASK * 2,
            results.len()
        );
        let mut unique = results.clone();
        unique.sort();
        unique.dedup();
        assert_eq!(
            unique.len(),
            results.len(),
            "duplicate session_ids detected: {results:?}"
        );

        server_handle.await.expect("server task");
    }
}