// mentra 0.6.0
//
// An agent runtime for tool-using LLM applications.
// Documentation
use std::{
    collections::VecDeque,
    fs,
    sync::atomic::{AtomicU64, Ordering},
    time::{SystemTime, UNIX_EPOCH},
};

use async_trait::async_trait;
use mentra::runtime::SqliteRuntimeStore;
use mentra::{
    AgentConfig, BuiltinProvider, ContentBlock, Message, Role, Runtime,
    agent::{AgentEvent, AgentStatus},
    error::RuntimeError,
    provider::{
        ContentBlockDelta, ContentBlockStart, ModelInfo, Provider, ProviderDescriptor,
        ProviderError, ProviderEvent, ProviderEventStream, ProviderId, Request,
    },
};
use reqwest::StatusCode;
use tokio::sync::{Mutex, broadcast, mpsc};

/// Process-wide counter; `temp_store` takes one value per call and embeds it in the
/// generated database file name so that names differ between calls.
static NEXT_TEMP_ID: AtomicU64 = AtomicU64::new(1);

/// One scripted outcome for a single `ScriptedProvider::stream` call.
enum StreamScript {
    /// The call succeeds; the items are queued on a channel, in order, and handed back as the
    /// stream. An `Err` item therefore surfaces while the stream is being consumed.
    Buffered(Vec<Result<ProviderEvent, ProviderError>>),
    /// The call itself returns this error, so no stream is produced.
    Fail(ProviderError),
}

/// Test double for `Provider` that replays pre-scripted outcomes and records every request.
///
/// The script queue and the request log sit behind `Arc`s, so a clone of the provider
/// observes the same queue and the same recorded requests as the original.
#[derive(Clone)]
struct ScriptedProvider {
    /// Identifier reported through `descriptor()`.
    kind: ProviderId,
    /// Models returned by `list_models()`.
    models: Vec<ModelInfo>,
    /// Remaining scripted outcomes; each `stream()` call pops one from the front.
    scripts: std::sync::Arc<Mutex<VecDeque<StreamScript>>>,
    /// Owned copies of the requests passed to `stream()`, in call order.
    requests: std::sync::Arc<Mutex<Vec<Request<'static>>>>,
}

impl ScriptedProvider {
    /// Builds a provider that reports `kind` and `models` and serves `scripts` front to back,
    /// one per `stream()` call. The request log starts out empty.
    fn new(
        kind: impl Into<ProviderId>,
        models: Vec<ModelInfo>,
        scripts: Vec<StreamScript>,
    ) -> Self {
        let queue: VecDeque<StreamScript> = scripts.into();
        let scripts = std::sync::Arc::new(Mutex::new(queue));
        let requests = std::sync::Arc::new(Mutex::new(Vec::new()));
        Self {
            kind: kind.into(),
            models,
            scripts,
            requests,
        }
    }

    /// Returns a copy of every request passed to `stream()` so far, in call order.
    async fn recorded_requests(&self) -> Vec<Request<'static>> {
        let recorded = self.requests.lock().await;
        (*recorded).clone()
    }
}

#[async_trait]
impl Provider for ScriptedProvider {
    fn descriptor(&self) -> ProviderDescriptor {
        ProviderDescriptor::new(self.kind.clone())
    }

    async fn list_models(&self) -> Result<Vec<ModelInfo>, ProviderError> {
        Ok(self.models.clone())
    }

    async fn stream(&self, request: Request<'_>) -> Result<ProviderEventStream, ProviderError> {
        self.requests.lock().await.push(request.into_owned());
        match self.scripts.lock().await.pop_front() {
            Some(StreamScript::Buffered(items)) => {
                let (tx, rx) = mpsc::unbounded_channel();
                for item in items {
                    tx.send(item)
                        .expect("test stream receiver dropped unexpectedly");
                }
                Ok(rx)
            }
            Some(StreamScript::Fail(error)) => Err(error),
            None => panic!("no scripted stream available"),
        }
    }
}

