use std::path::Path;
use tokio::{
io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter},
net::UnixStream,
sync::Mutex,
};
use crate::pool::{
error::PoolError,
protocol::{PoolRequest, PoolResponse, PoolResponseData},
};
/// Version string exchanged in the handshake.
///
/// It is taken verbatim from this crate's package version. [`PoolClient::connect`] requires the
/// server to echo exactly this string, so client and server must come from the same build version.
pub const POOL_PROTOCOL_VERSION: &str = env!("CARGO_PKG_VERSION");

/// How long [`PoolClient::connect`] waits for the server's handshake reply before giving up.
pub(crate) const HANDSHAKE_RECV_TIMEOUT: std::time::Duration =
    std::time::Duration::from_millis(10_000);
/// Buffered read/write halves of one split `UnixStream`.
///
/// `send_line` and `recv_line` operate on this type. The test server in this file also wraps its
/// accepted stream in it.
#[derive(Debug)]
struct Inner {
/// Outgoing side; `send_line` flushes it after every request line.
writer: BufWriter<tokio::net::unix::OwnedWriteHalf>,
/// Incoming side; `recv_line` reads one newline-terminated response per call.
reader: BufReader<tokio::net::unix::OwnedReadHalf>,
}
/// Client for the pool server's newline-delimited JSON request/response protocol over a Unix
/// socket.
///
/// Construct it with `PoolClient::connect`, which also performs the version handshake.
#[derive(Debug)]
pub struct PoolClient {
// NOTE(review): `send_request` takes `&mut self` and reaches the connection via
// `Mutex::get_mut`, so this lock is never actually acquired in this file.
inner: Mutex<Inner>,
}
impl PoolClient {
    /// Connects to the pool server at `sock_path` and performs the version handshake.
    ///
    /// Sends a `PoolRequest::Handshake` carrying [`POOL_PROTOCOL_VERSION`], then waits at most
    /// [`HANDSHAKE_RECV_TIMEOUT`] for the server's reply.
    ///
    /// # Errors
    ///
    /// - Any error from connecting the socket (converted into [`PoolError`] via `?`) or from
    ///   writing the request / reading the reply.
    /// - [`PoolError::Handshake`] if no reply arrives within the timeout, or if the reply carries
    ///   no handshake payload.
    /// - [`PoolError::VersionMismatch`] if the server reports a different version string.
    pub async fn connect(sock_path: &Path) -> Result<Self, PoolError> {
        let stream = UnixStream::connect(sock_path).await?;
        let (read_half, write_half) = stream.into_split();
        let mut inner = Inner {
            writer: BufWriter::new(write_half),
            reader: BufReader::new(read_half),
        };

        let handshake_req = PoolRequest::Handshake {
            version: POOL_PROTOCOL_VERSION.to_string(),
        };
        send_line(&mut inner, &handshake_req).await?;

        // Bound the wait so a server that accepts but never answers cannot hang the caller.
        // The message is derived from the constant so it cannot drift if the timeout changes.
        let resp = tokio::time::timeout(HANDSHAKE_RECV_TIMEOUT, recv_line(&mut inner))
            .await
            .map_err(|_elapsed| {
                PoolError::Handshake(format!(
                    "handshake recv timeout ({}s)",
                    HANDSHAKE_RECV_TIMEOUT.as_secs()
                ))
            })??;

        Self::check_handshake(&resp)?;

        Ok(Self {
            inner: Mutex::new(inner),
        })
    }

    /// Validates the server's handshake reply against [`POOL_PROTOCOL_VERSION`].
    ///
    /// Returns [`PoolError::VersionMismatch`] when the versions differ, and
    /// [`PoolError::Handshake`] when the reply is not a handshake payload at all.
    fn check_handshake(resp: &PoolResponse) -> Result<(), PoolError> {
        match &resp.data {
            Some(PoolResponseData::Handshake { version }) if version != POOL_PROTOCOL_VERSION => {
                Err(PoolError::VersionMismatch {
                    client: POOL_PROTOCOL_VERSION.to_string(),
                    server: version.clone(),
                })
            }
            Some(PoolResponseData::Handshake { .. }) => Ok(()),
            _ => Err(PoolError::Handshake(
                "unexpected handshake response".to_string(),
            )),
        }
    }

