use async_trait::async_trait;
use serde::Deserialize;
use super::client::OpenAI;
use crate::embedding::{
Embedding, EmbeddingProvider, EmbeddingRequest, EmbeddingResponse, EmbeddingUsage,
};
use crate::error::{LlmError, Result};
/// One embedding entry as it appears in the `data` array of the OpenAI
/// `/embeddings` response body. Deserialize-only; never sent back out.
#[derive(Debug, Clone, Deserialize)]
struct OpenAIEmbeddingData {
    // The embedding vector itself.
    pub embedding: Vec<f32>,
    // Position of this embedding within the request's input list; the API
    // provides it so results can be matched back to inputs.
    pub index: usize,
}
/// Top-level shape of the OpenAI `/embeddings` response body.
/// Unknown fields (e.g. `object`) are ignored by serde's default behavior.
#[derive(Debug, Clone, Deserialize)]
struct OpenAIEmbeddingResponse {
    // One entry per input in the request.
    pub data: Vec<OpenAIEmbeddingData>,
    // The model that actually served the request, as reported by the API.
    pub model: String,
    // Token usage accounting; optional in case the server omits it.
    pub usage: Option<EmbeddingUsage>,
}
// Model used when the caller does not specify one.
const DEFAULT_EMBEDDING_MODEL: &str = "text-embedding-3-small";
// Output vector length of `text-embedding-3-small` — NOTE(review): this is
// only correct for the default model; requests that override the model (or
// pass a `dimensions` parameter) may produce a different length.
const DEFAULT_EMBEDDING_DIMENSION: usize = 1536;
#[async_trait]
impl EmbeddingProvider for OpenAI {
    /// Sends `request` to the OpenAI `/embeddings` endpoint and converts the
    /// response into the provider-agnostic [`EmbeddingResponse`].
    ///
    /// # Errors
    ///
    /// Returns an error when the HTTP request itself fails, when the server
    /// responds with a non-success status (mapped through
    /// [`Self::parse_error`]), or when the body cannot be parsed as a valid
    /// embedding payload.
    async fn embed(&self, request: &EmbeddingRequest) -> Result<EmbeddingResponse> {
        let url = self.embeddings_url();
        let response = self.build_request(&url).json(request).send().await?;

        // Surface API-level failures with the provider's error mapping; the
        // body is best-effort here, so a failed read degrades to "".
        let status = response.status();
        if !status.is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(Self::parse_error(status.as_u16(), &error_text).into());
        }

        // Read the raw text first so a parse failure can include the body
        // verbatim in the error for debugging.
        let response_text = response.text().await?;
        let mut parsed: OpenAIEmbeddingResponse =
            serde_json::from_str(&response_text).map_err(|e| {
                LlmError::response_format(
                    "valid OpenAI embedding response",
                    format!("parse error: {e}, response: {response_text}"),
                )
            })?;

        // The API tags each entry with `index` rather than guaranteeing
        // arrival order; sort so `embeddings[i]` always corresponds to the
        // i-th request input.
        parsed.data.sort_unstable_by_key(|d| d.index);

        let embeddings = parsed
            .data
            .into_iter()
            .map(|d| Embedding::new(d.embedding, d.index))
            .collect();

        Ok(EmbeddingResponse {
            embeddings,
            model: Some(parsed.model),
            usage: parsed.usage,
        })
    }

    /// Model used when a request does not name one explicitly.
    fn default_embedding_model(&self) -> &str {
        DEFAULT_EMBEDDING_MODEL
    }

    /// Vector length produced by the default model.
    fn embedding_dimension(&self) -> Option<usize> {
        Some(DEFAULT_EMBEDDING_DIMENSION)
    }
}