use crate::types::*;
use async_sse::Decoder;
use bytes::Bytes;
use futures_util::{
stream::{IntoAsyncRead, MapErr, MapOk},
Stream, TryFutureExt, TryStreamExt,
};
use pin_project_lite::pin_project;
use reqwest::header::{self, HeaderValue};
use serde::{de::DeserializeOwned, Serialize};
use std::{
future::Future,
io,
pin::Pin,
task::{ready, Context, Poll},
time::Duration,
};
use tracing::{debug, trace, warn};
/// Function-pointer type that turns a `reqwest` body error into an `io::Error`, which is the
/// error type required to adapt the byte stream into an `AsyncRead` (see `new_stream`).
type TryIo = fn(reqwest::Error) -> io::Error;
/// Function-pointer type that maps each decoded `async_sse::Event` either to a JSON-deserialized
/// `T` or to a retry request.
type TryOk<T> = fn(async_sse::Event) -> serde_json::Result<EventOrRetry<T>>;
/// Boxed, type-erased stream of response body chunks as produced by `reqwest`.
type ReqStream = Pin<Box<dyn Stream<Item = reqwest::Result<Bytes>> + Send>>;
/// The full decoding pipeline: body bytes -> `AsyncRead` -> SSE events -> `EventOrRetry<T>`.
type SseDecoderStream<T> = MapOk<Decoder<IntoAsyncRead<MapErr<ReqStream, TryIo>>>, TryOk<T>>;
#[derive(Debug, Clone)]
pub struct EventClient {
client: reqwest::Client,
max_retries: Option<u64>,
}
impl EventClient {
pub fn new(client: reqwest::Client) -> Self {
Self { client, max_retries: None }
}
pub fn with_max_retries(mut self, max_retries: u64) -> Self {
self.set_max_retries(max_retries);
self
}
pub fn set_max_retries(&mut self, max_retries: u64) {
self.max_retries = Some(max_retries);
}
pub async fn subscribe<T: DeserializeOwned>(
&self,
endpoint: &str,
) -> reqwest::Result<EventStream<T>> {
let st = new_stream(&self.client, endpoint, None::<()>).await?;
let endpoint = endpoint.to_string();
let inner =
EventStreamInner { num_retries: 0, endpoint, client: self.clone(), query: None };
let st = EventStream { inner, state: Some(State::Active(Box::pin(st))) };
Ok(st)
}
pub async fn subscribe_with_query<T: DeserializeOwned, S: Serialize>(
&self,
endpoint: &str,
query: S,
) -> reqwest::Result<EventStream<T>> {
let query = Some(serde_json::to_value(query).expect("serialization failed"));
let st = new_stream(&self.client, endpoint, query.as_ref()).await?;
let endpoint = endpoint.to_string();
let inner = EventStreamInner { num_retries: 0, endpoint, client: self.clone(), query };
let st = EventStream { inner, state: Some(State::Active(Box::pin(st))) };
Ok(st)
}
pub async fn events(&self, endpoint: &str) -> reqwest::Result<EventStream<Event>> {
self.subscribe(endpoint).await
}
pub async fn event_history(
&self,
endpoint: &str,
params: EventHistoryParams,
) -> reqwest::Result<Vec<EventHistory>> {
self.client.get(endpoint).query(¶ms).send().await?.json().await
}
pub async fn event_history_info(&self, endpoint: &str) -> reqwest::Result<EventHistoryInfo> {
self.get_json(endpoint).await
}
async fn get_json<T: DeserializeOwned>(&self, endpoint: &str) -> reqwest::Result<T> {
self.client.get(endpoint).send().await?.json().await
}
}
impl Default for EventClient {
fn default() -> Self {
Self::new(Default::default())
}
}
/// A stream of server-sent events whose data is deserialized from JSON into `T`.
///
/// Created by [`EventClient::subscribe`] and friends. When the server sends a retry event the
/// stream waits for the requested delay and reconnects to the same endpoint by itself; the
/// connection can also be re-established explicitly with [`EventStream::retry`] or
/// [`EventStream::retry_with`].
#[must_use = "streams do nothing unless polled"]
pub struct EventStream<T> {
    // Endpoint, optional query, client and retry counter used to (re)connect.
    inner: EventStreamInner,
    // Current connection state; taken out temporarily while the stream is being polled.
    state: Option<State<T>>,
}

impl<T> EventStream<T> {
    /// Returns the endpoint this stream connects (and reconnects) to.
    pub fn endpoint(&self) -> &str {
        self.inner.endpoint.as_str()
    }

    /// Sets the reconnect counter back to zero, restoring the full retry budget.
    pub fn reset_retries(&mut self) {
        let inner = &mut self.inner;
        inner.num_retries = 0;
    }
}

