use crate::Result;
pub struct EmbeddingClient {
provider: String,
model: String,
api_key: Option<String>,
endpoint: String,
client: reqwest::Client,
}
impl EmbeddingClient {
pub fn new(
provider: impl Into<String>,
model: impl Into<String>,
api_key: Option<String>,
endpoint: Option<String>,
) -> Self {
let provider = provider.into();
let endpoint = endpoint.unwrap_or_else(|| match provider.as_str() {
"openai" => "https://api.openai.com/v1/embeddings".to_string(),
_ => "http://localhost:8080/v1/embeddings".to_string(),
});
Self {
provider,
model: model.into(),
api_key,
endpoint,
client: reqwest::Client::new(),
}
}
pub async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
match self.provider.as_str() {
"openai" | "openai-compatible" => self.generate_openai_embedding(text).await,
_ => Err(crate::Error::internal(format!(
"Unsupported embedding provider: {}",
self.provider
))),
}
}
pub async fn generate_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
let mut embeddings = Vec::new();
for text in texts {
let embedding = self.generate_embedding(&text).await?;
embeddings.push(embedding);
}
Ok(embeddings)
}
async fn generate_openai_embedding(&self, text: &str) -> Result<Vec<f32>> {
let api_key = self
.api_key
.clone()
.or_else(|| std::env::var("OPENAI_API_KEY").ok())
.ok_or_else(|| crate::Error::internal("OpenAI API key not found"))?;
let request_body = serde_json::json!({
"model": self.model,
"input": text,
});
let mut request =
self.client.post(&self.endpoint).header("Content-Type", "application/json");
if !api_key.is_empty() {
request = request.header("Authorization", format!("Bearer {}", api_key));
}
let response =
request.json(&request_body).send().await.map_err(|e| {
crate::Error::internal(format!("Embedding API request failed: {}", e))
})?;
if !response.status().is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(crate::Error::internal(format!("Embedding API error: {}", error_text)));
}
let response_json: serde_json::Value = response.json().await.map_err(|e| {
crate::Error::config(format!("Failed to parse embedding response: {}", e))
})?;
let embedding: Vec<f32> = response_json["data"][0]["embedding"]
.as_array()
.ok_or_else(|| crate::Error::internal("Invalid embedding response format"))?
.iter()
.filter_map(|v| v.as_f64().map(|f| f as f32))
.collect();
if embedding.is_empty() {
return Err(crate::Error::internal("Empty embedding returned"));
}
Ok(embedding)
}
}
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
if a.len() != b.len() {
return 0.0;
}
let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let magnitude_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let magnitude_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
if magnitude_a == 0.0 || magnitude_b == 0.0 {
return 0.0;
}
dot_product / (magnitude_a * magnitude_b)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_cosine_similarity() {
let a = vec![1.0, 0.0, 0.0];
let b = vec![1.0, 0.0, 0.0];
assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);
let c = vec![1.0, 0.0];
let d = vec![0.0, 1.0];
assert!((cosine_similarity(&c, &d) - 0.0).abs() < 0.001);
}
#[test]
fn test_embedding_client_creation() {
let client = EmbeddingClient::new(
"openai",
"text-embedding-ada-002",
Some("test_key".to_string()),
None,
);
assert_eq!(client.provider, "openai");
assert_eq!(client.model, "text-embedding-ada-002");
}
}