frigg 0.4.2

Local-first MCP server for code understanding.
Documentation
use std::sync::Arc;

use async_trait::async_trait;
use reqwest::{Client, Method};
use serde::{Deserialize, Serialize};

use super::*;

/// Translates an [`EmbeddingPurpose`] into the `taskType` value sent to Google.
///
/// Queries and documents get distinct retrieval task types so the two sides of a
/// search are embedded asymmetrically.
fn google_task_type(purpose: EmbeddingPurpose) -> &'static str {
    match purpose {
        EmbeddingPurpose::Query => "RETRIEVAL_QUERY",
        EmbeddingPurpose::Document => "RETRIEVAL_DOCUMENT",
    }
}

/// Returns the `models/`-qualified resource path for `model`.
///
/// A name that already begins with `models/` is returned unchanged, so applying this
/// function twice gives the same result as applying it once.
fn google_model_path(model: &str) -> String {
    const PREFIX: &str = "models/";

    if !model.starts_with(PREFIX) {
        return format!("{PREFIX}{model}");
    }
    model.to_owned()
}

/// JSON body of the `…:batchEmbedContents` POST request.
///
/// `build_http_request` creates one entry in `requests` per input text.
#[derive(Serialize)]
struct GoogleBatchEmbeddingRequestPayload {
    requests: Vec<GoogleBatchEmbeddingRequestItemPayload>,
}

/// One embedding request inside a batch.
///
/// Fields serialize in camelCase: `model`, `content`, `taskType`, and
/// `outputDimensionality`.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleBatchEmbeddingRequestItemPayload {
    /// Fully-qualified model path (`models/<name>`).
    model: String,
    content: GoogleContentPayload,
    /// Retrieval task type, e.g. `RETRIEVAL_QUERY` or `RETRIEVAL_DOCUMENT`.
    task_type: &'static str,
    // Omitted from the JSON entirely when `None`, so the model's default size is used.
    #[serde(skip_serializing_if = "Option::is_none")]
    output_dimensionality: Option<usize>,
}

/// Content wrapper holding the parts to embed.
#[derive(Serialize)]
struct GoogleContentPayload {
    parts: Vec<GooglePartPayload>,
}

/// A single text part of a content payload.
#[derive(Serialize)]
struct GooglePartPayload {
    text: String,
}

/// Successful `batchEmbedContents` response body.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct GoogleBatchEmbeddingResponsePayload {
    // Missing `embeddings` deserializes as an empty list; the caller rejects that case.
    #[serde(default)]
    embeddings: Vec<GoogleEmbeddingPayload>,
    // Explicit rename matches what `rename_all = "camelCase"` already produces.
    #[serde(rename = "usageMetadata")]
    usage_metadata: Option<GoogleUsagePayload>,
}

/// One embedding in the response.
///
/// Two shapes are accepted: a flat `values` array, or a nested `embedding.values`
/// object. NOTE(review): confirm which shape(s) the API actually returns.
#[derive(Deserialize)]
struct GoogleEmbeddingPayload {
    values: Option<Vec<f32>>,
    embedding: Option<GoogleEmbeddingValuesPayload>,
}

/// Nested `{ "values": [...] }` form of an embedding vector.
#[derive(Deserialize)]
struct GoogleEmbeddingValuesPayload {
    values: Vec<f32>,
}

/// Optional token accounting (`promptTokenCount`, `totalTokenCount`).
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct GoogleUsagePayload {
    prompt_token_count: Option<u64>,
    total_token_count: Option<u64>,
}

/// Error body of the form `{ "error": { "code", "message", "status" } }`.
#[derive(Deserialize)]
struct GoogleErrorEnvelope {
    error: GoogleErrorPayload,
}

/// Details of a Google API error; every field is optional.
#[derive(Deserialize)]
struct GoogleErrorPayload {
    /// Numeric HTTP-style status code reported in the body.
    code: Option<u16>,
    /// Human-readable error description.
    message: Option<String>,
    /// Symbolic status such as `RESOURCE_EXHAUSTED` or `UNAVAILABLE`.
    status: Option<String>,
}

