mentra-provider 0.2.0

Shared provider core for Mentra
Documentation
use async_trait::async_trait;
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;

pub(crate) mod model;
pub(crate) mod sse;
pub(crate) mod stream_model;

use crate::AuthScheme;
use crate::BuiltinProvider;
use crate::CompactionRequest;
use crate::CompactionResponse;
use crate::CredentialSource;
use crate::ModelCatalog;
use crate::ModelInfo;
use crate::ProviderCapabilities;
use crate::ProviderDefinition;
use crate::ProviderError;
use crate::ProviderEventStream;
use crate::ProviderSession;
use crate::ProviderSessionFactory;
use crate::RegisteredProvider;
use crate::Request;
use crate::StaticCredentialSource;
use crate::WireApi;

const DEFAULT_BASE_URL: &str = "https://api.anthropic.com";
const ANTHROPIC_VERSION: &str = "2023-06-01";

/// Provider for the Anthropic Messages API.
///
/// Generic over the credential source `C`, which defaults to [`StaticCredentialSource`]
/// (the source built by [`AnthropicProvider::new`]).
pub struct AnthropicProvider<C = StaticCredentialSource> {
    /// HTTP client used for every request this provider issues.
    client: reqwest::Client,
    /// Shared credential source; credentials are fetched from it before each request.
    credential_source: Arc<C>,
    /// Provider metadata: wire API, auth scheme, base URL, default headers, capabilities.
    definition: ProviderDefinition,
}

// Written by hand instead of `#[derive(Clone)]` so that `C` itself does not have to be
// `Clone`: the credential source is shared through its `Arc`, not duplicated.
impl<C> Clone for AnthropicProvider<C> {
    fn clone(&self) -> Self {
        let Self {
            client,
            credential_source,
            definition,
        } = self;
        Self {
            client: client.clone(),
            credential_source: Arc::clone(credential_source),
            definition: definition.clone(),
        }
    }
}

impl AnthropicProvider<StaticCredentialSource> {
    /// Creates a provider with the built-in Anthropic definition whose credentials
    /// come from a [`StaticCredentialSource`] wrapping `api_key`.
    pub fn new(api_key: impl Into<String>) -> Self {
        let static_source = StaticCredentialSource::new(api_key);
        Self::with_credential_source(static_source)
    }
}

impl<C> AnthropicProvider<C>
where
    C: CredentialSource + 'static,
{
    /// Creates a provider with the built-in Anthropic definition, taking ownership of
    /// `credential_source` and wrapping it in an [`Arc`].
    pub fn with_credential_source(credential_source: C) -> Self {
        let shared = Arc::new(credential_source);
        Self::with_shared_credential_source(shared)
    }

    /// Creates a provider with the built-in Anthropic definition and a credential
    /// source that is already behind an [`Arc`].
    pub fn with_shared_credential_source(credential_source: Arc<C>) -> Self {
        let builtin = Self::definition();
        Self::with_definition_and_shared_credential_source(builtin, credential_source)
    }

    /// Creates a provider from a caller-supplied `definition`, taking ownership of
    /// `credential_source` and wrapping it in an [`Arc`].
    pub fn with_definition_and_credential_source(
        definition: ProviderDefinition,
        credential_source: C,
    ) -> Self {
        let shared = Arc::new(credential_source);
        Self::with_definition_and_shared_credential_source(definition, shared)
    }

    /// Primary constructor; every other constructor delegates here.
    ///
    /// # Panics
    ///
    /// Panics ("Failed to build client") if the `reqwest` client cannot be built.
    pub fn with_definition_and_shared_credential_source(
        definition: ProviderDefinition,
        credential_source: Arc<C>,
    ) -> Self {
        let builder = reqwest::Client::builder();
        let client = builder.build().expect("Failed to build client");

        Self {
            client,
            credential_source,
            definition,
        }
    }

    /// Builds the built-in Anthropic [`ProviderDefinition`] used by constructors that
    /// don't take one. It configures:
    ///
    /// - the Messages wire API;
    /// - `x-api-key` header authentication;
    /// - `DEFAULT_BASE_URL` as the base URL;
    /// - a pinned `anthropic-version` header.
    ///
    /// Note: this is an associated function (no `self`) and is distinct from
    /// `RegisteredProvider::definition(&self)`, which returns an instance's definition.
    fn definition() -> ProviderDefinition {
        let capabilities = ProviderCapabilities {
            supports_model_listing: true,
            supports_streaming: true,
            supports_websockets: false,
            supports_tool_calls: true,
            supports_images: true,
            supports_history_compaction: true,
            supports_memory_summarization: true,
            supports_deferred_tools: true,
            supports_hosted_tool_search: true,
            supports_hosted_web_search: false,
            supports_image_generation: false,
            supports_reasoning_effort: true,
            reports_reasoning_tokens: false,
            reports_thoughts_tokens: false,
            supports_structured_tool_results: false,
        };
        let default_headers: HashMap<String, String> = [(
            String::from("anthropic-version"),
            ANTHROPIC_VERSION.to_owned(),
        )]
        .into_iter()
        .collect();

        let mut builtin = ProviderDefinition::new(BuiltinProvider::Anthropic);
        builtin.descriptor.display_name = Some(String::from("Anthropic"));
        builtin.descriptor.description = Some(String::from("Anthropic Messages API provider"));
        builtin.wire_api = WireApi::AnthropicMessages;
        builtin.auth_scheme = AuthScheme::Header {
            name: String::from("x-api-key"),
        };
        builtin.capabilities = capabilities;
        builtin.base_url = Some(DEFAULT_BASE_URL.to_owned());
        builtin.headers = Some(default_headers);
        builtin
    }
}

