rig/providers/anthropic/decoders/
jsonl.rs

//! JSONL decoding is currently unused; it may be needed once the Anthropic batches beta feature is used.
2use crate::providers::anthropic::decoders::line::LineDecoder;
3use futures::{Stream, StreamExt};
4use serde::de::DeserializeOwned;
5use serde::de::Error;
6use std::marker::PhantomData;
7use std::pin::Pin;
8use std::task::{Context, Poll};
9use thiserror::Error;
10
11#[derive(Debug, Error)]
12pub enum JSONLDecoderError {
13    #[error("Failed to parse JSON: {0}")]
14    ParseError(#[from] serde_json::Error),
15
16    #[error("Response has no body")]
17    NoBodyError,
18}
19
20/// Decoder for JSON Lines format, where each line is a separate JSON object.
21///
22/// This struct allows processing a stream of bytes, decoding them into lines,
23/// and then parsing each line as a JSON object of type T.
24pub struct JSONLDecoder<T, S>
25where
26    T: DeserializeOwned + Unpin,
27    S: Stream<Item = Result<Vec<u8>, std::io::Error>> + Unpin,
28{
29    stream: S,
30    line_decoder: LineDecoder,
31    buffer: Vec<T>,
32    _phantom: PhantomData<T>,
33}
34
35impl<T, S> JSONLDecoder<T, S>
36where
37    T: DeserializeOwned + Unpin,
38    S: Stream<Item = Result<Vec<u8>, std::io::Error>> + Unpin,
39{
40    /// Create a new JSONLDecoder from a byte stream
41    pub fn new(stream: S) -> Self {
42        Self {
43            stream,
44            line_decoder: LineDecoder::new(),
45            buffer: Vec::new(),
46            _phantom: PhantomData,
47        }
48    }
49
50    /// Process a chunk of data, returning a Result with any JSON parsing errors
51    fn process_chunk(&mut self, chunk: &[u8]) -> Result<Vec<T>, JSONLDecoderError> {
52        let lines = self.line_decoder.decode(chunk);
53        let mut results = Vec::with_capacity(lines.len());
54
55        for line in lines {
56            // Skip empty lines
57            if line.trim().is_empty() {
58                continue;
59            }
60
61            let value: T = serde_json::from_str(&line)?;
62            results.push(value);
63        }
64
65        Ok(results)
66    }
67
68    /// Flush any remaining data in the line decoder and parse it
69    fn flush(&mut self) -> Result<Vec<T>, JSONLDecoderError> {
70        let lines = self.line_decoder.flush();
71        let mut results = Vec::with_capacity(lines.len());
72
73        for line in lines {
74            // Skip empty lines
75            if line.trim().is_empty() {
76                continue;
77            }
78
79            let value: T = serde_json::from_str(&line)?;
80            results.push(value);
81        }
82
83        Ok(results)
84    }
85}
86
87impl<T, S> Stream for JSONLDecoder<T, S>
88where
89    T: DeserializeOwned + Unpin,
90    S: Stream<Item = Result<Vec<u8>, std::io::Error>> + Unpin,
91{
92    type Item = Result<T, JSONLDecoderError>;
93
94    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
95        // Get a mutable reference to self
96        let this = self.get_mut();
97
98        // Return any buffered items first
99        if !this.buffer.is_empty() {
100            return Poll::Ready(Some(Ok(this.buffer.remove(0))));
101        }
102
103        // Poll the underlying stream
104        match this.stream.poll_next_unpin(cx) {
105            Poll::Ready(Some(Ok(chunk))) => {
106                // Process the chunk
107                match this.process_chunk(&chunk) {
108                    Ok(mut parsed) => {
109                        // If we got any items, buffer them and return the first one
110                        if !parsed.is_empty() {
111                            let item = parsed.remove(0);
112                            this.buffer.append(&mut parsed);
113                            Poll::Ready(Some(Ok(item)))
114                        } else {
115                            // No items yet, try again
116                            Pin::new(this).poll_next(cx)
117                        }
118                    }
119                    Err(e) => Poll::Ready(Some(Err(e))),
120                }
121            }
122            Poll::Ready(Some(Err(e))) => {
123                // Propagate stream errors
124                Poll::Ready(Some(Err(JSONLDecoderError::ParseError(
125                    serde_json::Error::custom(format!("Stream error: {e}")),
126                ))))
127            }
128            Poll::Ready(None) => {
129                // Stream is done, flush any remaining data
130                match this.flush() {
131                    Ok(mut parsed) => {
132                        if !parsed.is_empty() {
133                            let item = parsed.remove(0);
134                            this.buffer.append(&mut parsed);
135                            Poll::Ready(Some(Ok(item)))
136                        } else {
137                            // Nothing left
138                            Poll::Ready(None)
139                        }
140                    }
141                    Err(e) => Poll::Ready(Some(Err(e))),
142                }
143            }
144            Poll::Pending => Poll::Pending,
145        }
146    }
147}