// frigg 0.3.2
//
// Local-first MCP server for code understanding.
// Documentation
use std::sync::Arc;

use async_trait::async_trait;
use reqwest::{Client, Method};
use serde::{Deserialize, Serialize};

use super::*;

/// JSON body serialized for an embeddings request.
///
/// The model name and inputs are borrowed from the caller's `EmbeddingRequest` rather than
/// copied into the payload.
#[derive(Serialize)]
struct OpenAiEmbeddingRequestPayload<'a> {
    /// Model identifier, borrowed from the request.
    model: &'a str,
    /// Texts to embed, borrowed from the request.
    input: &'a [String],
    /// Requested output dimensionality; omitted from the JSON entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    dimensions: Option<usize>,
    /// Encoding requested for the returned vectors; `build_http_request` always sends `"float"`.
    encoding_format: &'static str,
}

/// Body of a successful embeddings response, limited to the fields this provider reads.
#[derive(Deserialize)]
struct OpenAiEmbeddingResponsePayload {
    /// Returned embeddings. A missing field deserializes to an empty list, which
    /// `parse_success_response` rejects as an invalid response.
    #[serde(default)]
    data: Vec<OpenAiEmbeddingVectorPayload>,
    /// Model name reported by the server; the request's model is used when this is absent.
    model: Option<String>,
    /// Token accounting, when the server includes it.
    usage: Option<OpenAiUsagePayload>,
}

/// One entry of the response's `data` array.
#[derive(Deserialize)]
struct OpenAiEmbeddingVectorPayload {
    /// Index reported by the server for this entry; copied into `EmbeddingVector::index`.
    /// NOTE(review): presumably the position of the matching input text — confirm against the API.
    index: usize,
    /// Embedding components; become `EmbeddingVector::values`.
    embedding: Vec<f32>,
}

/// Token usage section of a response. Both counters are optional and are handed to
/// `usage_from_counts` as-is.
#[derive(Deserialize)]
struct OpenAiUsagePayload {
    /// Value of the `prompt_tokens` field, if present.
    prompt_tokens: Option<u64>,
    /// Value of the `total_tokens` field, if present.
    total_tokens: Option<u64>,
}

/// Top-level shape of an error body. `map_provider_http_error` attempts to parse non-2xx
/// response bodies as this type and falls back to a generic message when parsing fails.
#[derive(Deserialize)]
struct OpenAiErrorEnvelope {
    /// The error details nested under the `error` key.
    error: OpenAiErrorPayload,
}

/// Error details carried inside an [`OpenAiErrorEnvelope`].
#[derive(Deserialize)]
struct OpenAiErrorPayload {
    /// Error text from the server; required for the envelope to parse.
    message: String,
    /// Error code, preferred as the failure code when present.
    code: Option<String>,
    /// Value of the JSON `type` field; used as the failure code when `code` is absent.
    #[serde(rename = "type")]
    error_type: Option<String>,
}

/// Embedding provider that sends requests to the endpoint configured in
/// `OpenAiEmbeddingProviderConfig` and retries retryable failures.
///
/// The HTTP executor and the backoff sleeper are held as trait objects so they can be
/// supplied by the caller of `with_runtime`.
pub struct OpenAiEmbeddingProvider {
    /// Executes the HTTP requests built by this provider.
    http: Arc<dyn HttpExecutor>,
    /// Awaited between retry attempts with the backoff computed by the retry policy.
    sleeper: Arc<dyn BackoffSleeper>,
    /// Sent as the bearer token in the `Authorization` header.
    api_key: String,
    /// Supplies the endpoint URL, request timeout and retry policy.
    config: OpenAiEmbeddingProviderConfig,
}

impl OpenAiEmbeddingProvider {
    /// Creates a provider for `api_key` using `OpenAiEmbeddingProviderConfig::default()`.
    ///
    /// The key is not checked here; a blank key is reported when `embed` is called.
    pub fn new(api_key: impl Into<String>) -> Self {
        let default_config = OpenAiEmbeddingProviderConfig::default();
        Self::with_config(api_key, default_config)
    }

    pub fn with_config(api_key: impl Into<String>, config: OpenAiEmbeddingProviderConfig) -> Self {
        Self::with_runtime(
            api_key.into(),
            config,
            Arc::new(ReqwestHttpExecutor::new(Client::new())),
            Arc::new(TokioSleeper),
        )
    }

