Skip to main content

bob_chat/
stream.rs

1//! Streaming text types for progressive message delivery.
2
3use std::pin::Pin;
4
5use futures_core::Stream;
6use tokio_stream::StreamExt as _;
7
8use crate::{
9    adapter::ChatAdapter,
10    error::ChatError,
11    message::{AdapterPostableMessage, SentMessage},
12};
13
/// A stream of text chunks from an async source (e.g. an LLM response).
///
/// The stream is boxed and pinned so it can be polled through a trait
/// object.  Every item is one owned `String` chunk, and the `Send` bound
/// allows the stream to be moved to another thread.
pub type TextStream = Pin<Box<dyn Stream<Item = String> + Send>>;
16
/// Options that control how streamed text is posted and updated.
///
/// Values can be compared with `==`, which makes it easy to assert on a
/// configuration in tests or to detect whether a caller changed the defaults.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StreamOptions {
    /// Minimum interval (in milliseconds) between edit calls while
    /// streaming accumulated text to a thread.
    pub update_interval_ms: u64,
    /// Placeholder text used for the initial message, before any chunks
    /// arrive.
    ///
    /// `None` means no *visible* placeholder: `fallback_stream` still posts
    /// the initial message right away, but its text is a single zero-width
    /// space (`U+200B`) instead of a placeholder string.
    pub placeholder_text: Option<String>,
}

