Skip to main content

robson_core/
plugin.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use anyhow::Result;
6use async_trait::async_trait;
7use regex::Regex;
8use sea_orm::DatabaseConnection;
9use tracing::{debug, error, info, warn};
10use uuid::Uuid;
11
12use crate::llm::LlmProvider;
13use crate::AppState;
14
/// Carries events emitted by Workers back through the SensoriumLoop to Gateways.
///
/// Build values with the `started` / `progress` / `completed` / `failed` constructors,
/// each of which assigns a fresh random `id`.
#[derive(Debug, Clone)]
pub struct ProcessEvent {
    /// Unique identifier, generated with `Uuid::new_v4()` by the constructors.
    pub id: Uuid,
    /// Which lifecycle stage this event reports.
    pub kind: ProcessEventKind,
    /// Text payload of the event.
    pub content: String,
}
22
/// Lifecycle stage reported by a [`ProcessEvent`].
#[derive(Debug, Clone)]
pub enum ProcessEventKind {
    Started,
    Progress,
    Completed,
    Failed,
}

impl ProcessEventKind {
    /// Lowercase name of the variant (e.g. `"completed"`).
    ///
    /// This is the string persisted as the event kind when a process event row is
    /// inserted (the `/help` handler in `SensoriumLoop` uses `Completed.as_str()`).
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Started => "started",
            Self::Progress => "progress",
            Self::Completed => "completed",
            Self::Failed => "failed",
        }
    }
}
41
/// Carries messages between Gateways and the SensoriumLoop.
#[derive(Debug, Clone)]
pub struct MessageEvent {
    /// Unique identifier, generated with `Uuid::new_v4()` when the event is built.
    pub id: Uuid,
    /// Inbound (`Received`) or outbound (`Delivered`).
    pub kind: MessageEventKind,
    /// Message text.
    pub content: String,
    /// Gateway channel identifier; for outbound deliveries this is the target channel.
    pub channel_id: String,
    /// Identifier of the sending user (empty for events built with `delivered`).
    pub user_id: String,
    /// Thread identifier; `None` when the message is not part of a thread.
    pub thread_ts: Option<String>,
    /// The `conversations.id` that originated this event.
    /// Set by SensoriumLoop when dispatching inbound messages; used for routing
    /// `ProcessEvent` deliveries back to the correct gateway channel.
    pub conversation_id: Option<i32>,
}

