use futures::Stream;
use std::pin::Pin;
use std::task::{Context, Poll};
use super::error::LlmError;
use super::types::ChatStreamChunk;
/// Adapts a streaming HTTP response body into a [`Stream`] of parsed
/// [`ChatStreamChunk`] values.
///
/// Body bytes are decoded as UTF-8 and appended to an internal text buffer;
/// complete events are pulled out of that buffer with `try_parse_event`, and
/// whatever text is left when the body ends is handed to
/// `try_parse_remaining`.
pub struct SseStream {
    /// Raw byte chunks of the response body.
    body: Pin<Box<dyn Stream<Item = Result<Vec<u8>, reqwest::Error>> + Send>>,
    /// Decoded text that has been received but not yet consumed as an event.
    buf: String,
    /// Trailing bytes of a UTF-8 sequence that was cut off at the end of the
    /// previous body chunk; they are prepended to the next chunk before decoding.
    pending: Vec<u8>,
    /// Set once `body` has yielded `None`, so that it is never polled again.
    body_done: bool,
}

impl SseStream {
    /// Wraps `response`, consuming its body as a stream of byte chunks.
    pub fn new(response: reqwest::Response) -> Self {
        use futures::StreamExt;
        Self {
            body: Box::pin(response.bytes_stream().map(|r| r.map(|b| b.to_vec()))),
            buf: String::new(),
            pending: Vec::new(),
            body_done: false,
        }
    }

    /// Decodes one body chunk and appends the text to `buf`.
    ///
    /// A multi-byte character may be split across two chunks. When the chunk
    /// ends in the middle of such a character, the incomplete tail is kept in
    /// `pending` and decoded together with the next chunk instead of being
    /// reported as invalid.
    ///
    /// # Errors
    ///
    /// Returns [`LlmError::StreamInterrupted`] when the bytes contain a
    /// sequence that can never become valid UTF-8.
    fn push_bytes(&mut self, bytes: &[u8]) -> Result<(), LlmError> {
        // Only allocate a joined buffer when there is a carried-over tail.
        let carried;
        let input: &[u8] = if self.pending.is_empty() {
            bytes
        } else {
            self.pending.extend_from_slice(bytes);
            carried = std::mem::take(&mut self.pending);
            &carried
        };

        match std::str::from_utf8(input) {
            Ok(text) => self.buf.push_str(text),
            // `error_len() == None` means the input ended in the middle of a
            // character: more bytes may still complete it.
            Err(e) if e.error_len().is_none() => {
                let (valid, partial) = input.split_at(e.valid_up_to());
                // `valid` is well-formed UTF-8 by definition of `valid_up_to`,
                // so this conversion borrows and never substitutes characters.
                self.buf.push_str(&String::from_utf8_lossy(valid));
                self.pending.extend_from_slice(partial);
            }
            Err(e) => {
                return Err(LlmError::StreamInterrupted(format!(
                    "Invalid UTF-8 in SSE stream: {e}"
                )));
            }
        }
        Ok(())
    }
}

impl Stream for SseStream {
    type Item = Result<ChatStreamChunk, LlmError>;

    /// Yields the next parsed chunk, reading more of the body as needed.
    ///
    /// Transport failures and invalid UTF-8 are reported as
    /// [`LlmError::StreamInterrupted`]; errors from event parsing are passed
    /// through unchanged.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        loop {
            // Emit everything that is already buffered before reading more.
            if let Some(chunk) = try_parse_event(&mut this.buf)? {
                return Poll::Ready(Some(Ok(chunk)));
            }

            if this.body_done {
                // A character that was still incomplete when the body ended
                // can never be completed.
                if !this.pending.is_empty() {
                    this.pending.clear();
                    return Poll::Ready(Some(Err(LlmError::StreamInterrupted(
                        "Invalid UTF-8 in SSE stream: incomplete multi-byte sequence at end of stream"
                            .to_string(),
                    ))));
                }
                // Flush a final event that was not followed by a delimiter.
                // `try_parse_remaining` empties the buffer, so the next poll
                // ends the stream.
                return match try_parse_remaining(&mut this.buf) {
                    Ok(Some(chunk)) => Poll::Ready(Some(Ok(chunk))),
                    Ok(None) => Poll::Ready(None),
                    Err(e) => Poll::Ready(Some(Err(e))),
                };
            }

