use std::collections::VecDeque;
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};

use futures::{Stream, StreamExt};
use serde::de::DeserializeOwned;
use serde::de::Error;
use thiserror::Error;

use crate::providers::anthropic::decoders::line::LineDecoder;

#[derive(Debug, Error)]
pub enum JSONLDecoderError {
#[error("Failed to parse JSON: {0}")]
ParseError(#[from] serde_json::Error),
#[error("Response has no body")]
NoBodyError,
}
pub struct JSONLDecoder<T, S>
where
T: DeserializeOwned + Unpin,
S: Stream<Item = Result<Vec<u8>, std::io::Error>> + Unpin,
{
stream: S,
line_decoder: LineDecoder,
buffer: Vec<T>,
_phantom: PhantomData<T>,
}
impl<T, S> JSONLDecoder<T, S>
where
T: DeserializeOwned + Unpin,
S: Stream<Item = Result<Vec<u8>, std::io::Error>> + Unpin,
{
pub fn new(stream: S) -> Self {
Self {
stream,
line_decoder: LineDecoder::new(),
buffer: Vec::new(),
_phantom: PhantomData,
}
}
fn process_chunk(&mut self, chunk: &[u8]) -> Result<Vec<T>, JSONLDecoderError> {
let lines = self.line_decoder.decode(chunk);
let mut results = Vec::with_capacity(lines.len());
for line in lines {
if line.trim().is_empty() {
continue;
}
let value: T = serde_json::from_str(&line)?;
results.push(value);
}
Ok(results)
}
fn flush(&mut self) -> Result<Vec<T>, JSONLDecoderError> {
let lines = self.line_decoder.flush();
let mut results = Vec::with_capacity(lines.len());
for line in lines {
if line.trim().is_empty() {
continue;
}
let value: T = serde_json::from_str(&line)?;
results.push(value);
}
Ok(results)
}
}
impl<T, S> Stream for JSONLDecoder<T, S>
where
T: DeserializeOwned + Unpin,
S: Stream<Item = Result<Vec<u8>, std::io::Error>> + Unpin,
{
type Item = Result<T, JSONLDecoderError>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
if !this.buffer.is_empty() {
return Poll::Ready(Some(Ok(this.buffer.remove(0))));
}
match this.stream.poll_next_unpin(cx) {
Poll::Ready(Some(Ok(chunk))) => {
match this.process_chunk(&chunk) {
Ok(mut parsed) => {
if !parsed.is_empty() {
let item = parsed.remove(0);
this.buffer.append(&mut parsed);
Poll::Ready(Some(Ok(item)))
} else {
Pin::new(this).poll_next(cx)
}
}
Err(e) => Poll::Ready(Some(Err(e))),
}
}
Poll::Ready(Some(Err(e))) => {
Poll::Ready(Some(Err(JSONLDecoderError::ParseError(
serde_json::Error::custom(format!("Stream error: {}", e)),
))))
}
Poll::Ready(None) => {
match this.flush() {
Ok(mut parsed) => {
if !parsed.is_empty() {
let item = parsed.remove(0);
this.buffer.append(&mut parsed);
Poll::Ready(Some(Ok(item)))
} else {
Poll::Ready(None)
}
}
Err(e) => Poll::Ready(Some(Err(e))),
}
}
Poll::Pending => Poll::Pending,
}
}
}