use std::path::PathBuf;
use std::sync::Arc;
use async_trait::async_trait;
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use lunaris_core::{Embedder, LunarisError, StorageError};
use tokenizers::Tokenizer;
/// Length of every embedding vector: returned by `CandleEmbeddingGemma::dim` and required
/// of the second dimension of the loaded `embed_tokens` weight.
pub const EMBEDDING_GEMMA_DIM: usize = 768;
/// Maximum number of token ids kept per input; `embed_batch` truncates longer encodings
/// to this length before pooling.
pub const EMBEDDING_GEMMA_MAX_TOKENS: usize = 2048;
/// Construction options for [`CandleEmbeddingGemma`].
#[derive(Clone, Debug)]
pub struct CandleEmbeddingGemmaOpts {
/// Directory expected to contain `tokenizer.json` and `model.safetensors`.
/// When `None`, `CandleEmbeddingGemma::new` falls back to the path produced by `Default`.
pub model_path: Option<PathBuf>,
/// Candle device the embedding weights are loaded onto and token-id tensors are created on.
pub device: Device,
}
impl Default for CandleEmbeddingGemmaOpts {
    /// Points `model_path` at `<cache dir>/lunaris/models/embedding-gemma-300m` and selects
    /// the CPU device.
    ///
    /// When the platform cache directory cannot be determined, the path is rooted at `.`.
    fn default() -> Self {
        let mut model_dir = dirs::cache_dir().unwrap_or_else(|| PathBuf::from("."));
        // Appending the segments one by one is equivalent to chaining `join` calls.
        model_dir.extend(["lunaris", "models", "embedding-gemma-300m"]);
        Self {
            device: Device::Cpu,
            model_path: Some(model_dir),
        }
    }
}
/// Embedder that mean-pools rows of the model's token-embedding matrix.
///
/// Clones share the same loaded state through an `Arc`.
#[derive(Clone)]
pub struct CandleEmbeddingGemma {
/// Tokenizer, embedding weights and device, shared between clones and blocking tasks.
inner: Arc<CandleInner>,
}
impl std::fmt::Debug for CandleEmbeddingGemma {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CandleEmbeddingGemma")
.field("dim", &EMBEDDING_GEMMA_DIM)
.field("device", &format_args!("{:?}", self.inner.device))
.field("hidden_size", &self.inner.hidden_size)
.finish()
}
}
/// Loaded state backing a [`CandleEmbeddingGemma`].
struct CandleInner {
/// Tokenizer loaded from `tokenizer.json` in the model directory.
tokenizer: Tokenizer,
/// The `embed_tokens` weight tensor; checked at load time to be rank 2 with its second
/// dimension equal to `EMBEDDING_GEMMA_DIM`.
embed_weight: Tensor,
/// Device the weights were loaded onto; token-id tensors are created on it.
device: Device,
/// Second dimension of `embed_weight` (verified equal to `EMBEDDING_GEMMA_DIM` during load).
hidden_size: usize,
}
impl CandleEmbeddingGemma {
pub async fn new(opts: CandleEmbeddingGemmaOpts) -> Result<Self, LunarisError> {
let model_path = opts
.model_path
.clone()
.unwrap_or_else(|| CandleEmbeddingGemmaOpts::default().model_path.unwrap());
let tokenizer_path = model_path.join("tokenizer.json");
if !tokenizer_path.exists() {
return Err(LunarisError::Storage(StorageError::Backend(format!(
"embedding-gemma weights missing at {} — run `huggingface-cli download google/embeddinggemma-300m --local-dir {}`",
model_path.display(),
model_path.display()
))));
}
let safetensors_path = model_path.join("model.safetensors");
if !safetensors_path.exists() {
return Err(LunarisError::Storage(StorageError::Backend(format!(
"embedding-gemma weights missing at {} (no model.safetensors) — run `huggingface-cli download google/embeddinggemma-300m --local-dir {}`",
safetensors_path.display(),
model_path.display()
))));
}
let device = opts.device.clone();
let load = tokio::task::spawn_blocking(move || -> Result<CandleInner, LunarisError> {
let tokenizer = Tokenizer::from_file(&tokenizer_path).map_err(|e| {
LunarisError::Storage(StorageError::Backend(format!(
"embedding-gemma tokenizer: {e}"
)))
})?;
let bytes = std::fs::read(&safetensors_path).map_err(|e| {
LunarisError::Storage(StorageError::Backend(format!(
"embedding-gemma weights: read {} ({e})",
safetensors_path.display()
)))
})?;
let vb =
VarBuilder::from_buffered_safetensors(bytes, DType::F32, &device).map_err(|e| {
LunarisError::Storage(StorageError::Backend(format!(
"embedding-gemma weights: {e}"
)))
})?;
let embed_weight = vb
.pp("model")
.pp("embed_tokens")
.get_unchecked("weight")
.or_else(|_| vb.pp("embed_tokens").get_unchecked("weight"))
.map_err(|e| {
LunarisError::Storage(StorageError::Backend(format!(
"embedding-gemma weights: embed_tokens.weight not found \
(tried model.embed_tokens.weight and embed_tokens.weight): {e}"
)))
})?;
let dims = embed_weight.dims();
if dims.len() != 2 {
return Err(LunarisError::Storage(StorageError::Backend(format!(
"embedding-gemma weights: model.embed_tokens.weight has rank {} (expected 2)",
dims.len()
))));
}
let hidden_size = dims[1];
if hidden_size != EMBEDDING_GEMMA_DIM {
return Err(LunarisError::Storage(StorageError::Backend(format!(
"embedding-gemma weights: hidden_size {hidden_size} != {EMBEDDING_GEMMA_DIM}"
))));
}
Ok(CandleInner { tokenizer, embed_weight, device, hidden_size })
})
.await
.map_err(|e| LunarisError::Storage(StorageError::Backend(format!("candle join: {e}"))))??;
Ok(Self { inner: Arc::new(load) })
}
}
#[async_trait]
impl Embedder for CandleEmbeddingGemma {
fn dim(&self) -> usize {
EMBEDDING_GEMMA_DIM
}
async fn embed_batch(&self, inputs: &[&str]) -> Result<Vec<Vec<f32>>, LunarisError> {
if inputs.is_empty() {
return Ok(Vec::new());
}
let owned_inputs: Vec<String> = inputs.iter().map(|s| (*s).to_string()).collect();
let inner = self.inner.clone();
tokio::task::spawn_blocking(move || -> Result<Vec<Vec<f32>>, LunarisError> {
let mut out: Vec<Vec<f32>> = Vec::with_capacity(owned_inputs.len());
for text in owned_inputs.iter() {
let encoding = inner.tokenizer.encode(text.as_str(), true).map_err(|e| {
LunarisError::Storage(StorageError::Backend(format!(
"embedding-gemma tokenize: {e}"
)))
})?;
let mut ids = encoding.get_ids().to_vec();
if ids.len() > EMBEDDING_GEMMA_MAX_TOKENS {
ids.truncate(EMBEDDING_GEMMA_MAX_TOKENS);
}
if ids.is_empty() {
out.push(vec![0.0_f32; EMBEDDING_GEMMA_DIM]);
continue;
}
let id_tensor = Tensor::from_vec(
ids,
(encoding.get_ids().len().min(EMBEDDING_GEMMA_MAX_TOKENS),),
&inner.device,
)
.map_err(candle_err)?;
let token_embeds =
inner.embed_weight.index_select(&id_tensor, 0).map_err(candle_err)?;
let mean = token_embeds.mean(0).map_err(candle_err)?;
let mean_vec: Vec<f32> = mean.to_vec1::<f32>().map_err(candle_err)?;
let l2 = mean_vec.iter().map(|x| (*x as f64).powi(2)).sum::<f64>().sqrt();
let normalised: Vec<f32> = if l2 > f64::EPSILON {
mean_vec.iter().map(|x| (*x as f64 / l2) as f32).collect()
} else {
mean_vec
};
debug_assert_eq!(normalised.len(), inner.hidden_size);
out.push(normalised);
}
Ok(out)
})
.await
.map_err(|e| LunarisError::Storage(StorageError::Backend(format!("candle join: {e}"))))?
}
}
/// Converts a candle error into the storage-backend error variant, prefixing the message
/// with `candle: `.
#[inline]
fn candle_err(e: candle_core::Error) -> LunarisError {
    let message = format!("candle: {e}");
    LunarisError::Storage(StorageError::Backend(message))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default options must point inside the `lunaris/models/embedding-gemma-300m`
    /// cache layout.
    #[test]
    fn opts_default_resolves_to_cache_subdir() {
        let rendered = CandleEmbeddingGemmaOpts::default()
            .model_path
            .expect("default sets a path")
            .to_string_lossy()
            .into_owned();
        let has_all_segments = ["lunaris", "models", "embedding-gemma-300m"]
            .iter()
            .all(|segment| rendered.contains(*segment));
        assert!(
            has_all_segments,
            "default model_path should include the v0 cache layout, got: {rendered}"
        );
    }

    /// Guards against an accidental change of the published embedding width.
    #[test]
    fn dim_constant_is_768() {
        assert_eq!(768, EMBEDDING_GEMMA_DIM);
    }
}