use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::error::LlmError;
use crate::traits::{EmbeddingCapability, EmbeddingExtensions, OpenAiEmbeddingCapability};
use crate::types::{
EmbeddingFormat, EmbeddingModelInfo, EmbeddingRequest, EmbeddingResponse, EmbeddingUsage,
};
use super::config::OpenAiConfig;
#[derive(Debug, Clone, Serialize)]
struct OpenAiEmbeddingRequest {
input: Vec<String>,
model: String,
#[serde(skip_serializing_if = "Option::is_none")]
encoding_format: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
dimensions: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
user: Option<String>,
}
#[derive(Debug, Clone, Deserialize)]
struct OpenAiEmbeddingResponse {
data: Vec<OpenAiEmbeddingObject>,
model: String,
usage: OpenAiEmbeddingUsage,
}
#[derive(Debug, Clone, Deserialize)]
struct OpenAiEmbeddingObject {
embedding: Vec<f32>,
index: usize,
}
#[derive(Debug, Clone, Deserialize)]
struct OpenAiEmbeddingUsage {
prompt_tokens: u32,
total_tokens: u32,
}
#[derive(Debug, Clone)]
pub struct OpenAiEmbeddings {
config: OpenAiConfig,
http_client: reqwest::Client,
}
impl OpenAiEmbeddings {
pub const fn new(config: OpenAiConfig, http_client: reqwest::Client) -> Self {
Self {
config,
http_client,
}
}
fn default_model(&self) -> String {
"text-embedding-3-small".to_string()
}
fn build_request(&self, request: &EmbeddingRequest) -> OpenAiEmbeddingRequest {
let model = request
.model
.clone()
.or_else(|| {
if !self.config.common_params.model.is_empty() {
Some(self.config.common_params.model.clone())
} else {
None
}
})
.unwrap_or_else(|| self.default_model());
let encoding_format = request.encoding_format.as_ref().map(|f| match f {
EmbeddingFormat::Float => "float".to_string(),
EmbeddingFormat::Base64 => "base64".to_string(),
});
OpenAiEmbeddingRequest {
input: request.input.clone(),
model,
encoding_format,
dimensions: request.dimensions,
user: request
.user
.clone()
.or_else(|| self.config.openai_params.user.clone()),
}
}
async fn make_request(
&self,
request: OpenAiEmbeddingRequest,
) -> Result<OpenAiEmbeddingResponse, LlmError> {
let url = format!("{}/embeddings", self.config.base_url);
let mut headers = reqwest::header::HeaderMap::new();
for (key, value) in self.config.get_headers() {
let header_name = reqwest::header::HeaderName::from_bytes(key.as_bytes())
.map_err(|e| LlmError::HttpError(format!("Invalid header name: {e}")))?;
let header_value = reqwest::header::HeaderValue::from_str(&value)
.map_err(|e| LlmError::HttpError(format!("Invalid header value: {e}")))?;
headers.insert(header_name, header_value);
}
let response = self
.http_client
.post(&url)
.headers(headers)
.json(&request)
.send()
.await
.map_err(|e| LlmError::HttpError(format!("Request failed: {e}")))?;
if !response.status().is_success() {
let status = response.status();
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(LlmError::ApiError {
code: status.as_u16(),
message: format!("OpenAI API error {status}: {error_text}"),
details: None,
});
}
let openai_response: OpenAiEmbeddingResponse = response
.json()
.await
.map_err(|e| LlmError::ParseError(format!("Failed to parse response: {e}")))?;
Ok(openai_response)
}
fn convert_response(&self, openai_response: OpenAiEmbeddingResponse) -> EmbeddingResponse {
let mut embeddings_with_index: Vec<_> = openai_response.data.into_iter().collect();
embeddings_with_index.sort_by_key(|obj| obj.index);
let embeddings: Vec<Vec<f32>> = embeddings_with_index
.into_iter()
.map(|obj| obj.embedding)
.collect();
EmbeddingResponse::new(embeddings, openai_response.model).with_usage(EmbeddingUsage::new(
openai_response.usage.prompt_tokens,
openai_response.usage.total_tokens,
))
}
fn get_model_info(&self, model_id: &str) -> EmbeddingModelInfo {
match model_id {
"text-embedding-3-small" => EmbeddingModelInfo::new(
model_id.to_string(),
"Text Embedding 3 Small".to_string(),
1536,
8192,
)
.with_custom_dimensions(),
"text-embedding-3-large" => EmbeddingModelInfo::new(
model_id.to_string(),
"Text Embedding 3 Large".to_string(),
3072,
8192,
)
.with_custom_dimensions(),
"text-embedding-ada-002" => EmbeddingModelInfo::new(
model_id.to_string(),
"Text Embedding Ada 002 (Legacy)".to_string(),
1536,
8192,
),
_ => EmbeddingModelInfo::new(model_id.to_string(), model_id.to_string(), 1536, 8192),
}
}
}
#[async_trait]
impl EmbeddingCapability for OpenAiEmbeddings {
async fn embed(&self, input: Vec<String>) -> Result<EmbeddingResponse, LlmError> {
if input.is_empty() {
return Err(LlmError::InvalidInput("Input cannot be empty".to_string()));
}
let request = EmbeddingRequest::new(input);
self.embed_with_config(request).await
}
fn embedding_dimension(&self) -> usize {
let model = if !self.config.common_params.model.is_empty() {
&self.config.common_params.model
} else {
"text-embedding-3-small"
};
match model {
"text-embedding-3-small" => 1536,
"text-embedding-3-large" => 3072,
"text-embedding-ada-002" => 1536,
_ => 1536, }
}
fn max_tokens_per_embedding(&self) -> usize {
8192 }
fn supported_embedding_models(&self) -> Vec<String> {
vec![
"text-embedding-3-small".to_string(),
"text-embedding-3-large".to_string(),
"text-embedding-ada-002".to_string(),
]
}
}
#[async_trait]
impl EmbeddingExtensions for OpenAiEmbeddings {
async fn embed_with_config(
&self,
request: EmbeddingRequest,
) -> Result<EmbeddingResponse, LlmError> {
if request.input.is_empty() {
return Err(LlmError::InvalidInput("Input cannot be empty".to_string()));
}
let openai_request = self.build_request(&request);
let openai_response = self.make_request(openai_request).await?;
Ok(self.convert_response(openai_response))
}
async fn list_embedding_models(&self) -> Result<Vec<EmbeddingModelInfo>, LlmError> {
let models = self.supported_embedding_models();
let model_infos = models
.into_iter()
.map(|id| self.get_model_info(&id))
.collect();
Ok(model_infos)
}
}
#[async_trait]
impl OpenAiEmbeddingCapability for OpenAiEmbeddings {
async fn embed_with_dimensions(
&self,
input: Vec<String>,
dimensions: u32,
) -> Result<EmbeddingResponse, LlmError> {
let request = EmbeddingRequest::new(input).with_dimensions(dimensions);
self.embed_with_config(request).await
}
async fn embed_with_format(
&self,
input: Vec<String>,
format: EmbeddingFormat,
) -> Result<EmbeddingResponse, LlmError> {
let request = EmbeddingRequest::new(input).with_encoding_format(format);
self.embed_with_config(request).await
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::providers::openai::OpenAiConfig;
#[test]
fn test_embedding_dimensions() {
let config = OpenAiConfig::new("test-key");
let client = reqwest::Client::new();
let embeddings = OpenAiEmbeddings::new(config, client);
assert_eq!(embeddings.embedding_dimension(), 1536);
assert_eq!(embeddings.max_tokens_per_embedding(), 8192);
}
#[test]
fn test_supported_models() {
let config = OpenAiConfig::new("test-key");
let client = reqwest::Client::new();
let embeddings = OpenAiEmbeddings::new(config, client);
let models = embeddings.supported_embedding_models();
assert!(models.contains(&"text-embedding-3-small".to_string()));
assert!(models.contains(&"text-embedding-3-large".to_string()));
assert!(models.contains(&"text-embedding-ada-002".to_string()));
}
#[test]
fn test_model_info() {
let config = OpenAiConfig::new("test-key");
let client = reqwest::Client::new();
let embeddings = OpenAiEmbeddings::new(config, client);
let info = embeddings.get_model_info("text-embedding-3-small");
assert_eq!(info.id, "text-embedding-3-small");
assert_eq!(info.dimension, 1536);
assert!(info.supports_custom_dimensions);
}
#[test]
fn test_build_request() {
let config = OpenAiConfig::new("test-key");
let client = reqwest::Client::new();
let embeddings = OpenAiEmbeddings::new(config, client);
let request = EmbeddingRequest::new(vec!["test".to_string()])
.with_model("text-embedding-3-large")
.with_dimensions(2048)
.with_encoding_format(EmbeddingFormat::Float);
let openai_request = embeddings.build_request(&request);
assert_eq!(openai_request.model, "text-embedding-3-large");
assert_eq!(openai_request.dimensions, Some(2048));
assert_eq!(openai_request.encoding_format, Some("float".to_string()));
}
}