1use futures::Stream;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use super::error::LlmError;
6use super::types::ChatStreamChunk;
7
8pub struct SseStream {
14 body: Pin<Box<dyn Stream<Item = Result<Vec<u8>, reqwest::Error>> + Send>>,
15 byte_buf: Vec<u8>,
16 str_buf: String,
17}
18
19impl SseStream {
20 pub fn new(response: reqwest::Response) -> Self {
21 use futures::StreamExt;
22 Self {
23 body: Box::pin(response.bytes_stream().map(|r| r.map(|b| b.to_vec()))),
24 byte_buf: Vec::new(),
25 str_buf: String::new(),
26 }
27 }
28}
29
30impl Stream for SseStream {
31 type Item = Result<ChatStreamChunk, LlmError>;
32
33 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
34 let this = self.get_mut();
35
36 loop {
37 if let Some(chunk) = try_parse_event(&mut this.str_buf)? {
39 return Poll::Ready(Some(Ok(chunk)));
40 }
41
42 match Pin::new(&mut this.body).poll_next(cx) {
44 Poll::Ready(Some(Ok(bytes))) => {
45 this.byte_buf.extend_from_slice(&bytes);
46 flush_utf8(&mut this.byte_buf, &mut this.str_buf)?;
49 }
50 Poll::Ready(Some(Err(e))) => {
51 return Poll::Ready(Some(Err(LlmError::StreamInterrupted(e.to_string()))));
52 }
53 Poll::Ready(None) => {
54 if !this.byte_buf.is_empty() {
56 match std::str::from_utf8(&this.byte_buf) {
57 Ok(s) => {
58 this.str_buf.push_str(s);
59 this.byte_buf.clear();
60 }
61 Err(e) => {
62 return Poll::Ready(Some(Err(LlmError::StreamInterrupted(
63 format!("Invalid UTF-8 in SSE stream: {e}"),
64 ))));
65 }
66 }
67 }
68 if this.str_buf.trim().is_empty() {
70 return Poll::Ready(None);
71 }
72 match try_parse_remaining(&mut this.str_buf) {
74 Ok(Some(chunk)) => return Poll::Ready(Some(Ok(chunk))),
75 Ok(None) => return Poll::Ready(None),
76 Err(e) => return Poll::Ready(Some(Err(e))),
77 }
78 }
79 Poll::Pending => return Poll::Pending,
80 }
81 }
82 }
83}
84
85fn flush_utf8(byte_buf: &mut Vec<u8>, str_buf: &mut String) -> Result<(), LlmError> {
88 if byte_buf.is_empty() {
89 return Ok(());
90 }
91 match std::str::from_utf8(byte_buf) {
92 Ok(s) => {
94 str_buf.push_str(s);
95 byte_buf.clear();
96 Ok(())
97 }
98 Err(e) => {
100 let valid_up_to = e.valid_up_to();
101 if valid_up_to == 0 && e.error_len().is_some() {
102 return Err(LlmError::StreamInterrupted(format!(
104 "Invalid UTF-8 in SSE stream: {e}"
105 )));
106 }
107 let valid = std::str::from_utf8(&byte_buf[..valid_up_to])
111 .expect("valid_up_to is guaranteed to be a UTF-8 boundary");
112 str_buf.push_str(valid);
113 byte_buf.drain(..valid_up_to);
114 Ok(())
117 }
118 }
119}
120
/// Field prefix marking a payload line within an SSE event.
const SSE_DATA_PREFIX: &str = "data:";

/// Payload value that signals the end of the chat stream; it carries no chunk.
const SSE_DONE_MARKER: &str = "[DONE]";

/// Blank line (two consecutive LF line endings) that terminates an SSE event.
const SSE_EVENT_DELIMITER: &str = "\n\n";
129
130fn try_parse_event(buf: &mut String) -> Result<Option<ChatStreamChunk>, LlmError> {
132 loop {
133 let Some(boundary) = buf.find(SSE_EVENT_DELIMITER) else {
134 return Ok(None);
135 };
136
137 let result = parse_sse_event(&buf[..boundary])?;
138 buf.drain(..boundary + SSE_EVENT_DELIMITER.len());
139
140 if let Some(chunk) = result {
141 return Ok(Some(chunk));
142 }
143 }
144}
145
146fn try_parse_remaining(buf: &mut String) -> Result<Option<ChatStreamChunk>, LlmError> {
148 let text = std::mem::take(buf);
149 let trimmed = text.trim();
150 if trimmed.is_empty() {
151 return Ok(None);
152 }
153 parse_sse_event(trimmed)
154}
155
156fn parse_sse_event(event_text: &str) -> Result<Option<ChatStreamChunk>, LlmError> {
159 let mut data_parts = Vec::new();
160
161 for line in event_text.lines() {
162 if line.starts_with(':') {
163 continue;
164 }
165 if let Some(rest) = line.strip_prefix(SSE_DATA_PREFIX) {
166 let data = rest.strip_prefix(' ').unwrap_or(rest);
167 data_parts.push(data);
168 }
169 }
170
171 if data_parts.is_empty() {
172 return Ok(None);
173 }
174
175 let data = data_parts.join("\n");
176 let trimmed = data.trim();
177
178 if trimmed == SSE_DONE_MARKER || trimmed.is_empty() {
179 return Ok(None);
180 }
181
182 match serde_json::from_str::<ChatStreamChunk>(trimmed) {
183 Ok(chunk) => Ok(Some(chunk)),
184 Err(e) => Err(LlmError::Deserialize(format!(
185 "Failed to parse SSE data: {e} | raw: {}",
186 truncate_str(trimmed, 200)
187 ))),
188 }
189}
190
/// Returns the longest prefix of `s` that is at most `max_len` bytes long and ends on a
/// `char` boundary, so the result never splits a multi-byte character.
fn truncate_str(s: &str, max_len: usize) -> &str {
    if s.len() <= max_len {
        return s;
    }
    // Index 0 is always a char boundary, so this walk-back terminates without underflow.
    let mut cut = max_len;
    while !s.is_char_boundary(cut) {
        cut -= 1;
    }
    &s[..cut]
}
201}