Skip to main content

bob_chat/
thread.rs

1//! Thread handle for interacting with a conversation thread.
2//!
3//! [`ThreadHandle`] provides a scoped reference to a specific thread within a
4//! chat adapter.  Handler functions receive a `ThreadHandle` so they can post
5//! replies, manage subscriptions, and stream responses back to the originating
6//! thread.
7
8use std::sync::Arc;
9
10use crate::{
11    adapter::ChatAdapter,
12    error::ChatError,
13    message::{AdapterPostableMessage, EphemeralMessage, PostableMessage, SentMessage},
14};
15
16/// A scoped handle to a conversation thread.
17///
18/// Holds a reference to the originating adapter, the thread identifier, and the
19/// shared subscription map so handlers can interact with the thread without
20/// needing access to the full [`ChatBot`](crate::bot::ChatBot).
21pub struct ThreadHandle {
22    /// Platform-specific thread/channel identifier.
23    pub(crate) thread_id: String,
24    /// The adapter that owns this thread.
25    pub(crate) adapter: Arc<dyn ChatAdapter>,
26    /// Shared subscription map (thread_id → ()).
27    pub(crate) subscriptions: Arc<scc::HashMap<String, ()>>,
28}
29
30impl ThreadHandle {
31    /// Return the thread identifier.
32    #[must_use]
33    pub fn thread_id(&self) -> &str {
34        &self.thread_id
35    }
36
37    /// Return the name of the adapter backing this handle.
38    #[must_use]
39    pub fn adapter_name(&self) -> &str {
40        self.adapter.name()
41    }
42
43    /// Post a message to this thread.
44    ///
45    /// For `PostableMessage::Text` and `PostableMessage::Markdown`, the message
46    /// is forwarded to the adapter via [`ChatAdapter::post_message`].
47    ///
48    /// # Errors
49    ///
50    /// Returns an error if the adapter fails to post.
51    pub async fn post(
52        &self,
53        message: impl Into<PostableMessage>,
54    ) -> Result<SentMessage, ChatError> {
55        let postable = message.into();
56        let adapter_msg = match postable {
57            PostableMessage::Text(t) => AdapterPostableMessage::Text(t),
58            PostableMessage::Markdown(m) => AdapterPostableMessage::Markdown(m),
59        };
60        self.adapter.post_message(&self.thread_id, &adapter_msg).await
61    }
62
63    /// Post an ephemeral message visible only to the specified user.
64    ///
65    /// If the adapter does not support ephemeral messages and
66    /// `fallback_to_dm` is `true`, falls back to opening a direct-message
67    /// channel with the user and posting there.
68    ///
69    /// # Errors
70    ///
71    /// Returns an error if both the ephemeral post and DM fallback fail.
72    pub async fn post_ephemeral(
73        &self,
74        user_id: &str,
75        message: impl Into<AdapterPostableMessage>,
76        fallback_to_dm: bool,
77    ) -> Result<Option<EphemeralMessage>, ChatError> {
78        let msg = message.into();
79        match self.adapter.post_ephemeral(&self.thread_id, user_id, &msg).await {
80            Ok(eph) => Ok(Some(eph)),
81            Err(ChatError::NotSupported(_)) if fallback_to_dm => {
82                let dm_thread = self.adapter.open_dm(user_id).await?;
83                self.adapter.post_message(&dm_thread, &msg).await?;
84                Ok(None)
85            }
86            Err(e) => Err(e),
87        }
88    }
89
90    /// Show a typing / status indicator in this thread.
91    ///
92    /// # Errors
93    ///
94    /// Returns an error if the adapter fails.
95    pub async fn start_typing(&self, status: Option<&str>) -> Result<(), ChatError> {
96        self.adapter.start_typing(&self.thread_id, status).await
97    }
98
99    /// Subscribe to follow-up messages in this thread.
100    ///
101    /// Subsequent `ChatEvent::Message` events for this thread will be
102    /// routed to `on_subscribed_message` handlers instead of `on_message`.
103    pub async fn subscribe(&self) {
104        let _ = self.subscriptions.insert_async(self.thread_id.clone(), ()).await;
105    }
106
107    /// Unsubscribe from this thread.
108    pub async fn unsubscribe(&self) {
109        let _ = self.subscriptions.remove_async(&self.thread_id).await;
110    }
111
112    /// Format a platform-agnostic mention string for a user.
113    ///
114    /// This is a simple helper that returns `<@user_id>`.  Adapters that
115    /// need a different format should handle it at the rendering layer.
116    #[must_use]
117    pub fn mention_user(&self, user_id: &str) -> String {
118        format!("<@{user_id}>")
119    }
120}
121
122impl std::fmt::Debug for ThreadHandle {
123    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124        f.debug_struct("ThreadHandle")
125            .field("thread_id", &self.thread_id)
126            .field("adapter", &self.adapter.name())
127            .finish_non_exhaustive()
128    }
129}
130
#[cfg(test)]
mod tests {
    use std::sync::Mutex;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use super::*;
    use crate::{card::CardElement, event::ChatEvent};

