use anyhow::{Context, Result};
use serde_json::{json, Value};
use super::super::types::InputType;
use super::{EmbeddingProvider, HTTP_CLIENT};
/// [`EmbeddingProvider`] implementation bound to a single, validated OpenAI
/// embedding model.
///
/// Construct it with [`OpenAIProviderImpl::new`], which rejects unsupported models.
#[derive(Debug, Clone)]
pub struct OpenAIProviderImpl {
    /// Model identifier forwarded as `model` to the OpenAI embeddings endpoint.
    model_name: String,
    /// Length of the vectors produced by `model_name`, resolved at construction.
    dimension: usize,
}
impl OpenAIProviderImpl {
    /// Supported OpenAI embedding models paired with their output dimension.
    ///
    /// This table is the single source of truth for both validation in
    /// [`Self::new`] and dimension lookup, so the two cannot drift apart.
    const SUPPORTED_MODELS: [(&'static str, usize); 3] = [
        ("text-embedding-3-small", 1536),
        ("text-embedding-3-large", 3072),
        ("text-embedding-ada-002", 1536),
    ];

    /// Creates a provider for the OpenAI embedding model named `model`.
    ///
    /// # Errors
    ///
    /// Returns an error listing the supported models if `model` is not one of
    /// [`Self::SUPPORTED_MODELS`].
    pub fn new(model: &str) -> Result<Self> {
        let dimension = Self::get_model_dimension(model).ok_or_else(|| {
            let supported: Vec<&str> = Self::SUPPORTED_MODELS
                .iter()
                .map(|(name, _)| *name)
                .collect();
            anyhow::anyhow!(
                "Unsupported OpenAI model: '{}'. Supported models: {:?}",
                model,
                supported
            )
        })?;
        Ok(Self {
            model_name: model.to_string(),
            dimension,
        })
    }

    /// Looks up the embedding dimension for `model`.
    ///
    /// Returns `None` for unsupported models instead of panicking, so callers
    /// decide how to report the failure.
    fn get_model_dimension(model: &str) -> Option<usize> {
        Self::SUPPORTED_MODELS
            .iter()
            .find(|(name, _)| *name == model)
            .map(|&(_, dimension)| dimension)
    }
}
#[async_trait::async_trait]
impl EmbeddingProvider for OpenAIProviderImpl {
    /// Embeds one `text` with the configured model.
    ///
    /// Delegates to [`OpenAIProvider::generate_embeddings`], which sends the
    /// text with `InputType::None`.
    async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
        let model = self.model_name.as_str();
        OpenAIProvider::generate_embeddings(text, model).await
    }

    /// Embeds every entry of `texts` with the configured model, forwarding
    /// `input_type` unchanged to [`OpenAIProvider::generate_embeddings_batch`].
    async fn generate_embeddings_batch(
        &self,
        texts: Vec<String>,
        input_type: InputType,
    ) -> Result<Vec<Vec<f32>>> {
        let model = self.model_name.as_str();
        OpenAIProvider::generate_embeddings_batch(texts, model, input_type).await
    }

    /// Vector length recorded for the configured model at construction time.
    fn get_dimension(&self) -> usize {
        let Self { dimension, .. } = self;
        *dimension
    }

    /// Reports whether the stored model name is one of the known OpenAI
    /// embedding models.
    fn is_model_supported(&self) -> bool {
        const KNOWN_MODELS: [&str; 3] = [
            "text-embedding-3-small",
            "text-embedding-3-large",
            "text-embedding-ada-002",
        ];
        KNOWN_MODELS.contains(&self.model_name.as_str())
    }
}
/// Stateless entry points for calling the OpenAI embeddings endpoint directly,
/// without first constructing an [`OpenAIProviderImpl`].
#[derive(Debug, Clone, Copy)]
pub struct OpenAIProvider;
impl OpenAIProvider {
    /// OpenAI embeddings endpoint.
    const EMBEDDINGS_URL: &'static str = "https://api.openai.com/v1/embeddings";

    /// Embeds a single string with `model`, using `InputType::None`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::generate_embeddings_batch`]. Fails with
    /// "No embeddings found" if the batch call returns no vectors.
    pub async fn generate_embeddings(contents: &str, model: &str) -> Result<Vec<f32>> {
        Self::generate_embeddings_batch(vec![contents.to_string()], model, InputType::None)
            .await?
            .into_iter()
            .next()
            .ok_or_else(|| anyhow::anyhow!("No embeddings found"))
    }

