llm 1.3.8

A Rust library unifying multiple LLM backends.
Documentation
use std::collections::HashMap;
use std::pin::Pin;

use bytes::Bytes;
use futures::stream::{Stream, StreamExt};

use crate::chat::StreamChunk;
use crate::error::LLMError;
use crate::{FunctionCall, ToolCall};

use super::events::{extract_payload, parse_event, ResponsesEvent, ToolState};
use super::sse::SseEventBuffer;

pub(crate) fn create_responses_stream_chunks(
    response: reqwest::Response,
) -> Pin<Box<dyn Stream<Item = Result<StreamChunk, LLMError>> + Send>> {
    let stream = response
        .bytes_stream()
        .scan(ResponsesStreamChunkParser::new(), |parser, chunk| {
            futures::future::ready(Some(parser.handle_chunk(chunk)))
        })
        .flat_map(futures::stream::iter);
    Box::pin(stream)
}

struct ResponsesStreamChunkParser {
    sse_buffer: SseEventBuffer,
    results: Vec<Result<StreamChunk, LLMError>>,
    tool_states: HashMap<String, ToolState>,
    saw_tool_call: bool,
}

impl ResponsesStreamChunkParser {
    fn new() -> Self {
        Self {
            sse_buffer: SseEventBuffer::new(),
            results: Vec::new(),
            tool_states: HashMap::new(),
            saw_tool_call: false,
        }
    }

    fn handle_chunk(
        &mut self,
        chunk: Result<Bytes, reqwest::Error>,
    ) -> Vec<Result<StreamChunk, LLMError>> {
        match chunk {
            Ok(bytes) => self.handle_bytes(&bytes),
            Err(err) => vec![Err(LLMError::HttpError(err.to_string()))],
        }
    }

    fn handle_bytes(&mut self, bytes: &[u8]) -> Vec<Result<StreamChunk, LLMError>> {
        self.sse_buffer.push_bytes(bytes);
        for event in self.sse_buffer.drain_events() {
            self.parse_event(&event);
        }
        self.results.drain(..).collect()
    }

    fn parse_event(&mut self, event: &str) {
        let payload = match extract_payload(event) {
            Some(payload) => payload,
            None => return,
        };
        match parse_event(&payload) {
            Ok(Some(event)) => self.handle_event(event),
            Ok(None) => {}
            Err(err) => self.results.push(Err(err)),
        }
    }

    fn handle_event(&mut self, event: ResponsesEvent) {
        match event {
            ResponsesEvent::OutputTextDelta { delta } => {
                self.results.push(Ok(StreamChunk::Text(delta)));
            }
            ResponsesEvent::FunctionCallAdded {
                item_id,
                call_id,
                name,
                output_index,
            } => self.handle_call_added(item_id, call_id, name, output_index),
            ResponsesEvent::FunctionCallDelta {
                item_id,
                delta,
                output_index,
            } => self.handle_call_delta(item_id, delta, output_index),
            ResponsesEvent::FunctionCallDone {
                item_id,
                arguments,
                output_index,
            } => self.handle_call_done(item_id, arguments, output_index),
            ResponsesEvent::OutputItemDone {
                item_id,
                output_index,
            } => self.handle_item_done(item_id, output_index),
            ResponsesEvent::ResponseCompleted { .. } => {
                self.handle_response_completed();
            }
        }
    }

    fn handle_call_added(
        &mut self,
        item_id: String,
        call_id: String,
        name: String,
        output_index: usize,
    ) {
        let state = ToolState {
            call_id: call_id.clone(),
            name: name.clone(),
            arguments: String::new(),
            output_index,
        };
        self.saw_tool_call = true;
        self.results.push(Ok(StreamChunk::ToolUseStart {
            index: output_index,
            id: call_id,
            name,
        }));
        self.tool_states.insert(item_id, state);
    }

    fn handle_call_delta(&mut self, item_id: String, delta: String, output_index: usize) {
        if let Some(state) = self.tool_states.get_mut(&item_id) {
            state.arguments.push_str(&delta);
            self.results.push(Ok(StreamChunk::ToolUseInputDelta {
                index: output_index,
                partial_json: delta,
            }));
        }
    }

    fn handle_call_done(&mut self, item_id: String, arguments: String, output_index: usize) {
        if let Some(mut state) = self.tool_states.remove(&item_id) {
            if !arguments.is_empty() {
                state.arguments = arguments;
            }
            self.results.push(Ok(StreamChunk::ToolUseComplete {
                index: output_index,
                tool_call: tool_call_with_arguments(&state, &state.arguments),
            }));
        }
    }

    fn handle_item_done(&mut self, item_id: String, output_index: usize) {
        if let Some(state) = self.tool_states.remove(&item_id) {
            self.results.push(Ok(StreamChunk::ToolUseComplete {
                index: output_index,
                tool_call: tool_call_with_arguments(&state, &state.arguments),
            }));
        }
    }

    fn handle_response_completed(&mut self) {
        self.flush_tool_states();
        let stop_reason = if self.saw_tool_call {
            "tool_use"
        } else {
            "end_turn"
        };
        self.results.push(Ok(StreamChunk::Done {
            stop_reason: stop_reason.to_string(),
        }));
    }

    fn flush_tool_states(&mut self) {
        for (_, state) in self.tool_states.drain() {
            self.results.push(Ok(StreamChunk::ToolUseComplete {
                index: state.output_index,
                tool_call: tool_call_with_arguments(&state, &state.arguments),
            }));
        }
    }
}

fn tool_call_with_arguments(state: &ToolState, arguments: &str) -> ToolCall {
    ToolCall {
        id: state.call_id.clone(),
        call_type: "function".to_string(),
        function: FunctionCall {
            name: state.name.clone(),
            arguments: arguments.to_string(),
        },
    }
}

#[cfg(test)]
mod tests;