use std::pin::Pin;
use futures_core::Stream;
use tokio_stream::StreamExt as _;
use crate::{
adapter::ChatAdapter,
error::ChatError,
message::{AdapterPostableMessage, SentMessage},
};
pub type TextStream = Pin<Box<dyn Stream<Item = String> + Send>>;
#[derive(Debug, Clone)]
pub struct StreamOptions {
pub update_interval_ms: u64,
pub placeholder_text: Option<String>,
}
impl Default for StreamOptions {
fn default() -> Self {
Self { update_interval_ms: 500, placeholder_text: Some("...".into()) }
}
}
pub async fn fallback_stream<A: ChatAdapter + ?Sized>(
adapter: &A,
thread_id: &str,
text_stream: TextStream,
options: &StreamOptions,
) -> Result<SentMessage, ChatError> {
let placeholder = options.placeholder_text.clone().unwrap_or_else(|| String::from("\u{200B}"));
let initial =
adapter.post_message(thread_id, &AdapterPostableMessage::Text(placeholder)).await?;
let message_id = initial.id.clone();
let mut accumulated = String::new();
let interval = tokio::time::Duration::from_millis(options.update_interval_ms);
let mut last_edit = tokio::time::Instant::now();
let mut stream = text_stream;
while let Some(chunk) = stream.next().await {
accumulated.push_str(&chunk);
if last_edit.elapsed() >= interval {
let _interim = adapter
.edit_message(
thread_id,
&message_id,
&AdapterPostableMessage::Text(accumulated.clone()),
)
.await?;
last_edit = tokio::time::Instant::now();
}
}
let final_sent = adapter
.edit_message(thread_id, &message_id, &AdapterPostableMessage::Text(accumulated))
.await?;
Ok(final_sent)
}
#[cfg(test)]
mod tests {
use std::sync::{Arc, Mutex};
use super::*;
use crate::{card::CardElement, event::ChatEvent};
#[derive(Debug, Clone)]
#[expect(dead_code, reason = "fields read in test assertions via pattern matching")]
enum Call {
Post(String),
Edit { message_id: String, text: String },
}
struct MockStreamAdapter {
calls: Arc<Mutex<Vec<Call>>>,
next_id: Arc<Mutex<u64>>,
}
impl MockStreamAdapter {
fn new() -> Self {
Self { calls: Arc::new(Mutex::new(Vec::new())), next_id: Arc::new(Mutex::new(0)) }
}
fn take_calls(&self) -> Vec<Call> {
let Ok(mut guard) = self.calls.lock() else {
return Vec::new();
};
std::mem::take(&mut *guard)
}
}
#[async_trait::async_trait]
impl ChatAdapter for MockStreamAdapter {
fn name(&self) -> &'static str {
"mock-stream"
}
async fn post_message(
&self,
_thread_id: &str,
message: &AdapterPostableMessage,
) -> Result<SentMessage, ChatError> {
let text = match message {
AdapterPostableMessage::Text(t) | AdapterPostableMessage::Markdown(t) => t.clone(),
};
let id = {
let Ok(mut id) = self.next_id.lock() else {
return Err(ChatError::Adapter("lock poisoned".into()));
};
*id += 1;
format!("msg-{id}")
};
{
let Ok(mut calls) = self.calls.lock() else {
return Err(ChatError::Adapter("lock poisoned".into()));
};
calls.push(Call::Post(text));
}
Ok(SentMessage {
id,
thread_id: "t1".into(),
adapter_name: "mock-stream".into(),
raw: None,
})
}
async fn edit_message(
&self,
_thread_id: &str,
message_id: &str,
message: &AdapterPostableMessage,
) -> Result<SentMessage, ChatError> {
let text = match message {
AdapterPostableMessage::Text(t) | AdapterPostableMessage::Markdown(t) => t.clone(),
};
{
let Ok(mut calls) = self.calls.lock() else {
return Err(ChatError::Adapter("lock poisoned".into()));
};
calls.push(Call::Edit { message_id: message_id.to_owned(), text });
}
Ok(SentMessage {
id: message_id.to_owned(),
thread_id: "t1".into(),
adapter_name: "mock-stream".into(),
raw: None,
})
}
async fn delete_message(
&self,
_thread_id: &str,
_message_id: &str,
) -> Result<(), ChatError> {
Ok(())
}
fn render_card(&self, _card: &CardElement) -> String {
String::new()
}
fn render_message(&self, _msg: &AdapterPostableMessage) -> String {
String::new()
}
async fn recv_event(&mut self) -> Option<ChatEvent> {
None
}
}
#[test]
fn default_stream_options() {
let opts = StreamOptions::default();
assert_eq!(opts.update_interval_ms, 500);
assert_eq!(opts.placeholder_text.as_deref(), Some("..."));
}
#[tokio::test]
async fn fallback_stream_posts_then_edits() {
tokio::time::pause();
let adapter = MockStreamAdapter::new();
let chunks = vec!["Hello".to_owned(), " ".into(), "world".into()];
let stream: TextStream = Box::pin(tokio_stream::iter(chunks));
let options =
StreamOptions { update_interval_ms: 200, placeholder_text: Some("...".into()) };
let result = fallback_stream(&adapter, "t1", stream, &options).await;
assert!(result.is_ok());
let calls = adapter.take_calls();
assert!(matches!(&calls[0], Call::Post(t) if t == "..."));
let last = calls.last();
assert!(matches!(last, Some(Call::Edit { text, .. }) if text == "Hello world"));
}
#[tokio::test]
async fn fallback_stream_intermediate_edits_with_time_advance() {
tokio::time::pause();
let adapter = MockStreamAdapter::new();
let stream: TextStream = Box::pin(async_stream::stream! {
yield "A".to_owned();
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
yield "B".to_owned();
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
yield "C".to_owned();
});
let options =
StreamOptions { update_interval_ms: 200, placeholder_text: Some("...".into()) };
let result = fallback_stream(&adapter, "t1", stream, &options).await;
assert!(result.is_ok());
let calls = adapter.take_calls();
let post_count = calls.iter().filter(|c| matches!(c, Call::Post(_))).count();
let edit_count = calls.iter().filter(|c| matches!(c, Call::Edit { .. })).count();
assert_eq!(post_count, 1, "exactly one post_message call");
assert!(
edit_count >= 2,
"at least two edit calls (intermediate + final), got {edit_count}"
);
let last = calls.last();
assert!(matches!(last, Some(Call::Edit { text, .. }) if text == "ABC"));
}
#[tokio::test]
async fn fallback_stream_empty_stream_still_edits() {
tokio::time::pause();
let adapter = MockStreamAdapter::new();
let stream: TextStream = Box::pin(tokio_stream::iter(Vec::<String>::new()));
let options = StreamOptions::default();
let result = fallback_stream(&adapter, "t1", stream, &options).await;
assert!(result.is_ok());
let calls = adapter.take_calls();
assert_eq!(calls.len(), 2);
assert!(matches!(&calls[0], Call::Post(_)));
assert!(matches!(&calls[1], Call::Edit { text, .. } if text.is_empty()));
}
}