use std::pin::Pin;
use crate::core::{FinishReason, TokenBatch, TokenEmissionStats};
use futures_util::Stream;
use crate::providers::{ProviderError, ProviderErrorKind};
use crate::providers::{ProviderKind, ProviderResult, TokenUsage};
#[cfg(test)]
#[path = "../tests/providers/stream_tests.rs"]
mod stream_tests;
/// Pinned, boxed, `Send` stream whose items are [`ProviderResult`]s of `T`,
/// letting streaming code return a stream without naming its concrete type.
pub type ProviderStream<T> = Pin<Box<dyn Stream<Item = ProviderResult<T>> + Send + 'static>>;
/// One item of provider streaming output.
///
/// NOTE(review): presumably carried as the `T` of a `ProviderStream` — confirm
/// against the producers of this type.
#[derive(Clone, Debug, PartialEq)]
pub enum ProviderStreamEvent {
    /// A batch of streamed text.
    TokenBatch(TokenBatch),
    /// Usage figures for the request.
    Usage {
        /// The reported token usage.
        usage: TokenUsage,
    },
    /// Generation has finished.
    Finished {
        /// Why generation finished.
        finish_reason: FinishReason,
    },
}
/// Turns successive text fragments into [`TokenBatch`]es for a single stream,
/// numbering them and carrying running [`TokenEmissionStats`].
pub(crate) struct TokenBatchBuilder {
    /// Copied into every batch; `None` becomes an empty string.
    request_id: Option<String>,
    /// Stream identifier stamped on every batch (set to 0 by `new`).
    stream_id: u32,
    /// Sequence number given to the next batch; wraps on overflow.
    sequence: u32,
    /// Running totals, updated before each batch is built and copied into it.
    stats: TokenEmissionStats,
}

impl TokenBatchBuilder {
    /// Creates a builder with stream id 0, sequence 0 and default statistics.
    pub(crate) fn new(request_id: Option<String>) -> Self {
        Self {
            request_id,
            stream_id: 0,
            sequence: 0,
            stats: TokenEmissionStats::default(),
        }
    }

    /// Wraps `text` in a single-frame [`TokenBatch`] and advances the builder.
    ///
    /// The frame, byte and batch counters are bumped first, so the returned
    /// batch's `stats` already include this text. The batch takes the current
    /// sequence number, which is then incremented (wrapping at `u32::MAX`).
    pub(crate) fn push_text(&mut self, text: &str) -> TokenBatch {
        let len = text.len();
        // `len as u32` would silently truncate lengths above `u32::MAX`; saturate
        // the per-batch count instead and keep the cumulative total exact.
        let byte_count = u32::try_from(len).unwrap_or(u32::MAX);
        self.stats.frames_sent += 1;
        self.stats.bytes_sent += u64::try_from(len).unwrap_or(u64::MAX);
        self.stats.batches_sent += 1;
        let batch = TokenBatch {
            request_id: self.request_id.clone().unwrap_or_default(),
            stream_id: self.stream_id,
            sequence_start: self.sequence,
            text: text.to_owned(),
            frame_count: 1,
            byte_count,
            stats: self.stats,
        };
        self.sequence = self.sequence.wrapping_add(1);
        batch
    }
}
/// Largest pending SSE event, in bytes, that the parser keeps buffered while
/// waiting for its delimiter (1 MiB).
const MAX_SSE_BUFFER: usize = 1024 * 1024;
/// Buffer capacity: a maximal event plus the longest delimiter (`\r\n\r\n`),
/// so an event of exactly `MAX_SSE_BUFFER` bytes can still be split off.
const MAX_SSE_BUFFER_WITH_DELIMITER: usize = MAX_SSE_BUFFER + b"\r\n\r\n".len();
/// Incremental parser that splits a Server-Sent Events byte stream into events
/// and yields each event's `data:` payload.
pub(crate) struct SseParser {
    /// Bytes received after the last complete event boundary.
    buffer: Vec<u8>,
    /// Provider that returned errors are attributed to.
    provider: ProviderKind,
}

impl SseParser {
    /// Creates an empty parser whose errors are attributed to `provider`.
    pub(crate) fn new(provider: ProviderKind) -> Self {
        Self {
            buffer: Vec::new(),
            provider,
        }
    }

    /// Feeds a chunk of raw stream bytes and returns the `data:` payloads of
    /// every event it completes, in stream order.
    ///
    /// Bytes after the last event boundary are retained for the next call.
    ///
    /// # Errors
    ///
    /// Returns a [`ProviderErrorKind::Provider`] error when a completed event is
    /// not valid UTF-8, or when the pending (unterminated) event grows beyond
    /// `MAX_SSE_BUFFER` bytes. Payloads decoded earlier in the same call are
    /// discarded in that case.
    pub(crate) fn push(&mut self, mut bytes: &[u8]) -> ProviderResult<Vec<String>> {
        let mut payloads = Vec::new();
        while !bytes.is_empty() {
            let available = MAX_SSE_BUFFER_WITH_DELIMITER.saturating_sub(self.buffer.len());
            if available == 0 {
                return Err(self.buffer_limit_error());
            }
            // Append at most enough to hold a maximal event plus its delimiter.
            let take = bytes.len().min(available);
            self.buffer.extend_from_slice(&bytes[..take]);
            bytes = &bytes[take..];

            // Decode complete events in place and remove them with a single
            // drain, so retained bytes are shifted once per chunk rather than
            // once per event.
            let mut consumed = 0;
            while let Some((offset, delimiter_len)) = event_boundary(&self.buffer[consumed..]) {
                let end = consumed + offset;
                let decoded = self.decode_bytes(&self.buffer[consumed..end]);
                match decoded {
                    Ok(Some(payload)) => payloads.push(payload),
                    Ok(None) => {}
                    Err(err) => {
                        // Leave the buffer starting at the event that failed.
                        self.buffer.drain(..consumed);
                        return Err(err);
                    }
                }
                consumed = end + delimiter_len;
            }
            self.buffer.drain(..consumed);

            // Trailing bytes that may begin a delimiter split across chunks do
            // not count toward the limit; otherwise an event within a few bytes
            // of MAX_SSE_BUFFER would be accepted or rejected depending on
            // where the chunk boundary happened to fall.
            let pending = self.buffer.len() - Self::partial_delimiter_len(&self.buffer);
            if pending > MAX_SSE_BUFFER {
                return Err(self.buffer_limit_error());
            }
        }
        Ok(payloads)
    }

