Skip to main content

crabtalk_daemon/daemon/
protocol.rs

1//! Server trait implementation for the Daemon.
2
3use crate::{cron::CronEntry, daemon::Daemon};
4use anyhow::{Context, Result};
5use futures_util::{StreamExt, pin_mut};
6use std::sync::Arc;
7use wcore::protocol::{
8    api::Server,
9    message::{
10        AgentEventMsg, AskOption, AskQuestion, AskUserEvent, CreateCronMsg, CronInfo, CronList,
11        DaemonStats, SendMsg, SendResponse, SessionInfo, StreamChunk, StreamEnd, StreamEvent,
12        StreamMsg, StreamStart, StreamThinking, TokenUsage, ToolCallInfo, ToolResultEvent,
13        ToolStartEvent, ToolsCompleteEvent, stream_event,
14    },
15};
16use wcore::{AgentEvent, AgentStep};
17
18impl Server for Daemon {
19    async fn send(&self, req: SendMsg) -> Result<SendResponse> {
20        let rt: Arc<_> = self.runtime.read().await.clone();
21        let sender = req.sender.as_deref().unwrap_or("");
22        let created_by = if sender.is_empty() { "user" } else { sender };
23        let cwd = req.cwd.map(std::path::PathBuf::from);
24        let session_id = match req.session {
25            Some(id) => id,
26            None => {
27                let id = if let Some(ref file) = req.resume_file {
28                    rt.load_specific_session(std::path::Path::new(file)).await?
29                } else if req.new_chat {
30                    rt.create_session(&req.agent, created_by).await?
31                } else {
32                    rt.get_or_create_session(&req.agent, created_by).await?
33                };
34                if let Some(ref cwd) = cwd {
35                    rt.hook
36                        .bridge
37                        .session_cwds
38                        .lock()
39                        .await
40                        .insert(id, cwd.clone());
41                }
42                id
43            }
44        };
45        let response = rt.send_to(session_id, &req.content, sender).await?;
46        let provider = rt
47            .model
48            .provider_name_for(&response.model)
49            .unwrap_or_default();
50        Ok(SendResponse {
51            agent: req.agent,
52            content: response.final_response.unwrap_or_default(),
53            session: session_id,
54            provider,
55            model: response.model,
56            usage: Some(sum_usage(&response.steps)),
57        })
58    }
59
60    fn stream(
61        &self,
62        req: StreamMsg,
63    ) -> impl futures_core::Stream<Item = Result<StreamEvent>> + Send {
64        let runtime = self.runtime.clone();
65        let agent = req.agent;
66        let content = req.content;
67        let req_session = req.session;
68        let sender = req.sender.unwrap_or_default();
69        let cwd = req.cwd.map(std::path::PathBuf::from);
70        let new_chat = req.new_chat;
71        let resume_file = req.resume_file;
72        async_stream::try_stream! {
73            let rt: Arc<_> = runtime.read().await.clone();
74            let created_by = if sender.is_empty() { "user".into() } else { sender.clone() };
75            let session_id = match req_session {
76                Some(id) => id,
77                None => {
78                    let id = if let Some(ref file) = resume_file {
79                        rt.load_specific_session(std::path::Path::new(file)).await?
80                    } else if new_chat {
81                        rt.create_session(&agent, created_by.as_str()).await?
82                    } else {
83                        rt.get_or_create_session(&agent, created_by.as_str()).await?
84                    };
85                    if let Some(ref cwd) = cwd {
86                        rt.hook.bridge.session_cwds.lock().await.insert(id, cwd.clone());
87                    }
88                    id
89                }
90            };
91
92            yield StreamEvent { event: Some(stream_event::Event::Start(StreamStart { agent: agent.clone(), session: session_id })) };
93
94            let stream = rt.stream_to(session_id, &content, &sender);
95            pin_mut!(stream);
96            while let Some(event) = stream.next().await {
97                match event {
98                    AgentEvent::TextDelta(text) => {
99                        yield StreamEvent { event: Some(stream_event::Event::Chunk(StreamChunk { content: text })) };
100                    }
101                    AgentEvent::ThinkingDelta(text) => {
102                        yield StreamEvent { event: Some(stream_event::Event::Thinking(StreamThinking { content: text })) };
103                    }
104                    AgentEvent::ToolCallsBegin(calls) => {
105                        yield StreamEvent { event: Some(stream_event::Event::ToolStart(ToolStartEvent {
106                            calls: calls.into_iter().map(|c| ToolCallInfo {
107                                name: c.function.name.to_string(),
108                                arguments: String::new(),
109                            }).collect(),
110                        })) };
111                    }
112                    AgentEvent::ToolCallsStart(calls) => {
113                        // Extract structured questions from ask_user calls.
114                        let ask_questions: Vec<AskQuestion> = calls
115                            .iter()
116                            .filter(|c| c.function.name == "ask_user")
117                            .filter_map(|c| {
118                                serde_json::from_str::<runtime::ask_user::AskUser>(&c.function.arguments)
119                                    .ok()
120                            })
121                            .flat_map(|a| a.questions)
122                            .map(|q| AskQuestion {
123                                question: q.question,
124                                header: q.header,
125                                options: q.options.into_iter().map(|o| AskOption {
126                                    label: o.label,
127                                    description: o.description,
128                                }).collect(),
129                                multi_select: q.multi_select,
130                            })
131                            .collect();
132
133                        yield StreamEvent { event: Some(stream_event::Event::ToolStart(ToolStartEvent {
134                            calls: calls.into_iter().map(|c| ToolCallInfo {
135                                name: c.function.name.to_string(),
136                                arguments: c.function.arguments,
137                            }).collect(),
138                        })) };
139
140                        if !ask_questions.is_empty() {
141                            yield StreamEvent { event: Some(stream_event::Event::AskUser(AskUserEvent { questions: ask_questions })) };
142                        }
143                    }
144                    AgentEvent::ToolResult { call_id, output, duration_ms } => {
145                        yield StreamEvent { event: Some(stream_event::Event::ToolResult(ToolResultEvent { call_id: call_id.to_string(), output, duration_ms })) };
146                    }
147                    AgentEvent::ToolCallsComplete => {
148                        yield StreamEvent { event: Some(stream_event::Event::ToolsComplete(ToolsCompleteEvent {})) };
149                    }
150                    AgentEvent::Compact { .. } => {
151                    }
152                    AgentEvent::Done(resp) => {
153                        let error = if let wcore::AgentStopReason::Error(ref e) = resp.stop_reason {
154                            e.clone()
155                        } else {
156                            String::new()
157                        };
158                        let provider = rt
159                            .model
160                            .provider_name_for(&resp.model)
161                            .unwrap_or_default();
162                        yield StreamEvent { event: Some(stream_event::Event::End(StreamEnd {
163                            agent: agent.clone(),
164                            error,
165                            provider,
166                            model: resp.model,
167                            usage: Some(sum_usage(&resp.steps)),
168                        })) };
169                        return;
170                    }
171                }
172            }
173            yield StreamEvent { event: Some(stream_event::Event::End(StreamEnd {
174                agent: agent.clone(),
175                error: String::new(),
176                provider: String::new(),
177                model: String::new(),
178                usage: None,
179            })) };
180        }
181    }
182
183    async fn compact_session(&self, session: u64) -> Result<String> {
184        let rt = self.runtime.read().await.clone();
185        rt.compact_session(session)
186            .await
187            .ok_or_else(|| anyhow::anyhow!("compact failed for session {session}"))
188    }
189
190    async fn ping(&self) -> Result<()> {
191        Ok(())
192    }
193
194    async fn list_sessions(&self) -> Result<Vec<SessionInfo>> {
195        let rt = self.runtime.read().await.clone();
196        let sessions = rt.sessions().await;
197        let mut infos = Vec::with_capacity(sessions.len());
198        for s in sessions {
199            let s = s.lock().await;
200            let active = rt.is_active(s.id).await;
201            infos.push(SessionInfo {
202                id: s.id,
203                agent: s.agent.to_string(),
204                created_by: s.created_by.to_string(),
205                message_count: s.history.len() as u64,
206                alive_secs: s.uptime_secs,
207                active,
208                title: s.title.clone(),
209            });
210        }
211        Ok(infos)
212    }
213
214    async fn kill_session(&self, session: u64) -> Result<bool> {
215        let rt = self.runtime.read().await.clone();
216        rt.hook.bridge.pending_asks.lock().await.remove(&session);
217        rt.hook.bridge.session_cwds.lock().await.remove(&session);
218        Ok(rt.close_session(session).await)
219    }
220
221    fn subscribe_events(&self) -> impl futures_core::Stream<Item = Result<AgentEventMsg>> + Send {
222        let runtime = self.runtime.clone();
223        async_stream::try_stream! {
224            let rt = runtime.read().await.clone();
225            let mut rx = rt.hook.bridge.subscribe_events();
226            loop {
227                match rx.recv().await {
228                    Ok(event) => yield event,
229                    Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
230                    Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
231                }
232            }
233        }
234    }
235
236    async fn get_config(&self) -> Result<String> {
237        let config = self.load_config()?;
238        serde_json::to_string(&config).context("failed to serialize config")
239    }
240
241    async fn set_config(&self, config: String) -> Result<()> {
242        let parsed: crate::DaemonConfig =
243            serde_json::from_str(&config).context("invalid DaemonConfig JSON")?;
244        let toml_str =
245            toml::to_string_pretty(&parsed).context("failed to serialize config to TOML")?;
246        let config_path = self.config_dir.join(wcore::paths::CONFIG_FILE);
247        std::fs::write(&config_path, toml_str)
248            .with_context(|| format!("failed to write {}", config_path.display()))?;
249        self.reload().await
250    }
251
252    async fn reload(&self) -> Result<()> {
253        self.reload().await
254    }
255
256    async fn get_stats(&self) -> Result<DaemonStats> {
257        let rt = self.runtime.read().await.clone();
258        let active = rt.active_session_count().await;
259        let agents = rt.agents().len() as u32;
260        let uptime = self.started_at.elapsed().as_secs();
261        Ok(DaemonStats {
262            uptime_secs: uptime,
263            active_sessions: active as u32,
264            registered_agents: agents,
265        })
266    }
267
268    async fn create_cron(&self, req: CreateCronMsg) -> Result<CronInfo> {
269        // Validate the target session exists.
270        let rt = self.runtime.read().await.clone();
271        if rt.session(req.session).await.is_none() {
272            anyhow::bail!("session {} not found", req.session);
273        }
274        let entry = CronEntry {
275            id: 0, // assigned by store
276            schedule: req.schedule,
277            skill: req.skill,
278            session: req.session,
279            quiet_start: req.quiet_start,
280            quiet_end: req.quiet_end,
281            once: req.once,
282        };
283        // Schedule validation happens inside CronStore::create.
284        let created = self
285            .crons
286            .lock()
287            .await
288            .create(entry, self.crons.clone())
289            .map_err(|e| anyhow::anyhow!("{e}"))?;
290        Ok(cron_entry_to_info(&created))
291    }
292
293    async fn delete_cron(&self, id: u64) -> Result<bool> {
294        Ok(self.crons.lock().await.delete(id))
295    }
296
297    async fn list_crons(&self) -> Result<CronList> {
298        let entries = self.crons.lock().await.list();
299        Ok(CronList {
300            crons: entries.iter().map(cron_entry_to_info).collect(),
301        })
302    }
303
304    async fn reply_to_ask(&self, session: u64, content: String) -> Result<()> {
305        let rt = self.runtime.read().await.clone();
306        if let Some(tx) = rt.hook.bridge.pending_asks.lock().await.remove(&session) {
307            let _ = tx.send(content);
308            return Ok(());
309        }
310        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
311        if let Some(tx) = rt.hook.bridge.pending_asks.lock().await.remove(&session) {
312            let _ = tx.send(content);
313            return Ok(());
314        }
315        anyhow::bail!("no pending ask_user for session {session}")
316    }
317}
318
319impl Daemon {
320    /// Load the current `DaemonConfig` from disk.
321    fn load_config(&self) -> Result<crate::DaemonConfig> {
322        crate::DaemonConfig::load(&self.config_dir.join(wcore::paths::CONFIG_FILE))
323    }
324}
325
326fn cron_entry_to_info(e: &CronEntry) -> CronInfo {
327    CronInfo {
328        id: e.id,
329        schedule: e.schedule.clone(),
330        skill: e.skill.clone(),
331        session: e.session,
332        quiet_start: e.quiet_start.clone().unwrap_or_default(),
333        quiet_end: e.quiet_end.clone().unwrap_or_default(),
334        once: e.once,
335    }
336}
337
338fn sum_usage(steps: &[AgentStep]) -> TokenUsage {
339    let mut prompt = 0u32;
340    let mut completion = 0u32;
341    let mut total = 0u32;
342    let mut cache_hit = 0u32;
343    let mut cache_miss = 0u32;
344    let mut reasoning = 0u32;
345    let mut has_cache_hit = false;
346    let mut has_cache_miss = false;
347    let mut has_reasoning = false;
348
349    for step in steps {
350        let u = &step.response.usage;
351        prompt += u.prompt_tokens;
352        completion += u.completion_tokens;
353        total += u.total_tokens;
354        if let Some(v) = u.prompt_cache_hit_tokens {
355            cache_hit += v;
356            has_cache_hit = true;
357        }
358        if let Some(v) = u.prompt_cache_miss_tokens {
359            cache_miss += v;
360            has_cache_miss = true;
361        }
362        if let Some(ref d) = u.completion_tokens_details
363            && let Some(v) = d.reasoning_tokens
364        {
365            reasoning += v;
366            has_reasoning = true;
367        }
368    }
369
370    TokenUsage {
371        prompt_tokens: prompt,
372        completion_tokens: completion,
373        total_tokens: total,
374        cache_hit_tokens: has_cache_hit.then_some(cache_hit),
375        cache_miss_tokens: has_cache_miss.then_some(cache_miss),
376        reasoning_tokens: has_reasoning.then_some(reasoning),
377    }
378}