Skip to main content

algocline_app/pool/
client.rs

1//! UDS client for connecting to a pool worker process.
2//!
3//! [`PoolClient`] is the MCP-side (AppService-side) handle that opens a Unix
4//! domain socket connection to a worker subprocess and exchanges
5//! [`PoolRequest`] / [`PoolResponse`] messages in JSON-line format.
6
7use std::path::Path;
8
9use tokio::{
10    io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter},
11    net::UnixStream,
12    sync::Mutex,
13};
14
15use crate::pool::{
16    error::PoolError,
17    protocol::{PoolRequest, PoolResponse, PoolResponseData},
18};
19
/// Protocol version string sent in every handshake to detect client/worker skew.
///
/// Resolved at compile time from this crate's `CARGO_PKG_VERSION`, so every
/// crate version bump is treated as a protocol change: `PoolClient::connect`
/// requires the worker to echo back exactly this string and otherwise fails
/// with `PoolError::VersionMismatch`.
pub const POOL_PROTOCOL_VERSION: &str = env!("CARGO_PKG_VERSION");
22
/// How long [`PoolClient::connect`] waits for the worker's handshake reply
/// (ten seconds).
///
/// Once this deadline passes, `connect` gives up with
/// `Err(PoolError::Handshake("handshake recv timeout (10s)"))` and drops the
/// half-open connection. A worker that never answers the handshake therefore
/// cannot leave `RunningService::cancel` waiting forever.
pub(crate) const HANDSHAKE_RECV_TIMEOUT: std::time::Duration =
    std::time::Duration::from_millis(10_000);
