mentra 0.6.0

An agent runtime for tool-using LLM applications
Documentation
use std::collections::HashSet;

use futures_util::StreamExt;
use tokio::sync::mpsc;

use crate::provider::model::{ProviderError, ProviderEvent, ProviderEventStream, TokenUsage};

use super::stream_model::AnthropicStreamEvent;

/// Spawns a background task that decodes the streaming body of `response`
/// and returns the receiving half of the channel the task writes to.
///
/// Decoded events arrive as `Ok(..)` items. If decoding fails, the error is
/// sent as an `Err(..)` item and the task ends.
pub(crate) fn spawn_event_stream(response: reqwest::Response) -> ProviderEventStream {
    let (sender, receiver) = mpsc::unbounded_channel();

    tokio::spawn(async move {
        // The worker gets its own sender handle so this task can still
        // report a failure after the worker has finished.
        let outcome = forward_events(response, sender.clone()).await;
        if let Err(error) = outcome {
            // A send failure only means the receiver was dropped; ignore it.
            let _ = sender.send(Err(error));
        }
    });

    receiver
}

/// Reads the response body chunk by chunk, splits it into frames at the
/// boundaries reported by `find_frame_boundary`, and sends every event
/// produced by `parse_frame` over `tx`.
///
/// Returns `Ok(())` when the body ends, or early once the receiving side of
/// `tx` has been dropped. Bytes left over after the last boundary are parsed
/// as one final frame.
///
/// # Errors
///
/// Returns `ProviderError::Transport` when reading a chunk fails, and
/// propagates any error returned by `parse_frame`.
async fn forward_events(
    response: reqwest::Response,
    tx: mpsc::UnboundedSender<Result<ProviderEvent, ProviderError>>,
) -> Result<(), ProviderError> {
    let mut body = response.bytes_stream();
    let mut pending: Vec<u8> = Vec::new();
    let mut state = StreamState::default();

    while let Some(next) = body.next().await {
        let bytes = next.map_err(ProviderError::Transport)?;
        pending.extend_from_slice(&bytes);

        // A single chunk may complete several frames, so keep splitting
        // until no full frame remains in the buffer.
        while let Some((frame_len, delimiter_len)) = find_frame_boundary(&pending) {
            // Remove the frame together with its delimiter, then cut the
            // delimiter off the end of the extracted bytes.
            let mut frame: Vec<u8> = pending.drain(..frame_len + delimiter_len).collect();
            frame.truncate(frame_len);

            for event in parse_frame(&frame, &mut state)? {
                if tx.send(Ok(event)).is_err() {
                    // Nobody is listening any more; stop quietly.
                    return Ok(());
                }
            }
        }
    }

    if pending.is_empty() {
        return Ok(());
    }

    // The body ended without a closing delimiter: treat the rest as a frame.
    for event in parse_frame(&pending, &mut state)? {
        let _ = tx.send(Ok(event));
    }

    Ok(())
}

/// Mutable state carried across the frames of a single response stream.
#[derive(Default)]
struct StreamState {
    /// Indexes of content blocks whose start event carried an unsupported
    /// block; `parse_frame` drops the delta and stop events for these
    /// indexes and forgets an index once its stop event has been seen.
    ignored_blocks: HashSet<usize>,
    /// The merged usage from the most recent `MessageDelta` event, used as
    /// the base when the next usage update is merged.
    latest_usage: Option<TokenUsage>,
}

