use crate::daemon::event::{DaemonEvent, DaemonEventSender};
use runtime::host::Host;
use std::{
collections::HashMap,
path::PathBuf,
sync::{
Arc,
atomic::{AtomicU64, Ordering},
},
time::Duration,
};
use tokio::sync::{Mutex, broadcast, mpsc, oneshot};
use wcore::{
AgentEvent,
protocol::message::{AgentEventKind, AgentEventMsg, ClientMessage, SendMsg, server_message},
};
/// How long `ask_user` waits for a reply before giving up (five minutes).
const ASK_USER_TIMEOUT: Duration = Duration::from_secs(5 * 60);
/// Daemon-side implementation of the runtime `Host` trait.
///
/// Every field is either a channel sender or an `Arc`, so clones share the same
/// pending-ask table, working-directory table and outgoing channels.
#[derive(Clone)]
pub struct DaemonHost {
    /// Outgoing daemon events (delegated messages, kills, published events).
    pub(crate) event_tx: DaemonEventSender,
    /// Broadcast of per-agent progress events to any subscribed clients.
    pub(crate) events_tx: broadcast::Sender<AgentEventMsg>,
    /// One pending `ask_user` reply slot per conversation id.
    pub(crate) pending_asks: Arc<Mutex<HashMap<u64, oneshot::Sender<String>>>>,
    /// Working directory registered for each conversation id.
    pub(crate) conversation_cwds: Arc<Mutex<HashMap<u64, PathBuf>>>,
}
impl Host for DaemonHost {
    /// Ask the user a question on behalf of an agent and wait for the answer.
    ///
    /// `args` is the JSON payload of the `ask_user` tool call. The call registers a
    /// one-shot reply slot in `pending_asks` keyed by `conversation_id`, then waits
    /// up to [`ASK_USER_TIMEOUT`] for `reply_to_ask` to fill it. Every outcome is
    /// reported through the returned string, which becomes the tool result:
    /// invalid arguments, non-streaming mode (`conversation_id` is `None`),
    /// cancellation and timeout all produce explanatory text.
    async fn dispatch_ask_user(&self, args: &str, conversation_id: Option<u64>) -> String {
        let input: runtime::ask_user::AskUser = match serde_json::from_str(args) {
            Ok(v) => v,
            Err(e) => return format!("invalid arguments: {e}"),
        };
        let Some(conversation_id) = conversation_id else {
            return "ask_user is only available in streaming mode".to_owned();
        };
        let (tx, rx) = oneshot::channel();
        // Inserting replaces (and thereby cancels) any older ask for this conversation.
        self.pending_asks.lock().await.insert(conversation_id, tx);
        // Bind the outcome first so the timeout future, and `rx` with it, is dropped
        // before the cleanup below inspects the map.
        let outcome = tokio::time::timeout(ASK_USER_TIMEOUT, rx).await;
        match outcome {
            Ok(Ok(reply)) => reply,
            // Our sender was dropped, which can only happen after it left the map
            // (cleared, or replaced by a newer ask). Removing by key here would
            // discard the newer ask's sender, so there is nothing to clean up.
            Ok(Err(_)) => "ask_user cancelled: reply channel closed".to_owned(),
            Err(_) => {
                {
                    let mut asks = self.pending_asks.lock().await;
                    // Our receiver is gone, so our own sender reports closed; a newer
                    // ask's sender (whose receiver is alive) does not. Only drop ours.
                    if asks
                        .get(&conversation_id)
                        .is_some_and(|pending| pending.is_closed())
                    {
                        asks.remove(&conversation_id);
                    }
                }
                let headers: Vec<&str> =
                    input.questions.iter().map(|q| q.header.as_str()).collect();
                format!(
                    "ask_user timed out after {}s: no reply received for: {}",
                    ASK_USER_TIMEOUT.as_secs(),
                    headers.join("; "),
                )
            }
        }
    }

    /// Run one or more agent tasks on behalf of the calling agent.
    ///
    /// `args` is the JSON payload of the delegate tool call. Each task is sent to
    /// its target agent under a fresh `delegate:<n>` sender id. When
    /// `input.background` is set, the join handles are dropped (the tasks keep
    /// running detached) and a JSON list of `{agent, task_id}` is returned
    /// immediately. Otherwise every task is awaited in order and a JSON list of
    /// `{agent, result, error}` is returned.
    async fn dispatch_delegate(&self, args: &str, _agent: &str) -> String {
        let input: runtime::task::Delegate = match serde_json::from_str(args) {
            Ok(v) => v,
            Err(e) => return format!("invalid arguments: {e}"),
        };
        let mut tasks = Vec::with_capacity(input.tasks.len());
        for task in input.tasks {
            let sender = delegate_sender();
            let handle = spawn_agent_task(
                task.agent.clone(),
                task.message,
                sender.clone(),
                self.event_tx.clone(),
            );
            tasks.push((task.agent, sender, handle));
        }
        if input.background {
            // Dropping a JoinHandle detaches the task; it is not cancelled.
            let results: Vec<_> = tasks
                .into_iter()
                .map(|(agent, sender, _handle)| {
                    serde_json::json!({ "agent": agent, "task_id": sender })
                })
                .collect();
            return serde_json::to_string(&results)
                .unwrap_or_else(|e| format!("serialization error: {e}"));
        }
        let mut results = Vec::with_capacity(tasks.len());
        for (agent_name, _sender, handle) in tasks {
            let (result, error) = match handle.await {
                Ok((r, e)) => (r, e),
                Err(e) => (None, Some(format!("task panicked: {e}"))),
            };
            results.push(serde_json::json!({
                "agent": agent_name,
                "result": result,
                "error": error,
            }));
        }
        serde_json::to_string(&results).unwrap_or_else(|e| format!("serialization error: {e}"))
    }

