Skip to main content

crabtalk_daemon/hook/
host.rs

1//! DaemonHost — server-specific Host implementation.
2//!
3//! Provides `ask_user` and `delegate` dispatch using daemon event channels,
4//! per-session CWD resolution, and agent event broadcasting.
5
6use crate::daemon::event::{DaemonEvent, DaemonEventSender};
7use runtime::host::Host;
8use std::{
9    collections::HashMap,
10    path::PathBuf,
11    sync::{
12        Arc,
13        atomic::{AtomicU64, Ordering},
14    },
15    time::Duration,
16};
17use tokio::sync::{Mutex, broadcast, mpsc, oneshot};
18use wcore::{
19    AgentEvent,
20    protocol::message::{
21        AgentEventKind, AgentEventMsg, ClientMessage, SendMsg, ToolCallInfo, server_message,
22    },
23};
24
/// Byte cap (2 KiB) applied to tool result output before it is copied
/// into a broadcast event. Large enough for rich UIs to render a useful
/// preview, small enough to keep the event firehose lightweight.
const MAX_TOOL_OUTPUT_BROADCAST: usize = 2 * 1024;
29
/// How long an `ask_user` call waits for the user's reply before giving
/// up: five minutes.
const ASK_USER_TIMEOUT: Duration = Duration::from_secs(5 * 60);
32
/// Server-specific host for the daemon. Owns event channels and session state.
///
/// The two maps sit behind `Arc<Mutex<..>>`, so every clone of the host
/// observes the same pending asks and the same working-directory overrides.
#[derive(Clone)]
pub struct DaemonHost {
    /// Channel into the daemon event loop. In this file it carries
    /// ephemeral-agent registration/removal, delegate messages (including
    /// the final kill), and agent-completion publish events.
    pub(crate) event_tx: DaemonEventSender,
    /// Pending `ask_user` oneshots, keyed by conversation_id. At most one
    /// sender is stored per conversation; inserting a second one replaces
    /// (and thereby drops) the first.
    pub(crate) pending_asks: Arc<Mutex<HashMap<u64, oneshot::Sender<String>>>>,
    /// Per-conversation working directory overrides.
    pub(crate) conversation_cwds: Arc<Mutex<HashMap<u64, PathBuf>>>,
    /// Broadcast channel for agent events (console subscription).
    pub(crate) events_tx: broadcast::Sender<AgentEventMsg>,
}
45
46impl Host for DaemonHost {
47    async fn dispatch_ask_user(
48        &self,
49        args: &str,
50        conversation_id: Option<u64>,
51    ) -> Result<String, String> {
52        let input: runtime::ask_user::AskUser =
53            serde_json::from_str(args).map_err(|e| format!("invalid arguments: {e}"))?;
54
55        let conversation_id =
56            conversation_id.ok_or("ask_user is only available in streaming mode")?;
57
58        let (tx, rx) = oneshot::channel();
59        self.pending_asks.lock().await.insert(conversation_id, tx);
60
61        match tokio::time::timeout(ASK_USER_TIMEOUT, rx).await {
62            Ok(Ok(reply)) => Ok(reply),
63            Ok(Err(_)) => {
64                self.pending_asks.lock().await.remove(&conversation_id);
65                Err("ask_user cancelled: reply channel closed".to_owned())
66            }
67            Err(_) => {
68                self.pending_asks.lock().await.remove(&conversation_id);
69                let headers: Vec<&str> =
70                    input.questions.iter().map(|q| q.header.as_str()).collect();
71                Err(format!(
72                    "ask_user timed out after {}s: no reply received for: {}",
73                    ASK_USER_TIMEOUT.as_secs(),
74                    headers.join("; "),
75                ))
76            }
77        }
78    }
79
80    async fn dispatch_delegate(&self, args: &str, _agent: &str) -> Result<String, String> {
81        let input: runtime::task::Delegate =
82            serde_json::from_str(args).map_err(|e| format!("invalid arguments: {e}"))?;
83
84        // Register ephemeral agents and resolve agent names.
85        let mut ephemeral_names = Vec::new();
86        let mut tasks = Vec::with_capacity(input.tasks.len());
87        for task in input.tasks {
88            let agent_name = if let Some(prompt) = task.system_prompt {
89                let name = if task.agent.is_empty() {
90                    ephemeral_agent_name()
91                } else {
92                    task.agent
93                };
94                let mut config = wcore::AgentConfig::new(&name);
95                config.system_prompt = prompt;
96                let (tx, rx) = oneshot::channel();
97                let _ = self
98                    .event_tx
99                    .send(DaemonEvent::AddEphemeral { config, reply: tx });
100                let _ = rx.await;
101                ephemeral_names.push(name.clone());
102                name
103            } else {
104                task.agent
105            };
106
107            let sender = delegate_sender();
108            let handle = spawn_agent_task(
109                agent_name.clone(),
110                task.message,
111                task.cwd,
112                sender.clone(),
113                self.event_tx.clone(),
114            );
115            tasks.push((agent_name, sender, handle));
116        }
117
118        if input.background {
119            let mut json_results = Vec::with_capacity(tasks.len());
120            let mut handles = Vec::with_capacity(tasks.len());
121            for (agent, sender, handle) in tasks {
122                json_results.push(serde_json::json!({ "agent": agent, "task_id": sender }));
123                handles.push(handle);
124            }
125            // Spawn cleanup that waits for all delegates to finish.
126            if !ephemeral_names.is_empty() {
127                let event_tx = self.event_tx.clone();
128                tokio::spawn(async move {
129                    for h in handles {
130                        let _ = h.await;
131                    }
132                    for name in ephemeral_names {
133                        let _ = event_tx.send(DaemonEvent::RemoveEphemeral { name });
134                    }
135                });
136            }
137            return serde_json::to_string(&json_results)
138                .map_err(|e| format!("serialization error: {e}"));
139        }
140
141        let mut results = Vec::with_capacity(tasks.len());
142        for (agent_name, _sender, handle) in tasks {
143            let (result, error) = match handle.await {
144                Ok((r, e)) => (r, e),
145                Err(e) => (None, Some(format!("task panicked: {e}"))),
146            };
147            results.push(serde_json::json!({
148                "agent": agent_name,
149                "result": result,
150                "error": error,
151            }));
152        }
153
154        // Clean up ephemeral agents after foreground tasks complete.
155        for name in ephemeral_names {
156            let _ = self.event_tx.send(DaemonEvent::RemoveEphemeral { name });
157        }
158
159        serde_json::to_string(&results).map_err(|e| format!("serialization error: {e}"))
160    }
161
162    fn conversation_cwd(&self, conversation_id: u64) -> Option<PathBuf> {
163        self.conversation_cwds
164            .try_lock()
165            .ok()
166            .and_then(|m| m.get(&conversation_id).cloned())
167    }
168
169    fn on_agent_event(&self, agent: &str, conversation_id: u64, event: &AgentEvent) {
170        /// Kind-specific payload built per match arm. `kind` is required —
171        /// no `Default` impl, so the compiler forces every arm to set it.
172        /// The other fields default to empty via struct update syntax.
173        struct Payload {
174            kind: AgentEventKind,
175            content: String,
176            tool_calls: Vec<ToolCallInfo>,
177            tool_output: String,
178            tool_is_error: bool,
179        }
180
181        impl Payload {
182            fn of(kind: AgentEventKind) -> Self {
183                Self {
184                    kind,
185                    content: String::new(),
186                    tool_calls: Vec::new(),
187                    tool_output: String::new(),
188                    tool_is_error: false,
189                }
190            }
191        }
192
193        let p = match event {
194            AgentEvent::TextStart => Payload::of(AgentEventKind::TextStart),
195            AgentEvent::TextDelta(text) => {
196                tracing::trace!(%agent, text_len = text.len(), "agent text delta");
197                Payload {
198                    content: text.clone(),
199                    ..Payload::of(AgentEventKind::TextDelta)
200                }
201            }
202            AgentEvent::TextEnd => Payload::of(AgentEventKind::TextEnd),
203            AgentEvent::ThinkingStart => Payload::of(AgentEventKind::ThinkingStart),
204            AgentEvent::ThinkingDelta(text) => {
205                tracing::trace!(%agent, text_len = text.len(), "agent thinking delta");
206                Payload {
207                    content: text.clone(),
208                    ..Payload::of(AgentEventKind::ThinkingDelta)
209                }
210            }
211            AgentEvent::ThinkingEnd => Payload::of(AgentEventKind::ThinkingEnd),
212            AgentEvent::ToolCallsBegin(_) => return,
213            AgentEvent::ToolCallsStart(calls) => {
214                tracing::debug!(%agent, count = calls.len(), "agent tool calls");
215                // Single pass over `calls` builds both the human label and
216                // the structured copy.
217                let mut labels = Vec::with_capacity(calls.len());
218                let mut structured = Vec::with_capacity(calls.len());
219                for c in calls {
220                    labels.push(tool_call_label(c));
221                    structured.push(ToolCallInfo {
222                        name: c.function.name.to_string(),
223                        arguments: c.function.arguments.clone(),
224                    });
225                }
226                Payload {
227                    content: labels.join(", "),
228                    tool_calls: structured,
229                    ..Payload::of(AgentEventKind::ToolStart)
230                }
231            }
232            AgentEvent::ToolResult {
233                call_id,
234                output,
235                duration_ms,
236            } => {
237                let is_error = output.is_err();
238                let text: &str = match output {
239                    Ok(s) | Err(s) => s,
240                };
241                tracing::debug!(%agent, %call_id, %duration_ms, is_error, "agent tool result");
242                Payload {
243                    content: format!("{duration_ms}ms"),
244                    tool_output: truncate_for_broadcast(text, MAX_TOOL_OUTPUT_BROADCAST),
245                    tool_is_error: is_error,
246                    ..Payload::of(AgentEventKind::ToolResult)
247                }
248            }
249            AgentEvent::ToolCallsComplete => {
250                tracing::debug!(%agent, "agent tool calls complete");
251                Payload::of(AgentEventKind::ToolsComplete)
252            }
253            AgentEvent::Compact { summary } => {
254                tracing::info!(%agent, summary_len = summary.len(), "context compacted");
255                return;
256            }
257            AgentEvent::UserSteered { content } => {
258                tracing::info!(%agent, content_len = content.len(), "user steered session");
259                return;
260            }
261            AgentEvent::Done(response) => {
262                tracing::info!(
263                    %agent,
264                    iterations = response.iterations,
265                    stop_reason = %response.stop_reason,
266                    "agent run complete"
267                );
268                Payload {
269                    content: format_usage(response),
270                    ..Payload::of(AgentEventKind::Done)
271                }
272            }
273        };
274        // The sender field is derived from the conversation's created_by.
275        // Since we don't have access to conversation state here, we use
276        // conversation_id as a string placeholder — subscribers correlate
277        // by agent name.
278        let _ = self.events_tx.send(AgentEventMsg {
279            agent: agent.to_string(),
280            sender: conversation_id.to_string(),
281            kind: p.kind.into(),
282            content: p.content,
283            timestamp: chrono::Utc::now().to_rfc3339(),
284            tool_calls: p.tool_calls,
285            tool_output: p.tool_output,
286            tool_is_error: p.tool_is_error,
287        });
288
289        // Publish agent completion to the event bus.
290        if let AgentEvent::Done(response) = event {
291            let payload = response.final_response.clone().unwrap_or_default();
292            let _ = self.event_tx.send(DaemonEvent::PublishEvent {
293                source: format!("agent:{}:done", agent),
294                payload,
295            });
296        }
297    }
298
299    async fn reply_to_ask(&self, session: u64, content: String) -> anyhow::Result<bool> {
300        if let Some(tx) = self.pending_asks.lock().await.remove(&session) {
301            let _ = tx.send(content);
302            return Ok(true);
303        }
304        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
305        if let Some(tx) = self.pending_asks.lock().await.remove(&session) {
306            let _ = tx.send(content);
307            return Ok(true);
308        }
309        Ok(false)
310    }
311
312    async fn set_conversation_cwd(&self, conversation: u64, cwd: std::path::PathBuf) {
313        self.conversation_cwds
314            .lock()
315            .await
316            .insert(conversation, cwd);
317    }
318
319    async fn clear_conversation_state(&self, conversation: u64) {
320        self.pending_asks.lock().await.remove(&conversation);
321        self.conversation_cwds.lock().await.remove(&conversation);
322    }
323
324    fn subscribe_events(&self) -> Option<broadcast::Receiver<AgentEventMsg>> {
325        Some(self.events_tx.subscribe())
326    }
327}
328
/// Mint a process-unique sender identity of the form `delegate:<n>`,
/// where `<n>` counts up from zero across all calls.
fn delegate_sender() -> String {
    static NEXT_DELEGATE: AtomicU64 = AtomicU64::new(0);
    format!(
        "delegate:{}",
        NEXT_DELEGATE.fetch_add(1, Ordering::Relaxed)
    )
}
335
/// Mint a process-unique agent name of the form `_ephemeral:<n>`, where
/// `<n>` counts up from zero across all calls.
fn ephemeral_agent_name() -> String {
    static NEXT_EPHEMERAL: AtomicU64 = AtomicU64::new(0);
    format!(
        "_ephemeral:{}",
        NEXT_EPHEMERAL.fetch_add(1, Ordering::Relaxed)
    )
}
342
343/// Spawn an agent task via the event channel and collect its response.
344fn spawn_agent_task(
345    agent: String,
346    message: String,
347    cwd: Option<String>,
348    delegate_sender: String,
349    event_tx: DaemonEventSender,
350) -> tokio::task::JoinHandle<(Option<String>, Option<String>)> {
351    tokio::spawn(async move {
352        let (reply_tx, mut reply_rx) = mpsc::channel(transport::REPLY_CHANNEL_CAPACITY);
353        let msg = ClientMessage::from(SendMsg {
354            agent: agent.clone(),
355            content: message,
356            sender: Some(delegate_sender.clone()),
357            cwd,
358            guest: None,
359            tool_choice: None,
360        });
361        if event_tx
362            .send(DaemonEvent::Message {
363                msg,
364                reply: reply_tx,
365            })
366            .is_err()
367        {
368            return (None, Some("event channel closed".to_owned()));
369        }
370
371        let mut result_content: Option<String> = None;
372        let mut error_msg: Option<String> = None;
373
374        while let Some(msg) = reply_rx.recv().await {
375            match msg.msg {
376                Some(server_message::Msg::Response(resp)) => {
377                    result_content = Some(resp.content);
378                }
379                Some(server_message::Msg::Error(err)) => {
380                    error_msg = Some(err.message);
381                }
382                _ => {}
383            }
384        }
385
386        // Kill the delegate conversation after completion.
387        let (reply_tx, _) = mpsc::channel(1);
388        let _ = event_tx.send(DaemonEvent::Message {
389            msg: ClientMessage {
390                msg: Some(wcore::protocol::message::client_message::Msg::Kill(
391                    wcore::protocol::message::KillMsg {
392                        agent,
393                        sender: delegate_sender,
394                    },
395                )),
396            },
397            reply: reply_tx,
398        });
399
400        (result_content, error_msg)
401    })
402}
403
404fn format_usage(response: &wcore::AgentResponse) -> String {
405    if response.steps.is_empty() {
406        return String::new();
407    }
408    let mut prompt = 0u32;
409    let mut completion = 0u32;
410    let mut cache_hit = 0u32;
411    for step in &response.steps {
412        let u = &step.usage;
413        prompt += u.prompt_tokens;
414        completion += u.completion_tokens;
415        if let Some(v) = u.prompt_cache_hit_tokens {
416            cache_hit += v;
417        }
418    }
419    let model = &response.model;
420    if cache_hit > 0 {
421        format!(
422            "{model} {} in ({} cached) / {} out",
423            human_tokens(prompt),
424            human_tokens(cache_hit),
425            human_tokens(completion),
426        )
427    } else {
428        format!(
429            "{model} {} in / {} out",
430            human_tokens(prompt),
431            human_tokens(completion),
432        )
433    }
434}
435
/// Render a token count compactly: plain digits below 1 000, otherwise one
/// decimal with a `k` (thousands) or `M` (millions) suffix.
///
/// The unit is chosen by the value *after* rounding to one decimal, so
/// counts just under a million render as `1.0M` rather than `1000.0k`.
fn human_tokens(n: u32) -> String {
    // Smallest count whose one-decimal `k` rendering would round up to
    // "1000.0k"; from here on the `M` unit is used instead.
    const MEGA_FROM: u32 = 999_950;

    if n >= MEGA_FROM {
        format!("{:.1}M", f64::from(n) / 1_000_000.0)
    } else if n >= 1_000 {
        format!("{:.1}k", f64::from(n) / 1_000.0)
    } else {
        n.to_string()
    }
}
445
446/// Build the human-readable label for a single tool call. Bash gets a
447/// special preview of its first line; everything else falls back to the
448/// function name. Used by the legacy `content` field for display-only
449/// consumers — rich UIs should read `tool_calls` directly.
450fn tool_call_label(c: &wcore::model::ToolCall) -> String {
451    if c.function.name == "bash"
452        && let Ok(v) = serde_json::from_str::<serde_json::Value>(&c.function.arguments)
453        && let Some(cmd) = v.get("command").and_then(|c| c.as_str())
454    {
455        return format!("bash({})", cmd.lines().next().unwrap_or(""));
456    }
457    c.function.name.clone()
458}
459
/// Cap `s` at `max` bytes for the event broadcast.
///
/// Input that already fits is returned unchanged. Otherwise the text is
/// cut so that, together with the `…[truncated]` marker, it stays within
/// `max` bytes; the cut point is moved back to the nearest UTF-8 character
/// boundary so no character is split.
///
/// When `max` leaves no room for any text in front of the marker, the
/// marker alone is returned — which may itself be longer than `max`. The
/// helper is meant to tame multi-megabyte tool outputs, not to enforce an
/// exact byte budget, so callers should pass a generous `max`.
fn truncate_for_broadcast(s: &str, max: usize) -> String {
    const MARKER: &str = "…[truncated]";

    if s.len() <= max {
        return s.to_owned();
    }

    // Bytes of `s` that may be kept in front of the marker.
    let keep = match max.checked_sub(MARKER.len()) {
        Some(room) if room > 0 => room,
        _ => return MARKER.to_owned(),
    };

    // Index 0 is always a boundary, so the search cannot come up empty.
    let cut = (0..=keep)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);

    let mut clipped = String::with_capacity(cut + MARKER.len());
    clipped.push_str(&s[..cut]);
    clipped.push_str(MARKER);
    clipped
}