use async_trait::async_trait;
use serde::Deserialize;
use super::client::Ollama;
use crate::embedding::{
Embedding, EmbeddingProvider, EmbeddingRequest, EmbeddingResponse, EmbeddingUsage,
};
use crate::error::{LlmError, Result};
/// Model name reported by `default_embedding_model` when the caller does not
/// choose one explicitly.
const DEFAULT_EMBEDDING_MODEL: &str = "nomic-embed-text";

/// Wire-format mirror of a successful Ollama embedding response
/// (presumably the `/api/embed` endpoint — confirm against `embeddings_url`).
#[derive(Debug, Clone, Deserialize)]
struct OllamaEmbeddingResponse {
    /// One embedding vector per input, in input order.
    embeddings: Vec<Vec<f32>>,
    /// Prompt token count; `#[serde(default)]` keeps deserialization working
    /// when the server omits the field entirely.
    #[serde(default)]
    prompt_eval_count: Option<u32>,
}
#[async_trait]
impl EmbeddingProvider for Ollama {
async fn embed(&self, request: &EmbeddingRequest) -> Result<EmbeddingResponse> {
let url = self.embeddings_url();
let response = self.client().post(&url).json(request).send().await?;
let status = response.status();
if !status.is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(Self::parse_error(status.as_u16(), &error_text).into());
}
let response_text = response.text().await?;
let parsed: OllamaEmbeddingResponse =
serde_json::from_str(&response_text).map_err(|e| {
LlmError::response_format(
"valid Ollama embedding response",
format!("parse error: {e}, response: {response_text}"),
)
})?;
let embeddings = parsed
.embeddings
.into_iter()
.enumerate()
.map(|(i, vector)| Embedding::new(vector, i))
.collect();
let usage = parsed.prompt_eval_count.map(|tokens| EmbeddingUsage {
prompt_tokens: tokens,
total_tokens: tokens,
});
Ok(EmbeddingResponse {
embeddings,
model: Some(request.model.clone()),
usage,
})
}
fn default_embedding_model(&self) -> &str {
DEFAULT_EMBEDDING_MODEL
}
fn embedding_dimension(&self) -> Option<usize> {
None
}
}