// graphrag_core/embeddings/mod.rs
//! Embedding generation for GraphRAG
//!
//! This module provides embedding generation capabilities using various backends:
//! - Hugging Face Hub models (via hf-hub crate)
//! - Local models (ONNX, Candle)
//! - API providers (OpenAI, Voyage AI, Cohere, etc.)

use crate::core::error::Result;

/// Hugging Face Hub integration for downloading and using embedding models.
/// Compiled only when the `huggingface-hub` feature is enabled.
#[cfg(feature = "huggingface-hub")]
pub mod huggingface;

/// Neural embedding models (local inference).
/// Compiled only when the `neural-embeddings` feature is enabled.
#[cfg(feature = "neural-embeddings")]
pub mod neural;

/// API-based embedding providers (OpenAI, Voyage AI, Cohere, etc.).
/// Gated on the `ureq` feature since these providers need an HTTP client.
#[cfg(feature = "ureq")]
pub mod api_providers;

/// TOML configuration for embedding providers (always available)
pub mod config;
25/// Trait for embedding providers
/// Trait implemented by every embedding backend (Hub models, local
/// inference, HTTP API providers).
///
/// Implementations must be `Send + Sync` so a provider can be shared
/// across async tasks.
#[async_trait::async_trait]
pub trait EmbeddingProvider: Send + Sync {
    /// Initialize the embedding provider (e.g. load a model or validate
    /// credentials — implementation-specific).
    ///
    /// # Errors
    /// Returns an error if the provider cannot be made ready.
    async fn initialize(&mut self) -> Result<()>;

    /// Generate an embedding vector for a single text.
    ///
    /// # Errors
    /// Returns an error if embedding generation fails.
    async fn embed(&self, text: &str) -> Result<Vec<f32>>;

    /// Generate embeddings for multiple texts (batch processing).
    /// The result is expected to be one vector per input text, in order
    /// — NOTE(review): ordering contract inferred from the signature;
    /// confirm against implementations.
    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>>;

    /// Get the dimensionality of the vectors produced by this provider.
    fn dimensions(&self) -> usize;

    /// Check if the provider is available and ready to serve requests.
    fn is_available(&self) -> bool;

    /// Get a human-readable name identifying this provider.
    fn provider_name(&self) -> &str;
}
46
/// Configuration for embedding providers.
///
/// Construct via [`Default`] for sensible defaults, or fill the fields
/// explicitly; see [`EmbeddingProviderType`] for the supported backends.
#[derive(Debug, Clone)]
pub struct EmbeddingConfig {
    /// Which backend to use (huggingface, openai, voyage, cohere, etc.)
    pub provider: EmbeddingProviderType,

    /// Model name/identifier, in whatever form the chosen provider expects
    /// (e.g. a Hub repo id such as "sentence-transformers/all-MiniLM-L6-v2")
    pub model: String,

    /// API key; `None` for providers that need no authentication
    pub api_key: Option<String>,

    /// Cache directory for downloaded models; `None` lets the backend
    /// pick its own default location
    pub cache_dir: Option<String>,

    /// Number of texts to process per batch when embedding many texts
    pub batch_size: usize,
}
65
66impl Default for EmbeddingConfig {
67    fn default() -> Self {
68        Self {
69            provider: EmbeddingProviderType::HuggingFace,
70            model: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
71            api_key: None,
72            cache_dir: None,
73            batch_size: 32,
74        }
75    }
76}
77
/// Supported embedding provider types.
///
/// Derives `Eq` and `Hash` (in addition to `PartialEq`) so provider types
/// can be used directly as `HashMap`/`HashSet` keys and compared with full
/// equivalence semantics; every variant payload (`String`) supports both.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum EmbeddingProviderType {
    /// Hugging Face Hub models (free, downloadable)
    HuggingFace,

    /// OpenAI embeddings API
    OpenAI,

    /// Voyage AI embeddings API (recommended by Anthropic)
    VoyageAI,

    /// Cohere embeddings API
    Cohere,

    /// Jina AI embeddings API
    JinaAI,

    /// Mistral AI embeddings API
    Mistral,

    /// Together AI embeddings API
    TogetherAI,

    /// Local ONNX model
    Onnx,

    /// Local Candle model
    Candle,

    /// Custom provider, identified by name
    Custom(String),
}
111
112impl std::fmt::Display for EmbeddingProviderType {
113    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114        match self {
115            Self::HuggingFace => write!(f, "HuggingFace"),
116            Self::OpenAI => write!(f, "OpenAI"),
117            Self::VoyageAI => write!(f, "VoyageAI"),
118            Self::Cohere => write!(f, "Cohere"),
119            Self::JinaAI => write!(f, "JinaAI"),
120            Self::Mistral => write!(f, "Mistral"),
121            Self::TogetherAI => write!(f, "TogetherAI"),
122            Self::Onnx => write!(f, "ONNX"),
123            Self::Candle => write!(f, "Candle"),
124            Self::Custom(name) => write!(f, "Custom({})", name),
125        }
126    }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration should point at the free Hugging Face
    /// backend with the MiniLM sentence-transformer and a batch size of 32.
    #[test]
    fn test_default_config() {
        let cfg = EmbeddingConfig::default();
        assert_eq!(cfg.batch_size, 32);
        assert_eq!(cfg.model, "sentence-transformers/all-MiniLM-L6-v2");
        assert_eq!(cfg.provider, EmbeddingProviderType::HuggingFace);
    }

    /// `Display` output should be the variant's canonical name.
    #[test]
    fn test_provider_display() {
        let cases = [
            (EmbeddingProviderType::HuggingFace, "HuggingFace"),
            (EmbeddingProviderType::OpenAI, "OpenAI"),
            (EmbeddingProviderType::VoyageAI, "VoyageAI"),
        ];
        for (provider, expected) in cases {
            assert_eq!(provider.to_string(), expected);
        }
    }
}