1#![allow(dead_code)]
44
45use bytes::Bytes;
46use futures::stream::{Stream, StreamExt};
47use std::pin::Pin;
48use std::task::{Context, Poll};
49
50use super::error::BackendError;
51use crate::cancel_token::CancellationFlag;
52
53pub fn cancel_aware<S, T>(
93 stream: S,
94 cancel: CancellationFlag,
95) -> Pin<Box<dyn Stream<Item = T> + Send>>
96where
97 S: Stream<Item = T> + Send + Unpin + 'static,
98 T: Send + 'static,
99{
100 Box::pin(futures::stream::unfold(
101 (stream, cancel),
102 |(mut s, cancel)| async move {
103 if cancel.is_cancelled() {
108 return None;
109 }
110 tokio::select! {
115 biased;
116 _ = cancel.cancelled() => None,
117 item = s.next() => item.map(|x| (x, (s, cancel))),
118 }
119 },
120 ))
121}
122
/// Accumulates raw bytes and splits them into text lines on `\n`.
///
/// Bytes that follow the last `\n` seen so far are kept until a later
/// [`push`](LineBuffer::push) completes the line or
/// [`flush`](LineBuffer::flush) drains them. Decoding happens per completed
/// line, so a multi-byte UTF-8 sequence split across two chunks is decoded
/// intact.
#[derive(Debug, Default)]
pub struct LineBuffer {
    /// Bytes of the current line that has not been terminated yet.
    tail: Vec<u8>,
}

impl LineBuffer {
    /// Creates an empty buffer.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `chunk` and returns every line it completed, in input order.
    ///
    /// A line ends at `\n`. A single `\r` directly before that `\n` is
    /// removed, so CRLF input yields the same lines as LF input. Invalid
    /// UTF-8 is replaced with U+FFFD rather than reported as an error.
    /// Bytes after the last `\n` stay buffered for the next call.
    pub fn push(&mut self, chunk: &[u8]) -> Vec<String> {
        let mut lines = Vec::new();
        let mut rest = chunk;
        // Copy whole segments between newlines instead of one byte at a time.
        while let Some(newline) = rest.iter().position(|&byte| byte == b'\n') {
            let (head, after) = rest.split_at(newline);
            self.tail.extend_from_slice(head);
            lines.push(self.take_line());
            // `after` starts with the `\n` that was just consumed.
            rest = &after[1..];
        }
        self.tail.extend_from_slice(rest);
        lines
    }

    /// Drains the unterminated trailing fragment, if any.
    ///
    /// Returns `None` when nothing is buffered. Otherwise the buffered bytes
    /// are returned as a line (with one trailing `\r` removed) and the buffer
    /// is left empty.
    pub fn flush(&mut self) -> Option<String> {
        if self.tail.is_empty() {
            None
        } else {
            Some(self.take_line())
        }
    }

    /// Returns `true` when no partial line is buffered.
    pub fn is_empty(&self) -> bool {
        self.tail.is_empty()
    }

    /// Converts the buffered bytes into a line and resets the buffer.
    ///
    /// Shared by `push` and `flush` so both strip at most one trailing `\r`
    /// and decode lossily in exactly the same way.
    fn take_line(&mut self) -> String {
        if self.tail.last() == Some(&b'\r') {
            self.tail.pop();
        }
        let line = String::from_utf8_lossy(&self.tail).into_owned();
        // `clear` keeps the allocation for the next line.
        self.tail.clear();
        line
    }
}
192
/// One Server-Sent Events message as assembled by [`SseEventParser`].
///
/// Every field is `None` until the corresponding field line was seen.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SseEvent {
    /// Value of the `event:` field (the event type), if present.
    pub event: Option<String>,
    /// Value of the `id:` field, if present.
    pub id: Option<String>,
    /// All `data:` values of the event joined with `\n`, if any were seen.
    pub data: Option<String>,
    /// Value of the `retry:` field parsed as an integer, if valid.
    pub retry_ms: Option<u64>,
}

impl SseEvent {
    /// Returns `true` when no field of the event has been set.
    pub fn is_empty(&self) -> bool {
        // Destructure so that adding a field forces this check to be updated.
        let Self {
            event,
            id,
            data,
            retry_ms,
        } = self;
        event.is_none() && id.is_none() && data.is_none() && retry_ms.is_none()
    }
}

