use crate::core::error::{GraphRAGError, Result};
use crate::embeddings::EmbeddingProvider;
use ollama_rs::Ollama;
/// Embedding provider that talks to an Ollama server through the `ollama_rs` client.
pub struct OllamaEmbeddings {
    /// Client used for every request sent to the Ollama server.
    client: Ollama,
    /// Name of the Ollama model passed with each embedding request.
    model: String,
    /// Vector size reported to callers; configured locally, not discovered
    /// from the model.
    dimensions: usize,
}
impl OllamaEmbeddings {
pub fn new(model: impl Into<String>) -> Self {
Self {
model: model.into(),
client: Ollama::default(),
dimensions: 1024, }
}
pub fn with_dimensions(mut self, dimensions: usize) -> Self {
self.dimensions = dimensions;
self
}
}
#[async_trait::async_trait]
impl EmbeddingProvider for OllamaEmbeddings {
    /// Checks that the Ollama server is reachable by listing its local models.
    ///
    /// # Errors
    ///
    /// Returns [`GraphRAGError::Embedding`] if the server cannot be contacted.
    async fn initialize(&mut self) -> Result<()> {
        self.client
            .list_local_models()
            .await
            .map(|_| ())
            .map_err(|e| GraphRAGError::Embedding {
                message: format!("Failed to connect to Ollama: {}", e),
            })
    }

    /// Embeds a single text with the configured model.
    ///
    /// The returned vector's length is whatever the model produces; it is not
    /// compared against `dimensions()`.
    ///
    /// # Errors
    ///
    /// Returns [`GraphRAGError::Embedding`] if the request fails or the server
    /// returns no embedding.
    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
        use ollama_rs::generation::embeddings::request::EmbeddingsInput;

        // Move the first vector out of the response instead of copying it.
        self.request_embeddings(EmbeddingsInput::Single(text.to_string()))
            .await?
            .into_iter()
            .next()
            .ok_or_else(|| GraphRAGError::Embedding {
                message: "No embeddings returned from Ollama".to_string(),
            })
    }

    /// Embeds every text in `texts` using a single request to the server.
    ///
    /// Returns one vector per input, in the order the server returns them
    /// (expected to match the order of `texts`).
    ///
    /// # Errors
    ///
    /// Returns [`GraphRAGError::Embedding`] if the request fails or the server
    /// returns a different number of embeddings than there are inputs.
    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        use ollama_rs::generation::embeddings::request::EmbeddingsInput;

        // Nothing to embed: skip the network round-trip entirely.
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let inputs: Vec<String> = texts.iter().map(|&t| t.to_owned()).collect();
        let embeddings = self
            .request_embeddings(EmbeddingsInput::Multiple(inputs))
            .await?;

        // Results are meant to line up with `texts` by position, so a
        // mismatched count must not be passed through silently.
        if embeddings.len() != texts.len() {
            return Err(GraphRAGError::Embedding {
                message: format!(
                    "Ollama returned {} embeddings for {} inputs",
                    embeddings.len(),
                    texts.len()
                ),
            });
        }
        Ok(embeddings)
    }

    /// Returns the configured vector size; it is not queried from the model.
    fn dimensions(&self) -> usize {
        self.dimensions
    }

    /// Always reports `true`; server reachability is only checked by
    /// `initialize`.
    fn is_available(&self) -> bool {
        true
    }

    /// Human-readable name of this provider.
    fn provider_name(&self) -> &str {
        "Ollama"
    }
}

impl OllamaEmbeddings {
    /// Sends one embedding request for `input` with the configured model and
    /// returns every vector in the response.
    ///
    /// # Errors
    ///
    /// Maps any client/server failure to [`GraphRAGError::Embedding`].
    async fn request_embeddings(
        &self,
        input: ollama_rs::generation::embeddings::request::EmbeddingsInput,
    ) -> Result<Vec<Vec<f32>>> {
        use ollama_rs::generation::embeddings::request::GenerateEmbeddingsRequest;

        let request = GenerateEmbeddingsRequest::new(self.model.clone(), input);
        self.client
            .generate_embeddings(request)
            .await
            .map(|response| response.embeddings)
            .map_err(|e| GraphRAGError::Embedding {
                message: format!("Ollama embedding generation failed: {}", e),
            })
    }
}