    /// Return the working directory registered for `conversation_id`, if any.
    ///
    /// This method is synchronous, so it uses `try_lock` on the async mutex.
    /// NOTE(review): if another task holds the lock at that instant, this returns
    /// `None` even when a cwd is registered. Consider a `std::sync::Mutex` for this
    /// map, since it is never held across an `.await` in this file.
    fn conversation_cwd(&self, conversation_id: u64) -> Option<PathBuf> {
        self.conversation_cwds
            .try_lock()
            .ok()
            .and_then(|m| m.get(&conversation_id).cloned())
    }

    /// Translate a runtime [`AgentEvent`] into an [`AgentEventMsg`] and broadcast it.
    ///
    /// Text and thinking deltas are traced and forwarded with empty content.
    /// Tool starts carry a comma-separated list of tool labels (`bash(<first line>)`
    /// for bash calls). Tool results carry the duration, and `Done` carries a
    /// token-usage summary. `ToolCallsBegin` and `Compact` are not forwarded.
    /// On `Done`, the final response is additionally published as the daemon
    /// event `agent:<name>:done`. Send failures (e.g. no subscribers) are ignored.
    fn on_agent_event(&self, agent: &str, conversation_id: u64, event: &AgentEvent) {
        let (kind, content) = match event {
            AgentEvent::TextDelta(text) => {
                tracing::trace!(%agent, text_len = text.len(), "agent text delta");
                (AgentEventKind::TextDelta, String::new())
            }
            AgentEvent::ThinkingDelta(text) => {
                tracing::trace!(%agent, text_len = text.len(), "agent thinking delta");
                (AgentEventKind::ThinkingDelta, String::new())
            }
            AgentEvent::ToolCallsBegin(_) => return,
            AgentEvent::ToolCallsStart(calls) => {
                tracing::debug!(%agent, count = calls.len(), "agent tool calls");
                let labels: Vec<String> = calls
                    .iter()
                    .map(|c| {
                        // For bash calls, show the first line of the command.
                        if c.function.name == "bash"
                            && let Ok(v) =
                                serde_json::from_str::<serde_json::Value>(&c.function.arguments)
                            && let Some(cmd) = v.get("command").and_then(|field| field.as_str())
                        {
                            return format!("bash({})", cmd.lines().next().unwrap_or(""));
                        }
                        c.function.name.clone()
                    })
                    .collect();
                (AgentEventKind::ToolStart, labels.join(", "))
            }
            AgentEvent::ToolResult {
                call_id,
                duration_ms,
                ..
            } => {
                tracing::debug!(%agent, %call_id, %duration_ms, "agent tool result");
                (AgentEventKind::ToolResult, format!("{duration_ms}ms"))
            }
            AgentEvent::ToolCallsComplete => {
                tracing::debug!(%agent, "agent tool calls complete");
                (AgentEventKind::ToolsComplete, String::new())
            }
            AgentEvent::Compact { summary } => {
                tracing::info!(%agent, summary_len = summary.len(), "context compacted");
                return;
            }
            AgentEvent::Done(response) => {
                tracing::info!(
                    %agent,
                    iterations = response.iterations,
                    stop_reason = %response.stop_reason,
                    "agent run complete"
                );
                let content = format_usage(response);
                (AgentEventKind::Done, content)
            }
        };
        let _ = self.events_tx.send(AgentEventMsg {
            agent: agent.to_string(),
            sender: conversation_id.to_string(),
            kind: kind.into(),
            content,
            timestamp: chrono::Utc::now().to_rfc3339(),
        });
        if let AgentEvent::Done(response) = event {
            let payload = response.final_response.clone().unwrap_or_default();
            let _ = self.event_tx.send(DaemonEvent::PublishEvent {
                source: format!("agent:{}:done", agent),
                payload,
            });
        }
    }

    /// Deliver `content` as the answer to the pending `ask_user` for `session`.
    ///
    /// A reply can race ahead of `dispatch_ask_user` registering its slot, so if no
    /// ask is pending the lookup is retried once after 100 ms. Returns `Ok(true)`
    /// only when the reply was actually handed to a waiting asker, and `Ok(false)`
    /// when no ask was pending or the asker had already given up (e.g. timed out).
    /// Never returns `Err`.
    async fn reply_to_ask(&self, session: u64, content: String) -> anyhow::Result<bool> {
        let mut pending = self.pending_asks.lock().await.remove(&session);
        if pending.is_none() {
            tokio::time::sleep(Duration::from_millis(100)).await;
            pending = self.pending_asks.lock().await.remove(&session);
        }
        Ok(pending.is_some_and(|tx| tx.send(content).is_ok()))
    }

