// mentra 0.6.0
//
// An agent runtime for tool-using LLM applications
// Documentation
use std::collections::BTreeMap;

use super::{
    ContentBlock, ContentBlockDelta, ContentBlockStart, ProviderError, ProviderEvent,
    ProviderEventStream, Response, Role, TokenUsage,
};

/// Drains a provider event stream and assembles the events into a single [`Response`].
///
/// Events are applied in arrival order until the stream yields `None`.
///
/// # Errors
///
/// Returns the first error carried by the stream, any error raised while applying an
/// event, or the error produced when the accumulated state cannot be turned into a
/// complete response.
pub async fn collect_response_from_stream(
    mut stream: ProviderEventStream,
) -> Result<Response, ProviderError> {
    let mut accumulator = StreamingResponseBuilder::default();

    while let Some(received) = stream.recv().await {
        // Surface a transport/provider error before touching the accumulated state.
        let event = received?;
        accumulator.apply(event)?;
    }

    accumulator.build()
}

/// Accumulates provider stream events until they can be assembled into a `Response`.
#[derive(Default)]
struct StreamingResponseBuilder {
    /// Message id, recorded from the `MessageStarted` event.
    id: Option<String>,
    /// Model identifier, recorded from the `MessageStarted` event.
    model: Option<String>,
    /// Message role, recorded from the `MessageStarted` event.
    role: Option<Role>,
    /// In-progress content blocks keyed by their stream index; the ordered map means the
    /// final content is emitted in ascending index order.
    blocks: BTreeMap<usize, StreamingContentBlock>,
    /// Stop reason taken from `MessageDelta` events.
    stop_reason: Option<String>,
    /// Token usage taken from `MessageDelta` events.
    usage: Option<TokenUsage>,
    /// Set once `MessageStopped` has been observed; required for a successful build.
    stopped: bool,
}

impl StreamingResponseBuilder {
    /// Folds one provider event into the partially assembled response.
    ///
    /// # Errors
    ///
    /// Returns [`ProviderError::MalformedStream`] when a delta or stop event refers to a
    /// content block index that was never started, and propagates any error raised while
    /// applying a delta to its block.
    fn apply(&mut self, event: ProviderEvent) -> Result<(), ProviderError> {
        match event {
            ProviderEvent::MessageStarted { id, model, role } => {
                self.id = Some(id);
                self.model = Some(model);
                self.role = Some(role);
            }
            ProviderEvent::ContentBlockStarted { index, kind } => {
                self.blocks.insert(index, StreamingContentBlock::from(kind));
            }
            ProviderEvent::ContentBlockDelta { index, delta } => {
                let block = self.blocks.get_mut(&index).ok_or_else(|| {
                    ProviderError::MalformedStream(format!(
                        "content block delta received before start for index {index}"
                    ))
                })?;
                block.apply_delta(delta)?;
            }
            ProviderEvent::ContentBlockStopped { index } => {
                let block = self.blocks.get_mut(&index).ok_or_else(|| {
                    ProviderError::MalformedStream(format!(
                        "content block stop received before start for index {index}"
                    ))
                })?;
                block.mark_complete();
            }
            ProviderEvent::MessageDelta { stop_reason, usage } => {
                // Deltas are incremental: a `None` field means "no update in this event",
                // so it must not erase a value delivered by an earlier delta (for example
                // a stop reason followed by a usage-only update).
                if stop_reason.is_some() {
                    self.stop_reason = stop_reason;
                }
                if usage.is_some() {
                    self.usage = usage;
                }
            }
            ProviderEvent::MessageStopped => {
                self.stopped = true;
            }
        }

        Ok(())
    }

