use crate::error::{AixError, AixResult};
use crate::types::StreamChunk;
use futures_core::Stream;
use pin_project_lite::pin_project;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
/// Boxed, pinned, `Send` stream of `AixResult<StreamChunk>` items.
///
/// This is the return type of the stream constructor helpers in this module
/// (`from_iter`, `error_stream`, `single_chunk`, `chunks`, `from_string`, ...).
pub type TokenStream = Pin<Box<dyn Stream<Item = AixResult<StreamChunk>> + Send>>;
/// Extension methods for streams of chunk results.
///
/// Implemented for every [`Stream`]. The returned adapters only implement
/// `Future`/`Stream` when the underlying items are `AixResult<StreamChunk>`.
pub trait StreamExt: Stream {
    /// Returns a future resolving to the concatenated `delta` of every chunk,
    /// or to the first error the stream yields.
    fn collect_text(self) -> CollectText<Self>
    where
        Self: Sized,
    {
        CollectText::new(self)
    }

    /// Returns a stream that drops successful chunks whose `delta` is empty and
    /// that carry no `finish_reason`.
    fn filter_empty(self) -> FilterEmpty<Self>
    where
        Self: Sized,
    {
        FilterEmpty::new(self)
    }

    /// Returns a stream that coalesces chunks into combined chunks, using
    /// `duration` as the batching interval.
    fn buffer_chunks(self, duration: Duration) -> BufferChunks<Self>
    where
        Self: Sized,
    {
        BufferChunks::new(self, duration)
    }
}

impl<T> StreamExt for T where T: Stream + ?Sized {}
pin_project! {
    /// Future returned by [`StreamExt::collect_text`].
    ///
    /// Concatenates the `delta` of every chunk produced by the inner stream.
    #[must_use = "futures do nothing unless you `.await` or poll them"]
    pub struct CollectText<S> {
        #[pin]
        stream: S,
        // Text accumulated from the chunks seen so far.
        buffer: String,
    }
}

impl<S> CollectText<S> {
    /// Wraps `stream` with an empty text accumulator.
    fn new(stream: S) -> Self {
        Self {
            stream,
            buffer: String::new(),
        }
    }
}
impl<S> std::future::Future for CollectText<S>
where
    S: Stream<Item = AixResult<StreamChunk>>,
{
    type Output = AixResult<String>;

    /// Drains the inner stream, appending each chunk's `delta` to the accumulated text.
    ///
    /// Resolves to the full text once the stream ends, or to the first error the stream
    /// yields. Returns `Poll::Pending` whenever the inner stream does.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.project();
        while let Some(item) = futures_core::ready!(this.stream.as_mut().poll_next(cx)) {
            match item {
                Ok(chunk) => this.buffer.push_str(&chunk.delta),
                Err(error) => return Poll::Ready(Err(error)),
            }
        }
        // The future is complete, so hand over the accumulated text instead of cloning it.
        Poll::Ready(Ok(std::mem::take(this.buffer)))
    }
}
pin_project! {
    /// Stream returned by [`StreamExt::filter_empty`].
    ///
    /// Skips successful chunks whose `delta` is empty and that carry no `finish_reason`.
    #[must_use = "streams do nothing unless polled"]
    pub struct FilterEmpty<S> {
        #[pin]
        stream: S,
    }
}

impl<S> FilterEmpty<S> {
    /// Wraps `stream`; no items are consumed until the adapter is polled.
    fn new(stream: S) -> Self {
        Self { stream }
    }
}
impl<S> Stream for FilterEmpty<S>
where
    S: Stream<Item = AixResult<StreamChunk>>,
{
    type Item = AixResult<StreamChunk>;

    /// Yields the next item of the inner stream, skipping successful chunks that have
    /// an empty `delta` and no `finish_reason`. Errors and end-of-stream pass through
    /// unchanged.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut projected = self.project();
        loop {
            let next = futures_core::ready!(projected.stream.as_mut().poll_next(cx));
            let is_noise = matches!(
                &next,
                Some(Ok(chunk)) if chunk.delta.is_empty() && chunk.finish_reason.is_none()
            );
            if !is_noise {
                return Poll::Ready(next);
            }
        }
    }
}
pin_project! {
    /// Stream returned by [`StreamExt::buffer_chunks`].
    ///
    /// Coalesces consecutive successful chunks into a single chunk. A batch is flushed
    /// once `duration` has elapsed since its first chunk was buffered, or immediately
    /// when the inner stream ends or yields an error. The error itself is emitted after
    /// the buffered chunks, preserving order.
    #[must_use = "streams do nothing unless polled"]
    pub struct BufferChunks<S> {
        #[pin]
        stream: S,
        // Chunks received since the last flush.
        buffer: Vec<StreamChunk>,
        // How long a batch may be held before it is flushed.
        duration: Duration,
        // Flush timer, armed when the first chunk of a batch is buffered.
        #[pin]
        delay: Option<tokio::time::Sleep>,
        // Set once the inner stream returned `None`, so it is never polled again.
        done: bool,
        // Error from the inner stream, held back until the buffered chunks are emitted.
        pending_error: Option<AixError>,
    }
}