    // -----------------------------------------------------------------
    // Mock adapter for thread handle tests
    // -----------------------------------------------------------------

    struct MockThreadAdapter {
        post_count: Arc<AtomicUsize>,
        edit_count: Arc<AtomicUsize>,
        ephemeral_supported: bool,
        dm_opened: Arc<AtomicUsize>,
        /// Thread ids passed to `post_message`, in call order, so tests can
        /// check *where* a message was posted and not just how many times.
        posted_to: Arc<Mutex<Vec<String>>>,
    }

    impl MockThreadAdapter {
        fn new(ephemeral_supported: bool) -> Self {
            Self {
                post_count: Arc::new(AtomicUsize::new(0)),
                edit_count: Arc::new(AtomicUsize::new(0)),
                ephemeral_supported,
                dm_opened: Arc::new(AtomicUsize::new(0)),
                posted_to: Arc::new(Mutex::new(Vec::new())),
            }
        }
    }

    #[async_trait::async_trait]
    impl ChatAdapter for MockThreadAdapter {
        fn name(&self) -> &str {
            "mock-thread"
        }

        async fn post_message(
            &self,
            thread_id: &str,
            _message: &AdapterPostableMessage,
        ) -> Result<SentMessage, ChatError> {
            self.post_count.fetch_add(1, Ordering::SeqCst);
            self.posted_to.lock().expect("posted_to lock poisoned").push(thread_id.to_owned());
            Ok(SentMessage {
                id: "m1".into(),
                thread_id: "t1".into(),
                adapter_name: "mock-thread".into(),
                raw: None,
            })
        }

        async fn edit_message(
            &self,
            _thread_id: &str,
            _message_id: &str,
            _message: &AdapterPostableMessage,
        ) -> Result<SentMessage, ChatError> {
            self.edit_count.fetch_add(1, Ordering::SeqCst);
            Ok(SentMessage {
                id: "m1".into(),
                thread_id: "t1".into(),
                adapter_name: "mock-thread".into(),
                raw: None,
            })
        }

        async fn delete_message(
            &self,
            _thread_id: &str,
            _message_id: &str,
        ) -> Result<(), ChatError> {
            Ok(())
        }

        fn render_card(&self, _card: &CardElement) -> String {
            String::new()
        }

        fn render_message(&self, _message: &AdapterPostableMessage) -> String {
            String::new()
        }

        async fn recv_event(&mut self) -> Option<ChatEvent> {
            None
        }

        async fn post_ephemeral(
            &self,
            _thread_id: &str,
            _user_id: &str,
            _message: &AdapterPostableMessage,
        ) -> Result<EphemeralMessage, ChatError> {
            if self.ephemeral_supported {
                Ok(EphemeralMessage {
                    id: "e1".into(),
                    thread_id: "t1".into(),
                    used_fallback: false,
                })
            } else {
                Err(ChatError::NotSupported("ephemeral messages".into()))
            }
        }

        async fn open_dm(&self, _user_id: &str) -> Result<String, ChatError> {
            self.dm_opened.fetch_add(1, Ordering::SeqCst);
            Ok("dm-thread".into())
        }
    }

    fn make_handle(adapter: MockThreadAdapter) -> ThreadHandle {
        ThreadHandle {
            thread_id: "t1".into(),
            adapter: Arc::new(adapter),
            subscriptions: Arc::new(scc::HashMap::new()),
        }
    }

    // -----------------------------------------------------------------
    // Tests
    // -----------------------------------------------------------------

    #[test]
    fn thread_id_accessor() {
        let handle = make_handle(MockThreadAdapter::new(true));
        assert_eq!(handle.thread_id(), "t1");
    }

    #[test]
    fn adapter_name_accessor() {
        let handle = make_handle(MockThreadAdapter::new(true));
        assert_eq!(handle.adapter_name(), "mock-thread");
    }

    #[tokio::test]
    async fn post_text_message() {
        let adapter = MockThreadAdapter::new(true);
        let post_count = Arc::clone(&adapter.post_count);
        let posted_to = Arc::clone(&adapter.posted_to);
        let handle = make_handle(adapter);

        let result = handle.post(PostableMessage::Text("hello".into())).await;
        assert!(result.is_ok());
        assert_eq!(post_count.load(Ordering::SeqCst), 1);
        // The message must go to the handle's own thread.
        assert_eq!(*posted_to.lock().expect("posted_to lock"), vec!["t1".to_owned()]);
    }

