Skip to main content

aix_core/
streaming.rs

1//! Streaming abstractions and utilities.
2//!
3//! This module provides types and utilities for handling streaming responses
4//! from AI providers, including stream adapters and convenience functions.
5
6use crate::error::{AixError, AixResult};
7use crate::types::StreamChunk;
8use futures_core::Stream;
9use pin_project_lite::pin_project;
10use std::future::Future;
11use std::pin::Pin;
12use std::task::{Context, Poll};
13use std::time::Duration;
14
/// Type alias for a stream of stream chunks.
///
/// This is a pinned, heap-allocated, `Send` trait object whose items are
/// `AixResult<StreamChunk>`; the concrete stream type is erased, so the
/// constructor functions in this module can all return the same type.
pub type TokenStream = Pin<Box<dyn Stream<Item = AixResult<StreamChunk>> + Send>>;
17
18/// Extension trait for streams with convenience methods.
19pub trait StreamExt: Stream {
20    /// Collect all chunks into a single string.
21    ///
22    /// This method consumes the stream and collects all the delta content
23    /// from each chunk into a single string.
24    ///
25    /// # Returns
26    /// A future that resolves to the complete text or an error
27    fn collect_text(self) -> CollectText<Self>
28    where
29        Self: Sized,
30    {
31        CollectText::new(self)
32    }
33
34    /// Filter out chunks with empty delta content.
35    ///
36    /// # Returns
37    /// A stream that only yields chunks with non-empty delta content
38    fn filter_empty(self) -> FilterEmpty<Self>
39    where
40        Self: Sized,
41    {
42        FilterEmpty::new(self)
43    }
44
45    /// Buffer chunks for a given duration before yielding them.
46    ///
47    /// This can be useful to reduce the frequency of updates in UI applications.
48    ///
49    /// # Arguments
50    /// * `duration` - The buffer duration
51    ///
52    /// # Returns
53    /// A stream that yields buffered chunks
54    fn buffer_chunks(self, duration: Duration) -> BufferChunks<Self>
55    where
56        Self: Sized,
57    {
58        BufferChunks::new(self, duration)
59    }
60}
61
62// Blanket implementation for all streams
63impl<T: ?Sized> StreamExt for T where T: Stream {}
64
65/// Stream adapter that collects all text from chunks.
66pin_project! {
67    pub struct CollectText<S> {
68        #[pin]
69        stream: S,
70        buffer: String,
71    }
72}
73
74impl<S> CollectText<S> {
75    fn new(stream: S) -> Self {
76        Self {
77            stream,
78            buffer: String::new(),
79        }
80    }
81}
82
83impl<S> std::future::Future for CollectText<S>
84where
85    S: Stream<Item = AixResult<StreamChunk>>,
86{
87    type Output = AixResult<String>;
88
89    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
90        let mut this = self.project();
91        
92        loop {
93            match futures_core::ready!(this.stream.as_mut().poll_next(cx)) {
94                Some(Ok(chunk)) => {
95                    this.buffer.push_str(&chunk.delta);
96                }
97                Some(Err(error)) => {
98                    return Poll::Ready(Err(error));
99                }
100                None => {
101                    return Poll::Ready(Ok(this.buffer.clone()));
102                }
103            }
104        }
105    }
106}
107
108/// Stream adapter that filters out empty chunks.
109pin_project! {
110    pub struct FilterEmpty<S> {
111        #[pin]
112        stream: S,
113    }
114}
115
116impl<S> FilterEmpty<S> {
117    fn new(stream: S) -> Self {
118        Self { stream }
119    }
120}
121
122impl<S> Stream for FilterEmpty<S>
123where
124    S: Stream<Item = AixResult<StreamChunk>>,
125{
126    type Item = AixResult<StreamChunk>;
127
128    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
129        let mut this = self.project();
130        
131        loop {
132            match futures_core::ready!(this.stream.as_mut().poll_next(cx)) {
133                Some(Ok(chunk)) => {
134                    if chunk.delta.is_empty() && chunk.finish_reason.is_none() {
135                        // Skip empty chunks without finish reason
136                        continue;
137                    }
138                    return Poll::Ready(Some(Ok(chunk)));
139                }
140                other => return Poll::Ready(other),
141            }
142        }
143    }
144}
145
146/// Stream adapter that buffers chunks for a duration.
147pin_project! {
148    pub struct BufferChunks<S> {
149        #[pin]
150        stream: S,
151        buffer: Vec<StreamChunk>,
152        last_flush: Option<tokio::time::Instant>,
153        duration: Duration,
154        #[pin]
155        delay: Option<tokio::time::Sleep>,
156    }
157}
158
159impl<S> BufferChunks<S> {
160    fn new(stream: S, duration: Duration) -> Self {
161        Self {
162            stream,
163            buffer: Vec::new(),
164            last_flush: None,
165            duration,
166            delay: None,
167        }
168    }
169}
170
171impl<S> Stream for BufferChunks<S>
172where
173    S: Stream<Item = AixResult<StreamChunk>>,
174{
175    type Item = AixResult<StreamChunk>;
176
177    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
178        let mut this = self.project();
179        let now = tokio::time::Instant::now();
180
181        // Check if we have buffered chunks and it's time to flush
182        if !this.buffer.is_empty() {
183            let should_flush = if let Some(last_flush) = this.last_flush {
184                now.duration_since(*last_flush) >= *this.duration
185            } else {
186                true // Flush immediately on first chunk
187            };
188
189            if should_flush {
190                // Combine all buffered chunks into one
191                let combined_id = this.buffer
192                    .first()
193                    .map(|c| c.id.clone())
194                    .unwrap_or_else(|| "buffered".to_string());
195                
196                let combined_delta: String = this.buffer
197                    .iter()
198                    .map(|c| c.delta.as_str())
199                    .collect();
200                
201                let finish_reason = this.buffer
202                    .iter()
203                    .find_map(|c| c.finish_reason.clone());
204
205                let combined_chunk = StreamChunk {
206                    id: combined_id,
207                    delta: combined_delta,
208                    finish_reason,
209                };
210
211                this.buffer.clear();
212                *this.last_flush = Some(now);
213
214                return Poll::Ready(Some(Ok(combined_chunk)));
215            }
216        }
217
218        // Poll for new chunks
219        match futures_core::ready!(this.stream.as_mut().poll_next(cx)) {
220            Some(Ok(chunk)) => {
221                // Buffer the chunk
222                this.buffer.push(chunk);
223
224                // If this is the first chunk, set up the delay
225                if this.delay.is_none() {
226                    this.delay.set(Some(tokio::time::sleep(*this.duration)));
227                }
228
229                // Try to poll the delay
230                if let Some(delay) = this.delay.as_mut().as_pin_mut() {
231                    match delay.poll(cx) {
232                        std::task::Poll::Ready(_) => {
233                            this.delay.set(None);
234                            // The next poll will flush the buffer
235                        }
236                        std::task::Poll::Pending => {
237                            // Still waiting
238                        }
239                    }
240                }
241
242                Poll::Pending
243            }
244            Some(Err(error)) => {
245                Poll::Ready(Some(Err(error)))
246            }
247            None => {
248                // Stream ended, flush any remaining chunks
249                if !this.buffer.is_empty() {
250                    let combined_id = this.buffer
251                        .first()
252                        .map(|c| c.id.clone())
253                        .unwrap_or_else(|| "buffered".to_string());
254                    
255                    let combined_delta: String = this.buffer
256                        .iter()
257                        .map(|c| c.delta.as_str())
258                        .collect();
259                    
260                    let finish_reason = this.buffer
261                        .iter()
262                        .find_map(|c| c.finish_reason.clone());
263
264                    let combined_chunk = StreamChunk {
265                        id: combined_id,
266                        delta: combined_delta,
267                        finish_reason,
268                    };
269
270                    this.buffer.clear();
271
272                    Poll::Ready(Some(Ok(combined_chunk)))
273                } else {
274                    Poll::Ready(None)
275                }
276            }
277        }
278    }
279}
280
281/// Create a stream from an iterator of results.
282pub fn from_iter<I>(iter: I) -> TokenStream
283where
284    I: IntoIterator<Item = AixResult<StreamChunk>>,
285    I::IntoIter: Send + 'static,
286{
287    let stream = futures_util::stream::iter(iter);
288    Box::pin(stream)
289}
290
291/// Create a stream that immediately yields an error.
292pub fn error_stream(error: AixError) -> TokenStream {
293    let stream = futures_util::stream::once(async move { Err(error) });
294    Box::pin(stream)
295}
296
297/// Create a stream that yields a single chunk.
298pub fn single_chunk(chunk: StreamChunk) -> TokenStream {
299    let stream = futures_util::stream::once(async move { Ok(chunk) });
300    Box::pin(stream)
301}
302
303/// Create a stream that yields multiple chunks.
304pub fn chunks<I>(chunks: I) -> TokenStream
305where
306    I: IntoIterator<Item = StreamChunk>,
307    I::IntoIter: Send + 'static,
308{
309    let results = chunks.into_iter().map(Ok);
310    from_iter(results)
311}
312
313/// Create a stream from a string that yields character by character.
314pub fn from_string<S>(id: S, text: S) -> TokenStream
315where
316    S: Into<String> + Clone,
317{
318    let id = id.into();
319    let text = text.into();
320    let chars: Vec<char> = text.chars().collect();
321    let stream = futures_util::stream::iter(chars.into_iter().map(move |c| {
322        let id = id.clone();
323        Ok(StreamChunk::new(id, c.to_string()))
324    }));
325    Box::pin(stream)
326}
327
328/// Create a stream from a string that yields word by word.
329pub fn from_string_words<S>(id: S, text: S) -> TokenStream
330where
331    S: Into<String> + Clone,
332{
333    let id = id.into();
334    let text = text.into();
335    let words: Vec<String> = text.split_whitespace().map(|s| s.to_string()).collect();
336    let stream = futures_util::stream::iter(words.into_iter().map(move |word| {
337        let id = id.clone();
338        Ok(StreamChunk::new(id, format!("{} ", word)))
339    }));
340    Box::pin(stream)
341}
342
343/// Utility for parsing Server-Sent Events (SSE).
344pub struct SseParser {
345    buffer: String,
346}
347
348impl SseParser {
349    /// Create a new SSE parser.
350    pub fn new() -> Self {
351        Self {
352            buffer: String::new(),
353        }
354    }
355
356    /// Parse a chunk of SSE data.
357    ///
358    /// # Arguments
359    /// * `chunk` - A chunk of bytes from the SSE stream
360    ///
361    /// # Returns
362    /// A vector of parsed events, or an error if parsing fails
363    pub fn parse_chunk(&mut self, chunk: &[u8]) -> AixResult<Vec<String>> {
364        let chunk_str = std::str::from_utf8(chunk)
365            .map_err(|e| AixError::serialization(e.to_string(), "SSE chunk parsing"))?;
366
367        self.buffer.push_str(chunk_str);
368        self.extract_events()
369    }
370
371    /// Extract complete events from the buffer.
372    fn extract_events(&mut self) -> AixResult<Vec<String>> {
373        let mut events = Vec::new();
374        let mut lines = self.buffer.lines().peekable();
375
376        while let Some(line) = lines.next() {
377            if line.starts_with("data:") {
378                let mut event_data = line[5..].trim().to_string();
379                
380                // Look for additional data lines
381                while let Some(&next_line) = lines.peek() {
382                    if next_line.starts_with("data:") {
383                        event_data.push_str(&next_line[5..].trim());
384                        lines.next(); // Consume the line
385                    } else {
386                        break;
387                    }
388                }
389
390                // Check if this is the end of an event (empty line or [DONE])
391                if event_data == "[DONE]" {
392                    events.push("[DONE]".to_string());
393                } else if !event_data.is_empty() {
394                    events.push(event_data);
395                }
396            }
397        }
398
399        // Clear processed data from buffer
400        // Keep any incomplete data that might be waiting for more chunks
401        let last_complete_pos = self.buffer.rfind("\n\n").unwrap_or(0);
402        if last_complete_pos > 0 {
403            self.buffer.drain(0..=last_complete_pos + 1);
404        }
405
406        Ok(events)
407    }
408
409    /// Get any remaining data in the buffer.
410    pub fn remaining_data(&self) -> &str {
411        &self.buffer
412    }
413
414    /// Clear the buffer.
415    pub fn clear(&mut self) {
416        self.buffer.clear();
417    }
418}
419
420impl Default for SseParser {
421    fn default() -> Self {
422        Self::new()
423    }
424}
425
// Unit tests for the stream helpers and the SSE parser in this module.
#[cfg(test)]
mod tests {
    use super::*;
    // Aliased so that `collect()` comes from futures-util without clashing
    // with this module's own `StreamExt` trait.
    use futures_util::StreamExt as FuturesStreamExt;

