rexis_rag/
streaming.rs

1//! # RRAG Streaming System
2//!
3//! Real-time streaming responses using Rust's async ecosystem.
4//! Leverages tokio-stream and futures for efficient token streaming.
5
6use crate::{RragError, RragResult};
7use futures::{Stream, StreamExt};
8use serde::{Deserialize, Serialize};
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use tokio::sync::mpsc;
12use tokio_stream::wrappers::UnboundedReceiverStream;
13
14/// Streaming response token
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct StreamToken {
17    /// Token content
18    pub content: String,
19
20    /// Token type (text, tool_call, metadata, etc.)
21    pub token_type: TokenType,
22
23    /// Position in the stream
24    pub position: usize,
25
26    /// Whether this is the final token
27    pub is_final: bool,
28
29    /// Token metadata
30    pub metadata: Option<serde_json::Value>,
31}
32
33impl StreamToken {
34    pub fn text(content: impl Into<String>, position: usize) -> Self {
35        Self {
36            content: content.into(),
37            token_type: TokenType::Text,
38            position,
39            is_final: false,
40            metadata: None,
41        }
42    }
43
44    pub fn tool_call(content: impl Into<String>, position: usize) -> Self {
45        Self {
46            content: content.into(),
47            token_type: TokenType::ToolCall,
48            position,
49            is_final: false,
50            metadata: None,
51        }
52    }
53
54    pub fn final_token(position: usize) -> Self {
55        Self {
56            content: String::new(),
57            token_type: TokenType::End,
58            position,
59            is_final: true,
60            metadata: None,
61        }
62    }
63
64    pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
65        self.metadata = Some(metadata);
66        self
67    }
68}
69
70/// Token types for different streaming content
71#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
72pub enum TokenType {
73    /// Regular text content
74    Text,
75
76    /// Tool call information
77    ToolCall,
78
79    /// Tool result
80    ToolResult,
81
82    /// Metadata/system information
83    Metadata,
84
85    /// Stream end marker
86    End,
87
88    /// Error token
89    Error,
90}
91
92/// Streaming response wrapper
93pub struct StreamingResponse {
94    stream: Pin<Box<dyn Stream<Item = RragResult<StreamToken>> + Send>>,
95}
96
97impl StreamingResponse {
98    /// Create from a text string by splitting into tokens
99    pub fn from_text(text: impl Into<String>) -> Self {
100        let text = text.into();
101        let tokens: Vec<_> = text
102            .split_whitespace()
103            .enumerate()
104            .map(|(i, word)| Ok(StreamToken::text(format!("{} ", word), i)))
105            .collect();
106
107        // Add final token
108        let mut tokens = tokens;
109        let final_pos = tokens.len();
110        tokens.push(Ok(StreamToken::final_token(final_pos)));
111
112        let stream = futures::stream::iter(tokens);
113
114        Self {
115            stream: Box::pin(stream),
116        }
117    }
118
119    /// Create from a token stream
120    pub fn from_stream<S>(stream: S) -> Self
121    where
122        S: Stream<Item = RragResult<StreamToken>> + Send + 'static,
123    {
124        Self {
125            stream: Box::pin(stream),
126        }
127    }
128
129    /// Create from an async channel
130    pub fn from_channel(receiver: mpsc::UnboundedReceiver<RragResult<StreamToken>>) -> Self {
131        let stream = UnboundedReceiverStream::new(receiver);
132        Self::from_stream(stream)
133    }
134
135    /// Collect all tokens into a single string
136    pub async fn collect_text(mut self) -> RragResult<String> {
137        let mut result = String::new();
138
139        while let Some(token_result) = self.stream.next().await {
140            match token_result? {
141                token if token.token_type == TokenType::Text => {
142                    result.push_str(&token.content);
143                }
144                token if token.is_final => break,
145                _ => {} // Skip non-text tokens
146            }
147        }
148
149        Ok(result.trim().to_string())
150    }
151
152    /// Filter tokens by type
153    pub fn filter_by_type(self, token_type: TokenType) -> FilteredStream {
154        FilteredStream {
155            stream: self.stream,
156            filter_type: token_type,
157        }
158    }
159
160    /// Map tokens to a different type
161    pub fn map_tokens<F, T>(self, f: F) -> MappedStream<T>
162    where
163        F: Fn(StreamToken) -> T + Send + 'static,
164        T: Send + 'static,
165    {
166        let mapped_stream = self.stream.map(move |result| result.map(&f));
167
168        MappedStream {
169            stream: Box::pin(mapped_stream),
170        }
171    }
172}
173
174impl Stream for StreamingResponse {
175    type Item = RragResult<StreamToken>;
176
177    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
178        self.stream.as_mut().poll_next(cx)
179    }
180}
181
182/// Filtered stream that only yields specific token types
183pub struct FilteredStream {
184    stream: Pin<Box<dyn Stream<Item = RragResult<StreamToken>> + Send>>,
185    filter_type: TokenType,
186}
187
188impl Stream for FilteredStream {
189    type Item = RragResult<StreamToken>;
190
191    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
192        loop {
193            match self.stream.as_mut().poll_next(cx) {
194                Poll::Ready(Some(Ok(token))) => {
195                    if token.token_type == self.filter_type || token.is_final {
196                        return Poll::Ready(Some(Ok(token)));
197                    }
198                    // Continue polling for matching tokens
199                }
200                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
201                Poll::Ready(None) => return Poll::Ready(None),
202                Poll::Pending => return Poll::Pending,
203            }
204        }
205    }
206}
207
208/// Mapped stream that transforms tokens
209pub struct MappedStream<T> {
210    stream: Pin<Box<dyn Stream<Item = RragResult<T>> + Send>>,
211}
212
213impl<T> Stream for MappedStream<T> {
214    type Item = RragResult<T>;
215
216    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
217        self.stream.as_mut().poll_next(cx)
218    }
219}
220
221/// Token stream builder for creating custom streams
222pub struct TokenStreamBuilder {
223    sender: mpsc::UnboundedSender<RragResult<StreamToken>>,
224    position: usize,
225}
226
227impl TokenStreamBuilder {
228    /// Create a new token stream builder
229    pub fn new() -> (Self, mpsc::UnboundedReceiver<RragResult<StreamToken>>) {
230        let (sender, receiver) = mpsc::unbounded_channel();
231
232        let builder = Self {
233            sender,
234            position: 0,
235        };
236
237        (builder, receiver)
238    }
239
240    /// Send a text token
241    pub fn send_text(&mut self, content: impl Into<String>) -> RragResult<()> {
242        let token = StreamToken::text(content, self.position);
243        self.position += 1;
244
245        self.sender
246            .send(Ok(token))
247            .map_err(|_| RragError::stream("token_builder", "Channel closed"))?;
248
249        Ok(())
250    }
251
252    /// Send a tool call token
253    pub fn send_tool_call(&mut self, content: impl Into<String>) -> RragResult<()> {
254        let token = StreamToken::tool_call(content, self.position);
255        self.position += 1;
256
257        self.sender
258            .send(Ok(token))
259            .map_err(|_| RragError::stream("token_builder", "Channel closed"))?;
260
261        Ok(())
262    }
263
264    /// Send an error token
265    pub fn send_error(&mut self, error: RragError) -> RragResult<()> {
266        self.sender
267            .send(Err(error))
268            .map_err(|_| RragError::stream("token_builder", "Channel closed"))?;
269
270        Ok(())
271    }
272
273    /// Finalize the stream
274    pub fn finish(self) -> RragResult<()> {
275        let final_token = StreamToken::final_token(self.position);
276
277        self.sender
278            .send(Ok(final_token))
279            .map_err(|_| RragError::stream("token_builder", "Channel closed"))?;
280
281        // Close the channel
282        drop(self.sender);
283
284        Ok(())
285    }
286}
287
288impl Default for TokenStreamBuilder {
289    fn default() -> Self {
290        let (builder, _) = Self::new();
291        builder
292    }
293}
294
295/// Convenience type alias for token streams
296pub type TokenStream = StreamingResponse;
297
298/// Utility functions for working with streams
299pub mod stream_utils {
300    use super::*;
301    use std::time::Duration;
302
303    /// Create a stream that emits tokens with a delay (for demo purposes)
304    pub fn create_delayed_stream(text: impl Into<String>, delay: Duration) -> StreamingResponse {
305        let text = text.into();
306        let words: Vec<String> = text.split_whitespace().map(|s| s.to_string()).collect();
307
308        let stream = async_stream::stream! {
309            for (i, word) in words.iter().enumerate() {
310                tokio::time::sleep(delay).await;
311                yield Ok(StreamToken::text(format!("{} ", word), i));
312            }
313            yield Ok(StreamToken::final_token(words.len()));
314        };
315
316        StreamingResponse::from_stream(stream)
317    }
318
319    /// Create a stream from multiple text chunks
320    pub fn create_chunked_stream(chunks: Vec<String>) -> StreamingResponse {
321        let stream = async_stream::stream! {
322            for (i, chunk) in chunks.iter().enumerate() {
323                yield Ok(StreamToken::text(chunk.clone(), i));
324            }
325            yield Ok(StreamToken::final_token(chunks.len()));
326        };
327
328        StreamingResponse::from_stream(stream)
329    }
330
331    /// Merge multiple streams into one
332    pub async fn merge_streams(streams: Vec<StreamingResponse>) -> RragResult<StreamingResponse> {
333        let (mut builder, receiver) = TokenStreamBuilder::new();
334
335        tokio::spawn(async move {
336            let mut position = 0;
337
338            for mut stream in streams {
339                while let Some(token_result) = stream.next().await {
340                    match token_result {
341                        Ok(mut token) => {
342                            if !token.is_final {
343                                token.position = position;
344                                position += 1;
345
346                                if let Err(_) = builder.sender.send(Ok(token)) {
347                                    break;
348                                }
349                            }
350                        }
351                        Err(e) => {
352                            let _ = builder.send_error(e);
353                            break;
354                        }
355                    }
356                }
357            }
358
359            let _ = builder.finish();
360        });
361
362        Ok(StreamingResponse::from_channel(receiver))
363    }
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    #[tokio::test]
    async fn test_streaming_response_from_text() {
        let response = StreamingResponse::from_text("Hello world test");
        let text = response.collect_text().await.unwrap();

        assert_eq!(text, "Hello world test");
    }

    #[tokio::test]
    async fn test_token_stream_builder() {
        let (mut builder, receiver) = TokenStreamBuilder::new();

        tokio::spawn(async move {
            // The builder forwards content verbatim (it adds no separators),
            // so spacing has to be part of the tokens themselves.
            builder.send_text("Hello ").unwrap();
            builder.send_text("world").unwrap();
            builder.finish().unwrap();
        });

        let response = StreamingResponse::from_channel(receiver);
        let text = response.collect_text().await.unwrap();

        assert_eq!(text, "Hello world");
    }

    #[tokio::test]
    async fn test_filtered_stream() {
        let (mut builder, receiver) = TokenStreamBuilder::new();

        tokio::spawn(async move {
            builder.send_text("Hello ").unwrap();
            builder.send_tool_call("tool_call").unwrap();
            builder.send_text("world ").unwrap();
            builder.finish().unwrap();
        });

        let response = StreamingResponse::from_channel(receiver);
        let mut text_stream = response.filter_by_type(TokenType::Text);

        let mut text_tokens = Vec::new();
        while let Some(token_result) = text_stream.next().await {
            match token_result.unwrap() {
                token if token.token_type == TokenType::Text => {
                    text_tokens.push(token.content);
                }
                token if token.is_final => break,
                _ => {}
            }
        }

        assert_eq!(text_tokens, vec!["Hello ", "world "]);
    }

    #[tokio::test]
    async fn test_merge_streams_concatenates_in_order() {
        let merged = stream_utils::merge_streams(vec![
            StreamingResponse::from_text("a b"),
            StreamingResponse::from_text("c d"),
        ])
        .await
        .unwrap();

        assert_eq!(merged.collect_text().await.unwrap(), "a b c d");
    }

    #[tokio::test]
    async fn test_stream_utils_delayed() {
        use std::time::Duration;

        let start = std::time::Instant::now();
        let response = stream_utils::create_delayed_stream("one two", Duration::from_millis(10));
        let text = response.collect_text().await.unwrap();
        let elapsed = start.elapsed();

        assert_eq!(text, "one two");
        assert!(elapsed >= Duration::from_millis(20)); // At least 2 delays
    }

    #[test]
    fn test_stream_token_creation() {
        let token = StreamToken::text("hello", 0);
        assert_eq!(token.content, "hello");
        assert_eq!(token.token_type, TokenType::Text);
        assert_eq!(token.position, 0);
        assert!(!token.is_final);

        let final_token = StreamToken::final_token(10);
        assert!(final_token.is_final);
        assert_eq!(final_token.token_type, TokenType::End);
    }
}
449}