use super::provider::{EmbeddingProvider, EmbeddingResult};
use anyhow::{Context, Result};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
/// Embedding provider that talks to an OpenAI-compatible `/embeddings` HTTP endpoint.
///
/// `Debug` is implemented by hand (not derived) so that the API key is never
/// written to logs or panic messages.
pub struct OpenAIEmbedder {
    /// Secret sent as a bearer token; redacted in `Debug` output.
    api_key: String,
    /// Model identifier included in every request body.
    model: String,
    /// Endpoint base; `/embeddings` is appended to it for each request.
    base_url: String,
    /// Requested output dimensionality, forwarded as the `dimensions` request field.
    dimensions: usize,
    /// HTTP client used to send requests.
    client: reqwest::Client,
}

impl std::fmt::Debug for OpenAIEmbedder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAIEmbedder")
            // Never print the real credential.
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("base_url", &self.base_url)
            .field("dimensions", &self.dimensions)
            .field("client", &self.client)
            .finish()
    }
}
impl OpenAIEmbedder {
    /// Creates an embedder for `model` served at `base_url` (e.g. `https://api.openai.com/v1`).
    ///
    /// All trailing `/` characters are stripped from `base_url`, so appending
    /// `/embeddings` later never yields a double slash. A new `reqwest::Client`
    /// is constructed for this instance.
    pub fn new(api_key: &str, model: &str, base_url: &str, dimensions: usize) -> Self {
        let normalized_base = base_url.trim_end_matches('/');
        let http = reqwest::Client::new();
        Self {
            api_key: api_key.to_owned(),
            model: model.to_owned(),
            base_url: normalized_base.to_owned(),
            dimensions,
            client: http,
        }
    }
}
#[async_trait]
impl EmbeddingProvider for OpenAIEmbedder {
    /// Returns the dimensionality configured at construction time.
    fn dimensions(&self) -> usize {
        self.dimensions
    }

    /// Returns the model identifier configured at construction time.
    fn model_name(&self) -> &str {
        &self.model
    }

    /// Embeds a single text by sending a one-element batch through
    /// [`embed_batch`](Self::embed_batch).
    ///
    /// # Errors
    /// Any error from `embed_batch`, or an error if it yields no result.
    async fn embed(&self, text: &str) -> Result<EmbeddingResult> {
        let results = self.embed_batch(&[text.to_string()]).await?;
        results
            .into_iter()
            .next()
            .context("Empty response from embedding API")
    }

    /// Embeds all of `texts` in a single request and returns one result per
    /// input, in input order. An empty slice returns `Ok(vec![])` without any
    /// network traffic.
    ///
    /// `token_count` on every result is the batch's `usage.total_tokens`
    /// divided evenly (integer division) across the inputs, or `None` when the
    /// response carries no `usage`.
    ///
    /// # Errors
    /// - the request could not be sent;
    /// - the server answered with a non-success status (body included in the message);
    /// - the body could not be parsed as an embeddings response;
    /// - the response does not map one-to-one onto `texts` (wrong number of
    ///   entries, or `index` values that are not exactly `0..texts.len()`).
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<EmbeddingResult>> {
        if texts.is_empty() {
            return Ok(vec![]);
        }
        let request = EmbeddingRequest {
            model: &self.model,
            input: texts,
            dimensions: Some(self.dimensions),
        };
        let url = format!("{}/embeddings", self.base_url);
        // `bearer_auth` builds the same `Authorization: Bearer …` header and marks it
        // sensitive; `.json()` sets `Content-Type: application/json` itself.
        let response = self
            .client
            .post(&url)
            .bearer_auth(&self.api_key)
            .json(&request)
            .send()
            .await
            .context("Failed to send embedding request")?;
        let status = response.status();
        if !status.is_success() {
            // Best effort: an unreadable body must not mask the status error.
            let error_text = response.text().await.unwrap_or_default();
            anyhow::bail!("Embedding API error ({}): {}", status, error_text);
        }
        let response: EmbeddingResponse = response
            .json()
            .await
            .context("Failed to parse embedding response")?;
        let mut data = response.data;
        if data.len() != texts.len() {
            anyhow::bail!(
                "Embedding API returned {} embeddings for {} inputs",
                data.len(),
                texts.len()
            );
        }
        // Restore input order, then require indices to be exactly 0..n. Duplicates,
        // gaps, or out-of-range values would otherwise silently pair embeddings
        // with the wrong inputs.
        data.sort_unstable_by_key(|d| d.index);
        if let Some((position, item)) = data
            .iter()
            .enumerate()
            .find(|(position, item)| item.index != *position)
        {
            anyhow::bail!(
                "Embedding API response has index {} where {} was expected",
                item.index,
                position
            );
        }
        // The API reports a single token total for the whole batch. This is
        // computed once and shared as a per-item estimate.
        let token_count = response
            .usage
            .as_ref()
            .map(|u| u.total_tokens / texts.len());
        let results = data
            .into_iter()
            .map(|d| EmbeddingResult {
                embedding: d.embedding,
                token_count,
            })
            .collect();
        Ok(results)
    }
}
/// JSON body POSTed to `{base_url}/embeddings`.
///
/// The field declaration order below is the key order of the serialized JSON.
#[derive(Serialize, Debug)]
struct EmbeddingRequest<'req> {
    /// Model identifier.
    model: &'req str,
    /// Texts to embed, serialized as a JSON array of strings.
    input: &'req [String],
    /// Requested output dimensionality; the key is omitted entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    dimensions: Option<usize>,
}
/// The part of the `/embeddings` response body that this module reads.
///
/// JSON fields not declared here are ignored by the derived deserializer.
#[derive(Deserialize, Debug)]
struct EmbeddingResponse {
    /// Embedding entries; each carries an `index` used to restore input order.
    data: Vec<EmbeddingData>,
    /// Token accounting for the whole request; `None` when the key is absent.
    usage: Option<Usage>,
}
/// A single embedding entry from the response's `data` array.
#[derive(Deserialize, Debug)]
struct EmbeddingData {
    /// The embedding vector.
    embedding: Vec<f32>,
    /// Position of the corresponding text in the request's `input` array.
    index: usize,
}
/// Token accounting reported by the embeddings endpoint.
#[derive(Deserialize, Debug)]
struct Usage {
    /// Total tokens consumed by the entire request.
    total_tokens: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an embedder with fixed key/model/dimensions and the given base URL.
    fn make_embedder(base_url: &str) -> OpenAIEmbedder {
        OpenAIEmbedder::new("test-key", "text-embedding-3-small", base_url, 1536)
    }

    #[test]
    fn test_embedder_creation() {
        let embedder = make_embedder("https://api.openai.com/v1");
        assert_eq!(embedder.dimensions(), 1536);
        assert_eq!(embedder.model_name(), "text-embedding-3-small");
    }

    /// `new` must strip every trailing slash so `/embeddings` can be appended cleanly.
    #[test]
    fn test_base_url_trailing_slashes_are_trimmed() {
        let expected = "https://api.openai.com/v1";
        assert_eq!(make_embedder("https://api.openai.com/v1").base_url, expected);
        assert_eq!(make_embedder("https://api.openai.com/v1/").base_url, expected);
        assert_eq!(make_embedder("https://api.openai.com/v1///").base_url, expected);
    }
}