// rig/providers/anthropic/decoders/sse.rs

1use super::line::{self, LineDecoder};
2use crate::{if_not_wasm, if_wasm};
3use bytes::Bytes;
4use futures::{Stream, StreamExt};
5use std::fmt::Debug;
6use thiserror::Error;
7if_not_wasm! {
8    use futures::stream::BoxStream;
9}
10if_wasm! {
11    use std::pin::Pin;
12}
13
14#[derive(Debug, Error)]
15pub enum SSEDecoderError {
16    #[error("Failed to parse SSE: {0}")]
17    ParseError(String),
18
19    #[error("Failed to decode UTF-8: {0}")]
20    Utf8Error(#[from] std::string::FromUtf8Error),
21
22    #[error("IO error: {0}")]
23    IoError(#[from] std::io::Error),
24}
25
/// A single Server-Sent Event assembled by [`SSEDecoder`].
#[derive(Debug, Clone)]
pub struct ServerSentEvent {
    /// Value of the last `event:` field seen in the block, if any.
    pub event: Option<String>,
    /// All `data:` field values of the block, joined with `\n`.
    pub data: String,
    /// The block's non-blank lines (comments included), each with one trailing `\r` removed.
    pub raw: Vec<String>,
}

/// Incremental, line-oriented decoder for the Server-Sent Events (SSE) text format.
///
/// Feed it one line at a time (without its terminator) through [`SSEDecoder::decode`].
/// A blank line ends the current event block.
#[derive(Debug, Default)]
pub struct SSEDecoder {
    /// `data:` values collected for the current block.
    data: Vec<String>,
    /// Most recent `event:` value in the current block.
    event: Option<String>,
    /// Raw lines of the current block, reported as [`ServerSentEvent::raw`].
    chunks: Vec<String>,
}

impl SSEDecoder {
    /// Create a decoder with no pending state.
    pub fn new() -> Self {
        Self::default()
    }

    /// Decode one line of SSE text.
    ///
    /// A single trailing `\r` is ignored, so CRLF-terminated input works.
    ///
    /// # Returns
    /// `Some(event)` when `line` is blank and the current block contained an `event:` or
    /// `data:` field; the decoder is then reset. Otherwise the line is buffered and `None`
    /// is returned.
    ///
    /// Comment lines (leading `:`) are recorded in `raw` but otherwise ignored.
    /// Fields other than `event` and `data` are also ignored.
    pub fn decode(&mut self, line: &str) -> Option<ServerSentEvent> {
        // Strip a lone trailing carriage return, mirroring the TypeScript SDK.
        let line = line.strip_suffix('\r').unwrap_or(line);

        // An empty line terminates the current event block.
        if line.is_empty() {
            return self.dispatch();
        }

        self.chunks.push(line.to_owned());

        // Comment line.
        if line.starts_with(':') {
            return None;
        }

        // `field:value`; a line without a colon is a field name with an empty value.
        let (field_name, value) = line.split_once(':').unwrap_or((line, ""));
        // The SSE format strips exactly one leading space from the value.
        let value = value.strip_prefix(' ').unwrap_or(value);

        match field_name {
            "event" => self.event = Some(value.to_owned()),
            "data" => self.data.push(value.to_owned()),
            _ => {} // Other fields (e.g. `id`, `retry`) are ignored.
        }

        None
    }

    /// Close the current block: emit it if it carried an `event` or `data` field,
    /// and reset all per-block state either way.
    fn dispatch(&mut self) -> Option<ServerSentEvent> {
        if self.event.is_none() && self.data.is_empty() {
            // Nothing to emit. Still drop buffered raw lines (e.g. keep-alive comments)
            // so they do not leak into the next event's `raw`.
            self.chunks.clear();
            return None;
        }

        Some(ServerSentEvent {
            event: self.event.take(),
            data: std::mem::take(&mut self.data).join("\n"),
            raw: std::mem::take(&mut self.chunks),
        })
    }
}
121
122/// Process a byte stream to extract SSE messages
123pub fn iter_sse_messages<S>(
124    mut stream: S,
125) -> impl Stream<Item = Result<ServerSentEvent, SSEDecoderError>>
126where
127    S: Stream<Item = Result<Vec<u8>, std::io::Error>> + Unpin,
128{
129    let mut sse_decoder = SSEDecoder::new();
130    let mut line_decoder = LineDecoder::new();
131    let mut buffer = Vec::new();
132
133    async_stream::stream! {
134        while let Some(chunk_result) = stream.next().await {
135            let chunk = match chunk_result {
136                Ok(c) => c,
137                Err(e) => {
138                    yield Err(SSEDecoderError::IoError(e));
139                    continue;
140                }
141            };
142
143            // Process bytes through SSE chunking
144            buffer.extend_from_slice(&chunk);
145
146            // Extract chunks at double newlines
147            while let Some((chunk_data, remaining)) = extract_sse_chunk(&buffer) {
148                buffer = remaining;
149
150                // Process the chunk lines
151                for line in line_decoder.decode(&chunk_data) {
152                    if let Some(sse) = sse_decoder.decode(&line) {
153                        yield Ok(sse);
154                    }
155                }
156            }
157        }
158
159        // Process any remaining data
160        for line in line_decoder.flush() {
161            if let Some(sse) = sse_decoder.decode(&line) {
162                yield Ok(sse);
163            }
164        }
165
166        // Force final event if we have pending data
167        // TODO: Collapse if statement (when `||` operator is supported in if-let chains)
168        #[allow(clippy::collapsible_if)]
169        if !sse_decoder.data.is_empty() || sse_decoder.event.is_some() {
170            if let Some(sse) = sse_decoder.decode("") {
171                yield Ok(sse);
172            }
173        }
174    }
175}
176
177/// Extract an SSE chunk up to a double newline
178fn extract_sse_chunk(buffer: &[u8]) -> Option<(Vec<u8>, Vec<u8>)> {
179    let pattern_index = line::find_double_newline_index(buffer);
180
181    if pattern_index <= 0 {
182        return None;
183    }
184
185    let pattern_index = pattern_index as usize;
186    let chunk = buffer[0..pattern_index].to_vec();
187    let remaining = buffer[pattern_index..].to_vec();
188
189    Some((chunk, remaining))
190}
191
192if_wasm! {
193    pub fn from_response<'a, E>(
194        stream: Pin<Box<dyn Stream<Item = Result<Bytes, E>> + 'a>>,
195    ) -> impl Stream<Item = Result<ServerSentEvent, SSEDecoderError>>
196    where
197        E: std::fmt::Display + 'static
198    {
199        iter_sse_messages(stream.map(|result| match result {
200            Ok(bytes) => Ok(bytes.to_vec()),
201            Err(e) => Err(std::io::Error::other(e.to_string())),
202        }))
203    }
204}
205
206if_not_wasm! {
207    pub fn from_response<'a, E>(
208        stream: BoxStream<'a, Result<Bytes, E>>,
209    ) -> impl Stream<Item = Result<ServerSentEvent, SSEDecoderError>>
210    where
211        E: Into<Box<dyn std::error::Error + Send + Sync>>
212    {
213        iter_sse_messages(stream.map(|result| match result {
214            Ok(bytes) => Ok(bytes.to_vec()),
215            Err(e) => Err(std::io::Error::other(e)),
216        }))
217    }
218}