use crate::core::config::Config;
use crate::core::error::{Error, Result};
use crate::core::types::{EmbeddingResult, Sector};
use crate::memory::embed::{resize_vector, EmbeddingProvider};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
/// Embedding provider backed by the OpenAI embeddings HTTP API.
pub struct OpenAIProvider {
    // API key sent as a Bearer token; empty string means "not configured".
    api_key: String,
    // Base URL of the API; "/embeddings" is appended (trailing '/' trimmed first).
    base_url: String,
    // Embedding model name, e.g. "text-embedding-3-small".
    model: String,
    // Target vector dimensionality; also sent as the `dimensions` request field.
    dim: usize,
    // Shared HTTP client reused across requests.
    client: reqwest::Client,
}
impl OpenAIProvider {
    /// Builds a provider from the application `Config`.
    ///
    /// Falls back to an empty API key when `openai_key` is unset (the
    /// embed calls report a config error in that case) and defaults the
    /// model to `"text-embedding-3-small"` when `openai_model` is unset.
    pub fn new(config: &Config) -> Self {
        Self {
            api_key: config.openai_key.clone().unwrap_or_default(),
            base_url: config.openai_base_url.clone(),
            model: config
                .openai_model
                .clone()
                .unwrap_or_else(|| "text-embedding-3-small".to_string()),
            dim: config.vec_dim,
            client: reqwest::Client::new(),
        }
    }

    /// Model to use for a given sector. Currently every sector shares the
    /// single configured model; the parameter exists for future routing.
    fn model_for_sector(&self, _sector: &Sector) -> &str {
        &self.model
    }
}
/// JSON request body for POST `{base_url}/embeddings`.
#[derive(Serialize)]
struct EmbeddingRequest {
    // One or more texts to embed in a single request.
    input: Vec<String>,
    // Model identifier, e.g. "text-embedding-3-small".
    model: String,
    // Requested output dimensionality; omitted from the JSON when None.
    #[serde(skip_serializing_if = "Option::is_none")]
    dimensions: Option<usize>,
}
/// Subset of the OpenAI embeddings response we deserialize; other
/// top-level fields (usage, object, ...) are ignored.
#[derive(Deserialize)]
struct EmbeddingResponse {
    // One entry per input text, in request order.
    data: Vec<EmbeddingData>,
}
/// A single embedding entry from the response `data` array.
#[derive(Deserialize)]
struct EmbeddingData {
    // Raw embedding vector as returned by the API.
    embedding: Vec<f32>,
}
#[async_trait]
impl EmbeddingProvider for OpenAIProvider {
async fn embed(&self, text: &str, sector: &Sector) -> Result<EmbeddingResult> {
if self.api_key.is_empty() {
return Err(Error::config("OpenAI API key is not configured"));
}
let url = format!("{}/embeddings", self.base_url.trim_end_matches('/'));
let model = self.model_for_sector(sector);
let request = EmbeddingRequest {
input: vec![text.to_string()],
model: model.to_string(),
dimensions: Some(self.dim),
};
let response = self
.client
.post(&url)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&request)
.send()
.await?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
if status.as_u16() == 429 {
return Err(Error::RateLimit {
retry_after_secs: 5,
});
}
return Err(Error::embedding(format!(
"OpenAI API error {}: {}",
status, body
)));
}
let data: EmbeddingResponse = response.json().await?;
if data.data.is_empty() {
return Err(Error::embedding("No embedding returned from OpenAI"));
}
let vector = resize_vector(&data.data[0].embedding, self.dim);
Ok(EmbeddingResult {
sector: *sector,
vector: vector.clone(),
dim: vector.len(),
})
}
async fn embed_batch(&self, texts: &[(&str, &Sector)]) -> Result<Vec<EmbeddingResult>> {
if self.api_key.is_empty() {
return Err(Error::config("OpenAI API key is not configured"));
}
if texts.is_empty() {
return Ok(Vec::new());
}
let url = format!("{}/embeddings", self.base_url.trim_end_matches('/'));
let input: Vec<String> = texts.iter().map(|(t, _)| t.to_string()).collect();
let sectors: Vec<Sector> = texts.iter().map(|(_, s)| **s).collect();
let request = EmbeddingRequest {
input,
model: self.model.clone(),
dimensions: Some(self.dim),
};
let response = self
.client
.post(&url)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&request)
.send()
.await?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
if status.as_u16() == 429 {
return Err(Error::RateLimit {
retry_after_secs: 5,
});
}
return Err(Error::embedding(format!(
"OpenAI API error {}: {}",
status, body
)));
}
let data: EmbeddingResponse = response.json().await?;
let results: Vec<EmbeddingResult> = data
.data
.into_iter()
.zip(sectors.into_iter())
.map(|(emb, sector)| {
let vector = resize_vector(&emb.embedding, self.dim);
EmbeddingResult {
sector,
vector: vector.clone(),
dim: vector.len(),
}
})
.collect();
Ok(results)
}
fn dimensions(&self) -> usize {
self.dim
}
fn name(&self) -> &'static str {
"openai"
}
fn supports_batch(&self) -> bool {
true
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // A default config yields a batch-capable provider named "openai".
    #[test]
    fn test_provider_creation() {
        let provider = OpenAIProvider::new(&Config::default());
        assert_eq!(provider.name(), "openai");
        assert!(provider.supports_batch());
    }

    // An explicit model override in the config must be honored verbatim.
    #[test]
    fn test_model_name() {
        let mut config = Config::default();
        config.openai_model = Some("text-embedding-ada-002".to_string());
        assert_eq!(
            OpenAIProvider::new(&config).model,
            "text-embedding-ada-002"
        );
    }
}