/// Decodes one frame of the event stream into zero or more provider events.
///
/// Only `data:` lines contribute to the payload; blank lines, lines starting
/// with `:` and every other field are skipped. Multiple `data:` lines are
/// joined with `\n` before the payload is deserialized as JSON. A frame
/// without any `data:` line yields no events.
///
/// Content blocks that are not supported are remembered in `state` and all
/// of their start, delta and stop events are swallowed. Usage reported by
/// `MessageDelta` events is merged with the usage seen so far.
///
/// # Errors
///
/// * `ProviderError::MalformedStream` if the frame is not valid UTF-8 or the
///   event cannot be converted into provider events.
/// * `ProviderError::Deserialize` if the payload is not a valid event.
fn parse_frame(frame: &[u8], state: &mut StreamState) -> Result<Vec<ProviderEvent>, ProviderError> {
    let text = std::str::from_utf8(frame)
        .map_err(|error| ProviderError::MalformedStream(error.to_string()))?;

    let payload_lines: Vec<&str> = text
        .lines()
        .map(|raw| raw.strip_suffix('\r').unwrap_or(raw))
        .filter(|line| !(line.is_empty() || line.starts_with(':')))
        .filter_map(|line| line.strip_prefix("data:"))
        .map(str::trim_start)
        .collect();

    if payload_lines.is_empty() {
        return Ok(Vec::new());
    }

    let payload = payload_lines.join("\n");
    let parsed: AnthropicStreamEvent =
        serde_json::from_str(&payload).map_err(ProviderError::Deserialize)?;

    match &parsed {
        // An unsupported block: remember its index so that the events that
        // follow for the same block can be dropped as well.
        AnthropicStreamEvent::ContentBlockStart {
            index,
            content_block,
        } if !content_block.is_supported() => {
            state.ignored_blocks.insert(*index);
            return Ok(Vec::new());
        }
        AnthropicStreamEvent::ContentBlockDelta { index, .. }
            if state.ignored_blocks.contains(index) =>
        {
            return Ok(Vec::new());
        }
        // The stop event ends the ignored block, so its index is released.
        AnthropicStreamEvent::ContentBlockStop { index }
            if state.ignored_blocks.contains(index) =>
        {
            state.ignored_blocks.remove(index);
            return Ok(Vec::new());
        }
        _ => {}
    }

    let converted = parsed.into_provider_events().map_err(|failure| {
        ProviderError::MalformedStream(format!(
            "anthropic stream error ({}): {}",
            failure.kind, failure.message
        ))
    })?;

    let mut output = Vec::new();
    for event in converted {
        let event = match event {
            ProviderEvent::MessageDelta { stop_reason, usage } => {
                // Fold this update into the running usage and remember the
                // result for the next update.
                let merged = merge_usage(state.latest_usage.take(), usage);
                state.latest_usage = merged.clone();
                ProviderEvent::MessageDelta {
                    stop_reason,
                    usage: merged,
                }
            }
            other => other,
        };
        output.push(event);
    }

    Ok(output)
}

/// Combines two optional usage reports into one.
///
/// When both are present, every field takes the value from `update` if it is
/// set and falls back to the value from `base` otherwise. When only one side
/// is present it is returned unchanged, and `None` is returned when neither
/// is.
fn merge_usage(base: Option<TokenUsage>, update: Option<TokenUsage>) -> Option<TokenUsage> {
    match (base, update) {
        (Some(older), Some(newer)) => Some(TokenUsage {
            input_tokens: newer.input_tokens.or(older.input_tokens),
            output_tokens: newer.output_tokens.or(older.output_tokens),
            total_tokens: newer.total_tokens.or(older.total_tokens),
            cache_read_input_tokens: newer
                .cache_read_input_tokens
                .or(older.cache_read_input_tokens),
            cache_creation_input_tokens: newer
                .cache_creation_input_tokens
                .or(older.cache_creation_input_tokens),
            reasoning_tokens: newer.reasoning_tokens.or(older.reasoning_tokens),
            thoughts_tokens: newer.thoughts_tokens.or(older.thoughts_tokens),
            tool_input_tokens: newer.tool_input_tokens.or(older.tool_input_tokens),
        }),
        // At most one side is present here, so there is nothing to merge.
        (base, update) => base.or(update),
    }
}

