llm 1.3.8

A Rust library unifying multiple LLM backends.
Documentation
use std::pin::Pin;

use bytes::Bytes;
use futures::stream::{Stream, StreamExt};

use crate::error::LLMError;

/// Blank-line marker that `SseState::next_event` searches for to find the end of one event.
const SSE_DELIMITER: &str = "\n\n";

/// Adapts a streaming HTTP response into a stream of parsed payload strings.
///
/// Each chunk of the response body is passed to `handle_chunk` together with
/// an `SseState` that persists for the lifetime of the stream, so data split
/// across chunk boundaries is reassembled before `parser` sees it.
///
/// For every event handed to `parser`:
/// * `Ok(Some(text))` becomes an `Ok(text)` item,
/// * `Ok(None)` produces no item,
/// * `Err(e)` becomes an `Err(e)` item.
///
/// A failure while reading a chunk is reported as an
/// `LLMError::HttpError` item; the stream itself is not cut short by it.
pub(crate) fn create_sse_stream<F>(
    response: reqwest::Response,
    parser: F,
) -> Pin<Box<dyn Stream<Item = Result<String, LLMError>> + Send>>
where
    F: Fn(&str) -> Result<Option<String>, LLMError> + Send + 'static,
{
    // One chunk can complete zero, one, or several events, so every step
    // yields a batch. The step always returns `Some`, which means `scan`
    // only finishes when the underlying byte stream does.
    let batches = response.bytes_stream().scan(
        SseState::default(),
        move |state, chunk| std::future::ready(Some(handle_chunk(state, chunk, &parser))),
    );

    // Flatten the per-chunk batches into individual items.
    Box::pin(batches.flat_map(futures::stream::iter))
}

/// Decoding state carried from one response chunk to the next.
#[derive(Default)]
struct SseState {
    /// Decoded text that has not yet been split off as a complete event.
    buffer: String,
    /// Raw bytes received but not yet moved into `buffer` as text.
    utf8_buffer: Vec<u8>,
}

/// Processes one item of the response byte stream.
///
/// On a transport error, returns a single `LLMError::HttpError` carrying the
/// error's display text and leaves `state` untouched. Otherwise the bytes are
/// appended to `state` and every event that is now complete is run through
/// `parser`; the returned vector holds the resulting items in order and may
/// be empty.
fn handle_chunk<F>(
    state: &mut SseState,
    chunk: Result<Bytes, reqwest::Error>,
    parser: &F,
) -> Vec<Result<String, LLMError>>
where
    F: Fn(&str) -> Result<Option<String>, LLMError>,
{
    match chunk {
        Err(err) => vec![Err(LLMError::HttpError(err.to_string()))],
        Ok(bytes) => {
            state.push_bytes(&bytes);
            state.drain_events(parser)
        }
    }
}

impl SseState {
    /// Appends freshly received bytes and moves everything decodable into `buffer`.
    ///
    /// A chunk boundary may fall inside a multi-byte UTF-8 character; such an
    /// incomplete trailing sequence stays in `utf8_buffer` until a later call
    /// completes it. Bytes that can never be part of valid UTF-8 are replaced
    /// with U+FFFD (one replacement per invalid sequence) so that a single bad
    /// byte cannot block the decoding of everything received after it.
    fn push_bytes(&mut self, bytes: &[u8]) {
        self.utf8_buffer.extend_from_slice(bytes);

        loop {
            let (valid_up_to, invalid_len) = match std::str::from_utf8(&self.utf8_buffer) {
                Ok(text) => {
                    self.buffer.push_str(text);
                    self.utf8_buffer.clear();
                    return;
                }
                Err(err) => (err.valid_up_to(), err.error_len()),
            };

            self.consume_valid_prefix(valid_up_to);

            match invalid_len {
                // The input ended in the middle of a character: keep the
                // partial sequence and wait for the next chunk.
                None => return,
                // Invalid bytes now sit at the front of `utf8_buffer`. Drop
                // them, emit a replacement character, and keep decoding the
                // remainder.
                Some(len) => {
                    self.utf8_buffer.drain(..len);
                    self.buffer.push('\u{FFFD}');
                }
            }
        }
    }

    /// Moves the first `valid_up_to` bytes of `utf8_buffer` into `buffer`.
    ///
    /// The caller passes the length of a prefix already known to be valid
    /// UTF-8, so the lossy conversion performs no replacement there. A
    /// length of zero is a no-op.
    fn consume_valid_prefix(&mut self, valid_up_to: usize) {
        if valid_up_to == 0 {
            return;
        }

        let valid = String::from_utf8_lossy(&self.utf8_buffer[..valid_up_to]);
        self.buffer.push_str(&valid);
        self.utf8_buffer.drain(..valid_up_to);
    }

    /// Runs `parser` over every complete event currently buffered.
    ///
    /// Returns the parser's outputs in order: `Ok(Some(_))` and `Err(_)`
    /// results are kept, `Ok(None)` results are skipped. Text after the last
    /// delimiter stays buffered for a later call.
    fn drain_events<F>(&mut self, parser: &F) -> Vec<Result<String, LLMError>>
    where
        F: Fn(&str) -> Result<Option<String>, LLMError>,
    {
        let mut results = Vec::new();
        while let Some(event) = self.next_event() {
            match parser(&event) {
                Ok(Some(content)) => results.push(Ok(content)),
                Ok(None) => {}
                Err(err) => results.push(Err(err)),
            }
        }
        results
    }

    /// Removes and returns the first complete event from `buffer`.
    ///
    /// The returned text includes its trailing `SSE_DELIMITER`. Returns `None`
    /// when no delimiter is buffered yet.
    fn next_event(&mut self) -> Option<String> {
        let pos = self.buffer.find(SSE_DELIMITER)?;
        let end = pos + SSE_DELIMITER.len();
        let event = self.buffer[..end].to_string();
        self.buffer.drain(..end);
        Some(event)
    }
}

// Unit tests are kept out of line in `sse_tests.rs` and compiled only for test builds.
#[cfg(test)]
#[path = "sse_tests.rs"]
mod tests;