impl<S> BufferChunks<S> {
    /// Wraps `stream`, batching its chunks over `duration`.
    fn new(stream: S, duration: Duration) -> Self {
        Self {
            stream,
            buffer: Vec::new(),
            duration,
            delay: None,
            done: false,
            pending_error: None,
        }
    }

    /// Empties `buffer` into one chunk. The result has the first chunk's `id`, the
    /// concatenation of every `delta`, and the first non-`None` `finish_reason`.
    fn merge_buffered(buffer: &mut Vec<StreamChunk>) -> StreamChunk {
        let id = buffer
            .first()
            .map(|c| c.id.clone())
            .unwrap_or_else(|| "buffered".to_string());
        let delta: String = buffer.iter().map(|c| c.delta.as_str()).collect();
        let finish_reason = buffer.iter().find_map(|c| c.finish_reason.clone());
        buffer.clear();
        StreamChunk {
            id,
            delta,
            finish_reason,
        }
    }
}

impl<S> Stream for BufferChunks<S>
where
    S: Stream<Item = AixResult<StreamChunk>>,
{
    type Item = AixResult<StreamChunk>;

    /// Drains every item the inner stream has ready and flushes the batch when due.
    ///
    /// Every `Poll::Pending` returned here is backed by a registered waker. Either the
    /// inner stream returned `Pending`, or the flush timer did while chunks are buffered.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();
        loop {
            if !this.buffer.is_empty() {
                // Polling the timer on every pass also registers the waker for the deadline.
                let deadline_reached = this
                    .delay
                    .as_mut()
                    .as_pin_mut()
                    .map_or(true, |delay| delay.poll(cx).is_ready());
                if deadline_reached || *this.done || this.pending_error.is_some() {
                    this.delay.set(None);
                    return Poll::Ready(Some(Ok(Self::merge_buffered(this.buffer))));
                }
            }
            if let Some(error) = this.pending_error.take() {
                return Poll::Ready(Some(Err(error)));
            }
            if *this.done {
                return Poll::Ready(None);
            }
            match futures_core::ready!(this.stream.as_mut().poll_next(cx)) {
                Some(Ok(chunk)) => {
                    if this.buffer.is_empty() {
                        // The first chunk of a new batch starts its flush deadline.
                        this.delay.set(Some(tokio::time::sleep(*this.duration)));
                    }
                    this.buffer.push(chunk);
                }
                Some(Err(error)) => *this.pending_error = Some(error),
                None => *this.done = true,
            }
        }
    }
}
/// Turns an iterator of chunk results into a boxed [`TokenStream`].
///
/// Items are yielded in iteration order; the stream ends when the iterator is exhausted.
pub fn from_iter<I>(iter: I) -> TokenStream
where
    I: IntoIterator<Item = AixResult<StreamChunk>>,
    I::IntoIter: Send + 'static,
{
    Box::pin(futures_util::stream::iter(iter))
}
pub fn error_stream(error: AixError) -> TokenStream {
let stream = futures_util::stream::once(async move { Err(error) });
Box::pin(stream)
}
pub fn single_chunk(chunk: StreamChunk) -> TokenStream {
let stream = futures_util::stream::once(async move { Ok(chunk) });
Box::pin(stream)
}
/// Creates a [`TokenStream`] that yields each chunk of `chunks` as a successful item.
pub fn chunks<I>(chunks: I) -> TokenStream
where
    I: IntoIterator<Item = StreamChunk>,
    I::IntoIter: Send + 'static,
{
    from_iter(chunks.into_iter().map(Ok))
}
pub fn from_string<S>(id: S, text: S) -> TokenStream
where
S: Into<String> + Clone,
{
let id = id.into();
let text = text.into();
let chars: Vec<char> = text.chars().collect();
let stream = futures_util::stream::iter(chars.into_iter().map(move |c| {
let id = id.clone();
Ok(StreamChunk::new(id, c.to_string()))
}));
Box::pin(stream)
}
pub fn from_string_words<S>(id: S, text: S) -> TokenStream
where
S: Into<String> + Clone,
{
let id = id.into();
let text = text.into();
let words: Vec<String> = text.split_whitespace().map(|s| s.to_string()).collect();
let stream = futures_util::stream::iter(words.into_iter().map(move |word| {
let id = id.clone();
Ok(StreamChunk::new(id, format!("{} ", word)))
}));
Box::pin(stream)
}
/// Incremental parser that extracts the `data:` payloads of Server-Sent Events.
///
/// Feed raw bytes in arbitrarily sized pieces with [`SseParser::parse_chunk`]. Only events
/// terminated by a blank line are returned. A partially received event stays buffered
/// until the rest of it arrives, so a payload is never emitted truncated or twice.
#[derive(Debug, Default)]
pub struct SseParser {
    /// Decoded text that does not yet form a complete event.
    buffer: String,
    /// Leading bytes of a UTF-8 character that was split across chunk boundaries.
    pending_bytes: Vec<u8>,
}

