use std::collections::HashMap;
use std::future::Future;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex as StdMutex};
use std::time::Duration;
use car_engine::Runtime;
use car_inference::tasks::generate::{ContentBlock, Message};
use serde_json::{json, Value};
use tokio::sync::{mpsc, oneshot, Mutex as AsyncMutex};
use super::agent_loop::{
run_assistant_loop_cancellable, ApprovalDecision, ApprovalGate, AssistantEvent,
};
use super::AssistantConfig;
use crate::coder::native_loop::TurnGenerator;
/// How long a gated tool call waits for the user's decision before the
/// approval gate gives up and denies it (five minutes).
const APPROVAL_TIMEOUT: Duration = Duration::from_secs(5 * 60);
/// Chat front-end for the assistant loop.
///
/// Keeps a conversation thread per session id, the cancellation flag of the
/// turn running for each session, and the tool approvals that are waiting on
/// the user; each turn's events are streamed out as JSON values.
pub struct AssistantService {
/// Model backend used to produce each assistant turn.
generator: Arc<dyn TurnGenerator>,
/// Runtime handed to the assistant loop for every turn.
runtime: Arc<Runtime>,
/// Loop configuration shared by every turn.
cfg: AssistantConfig,
/// System prompt that seeds a session's thread on its first turn.
system: String,
/// Saved conversation per session id (system message first).
threads: AsyncMutex<HashMap<String, Vec<Message>>>,
/// Cancellation flag of the turn currently registered for each session id.
cancels: StdMutex<HashMap<String, Arc<AtomicBool>>>,
/// Pending approvals keyed by approval id; shared with each turn's
/// approval gate and completed by `resolve_approval`.
approvals: Arc<StdMutex<HashMap<String, oneshot::Sender<bool>>>>,
}
impl AssistantService {
pub fn new(
generator: Arc<dyn TurnGenerator>,
runtime: Arc<Runtime>,
cfg: AssistantConfig,
system: String,
) -> Self {
Self {
generator,
runtime,
cfg,
system,
threads: AsyncMutex::new(HashMap::new()),
cancels: StdMutex::new(HashMap::new()),
approvals: Arc::new(StdMutex::new(HashMap::new())),
}
}
pub fn cancel(&self, session_id: &str) {
if let Ok(g) = self.cancels.lock() {
if let Some(flag) = g.get(session_id) {
flag.store(true, Ordering::Relaxed);
}
}
}
pub fn resolve_approval(&self, approval_id: &str, approved: bool) -> bool {
let tx = self
.approvals
.lock()
.ok()
.and_then(|mut g| g.remove(approval_id));
match tx {
Some(tx) => tx.send(approved).is_ok(),
None => false,
}
}
pub async fn handle_turn<E, Fut>(
&self,
session_id: &str,
prompt: &str,
attachments: Option<Vec<Value>>,
emit: E,
) where
E: Fn(Value) -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
let images: Vec<ContentBlock> = attachments
.unwrap_or_default()
.into_iter()
.filter_map(|a| serde_json::from_value::<ContentBlock>(a).ok())
.filter(|c| {
matches!(
c,
ContentBlock::ImageBase64 { .. } | ContentBlock::ImageUrl { .. }
)
})
.collect();
let cancel = Arc::new(AtomicBool::new(false));
if let Ok(mut g) = self.cancels.lock() {
g.insert(session_id.to_string(), cancel.clone());
}
let mut messages = {
let mut g = self.threads.lock().await;
g.entry(session_id.to_string())
.or_insert_with(|| {
vec![Message::System {
content: self.system.clone(),
}]
})
.clone()
};
messages.push(Message::User {
content: prompt.to_string(),
});
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<Value>();
let drain = tokio::spawn(async move {
while let Some(v) = rx.recv().await {
emit(v).await;
}
});
let sid = session_id.to_string();
let tx_term = tx.clone();
let gate = ChatApprovalGate {
session_id: sid.clone(),
tx: tx.clone(),
approvals: self.approvals.clone(),
counter: AtomicU64::new(0),
};
let outcome = run_assistant_loop_cancellable(
&*self.generator,
&self.runtime,
&self.cfg,
&mut messages,
&cancel,
Some(&gate),
if images.is_empty() {
None
} else {
Some(images.as_slice())
},
{
let sid = sid.clone();
move |ev| {
if let Some(mut payload) = event_to_wire(ev) {
payload["session_id"] = json!(sid);
let _ = tx.send(payload);
}
}
},
)
.await;
drop(gate);
match outcome.status {
"max_turns" => {
let _ = tx_term
.send(json!({ "kind": "done", "text": outcome.summary, "session_id": sid }));
}
"cancelled" => {
let _ = tx_term
.send(json!({ "kind": "error", "error": "cancelled", "session_id": sid }));
}
_ => {}
}
drop(tx_term);
let _ = drain.await;
if outcome.status != "cancelled" {
let mut g = self.threads.lock().await;
g.insert(session_id.to_string(), messages);
}
if let Ok(mut g) = self.cancels.lock() {
g.remove(session_id);
}
}
}
/// `ApprovalGate` that asks the chat client to approve gated tool calls.
///
/// One gate is built per turn by `AssistantService::handle_turn`.
struct ChatApprovalGate {
/// Session the gate's turn belongs to; prefixes approval ids and is
/// attached to every emitted event.
session_id: String,
/// Sender into the turn's event channel, used for `approval_pending` events.
tx: mpsc::UnboundedSender<Value>,
/// Shared registry of pending approvals keyed by approval id; completed by
/// `AssistantService::resolve_approval`.
approvals: Arc<StdMutex<HashMap<String, oneshot::Sender<bool>>>>,
/// Number of approvals requested through this gate so far.
counter: AtomicU64,
}
#[async_trait::async_trait]
impl ApprovalGate for ChatApprovalGate {
    /// Asks the chat client to approve a call to `tool` with `params`.
    ///
    /// Registers a oneshot sender in the shared `approvals` map under a fresh,
    /// process-unique approval id. It then emits an `approval_pending` event
    /// carrying that id and waits up to `APPROVAL_TIMEOUT` for
    /// `AssistantService::resolve_approval` to answer.
    ///
    /// If no answer arrives (timeout, or the sender was dropped without
    /// sending), the pending entry is removed and the call is denied.
    async fn request(&self, tool: &str, params: &Value) -> ApprovalDecision {
        // `counter` restarts at 0 for every turn's gate, so `{session}-appr-{n}`
        // on its own repeats across turns and collides between concurrent turns
        // of one session. A late decision for an old request could then resolve
        // a new, different tool call. The process-wide sequence keeps every id
        // unique.
        static APPROVAL_SEQ: AtomicU64 = AtomicU64::new(0);
        let seq = APPROVAL_SEQ.fetch_add(1, Ordering::Relaxed);
        let n = self.counter.fetch_add(1, Ordering::Relaxed);
        let approval_id = format!("{}-appr-{seq}-{n}", self.session_id);

        let (otx, orx) = oneshot::channel();
        if let Ok(mut g) = self.approvals.lock() {
            g.insert(approval_id.clone(), otx);
        }
        // Best-effort: if the event channel is closed the client never sees the
        // request, and the wait below ends in a denial.
        let _ = self.tx.send(json!({
            "kind": "approval_pending",
            "approval_id": approval_id,
            "tool": tool,
            "params": params,
            "session_id": self.session_id,
        }));

        match tokio::time::timeout(APPROVAL_TIMEOUT, orx).await {
            Ok(Ok(true)) => ApprovalDecision::Approved,
            Ok(Ok(false)) => ApprovalDecision::Denied("declined by user".into()),
            _ => {
                // Nobody answered: forget the entry so a late resolve is rejected.
                if let Ok(mut g) = self.approvals.lock() {
                    g.remove(&approval_id);
                }
                ApprovalDecision::Denied("approval timed out".into())
            }
        }
    }
}
fn event_to_wire(ev: AssistantEvent) -> Option<Value> {
match ev {
AssistantEvent::Text(t) => Some(json!({ "kind": "token", "delta": t })),
AssistantEvent::ToolCall { name, params } => {
Some(json!({ "kind": "tool_call", "tool": name, "params": params }))
}
AssistantEvent::ToolResult { .. } => None,
AssistantEvent::Done { text } => Some(json!({ "kind": "done", "text": text })),
AssistantEvent::Error(e) => Some(json!({ "kind": "error", "error": e })),
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assistant::executor::GeneralExecutor;
    use async_trait::async_trait;
    use car_engine::{LocalSubstrate, Substrate, ToolExecutor};
    use car_inference::{GenerateRequest, InferenceEngine, InferenceResult};
    use std::sync::atomic::AtomicUsize;

    /// Events captured from `handle_turn`, in delivery order.
    type Events = Arc<StdMutex<Vec<Value>>>;

    /// Builds one scripted model turn with the given text and tool calls.
    fn turn(text: &str, tool_calls: Value) -> InferenceResult {
        let raw = json!({
            "text": text,
            "tool_calls": tool_calls,
            "trace_id": "t",
            "model_used": "scripted",
            "latency_ms": 0,
        });
        serde_json::from_value(raw).unwrap()
    }

    /// A `TurnGenerator` that replays `turns` in order, then fails with
    /// `"exhausted"`.
    struct Script {
        turns: Vec<InferenceResult>,
        cursor: AtomicUsize,
    }

    impl Script {
        /// Wraps a fresh script (cursor at the first turn) as a shared generator.
        fn shared(turns: Vec<InferenceResult>) -> Arc<dyn TurnGenerator> {
            Arc::new(Script {
                turns,
                cursor: AtomicUsize::new(0),
            })
        }
    }

    #[async_trait]
    impl TurnGenerator for Script {
        async fn generate(&self, _r: GenerateRequest) -> Result<InferenceResult, String> {
            let idx = self.cursor.fetch_add(1, Ordering::SeqCst);
            match self.turns.get(idx) {
                Some(next) => Ok(next.clone()),
                None => Err(String::from("exhausted")),
            }
        }
    }

    /// Runtime with the general executor rooted at `dir` and agent basics
    /// registered.
    async fn runtime(dir: &std::path::Path) -> Arc<Runtime> {
        let substrate: Arc<dyn Substrate> = Arc::new(LocalSubstrate::new());
        let executor: Arc<dyn ToolExecutor> =
            Arc::new(GeneralExecutor::new(Arc::clone(&substrate), dir, true));
        let rt = Runtime::new()
            .with_inference(Arc::new(InferenceEngine::new(Default::default())))
            .with_executor(executor)
            .with_substrate(substrate);
        rt.register_agent_basics().await;
        Arc::new(rt)
    }

    /// Scripted-model config with a 4-turn limit and the given gated tools.
    fn config(gated: &[&str]) -> AssistantConfig {
        AssistantConfig {
            model: Some("scripted".into()),
            max_turns: 4,
            tools: GeneralExecutor::tool_defs(),
            gated_tools: gated.iter().map(|&name| name.into()).collect(),
            approval_policy: None,
        }
    }

    /// Returns a shared event log plus an `emit` callback that appends to it.
    fn recorder() -> (
        Events,
        impl Fn(Value) -> std::future::Ready<()> + Send + Sync + 'static,
    ) {
        let log: Events = Arc::new(StdMutex::new(Vec::new()));
        let sink = Arc::clone(&log);
        let emit = move |v: Value| {
            sink.lock().unwrap().push(v);
            std::future::ready(())
        };
        (log, emit)
    }

    #[tokio::test]
    async fn chat_turn_streams_tokens_and_done() {
        let dir = tempfile::tempdir().unwrap();
        let rt = runtime(dir.path()).await;
        let generator = Script::shared(vec![
            turn(
                "let me compute",
                json!([{ "id": "c1", "name": "calculate", "arguments": { "expression": "2+2" } }]),
            ),
            turn("It's 4.", json!([])),
        ]);
        let svc = AssistantService::new(generator, rt, config(&[]), "sys".into());
        let (events, emit) = recorder();

        svc.handle_turn("s1", "what is 2+2?", None, emit).await;

        let got = events.lock().unwrap().clone();
        assert!(got.iter().all(|e| e["session_id"] == "s1"));
        assert!(got
            .iter()
            .any(|e| e["kind"] == "tool_call" && e["tool"] == "calculate"));
        let last = got.last().unwrap();
        assert_eq!(last["kind"], "done");
        assert_eq!(last["text"], "It's 4.");
        let thread_len = svc.threads.lock().await.get("s1").map(Vec::len).unwrap();
        assert!(thread_len >= 3, "thread should persist across the turn");
    }

    #[tokio::test]
    async fn every_chat_event_is_forwardable_by_the_daemon() {
        const KNOWN_KINDS: [&str; 5] = ["token", "tool_call", "approval_pending", "done", "error"];

        let dir = tempfile::tempdir().unwrap();
        let rt = runtime(dir.path()).await;
        let generator = Script::shared(vec![
            turn(
                "let me compute",
                json!([{ "id": "c1", "name": "calculate", "arguments": { "expression": "1+1" } }]),
            ),
            turn("It's 2.", json!([])),
        ]);
        let svc = AssistantService::new(generator, rt, config(&[]), "sys".into());
        let (events, emit) = recorder();

        svc.handle_turn("sess-42", "1+1?", None, emit).await;

        let got = events.lock().unwrap().clone();
        assert!(!got.is_empty());
        for e in got.iter() {
            let session = e.get("session_id").and_then(Value::as_str);
            assert_eq!(
                session,
                Some("sess-42"),
                "every event must carry its session_id: {e}"
            );
            let kind = e.get("kind").and_then(Value::as_str).unwrap_or("");
            assert!(KNOWN_KINDS.contains(&kind), "unknown event kind: {e}");
        }
        assert_eq!(got.last().unwrap()["kind"], "done");
    }

    #[tokio::test]
    async fn chat_gated_write_emits_approval_and_resumes_on_approve() {
        let dir = tempfile::tempdir().unwrap();
        let rt = runtime(dir.path()).await;
        let generator = Script::shared(vec![
            turn(
                "",
                json!([{ "id": "w1", "name": "write_file", "arguments": { "path": "z.txt", "content": "zephyr" } }]),
            ),
            turn("done", json!([])),
        ]);
        let svc = Arc::new(AssistantService::new(
            generator,
            rt,
            config(&["write_file"]),
            "sys".into(),
        ));
        let (events, emit) = recorder();

        let runner = Arc::clone(&svc);
        let turn_task = tokio::spawn(async move {
            runner.handle_turn("s1", "write z.txt", None, emit).await;
        });

        // Poll (up to ~2s) for the approval request emitted by the gate.
        let mut pending_id: Option<String> = None;
        for _ in 0..200 {
            pending_id = events
                .lock()
                .unwrap()
                .iter()
                .find(|e| e["kind"] == "approval_pending")
                .and_then(|e| e["approval_id"].as_str().map(String::from));
            if pending_id.is_some() {
                break;
            }
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
        let id = pending_id.expect("an approval_pending event should have been emitted");
        assert!(svc.resolve_approval(&id, true));

        turn_task.await.unwrap();
        assert_eq!(
            std::fs::read_to_string(dir.path().join("z.txt")).unwrap(),
            "zephyr"
        );
    }
}