30
31// ─── Internal state ───────────────────────────────────────────────────────────
32
/// Buffered halves of one connected `UnixStream`, split with
/// `UnixStream::into_split` in `PoolClient::connect`.
///
/// `send_line` writes through `writer` and flushes after every JSON line.
/// `recv_line` reads from `reader` one `\n`-terminated line at a time.
#[derive(Debug)]
struct Inner {
    /// Buffered write half; `send_line` flushes it after each line.
    writer: BufWriter<tokio::net::unix::OwnedWriteHalf>,
    /// Buffered read half; buffering lets `read_line` split on `\n`.
    reader: BufReader<tokio::net::unix::OwnedReadHalf>,
}
38
39// ─── PoolClient ───────────────────────────────────────────────────────────────
40
/// A thin UDS client for communicating with a pool worker process.
///
/// Each `PoolClient` instance owns a single Unix domain socket connection to
/// one worker. Messages are JSON lines (`\n`-terminated). Every request line
/// sent by `send_request` is answered by exactly one response line.
///
/// # Lifecycle
///
/// 1. Call [`PoolClient::connect`] — opens the socket and runs the version
///    handshake atomically.  If the handshake fails the connection is dropped
///    and an error is returned; no `PoolClient` instance is produced.
/// 2. Call [`PoolClient::send_request`] for each message.
/// 3. Drop the `PoolClient` when done; the underlying socket is closed.
#[derive(Debug)]
pub struct PoolClient {
    // Buffered socket halves, wrapped in a tokio `Mutex`. `send_request`
    // takes `&mut self` and reaches the state through `Mutex::get_mut`, so
    // no lock is acquired on the request path. Exclusivity comes from the
    // borrow checker, not from this mutex.
    inner: Mutex<Inner>,
}
57
58impl PoolClient {
59    /// Open a connection to a worker at `sock_path` and verify protocol version.
60    ///
61    /// The handshake (`PoolRequest::Handshake`) is performed inside this call.
62    /// If the worker reports a different version, `Err(PoolError::VersionMismatch)`
63    /// is returned and no `PoolClient` is constructed.
64    ///
65    /// # Concurrency
66    ///
67    /// **Cancel safety**: this function is **not** cancel safe. If dropped before
68    /// `recv_line` completes, the internal `BufReader` may hold a partial line;
69    /// the partial connection is dropped and must not be reused.
70    ///
71    /// **Timeout**: the handshake recv is bounded by `HANDSHAKE_RECV_TIMEOUT`
72    /// (10 s). If the worker does not respond within this window, the function
73    /// returns `Err(PoolError::Handshake("handshake recv timeout (10s)"))` and the
74    /// connection is dropped. This prevents `RunningService::cancel` from hanging
75    /// when a worker fails to send the handshake.
76    ///
77    /// **Send + Sync**: `PoolClient` is `Send` (all fields are `Send`). It is
78    /// **not** `Sync` — callers sharing across tasks must wrap in
79    /// `Arc<tokio::sync::Mutex<PoolClient>>`.
80    ///
81    /// # Errors
82    ///
83    /// - `PoolError::Connect` — socket connect failed (wraps `std::io::Error`).
84    /// - `PoolError::Handshake` — response could not be parsed as valid JSON, or
85    ///   the handshake recv timed out after 10 s.
86    /// - `PoolError::VersionMismatch` — worker version differs from client version.
87    pub async fn connect(sock_path: &Path) -> Result<Self, PoolError> {
88        let stream = UnixStream::connect(sock_path).await?;
89        let (read_half, write_half) = stream.into_split();
90
91        let mut inner = Inner {
92            writer: BufWriter::new(write_half),
93            reader: BufReader::new(read_half),
94        };
95
96        // Perform version handshake before returning.
97        let handshake_req = PoolRequest::Handshake {
98            version: POOL_PROTOCOL_VERSION.to_string(),
99        };
100        send_line(&mut inner, &handshake_req).await?;
101
102        let resp = match tokio::time::timeout(HANDSHAKE_RECV_TIMEOUT, recv_line(&mut inner)).await {
103            Ok(Ok(r)) => r,
104            Ok(Err(e)) => return Err(e),
105            Err(_elapsed) => {
106                return Err(PoolError::Handshake(
107                    "handshake recv timeout (10s)".to_string(),
108                ));
109            }
110        };
111
112        match &resp.data {
113            Some(PoolResponseData::Handshake { version }) => {
114                if version != POOL_PROTOCOL_VERSION {
115                    return Err(PoolError::VersionMismatch {
116                        client: POOL_PROTOCOL_VERSION.to_string(),
117                        server: version.clone(),
118                    });
119                }
120            }
121            _ => {
122                return Err(PoolError::Handshake(
123                    "unexpected handshake response".to_string(),
124                ));
125            }
126        }
127
128        Ok(Self {
129            inner: Mutex::new(inner),
130        })
131    }
132
133    /// Send a [`PoolRequest`] over the Unix domain socket and await the response.
134    ///
135    /// Serialises the request to a JSON line (`\n`-terminated), writes it via
136    /// `tokio::io::AsyncWriteExt::write_all`, then reads the response with
137    /// `tokio::io::AsyncBufReadExt::read_line`.
138    ///
139    /// # Concurrency
140    ///
141    /// **Cancel safety**: this function is **not** cancel safe.
142    /// `AsyncBufReadExt::read_line` is not cancel safe per tokio documentation:
143    /// if this future is dropped before `read_line` completes, the internal buffer
144    /// may hold a partial line. After cancellation the connection is no longer
145    /// usable; callers must drop this `PoolClient` and reconnect.
146    ///
147    /// **Mutex serialisation**: the internal `tokio::sync::Mutex<Inner>` serialises
148    /// concurrent callers. Cancelling a `lock().await` call loses queue position
149    /// (tokio docs: "Cancelling a call to `lock` makes you lose your place in the
150    /// queue"). Only one request can be in-flight per `PoolClient` instance at a
151    /// time. Holding the guard across `.await` (write + flush + read_line) is
152    /// intentional and correct with `tokio::sync::Mutex`.
153    ///
154    /// **Send + Sync**: `PoolClient` is `Send` (all fields are `Send`). It is
155    /// **not** `Sync` — callers sharing across tasks must wrap in
156    /// `Arc<tokio::sync::Mutex<PoolClient>>`.
157    ///
158    /// # Panics
159    ///
160    /// Does not panic.
161    pub async fn send_request(&mut self, req: PoolRequest) -> Result<PoolResponse, PoolError> {
162        let inner = self.inner.get_mut();
163        send_line(inner, &req).await?;
164        recv_line(inner).await
165    }
166}
167
168// ─── helpers ─────────────────────────────────────────────────────────────────
169
170/// Write a JSON-line for `req` and flush.
171async fn send_line(inner: &mut Inner, req: &PoolRequest) -> Result<(), PoolError> {
172    let mut line =
173        serde_json::to_string(req).map_err(|e| PoolError::ResponseParse(e.to_string()))?;
174    line.push('\n');
175    inner
176        .writer
177        .write_all(line.as_bytes())
178        .await
179        .map_err(|e| PoolError::IoWrite(e.to_string()))?;
180    inner
181        .writer
182        .flush()
183        .await
184        .map_err(|e| PoolError::IoWrite(e.to_string()))?;
185    Ok(())
186}
187
188/// Read one JSON line and deserialise it as [`PoolResponse`].
189async fn recv_line(inner: &mut Inner) -> Result<PoolResponse, PoolError> {
190    let mut buf = String::new();
191    inner
192        .reader
193        .read_line(&mut buf)
194        .await
195        .map_err(|e| PoolError::IoRead(e.to_string()))?;
196    serde_json::from_str(buf.trim_end_matches('\n'))
197        .map_err(|e| PoolError::ResponseParse(e.to_string()))
198}
199
// ─── Tests ───────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use std::{path::PathBuf, sync::Arc};

    use tokio::net::UnixListener;

    // The glob also brings in the parent's imports: `PoolRequest`,
    // `PoolResponse`, `PoolResponseData`, `PoolError` and the tokio I/O traits.
    use super::*;

    // ── helpers ──────────────────────────────────────────────────────────────

    /// Create a fresh temporary directory and bind a `UnixListener` on
    /// `worker.sock` inside it.
    ///
    /// The directory is new on every call, so no stale socket file can exist
    /// at the path. Keep the returned `TempDir` alive for the whole test,
    /// because dropping it deletes the directory.
    fn bind_temp_listener() -> (tempfile::TempDir, PathBuf, UnixListener) {
        let dir = tempfile::tempdir().expect("tempdir");
        let sock_path = dir.path().join("worker.sock");
        let listener = UnixListener::bind(&sock_path).expect("bind");
        (dir, sock_path, listener)
    }

    /// Spawn a mock worker that accepts exactly one connection.
    ///
    /// `handler` receives the connection as an `Inner` and must process exactly
    /// the messages the test expects, then return. Await the returned handle
    /// to surface assertion failures raised inside `handler`.
    fn spawn_server<F, Fut>(listener: UnixListener, handler: F) -> tokio::task::JoinHandle<()>
    where
        F: FnOnce(Inner) -> Fut + Send + 'static,
        Fut: std::future::Future<Output = ()> + Send + 'static,
    {
        tokio::spawn(async move {
            let (stream, _addr) = listener.accept().await.expect("accept");
            let (read_half, write_half) = stream.into_split();
            handler(Inner {
                writer: BufWriter::new(write_half),
                reader: BufReader::new(read_half),
            })
            .await;
        })
    }

    /// Write a `PoolResponse` as a JSON line to `inner` and flush it.
    async fn server_send(inner: &mut Inner, resp: &PoolResponse) {
        let mut line = serde_json::to_string(resp).expect("serialize");
        line.push('\n');
        inner
            .writer
            .write_all(line.as_bytes())
            .await
            .expect("write");
        inner.writer.flush().await.expect("flush");
    }

    /// Read one request line from `inner` and deserialise it.
    async fn server_recv(inner: &mut Inner) -> PoolRequest {
        let mut buf = String::new();
        inner
            .reader
            .read_line(&mut buf)
            .await
            .expect("server read_line");
        serde_json::from_str(buf.trim_end_matches('\n')).expect("server deserialize")
    }

    /// Consume the client's handshake request and reply with the client's own
    /// protocol version, so that `PoolClient::connect` succeeds.
    async fn server_accept_handshake(inner: &mut Inner) {
        let req = server_recv(inner).await;
        assert!(matches!(req, PoolRequest::Handshake { .. }));
        server_send(
            inner,
            &PoolResponse::success(PoolResponseData::Handshake {
                version: POOL_PROTOCOL_VERSION.to_string(),
            }),
        )
        .await;
    }

    /// Send `n` `Run` requests through a client shared behind a tokio mutex and
    /// collect the `session_id` of every `Feed` response.
    ///
    /// The outer guard is taken once per request, so concurrent callers may
    /// interleave between requests but never within one.
    async fn run_requests(client: Arc<tokio::sync::Mutex<PoolClient>>, n: usize) -> Vec<String> {
        let mut session_ids = Vec::with_capacity(n);
        for _ in 0..n {
            let mut guard = client.lock().await;
            let resp = guard
                .send_request(PoolRequest::Run {
                    code: String::new(),
                    ctx: None,
                    lib_paths: vec![],
                })
                .await
                .expect("send_request failed");
            if let Some(PoolResponseData::Feed { session_id, .. }) = resp.data {
                session_ids.push(session_id);
            }
        }
        session_ids
    }

    // ── test: happy-path round-trip ───────────────────────────────────────────

    /// Full round-trip: handshake → run (paused) → continue → shutdown.
    ///
    /// The mock server matches the protocol version and returns canned responses
    /// for each op. Verifies that `PoolClient` forwards every response back to
    /// the caller without re-interpreting the payload.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn round_trip_handshake_run_pause_continue_shutdown() {
        let (_dir, sock_path, listener) = bind_temp_listener();

        let server_handle = spawn_server(listener, |mut inner| async move {
            // 1. Handshake
            server_accept_handshake(&mut inner).await;

            // 2. Run → Paused (FeedResult represented as raw JSON value)
            let req = server_recv(&mut inner).await;
            assert!(matches!(req, PoolRequest::Run { .. }));
            let feed_result = serde_json::json!({
                "type": "paused",
                "session_id": "test-sid",
                "prompt": "hi",
                "query_id": "q1"
            });
            server_send(
                &mut inner,
                &PoolResponse::success(PoolResponseData::Feed {
                    session_id: "test-sid".to_string(),
                    feed_result,
                }),
            )
            .await;

            // 3. Continue → Finished
            let req = server_recv(&mut inner).await;
            assert!(matches!(req, PoolRequest::Continue { .. }));
            let feed_result = serde_json::json!({"type": "finished", "output": "done"});
            server_send(
                &mut inner,
                &PoolResponse::success(PoolResponseData::Feed {
                    session_id: "test-sid".to_string(),
                    feed_result,
                }),
            )
            .await;

            // 4. Shutdown
            let req = server_recv(&mut inner).await;
            assert!(matches!(req, PoolRequest::Shutdown));
            server_send(
                &mut inner,
                &PoolResponse::success(PoolResponseData::Shutdown),
            )
            .await;
        });

        // --- Client side ---
        let mut client = PoolClient::connect(&sock_path).await.expect("connect");

        // Run
        let resp = client
            .send_request(PoolRequest::Run {
                code: "return alc.llm('hi')".to_string(),
                ctx: None,
                lib_paths: vec![],
            })
            .await
            .expect("run");
        assert!(resp.ok);
        assert!(matches!(resp.data, Some(PoolResponseData::Feed { .. })));

        // Continue
        let resp = client
            .send_request(PoolRequest::Continue {
                sid: "test-sid".to_string(),
                response: "ok".to_string(),
                query_id: Some("q1".to_string()),
                usage: None,
            })
            .await
            .expect("continue");
        assert!(resp.ok);
        assert!(matches!(resp.data, Some(PoolResponseData::Feed { .. })));

        // Shutdown
        let resp = client
            .send_request(PoolRequest::Shutdown)
            .await
            .expect("shutdown");
        assert!(resp.ok);
        assert!(matches!(resp.data, Some(PoolResponseData::Shutdown)));

        server_handle.await.expect("server task");
    }

    // ── test: version mismatch ────────────────────────────────────────────────

    /// When the worker responds with a different version, `PoolClient::connect`
    /// must return `Err(PoolError::VersionMismatch)` and no `PoolClient` is
    /// constructed.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn version_mismatch_returns_pool_error() {
        let (_dir, sock_path, listener) = bind_temp_listener();

        let server_handle = spawn_server(listener, |mut inner| async move {
            // Consume the handshake request.
            let _ = server_recv(&mut inner).await;
            // Reply with a deliberately wrong version.
            server_send(
                &mut inner,
                &PoolResponse::success(PoolResponseData::Handshake {
                    version: "999.0.0".to_string(),
                }),
            )
            .await;
        });

        let err = PoolClient::connect(&sock_path)
            .await
            .expect_err("should fail with version mismatch");

        assert!(
            matches!(
                err,
                PoolError::VersionMismatch {
                    ref client,
                    ref server
                } if client == POOL_PROTOCOL_VERSION && server == "999.0.0"
            ),
            "unexpected error: {err:?}"
        );

        server_handle.await.expect("server task");
    }

    // ── test G3: handshake timeout finite ────────────────────────────────────

    /// Verify that `PoolClient::connect` returns `Err(PoolError::Handshake(_))`
    /// within a finite wall-clock bound when the worker never sends the
    /// handshake response.
    ///
    /// The mock server accepts the connection but sends nothing (it sleeps for
    /// 30 s). The client must time out after `HANDSHAKE_RECV_TIMEOUT` and
    /// return an error rather than block indefinitely. This test therefore
    /// takes about `HANDSHAKE_RECV_TIMEOUT` of real time.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_connect_handshake_timeout_finite() {
        let (_dir, sock_path, listener) = bind_temp_listener();

        // Fake server: accept, then stay silent for longer than the timeout.
        // Referencing `inner` inside the async block forces it to be captured,
        // which keeps the server-side socket open for the whole sleep. An
        // unreferenced closure parameter would be dropped as soon as the
        // closure returns. The client would then hit EOF and fail with a
        // non-`Handshake` error instead of timing out.
        let _server = spawn_server(listener, |inner| async move {
            let _hold = inner;
            tokio::time::sleep(std::time::Duration::from_secs(30)).await;
        });

        let start = tokio::time::Instant::now();
        let err = PoolClient::connect(&sock_path)
            .await
            .expect_err("should time out and return an error");
        let elapsed = start.elapsed();

        assert!(
            matches!(err, PoolError::Handshake(_)),
            "expected PoolError::Handshake, got {err:?}"
        );
        assert!(
            elapsed.as_secs() < HANDSHAKE_RECV_TIMEOUT.as_secs() + 1,
            "connect must complete within {}s, took {:?}",
            HANDSHAKE_RECV_TIMEOUT.as_secs() + 1,
            elapsed
        );
    }

    // ── test G4: concurrent tasks serialised via an outer Mutex ──────────────

    /// Verify that two tasks sharing a single `Arc<tokio::sync::Mutex<PoolClient>>`
    /// (i.e. one client, one underlying connection) can call `send_request`
    /// concurrently without deadlocking or mixing up responses.
    ///
    /// The mock server handles one connection and answers each request with a
    /// unique `session_id` taken from an incrementing counter. Two tasks each
    /// send 10 requests. The test asserts that all 20 responses arrive with no
    /// duplicates.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_connect_handshake_concurrent_two_clients() {
        const REQS_PER_TASK: usize = 10;

        let (_dir, sock_path, listener) = bind_temp_listener();

        // Fake server: handshake, then serve Run requests with unique
        // session_ids until a Shutdown arrives.
        let server_handle = spawn_server(listener, |mut inner| async move {
            server_accept_handshake(&mut inner).await;

            // Requests arrive one at a time: the outer mutex around the shared
            // client admits only one `send_request` at a time.
            let mut counter: u32 = 0;
            loop {
                match server_recv(&mut inner).await {
                    PoolRequest::Shutdown => {
                        server_send(
                            &mut inner,
                            &PoolResponse::success(PoolResponseData::Shutdown),
                        )
                        .await;
                        break;
                    }
                    _ => {
                        let sid = format!("sid-{counter}");
                        counter += 1;
                        let feed_result = serde_json::json!({
                            "type": "finished",
                            "session_id": sid,
                        });
                        server_send(
                            &mut inner,
                            &PoolResponse::success(PoolResponseData::Feed {
                                session_id: sid,
                                feed_result,
                            }),
                        )
                        .await;
                    }
                }
            }
        });

        let client = Arc::new(tokio::sync::Mutex::new(
            PoolClient::connect(&sock_path).await.expect("connect"),
        ));

        let task_a = tokio::spawn(run_requests(Arc::clone(&client), REQS_PER_TASK));
        let task_b = tokio::spawn(run_requests(Arc::clone(&client), REQS_PER_TASK));

        let mut results = task_a.await.expect("task_a panicked");
        results.extend(task_b.await.expect("task_b panicked"));

        // Shutdown cleanly.
        {
            let mut guard = client.lock().await;
            let _ = guard.send_request(PoolRequest::Shutdown).await;
        }

        // All 20 session_ids must be present with no duplicates.
        assert_eq!(
            results.len(),
            REQS_PER_TASK * 2,
            "expected {} responses, got {}",
            REQS_PER_TASK * 2,
            results.len()
        );
        let mut unique = results.clone();
        unique.sort();
        unique.dedup();
        assert_eq!(
            unique.len(),
            results.len(),
            "duplicate session_ids detected: {results:?}"
        );

        server_handle.await.expect("server task");
    }
}