graphrag_core/embeddings/
mod.rs1use crate::core::error::Result;
9
10#[cfg(feature = "huggingface-hub")]
12pub mod huggingface;
13
14#[cfg(feature = "neural-embeddings")]
16pub mod neural;
17
18#[cfg(feature = "ureq")]
20pub mod api_providers;
21
22pub mod config;
24
25#[async_trait::async_trait]
27pub trait EmbeddingProvider: Send + Sync {
28 async fn initialize(&mut self) -> Result<()>;
30
31 async fn embed(&self, text: &str) -> Result<Vec<f32>>;
33
34 async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>>;
36
37 fn dimensions(&self) -> usize;
39
40 fn is_available(&self) -> bool;
42
43 fn provider_name(&self) -> &str;
45}
46
47#[derive(Debug, Clone)]
49pub struct EmbeddingConfig {
50 pub provider: EmbeddingProviderType,
52
53 pub model: String,
55
56 pub api_key: Option<String>,
58
59 pub cache_dir: Option<String>,
61
62 pub batch_size: usize,
64}
65
66impl Default for EmbeddingConfig {
67 fn default() -> Self {
68 Self {
69 provider: EmbeddingProviderType::HuggingFace,
70 model: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
71 api_key: None,
72 cache_dir: None,
73 batch_size: 32,
74 }
75 }
76}
77
/// Identifies which embedding backend a configuration refers to.
///
/// Derives `Eq` and `Hash` in addition to `PartialEq`, so values can be used
/// as keys in `HashMap` / `HashSet`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum EmbeddingProviderType {
    /// Hugging Face provider (the default in `EmbeddingConfig`).
    HuggingFace,

    /// OpenAI provider.
    OpenAI,

    /// Voyage AI provider.
    VoyageAI,

    /// Cohere provider.
    Cohere,

    /// Jina AI provider.
    JinaAI,

    /// Mistral provider.
    Mistral,

    /// Together AI provider.
    TogetherAI,

    /// ONNX provider (displayed as `ONNX`).
    Onnx,

    /// Candle provider.
    Candle,

    /// Any other provider, identified by a caller-supplied name.
    Custom(String),
}
111
112impl std::fmt::Display for EmbeddingProviderType {
113 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114 match self {
115 Self::HuggingFace => write!(f, "HuggingFace"),
116 Self::OpenAI => write!(f, "OpenAI"),
117 Self::VoyageAI => write!(f, "VoyageAI"),
118 Self::Cohere => write!(f, "Cohere"),
119 Self::JinaAI => write!(f, "JinaAI"),
120 Self::Mistral => write!(f, "Mistral"),
121 Self::TogetherAI => write!(f, "TogetherAI"),
122 Self::Onnx => write!(f, "ONNX"),
123 Self::Candle => write!(f, "Candle"),
124 Self::Custom(name) => write!(f, "Custom({})", name),
125 }
126 }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration selects the Hugging Face MiniLM model with a
    /// batch size of 32 and leaves the optional fields unset.
    #[test]
    fn test_default_config() {
        let config = EmbeddingConfig::default();
        assert_eq!(config.provider, EmbeddingProviderType::HuggingFace);
        assert_eq!(config.model, "sentence-transformers/all-MiniLM-L6-v2");
        assert_eq!(config.batch_size, 32);
        assert!(config.api_key.is_none());
        assert!(config.cache_dir.is_none());
    }

    /// Every variant renders its expected label, including the irregular
    /// `ONNX` spelling and the `Custom(..)` wrapper.
    #[test]
    fn test_provider_display() {
        let expected = [
            (EmbeddingProviderType::HuggingFace, "HuggingFace"),
            (EmbeddingProviderType::OpenAI, "OpenAI"),
            (EmbeddingProviderType::VoyageAI, "VoyageAI"),
            (EmbeddingProviderType::Cohere, "Cohere"),
            (EmbeddingProviderType::JinaAI, "JinaAI"),
            (EmbeddingProviderType::Mistral, "Mistral"),
            (EmbeddingProviderType::TogetherAI, "TogetherAI"),
            (EmbeddingProviderType::Onnx, "ONNX"),
            (EmbeddingProviderType::Candle, "Candle"),
            (
                EmbeddingProviderType::Custom("local".to_string()),
                "Custom(local)",
            ),
        ];

        for (provider, label) in &expected {
            assert_eq!(provider.to_string(), *label);
        }
    }
}