use base64::Engine;
use bytes::Bytes;
use serde::Serialize;
/// JSON body of an SSE `control` event, as emitted by [`format_control_frame`].
///
/// Field names are serialized in camelCase (e.g. `stream_next_offset` becomes
/// `streamNextOffset`), keys appear in field-declaration order, and optional
/// fields that are `None` are omitted from the JSON entirely.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ControlPayload {
/// Serialized as `streamNextOffset`; always present.
/// NOTE(review): presumably the offset a client should resume reading from — confirm.
pub stream_next_offset: String,
/// Serialized as `streamCursor`; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub stream_cursor: Option<String>,
/// Serialized as `upToDate`; omitted when `None`.
/// NOTE(review): presumably signals the reader has caught up with the stream — confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub up_to_date: Option<bool>,
/// Serialized as `streamClosed`; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub stream_closed: Option<bool>,
}
/// Formats `payload` as an SSE `control` event.
///
/// The frame is an `event: control` line, then a single `data:` line carrying the
/// JSON serialization of `payload`, then the blank line that terminates the event.
/// `serde_json::to_string` produces compact JSON in which newlines inside strings
/// are escaped, so the whole payload always fits on one `data:` line.
///
/// # Panics
///
/// Panics if `payload` cannot be serialized, which its plain `String` / `Option`
/// field types make unexpected.
#[must_use]
pub fn format_control_frame(payload: &ControlPayload) -> String {
    let body = serde_json::to_string(payload)
        .expect("ControlPayload serialization should not fail");
    format!("event: control\ndata:{body}\n\n")
}
/// Formats a single message as one SSE `data` event.
///
/// The payload text is chosen by the flags, with `is_binary` taking precedence:
/// * `is_binary`: standard (padded) base64 of the raw bytes;
/// * `is_json`: the message wrapped in `[` / `]`, i.e. a one-element JSON array;
/// * otherwise: the bytes decoded as UTF-8, invalid sequences replaced by U+FFFD.
///
/// Multi-line payloads become one `data:` line per line of text.
#[must_use]
pub fn format_data_frame(data: &Bytes, is_binary: bool, is_json: bool) -> String {
    let payload = match (is_binary, is_json) {
        (true, _) => base64::engine::general_purpose::STANDARD.encode(data),
        (false, true) => format!("[{}]", String::from_utf8_lossy(data)),
        (false, false) => String::from_utf8_lossy(data).into_owned(),
    };
    format_raw_data_frame(&payload)
}
#[must_use]
pub fn format_data_frames(messages: &[Bytes], is_binary: bool, is_json: bool) -> String {
if messages.is_empty() {
return String::new();
}
if is_json {
let mut array_content = String::new();
for (i, msg) in messages.iter().enumerate() {
if i > 0 {
array_content.push(',');
}
let raw = String::from_utf8_lossy(msg);
array_content.push_str(&raw);
}
let text = format!("[{array_content}]");
format_raw_data_frame(&text)
} else {
let mut result = String::new();
for msg in messages {
result.push_str(&format_data_frame(msg, is_binary, false));
}
result
}
}
/// Wraps already-encoded `text` in a single SSE `data` event.
///
/// Emits `event: data`, then one `data:` line per line of `text`, then a blank line
/// that terminates the event. Lines are split on CRLF, CR or LF, none of which may
/// appear inside an SSE field value. The client rejoins the lines with `\n`.
///
/// SSE parsers strip one optional space immediately after the `:` of a field. A
/// payload line that itself begins with a space therefore gets one extra separator
/// space so its leading whitespace survives parsing. Every other line is emitted
/// verbatim as `data:<line>`.
fn format_raw_data_frame(text: &str) -> String {
    let mut frame = String::from("event: data\n");
    for line in split_lines(text) {
        frame.push_str("data:");
        if line.starts_with(' ') {
            // Sacrificial space: the parser removes this one, not the payload's.
            frame.push(' ');
        }
        frame.push_str(line);
        frame.push('\n');
    }
    frame.push('\n');
    frame
}
/// Splits `text` into lines, treating CRLF, lone CR and lone LF as terminators.
///
/// Unlike `str::lines`, every terminator produces a boundary. So `"a\n"` yields
/// `["a", ""]`, `""` yields `[""]`, and the result is never empty. Slicing only
/// happens at ASCII terminator positions, which are always valid char boundaries.
fn split_lines(text: &str) -> Vec<&str> {
    let bytes = text.as_bytes();
    let mut out = Vec::new();
    let mut seg_start = 0;
    let mut pos = 0;
    while let Some(&b) = bytes.get(pos) {
        let terminator_len = match b {
            b'\r' if bytes.get(pos + 1) == Some(&b'\n') => 2,
            b'\r' | b'\n' => 1,
            _ => {
                pos += 1;
                continue;
            }
        };
        out.push(&text[seg_start..pos]);
        pos += terminator_len;
        seg_start = pos;
    }
    out.push(&text[seg_start..]);
    out
}
/// Returns an SSE keep-alive frame: an empty comment line (`:`) followed by the
/// blank line that ends the block.
///
/// SSE clients ignore comment lines, so sending this produces no event on the client.
#[must_use]
pub fn format_keepalive_frame() -> &'static str {
    const KEEPALIVE_FRAME: &str = ":\n\n";
    KEEPALIVE_FRAME
}
/// Reports whether data of content type `ct` must be base64-encoded for SSE.
///
/// Only the media type is considered, meaning the part before any `;` parameters
/// such as `charset`, trimmed of surrounding whitespace. It is compared
/// ASCII-case-insensitively, because type and subtype are case-insensitive
/// (RFC 9110 §8.3.1). `text/*` and `application/json` are textual. Everything else,
/// including `application/ndjson` and an empty string, is treated as binary.
#[must_use]
pub fn is_binary_content_type(ct: &str) -> bool {
    // `split` always yields at least one item, so the fallback is never used.
    let media_type = ct.split(';').next().unwrap_or(ct).trim();
    let bytes = media_type.as_bytes();
    let is_text = bytes.len() >= 5 && bytes[..5].eq_ignore_ascii_case(b"text/");
    let is_json = media_type.eq_ignore_ascii_case("application/json");
    !(is_text || is_json)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_is_binary_text_types() {
assert!(!is_binary_content_type("text/plain"));
assert!(!is_binary_content_type("text/html"));
assert!(!is_binary_content_type("text/csv"));
}
#[test]
fn test_is_binary_json() {
assert!(!is_binary_content_type("application/json"));
}
#[test]
fn test_is_binary_octet_stream() {
assert!(is_binary_content_type("application/octet-stream"));
}
#[test]
fn test_is_binary_protobuf() {
assert!(is_binary_content_type("application/x-protobuf"));
}
#[test]
fn test_is_binary_ndjson() {
assert!(is_binary_content_type("application/ndjson"));
}
#[test]
fn test_control_payload_serializes_camel_case() {
let payload = ControlPayload {
stream_next_offset: "abc_123".to_string(),
stream_cursor: Some("cursor1".to_string()),
up_to_date: Some(true),
stream_closed: None,
};
let json = serde_json::to_string(&payload).unwrap();
assert!(json.contains("\"streamNextOffset\""));
assert!(json.contains("\"streamCursor\""));
assert!(json.contains("\"upToDate\""));
assert!(!json.contains("\"streamClosed\""));
}
#[test]
fn test_control_payload_skips_none_fields() {
let payload = ControlPayload {
stream_next_offset: "offset1".to_string(),
stream_cursor: None,
up_to_date: None,
stream_closed: Some(true),
};
let json = serde_json::to_string(&payload).unwrap();
assert!(json.contains("\"streamNextOffset\""));
assert!(!json.contains("\"streamCursor\""));
assert!(!json.contains("\"upToDate\""));
assert!(json.contains("\"streamClosed\":true"));
}
#[test]
fn test_format_data_frame_text() {
let data = Bytes::from("hello world");
let frame = format_data_frame(&data, false, false);
assert_eq!(frame, "event: data\ndata:hello world\n\n");
}
#[test]
fn test_format_data_frame_multiline_lf() {
let data = Bytes::from("line1\nline2\nline3");
let frame = format_data_frame(&data, false, false);
assert!(frame.contains("data:line1\n"));
assert!(frame.contains("data:line2\n"));
assert!(frame.contains("data:line3\n"));
}
#[test]
fn test_format_data_frame_multiline_crlf() {
let data = Bytes::from("line1\r\nline2\r\nline3");
let frame = format_data_frame(&data, false, false);
assert!(frame.contains("data:line1\n"));
assert!(frame.contains("data:line2\n"));
assert!(frame.contains("data:line3\n"));
assert!(!frame.contains('\r'));
}
#[test]
fn test_format_data_frame_multiline_cr() {
let data = Bytes::from("line1\rline2\rline3");
let frame = format_data_frame(&data, false, false);
assert!(frame.contains("data:line1\n"));
assert!(frame.contains("data:line2\n"));
assert!(frame.contains("data:line3\n"));
assert!(!frame.contains('\r'));
}
#[test]
fn test_format_data_frame_json() {
let data = Bytes::from(r#"{"id":1}"#);
let frame = format_data_frame(&data, false, true);
assert!(frame.contains(r#"data:[{"id":1}]"#));
}
#[test]
fn test_format_data_frame_binary() {
let data = Bytes::from(vec![0x01, 0x02, 0x03]);
let frame = format_data_frame(&data, true, false);
assert!(frame.contains("data:AQID"));
}
#[test]
fn test_format_control_frame() {
let payload = ControlPayload {
stream_next_offset: "test".to_string(),
stream_cursor: None,
up_to_date: Some(true),
stream_closed: None,
};
let frame = format_control_frame(&payload);
assert!(frame.starts_with("event: control\n"));
assert!(frame.contains("data:"));
assert!(frame.ends_with("\n\n"));
}
#[test]
fn test_split_lines_lf() {
assert_eq!(split_lines("a\nb\nc"), vec!["a", "b", "c"]);
}
#[test]
fn test_split_lines_crlf() {
assert_eq!(split_lines("a\r\nb\r\nc"), vec!["a", "b", "c"]);
}
#[test]
fn test_split_lines_cr() {
assert_eq!(split_lines("a\rb\rc"), vec!["a", "b", "c"]);
}
#[test]
fn test_split_lines_mixed() {
assert_eq!(split_lines("a\nb\rc\r\nd"), vec!["a", "b", "c", "d"]);
}
#[test]
fn test_split_lines_empty_segments() {
assert_eq!(split_lines("a\n\nb"), vec!["a", "", "b"]);
assert_eq!(split_lines("a\r\rb"), vec!["a", "", "b"]);
}
#[test]
fn test_format_data_frames_json_batches_into_single_event() {
let messages = vec![
Bytes::from(r#"{"id":1}"#),
Bytes::from(r#"{"id":2}"#),
Bytes::from(r#"{"id":3}"#),
];
let frame = format_data_frames(&messages, false, true);
assert_eq!(
frame.matches("event: data").count(),
1,
"JSON batch should produce exactly one SSE data event"
);
assert!(
frame.contains(r#"data:[{"id":1},{"id":2},{"id":3}]"#),
"JSON batch should contain all messages in one array"
);
}
#[test]
fn test_format_data_frames_json_single_message() {
let messages = vec![Bytes::from(r#"{"id":1}"#)];
let frame = format_data_frames(&messages, false, true);
assert_eq!(frame.matches("event: data").count(), 1);
assert!(frame.contains(r#"data:[{"id":1}]"#));
}
#[test]
fn test_format_data_frames_text_emits_per_message() {
let messages = vec![Bytes::from("hello"), Bytes::from("world")];
let frame = format_data_frames(&messages, false, false);
assert_eq!(
frame.matches("event: data").count(),
2,
"Text batch should produce one SSE data event per message"
);
}
#[test]
fn test_format_data_frames_empty() {
let result = format_data_frames(&[], false, true);
assert!(result.is_empty());
}
}