llm 1.3.8

A Rust library unifying multiple LLM backends.
Documentation
use std::{collections::HashMap, pin::Pin};

use bytes::Bytes;
use futures::stream::{Stream, StreamExt};

use crate::{
    chat::{StreamResponse, Usage},
    error::LLMError,
};

use super::events::{extract_payload, parse_event, ResponsesEvent as Event, ToolState};
use super::response_helpers::{
    stream_response_text, stream_response_tool_call, stream_response_usage, tool_call_from_state,
    tool_call_with_arguments,
};
use super::sse::SseEventBuffer;

/// Adapts a streaming `reqwest::Response` from the Responses API into a boxed
/// stream of [`StreamResponse`] results.
///
/// Every byte chunk is fed through a single `ResponsesStreamResponseParser`.
/// That parser carries SSE-buffer and tool-call state across chunks. Each chunk
/// produces a batch (possibly empty) of results, and the batches are flattened
/// into one output stream.
///
/// `normalize_response` is forwarded to the parser unchanged.
pub(crate) fn create_responses_stream_responses(
    response: reqwest::Response,
    normalize_response: bool,
) -> Pin<Box<dyn Stream<Item = Result<StreamResponse, LLMError>> + Send>> {
    let parser = ResponsesStreamResponseParser::new(normalize_response);

    // `scan` threads the parser through the byte stream. The closure always
    // returns `Some`, so the stream only ends when the underlying body ends.
    let batches = response.bytes_stream().scan(parser, |state, chunk| {
        let batch = state.handle_chunk(chunk);
        futures::future::ready(Some(batch))
    });

    let flattened = batches.flat_map(futures::stream::iter);
    Box::pin(flattened)
}

/// Incremental parser that turns raw SSE bytes from the Responses API into
/// [`StreamResponse`] results.
///
/// Bytes are pushed into an SSE buffer. Each event the buffer hands back is
/// decoded and translated into zero or more results. The results collected while
/// processing one chunk are returned to the caller after that chunk.
struct ResponsesStreamResponseParser {
    /// Buffers raw bytes; `drain_events` hands back the SSE events available so far.
    sse_buffer: SseEventBuffer,
    /// Results produced while processing the current chunk; emptied after each chunk.
    results: Vec<Result<StreamResponse, LLMError>>,
    /// In-flight function calls, keyed by the `item_id` carried on their events.
    tool_states: HashMap<String, ToolState>,
    /// When `true`, each tool call is emitted once, with its complete arguments.
    /// When `false`, the initial call and every argument delta are emitted as they arrive.
    normalize_response: bool,
}

impl ResponsesStreamResponseParser {
    /// Creates a parser with an empty SSE buffer and no in-flight tool calls.
    fn new(normalize_response: bool) -> Self {
        Self {
            sse_buffer: SseEventBuffer::new(),
            results: Vec::new(),
            tool_states: HashMap::new(),
            normalize_response,
        }
    }

    /// Processes one chunk from the HTTP byte stream and returns every result it produced.
    ///
    /// A transport error is surfaced as a single `LLMError::HttpError` item.
    fn handle_chunk(
        &mut self,
        chunk: Result<Bytes, reqwest::Error>,
    ) -> Vec<Result<StreamResponse, LLMError>> {
        match chunk {
            Ok(bytes) => self.handle_bytes(&bytes),
            Err(err) => vec![Err(LLMError::HttpError(err.to_string()))],
        }
    }

    /// Feeds `bytes` into the SSE buffer and handles every event now available.
    /// Returns the accumulated results and leaves `results` empty, with its
    /// capacity kept, for the next chunk.
    fn handle_bytes(&mut self, bytes: &[u8]) -> Vec<Result<StreamResponse, LLMError>> {
        self.sse_buffer.push_bytes(bytes);
        for event in self.sse_buffer.drain_events() {
            self.parse_event(&event);
        }
        self.results.drain(..).collect()
    }

    /// Decodes one raw SSE event.
    ///
    /// Events without a payload, and payloads that `parse_event` maps to `None`,
    /// are ignored. Decode errors are forwarded as `Err` results.
    fn parse_event(&mut self, event: &str) {
        let payload = match extract_payload(event) {
            Some(payload) => payload,
            None => return,
        };
        // This calls the free function imported from `events`, not this method.
        match parse_event(&payload) {
            Ok(Some(event)) => self.handle_event(event),
            Ok(None) => {}
            Err(err) => self.results.push(Err(err)),
        }
    }