    /// Sends `req` and returns the next response line read from the server.
    ///
    /// The protocol is strictly one request, then one response, on a single connection.
    /// If this future is dropped after the request is written but before the reply is read
    /// (e.g. by a caller-side timeout), the unread reply stays in the stream. It would then be
    /// returned to the *next* `send_request` call.
    ///
    /// # Errors
    ///
    /// Propagates the errors of the underlying line writer/reader:
    /// - [`PoolError::IoWrite`] or [`PoolError::IoRead`] for socket failures.
    /// - [`PoolError::ResponseParse`] for encoding/decoding failures.
    pub async fn send_request(&mut self, req: PoolRequest) -> Result<PoolResponse, PoolError> {
        // `&mut self` already guarantees exclusive access, so reach through the lock directly.
        let inner = self.inner.get_mut();
        send_line(inner, &req).await?;
        recv_line(inner).await
    }
}
async fn send_line(inner: &mut Inner, req: &PoolRequest) -> Result<(), PoolError> {
let mut line =
serde_json::to_string(req).map_err(|e| PoolError::ResponseParse(e.to_string()))?;
line.push('\n');
inner
.writer
.write_all(line.as_bytes())
.await
.map_err(|e| PoolError::IoWrite(e.to_string()))?;
inner
.writer
.flush()
.await
.map_err(|e| PoolError::IoWrite(e.to_string()))?;
Ok(())
}
/// Reads one newline-terminated JSON line from the server and decodes it as a [`PoolResponse`].
///
/// # Errors
///
/// - [`PoolError::IoRead`] if the read fails, or if the server closed the connection before
///   sending any bytes (EOF).
/// - [`PoolError::ResponseParse`] if the line is not a valid `PoolResponse`.
async fn recv_line(inner: &mut Inner) -> Result<PoolResponse, PoolError> {
    let mut buf = String::new();
    let bytes_read = inner
        .reader
        .read_line(&mut buf)
        .await
        .map_err(|e| PoolError::IoRead(e.to_string()))?;

    // `read_line` reports EOF as `Ok(0)`. Without this check, the empty buffer would surface as
    // a misleading JSON parse error instead of a closed connection.
    if bytes_read == 0 {
        return Err(PoolError::IoRead(
            "connection closed by peer (EOF)".to_string(),
        ));
    }

    // Strip the line terminator (including a possible `\r`) before decoding.
    serde_json::from_str(buf.trim_end()).map_err(|e| PoolError::ResponseParse(e.to_string()))
}
#[cfg(test)]
mod tests {
    use std::collections::HashSet;
    use std::path::PathBuf;
    use std::sync::Arc;

    use tokio::net::UnixListener;

    use super::*;

    /// Creates a private temp directory and returns it with a socket path inside it.
    ///
    /// The returned `TempDir` must be kept alive for the whole test, because dropping it deletes
    /// the directory.
    fn temp_sock() -> (tempfile::TempDir, PathBuf) {
        let dir = tempfile::tempdir().expect("tempdir");
        let sock = dir.path().join("worker.sock");
        (dir, sock)
    }

    /// Spawns a task that accepts a single connection on `listener` and runs `handler` on it.
    ///
    /// The accepted stream is wrapped in the same buffered [`Inner`] the client uses.
    fn spawn_server<F, Fut>(listener: UnixListener, handler: F) -> tokio::task::JoinHandle<()>
    where
        F: FnOnce(Inner) -> Fut + Send + 'static,
        Fut: std::future::Future<Output = ()> + Send + 'static,
    {
        tokio::spawn(async move {
            let (stream, _) = listener.accept().await.expect("accept");
            let (r, w) = stream.into_split();
            let inner = Inner {
                writer: BufWriter::new(w),
                reader: BufReader::new(r),
            };
            handler(inner).await;
        })
    }

    /// Writes `resp` to the client as one JSON line and flushes it, panicking on any failure.
    async fn server_send(inner: &mut Inner, resp: &PoolResponse) {
        let mut line = serde_json::to_string(resp).expect("serialize");
        line.push('\n');
        inner
            .writer
            .write_all(line.as_bytes())
            .await
            .expect("write");
        inner.writer.flush().await.expect("flush");
    }

