use crate::error::{Error, Result};
/// Settings that select an embedding backend and tell it how to reach its model.
#[derive(Debug, Clone)]
pub struct EmbedderConfig {
/// Backend name. `build_embedder` accepts `openai` and `ollama`, plus
/// `fastembed` when the crate is compiled with the `fastembed-embed` feature.
pub provider: String,
/// Model identifier. The REST backends send it as the `model` field of the
/// request body; the fastembed backend matches it against a fixed list of names.
pub model: String,
/// Name of the environment variable that holds the API key. Only the OpenAI
/// path reads it, falling back to `OPENAI_API_KEY` when this is `None`.
pub api_key_env: Option<String>,
/// Root URL of the REST service. When `None`, the OpenAI path uses
/// `https://api.openai.com` and the Ollama path uses `http://localhost:11434`.
pub base_url: Option<String>,
/// Cache directory handed to fastembed. When `None`, a `fastembed` directory
/// under the system temp dir is used.
pub cache_dir: Option<std::path::PathBuf>,
}
/// Common interface over the embedding backends in this file.
///
/// The `Send + Sync` supertraits let a boxed embedder be shared across threads.
pub trait Embedder: Send + Sync {
/// Computes embedding vectors for `texts`.
///
/// The implementations in this file return an empty list, without touching
/// the backend, when `texts` is empty.
fn embed_texts(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>>;
}
pub fn build_embedder(config: &EmbedderConfig) -> Result<Box<dyn Embedder + Send + Sync>> {
match config.provider.as_str() {
"openai" | "ollama" => Ok(Box::new(RestEmbedder {
config: config.clone(),
})),
#[cfg(feature = "fastembed-embed")]
"fastembed" => Ok(Box::new(FastEmbedEmbedder::new(config)?)),
other => Err(Error::Embed(format!(
"unknown embedding provider `{other}`; supported values: openai, ollama{}",
if cfg!(feature = "fastembed-embed") {
", fastembed"
} else {
""
}
))),
}
}
/// Embedder that delegates to a remote HTTP embedding API.
///
/// The provider stored in `config` decides which endpoint is called each time
/// `embed_texts` runs.
struct RestEmbedder {
/// Configuration cloned from the caller when the embedder is built.
config: EmbedderConfig,
}
impl Embedder for RestEmbedder {
    /// Routes the request to the OpenAI or Ollama helper based on the
    /// configured provider.
    ///
    /// Empty input short-circuits to an empty result before the provider is
    /// inspected, so no request is made.
    ///
    /// # Errors
    ///
    /// Returns `Error::Embed` for a provider that is neither `openai` nor
    /// `ollama`, and propagates any error from the selected helper.
    fn embed_texts(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        let cfg = &self.config;
        match cfg.provider.as_str() {
            // Nothing to embed: answer immediately, whatever the provider is.
            _ if texts.is_empty() => Ok(Vec::new()),
            "openai" => embed_openai(cfg, texts),
            "ollama" => embed_ollama(cfg, texts),
            other => Err(Error::Embed(format!("unknown REST provider `{other}`"))),
        }
    }
}
/// Embedder backed by a `fastembed::TextEmbedding` model.
///
/// Only compiled when the `fastembed-embed` feature is enabled.
#[cfg(feature = "fastembed-embed")]
struct FastEmbedEmbedder {
/// The loaded model, wrapped in a mutex that `embed_texts` locks for each call.
model: std::sync::Mutex<fastembed::TextEmbedding>,
}
#[cfg(feature = "fastembed-embed")]
impl FastEmbedEmbedder {
    /// Translates a configured model name into the matching
    /// `fastembed::EmbeddingModel` variant.
    ///
    /// # Errors
    ///
    /// Returns `Error::Embed` listing the supported names when `name` is not
    /// one of them.
    fn model_variant(name: &str) -> Result<fastembed::EmbeddingModel> {
        use fastembed::EmbeddingModel;

        let variant = match name {
            "AllMiniLML6V2" => EmbeddingModel::AllMiniLML6V2,
            "BGESmallENV15" => EmbeddingModel::BGESmallENV15,
            "BGEBaseENV15" => EmbeddingModel::BGEBaseENV15,
            "BGELargeENV15" => EmbeddingModel::BGELargeENV15,
            "NomicEmbedTextV1" => EmbeddingModel::NomicEmbedTextV1,
            "NomicEmbedTextV15" => EmbeddingModel::NomicEmbedTextV15,
            "MultilingualE5Small" => EmbeddingModel::MultilingualE5Small,
            "MultilingualE5Base" => EmbeddingModel::MultilingualE5Base,
            "MultilingualE5Large" => EmbeddingModel::MultilingualE5Large,
            other => {
                return Err(Error::Embed(format!(
                    "unknown fastembed model `{other}`; \
                     supported: AllMiniLML6V2, BGESmallENV15, BGEBaseENV15, BGELargeENV15, \
                     NomicEmbedTextV1, NomicEmbedTextV15, \
                     MultilingualE5Small, MultilingualE5Base, MultilingualE5Large"
                )));
            }
        };
        Ok(variant)
    }

