Skip to main content

walrus_daemon/daemon/
protocol.rs

//! `Server` protocol trait implementation for the `Daemon`.

3use crate::{daemon::Daemon, ext::hub};
4use anyhow::Result;
5use compact_str::CompactString;
6use futures_util::{StreamExt, pin_mut};
7use std::sync::Arc;
8use wcore::AgentEvent;
9use wcore::protocol::{
10    api::Server,
11    message::{
12        DownloadEvent, DownloadRequest, HubAction, HubEvent, SendRequest, SendResponse,
13        StreamEvent, StreamRequest,
14        server::{SessionInfo, TaskInfo, ToolCallInfo},
15    },
16};
17
18impl Server for Daemon {
19    async fn send(&self, req: SendRequest) -> Result<SendResponse> {
20        let rt: Arc<_> = self.runtime.read().await.clone();
21        let sender = req.sender.as_deref().unwrap_or("");
22        let created_by = if sender.is_empty() { "user" } else { sender };
23        let (session_id, is_new) = match req.session {
24            Some(id) => (id, false),
25            None => (rt.create_session(&req.agent, created_by).await?, true),
26        };
27        let response = rt.send_to(session_id, &req.content, sender).await?;
28        if is_new {
29            rt.close_session(session_id).await;
30        }
31        Ok(SendResponse {
32            agent: req.agent,
33            content: response.final_response.unwrap_or_default(),
34            session: session_id,
35        })
36    }
37
38    fn stream(
39        &self,
40        req: StreamRequest,
41    ) -> impl futures_core::Stream<Item = Result<StreamEvent>> + Send {
42        let runtime = self.runtime.clone();
43        let agent = req.agent;
44        let content = req.content;
45        let req_session = req.session;
46        let sender = req.sender.unwrap_or_default();
47        async_stream::try_stream! {
48            let rt: Arc<_> = runtime.read().await.clone();
49            let created_by = if sender.is_empty() { "user".into() } else { sender.clone() };
50            let (session_id, is_new) = match req_session {
51                Some(id) => (id, false),
52                None => (rt.create_session(&agent, created_by.as_str()).await?, true),
53            };
54
55            yield StreamEvent::Start { agent: agent.clone(), session: session_id };
56
57            let stream = rt.stream_to(session_id, &content, &sender);
58            pin_mut!(stream);
59            while let Some(event) = stream.next().await {
60                match event {
61                    AgentEvent::TextDelta(text) => {
62                        yield StreamEvent::Chunk { content: text };
63                    }
64                    AgentEvent::ThinkingDelta(text) => {
65                        yield StreamEvent::Thinking { content: text };
66                    }
67                    AgentEvent::ToolCallsStart(calls) => {
68                        yield StreamEvent::ToolStart {
69                            calls: calls.into_iter().map(|c| ToolCallInfo {
70                                name: CompactString::from(c.function.name.as_str()),
71                                arguments: c.function.arguments,
72                            }).collect(),
73                        };
74                    }
75                    AgentEvent::ToolResult { call_id, output } => {
76                        yield StreamEvent::ToolResult { call_id, output };
77                    }
78                    AgentEvent::ToolCallsComplete => {
79                        yield StreamEvent::ToolsComplete;
80                    }
81                    AgentEvent::Done(resp) => {
82                        if let wcore::AgentStopReason::Error(e) = &resp.stop_reason {
83                            if is_new {
84                                rt.close_session(session_id).await;
85                            }
86                            Err(anyhow::anyhow!("{e}"))?;
87                        }
88                        break;
89                    }
90                }
91            }
92            if is_new {
93                rt.close_session(session_id).await;
94            }
95
96            yield StreamEvent::End { agent: agent.clone() };
97        }
98    }
99
100    fn download(
101        &self,
102        req: DownloadRequest,
103    ) -> impl futures_core::Stream<Item = Result<DownloadEvent>> + Send {
104        #[cfg(feature = "local")]
105        {
106            use tokio::sync::mpsc;
107            async_stream::try_stream! {
108                // Only registry models are supported.
109                let entry = model::local::registry::find(&req.model)
110                    .ok_or_else(|| anyhow::anyhow!(
111                        "model '{}' is not in the registry", req.model
112                    ))?;
113
114                if !entry.fits() {
115                    let required = entry.memory_requirement();
116                    let actual = model::local::system_memory() / (1024 * 1024 * 1024);
117                    Err(anyhow::anyhow!(
118                        "model '{}' requires at least {} RAM, your system has {}GB",
119                        entry.name, required, actual
120                    ))?;
121                }
122
123                yield DownloadEvent::Start { model: req.model.clone() };
124
125                let (dtx, mut drx) = mpsc::unbounded_channel();
126                let model_str = req.model.to_string();
127                let download_handle = tokio::spawn(async move {
128                    model::local::download::download_model(&model_str, dtx).await
129                });
130
131                while let Some(event) = drx.recv().await {
132                    let dl_event = match event {
133                        model::local::download::DownloadEvent::FileStart { filename, size } => {
134                            DownloadEvent::FileStart { model: req.model.clone(), filename, size }
135                        }
136                        model::local::download::DownloadEvent::Progress { bytes } => {
137                            DownloadEvent::Progress { model: req.model.clone(), bytes }
138                        }
139                        model::local::download::DownloadEvent::FileEnd { filename } => {
140                            DownloadEvent::FileEnd { model: req.model.clone(), filename }
141                        }
142                    };
143                    yield dl_event;
144                }
145
146                match download_handle.await {
147                    Ok(Ok(())) => {
148                        yield DownloadEvent::End { model: req.model };
149                    }
150                    Ok(Err(e)) => {
151                        Err(anyhow::anyhow!("download failed: {e}"))?;
152                    }
153                    Err(e) => {
154                        Err(anyhow::anyhow!("download task panicked: {e}"))?;
155                    }
156                }
157            }
158        }
159        #[cfg(not(feature = "local"))]
160        {
161            let _ = req;
162            async_stream::stream! {
163                yield Err(anyhow::anyhow!("this daemon was built without local model support"));
164            }
165        }
166    }
167
168    async fn ping(&self) -> Result<()> {
169        Ok(())
170    }
171
172    async fn list_sessions(&self) -> Result<Vec<SessionInfo>> {
173        let rt = self.runtime.read().await.clone();
174        let sessions = rt.sessions().await;
175        let mut infos = Vec::with_capacity(sessions.len());
176        for s in sessions {
177            let s = s.lock().await;
178            infos.push(SessionInfo {
179                id: s.id,
180                agent: s.agent.clone(),
181                created_by: s.created_by.clone(),
182                message_count: s.history.len(),
183                alive_secs: s.created_at.elapsed().as_secs(),
184            });
185        }
186        Ok(infos)
187    }
188
189    async fn kill_session(&self, session: u64) -> Result<bool> {
190        let rt = self.runtime.read().await.clone();
191        Ok(rt.close_session(session).await)
192    }
193
194    async fn list_tasks(&self) -> Result<Vec<TaskInfo>> {
195        let rt = self.runtime.read().await.clone();
196        let registry = rt.hook.tasks.lock().await;
197        let tasks = registry.list(None, None, None);
198        Ok(tasks
199            .into_iter()
200            .map(|t| TaskInfo {
201                id: t.id,
202                parent_id: t.parent_id,
203                agent: t.agent.clone(),
204                status: t.status.to_string(),
205                description: t.description.clone(),
206                result: t.result.clone(),
207                error: t.error.clone(),
208                created_by: t.created_by.clone(),
209                prompt_tokens: t.prompt_tokens,
210                completion_tokens: t.completion_tokens,
211                alive_secs: t.created_at.elapsed().as_secs(),
212                blocked_on: t.blocked_on.as_ref().map(|i| i.question.clone()),
213            })
214            .collect())
215    }
216
217    async fn kill_task(&self, task_id: u64) -> Result<bool> {
218        let rt = self.runtime.read().await.clone();
219        let tasks = rt.hook.tasks.clone();
220        let mut registry = tasks.lock().await;
221        let Some(task) = registry.get(task_id) else {
222            return Ok(false);
223        };
224        match task.status {
225            crate::hook::task::TaskStatus::InProgress | crate::hook::task::TaskStatus::Blocked => {
226                if let Some(handle) = &task.abort_handle {
227                    handle.abort();
228                }
229                registry.set_status(task_id, crate::hook::task::TaskStatus::Failed);
230                if let Some(task) = registry.get_mut(task_id) {
231                    task.error = Some("killed by user".into());
232                }
233                // Close associated session.
234                if let Some(sid) = registry.get(task_id).and_then(|t| t.session_id) {
235                    drop(registry);
236                    rt.close_session(sid).await;
237                    let mut registry = tasks.lock().await;
238                    registry.promote_next(tasks.clone());
239                } else {
240                    registry.promote_next(tasks.clone());
241                }
242                Ok(true)
243            }
244            crate::hook::task::TaskStatus::Queued => {
245                registry.remove(task_id);
246                Ok(true)
247            }
248            _ => Ok(false),
249        }
250    }
251
252    async fn approve_task(&self, task_id: u64, response: String) -> Result<bool> {
253        let rt = self.runtime.read().await.clone();
254        let mut registry = rt.hook.tasks.lock().await;
255        Ok(registry.approve(task_id, response))
256    }
257
258    fn hub(
259        &self,
260        package: CompactString,
261        action: HubAction,
262    ) -> impl futures_core::Stream<Item = Result<HubEvent>> + Send {
263        async_stream::try_stream! {
264            match action {
265                HubAction::Install => {
266                    let s = hub::install(package);
267                    pin_mut!(s);
268                    while let Some(event) = s.next().await {
269                        yield event?;
270                    }
271                }
272                HubAction::Uninstall => {
273                    let s = hub::uninstall(package);
274                    pin_mut!(s);
275                    while let Some(event) = s.next().await {
276                        yield event?;
277                    }
278                }
279            }
280        }
281    }
282}