#![allow(dead_code)]
use bytes::Bytes;
use futures::stream::{Stream, StreamExt};
use std::pin::Pin;
use std::task::{Context, Poll};
use super::error::BackendError;
use crate::cancel_token::CancellationFlag;
/// Wraps `stream` so that it terminates once `cancel` reports cancellation.
///
/// Before each item the flag is checked synchronously. If it is already
/// cancelled, the returned stream ends without polling `stream` again.
/// Otherwise the next item is raced against `cancel.cancelled()`. The race is
/// `biased` with the cancellation branch listed first, so cancellation wins
/// when both are ready in the same poll. After the wrapper ends it never polls
/// `stream` again.
pub fn cancel_aware<S, T>(
    stream: S,
    cancel: CancellationFlag,
) -> Pin<Box<dyn Stream<Item = T> + Send>>
where
    S: Stream<Item = T> + Send + Unpin + 'static,
    T: Send + 'static,
{
    let guarded = futures::stream::unfold((stream, cancel), |state| async move {
        let (mut inner, flag) = state;
        // Already cancelled: stop without touching the inner stream.
        if flag.is_cancelled() {
            return None;
        }
        // `biased` polls branches in source order, giving cancellation priority.
        let next = tokio::select! {
            biased;
            _ = flag.cancelled() => None,
            item = inner.next() => item,
        };
        // Thread the state through for the next iteration only when an item arrived.
        next.map(|item| (item, (inner, flag)))
    });
    Box::pin(guarded)
}
/// Accumulates raw bytes and splits them into `\n`-terminated lines.
///
/// When a line is completed, a single `\r` directly before the terminator is
/// removed, so both LF and CRLF endings are accepted. A `\r` anywhere else is
/// kept as part of the line. Bytes are held until a whole line is available
/// and only then decoded as UTF-8 (lossily). This means a multi-byte character
/// split across two chunks is still decoded correctly.
#[derive(Debug, Default)]
pub struct LineBuffer {
    /// Bytes of the current line that has not been terminated yet.
    tail: Vec<u8>,
}

impl LineBuffer {
    /// Creates an empty buffer.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `chunk` and returns every line it completes, in order.
    ///
    /// Bytes after the last `\n` in `chunk` stay buffered for the next call.
    pub fn push(&mut self, chunk: &[u8]) -> Vec<String> {
        let mut completed = Vec::new();
        let mut rest = chunk;
        while let Some(newline) = rest.iter().position(|&b| b == b'\n') {
            self.tail.extend_from_slice(&rest[..newline]);
            completed.push(self.take_line());
            rest = &rest[newline + 1..];
        }
        self.tail.extend_from_slice(rest);
        completed
    }

    /// Returns the buffered unterminated fragment, or `None` if nothing is buffered.
    ///
    /// A single trailing `\r` is removed, as for completed lines. The buffer is
    /// empty afterwards. Intended for end of input.
    pub fn flush(&mut self) -> Option<String> {
        if self.tail.is_empty() {
            None
        } else {
            Some(self.take_line())
        }
    }

    /// Returns `true` when no partial line is buffered.
    pub fn is_empty(&self) -> bool {
        self.tail.is_empty()
    }

    /// Decodes the buffered bytes as one line and clears the buffer.
    ///
    /// One trailing `\r` is dropped before decoding.
    fn take_line(&mut self) -> String {
        if self.tail.last() == Some(&b'\r') {
            self.tail.pop();
        }
        let line = String::from_utf8_lossy(&self.tail).into_owned();
        self.tail.clear();
        line
    }
}
/// One dispatched Server-Sent Events message.
///
/// A field is `None` when no line for it appeared in the event block.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SseEvent {
    /// Value of the last `event:` line in the block.
    pub event: Option<String>,
    /// Value of the last accepted `id:` line in the block.
    pub id: Option<String>,
    /// Every `data:` value in the block, joined with `\n`.
    pub data: Option<String>,
    /// Value of the last valid `retry:` line, in milliseconds.
    pub retry_ms: Option<u64>,
}

impl SseEvent {
    /// Returns `true` when no field is set, i.e. there is nothing to dispatch.
    pub fn is_empty(&self) -> bool {
        self.event.is_none()
            && self.id.is_none()
            && self.data.is_none()
            && self.retry_ms.is_none()
    }
}

