1use crate::daemon::Daemon;
4use anyhow::{Context, Result};
5use futures_util::{StreamExt, pin_mut};
6use std::sync::Arc;
7use wcore::AgentEvent;
8use wcore::protocol::{
9 api::Server,
10 message::{
11 DownloadEvent, DownloadInfo, HubAction, SendMsg, SendResponse, SessionInfo, StreamChunk,
12 StreamEnd, StreamEvent, StreamMsg, StreamStart, StreamThinking, TaskEvent, TaskInfo,
13 ToolCallInfo, ToolResultEvent, ToolStartEvent, ToolsCompleteEvent, stream_event,
14 },
15};
16
17impl Server for Daemon {
18 async fn send(&self, req: SendMsg) -> Result<SendResponse> {
19 let rt: Arc<_> = self.runtime.read().await.clone();
20 let sender = req.sender.as_deref().unwrap_or("");
21 let created_by = if sender.is_empty() { "user" } else { sender };
22 let (session_id, is_new) = match req.session {
23 Some(id) => (id, false),
24 None => (rt.create_session(&req.agent, created_by).await?, true),
25 };
26 let response = rt.send_to(session_id, &req.content, sender).await?;
27 if is_new {
28 rt.close_session(session_id).await;
29 }
30 Ok(SendResponse {
31 agent: req.agent,
32 content: response.final_response.unwrap_or_default(),
33 session: session_id,
34 })
35 }
36
37 fn stream(
38 &self,
39 req: StreamMsg,
40 ) -> impl futures_core::Stream<Item = Result<StreamEvent>> + Send {
41 let runtime = self.runtime.clone();
42 let agent = req.agent;
43 let content = req.content;
44 let req_session = req.session;
45 let sender = req.sender.unwrap_or_default();
46 async_stream::try_stream! {
47 let rt: Arc<_> = runtime.read().await.clone();
48 let created_by = if sender.is_empty() { "user".into() } else { sender.clone() };
49 let (session_id, is_new) = match req_session {
50 Some(id) => (id, false),
51 None => (rt.create_session(&agent, created_by.as_str()).await?, true),
52 };
53
54 yield StreamEvent { event: Some(stream_event::Event::Start(StreamStart { agent: agent.clone(), session: session_id })) };
55
56 let stream = rt.stream_to(session_id, &content, &sender);
57 pin_mut!(stream);
58 while let Some(event) = stream.next().await {
59 match event {
60 AgentEvent::TextDelta(text) => {
61 yield StreamEvent { event: Some(stream_event::Event::Chunk(StreamChunk { content: text })) };
62 }
63 AgentEvent::ThinkingDelta(text) => {
64 yield StreamEvent { event: Some(stream_event::Event::Thinking(StreamThinking { content: text })) };
65 }
66 AgentEvent::ToolCallsStart(calls) => {
67 yield StreamEvent { event: Some(stream_event::Event::ToolStart(ToolStartEvent {
68 calls: calls.into_iter().map(|c| ToolCallInfo {
69 name: c.function.name.to_string(),
70 arguments: c.function.arguments,
71 }).collect(),
72 })) };
73 }
74 AgentEvent::ToolResult { call_id, output } => {
75 yield StreamEvent { event: Some(stream_event::Event::ToolResult(ToolResultEvent { call_id: call_id.to_string(), output })) };
76 }
77 AgentEvent::ToolCallsComplete => {
78 yield StreamEvent { event: Some(stream_event::Event::ToolsComplete(ToolsCompleteEvent {})) };
79 }
80 AgentEvent::Done(resp) => {
81 if let wcore::AgentStopReason::Error(e) = &resp.stop_reason {
82 if is_new {
83 rt.close_session(session_id).await;
84 }
85 Err(anyhow::anyhow!("{e}"))?;
86 }
87 break;
88 }
89 }
90 }
91 if is_new {
92 rt.close_session(session_id).await;
93 }
94
95 yield StreamEvent { event: Some(stream_event::Event::End(StreamEnd { agent: agent.clone() })) };
96 }
97 }
98
99 async fn ping(&self) -> Result<()> {
100 Ok(())
101 }
102
103 async fn list_sessions(&self) -> Result<Vec<SessionInfo>> {
104 let rt = self.runtime.read().await.clone();
105 let sessions = rt.sessions().await;
106 let mut infos = Vec::with_capacity(sessions.len());
107 for s in sessions {
108 let s = s.lock().await;
109 infos.push(SessionInfo {
110 id: s.id,
111 agent: s.agent.to_string(),
112 created_by: s.created_by.to_string(),
113 message_count: s.history.len() as u64,
114 alive_secs: s.created_at.elapsed().as_secs(),
115 });
116 }
117 Ok(infos)
118 }
119
120 async fn kill_session(&self, session: u64) -> Result<bool> {
121 let rt = self.runtime.read().await.clone();
122 Ok(rt.close_session(session).await)
123 }
124
125 async fn list_tasks(&self) -> Result<Vec<TaskInfo>> {
126 let rt = self.runtime.read().await.clone();
127 let registry = rt.hook.tasks.lock().await;
128 let tasks = registry.list(None, None, None);
129 Ok(tasks
130 .into_iter()
131 .map(|t| TaskInfo {
132 id: t.id,
133 parent_id: t.parent_id,
134 agent: t.agent.to_string(),
135 status: t.status.to_string(),
136 description: t.description.clone(),
137 result: t.result.clone(),
138 error: t.error.clone(),
139 created_by: t.created_by.to_string(),
140 prompt_tokens: t.prompt_tokens,
141 completion_tokens: t.completion_tokens,
142 alive_secs: t.created_at.elapsed().as_secs(),
143 blocked_on: t.blocked_on.as_ref().map(|i| i.question.clone()),
144 })
145 .collect())
146 }
147
148 async fn kill_task(&self, task_id: u64) -> Result<bool> {
149 let rt = self.runtime.read().await.clone();
150 let tasks = rt.hook.tasks.clone();
151 let mut registry = tasks.lock().await;
152 let Some(task) = registry.get(task_id) else {
153 return Ok(false);
154 };
155 match task.status {
156 crate::hook::task::TaskStatus::InProgress | crate::hook::task::TaskStatus::Blocked => {
157 if let Some(handle) = &task.abort_handle {
158 handle.abort();
159 }
160 registry.set_status(task_id, crate::hook::task::TaskStatus::Failed);
161 if let Some(task) = registry.get_mut(task_id) {
162 task.error = Some("killed by user".into());
163 }
164 if let Some(sid) = registry.get(task_id).and_then(|t| t.session_id) {
166 drop(registry);
167 rt.close_session(sid).await;
168 let mut registry = tasks.lock().await;
169 registry.promote_next(tasks.clone());
170 } else {
171 registry.promote_next(tasks.clone());
172 }
173 Ok(true)
174 }
175 crate::hook::task::TaskStatus::Queued => {
176 registry.remove(task_id);
177 Ok(true)
178 }
179 _ => Ok(false),
180 }
181 }
182
183 async fn approve_task(&self, task_id: u64, response: String) -> Result<bool> {
184 let rt = self.runtime.read().await.clone();
185 let mut registry = rt.hook.tasks.lock().await;
186 Ok(registry.approve(task_id, response))
187 }
188
189 fn hub(
190 &self,
191 package: String,
192 action: HubAction,
193 filters: Vec<String>,
194 ) -> impl futures_core::Stream<Item = Result<DownloadEvent>> + Send {
195 let runtime = self.runtime.clone();
196 async_stream::try_stream! {
197 let rt = runtime.read().await.clone();
198 let registry = rt.hook.downloads.clone();
199 let package = compact_str::CompactString::from(package.as_str());
200 match action {
201 HubAction::Install => {
202 let s = crate::ext::hub::package::install(package, registry, filters);
203 pin_mut!(s);
204 while let Some(event) = s.next().await {
205 yield event?;
206 }
207 }
208 HubAction::Uninstall => {
209 let s = crate::ext::hub::package::uninstall(package, registry, filters);
210 pin_mut!(s);
211 while let Some(event) = s.next().await {
212 yield event?;
213 }
214 }
215 }
216 }
217 }
218
219 fn subscribe_tasks(&self) -> impl futures_core::Stream<Item = Result<TaskEvent>> + Send {
220 let runtime = self.runtime.clone();
221 async_stream::try_stream! {
222 let rt = runtime.read().await.clone();
223 let mut rx = rt.hook.tasks.lock().await.subscribe();
224 loop {
225 match rx.recv().await {
226 Ok(event) => yield event,
227 Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
228 Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
229 }
230 }
231 }
232 }
233
234 async fn list_downloads(&self) -> Result<Vec<DownloadInfo>> {
235 let rt = self.runtime.read().await.clone();
236 let registry = rt.hook.downloads.lock().await;
237 Ok(registry.list())
238 }
239
240 fn subscribe_downloads(
241 &self,
242 ) -> impl futures_core::Stream<Item = Result<DownloadEvent>> + Send {
243 let runtime = self.runtime.clone();
244 async_stream::try_stream! {
245 let rt = runtime.read().await.clone();
246 let mut rx = rt.hook.downloads.lock().await.subscribe();
247 loop {
248 match rx.recv().await {
249 Ok(event) => yield event,
250 Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
251 Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
252 }
253 }
254 }
255 }
256
257 async fn get_config(&self) -> Result<String> {
258 let config = self.load_config()?;
259 serde_json::to_string(&config).context("failed to serialize config")
260 }
261
262 async fn set_config(&self, config: String) -> Result<()> {
263 let parsed: crate::DaemonConfig =
264 serde_json::from_str(&config).context("invalid DaemonConfig JSON")?;
265 let toml_str =
266 toml::to_string_pretty(&parsed).context("failed to serialize config to TOML")?;
267 let config_path = self.config_dir.join("walrus.toml");
268 std::fs::write(&config_path, toml_str)
269 .with_context(|| format!("failed to write {}", config_path.display()))?;
270 self.reload().await
271 }
272
273 async fn service_query(&self, service: String, query: String) -> Result<String> {
274 let rt = self.runtime.read().await.clone();
275 let registry = rt
276 .hook
277 .registry
278 .as_ref()
279 .ok_or_else(|| anyhow::anyhow!("no service registry"))?;
280 let handle = registry
281 .query
282 .get(&service)
283 .ok_or_else(|| anyhow::anyhow!("service '{}' not available", service))?;
284 let req = wcore::protocol::ext::ExtRequest {
285 msg: Some(wcore::protocol::ext::ext_request::Msg::ServiceQuery(
286 wcore::protocol::ext::ExtServiceQuery { query },
287 )),
288 };
289 let resp = handle.request(&req).await?;
290 match resp.msg {
291 Some(wcore::protocol::ext::ext_response::Msg::ServiceQueryResult(result)) => {
292 Ok(result.result)
293 }
294 Some(wcore::protocol::ext::ext_response::Msg::Error(e)) => {
295 anyhow::bail!("service '{}' error: {}", service, e.message)
296 }
297 other => anyhow::bail!("unexpected response from service '{}': {other:?}", service),
298 }
299 }
300
301 async fn get_service_schema(&self, service: String) -> Result<String> {
302 let rt = self.runtime.read().await.clone();
303 let registry = rt
304 .hook
305 .registry
306 .as_ref()
307 .ok_or_else(|| anyhow::anyhow!("no service registry"))?;
308 let handle = registry
309 .query
310 .get(&service)
311 .or_else(|| registry.tools.values().find(|h| h.name.as_str() == service))
312 .ok_or_else(|| anyhow::anyhow!("service '{}' not found", service))?;
313 let req = wcore::protocol::ext::ExtRequest {
314 msg: Some(wcore::protocol::ext::ext_request::Msg::GetSchema(
315 wcore::protocol::ext::ExtGetSchema {},
316 )),
317 };
318 let resp = handle.request(&req).await?;
319 match resp.msg {
320 Some(wcore::protocol::ext::ext_response::Msg::SchemaResult(result)) => {
321 Ok(result.schema)
322 }
323 Some(wcore::protocol::ext::ext_response::Msg::Error(e)) => {
324 anyhow::bail!("service '{}' schema error: {}", service, e.message)
325 }
326 other => anyhow::bail!(
327 "unexpected schema response from service '{}': {other:?}",
328 service
329 ),
330 }
331 }
332
333 async fn get_all_schemas(&self) -> Result<std::collections::HashMap<String, String>> {
334 let rt = self.runtime.read().await.clone();
335 let registry = rt
336 .hook
337 .registry
338 .as_ref()
339 .ok_or_else(|| anyhow::anyhow!("no service registry"))?;
340 let mut schemas = std::collections::HashMap::new();
341 for (name, handle) in ®istry.query {
343 let req = wcore::protocol::ext::ExtRequest {
344 msg: Some(wcore::protocol::ext::ext_request::Msg::GetSchema(
345 wcore::protocol::ext::ExtGetSchema {},
346 )),
347 };
348 if let Ok(resp) = handle.request(&req).await
349 && let Some(wcore::protocol::ext::ext_response::Msg::SchemaResult(result)) =
350 resp.msg
351 {
352 schemas.insert(name.clone(), result.schema);
353 }
354 }
355 Ok(schemas)
356 }
357
358 async fn list_services(&self) -> Result<Vec<wcore::protocol::message::ServiceInfoMsg>> {
359 let rt = self.runtime.read().await.clone();
360 let registry = rt.hook.registry.as_ref();
361 let mut services = Vec::new();
362 if let Some(reg) = registry {
363 let mut seen = std::collections::HashSet::new();
365 let all_handles: Vec<_> = reg
366 .build_agent
367 .iter()
368 .chain(reg.before_run.iter())
369 .chain(reg.compact.iter())
370 .chain(reg.event_observer.iter())
371 .chain(reg.query.values())
372 .chain(reg.tools.values())
373 .collect();
374 for handle in all_handles {
375 let name = handle.name.to_string();
376 if !seen.insert(name.clone()) {
377 continue;
378 }
379 let capabilities: Vec<String> = handle
380 .capabilities
381 .iter()
382 .filter_map(|c| match &c.cap {
383 Some(wcore::protocol::ext::capability::Cap::Tools(_)) => {
384 Some("tools".into())
385 }
386 Some(wcore::protocol::ext::capability::Cap::Query(_)) => {
387 Some("query".into())
388 }
389 Some(wcore::protocol::ext::capability::Cap::BuildAgent(_)) => {
390 Some("build_agent".into())
391 }
392 Some(wcore::protocol::ext::capability::Cap::BeforeRun(_)) => {
393 Some("before_run".into())
394 }
395 Some(wcore::protocol::ext::capability::Cap::Compact(_)) => {
396 Some("compact".into())
397 }
398 Some(wcore::protocol::ext::capability::Cap::EventObserver(_)) => {
399 Some("event_observer".into())
400 }
401 Some(wcore::protocol::ext::capability::Cap::AfterRun(_)) => {
402 Some("after_run".into())
403 }
404 Some(wcore::protocol::ext::capability::Cap::Infer(_)) => {
405 Some("infer".into())
406 }
407 None => None,
408 })
409 .collect();
410 services.push(wcore::protocol::message::ServiceInfoMsg {
411 name,
412 kind: "extension".into(),
413 status: "running".into(),
414 capabilities,
415 has_config: true,
416 });
417 }
418 }
419 Ok(services)
420 }
421
422 async fn set_service_config(&self, service: String, config: String) -> Result<()> {
423 let mut daemon_config = self.load_config()?;
424 let svc = daemon_config
425 .services
426 .get_mut(&service)
427 .ok_or_else(|| anyhow::anyhow!("service '{}' not found in config", service))?;
428 let parsed: serde_json::Value =
429 serde_json::from_str(&config).context("invalid service config JSON")?;
430 svc.config = parsed;
431 let toml_str =
432 toml::to_string_pretty(&daemon_config).context("failed to serialize config to TOML")?;
433 let config_path = self.config_dir.join("walrus.toml");
434 std::fs::write(&config_path, toml_str)
435 .with_context(|| format!("failed to write {}", config_path.display()))?;
436 self.reload().await
437 }
438
439 async fn reload(&self) -> Result<()> {
440 self.reload().await
441 }
442}
443
444impl Daemon {
445 fn load_config(&self) -> Result<crate::DaemonConfig> {
447 crate::DaemonConfig::load(&self.config_dir.join("walrus.toml"))
448 }
449}