    /// `collect_text` concatenates the deltas of all chunks in order.
    #[tokio::test]
    async fn test_collect_text() {
        let chunks = vec![
            Ok(StreamChunk::new("1", "Hello")),
            Ok(StreamChunk::new("2", ", ")),
            Ok(StreamChunk::new("3", "world")),
            Ok(StreamChunk::new("4", "!")),
        ];

        let stream = from_iter(chunks);
        let text = stream.collect_text().await.unwrap();
        assert_eq!(text, "Hello, world!");
    }

    /// `filter_empty` removes chunks whose delta is empty.
    #[tokio::test]
    async fn test_filter_empty() {
        let chunks = vec![
            Ok(StreamChunk::new("1", "Hello")),
            Ok(StreamChunk::new("2", "")), // Should be filtered out
            Ok(StreamChunk::new("3", "world")),
            Ok(StreamChunk::new("4", "")), // Should be filtered out
        ];

        let stream = from_iter(chunks).filter_empty();
        let collected: Vec<_> = stream.collect().await;

        assert_eq!(collected.len(), 2);
        assert_eq!(collected[0].as_ref().unwrap().delta, "Hello");
        assert_eq!(collected[1].as_ref().unwrap().delta, "world");
    }

    /// `from_string` yields one chunk per character.
    #[tokio::test]
    async fn test_from_string() {
        let stream = from_string("test", "Hello world");
        let collected: Vec<_> = stream.collect().await;

        assert_eq!(collected.len(), 11); // "Hello world" is 11 characters, including the space
        assert_eq!(collected[0].as_ref().unwrap().delta, "H");
        assert_eq!(collected[1].as_ref().unwrap().delta, "e");
    }

