use crate::error::{NpcError, Result};
/// Requests an embedding vector for `text` from an Ollama server.
///
/// The server base URL is read from the `OLLAMA_API_URL` environment variable. When the
/// variable is unset or blank, `http://localhost:11434` is used. Surrounding whitespace and
/// trailing slashes on the base URL are ignored, so a value such as `http://host:11434/`
/// does not produce a doubled `/` in the request path.
///
/// The request is a JSON `POST` to `<base>/api/embeddings` with the fields `model` and
/// `prompt`.
///
/// Returns the numeric entries of the response's `"embedding"` array converted to `f32`.
/// A response without an `"embedding"` array yields an empty vector, and non-numeric
/// entries are skipped.
///
/// # Errors
///
/// Returns [`NpcError::LlmRequest`] carrying the response body when the server answers with
/// a non-success status. Failures while sending the request or decoding the JSON body are
/// propagated with `?`.
pub async fn get_ollama_embeddings(text: &str, model: &str) -> Result<Vec<f32>> {
    // Normalise the configured base URL: a trailing slash would otherwise yield
    // `//api/embeddings`, and a blank value would yield a relative, unusable URL.
    let api_url = std::env::var("OLLAMA_API_URL")
        .ok()
        .map(|raw| raw.trim().trim_end_matches('/').to_string())
        .filter(|base| !base.is_empty())
        .unwrap_or_else(|| "http://localhost:11434".to_string());
    let url = format!("{}/api/embeddings", api_url);

    let body = serde_json::json!({
        "model": model,
        "prompt": text,
    });

    let client = reqwest::Client::new();
    let resp = client.post(&url).json(&body).send().await?;
    if !resp.status().is_success() {
        // Best effort: an unreadable error body is reported as an empty string.
        let text = resp.text().await.unwrap_or_default();
        return Err(NpcError::LlmRequest(format!(
            "Ollama embeddings failed: {}",
            text
        )));
    }

    let json: serde_json::Value = resp.json().await?;
    let embedding = json
        .get("embedding")
        .and_then(|e| e.as_array())
        .map(|arr| {
            arr.iter()
                // JSON numbers are read as f64 and narrowed to the f32 return type.
                .filter_map(|v| v.as_f64().map(|f| f as f32))
                .collect()
        })
        .unwrap_or_default();
    Ok(embedding)
}
/// Requests an embedding vector for `text` from the OpenAI embeddings endpoint.
///
/// The API key is taken from `api_key` when it is `Some`; otherwise the `OPENAI_API_KEY`
/// environment variable is consulted. The request is a JSON `POST` to
/// `https://api.openai.com/v1/embeddings` with the fields `model` and `input`, authenticated
/// through a `Bearer` `Authorization` header.
///
/// Returns the numeric entries of `data[0].embedding` from the response converted to `f32`.
/// If that path does not hold an array the result is an empty vector, and non-numeric
/// entries are skipped.
///
/// # Errors
///
/// Returns [`NpcError::LlmRequest`] when no API key is available, or (carrying the response
/// body) when the server answers with a non-success status. Failures while sending the
/// request or decoding the JSON body are propagated with `?`.
pub async fn get_openai_embeddings(
    text: &str,
    model: &str,
    api_key: Option<&str>,
) -> Result<Vec<f32>> {
    // An explicitly supplied key wins; the environment is only a fallback.
    let token = match api_key {
        Some(explicit) => explicit.to_owned(),
        None => match std::env::var("OPENAI_API_KEY") {
            Ok(from_env) => from_env,
            Err(_) => return Err(NpcError::LlmRequest("OPENAI_API_KEY not set".into())),
        },
    };

    let payload = serde_json::json!({
        "model": model,
        "input": text,
    });
    let auth_header = format!("Bearer {}", token);

    let response = reqwest::Client::new()
        .post("https://api.openai.com/v1/embeddings")
        .header("Authorization", auth_header)
        .json(&payload)
        .send()
        .await?;

    if !response.status().is_success() {
        // Best effort: an unreadable error body is reported as an empty string.
        let detail = response.text().await.unwrap_or_default();
        return Err(NpcError::LlmRequest(format!("OpenAI embeddings failed: {}", detail)));
    }

    let parsed: serde_json::Value = response.json().await?;

    // Indexing a `serde_json::Value` with a missing key/index yields `Null`, so a response
    // of unexpected shape simply leaves the vector empty.
    let mut vector: Vec<f32> = Vec::new();
    if let Some(values) = parsed["data"][0]["embedding"].as_array() {
        for value in values {
            if let Some(number) = value.as_f64() {
                vector.push(number as f32);
            }
        }
    }
    Ok(vector)
}
/// Dispatches an embedding request to the backend named by `provider`.
///
/// * `"ollama"` forwards to [`get_ollama_embeddings`]; `api_key` is not used.
/// * `"openai"` forwards to [`get_openai_embeddings`] together with `api_key`.
///
/// The provider name is compared exactly (case-sensitive).
///
/// # Errors
///
/// Returns [`NpcError::UnsupportedProvider`] for any other provider name; otherwise returns
/// whatever the selected backend returns.
pub async fn get_embeddings(
    text: &str,
    model: &str,
    provider: &str,
    api_key: Option<&str>,
) -> Result<Vec<f32>> {
    if provider == "ollama" {
        return get_ollama_embeddings(text, model).await;
    }
    if provider == "openai" {
        return get_openai_embeddings(text, model, api_key).await;
    }
    Err(NpcError::UnsupportedProvider {
        provider: provider.to_owned(),
    })
}
/// Computes the cosine similarity of two equally sized vectors.
///
/// Returns `0.0` when the slices differ in length, when they are empty, or when either
/// vector has zero magnitude. Otherwise the result lies in `[-1.0, 1.0]`.
///
/// The dot product and squared norms are accumulated in `f64` in a single pass. Squaring an
/// `f32` component can overflow to infinity (very large values) or underflow to zero (very
/// small values) when done in `f32`, which would produce `NaN` or a spurious `0.0`; `f64`
/// has enough range to hold the square of any finite `f32`.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b.iter()).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        },
    );

    if norm_a_sq == 0.0 || norm_b_sq == 0.0 {
        return 0.0;
    }

    // Take the square roots separately so the denominator cannot overflow, then clamp
    // because rounding can push the ratio marginally outside the mathematical range.
    let similarity = dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt());
    similarity.clamp(-1.0, 1.0) as f32
}