rsllm/
streaming.rs

1//! # RSLLM Streaming Support
2//! 
3//! Streaming response handling with proper async Stream traits.
4//! Supports real-time token streaming with backpressure and error handling.
5
6use crate::{RsllmError, RsllmResult, StreamChunk, ChatResponse, CompletionResponse};
7use futures_util::Stream;
8use pin_project_lite::pin_project;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use futures_util::Future;
12
13/// Type alias for chat streaming responses
14pub type ChatStream = Pin<Box<dyn Stream<Item = RsllmResult<StreamChunk>> + Send>>;
15
16/// Type alias for completion streaming responses  
17pub type CompletionStream = Pin<Box<dyn Stream<Item = RsllmResult<StreamChunk>> + Send>>;
18
19/// Stream collector for assembling complete responses from chunks
20pin_project! {
21    pub struct StreamCollector<S> {
22        #[pin]
23        stream: S,
24        accumulated_content: String,
25        model: Option<String>,
26        finish_reason: Option<String>,
27        usage: Option<crate::Usage>,
28        metadata: std::collections::HashMap<String, serde_json::Value>,
29        tool_calls: Vec<crate::ToolCall>,
30        is_done: bool,
31    }
32}
33
34impl<S> StreamCollector<S>
35where
36    S: Stream<Item = RsllmResult<StreamChunk>>,
37{
38    /// Create a new stream collector
39    pub fn new(stream: S) -> Self {
40        Self {
41            stream,
42            accumulated_content: String::new(),
43            model: None,
44            finish_reason: None,
45            usage: None,
46            metadata: std::collections::HashMap::new(),
47            tool_calls: Vec::new(),
48            is_done: false,
49        }
50    }
51    
52    /// Collect all chunks into a complete chat response
53    pub async fn collect_chat_response(mut self) -> RsllmResult<ChatResponse>
54    where
55        S: Unpin,
56    {
57        use futures_util::StreamExt;
58        while let Some(chunk_result) = self.next().await {
59            let _chunk = chunk_result?;
60            // Process chunk - this updates internal state
61        }
62        
63        let model = self.model.unwrap_or_else(|| "unknown".to_string());
64        
65        let mut response = ChatResponse::new(self.accumulated_content, model);
66        
67        if let Some(reason) = self.finish_reason {
68            response = response.with_finish_reason(reason);
69        }
70        
71        if let Some(usage) = self.usage {
72            response = response.with_usage(usage);
73        }
74        
75        if !self.tool_calls.is_empty() {
76            response = response.with_tool_calls(self.tool_calls);
77        }
78        
79        for (key, value) in self.metadata {
80            response = response.with_metadata(key, value);
81        }
82        
83        Ok(response)
84    }
85    
86    /// Collect all chunks into a complete completion response
87    pub async fn collect_completion_response(mut self) -> RsllmResult<CompletionResponse>
88    where
89        S: Unpin,
90    {
91        use futures_util::StreamExt;
92        while let Some(chunk_result) = self.next().await {
93            let _chunk = chunk_result?;
94            // Process chunk - this updates internal state
95        }
96        
97        let model = self.model.unwrap_or_else(|| "unknown".to_string());
98        
99        let mut response = CompletionResponse::new(self.accumulated_content, model);
100        
101        if let Some(reason) = self.finish_reason {
102            response = response.with_finish_reason(reason);
103        }
104        
105        if let Some(usage) = self.usage {
106            response = response.with_usage(usage);
107        }
108        
109        for (key, value) in self.metadata {
110            response = response.with_metadata(key, value);
111        }
112        
113        Ok(response)
114    }
115}
116
117impl<S> Stream for StreamCollector<S>
118where
119    S: Stream<Item = RsllmResult<StreamChunk>>,
120{
121    type Item = RsllmResult<StreamChunk>;
122    
123    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
124        let mut this = self.project();
125        
126        if *this.is_done {
127            return Poll::Ready(None);
128        }
129        
130        match this.stream.as_mut().poll_next(cx) {
131            Poll::Ready(Some(Ok(chunk))) => {
132                // Update accumulated state
133                if chunk.has_content() {
134                    this.accumulated_content.push_str(&chunk.content);
135                }
136                
137                if this.model.is_none() && !chunk.model.is_empty() {
138                    *this.model = Some(chunk.model.clone());
139                }
140                
141                if let Some(reason) = &chunk.finish_reason {
142                    *this.finish_reason = Some(reason.clone());
143                }
144                
145                if let Some(usage) = &chunk.usage {
146                    *this.usage = Some(usage.clone());
147                }
148                
149                // Merge metadata
150                for (key, value) in &chunk.metadata {
151                    this.metadata.insert(key.clone(), value.clone());
152                }
153                
154                // Handle tool calls delta (simplified - would need proper delta merging)
155                if let Some(_tool_calls_delta) = &chunk.tool_calls_delta {
156                    // TODO: Implement proper tool call delta merging
157                }
158                
159                if chunk.is_done {
160                    *this.is_done = true;
161                }
162                
163                Poll::Ready(Some(Ok(chunk)))
164            }
165            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
166            Poll::Ready(None) => {
167                *this.is_done = true;
168                Poll::Ready(None)
169            }
170            Poll::Pending => Poll::Pending,
171        }
172    }
173}
174
175/// Stream adapter for rate limiting
176pin_project! {
177    pub struct RateLimitedStream<S> {
178        #[pin]
179        stream: S,
180        delay: std::time::Duration,
181        last_emit: Option<std::time::Instant>,
182    }
183}
184
185impl<S> RateLimitedStream<S> {
186    /// Create a new rate-limited stream
187    pub fn new(stream: S, max_chunks_per_second: f64) -> Self {
188        let delay = std::time::Duration::from_secs_f64(1.0 / max_chunks_per_second);
189        Self {
190            stream,
191            delay,
192            last_emit: None,
193        }
194    }
195}
196
197impl<S> Stream for RateLimitedStream<S>
198where
199    S: Stream<Item = RsllmResult<StreamChunk>>,
200{
201    type Item = S::Item;
202    
203    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
204        let mut this = self.project();
205        
206        // Check if we need to delay
207        if let Some(last) = this.last_emit {
208            let elapsed = last.elapsed();
209            if elapsed < *this.delay {
210                let remaining = *this.delay - elapsed;
211                
212                // Set up a timer for the remaining delay
213                let sleep = tokio::time::sleep(remaining);
214                tokio::pin!(sleep);
215                
216                if sleep.as_mut().poll(cx).is_pending() {
217                    return Poll::Pending;
218                }
219            }
220        }
221        
222        match this.stream.as_mut().poll_next(cx) {
223            Poll::Ready(Some(item)) => {
224                *this.last_emit = Some(std::time::Instant::now());
225                Poll::Ready(Some(item))
226            }
227            other => other,
228        }
229    }
230}
231
232/// Stream adapter for filtering chunks
233pin_project! {
234    pub struct FilteredStream<S, F> {
235        #[pin]
236        stream: S,
237        filter: F,
238    }
239}
240
241impl<S, F> FilteredStream<S, F>
242where
243    F: Fn(&StreamChunk) -> bool,
244{
245    /// Create a new filtered stream
246    pub fn new(stream: S, filter: F) -> Self {
247        Self { stream, filter }
248    }
249}
250
251impl<S, F> Stream for FilteredStream<S, F>
252where
253    S: Stream<Item = RsllmResult<StreamChunk>>,
254    F: Fn(&StreamChunk) -> bool,
255{
256    type Item = S::Item;
257    
258    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
259        let mut this = self.project();
260        
261        loop {
262            match this.stream.as_mut().poll_next(cx) {
263                Poll::Ready(Some(Ok(chunk))) => {
264                    if (this.filter)(&chunk) {
265                        return Poll::Ready(Some(Ok(chunk)));
266                    }
267                    // Continue polling if chunk was filtered out
268                }
269                other => return other,
270            }
271        }
272    }
273}
274
275/// Stream adapter for mapping chunks
276pin_project! {
277    pub struct MappedStream<S, F> {
278        #[pin]
279        stream: S,
280        mapper: F,
281    }
282}
283
284impl<S, F> MappedStream<S, F>
285where
286    F: Fn(StreamChunk) -> StreamChunk,
287{
288    /// Create a new mapped stream
289    pub fn new(stream: S, mapper: F) -> Self {
290        Self { stream, mapper }
291    }
292}
293
294impl<S, F> Stream for MappedStream<S, F>
295where
296    S: Stream<Item = RsllmResult<StreamChunk>>,
297    F: Fn(StreamChunk) -> StreamChunk,
298{
299    type Item = S::Item;
300    
301    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
302        let mut this = self.project();
303        
304        match this.stream.as_mut().poll_next(cx) {
305            Poll::Ready(Some(Ok(chunk))) => {
306                let mapped = (this.mapper)(chunk);
307                Poll::Ready(Some(Ok(mapped)))
308            }
309            other => other,
310        }
311    }
312}
313
314/// Stream utilities
315pub struct StreamUtils;
316
317impl StreamUtils {
318    /// Convert a vector of chunks into a stream
319    pub fn from_chunks(chunks: Vec<StreamChunk>) -> ChatStream {
320        let stream = tokio_stream::iter(chunks.into_iter().map(Ok));
321        Box::pin(stream)
322    }
323    
324    /// Create an empty stream
325    pub fn empty() -> ChatStream {
326        let stream = tokio_stream::empty();
327        Box::pin(stream)
328    }
329    
330    /// Create a stream that immediately returns an error
331    pub fn error(error: RsllmError) -> ChatStream {
332        use futures_util::stream;
333        let stream = stream::once(async move { Err(error) });
334        Box::pin(stream)
335    }
336    
337    /// Collect stream into a vector of chunks
338    pub async fn collect_chunks<S>(stream: S) -> RsllmResult<Vec<StreamChunk>>
339    where
340        S: Stream<Item = RsllmResult<StreamChunk>>,
341    {
342        tokio_stream::StreamExt::collect::<Vec<_>>(stream)
343            .await
344            .into_iter()
345            .collect::<RsllmResult<Vec<_>>>()
346    }
347    
348    /// Take only the first N chunks from a stream
349    pub fn take<S>(stream: S, n: usize) -> impl Stream<Item = RsllmResult<StreamChunk>>
350    where
351        S: Stream<Item = RsllmResult<StreamChunk>>,
352    {
353        tokio_stream::StreamExt::take(stream, n)
354    }
355    
356    /// Skip the first N chunks from a stream
357    pub fn skip<S>(stream: S, n: usize) -> impl Stream<Item = RsllmResult<StreamChunk>>
358    where
359        S: Stream<Item = RsllmResult<StreamChunk>>,
360    {
361        tokio_stream::StreamExt::skip(stream, n)
362    }
363    
364    /// Filter chunks based on a predicate
365    pub fn filter<S, F>(stream: S, filter: F) -> FilteredStream<S, F>
366    where
367        S: Stream<Item = RsllmResult<StreamChunk>>,
368        F: Fn(&StreamChunk) -> bool,
369    {
370        FilteredStream::new(stream, filter)
371    }
372    
373    /// Map chunks with a function
374    pub fn map<S, F>(stream: S, mapper: F) -> MappedStream<S, F>
375    where
376        S: Stream<Item = RsllmResult<StreamChunk>>,
377        F: Fn(StreamChunk) -> StreamChunk,
378    {
379        MappedStream::new(stream, mapper)
380    }
381    
382    /// Rate limit a stream
383    pub fn rate_limit<S>(stream: S, max_chunks_per_second: f64) -> RateLimitedStream<S>
384    where
385        S: Stream<Item = RsllmResult<StreamChunk>>,
386    {
387        RateLimitedStream::new(stream, max_chunks_per_second)
388    }
389    
390    /// Buffer chunks to reduce API calls (simplified implementation)
391    pub async fn buffer<S>(
392        mut stream: S,
393        max_size: usize,
394    ) -> RsllmResult<Vec<StreamChunk>>
395    where
396        S: Stream<Item = RsllmResult<StreamChunk>> + Unpin,
397    {
398        let mut chunks = Vec::new();
399        let mut count = 0;
400        
401        use futures_util::StreamExt;
402        while let Some(chunk) = stream.next().await {
403            chunks.push(chunk?);
404            count += 1;
405            
406            if count >= max_size {
407                break;
408            }
409        }
410        
411        Ok(chunks)
412    }
413}
414
415/// Stream extension traits for additional functionality
416pub trait RsllmStreamExt: Stream<Item = RsllmResult<StreamChunk>> + Sized {
417    /// Collect stream into a complete chat response
418    fn collect_chat_response(self) -> impl std::future::Future<Output = RsllmResult<ChatResponse>> + Send
419    where
420        Self: Send + Unpin,
421    {
422        StreamCollector::new(self).collect_chat_response()
423    }
424    
425    /// Collect stream into a complete completion response
426    fn collect_completion_response(self) -> impl std::future::Future<Output = RsllmResult<CompletionResponse>> + Send
427    where
428        Self: Send + Unpin,
429    {
430        StreamCollector::new(self).collect_completion_response()
431    }
432    
433    /// Filter chunks that have content
434    fn content_only(self) -> FilteredStream<Self, fn(&StreamChunk) -> bool> {
435        FilteredStream::new(self, |chunk| chunk.has_content())
436    }
437    
438    /// Filter out done chunks
439    fn exclude_done(self) -> FilteredStream<Self, fn(&StreamChunk) -> bool> {
440        FilteredStream::new(self, |chunk| !chunk.is_done)
441    }
442    
443    /// Rate limit the stream
444    fn rate_limit(self, max_chunks_per_second: f64) -> RateLimitedStream<Self> {
445        RateLimitedStream::new(self, max_chunks_per_second)
446    }
447}
448
449impl<S> RsllmStreamExt for S where S: Stream<Item = RsllmResult<StreamChunk>> {}
450
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{MessageRole, StreamChunk};

    /// Collecting a three-chunk stream concatenates the content and carries
    /// over the model and finish reason.
    #[tokio::test]
    async fn test_stream_collector() {
        let first = StreamChunk::delta("Hello", "gpt-4").with_role(MessageRole::Assistant);
        let second = StreamChunk::delta(" world", "gpt-4");
        let last = StreamChunk::done("gpt-4").with_finish_reason("stop");

        let response = StreamUtils::from_chunks(vec![first, second, last])
            .collect_chat_response()
            .await
            .unwrap();

        assert_eq!(response.content, "Hello world");
        assert_eq!(response.model, "gpt-4");
        assert_eq!(response.finish_reason, Some("stop".to_string()));
    }

    /// `content_only` drops the empty chunk and keeps the other two in order.
    #[tokio::test]
    async fn test_filter_stream() {
        use futures_util::StreamExt;

        let chunks = vec![
            StreamChunk::delta("Hello", "gpt-4"),
            StreamChunk::new("", "gpt-4", false, false), // Empty chunk
            StreamChunk::delta(" world", "gpt-4"),
        ];

        let filtered_chunks: Vec<StreamChunk> = StreamUtils::from_chunks(chunks)
            .content_only()
            .map(|item| item.unwrap())
            .collect()
            .await;

        assert_eq!(filtered_chunks.len(), 2);
        assert_eq!(filtered_chunks[0].content, "Hello");
        assert_eq!(filtered_chunks[1].content, " world");
    }

    /// `StreamUtils::map` applies the mapper to every chunk.
    #[tokio::test]
    async fn test_map_stream() {
        fn shout(mut chunk: StreamChunk) -> StreamChunk {
            chunk.content = chunk.content.to_uppercase();
            chunk
        }

        let source = StreamUtils::from_chunks(vec![
            StreamChunk::delta("hello", "gpt-4"),
            StreamChunk::delta(" world", "gpt-4"),
        ]);

        let collected = StreamUtils::collect_chunks(StreamUtils::map(source, shout))
            .await
            .unwrap();

        assert_eq!(collected[0].content, "HELLO");
        assert_eq!(collected[1].content, " WORLD");
    }
}
511}