    /// Consumes the builder and produces the final [`Response`].
    ///
    /// Content blocks are emitted in ascending index order.
    ///
    /// # Errors
    ///
    /// Returns [`ProviderError::MalformedStream`] when the stream ended without
    /// `MessageStopped`, when the message id, model, or role was never provided, or when
    /// a content block was started but never stopped. Errors from converting an
    /// individual block into its final form are propagated unchanged.
    fn build(self) -> Result<Response, ProviderError> {
        if !self.stopped {
            return Err(ProviderError::MalformedStream(
                "message stream ended before MessageStopped".to_string(),
            ));
        }

        let id = self
            .id
            .ok_or_else(|| ProviderError::MalformedStream("missing message id".to_string()))?;
        let model = self
            .model
            .ok_or_else(|| ProviderError::MalformedStream("missing model id".to_string()))?;
        let role = self
            .role
            .ok_or_else(|| ProviderError::MalformedStream("missing message role".to_string()))?;
        let mut content = Vec::with_capacity(self.blocks.len());

        for (index, block) in self.blocks {
            // A block that never received its stop event may hold truncated data.
            if !block.is_complete() {
                return Err(ProviderError::MalformedStream(format!(
                    "content block {index} did not complete"
                )));
            }
            content.push(block.try_into_content_block()?);
        }

        Ok(Response {
            id,
            model,
            role,
            content,
            stop_reason: self.stop_reason,
            usage: self.usage,
        })
    }
}

/// A content block that is still being assembled from stream events.
///
/// Every variant carries a `complete` flag that starts as `false` and is set when the
/// block's stop event is applied.
enum StreamingContentBlock {
    /// Text block; `text` grows as text deltas are appended.
    Text {
        text: String,
        complete: bool,
    },
    /// Image block; `source` is supplied when the block starts and no delta kind is
    /// accepted for it.
    Image {
        source: super::ImageSource,
        complete: bool,
    },
    /// Tool invocation; `input_json` holds the raw JSON text accumulated from input
    /// deltas and is only parsed when the block is converted to its final form.
    ToolUse {
        id: String,
        name: String,
        input_json: String,
        complete: bool,
    },
    /// Tool result; `content` grows as result-content deltas are appended.
    ToolResult {
        tool_use_id: String,
        content: String,
        is_error: bool,
        complete: bool,
    },
}

impl StreamingContentBlock {
    /// Appends a streamed delta to this block.
    ///
    /// # Errors
    ///
    /// Returns [`ProviderError::MalformedStream`] when the delta kind does not match the
    /// block kind (image blocks accept no deltas at all).
    fn apply_delta(&mut self, delta: ContentBlockDelta) -> Result<(), ProviderError> {
        match (self, delta) {
            (StreamingContentBlock::Text { text, .. }, ContentBlockDelta::Text(delta)) => {
                text.push_str(&delta);
                Ok(())
            }
            (
                StreamingContentBlock::ToolUse { input_json, .. },
                ContentBlockDelta::ToolUseInputJson(delta),
            ) => {
                input_json.push_str(&delta);
                Ok(())
            }
            (
                StreamingContentBlock::ToolResult { content, .. },
                ContentBlockDelta::ToolResultContent(delta),
            ) => {
                content.push_str(&delta);
                Ok(())
            }
            (block, delta) => Err(ProviderError::MalformedStream(format!(
                "delta {delta:?} is not valid for block {}",
                block.kind_name()
            ))),
        }
    }

    /// Records that the stop event for this block has been received.
    fn mark_complete(&mut self) {
        match self {
            StreamingContentBlock::Text { complete, .. }
            | StreamingContentBlock::Image { complete, .. }
            | StreamingContentBlock::ToolUse { complete, .. }
            | StreamingContentBlock::ToolResult { complete, .. } => *complete = true,
        }
    }

    /// Reports whether the stop event for this block has been received.
    fn is_complete(&self) -> bool {
        match self {
            StreamingContentBlock::Text { complete, .. }
            | StreamingContentBlock::Image { complete, .. }
            | StreamingContentBlock::ToolUse { complete, .. }
            | StreamingContentBlock::ToolResult { complete, .. } => *complete,
        }
    }

