use bytes::{Bytes, BytesMut};
/// A single Server-Sent Event.
///
/// Values produced by [`SseEventIterator`] carry either the raw JSON text of
/// the `data:` field or the literal `[DONE]` sentinel.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SseEvent {
    /// Value of the event's `event:` field, if one was present.
    pub event_type: Option<String>,
    /// The event payload taken from the `data:` field.
    pub data: String,
}
/// Errors reported by [`SseEventIterator::next_event`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SseParseError {
    /// The end of the JSON payload after `data:` could not be located.
    UnterminatedJson,
    /// The JSON payload was complete but no `"\n\n"` delimiter followed it.
    MissingDelimiter,
    /// Invalid UTF-8 input. NOTE(review): not constructed anywhere in this file.
    InvalidUtf8,
}

impl std::fmt::Display for SseParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SseParseError::UnterminatedJson => write!(f, "unterminated JSON object"),
            SseParseError::MissingDelimiter => write!(f, "missing SSE event delimiter"),
            SseParseError::InvalidUtf8 => write!(f, "invalid UTF-8"),
        }
    }
}

// Lets the error be boxed as `dyn std::error::Error` and propagated with `?`.
impl std::error::Error for SseParseError {}
/// Forward-only cursor over the SSE events contained in a borrowed text buffer.
///
/// `position` is a byte offset into `text`. It only moves forward, always lands
/// just after an ASCII byte (or at 0), and never exceeds `text.len()`, so callers
/// can treat it as the number of bytes consumed so far.
pub struct SseEventIterator<'a> {
    text: &'a str,
    position: usize,
}

impl<'a> SseEventIterator<'a> {
    /// Creates an iterator positioned at the start of `text`.
    pub fn new(text: &'a str) -> Self {
        Self { text, position: 0 }
    }

    /// Returns the byte offset of the first byte not yet consumed.
    pub fn position(&self) -> usize {
        self.position
    }

    /// Skips the event at the current position and returns the new position.
    ///
    /// Advances just past the next `"\n\n"` delimiter, or to the end of the
    /// text when there is none. Meant for recovering after `next_event`
    /// returns an error, since `next_event` does not advance on errors.
    pub fn skip_incomplete_event(&mut self) -> usize {
        self.position = match self.text[self.position..].find("\n\n") {
            Some(delimiter) => self.position + delimiter + 2,
            None => self.text.len(),
        };
        self.position
    }

    /// Parses the next event at or after the current position.
    ///
    /// After `data:`, spaces, tabs and line breaks are skipped before the
    /// payload (so `data:\n{...}` is accepted), but never past a blank line.
    /// Events whose `data:` field is empty are consumed and skipped.
    ///
    /// Returns:
    /// * `None` when no further `data:` field exists, or when only whitespace
    ///   follows the last `data:` (the payload may still be arriving). The
    ///   position is left at the start of that unfinished event.
    /// * `Some(Ok(event))` for a `[DONE]` sentinel (always without an event
    ///   type; a directly following `"\n\n"` is consumed) or for a JSON payload
    ///   followed by `"\n\n"`. The position moves past the event.
    /// * `Some(Err(SseParseError::UnterminatedJson))` when `find_json_end`
    ///   cannot locate the end of the payload, or
    ///   `Some(Err(SseParseError::MissingDelimiter))` when no `"\n\n"` follows
    ///   it. The position is not advanced; see
    ///   [`skip_incomplete_event`](Self::skip_incomplete_event).
    pub fn next_event(&mut self) -> Option<Result<SseEvent, SseParseError>> {
        loop {
            let base_pos = self.position;
            let text = &self.text[base_pos..];
            let data_start = text.find("data:")?;

            let value_start = Self::skip_value_whitespace(text, data_start + "data:".len());
            if value_start >= text.len() {
                // Only whitespace after `data:` so far; wait for more input.
                return None;
            }
            let rest = &text[value_start..];

            if rest.starts_with("\n\n") {
                // Empty `data:` field: a complete event with no payload. Consume
                // it so its whitespace is never mistaken for the next event's.
                self.position = base_pos + value_start + 2;
                continue;
            }

            if rest.starts_with("[DONE]") {
                let mut end = base_pos + value_start + "[DONE]".len();
                // `starts_with` rather than slicing `end..end + 2`: a fixed-width
                // slice panics when it ends inside a multi-byte character.
                if self.text[end..].starts_with("\n\n") {
                    end += 2;
                }
                self.position = end;
                return Some(Ok(SseEvent {
                    event_type: None,
                    data: "[DONE]".to_string(),
                }));
            }

            let event_type = Self::event_type_in_header(&text[..data_start]);
            let json_end = match find_json_end(rest) {
                Some(end) => end,
                None => return Some(Err(SseParseError::UnterminatedJson)),
            };
            let delimiter_pos = match rest[json_end..].find("\n\n") {
                Some(pos) => pos,
                None => return Some(Err(SseParseError::MissingDelimiter)),
            };
            self.position = base_pos + value_start + json_end + delimiter_pos + 2;
            return Some(Ok(SseEvent {
                event_type,
                data: rest[..json_end].to_string(),
            }));
        }
    }

