use crate::ApiError;
use futures::Stream;
use pin_project::pin_project;
use reqwest::{header::CONTENT_TYPE, Response};
use reqwest_sse::{error::EventError, Event, EventSource};
use serde::de::DeserializeOwned;
use std::{
marker::PhantomData,
pin::Pin,
task::{Context, Poll},
};
/// Protocol-level fields of a single SSE event, separate from its data payload.
#[derive(Debug, Clone)]
pub struct SseMetadata {
/// Event type; set to `"message"` when the received event's type is empty.
pub event: String,
/// The event's last event ID, if one was present.
pub id: Option<String>,
/// The event's retry interval in milliseconds, if one was present.
pub retry: Option<u64>,
}
/// An SSE event whose data has been deserialized into `T`, together with the
/// event's metadata.
#[derive(Debug)]
pub struct SseEvent<T> {
/// Payload deserialized from the event's data as JSON.
pub data: T,
/// Event type, ID and retry interval that accompanied the payload.
pub metadata: SseMetadata,
}
/// Stream adapter that turns the events of an SSE response into values of type
/// `T`, deserialized from each event's JSON data.
#[pin_project]
pub struct SseStream<T> {
// Boxed source of raw SSE events (or event-level errors).
#[pin]
inner: Pin<Box<dyn Stream<Item = Result<Event, EventError>> + Send>>,
// Event data that ends the stream when an event matches it exactly;
// `None` disables the check.
terminator: Option<String>,
// Ties the stream to its output type `T` without storing a value of it.
_phantom: PhantomData<T>,
}
impl<T> SseStream<T>
where
    T: DeserializeOwned,
{
    /// Wraps an HTTP response in a typed SSE stream.
    ///
    /// The media type of the response's `Content-Type` header (the part before
    /// any `;` parameters, compared ASCII case-insensitively) must be
    /// `text/event-stream`. When `terminator` is set, an event whose data
    /// equals it exactly ends the stream.
    ///
    /// # Errors
    ///
    /// Returns [`ApiError::SseParseError`] when the content type does not
    /// match, or when the response cannot be turned into an event stream.
    pub(crate) async fn new(
        response: Response,
        terminator: Option<String>,
    ) -> Result<Self, ApiError> {
        // A missing or unreadable header is treated as an empty content type.
        let header_value = match response.headers().get(CONTENT_TYPE) {
            Some(value) => value.to_str().unwrap_or(""),
            None => "",
        };

        // Drop any parameters such as `; charset=utf-8` before comparing.
        let media_type = header_value
            .split_once(';')
            .map_or(header_value, |(main, _params)| main)
            .trim();

        let is_event_stream = media_type.eq_ignore_ascii_case("text/event-stream");
        if !is_event_stream {
            let message = format!(
                "Expected Content-Type to be 'text/event-stream', got '{}'",
                header_value
            );
            return Err(ApiError::SseParseError(message));
        }

        match response.events().await {
            Ok(events) => Ok(Self {
                inner: Box::pin(events),
                terminator,
                _phantom: PhantomData,
            }),
            Err(e) => Err(ApiError::SseParseError(e.to_string())),
        }
    }
}
impl<T> SseStream<T>
where
    T: DeserializeOwned,
{
    /// Consumes this stream and returns one that yields [`SseEvent`] items,
    /// pairing each deserialized payload with the event's [`SseMetadata`].
    pub fn with_metadata(self) -> SseStreamWithMetadata<T> {
        let inner = self;
        SseStreamWithMetadata { inner }
    }
}
impl<T> Stream for SseStream<T>
where
    T: DeserializeOwned,
{
    type Item = Result<T, ApiError>;

    /// Polls the underlying event source and yields the next payload
    /// deserialized from JSON.
    ///
    /// Yields `None` when the source is exhausted or when an event's data
    /// equals the configured terminator. Event-source failures are reported as
    /// [`ApiError::SseParseError`] and JSON failures as
    /// [`ApiError::Serialization`].
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();

        // Peel off every outcome that does not carry an event.
        let event = match this.inner.poll_next(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(None) => return Poll::Ready(None),
            Poll::Ready(Some(Err(source_err))) => {
                let wrapped = ApiError::SseParseError(source_err.to_string());
                return Poll::Ready(Some(Err(wrapped)));
            }
            Poll::Ready(Some(Ok(event))) => event,
        };

        // The terminator event ends the stream and is never deserialized.
        if this.terminator.as_deref() == Some(event.data.as_str()) {
            return Poll::Ready(None);
        }

        let parsed = serde_json::from_str::<T>(&event.data).map_err(ApiError::Serialization);
        Poll::Ready(Some(parsed))
    }
}
/// Stream adapter over [`SseStream`] that yields [`SseEvent`] items: the
/// deserialized data together with the event's metadata.
#[pin_project]
pub struct SseStreamWithMetadata<T> {
// The wrapped stream; its event source and terminator are polled directly.
#[pin]
inner: SseStream<T>,
}
impl<T> Stream for SseStreamWithMetadata<T>
where
    T: DeserializeOwned,
{
    type Item = Result<SseEvent<T>, ApiError>;

    /// Polls the wrapped stream's event source and yields the next event as an
    /// [`SseEvent`]: the payload deserialized from JSON plus its metadata.
    ///
    /// Yields `None` when the source is exhausted or when an event's data
    /// equals the configured terminator. Event-source failures are reported as
    /// [`ApiError::SseParseError`] and JSON failures as
    /// [`ApiError::Serialization`].
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        // Reach through to the raw event source so the event's metadata is
        // still available; the wrapped stream's own `poll_next` discards it.
        let inner_pin = this.inner.project();

        let event = match inner_pin.inner.poll_next(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(None) => return Poll::Ready(None),
            Poll::Ready(Some(Err(e))) => {
                return Poll::Ready(Some(Err(ApiError::SseParseError(e.to_string()))));
            }
            Poll::Ready(Some(Ok(event))) => event,
        };

        // The terminator event ends the stream and is never deserialized.
        if inner_pin.terminator.as_deref() == Some(event.data.as_str()) {
            return Poll::Ready(None);
        }

        // Parse first: metadata is only needed when the payload is valid.
        let data = match serde_json::from_str::<T>(&event.data) {
            Ok(data) => data,
            Err(e) => return Poll::Ready(Some(Err(ApiError::Serialization(e)))),
        };

        // The event is owned here, so its fields are moved rather than cloned.
        let metadata = SseMetadata {
            event: if event.event_type.is_empty() {
                "message".to_string()
            } else {
                event.event_type
            },
            id: event.last_event_id,
            // `as_millis` returns `u128`; saturate instead of silently
            // truncating with an `as` cast.
            retry: event
                .retry
                .map(|d| u64::try_from(d.as_millis()).unwrap_or(u64::MAX)),
        };

        Poll::Ready(Some(Ok(SseEvent { data, metadata })))
    }
}