/// Direction of a [`MessageEvent`].
#[derive(Debug, Clone)]
pub enum MessageEventKind {
    /// Inbound message picked up from the `conversations` table.
    Received,
    /// Outbound message handed to a gateway by the delivery loop.
    Delivered,
}
62
63impl ProcessEvent {
64    pub fn started(content: impl Into<String>) -> Self {
65        let content = content.into();
66        debug!(kind = "started", preview = %&content[..content.len().min(80)], "ProcessEvent::started");
67        Self {
68            id: Uuid::new_v4(),
69            kind: ProcessEventKind::Started,
70            content,
71        }
72    }
73    pub fn progress(content: impl Into<String>) -> Self {
74        let content = content.into();
75        debug!(kind = "progress", preview = %&content[..content.len().min(80)], "ProcessEvent::progress");
76        Self {
77            id: Uuid::new_v4(),
78            kind: ProcessEventKind::Progress,
79            content,
80        }
81    }
82    pub fn completed(content: impl Into<String>) -> Self {
83        let content = content.into();
84        debug!(kind = "completed", preview = %&content[..content.len().min(80)], "ProcessEvent::completed");
85        Self {
86            id: Uuid::new_v4(),
87            kind: ProcessEventKind::Completed,
88            content,
89        }
90    }
91    pub fn failed(content: impl Into<String>) -> Self {
92        let content = content.into();
93        debug!(kind = "failed", preview = %&content[..content.len().min(80)], "ProcessEvent::failed");
94        Self {
95            id: Uuid::new_v4(),
96            kind: ProcessEventKind::Failed,
97            content,
98        }
99    }
100}
101
102impl MessageEvent {
103    pub fn received(
104        content: impl Into<String>,
105        channel_id: impl Into<String>,
106        user_id: impl Into<String>,
107    ) -> Self {
108        Self {
109            id: Uuid::new_v4(),
110            kind: MessageEventKind::Received,
111            content: content.into(),
112            channel_id: channel_id.into(),
113            user_id: user_id.into(),
114            thread_ts: None,
115            conversation_id: None,
116        }
117    }
118
119    pub fn with_thread_ts(mut self, thread_ts: impl Into<String>) -> Self {
120        self.thread_ts = Some(thread_ts.into());
121        self
122    }
123
124    /// Set the target channel for delivery routing (gateway_channel_id from the conversation).
125    pub fn with_channel(mut self, channel_id: impl Into<String>) -> Self {
126        self.channel_id = channel_id.into();
127        self
128    }
129
130    /// Attach the originating conversation id for downstream routing.
131    pub fn with_conversation_id(mut self, id: i32) -> Self {
132        self.conversation_id = Some(id);
133        self
134    }
135
136    pub fn delivered(content: impl Into<String>) -> Self {
137        Self {
138            id: Uuid::new_v4(),
139            kind: MessageEventKind::Delivered,
140            content: content.into(),
141            channel_id: String::new(),
142            user_id: String::new(),
143            thread_ts: None,
144            conversation_id: None,
145        }
146    }
147}
148
149/// Common interface for all agent gateways (TUI, Slack, etc.).
150///
151/// Implementors provide access to their config, shared state, and optional LLM.
152/// The `on_message` method is called whenever an inbound message is received;
153/// the default implementation saves it to the `conversations` table.
154#[async_trait]
155pub trait AgentGateway: Send + Sync {
156    type Config: Send + Sync;
157
158    fn config(&self) -> &Self::Config;
159    fn state(&self) -> Option<&AppState>;
160    fn llm(&self) -> Option<&dyn LlmProvider>;
161
162    async fn on_message(&self, event: MessageEvent) -> Result<()> {
163        use crate::entities::conversation::{ConversationRole, Model as Conversation};
164        if let Some(state) = self.state() {
165            let thread = event.thread_ts.as_deref().unwrap_or("");
166            Conversation::insert(
167                &state.db,
168                None, // gateway_id unknown in the legacy AgentGateway path
169                &event.channel_id,
170                thread,
171                &event.user_id,
172                ConversationRole::User,
173                &event.content,
174            )
175            .await?;
176        }
177        Ok(())
178    }
179}
180
/// A Worker receives a MessageEvent from the SensoriumLoop and handles it.
///
/// # Contract
/// - The worker writes any `ProcessEvent`s it generates to the `process_events` table via `db`.
/// - Return `Ok(true)` if the message was accepted and handled, `Ok(false)` to pass.
///
/// Note: `SensoriumLoop` calls `handle` on a spawned task and only logs the outcome;
/// returning `Ok(false)` does not cause another worker to be tried.
#[async_trait]
pub trait Worker: Send + Sync {
    /// Human-readable name used in log output to identify this worker.
    fn name(&self) -> &'static str;
    /// One-line description shown in /help output.
    fn description(&self) -> &'static str;
    /// Example invocation shown in /help output (e.g. "`/task list`").
    fn example(&self) -> &'static str;
    /// Handles one message routed to this worker.
    ///
    /// - `db`: connection handle for persisting process events.
    /// - `msg`: the inbound message; `conversation_id` is set when dispatched by the loop.
    /// - `args`: `key=value` pairs parsed from the message content by `parse_kv`.
    async fn handle(
        &self,
        db: DatabaseConnection,
        msg: MessageEvent,
        args: HashMap<String, String>,
    ) -> Result<bool>;
}
201
/// Parse `key=value` pairs from a Slack message content string.
///
/// Before parsing, these prefixes are stripped (each only if present):
/// 1. a leading bot mention such as `<@U12345>`;
/// 2. a leading slash-command token such as `/task`;
/// 3. one further leading token that contains no `=` (typically a sub-command like `list`).
///
/// Keys are lowercased; values keep their case. A value may be double-quoted to include
/// whitespace (`title="Buy milk"`); an unterminated quote runs to the end of the input.
/// Tokens without `=` are ignored, a stray `=value` with no key is discarded, and a later
/// duplicate key overwrites an earlier one.
pub fn parse_kv(content: &str) -> HashMap<String, String> {
    let s = content.trim();

    // 1. Drop a leading bot mention (`<@U12345>`); without a closing `>` nothing is stripped.
    let s = if s.starts_with("<@") {
        s.find('>').map(|i| s[i + 1..].trim_start()).unwrap_or(s)
    } else {
        s
    };

    // 2. Drop a leading slash-command token; a lone command leaves nothing to parse.
    let s = if s.starts_with('/') {
        match s.split_once(|c: char| c.is_whitespace()) {
            Some((_, rest)) => rest.trim(),
            None => "",
        }
    } else {
        s
    };

    // 3. Drop one leading token that is not itself a pair (e.g. the `list` in `/task list`).
    let s = {
        let trimmed = s.trim();
        match trimmed.split_once(|c: char| c.is_whitespace()) {
            Some((first, rest)) if !first.contains('=') => rest.trim(),
            _ => trimmed,
        }
    };

    parse_kv_pairs(s)
}