impl SseParser {
    /// Creates a parser with empty buffers.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `chunk` to the internal buffer and returns the data payloads of every
    /// event it completes, in arrival order.
    ///
    /// The rules applied to each event:
    /// - Multiple `data:` lines in one event are joined with `\n`, as the SSE
    ///   specification requires.
    /// - Each line's value is trimmed of surrounding whitespace.
    /// - Events without data (e.g. `: keep-alive` comments) are skipped.
    /// - The `[DONE]` sentinel is returned verbatim like any other payload.
    /// - Lines may end in `\n` or `\r\n`.
    ///
    /// # Errors
    ///
    /// Returns a serialization error if the input contains invalid UTF-8; the offending
    /// bytes are discarded. A multi-byte character cut off at the end of `chunk` is not
    /// an error: its bytes are held until the next call completes it.
    pub fn parse_chunk(&mut self, chunk: &[u8]) -> AixResult<Vec<String>> {
        self.pending_bytes.extend_from_slice(chunk);
        let valid_len = match std::str::from_utf8(&self.pending_bytes) {
            Ok(text) => text.len(),
            // `error_len() == None` means the input merely ends in the middle of a character.
            Err(e) if e.error_len().is_none() => e.valid_up_to(),
            Err(e) => {
                self.pending_bytes.clear();
                return Err(AixError::serialization(e.to_string(), "SSE chunk parsing"));
            }
        };
        let decoded = std::str::from_utf8(&self.pending_bytes[..valid_len])
            .expect("prefix was validated as UTF-8 above");
        self.buffer.push_str(decoded);
        self.pending_bytes.drain(..valid_len);
        self.extract_events()
    }

    /// Removes every complete (blank-line terminated) event from the front of the buffer
    /// and returns their payloads. Text after the last blank line is left for later calls.
    fn extract_events(&mut self) -> AixResult<Vec<String>> {
        let (events, consumed) = {
            let mut events = Vec::new();
            let mut data_lines: Vec<&str> = Vec::new();
            // Byte offset just past the blank line that ended the last complete event.
            let mut consumed = 0;
            let mut cursor = 0;
            while let Some(offset) = self.buffer[cursor..].find('\n') {
                let raw_line = &self.buffer[cursor..cursor + offset];
                cursor += offset + 1;
                // Accept both `\n` and `\r\n` line terminators.
                let line = raw_line.strip_suffix('\r').unwrap_or(raw_line);
                if line.is_empty() {
                    // A blank line dispatches the event accumulated so far.
                    let data = data_lines.join("\n");
                    data_lines.clear();
                    if !data.is_empty() {
                        events.push(data);
                    }
                    consumed = cursor;
                } else if let Some(value) = line.strip_prefix("data:") {
                    data_lines.push(value.trim());
                }
            }
            (events, consumed)
        };
        self.buffer.drain(..consumed);
        Ok(events)
    }

