// mentra-provider 0.2.0

// Shared provider core for Mentra
// Documentation
use async_trait::async_trait;
use std::sync::Arc;

pub(crate) mod model;
pub(crate) mod sse;

use crate::AuthScheme;
use crate::BuiltinProvider;
use crate::CompactionRequest;
use crate::CompactionResponse;
use crate::CredentialSource;
use crate::ModelCatalog;
use crate::ModelInfo;
use crate::ProviderCapabilities;
use crate::ProviderDefinition;
use crate::ProviderError;
use crate::ProviderEventStream;
use crate::ProviderSession;
use crate::ProviderSessionFactory;
use crate::RegisteredProvider;
use crate::Request;
use crate::StaticCredentialSource;
use crate::WireApi;

/// Base URL installed into the built-in Gemini provider definition.
const DEFAULT_BASE_URL: &str = "https://generativelanguage.googleapis.com/";

/// Provider that talks to the Google Gemini API over HTTP.
///
/// Generic over the credential source `C`, which defaults to
/// [`StaticCredentialSource`].
pub struct GeminiProvider<C = StaticCredentialSource> {
    /// HTTP client used for every request issued by this provider.
    client: reqwest::Client,
    /// Shared source that is asked for credentials before each request.
    credential_source: Arc<C>,
    /// Definition used to build request URLs and headers; also returned to callers
    /// through [`RegisteredProvider::definition`].
    definition: ProviderDefinition,
}

/// Manual `Clone` impl: no `C: Clone` bound is required because the credential
/// source is shared through its [`Arc`] rather than duplicated.
impl<C> Clone for GeminiProvider<C> {
    fn clone(&self) -> Self {
        let Self {
            client,
            credential_source,
            definition,
        } = self;

        Self {
            client: client.clone(),
            // Bumps the reference count; the underlying source is not copied.
            credential_source: Arc::clone(credential_source),
            definition: definition.clone(),
        }
    }
}

impl GeminiProvider<StaticCredentialSource> {
    /// Creates a provider that authenticates with a fixed API key.
    ///
    /// The key is wrapped in a [`StaticCredentialSource`] and the built-in Gemini
    /// definition is used.
    ///
    /// # Panics
    ///
    /// Panics if the underlying HTTP client cannot be built.
    pub fn new(api_key: impl Into<String>) -> Self {
        let static_source = StaticCredentialSource::new(api_key);
        Self::with_credential_source(static_source)
    }
}

impl<C> GeminiProvider<C>
where
    C: CredentialSource + 'static,
{
    /// Creates a provider with the built-in Gemini definition, taking ownership of
    /// `credential_source` and placing it behind an [`Arc`].
    ///
    /// # Panics
    ///
    /// Panics if the underlying HTTP client cannot be built.
    pub fn with_credential_source(credential_source: C) -> Self {
        let shared_source = Arc::new(credential_source);
        Self::with_shared_credential_source(shared_source)
    }

    /// Creates a provider with the built-in Gemini definition around an already
    /// shared credential source.
    ///
    /// # Panics
    ///
    /// Panics if the underlying HTTP client cannot be built.
    pub fn with_shared_credential_source(credential_source: Arc<C>) -> Self {
        let builtin_definition = Self::definition();
        Self::with_definition_and_shared_credential_source(builtin_definition, credential_source)
    }

    /// Creates a provider from a caller-supplied `definition`, taking ownership of
    /// `credential_source` and placing it behind an [`Arc`].
    ///
    /// # Panics
    ///
    /// Panics if the underlying HTTP client cannot be built.
    pub fn with_definition_and_credential_source(
        definition: ProviderDefinition,
        credential_source: C,
    ) -> Self {
        let shared_source = Arc::new(credential_source);
        Self::with_definition_and_shared_credential_source(definition, shared_source)
    }

    /// Creates a provider from a caller-supplied `definition` and an already shared
    /// credential source. Every other constructor funnels into this one.
    ///
    /// # Panics
    ///
    /// Panics if the underlying HTTP client cannot be built.
    pub fn with_definition_and_shared_credential_source(
        definition: ProviderDefinition,
        credential_source: Arc<C>,
    ) -> Self {
        Self {
            client: reqwest::Client::builder()
                .build()
                .expect("Failed to build client"),
            credential_source,
            definition,
        }
    }

    /// Builds the built-in Gemini provider definition: descriptor text, wire API,
    /// header-based auth, capability flags, and the default base URL.
    fn definition() -> ProviderDefinition {
        let mut definition = ProviderDefinition::new(BuiltinProvider::Gemini);

        // The API key travels in a request header rather than in the URL.
        let auth_scheme = AuthScheme::Header {
            name: String::from("x-goog-api-key"),
        };
        let capabilities = ProviderCapabilities {
            supports_model_listing: true,
            supports_streaming: true,
            supports_websockets: false,
            supports_tool_calls: true,
            supports_images: true,
            supports_history_compaction: true,
            supports_memory_summarization: true,
            supports_deferred_tools: false,
            supports_hosted_tool_search: false,
            supports_hosted_web_search: false,
            supports_image_generation: false,
            supports_reasoning_effort: true,
            reports_reasoning_tokens: false,
            reports_thoughts_tokens: true,
            supports_structured_tool_results: false,
        };

        definition.descriptor.display_name = Some(String::from("Gemini"));
        definition.descriptor.description =
            Some(String::from("Google Gemini Developer API provider"));
        definition.wire_api = WireApi::GeminiGenerateContent;
        definition.auth_scheme = auth_scheme;
        definition.capabilities = capabilities;
        definition.base_url = Some(String::from(DEFAULT_BASE_URL));
        definition
    }
}

