Skip to main content

j_agent/llm/
stream.rs

1use futures::Stream;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use super::error::LlmError;
6use super::types::ChatStreamChunk;
7
8/// SSE stream: wraps a `reqwest::Response` byte stream, yields `ChatStreamChunk`.
9///
10/// Uses a `Vec<u8>` byte buffer to correctly handle multi-byte UTF-8 characters
11/// split across TCP chunks — a common occurrence when the server streams CJK text
12/// and the network fragments packets at arbitrary byte boundaries.
13pub struct SseStream {
14    body: Pin<Box<dyn Stream<Item = Result<Vec<u8>, reqwest::Error>> + Send>>,
15    byte_buf: Vec<u8>,
16    str_buf: String,
17}
18
19impl SseStream {
20    pub fn new(response: reqwest::Response) -> Self {
21        use futures::StreamExt;
22        Self {
23            body: Box::pin(response.bytes_stream().map(|r| r.map(|b| b.to_vec()))),
24            byte_buf: Vec::new(),
25            str_buf: String::new(),
26        }
27    }
28}
29
30impl Stream for SseStream {
31    type Item = Result<ChatStreamChunk, LlmError>;
32
33    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
34        let this = self.get_mut();
35
36        loop {
37            // Try to extract a complete SSE event from the string buffer
38            if let Some(chunk) = try_parse_event(&mut this.str_buf)? {
39                return Poll::Ready(Some(Ok(chunk)));
40            }
41
42            // Need more data from the body stream
43            match Pin::new(&mut this.body).poll_next(cx) {
44                Poll::Ready(Some(Ok(bytes))) => {
45                    this.byte_buf.extend_from_slice(&bytes);
46                    // Convert as many valid UTF-8 bytes as possible from byte_buf → str_buf.
47                    // Any trailing incomplete multi-byte sequence stays in byte_buf for the next chunk.
48                    flush_utf8(&mut this.byte_buf, &mut this.str_buf)?;
49                }
50                Poll::Ready(Some(Err(e))) => {
51                    return Poll::Ready(Some(Err(LlmError::StreamInterrupted(e.to_string()))));
52                }
53                Poll::Ready(None) => {
54                    // Stream ended. Flush any remaining bytes (truly invalid UTF-8 at this point).
55                    if !this.byte_buf.is_empty() {
56                        match std::str::from_utf8(&this.byte_buf) {
57                            Ok(s) => {
58                                this.str_buf.push_str(s);
59                                this.byte_buf.clear();
60                            }
61                            Err(e) => {
62                                return Poll::Ready(Some(Err(LlmError::StreamInterrupted(
63                                    format!("Invalid UTF-8 in SSE stream: {e}"),
64                                ))));
65                            }
66                        }
67                    }
68                    // Check if there's a final partial event in str_buf.
69                    if this.str_buf.trim().is_empty() {
70                        return Poll::Ready(None);
71                    }
72                    // Try to parse whatever remains
73                    match try_parse_remaining(&mut this.str_buf) {
74                        Ok(Some(chunk)) => return Poll::Ready(Some(Ok(chunk))),
75                        Ok(None) => return Poll::Ready(None),
76                        Err(e) => return Poll::Ready(Some(Err(e))),
77                    }
78                }
79                Poll::Pending => return Poll::Pending,
80            }
81        }
82    }
83}
84
85/// Convert complete UTF-8 sequences from `byte_buf` → `str_buf`.
86/// Any trailing incomplete multi-byte sequence stays in `byte_buf`.
87fn flush_utf8(byte_buf: &mut Vec<u8>, str_buf: &mut String) -> Result<(), LlmError> {
88    if byte_buf.is_empty() {
89        return Ok(());
90    }
91    match std::str::from_utf8(byte_buf) {
92        // Entire buffer is valid UTF-8 — flush everything.
93        Ok(s) => {
94            str_buf.push_str(s);
95            byte_buf.clear();
96            Ok(())
97        }
98        // Partial UTF-8 — flush the valid prefix, keep the incomplete tail.
99        Err(e) => {
100            let valid_up_to = e.valid_up_to();
101            if valid_up_to == 0 && e.error_len().is_some() {
102                // No valid prefix at all, and it's a real error (not just incomplete).
103                return Err(LlmError::StreamInterrupted(format!(
104                    "Invalid UTF-8 in SSE stream: {e}"
105                )));
106            }
107            // SAFETY: `valid_up_to` from `Utf8Error` is guaranteed to be a valid UTF-8 boundary
108            // index into the original slice, so slicing `[..valid_up_to]` always produces
109            // valid UTF-8 and `from_utf8` will succeed.
110            let valid = std::str::from_utf8(&byte_buf[..valid_up_to])
111                .expect("valid_up_to is guaranteed to be a UTF-8 boundary");
112            str_buf.push_str(valid);
113            byte_buf.drain(..valid_up_to);
114            // The remaining bytes are an incomplete multi-byte sequence —
115            // they'll be completed when the next TCP chunk arrives.
116            Ok(())
117        }
118    }
119}
120
/// Separator between two SSE events: a blank line, i.e. two consecutive `\n`.
const SSE_EVENT_DELIMITER: &str = "\n\n";