    /// Reads one JSON line from the client and decodes it as a `PoolRequest`, panicking on failure.
    async fn server_recv(inner: &mut Inner) -> PoolRequest {
        let mut buf = String::new();
        inner
            .reader
            .read_line(&mut buf)
            .await
            .expect("server read_line");
        serde_json::from_str(buf.trim_end_matches('\n')).expect("server deserialize")
    }

    /// Server side of a successful handshake.
    ///
    /// Expects a `Handshake` request and replies with this build's [`POOL_PROTOCOL_VERSION`].
    async fn server_handshake_ok(inner: &mut Inner) {
        let req = server_recv(inner).await;
        assert!(matches!(req, PoolRequest::Handshake { .. }));
        server_send(
            inner,
            &PoolResponse::success(PoolResponseData::Handshake {
                version: POOL_PROTOCOL_VERSION.to_string(),
            }),
        )
        .await;
    }

    /// Replies with a successful `Feed` response for `session_id` carrying `feed_result`.
    async fn server_send_feed(inner: &mut Inner, session_id: &str, feed_result: serde_json::Value) {
        server_send(
            inner,
            &PoolResponse::success(PoolResponseData::Feed {
                session_id: session_id.to_string(),
                feed_result,
            }),
        )
        .await;
    }

    /// Spawns a task that issues `count` `Run` requests through the shared client.
    ///
    /// The lock is taken per request, so concurrently running loops interleave. The task returns
    /// the session ids it received.
    fn spawn_run_loop(
        client: Arc<tokio::sync::Mutex<PoolClient>>,
        count: usize,
    ) -> tokio::task::JoinHandle<Vec<String>> {
        tokio::spawn(async move {
            let mut session_ids = Vec::with_capacity(count);
            for _ in 0..count {
                let mut guard = client.lock().await;
                let resp = guard
                    .send_request(PoolRequest::Run {
                        code: String::new(),
                        ctx: None,
                        lib_paths: vec![],
                    })
                    .await
                    .expect("send_request failed");
                if let Some(PoolResponseData::Feed { session_id, .. }) = resp.data {
                    session_ids.push(session_id);
                }
            }
            session_ids
        })
    }

    /// Full happy path: handshake, `Run` (paused), `Continue` (finished), then `Shutdown`.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn round_trip_handshake_run_pause_continue_shutdown() {
        let (_dir, sock_path) = temp_sock();
        let _ = std::fs::remove_file(&sock_path);
        let listener = UnixListener::bind(&sock_path).expect("bind");
        let server_handle = spawn_server(listener, |mut inner| async move {
            server_handshake_ok(&mut inner).await;

            let req = server_recv(&mut inner).await;
            assert!(matches!(req, PoolRequest::Run { .. }));
            let feed_result = serde_json::json!({
                "type": "paused",
                "session_id": "test-sid",
                "prompt": "hi",
                "query_id": "q1"
            });
            server_send_feed(&mut inner, "test-sid", feed_result).await;

            let req = server_recv(&mut inner).await;
            assert!(matches!(req, PoolRequest::Continue { .. }));
            let feed_result = serde_json::json!({"type": "finished", "output": "done"});
            server_send_feed(&mut inner, "test-sid", feed_result).await;

            let req = server_recv(&mut inner).await;
            assert!(matches!(req, PoolRequest::Shutdown));
            server_send(
                &mut inner,
                &PoolResponse::success(PoolResponseData::Shutdown),
            )
            .await;
        });

        let mut client = PoolClient::connect(&sock_path).await.expect("connect");

        let resp = client
            .send_request(PoolRequest::Run {
                code: "return alc.llm('hi')".to_string(),
                ctx: None,
                lib_paths: vec![],
            })
            .await
            .expect("run");
        assert!(resp.ok);
        assert!(matches!(resp.data, Some(PoolResponseData::Feed { .. })));

        let resp = client
            .send_request(PoolRequest::Continue {
                sid: "test-sid".to_string(),
                response: "ok".to_string(),
                query_id: Some("q1".to_string()),
                usage: None,
            })
            .await
            .expect("continue");
        assert!(resp.ok);
        assert!(matches!(resp.data, Some(PoolResponseData::Feed { .. })));

        let resp = client
            .send_request(PoolRequest::Shutdown)
            .await
            .expect("shutdown");
        assert!(resp.ok);
        assert!(matches!(resp.data, Some(PoolResponseData::Shutdown)));