/// A successful streamed text turn returns the assistant message, commits both the user and
/// assistant messages to history, emits start/delta/finish events, and leaves the snapshot
/// in its finished state.
#[tokio::test]
async fn send_streamed_text_turn_emits_events_and_commits_history() {
    let model = model_info("model", BuiltinProvider::Anthropic);
    let scripts = vec![text_stream(&model.id, "Hello")];
    let provider = ScriptedProvider::new(BuiltinProvider::Anthropic, vec![model.clone()], scripts);
    let config = AgentConfig {
        system: Some("system prompt".to_string()),
        ..AgentConfig::default()
    };

    let runtime = test_runtime(provider);
    let mut agent = runtime.spawn_with_config("agent", model, config).unwrap();
    // Subscribe before sending so the whole run is observed.
    let mut receiver = agent.subscribe_events();

    let expected_reply = Message::assistant(ContentBlock::text("Hello"));
    let reply = agent.send(vec![ContentBlock::text("hi")]).await.unwrap();

    assert_eq!(reply, expected_reply);
    assert_eq!(agent.name(), "agent");
    assert_eq!(agent.model(), "model");
    assert_eq!(agent.history().len(), 2);
    assert_eq!(agent.config().system.as_deref(), Some("system prompt"));
    assert_eq!(agent.last_message(), Some(&expected_reply));

    let observed = collect_events(&mut receiver);
    let expected_delta = AgentEvent::TextDelta {
        delta: "Hello".to_string(),
        full_text: "Hello".to_string(),
    };
    assert!(observed.contains(&AgentEvent::RunStarted));
    assert!(observed.contains(&expected_delta));
    assert!(matches!(observed.last(), Some(AgentEvent::RunFinished)));

    let snapshot = agent.watch_snapshot();
    let state = snapshot.borrow();
    assert_eq!(state.status, AgentStatus::Finished);
    assert_eq!(state.history_len, 2);
    assert!(state.current_text.is_empty());
    assert!(state.pending_tool_uses.is_empty());
}

/// When the stream errors mid-turn, `send` fails, history is restored to what it was before
/// the failed turn, and the last emitted event is `RunFailed`.
#[tokio::test]
async fn send_failure_rolls_history_back_and_emits_run_failed() {
    let model = model_info("model", BuiltinProvider::Anthropic);
    let broken_turn = erroring_stream(
        vec![ProviderEvent::MessageStarted {
            id: "msg-2".to_string(),
            model: model.id.clone(),
            role: Role::Assistant,
        }],
        ProviderError::MalformedStream("boom".to_string()),
    );
    let scripts = vec![text_stream(&model.id, "ok"), broken_turn];
    let provider = ScriptedProvider::new(BuiltinProvider::Anthropic, vec![model.clone()], scripts);

    let runtime = test_runtime(provider);
    let mut agent = runtime.spawn("agent", model).unwrap();

    // The first turn succeeds and establishes the history the failed turn must roll back to.
    agent.send(vec![ContentBlock::text("first")]).await.unwrap();
    let history_before = agent.history().to_vec();
    let mut receiver = agent.subscribe_events();

    let outcome = agent.send(vec![ContentBlock::text("second")]).await;
    assert!(outcome.is_err());
    assert_eq!(agent.history(), &history_before[..]);

    let observed = collect_events(&mut receiver);
    assert!(matches!(observed.last(), Some(AgentEvent::RunFailed { .. })));
}

/// A 503 returned by the provider before any stream exists is retried: the second scripted
/// outcome supplies the reply and the provider sees exactly two requests.
#[tokio::test]
async fn send_retries_transient_provider_error_before_streaming() {
    let model = model_info("model", BuiltinProvider::Anthropic);
    let unavailable = ProviderError::Http {
        status: StatusCode::SERVICE_UNAVAILABLE,
        body: "offline".to_string(),
    };
    let scripts = vec![
        failed_request(unavailable),
        text_stream(&model.id, "recovered"),
    ];
    let provider = ScriptedProvider::new(BuiltinProvider::Anthropic, vec![model.clone()], scripts);
    // Keep a clone so the request log stays readable after the runtime takes the provider.
    let request_log = provider.clone();

    let runtime = test_runtime(provider);
    let mut agent = runtime.spawn("agent", model).expect("spawn agent");

    let reply = agent
        .send(vec![ContentBlock::text("hello")])
        .await
        .expect("send should retry");
    let request_count = request_log.recorded_requests().await.len();

    assert_eq!(reply.text(), "recovered");
    assert_eq!(request_count, 2);
    assert_eq!(
        agent.last_message(),
        Some(&Message::assistant(ContentBlock::text("recovered")))
    );
}

