Skip to main content

mur_chat/
adapter.rs

1//! Chat Adapter — bridges SSE workflow events to chat platform messages.
2//!
3//! Subscribes to the EventBroadcaster and converts WorkflowEvents into
4//! chat messages sent via the ChatPlatform trait. Each workflow execution
5//! gets its own thread for organized conversation.
6
use std::collections::HashMap;
use std::sync::Arc;

use tokio::sync::broadcast;
use tokio::sync::RwLock;
use tracing::{debug, error, info, warn};

use crate::platform::{ChatPlatform, ProgressStatus, ProgressUpdate, WorkflowNotification};
15
/// SSE event types that the adapter consumes.
/// Mirrors mur_daemon::sse::WorkflowEvent but decoupled to avoid circular deps.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WorkflowEvent {
    /// An execution began; the adapter opens a chat thread for it.
    Started {
        execution_id: String,
        workflow_id: String,
        total_steps: usize,
    },
    /// A step of a running execution began.
    StepStarted {
        execution_id: String,
        step_index: usize,
        step_name: String,
    },
    /// A step finished; `success` selects the done/failed status and `output`
    /// is the step's output text (the adapter shortens long output before posting).
    StepCompleted {
        execution_id: String,
        step_index: usize,
        step_name: String,
        success: bool,
        output: String,
        duration_ms: u64,
    },
    /// Execution paused at a breakpoint; `message` becomes the description of
    /// the approval request the adapter posts.
    BreakpointHit {
        execution_id: String,
        step_index: usize,
        step_name: String,
        message: String,
    },
    /// The execution finished. `error` presumably carries the failure reason
    /// when `success` is false — confirm against the daemon's emitter.
    Completed {
        execution_id: String,
        success: bool,
        steps_completed: usize,
        total_steps: usize,
        duration_ms: u64,
        error: Option<String>,
    },
}

