1use crate::{RsllmError, RsllmResult, StreamChunk, ChatResponse, CompletionResponse};
7use futures_util::Stream;
8use pin_project_lite::pin_project;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use futures_util::Future;
12
/// Type-erased, heap-pinned, `Send` stream whose items are `RsllmResult<StreamChunk>`.
///
/// This is the stream type returned by the `StreamUtils` constructors in this module.
pub type ChatStream = Pin<Box<dyn Stream<Item = RsllmResult<StreamChunk>> + Send>>;
15
/// Type-erased, heap-pinned, `Send` stream whose items are `RsllmResult<StreamChunk>`.
///
/// Structurally identical to [`ChatStream`]; the separate name only distinguishes
/// intent at call sites.
pub type CompletionStream = Pin<Box<dyn Stream<Item = RsllmResult<StreamChunk>> + Send>>;
18
pin_project! {
    /// Stream adapter that forwards every chunk of an inner stream unchanged while
    /// recording the pieces needed to assemble a complete response afterwards.
    pub struct StreamCollector<S> {
        // Inner stream; structurally pinned so it can be polled through `Pin<&mut Self>`.
        #[pin]
        stream: S,
        // Concatenation of the `content` of every chunk that reported content.
        accumulated_content: String,
        // First non-empty model name seen on a chunk.
        model: Option<String>,
        // Most recent finish reason seen on a chunk.
        finish_reason: Option<String>,
        // Most recent usage record seen on a chunk.
        usage: Option<crate::Usage>,
        // Union of all chunk metadata; later chunks overwrite earlier keys.
        metadata: std::collections::HashMap<String, serde_json::Value>,
        // NOTE(review): nothing in the visible `poll_next` ever writes to this vector,
        // so it stays empty; tool-call deltas are currently ignored.
        tool_calls: Vec<crate::ToolCall>,
        // Set once a chunk with `is_done` was forwarded or the inner stream ended.
        is_done: bool,
    }
}
33
34impl<S> StreamCollector<S>
35where
36 S: Stream<Item = RsllmResult<StreamChunk>>,
37{
38 pub fn new(stream: S) -> Self {
40 Self {
41 stream,
42 accumulated_content: String::new(),
43 model: None,
44 finish_reason: None,
45 usage: None,
46 metadata: std::collections::HashMap::new(),
47 tool_calls: Vec::new(),
48 is_done: false,
49 }
50 }
51
52 pub async fn collect_chat_response(mut self) -> RsllmResult<ChatResponse>
54 where
55 S: Unpin,
56 {
57 use futures_util::StreamExt;
58 while let Some(chunk_result) = self.next().await {
59 let _chunk = chunk_result?;
60 }
62
63 let model = self.model.unwrap_or_else(|| "unknown".to_string());
64
65 let mut response = ChatResponse::new(self.accumulated_content, model);
66
67 if let Some(reason) = self.finish_reason {
68 response = response.with_finish_reason(reason);
69 }
70
71 if let Some(usage) = self.usage {
72 response = response.with_usage(usage);
73 }
74
75 if !self.tool_calls.is_empty() {
76 response = response.with_tool_calls(self.tool_calls);
77 }
78
79 for (key, value) in self.metadata {
80 response = response.with_metadata(key, value);
81 }
82
83 Ok(response)
84 }
85
86 pub async fn collect_completion_response(mut self) -> RsllmResult<CompletionResponse>
88 where
89 S: Unpin,
90 {
91 use futures_util::StreamExt;
92 while let Some(chunk_result) = self.next().await {
93 let _chunk = chunk_result?;
94 }
96
97 let model = self.model.unwrap_or_else(|| "unknown".to_string());
98
99 let mut response = CompletionResponse::new(self.accumulated_content, model);
100
101 if let Some(reason) = self.finish_reason {
102 response = response.with_finish_reason(reason);
103 }
104
105 if let Some(usage) = self.usage {
106 response = response.with_usage(usage);
107 }
108
109 for (key, value) in self.metadata {
110 response = response.with_metadata(key, value);
111 }
112
113 Ok(response)
114 }
115}
116
117impl<S> Stream for StreamCollector<S>
118where
119 S: Stream<Item = RsllmResult<StreamChunk>>,
120{
121 type Item = RsllmResult<StreamChunk>;
122
123 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
124 let mut this = self.project();
125
126 if *this.is_done {
127 return Poll::Ready(None);
128 }
129
130 match this.stream.as_mut().poll_next(cx) {
131 Poll::Ready(Some(Ok(chunk))) => {
132 if chunk.has_content() {
134 this.accumulated_content.push_str(&chunk.content);
135 }
136
137 if this.model.is_none() && !chunk.model.is_empty() {
138 *this.model = Some(chunk.model.clone());
139 }
140
141 if let Some(reason) = &chunk.finish_reason {
142 *this.finish_reason = Some(reason.clone());
143 }
144
145 if let Some(usage) = &chunk.usage {
146 *this.usage = Some(usage.clone());
147 }
148
149 for (key, value) in &chunk.metadata {
151 this.metadata.insert(key.clone(), value.clone());
152 }
153
154 if let Some(_tool_calls_delta) = &chunk.tool_calls_delta {
156 }
158
159 if chunk.is_done {
160 *this.is_done = true;
161 }
162
163 Poll::Ready(Some(Ok(chunk)))
164 }
165 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
166 Poll::Ready(None) => {
167 *this.is_done = true;
168 Poll::Ready(None)
169 }
170 Poll::Pending => Poll::Pending,
171 }
172 }
173}
174
175pin_project! {
177 pub struct RateLimitedStream<S> {
178 #[pin]
179 stream: S,
180 delay: std::time::Duration,
181 last_emit: Option<std::time::Instant>,
182 }
183}
184
185impl<S> RateLimitedStream<S> {
186 pub fn new(stream: S, max_chunks_per_second: f64) -> Self {
188 let delay = std::time::Duration::from_secs_f64(1.0 / max_chunks_per_second);
189 Self {
190 stream,
191 delay,
192 last_emit: None,
193 }
194 }
195}
196
197impl<S> Stream for RateLimitedStream<S>
198where
199 S: Stream<Item = RsllmResult<StreamChunk>>,
200{
201 type Item = S::Item;
202
203 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
204 let mut this = self.project();
205
206 if let Some(last) = this.last_emit {
208 let elapsed = last.elapsed();
209 if elapsed < *this.delay {
210 let remaining = *this.delay - elapsed;
211
212 let sleep = tokio::time::sleep(remaining);
214 tokio::pin!(sleep);
215
216 if sleep.as_mut().poll(cx).is_pending() {
217 return Poll::Pending;
218 }
219 }
220 }
221
222 match this.stream.as_mut().poll_next(cx) {
223 Poll::Ready(Some(item)) => {
224 *this.last_emit = Some(std::time::Instant::now());
225 Poll::Ready(Some(item))
226 }
227 other => other,
228 }
229 }
230}
231
pin_project! {
    /// Stream adapter that yields only the `Ok` chunks accepted by a predicate.
    ///
    /// Errors and end-of-stream from the inner stream are passed through unchanged.
    pub struct FilteredStream<S, F> {
        // Inner stream; structurally pinned so it can be polled through `Pin<&mut Self>`.
        #[pin]
        stream: S,
        // Predicate deciding which chunks are kept (`true` keeps the chunk).
        filter: F,
    }
}
240
241impl<S, F> FilteredStream<S, F>
242where
243 F: Fn(&StreamChunk) -> bool,
244{
245 pub fn new(stream: S, filter: F) -> Self {
247 Self { stream, filter }
248 }
249}
250
251impl<S, F> Stream for FilteredStream<S, F>
252where
253 S: Stream<Item = RsllmResult<StreamChunk>>,
254 F: Fn(&StreamChunk) -> bool,
255{
256 type Item = S::Item;
257
258 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
259 let mut this = self.project();
260
261 loop {
262 match this.stream.as_mut().poll_next(cx) {
263 Poll::Ready(Some(Ok(chunk))) => {
264 if (this.filter)(&chunk) {
265 return Poll::Ready(Some(Ok(chunk)));
266 }
267 }
269 other => return other,
270 }
271 }
272 }
273}
274
pin_project! {
    /// Stream adapter that applies a function to every `Ok` chunk of an inner stream.
    ///
    /// Errors and end-of-stream from the inner stream are passed through unchanged.
    pub struct MappedStream<S, F> {
        // Inner stream; structurally pinned so it can be polled through `Pin<&mut Self>`.
        #[pin]
        stream: S,
        // Transformation applied to each successfully produced chunk.
        mapper: F,
    }
}
283
284impl<S, F> MappedStream<S, F>
285where
286 F: Fn(StreamChunk) -> StreamChunk,
287{
288 pub fn new(stream: S, mapper: F) -> Self {
290 Self { stream, mapper }
291 }
292}
293
294impl<S, F> Stream for MappedStream<S, F>
295where
296 S: Stream<Item = RsllmResult<StreamChunk>>,
297 F: Fn(StreamChunk) -> StreamChunk,
298{
299 type Item = S::Item;
300
301 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
302 let mut this = self.project();
303
304 match this.stream.as_mut().poll_next(cx) {
305 Poll::Ready(Some(Ok(chunk))) => {
306 let mapped = (this.mapper)(chunk);
307 Poll::Ready(Some(Ok(mapped)))
308 }
309 other => other,
310 }
311 }
312}
313
/// Namespace type for stream helper functions; it carries no data and all of its
/// functions are associated functions.
pub struct StreamUtils;
316
317impl StreamUtils {
318 pub fn from_chunks(chunks: Vec<StreamChunk>) -> ChatStream {
320 let stream = tokio_stream::iter(chunks.into_iter().map(Ok));
321 Box::pin(stream)
322 }
323
324 pub fn empty() -> ChatStream {
326 let stream = tokio_stream::empty();
327 Box::pin(stream)
328 }
329
330 pub fn error(error: RsllmError) -> ChatStream {
332 use futures_util::stream;
333 let stream = stream::once(async move { Err(error) });
334 Box::pin(stream)
335 }
336
337 pub async fn collect_chunks<S>(stream: S) -> RsllmResult<Vec<StreamChunk>>
339 where
340 S: Stream<Item = RsllmResult<StreamChunk>>,
341 {
342 tokio_stream::StreamExt::collect::<Vec<_>>(stream)
343 .await
344 .into_iter()
345 .collect::<RsllmResult<Vec<_>>>()
346 }
347
348 pub fn take<S>(stream: S, n: usize) -> impl Stream<Item = RsllmResult<StreamChunk>>
350 where
351 S: Stream<Item = RsllmResult<StreamChunk>>,
352 {
353 tokio_stream::StreamExt::take(stream, n)
354 }
355
356 pub fn skip<S>(stream: S, n: usize) -> impl Stream<Item = RsllmResult<StreamChunk>>
358 where
359 S: Stream<Item = RsllmResult<StreamChunk>>,
360 {
361 tokio_stream::StreamExt::skip(stream, n)
362 }
363
364 pub fn filter<S, F>(stream: S, filter: F) -> FilteredStream<S, F>
366 where
367 S: Stream<Item = RsllmResult<StreamChunk>>,
368 F: Fn(&StreamChunk) -> bool,
369 {
370 FilteredStream::new(stream, filter)
371 }
372
373 pub fn map<S, F>(stream: S, mapper: F) -> MappedStream<S, F>
375 where
376 S: Stream<Item = RsllmResult<StreamChunk>>,
377 F: Fn(StreamChunk) -> StreamChunk,
378 {
379 MappedStream::new(stream, mapper)
380 }
381
382 pub fn rate_limit<S>(stream: S, max_chunks_per_second: f64) -> RateLimitedStream<S>
384 where
385 S: Stream<Item = RsllmResult<StreamChunk>>,
386 {
387 RateLimitedStream::new(stream, max_chunks_per_second)
388 }
389
390 pub async fn buffer<S>(
392 mut stream: S,
393 max_size: usize,
394 ) -> RsllmResult<Vec<StreamChunk>>
395 where
396 S: Stream<Item = RsllmResult<StreamChunk>> + Unpin,
397 {
398 let mut chunks = Vec::new();
399 let mut count = 0;
400
401 use futures_util::StreamExt;
402 while let Some(chunk) = stream.next().await {
403 chunks.push(chunk?);
404 count += 1;
405
406 if count >= max_size {
407 break;
408 }
409 }
410
411 Ok(chunks)
412 }
413}
414
415pub trait RsllmStreamExt: Stream<Item = RsllmResult<StreamChunk>> + Sized {
417 fn collect_chat_response(self) -> impl std::future::Future<Output = RsllmResult<ChatResponse>> + Send
419 where
420 Self: Send + Unpin,
421 {
422 StreamCollector::new(self).collect_chat_response()
423 }
424
425 fn collect_completion_response(self) -> impl std::future::Future<Output = RsllmResult<CompletionResponse>> + Send
427 where
428 Self: Send + Unpin,
429 {
430 StreamCollector::new(self).collect_completion_response()
431 }
432
433 fn content_only(self) -> FilteredStream<Self, fn(&StreamChunk) -> bool> {
435 FilteredStream::new(self, |chunk| chunk.has_content())
436 }
437
438 fn exclude_done(self) -> FilteredStream<Self, fn(&StreamChunk) -> bool> {
440 FilteredStream::new(self, |chunk| !chunk.is_done)
441 }
442
443 fn rate_limit(self, max_chunks_per_second: f64) -> RateLimitedStream<Self> {
445 RateLimitedStream::new(self, max_chunks_per_second)
446 }
447}
448
// Blanket implementation: every stream of `RsllmResult<StreamChunk>` gets the
// extension methods without any per-type impl.
impl<S> RsllmStreamExt for S where S: Stream<Item = RsllmResult<StreamChunk>> {}
450
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{MessageRole, StreamChunk};

    /// Collecting two deltas and a final chunk yields the joined content, the
    /// model name and the finish reason.
    #[tokio::test]
    async fn test_stream_collector() {
        let first = StreamChunk::delta("Hello", "gpt-4").with_role(MessageRole::Assistant);
        let second = StreamChunk::delta(" world", "gpt-4");
        let last = StreamChunk::done("gpt-4").with_finish_reason("stop");

        let response = StreamUtils::from_chunks(vec![first, second, last])
            .collect_chat_response()
            .await
            .unwrap();

        assert_eq!(response.content, "Hello world");
        assert_eq!(response.model, "gpt-4");
        assert_eq!(response.finish_reason, Some("stop".to_string()));
    }

    /// `content_only` drops the chunk with empty content and keeps the other two.
    #[tokio::test]
    async fn test_filter_stream() {
        use futures_util::StreamExt;

        let input = vec![
            StreamChunk::delta("Hello", "gpt-4"),
            StreamChunk::new("", "gpt-4", false, false),
            StreamChunk::delta(" world", "gpt-4"),
        ];

        let mut only_content = StreamUtils::from_chunks(input).content_only();
        let mut kept = Vec::new();
        while let Some(item) = only_content.next().await {
            kept.push(item.unwrap());
        }

        assert_eq!(kept.len(), 2);
        assert_eq!(kept[0].content, "Hello");
        assert_eq!(kept[1].content, " world");
    }

    /// `StreamUtils::map` applies the mapper to every chunk.
    #[tokio::test]
    async fn test_map_stream() {
        let input = vec![
            StreamChunk::delta("hello", "gpt-4"),
            StreamChunk::delta(" world", "gpt-4"),
        ];

        let upper = StreamUtils::map(StreamUtils::from_chunks(input), |mut chunk| {
            chunk.content = chunk.content.to_uppercase();
            chunk
        });

        let collected = StreamUtils::collect_chunks(upper).await.unwrap();

        assert_eq!(collected[0].content, "HELLO");
        assert_eq!(collected[1].content, " WORLD");
    }
}