// graphrag_core/embeddings/mod.rs
1//! Embedding generation for GraphRAG
2//!
3//! This module provides embedding generation capabilities using various backends:
4//! - Hugging Face Hub models (via hf-hub crate)
5//! - Local models (ONNX, Candle)
6//! - API providers (OpenAI, Voyage AI, Cohere, etc.)
7
8use crate::core::error::Result;
9
10/// Hugging Face Hub integration for downloading and using embedding models
11#[cfg(feature = "huggingface-hub")]
12pub mod huggingface;
13
14/// API-based embedding providers (OpenAI, Voyage AI, Cohere, etc.)
15#[cfg(feature = "ureq")]
16pub mod api_providers;
17
18/// Ollama embedding provider
19#[cfg(feature = "ollama")]
20pub mod ollama;
21
22/// TOML configuration for embedding providers
23pub mod config;
24
/// Common interface implemented by every embedding backend.
///
/// Implementations must be `Send + Sync` so a provider can be shared
/// across threads and async tasks. Async methods are enabled via the
/// `async_trait` attribute macro.
#[async_trait::async_trait]
pub trait EmbeddingProvider: Send + Sync {
    /// Initialize the embedding provider (e.g. load a model or set up a
    /// client) before any embedding calls are made.
    ///
    /// # Errors
    /// Returns an error if the provider cannot be brought into a usable state.
    async fn initialize(&mut self) -> Result<()>;

    /// Generate the embedding vector for a single text.
    ///
    /// # Errors
    /// Returns an error if embedding generation fails.
    async fn embed(&self, text: &str) -> Result<Vec<f32>>;

    /// Generate embeddings for multiple texts in one call (batch processing);
    /// the output vector is expected to correspond to `texts` element-wise.
    ///
    /// # Errors
    /// Returns an error if embedding generation fails.
    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>>;

    /// Length of the vectors produced by [`embed`](Self::embed) and
    /// [`embed_batch`](Self::embed_batch).
    fn dimensions(&self) -> usize;

    /// Check whether the provider is available and ready to serve requests.
    fn is_available(&self) -> bool;

    /// Human-readable name identifying this provider.
    fn provider_name(&self) -> &str;
}
46
47/// Configuration for embedding providers
48#[derive(Debug, Clone)]
49pub struct EmbeddingConfig {
50    /// Provider type (huggingface, openai, voyage, cohere, etc.)
51    pub provider: EmbeddingProviderType,
52
53    /// Model name/identifier
54    pub model: String,
55
56    /// API key (if required)
57    pub api_key: Option<String>,
58
59    /// Cache directory for downloaded models
60    pub cache_dir: Option<String>,
61
62    /// Batch size for processing multiple texts
63    pub batch_size: usize,
64}
65
66impl Default for EmbeddingConfig {
67    fn default() -> Self {
68        Self {
69            provider: EmbeddingProviderType::HuggingFace,
70            model: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
71            api_key: None,
72            cache_dir: None,
73            batch_size: 32,
74        }
75    }
76}
77
/// Supported embedding provider types.
///
/// `Eq` and `Hash` are derived in addition to `PartialEq` so the type can be
/// used as a `HashMap`/`HashSet` key (the `Custom(String)` payload supports
/// both).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum EmbeddingProviderType {
    /// Hugging Face Hub models (free, downloadable)
    HuggingFace,

    /// OpenAI embeddings API
    OpenAI,

    /// Voyage AI embeddings API (recommended by Anthropic)
    VoyageAI,

    /// Cohere embeddings API
    Cohere,

    /// Jina AI embeddings API
    JinaAI,

    /// Mistral AI embeddings API
    Mistral,

    /// Together AI embeddings API
    TogetherAI,

    /// Local ONNX model
    Onnx,

    /// Local Candle model
    Candle,

    /// Local Ollama model
    Ollama,

    /// Custom provider identified by name
    Custom(String),
}
114
115impl std::fmt::Display for EmbeddingProviderType {
116    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
117        match self {
118            Self::HuggingFace => write!(f, "HuggingFace"),
119            Self::OpenAI => write!(f, "OpenAI"),
120            Self::VoyageAI => write!(f, "VoyageAI"),
121            Self::Cohere => write!(f, "Cohere"),
122            Self::JinaAI => write!(f, "JinaAI"),
123            Self::Mistral => write!(f, "Mistral"),
124            Self::TogetherAI => write!(f, "TogetherAI"),
125            Self::Onnx => write!(f, "ONNX"),
126            Self::Candle => write!(f, "Candle"),
127            Self::Ollama => write!(f, "Ollama"),
128            Self::Custom(name) => write!(f, "Custom({})", name),
129        }
130    }
131}
132
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config points at the free HuggingFace backend with the
    /// MiniLM sentence-transformer and a batch size of 32.
    #[test]
    fn test_default_config() {
        let cfg = EmbeddingConfig::default();
        assert!(matches!(cfg.provider, EmbeddingProviderType::HuggingFace));
        assert_eq!(cfg.model.as_str(), "sentence-transformers/all-MiniLM-L6-v2");
        assert_eq!(cfg.batch_size, 32);
    }

    /// `Display` output matches the expected label for each variant.
    #[test]
    fn test_provider_display() {
        let cases = [
            (EmbeddingProviderType::HuggingFace, "HuggingFace"),
            (EmbeddingProviderType::OpenAI, "OpenAI"),
            (EmbeddingProviderType::VoyageAI, "VoyageAI"),
        ];
        for (provider, expected) in cases {
            assert_eq!(provider.to_string(), expected);
        }
    }
}