#[async_trait]
impl<C> ModelCatalog for AnthropicProvider<C>
where
    C: CredentialSource + 'static,
{
    /// Lists every model visible to the configured credentials by walking the paginated
    /// `GET v1/models` endpoint (up to 1000 models per page).
    ///
    /// Credentials are re-fetched for each page. The first page is requested without an
    /// `after_id` cursor; later pages pass the `last_id` reported by the previous page.
    /// Pagination stops when the server reports no more pages, or when it claims more
    /// pages but gives no new cursor, which would otherwise loop forever. In that case
    /// the models collected so far are returned.
    ///
    /// # Errors
    ///
    /// - Errors from the credential source, URL construction, or header construction
    ///   are propagated unchanged.
    /// - [`ProviderError::Transport`] if a request cannot be sent.
    /// - [`ProviderError::Http`] for a non-success status; the body is included when
    ///   it is readable.
    /// - [`ProviderError::Decode`] if a page cannot be parsed.
    async fn list_models(&self) -> Result<Vec<ModelInfo>, ProviderError> {
        let mut models = Vec::new();
        let mut after_id: Option<String> = None;

        loop {
            let credentials = self.credential_source.credentials().await?;
            let mut request = self
                .client
                .get(
                    self.definition
                        .request_url_with_auth_for_path("v1/models", &credentials)?,
                )
                .headers(self.definition.build_headers(&credentials)?)
                .query(&[("limit", "1000")]);
            // Only send a cursor once we have one; an empty `after_id=` is not a valid
            // cursor and must not be sent with the first page request.
            // (`RequestBuilder::query` appends to the existing query parameters.)
            if let Some(cursor) = after_id.as_deref() {
                request = request.query(&[("after_id", cursor)]);
            }

            let response = request.send().await.map_err(ProviderError::Transport)?;

            let status = response.status();
            if !status.is_success() {
                return Err(ProviderError::Http {
                    status,
                    body: response.text().await.unwrap_or_default(),
                });
            }

            let page = response
                .json::<model::AnthropicModelsPage>()
                .await
                .map_err(ProviderError::Decode)?;

            let has_more = page.has_more;
            let next_cursor = page.last_id;
            models.extend(page.data.into_iter().map(|model| model.into()));

            if !has_more {
                break;
            }

            // A missing or non-advancing cursor would re-request the same page
            // indefinitely; stop instead of spinning.
            match next_cursor {
                Some(cursor) if after_id.as_ref() != Some(&cursor) => after_id = Some(cursor),
                _ => break,
            }
        }

        Ok(models)
    }
}

