#[cfg(feature = "embeddings-openai")]
use std::any::Any;
#[cfg(feature = "embeddings-openai")]
use async_trait::async_trait;
#[cfg(feature = "embeddings-openai")]
use reqwest::Client;
#[cfg(feature = "embeddings-openai")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "embeddings-openai")]
use crate::embedding::embedder::{EmbedInput, EmbedInputType, Embedder};
#[cfg(feature = "embeddings-openai")]
use crate::error::{LaurusError, Result};
#[cfg(feature = "embeddings-openai")]
use crate::vector::core::vector::Vector;
#[cfg(feature = "embeddings-openai")]
#[derive(Debug, Serialize)]
/// Request body for the OpenAI `/v1/embeddings` endpoint.
struct EmbeddingRequest {
/// Model identifier, e.g. "text-embedding-3-small".
model: String,
/// One or more texts to embed in a single call.
input: Vec<String>,
/// Requested output dimension; omitted from the JSON when `None` so the
/// model's default applies.
#[serde(skip_serializing_if = "Option::is_none")]
dimensions: Option<usize>,
}
#[cfg(feature = "embeddings-openai")]
#[derive(Debug, Deserialize)]
/// Successful response from the embeddings endpoint; only the `data` array
/// is deserialized, other response fields are ignored.
struct EmbeddingResponse {
/// One entry per input string, in request order.
data: Vec<EmbeddingData>,
}
#[cfg(feature = "embeddings-openai")]
#[derive(Debug, Deserialize)]
/// A single embedding entry within [`EmbeddingResponse::data`].
struct EmbeddingData {
/// The raw embedding values for one input.
embedding: Vec<f32>,
}
#[cfg(feature = "embeddings-openai")]
/// Text embedder backed by the OpenAI embeddings HTTP API.
///
/// Constructed via [`OpenAIEmbedder::new`] / [`OpenAIEmbedder::with_dimension`];
/// `Debug` is hand-implemented to redact the API key.
pub struct OpenAIEmbedder {
// Reused reqwest client (connection pooling across requests).
client: Client,
// Bearer token sent on every request; never printed via Debug.
api_key: String,
// Model identifier passed through to the API.
model: String,
// Expected output dimension; only sent to the API when it differs from
// the model's default (see `embed_text` / `embed_text_batch`).
dimension: usize,
}
#[cfg(feature = "embeddings-openai")]
// Manual Debug so the API key is never leaked into logs: it is rendered as
// "[REDACTED]" instead of its real value. Deriving Debug would print it.
impl std::fmt::Debug for OpenAIEmbedder {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("OpenAIEmbedder")
.field("model", &self.model)
.field("dimension", &self.dimension)
.field("api_key", &"[REDACTED]")
.finish()
}
}
#[cfg(feature = "embeddings-openai")]
impl OpenAIEmbedder {
    /// Creates a new embedder, validating `model` against the OpenAI API.
    ///
    /// Probes `GET /v1/models/{model}` so a bad key or unknown model fails
    /// fast at construction time instead of on the first embed call.
    ///
    /// # Errors
    /// Returns `LaurusError::InvalidOperation` if the API is unreachable or
    /// the model lookup does not return a success status.
    pub async fn new(api_key: String, model: String) -> Result<Self> {
        let client = Client::new();
        let url = format!("https://api.openai.com/v1/models/{}", model);
        let response = client
            .get(&url)
            .header("Authorization", format!("Bearer {}", api_key))
            .send()
            .await
            .map_err(|e| {
                LaurusError::InvalidOperation(format!("Failed to connect to OpenAI API: {}", e))
            })?;
        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await.unwrap_or_default();
            return Err(LaurusError::InvalidOperation(format!(
                "Failed to validate OpenAI model '{}'. Status: {}. Response: {}",
                model, status, text
            )));
        }
        let dimension = Self::default_dimension(&model);
        Ok(Self {
            client,
            api_key,
            model,
            dimension,
        })
    }

    /// Like [`OpenAIEmbedder::new`], but overrides the model's default
    /// output dimension (sent as the `dimensions` request parameter).
    pub async fn with_dimension(api_key: String, model: String, dimension: usize) -> Result<Self> {
        let mut embedder = Self::new(api_key, model).await?;
        embedder.dimension = dimension;
        Ok(embedder)
    }

    /// Default embedding dimension for known OpenAI embedding models;
    /// unknown models fall back to 1536.
    fn default_dimension(model: &str) -> usize {
        match model {
            "text-embedding-3-small" => 1536,
            "text-embedding-3-large" => 3072,
            "text-embedding-ada-002" => 1536,
            _ => 1536,
        }
    }

    /// Sends one embeddings request for `input` and returns one vector per
    /// input string, in request order.
    ///
    /// Shared request/response plumbing for [`Self::embed_text`] and
    /// [`Self::embed_text_batch`] (previously duplicated in both).
    async fn request_embeddings(&self, input: Vec<String>) -> Result<Vec<Vector>> {
        // Only send `dimensions` when it deviates from the model default;
        // the API applies the default when the field is omitted.
        let dimensions = if self.dimension == Self::default_dimension(&self.model) {
            None
        } else {
            Some(self.dimension)
        };
        let request = EmbeddingRequest {
            model: self.model.clone(),
            input,
            dimensions,
        };
        let http_response = self
            .client
            .post("https://api.openai.com/v1/embeddings")
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await
            .map_err(|e| {
                LaurusError::InvalidOperation(format!("OpenAI API request failed: {}", e))
            })?;
        // Read the body before checking status so error responses can be
        // quoted in the error message.
        let status = http_response.status();
        let response_text = http_response.text().await.map_err(|e| {
            LaurusError::InvalidOperation(format!("Failed to read response text: {}", e))
        })?;
        if !status.is_success() {
            return Err(LaurusError::InvalidOperation(format!(
                "OpenAI API error (status {}): {}",
                status, response_text
            )));
        }
        let response: EmbeddingResponse = serde_json::from_str(&response_text).map_err(|e| {
            LaurusError::InvalidOperation(format!(
                "Failed to parse OpenAI response: {}. Response text: {}",
                e, response_text
            ))
        })?;
        Ok(response
            .data
            .into_iter()
            .map(|d| Vector::new(d.embedding))
            .collect())
    }

    /// Embeds a single text string.
    ///
    /// # Errors
    /// Propagates API/parse errors from [`Self::request_embeddings`], plus
    /// `InvalidOperation` if the response contains no embeddings.
    async fn embed_text(&self, text: &str) -> Result<Vector> {
        self.request_embeddings(vec![text.to_string()])
            .await?
            .into_iter()
            .next()
            .ok_or_else(|| LaurusError::InvalidOperation("No embedding in response".to_string()))
    }

    /// Embeds a batch of texts in a single API call, returning vectors in
    /// input order. An empty batch short-circuits without a network call.
    async fn embed_text_batch(&self, texts: &[&str]) -> Result<Vec<Vector>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }
        self.request_embeddings(texts.iter().map(|s| s.to_string()).collect())
            .await
    }
}
#[cfg(feature = "embeddings-openai")]
#[async_trait]
impl Embedder for OpenAIEmbedder {
    /// Embeds a single input. Only [`EmbedInput::Text`] is supported; any
    /// other variant yields an invalid-argument error.
    async fn embed(&self, input: &EmbedInput<'_>) -> Result<Vector> {
        if let EmbedInput::Text(text) = input {
            self.embed_text(text).await
        } else {
            Err(LaurusError::invalid_argument(
                "OpenAIEmbedder only supports text input",
            ))
        }
    }

    /// Embeds a batch of inputs in one API call. Fails up front if any
    /// input is not text, before any network traffic is sent.
    async fn embed_batch(&self, inputs: &[EmbedInput<'_>]) -> Result<Vec<Vector>> {
        let mut texts: Vec<&str> = Vec::with_capacity(inputs.len());
        for input in inputs {
            match input {
                EmbedInput::Text(text) => texts.push(text),
                _ => {
                    return Err(LaurusError::invalid_argument(
                        "OpenAIEmbedder only supports text input",
                    ));
                }
            }
        }
        self.embed_text_batch(&texts).await
    }

    /// This embedder accepts text input only.
    fn supported_input_types(&self) -> Vec<EmbedInputType> {
        vec![EmbedInputType::Text]
    }

    /// The embedder's name is the configured model identifier.
    fn name(&self) -> &str {
        &self.model
    }

    /// Downcasting hook for callers that need the concrete type.
    fn as_any(&self) -> &dyn Any {
        self
    }
}