    #[tokio::test]
    async fn post_markdown_message() {
        let adapter = MockThreadAdapter::new(true);
        let post_count = Arc::clone(&adapter.post_count);
        let posted_to = Arc::clone(&adapter.posted_to);
        let handle = make_handle(adapter);

        let result = handle.post(PostableMessage::Markdown("**bold**".into())).await;
        assert!(result.is_ok());
        assert_eq!(post_count.load(Ordering::SeqCst), 1);
        assert_eq!(*posted_to.lock().expect("posted_to lock"), vec!["t1".to_owned()]);
    }

    #[tokio::test]
    async fn post_ephemeral_supported() {
        let adapter = MockThreadAdapter::new(true);
        let dm_count = Arc::clone(&adapter.dm_opened);
        let post_count = Arc::clone(&adapter.post_count);
        let handle = make_handle(adapter);

        let eph = handle
            .post_ephemeral("u1", AdapterPostableMessage::Text("secret".into()), false)
            .await
            .expect("ephemeral post should succeed");
        assert!(eph.is_some());
        // The fallback path must not have been taken.
        assert_eq!(dm_count.load(Ordering::SeqCst), 0);
        assert_eq!(post_count.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn post_ephemeral_supported_ignores_fallback_flag() {
        let adapter = MockThreadAdapter::new(true);
        let dm_count = Arc::clone(&adapter.dm_opened);
        let handle = make_handle(adapter);

        // With native ephemeral support, `fallback_to_dm = true` must not
        // trigger a DM.
        let eph = handle
            .post_ephemeral("u1", AdapterPostableMessage::Text("secret".into()), true)
            .await
            .expect("ephemeral post should succeed");
        assert!(eph.is_some());
        assert_eq!(dm_count.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn post_ephemeral_fallback_to_dm() {
        let adapter = MockThreadAdapter::new(false);
        let dm_count = Arc::clone(&adapter.dm_opened);
        let post_count = Arc::clone(&adapter.post_count);
        let posted_to = Arc::clone(&adapter.posted_to);
        let handle = make_handle(adapter);

        let result =
            handle.post_ephemeral("u1", AdapterPostableMessage::Text("secret".into()), true).await;
        assert!(result.is_ok());
        // Should have opened a DM and posted there
        assert_eq!(dm_count.load(Ordering::SeqCst), 1);
        assert_eq!(post_count.load(Ordering::SeqCst), 1);
        // ...to the DM thread, never the original (public) thread.
        assert_eq!(*posted_to.lock().expect("posted_to lock"), vec!["dm-thread".to_owned()]);
        // Returns None when fallback was used
        assert!(result.expect("should be Ok").is_none());
    }

    #[tokio::test]
    async fn post_ephemeral_no_fallback_returns_error() {
        let adapter = MockThreadAdapter::new(false);
        let dm_count = Arc::clone(&adapter.dm_opened);
        let post_count = Arc::clone(&adapter.post_count);
        let handle = make_handle(adapter);

        let result =
            handle.post_ephemeral("u1", AdapterPostableMessage::Text("secret".into()), false).await;
        // The adapter's `NotSupported` error is surfaced unchanged...
        assert!(matches!(result, Err(ChatError::NotSupported(_))));
        // ...and no DM fallback was attempted.
        assert_eq!(dm_count.load(Ordering::SeqCst), 0);
        assert_eq!(post_count.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn subscribe_and_unsubscribe() {
        let handle = make_handle(MockThreadAdapter::new(true));

        assert!(!handle.subscriptions.contains_sync("t1"));
        handle.subscribe().await;
        assert!(handle.subscriptions.contains_sync("t1"));
        handle.unsubscribe().await;
        assert!(!handle.subscriptions.contains_sync("t1"));
    }

    #[tokio::test]
    async fn subscribe_is_idempotent_and_unsubscribe_tolerates_missing() {
        let handle = make_handle(MockThreadAdapter::new(true));

        // Unsubscribing a thread that was never subscribed is a no-op.
        handle.unsubscribe().await;
        assert!(!handle.subscriptions.contains_sync("t1"));

        // Subscribing twice leaves the thread subscribed.
        handle.subscribe().await;
        handle.subscribe().await;
        assert!(handle.subscriptions.contains_sync("t1"));

        // A single unsubscribe fully removes it.
        handle.unsubscribe().await;
        assert!(!handle.subscriptions.contains_sync("t1"));
    }

    #[test]
    fn mention_user_formatting() {
        let handle = make_handle(MockThreadAdapter::new(true));
        assert_eq!(handle.mention_user("U123"), "<@U123>");
    }

    #[test]
    fn debug_impl() {
        let handle = make_handle(MockThreadAdapter::new(true));
        let dbg = format!("{handle:?}");
        assert!(dbg.contains("t1"));
        assert!(dbg.contains("mock-thread"));
    }
}