    /// Loads the fastembed model named by `config.model`.
    ///
    /// The model cache lives in `config.cache_dir` when set, otherwise in a
    /// `fastembed` directory under the system temp dir. Download progress
    /// display is switched on.
    ///
    /// # Errors
    ///
    /// Returns `Error::Embed` when the model name is unknown or when fastembed
    /// fails to initialise the model.
    fn new(config: &EmbedderConfig) -> Result<Self> {
        use fastembed::{InitOptions, TextEmbedding};

        let variant = Self::model_variant(config.model.as_str())?;

        let cache_dir = match &config.cache_dir {
            Some(dir) => dir.clone(),
            None => std::env::temp_dir().join("fastembed"),
        };

        let options = InitOptions::new(variant)
            .with_cache_dir(cache_dir)
            .with_show_download_progress(true);

        match TextEmbedding::try_new(options) {
            Ok(engine) => Ok(Self {
                model: std::sync::Mutex::new(engine),
            }),
            Err(e) => Err(Error::Embed(format!(
                "failed to load fastembed model: {e}"
            ))),
        }
    }
}
#[cfg(feature = "fastembed-embed")]
impl Embedder for FastEmbedEmbedder {
    /// Embeds `texts` with the locally loaded fastembed model.
    ///
    /// Empty input returns an empty result without locking the model.
    ///
    /// # Errors
    ///
    /// Returns `Error::Embed` when the model mutex is poisoned (a previous
    /// holder panicked) or when fastembed reports an embedding failure.
    fn embed_texts(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }
        let texts_owned: Vec<String> = texts.iter().map(|s| (*s).to_owned()).collect();
        self.model
            .lock()
            // A poisoned lock is reported as an ordinary error instead of
            // panicking inside library code.
            .map_err(|_| Error::Embed("fastembed model lock is poisoned".into()))?
            .embed(texts_owned, None)
            .map_err(|e| Error::Embed(format!("fastembed embedding failed: {e}")))
    }
}
/// Requests embeddings for `texts` from an OpenAI-style `/v1/embeddings` endpoint.
///
/// The API key is read from the environment variable named by
/// `config.api_key_env` (default `OPENAI_API_KEY`). The request goes to
/// `{base_url}/v1/embeddings`, where `base_url` defaults to
/// `https://api.openai.com`; a configured base URL must therefore not already
/// include the `/v1` segment. Trailing slashes on the base URL are ignored.
///
/// # Errors
///
/// Returns `Error::Embed` when the key variable is not set, when the request or
/// JSON decoding fails, or when the response cannot be parsed.
fn embed_openai(config: &EmbedderConfig, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
    let api_key_env = config.api_key_env.as_deref().unwrap_or("OPENAI_API_KEY");
    let api_key = std::env::var(api_key_env)
        .map_err(|_| Error::Embed(format!("environment variable `{api_key_env}` is not set")))?;
    // Strip trailing slashes so a base URL such as `https://host/` does not
    // turn into `https://host//v1/embeddings`.
    let base_url = config
        .base_url
        .as_deref()
        .unwrap_or("https://api.openai.com")
        .trim_end_matches('/');
    let url = format!("{base_url}/v1/embeddings");
    let body = serde_json::json!({
        "model": config.model,
        "input": texts,
    });
    let response: serde_json::Value = ureq::post(&url)
        .header("Authorization", &format!("Bearer {api_key}"))
        .header("Content-Type", "application/json")
        .send_json(body)
        .map_err(|e| Error::Embed(e.to_string()))?
        .into_body()
        .read_json()
        .map_err(|e| Error::Embed(e.to_string()))?;
    parse_openai_response(&response, texts.len())
}
fn parse_openai_response(response: &serde_json::Value, expected: usize) -> Result<Vec<Vec<f32>>> {
let data = response["data"]
.as_array()
.ok_or_else(|| Error::Embed("unexpected OpenAI response: missing `data` array".into()))?;
let mut results = vec![Vec::new(); expected];
for item in data {
let index = item["index"]
.as_u64()
.ok_or_else(|| Error::Embed("missing `index` in embedding object".into()))?
as usize;
let vec = parse_float_array(&item["embedding"])?;
if index < results.len() {
results[index] = vec;
}
}
Ok(results)
}
fn embed_ollama(config: &EmbedderConfig, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
let base_url = config
.base_url
.as_deref()
.unwrap_or("http://localhost:11434");
let url = format!("{base_url}/api/embed");
let body = serde_json::json!({
"model": config.model,
"input": texts,
});
let response: serde_json::Value = ureq::post(&url)
.header("Content-Type", "application/json")
.send_json(body)
.map_err(|e| Error::Embed(e.to_string()))?
.into_body()
.read_json()
.map_err(|e| Error::Embed(e.to_string()))?;
response["embeddings"]
.as_array()
.ok_or_else(|| {
Error::Embed("unexpected Ollama response: missing `embeddings` array".into())
})?
.iter()
.map(parse_float_array)
.collect()
}
fn parse_float_array(value: &serde_json::Value) -> Result<Vec<f32>> {
value
.as_array()
.ok_or_else(|| Error::Embed("embedding value is not a JSON array".into()))?
.iter()
.map(|v| {
v.as_f64()
.map(|f| f as f32)
.ok_or_else(|| Error::Embed("non-numeric value in embedding vector".into()))
})
.collect()
}