impl Default for StreamOptions {
    /// Returns the default options: an update interval of 500 ms and the
    /// placeholder text `"..."`.
    fn default() -> Self {
        Self { update_interval_ms: 500, placeholder_text: Some(String::from("...")) }
    }
}
34
35/// Stream text progressively using a post-then-edit loop.
36///
37/// This is the default fallback streaming strategy for adapters that do
38/// not provide a native streaming mechanism:
39///
40/// 1. Post an initial placeholder message.
41/// 2. Consume chunks from `text_stream`, accumulating them.
42/// 3. Every `options.update_interval_ms` milliseconds, edit the message with the accumulated text
43///    so far.
44/// 4. After the stream ends, perform a final edit with the complete text.
45pub async fn fallback_stream<A: ChatAdapter + ?Sized>(
46    adapter: &A,
47    thread_id: &str,
48    text_stream: TextStream,
49    options: &StreamOptions,
50) -> Result<SentMessage, ChatError> {
51    let placeholder = options.placeholder_text.clone().unwrap_or_else(|| String::from("\u{200B}")); // zero-width space
52
53    let initial =
54        adapter.post_message(thread_id, &AdapterPostableMessage::Text(placeholder)).await?;
55
56    let message_id = initial.id.clone();
57    let mut accumulated = String::new();
58    let interval = tokio::time::Duration::from_millis(options.update_interval_ms);
59    let mut last_edit = tokio::time::Instant::now();
60
61    let mut stream = text_stream;
62
63    while let Some(chunk) = stream.next().await {
64        accumulated.push_str(&chunk);
65
66        if last_edit.elapsed() >= interval {
67            let _interim = adapter
68                .edit_message(
69                    thread_id,
70                    &message_id,
71                    &AdapterPostableMessage::Text(accumulated.clone()),
72                )
73                .await?;
74            last_edit = tokio::time::Instant::now();
75        }
76    }
77
78    // Always perform a final edit so the message contains the complete text.
79    let final_sent = adapter
80        .edit_message(thread_id, &message_id, &AdapterPostableMessage::Text(accumulated))
81        .await?;
82
83    Ok(final_sent)
84}
85
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use super::*;
    use crate::{card::CardElement, event::ChatEvent};

    // -----------------------------------------------------------------
    // Recording adapter: a test double that logs post/edit calls
    // -----------------------------------------------------------------

    /// One adapter call captured by [`MockStreamAdapter`].
    #[derive(Debug, Clone)]
    #[expect(dead_code, reason = "fields read in test assertions via pattern matching")]
    enum Call {
        Post(String),
        Edit { message_id: String, text: String },
    }

    /// Adapter that records every post/edit and hands out sequential ids.
    struct MockStreamAdapter {
        calls: Arc<Mutex<Vec<Call>>>,
        next_id: Arc<Mutex<u64>>,
    }

    /// Extracts the text payload of a postable message.
    fn text_of(message: &AdapterPostableMessage) -> String {
        match message {
            AdapterPostableMessage::Text(t) | AdapterPostableMessage::Markdown(t) => t.clone(),
        }
    }

    /// Builds the `SentMessage` the mock returns for the given message id.
    fn sent(id: String) -> SentMessage {
        SentMessage { id, thread_id: "t1".into(), adapter_name: "mock-stream".into(), raw: None }
    }

    impl MockStreamAdapter {
        fn new() -> Self {
            let calls = Arc::new(Mutex::new(Vec::new()));
            let next_id = Arc::new(Mutex::new(0));
            Self { calls, next_id }
        }

        /// Drains and returns the recorded calls (empty if the lock is poisoned).
        fn take_calls(&self) -> Vec<Call> {
            self.calls.lock().map(|mut log| std::mem::take(&mut *log)).unwrap_or_default()
        }

        /// Appends `call` to the log.
        fn record(&self, call: Call) -> Result<(), ChatError> {
            let mut log =
                self.calls.lock().map_err(|_| ChatError::Adapter("lock poisoned".into()))?;
            log.push(call);
            Ok(())
        }

        /// Bumps the counter and returns the next `msg-N` identifier.
        fn allocate_id(&self) -> Result<String, ChatError> {
            let mut counter =
                self.next_id.lock().map_err(|_| ChatError::Adapter("lock poisoned".into()))?;
            *counter += 1;
            Ok(format!("msg-{}", *counter))
        }
    }

    #[async_trait::async_trait]
    impl ChatAdapter for MockStreamAdapter {
        fn name(&self) -> &'static str {
            "mock-stream"
        }

        async fn post_message(
            &self,
            _thread_id: &str,
            message: &AdapterPostableMessage,
        ) -> Result<SentMessage, ChatError> {
            let text = text_of(message);
            let id = self.allocate_id()?;
            self.record(Call::Post(text))?;
            Ok(sent(id))
        }

        async fn edit_message(
            &self,
            _thread_id: &str,
            message_id: &str,
            message: &AdapterPostableMessage,
        ) -> Result<SentMessage, ChatError> {
            let text = text_of(message);
            self.record(Call::Edit { message_id: message_id.to_owned(), text })?;
            Ok(sent(message_id.to_owned()))
        }

        async fn delete_message(
            &self,
            _thread_id: &str,
            _message_id: &str,
        ) -> Result<(), ChatError> {
            Ok(())
        }

        fn render_card(&self, _card: &CardElement) -> String {
            String::new()
        }

        fn render_message(&self, _msg: &AdapterPostableMessage) -> String {
            String::new()
        }

        async fn recv_event(&mut self) -> Option<ChatEvent> {
            None
        }
    }

    // -----------------------------------------------------------------
    // Test cases
    // -----------------------------------------------------------------

    #[test]
    fn default_stream_options() {
        let StreamOptions { update_interval_ms, placeholder_text } = StreamOptions::default();
        assert_eq!(update_interval_ms, 500);
        assert_eq!(placeholder_text.as_deref(), Some("..."));
    }

    #[tokio::test]
    async fn fallback_stream_posts_then_edits() {
        tokio::time::pause();

        let adapter = MockStreamAdapter::new();
        let stream: TextStream =
            Box::pin(tokio_stream::iter(["Hello", " ", "world"].map(String::from)));
        let options =
            StreamOptions { placeholder_text: Some("...".into()), update_interval_ms: 200 };

        let outcome = fallback_stream(&adapter, "t1", stream, &options).await;
        assert!(outcome.is_ok());

        let calls = adapter.take_calls();

        // The placeholder post comes first.
        assert!(matches!(calls.first(), Some(Call::Post(t)) if t == "..."));

        // The last recorded call is an edit carrying the whole text.
        assert!(matches!(calls.last(), Some(Call::Edit { text, .. }) if text == "Hello world"));
    }

    #[tokio::test]
    async fn fallback_stream_intermediate_edits_with_time_advance() {
        tokio::time::pause();

        let adapter = MockStreamAdapter::new();

        // Chunks are separated by sleeps longer than the update interval,
        // so the throttling logic lets intermediate edits through.
        let gap = tokio::time::Duration::from_millis(300);
        let stream: TextStream = Box::pin(async_stream::stream! {
            yield "A".to_owned();
            tokio::time::sleep(gap).await;
            yield "B".to_owned();
            tokio::time::sleep(gap).await;
            yield "C".to_owned();
        });

        let options =
            StreamOptions { placeholder_text: Some("...".into()), update_interval_ms: 200 };

        let outcome = fallback_stream(&adapter, "t1", stream, &options).await;
        assert!(outcome.is_ok());

        let calls = adapter.take_calls();

        // Expected sequence: Post("..."), Edit("AB") once 300ms ≥ 200ms has
        // elapsed, Edit("ABC") after a further 300ms, then the final
        // Edit("ABC").  The assertions only require 1 post and ≥2 edits.
        let (posts, edits): (Vec<&Call>, Vec<&Call>) =
            calls.iter().partition(|c| matches!(c, Call::Post(_)));
        let post_count = posts.len();
        let edit_count = edits.len();

        assert_eq!(post_count, 1, "exactly one post_message call");
        assert!(
            edit_count >= 2,
            "at least two edit calls (intermediate + final), got {edit_count}"
        );

        // The closing edit carries everything that was streamed.
        assert!(matches!(calls.last(), Some(Call::Edit { text, .. }) if text == "ABC"));
    }

    #[tokio::test]
    async fn fallback_stream_empty_stream_still_edits() {
        tokio::time::pause();

        let adapter = MockStreamAdapter::new();
        let no_chunks: Vec<String> = Vec::new();
        let stream: TextStream = Box::pin(tokio_stream::iter(no_chunks));
        let options = StreamOptions::default();

        let outcome = fallback_stream(&adapter, "t1", stream, &options).await;
        assert!(outcome.is_ok());

        let calls = adapter.take_calls();

        // Exactly one post followed by one final edit whose text is empty.
        assert_eq!(calls.len(), 2);
        assert!(matches!(calls.first(), Some(Call::Post(_))));
        assert!(matches!(calls.last(), Some(Call::Edit { text, .. }) if text.is_empty()));
    }
}
296}