/// Tokenise `input` into `key=value` pairs (see [`parse_kv`] for the accepted syntax).
fn parse_kv_pairs(input: &str) -> HashMap<String, String> {
    let mut map = HashMap::new();
    let mut chars = input.chars().peekable();
    loop {
        // Skip whitespace between tokens.
        while chars.next_if(|c| c.is_whitespace()).is_some() {}
        if chars.peek().is_none() {
            break;
        }

        // Key: everything up to the first '=' (consumed) or whitespace (left in place).
        let mut key = String::new();
        let mut has_eq = false;
        while let Some(c) = chars.next_if(|c| !c.is_whitespace()) {
            if c == '=' {
                has_eq = true;
                break;
            }
            key.push(c);
        }
        if !has_eq {
            // Bare token with no '=': ignore it.
            continue;
        }

        // Value: a double-quoted string (closing quote consumed; may be missing) or
        // everything up to the next whitespace.
        let value: String = if chars.next_if_eq(&'"').is_some() {
            chars.by_ref().take_while(|&c| c != '"').collect()
        } else {
            let mut v = String::new();
            while let Some(c) = chars.next_if(|c| !c.is_whitespace()) {
                v.push(c);
            }
            v
        };

        // A stray `=value` has no key: discard it but keep parsing the remaining tokens
        // instead of abandoning every pair that follows.
        if key.is_empty() {
            continue;
        }
        map.insert(key.to_lowercase(), value);
    }
    map
}
292
293/// Builds the formatted /help response from the registered workers list.
294pub fn build_help_response(workers: &[WorkerRegistration]) -> String {
295    let mut out = String::from(":book: *Robson \u{2014} Available Commands*\n\n");
296    for reg in workers {
297        out.push_str(&format!(
298            "{} \u{2014} {}\n",
299            reg.worker.example(),
300            reg.worker.description()
301        ));
302    }
303    out
304}
305
/// A worker paired with the compiled regex that selects which messages it handles.
///
/// Created by `SensoriumLoop::register_worker`; the loop dispatches a message to the
/// first registration (in insertion order) whose `pattern` matches the message content.
pub struct WorkerRegistration {
    /// Compiled from the pattern string passed to `register_worker`.
    pub pattern: Regex,
    /// Handler invoked on a spawned task when `pattern` matches.
    pub worker: Arc<dyn Worker>,
}
310
/// A Gateway bridges the sensorium loop with an external communication channel
/// (Slack, HTTP webhook, chat TUI, etc).
#[async_trait]
pub trait Gateway: Send + Sync {
    /// Unique, stable identifier for this gateway (e.g. "slack", "tui").
    /// Used as the key in `process_event_deliveries` to track per-gateway delivery state,
    /// and compared with the gateway row's `name` when routing an event back to the
    /// gateway its conversation came from.
    fn name(&self) -> &'static str;

