use crate::core::error::Result;
#[cfg(feature = "huggingface-hub")]
pub mod huggingface;
#[cfg(feature = "ureq")]
pub mod api_providers;
#[cfg(feature = "ollama")]
pub mod ollama;
pub mod config;
/// Common interface for a backend that turns text into embedding vectors.
///
/// The `Send + Sync` supertraits require every implementor to be safe to move
/// to, and share between, threads. The `async_trait` attribute is what allows
/// the `async fn` methods below to be declared inside a trait.
#[async_trait::async_trait]
pub trait EmbeddingProvider: Send + Sync {
/// Performs provider setup; takes `&mut self`, so it may mutate the provider.
///
/// NOTE(review): presumably must complete successfully before `embed` /
/// `embed_batch` are used — confirm against the implementations.
///
/// # Errors
/// Returns the crate-level error type when setup fails.
async fn initialize(&mut self) -> Result<()>;
/// Produces the embedding vector for a single `text`.
///
/// # Errors
/// Returns the crate-level error type when the embedding cannot be produced.
async fn embed(&self, text: &str) -> Result<Vec<f32>>;
/// Produces one embedding vector per entry of `texts`.
///
/// NOTE(review): presumably the output order matches the input order —
/// confirm against the implementations.
///
/// # Errors
/// Returns the crate-level error type when the batch cannot be embedded.
async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>>;
/// Reports the embedding dimensionality.
///
/// NOTE(review): presumably equals the length of the vectors returned by
/// `embed` — confirm against the implementations.
fn dimensions(&self) -> usize;
/// Reports whether the provider can currently be used.
fn is_available(&self) -> bool;
/// Returns a name identifying this provider.
fn provider_name(&self) -> &str;
}
/// Settings used to select and configure an embedding provider.
///
/// `Debug` is implemented by hand (instead of derived) so that the value of
/// [`EmbeddingConfig::api_key`] is never written to logs or panic messages.
#[derive(Clone)]
pub struct EmbeddingConfig {
    /// Which embedding backend to use.
    pub provider: EmbeddingProviderType,
    /// Model identifier passed to the selected provider.
    pub model: String,
    /// Optional API credential; `None` when no key is configured.
    pub api_key: Option<String>,
    /// Optional cache directory; `None` when no directory is configured.
    pub cache_dir: Option<String>,
    /// Batch size setting.
    ///
    /// NOTE(review): presumably the maximum number of texts sent per
    /// `embed_batch` request — confirm against the providers.
    pub batch_size: usize,
}

impl std::fmt::Debug for EmbeddingConfig {
    /// Formats the configuration like a derived `Debug` would, except that a
    /// present API key is rendered as `Some("<redacted>")`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep the `Some`/`None` shape visible so it is still clear whether a
        // key was configured, without revealing the secret itself.
        let masked_key = self.api_key.as_ref().map(|_| "<redacted>");
        f.debug_struct("EmbeddingConfig")
            .field("provider", &self.provider)
            .field("model", &self.model)
            .field("api_key", &masked_key)
            .field("cache_dir", &self.cache_dir)
            .field("batch_size", &self.batch_size)
            .finish()
    }
}
impl Default for EmbeddingConfig {
fn default() -> Self {
Self {
provider: EmbeddingProviderType::HuggingFace,
model: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
api_key: None,
cache_dir: None,
batch_size: 32,
}
}
}
/// Identifies which embedding backend a configuration refers to.
///
/// Derives `Eq` and `Hash` in addition to `PartialEq` so values can be used
/// as keys in `HashMap`/`HashSet` (for example, a registry of providers).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum EmbeddingProviderType {
    /// The HuggingFace provider.
    HuggingFace,
    /// The OpenAI provider.
    OpenAI,
    /// The VoyageAI provider.
    VoyageAI,
    /// The Cohere provider.
    Cohere,
    /// The JinaAI provider.
    JinaAI,
    /// The Mistral provider.
    Mistral,
    /// The TogetherAI provider.
    TogetherAI,
    /// The ONNX provider.
    Onnx,
    /// The Candle provider.
    Candle,
    /// The Ollama provider.
    Ollama,
    /// A provider not covered by the built-in variants, identified by name.
    Custom(String),
}
impl std::fmt::Display for EmbeddingProviderType {
    /// Writes the human-readable provider name.
    ///
    /// Built-in variants print a fixed label (note that `Onnx` prints as
    /// `ONNX`); `Custom(name)` prints as `Custom(<name>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            // The only variant carrying data; format it and finish early.
            Self::Custom(name) => return write!(f, "Custom({})", name),
            Self::HuggingFace => "HuggingFace",
            Self::OpenAI => "OpenAI",
            Self::VoyageAI => "VoyageAI",
            Self::Cohere => "Cohere",
            Self::JinaAI => "JinaAI",
            Self::Mistral => "Mistral",
            Self::TogetherAI => "TogetherAI",
            Self::Onnx => "ONNX",
            Self::Candle => "Candle",
            Self::Ollama => "Ollama",
        };
        f.write_str(label)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration must expose every documented default,
    /// including the two optional fields being unset.
    #[test]
    fn test_default_config() {
        let config = EmbeddingConfig::default();
        assert_eq!(config.provider, EmbeddingProviderType::HuggingFace);
        assert_eq!(config.model, "sentence-transformers/all-MiniLM-L6-v2");
        assert_eq!(config.batch_size, 32);
        assert!(config.api_key.is_none());
        assert!(config.cache_dir.is_none());
    }

    /// Every variant is checked, because the label is not always the variant
    /// name (`Onnx` prints as `ONNX`) and `Custom` embeds its payload.
    #[test]
    fn test_provider_display() {
        let cases = [
            (EmbeddingProviderType::HuggingFace, "HuggingFace"),
            (EmbeddingProviderType::OpenAI, "OpenAI"),
            (EmbeddingProviderType::VoyageAI, "VoyageAI"),
            (EmbeddingProviderType::Cohere, "Cohere"),
            (EmbeddingProviderType::JinaAI, "JinaAI"),
            (EmbeddingProviderType::Mistral, "Mistral"),
            (EmbeddingProviderType::TogetherAI, "TogetherAI"),
            (EmbeddingProviderType::Onnx, "ONNX"),
            (EmbeddingProviderType::Candle, "Candle"),
            (EmbeddingProviderType::Ollama, "Ollama"),
            (
                EmbeddingProviderType::Custom("my-provider".to_string()),
                "Custom(my-provider)",
            ),
        ];
        for (provider, expected) in &cases {
            assert_eq!(provider.to_string(), *expected);
        }
    }
}