    /// Assembles a provider from fully specified parts.
    ///
    /// Visible to the parent module only, so that it can supply its own HTTP executor and
    /// sleeper implementations (presumably for tests — confirm against callers).
    pub(super) fn with_runtime(
        api_key: String,
        config: OpenAiEmbeddingProviderConfig,
        http: Arc<dyn HttpExecutor>,
        sleeper: Arc<dyn BackoffSleeper>,
    ) -> Self {
        Self {
            api_key,
            config,
            http,
            sleeper,
        }
    }

    /// Builds the HTTP `POST` request that carries `request` to the configured endpoint.
    ///
    /// The JSON body contains the request's model and inputs, the optional `dimensions`, and a
    /// fixed `encoding_format` of `"float"`. The API key is sent as a bearer token with
    /// surrounding whitespace removed.
    ///
    /// # Errors
    ///
    /// Returns a non-retryable provider failure with code `request_serialization_failed` when
    /// the payload cannot be converted to JSON, and propagates any error from
    /// `HttpRequestDiagnostics::from_request`.
    fn build_http_request(&self, request: &EmbeddingRequest) -> EmbeddingResult<HttpRequest> {
        let payload = OpenAiEmbeddingRequestPayload {
            model: &request.model,
            input: &request.input,
            dimensions: request.dimensions,
            encoding_format: "float",
        };

        let body = serde_json::to_value(payload).map_err(|error| {
            EmbeddingError::Provider(ProviderFailure::non_retryable(
                self.kind(),
                format!("failed to serialize OpenAI request payload: {error}"),
                Some("request_serialization_failed".to_string()),
                None,
                request.trace_id.clone(),
            ))
        })?;
        let diagnostics = HttpRequestDiagnostics::from_request(self.kind(), request, &body)?;

        // `embed` decides whether the key is blank after trimming, so send the trimmed key too:
        // a key loaded with a trailing newline or padding would otherwise yield a header value
        // that is invalid or does not match the real token.
        let authorization = format!("Bearer {}", self.api_key.trim());

        Ok(HttpRequest {
            method: Method::POST,
            url: self.config.endpoint.clone(),
            headers: vec![("Authorization".to_string(), authorization)],
            body,
            timeout: self.config.timeout,
            diagnostics,
        })
    }

    fn map_transport_error(
        &self,
        operation: &str,
        trace_id: Option<String>,
        error: HttpTransportError,
        diagnostics: &HttpRequestDiagnostics,
    ) -> EmbeddingError {
        let message = append_request_diagnostics(error.message, diagnostics);
        let failure = match error.retryability {
            Retryability::Retryable => {
                TransportFailure::retryable(self.kind(), operation, message, trace_id)
            }
            Retryability::NonRetryable => {
                TransportFailure::non_retryable(self.kind(), operation, message, trace_id)
            }
        };

        EmbeddingError::Transport(failure)
    }

    fn map_provider_http_error(
        &self,
        status_code: u16,
        body: &str,
        trace_id: Option<String>,
        diagnostics: &HttpRequestDiagnostics,
    ) -> EmbeddingError {
        let mut message = format!("OpenAI request failed with status {status_code}");
        let mut code = None;

        if let Ok(envelope) = serde_json::from_str::<OpenAiErrorEnvelope>(body) {
            message = envelope.error.message;
            code = envelope.error.code.or(envelope.error.error_type);
        }
        let message = append_request_diagnostics(message, diagnostics);

        let retryability = status_retryability(status_code);
        let failure = match retryability {
            Retryability::Retryable => {
                ProviderFailure::retryable(self.kind(), message, code, Some(status_code), trace_id)
            }
            Retryability::NonRetryable => ProviderFailure::non_retryable(
                self.kind(),
                message,
                code,
                Some(status_code),
                trace_id,
            ),
        };

        EmbeddingError::Provider(failure)
    }