    /// Extracts the event type from the text preceding a `data:` field.
    ///
    /// Only complete lines count: text sitting before `data:` on its own line
    /// is ignored. The search walks backwards from `data:` and stops at the
    /// nearest blank line, which ends the previous event. The last `event:`
    /// field in that range wins. Other lines (`id:`, `retry:`, comments) are
    /// passed over. Returns `None` if no `event:` field is found.
    fn event_type_in_header(header: &str) -> Option<String> {
        let last_newline = header.rfind('\n')?;
        header[..last_newline]
            .split('\n')
            .rev()
            .map(|line| line.strip_suffix('\r').unwrap_or(line))
            .take_while(|line| !line.is_empty())
            .find_map(|line| line.trim().strip_prefix("event:"))
            .map(|value| value.trim().to_string())
    }

    /// Returns the index of the first byte at or after `from` that is not a
    /// space, tab, CR or LF. The scan stops at the start of a `"\n\n"` so it
    /// can never cross into the next event.
    fn skip_value_whitespace(text: &str, from: usize) -> usize {
        let bytes = text.as_bytes();
        let mut idx = from;
        while idx < bytes.len()
            && matches!(bytes[idx], b' ' | b'\t' | b'\n' | b'\r')
            && !bytes[idx..].starts_with(b"\n\n")
        {
            idx += 1;
        }
        idx
    }
}
/// Parses the SSE events in `text`, in order.
///
/// Returns the parsed events together with the byte offset up to which `text`
/// was consumed. Iteration ends when no further event is available. It also
/// ends at the first event that fails to parse: that event is skipped up to
/// and including the next `"\n\n"` (or to the end of `text` when there is
/// none), a debug message is logged, and nothing after it is examined.
pub fn parse_sse(text: &str) -> (Vec<SseEvent>, usize) {
    let mut iter = SseEventIterator::new(text);
    let mut events = Vec::new();
    loop {
        match iter.next_event() {
            Some(Ok(event)) => events.push(event),
            Some(Err(e)) => {
                tracing::debug!("SSE parse error: {}, skipping incomplete event", e);
                iter.skip_incomplete_event();
                break;
            }
            None => break,
        }
    }
    // The iterator never moves past `text.len()`, so its position already is
    // the consumed length; clamping against `text.len()` would be a no-op.
    let consumed = iter.position();
    (events, consumed)
}
/// Returns the byte length of the JSON object or array at the start of `text`.
///
/// The scan tracks bracket nesting depth, ignoring brackets inside string
/// literals (including backslash-escaped characters in them). It stops at the
/// bracket that closes the opening one; anything after it is ignored. Bracket
/// kinds are not matched against each other; only depth is tracked.
///
/// Returns `None` if `text` does not start with `{` or `[`, or if the opening
/// bracket is never closed (e.g. the payload is still arriving). Requiring
/// the opening bracket up front keeps the scan from running ahead into later
/// text, such as the next SSE event, when the value is not a JSON container.
fn find_json_end(text: &str) -> Option<usize> {
    if !matches!(text.as_bytes().first(), Some(b'{') | Some(b'[')) {
        return None;
    }
    // Starts at 0, becomes 1 on the first char, and we return as soon as it
    // drops back to 0, so the decrement below can never underflow.
    let mut depth: usize = 0;
    let mut in_string = false;
    let mut escaped = false;
    for (i, c) in text.char_indices() {
        if escaped {
            escaped = false;
            continue;
        }
        match c {
            '\\' if in_string => escaped = true,
            '"' => in_string = !in_string,
            '{' | '[' if !in_string => depth += 1,
            '}' | ']' if !in_string => {
                depth -= 1;
                if depth == 0 {
                    return Some(i + c.len_utf8());
                }
            }
            _ => {}
        }
    }
    None
}
pub fn serialize_sse(event: &SseEvent) -> String {
let mut result = String::new();
if let Some(ref et) = event.event_type {
result.push_str("event: ");
result.push_str(et);
result.push('\n');
}
result.push_str("data: ");
result.push_str(&event.data);
result.push_str("\n\n");
result
}
/// Concatenates `frames`, in order, into one contiguous buffer.
///
/// An empty slice yields an empty `Bytes`. A single frame is returned via
/// `Bytes::clone`. Otherwise every frame is copied into a buffer allocated
/// once with the exact total length.
pub fn collect_frames(frames: &[Bytes]) -> Bytes {
    match frames {
        [] => Bytes::new(),
        [only] => only.clone(),
        many => {
            let capacity: usize = many.iter().map(|frame| frame.len()).sum();
            let mut buffer = BytesMut::with_capacity(capacity);
            many.iter()
                .for_each(|frame| buffer.extend_from_slice(frame));
            buffer.freeze()
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `text`, asserts exactly one event came out, and returns it.
    fn parse_exactly_one(text: &str) -> SseEvent {
        let (events, _) = parse_sse(text);
        assert_eq!(events.len(), 1);
        events.into_iter().next().unwrap()
    }

    /// Shorthand for an owned `Some(event_type)`.
    fn typed(name: &str) -> Option<String> {
        Some(name.to_string())
    }

    #[test]
    fn test_parse_single_event() {
        let event =
            parse_exactly_one("data: {\"type\": \"response.created\", \"id\": \"resp_123\"}\n\n");
        assert_eq!(event.event_type, None);
        assert_eq!(
            event.data,
            "{\"type\": \"response.created\", \"id\": \"resp_123\"}"
        );
    }

    #[test]
    fn test_parse_event_with_type() {
        let event =
            parse_exactly_one("event: response.created\ndata: {\"id\": \"resp_123\"}\n\n");
        assert_eq!(event.event_type, typed("response.created"));
        assert_eq!(event.data, "{\"id\": \"resp_123\"}");
    }

    #[test]
    fn test_parse_event_without_space_after_data_colon() {
        let event = parse_exactly_one("event: response.created\ndata:{\"id\":\"resp_123\"}\n\n");
        assert_eq!(event.event_type, typed("response.created"));
        assert_eq!(event.data, "{\"id\":\"resp_123\"}");
    }

    #[test]
    fn test_parse_event_with_newline_after_data_colon() {
        let event =
            parse_exactly_one("event: response.created\ndata:\n{\"id\":\"resp_123\"}\n\n");
        assert_eq!(event.event_type, typed("response.created"));
        assert_eq!(event.data, "{\"id\":\"resp_123\"}");
    }

    #[test]
    fn test_parse_done_event() {
        let event = parse_exactly_one("data: [DONE]\n\n");
        assert_eq!(event.event_type, None);
        assert_eq!(event.data, "[DONE]");
    }

    #[test]
    fn test_parse_multiple_events() {
        let text = concat!(
            "event: response.created\ndata: {\"id\": \"1\"}\n\n",
            "event: response.output_text.delta\ndata: {\"delta\": \"hello\"}\n\n"
        );
        let (events, _) = parse_sse(text);
        let kinds: Vec<Option<String>> = events.iter().map(|e| e.event_type.clone()).collect();
        assert_eq!(
            kinds,
            vec![
                typed("response.created"),
                typed("response.output_text.delta")
            ]
        );
    }

    #[test]
    fn test_parse_empty_data() {
        let (events, _) = parse_sse("event: done\ndata: \n\n");
        assert!(events.is_empty());
    }

    #[test]
    fn test_serialize_sse() {
        let frame = serialize_sse(&SseEvent {
            event_type: typed("response.created"),
            data: "{\"id\": \"resp_123\"}".to_string(),
        });
        assert!(frame.contains("event: response.created\n"));
        assert!(frame.contains("data: {\"id\": \"resp_123\"}\n\n"));
    }

    #[test]
    fn test_collect_frames_empty() {
        let frames: [Bytes; 0] = [];
        assert!(collect_frames(&frames).is_empty());
    }

    #[test]
    fn test_collect_frames_single() {
        let frames = [Bytes::from("hello")];
        assert_eq!(&collect_frames(&frames)[..], b"hello");
    }

    #[test]
    fn test_collect_frames_multiple() {
        let frames: Vec<Bytes> = ["hello", " world", "!"]
            .iter()
            .map(|s| Bytes::from(*s))
            .collect();
        assert_eq!(&collect_frames(&frames)[..], b"hello world!");
    }

    #[test]
    fn test_find_json_end() {
        assert_eq!(find_json_end(r#"{"key": "value"}"#), Some(16));
        assert_eq!(find_json_end(r#"{"outer": {"inner": "value"}}"#), Some(29));
        assert_eq!(find_json_end("[1, 2, 3]"), Some(9));
        assert_eq!(find_json_end(""), None);
        assert_eq!(find_json_end(r#"{"key": "value"#), None);
    }
}