    /// `from_string_words` yields one chunk per word; the assertions expect a
    /// trailing space after every word except the last.
    #[tokio::test]
    async fn test_from_string_words() {
        let stream = from_string_words("test", "Hello world from Rust");
        let collected: Vec<_> = stream.collect().await;

        assert_eq!(collected.len(), 4);
        assert_eq!(collected[0].as_ref().unwrap().delta, "Hello ");
        assert_eq!(collected[1].as_ref().unwrap().delta, "world ");
        assert_eq!(collected[2].as_ref().unwrap().delta, "from ");
        assert_eq!(collected[3].as_ref().unwrap().delta, "Rust");
    }

    /// A terminated event is returned with its `data:` prefix removed, and
    /// the `[DONE]` sentinel is returned verbatim.
    #[test]
    fn test_sse_parser() {
        let mut parser = SseParser::new();

        // Test parsing a complete event
        let chunk = b"data: {\"content\": \"Hello\"}\n\n";
        let events = parser.parse_chunk(chunk).unwrap();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0], "{\"content\": \"Hello\"}");

        // Test parsing [DONE] event
        let chunk = b"data: [DONE]\n\n";
        let events = parser.parse_chunk(chunk).unwrap();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0], "[DONE]");
    }

    /// An event split across two chunks is reported only after the second
    /// chunk delivers its terminating blank line.
    #[test]
    fn test_sse_parser_incomplete_event() {
        let mut parser = SseParser::new();

        // Send incomplete event
        let chunk = b"data: {\"content\":";
        let events = parser.parse_chunk(chunk).unwrap();
        assert_eq!(events.len(), 0); // Should not yield events yet

        // Complete the event
        let chunk = b" \"Hello\"}\n\n";
        let events = parser.parse_chunk(chunk).unwrap();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0], "{\"content\": \"Hello\"}");
    }

    /// Several events delivered in one chunk are all returned, in order.
    #[test]
    fn test_sse_parser_multiple_events() {
        let mut parser = SseParser::new();

        let chunk = b"data: {\"content\": \"Hello\"}\n\ndata: {\"content\": \"world\"}\n\ndata: [DONE]\n\n";
        let events = parser.parse_chunk(chunk).unwrap();
        assert_eq!(events.len(), 3);
        assert_eq!(events[0], "{\"content\": \"Hello\"}");
        assert_eq!(events[1], "{\"content\": \"world\"}");
        assert_eq!(events[2], "[DONE]");
    }

    /// `error_stream` yields exactly one item, and that item is the error.
    #[tokio::test]
    async fn test_error_stream() {
        let error = AixError::other("test error");
        let stream = error_stream(error);
        let collected: Vec<_> = stream.collect().await;

        assert_eq!(collected.len(), 1);
        assert!(collected[0].is_err());
        assert_eq!(collected[0].as_ref().unwrap_err().to_string(), "Error: test error");
    }

    /// `single_chunk` yields exactly the chunk it was given.
    #[tokio::test]
    async fn test_single_chunk() {
        let chunk = StreamChunk::new("test", "Hello");
        let stream = single_chunk(chunk);
        let collected: Vec<_> = stream.collect().await;

        assert_eq!(collected.len(), 1);
        assert_eq!(collected[0].as_ref().unwrap().delta, "Hello");
    }
}
545        
546        assert_eq!(collected.len(), 1);
547        assert_eq!(collected[0].as_ref().unwrap().delta, "Hello");
548    }
549}