    /// Converts the accumulated block into its final [`ContentBlock`] form.
    ///
    /// For tool-use blocks the buffered JSON text is parsed here. A buffer that is empty
    /// or whitespace-only is parsed as the empty object `{}`, so a tool call that
    /// streamed no input deltas does not fail on end-of-input.
    ///
    /// # Errors
    ///
    /// Returns [`ProviderError::Deserialize`] when the buffered tool input is not valid
    /// JSON for the target input type.
    fn try_into_content_block(self) -> Result<ContentBlock, ProviderError> {
        match self {
            StreamingContentBlock::Text { text, .. } => Ok(ContentBlock::Text { text }),
            StreamingContentBlock::Image { source, .. } => Ok(ContentBlock::Image { source }),
            StreamingContentBlock::ToolUse {
                id,
                name,
                input_json,
                ..
            } => {
                // No input deltas means "no arguments"; parsing "" would otherwise fail
                // with an EOF error and discard an otherwise valid tool call.
                // NOTE(review): assumes the input type accepts an empty JSON object — confirm.
                let raw_input = if input_json.trim().is_empty() {
                    "{}"
                } else {
                    input_json.as_str()
                };
                Ok(ContentBlock::ToolUse {
                    id,
                    name,
                    input: serde_json::from_str(raw_input).map_err(ProviderError::Deserialize)?,
                })
            }
            StreamingContentBlock::ToolResult {
                tool_use_id,
                content,
                is_error,
                ..
            } => Ok(ContentBlock::ToolResult {
                tool_use_id,
                content,
                is_error,
            }),
        }
    }

    /// Short label for the block kind, used in error messages.
    fn kind_name(&self) -> &'static str {
        match self {
            StreamingContentBlock::Text { .. } => "text",
            StreamingContentBlock::Image { .. } => "image",
            StreamingContentBlock::ToolUse { .. } => "tool_use",
            StreamingContentBlock::ToolResult { .. } => "tool_result",
        }
    }
}

impl From<ContentBlockStart> for StreamingContentBlock {
    fn from(value: ContentBlockStart) -> Self {
        match value {
            ContentBlockStart::Text => StreamingContentBlock::Text {
                text: String::new(),
                complete: false,
            },
            ContentBlockStart::Image { source } => StreamingContentBlock::Image {
                source,
                complete: false,
            },
            ContentBlockStart::ToolUse { id, name } => StreamingContentBlock::ToolUse {
                id,
                name,
                input_json: String::new(),
                complete: false,
            },
            ContentBlockStart::ToolResult {
                tool_use_id,
                is_error,
            } => StreamingContentBlock::ToolResult {
                tool_use_id,
                content: String::new(),
                is_error,
                complete: false,
            },
        }
    }
}

#[cfg(test)]
mod tests {
    use tokio::sync::mpsc;

    use crate::provider::model::{ProviderEvent, TokenUsage};

    use super::collect_response_from_stream;

    /// Two `MessageDelta` events arrive in sequence; the collected response must carry
    /// the usage and stop reason from the second one.
    #[tokio::test]
    async fn collect_response_keeps_latest_usage_update() {
        // Built on demand so the same expected value can be used for both the event
        // payload and the final assertion without requiring `Clone`.
        let final_usage = || TokenUsage {
            input_tokens: Some(4),
            output_tokens: Some(2),
            total_tokens: Some(6),
            ..TokenUsage::default()
        };

        let events = vec![
            ProviderEvent::MessageStarted {
                id: "resp-1".to_string(),
                model: "model".to_string(),
                role: crate::provider::Role::Assistant,
            },
            ProviderEvent::ContentBlockStarted {
                index: 0,
                kind: crate::provider::ContentBlockStart::Text,
            },
            ProviderEvent::ContentBlockDelta {
                index: 0,
                delta: crate::provider::ContentBlockDelta::Text("hello".to_string()),
            },
            ProviderEvent::ContentBlockStopped { index: 0 },
            ProviderEvent::MessageDelta {
                stop_reason: None,
                usage: Some(TokenUsage {
                    input_tokens: Some(4),
                    ..TokenUsage::default()
                }),
            },
            ProviderEvent::MessageDelta {
                stop_reason: Some("stop".to_string()),
                usage: Some(final_usage()),
            },
            ProviderEvent::MessageStopped,
        ];

        let (tx, rx) = mpsc::unbounded_channel();
        for event in events {
            tx.send(Ok(event)).unwrap();
        }
        // Closing the sender lets the collector observe end-of-stream.
        drop(tx);

        let response = collect_response_from_stream(rx).await.unwrap();
        assert_eq!(response.usage, Some(final_usage()));
        assert_eq!(response.stop_reason.as_deref(), Some("stop"));
    }
}