impl WorkflowEvent {
    /// Returns the ID of the execution this event belongs to.
    ///
    /// Every variant carries one, so callers can route or filter events
    /// without matching on the individual variants.
    #[must_use]
    pub fn execution_id(&self) -> &str {
        match self {
            Self::Started { execution_id, .. }
            | Self::StepStarted { execution_id, .. }
            | Self::StepCompleted { execution_id, .. }
            | Self::BreakpointHit { execution_id, .. }
            | Self::Completed { execution_id, .. } => execution_id,
        }
    }
}
53
/// Chat-side bookkeeping kept for one tracked workflow execution.
#[derive(Clone, Debug)]
struct ExecutionThread {
    /// Thread identifier returned by the platform's `start_thread`.
    thread_id: String,
    /// Channel hosting the thread; progress and notifications are posted here.
    channel_id: String,
    /// Workflow being executed, copied into progress and completion messages.
    workflow_id: String,
    /// Step count announced by the execution's `Started` event.
    total_steps: usize,
}
62
63/// Chat adapter that converts workflow events to chat messages.
64pub struct ChatAdapter<P: ChatPlatform> {
65    platform: Arc<P>,
66    channel_id: String,
67    /// Maps execution_id → thread state.
68    threads: Arc<RwLock<HashMap<String, ExecutionThread>>>,
69    shadow_default: bool,
70}
71
72impl<P: ChatPlatform + 'static> ChatAdapter<P> {
73    pub fn new(platform: Arc<P>, channel_id: String) -> Self {
74        Self {
75            platform,
76            channel_id,
77            threads: Arc::new(RwLock::new(HashMap::new())),
78            shadow_default: false,
79        }
80    }
81
82    pub fn with_shadow_default(mut self, shadow: bool) -> Self {
83        self.shadow_default = shadow;
84        self
85    }
86
87    /// Start listening to a broadcast channel of workflow events.
88    /// Runs until the channel is closed or an error occurs.
89    pub async fn run(&self, mut rx: broadcast::Receiver<WorkflowEvent>) {
90        info!("ChatAdapter started, forwarding events to chat");
91
92        loop {
93            match rx.recv().await {
94                Ok(event) => {
95                    if let Err(e) = self.handle_event(event).await {
96                        error!("ChatAdapter error: {}", e);
97                    }
98                }
99                Err(broadcast::error::RecvError::Lagged(n)) => {
100                    debug!("ChatAdapter lagged by {} events", n);
101                }
102                Err(broadcast::error::RecvError::Closed) => {
103                    info!("ChatAdapter: event channel closed");
104                    break;
105                }
106            }
107        }
108    }
109
110    async fn handle_event(&self, event: WorkflowEvent) -> anyhow::Result<()> {
111        match event {
112            WorkflowEvent::Started {
113                execution_id,
114                workflow_id,
115                total_steps,
116            } => {
117                let thread_id = self
118                    .platform
119                    .start_thread(
120                        &self.channel_id,
121                        &execution_id,
122                        &workflow_id,
123                        total_steps,
124                        self.shadow_default,
125                    )
126                    .await?;
127
128                let mut threads = self.threads.write().await;
129                threads.insert(
130                    execution_id,
131                    ExecutionThread {
132                        thread_id,
133                        channel_id: self.channel_id.clone(),
134                        workflow_id,
135                        total_steps,
136                    },
137                );
138            }
139
140            WorkflowEvent::StepStarted {
141                execution_id,
142                step_index,
143                step_name,
144            } => {
145                let threads = self.threads.read().await;
146                if let Some(thread) = threads.get(&execution_id) {
147                    let progress = ProgressUpdate {
148                        execution_id: execution_id.clone(),
149                        workflow_id: thread.workflow_id.clone(),
150                        step_index,
151                        total_steps: thread.total_steps,
152                        step_name,
153                        status: ProgressStatus::StepRunning,
154                        output: None,
155                        duration_ms: None,
156                    };
157                    self.platform
158                        .send_progress(&thread.channel_id, &thread.thread_id, &progress)
159                        .await?;
160                }
161            }
162
163            WorkflowEvent::StepCompleted {
164                execution_id,
165                step_index,
166                step_name,
167                success,
168                output,
169                duration_ms,
170            } => {
171                let threads = self.threads.read().await;
172                if let Some(thread) = threads.get(&execution_id) {
173                    let status = if success {
174                        ProgressStatus::StepDone
175                    } else {
176                        ProgressStatus::StepFailed
177                    };
178                    let trimmed = if output.len() > 200 {
179                        format!("{}...", &output[..output.char_indices().take_while(|(i, _)| *i < 200).last().map(|(i, c)| i + c.len_utf8()).unwrap_or(0)])
180                    } else {
181                        output
182                    };
183                    let progress = ProgressUpdate {
184                        execution_id: execution_id.clone(),
185                        workflow_id: thread.workflow_id.clone(),
186                        step_index,
187                        total_steps: thread.total_steps,
188                        step_name,
189                        status,
190                        output: Some(trimmed),
191                        duration_ms: Some(duration_ms),
192                    };
193                    self.platform
194                        .send_progress(&thread.channel_id, &thread.thread_id, &progress)
195                        .await?;
196                }
197            }
198
199            WorkflowEvent::BreakpointHit {
200                execution_id,
201                step_name,
202                message,
203                ..
204            } => {
205                let threads = self.threads.read().await;
206                if let Some(thread) = threads.get(&execution_id) {
207                    let approval = crate::platform::ApprovalRequest {
208                        execution_id,
209                        step_name,
210                        description: message,
211                        action: "breakpoint".into(),
212                        allowed_approvers: Vec::new(),
213                    };
214                    self.platform
215                        .send_approval(&thread.channel_id, &approval)
216                        .await?;
217                }
218            }
219
220            WorkflowEvent::Completed {
221                execution_id,
222                success,
223                steps_completed,
224                total_steps,
225                duration_ms,
226                error,
227            } => {
228                // Clone data while holding read lock, then drop it before network calls
229                let thread_info = {
230                    let threads = self.threads.read().await;
231                    threads.get(&execution_id).cloned()
232                };
233
234                if let Some(thread) = thread_info {
235                    let notification = WorkflowNotification {
236                        execution_id: execution_id.clone(),
237                        workflow_id: thread.workflow_id.clone(),
238                        success,
239                        steps_completed,
240                        total_steps,
241                        duration_ms,
242                        error,
243                    };
244                    self.platform
245                        .send_notification(
246                            &thread.channel_id,
247                            Some(&thread.thread_id),
248                            &notification,
249                        )
250                        .await?;
251
252                    self.platform
253                        .send_notification(&thread.channel_id, None, &notification)
254                        .await?;
255
256                    // Clean up thread tracking
257                    let mut threads = self.threads.write().await;
258                    threads.remove(&execution_id);
259                }
260            }
261        }
262        Ok(())
263    }
264
265    /// Get the thread ID for an active execution.
266    pub async fn thread_for_execution(&self, execution_id: &str) -> Option<String> {
267        self.threads
268            .read()
269            .await
270            .get(execution_id)
271            .map(|t| t.thread_id.clone())
272    }
273}
274
#[cfg(test)]
mod tests {
    use super::*;
    use crate::platform::{ApprovalRequest, OutgoingMessage};
    use std::sync::Mutex;
    use std::time::Duration;

