Skip to main content

walrus_daemon/daemon/
protocol.rs

1//! Server trait implementation for the Daemon.
2
3use crate::{daemon::Daemon, ext::hub};
4use anyhow::Result;
5use compact_str::CompactString;
6use futures_util::{StreamExt, pin_mut};
7use std::sync::Arc;
8use wcore::AgentEvent;
9use wcore::protocol::{
10    api::Server,
11    message::{
12        DownloadEvent, DownloadRequest, HubAction, HubEvent, SendRequest, SendResponse,
13        StreamEvent, StreamRequest,
14    },
15};
16
17impl Server for Daemon {
18    async fn send(&self, req: SendRequest) -> Result<SendResponse> {
19        let rt: Arc<_> = self.runtime.read().await.clone();
20        let response = rt.send_to(&req.agent, &req.content).await?;
21        Ok(SendResponse {
22            agent: req.agent,
23            content: response.final_response.unwrap_or_default(),
24        })
25    }
26
27    fn stream(
28        &self,
29        req: StreamRequest,
30    ) -> impl futures_core::Stream<Item = Result<StreamEvent>> + Send {
31        let runtime = self.runtime.clone();
32        let agent = req.agent;
33        let content = req.content;
34        async_stream::try_stream! {
35            yield StreamEvent::Start { agent: agent.clone() };
36
37            let rt: Arc<_> = runtime.read().await.clone();
38            let stream = rt.stream_to(&agent, &content);
39            pin_mut!(stream);
40            while let Some(event) = stream.next().await {
41                match event {
42                    AgentEvent::TextDelta(text) => {
43                        yield StreamEvent::Chunk { content: text };
44                    }
45                    AgentEvent::Done(_) => break,
46                    _ => {}
47                }
48            }
49
50            yield StreamEvent::End { agent: agent.clone() };
51        }
52    }
53
54    fn download(
55        &self,
56        req: DownloadRequest,
57    ) -> impl futures_core::Stream<Item = Result<DownloadEvent>> + Send {
58        #[cfg(feature = "local")]
59        {
60            use tokio::sync::mpsc;
61            async_stream::try_stream! {
62                yield DownloadEvent::Start { model: req.model.clone() };
63
64                let (dtx, mut drx) = mpsc::unbounded_channel();
65                let model_str = req.model.to_string();
66                let download_handle = tokio::spawn(async move {
67                    model::local::download::download_model(&model_str, dtx).await
68                });
69
70                while let Some(event) = drx.recv().await {
71                    let dl_event = match event {
72                        model::local::download::DownloadEvent::FileStart { filename, size } => {
73                            DownloadEvent::FileStart { model: req.model.clone(), filename, size }
74                        }
75                        model::local::download::DownloadEvent::Progress { bytes } => {
76                            DownloadEvent::Progress { model: req.model.clone(), bytes }
77                        }
78                        model::local::download::DownloadEvent::FileEnd { filename } => {
79                            DownloadEvent::FileEnd { model: req.model.clone(), filename }
80                        }
81                    };
82                    yield dl_event;
83                }
84
85                match download_handle.await {
86                    Ok(Ok(())) => {
87                        yield DownloadEvent::End { model: req.model };
88                    }
89                    Ok(Err(e)) => {
90                        Err(anyhow::anyhow!("download failed: {e}"))?;
91                    }
92                    Err(e) => {
93                        Err(anyhow::anyhow!("download task panicked: {e}"))?;
94                    }
95                }
96            }
97        }
98        #[cfg(not(feature = "local"))]
99        {
100            let _ = req;
101            async_stream::stream! {
102                yield Err(anyhow::anyhow!("this daemon was built without local model support"));
103            }
104        }
105    }
106
107    async fn ping(&self) -> Result<()> {
108        Ok(())
109    }
110
111    fn hub(
112        &self,
113        package: CompactString,
114        action: HubAction,
115    ) -> impl futures_core::Stream<Item = Result<HubEvent>> + Send {
116        async_stream::try_stream! {
117            match action {
118                HubAction::Install => {
119                    let s = hub::install(package);
120                    pin_mut!(s);
121                    while let Some(event) = s.next().await {
122                        yield event?;
123                    }
124                }
125                HubAction::Uninstall => {
126                    let s = hub::uninstall(package);
127                    pin_mut!(s);
128                    while let Some(event) = s.next().await {
129                        yield event?;
130                    }
131                }
132            }
133        }
134    }
135}