async_llm/http/
stream.rs

1use std::pin::Pin;
2
3use futures::{Stream, StreamExt};
4use reqwest_eventsource::{Event, EventSource};
5use serde::de::DeserializeOwned;
6use tokio::sync::mpsc;
7use tokio_stream::wrappers::UnboundedReceiverStream;
8
9use crate::error::Error;
10
11pub async fn stream<O: DeserializeOwned + Send + 'static>(
12    mut event_source: EventSource,
13    stream_done_message: &'static str,
14) -> Result<Pin<Box<dyn Stream<Item = Result<O, Error>> + Send>>, Error> {
15    let (tx, rx) = mpsc::unbounded_channel();
16
17    tokio::spawn(async move {
18        while let Some(event) = event_source.next().await {
19            match event {
20                Err(e) => {
21                    if let Err(_) = tx.send(Err(Error::Stream(e.to_string()))) {
22                        break;
23                    }
24                }
25                Ok(event) => match event {
26                    Event::Open => continue,
27                    Event::Message(event) => {
28                        if event.data == stream_done_message {
29                            break;
30                        }
31
32                        let output: Result<O, Error> =
33                            serde_json::from_str::<O>(&event.data).map_err(|e| e.into());
34                        if let Err(_) = tx.send(output) {
35                            break;
36                        }
37                    }
38                },
39            }
40        }
41        event_source.close();
42    });
43
44    Ok(Box::pin(UnboundedReceiverStream::new(rx)))
45}