Skip to main content

sparrow/gateway/
mod.rs

1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3use tokio::sync::mpsc;
4use tokio::task::AbortHandle;
5
6use crate::engine::{Engine, Task};
7use crate::event::Event;
8use crate::runtime::recorder::{FsRecorder, Recorder, RunInputs};
9
10/// Active-run registry. Keyed by run_id, holds the `AbortHandle` of the
11/// spawned gateway task so `sparrow gateway abort <run>` can actually cancel
12/// it instead of just writing a signal file.
13#[derive(Default, Clone)]
14pub struct RunRegistry {
15    inner: Arc<Mutex<HashMap<String, AbortHandle>>>,
16}
17
18impl RunRegistry {
19    pub fn new() -> Self {
20        Self::default()
21    }
22
23    pub fn insert(&self, run_id: String, handle: AbortHandle) {
24        if let Ok(mut g) = self.inner.lock() {
25            g.insert(run_id, handle);
26        }
27    }
28
29    pub fn remove(&self, run_id: &str) {
30        if let Ok(mut g) = self.inner.lock() {
31            g.remove(run_id);
32        }
33    }
34
35    /// Cancel the run if known. Returns true when an abort was issued.
36    pub fn abort(&self, run_id: &str) -> bool {
37        if let Ok(mut g) = self.inner.lock() {
38            if let Some(h) = g.remove(run_id) {
39                h.abort();
40                return true;
41            }
42        }
43        false
44    }
45
46    pub fn active_run_ids(&self) -> Vec<String> {
47        self.inner
48            .lock()
49            .map(|g| g.keys().cloned().collect())
50            .unwrap_or_default()
51    }
52}
53
#[cfg(test)]
mod registry_tests {
    use super::*;
    use std::time::Duration;

    /// Aborting an id that was never registered reports that nothing happened.
    #[tokio::test]
    async fn abort_unknown_run_returns_false() {
        let registry = RunRegistry::new();
        let aborted = registry.abort("does-not-exist");
        assert!(!aborted);
    }

