rexis_llm/
streaming.rs

//! # RSLLM Streaming Support
//!
//! Streaming response handling built on the async `Stream` trait.
//! Supports real-time token streaming with backpressure and error handling.
5
6use crate::{ChatResponse, CompletionResponse, RsllmError, RsllmResult, StreamChunk};
7use futures_util::Future;
8use futures_util::Stream;
9use pin_project_lite::pin_project;
10use std::pin::Pin;
11use std::task::{Context, Poll};
12
13/// Type alias for chat streaming responses
14pub type ChatStream = Pin<Box<dyn Stream<Item = RsllmResult<StreamChunk>> + Send>>;
15
16/// Type alias for completion streaming responses
17pub type CompletionStream = Pin<Box<dyn Stream<Item = RsllmResult<StreamChunk>> + Send>>;
18
19pin_project! {
20    /// Stream collector for assembling complete responses from chunks
21    pub struct StreamCollector<S> {
22        #[pin]
23        stream: S,
24        accumulated_content: String,
25        model: Option<String>,
26        finish_reason: Option<String>,
27        usage: Option<crate::Usage>,
28        metadata: std::collections::HashMap<String, serde_json::Value>,
29        tool_calls: Vec<crate::ToolCall>,
30        is_done: bool,
31    }
32}
33
34impl<S> StreamCollector<S>
35where
36    S: Stream<Item = RsllmResult<StreamChunk>>,
37{
38    /// Create a new stream collector
39    pub fn new(stream: S) -> Self {
40        Self {
41            stream,
42            accumulated_content: String::new(),
43            model: None,
44            finish_reason: None,
45            usage: None,
46            metadata: std::collections::HashMap::new(),
47            tool_calls: Vec::new(),
48            is_done: false,
49        }
50    }
51
52    /// Collect all chunks into a complete chat response
53    pub async fn collect_chat_response(mut self) -> RsllmResult<ChatResponse>
54    where
55        S: Unpin,
56    {
57        use futures_util::StreamExt;
58        while let Some(chunk_result) = self.next().await {
59            let _chunk = chunk_result?;
60            // Process chunk - this updates internal state
61        }
62
63        let model = self.model.unwrap_or_else(|| "unknown".to_string());
64
65        let mut response = ChatResponse::new(self.accumulated_content, model);
66
67        if let Some(reason) = self.finish_reason {
68            response = response.with_finish_reason(reason);
69        }
70
71        if let Some(usage) = self.usage {
72            response = response.with_usage(usage);
73        }
74
75        if !self.tool_calls.is_empty() {
76            response = response.with_tool_calls(self.tool_calls);
77        }
78
79        for (key, value) in self.metadata {
80            response = response.with_metadata(key, value);
81        }
82
83        Ok(response)
84    }
85
86    /// Collect all chunks into a complete completion response
87    pub async fn collect_completion_response(mut self) -> RsllmResult<CompletionResponse>
88    where
89        S: Unpin,
90    {
91        use futures_util::StreamExt;
92        while let Some(chunk_result) = self.next().await {
93            let _chunk = chunk_result?;
94            // Process chunk - this updates internal state
95        }
96
97        let model = self.model.unwrap_or_else(|| "unknown".to_string());
98
99        let mut response = CompletionResponse::new(self.accumulated_content, model);
100
101        if let Some(reason) = self.finish_reason {
102            response = response.with_finish_reason(reason);
103        }
104
105        if let Some(usage) = self.usage {
106            response = response.with_usage(usage);
107        }
108
109        for (key, value) in self.metadata {
110            response = response.with_metadata(key, value);
111        }
112
113        Ok(response)
114    }
115}
116
117impl<S> Stream for StreamCollector<S>
118where
119    S: Stream<Item = RsllmResult<StreamChunk>>,
120{
121    type Item = RsllmResult<StreamChunk>;
122
123    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
124        let mut this = self.project();
125
126        if *this.is_done {
127            return Poll::Ready(None);
128        }
129
130        match this.stream.as_mut().poll_next(cx) {
131            Poll::Ready(Some(Ok(chunk))) => {
132                // Update accumulated state
133                if chunk.has_content() {
134                    this.accumulated_content.push_str(&chunk.content);
135                }
136
137                if this.model.is_none() && !chunk.model.is_empty() {
138                    *this.model = Some(chunk.model.clone());
139                }
140
141                if let Some(reason) = &chunk.finish_reason {
142                    *this.finish_reason = Some(reason.clone());
143                }
144
145                if let Some(usage) = &chunk.usage {
146                    *this.usage = Some(usage.clone());
147                }
148
149                // Merge metadata
150                for (key, value) in &chunk.metadata {
151                    this.metadata.insert(key.clone(), value.clone());
152                }
153
154                // Handle tool calls delta (simplified - would need proper delta merging)
155                if let Some(_tool_calls_delta) = &chunk.tool_calls_delta {
156                    // TODO: Implement proper tool call delta merging
157                }
158
159                if chunk.is_done {
160                    *this.is_done = true;
161                }
162
163                Poll::Ready(Some(Ok(chunk)))
164            }
165            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
166            Poll::Ready(None) => {
167                *this.is_done = true;
168                Poll::Ready(None)
169            }
170            Poll::Pending => Poll::Pending,
171        }
172    }
173}
174
175pin_project! {
176    /// Stream adapter for rate limiting
177    pub struct RateLimitedStream<S> {
178        #[pin]
179        stream: S,
180        delay: std::time::Duration,
181        last_emit: Option<std::time::Instant>,
182    }
183}
184
185impl<S> RateLimitedStream<S> {
186    /// Create a new rate-limited stream
187    pub fn new(stream: S, max_chunks_per_second: f64) -> Self {
188        let delay = std::time::Duration::from_secs_f64(1.0 / max_chunks_per_second);
189        Self {
190            stream,
191            delay,
192            last_emit: None,
193        }
194    }
195}
196
197impl<S> Stream for RateLimitedStream<S>
198where
199    S: Stream<Item = RsllmResult<StreamChunk>>,
200{
201    type Item = S::Item;
202
203    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
204        let mut this = self.project();
205
206        // Check if we need to delay
207        if let Some(last) = this.last_emit {
208            let elapsed = last.elapsed();
209            if elapsed < *this.delay {
210                let remaining = *this.delay - elapsed;
211
212                // Set up a timer for the remaining delay
213                let sleep = tokio::time::sleep(remaining);
214                tokio::pin!(sleep);
215
216                if sleep.as_mut().poll(cx).is_pending() {
217                    return Poll::Pending;
218                }
219            }
220        }
221
222        match this.stream.as_mut().poll_next(cx) {
223            Poll::Ready(Some(item)) => {
224                *this.last_emit = Some(std::time::Instant::now());
225                Poll::Ready(Some(item))
226            }
227            other => other,
228        }
229    }
230}
231
232pin_project! {
233    /// Stream adapter for filtering chunks
234    pub struct FilteredStream<S, F> {
235        #[pin]
236        stream: S,
237        filter: F,
238    }
239}
240
241impl<S, F> FilteredStream<S, F>
242where
243    F: Fn(&StreamChunk) -> bool,
244{
245    /// Create a new filtered stream
246    pub fn new(stream: S, filter: F) -> Self {
247        Self { stream, filter }
248    }
249}
250
251impl<S, F> Stream for FilteredStream<S, F>
252where
253    S: Stream<Item = RsllmResult<StreamChunk>>,
254    F: Fn(&StreamChunk) -> bool,
255{
256    type Item = S::Item;
257
258    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
259        let mut this = self.project();
260
261        loop {
262            match this.stream.as_mut().poll_next(cx) {
263                Poll::Ready(Some(Ok(chunk))) => {
264                    if (this.filter)(&chunk) {
265                        return Poll::Ready(Some(Ok(chunk)));
266                    }
267                    // Continue polling if chunk was filtered out
268                }
269                other => return other,
270            }
271        }
272    }
273}
274
275pin_project! {
276    /// Stream adapter for mapping chunks
277    pub struct MappedStream<S, F> {
278        #[pin]
279        stream: S,
280        mapper: F,
281    }
282}
283
284impl<S, F> MappedStream<S, F>
285where
286    F: Fn(StreamChunk) -> StreamChunk,
287{
288    /// Create a new mapped stream
289    pub fn new(stream: S, mapper: F) -> Self {
290        Self { stream, mapper }
291    }
292}
293
294impl<S, F> Stream for MappedStream<S, F>
295where
296    S: Stream<Item = RsllmResult<StreamChunk>>,
297    F: Fn(StreamChunk) -> StreamChunk,
298{
299    type Item = S::Item;
300
301    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
302        let mut this = self.project();
303
304        match this.stream.as_mut().poll_next(cx) {
305            Poll::Ready(Some(Ok(chunk))) => {
306                let mapped = (this.mapper)(chunk);
307                Poll::Ready(Some(Ok(mapped)))
308            }
309            other => other,
310        }
311    }
312}
313
314/// Stream utilities
315pub struct StreamUtils;
316
317impl StreamUtils {
318    /// Convert a vector of chunks into a stream
319    pub fn from_chunks(chunks: Vec<StreamChunk>) -> ChatStream {
320        let stream = tokio_stream::iter(chunks.into_iter().map(Ok));
321        Box::pin(stream)
322    }
323
324    /// Create an empty stream
325    pub fn empty() -> ChatStream {
326        let stream = tokio_stream::empty();
327        Box::pin(stream)
328    }
329
330    /// Create a stream that immediately returns an error
331    pub fn error(error: RsllmError) -> ChatStream {
332        use futures_util::stream;
333        let stream = stream::once(async move { Err(error) });
334        Box::pin(stream)
335    }
336
337    /// Collect stream into a vector of chunks
338    pub async fn collect_chunks<S>(stream: S) -> RsllmResult<Vec<StreamChunk>>
339    where
340        S: Stream<Item = RsllmResult<StreamChunk>>,
341    {
342        tokio_stream::StreamExt::collect::<Vec<_>>(stream)
343            .await
344            .into_iter()
345            .collect::<RsllmResult<Vec<_>>>()
346    }
347
348    /// Take only the first N chunks from a stream
349    pub fn take<S>(stream: S, n: usize) -> impl Stream<Item = RsllmResult<StreamChunk>>
350    where
351        S: Stream<Item = RsllmResult<StreamChunk>>,
352    {
353        tokio_stream::StreamExt::take(stream, n)
354    }
355
356    /// Skip the first N chunks from a stream
357    pub fn skip<S>(stream: S, n: usize) -> impl Stream<Item = RsllmResult<StreamChunk>>
358    where
359        S: Stream<Item = RsllmResult<StreamChunk>>,
360    {
361        tokio_stream::StreamExt::skip(stream, n)
362    }
363
364    /// Filter chunks based on a predicate
365    pub fn filter<S, F>(stream: S, filter: F) -> FilteredStream<S, F>
366    where
367        S: Stream<Item = RsllmResult<StreamChunk>>,
368        F: Fn(&StreamChunk) -> bool,
369    {
370        FilteredStream::new(stream, filter)
371    }
372
373    /// Map chunks with a function
374    pub fn map<S, F>(stream: S, mapper: F) -> MappedStream<S, F>
375    where
376        S: Stream<Item = RsllmResult<StreamChunk>>,
377        F: Fn(StreamChunk) -> StreamChunk,
378    {
379        MappedStream::new(stream, mapper)
380    }
381
382    /// Rate limit a stream
383    pub fn rate_limit<S>(stream: S, max_chunks_per_second: f64) -> RateLimitedStream<S>
384    where
385        S: Stream<Item = RsllmResult<StreamChunk>>,
386    {
387        RateLimitedStream::new(stream, max_chunks_per_second)
388    }
389
390    /// Buffer chunks to reduce API calls (simplified implementation)
391    pub async fn buffer<S>(mut stream: S, max_size: usize) -> RsllmResult<Vec<StreamChunk>>
392    where
393        S: Stream<Item = RsllmResult<StreamChunk>> + Unpin,
394    {
395        let mut chunks = Vec::new();
396        let mut count = 0;
397
398        use futures_util::StreamExt;
399        while let Some(chunk) = stream.next().await {
400            chunks.push(chunk?);
401            count += 1;
402
403            if count >= max_size {
404                break;
405            }
406        }
407
408        Ok(chunks)
409    }
410}
411
412/// Stream extension traits for additional functionality
413pub trait RsllmStreamExt: Stream<Item = RsllmResult<StreamChunk>> + Sized {
414    /// Collect stream into a complete chat response
415    fn collect_chat_response(
416        self,
417    ) -> impl std::future::Future<Output = RsllmResult<ChatResponse>> + Send
418    where
419        Self: Send + Unpin,
420    {
421        StreamCollector::new(self).collect_chat_response()
422    }
423
424    /// Collect stream into a complete completion response
425    fn collect_completion_response(
426        self,
427    ) -> impl std::future::Future<Output = RsllmResult<CompletionResponse>> + Send
428    where
429        Self: Send + Unpin,
430    {
431        StreamCollector::new(self).collect_completion_response()
432    }
433
434    /// Filter chunks that have content
435    fn content_only(self) -> FilteredStream<Self, fn(&StreamChunk) -> bool> {
436        FilteredStream::new(self, |chunk| chunk.has_content())
437    }
438
439    /// Filter out done chunks
440    fn exclude_done(self) -> FilteredStream<Self, fn(&StreamChunk) -> bool> {
441        FilteredStream::new(self, |chunk| !chunk.is_done)
442    }
443
444    /// Rate limit the stream
445    fn rate_limit(self, max_chunks_per_second: f64) -> RateLimitedStream<Self> {
446        RateLimitedStream::new(self, max_chunks_per_second)
447    }
448}
449
450impl<S> RsllmStreamExt for S where S: Stream<Item = RsllmResult<StreamChunk>> {}
451
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{MessageRole, StreamChunk};
    use futures_util::StreamExt;

    /// Chunks are concatenated, and model / finish reason are picked up from the stream.
    #[tokio::test]
    async fn test_stream_collector() {
        let stream = StreamUtils::from_chunks(vec![
            StreamChunk::delta("Hello", "gpt-4").with_role(MessageRole::Assistant),
            StreamChunk::delta(" world", "gpt-4"),
            StreamChunk::done("gpt-4").with_finish_reason("stop"),
        ]);

        let response = stream
            .collect_chat_response()
            .await
            .expect("collecting a well-formed stream should succeed");

        assert_eq!(response.content, "Hello world");
        assert_eq!(response.model, "gpt-4");
        assert_eq!(response.finish_reason.as_deref(), Some("stop"));
    }

    /// `content_only` drops chunks without content and keeps the rest in order.
    #[tokio::test]
    async fn test_filter_stream() {
        let stream = StreamUtils::from_chunks(vec![
            StreamChunk::delta("Hello", "gpt-4"),
            StreamChunk::new("", "gpt-4", false, false), // no content: must be dropped
            StreamChunk::delta(" world", "gpt-4"),
        ]);

        let kept: Vec<StreamChunk> = stream
            .content_only()
            .map(|item| item.unwrap())
            .collect()
            .await;

        let contents: Vec<&str> = kept.iter().map(|chunk| chunk.content.as_str()).collect();
        assert_eq!(contents, ["Hello", " world"]);
    }

    /// The mapper is applied to every successful chunk.
    #[tokio::test]
    async fn test_map_stream() {
        let source = StreamUtils::from_chunks(vec![
            StreamChunk::delta("hello", "gpt-4"),
            StreamChunk::delta(" world", "gpt-4"),
        ]);
        let shouted = StreamUtils::map(source, |mut chunk| {
            chunk.content = chunk.content.to_uppercase();
            chunk
        });

        let collected = StreamUtils::collect_chunks(shouted).await.unwrap();

        let contents: Vec<&str> = collected.iter().map(|chunk| chunk.content.as_str()).collect();
        assert_eq!(contents, ["HELLO", " WORLD"]);
    }
}