/// Embedding provider backed by Google's `batchEmbedContents` REST endpoint.
///
/// Failed calls that are classified as retryable are retried according to
/// `config.retry_policy`, with `sleeper` waiting out the backoff between attempts.
pub struct GoogleEmbeddingProvider {
    /// Executes outgoing HTTP requests.
    http: Arc<dyn HttpExecutor>,
    /// Waits for the backoff delay between retry attempts.
    sleeper: Arc<dyn BackoffSleeper>,
    /// API key, sent as the `key` query parameter of every request URL.
    api_key: String,
    /// Endpoint, request timeout, and retry policy.
    config: GoogleEmbeddingProviderConfig,
}

impl GoogleEmbeddingProvider {
    pub fn new(api_key: impl Into<String>) -> Self {
        Self::with_config(api_key, GoogleEmbeddingProviderConfig::default())
    }

    pub fn with_config(api_key: impl Into<String>, config: GoogleEmbeddingProviderConfig) -> Self {
        Self::with_runtime(
            api_key.into(),
            config,
            Arc::new(ReqwestHttpExecutor::new(Client::new())),
            Arc::new(TokioSleeper),
        )
    }

    pub(super) fn with_runtime(
        api_key: String,
        config: GoogleEmbeddingProviderConfig,
        http: Arc<dyn HttpExecutor>,
        sleeper: Arc<dyn BackoffSleeper>,
    ) -> Self {
        Self {
            http,
            sleeper,
            api_key,
            config,
        }
    }

    fn build_http_request(&self, request: &EmbeddingRequest) -> EmbeddingResult<HttpRequest> {
        let model_path = google_model_path(&request.model);
        let task_type = google_task_type(request.purpose);

        let requests = request
            .input
            .iter()
            .map(|text| GoogleBatchEmbeddingRequestItemPayload {
                model: model_path.clone(),
                content: GoogleContentPayload {
                    parts: vec![GooglePartPayload { text: text.clone() }],
                },
                task_type,
                output_dimensionality: request.dimensions,
            })
            .collect();

        let payload = GoogleBatchEmbeddingRequestPayload { requests };
        let body = serde_json::to_value(payload).map_err(|error| {
            EmbeddingError::Provider(ProviderFailure::non_retryable(
                self.kind(),
                format!("failed to serialize Google request payload: {error}"),
                Some("request_serialization_failed".to_string()),
                None,
                request.trace_id.clone(),
            ))
        })?;
        let diagnostics = HttpRequestDiagnostics::from_request(self.kind(), request, &body)?;

        let endpoint = self.config.endpoint.trim_end_matches('/');
        let url = format!(
            "{endpoint}/v1beta/{model_path}:batchEmbedContents?key={}",
            self.api_key
        );

        Ok(HttpRequest {
            method: Method::POST,
            url,
            headers: Vec::new(),
            body,
            timeout: self.config.timeout,
            diagnostics,
        })
    }

    fn map_transport_error(
        &self,
        operation: &str,
        trace_id: Option<String>,
        error: HttpTransportError,
        diagnostics: &HttpRequestDiagnostics,
    ) -> EmbeddingError {
        let message = append_request_diagnostics(error.message, diagnostics);
        let failure = match error.retryability {
            Retryability::Retryable => {
                TransportFailure::retryable(self.kind(), operation, message, trace_id)
            }
            Retryability::NonRetryable => {
                TransportFailure::non_retryable(self.kind(), operation, message, trace_id)
            }
        };

        EmbeddingError::Transport(failure)
    }

    fn map_provider_http_error(
        &self,
        status_code: u16,
        body: &str,
        trace_id: Option<String>,
        diagnostics: &HttpRequestDiagnostics,
    ) -> EmbeddingError {
        let mut message = format!("Google request failed with status {status_code}");
        let mut code = None;
        let mut retryability = status_retryability(status_code);

        if let Ok(envelope) = serde_json::from_str::<GoogleErrorEnvelope>(body) {
            if let Some(error_message) = envelope.error.message {
                message = error_message;
            }

            if let Some(error_status) = envelope.error.status {
                if matches!(
                    error_status.as_str(),
                    "RESOURCE_EXHAUSTED" | "UNAVAILABLE" | "DEADLINE_EXCEEDED" | "ABORTED"
                ) {
                    retryability = Retryability::Retryable;
                }
                code = Some(error_status);
            }

            if let Some(provider_status_code) = envelope.error.code {
                retryability = status_retryability(provider_status_code);
            }
        }
        let message = append_request_diagnostics(message, diagnostics);

        let failure = match retryability {
            Retryability::Retryable => {
                ProviderFailure::retryable(self.kind(), message, code, Some(status_code), trace_id)
            }
            Retryability::NonRetryable => ProviderFailure::non_retryable(
                self.kind(),
                message,
                code,
                Some(status_code),
                trace_id,
            ),
        };

        EmbeddingError::Provider(failure)
    }

