use std::path::{Path, PathBuf};
use tokio::time::{sleep, timeout, Duration};
use crate::pool::{
registry::with_registry_lock, PoolClient, PoolError, PoolRegistry, PoolRequest, PoolResponse,
PoolResponseData, PoolSessionEntry,
};
/// Spawns a `pool-worker` child process of the current executable.
///
/// The child is started as
/// `<current_exe> pool-worker --sid <sid> --sock <pool_dir>/<sid>.sock`.
/// On Unix a `pre_exec` hook calls `setsid()` in the child before exec.
/// A background task awaits the child's exit and logs the outcome, so the
/// caller does not have to keep the `Child` handle around.
///
/// # Returns
/// The child's PID and the socket path that was passed to the worker.
///
/// # Errors
/// Returns [`PoolError::Spawn`] when the current executable cannot be
/// resolved, when spawning fails, or when the child's PID is unavailable.
async fn spawn_worker(pool_dir: &Path, sid: &str) -> Result<(u32, PathBuf), PoolError> {
    let sock = pool_dir.join(format!("{sid}.sock"));
    let exe = std::env::current_exe()
        .map_err(|e| PoolError::Spawn(format!("current_exe failed: {e}")))?;

    let mut cmd = tokio::process::Command::new(&exe);
    // Hand the socket path over as an `OsStr` rather than a lossy string: a
    // lossy conversion would replace non-UTF-8 bytes, so the worker would
    // receive a different path than the one returned to the caller.
    cmd.arg("pool-worker")
        .arg("--sid")
        .arg(sid)
        .arg("--sock")
        .arg(&sock);

    #[cfg(unix)]
    {
        // SAFETY: the hook runs in the forked child before exec and only
        // calls `setsid`, which POSIX lists as async-signal-safe; it does not
        // allocate or touch state shared with the parent.
        unsafe {
            cmd.pre_exec(|| {
                // The return value is deliberately ignored (best effort).
                libc::setsid();
                Ok(())
            });
        }
    }

    let mut child = cmd
        .spawn()
        .map_err(|e| PoolError::Spawn(format!("worker spawn failed: {e}")))?;
    let pid = child.id().ok_or_else(|| {
        PoolError::Spawn("child.id() returned None — process already exited".to_string())
    })?;

    // Await the child's exit in the background so its status is collected
    // and logged once the worker terminates.
    let sid_owned = sid.to_string();
    tokio::spawn(async move {
        match child.wait().await {
            Ok(status) => tracing::debug!(sid = %sid_owned, ?status, "pool worker reaped"),
            Err(e) => tracing::warn!(sid = %sid_owned, error = %e, "pool worker wait error"),
        }
    });

    Ok((pid, sock))
}
/// Builds a fresh pool session id of the form `p-<nanos-hex>-<16-hex-digits>`.
///
/// The middle part is the current time in nanoseconds since the Unix epoch
/// (zero if the clock reads earlier than the epoch). The suffix is that
/// timestamp hashed with a newly created `RandomState`, so two calls in the
/// same nanosecond still produce different ids.
fn gen_pool_sid() -> String {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::time::{SystemTime, UNIX_EPOCH};

    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_nanos())
        .unwrap_or(0);

    let mut hasher = RandomState::new().build_hasher();
    hasher.write_u128(nanos);
    let salt: u64 = hasher.finish();

    format!("p-{nanos:x}-{salt:016x}")
}
/// Launches a new pool worker, runs `code` on it, and records the session.
///
/// Sequence:
/// 1. generate a session id and spawn a worker for it;
/// 2. poll every 50 ms, for at most 5 s, until the worker's socket path exists;
/// 3. connect and send a `Run` request carrying `code`, `ctx` and
///    `extra_lib_paths`;
/// 4. convert the worker's feed result into the MCP JSON reply;
/// 5. store a registry entry for the session under the registry lock.
///
/// # Returns
/// `(session_id, mcp_json_string, pool_save_error)`, where `session_id` is the
/// id reported by the worker and `pool_save_error` is `Some(message)` when the
/// registry entry could not be persisted (the run itself still succeeds).
///
/// # Errors
/// Propagates spawn, connect, request and response-extraction failures, and
/// returns [`PoolError::Handshake`] when the socket does not appear in time.
pub async fn run_via_pool(
    pool_dir: &Path,
    reg_path: &Path,
    lock_path: &Path,
    extra_lib_paths: Vec<PathBuf>,
    code: String,
    ctx: serde_json::Value,
) -> Result<(String, String, Option<String>), PoolError> {
    let sid = gen_pool_sid();
    let (pid, sock) = spawn_worker(pool_dir, &sid).await?;

    // Wait for the worker to create its socket before trying to connect.
    let socket_ready = async {
        while !sock.exists() {
            sleep(Duration::from_millis(50)).await;
        }
    };
    if timeout(Duration::from_secs(5), socket_ready).await.is_err() {
        return Err(PoolError::Handshake(format!(
            "timeout waiting for worker socket at {}",
            sock.display()
        )));
    }

    let mut client = PoolClient::connect(&sock).await?;
    let request = PoolRequest::Run {
        code,
        ctx: Some(ctx),
        lib_paths: extra_lib_paths,
    };
    let resp = client.send_request(request).await?;

    let (worker_sid, feed_result_json) = extract_feed_response(resp)?;
    let mcp_json = feed_result_to_mcp_json(&worker_sid, &feed_result_json);

    // A failure to persist the entry is surfaced to the caller as a string,
    // not as an error, so the run result is still delivered.
    let entry = PoolSessionEntry::new(&worker_sid, pid, sock, env!("CARGO_PKG_VERSION"));
    let pool_save_error =
        persist_entry(reg_path.to_path_buf(), lock_path.to_path_buf(), entry).await;

    Ok((worker_sid, mcp_json.to_string(), pool_save_error))
}
/// Forwards a response to an existing pool session and returns the MCP reply.
///
/// Connects to the socket stored in `entry`, sends a `Continue` request built
/// from `sid`, `response`, `query_id` and `usage`, and converts the worker's
/// feed result into the MCP JSON string.
///
/// # Errors
/// Propagates connection and request failures, and the error produced when
/// the worker's reply is not a feed response.
pub async fn continue_via_pool(
    entry: &PoolSessionEntry,
    sid: &str,
    response: String,
    query_id: Option<String>,
    usage: Option<algocline_core::TokenUsage>,
) -> Result<String, PoolError> {
    let request = PoolRequest::Continue {
        sid: sid.to_owned(),
        response,
        query_id,
        usage,
    };

    let mut client = PoolClient::connect(&entry.sock).await?;
    let reply = client.send_request(request).await?;

    let (session_id, feed_result_json) = extract_feed_response(reply)?;
    Ok(feed_result_to_mcp_json(&session_id, &feed_result_json).to_string())
}
/// Translates a worker feed result into the JSON object returned over MCP.
///
/// The input is inspected for one of these top-level keys, in order:
///
/// * `"Paused"` — yields `status: "needs_response"` plus `session_id`.
///   With exactly one query the query's fields are flattened into the reply
///   (`query_id`, `prompt`, `system`, `max_tokens`); with any other number of
///   queries they are listed under `queries`; without a `queries` array only
///   the status and session id are returned. `grounded` / `underspecified`
///   are copied only when they are `true`.
/// * `"Finished"` — yields `completed` (with `result` and `stats`), `error`
///   (with `error`), or otherwise `cancelled` (with `stats`), depending on
///   whether `state` holds `Completed` or `Failed`.
/// * `"Accepted"` — yields `status: "accepted"` and `remaining` (default `0`).
///
/// Anything else becomes `status: "error"` describing the unrecognised shape.
fn feed_result_to_mcp_json(session_id: &str, feed_result: &serde_json::Value) -> serde_json::Value {
    use serde_json::{json, Value};

    // Clones `key` out of `src`, falling back to `fallback` when it is absent.
    fn field_or(src: &Value, key: &str, fallback: Value) -> Value {
        src.get(key).cloned().unwrap_or(fallback)
    }

    // Copies the optional boolean markers onto `dst`, but only when the
    // query carries them as a literal `true`.
    fn copy_true_flags(src: &Value, dst: &mut Value) {
        for flag in ["grounded", "underspecified"] {
            if src.get(flag).and_then(Value::as_bool) == Some(true) {
                dst[flag] = json!(true);
            }
        }
    }

    if let Some(paused) = feed_result.get("Paused") {
        let queries = paused
            .get("queries")
            .and_then(Value::as_array)
            .map(Vec::as_slice);

        return match queries {
            None => json!({
                "status": "needs_response",
                "session_id": session_id,
            }),
            // Exactly one query: flatten it into the top-level reply.
            Some([only]) => {
                let query_id = only.get("id").and_then(Value::as_str).unwrap_or("q-0");
                let mut reply = json!({
                    "status": "needs_response",
                    "session_id": session_id,
                    "query_id": query_id,
                    "prompt": field_or(only, "prompt", Value::Null),
                    "system": field_or(only, "system", Value::Null),
                    "max_tokens": field_or(only, "max_tokens", json!(1024)),
                });
                copy_true_flags(only, &mut reply);
                reply
            }
            // Zero or several queries: report them as a list.
            Some(many) => {
                let mut mapped: Vec<Value> = Vec::with_capacity(many.len());
                for query in many {
                    let mut item = json!({
                        "id": field_or(query, "id", json!("q-0")),
                        "prompt": field_or(query, "prompt", Value::Null),
                        "system": field_or(query, "system", Value::Null),
                        "max_tokens": field_or(query, "max_tokens", json!(1024)),
                    });
                    copy_true_flags(query, &mut item);
                    mapped.push(item);
                }
                json!({
                    "status": "needs_response",
                    "session_id": session_id,
                    "queries": mapped,
                })
            }
        };
    }

    if let Some(finished) = feed_result.get("Finished") {
        let state = finished.get("state");
        let stats = || finished.get("metrics").cloned().unwrap_or(Value::Null);

        if let Some(completed) = state.and_then(|s| s.get("Completed")) {
            return json!({
                "status": "completed",
                "result": field_or(completed, "result", Value::Null),
                "stats": stats(),
            });
        }
        if let Some(failed) = state.and_then(|s| s.get("Failed")) {
            let message = failed
                .get("error")
                .and_then(Value::as_str)
                .unwrap_or("execution failed");
            return json!({
                "status": "error",
                "error": message,
            });
        }
        // Neither Completed nor Failed is present: reported as cancelled.
        return json!({
            "status": "cancelled",
            "stats": stats(),
        });
    }

    if let Some(accepted) = feed_result.get("Accepted") {
        return json!({
            "status": "accepted",
            "remaining": field_or(accepted, "remaining", json!(0)),
        });
    }

    json!({
        "status": "error",
        "error": format!("unrecognised FeedResult shape from worker: {feed_result}"),
    })
}
fn extract_feed_response(resp: PoolResponse) -> Result<(String, serde_json::Value), PoolError> {
match resp.data {
Some(PoolResponseData::Feed {
session_id,
feed_result,
}) => Ok((session_id, feed_result)),
Some(other) => Err(PoolError::ResponseParse(format!(
"expected Feed response, got {other:?}"
))),
None => {
let err = resp.error.unwrap_or_else(|| "unknown error".to_string());
Err(PoolError::ResponseParse(format!(
"worker returned error: {err}"
)))
}
}
}
/// Adds `entry` to the registry file at `reg_path` while holding the lock at
/// `lock_path`.
///
/// The load / add / save sequence runs on the blocking thread pool via
/// `spawn_blocking`, so the synchronous registry calls stay off the async
/// executor.
///
/// # Returns
/// `None` on success. `Some(message)` when loading or saving the registry
/// (or taking the lock) fails, or when the blocking task could not be joined.
async fn persist_entry(
    reg_path: PathBuf,
    lock_path: PathBuf,
    entry: PoolSessionEntry,
) -> Option<String> {
    let outcome = tokio::task::spawn_blocking(move || {
        with_registry_lock(&lock_path, || {
            let mut reg = PoolRegistry::load_or_default(&reg_path)?;
            reg.add(entry);
            reg.save(&reg_path)
        })
    })
    .await;

    match outcome {
        Ok(Ok(())) => None,
        Ok(Err(e)) => Some(e.to_string()),
        Err(e) => Some(format!("spawn_blocking join error: {e}")),
    }
}
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use super::*;
    use crate::pool::{protocol::PoolResponseData, PoolResponse};

    /// A `Feed` payload is unpacked into its session id and feed result.
    #[test]
    fn extract_feed_response_ok_on_feed() {
        let resp = PoolResponse::success(PoolResponseData::Feed {
            session_id: "test-sid".to_string(),
            feed_result: serde_json::json!({"status": "needs_response"}),
        });
        let (sid, json) = extract_feed_response(resp).expect("should extract feed");
        assert_eq!(sid, "test-sid");
        assert_eq!(json["status"], "needs_response");
    }

    /// Consecutive ids never collide.
    #[test]
    fn gen_pool_sid_is_unique() {
        let ids: Vec<_> = (0..20).map(|_| gen_pool_sid()).collect();
        let unique: std::collections::HashSet<_> = ids.iter().collect();
        assert_eq!(
            unique.len(),
            ids.len(),
            "all generated session IDs must be distinct"
        );
    }

    /// Every id carries the `p-` prefix.
    #[test]
    fn gen_pool_sid_has_prefix() {
        let sid = gen_pool_sid();
        assert!(sid.starts_with("p-"), "sid must start with 'p-', got {sid}");
    }

    /// A successful response with a non-`Feed` payload is rejected.
    #[test]
    fn extract_feed_response_error_on_non_feed() {
        let resp = PoolResponse::success(PoolResponseData::Shutdown);
        let err = extract_feed_response(resp).expect_err("should fail on Shutdown response");
        assert!(
            matches!(err, PoolError::ResponseParse(_)),
            "expected ResponseParse, got {err:?}"
        );
    }

    /// A failure response surfaces the worker's message in the error.
    #[test]
    fn extract_feed_response_error_on_failure_response() {
        let resp = PoolResponse::failure("something went wrong");
        let err = extract_feed_response(resp).expect_err("should fail on error response");
        match err {
            PoolError::ResponseParse(msg) => {
                assert!(
                    msg.contains("something went wrong"),
                    "error must include worker message, got: {msg}"
                );
            }
            other => panic!("expected ResponseParse, got {other:?}"),
        }
    }

    /// Registry and lock paths that sit "inside" a regular file cannot be
    /// created, so persisting must report an error string.
    #[tokio::test]
    async fn persist_entry_returns_some_on_io_error() {
        let dir = tempfile::tempdir().expect("tempdir");
        let blocker = dir.path().join("blocker");
        std::fs::write(&blocker, b"not a dir").expect("write blocker");
        let reg_path = blocker.join("registry.json");
        let lock_path = blocker.join("registry.lock");
        let entry = PoolSessionEntry::new(
            "test-sid",
            std::process::id(),
            PathBuf::from("/tmp/test.sock"),
            "0.30.0",
        );
        let result = persist_entry(reg_path, lock_path, entry).await;
        assert!(
            result.is_some(),
            "persist_entry must return Some(error) on I/O failure"
        );
    }

    /// A single paused query is flattened into the top-level reply.
    #[test]
    fn feed_result_to_mcp_json_paused_single_query() {
        let feed = serde_json::json!({
            "Paused": {
                "queries": [{
                    "id": "q-0",
                    "prompt": "What is 1+1?",
                    "system": null,
                    "max_tokens": 1024,
                    "grounded": false,
                    "underspecified": false
                }]
            }
        });
        let mcp = feed_result_to_mcp_json("sid-abc", &feed);
        assert_eq!(mcp["status"], "needs_response");
        assert_eq!(mcp["session_id"], "sid-abc");
        assert_eq!(mcp["query_id"], "q-0");
        assert_eq!(mcp["prompt"], "What is 1+1?");
    }

    /// A completed run exposes its result under `result`.
    #[test]
    fn feed_result_to_mcp_json_finished_completed() {
        let feed = serde_json::json!({
            "Finished": {
                "state": {
                    "Completed": { "result": {"answer": 42} }
                },
                "metrics": {}
            }
        });
        let mcp = feed_result_to_mcp_json("sid-xyz", &feed);
        assert_eq!(mcp["status"], "completed");
        assert_eq!(mcp["result"]["answer"], 42);
    }

    /// Several paused queries are reported as a `queries` array.
    #[test]
    fn feed_result_to_mcp_json_paused_multi_query() {
        let feed = serde_json::json!({
            "Paused": {
                "queries": [
                    {"id": "q-0", "prompt": "P1", "system": null, "max_tokens": 512},
                    {"id": "q-1", "prompt": "P2", "system": null, "max_tokens": 512}
                ]
            }
        });
        let mcp = feed_result_to_mcp_json("sid-multi", &feed);
        assert_eq!(mcp["status"], "needs_response");
        assert_eq!(mcp["session_id"], "sid-multi");
        let qs = mcp["queries"].as_array().expect("queries array");
        assert_eq!(qs.len(), 2);
        assert_eq!(qs[0]["id"], "q-0");
    }

    /// An unknown top-level key is mapped to an error reply.
    #[test]
    fn feed_result_to_mcp_json_unknown_shape_is_error() {
        let feed = serde_json::json!({"Unknown": {}});
        let mcp = feed_result_to_mcp_json("sid-bad", &feed);
        assert_eq!(mcp["status"], "error");
        assert!(
            mcp["error"].as_str().unwrap_or("").contains("unrecognised"),
            "error message must mention 'unrecognised', got: {}",
            mcp["error"]
        );
    }

    /// Mirrors the reaper task in `spawn_worker`: once a spawned task has
    /// awaited the child, the PID must no longer exist.
    ///
    /// Unix-only: relies on the `true` binary and `libc::kill`.
    #[cfg(unix)]
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_spawn_worker_reaps_child_no_zombie() {
        let mut cmd = tokio::process::Command::new("true");
        let mut child = cmd.spawn().expect("spawn true");
        let pid = child.id().expect("child.id() must be Some before wait");

        // Await the reaper instead of sleeping for a fixed interval, so the
        // check below cannot race with a slow scheduler.
        let reaper = tokio::spawn(async move {
            let _ = child.wait().await;
        });
        reaper.await.expect("reaper task must not panic");

        let pid_i32 = i32::try_from(pid).expect("pid fits i32");
        // SAFETY: signal 0 only performs the existence/permission check and
        // delivers no signal to any process.
        let rc = unsafe { libc::kill(pid_i32, 0) };
        // Capture errno immediately, before any other call can overwrite it.
        let errno = std::io::Error::last_os_error().raw_os_error().unwrap_or(0);

        assert_eq!(
            rc, -1,
            "process should be gone (kill(pid,0) must return -1)"
        );
        assert_eq!(
            errno,
            libc::ESRCH,
            "errno must be ESRCH (no such process), got {errno}"
        );
    }

    /// Mirrors the `child.id()` handling in `spawn_worker`: after the child
    /// has been waited on, `id()` is `None` and maps to `PoolError::Spawn`.
    ///
    /// Unix-only: relies on the `false` binary.
    #[cfg(unix)]
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_spawn_worker_child_id_none_returns_pool_error() {
        let mut child = tokio::process::Command::new("false")
            .spawn()
            .expect("spawn false");
        let _status = child.wait().await.expect("wait");
        let id = child.id();
        assert!(
            id.is_none(),
            "child.id() must be None after wait(): got {:?}",
            id
        );
        let result: Result<u32, crate::pool::PoolError> = id.ok_or_else(|| {
            crate::pool::PoolError::Spawn(
                "child.id() returned None — process already exited".to_string(),
            )
        });
        assert!(
            matches!(result, Err(crate::pool::PoolError::Spawn(_))),
            "expected Err(PoolError::Spawn), got {:?}",
            result
        );
    }
}