// rig/providers/anthropic/decoders/sse.rs
sse.rs1use super::line::{self, LineDecoder};
2use crate::{if_not_wasm, if_wasm};
3use bytes::Bytes;
4use futures::{Stream, StreamExt};
5use std::fmt::Debug;
6use thiserror::Error;
7if_not_wasm! {
8 use futures::stream::BoxStream;
9}
10if_wasm! {
11 use std::pin::Pin;
12}
13
14#[derive(Debug, Error)]
15pub enum SSEDecoderError {
16 #[error("Failed to parse SSE: {0}")]
17 ParseError(String),
18
19 #[error("Failed to decode UTF-8: {0}")]
20 Utf8Error(#[from] std::string::FromUtf8Error),
21
22 #[error("IO error: {0}")]
23 IoError(#[from] std::io::Error),
24}
25
/// A single dispatched Server-Sent Event.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ServerSentEvent {
    /// Value of the last `event:` field seen in this event block, if any.
    pub event: Option<String>,
    /// All `data:` field values of the block, joined with `'\n'`.
    pub data: String,
    /// Every non-blank line of the block (comments included), with one
    /// trailing `'\r'` removed, in arrival order.
    pub raw: Vec<String>,
}

/// Incremental, line-oriented SSE parser.
///
/// Feed it one line at a time (without the `'\n'` terminator) through
/// [`SSEDecoder::decode`]. Fields are buffered until a blank line dispatches
/// the accumulated event.
pub struct SSEDecoder {
    // `data:` values collected for the event currently being built.
    data: Vec<String>,
    // Most recent `event:` value for the event currently being built.
    event: Option<String>,
    // Raw lines of the current event block, exposed as `ServerSentEvent::raw`.
    chunks: Vec<String>,
}

impl Default for SSEDecoder {
    fn default() -> Self {
        Self::new()
    }
}

impl SSEDecoder {
    /// Creates a decoder with no buffered state.
    pub fn new() -> Self {
        Self {
            data: Vec::new(),
            event: None,
            chunks: Vec::new(),
        }
    }

    /// Processes one line of an SSE stream.
    ///
    /// `line` is expected without its `'\n'` terminator. A single trailing
    /// `'\r'` is stripped, so CRLF-terminated input is accepted too.
    ///
    /// Returns `Some(event)` when `line` is blank and at least one `event:` or
    /// `data:` field has been buffered since the previous dispatch. Otherwise
    /// the line is buffered and `None` is returned. Lines starting with `':'`
    /// are comments. Fields other than `event` and `data` are ignored.
    pub fn decode(&mut self, line: &str) -> Option<ServerSentEvent> {
        let line = line.strip_suffix('\r').unwrap_or(line);

        if line.is_empty() {
            if self.event.is_none() && self.data.is_empty() {
                // The block held nothing dispatchable (e.g. only keep-alive
                // comments). Drop its raw lines so they neither leak into the
                // next event's `raw` nor accumulate without bound.
                self.chunks.clear();
                return None;
            }

            // Move the buffered state out instead of cloning it; this also
            // resets the decoder for the next event.
            return Some(ServerSentEvent {
                event: self.event.take(),
                data: std::mem::take(&mut self.data).join("\n"),
                raw: std::mem::take(&mut self.chunks),
            });
        }

        self.chunks.push(line.to_owned());

        if line.starts_with(':') {
            return None;
        }

        // A line without a colon is a field name with an empty value; only the
        // first colon separates name from value.
        let (field_name, value) = line.split_once(':').unwrap_or((line, ""));
        // Only a single leading space belongs to the separator.
        let value = value.strip_prefix(' ').unwrap_or(value);

        match field_name {
            "event" => self.event = Some(value.to_owned()),
            "data" => self.data.push(value.to_owned()),
            _ => {}
        }

        None
    }
}
121
122pub fn iter_sse_messages<S>(
124 mut stream: S,
125) -> impl Stream<Item = Result<ServerSentEvent, SSEDecoderError>>
126where
127 S: Stream<Item = Result<Vec<u8>, std::io::Error>> + Unpin,
128{
129 let mut sse_decoder = SSEDecoder::new();
130 let mut line_decoder = LineDecoder::new();
131 let mut buffer = Vec::new();
132
133 async_stream::stream! {
134 while let Some(chunk_result) = stream.next().await {
135 let chunk = match chunk_result {
136 Ok(c) => c,
137 Err(e) => {
138 yield Err(SSEDecoderError::IoError(e));
139 continue;
140 }
141 };
142
143 buffer.extend_from_slice(&chunk);
145
146 while let Some((chunk_data, remaining)) = extract_sse_chunk(&buffer) {
148 buffer = remaining;
149
150 for line in line_decoder.decode(&chunk_data) {
152 if let Some(sse) = sse_decoder.decode(&line) {
153 yield Ok(sse);
154 }
155 }
156 }
157 }
158
159 for line in line_decoder.flush() {
161 if let Some(sse) = sse_decoder.decode(&line) {
162 yield Ok(sse);
163 }
164 }
165
166 #[allow(clippy::collapsible_if)]
169 if !sse_decoder.data.is_empty() || sse_decoder.event.is_some() {
170 if let Some(sse) = sse_decoder.decode("") {
171 yield Ok(sse);
172 }
173 }
174 }
175}
176
177fn extract_sse_chunk(buffer: &[u8]) -> Option<(Vec<u8>, Vec<u8>)> {
179 let pattern_index = line::find_double_newline_index(buffer);
180
181 if pattern_index <= 0 {
182 return None;
183 }
184
185 let pattern_index = pattern_index as usize;
186 let chunk = buffer[0..pattern_index].to_vec();
187 let remaining = buffer[pattern_index..].to_vec();
188
189 Some((chunk, remaining))
190}
191
192if_wasm! {
193 pub fn from_response<'a, E>(
194 stream: Pin<Box<dyn Stream<Item = Result<Bytes, E>> + 'a>>,
195 ) -> impl Stream<Item = Result<ServerSentEvent, SSEDecoderError>>
196 where
197 E: std::fmt::Display + 'static
198 {
199 iter_sse_messages(stream.map(|result| match result {
200 Ok(bytes) => Ok(bytes.to_vec()),
201 Err(e) => Err(std::io::Error::other(e.to_string())),
202 }))
203 }
204}
205
206if_not_wasm! {
207 pub fn from_response<'a, E>(
208 stream: BoxStream<'a, Result<Bytes, E>>,
209 ) -> impl Stream<Item = Result<ServerSentEvent, SSEDecoderError>>
210 where
211 E: Into<Box<dyn std::error::Error + Send + Sync>>
212 {
213 iter_sse_messages(stream.map(|result| match result {
214 Ok(bytes) => Ok(bytes.to_vec()),
215 Err(e) => Err(std::io::Error::other(e)),
216 }))
217 }
218}