use std::{
error::Error,
fmt::Display,
future::Future,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use self::sealed::NdjsonError;
use futures::{
future,
stream::{Stream, TryStream, TryStreamExt},
StreamExt,
};
use pin_project::pin_project;
use tokio::time::{self, Sleep};
/// Owned byte buffer used for serialized NDJSON records and heartbeat payloads.
type Bytes = Vec<u8>;
/// Default record terminator and heartbeat payload: a single line feed.
static DELIM: &[u8] = &[b'\n'];
pub fn reply<S>(event_stream: S) -> impl warp::Reply
where
S: TryStream<Ok = Bytes> + Send + 'static,
S::Error: Error + Send + Sync + 'static,
{
NdjsonReply { event_stream }
}
/// Reply type returned by `reply`: holds the byte stream until `into_response`
/// turns it into a streaming HTTP body.
struct NdjsonReply<S> {
    /// Fallible stream of byte chunks written to the response body as-is.
    event_stream: S,
}
impl<S> warp::Reply for NdjsonReply<S>
where
    S: TryStream<Ok = Bytes> + Send + 'static,
    S::Error: Error + Send + Sync + 'static,
{
    /// Turn the wrapped stream into a streaming response tagged
    /// `Content-Type: application/x-ndjson`.
    ///
    /// Each `Ok` chunk becomes a body chunk unchanged. Each upstream error is
    /// logged and replaced with the opaque `NdjsonError`, which hyper accepts
    /// as a boxed `std::error::Error`.
    #[inline]
    fn into_response(self) -> warp::reply::Response {
        // `IntoStream<MapErr<..>>` already yields `Result<Vec<u8>, NdjsonError>`,
        // which satisfies `Body::wrap_stream`. No identity `and_then` is needed.
        let body_stream = self
            .event_stream
            .map_err(|error| {
                tracing::error!(?error, "Error converting to Ndjson");
                NdjsonError
            })
            .into_stream();
        let mut res = warp::reply::Response::new(hyper::Body::wrap_stream(body_stream));
        res.headers_mut().insert(
            warp::http::header::CONTENT_TYPE,
            warp::http::header::HeaderValue::from_static("application/x-ndjson"),
        );
        res
    }
}
/// Configuration for turning a stream of serializable events into NDJSON with
/// keep-alive heartbeats.
///
/// Obtain one from `keep_alive()`, adjust it with the builder methods, then
/// call `KeepAlive::stream`. `stream` consumes the configuration, so it derives
/// `Clone` to let one configured instance serve several streams.
#[derive(Debug, Clone)]
pub struct KeepAlive {
    /// Longest idle period before a bare delimiter is emitted as a heartbeat.
    max_interval: Duration,
    /// Bytes appended after every serialized record; also sent alone as the heartbeat.
    delimiter: Bytes,
    /// Initial capacity, in bytes, of the buffer each event is serialized into.
    writer_capacity: usize,
}
impl KeepAlive {
    /// Set the longest idle period after which a bare delimiter is sent as a heartbeat.
    // NOTE(review): `dead_code` is presumably allowed because not every builder
    // method is called inside this crate. Confirm before removing.
    #[allow(dead_code)]
    #[must_use]
    pub fn interval(mut self, time: Duration) -> Self {
        self.max_interval = time;
        self
    }

    /// Set the bytes that terminate each record and that form the heartbeat payload.
    #[allow(dead_code)]
    #[must_use]
    pub fn delimiter(mut self, delim: Bytes) -> Self {
        self.delimiter = delim;
        self
    }

    /// Set the initial capacity of the per-event serialization buffer.
    #[allow(dead_code)]
    #[must_use]
    pub fn writer_capacity(mut self, capacity: usize) -> Self {
        self.writer_capacity = capacity;
        self
    }

    /// Wrap `event_stream` into an NDJSON byte stream with keep-alive heartbeats.
    ///
    /// Each event is serialized with `serde_json::to_writer` (compact JSON) into a
    /// buffer preallocated with `writer_capacity` bytes, then terminated with the
    /// delimiter. If no event is ready for `max_interval`, the delimiter alone is
    /// emitted instead. Serialization failures are logged and yielded as an
    /// opaque error item.
    ///
    /// # Panics
    ///
    /// The heartbeat timer is created here with `tokio::time::sleep`, which panics
    /// when called outside a Tokio runtime that has the time driver enabled. For
    /// the same reason, the first heartbeat interval is measured from this call,
    /// not from the first poll.
    #[must_use = "streams do nothing unless polled"]
    pub fn stream<S>(
        self,
        event_stream: impl Stream<Item = S> + Send + 'static,
    ) -> impl TryStream<Ok = Bytes, Error = impl Error + Send + Sync + 'static> + Send + 'static
    where
        S: serde::Serialize + Send + 'static,
    {
        let alive_timer = time::sleep(self.max_interval);
        // The serializing closure needs its own copy, because `self.delimiter`
        // moves into the adapter below.
        let delimiter = self.delimiter.clone();
        let capacity = self.writer_capacity;
        let event_stream = event_stream.map(move |e| {
            let mut writer = Vec::with_capacity(capacity);
            serde_json::to_writer(&mut writer, &e)?;
            writer.extend_from_slice(&delimiter);
            Ok::<Bytes, serde_json::error::Error>(writer)
        });
        NdjsonKeepAlive {
            event_stream,
            max_interval: self.max_interval,
            delimiter: self.delimiter,
            alive_timer,
        }
    }
}
/// Stream adapter that forwards already-encoded lines from `event_stream` and,
/// while the upstream is pending, emits `delimiter` as a heartbeat once
/// `alive_timer` fires.
#[pin_project]
struct NdjsonKeepAlive<S> {
    /// Upstream of encoded NDJSON lines (structurally pinned).
    #[pin]
    event_stream: S,
    /// Delay used to re-arm `alive_timer` after each emitted line or heartbeat.
    max_interval: Duration,
    /// Heartbeat payload yielded when the timer fires.
    delimiter: Bytes,
    /// Deadline for the next heartbeat (structurally pinned).
    #[pin]
    alive_timer: Sleep,
}
pub fn keep_alive() -> KeepAlive {
KeepAlive {
max_interval: Duration::from_secs(15),
delimiter: DELIM.to_vec(),
writer_capacity: 128,
}
}
impl<S> Stream for NdjsonKeepAlive<S>
where
    S: TryStream<Ok = Bytes> + Send + 'static,
    S::Error: Error + Send + Sync + 'static,
{
    type Item = Result<Bytes, NdjsonError>;

    /// Yield the next upstream line, or a heartbeat delimiter if the upstream
    /// is pending and the idle timer has fired.
    ///
    /// After every yielded line or heartbeat, the timer is re-armed to
    /// `now + max_interval`. Upstream errors are logged and replaced with
    /// `NdjsonError`. The stream ends when the upstream ends.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        let mut this = self.project();
        let outgoing = match this.event_stream.try_poll_next(cx) {
            // Terminal outcomes: nothing is emitted, so the timer is left alone.
            Poll::Ready(None) => return Poll::Ready(None),
            Poll::Ready(Some(Err(error))) => {
                tracing::error!("ndjson error: {}", error);
                return Poll::Ready(Some(Err(NdjsonError)));
            }
            // A real line counts as activity.
            Poll::Ready(Some(Ok(line))) => line,
            // Upstream idle: send a heartbeat only once the timer has elapsed.
            Poll::Pending => {
                if this.alive_timer.as_mut().poll(cx).is_pending() {
                    return Poll::Pending;
                }
                this.delimiter.clone()
            }
        };
        let next_deadline = tokio::time::Instant::now() + *this.max_interval;
        this.alive_timer.reset(next_deadline);
        Poll::Ready(Some(Ok(outgoing)))
    }
}
mod sealed {
    use super::*;

    /// Opaque error item produced by the NDJSON streams.
    ///
    /// The underlying cause is logged where it occurs; only this marker is
    /// passed on.
    #[derive(Debug)]
    pub struct NdjsonError;

    impl Display for NdjsonError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str("ndjson error")
        }
    }

    impl Error for NdjsonError {}
}