    /// Dispatches a decoded Responses event to its handler.
    fn handle_event(&mut self, event: Event) {
        match event {
            Event::OutputTextDelta { delta } => self.handle_output_text(delta),
            Event::FunctionCallAdded {
                item_id,
                call_id,
                name,
                output_index,
            } => self.handle_function_call_added(item_id, call_id, name, output_index),
            Event::FunctionCallDelta {
                item_id,
                delta,
                output_index,
            } => self.handle_function_call_delta(item_id, delta, output_index),
            Event::FunctionCallDone {
                item_id,
                arguments,
                output_index,
            } => self.handle_function_call_done(item_id, arguments, output_index),
            Event::OutputItemDone { item_id, .. } => self.handle_output_item_done(&item_id),
            Event::ResponseCompleted { usage } => self.handle_response_completed(usage),
        }
    }

    /// Emits a text delta as-is.
    fn handle_output_text(&mut self, delta: String) {
        self.results.push(Ok(stream_response_text(delta)));
    }

    /// Starts tracking a new function call under `item_id`.
    ///
    /// In non-normalized mode the call header (with empty arguments) is emitted immediately.
    fn handle_function_call_added(
        &mut self,
        item_id: String,
        call_id: String,
        name: String,
        output_index: usize,
    ) {
        let state = ToolState {
            call_id,
            name,
            arguments: String::new(),
            output_index,
        };
        if !self.normalize_response {
            self.results
                .push(Ok(stream_response_tool_call(tool_call_from_state(&state))));
        }
        self.tool_states.insert(item_id, state);
    }

    /// Appends an argument fragment to a tracked call.
    ///
    /// In non-normalized mode the fragment is also emitted immediately. Deltas
    /// for unknown `item_id`s are ignored.
    fn handle_function_call_delta(&mut self, item_id: String, delta: String, _output_index: usize) {
        if let Some(state) = self.tool_states.get_mut(&item_id) {
            state.arguments.push_str(&delta);
            if !self.normalize_response {
                self.results
                    .push(Ok(stream_response_tool_call(tool_call_with_arguments(
                        state, &delta,
                    ))));
            }
        }
    }

    /// Finalizes a tracked call and stops tracking it.
    ///
    /// Non-empty `arguments` replace whatever was accumulated from deltas. In
    /// normalized mode the complete call is emitted here.
    fn handle_function_call_done(
        &mut self,
        item_id: String,
        arguments: String,
        _output_index: usize,
    ) {
        if let Some(mut state) = self.tool_states.remove(&item_id) {
            // An empty `arguments` field keeps what was accumulated from deltas.
            if !arguments.is_empty() {
                state.arguments = arguments;
            }
            if self.normalize_response {
                self.push_complete_tool_call(&state);
            }
        }
    }

    /// Handles the end of an output item.
    ///
    /// In normalized mode, any still-tracked call for `item_id` is emitted. In
    /// non-normalized mode it is just dropped, because its pieces were already
    /// streamed.
    fn handle_output_item_done(&mut self, item_id: &str) {
        if self.normalize_response {
            self.finish_tool_call(item_id);
        } else {
            self.tool_states.remove(item_id);
        }
    }

    /// Handles the end of the response.
    ///
    /// Remaining calls are flushed (normalized mode) or discarded (otherwise).
    /// Usage is then emitted if the event carried it.
    fn handle_response_completed(&mut self, usage: Option<Usage>) {
        if self.normalize_response {
            self.finish_all_tool_calls();
        } else {
            self.tool_states.clear();
        }
        if let Some(usage) = usage {
            self.results.push(Ok(stream_response_usage(usage)));
        }
    }

    /// Emits and stops tracking the call for `item_id`, if it is still tracked.
    fn finish_tool_call(&mut self, item_id: &str) {
        if let Some(state) = self.tool_states.remove(item_id) {
            self.push_complete_tool_call(&state);
        }
    }

    /// Emits every still-tracked call, ordered by `output_index`, and clears the tracking map.
    ///
    /// `HashMap` iteration order is unspecified and randomized per map, so
    /// draining it directly would emit the calls in a nondeterministic order.
    /// Sorting by the index the API assigned to each call keeps the output stable.
    fn finish_all_tool_calls(&mut self) {
        let mut pending: Vec<ToolState> =
            self.tool_states.drain().map(|(_, state)| state).collect();
        pending.sort_by_key(|state| state.output_index);
        for state in &pending {
            self.push_complete_tool_call(state);
        }
    }

    /// Pushes `state` as a single tool-call result carrying its full accumulated arguments.
    fn push_complete_tool_call(&mut self, state: &ToolState) {
        self.results
            .push(Ok(stream_response_tool_call(tool_call_with_arguments(
                state,
                &state.arguments,
            ))));
    }
}

#[cfg(test)]
mod tests;