// sipp-rs 0.1.0
//
// Unified Rust library for extensible Sipp inference
use std::pin::Pin;

use crate::core::{FinishReason, TokenBatch, TokenEmissionStats};
use futures_util::Stream;

use crate::providers::{ProviderError, ProviderErrorKind};
use crate::providers::{ProviderKind, ProviderResult, TokenUsage};

/////////////////////////////////////////////////////////////////////////////////
/// TESTS
/////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
#[path = "../tests/providers/stream_tests.rs"]
mod stream_tests;

/////////////////////////////////////////////////////////////////////////////////
/// SRC
/////////////////////////////////////////////////////////////////////////////////
pub type ProviderStream<T> = Pin<Box<dyn Stream<Item = ProviderResult<T>> + Send>>;

#[derive(Debug, Clone, PartialEq)]
pub enum ProviderStreamEvent {
    TokenBatch(TokenBatch),
    Usage { usage: TokenUsage },
    Finished { finish_reason: FinishReason },
}

pub(crate) struct TokenBatchBuilder {
    request_id: Option<String>,
    stream_id: u32,
    sequence: u32,
    stats: TokenEmissionStats,
}

impl TokenBatchBuilder {
    pub(crate) fn new(request_id: Option<String>) -> Self {
        Self {
            request_id,
            stream_id: 0,
            sequence: 0,
            stats: TokenEmissionStats::default(),
        }
    }

    pub(crate) fn push_text(&mut self, text: &str) -> TokenBatch {
        let byte_count = text.len() as u32;
        self.stats.frames_sent += 1;
        self.stats.bytes_sent += u64::from(byte_count);
        self.stats.batches_sent += 1;

        let batch = TokenBatch {
            request_id: self.request_id.clone().unwrap_or_default(),
            stream_id: self.stream_id,
            sequence_start: self.sequence,
            text: text.to_string(),
            frame_count: 1,
            byte_count,
            stats: self.stats,
        };
        self.sequence = self.sequence.wrapping_add(1);
        batch
    }
}

/// Maximum bytes buffered while waiting for an SSE event boundary. Guards
/// against a server that streams without `\n\n` delimiters.
const MAX_SSE_BUFFER: usize = 1 << 20;
const MAX_SSE_BUFFER_WITH_DELIMITER: usize = MAX_SSE_BUFFER + 4;

pub(crate) struct SseParser {
    buffer: Vec<u8>,
    provider: ProviderKind,
}

impl SseParser {
    pub(crate) fn new(provider: ProviderKind) -> Self {
        Self {
            buffer: Vec::new(),
            provider,
        }
    }

    /// Accept a raw network chunk and return any complete SSE `data:` payloads.
    /// Bytes are buffered, so a multi-byte UTF-8 character split across chunk
    /// boundaries is only decoded once the whole event has arrived.
    pub(crate) fn push(&mut self, mut bytes: &[u8]) -> ProviderResult<Vec<String>> {
        let mut payloads = Vec::new();

        while !bytes.is_empty() {
            let available = MAX_SSE_BUFFER_WITH_DELIMITER.saturating_sub(self.buffer.len());
            if available == 0 {
                return Err(self.buffer_limit_error());
            }

            let take = bytes.len().min(available);
            self.buffer.extend_from_slice(&bytes[..take]);
            bytes = &bytes[take..];

            while let Some((index, length)) = event_boundary(&self.buffer) {
                let payload = self.decode_event(index)?;
                self.buffer.drain(..index + length);
                if let Some(payload) = payload {
                    payloads.push(payload);
                }
            }

            if self.buffer.len() > MAX_SSE_BUFFER {
                return Err(self.buffer_limit_error());
            }
        }

        Ok(payloads)
    }

    /// Flush any trailing event the stream ended without a blank-line boundary.
    /// A complete final `data:` line is still delivered; a genuinely truncated
    /// multi-byte sequence surfaces here as an error rather than silently.
    pub(crate) fn finish(&mut self) -> ProviderResult<Vec<String>> {
        if self.buffer.is_empty() {
            return Ok(Vec::new());
        }
        let payload = self.decode_event(self.buffer.len())?;
        self.buffer.clear();
        Ok(payload.into_iter().collect())
    }

    /// Decode `self.buffer[..end]` as one UTF-8 SSE event and extract its `data:`
    /// payload. `end` always lands on an event boundary (an ASCII `\n`), so the
    /// slice is a complete UTF-8 region.
    fn decode_event(&self, end: usize) -> ProviderResult<Option<String>> {
        let event = std::str::from_utf8(&self.buffer[..end]).map_err(|err| {
            ProviderError::new(
                ProviderErrorKind::Provider,
                self.provider,
                format!("invalid UTF-8 SSE event: {err}"),
            )
        })?;
        Ok(sse_data_payload(event))
    }

    fn buffer_limit_error(&self) -> ProviderError {
        ProviderError::new(
            ProviderErrorKind::Provider,
            self.provider,
            "SSE event exceeded buffer limit without a boundary",
        )
    }
}

fn event_boundary(buffer: &[u8]) -> Option<(usize, usize)> {
    match (
        find_subslice(buffer, b"\r\n\r\n"),
        find_subslice(buffer, b"\n\n"),
    ) {
        (Some(crlf), Some(lf)) if crlf < lf => Some((crlf, 4)),
        (Some(_), Some(lf)) => Some((lf, 2)),
        (Some(crlf), None) => Some((crlf, 4)),
        (None, Some(lf)) => Some((lf, 2)),
        (None, None) => None,
    }
}

fn find_subslice(haystack: &[u8], needle: &[u8]) -> Option<usize> {
    haystack
        .windows(needle.len())
        .position(|window| window == needle)
}

fn sse_data_payload(raw_event: &str) -> Option<String> {
    let lines = raw_event
        .lines()
        .filter_map(|line| line.trim_end_matches('\r').strip_prefix("data:"))
        .map(|data| data.strip_prefix(' ').unwrap_or(data))
        .collect::<Vec<_>>();

    if lines.is_empty() {
        None
    } else {
        Some(lines.join("\n"))
    }
}