    /// Register (or replace) the working directory for `conversation`.
    async fn set_conversation_cwd(&self, conversation: u64, cwd: std::path::PathBuf) {
        self.conversation_cwds
            .lock()
            .await
            .insert(conversation, cwd);
    }

    /// Forget all per-conversation state for `conversation`.
    ///
    /// Dropping a pending ask's sender makes the waiting `ask_user` call return
    /// its "reply channel closed" message.
    async fn clear_conversation_state(&self, conversation: u64) {
        self.pending_asks.lock().await.remove(&conversation);
        self.conversation_cwds.lock().await.remove(&conversation);
    }

    /// Subscribe to the broadcast of agent progress events.
    fn subscribe_events(&self) -> Option<broadcast::Receiver<AgentEventMsg>> {
        Some(self.events_tx.subscribe())
    }
}
/// Produce a process-unique sender id for a delegated task, of the form `delegate:<n>`.
///
/// `<n>` comes from a monotonically increasing process-wide counter; only
/// uniqueness matters, so relaxed ordering is sufficient.
fn delegate_sender() -> String {
    static NEXT_ID: AtomicU64 = AtomicU64::new(0);
    let mut id = String::from("delegate:");
    id.push_str(&NEXT_ID.fetch_add(1, Ordering::Relaxed).to_string());
    id
}
fn spawn_agent_task(
agent: String,
message: String,
delegate_sender: String,
event_tx: DaemonEventSender,
) -> tokio::task::JoinHandle<(Option<String>, Option<String>)> {
tokio::spawn(async move {
let (reply_tx, mut reply_rx) = mpsc::channel(transport::REPLY_CHANNEL_CAPACITY);
let msg = ClientMessage::from(SendMsg {
agent: agent.clone(),
content: message,
sender: Some(delegate_sender.clone()),
cwd: None,
guest: None,
});
if event_tx
.send(DaemonEvent::Message {
msg,
reply: reply_tx,
})
.is_err()
{
return (None, Some("event channel closed".to_owned()));
}
let mut result_content: Option<String> = None;
let mut error_msg: Option<String> = None;
while let Some(msg) = reply_rx.recv().await {
match msg.msg {
Some(server_message::Msg::Response(resp)) => {
result_content = Some(resp.content);
}
Some(server_message::Msg::Error(err)) => {
error_msg = Some(err.message);
}
_ => {}
}
}
let (reply_tx, _) = mpsc::channel(1);
let _ = event_tx.send(DaemonEvent::Message {
msg: ClientMessage {
msg: Some(wcore::protocol::message::client_message::Msg::Kill(
wcore::protocol::message::KillMsg {
agent,
sender: delegate_sender,
},
)),
},
reply: reply_tx,
});
(result_content, error_msg)
})
}
/// Summarise token usage across every step of an agent run.
///
/// Returns an empty string when the run has no steps. Otherwise it returns
/// `"<model> <prompt> in / <completion> out"`, inserting `" (<cached> cached)"`
/// after the prompt count when any step reported prompt-cache hits. Counts are
/// rendered with [`human_tokens`].
///
/// Totals saturate at `u32::MAX`. Summing per-step prompt counts (each of which
/// includes the whole context) over a long run could otherwise overflow, which
/// panics in debug builds and wraps in release builds.
fn format_usage(response: &wcore::AgentResponse) -> String {
    if response.steps.is_empty() {
        return String::new();
    }
    let mut prompt = 0u32;
    let mut completion = 0u32;
    let mut cache_hit = 0u32;
    for step in &response.steps {
        let usage = &step.response.usage;
        prompt = prompt.saturating_add(usage.prompt_tokens);
        completion = completion.saturating_add(usage.completion_tokens);
        if let Some(hits) = usage.prompt_cache_hit_tokens {
            cache_hit = cache_hit.saturating_add(hits);
        }
    }
    let model = &response.model;
    if cache_hit > 0 {
        format!(
            "{model} {} in ({} cached) / {} out",
            human_tokens(prompt),
            human_tokens(cache_hit),
            human_tokens(completion),
        )
    } else {
        format!(
            "{model} {} in / {} out",
            human_tokens(prompt),
            human_tokens(completion),
        )
    }
}
/// Render a token count compactly: `"42"`, `"1.5k"`, `"2.5M"`.
///
/// Counts below 1 000 are printed exactly. Larger counts are printed with one
/// decimal and a `k` or `M` suffix. The switch to `M` happens at 999 950 rather
/// than 1 000 000, because counts from 999 950 to 999 999 would otherwise round
/// to the awkward `"1000.0k"`.
fn human_tokens(n: u32) -> String {
    if n >= 999_950 {
        format!("{:.1}M", f64::from(n) / 1_000_000.0)
    } else if n >= 1_000 {
        format!("{:.1}k", f64::from(n) / 1_000.0)
    } else {
        n.to_string()
    }
}