    fn parse_success_response(
        &self,
        body: &str,
        request: &EmbeddingRequest,
    ) -> EmbeddingResult<EmbeddingResponse> {
        let parsed =
            serde_json::from_str::<GoogleBatchEmbeddingResponsePayload>(body).map_err(|error| {
                EmbeddingError::Provider(ProviderFailure::non_retryable(
                    self.kind(),
                    format!("failed to parse Google success response: {error}"),
                    Some("invalid_response".to_string()),
                    Some(200),
                    request.trace_id.clone(),
                ))
            })?;

        if parsed.embeddings.is_empty() {
            return Err(EmbeddingError::Provider(ProviderFailure::non_retryable(
                self.kind(),
                "Google response did not contain embeddings",
                Some("invalid_response".to_string()),
                Some(200),
                request.trace_id.clone(),
            )));
        }

        let mut vectors = Vec::with_capacity(parsed.embeddings.len());
        for (index, embedding) in parsed.embeddings.into_iter().enumerate() {
            let values = embedding
                .values
                .or_else(|| embedding.embedding.map(|nested| nested.values))
                .ok_or_else(|| {
                    EmbeddingError::Provider(ProviderFailure::non_retryable(
                        self.kind(),
                        "Google response contained an embedding without vector values",
                        Some("invalid_response".to_string()),
                        Some(200),
                        request.trace_id.clone(),
                    ))
                })?;

            vectors.push(EmbeddingVector { index, values });
        }

        let usage = parsed
            .usage_metadata
            .and_then(|usage| usage_from_counts(usage.prompt_token_count, usage.total_token_count));

        Ok(EmbeddingResponse {
            provider: self.kind(),
            model: request.model.clone(),
            vectors,
            trace_id: request.trace_id.clone(),
            usage,
        })
    }

    async fn embed_once(&self, request: &EmbeddingRequest) -> EmbeddingResult<EmbeddingResponse> {
        let http_request = self.build_http_request(request)?;
        let diagnostics = http_request.diagnostics.clone();
        let http_response = self.http.execute(http_request).await.map_err(|error| {
            self.map_transport_error(
                "google_batch_embed_contents",
                request.trace_id.clone(),
                error,
                &diagnostics,
            )
        })?;

        if (200..=299).contains(&http_response.status_code) {
            self.parse_success_response(&http_response.body, request)
        } else {
            Err(self.map_provider_http_error(
                http_response.status_code,
                &http_response.body,
                request.trace_id.clone(),
                &diagnostics,
            ))
        }
    }

    async fn embed_with_retry(
        &self,
        request: &EmbeddingRequest,
    ) -> EmbeddingResult<EmbeddingResponse> {
        let mut retries = 0usize;

        loop {
            match self.embed_once(request).await {
                Ok(response) => return Ok(response),
                Err(error)
                    if error.is_retryable() && retries < self.config.retry_policy.max_retries =>
                {
                    let backoff = self.config.retry_policy.backoff_for_retry(retries);
                    retries += 1;
                    self.sleeper.sleep(backoff).await;
                }
                Err(error) => return Err(error),
            }
        }
    }
}

#[async_trait]
impl EmbeddingProvider for GoogleEmbeddingProvider {
    /// Identifies this provider as [`EmbeddingProviderKind::Google`].
    fn kind(&self) -> EmbeddingProviderKind {
        EmbeddingProviderKind::Google
    }

    /// Validates `request`, rejects a blank API key, then runs the embedding call with
    /// retries.
    ///
    /// # Errors
    ///
    /// Errors are checked in this order:
    /// 1. the request's own validation error;
    /// 2. a validation error on `api_key` when the key is empty or whitespace-only;
    /// 3. whatever the retried HTTP call produces.
    async fn embed(&self, request: EmbeddingRequest) -> EmbeddingResult<EmbeddingResponse> {
        request.validate()?;

        if !self.api_key.trim().is_empty() {
            return self.embed_with_retry(&request).await;
        }

        Err(EmbeddingError::Validation(ValidationFailure::new(
            "api_key",
            "api_key must not be empty",
        )))
    }
}