    /// Flushes a trailing event that ended without a blank-line delimiter.
    ///
    /// Returns at most one payload and leaves the parser empty on success.
    ///
    /// # Errors
    ///
    /// Returns an error if the trailing bytes are not valid UTF-8; the buffer is
    /// left untouched in that case.
    pub(crate) fn finish(&mut self) -> ProviderResult<Vec<String>> {
        if self.buffer.is_empty() {
            return Ok(Vec::new());
        }
        let payload = self.decode_event(self.buffer.len())?;
        self.buffer.clear();
        Ok(payload.into_iter().collect())
    }

    /// Decodes the first `end` buffered bytes as one event (delimiter excluded).
    fn decode_event(&self, end: usize) -> ProviderResult<Option<String>> {
        self.decode_bytes(&self.buffer[..end])
    }

    /// Decodes `raw` as a single UTF-8 SSE event and extracts its `data:`
    /// payload, returning `None` when the event has no `data:` line.
    fn decode_bytes(&self, raw: &[u8]) -> ProviderResult<Option<String>> {
        let event = std::str::from_utf8(raw).map_err(|err| {
            ProviderError::new(
                ProviderErrorKind::Provider,
                self.provider,
                format!("invalid UTF-8 SSE event: {err}"),
            )
        })?;
        Ok(sse_data_payload(event))
    }

    /// Builds the error returned when buffered data exceeds the size limit.
    fn buffer_limit_error(&self) -> ProviderError {
        ProviderError::new(
            ProviderErrorKind::Provider,
            self.provider,
            "SSE event exceeded buffer limit without a boundary",
        )
    }

    /// Returns how many trailing bytes of `buffer` could be the beginning of an
    /// event delimiter (`\n\n` or `\r\n\r\n`) whose remaining bytes have not
    /// arrived yet.
    fn partial_delimiter_len(buffer: &[u8]) -> usize {
        // Longest first, so `\r\n` wins over its `\n` suffix.
        const PREFIXES: [&[u8]; 4] = [b"\r\n\r", b"\r\n", b"\r", b"\n"];
        PREFIXES
            .iter()
            .find(|prefix| buffer.ends_with(prefix))
            .map_or(0, |prefix| prefix.len())
    }
}
/// Locates the first SSE event delimiter in `buffer`.
///
/// Recognises a blank line written as `\n\n` or `\r\n\r\n` and returns
/// `(start, delimiter_len)` for whichever occurs first, or `None` when the
/// buffer holds no complete delimiter yet.
///
/// Every delimiter has a `\n` as its first or second byte, so the scan hops from
/// newline to newline and stops at the first match. It does not search the
/// whole buffer for each pattern separately, which would rescan all remaining
/// bytes for every event (e.g. looking for `\r\n\r\n` in an LF-only stream).
fn event_boundary(buffer: &[u8]) -> Option<(usize, usize)> {
    let mut from = 0;
    while let Some(offset) = find_subslice(&buffer[from..], b"\n") {
        let newline = from + offset;
        if buffer[newline..].starts_with(b"\n\n") {
            return Some((newline, 2));
        }
        if newline > 0 && buffer[newline - 1..].starts_with(b"\r\n\r\n") {
            return Some((newline - 1, 4));
        }
        from = newline + 1;
    }
    None
}

/// Returns the index of the first occurrence of `needle` in `haystack`.
///
/// An empty `needle` matches at index 0 (like [`str::find`]) instead of
/// panicking inside [`slice::windows`].
fn find_subslice(haystack: &[u8], needle: &[u8]) -> Option<usize> {
    if needle.is_empty() {
        return Some(0);
    }
    haystack
        .windows(needle.len())
        .position(|window| window == needle)
}
/// Extracts the `data:` payload from one raw SSE event.
///
/// Every line that starts with `data:` (after trailing `\r` characters are
/// trimmed) contributes its value with at most one leading space removed; the
/// values are joined with `\n`. All other lines (other fields, comments) are
/// skipped. Returns `None` when the event contains no `data:` line.
fn sse_data_payload(raw_event: &str) -> Option<String> {
    raw_event
        .lines()
        .map(|line| line.trim_end_matches('\r'))
        .filter_map(|line| line.strip_prefix("data:"))
        .map(|value| value.strip_prefix(' ').unwrap_or(value))
        .fold(None, |joined: Option<String>, value| match joined {
            None => Some(value.to_owned()),
            Some(mut text) => {
                text.push('\n');
                text.push_str(value);
                Some(text)
            }
        })
}