behest 0.2.1

A Rust-native cloud agent runtime with typed tools, pluggable memory, queues, and observability.
//! OpenAI embedding provider adapter implementing [`EmbeddingProvider`].

use async_trait::async_trait;
use reqwest::Client;

use crate::adapt::http::{build_client, parse_retry_after, status_to_error, with_bearer_auth};
use crate::error::ProviderError;
use crate::provider::{
    Embedding, EmbeddingInput, EmbeddingProvider, EmbeddingRequest, EmbeddingResponse, ModelName,
    ProviderCapabilities, ProviderHttpConfig, ProviderId, ProviderResult, TokenUsage,
};

use super::types::{OpenAiEmbeddingRequest, OpenAiEmbeddingResponse};

/// OpenAI-compatible embedding adapter.
///
/// Sends embedding requests as JSON `POST`s to the `/embeddings` endpoint
/// under the configured base URL.
pub struct OpenAiEmbeddingAdapter {
    /// Provider identifier, cloned from `config.id` when the adapter is built;
    /// attached to every error this adapter produces.
    id: ProviderId,
    /// HTTP client used to send embedding requests.
    client: Client,
    /// HTTP configuration; supplies the base URL and is handed to the
    /// bearer-auth helper when a request is built.
    config: ProviderHttpConfig,
}

impl OpenAiEmbeddingAdapter {
    /// Creates an OpenAI embedding adapter with a new HTTP client.
    ///
    /// The client is built from `config` via `build_client`; the adapter's
    /// provider id is taken from `config.id`.
    ///
    /// # Errors
    ///
    /// Returns [`ProviderError::Transport`] when the HTTP client cannot be built.
    pub fn new(config: ProviderHttpConfig) -> Result<Self, ProviderError> {
        let client = build_client(&config)?;
        Ok(Self {
            id: config.id.clone(),
            client,
            config,
        })
    }

    /// Creates an OpenAI embedding adapter reusing an existing HTTP client.
    ///
    /// No client is built here, so this constructor cannot fail. The
    /// adapter's provider id is taken from `config.id`.
    #[must_use]
    pub fn with_client(config: ProviderHttpConfig, client: Client) -> Self {
        Self {
            id: config.id.clone(),
            client,
            config,
        }
    }

    /// Builds the embeddings endpoint URL from the configured base URL.
    ///
    /// Trailing slashes on the base URL are stripped before the path is
    /// appended, so a base of `https://host/v1/` yields
    /// `https://host/v1/embeddings` rather than `https://host/v1//embeddings`.
    fn url(&self) -> String {
        let base = self.config.base_url.to_string();
        format!("{}/embeddings", base.trim_end_matches('/'))
    }

    /// Converts a `reqwest` failure into a [`ProviderError`] tagged with this
    /// adapter's provider id.
    ///
    /// Timeouts are reported as [`ProviderError::Timeout`]; every other
    /// failure is wrapped in [`ProviderError::Transport`] with the original
    /// error kept as its source.
    fn wrap_transport(&self, source: reqwest::Error) -> ProviderError {
        if source.is_timeout() {
            ProviderError::Timeout {
                provider: self.id.clone(),
            }
        } else {
            ProviderError::Transport {
                provider: self.id.clone(),
                source,
            }
        }
    }
}

#[async_trait]
impl EmbeddingProvider for OpenAiEmbeddingAdapter {
    /// Returns a clone of the provider id this adapter was constructed with.
    fn id(&self) -> ProviderId {
        self.id.clone()
    }

    /// Reports this adapter's capabilities as the embeddings capability set.
    fn capabilities(&self) -> ProviderCapabilities {
        ProviderCapabilities::embeddings()
    }

    /// Sends `request` to the embeddings endpoint and converts the reply.
    ///
    /// # Errors
    ///
    /// - A transport or timeout error (see `wrap_transport`) when the request
    ///   cannot be sent.
    /// - The error produced by `status_to_error` when the server answers with
    ///   a non-success status.
    /// - [`ProviderError::Decode`] when the success body cannot be parsed as
    ///   JSON of the expected shape.
    async fn embed(&self, request: EmbeddingRequest) -> ProviderResult<EmbeddingResponse> {
        let payload = OpenAiEmbeddingRequest {
            model: request.model.as_str().to_owned(),
            input: extract_texts(&request.input),
            dimensions: request.dimensions,
        };

        let response = with_bearer_auth(self.client.post(self.url()).json(&payload), &self.config)
            .send()
            .await
            .map_err(|err| self.wrap_transport(err))?;

        let status = response.status();
        if !status.is_success() {
            // Headers must be inspected before `text()` consumes the response.
            let retry_after = parse_retry_after(response.headers());
            let error_body = match response.text().await {
                Ok(text) => text,
                Err(e) => format!("<failed to read error body: {e}>"),
            };
            return Err(status_to_error(&self.id, status, &error_body, retry_after));
        }

        let decoded = response
            .json::<OpenAiEmbeddingResponse>()
            .await
            .map_err(|err| ProviderError::Decode {
                provider: self.id.clone(),
                message: err.to_string(),
            })?;

        Ok(from_response(&self.id, &request.model, &decoded))
    }
}

/// Flattens embedding inputs into the list of strings sent on the wire.
///
/// Text inputs are copied as-is. Token inputs are rendered by converting
/// each token to its string form and joining them with a single space.
///
/// NOTE(review): token inputs are therefore sent as text rather than as a
/// token array — confirm this is what the target endpoint expects.
fn extract_texts(inputs: &[EmbeddingInput]) -> Vec<String> {
    let mut texts = Vec::with_capacity(inputs.len());
    for input in inputs {
        let rendered = match input {
            EmbeddingInput::Text { text } => text.clone(),
            EmbeddingInput::Tokens { tokens } => {
                let ids: Vec<String> = tokens.iter().map(|id| id.to_string()).collect();
                ids.join(" ")
            }
        };
        texts.push(rendered);
    }
    texts
}

/// Converts a decoded OpenAI embedding payload into an [`EmbeddingResponse`].
///
/// `provider` and `model` are cloned into the result. Each entry of
/// `response.data` becomes one [`Embedding`] carrying the entry's index and a
/// copy of its vector. `raw` is always left as `None`.
fn from_response(
    provider: &ProviderId,
    model: &ModelName,
    response: &OpenAiEmbeddingResponse,
) -> EmbeddingResponse {
    EmbeddingResponse {
        provider: provider.clone(),
        model: model.clone(),
        embeddings: response
            .data
            .iter()
            .map(|item| Embedding::new(item.index, item.embedding.clone()))
            .collect(),
        // Only prompt tokens are reported for embeddings, so the output-token
        // count is fixed at 0. Usage stays `None` when the payload omits it.
        usage: response
            .usage
            .as_ref()
            .map(|reported| TokenUsage::new(reported.prompt_tokens, 0)),
        raw: None,
    }
}