use crate::core::service::ServiceError;
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
/// Deserialized body of a successful embeddings response.
///
/// Only the `data` array is modelled here; any other fields present in the
/// JSON are not captured by this struct.
#[derive(Debug, Deserialize)]
struct OpenAIEmbeddingResponse {
/// Embedding entries returned by the API. `call_openai_api` uses the first
/// entry and reports an error when this list is empty.
data: Vec<OpenAIEmbeddingData>,
}
/// One entry of the `data` array in an embeddings response.
#[derive(Debug, Deserialize)]
struct OpenAIEmbeddingData {
/// The embedding vector, deserialized as `f32` components.
embedding: Vec<f32>,
}
/// Abstraction over a backend that turns a piece of text into an embedding
/// vector.
///
/// Implementors must be `Send + Sync`. Both methods return the vector as
/// `Vec<f32>` or a `ServiceError` on failure.
#[async_trait]
pub trait EmbeddingService: Send + Sync {
/// Computes the embedding for `text`.
async fn embed_text(&self, text: &str) -> Result<Vec<f32>, ServiceError>;
/// Computes the embedding for `query`.
///
/// NOTE(review): presumably kept separate from `embed_text` so a backend
/// can treat queries differently; the OpenAI implementation in this file
/// handles both identically.
async fn embed_query(&self, query: &str) -> Result<Vec<f32>, ServiceError>;
}
/// `EmbeddingService` implementation that calls an OpenAI-style
/// `/embeddings` HTTP endpoint.
pub struct OpenAIEmbeddingService {
/// HTTP client used to send every request.
client: Client,
/// API base URL; `/embeddings` is appended to it (after trimming trailing
/// slashes) to form the request URL.
base_url: String,
/// Model identifier sent as the `model` field of the request body.
model: String,
/// Secret sent as a `Bearer` token in the `Authorization` header.
api_key: String,
}
impl OpenAIEmbeddingService {
    /// Creates a service that sends requests to the embeddings endpoint under
    /// `base_url`.
    ///
    /// * `base_url` - API root; `/embeddings` is appended per request and any
    ///   trailing slashes are ignored.
    /// * `model` - model identifier placed in every request body.
    /// * `api_key` - secret sent as a bearer token with every request.
    pub fn new(base_url: String, model: String, api_key: String) -> Self {
        Self {
            client: Client::new(),
            base_url,
            model,
            api_key,
        }
    }

    /// Builds a service from `config.openai_base_url` and
    /// `config.embedding_model`, with the API key supplied separately.
    pub fn from_config(config: &crate::core::service::EmbeddingConfig, api_key: String) -> Self {
        Self::new(
            config.openai_base_url.clone(),
            config.embedding_model.clone(),
            api_key,
        )
    }

    /// Requests the embedding of `text` and returns the first vector in the
    /// response.
    ///
    /// # Errors
    ///
    /// Returns `ServiceError::Custom` when the request cannot be sent, the
    /// server answers with a non-success status, the body cannot be parsed,
    /// or the response contains no embeddings.
    async fn call_openai_api(&self, text: &str) -> Result<Vec<f32>, ServiceError> {
        // The payload is only serialized, so borrow instead of copying the
        // input text and model name on every call.
        #[derive(Serialize)]
        struct OpenAIRequest<'a> {
            input: &'a str,
            model: &'a str,
        }

        let request = OpenAIRequest {
            input: text,
            model: &self.model,
        };
        let url = format!("{}/embeddings", self.base_url.trim_end_matches('/'));

        // `bearer_auth` emits the same `Authorization: Bearer <key>` header;
        // `.json` already sets `Content-Type: application/json`.
        let response = self
            .client
            .post(&url)
            .bearer_auth(&self.api_key)
            .json(&request)
            .send()
            .await
            .map_err(|e| ServiceError::Custom(format!("OpenAI API request failed: {}", e)))?;

        let status = response.status();
        if !status.is_success() {
            // Still best-effort, but say so when the error body itself could
            // not be read instead of silently reporting an empty body.
            let body = response
                .text()
                .await
                .unwrap_or_else(|e| format!("<failed to read error body: {}>", e));
            return Err(ServiceError::Custom(format!(
                "OpenAI API error {}: {}",
                status, body
            )));
        }

        let embedding_response: OpenAIEmbeddingResponse = response
            .json()
            .await
            .map_err(|e| ServiceError::Custom(format!("Failed to parse OpenAI response: {}", e)))?;

        // Move the first embedding out rather than indexing and cloning it.
        embedding_response
            .data
            .into_iter()
            .next()
            .map(|entry| entry.embedding)
            .ok_or_else(|| ServiceError::Custom("No embeddings returned from OpenAI".to_string()))
    }
}
#[async_trait]
impl EmbeddingService for OpenAIEmbeddingService {
    /// Embeds `text` with a single request to the embeddings endpoint.
    async fn embed_text(&self, text: &str) -> Result<Vec<f32>, ServiceError> {
        let embedding = self.call_openai_api(text).await?;
        Ok(embedding)
    }

    /// Embeds `query`; goes through the same request path as `embed_text`.
    async fn embed_query(&self, query: &str) -> Result<Vec<f32>, ServiceError> {
        let embedding = self.call_openai_api(query).await?;
        Ok(embedding)
    }
}