openai_openapi_http/
event_stream.rs1use crate::{ApiError, Error};
2use futures::Stream;
3use http_body_util::BodyExt;
4use serde::Deserialize;
5use std::convert::Infallible;
6use std::marker::PhantomData;
7use std::pin::Pin;
8use std::task::{Context, Poll, ready};
9
10#[pin_project::pin_project]
11pub struct EventStream<B, T>(
12 #[pin]
13 #[allow(clippy::type_complexity)]
14 http_body_server_sent_events::Decode<
15 http_body_util::combinators::MapErr<B, fn(B::Error) -> Error<Infallible, B::Error>>,
16 >,
17 PhantomData<fn() -> T>,
18)
19where
20 B: http_body::Body,
21 T: for<'de> Deserialize<'de>;
22
23impl<B, T> EventStream<B, T>
24where
25 B: http_body::Body,
26 T: for<'de> Deserialize<'de>,
27{
28 pub(crate) fn new(response: http::Response<B>) -> Self {
29 let (parts, body) = response.into_parts();
30 tracing::debug!(response.parts = ?parts);
31 Self(
32 http_body_server_sent_events::decode(body.map_err(Error::Body)),
33 PhantomData,
34 )
35 }
36}
37
38impl<B, T> Stream for EventStream<B, T>
39where
40 B: http_body::Body,
41 T: for<'de> Deserialize<'de>,
42{
43 type Item = Result<T, Error<Infallible, B::Error>>;
44
45 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
46 let mut this = self.project();
47 loop {
48 match ready!(this.0.as_mut().poll_frame(cx)?) {
49 Some(frame) => {
50 let Ok(event) = frame.into_data() else {
51 continue;
52 };
53 tracing::debug!(?event);
54 let Some(data) = event.data else { continue };
55 if data == "[DONE]" {
56 break Poll::Ready(None);
57 } else if event.event.is_some_and(|event| event == "error") {
58 let openai_openapi_types::ErrorResponse {
59 error: openai_openapi_types::Error { code, message, .. },
60 } = serde_json::from_str(&data)?;
61 break Poll::Ready(Some(Err(Error::Api(ApiError {
62 code,
63 message: Some(message),
64 response: None,
65 }))));
66 } else {
67 break Poll::Ready(Some(Ok(serde_json::from_str(&data)?)));
68 };
69 }
70 None => break Poll::Ready(None),
71 }
72 }
73 }
74}