Skip to main content

walrus_daemon/daemon/
protocol.rs

1//! Server trait implementation for the Daemon.
2
3use crate::daemon::Daemon;
4use anyhow::{Context, Result};
5use futures_util::{StreamExt, pin_mut};
6use std::sync::Arc;
7use wcore::AgentEvent;
8use wcore::protocol::{
9    api::Server,
10    message::{
11        DownloadEvent, DownloadInfo, HubAction, SendMsg, SendResponse, SessionInfo, StreamChunk,
12        StreamEnd, StreamEvent, StreamMsg, StreamStart, StreamThinking, TaskEvent, TaskInfo,
13        ToolCallInfo, ToolResultEvent, ToolStartEvent, ToolsCompleteEvent, stream_event,
14    },
15};
16
17impl Server for Daemon {
18    async fn send(&self, req: SendMsg) -> Result<SendResponse> {
19        let rt: Arc<_> = self.runtime.read().await.clone();
20        let sender = req.sender.as_deref().unwrap_or("");
21        let created_by = if sender.is_empty() { "user" } else { sender };
22        let (session_id, is_new) = match req.session {
23            Some(id) => (id, false),
24            None => (rt.create_session(&req.agent, created_by).await?, true),
25        };
26        let response = rt.send_to(session_id, &req.content, sender).await?;
27        if is_new {
28            rt.close_session(session_id).await;
29        }
30        Ok(SendResponse {
31            agent: req.agent,
32            content: response.final_response.unwrap_or_default(),
33            session: session_id,
34        })
35    }
36
37    fn stream(
38        &self,
39        req: StreamMsg,
40    ) -> impl futures_core::Stream<Item = Result<StreamEvent>> + Send {
41        let runtime = self.runtime.clone();
42        let agent = req.agent;
43        let content = req.content;
44        let req_session = req.session;
45        let sender = req.sender.unwrap_or_default();
46        async_stream::try_stream! {
47            let rt: Arc<_> = runtime.read().await.clone();
48            let created_by = if sender.is_empty() { "user".into() } else { sender.clone() };
49            let (session_id, is_new) = match req_session {
50                Some(id) => (id, false),
51                None => (rt.create_session(&agent, created_by.as_str()).await?, true),
52            };
53
54            yield StreamEvent { event: Some(stream_event::Event::Start(StreamStart { agent: agent.clone(), session: session_id })) };
55
56            let stream = rt.stream_to(session_id, &content, &sender);
57            pin_mut!(stream);
58            while let Some(event) = stream.next().await {
59                match event {
60                    AgentEvent::TextDelta(text) => {
61                        yield StreamEvent { event: Some(stream_event::Event::Chunk(StreamChunk { content: text })) };
62                    }
63                    AgentEvent::ThinkingDelta(text) => {
64                        yield StreamEvent { event: Some(stream_event::Event::Thinking(StreamThinking { content: text })) };
65                    }
66                    AgentEvent::ToolCallsStart(calls) => {
67                        yield StreamEvent { event: Some(stream_event::Event::ToolStart(ToolStartEvent {
68                            calls: calls.into_iter().map(|c| ToolCallInfo {
69                                name: c.function.name.to_string(),
70                                arguments: c.function.arguments,
71                            }).collect(),
72                        })) };
73                    }
74                    AgentEvent::ToolResult { call_id, output } => {
75                        yield StreamEvent { event: Some(stream_event::Event::ToolResult(ToolResultEvent { call_id: call_id.to_string(), output })) };
76                    }
77                    AgentEvent::ToolCallsComplete => {
78                        yield StreamEvent { event: Some(stream_event::Event::ToolsComplete(ToolsCompleteEvent {})) };
79                    }
80                    AgentEvent::Done(resp) => {
81                        if let wcore::AgentStopReason::Error(e) = &resp.stop_reason {
82                            if is_new {
83                                rt.close_session(session_id).await;
84                            }
85                            Err(anyhow::anyhow!("{e}"))?;
86                        }
87                        break;
88                    }
89                }
90            }
91            yield StreamEvent { event: Some(stream_event::Event::End(StreamEnd { agent: agent.clone() })) };
92        }
93    }
94
95    async fn ping(&self) -> Result<()> {
96        Ok(())
97    }
98
99    async fn list_sessions(&self) -> Result<Vec<SessionInfo>> {
100        let rt = self.runtime.read().await.clone();
101        let sessions = rt.sessions().await;
102        let mut infos = Vec::with_capacity(sessions.len());
103        for s in sessions {
104            let s = s.lock().await;
105            infos.push(SessionInfo {
106                id: s.id,
107                agent: s.agent.to_string(),
108                created_by: s.created_by.to_string(),
109                message_count: s.history.len() as u64,
110                alive_secs: s.created_at.elapsed().as_secs(),
111            });
112        }
113        Ok(infos)
114    }
115
116    async fn kill_session(&self, session: u64) -> Result<bool> {
117        let rt = self.runtime.read().await.clone();
118        Ok(rt.close_session(session).await)
119    }
120
121    async fn list_tasks(&self) -> Result<Vec<TaskInfo>> {
122        let rt = self.runtime.read().await.clone();
123        let registry = rt.hook.tasks.lock().await;
124        let tasks = registry.list(None, None, None);
125        Ok(tasks
126            .into_iter()
127            .map(|t| TaskInfo {
128                id: t.id,
129                parent_id: t.parent_id,
130                agent: t.agent.to_string(),
131                status: t.status.to_string(),
132                description: t.description.clone(),
133                result: t.result.clone(),
134                error: t.error.clone(),
135                created_by: t.created_by.to_string(),
136                prompt_tokens: t.prompt_tokens,
137                completion_tokens: t.completion_tokens,
138                alive_secs: t.created_at.elapsed().as_secs(),
139                blocked_on: t.blocked_on.as_ref().map(|i| i.question.clone()),
140            })
141            .collect())
142    }
143
144    async fn kill_task(&self, task_id: u64) -> Result<bool> {
145        let rt = self.runtime.read().await.clone();
146        let tasks = rt.hook.tasks.clone();
147        let mut registry = tasks.lock().await;
148        let Some(task) = registry.get(task_id) else {
149            return Ok(false);
150        };
151        match task.status {
152            crate::hook::task::TaskStatus::InProgress | crate::hook::task::TaskStatus::Blocked => {
153                if let Some(handle) = &task.abort_handle {
154                    handle.abort();
155                }
156                registry.set_status(task_id, crate::hook::task::TaskStatus::Failed);
157                if let Some(task) = registry.get_mut(task_id) {
158                    task.error = Some("killed by user".into());
159                }
160                // Close associated session.
161                if let Some(sid) = registry.get(task_id).and_then(|t| t.session_id) {
162                    drop(registry);
163                    rt.close_session(sid).await;
164                    let mut registry = tasks.lock().await;
165                    registry.promote_next(tasks.clone());
166                } else {
167                    registry.promote_next(tasks.clone());
168                }
169                Ok(true)
170            }
171            crate::hook::task::TaskStatus::Queued => {
172                registry.remove(task_id);
173                Ok(true)
174            }
175            _ => Ok(false),
176        }
177    }
178
179    async fn approve_task(&self, task_id: u64, response: String) -> Result<bool> {
180        let rt = self.runtime.read().await.clone();
181        let mut registry = rt.hook.tasks.lock().await;
182        Ok(registry.approve(task_id, response))
183    }
184
185    fn hub(
186        &self,
187        package: String,
188        action: HubAction,
189        filters: Vec<String>,
190    ) -> impl futures_core::Stream<Item = Result<DownloadEvent>> + Send {
191        let runtime = self.runtime.clone();
192        async_stream::try_stream! {
193            let rt = runtime.read().await.clone();
194            let registry = rt.hook.downloads.clone();
195            let package = compact_str::CompactString::from(package.as_str());
196            match action {
197                HubAction::Install => {
198                    let s = crate::ext::hub::package::install(package, registry, filters);
199                    pin_mut!(s);
200                    while let Some(event) = s.next().await {
201                        yield event?;
202                    }
203                }
204                HubAction::Uninstall => {
205                    let s = crate::ext::hub::package::uninstall(package, registry, filters);
206                    pin_mut!(s);
207                    while let Some(event) = s.next().await {
208                        yield event?;
209                    }
210                }
211            }
212        }
213    }
214
215    fn subscribe_tasks(&self) -> impl futures_core::Stream<Item = Result<TaskEvent>> + Send {
216        let runtime = self.runtime.clone();
217        async_stream::try_stream! {
218            let rt = runtime.read().await.clone();
219            let mut rx = rt.hook.tasks.lock().await.subscribe();
220            loop {
221                match rx.recv().await {
222                    Ok(event) => yield event,
223                    Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
224                    Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
225                }
226            }
227        }
228    }
229
230    async fn list_downloads(&self) -> Result<Vec<DownloadInfo>> {
231        let rt = self.runtime.read().await.clone();
232        let registry = rt.hook.downloads.lock().await;
233        Ok(registry.list())
234    }
235
236    fn subscribe_downloads(
237        &self,
238    ) -> impl futures_core::Stream<Item = Result<DownloadEvent>> + Send {
239        let runtime = self.runtime.clone();
240        async_stream::try_stream! {
241            let rt = runtime.read().await.clone();
242            let mut rx = rt.hook.downloads.lock().await.subscribe();
243            loop {
244                match rx.recv().await {
245                    Ok(event) => yield event,
246                    Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
247                    Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
248                }
249            }
250        }
251    }
252
253    async fn get_config(&self) -> Result<String> {
254        let config = self.load_config()?;
255        serde_json::to_string(&config).context("failed to serialize config")
256    }
257
258    async fn set_config(&self, config: String) -> Result<()> {
259        let parsed: crate::DaemonConfig =
260            serde_json::from_str(&config).context("invalid DaemonConfig JSON")?;
261        let toml_str =
262            toml::to_string_pretty(&parsed).context("failed to serialize config to TOML")?;
263        let config_path = self.config_dir.join("walrus.toml");
264        std::fs::write(&config_path, toml_str)
265            .with_context(|| format!("failed to write {}", config_path.display()))?;
266        self.reload().await
267    }
268
269    async fn service_query(&self, service: String, query: String) -> Result<String> {
270        let rt = self.runtime.read().await.clone();
271        let registry = rt
272            .hook
273            .registry
274            .as_ref()
275            .ok_or_else(|| anyhow::anyhow!("no service registry"))?;
276        let handle = registry
277            .query
278            .get(&service)
279            .ok_or_else(|| anyhow::anyhow!("service '{}' not available", service))?;
280        let req = wcore::protocol::ext::ExtRequest {
281            msg: Some(wcore::protocol::ext::ext_request::Msg::ServiceQuery(
282                wcore::protocol::ext::ExtServiceQuery { query },
283            )),
284        };
285        let resp = handle.request(&req).await?;
286        match resp.msg {
287            Some(wcore::protocol::ext::ext_response::Msg::ServiceQueryResult(result)) => {
288                Ok(result.result)
289            }
290            Some(wcore::protocol::ext::ext_response::Msg::Error(e)) => {
291                anyhow::bail!("service '{}' error: {}", service, e.message)
292            }
293            other => anyhow::bail!("unexpected response from service '{}': {other:?}", service),
294        }
295    }
296
297    async fn get_service_schema(&self, service: String) -> Result<String> {
298        let rt = self.runtime.read().await.clone();
299        let registry = rt
300            .hook
301            .registry
302            .as_ref()
303            .ok_or_else(|| anyhow::anyhow!("no service registry"))?;
304        let handle = registry
305            .query
306            .get(&service)
307            .or_else(|| registry.tools.values().find(|h| h.name.as_str() == service))
308            .ok_or_else(|| anyhow::anyhow!("service '{}' not found", service))?;
309        let req = wcore::protocol::ext::ExtRequest {
310            msg: Some(wcore::protocol::ext::ext_request::Msg::GetSchema(
311                wcore::protocol::ext::ExtGetSchema {},
312            )),
313        };
314        let resp = handle.request(&req).await?;
315        match resp.msg {
316            Some(wcore::protocol::ext::ext_response::Msg::SchemaResult(result)) => {
317                Ok(result.schema)
318            }
319            Some(wcore::protocol::ext::ext_response::Msg::Error(e)) => {
320                anyhow::bail!("service '{}' schema error: {}", service, e.message)
321            }
322            other => anyhow::bail!(
323                "unexpected schema response from service '{}': {other:?}",
324                service
325            ),
326        }
327    }
328
329    async fn get_all_schemas(&self) -> Result<std::collections::HashMap<String, String>> {
330        let rt = self.runtime.read().await.clone();
331        let registry = rt
332            .hook
333            .registry
334            .as_ref()
335            .ok_or_else(|| anyhow::anyhow!("no service registry"))?;
336        let mut schemas = std::collections::HashMap::new();
337        // Collect unique service handles from the query registry.
338        for (name, handle) in &registry.query {
339            let req = wcore::protocol::ext::ExtRequest {
340                msg: Some(wcore::protocol::ext::ext_request::Msg::GetSchema(
341                    wcore::protocol::ext::ExtGetSchema {},
342                )),
343            };
344            if let Ok(resp) = handle.request(&req).await
345                && let Some(wcore::protocol::ext::ext_response::Msg::SchemaResult(result)) =
346                    resp.msg
347            {
348                schemas.insert(name.clone(), result.schema);
349            }
350        }
351        Ok(schemas)
352    }
353
354    async fn list_services(&self) -> Result<Vec<wcore::protocol::message::ServiceInfoMsg>> {
355        let rt = self.runtime.read().await.clone();
356        let registry = rt.hook.registry.as_ref();
357        let mut services = Vec::new();
358        if let Some(reg) = registry {
359            // Collect unique service names from all capability buckets.
360            let mut seen = std::collections::HashSet::new();
361            let all_handles: Vec<_> = reg
362                .build_agent
363                .iter()
364                .chain(reg.before_run.iter())
365                .chain(reg.compact.iter())
366                .chain(reg.event_observer.iter())
367                .chain(reg.query.values())
368                .chain(reg.tools.values())
369                .collect();
370            for handle in all_handles {
371                let name = handle.name.to_string();
372                if !seen.insert(name.clone()) {
373                    continue;
374                }
375                let capabilities: Vec<String> = handle
376                    .capabilities
377                    .iter()
378                    .filter_map(|c| match &c.cap {
379                        Some(wcore::protocol::ext::capability::Cap::Tools(_)) => {
380                            Some("tools".into())
381                        }
382                        Some(wcore::protocol::ext::capability::Cap::Query(_)) => {
383                            Some("query".into())
384                        }
385                        Some(wcore::protocol::ext::capability::Cap::BuildAgent(_)) => {
386                            Some("build_agent".into())
387                        }
388                        Some(wcore::protocol::ext::capability::Cap::BeforeRun(_)) => {
389                            Some("before_run".into())
390                        }
391                        Some(wcore::protocol::ext::capability::Cap::Compact(_)) => {
392                            Some("compact".into())
393                        }
394                        Some(wcore::protocol::ext::capability::Cap::EventObserver(_)) => {
395                            Some("event_observer".into())
396                        }
397                        Some(wcore::protocol::ext::capability::Cap::AfterRun(_)) => {
398                            Some("after_run".into())
399                        }
400                        Some(wcore::protocol::ext::capability::Cap::Infer(_)) => {
401                            Some("infer".into())
402                        }
403                        None => None,
404                    })
405                    .collect();
406                services.push(wcore::protocol::message::ServiceInfoMsg {
407                    name,
408                    kind: "extension".into(),
409                    status: "running".into(),
410                    capabilities,
411                    has_config: true,
412                });
413            }
414        }
415        Ok(services)
416    }
417
418    async fn set_service_config(&self, service: String, config: String) -> Result<()> {
419        let mut daemon_config = self.load_config()?;
420        let svc = daemon_config
421            .services
422            .get_mut(&service)
423            .ok_or_else(|| anyhow::anyhow!("service '{}' not found in config", service))?;
424        let parsed: serde_json::Value =
425            serde_json::from_str(&config).context("invalid service config JSON")?;
426        svc.config = parsed;
427        let toml_str =
428            toml::to_string_pretty(&daemon_config).context("failed to serialize config to TOML")?;
429        let config_path = self.config_dir.join("walrus.toml");
430        std::fs::write(&config_path, toml_str)
431            .with_context(|| format!("failed to write {}", config_path.display()))?;
432        self.reload().await
433    }
434
435    async fn reload(&self) -> Result<()> {
436        self.reload().await
437    }
438}
439
440impl Daemon {
441    /// Load the current `DaemonConfig` from disk.
442    fn load_config(&self) -> Result<crate::DaemonConfig> {
443        crate::DaemonConfig::load(&self.config_dir.join("walrus.toml"))
444    }
445}