openai_openapi_http/
event_stream.rs

1use crate::{ApiError, Error};
2use futures::Stream;
3use http_body_util::BodyExt;
4use serde::Deserialize;
5use std::convert::Infallible;
6use std::marker::PhantomData;
7use std::pin::Pin;
8use std::task::{Context, Poll, ready};
9
10#[pin_project::pin_project]
11pub struct EventStream<B, T>(
12    #[pin]
13    #[allow(clippy::type_complexity)]
14    http_body_server_sent_events::Decode<
15        http_body_util::combinators::MapErr<B, fn(B::Error) -> Error<Infallible, B::Error>>,
16    >,
17    PhantomData<fn() -> T>,
18)
19where
20    B: http_body::Body,
21    T: for<'de> Deserialize<'de>;
22
23impl<B, T> EventStream<B, T>
24where
25    B: http_body::Body,
26    T: for<'de> Deserialize<'de>,
27{
28    pub(crate) fn new(response: http::Response<B>) -> Self {
29        let (parts, body) = response.into_parts();
30        tracing::debug!(response.parts = ?parts);
31        Self(
32            http_body_server_sent_events::decode(body.map_err(Error::Body)),
33            PhantomData,
34        )
35    }
36}
37
38impl<B, T> Stream for EventStream<B, T>
39where
40    B: http_body::Body,
41    T: for<'de> Deserialize<'de>,
42{
43    type Item = Result<T, Error<Infallible, B::Error>>;
44
45    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
46        let mut this = self.project();
47        loop {
48            match ready!(this.0.as_mut().poll_frame(cx)?) {
49                Some(frame) => {
50                    let Ok(event) = frame.into_data() else {
51                        continue;
52                    };
53                    tracing::debug!(?event);
54                    let Some(data) = event.data else { continue };
55                    if data == "[DONE]" {
56                        break Poll::Ready(None);
57                    } else if event.event.is_some_and(|event| event == "error") {
58                        let openai_openapi_types::ErrorResponse {
59                            error: openai_openapi_types::Error { code, message, .. },
60                        } = serde_json::from_str(&data)?;
61                        break Poll::Ready(Some(Err(Error::Api(ApiError {
62                            code,
63                            message: Some(message),
64                            response: None,
65                        }))));
66                    } else {
67                        break Poll::Ready(Some(Ok(serde_json::from_str(&data)?)));
68                    };
69                }
70                None => break Poll::Ready(None),
71            }
72        }
73    }
74}