use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use tracing::{debug, error};
use crate::embedding::EmbeddingProvider;
use crate::error::{RagError, Result};
/// Dimension count a freshly constructed provider reports alongside [`DEFAULT_MODEL`].
const DEFAULT_DIMENSIONS: usize = 1536;
/// Model requested when the caller does not choose one explicitly.
const DEFAULT_MODEL: &str = "text-embedding-3-small";
/// Endpoint every embeddings request is POSTed to.
const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings";
/// [`EmbeddingProvider`] that calls the OpenAI embeddings HTTP endpoint.
pub struct OpenAIEmbeddingProvider {
    /// HTTP client used for every embeddings request.
    client: reqwest::Client,
    /// Secret sent as the bearer token; redacted by the `Debug` impl below.
    api_key: String,
    /// Model name sent with each request.
    model: String,
    /// Vector length reported through `dimensions()`.
    dimensions: usize,
    /// When `Some`, sent as the request's `dimensions` field; when `None` the field is omitted.
    request_dimensions: Option<usize>,
}

// Written by hand rather than derived so the API key can never end up in logs or panic
// messages via `{:?}`.
impl std::fmt::Debug for OpenAIEmbeddingProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAIEmbeddingProvider")
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("dimensions", &self.dimensions)
            .field("request_dimensions", &self.request_dimensions)
            .finish_non_exhaustive()
    }
}
impl OpenAIEmbeddingProvider {
    /// Creates a provider for `text-embedding-3-small`, reporting 1536 dimensions.
    ///
    /// # Errors
    ///
    /// Returns [`RagError::EmbeddingError`] if `api_key` is empty or whitespace-only.
    pub fn new(api_key: impl Into<String>) -> Result<Self> {
        let api_key = api_key.into();
        // A whitespace-only key can never authenticate; reject it here instead of letting it
        // surface later as an opaque HTTP error.
        if api_key.trim().is_empty() {
            return Err(RagError::EmbeddingError {
                provider: "OpenAI".into(),
                message: "API key must not be empty".into(),
            });
        }
        Ok(Self {
            client: reqwest::Client::new(),
            api_key,
            model: DEFAULT_MODEL.into(),
            dimensions: DEFAULT_DIMENSIONS,
            request_dimensions: None,
        })
    }

    /// Creates a provider using the key stored in the `OPENAI_API_KEY` environment variable.
    ///
    /// # Errors
    ///
    /// Returns [`RagError::EmbeddingError`] if the variable is unset or not valid Unicode.
    /// Also returns it if the key is rejected by [`new`](Self::new).
    pub fn from_env() -> Result<Self> {
        let api_key = std::env::var("OPENAI_API_KEY").map_err(|e| {
            let message = match e {
                std::env::VarError::NotPresent => {
                    String::from("OPENAI_API_KEY environment variable not set")
                }
                std::env::VarError::NotUnicode(_) => {
                    String::from("OPENAI_API_KEY environment variable is not valid Unicode")
                }
            };
            RagError::EmbeddingError {
                provider: "OpenAI".into(),
                message,
            }
        })?;
        Self::new(api_key)
    }

    /// Selects the embedding model to request.
    ///
    /// For `text-embedding-3-small`, `text-embedding-3-large` and `text-embedding-ada-002`,
    /// the reported [`EmbeddingProvider::dimensions`] is updated to that model's native size.
    /// This does not happen when an explicit size was set with
    /// [`with_dimensions`](Self::with_dimensions). For any other model name the current
    /// dimension count is kept, so call `with_dimensions` to make it accurate.
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        // An explicit `with_dimensions` always wins, regardless of builder call order.
        if self.request_dimensions.is_none() {
            if let Some(dims) = Self::native_dimensions(&self.model) {
                self.dimensions = dims;
            }
        }
        self
    }

    /// Sets the output vector length.
    ///
    /// The value is both reported by `dimensions()` and sent as the `dimensions` field of
    /// every request.
    pub fn with_dimensions(mut self, dims: usize) -> Self {
        self.dimensions = dims;
        self.request_dimensions = Some(dims);
        self
    }

    /// Native output size of well-known OpenAI embedding models, or `None` if unknown.
    fn native_dimensions(model: &str) -> Option<usize> {
        match model {
            "text-embedding-3-small" | "text-embedding-ada-002" => Some(1536),
            "text-embedding-3-large" => Some(3072),
            _ => None,
        }
    }
}
/// JSON body POSTed to the embeddings endpoint.
#[derive(Serialize)]
struct EmbeddingRequest<'req> {
    /// Model name to embed with.
    model: &'req str,
    /// Texts to embed in this request.
    input: Vec<&'req str>,
    /// Requested output size; omitted from the JSON entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    dimensions: Option<usize>,
}

/// One element of a successful response's `data` array; only `embedding` is deserialized.
#[derive(Deserialize)]
struct EmbeddingData {
    embedding: Vec<f32>,
}

/// Successful response body; only the `data` array is deserialized.
#[derive(Deserialize)]
struct EmbeddingResponse {
    data: Vec<EmbeddingData>,
}

/// Inner `error` object of an error body; only `message` is deserialized.
#[derive(Deserialize)]
struct ErrorDetail {
    message: String,
}

/// Error body of shape `{"error": {"message": ...}}`, used to extract a readable message.
#[derive(Deserialize)]
struct ErrorResponse {
    error: ErrorDetail,
}
#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddingProvider {
async fn embed(&self, text: &str) -> Result<Vec<f32>> {
debug!(provider = "OpenAI", text_len = text.len(), "embedding single text");
let results = self.embed_batch(&[text]).await?;
results.into_iter().next().ok_or_else(|| RagError::EmbeddingError {
provider: "OpenAI".into(),
message: "API returned empty response".into(),
})
}
async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
if texts.is_empty() {
return Ok(Vec::new());
}
debug!(
provider = "OpenAI",
batch_size = texts.len(),
model = %self.model,
"embedding batch"
);
let request_body = EmbeddingRequest {
model: &self.model,
input: texts.to_vec(),
dimensions: self.request_dimensions,
};
let response = self
.client
.post(OPENAI_EMBEDDINGS_URL)
.bearer_auth(&self.api_key)
.json(&request_body)
.send()
.await
.map_err(|e| {
error!(provider = "OpenAI", error = %e, "request failed");
RagError::EmbeddingError {
provider: "OpenAI".into(),
message: format!("request failed: {e}"),
}
})?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
let detail = serde_json::from_str::<ErrorResponse>(&body)
.map(|e| e.error.message)
.unwrap_or(body);
error!(provider = "OpenAI", %status, "API error");
return Err(RagError::EmbeddingError {
provider: "OpenAI".into(),
message: format!("API returned {status}: {detail}"),
});
}
let embedding_response: EmbeddingResponse = response.json().await.map_err(|e| {
error!(provider = "OpenAI", error = %e, "failed to parse response");
RagError::EmbeddingError {
provider: "OpenAI".into(),
message: format!("failed to parse response: {e}"),
}
})?;
Ok(embedding_response.data.into_iter().map(|d| d.embedding).collect())
}
fn dimensions(&self) -> usize {
self.dimensions
}
}