/// Incremental, line-oriented parser for the `text/event-stream` format.
///
/// Feed it one line at a time (without the line terminator) through
/// [`push_line`](SseEventParser::push_line); call
/// [`flush`](SseEventParser::flush) at end of input to obtain an event that
/// was not followed by a blank line.
#[derive(Debug, Default)]
pub struct SseEventParser {
    /// Event being assembled from the field lines seen so far.
    current: SseEvent,
    /// `data:` values of the current event, joined with `\n` on dispatch.
    data_acc: Vec<String>,
}

impl SseEventParser {
    /// Creates a parser with no pending event.
    pub fn new() -> Self {
        Self::default()
    }

    /// Processes one line and returns an event when the line completes one.
    ///
    /// * An empty line dispatches the pending event; `None` is returned if
    ///   no field has been set since the last dispatch.
    /// * A line starting with `:` is a comment and is ignored.
    /// * Otherwise the line is split at the first `:` into field name and
    ///   value (a line without `:` is a field with an empty value), and a
    ///   single leading space is removed from the value.
    ///
    /// Recognised fields are `event`, `id`, `data` and `retry`; any other
    /// field name is ignored. An `id` value containing U+0000 is ignored, and
    /// `retry` is accepted only when it consists solely of ASCII digits and
    /// fits in a `u64`.
    pub fn push_line(&mut self, line: &str) -> Option<SseEvent> {
        if line.is_empty() {
            // Blank line: dispatch uses exactly the same path as end of input.
            return self.flush();
        }
        if line.starts_with(':') {
            return None;
        }

        let (field, raw_value) = line.split_once(':').unwrap_or((line, ""));
        let value = raw_value.strip_prefix(' ').unwrap_or(raw_value);

        match field {
            "event" => self.current.event = Some(value.to_owned()),
            "id" => {
                // An id containing NUL must not replace the current id.
                if !value.contains('\0') {
                    self.current.id = Some(value.to_owned());
                }
            }
            "data" => self.data_acc.push(value.to_owned()),
            "retry" => {
                // `u64::from_str` alone would also accept a leading `+`.
                let digits_only =
                    !value.is_empty() && value.bytes().all(|byte| byte.is_ascii_digit());
                if digits_only {
                    // Overflowing values fail to parse and are ignored.
                    if let Ok(ms) = value.parse::<u64>() {
                        self.current.retry_ms = Some(ms);
                    }
                }
            }
            _ => {}
        }
        None
    }