/// Field name (including the colon) that introduces a data line inside an event.
const SSE_DATA_PREFIX: &str = "data:";

/// Data payload that signals the end of the stream instead of carrying a chunk.
const SSE_DONE_MARKER: &str = "[DONE]";
129
130/// Try to extract one complete SSE event from the buffer.
131fn try_parse_event(buf: &mut String) -> Result<Option<ChatStreamChunk>, LlmError> {
132    loop {
133        let Some(boundary) = buf.find(SSE_EVENT_DELIMITER) else {
134            return Ok(None);
135        };
136
137        let result = parse_sse_event(&buf[..boundary])?;
138        buf.drain(..boundary + SSE_EVENT_DELIMITER.len());
139
140        if let Some(chunk) = result {
141            return Ok(Some(chunk));
142        }
143    }
144}
145
146/// Try to parse any remaining data in the buffer when the stream ends.
147fn try_parse_remaining(buf: &mut String) -> Result<Option<ChatStreamChunk>, LlmError> {
148    let text = std::mem::take(buf);
149    let trimmed = text.trim();
150    if trimmed.is_empty() {
151        return Ok(None);
152    }
153    parse_sse_event(trimmed)
154}
155
156/// Parse a single SSE event text block.
157/// Returns `Ok(None)` for [DONE], comments, or empty data lines.
158fn parse_sse_event(event_text: &str) -> Result<Option<ChatStreamChunk>, LlmError> {
159    let mut data_parts = Vec::new();
160
161    for line in event_text.lines() {
162        if line.starts_with(':') {
163            continue;
164        }
165        if let Some(rest) = line.strip_prefix(SSE_DATA_PREFIX) {
166            let data = rest.strip_prefix(' ').unwrap_or(rest);
167            data_parts.push(data);
168        }
169    }
170
171    if data_parts.is_empty() {
172        return Ok(None);
173    }
174
175    let data = data_parts.join("\n");
176    let trimmed = data.trim();
177
178    if trimmed == SSE_DONE_MARKER || trimmed.is_empty() {
179        return Ok(None);
180    }
181
182    match serde_json::from_str::<ChatStreamChunk>(trimmed) {
183        Ok(chunk) => Ok(Some(chunk)),
184        Err(e) => Err(LlmError::Deserialize(format!(
185            "Failed to parse SSE data: {e} | raw: {}",
186            truncate_str(trimmed, 200)
187        ))),
188    }
189}
190
/// Returns the longest prefix of `s` that is at most `max_len` bytes long and
/// ends on a character boundary.
///
/// `s` is returned unchanged when it already fits. The result may be shorter
/// than `max_len` bytes because a multi-byte character is never split.
fn truncate_str(s: &str, max_len: usize) -> &str {
    if s.len() <= max_len {
        return s;
    }

    // Walk back from the byte limit to the nearest character boundary. Index 0
    // is always a boundary, so the loop stops before `end` can underflow.
    let mut end = max_len;
    while !s.is_char_boundary(end) {
        end -= 1;
    }
    &s[..end]
}
201}