    /// Deliver a message to the external system (called for each ProcessEvent).
    ///
    /// The loop passes a `Delivered` [`MessageEvent`] whose `channel_id` is the originating
    /// conversation's `gateway_channel_id`. An `Err` is recorded and the delivery retried.
    async fn send(&self, msg: MessageEvent) -> Result<()>;

    /// Start listening for inbound messages, writing them directly to the `conversations` table.
    /// This is expected to run until the gateway shuts down.
    /// `SensoriumLoop::run` calls this once per gateway on its own spawned task and only
    /// logs the result.
    async fn start(&self, db: DatabaseConnection) -> Result<()>;
}
326
/// The SensoriumLoop orchestrates workers and gateways.
///
/// Flow:
///   Gateway.start(db) → writes to conversations (processed=false)
///   → Loop 1: polls conversations → matches message content against registered worker patterns
///             in registration order; first match wins and is dispatched via tokio::spawn
///             (`/help` messages are answered directly and never reach a worker)
///   → Workers write ProcessEvents to process_events table
///   → Loop 2: polls undelivered process_events, ensures a process_event_deliveries row per
///             (event, target gateway), and retries delivery per gateway (respecting each
///             row's retry window) until no gateway is pending for the event
pub struct SensoriumLoop {
    /// Workers in registration order; the first whose pattern matches a message wins.
    workers: Vec<WorkerRegistration>,
    /// Gateways started by `run`; also the delivery targets for process events.
    gateways: Vec<Arc<dyn Gateway>>,
    /// Sleep between conversations polls, in seconds (1 unless overridden).
    conversations_poll_interval_secs: u64,
    /// Sleep between delivery polls, in seconds (1 unless overridden).
    process_event_deliveries_poll_interval_secs: u64,
    /// Recognises `/help` requests, which are answered without dispatching to a worker.
    help_re: Regex,
}
342
343impl Default for SensoriumLoop {
344    fn default() -> Self {
345        Self::new()
346    }
347}
348
349impl SensoriumLoop {
350    pub fn new() -> Self {
351        info!("SensoriumLoop created");
352        Self {
353            workers: Vec::new(),
354            gateways: Vec::new(),
355            conversations_poll_interval_secs: 1,
356            process_event_deliveries_poll_interval_secs: 1,
357            help_re: Regex::new(r"(?i)(^<@\S+>\s*)?/help\b").expect("help regex is valid"),
358        }
359    }
360
361    pub fn with_conversations_poll_interval(mut self, secs: u64) -> Self {
362        info!(
363            conversations_poll_interval_secs = secs,
364            "SensoriumLoop conversations poll interval set"
365        );
366        self.conversations_poll_interval_secs = secs;
367        self
368    }
369
370    pub fn with_process_event_deliveries_poll_interval(mut self, secs: u64) -> Self {
371        info!(
372            process_event_deliveries_poll_interval_secs = secs,
373            "SensoriumLoop deliveries poll interval set"
374        );
375        self.process_event_deliveries_poll_interval_secs = secs;
376        self
377    }
378
379    pub fn register_worker(&mut self, pattern: &str, worker: Arc<dyn Worker>) -> Result<()> {
380        let compiled = Regex::new(pattern)
381            .map_err(|e| anyhow::anyhow!("invalid worker pattern {:?}: {}", pattern, e))?;
382        let help_probes = ["/help", "/HELP", "<@U1234> /help", "<@U1234> /HELP"];
383        if help_probes.iter().any(|s| compiled.is_match(s)) {
384            return Err(anyhow::anyhow!(
385                "worker pattern {:?} collides with the reserved /help command",
386                pattern
387            ));
388        }
389        if let Some(existing) = self.workers.iter().find(|r| r.pattern.as_str() == pattern) {
390            return Err(anyhow::anyhow!(
391                "worker pattern {:?} already registered by worker {:?}",
392                pattern,
393                existing.worker.name()
394            ));
395        }
396        debug!(
397            worker = worker.name(),
398            pattern,
399            total_workers = self.workers.len() + 1,
400            "Worker registered"
401        );
402        self.workers.push(WorkerRegistration {
403            pattern: compiled,
404            worker,
405        });
406        Ok(())
407    }
408
409    pub fn workers(&self) -> &[WorkerRegistration] {
410        &self.workers
411    }
412
413    pub fn register_gateway(&mut self, gateway: Arc<dyn Gateway>) {
414        let idx = self.gateways.len();
415        debug!(
416            gateway = gateway.name(),
417            gateway_index = idx,
418            total_gateways = idx + 1,
419            "Gateway registered"
420        );
421        self.gateways.push(gateway);
422    }
423
424    pub async fn run(self, db: DatabaseConnection) -> Result<()> {
425        use crate::entities::conversation::Model as Conversation;
426        use crate::entities::process_event::Model as ProcessEventModel;
427        use crate::entities::process_event_deliveries::Model as Delivery;
428
429        let workers = Arc::new(self.workers);
430        let gateways = Arc::new(self.gateways);
431        let help_re = Arc::new(self.help_re);
432        let conversations_poll_interval =
433            Duration::from_secs(self.conversations_poll_interval_secs);
434        let deliveries_poll_interval =
435            Duration::from_secs(self.process_event_deliveries_poll_interval_secs);
436
437        info!(
438            worker_count = workers.len(),
439            gateway_count = gateways.len(),
440            conversations_poll_interval_secs = self.conversations_poll_interval_secs,
441            process_event_deliveries_poll_interval_secs =
442                self.process_event_deliveries_poll_interval_secs,
443            "SensoriumLoop starting"
444        );
445
446        // Start all gateways — each writes inbound messages to the conversations table
447        for (idx, gateway) in gateways.iter().enumerate() {
448            let gateway = gateway.clone();
449            let db_clone = db.clone();
450            debug!(
451                gateway = gateway.name(),
452                gateway_index = idx,
453                "Spawning gateway listener"
454            );
455            tokio::spawn(async move {
456                if let Err(e) = gateway.start(db_clone).await {
457                    error!(error = %e, gateway_index = idx, "gateway stopped with error");
458                }
459            });
460        }
461
462        info!(
463            gateway_count = gateways.len(),
464            "All gateways spawned, starting poll loops"
465        );
466
467        // Loop 2: delivery poll — forwards process_events to gateways with per-gateway retry/backoff
468        let gateways_delivery = gateways.clone();
469        let db_delivery = db.clone();
470        tokio::spawn(async move {
471            loop {
472                // Ensure every undelivered process_event has a delivery row per registered gateway,
473                // then attempt delivery for rows whose retry window has elapsed.
474                let undelivered_events = match ProcessEventModel::find_undelivered(&db_delivery)
475                    .await
476                {
477                    Ok(rows) => rows,
478                    Err(e) => {
479                        warn!(error = %e, "delivery poll: failed to query undelivered process_events");
480                        tokio::time::sleep(deliveries_poll_interval).await;
481                        continue;
482                    }
483                };
484
485                for event in &undelivered_events {
486                    // Resolve the originating conversation and its gateway
487                    let conv = match Conversation::find_by_id(&db_delivery, event.conversation_id)
488                        .await
489                    {
490                        Ok(Some(c)) => c,
491                        Ok(None) => {
492                            debug!(
493                                process_event_id = event.id,
494                                conversation_id = event.conversation_id,
495                                "delivery poll: conversation not found, skipping"
496                            );
497                            continue;
498                        }
499                        Err(e) => {
500                            warn!(error = %e, process_event_id = event.id, "delivery poll: failed to load conversation");
501                            continue;
502                        }
503                    };
504
505                    // Determine target gateway: use conversation's gateway_id if set,
506                    // otherwise broadcast to all registered gateways (legacy path).
507                    let target_gateways: Vec<_> = match conv.gateway_id {
508                        Some(gw_id) => {
509                            use crate::entities::gateway::Model as GatewayModel;
510                            match GatewayModel::find_by_id(&db_delivery, gw_id).await {
511                                Ok(Some(gw_row)) => gateways_delivery
512                                    .iter()
513                                    .filter(|g| g.name() == gw_row.name)
514                                    .cloned()
515                                    .collect(),
516                                _ => {
517                                    warn!(
518                                        process_event_id = event.id,
519                                        gateway_id = gw_id,
520                                        "delivery poll: gateway row not found, skipping"
521                                    );
522                                    continue;
523                                }
524                            }
525                        }
526                        None => gateways_delivery.iter().cloned().collect(),
527                    };
528
529                    for gateway in target_gateways.iter() {
530                        let gateway_name = gateway.name();
531
532                        // Ensure a delivery row exists for this (event, gateway) pair
533                        let delivery_id =
534                            match Delivery::upsert_pending(&db_delivery, event.id, gateway_name)
535                                .await
536                            {
537                                Ok(id) => id,
538                                Err(e) => {
539                                    warn!(
540                                        error = %e,
541                                        process_event_id = event.id,
542                                        gateway = gateway_name,
543                                        "delivery poll: failed to upsert delivery record"
544                                    );
545                                    continue;
546                                }
547                            };
548
549                        // Check eligibility: undelivered and retry window elapsed
550                        let pending =
551                            match Delivery::find_pending_for_gateway(&db_delivery, gateway_name)
552                                .await
553                            {
554                                Ok(rows) => rows,
555                                Err(e) => {
556                                    warn!(
557                                        error = %e,
558                                        gateway = gateway_name,
559                                        "delivery poll: failed to find pending deliveries"
560                                    );
561                                    continue;
562                                }
563                            };
564
565                        let is_eligible = pending.iter().any(|r| r.id == delivery_id);
566                        if !is_eligible {
567                            debug!(
568                                process_event_id = event.id,
569                                gateway = gateway_name,
570                                delivery_id,
571                                "delivery poll: skipping — within backoff window"
572                            );
573                            continue;
574                        }
575
576                        // Attempt delivery; pass gateway_channel_id so the gateway
577                        // routes to the originating channel rather than a default.
578                        let msg = MessageEvent::delivered(event.content.clone())
579                            .with_channel(conv.gateway_channel_id.clone());
580                        match gateway.send(msg).await {
581                            Ok(_) => {
582                                debug!(
583                                    process_event_id = event.id,
584                                    gateway = gateway_name,
585                                    "delivery poll: delivered successfully"
586                                );
587                                if let Err(e) =
588                                    Delivery::mark_delivered(&db_delivery, delivery_id).await
589                                {
590                                    warn!(error = %e, delivery_id, "delivery poll: failed to mark delivery record");
591                                }
592
593                                // Check if all gateways have now delivered this event
594                                match Delivery::count_pending_for_event(&db_delivery, event.id)
595                                    .await
596                                {
597                                    Ok(0) => {
598                                        if let Err(e) = ProcessEventModel::mark_delivered(
599                                            &db_delivery,
600                                            event.id,
601                                        )
602                                        .await
603                                        {
604                                            warn!(
605                                                error = %e,
606                                                process_event_id = event.id,
607                                                "delivery poll: failed to mark process_event delivered"
608                                            );
609                                        } else {
610                                            info!(
611                                                process_event_id = event.id,
612                                                "delivery poll: all gateways delivered — process_event marked done"
613                                            );
614                                        }
615                                    }
616                                    Ok(pending_count) => {
617                                        debug!(
618                                            process_event_id = event.id,
619                                            pending_count,
620                                            "delivery poll: still waiting on other gateways"
621                                        );
622                                    }
623                                    Err(e) => {
624                                        warn!(
625                                            error = %e,
626                                            process_event_id = event.id,
627                                            "delivery poll: failed to count pending deliveries"
628                                        );
629                                    }
630                                }
631                            }
632                            Err(e) => {
633                                warn!(
634                                    error = %e,
635                                    process_event_id = event.id,
636                                    gateway = gateway_name,
637                                    "delivery poll: delivery failed, scheduling retry"
638                                );
639                                if let Err(re) = Delivery::record_failure(
640                                    &db_delivery,
641                                    delivery_id,
642                                    &format!("{:#}", e),
643                                )
644                                .await
645                                {
646                                    warn!(error = %re, delivery_id, "delivery poll: failed to record failure");
647                                }
648                            }
649                        }
650                    }
651                }
652
653                tokio::time::sleep(deliveries_poll_interval).await;
654            }
655        });
656
657        // Loop 1: conversations poll — dispatches inbound messages to workers
658        loop {
659            let unprocessed = match Conversation::find_unprocessed(&db).await {
660                Ok(rows) => rows,
661                Err(e) => {
662                    warn!(error = %e, "conversations poll: failed to query unprocessed conversations");
663                    tokio::time::sleep(conversations_poll_interval).await;
664                    continue;
665                }
666            };
667
668            if !unprocessed.is_empty() {
669                debug!(
670                    count = unprocessed.len(),
671                    "conversations poll: found unprocessed rows"
672                );
673            }
674
675            for row in unprocessed {
676                let conversation_id = row.id;
677                let msg = MessageEvent {
678                    id: Uuid::new_v4(),
679                    kind: MessageEventKind::Received,
680                    content: row.content.clone(),
681                    channel_id: row.gateway_channel_id.clone(),
682                    user_id: row.user_id.clone(),
683                    thread_ts: if row.thread_ts.is_empty() {
684                        None
685                    } else {
686                        Some(row.thread_ts.clone())
687                    },
688                    conversation_id: Some(conversation_id),
689                };
690
691                info!(
692                    conversation_id,
693                    channel_id = %row.gateway_channel_id,
694                    user_id = %row.user_id,
695                    thread_ts = ?msg.thread_ts,
696                    preview = %&row.content[..row.content.len().min(80)],
697                    "Conversation picked up for dispatch"
698                );
699
700                // Intercept /help before dispatching to workers
701                if help_re.is_match(&msg.content) {
702                    let response = build_help_response(&workers);
703                    if let Some(conv_id) = msg.conversation_id {
704                        let db_h = db.clone();
705                        tokio::spawn(async move {
706                            if let Err(e) = crate::entities::process_event::Model::insert(
707                                &db_h,
708                                conv_id,
709                                ProcessEventKind::Completed.as_str(),
710                                &response,
711                            )
712                            .await
713                            {
714                                warn!(error = %e, conversation_id = conv_id, "help: failed to persist process event");
715                            }
716                        });
717                    }
718                    if let Err(e) = Conversation::mark_processed(&db, conversation_id).await {
719                        warn!(error = %e, conversation_id, "conversations poll: failed to mark conversation as processed");
720                    }
721                    continue;
722                }
723
724                // Mark processed before dispatching so a restart won't re-deliver
725                if let Err(e) = Conversation::mark_processed(&db, conversation_id).await {
726                    warn!(error = %e, conversation_id, "conversations poll: failed to mark conversation as processed");
727                    continue;
728                }
729
730                // Find first worker whose pattern matches message content
731                let matched = workers.iter().find(|r| r.pattern.is_match(&msg.content));
732
733                match matched {
734                    Some(registration) => {
735                        let args = parse_kv(&msg.content);
736                        let worker = registration.worker.clone();
737                        let worker_name = worker.name();
738                        debug!(
739                            worker = worker_name,
740                            conversation_id,
741                            pattern = registration.pattern.as_str(),
742                            "Dispatching to matched worker"
743                        );
744                        let db_w = db.clone();
745                        let msg = msg.clone();
746                        tokio::spawn(async move {
747                            match worker.handle(db_w, msg, args).await {
748                                Ok(_) => {
749                                    info!(worker = worker_name, conversation_id, "Worker completed")
750                                }
751                                Err(e) => {
752                                    error!(error = %e, worker = worker_name, conversation_id, "Worker failed")
753                                }
754                            }
755                        });
756                    }
757                    None => {
758                        warn!(conversation_id, preview = %&row.content[..row.content.len().min(80)], "No worker matched message, skipping");
759                    }
760                }
761            }
762
763            tokio::time::sleep(conversations_poll_interval).await;
764        }
765    }
766}
767
// Unit tests for this module are kept in a separate file, `plugin_tests.rs`,
// compiled only for test builds.
#[cfg(test)]
#[path = "plugin_tests.rs"]
mod plugin_tests;