    /// In-memory chat platform that records what the adapter sends.
    struct MockPlatform {
        /// One tagged entry per call (`thread:`, `progress:`, `approval:`,
        /// `notification:`), or the raw text for `send_message`.
        messages: Arc<Mutex<Vec<String>>>,
        /// The `output` field of every progress update, in arrival order.
        outputs: Arc<Mutex<Vec<Option<String>>>>,
    }

    impl MockPlatform {
        fn new() -> Self {
            Self {
                messages: Arc::new(Mutex::new(Vec::new())),
                outputs: Arc::new(Mutex::new(Vec::new())),
            }
        }
    }

    impl ChatPlatform for MockPlatform {
        async fn send_message(&self, msg: &OutgoingMessage) -> anyhow::Result<String> {
            self.messages.lock().unwrap().push(msg.text.clone());
            Ok("mock-ts".into())
        }

        async fn send_approval(
            &self,
            _channel_id: &str,
            request: &ApprovalRequest,
        ) -> anyhow::Result<String> {
            self.messages
                .lock()
                .unwrap()
                .push(format!("approval:{}", request.step_name));
            Ok("mock-ts".into())
        }

        async fn update_message(
            &self,
            _channel_id: &str,
            _message_id: &str,
            _text: &str,
        ) -> anyhow::Result<()> {
            Ok(())
        }

        async fn add_reaction(
            &self,
            _channel_id: &str,
            _message_id: &str,
            _emoji: &str,
        ) -> anyhow::Result<()> {
            Ok(())
        }

        async fn send_progress(
            &self,
            _channel_id: &str,
            _thread_id: &str,
            progress: &ProgressUpdate,
        ) -> anyhow::Result<String> {
            self.messages
                .lock()
                .unwrap()
                .push(format!("progress:{}", progress.step_name));
            self.outputs.lock().unwrap().push(progress.output.clone());
            Ok("mock-ts".into())
        }

        async fn send_notification(
            &self,
            _channel_id: &str,
            _thread_id: Option<&str>,
            notification: &WorkflowNotification,
        ) -> anyhow::Result<String> {
            self.messages.lock().unwrap().push(format!(
                "notification:{}:{}",
                notification.workflow_id, notification.success
            ));
            Ok("mock-ts".into())
        }

        async fn start_thread(
            &self,
            _channel_id: &str,
            _execution_id: &str,
            workflow_id: &str,
            _total_steps: usize,
            _shadow: bool,
        ) -> anyhow::Result<String> {
            self.messages
                .lock()
                .unwrap()
                .push(format!("thread:{}", workflow_id));
            Ok("mock-thread-ts".into())
        }
    }

    /// Builds a `Started` event.
    fn started(execution_id: &str, workflow_id: &str, total_steps: usize) -> WorkflowEvent {
        WorkflowEvent::Started {
            execution_id: execution_id.into(),
            workflow_id: workflow_id.into(),
            total_steps,
        }
    }

    /// Builds a `Completed` event for a two-step workflow.
    fn completed(execution_id: &str, success: bool) -> WorkflowEvent {
        WorkflowEvent::Completed {
            execution_id: execution_id.into(),
            success,
            steps_completed: 2,
            total_steps: 2,
            duration_ms: 500,
            error: None,
        }
    }