    /// Dispatches the pending event, if any, and resets the parser.
    ///
    /// Accumulated `data` values are joined with `\n`. Returns `None` when no
    /// field has been set since the last dispatch.
    pub fn flush(&mut self) -> Option<SseEvent> {
        if !self.data_acc.is_empty() {
            self.current.data = Some(self.data_acc.join("\n"));
            self.data_acc.clear();
        }
        // `take` leaves a fresh default event behind for the next message.
        let event = std::mem::take(&mut self.current);
        if event.is_empty() {
            None
        } else {
            Some(event)
        }
    }
}
310
311pub struct LineStream<S> {
319 inner: S,
320 buffer: LineBuffer,
321 pending: std::collections::VecDeque<String>,
325 done: bool,
328 provider: String,
332 model: String,
333}
334
335impl<S> LineStream<S> {
336 pub fn new(inner: S, provider: impl Into<String>, model: impl Into<String>) -> Self {
337 Self {
338 inner,
339 buffer: LineBuffer::new(),
340 pending: std::collections::VecDeque::new(),
341 done: false,
342 provider: provider.into(),
343 model: model.into(),
344 }
345 }
346}
347
348impl<S> Stream for LineStream<S>
349where
350 S: Stream<Item = Result<Bytes, reqwest::Error>> + Unpin,
351{
352 type Item = Result<String, BackendError>;
353
354 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
355 loop {
356 if let Some(line) = self.pending.pop_front() {
358 return Poll::Ready(Some(Ok(line)));
359 }
360 if self.done {
361 return Poll::Ready(None);
362 }
363 match self.inner.poll_next_unpin(cx) {
364 Poll::Ready(Some(Ok(chunk))) => {
365 let lines = self.buffer.push(&chunk);
366 self.pending.extend(lines);
367 }
369 Poll::Ready(Some(Err(e))) => {
370 self.done = true;
371 return Poll::Ready(Some(Err(BackendError::Generic {
372 provider: self.provider.clone(),
373 model: self.model.clone(),
374 status: None,
375 message: format!("stream transport error: {e}"),
376 })));
377 }
378 Poll::Ready(None) => {
379 if let Some(tail) = self.buffer.flush() {
381 self.pending.push_back(tail);
382 }
383 self.done = true;
384 }
386 Poll::Pending => return Poll::Pending,
387 }
388 }
389 }
390}
391
392pub struct SseEventStream<S> {
399 line_stream: LineStream<S>,
400 parser: SseEventParser,
401 done: bool,
404 flushed: bool,
407}
408
409impl<S> SseEventStream<S> {
410 pub fn new(
411 inner: S,
412 provider: impl Into<String>,
413 model: impl Into<String>,
414 ) -> Self {
415 Self {
416 line_stream: LineStream::new(inner, provider, model),
417 parser: SseEventParser::new(),
418 done: false,
419 flushed: false,
420 }
421 }
422}
423
424impl<S> Stream for SseEventStream<S>
425where
426 S: Stream<Item = Result<Bytes, reqwest::Error>> + Unpin,
427{
428 type Item = Result<SseEvent, BackendError>;
429
430 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
431 loop {
432 if self.done {
433 return Poll::Ready(None);
434 }
435 match self.line_stream.poll_next_unpin(cx) {
436 Poll::Ready(Some(Ok(line))) => {
437 if let Some(event) = self.parser.push_line(&line) {
438 return Poll::Ready(Some(Ok(event)));
439 }
440 }
442 Poll::Ready(Some(Err(e))) => {
443 self.done = true;
444 return Poll::Ready(Some(Err(e)));
445 }
446 Poll::Ready(None) => {
447 if !self.flushed {
448 self.flushed = true;
449 if let Some(event) = self.parser.flush() {
450 return Poll::Ready(Some(Ok(event)));
451 }
452 }
453 self.done = true;
454 return Poll::Ready(None);
455 }
456 Poll::Pending => return Poll::Pending,
457 }
458 }
459 }
460}
461
462pub fn line_stream(
469 response: reqwest::Response,
470 provider: impl Into<String>,
471 model: impl Into<String>,
472) -> LineStream<impl Stream<Item = Result<Bytes, reqwest::Error>> + Unpin> {
473 LineStream::new(Box::pin(response.bytes_stream()), provider, model)
474}
475
476pub fn sse_event_stream(
479 response: reqwest::Response,
480 provider: impl Into<String>,
481 model: impl Into<String>,
482) -> SseEventStream<impl Stream<Item = Result<Bytes, reqwest::Error>> + Unpin> {
483 SseEventStream::new(Box::pin(response.bytes_stream()), provider, model)
484}
485
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream;

    /// Builds an in-memory byte stream that yields `chunks` in order, all `Ok`.
    fn fake_chunk_stream(
        chunks: Vec<&'static [u8]>,
    ) -> impl Stream<Item = Result<Bytes, reqwest::Error>> + Unpin {
        let items = chunks
            .into_iter()
            .map(|bytes| Ok(Bytes::from_static(bytes)));
        Box::pin(stream::iter(items))
    }

    /// Drives `source` to completion, panicking on the first error item.
    async fn collect_ok<St, T>(source: St) -> Vec<T>
    where
        St: Stream<Item = Result<T, BackendError>>,
    {
        source.map(|item| item.unwrap()).collect().await
    }

    // ---- LineBuffer -------------------------------------------------------

    #[test]
    fn line_buffer_yields_complete_lf_lines() {
        let mut buffer = LineBuffer::new();
        assert_eq!(buffer.push(b"hello\nworld\n"), vec!["hello", "world"]);
        assert!(buffer.is_empty());
    }

    #[test]
    fn line_buffer_holds_partial_line_until_lf() {
        let mut buffer = LineBuffer::new();
        let first = buffer.push(b"hello");
        assert!(first.is_empty());
        assert!(!buffer.is_empty());
        let second = buffer.push(b" world\n");
        assert_eq!(second, vec!["hello world"]);
    }

    #[test]
    fn line_buffer_normalizes_crlf() {
        let mut buffer = LineBuffer::new();
        assert_eq!(buffer.push(b"hello\r\nworld\r\n"), vec!["hello", "world"]);
    }

    #[test]
    fn line_buffer_splits_chunk_across_pushes() {
        let mut buffer = LineBuffer::new();
        assert!(buffer.push(b"hel").is_empty());
        assert_eq!(buffer.push(b"lo\nwor"), vec!["hello"]);
        assert_eq!(buffer.push(b"ld\n"), vec!["world"]);
    }

    #[test]
    fn line_buffer_flush_returns_trailing_fragment() {
        let mut buffer = LineBuffer::new();
        let _ = buffer.push(b"complete\nincomplete");
        assert_eq!(buffer.flush(), Some("incomplete".to_string()));
        assert!(buffer.is_empty());
    }

    #[test]
    fn line_buffer_flush_on_empty_returns_none() {
        let mut buffer = LineBuffer::new();
        assert_eq!(buffer.flush(), None);
    }

    #[test]
    fn line_buffer_empty_chunk_is_noop() {
        let mut buffer = LineBuffer::new();
        assert!(buffer.push(b"").is_empty());
        assert!(buffer.is_empty());
    }

    #[test]
    fn line_buffer_handles_consecutive_lf() {
        let mut buffer = LineBuffer::new();
        assert_eq!(buffer.push(b"a\n\nb\n"), vec!["a", "", "b"]);
    }

    // ---- SseEventParser ---------------------------------------------------

    #[test]
    fn sse_parser_data_only_event() {
        let mut parser = SseEventParser::new();
        assert!(parser.push_line("data: hello").is_none());
        let event = parser.push_line("").expect("event dispatched on blank");
        assert_eq!(event.data, Some("hello".to_string()));
        assert!(event.event.is_none());
    }

    #[test]
    fn sse_parser_full_event_shape() {
        let mut parser = SseEventParser::new();
        for line in ["event: axon.token", "id: 42", "data: hello"] {
            assert!(parser.push_line(line).is_none());
        }
        let event = parser.push_line("").expect("dispatched");
        assert_eq!(event.event, Some("axon.token".to_string()));
        assert_eq!(event.id, Some("42".to_string()));
        assert_eq!(event.data, Some("hello".to_string()));
    }

    #[test]
    fn sse_parser_multi_line_data_joins_with_lf() {
        let mut parser = SseEventParser::new();
        for line in ["data: line1", "data: line2", "data: line3"] {
            parser.push_line(line);
        }
        let event = parser.push_line("").expect("dispatched");
        assert_eq!(event.data, Some("line1\nline2\nline3".to_string()));
    }

    #[test]
    fn sse_parser_retry_directive_parsed_to_u64() {
        let mut parser = SseEventParser::new();
        parser.push_line("retry: 5000");
        let event = parser.push_line("").expect("dispatched");
        assert_eq!(event.retry_ms, Some(5000));
    }

    #[test]
    fn sse_parser_retry_invalid_value_silently_ignored() {
        let mut parser = SseEventParser::new();
        parser.push_line("retry: not-a-number");
        parser.push_line("data: x");
        let event = parser.push_line("").expect("dispatched");
        assert_eq!(event.retry_ms, None);
        assert_eq!(event.data, Some("x".to_string()));
    }

    #[test]
    fn sse_parser_comment_lines_ignored() {
        let mut parser = SseEventParser::new();
        parser.push_line(": this is a comment");
        parser.push_line("data: visible");
        let event = parser.push_line("").expect("dispatched");
        assert_eq!(event.data, Some("visible".to_string()));
    }

    #[test]
    fn sse_parser_unknown_field_ignored() {
        let mut parser = SseEventParser::new();
        parser.push_line("bogus: ignored");
        parser.push_line("data: visible");
        let event = parser.push_line("").expect("dispatched");
        assert_eq!(event.data, Some("visible".to_string()));
    }

    #[test]
    fn sse_parser_consecutive_blank_lines_dont_dispatch_empty() {
        let mut parser = SseEventParser::new();
        assert!(parser.push_line("").is_none());
        assert!(parser.push_line("").is_none());
        parser.push_line("data: x");
        let event = parser.push_line("").expect("dispatched");
        assert_eq!(event.data, Some("x".to_string()));
    }

    #[test]
    fn sse_parser_field_without_space_after_colon() {
        let mut parser = SseEventParser::new();
        parser.push_line("data:nospace");
        let event = parser.push_line("").expect("dispatched");
        assert_eq!(event.data, Some("nospace".to_string()));
    }

    #[test]
    fn sse_parser_field_without_colon_still_parsed_as_empty_value() {
        let mut parser = SseEventParser::new();
        parser.push_line("data");
        let event = parser.push_line("").expect("dispatched");
        assert_eq!(event.data, Some(String::new()));
    }

    #[test]
    fn sse_parser_flush_yields_pending_event_on_eof() {
        let mut parser = SseEventParser::new();
        parser.push_line("data: trailing");
        let event = parser.flush().expect("flush yields pending");
        assert_eq!(event.data, Some("trailing".to_string()));
    }

    #[test]
    fn sse_parser_flush_on_clean_state_returns_none() {
        let mut parser = SseEventParser::new();
        assert!(parser.flush().is_none());
    }

    #[test]
    fn sse_event_is_empty_predicate_total() {
        assert!(SseEvent::default().is_empty());
        let with_data = SseEvent {
            data: Some("x".into()),
            ..Default::default()
        };
        assert!(!with_data.is_empty());
    }

    // ---- Stream adapters --------------------------------------------------

    #[tokio::test]
    async fn line_stream_yields_complete_lines_across_chunk_boundaries() {
        let chunks = fake_chunk_stream(vec![b"hel", b"lo\nwor", b"ld\n"]);
        let lines: Vec<String> = collect_ok(LineStream::new(chunks, "test", "test-model")).await;
        assert_eq!(lines, vec!["hello".to_string(), "world".to_string()]);
    }

    #[tokio::test]
    async fn line_stream_flushes_trailing_fragment_on_eof() {
        let chunks = fake_chunk_stream(vec![b"a\nb"]);
        let lines: Vec<String> = collect_ok(LineStream::new(chunks, "test", "test-model")).await;
        assert_eq!(lines, vec!["a".to_string(), "b".to_string()]);
    }

    #[tokio::test]
    async fn sse_event_stream_parses_canonical_openai_data_format() {
        let chunks = fake_chunk_stream(vec![
            b"data: {\"chunk\":1}\n",
            b"\n",
            b"data: {\"chunk\":2}\n",
            b"\n",
        ]);
        let events: Vec<SseEvent> =
            collect_ok(SseEventStream::new(chunks, "openai", "gpt-4o-mini")).await;
        assert_eq!(events.len(), 2);
        assert_eq!(events[0].data, Some(r#"{"chunk":1}"#.to_string()));
        assert_eq!(events[1].data, Some(r#"{"chunk":2}"#.to_string()));
    }

    #[tokio::test]
    async fn sse_event_stream_parses_anthropic_event_data_pairs() {
        let chunks = fake_chunk_stream(vec![
            b"event: message_start\n",
            b"data: {\"type\":\"message_start\"}\n",
            b"\n",
            b"event: content_block_delta\n",
            b"data: {\"delta\":{\"text\":\"hi\"}}\n",
            b"\n",
        ]);
        let events: Vec<SseEvent> =
            collect_ok(SseEventStream::new(chunks, "anthropic", "claude-x")).await;
        assert_eq!(events.len(), 2);
        assert_eq!(events[0].event.as_deref(), Some("message_start"));
        assert_eq!(events[1].event.as_deref(), Some("content_block_delta"));
        assert!(events[1].data.as_ref().unwrap().contains("hi"));
    }

    #[tokio::test]
    async fn sse_event_stream_yields_final_event_without_trailing_blank() {
        // The second event is never terminated by a blank line; it must be
        // delivered by the end-of-input flush.
        let chunks = fake_chunk_stream(vec![b"data: one\n\n", b"data: two\n"]);
        let events: Vec<SseEvent> =
            collect_ok(SseEventStream::new(chunks, "test", "test-model")).await;
        assert_eq!(events.len(), 2);
        assert_eq!(events[1].data, Some("two".to_string()));
    }
}