1use std::pin::Pin;
4
5use futures_core::Stream;
6use tokio_stream::StreamExt as _;
7
8use crate::{
9 adapter::ChatAdapter,
10 error::ChatError,
11 message::{AdapterPostableMessage, SentMessage},
12};
13
14pub type TextStream = Pin<Box<dyn Stream<Item = String> + Send>>;
16
/// Tuning knobs for [`fallback_stream`].
///
/// Compared by value (`PartialEq`/`Eq`) so callers and tests can check a
/// configuration directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StreamOptions {
    /// Minimum time, in milliseconds, between two throttled intermediate
    /// edits of the streamed message. The final edit made after the stream
    /// ends is not throttled.
    pub update_interval_ms: u64,
    /// Text posted as the initial message before any chunk is shown.
    /// `None` makes `fallback_stream` post a zero-width space (`U+200B`).
    pub placeholder_text: Option<String>,
}
28
29impl Default for StreamOptions {
30 fn default() -> Self {
31 Self { update_interval_ms: 500, placeholder_text: Some("...".into()) }
32 }
33}
34
35pub async fn fallback_stream<A: ChatAdapter + ?Sized>(
46 adapter: &A,
47 thread_id: &str,
48 text_stream: TextStream,
49 options: &StreamOptions,
50) -> Result<SentMessage, ChatError> {
51 let placeholder = options.placeholder_text.clone().unwrap_or_else(|| String::from("\u{200B}")); let initial =
54 adapter.post_message(thread_id, &AdapterPostableMessage::Text(placeholder)).await?;
55
56 let message_id = initial.id.clone();
57 let mut accumulated = String::new();
58 let interval = tokio::time::Duration::from_millis(options.update_interval_ms);
59 let mut last_edit = tokio::time::Instant::now();
60
61 let mut stream = text_stream;
62
63 while let Some(chunk) = stream.next().await {
64 accumulated.push_str(&chunk);
65
66 if last_edit.elapsed() >= interval {
67 let _interim = adapter
68 .edit_message(
69 thread_id,
70 &message_id,
71 &AdapterPostableMessage::Text(accumulated.clone()),
72 )
73 .await?;
74 last_edit = tokio::time::Instant::now();
75 }
76 }
77
78 let final_sent = adapter
80 .edit_message(thread_id, &message_id, &AdapterPostableMessage::Text(accumulated))
81 .await?;
82
83 Ok(final_sent)
84}
85
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use super::*;
    use crate::{card::CardElement, event::ChatEvent};

    /// A single adapter call recorded by [`MockStreamAdapter`].
    #[derive(Debug, Clone)]
    enum Call {
        /// `post_message`, carrying the posted text.
        Post(String),
        /// `edit_message`, carrying the targeted message id and the new text.
        Edit { message_id: String, text: String },
    }

    /// Id the mock assigns to the first message posted through it.
    const FIRST_ID: &str = "msg-1";

    /// Asserts that every recorded edit targets the message `expected_id`.
    fn assert_edits_target(calls: &[Call], expected_id: &str) {
        for call in calls {
            if let Call::Edit { message_id, .. } = call {
                assert_eq!(message_id, expected_id, "edit targeted an unexpected message");
            }
        }
    }

    /// Extracts the text payload of a postable message.
    fn text_of(message: &AdapterPostableMessage) -> String {
        match message {
            AdapterPostableMessage::Text(t) | AdapterPostableMessage::Markdown(t) => t.clone(),
        }
    }

    /// In-memory adapter that records every post/edit and hands out
    /// sequential message ids (`msg-1`, `msg-2`, ...).
    struct MockStreamAdapter {
        calls: Arc<Mutex<Vec<Call>>>,
        next_id: Arc<Mutex<u64>>,
    }

    impl MockStreamAdapter {
        fn new() -> Self {
            Self { calls: Arc::new(Mutex::new(Vec::new())), next_id: Arc::new(Mutex::new(0)) }
        }

        /// Appends `call` to the log; the guard is released before returning.
        fn record(&self, call: Call) -> Result<(), ChatError> {
            let Ok(mut calls) = self.calls.lock() else {
                return Err(ChatError::Adapter("lock poisoned".into()));
            };
            calls.push(call);
            Ok(())
        }

        /// Drains and returns the recorded calls (empty if the lock is poisoned).
        fn take_calls(&self) -> Vec<Call> {
            let Ok(mut guard) = self.calls.lock() else {
                return Vec::new();
            };
            std::mem::take(&mut *guard)
        }
    }

    #[async_trait::async_trait]
    impl ChatAdapter for MockStreamAdapter {
        fn name(&self) -> &'static str {
            "mock-stream"
        }

        async fn post_message(
            &self,
            thread_id: &str,
            message: &AdapterPostableMessage,
        ) -> Result<SentMessage, ChatError> {
            let id = {
                let Ok(mut next) = self.next_id.lock() else {
                    return Err(ChatError::Adapter("lock poisoned".into()));
                };
                *next += 1;
                format!("msg-{next}")
            };
            self.record(Call::Post(text_of(message)))?;
            Ok(SentMessage {
                id,
                thread_id: thread_id.into(),
                adapter_name: "mock-stream".into(),
                raw: None,
            })
        }

        async fn edit_message(
            &self,
            thread_id: &str,
            message_id: &str,
            message: &AdapterPostableMessage,
        ) -> Result<SentMessage, ChatError> {
            self.record(Call::Edit {
                message_id: message_id.to_owned(),
                text: text_of(message),
            })?;
            Ok(SentMessage {
                id: message_id.to_owned(),
                thread_id: thread_id.into(),
                adapter_name: "mock-stream".into(),
                raw: None,
            })
        }

        async fn delete_message(
            &self,
            _thread_id: &str,
            _message_id: &str,
        ) -> Result<(), ChatError> {
            Ok(())
        }

        fn render_card(&self, _card: &CardElement) -> String {
            String::new()
        }

        fn render_message(&self, _msg: &AdapterPostableMessage) -> String {
            String::new()
        }

        async fn recv_event(&mut self) -> Option<ChatEvent> {
            None
        }
    }

    #[test]
    fn default_stream_options() {
        let opts = StreamOptions::default();
        assert_eq!(opts.update_interval_ms, 500);
        assert_eq!(opts.placeholder_text.as_deref(), Some("..."));
    }

    #[tokio::test]
    async fn fallback_stream_posts_then_edits() {
        tokio::time::pause();

        let adapter = MockStreamAdapter::new();
        let chunks = vec!["Hello".to_owned(), " ".into(), "world".into()];
        let stream: TextStream = Box::pin(tokio_stream::iter(chunks));

        let options =
            StreamOptions { update_interval_ms: 200, placeholder_text: Some("...".into()) };

        let result = fallback_stream(&adapter, "t1", stream, &options).await;
        assert!(matches!(&result, Ok(sent) if sent.id == FIRST_ID));

        let calls = adapter.take_calls();

        assert!(matches!(calls.first(), Some(Call::Post(t)) if t == "..."), "got {calls:?}");
        assert!(
            matches!(calls.last(), Some(Call::Edit { text, .. }) if text == "Hello world"),
            "got {calls:?}"
        );
        assert_edits_target(&calls, FIRST_ID);
    }

    #[tokio::test]
    async fn fallback_stream_intermediate_edits_with_time_advance() {
        tokio::time::pause();

        let adapter = MockStreamAdapter::new();

        // Each gap exceeds the 200 ms interval, so chunks after the first
        // should trigger intermediate edits.
        let stream: TextStream = Box::pin(async_stream::stream! {
            yield "A".to_owned();
            tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
            yield "B".to_owned();
            tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
            yield "C".to_owned();
        });

        let options =
            StreamOptions { update_interval_ms: 200, placeholder_text: Some("...".into()) };

        let result = fallback_stream(&adapter, "t1", stream, &options).await;
        assert!(matches!(&result, Ok(sent) if sent.id == FIRST_ID));

        let calls = adapter.take_calls();

        let post_count = calls.iter().filter(|c| matches!(c, Call::Post(_))).count();
        let edit_count = calls.iter().filter(|c| matches!(c, Call::Edit { .. })).count();

        assert_eq!(post_count, 1, "exactly one post_message call");
        assert!(
            edit_count >= 2,
            "at least two edit calls (intermediate + final), got {edit_count}"
        );
        assert!(
            matches!(calls.last(), Some(Call::Edit { text, .. }) if text == "ABC"),
            "got {calls:?}"
        );
        assert_edits_target(&calls, FIRST_ID);
    }

    #[tokio::test]
    async fn fallback_stream_empty_stream_still_edits() {
        tokio::time::pause();

        let adapter = MockStreamAdapter::new();
        let stream: TextStream = Box::pin(tokio_stream::iter(Vec::<String>::new()));

        let options = StreamOptions::default();

        let result = fallback_stream(&adapter, "t1", stream, &options).await;
        assert!(matches!(&result, Ok(sent) if sent.id == FIRST_ID));

        let calls = adapter.take_calls();

        assert!(
            matches!(calls.as_slice(), [Call::Post(_), Call::Edit { text, .. }] if text.is_empty()),
            "expected a placeholder post followed by one empty edit, got {calls:?}"
        );
        assert_edits_target(&calls, FIRST_ID);
    }
}