/// After a turn fails mid-stream, `resume` replays the same user input and commits both
/// messages; a second `resume` then reports that nothing is left to resume.
#[tokio::test]
async fn resume_replays_last_failed_turn() {
    let model = model_info("model", BuiltinProvider::Anthropic);
    let broken_turn = erroring_stream(
        vec![ProviderEvent::MessageStarted {
            id: "msg-1".to_string(),
            model: model.id.clone(),
            role: Role::Assistant,
        }],
        ProviderError::MalformedStream("boom".to_string()),
    );
    let scripts = vec![broken_turn, text_stream(&model.id, "done")];
    let provider = ScriptedProvider::new(BuiltinProvider::Anthropic, vec![model.clone()], scripts);

    let runtime = test_runtime(provider);
    let mut agent = runtime.spawn("agent", model).expect("spawn agent");

    // First attempt: the stream breaks and nothing is committed.
    let failure = agent
        .send(vec![ContentBlock::text("retry me")])
        .await
        .expect_err("first send should fail");
    assert!(matches!(failure, RuntimeError::FailedToStreamResponse(_)));
    assert!(agent.history().is_empty());

    // Resume: the original user turn is replayed against the second script.
    let reply = agent
        .resume()
        .await
        .expect("resume should replay user turn");
    let expected_user = Message::user(ContentBlock::text("retry me"));
    let expected_assistant = Message::assistant(ContentBlock::text("done"));

    assert_eq!(reply.text(), "done");
    assert_eq!(agent.history().len(), 2);
    assert_eq!(agent.history()[0], expected_user);
    assert_eq!(agent.history()[1], expected_assistant);

    // The successful run consumed the resumable state.
    let second_failure = agent
        .resume()
        .await
        .expect_err("successful run clears resume state");
    assert!(matches!(second_failure, RuntimeError::NoResumableTurn));
}

fn collect_events(receiver: &mut broadcast::Receiver<AgentEvent>) -> Vec<AgentEvent> {
    let mut events = Vec::new();
    while let Ok(event) = receiver.try_recv() {
        events.push(event);
    }
    events
}

/// Builds a `Runtime` whose only provider is `provider`, backed by a fresh SQLite store
/// created by `temp_store`.
///
/// # Panics
///
/// Panics if the runtime builder reports an error.
fn test_runtime(provider: ScriptedProvider) -> Runtime {
    let store = temp_store("agent-runtime");
    let builder = Runtime::empty_builder()
        .with_provider_instance(provider)
        .with_store(store);
    builder.build().expect("build runtime")
}

/// Creates a `SqliteRuntimeStore` pointing at a new path inside the system temp directory.
///
/// The file name combines `label`, the current time in nanoseconds since the Unix epoch,
/// and the next value of `NEXT_TEMP_ID`.
///
/// # Panics
///
/// Panics if the system clock is before the Unix epoch or the parent directory cannot be
/// created.
fn temp_store(label: &str) -> SqliteRuntimeStore {
    let sequence = NEXT_TEMP_ID.fetch_add(1, Ordering::Relaxed);
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("system time")
        .as_nanos();

    let file_name = format!("mentra-agent-runtime-{label}-{nanos}-{sequence}.sqlite");
    let db_path = std::env::temp_dir().join(file_name);

    let parent_dir = db_path.parent();
    if let Some(dir) = parent_dir {
        fs::create_dir_all(dir).expect("create temp dir");
    }

    SqliteRuntimeStore::new(db_path)
}

/// Shorthand for `ModelInfo::new(id, provider)` used by the tests in this file.
fn model_info(id: &str, provider: impl Into<ProviderId>) -> ModelInfo {
    ModelInfo::new(id, provider)
}

/// Builds a script whose stream yields every event in `events` successfully, in order.
fn buffered_stream(events: Vec<ProviderEvent>) -> StreamScript {
    let mut items = Vec::with_capacity(events.len());
    for event in events {
        items.push(Ok(event));
    }
    StreamScript::Buffered(items)
}

fn erroring_stream(mut events: Vec<ProviderEvent>, error: ProviderError) -> StreamScript {
    let mut items = events.drain(..).map(Ok).collect::<Vec<_>>();
    items.push(Err(error));
    StreamScript::Buffered(items)
}

/// Builds a script that makes the `stream()` call itself return `error`, so no stream is
/// ever produced for that call.
fn failed_request(error: ProviderError) -> StreamScript {
    StreamScript::Fail(error)
}

/// Builds a script for one complete assistant message holding a single text block.
///
/// The event sequence is: message start (id `"msg-text"`, the given `model`), start of text
/// block 0, one delta carrying all of `text`, stop of block 0, message stop.
fn text_stream(model: &str, text: &str) -> StreamScript {
    let message_open = ProviderEvent::MessageStarted {
        id: String::from("msg-text"),
        model: model.to_owned(),
        role: Role::Assistant,
    };
    let block_open = ProviderEvent::ContentBlockStarted {
        index: 0,
        kind: ContentBlockStart::Text,
    };
    let block_body = ProviderEvent::ContentBlockDelta {
        index: 0,
        delta: ContentBlockDelta::Text(text.to_owned()),
    };
    let block_close = ProviderEvent::ContentBlockStopped { index: 0 };

    buffered_stream(vec![
        message_open,
        block_open,
        block_body,
        block_close,
        ProviderEvent::MessageStopped,
    ])
}