1mod openai;
2
3#[cfg(feature = "fastembed")]
4mod fastembed_provider;
5
6pub use openai::OpenAICompatibleEmbedding;
7
8#[cfg(feature = "fastembed")]
9pub use fastembed_provider::FastEmbedClient;
10
11use std::sync::Arc;
12
13use async_trait::async_trait;
14use mem7_config::EmbeddingConfig;
15use mem7_error::{Mem7Error, Result};
16
17#[async_trait]
18pub trait EmbeddingClient: Send + Sync {
19 async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
20}
21
22pub fn create_embedding(config: &EmbeddingConfig) -> Result<Arc<dyn EmbeddingClient>> {
25 match config.provider.as_str() {
26 "openai" | "ollama" | "vllm" | "lmstudio" | "deepseek" => {
27 Ok(Arc::new(OpenAICompatibleEmbedding::new(config.clone())))
28 }
29 #[cfg(feature = "fastembed")]
30 "fastembed" => Ok(Arc::new(FastEmbedClient::new(
31 &config.model,
32 config.cache_dir.as_deref(),
33 )?)),
34 #[cfg(not(feature = "fastembed"))]
35 "fastembed" => Err(Mem7Error::Config(
36 "fastembed provider requires the `fastembed` feature to be enabled".into(),
37 )),
38 other => Err(Mem7Error::Config(format!(
39 "unknown embedding provider: {other}"
40 ))),
41 }
42}