#[async_trait]
impl<C> ModelCatalog for GeminiProvider<C>
where
    C: CredentialSource + 'static,
{
    /// Lists the models exposed under `v1beta/models`, following pagination until
    /// the service stops returning a continuation token.
    ///
    /// Only models for which `supports_generate_content()` returns `true` are kept;
    /// each is converted into a [`ModelInfo`]. Credentials are requested from the
    /// credential source again before every page request.
    ///
    /// # Errors
    ///
    /// * Any error returned by the credential source or while building the request
    ///   URL or headers is propagated.
    /// * [`ProviderError::Transport`] if sending a request fails.
    /// * [`ProviderError::Http`] if a page responds with a non-success status; the
    ///   response body is included when it can be read, otherwise it is empty.
    /// * [`ProviderError::Decode`] if a page body cannot be decoded.
    async fn list_models(&self) -> Result<Vec<ModelInfo>, ProviderError> {
        let mut models = Vec::new();
        let mut page_token = None::<String>;

        loop {
            let credentials = self.credential_source.credentials().await?;
            let mut request = self
                .client
                .get(
                    self.definition
                        .request_url_with_auth_for_path("v1beta/models", &credentials)?,
                )
                .headers(self.definition.build_headers(&credentials)?)
                .query(&[("pageSize", "1000")]);

            // The first request carries no token; later requests continue the listing.
            if let Some(token) = page_token.as_deref() {
                request = request.query(&[("pageToken", token)]);
            }

            let response = request.send().await.map_err(ProviderError::Transport)?;
            if !response.status().is_success() {
                return Err(ProviderError::Http {
                    status: response.status(),
                    body: response.text().await.unwrap_or_default(),
                });
            }

            let page = response
                .json::<model::GeminiModelsPage>()
                .await
                .map_err(ProviderError::Decode)?;

            models.extend(
                page.models
                    .into_iter()
                    .filter(|model| model.supports_generate_content())
                    .map(ModelInfo::from),
            );

            // Treat an empty token exactly like a missing one. Re-sending an empty
            // `pageToken` cannot advance the listing, so without this guard a
            // `Some("")` token could keep the loop running forever.
            page_token = page.next_page_token.filter(|token| !token.is_empty());
            if page_token.is_none() {
                break;
            }
        }

        Ok(models)
    }
}