/// Incremental Server-Sent Events field parser, fed one line at a time.
///
/// Lines must already have their terminators removed. A blank line dispatches
/// the event collected so far. This differs from the browser `EventSource`
/// processing model in three ways:
/// - an event is dispatched whenever *any* field was set, even without `data`;
/// - `id` is not carried over to later events (each dispatch starts fresh);
/// - [`flush`](Self::flush) returns an event that was never terminated by a
///   blank line, instead of discarding it.
#[derive(Debug, Default)]
pub struct SseEventParser {
    /// Fields collected so far for the event being built.
    current: SseEvent,
    /// `data:` values collected so far; joined only at dispatch time.
    data_acc: Vec<String>,
}

impl SseEventParser {
    /// Creates a parser with no pending event.
    pub fn new() -> Self {
        Self::default()
    }

    /// Processes one line (without its terminator).
    ///
    /// Returns the completed event when `line` is blank and at least one field
    /// was collected. Otherwise returns `None`.
    pub fn push_line(&mut self, line: &str) -> Option<SseEvent> {
        if line.is_empty() {
            return self.take_event();
        }
        // Lines starting with `:` are comments.
        if line.starts_with(':') {
            return None;
        }
        // A line without a colon is a field name with an empty value.
        let (field, raw_value) = line.split_once(':').unwrap_or((line, ""));
        // Only one leading space after the colon is syntax; further spaces belong to the value.
        let value = raw_value.strip_prefix(' ').unwrap_or(raw_value);
        match field {
            "event" => self.current.event = Some(value.to_string()),
            "id" => {
                // Per the SSE spec, an `id` containing U+0000 NULL is ignored.
                if !value.contains('\0') {
                    self.current.id = Some(value.to_string());
                }
            }
            "data" => self.data_acc.push(value.to_string()),
            "retry" => {
                if let Some(ms) = Self::parse_retry(value) {
                    self.current.retry_ms = Some(ms);
                }
            }
            // Unknown field names are ignored, as the spec requires.
            _ => {}
        }
        None
    }

    /// Dispatches whatever has been collected without waiting for a blank line.
    ///
    /// Intended for end of input. Returns `None` if nothing is pending.
    pub fn flush(&mut self) -> Option<SseEvent> {
        self.take_event()
    }

    /// Finalizes `data`, resets the parser, and returns the event if any field was set.
    fn take_event(&mut self) -> Option<SseEvent> {
        if !self.data_acc.is_empty() {
            self.current.data = Some(self.data_acc.join("\n"));
            self.data_acc.clear();
        }
        let event = std::mem::take(&mut self.current);
        if event.is_empty() {
            None
        } else {
            Some(event)
        }
    }

    /// Parses a `retry:` value.
    ///
    /// Only a non-empty string of ASCII digits is accepted; `u64::from_str` on
    /// its own would also accept a leading `+`. Values that overflow `u64` are
    /// rejected.
    fn parse_retry(value: &str) -> Option<u64> {
        if value.is_empty() || !value.bytes().all(|b| b.is_ascii_digit()) {
            return None;
        }
        value.parse().ok()
    }
}
/// Adapts a stream of byte chunks into a stream of text lines.
///
/// Splitting is delegated to an internal [`LineBuffer`]. Any unterminated
/// trailing fragment is emitted when the inner stream ends. `provider` and
/// `model` are stored only to label the `BackendError` produced when the
/// inner stream fails.
pub struct LineStream<S> {
    /// Underlying chunk stream.
    inner: S,
    /// Holds a partial line between chunks.
    buffer: LineBuffer,
    /// Lines already split out but not yet yielded.
    pending: std::collections::VecDeque<String>,
    /// Set once `inner` has ended or failed; `inner` is not polled afterwards.
    done: bool,
    provider: String,
    model: String,
}

