1use crate::{RragError, RragResult};
7use futures::{Stream, StreamExt};
8use serde::{Deserialize, Serialize};
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use tokio::sync::mpsc;
12use tokio_stream::wrappers::UnboundedReceiverStream;
13
/// A single unit of output emitted by a streaming response.
// NOTE(review): serde serializes fields in declaration order, so reordering them changes
// the serialized form — keep the order stable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamToken {
    /// Payload carried by this token (left empty by `StreamToken::final_token`).
    pub content: String,

    /// Kind of payload carried by this token.
    pub token_type: TokenType,

    /// Index of the token in its stream; the helpers in this module assign it
    /// sequentially starting at 0.
    pub position: usize,

    /// Set by `StreamToken::final_token`; `StreamingResponse::collect_text` stops
    /// reading when it sees a final, non-text token.
    pub is_final: bool,

    /// Optional JSON metadata, attached with `StreamToken::with_metadata`.
    pub metadata: Option<serde_json::Value>,
}
32
33impl StreamToken {
34 pub fn text(content: impl Into<String>, position: usize) -> Self {
35 Self {
36 content: content.into(),
37 token_type: TokenType::Text,
38 position,
39 is_final: false,
40 metadata: None,
41 }
42 }
43
44 pub fn tool_call(content: impl Into<String>, position: usize) -> Self {
45 Self {
46 content: content.into(),
47 token_type: TokenType::ToolCall,
48 position,
49 is_final: false,
50 metadata: None,
51 }
52 }
53
54 pub fn final_token(position: usize) -> Self {
55 Self {
56 content: String::new(),
57 token_type: TokenType::End,
58 position,
59 is_final: true,
60 metadata: None,
61 }
62 }
63
64 pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
65 self.metadata = Some(metadata);
66 self
67 }
68}
69
70#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
72pub enum TokenType {
73 Text,
75
76 ToolCall,
78
79 ToolResult,
81
82 Metadata,
84
85 End,
87
88 Error,
90}
91
/// A boxed, `Send` stream of tokens (or errors) that can be consumed incrementally,
/// collected into text, filtered by token type, or mapped.
pub struct StreamingResponse {
    // Type-erased so any source fits: an iterator (`from_text`), a channel
    // (`from_channel`) or an arbitrary stream (`from_stream`).
    stream: Pin<Box<dyn Stream<Item = RragResult<StreamToken>> + Send>>,
}
96
97impl StreamingResponse {
98 pub fn from_text(text: impl Into<String>) -> Self {
100 let text = text.into();
101 let tokens: Vec<_> = text
102 .split_whitespace()
103 .enumerate()
104 .map(|(i, word)| Ok(StreamToken::text(format!("{} ", word), i)))
105 .collect();
106
107 let mut tokens = tokens;
109 let final_pos = tokens.len();
110 tokens.push(Ok(StreamToken::final_token(final_pos)));
111
112 let stream = futures::stream::iter(tokens);
113
114 Self {
115 stream: Box::pin(stream),
116 }
117 }
118
119 pub fn from_stream<S>(stream: S) -> Self
121 where
122 S: Stream<Item = RragResult<StreamToken>> + Send + 'static,
123 {
124 Self {
125 stream: Box::pin(stream),
126 }
127 }
128
129 pub fn from_channel(receiver: mpsc::UnboundedReceiver<RragResult<StreamToken>>) -> Self {
131 let stream = UnboundedReceiverStream::new(receiver);
132 Self::from_stream(stream)
133 }
134
135 pub async fn collect_text(mut self) -> RragResult<String> {
137 let mut result = String::new();
138
139 while let Some(token_result) = self.stream.next().await {
140 match token_result? {
141 token if token.token_type == TokenType::Text => {
142 result.push_str(&token.content);
143 }
144 token if token.is_final => break,
145 _ => {} }
147 }
148
149 Ok(result.trim().to_string())
150 }
151
152 pub fn filter_by_type(self, token_type: TokenType) -> FilteredStream {
154 FilteredStream {
155 stream: self.stream,
156 filter_type: token_type,
157 }
158 }
159
160 pub fn map_tokens<F, T>(self, f: F) -> MappedStream<T>
162 where
163 F: Fn(StreamToken) -> T + Send + 'static,
164 T: Send + 'static,
165 {
166 let mapped_stream = self.stream.map(move |result| result.map(&f));
167
168 MappedStream {
169 stream: Box::pin(mapped_stream),
170 }
171 }
172}
173
174impl Stream for StreamingResponse {
175 type Item = RragResult<StreamToken>;
176
177 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
178 self.stream.as_mut().poll_next(cx)
179 }
180}
181
/// Stream adapter returned by `StreamingResponse::filter_by_type`.
///
/// Yields only tokens whose type equals `filter_type`, plus any final token and any
/// error item.
pub struct FilteredStream {
    // Underlying token stream being filtered.
    stream: Pin<Box<dyn Stream<Item = RragResult<StreamToken>> + Send>>,
    // Token type to keep.
    filter_type: TokenType,
}
187
188impl Stream for FilteredStream {
189 type Item = RragResult<StreamToken>;
190
191 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
192 loop {
193 match self.stream.as_mut().poll_next(cx) {
194 Poll::Ready(Some(Ok(token))) => {
195 if token.token_type == self.filter_type || token.is_final {
196 return Poll::Ready(Some(Ok(token)));
197 }
198 }
200 Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
201 Poll::Ready(None) => return Poll::Ready(None),
202 Poll::Pending => return Poll::Pending,
203 }
204 }
205 }
206}
207
/// Stream adapter returned by `StreamingResponse::map_tokens`.
///
/// Yields the mapped value for each successful token and passes errors through
/// unchanged.
pub struct MappedStream<T> {
    // Already-mapped, type-erased stream.
    stream: Pin<Box<dyn Stream<Item = RragResult<T>> + Send>>,
}
212
213impl<T> Stream for MappedStream<T> {
214 type Item = RragResult<T>;
215
216 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
217 self.stream.as_mut().poll_next(cx)
218 }
219}
220
/// Pushes tokens into an unbounded channel, numbering them in send order.
pub struct TokenStreamBuilder {
    // Sending half of the channel; `TokenStreamBuilder::new` returns the receiver.
    sender: mpsc::UnboundedSender<RragResult<StreamToken>>,
    // Position assigned to the next token; `send_text`/`send_tool_call` increment it.
    position: usize,
}
226
227impl TokenStreamBuilder {
228 pub fn new() -> (Self, mpsc::UnboundedReceiver<RragResult<StreamToken>>) {
230 let (sender, receiver) = mpsc::unbounded_channel();
231
232 let builder = Self {
233 sender,
234 position: 0,
235 };
236
237 (builder, receiver)
238 }
239
240 pub fn send_text(&mut self, content: impl Into<String>) -> RragResult<()> {
242 let token = StreamToken::text(content, self.position);
243 self.position += 1;
244
245 self.sender
246 .send(Ok(token))
247 .map_err(|_| RragError::stream("token_builder", "Channel closed"))?;
248
249 Ok(())
250 }
251
252 pub fn send_tool_call(&mut self, content: impl Into<String>) -> RragResult<()> {
254 let token = StreamToken::tool_call(content, self.position);
255 self.position += 1;
256
257 self.sender
258 .send(Ok(token))
259 .map_err(|_| RragError::stream("token_builder", "Channel closed"))?;
260
261 Ok(())
262 }
263
264 pub fn send_error(&mut self, error: RragError) -> RragResult<()> {
266 self.sender
267 .send(Err(error))
268 .map_err(|_| RragError::stream("token_builder", "Channel closed"))?;
269
270 Ok(())
271 }
272
273 pub fn finish(self) -> RragResult<()> {
275 let final_token = StreamToken::final_token(self.position);
276
277 self.sender
278 .send(Ok(final_token))
279 .map_err(|_| RragError::stream("token_builder", "Channel closed"))?;
280
281 drop(self.sender);
283
284 Ok(())
285 }
286}
287
288impl Default for TokenStreamBuilder {
289 fn default() -> Self {
290 let (builder, _) = Self::new();
291 builder
292 }
293}
294
/// Alternative name for [`StreamingResponse`].
pub type TokenStream = StreamingResponse;
297
298pub mod stream_utils {
300 use super::*;
301 use std::time::Duration;
302
303 pub fn create_delayed_stream(text: impl Into<String>, delay: Duration) -> StreamingResponse {
305 let text = text.into();
306 let words: Vec<String> = text.split_whitespace().map(|s| s.to_string()).collect();
307
308 let stream = async_stream::stream! {
309 for (i, word) in words.iter().enumerate() {
310 tokio::time::sleep(delay).await;
311 yield Ok(StreamToken::text(format!("{} ", word), i));
312 }
313 yield Ok(StreamToken::final_token(words.len()));
314 };
315
316 StreamingResponse::from_stream(stream)
317 }
318
319 pub fn create_chunked_stream(chunks: Vec<String>) -> StreamingResponse {
321 let stream = async_stream::stream! {
322 for (i, chunk) in chunks.iter().enumerate() {
323 yield Ok(StreamToken::text(chunk.clone(), i));
324 }
325 yield Ok(StreamToken::final_token(chunks.len()));
326 };
327
328 StreamingResponse::from_stream(stream)
329 }
330
331 pub async fn merge_streams(streams: Vec<StreamingResponse>) -> RragResult<StreamingResponse> {
333 let (mut builder, receiver) = TokenStreamBuilder::new();
334
335 tokio::spawn(async move {
336 let mut position = 0;
337
338 for mut stream in streams {
339 while let Some(token_result) = stream.next().await {
340 match token_result {
341 Ok(mut token) => {
342 if !token.is_final {
343 token.position = position;
344 position += 1;
345
346 if let Err(_) = builder.sender.send(Ok(token)) {
347 break;
348 }
349 }
350 }
351 Err(e) => {
352 let _ = builder.send_error(e);
353 break;
354 }
355 }
356 }
357 }
358
359 let _ = builder.finish();
360 });
361
362 Ok(StreamingResponse::from_channel(receiver))
363 }
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    #[tokio::test]
    async fn test_streaming_response_from_text() {
        let response = StreamingResponse::from_text("Hello world test");
        let text = response.collect_text().await.unwrap();

        assert_eq!(text, "Hello world test");
    }

    #[tokio::test]
    async fn test_from_text_empty_input() {
        let text = StreamingResponse::from_text("")
            .collect_text()
            .await
            .unwrap();

        assert_eq!(text, "");
    }

    #[tokio::test]
    async fn test_token_stream_builder() {
        let (mut builder, receiver) = TokenStreamBuilder::new();

        tokio::spawn(async move {
            builder.send_text("Hello").unwrap();
            builder.send_text("world").unwrap();
            builder.finish().unwrap();
        });

        let response = StreamingResponse::from_channel(receiver);
        let text = response.collect_text().await.unwrap();

        assert_eq!(text, "Hello world");
    }

    #[tokio::test]
    async fn test_filtered_stream() {
        let (mut builder, receiver) = TokenStreamBuilder::new();

        tokio::spawn(async move {
            builder.send_text("Hello").unwrap();
            builder.send_tool_call("tool_call").unwrap();
            builder.send_text("world").unwrap();
            builder.finish().unwrap();
        });

        let response = StreamingResponse::from_channel(receiver);
        let mut text_stream = response.filter_by_type(TokenType::Text);

        let mut text_tokens = Vec::new();
        while let Some(token_result) = text_stream.next().await {
            match token_result.unwrap() {
                token if token.token_type == TokenType::Text => {
                    text_tokens.push(token.content);
                }
                token if token.is_final => break,
                _ => {}
            }
        }

        assert_eq!(text_tokens, vec!["Hello ", "world "]);
    }

    #[tokio::test]
    async fn test_map_tokens_positions() {
        let positions: Vec<usize> = StreamingResponse::from_text("a b")
            .map_tokens(|token| token.position)
            .map(|result| result.unwrap())
            .collect()
            .await;

        // Two word tokens followed by the final token.
        assert_eq!(positions, vec![0, 1, 2]);
    }

    #[tokio::test]
    async fn test_stream_utils_chunked() {
        let response =
            stream_utils::create_chunked_stream(vec!["ab".to_string(), "cd".to_string()]);
        let text = response.collect_text().await.unwrap();

        assert_eq!(text, "abcd");
    }

    #[tokio::test]
    async fn test_merge_streams_concatenates_text() {
        let merged = stream_utils::merge_streams(vec![
            StreamingResponse::from_text("Hello"),
            StreamingResponse::from_text("world"),
        ])
        .await
        .unwrap();

        assert_eq!(merged.collect_text().await.unwrap(), "Hello world");
    }

    #[tokio::test]
    async fn test_stream_utils_delayed() {
        use std::time::Duration;

        let start = std::time::Instant::now();
        let response = stream_utils::create_delayed_stream("one two", Duration::from_millis(10));
        let text = response.collect_text().await.unwrap();
        let elapsed = start.elapsed();

        assert_eq!(text, "one two");
        assert!(elapsed >= Duration::from_millis(20));
    }

    #[test]
    fn test_stream_token_creation() {
        let token = StreamToken::text("hello", 0);
        assert_eq!(token.content, "hello");
        assert_eq!(token.token_type, TokenType::Text);
        assert_eq!(token.position, 0);
        assert!(!token.is_final);

        let final_token = StreamToken::final_token(10);
        assert!(final_token.is_final);
        assert_eq!(final_token.token_type, TokenType::End);
    }
}