impl<T: DeserializeOwned> EventStream<T> {
    /// Reconnects to the current endpoint, replacing whatever connection state was active.
    ///
    /// # Errors
    ///
    /// [`SseError::MaxRetriesExceeded`] if the retry limit is exhausted, or
    /// [`SseError::RetryError`] if the new connection cannot be established.
    pub async fn retry(&mut self) -> Result<(), SseError> {
        let active = self.inner.retry::<T>().await?;
        self.state = Some(State::Active(Box::pin(active)));
        Ok(())
    }

    /// Switches to `endpoint` and then reconnects exactly like [`EventStream::retry`].
    ///
    /// The new endpoint is kept even if reconnecting fails.
    ///
    /// # Errors
    ///
    /// Same as [`EventStream::retry`].
    pub async fn retry_with(&mut self, endpoint: impl Into<String>) -> Result<(), SseError> {
        self.inner.endpoint = endpoint.into();
        self.retry().await
    }
}
impl<T: DeserializeOwned> Stream for EventStream<T> {
type Item = Result<T, SseError>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
let mut res = Poll::Pending;
loop {
match this.state.take().expect("EventStream polled after completion") {
State::End => return Poll::Ready(None),
State::Retry(mut fut) => match fut.as_mut().poll(cx) {
Poll::Ready(Ok(st)) => {
this.state = Some(State::Active(Box::pin(st)));
continue
}
Poll::Ready(Err(err)) => {
this.state = Some(State::End);
return Poll::Ready(Some(Err(err)))
}
Poll::Pending => {
this.state = Some(State::Retry(fut));
return Poll::Pending
}
},
State::Active(mut st) => {
match st.as_mut().poll_next(cx) {
Poll::Ready(None) => {
this.state = Some(State::End);
debug!("stream finished");
return Poll::Ready(None)
}
Poll::Ready(Some(Ok(maybe_event))) => match maybe_event {
EventOrRetry::Event(event) => {
res = Poll::Ready(Some(Ok(event)));
}
EventOrRetry::Retry(duration) => {
let mut client = this.inner.clone();
let fut = Box::pin(async move {
tokio::time::sleep(duration).await;
client.retry().await
});
this.state = Some(State::Retry(fut));
continue
}
},
Poll::Ready(Some(Err(err))) => {
warn!(?err, "active stream error");
res = Poll::Ready(Some(Err(err)));
}
Poll::Pending => {}
}
this.state = Some(State::Active(st));
break
}
}
}
res
}
}
impl<T> std::fmt::Debug for EventStream<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("EventStream")
.field("endpoint", &self.inner.endpoint)
.field("num_retries", &self.inner.num_retries)
.field("client", &self.inner.client.client)
.finish_non_exhaustive()
}
}
/// Connection state of an [`EventStream`].
enum State<T> {
    /// Terminal state, entered when the underlying stream finished or a reconnect attempt failed.
    End,
    /// A pending reconnect. In `poll_next` this future sleeps for the server-requested delay
    /// and then reconnects.
    Retry(Pin<Box<dyn Future<Output = Result<ActiveEventStream<T>, SseError>> + Send>>),
    /// Connected and decoding events.
    Active(Pin<Box<ActiveEventStream<T>>>),
}
/// Connection parameters and retry bookkeeping owned by an [`EventStream`]. A clone of it is
/// moved into each automatic reconnect future.
#[derive(Clone)]
struct EventStreamInner {
    /// Reconnect attempts made so far (cleared by `EventStream::reset_retries`).
    num_retries: u64,
    /// SSE endpoint URL.
    endpoint: String,
    /// HTTP client and retry limit used when reconnecting.
    client: EventClient,
    /// Optional query, already converted to JSON, re-sent on every reconnect.
    query: Option<serde_json::Value>,
}

impl EventStreamInner {
    /// Records one more reconnect attempt and, if the client's retry limit still allows it,
    /// opens a fresh connection using the stored endpoint and query.
    ///
    /// # Errors
    ///
    /// - [`SseError::MaxRetriesExceeded`] once the attempt count passes the limit. The counter
    ///   has still been incremented at that point.
    /// - [`SseError::RetryError`] if the new connection cannot be established.
    async fn retry<T: DeserializeOwned>(&mut self) -> Result<ActiveEventStream<T>, SseError> {
        self.num_retries += 1;
        let attempt = self.num_retries;
        match self.client.max_retries {
            Some(limit) if attempt > limit => return Err(SseError::MaxRetriesExceeded(limit)),
            _ => {}
        }
        debug!(retries = attempt, "retrying SSE stream");
        let Self { client, endpoint, query, .. } = &*self;
        new_stream(&client.client, endpoint, query.as_ref()).map_err(SseError::RetryError).await
    }
}
pin_project! {
    /// A single live SSE connection: the decoded event stream of one HTTP response.
    struct ActiveEventStream<T> {
        // Structurally pinned so it can be polled through `Pin<&mut Self>`.
        #[pin]
        st: SseDecoderStream<T>
    }
}

impl<T: DeserializeOwned> Stream for ActiveEventStream<T> {
    type Item = Result<EventOrRetry<T>, SseError>;

