ai_provider_sdk/
streaming.rs1use async_stream::try_stream;
4use bytes::{Buf, Bytes, BytesMut};
5use futures_core::Stream;
6use futures_util::StreamExt;
7use serde_json::Value;
8
9use crate::error::{Error, Result};
10
/// A single event decoded from a server-sent events stream.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ServerSentEvent {
    /// Value of the `event:` field, or `None` when the event carried none.
    pub event: Option<String>,
    /// Every `data:` field value of the event, joined with `\n`.
    pub data: String,
    /// Most recent `id:` value seen on the stream; it carries over to later events.
    pub id: Option<String>,
    /// Value of the `retry:` field parsed as an integer, if one was present.
    pub retry: Option<u64>,
}
22
23pub struct SseStream {
24 response: reqwest::Response,
25}
26
27impl SseStream {
28 pub(crate) fn new(response: reqwest::Response) -> Self {
29 Self { response }
30 }
31
32 pub fn events(self) -> impl Stream<Item = Result<ServerSentEvent>> {
38 let mut chunks = self.response.bytes_stream();
39
40 try_stream! {
41 let mut decoder = SseDecoder::new();
42 while let Some(chunk) = chunks.next().await {
43 let chunk = chunk.map_err(|err| Error::Stream(err.to_string()))?;
44 for event in decoder.push(chunk)? {
45 if event.data.starts_with("[DONE]") {
46 return;
47 }
48 if let Ok(data) = serde_json::from_str::<Value>(&event.data) {
49 if let Some(error) = data.get("error") {
50 Err(Error::Stream(
51 error
52 .get("message")
53 .and_then(Value::as_str)
54 .unwrap_or("An error occurred during streaming")
55 .to_string(),
56 ))?;
57 }
58 }
59 yield event;
60 }
61 }
62
63 for event in decoder.finish()? {
64 if event.data.starts_with("[DONE]") {
65 return;
66 }
67 yield event;
68 }
69 }
70 }
71}
72
73#[derive(Debug, Default)]
74pub struct SseDecoder {
78 bytes: BytesMut,
79 event: Option<String>,
80 data: Vec<String>,
81 last_event_id: Option<String>,
82 retry: Option<u64>,
83}
84
85impl SseDecoder {
86 pub fn new() -> Self {
88 Self::default()
89 }
90
91 pub fn push(&mut self, chunk: Bytes) -> Result<Vec<ServerSentEvent>> {
95 self.bytes.extend_from_slice(&chunk);
96 let mut events = Vec::new();
97
98 while let Some(line) = self.next_line()? {
99 if let Some(event) = self.decode_line(&line) {
100 events.push(event);
101 }
102 }
103
104 Ok(events)
105 }
106
107 pub fn finish(&mut self) -> Result<Vec<ServerSentEvent>> {
109 let mut events = Vec::new();
110 if !self.bytes.is_empty() {
111 let line = std::str::from_utf8(&self.bytes)
112 .map_err(|err| Error::Stream(err.to_string()))?
113 .to_string();
114 self.bytes.clear();
115 if let Some(event) = self.decode_line(&line) {
116 events.push(event);
117 }
118 }
119
120 if let Some(event) = self.flush_event() {
121 events.push(event);
122 }
123
124 Ok(events)
125 }
126
127 fn next_line(&mut self) -> Result<Option<String>> {
129 let Some(pos) = self
130 .bytes
131 .iter()
132 .position(|byte| *byte == b'\n' || *byte == b'\r')
133 else {
134 return Ok(None);
135 };
136
137 let line = self.bytes.split_to(pos);
138 let newline = self.bytes.get_u8();
139 if newline == b'\r' && self.bytes.first() == Some(&b'\n') {
140 self.bytes.advance(1);
141 }
142
143 let line = std::str::from_utf8(&line)
144 .map_err(|err| Error::Stream(err.to_string()))?
145 .to_string();
146 Ok(Some(line))
147 }
148
149 fn decode_line(&mut self, line: &str) -> Option<ServerSentEvent> {
151 if line.is_empty() {
152 return self.flush_event();
153 }
154
155 if line.starts_with(':') {
156 return None;
157 }
158
159 let (field, value) = line.split_once(':').unwrap_or((line, ""));
160 let value = value.strip_prefix(' ').unwrap_or(value);
161
162 match field {
163 "event" => self.event = Some(value.to_string()),
164 "data" => self.data.push(value.to_string()),
165 "id" if !value.contains('\0') => self.last_event_id = Some(value.to_string()),
166 "retry" => self.retry = value.parse().ok(),
167 _ => {}
168 }
169
170 None
171 }
172
173 fn flush_event(&mut self) -> Option<ServerSentEvent> {
175 if self.event.is_none()
176 && self.data.is_empty()
177 && self.last_event_id.is_none()
178 && self.retry.is_none()
179 {
180 return None;
181 }
182
183 let event = ServerSentEvent {
184 event: self.event.take(),
185 data: self.data.join("\n"),
186 id: self.last_event_id.clone(),
187 retry: self.retry.take(),
188 };
189 self.data.clear();
190 Some(event)
191 }
192}
193
194#[cfg(test)]
195mod tests {
196 use super::*;
197
198 #[test]
199 fn decodes_complete_event() {
200 let mut decoder = SseDecoder::new();
201 let events = decoder
202 .push(Bytes::from_static(
203 b"event: ping\ndata: {\"x\":1}\nid: abc\n\n",
204 ))
205 .unwrap();
206
207 assert_eq!(
208 events,
209 vec![ServerSentEvent {
210 event: Some("ping".to_string()),
211 data: "{\"x\":1}".to_string(),
212 id: Some("abc".to_string()),
213 retry: None,
214 }]
215 );
216 }
217
218 #[test]
219 fn decodes_split_event_and_multi_data_lines() {
220 let mut decoder = SseDecoder::new();
221 assert!(decoder
222 .push(Bytes::from_static(b"data: a\n"))
223 .unwrap()
224 .is_empty());
225 let events = decoder.push(Bytes::from_static(b"data: b\n\n")).unwrap();
226
227 assert_eq!(events[0].data, "a\nb");
228 }
229
230 #[test]
231 fn keeps_last_event_id_across_events() {
232 let mut decoder = SseDecoder::new();
233 let events = decoder
234 .push(Bytes::from_static(b"id: one\ndata: a\n\ndata: b\n\n"))
235 .unwrap();
236
237 assert_eq!(events[0].id.as_deref(), Some("one"));
238 assert_eq!(events[1].id.as_deref(), Some("one"));
239 }
240}