// leann_core/embedding/mod.rs

pub mod manager;
pub mod onnx;

#[cfg(feature = "embedding-remote")]
pub mod gemini;
#[cfg(feature = "embedding-remote")]
pub mod ollama;
#[cfg(feature = "embedding-remote")]
pub mod openai;

use std::collections::HashMap;

use anyhow::Result;
use ndarray::Array2;

15/// Trait for embedding computation backends.
16pub trait EmbeddingProvider: Send + Sync {
17    /// Compute embeddings for a batch of text chunks.
18    fn compute_embeddings(&self, chunks: &[String]) -> Result<Array2<f32>>;
19
20    /// Get the dimensionality of the embeddings.
21    fn dimensions(&self) -> usize;
22
23    /// Get the provider name.
24    fn name(&self) -> &str;
25}
26
/// Embedding backend selector, matching the embedding modes of the Python implementation.
///
/// Parse user/config input with [`EmbeddingMode::from_str_lossy`] and obtain the
/// canonical spelling back with [`EmbeddingMode::as_str`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum EmbeddingMode {
    /// `"sentence-transformers"`; also the fallback for unrecognised names.
    SentenceTransformers,
    /// `"openai"`.
    OpenAI,
    /// `"ollama"`.
    Ollama,
    /// `"gemini"`.
    Gemini,
    /// `"mlx"`.
    Mlx,
}

impl EmbeddingMode {
    /// Every variant; parsing iterates this so it stays in sync with `as_str`.
    const ALL: [EmbeddingMode; 5] = [
        EmbeddingMode::SentenceTransformers,
        EmbeddingMode::OpenAI,
        EmbeddingMode::Ollama,
        EmbeddingMode::Gemini,
        EmbeddingMode::Mlx,
    ];

    /// Parse a mode name without ever failing.
    ///
    /// Surrounding whitespace is ignored and matching is ASCII case-insensitive
    /// (no allocation). Any name that does not match a variant's canonical
    /// spelling yields [`EmbeddingMode::SentenceTransformers`].
    pub fn from_str_lossy(s: &str) -> Self {
        let wanted = s.trim();
        Self::ALL
            .iter()
            .find(|mode| mode.as_str().eq_ignore_ascii_case(wanted))
            .cloned()
            .unwrap_or(EmbeddingMode::SentenceTransformers)
    }

    /// Canonical lowercase name of this mode.
    ///
    /// The returned string is `'static`, so it may outlive the `EmbeddingMode`
    /// it was obtained from. Feeding it back into
    /// [`EmbeddingMode::from_str_lossy`] yields the same variant.
    pub fn as_str(&self) -> &'static str {
        match self {
            EmbeddingMode::SentenceTransformers => "sentence-transformers",
            EmbeddingMode::OpenAI => "openai",
            EmbeddingMode::Ollama => "ollama",
            EmbeddingMode::Gemini => "gemini",
            EmbeddingMode::Mlx => "mlx",
        }
    }
}

/// Create an embedding provider from mode, model name, and options map.
///
/// The `options` map may contain provider-specific keys. Missing or non-string
/// values are treated as absent:
/// - `"host"` — Ollama host override
/// - `"api_key"` — API key for OpenAI/Gemini
/// - `"base_url"` — Base URL for OpenAI-compatible services
///
/// For [`EmbeddingMode::SentenceTransformers`] and [`EmbeddingMode::Mlx`], `model`
/// is ignored and a remote fallback is used instead. It is OpenAI
/// `text-embedding-3-small` if that provider can be constructed, otherwise Ollama
/// `nomic-embed-text`. The fallback honours the same `api_key` / `base_url` /
/// `host` options as the explicit modes.
///
/// # Errors
///
/// Returns an error if the OpenAI or Gemini provider cannot be constructed for
/// the corresponding explicit mode. The fallback path never returns an error.
#[cfg(feature = "embedding-remote")]
pub fn create_embedding_provider(
    mode: &EmbeddingMode,
    model: &str,
    options: &HashMap<String, serde_json::Value>,
) -> Result<Box<dyn EmbeddingProvider>> {
    // Remote models used when the requested mode has no backend constructed here.
    const FALLBACK_OPENAI_MODEL: &str = "text-embedding-3-small";
    const FALLBACK_OLLAMA_MODEL: &str = "nomic-embed-text";

    // Look up a string-valued option; missing or non-string values read as `None`.
    fn string_option<'a>(
        options: &'a HashMap<String, serde_json::Value>,
        key: &str,
    ) -> Option<&'a str> {
        options.get(key).and_then(|v| v.as_str())
    }

    let api_key = string_option(options, "api_key");
    let base_url = string_option(options, "base_url");
    let host = string_option(options, "host");

    // Every variant is listed explicitly, so adding a new mode forces a decision here.
    match mode {
        EmbeddingMode::OpenAI => {
            let provider = openai::OpenAiEmbedding::new(model, api_key, base_url, None)?;
            Ok(Box::new(provider))
        }
        EmbeddingMode::Ollama => {
            let provider = ollama::OllamaEmbedding::new(model, host);
            Ok(Box::new(provider))
        }
        EmbeddingMode::Gemini => {
            let provider = gemini::GeminiEmbedding::new(model, api_key)?;
            Ok(Box::new(provider))
        }
        EmbeddingMode::SentenceTransformers | EmbeddingMode::Mlx => {
            // Best-effort: the OpenAI construction error is deliberately
            // discarded so that we can fall back to Ollama.
            match openai::OpenAiEmbedding::new(FALLBACK_OPENAI_MODEL, api_key, base_url, None) {
                Ok(provider) => Ok(Box::new(provider)),
                Err(_) => {
                    let provider = ollama::OllamaEmbedding::new(FALLBACK_OLLAMA_MODEL, host);
                    Ok(Box::new(provider))
                }
            }
        }
    }
}