// leann_core/embedding/mod.rs

1pub mod manager;
2pub mod onnx;
3
4#[cfg(feature = "embedding-remote")]
5pub mod gemini;
6#[cfg(feature = "embedding-remote")]
7pub mod ollama;
8#[cfg(feature = "embedding-remote")]
9pub mod openai;
10
11use anyhow::Result;
12use ndarray::Array2;
13use std::collections::HashMap;
14
15/// Trait for embedding computation backends.
16pub trait EmbeddingProvider: Send + Sync {
17    /// Compute embeddings for a batch of text chunks.
18    fn compute_embeddings(
19        &self,
20        chunks: &[String],
21        progress: Option<&dyn crate::hnsw::IndexProgress>,
22    ) -> Result<Array2<f32>>;
23
24    /// Get the dimensionality of the embeddings.
25    fn dimensions(&self) -> usize;
26
27    /// Get the provider name.
28    fn name(&self) -> &str;
29}
30
/// Embedding mode enum matching Python's embedding modes.
///
/// Every variant is field-less, so the type is `Copy` and can be compared with
/// `==` or used as a key in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EmbeddingMode {
    SentenceTransformers,
    OpenAI,
    Ollama,
    Gemini,
    Mlx,
}

impl EmbeddingMode {
    /// Parse a mode name, ignoring ASCII case; this never fails.
    ///
    /// Recognized names are `"openai"`, `"ollama"`, `"gemini"` and `"mlx"`.
    /// Every other input — including the empty string and
    /// `"sentence-transformers"` itself — yields
    /// [`EmbeddingMode::SentenceTransformers`]. Input is not trimmed, so
    /// surrounding whitespace makes a name unrecognized.
    pub fn from_str_lossy(s: &str) -> Self {
        // Names that map to something other than the default mode.
        const NAMED: [(&str, EmbeddingMode); 4] = [
            ("openai", EmbeddingMode::OpenAI),
            ("ollama", EmbeddingMode::Ollama),
            ("gemini", EmbeddingMode::Gemini),
            ("mlx", EmbeddingMode::Mlx),
        ];

        // All names are ASCII, so an ASCII case-insensitive comparison is
        // sufficient and avoids allocating a lowercased copy of the input.
        NAMED
            .iter()
            .find(|(name, _)| s.eq_ignore_ascii_case(name))
            .map_or(EmbeddingMode::SentenceTransformers, |&(_, mode)| mode)
    }

    /// Canonical lowercase name of the mode.
    ///
    /// The returned string is a `'static` literal, so it can outlive the
    /// borrow of `self`.
    pub fn as_str(&self) -> &'static str {
        match self {
            EmbeddingMode::SentenceTransformers => "sentence-transformers",
            EmbeddingMode::OpenAI => "openai",
            EmbeddingMode::Ollama => "ollama",
            EmbeddingMode::Gemini => "gemini",
            EmbeddingMode::Mlx => "mlx",
        }
    }
}
62
/// Create an embedding provider from mode, model name, and options map.
///
/// The `options` map may contain provider-specific keys:
/// - `"host"` — Ollama host override
/// - `"api_key"` — API key for OpenAI/Gemini
/// - `"base_url"` — Base URL for OpenAI-compatible services
///
/// An entry that is missing, or whose value is not a JSON string, is passed to
/// the provider constructor as `None`.
///
/// `SentenceTransformers` and `Mlx` have no dedicated provider in this
/// function. For those modes it tries an OpenAI provider for
/// `"text-embedding-3-small"` with no explicit key or URL and, if that
/// constructor fails, falls back to an Ollama provider for
/// `"nomic-embed-text"`. Neither `model` nor `options` is consulted on that
/// path, and the OpenAI construction error is discarded.
///
/// # Errors
///
/// Propagates the constructor error when the OpenAI or Gemini provider is
/// selected explicitly and cannot be built.
#[cfg(feature = "embedding-remote")]
pub fn create_embedding_provider(
    mode: &EmbeddingMode,
    model: &str,
    options: &HashMap<String, serde_json::Value>,
) -> Result<Box<dyn EmbeddingProvider>> {
    /// Look up `key` and return its value only when it is a JSON string.
    fn string_option<'a>(
        options: &'a HashMap<String, serde_json::Value>,
        key: &str,
    ) -> Option<&'a str> {
        options.get(key).and_then(serde_json::Value::as_str)
    }

    // Every variant is listed explicitly so that adding a new `EmbeddingMode`
    // forces a decision here instead of silently taking the fallback path.
    match mode {
        EmbeddingMode::OpenAI => {
            let api_key = string_option(options, "api_key");
            let base_url = string_option(options, "base_url");
            let provider = openai::OpenAiEmbedding::new(model, api_key, base_url, None)?;
            Ok(Box::new(provider))
        }
        EmbeddingMode::Ollama => {
            let host = string_option(options, "host");
            let provider = ollama::OllamaEmbedding::new(model, host);
            Ok(Box::new(provider))
        }
        EmbeddingMode::Gemini => {
            let api_key = string_option(options, "api_key");
            let provider = gemini::GeminiEmbedding::new(model, api_key)?;
            Ok(Box::new(provider))
        }
        EmbeddingMode::SentenceTransformers | EmbeddingMode::Mlx => {
            // sentence-transformers / mlx: try OpenAI, fall back to Ollama.
            // The OpenAI error is intentionally dropped; this path is best-effort.
            match openai::OpenAiEmbedding::new("text-embedding-3-small", None, None, None) {
                Ok(provider) => Ok(Box::new(provider)),
                Err(_) => {
                    let provider = ollama::OllamaEmbedding::new("nomic-embed-text", None);
                    Ok(Box::new(provider))
                }
            }
        }
    }
}