    /// Parses the body of a successful response into an [`EmbeddingResponse`].
    ///
    /// The returned vectors are ordered by the `index` the server reported for each entry, so
    /// their order does not depend on the order of the `data` array. The response model falls
    /// back to the request's model when the server omits it, and usage is derived through
    /// `usage_from_counts` when a usage section is present.
    ///
    /// # Errors
    ///
    /// Returns a non-retryable provider failure with code `invalid_response` when `body` does
    /// not deserialize into the expected shape or contains no embeddings. These failures always
    /// report status `200`, even though `embed_once` routes any 2xx response here.
    fn parse_success_response(
        &self,
        body: &str,
        request: &EmbeddingRequest,
    ) -> EmbeddingResult<EmbeddingResponse> {
        let mut parsed =
            serde_json::from_str::<OpenAiEmbeddingResponsePayload>(body).map_err(|error| {
                EmbeddingError::Provider(ProviderFailure::non_retryable(
                    self.kind(),
                    format!("failed to parse OpenAI success response: {error}"),
                    Some("invalid_response".to_string()),
                    Some(200),
                    request.trace_id.clone(),
                ))
            })?;

        if parsed.data.is_empty() {
            return Err(EmbeddingError::Provider(ProviderFailure::non_retryable(
                self.kind(),
                "OpenAI response did not contain embeddings",
                Some("invalid_response".to_string()),
                Some(200),
                request.trace_id.clone(),
            )));
        }

        // Each entry carries its own index, so do not rely on the array order of the response.
        // The sort is stable: entries that share an index keep their relative order.
        parsed.data.sort_by_key(|item| item.index);

        let vectors = parsed
            .data
            .into_iter()
            .map(|item| EmbeddingVector {
                index: item.index,
                values: item.embedding,
            })
            .collect();

        let usage = parsed
            .usage
            .and_then(|usage| usage_from_counts(usage.prompt_tokens, usage.total_tokens));

        Ok(EmbeddingResponse {
            provider: self.kind(),
            model: parsed.model.unwrap_or_else(|| request.model.clone()),
            vectors,
            trace_id: request.trace_id.clone(),
            usage,
        })
    }

    /// Performs a single embedding attempt without any retry handling.
    ///
    /// Errors from the HTTP executor are converted by `map_transport_error` under the operation
    /// name `openai_embed`. Responses with a status outside `200..=299` are converted by
    /// `map_provider_http_error`; all other bodies are handed to `parse_success_response`.
    async fn embed_once(&self, request: &EmbeddingRequest) -> EmbeddingResult<EmbeddingResponse> {
        let outgoing = self.build_http_request(request)?;
        // The request is moved into the executor, so keep its diagnostics for error messages.
        let diagnostics = outgoing.diagnostics.clone();

        let response = match self.http.execute(outgoing).await {
            Ok(response) => response,
            Err(transport_error) => {
                return Err(self.map_transport_error(
                    "openai_embed",
                    request.trace_id.clone(),
                    transport_error,
                    &diagnostics,
                ));
            }
        };

        let status = response.status_code;
        if !(200..=299).contains(&status) {
            return Err(self.map_provider_http_error(
                status,
                &response.body,
                request.trace_id.clone(),
                &diagnostics,
            ));
        }

        self.parse_success_response(&response.body, request)
    }

    /// Runs `embed_once` until it succeeds, fails with a non-retryable error, or the retry
    /// budget of the configured retry policy is used up.
    ///
    /// Before each retry the sleeper is awaited with the backoff the policy returns for the
    /// number of retries performed so far. The error of the last attempt is returned unchanged.
    async fn embed_with_retry(
        &self,
        request: &EmbeddingRequest,
    ) -> EmbeddingResult<EmbeddingResponse> {
        let mut attempts_retried = 0usize;

        loop {
            let failure = match self.embed_once(request).await {
                Ok(response) => return Ok(response),
                Err(failure) => failure,
            };

            let may_retry = failure.is_retryable()
                && attempts_retried < self.config.retry_policy.max_retries;
            if !may_retry {
                return Err(failure);
            }

            // The backoff is computed from the count before it is incremented.
            let delay = self
                .config
                .retry_policy
                .backoff_for_retry(attempts_retried);
            attempts_retried += 1;
            self.sleeper.sleep(delay).await;
        }
    }
}

#[async_trait]
impl EmbeddingProvider for OpenAiEmbeddingProvider {
    /// Identifies this provider as [`EmbeddingProviderKind::OpenAi`].
    fn kind(&self) -> EmbeddingProviderKind {
        EmbeddingProviderKind::OpenAi
    }

    /// Validates `request` and the API key, then embeds with retries.
    ///
    /// # Errors
    ///
    /// Returns whatever `request.validate()` reports first; then a validation failure on the
    /// `api_key` field when the key is empty or whitespace only; otherwise the outcome of
    /// `embed_with_retry`.
    async fn embed(&self, request: EmbeddingRequest) -> EmbeddingResult<EmbeddingResponse> {
        request.validate()?;

        let key_is_blank = self.api_key.trim().is_empty();
        if key_is_blank {
            let failure = ValidationFailure::new("api_key", "api_key must not be empty");
            return Err(EmbeddingError::Validation(failure));
        }

        self.embed_with_retry(&request).await
    }
}