rig/providers/anthropic/decoders/
jsonl.rs1use crate::providers::anthropic::decoders::line::LineDecoder;
3use futures::{Stream, StreamExt};
4use serde::de::DeserializeOwned;
5use serde::de::Error;
6use std::marker::PhantomData;
7use std::pin::Pin;
8use std::task::{Context, Poll};
9use thiserror::Error;
10
11#[derive(Debug, Error)]
12pub enum JSONLDecoderError {
13 #[error("Failed to parse JSON: {0}")]
14 ParseError(#[from] serde_json::Error),
15
16 #[error("Response has no body")]
17 NoBodyError,
18}
19
20pub struct JSONLDecoder<T, S>
25where
26 T: DeserializeOwned + Unpin,
27 S: Stream<Item = Result<Vec<u8>, std::io::Error>> + Unpin,
28{
29 stream: S,
30 line_decoder: LineDecoder,
31 buffer: Vec<T>,
32 _phantom: PhantomData<T>,
33}
34
35impl<T, S> JSONLDecoder<T, S>
36where
37 T: DeserializeOwned + Unpin,
38 S: Stream<Item = Result<Vec<u8>, std::io::Error>> + Unpin,
39{
40 pub fn new(stream: S) -> Self {
42 Self {
43 stream,
44 line_decoder: LineDecoder::new(),
45 buffer: Vec::new(),
46 _phantom: PhantomData,
47 }
48 }
49
50 fn process_chunk(&mut self, chunk: &[u8]) -> Result<Vec<T>, JSONLDecoderError> {
52 let lines = self.line_decoder.decode(chunk);
53 let mut results = Vec::with_capacity(lines.len());
54
55 for line in lines {
56 if line.trim().is_empty() {
58 continue;
59 }
60
61 let value: T = serde_json::from_str(&line)?;
62 results.push(value);
63 }
64
65 Ok(results)
66 }
67
68 fn flush(&mut self) -> Result<Vec<T>, JSONLDecoderError> {
70 let lines = self.line_decoder.flush();
71 let mut results = Vec::with_capacity(lines.len());
72
73 for line in lines {
74 if line.trim().is_empty() {
76 continue;
77 }
78
79 let value: T = serde_json::from_str(&line)?;
80 results.push(value);
81 }
82
83 Ok(results)
84 }
85}
86
87impl<T, S> Stream for JSONLDecoder<T, S>
88where
89 T: DeserializeOwned + Unpin,
90 S: Stream<Item = Result<Vec<u8>, std::io::Error>> + Unpin,
91{
92 type Item = Result<T, JSONLDecoderError>;
93
94 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
95 let this = self.get_mut();
97
98 if !this.buffer.is_empty() {
100 return Poll::Ready(Some(Ok(this.buffer.remove(0))));
101 }
102
103 match this.stream.poll_next_unpin(cx) {
105 Poll::Ready(Some(Ok(chunk))) => {
106 match this.process_chunk(&chunk) {
108 Ok(mut parsed) => {
109 if !parsed.is_empty() {
111 let item = parsed.remove(0);
112 this.buffer.append(&mut parsed);
113 Poll::Ready(Some(Ok(item)))
114 } else {
115 Pin::new(this).poll_next(cx)
117 }
118 }
119 Err(e) => Poll::Ready(Some(Err(e))),
120 }
121 }
122 Poll::Ready(Some(Err(e))) => {
123 Poll::Ready(Some(Err(JSONLDecoderError::ParseError(
125 serde_json::Error::custom(format!("Stream error: {e}")),
126 ))))
127 }
128 Poll::Ready(None) => {
129 match this.flush() {
131 Ok(mut parsed) => {
132 if !parsed.is_empty() {
133 let item = parsed.remove(0);
134 this.buffer.append(&mut parsed);
135 Poll::Ready(Some(Ok(item)))
136 } else {
137 Poll::Ready(None)
139 }
140 }
141 Err(e) => Poll::Ready(Some(Err(e))),
142 }
143 }
144 Poll::Pending => Poll::Pending,
145 }
146 }
147}