    #[tokio::test]
    async fn test_adapter_full_lifecycle() {
        let platform = Arc::new(MockPlatform::new());
        let adapter = ChatAdapter::new(Arc::clone(&platform), "#test".into());
        let (tx, rx) = broadcast::channel(16);

        let adapter_handle = tokio::spawn(async move { adapter.run(rx).await });

        let events = [
            started("e1", "deploy", 2),
            WorkflowEvent::StepStarted {
                execution_id: "e1".into(),
                step_index: 0,
                step_name: "build".into(),
            },
            WorkflowEvent::StepCompleted {
                execution_id: "e1".into(),
                step_index: 0,
                step_name: "build".into(),
                success: true,
                output: "OK".into(),
                duration_ms: 100,
            },
            completed("e1", true),
        ];
        for event in events {
            tx.send(event).unwrap();
        }

        // Dropping the only sender closes the channel; `run` still receives
        // every buffered event before seeing `Closed`, so no sleep is needed.
        drop(tx);
        tokio::time::timeout(Duration::from_secs(1), adapter_handle)
            .await
            .expect("adapter did not stop after the channel closed")
            .expect("adapter task panicked");

        let messages = platform.messages.lock().unwrap();
        assert!(
            messages.iter().any(|m| m.contains("thread:deploy")),
            "Should have started a thread"
        );
        assert!(
            messages.iter().any(|m| m.contains("progress:build")),
            "Should have sent progress"
        );
        assert!(
            messages.iter().any(|m| m.contains("notification:deploy:true")),
            "Should have sent completion notification"
        );
        assert_eq!(
            *platform.outputs.lock().unwrap(),
            vec![None, Some("OK".to_string())],
            "Short output should be forwarded unchanged"
        );
    }

    #[tokio::test]
    async fn test_adapter_thread_tracking() {
        let adapter = ChatAdapter::new(Arc::new(MockPlatform::new()), "#test".into());

        // No thread initially
        assert!(adapter.thread_for_execution("e1").await.is_none());

        adapter.handle_event(started("e1", "deploy", 1)).await.unwrap();
        assert_eq!(
            adapter.thread_for_execution("e1").await.as_deref(),
            Some("mock-thread-ts")
        );

        adapter.handle_event(completed("e1", true)).await.unwrap();
        assert!(adapter.thread_for_execution("e1").await.is_none());
    }

    #[tokio::test]
    async fn test_events_for_unknown_execution_are_ignored() {
        let platform = Arc::new(MockPlatform::new());
        let adapter = ChatAdapter::new(Arc::clone(&platform), "#test".into());

        adapter
            .handle_event(WorkflowEvent::StepStarted {
                execution_id: "ghost".into(),
                step_index: 0,
                step_name: "build".into(),
            })
            .await
            .unwrap();
        adapter.handle_event(completed("ghost", false)).await.unwrap();

        assert!(platform.messages.lock().unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_breakpoint_requests_approval() {
        let platform = Arc::new(MockPlatform::new());
        let adapter = ChatAdapter::new(Arc::clone(&platform), "#test".into());

        adapter.handle_event(started("e1", "deploy", 1)).await.unwrap();
        adapter
            .handle_event(WorkflowEvent::BreakpointHit {
                execution_id: "e1".into(),
                step_index: 0,
                step_name: "migrate".into(),
                message: "confirm".into(),
            })
            .await
            .unwrap();

        assert!(platform
            .messages
            .lock()
            .unwrap()
            .iter()
            .any(|m| m == "approval:migrate"));
    }

    #[tokio::test]
    async fn test_long_multibyte_output_is_truncated_on_char_boundary() {
        let platform = Arc::new(MockPlatform::new());
        let adapter = ChatAdapter::new(Arc::clone(&platform), "#test".into());
        adapter.handle_event(started("e1", "deploy", 1)).await.unwrap();

        // 100 × '€' is 300 bytes; byte 200 falls inside the 3-byte char at 198..201,
        // which must be kept whole rather than split (no panic).
        adapter
            .handle_event(WorkflowEvent::StepCompleted {
                execution_id: "e1".into(),
                step_index: 0,
                step_name: "build".into(),
                success: false,
                output: "€".repeat(100),
                duration_ms: 1,
            })
            .await
            .unwrap();

        let expected = format!("{}...", "€".repeat(67));
        assert_eq!(*platform.outputs.lock().unwrap(), vec![Some(expected)]);
    }
}