1use crate::daemon::event::{DaemonEvent, DaemonEventSender};
7use runtime::host::Host;
8use std::{
9 collections::HashMap,
10 path::PathBuf,
11 sync::{
12 Arc,
13 atomic::{AtomicU64, Ordering},
14 },
15 time::Duration,
16};
17use tokio::sync::{Mutex, broadcast, mpsc, oneshot};
18use wcore::{
19 AgentEvent,
20 protocol::message::{
21 AgentEventKind, AgentEventMsg, ClientMessage, SendMsg, ToolCallInfo, server_message,
22 },
23};
24
/// Maximum size, in bytes (as measured by `str::len`), of tool output copied into a
/// broadcast `AgentEventMsg`; longer output is shortened by `truncate_for_broadcast`.
const MAX_TOOL_OUTPUT_BROADCAST: usize = 2048;

/// How long `dispatch_ask_user` waits for a reply before failing with a timeout error.
const ASK_USER_TIMEOUT: Duration = Duration::from_secs(300);
32
/// Daemon-side implementation of the runtime `Host` trait.
///
/// The `Arc`-wrapped maps and the broadcast sender are shared between clones, so a
/// cloned `DaemonHost` observes the same pending asks, working directories and event
/// subscribers as the original.
#[derive(Clone)]
pub struct DaemonHost {
    /// Sender for `DaemonEvent`s (agent messages, ephemeral-agent add/remove,
    /// published events) handled elsewhere in the daemon.
    pub(crate) event_tx: DaemonEventSender,
    /// Outstanding `ask_user` calls keyed by conversation id; each sender delivers
    /// the user's reply to the `dispatch_ask_user` call waiting on it.
    pub(crate) pending_asks: Arc<Mutex<HashMap<u64, oneshot::Sender<String>>>>,
    /// Working directory recorded per conversation id.
    pub(crate) conversation_cwds: Arc<Mutex<HashMap<u64, PathBuf>>>,
    /// Broadcast channel for agent events; receivers are handed out by
    /// `subscribe_events`.
    pub(crate) events_tx: broadcast::Sender<AgentEventMsg>,
}
45
46impl Host for DaemonHost {
47 async fn dispatch_ask_user(
48 &self,
49 args: &str,
50 conversation_id: Option<u64>,
51 ) -> Result<String, String> {
52 let input: runtime::ask_user::AskUser =
53 serde_json::from_str(args).map_err(|e| format!("invalid arguments: {e}"))?;
54
55 let conversation_id =
56 conversation_id.ok_or("ask_user is only available in streaming mode")?;
57
58 let (tx, rx) = oneshot::channel();
59 self.pending_asks.lock().await.insert(conversation_id, tx);
60
61 match tokio::time::timeout(ASK_USER_TIMEOUT, rx).await {
62 Ok(Ok(reply)) => Ok(reply),
63 Ok(Err(_)) => {
64 self.pending_asks.lock().await.remove(&conversation_id);
65 Err("ask_user cancelled: reply channel closed".to_owned())
66 }
67 Err(_) => {
68 self.pending_asks.lock().await.remove(&conversation_id);
69 let headers: Vec<&str> =
70 input.questions.iter().map(|q| q.header.as_str()).collect();
71 Err(format!(
72 "ask_user timed out after {}s: no reply received for: {}",
73 ASK_USER_TIMEOUT.as_secs(),
74 headers.join("; "),
75 ))
76 }
77 }
78 }
79
80 async fn dispatch_delegate(&self, args: &str, _agent: &str) -> Result<String, String> {
81 let input: runtime::task::Delegate =
82 serde_json::from_str(args).map_err(|e| format!("invalid arguments: {e}"))?;
83
84 let mut ephemeral_names = Vec::new();
86 let mut tasks = Vec::with_capacity(input.tasks.len());
87 for task in input.tasks {
88 let agent_name = if let Some(prompt) = task.system_prompt {
89 let name = if task.agent.is_empty() {
90 ephemeral_agent_name()
91 } else {
92 task.agent
93 };
94 let mut config = wcore::AgentConfig::new(&name);
95 config.system_prompt = prompt;
96 let (tx, rx) = oneshot::channel();
97 let _ = self
98 .event_tx
99 .send(DaemonEvent::AddEphemeral { config, reply: tx });
100 let _ = rx.await;
101 ephemeral_names.push(name.clone());
102 name
103 } else {
104 task.agent
105 };
106
107 let sender = delegate_sender();
108 let handle = spawn_agent_task(
109 agent_name.clone(),
110 task.message,
111 task.cwd,
112 sender.clone(),
113 self.event_tx.clone(),
114 );
115 tasks.push((agent_name, sender, handle));
116 }
117
118 if input.background {
119 let mut json_results = Vec::with_capacity(tasks.len());
120 let mut handles = Vec::with_capacity(tasks.len());
121 for (agent, sender, handle) in tasks {
122 json_results.push(serde_json::json!({ "agent": agent, "task_id": sender }));
123 handles.push(handle);
124 }
125 if !ephemeral_names.is_empty() {
127 let event_tx = self.event_tx.clone();
128 tokio::spawn(async move {
129 for h in handles {
130 let _ = h.await;
131 }
132 for name in ephemeral_names {
133 let _ = event_tx.send(DaemonEvent::RemoveEphemeral { name });
134 }
135 });
136 }
137 return serde_json::to_string(&json_results)
138 .map_err(|e| format!("serialization error: {e}"));
139 }
140
141 let mut results = Vec::with_capacity(tasks.len());
142 for (agent_name, _sender, handle) in tasks {
143 let (result, error) = match handle.await {
144 Ok((r, e)) => (r, e),
145 Err(e) => (None, Some(format!("task panicked: {e}"))),
146 };
147 results.push(serde_json::json!({
148 "agent": agent_name,
149 "result": result,
150 "error": error,
151 }));
152 }
153
154 for name in ephemeral_names {
156 let _ = self.event_tx.send(DaemonEvent::RemoveEphemeral { name });
157 }
158
159 serde_json::to_string(&results).map_err(|e| format!("serialization error: {e}"))
160 }
161
162 fn conversation_cwd(&self, conversation_id: u64) -> Option<PathBuf> {
163 self.conversation_cwds
164 .try_lock()
165 .ok()
166 .and_then(|m| m.get(&conversation_id).cloned())
167 }
168
169 fn on_agent_event(&self, agent: &str, conversation_id: u64, event: &AgentEvent) {
170 struct Payload {
174 kind: AgentEventKind,
175 content: String,
176 tool_calls: Vec<ToolCallInfo>,
177 tool_output: String,
178 tool_is_error: bool,
179 }
180
181 impl Payload {
182 fn of(kind: AgentEventKind) -> Self {
183 Self {
184 kind,
185 content: String::new(),
186 tool_calls: Vec::new(),
187 tool_output: String::new(),
188 tool_is_error: false,
189 }
190 }
191 }
192
193 let p = match event {
194 AgentEvent::TextStart => Payload::of(AgentEventKind::TextStart),
195 AgentEvent::TextDelta(text) => {
196 tracing::trace!(%agent, text_len = text.len(), "agent text delta");
197 Payload {
198 content: text.clone(),
199 ..Payload::of(AgentEventKind::TextDelta)
200 }
201 }
202 AgentEvent::TextEnd => Payload::of(AgentEventKind::TextEnd),
203 AgentEvent::ThinkingStart => Payload::of(AgentEventKind::ThinkingStart),
204 AgentEvent::ThinkingDelta(text) => {
205 tracing::trace!(%agent, text_len = text.len(), "agent thinking delta");
206 Payload {
207 content: text.clone(),
208 ..Payload::of(AgentEventKind::ThinkingDelta)
209 }
210 }
211 AgentEvent::ThinkingEnd => Payload::of(AgentEventKind::ThinkingEnd),
212 AgentEvent::ToolCallsBegin(_) => return,
213 AgentEvent::ToolCallsStart(calls) => {
214 tracing::debug!(%agent, count = calls.len(), "agent tool calls");
215 let mut labels = Vec::with_capacity(calls.len());
218 let mut structured = Vec::with_capacity(calls.len());
219 for c in calls {
220 labels.push(tool_call_label(c));
221 structured.push(ToolCallInfo {
222 name: c.function.name.to_string(),
223 arguments: c.function.arguments.clone(),
224 });
225 }
226 Payload {
227 content: labels.join(", "),
228 tool_calls: structured,
229 ..Payload::of(AgentEventKind::ToolStart)
230 }
231 }
232 AgentEvent::ToolResult {
233 call_id,
234 output,
235 duration_ms,
236 } => {
237 let is_error = output.is_err();
238 let text: &str = match output {
239 Ok(s) | Err(s) => s,
240 };
241 tracing::debug!(%agent, %call_id, %duration_ms, is_error, "agent tool result");
242 Payload {
243 content: format!("{duration_ms}ms"),
244 tool_output: truncate_for_broadcast(text, MAX_TOOL_OUTPUT_BROADCAST),
245 tool_is_error: is_error,
246 ..Payload::of(AgentEventKind::ToolResult)
247 }
248 }
249 AgentEvent::ToolCallsComplete => {
250 tracing::debug!(%agent, "agent tool calls complete");
251 Payload::of(AgentEventKind::ToolsComplete)
252 }
253 AgentEvent::Compact { summary } => {
254 tracing::info!(%agent, summary_len = summary.len(), "context compacted");
255 return;
256 }
257 AgentEvent::UserSteered { content } => {
258 tracing::info!(%agent, content_len = content.len(), "user steered session");
259 return;
260 }
261 AgentEvent::Done(response) => {
262 tracing::info!(
263 %agent,
264 iterations = response.iterations,
265 stop_reason = %response.stop_reason,
266 "agent run complete"
267 );
268 Payload {
269 content: format_usage(response),
270 ..Payload::of(AgentEventKind::Done)
271 }
272 }
273 };
274 let _ = self.events_tx.send(AgentEventMsg {
279 agent: agent.to_string(),
280 sender: conversation_id.to_string(),
281 kind: p.kind.into(),
282 content: p.content,
283 timestamp: chrono::Utc::now().to_rfc3339(),
284 tool_calls: p.tool_calls,
285 tool_output: p.tool_output,
286 tool_is_error: p.tool_is_error,
287 });
288
289 if let AgentEvent::Done(response) = event {
291 let payload = response.final_response.clone().unwrap_or_default();
292 let _ = self.event_tx.send(DaemonEvent::PublishEvent {
293 source: format!("agent:{}:done", agent),
294 payload,
295 });
296 }
297 }
298
299 async fn reply_to_ask(&self, session: u64, content: String) -> anyhow::Result<bool> {
300 if let Some(tx) = self.pending_asks.lock().await.remove(&session) {
301 let _ = tx.send(content);
302 return Ok(true);
303 }
304 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
305 if let Some(tx) = self.pending_asks.lock().await.remove(&session) {
306 let _ = tx.send(content);
307 return Ok(true);
308 }
309 Ok(false)
310 }
311
312 async fn set_conversation_cwd(&self, conversation: u64, cwd: std::path::PathBuf) {
313 self.conversation_cwds
314 .lock()
315 .await
316 .insert(conversation, cwd);
317 }
318
319 async fn clear_conversation_state(&self, conversation: u64) {
320 self.pending_asks.lock().await.remove(&conversation);
321 self.conversation_cwds.lock().await.remove(&conversation);
322 }
323
324 fn subscribe_events(&self) -> Option<broadcast::Receiver<AgentEventMsg>> {
325 Some(self.events_tx.subscribe())
326 }
327}
328
/// Returns a fresh sender id of the form `delegate:<n>` used to tag one delegated
/// task. `<n>` comes from a process-wide counter that increases by one per call.
fn delegate_sender() -> String {
    static NEXT_TASK: AtomicU64 = AtomicU64::new(0);
    format!("delegate:{}", NEXT_TASK.fetch_add(1, Ordering::Relaxed))
}
335
/// Generates a name of the form `_ephemeral:<n>` for a temporary agent.
///
/// `<n>` is drawn from a process-wide counter, so successive calls never repeat a
/// name within the same process.
fn ephemeral_agent_name() -> String {
    static NEXT_EPHEMERAL: AtomicU64 = AtomicU64::new(0);
    let serial = NEXT_EPHEMERAL.fetch_add(1, Ordering::Relaxed);
    let mut name = String::from("_ephemeral:");
    name.push_str(&serial.to_string());
    name
}
342
343fn spawn_agent_task(
345 agent: String,
346 message: String,
347 cwd: Option<String>,
348 delegate_sender: String,
349 event_tx: DaemonEventSender,
350) -> tokio::task::JoinHandle<(Option<String>, Option<String>)> {
351 tokio::spawn(async move {
352 let (reply_tx, mut reply_rx) = mpsc::channel(transport::REPLY_CHANNEL_CAPACITY);
353 let msg = ClientMessage::from(SendMsg {
354 agent: agent.clone(),
355 content: message,
356 sender: Some(delegate_sender.clone()),
357 cwd,
358 guest: None,
359 tool_choice: None,
360 });
361 if event_tx
362 .send(DaemonEvent::Message {
363 msg,
364 reply: reply_tx,
365 })
366 .is_err()
367 {
368 return (None, Some("event channel closed".to_owned()));
369 }
370
371 let mut result_content: Option<String> = None;
372 let mut error_msg: Option<String> = None;
373
374 while let Some(msg) = reply_rx.recv().await {
375 match msg.msg {
376 Some(server_message::Msg::Response(resp)) => {
377 result_content = Some(resp.content);
378 }
379 Some(server_message::Msg::Error(err)) => {
380 error_msg = Some(err.message);
381 }
382 _ => {}
383 }
384 }
385
386 let (reply_tx, _) = mpsc::channel(1);
388 let _ = event_tx.send(DaemonEvent::Message {
389 msg: ClientMessage {
390 msg: Some(wcore::protocol::message::client_message::Msg::Kill(
391 wcore::protocol::message::KillMsg {
392 agent,
393 sender: delegate_sender,
394 },
395 )),
396 },
397 reply: reply_tx,
398 });
399
400 (result_content, error_msg)
401 })
402}
403
404fn format_usage(response: &wcore::AgentResponse) -> String {
405 if response.steps.is_empty() {
406 return String::new();
407 }
408 let mut prompt = 0u32;
409 let mut completion = 0u32;
410 let mut cache_hit = 0u32;
411 for step in &response.steps {
412 let u = &step.usage;
413 prompt += u.prompt_tokens;
414 completion += u.completion_tokens;
415 if let Some(v) = u.prompt_cache_hit_tokens {
416 cache_hit += v;
417 }
418 }
419 let model = &response.model;
420 if cache_hit > 0 {
421 format!(
422 "{model} {} in ({} cached) / {} out",
423 human_tokens(prompt),
424 human_tokens(cache_hit),
425 human_tokens(completion),
426 )
427 } else {
428 format!(
429 "{model} {} in / {} out",
430 human_tokens(prompt),
431 human_tokens(completion),
432 )
433 }
434}
435
/// Formats a token count compactly.
///
/// - Below 1 000: the plain number (`"999"`).
/// - Thousands: one decimal place plus `k` (`"12.3k"`).
/// - Millions: one decimal place plus `M` (`"2.5M"`).
///
/// Counts that would round up to `"1000.0k"` at one decimal place (999 950 and
/// above) are shown in millions instead, so the output reads `"1.0M"`.
fn human_tokens(n: u32) -> String {
    // Smallest count whose value in thousands rounds to 1000.0 at one decimal.
    const MILLIONS_FROM: u32 = 999_950;
    if n >= MILLIONS_FROM {
        format!("{:.1}M", f64::from(n) / 1_000_000.0)
    } else if n >= 1_000 {
        format!("{:.1}k", f64::from(n) / 1_000.0)
    } else {
        n.to_string()
    }
}
445
446fn tool_call_label(c: &wcore::model::ToolCall) -> String {
451 if c.function.name == "bash"
452 && let Ok(v) = serde_json::from_str::<serde_json::Value>(&c.function.arguments)
453 && let Some(cmd) = v.get("command").and_then(|c| c.as_str())
454 {
455 return format!("bash({})", cmd.lines().next().unwrap_or(""));
456 }
457 c.function.name.clone()
458}
459
/// Shortens `s` for broadcasting, measuring lengths in bytes.
///
/// - If `s` fits within `max` bytes, it is returned unchanged.
/// - Otherwise, the longest prefix that ends on a UTF-8 character boundary and
///   leaves room for the `…[truncated]` marker is kept, followed by the marker.
/// - If `max` is not larger than the marker itself, the marker alone is returned.
fn truncate_for_broadcast(s: &str, max: usize) -> String {
    const MARKER: &str = "…[truncated]";
    if s.len() <= max {
        return s.to_owned();
    }
    let Some(budget) = max.checked_sub(MARKER.len()).filter(|&room| room > 0) else {
        return MARKER.to_owned();
    };
    // Index 0 is always a char boundary, so the search cannot come up empty.
    let cut = (0..=budget)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    let mut out = String::with_capacity(cut + MARKER.len());
    out.push_str(&s[..cut]);
    out.push_str(MARKER);
    out
}