#[async_trait]
impl<C> ProviderSessionFactory for AnthropicProvider<C>
where
    C: CredentialSource + 'static,
{
    /// Creates a session by cloning this provider.
    ///
    /// The clone shares the credential source through its `Arc`; the `reqwest::Client`
    /// and the definition are cloned.
    async fn create_session(&self) -> Result<Box<dyn ProviderSession>, ProviderError> {
        let session: Box<dyn ProviderSession> = Box::new(Self::clone(self));
        Ok(session)
    }
}

#[async_trait]
impl<C> ProviderSession for AnthropicProvider<C>
where
    C: CredentialSource + 'static,
{
    /// Sends `request` to the Messages endpoint with streaming enabled and hands the
    /// successful HTTP response to `sse::spawn_event_stream`.
    async fn stream(&self, request: Request<'_>) -> Result<ProviderEventStream, ProviderError> {
        let http_response = self.send_message(request, true).await?;
        let events = sse::spawn_event_stream(http_response);
        Ok(events)
    }

    /// Compacts history.
    ///
    /// Converts the compaction request into a regular model request, sends it via
    /// [`ProviderSession::send`], and converts the reply into a [`CompactionResponse`].
    async fn compact(
        &self,
        request: CompactionRequest<'_>,
    ) -> Result<CompactionResponse, ProviderError> {
        let model_request = request.into_model_request()?;
        let reply = ProviderSession::send(self, model_request).await?;
        Ok(reply.into_compaction_response())
    }

    /// Summarizes memories.
    ///
    /// Converts the request into a regular model request, sends it via
    /// [`ProviderSession::send`], and converts the reply; conversion errors are
    /// returned as-is.
    async fn summarize_memories(
        &self,
        request: crate::MemorySummarizeRequest<'_>,
    ) -> Result<crate::MemorySummarizeResponse, ProviderError> {
        let model_request = request.into_model_request()?;
        let reply = ProviderSession::send(self, model_request).await?;
        reply.into_memory_summarize_response()
    }
}

#[async_trait]
impl<C> RegisteredProvider for AnthropicProvider<C>
where
    C: CredentialSource + 'static,
{
    fn definition(&self) -> ProviderDefinition {
        self.definition.clone()
    }
}

impl<C> AnthropicProvider<C>
where
    C: CredentialSource + 'static,
{
    /// POSTs `request` to `v1/messages` and returns the raw HTTP response once a success
    /// status has been confirmed.
    ///
    /// When `stream` is `true`, `"stream": true` is set on the serialized JSON body.
    /// Headers are built for the request's session.
    ///
    /// # Errors
    ///
    /// - Conversion errors from `AnthropicRequest::try_from` are propagated.
    /// - [`ProviderError::Serialize`] if the body cannot be turned into JSON.
    /// - Credential, URL, and header errors are propagated.
    /// - [`ProviderError::Transport`] if sending fails.
    /// - [`ProviderError::Http`] for a non-success status; the body is included when
    ///   it is readable.
    async fn send_message(
        &self,
        request: Request<'_>,
        stream: bool,
    ) -> Result<reqwest::Response, ProviderError> {
        // Grab the session before `request` is consumed by the conversion below.
        let session = request.provider_request_options.session.clone();
        let payload = model::AnthropicRequest::try_from(request)?;
        let mut body = serde_json::to_value(payload).map_err(ProviderError::Serialize)?;
        if stream {
            body["stream"] = Value::from(true);
        }

        let credentials = self.credential_source.credentials().await?;
        let url = self
            .definition
            .request_url_with_auth_for_path("v1/messages", &credentials)?;
        let headers =
            self.definition
                .build_headers_for_session(&credentials, Some(&session), None)?;

        let response = self
            .client
            .post(url)
            .headers(headers)
            .json(&body)
            .send()
            .await
            .map_err(ProviderError::Transport)?;

        let status = response.status();
        if status.is_success() {
            return Ok(response);
        }
        Err(ProviderError::Http {
            status,
            body: response.text().await.unwrap_or_default(),
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::RegisteredProvider;

    #[test]
    fn definition_advertises_history_compaction_support() {
        let anthropic = AnthropicProvider::new("test-key");

        // Call through the trait explicitly: the inherent `AnthropicProvider::definition()`
        // (no `self`) is a different associated function with the same name.
        let advertised = RegisteredProvider::definition(&anthropic);
        assert!(advertised.capabilities.supports_history_compaction);
    }
}