1use crate::error::{AixError, AixResult};
7use crate::types::StreamChunk;
8use futures_core::Stream;
9use pin_project_lite::pin_project;
10use std::future::Future;
11use std::pin::Pin;
12use std::task::{Context, Poll};
13use std::time::Duration;
14
15pub type TokenStream = Pin<Box<dyn Stream<Item = AixResult<StreamChunk>> + Send>>;
17
18pub trait StreamExt: Stream {
20 fn collect_text(self) -> CollectText<Self>
28 where
29 Self: Sized,
30 {
31 CollectText::new(self)
32 }
33
34 fn filter_empty(self) -> FilterEmpty<Self>
39 where
40 Self: Sized,
41 {
42 FilterEmpty::new(self)
43 }
44
45 fn buffer_chunks(self, duration: Duration) -> BufferChunks<Self>
55 where
56 Self: Sized,
57 {
58 BufferChunks::new(self, duration)
59 }
60}
61
62impl<T: ?Sized> StreamExt for T where T: Stream {}
64
65pin_project! {
67 pub struct CollectText<S> {
68 #[pin]
69 stream: S,
70 buffer: String,
71 }
72}
73
74impl<S> CollectText<S> {
75 fn new(stream: S) -> Self {
76 Self {
77 stream,
78 buffer: String::new(),
79 }
80 }
81}
82
83impl<S> std::future::Future for CollectText<S>
84where
85 S: Stream<Item = AixResult<StreamChunk>>,
86{
87 type Output = AixResult<String>;
88
89 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
90 let mut this = self.project();
91
92 loop {
93 match futures_core::ready!(this.stream.as_mut().poll_next(cx)) {
94 Some(Ok(chunk)) => {
95 this.buffer.push_str(&chunk.delta);
96 }
97 Some(Err(error)) => {
98 return Poll::Ready(Err(error));
99 }
100 None => {
101 return Poll::Ready(Ok(this.buffer.clone()));
102 }
103 }
104 }
105 }
106}
107
108pin_project! {
110 pub struct FilterEmpty<S> {
111 #[pin]
112 stream: S,
113 }
114}
115
116impl<S> FilterEmpty<S> {
117 fn new(stream: S) -> Self {
118 Self { stream }
119 }
120}
121
122impl<S> Stream for FilterEmpty<S>
123where
124 S: Stream<Item = AixResult<StreamChunk>>,
125{
126 type Item = AixResult<StreamChunk>;
127
128 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
129 let mut this = self.project();
130
131 loop {
132 match futures_core::ready!(this.stream.as_mut().poll_next(cx)) {
133 Some(Ok(chunk)) => {
134 if chunk.delta.is_empty() && chunk.finish_reason.is_none() {
135 continue;
137 }
138 return Poll::Ready(Some(Ok(chunk)));
139 }
140 other => return Poll::Ready(other),
141 }
142 }
143 }
144}
145
146pin_project! {
148 pub struct BufferChunks<S> {
149 #[pin]
150 stream: S,
151 buffer: Vec<StreamChunk>,
152 last_flush: Option<tokio::time::Instant>,
153 duration: Duration,
154 #[pin]
155 delay: Option<tokio::time::Sleep>,
156 }
157}
158
159impl<S> BufferChunks<S> {
160 fn new(stream: S, duration: Duration) -> Self {
161 Self {
162 stream,
163 buffer: Vec::new(),
164 last_flush: None,
165 duration,
166 delay: None,
167 }
168 }
169}
170
171impl<S> Stream for BufferChunks<S>
172where
173 S: Stream<Item = AixResult<StreamChunk>>,
174{
175 type Item = AixResult<StreamChunk>;
176
177 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
178 let mut this = self.project();
179 let now = tokio::time::Instant::now();
180
181 if !this.buffer.is_empty() {
183 let should_flush = if let Some(last_flush) = this.last_flush {
184 now.duration_since(*last_flush) >= *this.duration
185 } else {
186 true };
188
189 if should_flush {
190 let combined_id = this.buffer
192 .first()
193 .map(|c| c.id.clone())
194 .unwrap_or_else(|| "buffered".to_string());
195
196 let combined_delta: String = this.buffer
197 .iter()
198 .map(|c| c.delta.as_str())
199 .collect();
200
201 let finish_reason = this.buffer
202 .iter()
203 .find_map(|c| c.finish_reason.clone());
204
205 let combined_chunk = StreamChunk {
206 id: combined_id,
207 delta: combined_delta,
208 finish_reason,
209 };
210
211 this.buffer.clear();
212 *this.last_flush = Some(now);
213
214 return Poll::Ready(Some(Ok(combined_chunk)));
215 }
216 }
217
218 match futures_core::ready!(this.stream.as_mut().poll_next(cx)) {
220 Some(Ok(chunk)) => {
221 this.buffer.push(chunk);
223
224 if this.delay.is_none() {
226 this.delay.set(Some(tokio::time::sleep(*this.duration)));
227 }
228
229 if let Some(delay) = this.delay.as_mut().as_pin_mut() {
231 match delay.poll(cx) {
232 std::task::Poll::Ready(_) => {
233 this.delay.set(None);
234 }
236 std::task::Poll::Pending => {
237 }
239 }
240 }
241
242 Poll::Pending
243 }
244 Some(Err(error)) => {
245 Poll::Ready(Some(Err(error)))
246 }
247 None => {
248 if !this.buffer.is_empty() {
250 let combined_id = this.buffer
251 .first()
252 .map(|c| c.id.clone())
253 .unwrap_or_else(|| "buffered".to_string());
254
255 let combined_delta: String = this.buffer
256 .iter()
257 .map(|c| c.delta.as_str())
258 .collect();
259
260 let finish_reason = this.buffer
261 .iter()
262 .find_map(|c| c.finish_reason.clone());
263
264 let combined_chunk = StreamChunk {
265 id: combined_id,
266 delta: combined_delta,
267 finish_reason,
268 };
269
270 this.buffer.clear();
271
272 Poll::Ready(Some(Ok(combined_chunk)))
273 } else {
274 Poll::Ready(None)
275 }
276 }
277 }
278 }
279}
280
281pub fn from_iter<I>(iter: I) -> TokenStream
283where
284 I: IntoIterator<Item = AixResult<StreamChunk>>,
285 I::IntoIter: Send + 'static,
286{
287 let stream = futures_util::stream::iter(iter);
288 Box::pin(stream)
289}
290
291pub fn error_stream(error: AixError) -> TokenStream {
293 let stream = futures_util::stream::once(async move { Err(error) });
294 Box::pin(stream)
295}
296
297pub fn single_chunk(chunk: StreamChunk) -> TokenStream {
299 let stream = futures_util::stream::once(async move { Ok(chunk) });
300 Box::pin(stream)
301}
302
303pub fn chunks<I>(chunks: I) -> TokenStream
305where
306 I: IntoIterator<Item = StreamChunk>,
307 I::IntoIter: Send + 'static,
308{
309 let results = chunks.into_iter().map(Ok);
310 from_iter(results)
311}
312
313pub fn from_string<S>(id: S, text: S) -> TokenStream
315where
316 S: Into<String> + Clone,
317{
318 let id = id.into();
319 let text = text.into();
320 let chars: Vec<char> = text.chars().collect();
321 let stream = futures_util::stream::iter(chars.into_iter().map(move |c| {
322 let id = id.clone();
323 Ok(StreamChunk::new(id, c.to_string()))
324 }));
325 Box::pin(stream)
326}
327
328pub fn from_string_words<S>(id: S, text: S) -> TokenStream
330where
331 S: Into<String> + Clone,
332{
333 let id = id.into();
334 let text = text.into();
335 let words: Vec<String> = text.split_whitespace().map(|s| s.to_string()).collect();
336 let stream = futures_util::stream::iter(words.into_iter().map(move |word| {
337 let id = id.clone();
338 Ok(StreamChunk::new(id, format!("{} ", word)))
339 }));
340 Box::pin(stream)
341}
342
343pub struct SseParser {
345 buffer: String,
346}
347
348impl SseParser {
349 pub fn new() -> Self {
351 Self {
352 buffer: String::new(),
353 }
354 }
355
356 pub fn parse_chunk(&mut self, chunk: &[u8]) -> AixResult<Vec<String>> {
364 let chunk_str = std::str::from_utf8(chunk)
365 .map_err(|e| AixError::serialization(e.to_string(), "SSE chunk parsing"))?;
366
367 self.buffer.push_str(chunk_str);
368 self.extract_events()
369 }
370
371 fn extract_events(&mut self) -> AixResult<Vec<String>> {
373 let mut events = Vec::new();
374 let mut lines = self.buffer.lines().peekable();
375
376 while let Some(line) = lines.next() {
377 if line.starts_with("data:") {
378 let mut event_data = line[5..].trim().to_string();
379
380 while let Some(&next_line) = lines.peek() {
382 if next_line.starts_with("data:") {
383 event_data.push_str(&next_line[5..].trim());
384 lines.next(); } else {
386 break;
387 }
388 }
389
390 if event_data == "[DONE]" {
392 events.push("[DONE]".to_string());
393 } else if !event_data.is_empty() {
394 events.push(event_data);
395 }
396 }
397 }
398
399 let last_complete_pos = self.buffer.rfind("\n\n").unwrap_or(0);
402 if last_complete_pos > 0 {
403 self.buffer.drain(0..=last_complete_pos + 1);
404 }
405
406 Ok(events)
407 }
408
409 pub fn remaining_data(&self) -> &str {
411 &self.buffer
412 }
413
414 pub fn clear(&mut self) {
416 self.buffer.clear();
417 }
418}
419
420impl Default for SseParser {
421 fn default() -> Self {
422 Self::new()
423 }
424}
425
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::StreamExt as FuturesStreamExt;

    /// Builds a stream of successful chunks from `(id, delta)` pairs.
    fn stream_of(pairs: Vec<(&'static str, &'static str)>) -> TokenStream {
        from_iter(
            pairs
                .into_iter()
                .map(|(id, delta)| Ok(StreamChunk::new(id, delta))),
        )
    }

    /// Returns the delta of the item at `index`, panicking if that item is an error.
    fn delta_at(items: &[AixResult<StreamChunk>], index: usize) -> &str {
        &items[index].as_ref().unwrap().delta
    }

    #[tokio::test]
    async fn test_collect_text() {
        let text = stream_of(vec![("1", "Hello"), ("2", ", "), ("3", "world"), ("4", "!")])
            .collect_text()
            .await
            .unwrap();
        assert_eq!(text, "Hello, world!");
    }

    #[tokio::test]
    async fn test_filter_empty() {
        let collected: Vec<_> =
            stream_of(vec![("1", "Hello"), ("2", ""), ("3", "world"), ("4", "")])
                .filter_empty()
                .collect()
                .await;

        assert_eq!(collected.len(), 2);
        assert_eq!(delta_at(&collected, 0), "Hello");
        assert_eq!(delta_at(&collected, 1), "world");
    }

    #[tokio::test]
    async fn test_from_string() {
        let collected: Vec<_> = from_string("test", "Hello world").collect().await;

        assert_eq!(collected.len(), 11);
        assert_eq!(delta_at(&collected, 0), "H");
        assert_eq!(delta_at(&collected, 1), "e");
    }

    #[tokio::test]
    async fn test_from_string_words() {
        let collected: Vec<_> = from_string_words("test", "Hello world from Rust")
            .collect()
            .await;

        assert_eq!(collected.len(), 4);
        for (index, expected) in ["Hello ", "world ", "from ", "Rust"].iter().enumerate() {
            assert_eq!(delta_at(&collected, index), *expected);
        }
    }

    #[test]
    fn test_sse_parser() {
        let mut parser = SseParser::new();

        let events = parser
            .parse_chunk(b"data: {\"content\": \"Hello\"}\n\n")
            .unwrap();
        assert_eq!(events, vec!["{\"content\": \"Hello\"}"]);

        let events = parser.parse_chunk(b"data: [DONE]\n\n").unwrap();
        assert_eq!(events, vec!["[DONE]"]);
    }

    #[test]
    fn test_sse_parser_incomplete_event() {
        let mut parser = SseParser::new();

        // The first half has no terminating blank line, so nothing is emitted yet.
        assert!(parser.parse_chunk(b"data: {\"content\":").unwrap().is_empty());

        let events = parser.parse_chunk(b" \"Hello\"}\n\n").unwrap();
        assert_eq!(events, vec!["{\"content\": \"Hello\"}"]);
    }

    #[test]
    fn test_sse_parser_multiple_events() {
        let mut parser = SseParser::new();
        let payload: &[u8] =
            b"data: {\"content\": \"Hello\"}\n\ndata: {\"content\": \"world\"}\n\ndata: [DONE]\n\n";

        let events = parser.parse_chunk(payload).unwrap();
        assert_eq!(
            events,
            vec!["{\"content\": \"Hello\"}", "{\"content\": \"world\"}", "[DONE]"]
        );
    }

    #[tokio::test]
    async fn test_error_stream() {
        let collected: Vec<_> = error_stream(AixError::other("test error")).collect().await;

        assert_eq!(collected.len(), 1);
        assert!(collected[0].is_err());
        let message = collected[0].as_ref().unwrap_err().to_string();
        assert_eq!(message, "Error: test error");
    }

    #[tokio::test]
    async fn test_single_chunk() {
        let collected: Vec<_> = single_chunk(StreamChunk::new("test", "Hello"))
            .collect()
            .await;

        assert_eq!(collected.len(), 1);
        assert_eq!(delta_at(&collected, 0), "Hello");
    }
}