use std::path::PathBuf;
use std::sync::Arc;
use async_trait::async_trait;
use fastembed::{
EmbeddingModel, InitOptions, InitOptionsUserDefined, Pooling, QuantizationMode, TextEmbedding,
TokenizerFiles, UserDefinedEmbeddingModel,
};
use lunaris_core::{Embedder, LunarisError, StorageError};
use parking_lot::Mutex;
/// Embedding width reported by embedders built with [`FastembedEmbedder::new`]
/// (the built-in `EmbeddingGemma300M` model).
pub const FASTEMBED_GEMMA_DIM: usize = 768;
/// Default `max_length` used by [`FastembedUserDefinedOpts::default`].
pub const FASTEMBED_GEMMA_MAX_TOKENS: usize = 2048;
/// Environment variable that, when set to a non-empty value, overrides the
/// default model cache directory.
pub const FASTEMBED_CACHE_DIR_ENV: &str = "LUNARIS_FASTEMBED_CACHE_DIR";
pub use crate::fastembed_exec::{
ExecutionPreference, FASTEMBED_EXECUTION_ENV, execution_from_env, parse_execution,
};
use crate::fastembed_exec::{build_execution_providers, requests_accelerator};
/// Options accepted by [`FastembedEmbedder::new`].
#[derive(Clone, Debug)]
pub struct FastembedEmbedderOpts {
/// Directory handed to fastembed as its model cache. When `None`,
/// `FastembedEmbedder::new` resolves the default location itself.
pub cache_dir: Option<PathBuf>,
/// Forwarded to fastembed's `with_show_download_progress`.
pub show_download_progress: bool,
/// Requested execution preference; used to decide whether execution
/// providers are configured when the model is initialized.
pub execution: ExecutionPreference,
}
impl Default for FastembedEmbedderOpts {
    /// Defaults: the resolved default cache directory, no download progress
    /// output, and whatever `execution_from_env()` returns.
    fn default() -> Self {
        // Resolve the cache directory first, then the execution preference.
        let cache_dir = Some(resolve_default_cache_dir());
        let execution = execution_from_env();
        Self {
            cache_dir,
            show_download_progress: false,
            execution,
        }
    }
}
/// Picks the directory used as the fastembed model cache.
///
/// A non-empty value of the `FASTEMBED_CACHE_DIR_ENV` environment variable
/// wins. Otherwise the result is `<cache root>/lunaris/models/fastembed`,
/// where the cache root is `dirs::cache_dir()` or `.` when that is unavailable.
fn resolve_default_cache_dir() -> PathBuf {
    // `var_os` (rather than `var`) keeps overrides that are valid paths but not
    // valid Unicode; `var` would report them as an error and the override would
    // be silently ignored.
    if let Some(env_dir) =
        std::env::var_os(FASTEMBED_CACHE_DIR_ENV).filter(|dir| !dir.is_empty())
    {
        return PathBuf::from(env_dir);
    }
    let cache_root = dirs::cache_dir().unwrap_or_else(|| PathBuf::from("."));
    cache_root.join("lunaris").join("models").join("fastembed")
}
/// Embedder backed by a fastembed `TextEmbedding` model.
///
/// The state lives behind an `Arc`, so clones share the same underlying model.
#[derive(Clone)]
pub struct FastembedEmbedder {
// Shared model handle plus the metadata captured at construction time.
inner: Arc<Inner>,
}
impl std::fmt::Debug for FastembedEmbedder {
    /// Prints only `dim` and `cache_dir`; the model handle is left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let state = &*self.inner;
        let mut repr = f.debug_struct("FastembedEmbedder");
        repr.field("dim", &state.dim);
        repr.field("cache_dir", &state.cache_dir);
        repr.finish()
    }
}
/// State shared by every clone of [`FastembedEmbedder`].
struct Inner {
// The model is locked for the duration of each `embed` call.
model: Mutex<TextEmbedding>,
// Cache directory used by `new`; an empty path for user-defined models.
cache_dir: PathBuf,
// Expected length of every embedding row; enforced in `embed_batch`.
dim: usize,
}
impl FastembedEmbedder {
pub fn new(opts: FastembedEmbedderOpts) -> Result<Self, LunarisError> {
let cache_dir = opts.cache_dir.unwrap_or_else(resolve_default_cache_dir);
let execution = opts.execution.clone();
tracing::info!(
backend = "fastembed",
model = "EmbeddingGemma300M",
cache_dir = %cache_dir.display(),
execution = ?execution,
"fastembed embedder constructing"
);
let build = |providers_enabled: bool| -> Result<TextEmbedding, anyhow::Error> {
let mut init = InitOptions::new(EmbeddingModel::EmbeddingGemma300M)
.with_cache_dir(cache_dir.clone())
.with_show_download_progress(opts.show_download_progress);
if providers_enabled {
init = init.with_execution_providers(build_execution_providers(&execution));
}
TextEmbedding::try_new(init)
};
let model = try_with_fallback(&execution, build)?;
let resolved = execution.clone();
tracing::info!(
backend = "fastembed",
model = "EmbeddingGemma300M",
execution = ?resolved,
"fastembed embedder initialized"
);
Ok(Self {
inner: Arc::new(Inner {
model: Mutex::new(model),
cache_dir,
dim: FASTEMBED_GEMMA_DIM,
}),
})
}
pub fn from_user_defined(opts: FastembedUserDefinedOpts) -> Result<Self, LunarisError> {
if opts.onnx_file.is_empty() {
return Err(LunarisError::Storage(StorageError::Backend(
"fastembed: from_user_defined called with empty onnx_file bytes".to_string(),
)));
}
if opts.tokenizer_file.is_empty() {
return Err(LunarisError::Storage(StorageError::Backend(
"fastembed: from_user_defined called with empty tokenizer_file bytes".to_string(),
)));
}
if opts.dim == 0 {
return Err(LunarisError::Storage(StorageError::Backend(
"fastembed: from_user_defined called with dim = 0".to_string(),
)));
}
let execution = opts.execution.clone();
let dim = opts.dim;
let max_length = opts.max_length;
tracing::info!(
backend = "fastembed",
model = "user-defined",
dim,
execution = ?execution,
"fastembed user-defined embedder constructing"
);
let user_model = UserDefinedEmbeddingModel {
onnx_file: opts.onnx_file,
external_initializers: Vec::new(),
tokenizer_files: TokenizerFiles {
tokenizer_file: opts.tokenizer_file,
config_file: opts.config_file.unwrap_or_default(),
special_tokens_map_file: opts.special_tokens_map_file.unwrap_or_default(),
tokenizer_config_file: opts.tokenizer_config_file.unwrap_or_default(),
},
pooling: Some(opts.pooling.into()),
quantization: QuantizationMode::None,
output_key: None,
};
let model = try_user_defined_with_fallback(&execution, user_model, max_length)?;
let resolved = execution.clone();
tracing::info!(
backend = "fastembed",
model = "user-defined",
dim,
execution = ?resolved,
"fastembed user-defined embedder initialized"
);
Ok(Self {
inner: Arc::new(Inner { model: Mutex::new(model), cache_dir: PathBuf::new(), dim }),
})
}
}
/// Options accepted by [`FastembedEmbedder::from_user_defined`].
#[derive(Clone, Debug)]
pub struct FastembedUserDefinedOpts {
/// Raw bytes of the ONNX model; must not be empty.
pub onnx_file: Vec<u8>,
/// Raw bytes of the tokenizer file; must not be empty.
pub tokenizer_file: Vec<u8>,
/// Optional tokenizer config bytes; `None` is passed on as an empty vector.
pub tokenizer_config_file: Option<Vec<u8>>,
/// Optional special-tokens map bytes; `None` is passed on as an empty vector.
pub special_tokens_map_file: Option<Vec<u8>>,
/// Optional model config bytes; `None` is passed on as an empty vector.
pub config_file: Option<Vec<u8>>,
/// Expected embedding width; must be non-zero. Rows of any other length are
/// rejected by `embed_batch`.
pub dim: usize,
/// Pooling strategy handed to fastembed.
pub pooling: PoolingMode,
/// Requested execution preference for model initialization.
pub execution: ExecutionPreference,
/// Forwarded to fastembed's `with_max_length`.
pub max_length: usize,
}
impl Default for FastembedUserDefinedOpts {
    /// Defaults: no model or tokenizer bytes, `dim` of zero, mean pooling, the
    /// execution preference from `execution_from_env()`, and a `max_length` of
    /// `FASTEMBED_GEMMA_MAX_TOKENS`.
    ///
    /// The default value is not accepted by `from_user_defined` as-is: the
    /// caller has to supply the model bytes, tokenizer bytes and `dim`.
    fn default() -> Self {
        let execution = execution_from_env();
        Self {
            // Required inputs, intentionally left empty / zero.
            onnx_file: Vec::new(),
            tokenizer_file: Vec::new(),
            dim: 0,
            // Optional tokenizer side files.
            config_file: None,
            special_tokens_map_file: None,
            tokenizer_config_file: None,
            // Behavioral knobs.
            pooling: PoolingMode::Mean,
            max_length: FASTEMBED_GEMMA_MAX_TOKENS,
            execution,
        }
    }
}
/// Pooling strategy selectable for user-defined models; converted into
/// fastembed's `Pooling` via `From`.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub enum PoolingMode {
/// Maps to fastembed's `Pooling::Cls`.
Cls,
/// Maps to fastembed's `Pooling::Mean`; this is the default variant.
#[default]
Mean,
}
impl From<PoolingMode> for Pooling {
    /// Maps each `PoolingMode` variant onto the fastembed variant of the same name.
    fn from(mode: PoolingMode) -> Self {
        match mode {
            PoolingMode::Mean => Self::Mean,
            PoolingMode::Cls => Self::Cls,
        }
    }
}
/// Runs `build`, retrying once without execution providers when an
/// accelerator was requested and the first attempt failed.
///
/// `build` receives `true` when execution providers should be configured.
/// If no accelerator was requested, the first error is returned directly.
/// Errors are converted with `anyhow_to_lunaris`.
fn try_with_fallback<F>(
    pref: &ExecutionPreference,
    mut build: F,
) -> Result<TextEmbedding, LunarisError>
where
    F: FnMut(bool) -> Result<TextEmbedding, anyhow::Error>,
{
    let accelerated = requests_accelerator(pref);

    let first_failure = match build(accelerated) {
        Ok(model) => return Ok(model),
        Err(err) => err,
    };

    // The first attempt already ran on CPU, so there is nothing to fall back to.
    if !accelerated {
        return Err(anyhow_to_lunaris(first_failure));
    }

    tracing::warn!(
        error = %first_failure,
        requested = ?pref,
        "fastembed execution provider init failed, falling back to CPU"
    );
    build(false).map_err(anyhow_to_lunaris)
}
/// Initializes a user-defined model, retrying once without execution
/// providers when an accelerator was requested and the first attempt failed.
///
/// fastembed consumes the model on every attempt, so a copy is taken up front
/// whenever a retry is possible.
fn try_user_defined_with_fallback(
    pref: &ExecutionPreference,
    user_model: UserDefinedEmbeddingModel,
    max_length: usize,
) -> Result<TextEmbedding, LunarisError> {
    let load = |model: UserDefinedEmbeddingModel, with_providers: bool| {
        let base = InitOptionsUserDefined::new().with_max_length(max_length);
        let init = if with_providers {
            base.with_execution_providers(build_execution_providers(pref))
        } else {
            base
        };
        TextEmbedding::try_new_from_user_defined(model, init)
    };

    // CPU-only request: a single attempt, no copy of the model needed.
    if !requests_accelerator(pref) {
        return load(user_model, false).map_err(anyhow_to_lunaris);
    }

    let cpu_copy = user_model.clone();
    load(user_model, true).or_else(|err| {
        tracing::warn!(
            error = %err,
            requested = ?pref,
            "fastembed (user-defined) execution provider init failed, falling back to CPU"
        );
        load(cpu_copy, false).map_err(anyhow_to_lunaris)
    })
}
#[async_trait]
impl Embedder for FastembedEmbedder {
fn dim(&self) -> usize {
self.inner.dim
}
async fn embed_batch(&self, inputs: &[&str]) -> Result<Vec<Vec<f32>>, LunarisError> {
if inputs.is_empty() {
return Ok(Vec::new());
}
let owned: Vec<String> = inputs.iter().map(|s| (*s).to_string()).collect();
let inner = self.inner.clone();
let expected_dim = inner.dim;
tokio::task::spawn_blocking(move || -> Result<Vec<Vec<f32>>, LunarisError> {
let raw: Vec<Vec<f32>> = {
let mut guard = inner.model.lock();
guard.embed(owned, None).map_err(anyhow_to_lunaris)?
};
let mut out: Vec<Vec<f32>> = Vec::with_capacity(raw.len());
for row in raw.into_iter() {
if row.len() != expected_dim {
return Err(LunarisError::Storage(StorageError::Backend(format!(
"fastembed: dim mismatch — model returned {} dims, expected {expected_dim}",
row.len()
))));
}
out.push(l2_normalize_row(row, expected_dim));
}
Ok(out)
})
.await
.map_err(|e| LunarisError::Storage(StorageError::Backend(format!("fastembed join: {e}"))))?
}
}
/// Scales `row` to unit Euclidean length.
///
/// The norm is accumulated in `f64`. Rows whose norm is not greater than
/// `f64::EPSILON` (including all-zero rows and rows whose norm is NaN) are
/// returned unchanged. `expected_dim` is only checked by a debug assertion on
/// the scaled path.
#[inline]
fn l2_normalize_row(row: Vec<f32>, expected_dim: usize) -> Vec<f32> {
    let norm = row
        .iter()
        .fold(0.0_f64, |acc, &x| acc + f64::from(x).powi(2))
        .sqrt();

    // Written as a positive test so that a NaN norm also takes the early exit.
    let scalable = norm > f64::EPSILON;
    if !scalable {
        return row;
    }

    let scaled: Vec<f32> = row
        .into_iter()
        .map(|v| (f64::from(v) / norm) as f32)
        .collect();
    debug_assert_eq!(scaled.len(), expected_dim);
    scaled
}
/// Wraps an `anyhow::Error` from fastembed in the crate's backend storage error.
///
/// The message is prefixed with `fastembed: ` and rendered with anyhow's
/// alternate format (`{:#}`) so the whole cause chain is kept. Plain `{}`
/// prints only the outermost message and would hide the underlying cause.
#[inline]
fn anyhow_to_lunaris(e: anyhow::Error) -> LunarisError {
    LunarisError::Storage(StorageError::Backend(format!("fastembed: {e:#}")))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `from_user_defined`, expects it to fail, and returns the rendered
    /// error message. `why` is the label shown if construction succeeds.
    fn rejection_message(opts: FastembedUserDefinedOpts, why: &str) -> String {
        FastembedEmbedder::from_user_defined(opts)
            .expect_err(why)
            .to_string()
    }

    /// Without an env override, the default cache dir follows the
    /// `lunaris/models/fastembed` layout.
    #[test]
    fn opts_default_resolves_to_cache_subdir() {
        // An override in the environment makes the layout check meaningless.
        if std::env::var(FASTEMBED_CACHE_DIR_ENV).is_ok() {
            return;
        }
        let cache_dir = FastembedEmbedderOpts::default()
            .cache_dir
            .expect("default sets a cache_dir");
        let s = cache_dir.to_string_lossy().to_string();
        let has_layout = ["lunaris", "models", "fastembed"]
            .iter()
            .all(|part| s.contains(*part));
        assert!(
            has_layout,
            "default cache_dir should include the v0 cache layout, got: {s}"
        );
    }

    /// Guards against accidental edits to the published dimension.
    #[test]
    fn dim_constant_is_768() {
        assert_eq!(FASTEMBED_GEMMA_DIM, 768);
    }

    /// A (3, 4, 0, ...) row normalizes to (0.6, 0.8, 0, ...).
    #[test]
    fn l2_normalize_unit_vector() {
        let mut row = vec![0.0_f32; FASTEMBED_GEMMA_DIM];
        row[..2].copy_from_slice(&[3.0, 4.0]);
        let out = l2_normalize_row(row, FASTEMBED_GEMMA_DIM);
        let l2 = out.iter().map(|x| (*x as f64).powi(2)).sum::<f64>().sqrt();
        assert!((l2 - 1.0).abs() < 1e-6, "expected unit norm, got {l2}");
        assert!((out[0] - 0.6).abs() < 1e-6);
        assert!((out[1] - 0.8).abs() < 1e-6);
    }

    /// An all-zero row cannot be scaled and comes back untouched.
    #[test]
    fn l2_normalize_degenerate_row_returned_as_is() {
        let zeros = vec![0.0_f32; FASTEMBED_GEMMA_DIM];
        let out = l2_normalize_row(zeros, FASTEMBED_GEMMA_DIM);
        assert_eq!(out.len(), FASTEMBED_GEMMA_DIM);
        assert!(out.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn from_user_defined_empty_onnx_returns_actionable_error() {
        let msg = rejection_message(
            FastembedUserDefinedOpts {
                onnx_file: Vec::new(),
                tokenizer_file: vec![0u8; 4],
                dim: 768,
                ..Default::default()
            },
            "empty onnx",
        );
        assert!(
            msg.contains("fastembed") && msg.contains("onnx_file"),
            "unexpected error message: {msg}"
        );
    }

    #[test]
    fn from_user_defined_empty_tokenizer_returns_actionable_error() {
        let msg = rejection_message(
            FastembedUserDefinedOpts {
                onnx_file: vec![0u8; 4],
                tokenizer_file: Vec::new(),
                dim: 768,
                ..Default::default()
            },
            "empty tokenizer",
        );
        assert!(
            msg.contains("fastembed") && msg.contains("tokenizer_file"),
            "unexpected error message: {msg}"
        );
    }

    #[test]
    fn from_user_defined_zero_dim_returns_actionable_error() {
        let msg = rejection_message(
            FastembedUserDefinedOpts {
                onnx_file: vec![0u8; 4],
                tokenizer_file: vec![0u8; 4],
                dim: 0,
                ..Default::default()
            },
            "zero dim",
        );
        assert!(msg.contains("fastembed") && msg.contains("dim"), "unexpected: {msg}");
    }

    /// Non-empty but invalid bytes pass validation and fail inside fastembed;
    /// the error must still carry the `fastembed` prefix.
    #[test]
    fn from_user_defined_bad_onnx_bytes_surfaces_fastembed_error() {
        let msg = rejection_message(
            FastembedUserDefinedOpts {
                onnx_file: b"not-a-real-onnx-graph".to_vec(),
                tokenizer_file: b"not-a-real-tokenizer".to_vec(),
                dim: 768,
                ..Default::default()
            },
            "bad bytes",
        );
        assert!(msg.contains("fastembed"), "expected fastembed-prefixed error, got: {msg}");
    }

    #[test]
    fn pooling_mode_maps_to_fastembed_pooling() {
        assert!(matches!(Pooling::from(PoolingMode::Cls), Pooling::Cls));
        assert!(matches!(Pooling::from(PoolingMode::Mean), Pooling::Mean));
    }
}
#[cfg(all(test, feature = "embedder-it"))]
mod live_tests {
    use super::*;

    /// Loads the real model with default options and checks that a two-item
    /// batch comes back with the expected width and (approximately) unit norm.
    /// Only compiled when the `embedder-it` feature is enabled.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn fastembed_loads_real_model_and_embeds_one_batch() {
        let embedder = FastembedEmbedder::new(FastembedEmbedderOpts::default())
            .expect("real model load — auto-download to ~/.cache/lunaris/models/fastembed/");
        assert_eq!(embedder.dim(), FASTEMBED_GEMMA_DIM);

        let vecs = embedder
            .embed_batch(&["hello world", "lunaris memory engine"])
            .await
            .expect("embed_batch");
        assert_eq!(vecs.len(), 2);

        vecs.iter().for_each(|v| {
            assert_eq!(v.len(), FASTEMBED_GEMMA_DIM);
            let l2 = v.iter().map(|x| (*x as f64).powi(2)).sum::<f64>().sqrt();
            assert!((l2 - 1.0).abs() < 1e-3, "L2 norm = {l2}, expected ~ 1.0");
        });
    }
}