        server_handle.await.expect("server task");
    }

    /// A server reporting a different version must make `connect` fail with `VersionMismatch`.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn version_mismatch_returns_pool_error() {
        let (_dir, sock_path) = temp_sock();
        let _ = std::fs::remove_file(&sock_path);
        let listener = UnixListener::bind(&sock_path).expect("bind");
        let server_handle = spawn_server(listener, |mut inner| async move {
            let _ = server_recv(&mut inner).await;
            server_send(
                &mut inner,
                &PoolResponse::success(PoolResponseData::Handshake {
                    version: "999.0.0".to_string(),
                }),
            )
            .await;
        });

        let err = PoolClient::connect(&sock_path)
            .await
            .expect_err("should fail with version mismatch");
        assert!(
            matches!(
                err,
                PoolError::VersionMismatch {
                    ref client,
                    ref server
                } if client == POOL_PROTOCOL_VERSION && server == "999.0.0"
            ),
            "unexpected error: {err:?}"
        );

        server_handle.await.expect("server task");
    }

    /// A server that accepts but never replies must not hang `connect` past the handshake timeout.
    ///
    /// Note: this waits the full real-time timeout.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_connect_handshake_timeout_finite() {
        let (_dir, sock_path) = temp_sock();
        let _ = std::fs::remove_file(&sock_path);
        let listener = UnixListener::bind(&sock_path).expect("bind");
        // Keep the connection open (without answering) for longer than the timeout.
        let _server = spawn_server(listener, |inner| async move {
            let _hold = inner;
            tokio::time::sleep(std::time::Duration::from_secs(30)).await;
        });

        let start = tokio::time::Instant::now();
        let err = PoolClient::connect(&sock_path)
            .await
            .expect_err("should time out and return an error");
        let elapsed = start.elapsed();

        assert!(
            matches!(err, PoolError::Handshake(_)),
            "expected PoolError::Handshake, got {err:?}"
        );
        assert!(
            elapsed.as_secs() < HANDSHAKE_RECV_TIMEOUT.as_secs() + 1,
            "connect must complete within {}s, took {:?}",
            HANDSHAKE_RECV_TIMEOUT.as_secs() + 1,
            elapsed
        );
    }

    /// Two tasks sharing one client via an outer mutex must each get their own response.
    ///
    /// Every request must receive a response, and no response may be duplicated or delivered to
    /// the wrong task.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_connect_handshake_concurrent_two_clients() {
        const REQS_PER_TASK: usize = 10;

        let (_dir, sock_path) = temp_sock();
        let _ = std::fs::remove_file(&sock_path);
        let listener = UnixListener::bind(&sock_path).expect("bind");
        let server_handle = spawn_server(listener, |mut inner| async move {
            server_handshake_ok(&mut inner).await;
            // Hand out a distinct session id per request until the client asks to shut down.
            let mut counter: u32 = 0;
            loop {
                let req = server_recv(&mut inner).await;
                if matches!(req, PoolRequest::Shutdown) {
                    server_send(
                        &mut inner,
                        &PoolResponse::success(PoolResponseData::Shutdown),
                    )
                    .await;
                    break;
                }
                let sid = format!("sid-{counter}");
                counter += 1;
                let feed_result = serde_json::json!({
                    "type": "finished",
                    "session_id": sid.as_str(),
                });
                server_send_feed(&mut inner, &sid, feed_result).await;
            }
        });

        let client = Arc::new(tokio::sync::Mutex::new(
            PoolClient::connect(&sock_path).await.expect("connect"),
        ));

        let task_a = spawn_run_loop(Arc::clone(&client), REQS_PER_TASK);
        let task_b = spawn_run_loop(Arc::clone(&client), REQS_PER_TASK);
        let mut session_ids = task_a.await.expect("task_a panicked");
        session_ids.extend(task_b.await.expect("task_b panicked"));

        {
            let mut guard = client.lock().await;
            let _ = guard.send_request(PoolRequest::Shutdown).await;
        }

        assert_eq!(
            session_ids.len(),
            REQS_PER_TASK * 2,
            "expected {} responses, got {}",
            REQS_PER_TASK * 2,
            session_ids.len()
        );
        let unique: HashSet<&String> = session_ids.iter().collect();
        assert_eq!(
            unique.len(),
            session_ids.len(),
            "duplicate session_ids detected: {session_ids:?}"
        );

        server_handle.await.expect("server task");
    }
}