#[async_trait]
impl<C> ProviderSessionFactory for GeminiProvider<C>
where
    C: CredentialSource + 'static,
{
    /// Returns a boxed clone of this provider to act as the session.
    ///
    /// This implementation always returns `Ok`.
    async fn create_session(&self) -> Result<Box<dyn ProviderSession>, ProviderError> {
        let session: GeminiProvider<C> = self.clone();
        Ok(Box::new(session))
    }
}

#[async_trait]
impl<C> ProviderSession for GeminiProvider<C>
where
    C: CredentialSource + 'static,
{
    /// Starts a `streamGenerateContent` call for `request` and hands the HTTP
    /// response to the SSE module to produce the event stream.
    ///
    /// # Errors
    ///
    /// * Any error from converting `request`, fetching credentials, or building the
    ///   URL or headers is propagated.
    /// * [`ProviderError::Transport`] if sending the request fails.
    /// * [`ProviderError::Http`] if the response status is not a success; the
    ///   response body is included when it can be read, otherwise it is empty.
    async fn stream(&self, request: Request<'_>) -> Result<ProviderEventStream, ProviderError> {
        // Capture what is needed from `request` before the conversion consumes it.
        let session = request.provider_request_options.session.clone();
        let model_name = request.model.to_string();
        let payload = model::GeminiGenerateContentRequest::try_from(request)?;
        let credentials = self.credential_source.credentials().await?;

        let path = format!(
            "v1beta/{}:streamGenerateContent",
            normalize_model_name(&model_name)
        );
        let url = self
            .definition
            .request_url_with_auth_for_path(&path, &credentials)?;
        let headers =
            self.definition
                .build_headers_for_session(&credentials, Some(&session), None)?;

        let response = self
            .client
            .post(url)
            .headers(headers)
            // Ask for the response as server-sent events.
            .query(&[("alt", "sse")])
            .json(&payload)
            .send()
            .await
            .map_err(ProviderError::Transport)?;

        let status = response.status();
        if !status.is_success() {
            let body = response.text().await.unwrap_or_default();
            return Err(ProviderError::Http { status, body });
        }

        Ok(sse::spawn_event_stream(response, model_name))
    }

    /// Runs a compaction by converting `request` into a model request, sending it
    /// through [`ProviderSession::send`], and converting the response.
    ///
    /// # Errors
    ///
    /// Propagates any error from the request conversion or from `send`.
    async fn compact(
        &self,
        request: CompactionRequest<'_>,
    ) -> Result<CompactionResponse, ProviderError> {
        let model_request = request.into_model_request()?;
        let model_response = ProviderSession::send(self, model_request).await?;
        Ok(model_response.into_compaction_response())
    }

    /// Runs a memory summarization by converting `request` into a model request,
    /// sending it through [`ProviderSession::send`], and converting the response.
    ///
    /// # Errors
    ///
    /// Propagates any error from the request conversion, from `send`, or from
    /// converting the response.
    async fn summarize_memories(
        &self,
        request: crate::MemorySummarizeRequest<'_>,
    ) -> Result<crate::MemorySummarizeResponse, ProviderError> {
        let model_request = request.into_model_request()?;
        let model_response = ProviderSession::send(self, model_request).await?;
        model_response.into_memory_summarize_response()
    }
}

#[async_trait]
impl<C> RegisteredProvider for GeminiProvider<C>
where
    C: CredentialSource + 'static,
{
    /// Returns a copy of the definition this provider was constructed with.
    fn definition(&self) -> ProviderDefinition {
        let Self { definition, .. } = self;
        definition.clone()
    }
}

/// Returns `model` as a `models/...` resource name.
///
/// Input that already starts with `models/` is returned unchanged; anything else
/// gets the `models/` prefix prepended.
fn normalize_model_name(model: &str) -> String {
    const PREFIX: &str = "models/";

    let already_qualified = model.starts_with(PREFIX);
    if already_qualified {
        return model.to_string();
    }

    format!("{PREFIX}{model}")
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::RegisteredProvider;

    /// The built-in definition must report history compaction as supported.
    #[test]
    fn definition_advertises_history_compaction_support() {
        let definition = GeminiProvider::new("test-key").definition();

        assert!(definition.capabilities.supports_history_compaction);
    }
}