/// Locates the first frame delimiter in `buffer`.
///
/// A delimiter is a blank line, written either as `\n\n` or as `\r\n\r\n`.
/// Returns `Some((offset, len))` where `offset` is the index at which the
/// delimiter starts (that is, the length of the frame before it) and `len`
/// is the delimiter's length in bytes, or `None` when `buffer` does not yet
/// contain a complete delimiter.
///
/// The delimiter that occurs earliest wins regardless of its kind, so a
/// buffer that mixes both line-ending styles is still split one frame at a
/// time instead of merging a CRLF-terminated frame into the next one.
fn find_frame_boundary(buffer: &[u8]) -> Option<(usize, usize)> {
    (0..buffer.len()).find_map(|index| {
        let rest = &buffer[index..];
        if rest.starts_with(b"\n\n") {
            Some((index, 2))
        } else if rest.starts_with(b"\r\n\r\n") {
            Some((index, 4))
        } else {
            None
        }
    })
}

// Unit tests that feed single frames straight into `parse_frame`.
#[cfg(test)]
mod tests {
    use super::{StreamState, parse_frame};
    use crate::provider::{ProviderEvent, Role, TokenUsage};

    // Usage reported at message start must be carried over into the usage
    // reported by a later message delta that only sets some of the fields.
    #[test]
    fn merges_anthropic_usage_updates_into_cumulative_totals() {
        let mut state = StreamState::default();

        // The start frame yields the message metadata plus a usage-only delta.
        let started = parse_frame(
            br#"data: {"type":"message_start","message":{"id":"msg_1","model":"claude-sonnet","role":"assistant","content":[],"usage":{"input_tokens":10,"cache_read_input_tokens":2}}}"#,
            &mut state,
        )
        .expect("message start should parse");
        assert_eq!(
            started,
            vec![
                ProviderEvent::MessageStarted {
                    id: "msg_1".to_string(),
                    model: "claude-sonnet".to_string(),
                    role: Role::Assistant,
                },
                ProviderEvent::MessageDelta {
                    stop_reason: None,
                    usage: Some(TokenUsage {
                        input_tokens: Some(10),
                        output_tokens: None,
                        total_tokens: None,
                        cache_read_input_tokens: Some(2),
                        cache_creation_input_tokens: None,
                        reasoning_tokens: None,
                        thoughts_tokens: None,
                        tool_input_tokens: None,
                    }),
                },
            ]
        );

        // The delta only reports output tokens; the input and cache counts
        // must come from the state remembered after the start frame.
        let delta = parse_frame(
            br#"data: {"type":"message_delta","delta":{"stop_reason":"end_turn","usage":{"output_tokens":3}}}"#,
            &mut state,
        )
        .expect("message delta should parse");
        assert_eq!(
            delta,
            vec![ProviderEvent::MessageDelta {
                stop_reason: Some("end_turn".to_string()),
                usage: Some(TokenUsage {
                    input_tokens: Some(10),
                    output_tokens: Some(3),
                    total_tokens: None,
                    cache_read_input_tokens: Some(2),
                    cache_creation_input_tokens: None,
                    reasoning_tokens: None,
                    thoughts_tokens: None,
                    tool_input_tokens: None,
                }),
            }]
        );
    }

    // A `server_tool_use` block and the delta/stop frames that follow it for
    // the same index must all parse successfully without producing events.
    #[test]
    fn ignores_hosted_tool_search_bookkeeping_blocks() {
        let mut state = StreamState::default();

        let started = parse_frame(
            br#"data: {"type":"content_block_start","index":1,"content_block":{"type":"server_tool_use","id":"srvtoolu_1","name":"tool_search_tool_bm25"}}"#,
            &mut state,
        )
        .expect("server tool use should parse");
        assert!(started.is_empty());

        let delta = parse_frame(
            br#"data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"{\"query\":\"weather\"}"}}"#,
            &mut state,
        )
        .expect("ignored delta should parse");
        assert!(delta.is_empty());

        let stopped = parse_frame(
            br#"data: {"type":"content_block_stop","index":1}"#,
            &mut state,
        )
        .expect("ignored stop should parse");
        assert!(stopped.is_empty());
    }
}