impl<S> LineStream<S> {
    /// Wraps `inner`, remembering `provider` and `model` for error reporting.
    pub fn new(inner: S, provider: impl Into<String>, model: impl Into<String>) -> Self {
        let provider = provider.into();
        let model = model.into();
        Self {
            inner,
            buffer: LineBuffer::default(),
            pending: Default::default(),
            done: false,
            provider,
            model,
        }
    }
}
/// Yields one decoded line per item.
///
/// Buffered lines are drained before the inner stream is polled again. A
/// transport error is surfaced once as `BackendError::Generic` and ends the
/// stream. At end of input the trailing fragment (if any) is emitted, then
/// `None`.
impl<S> Stream for LineStream<S>
where
    S: Stream<Item = Result<Bytes, reqwest::Error>> + Unpin,
{
    type Item = Result<String, BackendError>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `Self: Unpin` here, so a plain reborrow lets us borrow fields independently.
        let this = &mut *self;
        loop {
            if let Some(line) = this.pending.pop_front() {
                return Poll::Ready(Some(Ok(line)));
            }
            if this.done {
                return Poll::Ready(None);
            }
            let polled = match this.inner.poll_next_unpin(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(polled) => polled,
            };
            match polled {
                Some(Ok(chunk)) => this.pending.extend(this.buffer.push(&chunk)),
                Some(Err(err)) => {
                    this.done = true;
                    return Poll::Ready(Some(Err(BackendError::Generic {
                        provider: this.provider.clone(),
                        model: this.model.clone(),
                        status: None,
                        message: format!("stream transport error: {err}"),
                    })));
                }
                None => {
                    // End of input: queue any unterminated fragment, then stop polling `inner`.
                    this.pending.extend(this.buffer.flush());
                    this.done = true;
                }
            }
        }
    }
}
/// Stream of parsed [`SseEvent`]s layered on a [`LineStream`].
pub struct SseEventStream<S> {
    /// Source of terminator-free lines.
    line_stream: LineStream<S>,
    /// Turns lines into events.
    parser: SseEventParser,
    /// Set after end of input or an error; no further items are produced.
    done: bool,
    /// Set once the parser has been flushed at end of input, so it is flushed only once.
    flushed: bool,
}

impl<S> SseEventStream<S> {
    /// Wraps the byte-chunk stream `inner`; `provider`/`model` label transport errors.
    pub fn new(
        inner: S,
        provider: impl Into<String>,
        model: impl Into<String>,
    ) -> Self {
        let line_stream = LineStream::new(inner, provider, model);
        Self {
            line_stream,
            parser: SseEventParser::default(),
            done: false,
            flushed: false,
        }
    }
}
/// Yields each event dispatched by the parser.
///
/// A line-level error is forwarded once and ends the stream. At end of input
/// the parser is flushed exactly once, so a final event without a trailing
/// blank line is still emitted. After that the stream yields `None`.
impl<S> Stream for SseEventStream<S>
where
    S: Stream<Item = Result<Bytes, reqwest::Error>> + Unpin,
{
    type Item = Result<SseEvent, BackendError>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = &mut *self;
        while !this.done {
            let next = match this.line_stream.poll_next_unpin(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(next) => next,
            };
            match next {
                Some(Ok(line)) => {
                    if let Some(event) = this.parser.push_line(&line) {
                        return Poll::Ready(Some(Ok(event)));
                    }
                }
                Some(Err(err)) => {
                    this.done = true;
                    return Poll::Ready(Some(Err(err)));
                }
                None => {
                    if !this.flushed {
                        this.flushed = true;
                        if let Some(event) = this.parser.flush() {
                            // Not done yet: the next poll observes end of input again and stops.
                            return Poll::Ready(Some(Ok(event)));
                        }
                    }
                    this.done = true;
                }
            }
        }
        Poll::Ready(None)
    }
}
/// Splits the body of `response` into text lines.
///
/// `provider` and `model` label any transport error.
pub fn line_stream(
    response: reqwest::Response,
    provider: impl Into<String>,
    model: impl Into<String>,
) -> LineStream<impl Stream<Item = Result<Bytes, reqwest::Error>> + Unpin> {
    // Boxing pins the body stream, which makes it satisfy the adapter's `Unpin` bound.
    let body = Box::pin(response.bytes_stream());
    LineStream::new(body, provider, model)
}

