// mentra 0.6.0
//
// An agent runtime for tool-using LLM applications.
// Documentation
use tokio::sync::mpsc;

use super::{ContentBlock, ImageSource, ProviderError, Response, Role, TokenUsage};

/// Receiving half of an unbounded tokio mpsc channel over which a provider
/// delivers its streaming output.
///
/// Each item is either a successfully decoded [`ProviderEvent`] or the
/// [`ProviderError`] that occurred while producing the stream.
pub type ProviderEventStream = mpsc::UnboundedReceiver<Result<ProviderEvent, ProviderError>>;

/// A single step in the event-based representation of a provider message.
///
/// When a complete [`Response`] is converted with
/// `Response::into_provider_events`, the events appear in this order:
/// one `MessageStarted`, then for every content block a `ContentBlockStarted`,
/// an optional `ContentBlockDelta` and a `ContentBlockStopped`, followed by one
/// `MessageDelta` and finally `MessageStopped`.
#[derive(Debug, Clone, PartialEq)]
pub enum ProviderEvent {
    /// Opens a message and carries its identifying metadata.
    MessageStarted {
        /// Identifier of the message (taken from `Response::id`).
        id: String,
        /// Name of the model that produced the message.
        model: String,
        /// Role of the message author.
        role: Role,
    },
    /// Opens the content block at position `index`.
    ContentBlockStarted {
        /// Zero-based position of the block within the message content.
        index: usize,
        /// What kind of block is starting, plus its up-front metadata.
        kind: ContentBlockStart,
    },
    /// Delivers payload data for the content block at position `index`.
    ContentBlockDelta {
        /// Zero-based position of the block the payload belongs to.
        index: usize,
        /// The payload itself.
        delta: ContentBlockDelta,
    },
    /// Closes the content block at position `index`.
    ContentBlockStopped {
        /// Zero-based position of the block that has ended.
        index: usize,
    },
    /// Message-level trailer data emitted after all content blocks.
    MessageDelta {
        /// Why generation stopped, if known.
        stop_reason: Option<String>,
        /// Token accounting for the message, if available.
        usage: Option<TokenUsage>,
    },
    /// Terminates the message; no further events follow it in a converted
    /// response.
    MessageStopped,
}

/// Kind of content block announced by [`ProviderEvent::ContentBlockStarted`],
/// together with the metadata that is known when the block opens.
///
/// Payload text (if any) is not stored here; it arrives separately as a
/// [`ContentBlockDelta`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ContentBlockStart {
    /// A text block; its text follows as `ContentBlockDelta::Text`.
    Text,
    /// An image block; the image source is carried in full on the start event.
    Image { source: ImageSource },
    /// A tool invocation; its input follows as
    /// `ContentBlockDelta::ToolUseInputJson`.
    ToolUse { id: String, name: String },
    /// The result of a tool invocation, linked to it via `tool_use_id`; its
    /// content follows as `ContentBlockDelta::ToolResultContent`.
    ToolResult { tool_use_id: String, is_error: bool },
}

/// Payload carried by [`ProviderEvent::ContentBlockDelta`].
///
/// When events are derived from a complete [`Response`] in this module, each
/// block produces at most one delta holding its entire payload.
// NOTE(review): live provider streams presumably split a payload across
// several deltas — confirm against the stream collector.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ContentBlockDelta {
    /// Text belonging to a text block.
    Text(String),
    /// Serialized tool input belonging to a tool-use block.
    ToolUseInputJson(String),
    /// Content belonging to a tool-result block.
    ToolResultContent(String),
}

pub fn provider_event_stream_from_response(response: Response) -> ProviderEventStream {
    let events = response.into_provider_events();
    let (tx, rx) = mpsc::unbounded_channel();

    for event in events {
        if tx.send(Ok(event)).is_err() {
            break;
        }
    }

    rx
}

impl Response {
    /// Expands this response into the equivalent ordered list of
    /// [`ProviderEvent`]s.
    ///
    /// The list starts with `MessageStarted`, continues with the events of
    /// each content block (indexed by the block's position in `content`), and
    /// ends with a `MessageDelta` carrying `stop_reason` and `usage` followed
    /// by `MessageStopped`.
    pub fn into_provider_events(self) -> Vec<ProviderEvent> {
        let opening = ProviderEvent::MessageStarted {
            id: self.id,
            model: self.model,
            role: self.role,
        };

        let blocks = self
            .content
            .into_iter()
            .enumerate()
            .flat_map(|(position, block)| block.into_provider_events(position));

        let closing = [
            ProviderEvent::MessageDelta {
                stop_reason: self.stop_reason,
                usage: self.usage,
            },
            ProviderEvent::MessageStopped,
        ];

        std::iter::once(opening)
            .chain(blocks)
            .chain(closing)
            .collect()
    }
}

impl ContentBlock {
    /// Expands one content block into its provider events.
    ///
    /// Every block yields `ContentBlockStarted` and `ContentBlockStopped`
    /// tagged with `index`. Between them a single `ContentBlockDelta` is
    /// emitted when the block has a non-empty payload: the text of a text
    /// block, the string form of a tool-use input, or the content of a tool
    /// result. Image blocks never produce a delta; their source travels on the
    /// start event.
    fn into_provider_events(self, index: usize) -> Vec<ProviderEvent> {
        // Work out the block-specific parts first; the surrounding
        // start/stop scaffolding is the same for every variant.
        let (kind, payload) = match self {
            ContentBlock::Text { text } => (
                ContentBlockStart::Text,
                (!text.is_empty()).then(|| ContentBlockDelta::Text(text)),
            ),
            ContentBlock::Image { source } => (ContentBlockStart::Image { source }, None),
            ContentBlock::ToolUse { id, name, input } => {
                let serialized = input.to_string();
                (
                    ContentBlockStart::ToolUse { id, name },
                    (!serialized.is_empty())
                        .then(|| ContentBlockDelta::ToolUseInputJson(serialized)),
                )
            }
            ContentBlock::ToolResult {
                tool_use_id,
                content,
                is_error,
            } => (
                ContentBlockStart::ToolResult {
                    tool_use_id,
                    is_error,
                },
                (!content.is_empty()).then(|| ContentBlockDelta::ToolResultContent(content)),
            ),
        };

        let mut sequence = Vec::with_capacity(3);
        sequence.push(ProviderEvent::ContentBlockStarted { index, kind });
        // `Option` iterates over zero or one item, so this adds the delta
        // only when there is a payload.
        sequence.extend(payload.map(|delta| ProviderEvent::ContentBlockDelta { index, delta }));
        sequence.push(ProviderEvent::ContentBlockStopped { index });
        sequence
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Converting a response to an event stream and collecting it again must
    /// yield an equal response, including every populated usage counter.
    #[tokio::test]
    async fn response_round_trip_preserves_usage() {
        let usage = TokenUsage {
            input_tokens: Some(10),
            output_tokens: Some(3),
            total_tokens: Some(13),
            cache_read_input_tokens: Some(2),
            cache_creation_input_tokens: None,
            reasoning_tokens: Some(1),
            thoughts_tokens: None,
            tool_input_tokens: None,
        };
        let original = Response {
            id: String::from("resp-1"),
            model: String::from("model"),
            role: Role::Assistant,
            content: vec![ContentBlock::text("hello")],
            stop_reason: Some(String::from("stop")),
            usage: Some(usage),
        };

        let stream = provider_event_stream_from_response(original.clone());
        let rebuilt = crate::provider::collect_response_from_stream(stream)
            .await
            .expect("response should rebuild");

        assert_eq!(rebuilt, original);
    }
}