            match Pin::new(&mut this.body).poll_next(cx) {
                Poll::Ready(Some(Ok(bytes))) => {
                    if let Err(e) = this.push_bytes(&bytes) {
                        return Poll::Ready(Some(Err(e)));
                    }
                }
                Poll::Ready(Some(Err(e))) => {
                    return Poll::Ready(Some(Err(LlmError::StreamInterrupted(e.to_string()))));
                }
                // Remember exhaustion; the loop then drains what is left.
                Poll::Ready(None) => this.body_done = true,
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}
/// Blank-line separator (two line feeds) searched for in the text buffer to
/// find the end of one event.
const SSE_EVENT_DELIMITER: &str = "\n\n";
/// Prefix of the lines within an event whose content is collected as payload.
const SSE_DATA_PREFIX: &str = "data:";
/// Payload value that is skipped rather than parsed as JSON.
const SSE_DONE_MARKER: &str = "[DONE]";
/// Removes complete events from the front of `buf` until one yields a chunk.
///
/// An event ends at a blank line, written either as `"\n\n"` or as
/// `"\r\n\r\n"`. Events that carry no chunk (comments, empty payloads, the
/// done marker) are consumed and skipped.
///
/// Returns `Ok(None)` when `buf` holds no complete event; any partial event is
/// left in place for a later call.
///
/// # Errors
///
/// Propagates the error from `parse_sse_event`. The offending event is removed
/// from `buf` first, so a malformed event is reported once and the events
/// behind it stay reachable.
fn try_parse_event(buf: &mut String) -> Result<Option<ChatStreamChunk>, LlmError> {
    /// Blank-line separator used by servers that emit CRLF line endings.
    const CRLF_EVENT_DELIMITER: &str = "\r\n\r\n";

    loop {
        // Take whichever delimiter occurs first in the buffer.
        let boundary = [SSE_EVENT_DELIMITER, CRLF_EVENT_DELIMITER]
            .into_iter()
            .filter_map(|delimiter| buf.find(delimiter).map(|at| (at, delimiter.len())))
            .min_by_key(|&(at, _)| at);
        let Some((event_end, delimiter_len)) = boundary else {
            return Ok(None);
        };

        // Parse, then drain unconditionally, and only afterwards look at the
        // outcome: returning early on `Err` would leave the bad event at the
        // front of the buffer to be parsed again on every call.
        let parsed = parse_sse_event(&buf[..event_end]);
        buf.drain(..event_end + delimiter_len);

        if let Some(chunk) = parsed? {
            return Ok(Some(chunk));
        }
    }
}
/// Empties `buf` and parses its whitespace-trimmed content as a single event.
///
/// Returns `Ok(None)` when nothing but whitespace was buffered; otherwise the
/// result of `parse_sse_event` on the trimmed text. `buf` is left empty in
/// every case.
fn try_parse_remaining(buf: &mut String) -> Result<Option<ChatStreamChunk>, LlmError> {
    let leftover = std::mem::take(buf);
    match leftover.trim() {
        "" => Ok(None),
        event => parse_sse_event(event),
    }
}
/// Parses the text of one event into a [`ChatStreamChunk`].
///
/// Lines starting with `:` are skipped. From each line starting with
/// `data:` the prefix and at most one following space are removed, and the
/// results are joined with `\n` to form the payload. Other lines are ignored.
///
/// Returns `Ok(None)` when the trimmed payload is empty or equals the done
/// marker.
///
/// # Errors
///
/// Returns [`LlmError::Deserialize`] when the payload is not valid JSON for a
/// `ChatStreamChunk`; the message includes a truncated copy of the payload.
fn parse_sse_event(event_text: &str) -> Result<Option<ChatStreamChunk>, LlmError> {
    let payload = event_text
        .lines()
        .filter(|line| !line.starts_with(':'))
        .filter_map(|line| line.strip_prefix(SSE_DATA_PREFIX))
        .map(|value| value.strip_prefix(' ').unwrap_or(value))
        .collect::<Vec<_>>()
        .join("\n");

    // An event without any data line produces an empty payload here, so it is
    // covered by the same emptiness check.
    let json = payload.trim();
    if json.is_empty() || json == SSE_DONE_MARKER {
        return Ok(None);
    }

    serde_json::from_str::<ChatStreamChunk>(json)
        .map(Some)
        .map_err(|e| {
            LlmError::Deserialize(format!(
                "Failed to parse SSE data: {e} | raw: {}",
                truncate_str(json, 200)
            ))
        })
}
/// Returns the longest prefix of `s` that is at most `max_len` bytes long and
/// ends on a character boundary.
///
/// Cutting at a raw byte offset would panic when `max_len` falls inside a
/// multi-byte character, so the cut is moved back to the previous boundary.
/// The result may therefore be up to three bytes shorter than `max_len`.
fn truncate_str(s: &str, max_len: usize) -> &str {
    if s.len() <= max_len {
        return s;
    }
    let mut end = max_len;
    // Index 0 is always a boundary, so this loop terminates.
    while !s.is_char_boundary(end) {
        end -= 1;
    }
    &s[..end]
}