vyctor 0.1.0

A fast CLI tool for semantic file search using vector embeddings
Documentation
//! OpenAI embedding provider implementation
//! Also works with Voyage AI and other OpenAI-compatible APIs

use super::provider::{EmbeddingProvider, EmbeddingResult};
use anyhow::{Context, Result};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};

/// OpenAI embedding provider.
///
/// Sends requests to `{base_url}/embeddings` with a bearer token, so any
/// OpenAI-compatible endpoint can be targeted by changing the base URL.
///
/// `Debug` is implemented by hand (rather than derived) so the API key is
/// never printed in logs, panic messages, or error reports.
pub struct OpenAIEmbedder {
    /// Secret sent as a bearer token; redacted from `Debug` output.
    api_key: String,
    /// Model identifier sent with every request.
    model: String,
    /// API root; `new` strips any trailing `/` before storing it.
    base_url: String,
    /// Requested embedding vector length.
    dimensions: usize,
    /// HTTP client reused for every request made by this embedder.
    client: reqwest::Client,
}

impl std::fmt::Debug for OpenAIEmbedder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAIEmbedder")
            // Never expose the credential, even in debug output.
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("base_url", &self.base_url)
            .field("dimensions", &self.dimensions)
            .field("client", &self.client)
            .finish()
    }
}

impl OpenAIEmbedder {
    /// Build an embedder for an OpenAI-compatible endpoint.
    ///
    /// * `api_key` – token sent in the `Authorization: Bearer …` header.
    /// * `model` – model identifier, e.g. `text-embedding-3-small`.
    /// * `base_url` – API root such as `https://api.openai.com/v1`. All
    ///   trailing `/` characters are removed so `/embeddings` can be appended
    ///   directly.
    /// * `dimensions` – requested output vector length.
    pub fn new(api_key: &str, model: &str, base_url: &str, dimensions: usize) -> Self {
        let normalized_url = base_url.trim_end_matches('/');

        Self {
            api_key: api_key.to_owned(),
            model: model.to_owned(),
            base_url: normalized_url.to_owned(),
            dimensions,
            client: reqwest::Client::new(),
        }
    }
}

#[async_trait]
impl EmbeddingProvider for OpenAIEmbedder {
    /// Vector length this embedder was configured with (also sent as the
    /// `dimensions` request field).
    fn dimensions(&self) -> usize {
        self.dimensions
    }

    /// Model identifier sent with each request.
    fn model_name(&self) -> &str {
        &self.model
    }

    /// Embed a single text by delegating to `embed_batch` with a one-element batch.
    ///
    /// # Errors
    /// Propagates any error from `embed_batch`, and fails if the batch result
    /// is empty.
    async fn embed(&self, text: &str) -> Result<EmbeddingResult> {
        let results = self.embed_batch(&[text.to_string()]).await?;
        results
            .into_iter()
            .next()
            .context("Empty response from embedding API")
    }

    /// Embed a batch of texts with a single `POST {base_url}/embeddings` call.
    ///
    /// Results are returned in the same order as `texts`. Each result's
    /// `token_count` is the response's `usage.total_tokens` divided evenly
    /// (integer division) across the batch, or `None` when the response has
    /// no `usage` field. An empty `texts` slice returns `Ok(vec![])` without
    /// making a request.
    ///
    /// # Errors
    /// Fails if the request cannot be sent, the API returns a non-success
    /// status (the response body is included in the message), the body cannot
    /// be parsed, or the returned embeddings do not map one-to-one onto the
    /// inputs by `index`.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<EmbeddingResult>> {
        if texts.is_empty() {
            return Ok(vec![]);
        }

        let request = EmbeddingRequest {
            model: &self.model,
            input: texts,
            dimensions: Some(self.dimensions),
        };

        let url = format!("{}/embeddings", self.base_url);

        // `bearer_auth` builds the same `Authorization: Bearer …` header and
        // marks it sensitive; `.json()` already sets `Content-Type`.
        let response = self
            .client
            .post(&url)
            .bearer_auth(&self.api_key)
            .json(&request)
            .send()
            .await
            .context("Failed to send embedding request")?;

        let status = response.status();
        if !status.is_success() {
            let error_text = response.text().await.unwrap_or_default();
            anyhow::bail!("Embedding API error ({}): {}", status, error_text);
        }

        let response: EmbeddingResponse = response
            .json()
            .await
            .context("Failed to parse embedding response")?;

        // Usage is reported for the whole request; spread it evenly over the
        // inputs. `texts` is non-empty here, so the division cannot panic.
        let per_text_tokens = response
            .usage
            .as_ref()
            .map(|u| u.total_tokens / texts.len());

        let mut data = response.data;

        // Callers pair results with inputs positionally, so a short or padded
        // response must be rejected rather than silently misaligned.
        if data.len() != texts.len() {
            anyhow::bail!(
                "Embedding API returned {} embeddings for {} inputs",
                data.len(),
                texts.len()
            );
        }

        // The API may return items out of order; restore input order via `index`,
        // then verify the indices are exactly 0..n (no duplicates or gaps).
        data.sort_unstable_by_key(|d| d.index);
        if data.iter().enumerate().any(|(pos, d)| d.index != pos) {
            anyhow::bail!(
                "Embedding API returned inconsistent indices for {} inputs",
                texts.len()
            );
        }

        let results = data
            .into_iter()
            .map(|d| EmbeddingResult {
                embedding: d.embedding,
                token_count: per_text_tokens,
            })
            .collect();

        Ok(results)
    }
}

/// JSON body for `POST {base_url}/embeddings`.
#[derive(Debug, Serialize)]
struct EmbeddingRequest<'req> {
    /// Model identifier.
    model: &'req str,
    /// Texts to embed, borrowed from the caller.
    input: &'req [String],
    /// Requested vector length; the key is left out of the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    dimensions: Option<usize>,
}

/// The parts of the `/embeddings` response body that this provider reads.
#[derive(Debug, Deserialize)]
struct EmbeddingResponse {
    /// Returned embeddings; their order is restored via each item's `index`.
    data: Vec<EmbeddingData>,
    /// Token accounting, if the server reports it.
    usage: Option<Usage>,
}

/// One embedding vector plus the position of the input it belongs to.
#[derive(Debug, Deserialize)]
struct EmbeddingData {
    /// The embedding values.
    embedding: Vec<f32>,
    /// Position of the corresponding input; used to restore request order.
    index: usize,
}

/// Token usage reported by the server for a request.
#[derive(Debug, Deserialize)]
struct Usage {
    /// Total tokens consumed by the request (split across inputs by the caller).
    total_tokens: usize,
}

#[cfg(test)]
mod tests {
    use super::*;

    /// The constructor stores the configured model and dimensions.
    #[test]
    fn test_embedder_creation() {
        let embedder = OpenAIEmbedder::new(
            "test-key",
            "text-embedding-3-small",
            "https://api.openai.com/v1",
            1536,
        );
        assert_eq!(embedder.dimensions(), 1536);
        assert_eq!(embedder.model_name(), "text-embedding-3-small");
    }

    /// Trailing slashes on the base URL are stripped, so appending
    /// `/embeddings` never produces a `//` in the request path.
    #[test]
    fn test_base_url_trailing_slashes_trimmed() {
        let embedder = OpenAIEmbedder::new(
            "test-key",
            "text-embedding-3-small",
            "https://api.openai.com/v1//",
            1536,
        );
        assert_eq!(embedder.base_url, "https://api.openai.com/v1");
    }
}