Skip to main content

walrus_channel/
spawn.rs

1//! Channel spawn logic.
2//!
3//! Connects configured platform bots (Telegram, Discord) and routes all
4//! messages through a single `on_message` callback that accepts a
5//! `ClientMessage` and returns a `ServerMessage` stream.
6
7use crate::command::parse_command;
8use crate::config::ChannelConfig;
9use crate::message::ChannelMessage;
10use compact_str::CompactString;
11use serenity::model::id::ChannelId;
12use std::{
13    collections::{HashMap, HashSet},
14    future::Future,
15    sync::Arc,
16};
17use teloxide::prelude::*;
18use tokio::sync::{RwLock, mpsc};
19use wcore::protocol::message::{client::ClientMessage, server::ServerMessage};
20
/// Shared set of sender IDs belonging to sibling Walrus bots.
///
/// Built incrementally as each bot connects. Channel loops check this set
/// before dispatching messages — senders in this set are silently dropped
/// to prevent agent-to-agent loops.
///
/// Entries are platform-prefixed sender strings: the Telegram transport
/// inserts `tg:<user id>` for its own bot, and the channel loops look up
/// `tg:<sender id>` / `dc:<sender id>` keys. A handle is also passed to the
/// Discord event loop — presumably so it can register its own `dc:` identity
/// (TODO confirm against `crate::discord::event_loop`).
type KnownBots = Arc<RwLock<HashSet<CompactString>>>;
27
28/// Connect configured channels and spawn message loops.
29///
30/// Iterates all channel entries and spawns a transport for each one.
31/// `default_agent` is used when an entry does not specify an agent.
32/// `on_message` dispatches any `ClientMessage` and returns a receiver for
33/// streamed `ServerMessage` results.
34pub async fn spawn_channels<C, CFut>(
35    config: &ChannelConfig,
36    default_agent: CompactString,
37    on_message: Arc<C>,
38) where
39    C: Fn(ClientMessage) -> CFut + Send + Sync + 'static,
40    CFut: Future<Output = mpsc::UnboundedReceiver<ServerMessage>> + Send + 'static,
41{
42    let known_bots: KnownBots = Arc::new(RwLock::new(HashSet::new()));
43
44    if let Some(tg) = &config.telegram {
45        if tg.token.is_empty() {
46            tracing::warn!(platform = "telegram", "token is empty, skipping");
47        } else {
48            spawn_telegram(
49                &tg.token,
50                default_agent.clone(),
51                on_message.clone(),
52                known_bots.clone(),
53            )
54            .await;
55        }
56    }
57
58    if let Some(dc) = &config.discord {
59        if dc.token.is_empty() {
60            tracing::warn!(platform = "discord", "token is empty, skipping");
61        } else {
62            spawn_discord(&dc.token, default_agent, on_message, known_bots).await;
63        }
64    }
65}
66
67async fn spawn_telegram<C, CFut>(
68    token: &str,
69    agent: CompactString,
70    on_message: Arc<C>,
71    known_bots: KnownBots,
72) where
73    C: Fn(ClientMessage) -> CFut + Send + Sync + 'static,
74    CFut: Future<Output = mpsc::UnboundedReceiver<ServerMessage>> + Send + 'static,
75{
76    let bot = Bot::new(token);
77
78    // Resolve our own user ID and register it in the known-bot set.
79    match bot.get_me().await {
80        Ok(me) => {
81            let bot_sender: CompactString = format!("tg:{}", me.id.0).into();
82            tracing::info!(platform = "telegram", %bot_sender, "registered bot identity");
83            known_bots.write().await.insert(bot_sender);
84        }
85        Err(e) => {
86            tracing::warn!(platform = "telegram", "failed to resolve bot identity: {e}");
87        }
88    }
89
90    let (tx, rx) = mpsc::unbounded_channel::<ChannelMessage>();
91
92    let poll_bot = bot.clone();
93    tokio::spawn(async move {
94        crate::telegram::poll_loop(poll_bot, tx).await;
95    });
96
97    tokio::spawn(telegram_loop(rx, bot, agent, on_message, known_bots));
98    tracing::info!(platform = "telegram", "channel transport started");
99}
100
101async fn spawn_discord<C, CFut>(
102    token: &str,
103    agent: CompactString,
104    on_message: Arc<C>,
105    known_bots: KnownBots,
106) where
107    C: Fn(ClientMessage) -> CFut + Send + Sync + 'static,
108    CFut: Future<Output = mpsc::UnboundedReceiver<ServerMessage>> + Send + 'static,
109{
110    let (msg_tx, msg_rx) = mpsc::unbounded_channel::<ChannelMessage>();
111    let (http_tx, http_rx) = tokio::sync::oneshot::channel();
112
113    let token = token.to_owned();
114    let kb = known_bots.clone();
115    tokio::spawn(async move {
116        crate::discord::event_loop(&token, msg_tx, http_tx, kb).await;
117    });
118
119    tokio::spawn(async move {
120        match http_rx.await {
121            Ok(http) => {
122                discord_loop(msg_rx, http, agent, on_message, known_bots).await;
123            }
124            Err(_) => {
125                tracing::error!("discord gateway failed to send http client");
126            }
127        }
128    });
129
130    tracing::info!(platform = "discord", "channel transport started");
131}
132
133/// Telegram message loop: routes incoming messages to agents or bot commands.
134///
135/// Maintains a `chat_id → session_id` mapping so consecutive messages from the
136/// same chat reuse the same session. If a session is killed externally, the
137/// error triggers a retry with `session: None` to create a fresh session.
138async fn telegram_loop<C, CFut>(
139    mut rx: mpsc::UnboundedReceiver<ChannelMessage>,
140    bot: Bot,
141    agent: CompactString,
142    on_message: Arc<C>,
143    known_bots: KnownBots,
144) where
145    C: Fn(ClientMessage) -> CFut + Send + Sync + 'static,
146    CFut: Future<Output = mpsc::UnboundedReceiver<ServerMessage>> + Send + 'static,
147{
148    let mut sessions: HashMap<i64, u64> = HashMap::new();
149
150    while let Some(msg) = rx.recv().await {
151        let chat_id = msg.chat_id;
152        let content = msg.content.clone();
153        let sender: CompactString = format!("tg:{}", msg.sender_id).into();
154
155        // Drop messages from sibling Walrus bots.
156        if known_bots.read().await.contains(&sender) {
157            tracing::debug!(%sender, chat_id, "dropping message from known bot");
158            continue;
159        }
160
161        tracing::info!(%agent, chat_id, "telegram dispatch");
162
163        // Bot command path.
164        if content.starts_with('/') {
165            match parse_command(&content) {
166                Some(cmd) => {
167                    let b = bot.clone();
168                    let om = on_message.clone();
169                    tokio::spawn(async move {
170                        crate::telegram::command::dispatch_command(cmd, om, b, chat_id).await;
171                    });
172                }
173                None => {
174                    tracing::warn!(chat_id, content, "unrecognised bot command");
175                    let hint = "Unknown command. Available: /hub install <pkg>, /hub uninstall <pkg>, /model download <model>";
176                    if let Err(e) = bot.send_message(ChatId(chat_id), hint).await {
177                        tracing::warn!("failed to send command hint: {e}");
178                    }
179                }
180            }
181            continue;
182        }
183
184        // Normal agent chat path with session mapping.
185        let session = sessions.get(&chat_id).copied();
186
187        // Group chat: evaluate whether the agent should respond.
188        if msg.is_group && !should_respond(&on_message, &agent, &content, session, &sender).await {
189            tracing::debug!(%agent, chat_id, "agent declined to respond in group");
190            continue;
191        }
192        let client_msg = ClientMessage::Send {
193            agent: agent.clone(),
194            content: content.clone(),
195            session,
196            sender: Some(sender.clone()),
197        };
198        let mut reply_rx = on_message(client_msg).await;
199        let mut retry = false;
200        while let Some(server_msg) = reply_rx.recv().await {
201            match server_msg {
202                ServerMessage::Response(resp) => {
203                    sessions.insert(chat_id, resp.session);
204                    if let Err(e) = bot.send_message(ChatId(chat_id), resp.content).await {
205                        tracing::warn!(%agent, "failed to send channel reply: {e}");
206                    }
207                }
208                ServerMessage::Error { ref message, .. } if session.is_some() => {
209                    tracing::warn!(%agent, chat_id, "session error, retrying: {message}");
210                    sessions.remove(&chat_id);
211                    retry = true;
212                }
213                ServerMessage::Error { message, .. } => {
214                    tracing::warn!(%agent, chat_id, "dispatch error: {message}");
215                }
216                _ => {}
217            }
218        }
219
220        // Retry with a fresh session if the previous one was stale.
221        if retry {
222            let client_msg = ClientMessage::Send {
223                agent: agent.clone(),
224                content,
225                session: None,
226                sender: Some(sender),
227            };
228            let mut reply_rx = on_message(client_msg).await;
229            while let Some(server_msg) = reply_rx.recv().await {
230                match server_msg {
231                    ServerMessage::Response(resp) => {
232                        sessions.insert(chat_id, resp.session);
233                        if let Err(e) = bot.send_message(ChatId(chat_id), resp.content).await {
234                            tracing::warn!(%agent, "failed to send channel reply: {e}");
235                        }
236                    }
237                    ServerMessage::Error { message, .. } => {
238                        tracing::warn!(%agent, chat_id, "dispatch error on retry: {message}");
239                    }
240                    _ => {}
241                }
242            }
243        }
244    }
245
246    tracing::info!(platform = "telegram", "channel loop ended");
247}
248
249/// Discord message loop: routes incoming messages to agents or bot commands.
250///
251/// Maintains a `chat_id → session_id` mapping so consecutive messages from the
252/// same chat reuse the same session. Same stale-session retry logic as Telegram.
253async fn discord_loop<C, CFut>(
254    mut rx: mpsc::UnboundedReceiver<ChannelMessage>,
255    http: Arc<serenity::http::Http>,
256    agent: CompactString,
257    on_message: Arc<C>,
258    known_bots: KnownBots,
259) where
260    C: Fn(ClientMessage) -> CFut + Send + Sync + 'static,
261    CFut: Future<Output = mpsc::UnboundedReceiver<ServerMessage>> + Send + 'static,
262{
263    let mut sessions: HashMap<i64, u64> = HashMap::new();
264
265    while let Some(msg) = rx.recv().await {
266        let chat_id = msg.chat_id;
267        let channel_id = ChannelId::new(chat_id as u64);
268        let content = msg.content.clone();
269        let sender: CompactString = format!("dc:{}", msg.sender_id).into();
270
271        // Drop messages from sibling Walrus bots.
272        if known_bots.read().await.contains(&sender) {
273            tracing::debug!(%sender, chat_id, "dropping message from known bot");
274            continue;
275        }
276
277        tracing::info!(%agent, chat_id, "discord dispatch");
278
279        // Bot command path.
280        if content.starts_with('/') {
281            match parse_command(&content) {
282                Some(cmd) => {
283                    let h = http.clone();
284                    let om = on_message.clone();
285                    tokio::spawn(async move {
286                        crate::discord::command::dispatch_command(cmd, om, h, channel_id).await;
287                    });
288                }
289                None => {
290                    tracing::warn!(chat_id, content, "unrecognised bot command");
291                    let hint = "Unknown command. Available: /hub install <pkg>, /hub uninstall <pkg>, /model download <model>";
292                    crate::discord::send_text(&http, channel_id, hint.to_owned()).await;
293                }
294            }
295            continue;
296        }
297
298        // Normal agent chat path with session mapping.
299        let session = sessions.get(&chat_id).copied();
300
301        // Group chat: evaluate whether the agent should respond.
302        if msg.is_group && !should_respond(&on_message, &agent, &content, session, &sender).await {
303            tracing::debug!(%agent, chat_id, "agent declined to respond in group");
304            continue;
305        }
306
307        let client_msg = ClientMessage::Send {
308            agent: agent.clone(),
309            content: content.clone(),
310            session,
311            sender: Some(sender.clone()),
312        };
313        let mut reply_rx = on_message(client_msg).await;
314        let mut retry = false;
315        while let Some(server_msg) = reply_rx.recv().await {
316            match server_msg {
317                ServerMessage::Response(resp) => {
318                    sessions.insert(chat_id, resp.session);
319                    crate::discord::send_text(&http, channel_id, resp.content).await;
320                }
321                ServerMessage::Error { ref message, .. } if session.is_some() => {
322                    tracing::warn!(%agent, chat_id, "session error, retrying: {message}");
323                    sessions.remove(&chat_id);
324                    retry = true;
325                }
326                ServerMessage::Error { message, .. } => {
327                    tracing::warn!(%agent, chat_id, "dispatch error: {message}");
328                }
329                _ => {}
330            }
331        }
332
333        // Retry with a fresh session if the previous one was stale.
334        if retry {
335            let client_msg = ClientMessage::Send {
336                agent: agent.clone(),
337                content,
338                session: None,
339                sender: Some(sender),
340            };
341            let mut reply_rx = on_message(client_msg).await;
342            while let Some(server_msg) = reply_rx.recv().await {
343                match server_msg {
344                    ServerMessage::Response(resp) => {
345                        sessions.insert(chat_id, resp.session);
346                        crate::discord::send_text(&http, channel_id, resp.content).await;
347                    }
348                    ServerMessage::Error { message, .. } => {
349                        tracing::warn!(%agent, chat_id, "dispatch error on retry: {message}");
350                    }
351                    _ => {}
352                }
353            }
354        }
355    }
356
357    tracing::info!(platform = "discord", "channel loop ended");
358}
359
360/// Ask the daemon whether the agent should respond to a group message.
361///
362/// Dispatches `ClientMessage::Evaluate` and checks for
363/// `ServerMessage::Evaluation { respond }`. Falls back to `true` on any
364/// unexpected response or error so the agent still responds if evaluation
365/// fails.
366async fn should_respond<C, CFut>(
367    on_message: &Arc<C>,
368    agent: &CompactString,
369    content: &str,
370    session: Option<u64>,
371    sender: &CompactString,
372) -> bool
373where
374    C: Fn(ClientMessage) -> CFut + Send + Sync + 'static,
375    CFut: Future<Output = mpsc::UnboundedReceiver<ServerMessage>> + Send + 'static,
376{
377    let eval_msg = ClientMessage::Evaluate {
378        agent: agent.clone(),
379        content: content.to_owned(),
380        session,
381        sender: Some(sender.clone()),
382    };
383    let mut rx = on_message(eval_msg).await;
384    match rx.recv().await {
385        Some(ServerMessage::Evaluation { respond }) => respond,
386        _ => {
387            tracing::warn!(%agent, "evaluate returned unexpected response, defaulting to respond");
388            true
389        }
390    }
391}