Skip to main content

walrus_daemon/daemon/
protocol.rs

1//! Server trait implementation for the Daemon.
2
3use crate::daemon::Daemon;
4use anyhow::{Context, Result};
5use futures_util::{StreamExt, pin_mut};
6use std::sync::Arc;
7use wcore::AgentEvent;
8use wcore::protocol::{
9    api::Server,
10    message::{
11        DownloadEvent, DownloadInfo, HubAction, SendMsg, SendResponse, SessionInfo, StreamChunk,
12        StreamEnd, StreamEvent, StreamMsg, StreamStart, StreamThinking, TaskEvent, TaskInfo,
13        ToolCallInfo, ToolResultEvent, ToolStartEvent, ToolsCompleteEvent, stream_event,
14    },
15};
16
17impl Server for Daemon {
18    async fn send(&self, req: SendMsg) -> Result<SendResponse> {
19        let rt: Arc<_> = self.runtime.read().await.clone();
20        let sender = req.sender.as_deref().unwrap_or("");
21        let created_by = if sender.is_empty() { "user" } else { sender };
22        let (session_id, is_new) = match req.session {
23            Some(id) => (id, false),
24            None => (rt.create_session(&req.agent, created_by).await?, true),
25        };
26        let response = rt.send_to(session_id, &req.content, sender).await?;
27        if is_new {
28            rt.close_session(session_id).await;
29        }
30        Ok(SendResponse {
31            agent: req.agent,
32            content: response.final_response.unwrap_or_default(),
33            session: session_id,
34        })
35    }
36
37    fn stream(
38        &self,
39        req: StreamMsg,
40    ) -> impl futures_core::Stream<Item = Result<StreamEvent>> + Send {
41        let runtime = self.runtime.clone();
42        let agent = req.agent;
43        let content = req.content;
44        let req_session = req.session;
45        let sender = req.sender.unwrap_or_default();
46        async_stream::try_stream! {
47            let rt: Arc<_> = runtime.read().await.clone();
48            let created_by = if sender.is_empty() { "user".into() } else { sender.clone() };
49            let (session_id, is_new) = match req_session {
50                Some(id) => (id, false),
51                None => (rt.create_session(&agent, created_by.as_str()).await?, true),
52            };
53
54            yield StreamEvent { event: Some(stream_event::Event::Start(StreamStart { agent: agent.clone(), session: session_id })) };
55
56            let stream = rt.stream_to(session_id, &content, &sender);
57            pin_mut!(stream);
58            while let Some(event) = stream.next().await {
59                match event {
60                    AgentEvent::TextDelta(text) => {
61                        yield StreamEvent { event: Some(stream_event::Event::Chunk(StreamChunk { content: text })) };
62                    }
63                    AgentEvent::ThinkingDelta(text) => {
64                        yield StreamEvent { event: Some(stream_event::Event::Thinking(StreamThinking { content: text })) };
65                    }
66                    AgentEvent::ToolCallsStart(calls) => {
67                        yield StreamEvent { event: Some(stream_event::Event::ToolStart(ToolStartEvent {
68                            calls: calls.into_iter().map(|c| ToolCallInfo {
69                                name: c.function.name.to_string(),
70                                arguments: c.function.arguments,
71                            }).collect(),
72                        })) };
73                    }
74                    AgentEvent::ToolResult { call_id, output } => {
75                        yield StreamEvent { event: Some(stream_event::Event::ToolResult(ToolResultEvent { call_id: call_id.to_string(), output })) };
76                    }
77                    AgentEvent::ToolCallsComplete => {
78                        yield StreamEvent { event: Some(stream_event::Event::ToolsComplete(ToolsCompleteEvent {})) };
79                    }
80                    AgentEvent::Compact { .. } => {
81                        // Compact events are handled by on_event in the hook layer.
82                    }
83                    AgentEvent::Done(resp) => {
84                        if let wcore::AgentStopReason::Error(e) = &resp.stop_reason {
85                            if is_new {
86                                rt.close_session(session_id).await;
87                            }
88                            Err(anyhow::anyhow!("{e}"))?;
89                        }
90                        break;
91                    }
92                }
93            }
94            yield StreamEvent { event: Some(stream_event::Event::End(StreamEnd { agent: agent.clone() })) };
95        }
96    }
97
98    async fn ping(&self) -> Result<()> {
99        Ok(())
100    }
101
102    async fn list_sessions(&self) -> Result<Vec<SessionInfo>> {
103        let rt = self.runtime.read().await.clone();
104        let sessions = rt.sessions().await;
105        let mut infos = Vec::with_capacity(sessions.len());
106        for s in sessions {
107            let s = s.lock().await;
108            infos.push(SessionInfo {
109                id: s.id,
110                agent: s.agent.to_string(),
111                created_by: s.created_by.to_string(),
112                message_count: s.history.len() as u64,
113                alive_secs: s.created_at.elapsed().as_secs(),
114            });
115        }
116        Ok(infos)
117    }
118
119    async fn kill_session(&self, session: u64) -> Result<bool> {
120        let rt = self.runtime.read().await.clone();
121        Ok(rt.close_session(session).await)
122    }
123
124    async fn list_tasks(&self) -> Result<Vec<TaskInfo>> {
125        let rt = self.runtime.read().await.clone();
126        let tasks = rt.hook.tasks.lock().await;
127        Ok(tasks.list(16).iter().map(|t| t.to_info()).collect())
128    }
129
130    async fn kill_task(&self, task_id: u64) -> Result<bool> {
131        let rt = self.runtime.read().await.clone();
132        let session_id = {
133            let tasks = rt.hook.tasks.lock().await;
134            tasks.get(task_id).and_then(|t| t.session_id)
135        };
136        let killed = rt.hook.tasks.lock().await.kill(task_id);
137        if killed && let Some(sid) = session_id {
138            rt.close_session(sid).await;
139        }
140        Ok(killed)
141    }
142
143    async fn approve_task(&self, _task_id: u64, _response: String) -> Result<bool> {
144        // Approval system removed — sub-agents are autonomous.
145        Ok(false)
146    }
147
148    fn hub(
149        &self,
150        package: String,
151        action: HubAction,
152        filters: Vec<String>,
153    ) -> impl futures_core::Stream<Item = Result<DownloadEvent>> + Send {
154        let runtime = self.runtime.clone();
155        async_stream::try_stream! {
156            let rt = runtime.read().await.clone();
157            let registry = rt.hook.downloads.clone();
158            match action {
159                HubAction::Install => {
160                    let s = crate::ext::hub::package::install(package, registry, filters);
161                    pin_mut!(s);
162                    while let Some(event) = s.next().await {
163                        yield event?;
164                    }
165                }
166                HubAction::Uninstall => {
167                    let s = crate::ext::hub::package::uninstall(package, registry, filters);
168                    pin_mut!(s);
169                    while let Some(event) = s.next().await {
170                        yield event?;
171                    }
172                }
173            }
174        }
175    }
176
177    fn subscribe_tasks(&self) -> impl futures_core::Stream<Item = Result<TaskEvent>> + Send {
178        // Task subscription removed — tasks are lightweight JoinHandles now.
179        futures_util::stream::empty()
180    }
181
182    async fn list_downloads(&self) -> Result<Vec<DownloadInfo>> {
183        let rt = self.runtime.read().await.clone();
184        let registry = rt.hook.downloads.lock().await;
185        Ok(registry.list())
186    }
187
188    fn subscribe_downloads(
189        &self,
190    ) -> impl futures_core::Stream<Item = Result<DownloadEvent>> + Send {
191        let runtime = self.runtime.clone();
192        async_stream::try_stream! {
193            let rt = runtime.read().await.clone();
194            let mut rx = rt.hook.downloads.lock().await.subscribe();
195            loop {
196                match rx.recv().await {
197                    Ok(event) => yield event,
198                    Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
199                    Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
200                }
201            }
202        }
203    }
204
205    async fn get_config(&self) -> Result<String> {
206        let config = self.load_config()?;
207        serde_json::to_string(&config).context("failed to serialize config")
208    }
209
210    async fn set_config(&self, config: String) -> Result<()> {
211        let parsed: crate::DaemonConfig =
212            serde_json::from_str(&config).context("invalid DaemonConfig JSON")?;
213        let toml_str =
214            toml::to_string_pretty(&parsed).context("failed to serialize config to TOML")?;
215        let config_path = self.config_dir.join("walrus.toml");
216        std::fs::write(&config_path, toml_str)
217            .with_context(|| format!("failed to write {}", config_path.display()))?;
218        self.reload().await
219    }
220
221    async fn service_query(&self, service: String, query: String) -> Result<String> {
222        let rt = self.runtime.read().await.clone();
223        let registry = rt
224            .hook
225            .registry
226            .as_ref()
227            .ok_or_else(|| anyhow::anyhow!("no service registry"))?;
228        let handle = registry
229            .query
230            .get(&service)
231            .ok_or_else(|| anyhow::anyhow!("service '{}' not available", service))?;
232        let req = wcore::protocol::ext::ExtRequest {
233            msg: Some(wcore::protocol::ext::ext_request::Msg::ServiceQuery(
234                wcore::protocol::ext::ExtServiceQuery { query },
235            )),
236        };
237        let resp = handle.request(&req).await?;
238        match resp.msg {
239            Some(wcore::protocol::ext::ext_response::Msg::ServiceQueryResult(result)) => {
240                Ok(result.result)
241            }
242            Some(wcore::protocol::ext::ext_response::Msg::Error(e)) => {
243                anyhow::bail!("service '{}' error: {}", service, e.message)
244            }
245            other => anyhow::bail!("unexpected response from service '{}': {other:?}", service),
246        }
247    }
248
249    async fn get_service_schema(&self, service: String) -> Result<String> {
250        let rt = self.runtime.read().await.clone();
251        let registry = rt
252            .hook
253            .registry
254            .as_ref()
255            .ok_or_else(|| anyhow::anyhow!("no service registry"))?;
256        let handle = registry
257            .query
258            .get(&service)
259            .or_else(|| registry.tools.values().find(|h| h.name.as_str() == service))
260            .ok_or_else(|| anyhow::anyhow!("service '{}' not found", service))?;
261        let req = wcore::protocol::ext::ExtRequest {
262            msg: Some(wcore::protocol::ext::ext_request::Msg::GetSchema(
263                wcore::protocol::ext::ExtGetSchema {},
264            )),
265        };
266        let resp = handle.request(&req).await?;
267        match resp.msg {
268            Some(wcore::protocol::ext::ext_response::Msg::SchemaResult(result)) => {
269                Ok(result.schema)
270            }
271            Some(wcore::protocol::ext::ext_response::Msg::Error(e)) => {
272                anyhow::bail!("service '{}' schema error: {}", service, e.message)
273            }
274            other => anyhow::bail!(
275                "unexpected schema response from service '{}': {other:?}",
276                service
277            ),
278        }
279    }
280
281    async fn get_all_schemas(&self) -> Result<std::collections::HashMap<String, String>> {
282        let rt = self.runtime.read().await.clone();
283        let registry = rt
284            .hook
285            .registry
286            .as_ref()
287            .ok_or_else(|| anyhow::anyhow!("no service registry"))?;
288        let mut schemas = std::collections::HashMap::new();
289        // Collect unique service handles from the query registry.
290        for (name, handle) in &registry.query {
291            let req = wcore::protocol::ext::ExtRequest {
292                msg: Some(wcore::protocol::ext::ext_request::Msg::GetSchema(
293                    wcore::protocol::ext::ExtGetSchema {},
294                )),
295            };
296            if let Ok(resp) = handle.request(&req).await
297                && let Some(wcore::protocol::ext::ext_response::Msg::SchemaResult(result)) =
298                    resp.msg
299            {
300                schemas.insert(name.clone(), result.schema);
301            }
302        }
303        Ok(schemas)
304    }
305
306    async fn list_services(&self) -> Result<Vec<wcore::protocol::message::ServiceInfoMsg>> {
307        let rt = self.runtime.read().await.clone();
308        let registry = rt.hook.registry.as_ref();
309        let mut services = Vec::new();
310        if let Some(reg) = registry {
311            // Collect unique service names from capability buckets.
312            let mut seen = std::collections::HashSet::new();
313            let all_handles: Vec<_> = reg.query.values().chain(reg.tools.values()).collect();
314            for handle in all_handles {
315                let name = handle.name.to_string();
316                if !seen.insert(name.clone()) {
317                    continue;
318                }
319                let capabilities: Vec<String> = handle
320                    .capabilities
321                    .iter()
322                    .filter_map(|c| match &c.cap {
323                        Some(wcore::protocol::ext::capability::Cap::Tools(_)) => {
324                            Some("tools".into())
325                        }
326                        Some(wcore::protocol::ext::capability::Cap::Query(_)) => {
327                            Some("query".into())
328                        }
329                        _ => None,
330                    })
331                    .collect();
332                services.push(wcore::protocol::message::ServiceInfoMsg {
333                    name,
334                    kind: "extension".into(),
335                    status: "running".into(),
336                    capabilities,
337                    has_config: true,
338                });
339            }
340        }
341        Ok(services)
342    }
343
344    async fn set_service_config(&self, service: String, config: String) -> Result<()> {
345        let mut daemon_config = self.load_config()?;
346        let svc = daemon_config
347            .services
348            .get_mut(&service)
349            .ok_or_else(|| anyhow::anyhow!("service '{}' not found in config", service))?;
350        let parsed: serde_json::Value =
351            serde_json::from_str(&config).context("invalid service config JSON")?;
352        svc.config = parsed;
353        let toml_str =
354            toml::to_string_pretty(&daemon_config).context("failed to serialize config to TOML")?;
355        let config_path = self.config_dir.join("walrus.toml");
356        std::fs::write(&config_path, toml_str)
357            .with_context(|| format!("failed to write {}", config_path.display()))?;
358        self.reload().await
359    }
360
361    async fn reload(&self) -> Result<()> {
362        self.reload().await
363    }
364}
365
366impl Daemon {
367    /// Load the current `DaemonConfig` from disk.
368    fn load_config(&self) -> Result<crate::DaemonConfig> {
369        crate::DaemonConfig::load(&self.config_dir.join("walrus.toml"))
370    }
371}