Skip to main content

walrus_daemon/daemon/
protocol.rs

//! [`Server`] implementation for [`Daemon`]: forwards protocol requests to the daemon's runtime, memory, skills, and MCP managers.
2
3use crate::daemon::Daemon;
4use anyhow::{Result, bail};
5use futures_util::{StreamExt, pin_mut};
6use memory::Memory;
7use protocol::{
8    api::Server,
9    message::{
10        AgentDetail, AgentInfoRequest, AgentList, AgentSummary, ClearSessionRequest, DownloadEvent,
11        DownloadRequest, GetMemoryRequest, McpAddRequest, McpAdded, McpReloaded, McpRemoveRequest,
12        McpRemoved, McpServerList, McpServerSummary, MemoryEntry, MemoryList, SendRequest,
13        SendResponse, SessionCleared, SkillsReloaded, StreamEvent, StreamRequest,
14    },
15};
16use wcore::AgentEvent;
17
18impl Server for Daemon {
19    async fn send(&self, req: SendRequest) -> Result<SendResponse> {
20        let response = self.runtime.send_to(&req.agent, &req.content).await?;
21        Ok(SendResponse {
22            agent: req.agent,
23            content: response.final_response.unwrap_or_default(),
24        })
25    }
26
27    fn stream(
28        &self,
29        req: StreamRequest,
30    ) -> impl futures_core::Stream<Item = Result<StreamEvent>> + Send {
31        let runtime = self.runtime.clone();
32        let agent = req.agent;
33        let content = req.content;
34        async_stream::try_stream! {
35            yield StreamEvent::Start { agent: agent.clone() };
36
37            let stream = runtime.stream_to(&agent, &content);
38            pin_mut!(stream);
39            while let Some(event) = stream.next().await {
40                match event {
41                    AgentEvent::TextDelta(text) => {
42                        yield StreamEvent::Chunk { content: text };
43                    }
44                    AgentEvent::Done(_) => break,
45                    _ => {}
46                }
47            }
48
49            yield StreamEvent::End { agent: agent.clone() };
50        }
51    }
52
53    async fn clear_session(&self, req: ClearSessionRequest) -> Result<SessionCleared> {
54        self.runtime.clear_session(&req.agent).await;
55        Ok(SessionCleared { agent: req.agent })
56    }
57
58    async fn list_agents(&self) -> Result<AgentList> {
59        let agents = self
60            .runtime
61            .agents()
62            .await
63            .into_iter()
64            .map(|a| AgentSummary {
65                name: a.name.clone(),
66                description: a.description.clone(),
67            })
68            .collect();
69        Ok(AgentList { agents })
70    }
71
72    async fn agent_info(&self, req: AgentInfoRequest) -> Result<AgentDetail> {
73        match self.runtime.agent(&req.agent).await {
74            Some(a) => Ok(AgentDetail {
75                name: a.name.clone(),
76                description: a.description.clone(),
77                tools: a.tools.to_vec(),
78                skill_tags: a.skill_tags.to_vec(),
79                system_prompt: a.system_prompt.clone(),
80            }),
81            None => bail!("agent not found: {}", req.agent),
82        }
83    }
84
85    async fn list_memory(&self) -> Result<MemoryList> {
86        let entries = self.runtime.hook.memory.entries();
87        Ok(MemoryList { entries })
88    }
89
90    async fn get_memory(&self, req: GetMemoryRequest) -> Result<MemoryEntry> {
91        let value = self.runtime.hook.memory.get(&req.key);
92        Ok(MemoryEntry {
93            key: req.key,
94            value,
95        })
96    }
97
98    fn download(
99        &self,
100        req: DownloadRequest,
101    ) -> impl futures_core::Stream<Item = Result<DownloadEvent>> + Send {
102        #[cfg(feature = "local")]
103        {
104            use tokio::sync::mpsc;
105            async_stream::try_stream! {
106                yield DownloadEvent::Start { model: req.model.clone() };
107
108                let (dtx, mut drx) = mpsc::unbounded_channel();
109                let model_str = req.model.to_string();
110                let download_handle = tokio::spawn(async move {
111                    model::local::download::download_model(&model_str, dtx).await
112                });
113
114                while let Some(event) = drx.recv().await {
115                    let dl_event = match event {
116                        model::local::download::DownloadEvent::FileStart { filename, size } => {
117                            DownloadEvent::FileStart { filename, size }
118                        }
119                        model::local::download::DownloadEvent::Progress { bytes } => {
120                            DownloadEvent::Progress { bytes }
121                        }
122                        model::local::download::DownloadEvent::FileEnd { filename } => {
123                            DownloadEvent::FileEnd { filename }
124                        }
125                    };
126                    yield dl_event;
127                }
128
129                match download_handle.await {
130                    Ok(Ok(())) => {
131                        yield DownloadEvent::End { model: req.model };
132                    }
133                    Ok(Err(e)) => {
134                        Err(anyhow::anyhow!("download failed: {e}"))?;
135                    }
136                    Err(e) => {
137                        Err(anyhow::anyhow!("download task panicked: {e}"))?;
138                    }
139                }
140            }
141        }
142        #[cfg(not(feature = "local"))]
143        {
144            let _ = req;
145            async_stream::stream! {
146                yield Err(anyhow::anyhow!("this daemon was built without local model support"));
147            }
148        }
149    }
150
151    async fn reload_skills(&self) -> Result<SkillsReloaded> {
152        let count = self.runtime.hook.skills.reload().await?;
153        tracing::info!("reloaded {count} skill(s)");
154        Ok(SkillsReloaded { count })
155    }
156
157    async fn mcp_add(&self, req: McpAddRequest) -> Result<McpAdded> {
158        let config = mcp::McpServerConfig {
159            name: req.name.clone(),
160            command: req.command,
161            args: req.args,
162            env: req.env,
163            auto_restart: true,
164        };
165        let tools = self.runtime.hook.mcp.add(config).await?;
166
167        // Register newly added MCP tools on Runtime's registry.
168        for (tool, handler) in self.runtime.hook.mcp.tool_handlers().await {
169            if tools.iter().any(|t| t == &*tool.name) {
170                self.runtime.register_tool(tool, handler).await;
171            }
172        }
173
174        Ok(McpAdded {
175            name: req.name,
176            tools,
177        })
178    }
179
180    async fn mcp_remove(&self, req: McpRemoveRequest) -> Result<McpRemoved> {
181        let tools = self.runtime.hook.mcp.remove(&req.name).await?;
182
183        // Unregister removed MCP tools from Runtime's registry.
184        for tool_name in &tools {
185            self.runtime.unregister_tool(tool_name).await;
186        }
187
188        Ok(McpRemoved {
189            name: req.name,
190            tools,
191        })
192    }
193
194    async fn mcp_reload(&self) -> Result<McpReloaded> {
195        // Collect old tool names before reload.
196        let old_tool_names: Vec<compact_str::CompactString> = self
197            .runtime
198            .hook
199            .mcp
200            .tool_handlers()
201            .await
202            .into_iter()
203            .map(|(t, _)| t.name)
204            .collect();
205
206        let servers = self
207            .runtime
208            .hook
209            .mcp
210            .reload(|path| {
211                let config = crate::DaemonConfig::load(path)?;
212                Ok(config.mcp_servers)
213            })
214            .await?;
215
216        // Atomically swap old MCP tools for new ones on Runtime.
217        let new_tools = self.runtime.hook.mcp.tool_handlers().await;
218        self.runtime.replace_tools(&old_tool_names, new_tools).await;
219
220        let servers = servers
221            .into_iter()
222            .map(|(name, tools)| McpServerSummary { name, tools })
223            .collect();
224        Ok(McpReloaded { servers })
225    }
226
227    async fn mcp_list(&self) -> Result<McpServerList> {
228        let servers = self
229            .runtime
230            .hook
231            .mcp
232            .list()
233            .await
234            .into_iter()
235            .map(|(name, tools)| McpServerSummary { name, tools })
236            .collect();
237        Ok(McpServerList { servers })
238    }
239
240    async fn ping(&self) -> Result<()> {
241        Ok(())
242    }
243}