    /// Aborting a registered id cancels the underlying tokio task.
    #[tokio::test]
    async fn abort_cancels_a_registered_task() {
        let registry = RunRegistry::new();
        let sleeper = tokio::spawn(async {
            tokio::time::sleep(Duration::from_secs(30)).await;
        });
        registry.insert(String::from("r1"), sleeper.abort_handle());
        assert!(registry.abort("r1"));

        // An aborted task resolves to a JoinError that reports cancellation.
        let joined = sleeper.await;
        assert!(matches!(joined, Err(ref err) if err.is_cancelled()));
    }
}
77
78pub mod discord;
79pub mod email;
80pub mod extra_transports;
81pub mod slack;
82pub mod telegram;
83pub mod ws;
84
85// ─── Gateway message types ──────────────────────────────────────────────────────
86
/// One inbound chat message received by a gateway transport.
#[derive(Debug, Clone)]
pub struct GatewayMessage {
    /// Name of the surface the message arrived on (e.g. "telegram").
    pub surface: String,
    /// Sender id; `MessageRouter::route` checks it against the allow-list.
    pub user_id: String,
    /// Conversation/channel id that replies are addressed to.
    pub chat_id: String,
    /// Raw message text (trimmed by the router before use).
    pub text: String,
    /// Transport-level id of this message, used as the `reply_to` target.
    pub message_id: Option<String>,
}
95
/// One outbound message the router asks a transport to deliver.
#[derive(Debug, Clone)]
pub struct GatewayResponse {
    /// Name of the surface that should deliver this response.
    pub surface: String,
    /// Conversation/channel id to post into.
    pub chat_id: String,
    /// Message body.
    pub text: String,
    /// Id of the message to reply to, if the reply should be threaded.
    pub reply_to: Option<String>,
    /// Button labels, presumably laid out as rows of buttons — confirm per transport.
    pub buttons: Vec<Vec<String>>,
}
104
105// ─── THE GATEWAY TRAIT ──────────────────────────────────────────────────────────
106
107#[async_trait::async_trait]
108pub trait GatewayTransport: Send + Sync {
109    fn name(&self) -> &str;
110    async fn start(&self, tx: mpsc::UnboundedSender<GatewayMessage>) -> anyhow::Result<()>;
111    async fn send(&self, response: GatewayResponse) -> anyhow::Result<()>;
112    async fn stop(&self) -> anyhow::Result<()>;
113}
114
115// ─── Message router: maps incoming messages to engine tasks ─────────────────────
116
117pub struct MessageRouter {
118    engine: Arc<Engine>,
119    recorder: Arc<FsRecorder>,
120    event_bus_tx: tokio::sync::broadcast::Sender<Event>,
121    allowed_users: Vec<String>,
122    /// Cross-surface session continuity (§8). Keyed by user identity so the same
123    /// user resumes the same conversation/context regardless of surface.
124    sessions: Option<Arc<crate::runtime::session::SessionStore>>,
125    /// Tracks every spawned gateway run so `gateway abort` can cancel it.
126    pub run_registry: RunRegistry,
127}
128
129impl MessageRouter {
130    pub fn new(
131        engine: Arc<Engine>,
132        recorder: Arc<FsRecorder>,
133        event_bus_tx: tokio::sync::broadcast::Sender<Event>,
134        allowed_users: Vec<String>,
135    ) -> Self {
136        Self {
137            engine,
138            recorder,
139            event_bus_tx,
140            allowed_users,
141            sessions: None,
142            run_registry: RunRegistry::new(),
143        }
144    }
145
146    /// Attach a session store to enable cross-surface conversation continuity.
147    pub fn with_sessions(mut self, sessions: Arc<crate::runtime::session::SessionStore>) -> Self {
148        self.sessions = Some(sessions);
149        self
150    }
151
152    /// Stable gateway session key. OpenClaw-style gateway continuity is scoped
153    /// by surface + channel/account + peer, so a user can have separate sessions
154    /// in separate channels while still surviving restarts.
155    pub fn session_key(msg_user_id: &str, surface: &str, chat_id: &str) -> String {
156        let surface = session_component(surface, "surface");
157        let chat = session_component(chat_id, "channel");
158        let user = session_component(msg_user_id, "anonymous");
159        format!("gateway:{}:channel:{}:peer:{}", surface, chat, user)
160    }
161
162    /// Route an incoming message: parse command, submit to engine, return response
163    pub async fn route(
164        &self,
165        msg: GatewayMessage,
166        responses: &mpsc::UnboundedSender<GatewayResponse>,
167    ) {
168        // Check user authorization
169        if !self.allowed_users.is_empty() && !self.allowed_users.contains(&msg.user_id) {
170            let _ = responses.send(GatewayResponse {
171                surface: msg.surface.clone(),
172                chat_id: msg.chat_id.clone(),
173                text: "Unauthorized. Ask the admin to add your user ID.".into(),
174                reply_to: msg.message_id,
175                buttons: vec![],
176            });
177            return;
178        }
179
180        let text = msg.text.trim();
181        let surface = msg.surface.clone();
182        let chat_id = msg.chat_id.clone();
183        let user_id = msg.user_id.clone();
184        let reply_to = msg.message_id.clone();
185
186        if text.is_empty() {
187            return;
188        }
189
190        // Command parsing
191        if text.starts_with('/') {
192            self.handle_command(text, surface, chat_id, user_id, reply_to, responses)
193                .await;
194        } else {
195            self.handle_task(text, surface, chat_id, user_id, reply_to, responses)
196                .await;
197        }
198    }
199
200    async fn handle_command(
201        &self,
202        text: &str,
203        surface: String,
204        chat_id: String,
205        user_id: String,
206        reply_to: Option<String>,
207        responses: &mpsc::UnboundedSender<GatewayResponse>,
208    ) {
209        let parts: Vec<&str> = text.splitn(2, ' ').collect();
210        let cmd = parts[0].to_lowercase();
211        let args = parts.get(1).unwrap_or(&"");
212
213        match cmd.as_str() {
214            "/start" | "/help" => {
215                let _ = responses.send(GatewayResponse {
216                    surface,
217                    chat_id,
218                    text: format!(
219                        "Sparrow — one cli · grows with you\n\n\
220                         Commands:\n\
221                         /run <task> — Execute a task\n\
222                         /status — Show engine status\n\
223                         /models — List configured models\n\
224                         /budget — Show budget status\n\
225                         /help — This message\n\n\
226                         Or just send a message to start a task."
227                    ),
228                    reply_to,
229                    buttons: vec![vec!["/run ".into(), "/status".into()]],
230                });
231            }
232            "/run" => {
233                if args.is_empty() {
234                    let _ = responses.send(GatewayResponse {
235                        surface,
236                        chat_id,
237                        text: "Usage: /run <task description>".into(),
238                        reply_to,
239                        buttons: vec![],
240                    });
241                    return;
242                }
243                self.handle_task(args, surface, chat_id, user_id, reply_to, responses)
244                    .await;
245            }
246            "/reset" => {
247                // Clear the cross-surface session for this user
248                if let Some(sessions) = &self.sessions {
249                    let key = Self::session_key(&user_id, &surface, &chat_id);
250                    let _ = sessions.delete(&key);
251                }
252                let _ = responses.send(GatewayResponse {
253                    surface,
254                    chat_id,
255                    text: "Session cleared. Next message starts fresh.".into(),
256                    reply_to,
257                    buttons: vec![],
258                });
259            }
260            "/status" => {
261                let _ = responses.send(GatewayResponse {
262                    surface,
263                    chat_id,
264                    text: "Engine: online\nMode: headless".into(),
265                    reply_to,
266                    buttons: vec![],
267                });
268            }
269            "/models" => {
270                let _ = responses.send(GatewayResponse {
271                    surface,
272                    chat_id,
273                    text: "Use 'sparrow model --list' in CLI for model listing.".into(),
274                    reply_to,
275                    buttons: vec![],
276                });
277            }
278            "/budget" => {
279                let _ = responses.send(GatewayResponse {
280                    surface,
281                    chat_id,
282                    text: "Budget: configured in ~/.config/sparrow/config.toml".into(),
283                    reply_to,
284                    buttons: vec![],
285                });
286            }
287            _ => {
288                let _ = responses.send(GatewayResponse {
289                    surface,
290                    chat_id,
291                    text: format!("Unknown command: {}. Try /help", cmd),
292                    reply_to,
293                    buttons: vec![],
294                });
295            }
296        }
297    }
298
299    async fn handle_task(
300        &self,
301        text: &str,
302        surface: String,
303        chat_id: String,
304        user_id: String,
305        reply_to: Option<String>,
306        responses: &mpsc::UnboundedSender<GatewayResponse>,
307    ) {
308        let task_text = text.to_string();
309        let resp_tx = responses.clone();
310        let cid = chat_id.clone();
311        let surface_for_done = surface.clone();
312
313        // Clone for second spawn
314        let resp_tx2 = resp_tx.clone();
315        let cid2 = cid.clone();
316        let surface_for_stream = surface.clone();
317        let reply_to2 = reply_to.clone();
318
319        // ── Session continuity (§8) ───────────────────────────────────────────
320        // Load prior conversation for this user so context follows them across
321        // surfaces and survives gateway restarts.
322        let session_key = Self::session_key(&user_id, &surface, &chat_id);
323        let prior_msgs: Vec<crate::provider::Msg> = self
324            .sessions
325            .as_ref()
326            .and_then(|s| s.load(&session_key))
327            .and_then(|sess| serde_json::from_str(&sess.messages_json).ok())
328            .unwrap_or_default();
329        let sessions_for_save = self.sessions.clone();
330        let session_key_save = session_key.clone();
331        let prior_for_save = prior_msgs.clone();
332
333        // Create a one-shot event stream for this task
334        let (task_tx, mut task_rx) = mpsc::unbounded_channel::<Event>();
335        let event_bus = self.event_bus_tx.clone();
336        let engine = self.engine.clone();
337        let recorder = self.recorder.clone();
338
339        // Send initial "thinking" response
340        let _ = resp_tx.send(GatewayResponse {
341            surface: surface.clone(),
342            chat_id: cid.clone(),
343            text: format!("Working on: {}", &task_text[..task_text.len().min(80)]),
344            reply_to: reply_to.clone(),
345            buttons: vec![],
346        });
347
348        // Start recording
349        let run_id = uuid::Uuid::new_v4().to_string();
350        recorder.start_run(
351            run_id.clone(),
352            RunInputs {
353                task: task_text.clone(),
354                config_snapshot: serde_json::json!({}),
355                model_id: "gateway".into(),
356                repo_head: None,
357                timestamp: chrono::Utc::now().to_rfc3339(),
358                agent: "gateway".into(),
359            },
360        );
361
362        let registry = self.run_registry.clone();
363        let run_id_for_dereg = run_id.clone();
364        let drive_handle = tokio::spawn(async move {
365            let task = Task {
366                description: task_text.clone(),
367                context: prior_msgs,
368            };
369
370            match engine.drive(task, task_tx.clone()).await {
371                Ok(outcome) => {
372                    let _ = event_bus.send(Event::RunFinished {
373                        run: crate::event::RunId(run_id.clone()),
374                        outcome: outcome.clone(),
375                    });
376                    let _ = recorder.finalize(&run_id);
377                    let _ = resp_tx.send(GatewayResponse {
378                        surface: surface_for_done,
379                        chat_id: cid.clone(),
380                        text: format!(
381                            "Done.\nStatus: {}\nCost: ${:.4}\nFiles: {}",
382                            outcome.status,
383                            outcome.cost_usd,
384                            outcome.diffs.len()
385                        ),
386                        reply_to: reply_to.clone(),
387                        buttons: vec![],
388                    });
389                }
390                Err(e) => {
391                    let _ = resp_tx.send(GatewayResponse {
392                        surface: surface_for_done,
393                        chat_id: cid,
394                        text: format!("Error: {}", e),
395                        reply_to: reply_to2,
396                        buttons: vec![],
397                    });
398                }
399            }
400
401            drop(task_tx);
402        });
403        self.run_registry
404            .insert(run_id_for_dereg.clone(), drive_handle.abort_handle());
405        // Auto-deregister on completion so the registry doesn't grow unbounded.
406        {
407            let registry_for_dereg = registry.clone();
408            tokio::spawn(async move {
409                let _ = drive_handle.await;
410                registry_for_dereg.remove(&run_id_for_dereg);
411            });
412        }
413
414        // Stream intermediate updates
415        let user_task_text = text.to_string();
416        tokio::spawn(async move {
417            let mut buffer = String::new();
418            let mut full_reply = String::new();
419            let mut reasoning_reply = String::new();
420            while let Some(event) = task_rx.recv().await {
421                if let Event::ThinkingDelta { text, .. } = &event {
422                    full_reply.push_str(text);
423                }
424                if let Event::ReasoningDelta { text, .. } = &event {
425                    reasoning_reply.push_str(text);
426                }
427                match &event {
428                    Event::ThinkingDelta { text, .. } => {
429                        buffer.push_str(text);
430                        if buffer.len() > 500 || buffer.contains('\n') {
431                            let _ = resp_tx2.send(GatewayResponse {
432                                surface: surface_for_stream.clone(),
433                                chat_id: cid2.clone(),
434                                text: buffer.clone(),
435                                reply_to: None,
436                                buttons: vec![],
437                            });
438                            buffer.clear();
439                        }
440                    }
441                    Event::ToolUseProposed { name, .. } => {
442                        if !buffer.is_empty() {
443                            let _ = resp_tx2.send(GatewayResponse {
444                                surface: surface_for_stream.clone(),
445                                chat_id: cid2.clone(),
446                                text: buffer.clone(),
447                                reply_to: None,
448                                buttons: vec![],
449                            });
450                            buffer.clear();
451                        }
452                        let _ = resp_tx2.send(GatewayResponse {
453                            surface: surface_for_stream.clone(),
454                            chat_id: cid2.clone(),
455                            text: format!("[Tool: {}]", name),
456                            reply_to: None,
457                            buttons: vec![],
458                        });
459                    }
460                    Event::ModelSwitched {
461                        from, to, reason, ..
462                    } => {
463                        if !buffer.is_empty() {
464                            let _ = resp_tx2.send(GatewayResponse {
465                                surface: surface_for_stream.clone(),
466                                chat_id: cid2.clone(),
467                                text: buffer.clone(),
468                                reply_to: None,
469                                buttons: vec![],
470                            });
471                            buffer.clear();
472                        }
473                        let clean = crate::event::friendly_model_switch_reason(reason);
474                        let text = if crate::event::is_local_model_unavailable(reason) {
475                            format!("modèle local indisponible → routage modèle cloud ({})", to)
476                        } else {
477                            format!("fallback: {} → {} ({})", from, to, clean)
478                        };
479                        let _ = resp_tx2.send(GatewayResponse {
480                            surface: surface_for_stream.clone(),
481                            chat_id: cid2.clone(),
482                            text,
483                            reply_to: None,
484                            buttons: vec![],
485                        });
486                    }
487                    Event::ApprovalRequested { summary, .. } => {
488                        let _ = resp_tx2.send(GatewayResponse {
489                            surface: surface_for_stream.clone(),
490                            chat_id: cid2.clone(),
491                            text: format!("Approval needed: {}", summary),
492                            reply_to: None,
493                            buttons: vec![vec!["/approve".into(), "/deny".into()]],
494                        });
495                    }
496                    _ => {}
497                }
498            }
499            if !buffer.is_empty() {
500                let _ = resp_tx2.send(GatewayResponse {
501                    surface: surface_for_stream,
502                    chat_id: cid2.clone(),
503                    text: buffer,
504                    reply_to: None,
505                    buttons: vec![],
506                });
507            }
508
509            // ── Persist the turn to the session (§8) ──────────────────────────
510            // Append the user message and the assistant reply so the next message
511            // — on any surface — resumes with full context.
512            if let Some(sessions) = &sessions_for_save {
513                let mut updated = prior_for_save;
514                updated.push(crate::provider::Msg {
515                    role: "user".into(),
516                    content: vec![crate::provider::ContentBlock::Text {
517                        text: user_task_text,
518                    }],
519                });
520                if !full_reply.trim().is_empty() {
521                    let mut content = Vec::new();
522                    if !reasoning_reply.trim().is_empty() {
523                        content.push(crate::provider::ContentBlock::Reasoning {
524                            text: reasoning_reply,
525                        });
526                    }
527                    content.push(crate::provider::ContentBlock::Text { text: full_reply });
528                    updated.push(crate::provider::Msg {
529                        role: "assistant".into(),
530                        content,
531                    });
532                }
533                // Cap session history to the last 40 messages to bound growth.
534                let len = updated.len();
535                if len > 40 {
536                    updated.drain(..len - 40);
537                }
538                let _ = sessions.save(&session_key_save, &updated, None);
539            }
540        });
541    }
542}
543
544// ─── Event formatter: Event → human-readable message ────────────────────────────
545
546pub fn format_event(event: &Event) -> Option<String> {
547    match event {
548        Event::RunStarted { task, agent, .. } => {
549            Some(format!("Started: {} (agent: {})", task, agent))
550        }
551        Event::RunFinished { outcome, .. } => Some(format!(
552            "Finished: {} | Cost: ${:.4} | Files: {}",
553            outcome.status,
554            outcome.cost_usd,
555            outcome.diffs.len()
556        )),
557        Event::ThinkingDelta { text, .. } => Some(text.clone()),
558        Event::ReasoningDelta { .. } => None,
559        Event::ModelSwitched {
560            from, to, reason, ..
561        } => {
562            let clean = crate::event::friendly_model_switch_reason(reason);
563            if crate::event::is_local_model_unavailable(reason) {
564                Some(format!(
565                    "modèle local indisponible → routage modèle cloud ({})",
566                    to
567                ))
568            } else {
569                Some(format!("Fallback: {} → {} ({})", from, to, clean))
570            }
571        }
572        Event::ToolUseProposed { name, .. } => Some(format!("[{}]", name)),
573        Event::ApprovalRequested { summary, .. } => Some(format!("Approve: {}", summary)),
574        Event::Error { message, .. } => {
575            if crate::event::is_local_model_unavailable(message) {
576                None
577            } else {
578                Some(format!("Error: {}", message))
579            }
580        }
581        Event::CostUpdate { usd, .. } => Some(format!("Cost: ${:.4}", usd)),
582        Event::CheckpointCreated { label, .. } => Some(format!("Checkpoint: {}", label)),
583        _ => None,
584    }
585}
586
/// Sanitizes one component of a gateway session key.
///
/// Every character outside `[A-Za-z0-9._-]`, including every non-ASCII
/// character, is replaced with `_`. Leading and trailing underscores are then
/// stripped. If nothing remains, `fallback` is returned instead.
///
/// NOTE(review): the mapping is lossy, so distinct ids can share a session key.
/// For example, `a@b` and `a b` both become `a_b`, and any id made only of
/// non-ASCII characters collapses to `fallback`. Confirm that ids reaching
/// here are ASCII-safe.
fn session_component(value: &str, fallback: &str) -> String {
    let keep = |c: char| c.is_ascii_alphanumeric() || c == '-' || c == '_' || c == '.';

    let mut sanitized = String::with_capacity(value.len());
    for c in value.chars() {
        sanitized.push(if keep(c) { c } else { '_' });
    }

    match sanitized.trim_matches('_') {
        "" => fallback.to_owned(),
        trimmed => trimmed.to_owned(),
    }
}