graphrag_core/embeddings/
mod.rs1use crate::core::error::Result;
9
10#[cfg(feature = "huggingface-hub")]
12pub mod huggingface;
13
14#[cfg(feature = "ureq")]
16pub mod api_providers;
17
18#[cfg(feature = "ollama")]
20pub mod ollama;
21
22pub mod config;
24
/// Common interface implemented by every embedding backend (HuggingFace,
/// hosted APIs, Ollama, …). Implementations must be `Send + Sync` so a
/// provider can be shared across async tasks.
#[async_trait::async_trait]
pub trait EmbeddingProvider: Send + Sync {
    /// Perform one-time setup (e.g. load a model or validate credentials).
    /// Must be called before `embed`/`embed_batch`.
    async fn initialize(&mut self) -> Result<()>;

    /// Embed a single text, returning a vector of `dimensions()` floats.
    async fn embed(&self, text: &str) -> Result<Vec<f32>>;

    /// Embed many texts in one call; returns one vector per input text,
    /// in the same order. Typically cheaper than repeated `embed` calls.
    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>>;

    /// Length of the vectors produced by this provider.
    fn dimensions(&self) -> usize;

    /// Whether the provider is ready to serve requests
    /// (e.g. model loaded, endpoint reachable).
    fn is_available(&self) -> bool;

    /// Short human-readable name identifying the backend.
    fn provider_name(&self) -> &str;
}
46
/// Configuration for constructing an embedding provider.
#[derive(Debug, Clone)]
pub struct EmbeddingConfig {
    /// Which backend to use.
    pub provider: EmbeddingProviderType,

    /// Model identifier understood by the chosen provider
    /// (e.g. a HuggingFace repo id or an API model name).
    pub model: String,

    /// API key for hosted providers; `None` for local backends.
    pub api_key: Option<String>,

    /// Optional directory for cached model files; `None` uses the
    /// provider's default location.
    pub cache_dir: Option<String>,

    /// Number of texts sent per `embed_batch` request.
    pub batch_size: usize,
}
65
66impl Default for EmbeddingConfig {
67 fn default() -> Self {
68 Self {
69 provider: EmbeddingProviderType::HuggingFace,
70 model: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
71 api_key: None,
72 cache_dir: None,
73 batch_size: 32,
74 }
75 }
76}
77
/// Supported embedding backends.
///
/// Derives `Eq` and `Hash` in addition to `PartialEq` (every payload is a
/// `String`, which supports both) so the type can be used as a
/// `HashMap`/`HashSet` key and satisfies clippy's
/// `derive_partial_eq_without_eq` lint.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum EmbeddingProviderType {
    /// Local models downloaded from the HuggingFace Hub.
    HuggingFace,

    /// OpenAI embeddings API.
    OpenAI,

    /// Voyage AI embeddings API.
    VoyageAI,

    /// Cohere embeddings API.
    Cohere,

    /// Jina AI embeddings API.
    JinaAI,

    /// Mistral embeddings API.
    Mistral,

    /// Together AI embeddings API.
    TogetherAI,

    /// Local inference via ONNX Runtime.
    Onnx,

    /// Local inference via the Candle framework.
    Candle,

    /// Local inference via an Ollama server.
    Ollama,

    /// User-supplied backend identified by name.
    Custom(String),
}
114
115impl std::fmt::Display for EmbeddingProviderType {
116 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
117 match self {
118 Self::HuggingFace => write!(f, "HuggingFace"),
119 Self::OpenAI => write!(f, "OpenAI"),
120 Self::VoyageAI => write!(f, "VoyageAI"),
121 Self::Cohere => write!(f, "Cohere"),
122 Self::JinaAI => write!(f, "JinaAI"),
123 Self::Mistral => write!(f, "Mistral"),
124 Self::TogetherAI => write!(f, "TogetherAI"),
125 Self::Onnx => write!(f, "ONNX"),
126 Self::Candle => write!(f, "Candle"),
127 Self::Ollama => write!(f, "Ollama"),
128 Self::Custom(name) => write!(f, "Custom({})", name),
129 }
130 }
131}
132
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config targets the local MiniLM sentence-transformer
    /// with no credentials and no cache override.
    #[test]
    fn test_default_config() {
        let config = EmbeddingConfig::default();
        assert_eq!(config.provider, EmbeddingProviderType::HuggingFace);
        assert_eq!(config.model, "sentence-transformers/all-MiniLM-L6-v2");
        assert_eq!(config.batch_size, 32);
        // Optional fields must default to unset.
        assert!(config.api_key.is_none());
        assert!(config.cache_dir.is_none());
    }

    /// Every variant renders its expected label — including the
    /// payload-carrying `Custom` variant, previously untested.
    #[test]
    fn test_provider_display() {
        let cases = [
            (EmbeddingProviderType::HuggingFace, "HuggingFace"),
            (EmbeddingProviderType::OpenAI, "OpenAI"),
            (EmbeddingProviderType::VoyageAI, "VoyageAI"),
            (EmbeddingProviderType::Cohere, "Cohere"),
            (EmbeddingProviderType::JinaAI, "JinaAI"),
            (EmbeddingProviderType::Mistral, "Mistral"),
            (EmbeddingProviderType::TogetherAI, "TogetherAI"),
            (EmbeddingProviderType::Onnx, "ONNX"),
            (EmbeddingProviderType::Candle, "Candle"),
            (EmbeddingProviderType::Ollama, "Ollama"),
        ];
        for (provider, expected) in cases {
            assert_eq!(provider.to_string(), expected);
        }
        assert_eq!(
            EmbeddingProviderType::Custom("my-provider".to_string()).to_string(),
            "Custom(my-provider)"
        );
    }
}