    /// Polls the decoder and flattens its nested result. Decoder failures become
    /// [`SseError::Http`] and JSON failures become [`SseError::SerdeJsonError`].
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let decoded = ready!(self.project().st.poll_next(cx));
        let item = decoded.map(|outcome| match outcome {
            Ok(parsed) => parsed.map_err(SseError::SerdeJsonError),
            Err(decode_err) => Err(SseError::Http(decode_err)),
        });
        Poll::Ready(item)
    }
}
/// Opens an SSE connection to `endpoint` and wraps the response body in an event decoder.
///
/// The request sends `Accept: text/event-stream` and `Cache-Control: no-cache`. If `query` is
/// present, it is serialized into the URL query string. Each SSE message's data is
/// deserialized from JSON into `T`, and SSE retry events are surfaced as
/// [`EventOrRetry::Retry`] so the caller can decide how to reconnect.
///
/// # Errors
///
/// Returns a [`reqwest::Error`] if the request cannot be sent or if the server answers with a
/// client or server error status (4xx/5xx).
async fn new_stream<T: DeserializeOwned, S: Serialize>(
    client: &reqwest::Client,
    endpoint: &str,
    query: Option<S>,
) -> reqwest::Result<ActiveEventStream<T>> {
    let mut builder = client
        .get(endpoint)
        .header(header::ACCEPT, HeaderValue::from_static("text/event-stream"))
        .header(header::CACHE_CONTROL, HeaderValue::from_static("no-cache"));
    if let Some(query) = query {
        builder = builder.query(&query);
    }
    // `send` only fails on transport errors. Without this check an error response (e.g. an
    // HTML 404 page) would be fed to the SSE decoder, and the failure would surface late, if
    // at all, as deserialization errors or an immediately-ending stream.
    let resp = builder.send().await?.error_for_status()?;
    // `into_async_read` requires the stream's error type to be `io::Error`.
    let map_io_err: TryIo = |e| io::Error::new(io::ErrorKind::Other, e);
    let o: TryOk<_> = |e| match e {
        async_sse::Event::Message(msg) => {
            trace!(
                message = ?String::from_utf8_lossy(msg.data()),
                "received message"
            );
            serde_json::from_slice::<T>(msg.data()).map(EventOrRetry::Event)
        }
        async_sse::Event::Retry(duration) => Ok(EventOrRetry::Retry(duration)),
    };
    let event_stream: ReqStream = Box::pin(resp.bytes_stream());
    let st = async_sse::decode(event_stream.map_err(map_io_err).into_async_read()).map_ok(o);
    Ok(ActiveEventStream { st })
}
/// Interpretation of one decoded SSE event.
enum EventOrRetry<T> {
    /// The decoder produced an `async_sse::Event::Retry` with this delay. `EventStream` sleeps
    /// for it and then reconnects.
    Retry(Duration),
    /// A message whose data was successfully deserialized as `T`.
    Event(T),
}
/// Errors yielded by [`EventStream`] and its retry methods.
#[derive(Debug, thiserror::Error)]
pub enum SseError {
    /// A message's data could not be deserialized from JSON into the requested type.
    #[error("Failed to deserialize event: {0}")]
    SerdeJsonError(serde_json::Error),
    /// The SSE decoder reported an error while reading or parsing the response body.
    #[error("{0}")]
    Http(http_types::Error),
    /// Opening a new connection during a retry failed.
    #[error("Failed to establish a retry connection: {0}")]
    RetryError(reqwest::Error),
    /// The configured retry limit (carried in the variant) was exceeded.
    #[error("Exceeded all retries: {0}")]
    MaxRetriesExceeded(u64),
}
#[cfg(test)]
mod tests {
    use super::*;
    use tracing_subscriber::{fmt, prelude::*, EnvFilter};

    // Live mev-share.flashbots.net endpoints. The tests below need network access and are
    // therefore `#[ignore]`d by default.
    const HISTORY_V1: &str = "https://mev-share.flashbots.net/api/v1/history";
    const HISTORY_INFO_V1: &str = "https://mev-share.flashbots.net/api/v1/history/info";

    /// Installs a global `fmt` subscriber filtered by `EnvFilter::from_default_env`.
    fn init_tracing() {
        let subscriber =
            tracing_subscriber::registry().with(fmt::layer()).with(EnvFilter::from_default_env());
        // Every test calls this but only the first install can succeed, so the error is ignored.
        let _ = subscriber.try_init();
    }

    #[tokio::test]
    #[ignore]
    async fn get_event_history_info() {
        init_tracing();
        let client = EventClient::default();
        let response = client.event_history_info(HISTORY_INFO_V1).await;
        let _info = response.unwrap();
    }

    #[tokio::test]
    #[ignore]
    async fn get_event_history() {
        init_tracing();
        let client = EventClient::default();
        let response = client.event_history(HISTORY_V1, Default::default()).await;
        let _history = response.unwrap();
    }
}