    /// Embeds every entry of `texts` with `model` in one request.
    ///
    /// Each text is first passed through `input_type.apply_prefix`. The
    /// returned vectors line up one-to-one with `texts`. An empty `texts`
    /// returns an empty result without touching the network.
    ///
    /// # Errors
    ///
    /// - `OPENAI_API_KEY` is not set.
    /// - The HTTP request fails, or the API answers with a non-success status
    ///   (the status and response body are included in the error).
    /// - The response is not JSON, or does not contain exactly one well-formed
    ///   numeric embedding per input.
    pub async fn generate_embeddings_batch(
        texts: Vec<String>,
        model: &str,
        input_type: InputType,
    ) -> Result<Vec<Vec<f32>>> {
        // Nothing to embed: skip the request (the API documents that `input`
        // must not be empty) and do not require an API key.
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let openai_api_key = std::env::var("OPENAI_API_KEY")
            .context("OPENAI_API_KEY environment variable not set")?;

        let expected = texts.len();
        let processed_texts: Vec<String> = texts
            .into_iter()
            .map(|text| input_type.apply_prefix(&text))
            .collect();

        let request_body = json!({
            "input": processed_texts,
            "model": model,
            "encoding_format": "float"
        });

        let response = HTTP_CLIENT
            .post(Self::EMBEDDINGS_URL)
            .header("Authorization", format!("Bearer {}", openai_api_key))
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await
            .context("Failed to send request to OpenAI embeddings API")?;

        let status = response.status();
        if !status.is_success() {
            // Still report the status even if the body itself cannot be read.
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|err| format!("<failed to read error body: {}>", err));
            return Err(anyhow::anyhow!(
                "OpenAI API error ({}): {}",
                status,
                error_text
            ));
        }

        let response_json: Value = response
            .json()
            .await
            .context("Failed to parse OpenAI embeddings response as JSON")?;

        Self::parse_embeddings(&response_json, expected)
    }

    /// Extracts `expected` embeddings from an embeddings API response body.
    ///
    /// Results are ordered by each item's `index` field (falling back to the
    /// item's position when absent) so they match the order of the inputs.
    /// Malformed entries are reported as errors rather than silently replaced
    /// with empty vectors or zeros.
    fn parse_embeddings(response_json: &Value, expected: usize) -> Result<Vec<Vec<f32>>> {
        let data = response_json["data"]
            .as_array()
            .context("Failed to get embeddings array")?;

        let mut indexed = data
            .iter()
            .enumerate()
            .map(|(position, item)| -> Result<(u64, Vec<f32>)> {
                let index = item["index"].as_u64().unwrap_or(position as u64);
                let embedding = item["embedding"]
                    .as_array()
                    .with_context(|| {
                        format!("Embedding {} is missing its `embedding` array", position)
                    })?
                    .iter()
                    .map(|value| {
                        value
                            .as_f64()
                            // Narrowing to f32 is intentional: callers work with f32 vectors.
                            .map(|component| component as f32)
                            .with_context(|| {
                                format!("Embedding {} contains a non-numeric value", position)
                            })
                    })
                    .collect::<Result<Vec<f32>>>()?;
                Ok((index, embedding))
            })
            .collect::<Result<Vec<(u64, Vec<f32>)>>>()?;

        // Callers pair results with their inputs positionally, so a short or
        // long response must be an error rather than a silent misalignment.
        if indexed.len() != expected {
            return Err(anyhow::anyhow!(
                "OpenAI returned {} embeddings for {} inputs",
                indexed.len(),
                expected
            ));
        }

        indexed.sort_by_key(|(index, _)| *index);
        Ok(indexed.into_iter().map(|(_, embedding)| embedding).collect())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every model the provider must accept, paired with its expected dimension.
    const KNOWN_MODELS: [(&str, usize); 3] = [
        ("text-embedding-3-small", 1536),
        ("text-embedding-3-large", 3072),
        ("text-embedding-ada-002", 1536),
    ];

    #[test]
    fn test_openai_provider_creation() {
        for (model, _) in KNOWN_MODELS {
            assert!(OpenAIProviderImpl::new(model).is_ok());
        }
        assert!(OpenAIProviderImpl::new("invalid-model").is_err());
    }

    #[test]
    fn test_model_dimensions() {
        for (model, expected_dimension) in KNOWN_MODELS {
            let provider = OpenAIProviderImpl::new(model).unwrap();
            assert_eq!(provider.get_dimension(), expected_dimension);
        }
    }

    #[test]
    fn test_model_validation() {
        let every_model_supported = KNOWN_MODELS
            .iter()
            .map(|(model, _)| OpenAIProviderImpl::new(model).unwrap())
            .all(|provider| provider.is_model_supported());
        assert!(every_model_supported);
    }
}