    /// Returns the buffered text that has not yet formed a complete event.
    ///
    /// Bytes of a UTF-8 character still waiting for its remaining bytes are not included.
    pub fn remaining_data(&self) -> &str {
        &self.buffer
    }

    /// Discards all buffered text and any partially received UTF-8 bytes.
    pub fn clear(&mut self) {
        self.buffer.clear();
        self.pending_bytes.clear();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::StreamExt as FuturesStreamExt;

    /// Unwraps every collected item and returns the chunk deltas in order.
    fn deltas(items: &[AixResult<StreamChunk>]) -> Vec<&str> {
        items
            .iter()
            .map(|item| item.as_ref().unwrap().delta.as_str())
            .collect()
    }

    #[tokio::test]
    async fn test_collect_text() {
        let stream = from_iter(vec![
            Ok(StreamChunk::new("1", "Hello")),
            Ok(StreamChunk::new("2", ", ")),
            Ok(StreamChunk::new("3", "world")),
            Ok(StreamChunk::new("4", "!")),
        ]);
        let text = stream.collect_text().await.unwrap();
        assert_eq!(text, "Hello, world!");
    }

    #[tokio::test]
    async fn test_filter_empty() {
        let source = from_iter(vec![
            Ok(StreamChunk::new("1", "Hello")),
            Ok(StreamChunk::new("2", "")),
            Ok(StreamChunk::new("3", "world")),
            Ok(StreamChunk::new("4", "")),
        ]);
        let collected: Vec<_> = source.filter_empty().collect().await;
        assert_eq!(deltas(&collected), ["Hello", "world"]);
    }

    #[tokio::test]
    async fn test_from_string() {
        let collected: Vec<_> = from_string("test", "Hello world").collect().await;
        // One chunk per character of "Hello world".
        assert_eq!(collected.len(), 11);
        assert_eq!(deltas(&collected[..2]), ["H", "e"]);
    }

    #[tokio::test]
    async fn test_from_string_words() {
        let collected: Vec<_> = from_string_words("test", "Hello world from Rust")
            .collect()
            .await;
        assert_eq!(deltas(&collected), ["Hello ", "world ", "from ", "Rust"]);
    }

    #[test]
    fn test_sse_parser() {
        let mut parser = SseParser::new();
        let first = parser
            .parse_chunk(b"data: {\"content\": \"Hello\"}\n\n")
            .unwrap();
        assert_eq!(first, ["{\"content\": \"Hello\"}"]);
        let second = parser.parse_chunk(b"data: [DONE]\n\n").unwrap();
        assert_eq!(second, ["[DONE]"]);
    }

    #[test]
    fn test_sse_parser_incomplete_event() {
        let mut parser = SseParser::new();
        assert!(parser.parse_chunk(b"data: {\"content\":").unwrap().is_empty());
        let events = parser.parse_chunk(b" \"Hello\"}\n\n").unwrap();
        assert_eq!(events, ["{\"content\": \"Hello\"}"]);
    }

    #[test]
    fn test_sse_parser_multiple_events() {
        let mut parser = SseParser::new();
        let payload: &[u8] =
            b"data: {\"content\": \"Hello\"}\n\ndata: {\"content\": \"world\"}\n\ndata: [DONE]\n\n";
        let events = parser.parse_chunk(payload).unwrap();
        assert_eq!(
            events,
            ["{\"content\": \"Hello\"}", "{\"content\": \"world\"}", "[DONE]"]
        );
    }

    #[tokio::test]
    async fn test_error_stream() {
        let collected: Vec<_> = error_stream(AixError::other("test error")).collect().await;
        assert_eq!(collected.len(), 1);
        let error = collected[0]
            .as_ref()
            .expect_err("error_stream must yield an error");
        assert_eq!(error.to_string(), "Error: test error");
    }

    #[tokio::test]
    async fn test_single_chunk() {
        let collected: Vec<_> = single_chunk(StreamChunk::new("test", "Hello"))
            .collect()
            .await;
        assert_eq!(deltas(&collected), ["Hello"]);
    }
}