use std::{
marker::PhantomData,
pin::Pin,
task::{Context, Poll},
};
use bytes::{Bytes, BytesMut};
use futures_util::stream::{Stream, StreamExt};
use http_body_util::BodyDataStream;
use serde::de::DeserializeOwned;
use crate::{
BoxError,
body::{Body, codec::BodyDecoder},
};
/// Body decoder for newline-delimited JSON (NDJSON).
///
/// This is a stateless unit type; its [`BodyDecoder`] implementation wraps a
/// body in an [`NdjsonStream`] that deserializes one value per line.
#[derive(Clone, Debug, Default)]
pub struct NdjsonDecoder;
/// Errors yielded as items by [`NdjsonStream`].
#[derive(Debug, thiserror::Error)]
pub enum NdjsonDecodeError {
/// The underlying body stream reported an error while it was being read.
#[error("body read error: {0}")]
Body(#[source] BoxError),
/// A line of the body could not be deserialized by `serde_json`.
#[error("ndjson decode error: {0}")]
Json(#[from] serde_json::Error),
}
/// Stream that splits a body into lines and deserializes each line as an `O`.
///
/// Construct it with `NdjsonStream::new`; items are produced by its `Stream`
/// implementation.
pub struct NdjsonStream<O> {
/// Data-frame stream over the wrapped body.
body: BodyDataStream<Body>,
/// Bytes received from the body that have not yet been emitted as a line.
buf: BytesMut,
/// Set once the body has ended or failed; the body is not polled again afterwards.
eof: bool,
/// Ties the stream to its output type without storing an `O`.
_marker: PhantomData<fn() -> O>,
}
impl<O> NdjsonStream<O>
where
    O: DeserializeOwned + Send + 'static,
{
    /// Builds a stream that reads `body` and decodes it line by line.
    ///
    /// The body is wrapped in the crate's [`Body`] type and then exposed as a
    /// stream of data chunks. The line buffer starts out empty and the stream
    /// is not yet marked as finished.
    pub fn new<B>(body: B) -> Self
    where
        B: http_body::Body<Data = Bytes> + Send + Sync + 'static,
        B::Error: Into<BoxError>,
    {
        let body = BodyDataStream::new(Body::new(body));
        let buf = BytesMut::new();
        Self {
            body,
            buf,
            eof: false,
            _marker: PhantomData,
        }
    }
}
/// Yields one deserialized `O` for every non-empty line of the body.
///
/// Behavior:
/// - Lines are terminated by `\n`; a preceding `\r` is stripped as well.
/// - Empty lines are skipped.
/// - A final line without a terminating newline is still decoded once the
///   body ends.
/// - A line that fails to deserialize yields `Err(NdjsonDecodeError::Json)`;
///   the stream then continues with the next line.
/// - A body read error yields `Err(NdjsonDecodeError::Body)` and ends the
///   stream: the next poll returns `None`.
///
/// NOTE: there is no limit on line length. A body that never sends a newline
/// is buffered in full until it ends.
impl<O> Stream for NdjsonStream<O>
where
    O: DeserializeOwned,
{
    type Item = Result<O, NdjsonDecodeError>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = &mut *self;
        loop {
            // Emit every complete line that is already buffered before
            // asking the body for more data.
            while let Some(pos) = this.buf.iter().position(|b| *b == b'\n') {
                let line = this.buf.split_to(pos + 1);
                let trimmed = trim_end_newline(&line);
                if !trimmed.is_empty() {
                    return Poll::Ready(Some(parse_line::<O>(trimmed)));
                }
                // Blank line: skip it and look for the next one.
            }

            if this.eof {
                // The body is finished; whatever is left is a final line
                // without a newline terminator (or nothing at all).
                let trailing = std::mem::take(&mut this.buf);
                let trimmed = trim_end_newline(&trailing);
                if trimmed.is_empty() {
                    return Poll::Ready(None);
                }
                return Poll::Ready(Some(parse_line::<O>(trimmed)));
            }

            // At this point the buffer is known to contain no newline, so only
            // a chunk that itself contains one can complete a line. Keep
            // pulling chunks until that happens instead of rescanning the
            // whole buffer after every chunk, which would be quadratic for a
            // long line delivered in many small chunks.
            loop {
                match this.body.poll_next_unpin(cx) {
                    Poll::Pending => return Poll::Pending,
                    Poll::Ready(Some(Ok(chunk))) => {
                        let bytes: &[u8] = chunk.as_ref();
                        let completes_line = bytes.contains(&b'\n');
                        this.buf.extend_from_slice(bytes);
                        if completes_line {
                            break;
                        }
                    }
                    Poll::Ready(Some(Err(e))) => {
                        this.eof = true;
                        // Only an incomplete line can be buffered here (all
                        // complete lines were drained above). Drop it so a
                        // truncated fragment is never decoded as if it were a
                        // full trailing record on the next poll.
                        this.buf.clear();
                        return Poll::Ready(Some(Err(NdjsonDecodeError::Body(e))));
                    }
                    Poll::Ready(None) => {
                        this.eof = true;
                        break;
                    }
                }
            }
        }
    }
}
/// Deserializes a single line of JSON into an `O`.
///
/// Any `serde_json` failure is returned as `NdjsonDecodeError::Json`.
fn parse_line<O: DeserializeOwned>(bytes: &[u8]) -> Result<O, NdjsonDecodeError> {
    match serde_json::from_slice::<O>(bytes) {
        Ok(value) => Ok(value),
        Err(err) => Err(NdjsonDecodeError::Json(err)),
    }
}
/// Removes a single trailing line terminator from `bytes`.
///
/// At most one `\n` is removed, followed by at most one `\r`, so `"\r\n"`,
/// `"\n"` and a lone `"\r"` are all stripped. Input without a terminator is
/// returned unchanged.
fn trim_end_newline(bytes: &[u8]) -> &[u8] {
    match bytes {
        // CRLF: drop both bytes.
        [head @ .., b'\r', b'\n'] => head,
        // Bare LF or bare CR: drop the single byte.
        [head @ .., b'\n' | b'\r'] => head,
        _ => bytes,
    }
}
/// Decodes a body into an [`NdjsonStream`] of `O` values.
///
/// Decoding itself never fails (`Error` is `Infallible`): the body is only
/// wrapped here. Read and parse errors surface later as stream items.
impl<O> BodyDecoder<NdjsonStream<O>> for NdjsonDecoder
where
    O: DeserializeOwned + Send + 'static,
{
    type Error = std::convert::Infallible;

    /// Wraps `body` in a new [`NdjsonStream`] without reading from it.
    async fn decode<B>(&self, body: B) -> Result<NdjsonStream<O>, Self::Error>
    where
        B: http_body::Body<Data = Bytes> + Send + Sync + 'static,
        B::Error: Into<BoxError>,
    {
        let stream = NdjsonStream::<O>::new(body);
        Ok(stream)
    }
}