use crate::{header::X_STREAM_ERROR, response::Error};
use bytes::{Bytes, BytesMut};
use futures::{
task::{Context, Poll},
Stream,
};
use serde::Deserialize;
use std::{cmp, fmt::Display, io, marker::PhantomData, pin::Pin};
use tokio::io::{AsyncRead, ReadBuf};
use tokio_util::codec::Decoder;
use tracing::{event, instrument, Level};
/// Decoder that turns newline-delimited JSON into values of type `T`.
///
/// When `parse_stream_error` is enabled, a line that fails to parse as JSON
/// is additionally inspected for the stream-error marker so it can be
/// surfaced as a stream error instead of a plain JSON error.
pub struct JsonLineDecoder<T> {
    /// Whether undecodable lines are checked for a stream-error marker.
    parse_stream_error: bool,
    /// Binds the decoder to its output type without storing a `T`.
    ty: PhantomData<T>,
}
impl<T> JsonLineDecoder<T> {
    /// Creates a decoder. `parse_stream_error` turns on stream-error
    /// detection for lines that are not valid JSON.
    #[inline]
    pub fn new(parse_stream_error: bool) -> Self {
        Self {
            ty: PhantomData,
            parse_stream_error,
        }
    }
}
impl<T> Decoder for JsonLineDecoder<T>
where
for<'de> T: Deserialize<'de>,
{
type Item = T;
type Error = Error;
#[instrument(skip(self, src), fields(stream_trailer = self.parse_stream_error))]
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
let nl_index = src.iter().position(|b| *b == b'\n');
if let Some(pos) = nl_index {
event!(Level::INFO, "Found new line delimeter in buffer");
let slice = src.split_to(pos + 1);
let slice = &slice[..slice.len() - 1];
match serde_json::from_slice(slice) {
Ok(json) => Ok(json),
Err(e) => {
if self.parse_stream_error {
match slice.iter().position(|&x| x == b':') {
Some(colon) if &slice[..colon] == X_STREAM_ERROR.as_bytes() => {
let e = Error::StreamError(
String::from_utf8_lossy(&slice[colon + 2..]).into(),
);
Err(e)
}
_ => Err(e.into()),
}
} else {
Err(e.into())
}
}
}
} else {
event!(Level::INFO, "Waiting for more data to decode JSON");
Ok(None)
}
}
}
/// Decoder that splits a byte stream into `\n`-terminated lines.
///
/// Each line is yielded without its `\n` as a `String`. Invalid UTF-8 is
/// replaced lossily.
pub struct LineDecoder;
impl Decoder for LineDecoder {
    type Item = String;
    type Error = Error;

    /// Removes the first complete line from `src` and returns it.
    ///
    /// If no `\n` has arrived yet, returns `Ok(None)` and leaves `src`
    /// untouched.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        let newline = src.iter().position(|&byte| byte == b'\n');
        Ok(newline.map(|idx| {
            // Take the line together with its terminator, then decode all
            // but the terminator.
            let raw = src.split_to(idx + 1);
            String::from_utf8_lossy(&raw[..idx]).into_owned()
        }))
    }
}
/// Copies as much of `chunk[chunk_start..]` into `dest` as `dest` has room
/// for, and returns the number of bytes copied.
///
/// `chunk_start` must not exceed `chunk.len()`.
fn copy_from_chunk_to(dest: &mut ReadBuf<'_>, chunk: &mut Bytes, chunk_start: usize) -> usize {
    // Bound by `remaining()` (unfilled space), not `capacity()` (total size).
    // `put_slice` panics if given more than `remaining()` bytes, which
    // happens whenever the caller passes a partially-filled `ReadBuf`.
    let available = &chunk[chunk_start..];
    let len = cmp::min(dest.remaining(), available.len());
    dest.put_slice(&available[..len]);
    len
}
/// Tracks the chunk currently being handed out by [`StreamReader`].
enum ReadState {
    /// Nothing buffered; the next read must poll the underlying stream.
    NotReady,
    /// A chunk that still has unread bytes, plus the offset of the first
    /// unread byte.
    Ready(Bytes, usize),
}
/// Adapts a stream of `Result<Bytes, E>` chunks into an [`AsyncRead`].
pub struct StreamReader<S> {
    stream: S,
    state: ReadState,
}
impl<S, E> StreamReader<S>
where
    S: Stream<Item = Result<Bytes, E>>,
    E: Display,
{
    /// Wraps `stream`. No chunk is buffered until the first read.
    #[inline]
    pub fn new(stream: S) -> StreamReader<S> {
        let state = ReadState::NotReady;
        Self { state, stream }
    }
}
impl<S, E> AsyncRead for StreamReader<S>
where
    S: Stream<Item = Result<Bytes, E>> + Unpin,
    E: Display,
{
    /// Fills `buf` from the buffered chunk. When no chunk is buffered, it
    /// polls the underlying stream for a new one.
    ///
    /// It returns `Poll::Ready(Ok(()))` with nothing written only when:
    /// - the stream has ended, or
    /// - `buf` has no space left.
    ///
    /// That follows the `AsyncRead` EOF convention. Stream errors become
    /// `io::ErrorKind::Other` errors carrying the error's `Display` text.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // `S: Unpin` (and `Bytes: Unpin`) make `Self: Unpin`, so unpinning is sound.
        let this = self.get_mut();
        loop {
            match this.state {
                ReadState::Ready(ref mut chunk, ref mut pos) => {
                    let bytes_read = copy_from_chunk_to(buf, chunk, *pos);
                    *pos += bytes_read;
                    if *pos >= chunk.len() {
                        this.state = ReadState::NotReady;
                    }
                    return Poll::Ready(Ok(()));
                }
                ReadState::NotReady => match Stream::poll_next(Pin::new(&mut this.stream), cx) {
                    Poll::Ready(Some(Ok(chunk))) => {
                        // An empty chunk must not surface as a zero-byte
                        // read, which callers interpret as EOF. Skip it and
                        // poll the stream again.
                        if !chunk.is_empty() {
                            this.state = ReadState::Ready(chunk, 0);
                        }
                    }
                    Poll::Ready(Some(Err(e))) => {
                        return Poll::Ready(Err(io::Error::new(
                            io::ErrorKind::Other,
                            e.to_string(),
                        )));
                    }
                    Poll::Ready(None) => return Poll::Ready(Ok(())),
                    // The stream has registered `cx`'s waker.
                    Poll::Pending => return Poll::Pending,
                },
            }
        }
    }
}