/// Parses the body of `response` as a Server-Sent Events stream.
///
/// `provider` and `model` label any transport error.
pub fn sse_event_stream(
    response: reqwest::Response,
    provider: impl Into<String>,
    model: impl Into<String>,
) -> SseEventStream<impl Stream<Item = Result<Bytes, reqwest::Error>> + Unpin> {
    // Boxing pins the body stream, which makes it satisfy the adapter's `Unpin` bound.
    let body = Box::pin(response.bytes_stream());
    SseEventStream::new(body, provider, model)
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream;

    /// Builds an in-memory chunk stream with the same item type as
    /// `reqwest::Response::bytes_stream()`; every chunk succeeds.
    fn fake_chunk_stream(
        chunks: Vec<&'static [u8]>,
    ) -> impl Stream<Item = Result<Bytes, reqwest::Error>> + Unpin {
        Box::pin(stream::iter(
            chunks.into_iter().map(|c| Ok(Bytes::from_static(c))),
        ))
    }

    // ---- LineBuffer ------------------------------------------------------

    #[test]
    fn line_buffer_yields_complete_lf_lines() {
        let mut buf = LineBuffer::new();
        let lines = buf.push(b"hello\nworld\n");
        assert_eq!(lines, vec!["hello", "world"]);
        assert!(buf.is_empty());
    }

    #[test]
    fn line_buffer_holds_partial_line_until_lf() {
        let mut buf = LineBuffer::new();
        let lines = buf.push(b"hello");
        assert!(lines.is_empty());
        assert!(!buf.is_empty());
        let lines = buf.push(b" world\n");
        assert_eq!(lines, vec!["hello world"]);
    }

    #[test]
    fn line_buffer_normalizes_crlf() {
        let mut buf = LineBuffer::new();
        let lines = buf.push(b"hello\r\nworld\r\n");
        assert_eq!(lines, vec!["hello", "world"]);
    }

    #[test]
    fn line_buffer_splits_chunk_across_pushes() {
        let mut buf = LineBuffer::new();
        let lines = buf.push(b"hel");
        assert!(lines.is_empty());
        let lines = buf.push(b"lo\nwor");
        assert_eq!(lines, vec!["hello"]);
        let lines = buf.push(b"ld\n");
        assert_eq!(lines, vec!["world"]);
    }

    #[test]
    fn line_buffer_crlf_split_across_pushes() {
        // The `\r` ends one chunk and its `\n` starts the next; the pair must
        // still act as a single terminator.
        let mut buf = LineBuffer::new();
        assert!(buf.push(b"hello\r").is_empty());
        let lines = buf.push(b"\nworld\n");
        assert_eq!(lines, vec!["hello", "world"]);
    }

    #[test]
    fn line_buffer_multibyte_utf8_split_across_pushes() {
        // U+00E9 is encoded as 0xC3 0xA9; splitting it between chunks must not
        // produce replacement characters.
        let mut buf = LineBuffer::new();
        assert!(buf.push(&[b'c', b'a', b'f', 0xC3]).is_empty());
        let lines = buf.push(&[0xA9, b'\n']);
        assert_eq!(lines, vec!["caf\u{e9}"]);
    }

    #[test]
    fn line_buffer_invalid_utf8_is_replaced_lossily() {
        let mut buf = LineBuffer::new();
        let lines = buf.push(&[b'a', 0xFF, b'\n']);
        assert_eq!(lines, vec!["a\u{FFFD}"]);
    }

    #[test]
    fn line_buffer_flush_returns_trailing_fragment() {
        let mut buf = LineBuffer::new();
        let _ = buf.push(b"complete\nincomplete");
        let tail = buf.flush();
        assert_eq!(tail, Some("incomplete".to_string()));
        assert!(buf.is_empty());
    }

    #[test]
    fn line_buffer_flush_on_empty_returns_none() {
        let mut buf = LineBuffer::new();
        assert_eq!(buf.flush(), None);
    }

    #[test]
    fn line_buffer_empty_chunk_is_noop() {
        let mut buf = LineBuffer::new();
        let lines = buf.push(b"");
        assert!(lines.is_empty());
        assert!(buf.is_empty());
    }

    #[test]
    fn line_buffer_handles_consecutive_lf() {
        let mut buf = LineBuffer::new();
        let lines = buf.push(b"a\n\nb\n");
        assert_eq!(lines, vec!["a", "", "b"]);
    }

    // ---- SseEventParser --------------------------------------------------

    #[test]
    fn sse_parser_data_only_event() {
        let mut p = SseEventParser::new();
        assert!(p.push_line("data: hello").is_none());
        let ev = p.push_line("").expect("event dispatched on blank");
        assert_eq!(ev.data, Some("hello".to_string()));
        assert!(ev.event.is_none());
    }

    #[test]
    fn sse_parser_full_event_shape() {
        let mut p = SseEventParser::new();
        assert!(p.push_line("event: axon.token").is_none());
        assert!(p.push_line("id: 42").is_none());
        assert!(p.push_line("data: hello").is_none());
        let ev = p.push_line("").expect("dispatched");
        assert_eq!(ev.event, Some("axon.token".to_string()));
        assert_eq!(ev.id, Some("42".to_string()));
        assert_eq!(ev.data, Some("hello".to_string()));
    }

    #[test]
    fn sse_parser_multi_line_data_joins_with_lf() {
        let mut p = SseEventParser::new();
        p.push_line("data: line1");
        p.push_line("data: line2");
        p.push_line("data: line3");
        let ev = p.push_line("").expect("dispatched");
        assert_eq!(ev.data, Some("line1\nline2\nline3".to_string()));
    }

    #[test]
    fn sse_parser_retry_directive_parsed_to_u64() {
        let mut p = SseEventParser::new();
        p.push_line("retry: 5000");
        let ev = p.push_line("").expect("dispatched");
        assert_eq!(ev.retry_ms, Some(5000));
    }

    #[test]
    fn sse_parser_retry_invalid_value_silently_ignored() {
        let mut p = SseEventParser::new();
        p.push_line("retry: not-a-number");
        p.push_line("data: x");
        let ev = p.push_line("").expect("dispatched");
        assert_eq!(ev.retry_ms, None);
        assert_eq!(ev.data, Some("x".to_string()));
    }

    #[test]
    fn sse_parser_comment_lines_ignored() {
        let mut p = SseEventParser::new();
        p.push_line(": this is a comment");
        p.push_line("data: visible");
        let ev = p.push_line("").expect("dispatched");
        assert_eq!(ev.data, Some("visible".to_string()));
    }

    #[test]
    fn sse_parser_unknown_field_ignored() {
        let mut p = SseEventParser::new();
        p.push_line("bogus: ignored");
        p.push_line("data: visible");
        let ev = p.push_line("").expect("dispatched");
        assert_eq!(ev.data, Some("visible".to_string()));
    }

    #[test]
    fn sse_parser_consecutive_blank_lines_dont_dispatch_empty() {
        let mut p = SseEventParser::new();
        assert!(p.push_line("").is_none());
        assert!(p.push_line("").is_none());
        p.push_line("data: x");
        let ev = p.push_line("").expect("dispatched");
        assert_eq!(ev.data, Some("x".to_string()));
    }

    #[test]
    fn sse_parser_field_without_space_after_colon() {
        let mut p = SseEventParser::new();
        p.push_line("data:nospace");
        let ev = p.push_line("").expect("dispatched");
        assert_eq!(ev.data, Some("nospace".to_string()));
    }

    #[test]
    fn sse_parser_strips_only_one_leading_space() {
        let mut p = SseEventParser::new();
        p.push_line("data:  two");
        let ev = p.push_line("").expect("dispatched");
        assert_eq!(ev.data, Some(" two".to_string()));
    }

    #[test]
    fn sse_parser_field_without_colon_still_parsed_as_empty_value() {
        let mut p = SseEventParser::new();
        p.push_line("data");
        let ev = p.push_line("").expect("dispatched");
        assert_eq!(ev.data, Some(String::new()));
    }

    #[test]
    fn sse_parser_flush_yields_pending_event_on_eof() {
        let mut p = SseEventParser::new();
        p.push_line("data: trailing");
        let ev = p.flush().expect("flush yields pending");
        assert_eq!(ev.data, Some("trailing".to_string()));
    }

    #[test]
    fn sse_parser_flush_on_clean_state_returns_none() {
        let mut p = SseEventParser::new();
        assert!(p.flush().is_none());
    }

    #[test]
    fn sse_event_is_empty_predicate_total() {
        let empty = SseEvent::default();
        assert!(empty.is_empty());
        let non_empty = SseEvent {
            data: Some("x".into()),
            ..Default::default()
        };
        assert!(!non_empty.is_empty());
    }

    // ---- LineStream --------------------------------------------------------

    #[tokio::test]
    async fn line_stream_yields_complete_lines_across_chunk_boundaries() {
        let inner = fake_chunk_stream(vec![b"hel", b"lo\nwor", b"ld\n"]);
        let stream = LineStream::new(inner, "test", "test-model");
        let lines: Vec<String> = stream
            .map(|r| r.unwrap())
            .collect()
            .await;
        assert_eq!(lines, vec!["hello".to_string(), "world".to_string()]);
    }

    #[tokio::test]
    async fn line_stream_flushes_trailing_fragment_on_eof() {
        let inner = fake_chunk_stream(vec![b"a\nb"]);
        let stream = LineStream::new(inner, "test", "test-model");
        let lines: Vec<String> = stream
            .map(|r| r.unwrap())
            .collect()
            .await;
        assert_eq!(lines, vec!["a".to_string(), "b".to_string()]);
    }

    #[tokio::test]
    async fn line_stream_stays_terminated_after_eof() {
        let inner = fake_chunk_stream(vec![b"only"]);
        let mut stream = LineStream::new(inner, "test", "test-model");
        assert_eq!(
            stream.next().await.map(|r| r.unwrap()),
            Some("only".to_string())
        );
        assert!(stream.next().await.is_none());
        assert!(stream.next().await.is_none());
    }

    // ---- SseEventStream ----------------------------------------------------

    #[tokio::test]
    async fn sse_event_stream_parses_canonical_openai_data_format() {
        let inner = fake_chunk_stream(vec![
            b"data: {\"chunk\":1}\n",
            b"\n",
            b"data: {\"chunk\":2}\n",
            b"\n",
        ]);
        let stream = SseEventStream::new(inner, "openai", "gpt-4o-mini");
        let events: Vec<SseEvent> = stream
            .map(|r| r.unwrap())
            .collect()
            .await;
        assert_eq!(events.len(), 2);
        assert_eq!(events[0].data, Some(r#"{"chunk":1}"#.to_string()));
        assert_eq!(events[1].data, Some(r#"{"chunk":2}"#.to_string()));
    }

    #[tokio::test]
    async fn sse_event_stream_parses_anthropic_event_data_pairs() {
        let inner = fake_chunk_stream(vec![
            b"event: message_start\n",
            b"data: {\"type\":\"message_start\"}\n",
            b"\n",
            b"event: content_block_delta\n",
            b"data: {\"delta\":{\"text\":\"hi\"}}\n",
            b"\n",
        ]);
        let stream = SseEventStream::new(inner, "anthropic", "claude-x");
        let events: Vec<SseEvent> = stream
            .map(|r| r.unwrap())
            .collect()
            .await;
        assert_eq!(events.len(), 2);
        assert_eq!(events[0].event.as_deref(), Some("message_start"));
        assert_eq!(events[1].event.as_deref(), Some("content_block_delta"));
        assert!(events[1].data.as_ref().unwrap().contains("hi"));
    }

    #[tokio::test]
    async fn sse_event_stream_yields_final_event_without_trailing_blank() {
        let inner = fake_chunk_stream(vec![
            b"data: one\n\n",
            b"data: two\n",
        ]);
        let stream = SseEventStream::new(inner, "test", "test-model");
        let events: Vec<SseEvent> = stream
            .map(|r| r.unwrap())
            .collect()
            .await;
        assert_eq!(events.len(), 2);
        assert_eq!(events[0].data, Some("one".to_string()));
        assert_eq!(events[1].data, Some("two".to_string()));
    }

    #[tokio::test]
    async fn sse_event_stream_reassembles_event_split_across_chunks() {
        // Field name, value, and the CRLF terminators are all split across chunks.
        let inner = fake_chunk_stream(vec![b"da", b"ta: hel", b"lo\r\n\r\n"]);
        let stream = SseEventStream::new(inner, "test", "test-model");
        let events: Vec<SseEvent> = stream
            .map(|r| r.unwrap())
            .collect()
            .await;
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].data, Some("hello".to_string()));
    }

    #[tokio::test]
    async fn sse_event_stream_stays_terminated_after_eof() {
        let inner = fake_chunk_stream(vec![b"data: x\n"]);
        let mut stream = SseEventStream::new(inner, "test", "test-model");
        let first = stream.next().await.expect("flushed event").unwrap();
        assert_eq!(first.data, Some("x".to_string()));
        assert!(stream.next().await.is_none());
        assert!(stream.next().await.is_none());
    }
}