#![cfg(feature = "embeddings")]
use anyhow::{Context, Result};
use std::fs;
use std::path::{Path, PathBuf};
use super::Embedder;
/// Sentence-embedding models that [`AutoEmbedder`] can download and load.
///
/// Resolve a user-supplied name with [`PretrainedModel::from_name`]. The enum is
/// a plain `Copy` tag, so it can be compared and used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PretrainedModel {
    /// `sentence-transformers/all-MiniLM-L6-v2` (384-dimensional embeddings).
    AllMiniLML6V2,
    /// `sentence-transformers/all-mpnet-base-v2` (768-dimensional embeddings).
    AllMpnetBaseV2,
    /// `intfloat/multilingual-e5-small` (384-dimensional embeddings).
    MultilingualE5Small,
    /// `sentence-transformers/all-MiniLM-L12-v2` (384-dimensional embeddings).
    AllMiniLML12V2,
}
impl PretrainedModel {
    /// Hugging Face Hub repository id (`org/name`) the model files are fetched from.
    pub fn model_id(&self) -> &'static str {
        match self {
            Self::AllMiniLML6V2 => "sentence-transformers/all-MiniLM-L6-v2",
            Self::AllMpnetBaseV2 => "sentence-transformers/all-mpnet-base-v2",
            Self::MultilingualE5Small => "intfloat/multilingual-e5-small",
            Self::AllMiniLML12V2 => "sentence-transformers/all-MiniLM-L12-v2",
        }
    }

    /// Length of each embedding vector produced by this model.
    pub fn dimension(&self) -> usize {
        match self {
            Self::AllMiniLML6V2 => 384,
            Self::AllMpnetBaseV2 => 768,
            Self::MultilingualE5Small => 384,
            Self::AllMiniLML12V2 => 384,
        }
    }

    /// Approximate download size in megabytes, shown in the download message.
    pub fn size_mb(&self) -> usize {
        match self {
            Self::AllMiniLML6V2 => 80,
            Self::AllMpnetBaseV2 => 420,
            Self::MultilingualE5Small => 118,
            Self::AllMiniLML12V2 => 120,
        }
    }

    /// Resolves a user-supplied model name to a [`PretrainedModel`].
    ///
    /// Matching is case-insensitive and ignores surrounding whitespace. Each
    /// model accepts its short name (e.g. `all-MiniLM-L6-v2`), a short alias
    /// (e.g. `minilm`), and its full Hub id as returned by
    /// [`PretrainedModel::model_id`]. The last form means
    /// `from_name(m.model_id())` round-trips.
    ///
    /// # Errors
    /// Returns an error naming the input if it matches no known model.
    pub fn from_name(name: &str) -> Result<Self> {
        let key = name.trim().to_lowercase();
        match key.as_str() {
            "all-minilm-l6-v2" | "minilm" | "sentence-transformers/all-minilm-l6-v2" => {
                Ok(Self::AllMiniLML6V2)
            }
            "all-mpnet-base-v2" | "mpnet" | "sentence-transformers/all-mpnet-base-v2" => {
                Ok(Self::AllMpnetBaseV2)
            }
            "multilingual-e5-small" | "e5-small" | "intfloat/multilingual-e5-small" => {
                Ok(Self::MultilingualE5Small)
            }
            "all-minilm-l12-v2" | "minilm-l12" | "sentence-transformers/all-minilm-l12-v2" => {
                Ok(Self::AllMiniLML12V2)
            }
            _ => anyhow::bail!("Unknown model: {}", name),
        }
    }
}
/// Embedder that resolves a [`PretrainedModel`] by name and loads it from a local cache.
///
/// On first use, the model's `model.onnx` and `tokenizer.json` are downloaded
/// into the cache. Encoding calls are delegated to the wrapped [`Embedder`].
pub struct AutoEmbedder {
    /// Loaded model + tokenizer that every `encode*` call delegates to.
    embedder: Embedder,
    /// Model name exactly as passed by the caller (not normalized).
    model_name: String,
    /// Root cache directory; each model lives in its own subdirectory beneath it.
    cache_dir: PathBuf,
}
impl AutoEmbedder {
    /// File name of the ONNX weights inside a cached model directory.
    const MODEL_FILE: &'static str = "model.onnx";
    /// File name of the tokenizer definition inside a cached model directory
    /// (also its path inside the Hub repository).
    const TOKENIZER_FILE: &'static str = "tokenizer.json";
    /// Path of the ONNX weights inside the Hugging Face repository.
    ///
    /// The sentence-transformers repositories publish their ONNX export under
    /// `onnx/`, not at the repository root, so `resolve/main/model.onnx` 404s.
    /// NOTE(review): assumed to hold for every `PretrainedModel` variant —
    /// confirm when adding models.
    const REMOTE_MODEL_FILE: &'static str = "onnx/model.onnx";

    /// Loads `model_name` using the default cache directory (`~/.vecstore/models`).
    ///
    /// Model files are downloaded on first use.
    ///
    /// # Errors
    /// Returns an error if the home directory cannot be determined or the cache
    /// directory cannot be created. Also returns every error of
    /// [`AutoEmbedder::from_pretrained_with_cache`].
    pub fn from_pretrained(model_name: &str) -> Result<Self> {
        let cache_dir = Self::get_cache_dir()?;
        Self::from_pretrained_with_cache(model_name, cache_dir)
    }

    /// Loads `model_name`, caching its files under `cache_dir` (created as needed).
    ///
    /// `model_name` is resolved with [`PretrainedModel::from_name`]. It is stored
    /// verbatim for [`AutoEmbedder::model_name`].
    ///
    /// # Errors
    /// Returns an error in any of these cases:
    /// - the name is not recognized;
    /// - a directory cannot be created;
    /// - a download or file write fails;
    /// - the embedder cannot be built from the cached files.
    pub fn from_pretrained_with_cache(
        model_name: &str,
        cache_dir: impl AsRef<Path>,
    ) -> Result<Self> {
        let cache_dir = cache_dir.as_ref().to_path_buf();
        let model = PretrainedModel::from_name(model_name)?;
        let model_dir = Self::ensure_model_cached(&cache_dir, model)?;
        let embedder = Embedder::new(
            model_dir.join(Self::MODEL_FILE),
            model_dir.join(Self::TOKENIZER_FILE),
        )
        .context("Failed to load embedding model")?;
        Ok(Self {
            embedder,
            model_name: model_name.to_string(),
            cache_dir,
        })
    }

    /// Embeds a single text by delegating to the wrapped embedder.
    pub fn encode(&self, text: &str) -> Result<Vec<f32>> {
        self.embedder.embed(text)
    }

    /// Embeds several texts by delegating to the wrapped embedder's batch API.
    pub fn encode_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        self.embedder.embed_batch(texts)
    }

    /// The model name exactly as passed to the constructor.
    pub fn model_name(&self) -> &str {
        &self.model_name
    }

    /// Root directory under which model files are cached.
    pub fn cache_dir(&self) -> &Path {
        &self.cache_dir
    }

    /// Returns `~/.vecstore/models`, creating it if it does not exist.
    fn get_cache_dir() -> Result<PathBuf> {
        let home = directories::UserDirs::new().context("Failed to get user home directory")?;
        let cache = home.home_dir().join(".vecstore").join("models");
        fs::create_dir_all(&cache).context("Failed to create cache directory")?;
        Ok(cache)
    }

    /// Returns the cache directory for `model`.
    ///
    /// The model is downloaded first unless both the weights and the tokenizer
    /// are already present. The directory name is the Hub id with `/` replaced
    /// by `_`.
    fn ensure_model_cached(cache_dir: &Path, model: PretrainedModel) -> Result<PathBuf> {
        let model_dir = cache_dir.join(model.model_id().replace('/', "_"));
        // `download_file` only creates files under their final name via a rename
        // after a complete write. For files written by this code, presence
        // therefore implies a complete download.
        let cached = [Self::MODEL_FILE, Self::TOKENIZER_FILE]
            .iter()
            .all(|file| model_dir.join(file).is_file());
        if cached {
            println!("Using cached model: {}", model.model_id());
            return Ok(model_dir);
        }
        println!(
            "Downloading model: {} (~{}MB)...",
            model.model_id(),
            model.size_mb()
        );
        fs::create_dir_all(&model_dir).context("Failed to create model directory")?;
        Self::download_model_files(&model_dir, model)?;
        println!("Model downloaded and cached successfully!");
        Ok(model_dir)
    }

    /// Downloads the ONNX weights and tokenizer for `model` into `model_dir`.
    fn download_model_files(model_dir: &Path, model: PretrainedModel) -> Result<()> {
        let base_url = format!("https://huggingface.co/{}/resolve/main", model.model_id());
        // (path inside the Hub repository, file name in the local cache)
        let files = [
            (Self::REMOTE_MODEL_FILE, Self::MODEL_FILE),
            (Self::TOKENIZER_FILE, Self::TOKENIZER_FILE),
        ];
        for (remote_name, local_name) in files {
            let url = format!("{}/{}", base_url, remote_name);
            println!(" Downloading {}...", remote_name);
            Self::download_file(&url, &model_dir.join(local_name))?;
            println!(" ✓ {} downloaded", remote_name);
        }
        Ok(())
    }

    /// Streams the body of `url` into `dest` without ever exposing a partial file.
    ///
    /// The data is written to `<dest>.part` and renamed onto `dest` only after
    /// the whole body has been written and synced. An interrupted download
    /// therefore cannot leave a truncated file that `ensure_model_cached` would
    /// later accept as a valid cache entry. If any step fails, the temporary
    /// file is removed (best effort).
    fn download_file(url: &str, dest: &Path) -> Result<()> {
        let response = ureq::get(url)
            .call()
            .with_context(|| format!("Failed to download {}", url))?;

        let mut partial = dest.as_os_str().to_owned();
        partial.push(".part");
        let partial = PathBuf::from(partial);

        let outcome = Self::write_stream(&mut response.into_reader(), &partial).and_then(|()| {
            fs::rename(&partial, dest)
                .with_context(|| format!("Failed to move {:?} to {:?}", partial, dest))
        });
        if outcome.is_err() {
            // Best-effort cleanup; the original error is what the caller needs.
            let _ = fs::remove_file(&partial);
        }
        outcome
    }

    /// Copies everything from `reader` into a newly created file at `path` and
    /// syncs it to disk.
    fn write_stream(reader: &mut impl std::io::Read, path: &Path) -> Result<()> {
        let mut file = fs::File::create(path)
            .with_context(|| format!("Failed to create file: {:?}", path))?;
        std::io::copy(reader, &mut file).context("Failed to write downloaded file")?;
        file.sync_all().context("Failed to sync downloaded file")?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Known names and aliases resolve (case-insensitively); unknown ones are rejected.
    #[test]
    fn test_model_parsing() {
        for name in ["all-MiniLM-L6-v2", "minilm", "mpnet"] {
            assert!(PretrainedModel::from_name(name).is_ok());
        }
        assert!(PretrainedModel::from_name("unknown").is_err());
    }

    /// Static metadata for the default MiniLM model.
    #[test]
    fn test_model_metadata() {
        let mini = PretrainedModel::AllMiniLML6V2;
        assert_eq!(mini.dimension(), 384);
        assert_eq!(mini.size_mb(), 80);
        assert!(mini.model_id().contains("MiniLM"));
    }

    /// The default cache directory is created under the user's home directory.
    ///
    /// NOTE: this touches the real `~/.vecstore/models`.
    #[test]
    fn test_cache_dir_creation() -> anyhow::Result<()> {
        let dir = AutoEmbedder::get_cache_dir()?;
        assert!(dir.exists());
        assert!(dir.ends_with(".vecstore/models"));
        Ok(())
    }

    /// End-to-end download + encode; ignored by default because it needs the network.
    #[test]
    #[ignore]
    fn test_auto_download_and_encode() -> anyhow::Result<()> {
        let embedder = AutoEmbedder::from_pretrained("all-MiniLM-L6-v2")?;
        assert_eq!(embedder.encode("Hello world")?.len(), 384);

        let batch = embedder.encode_batch(&["test1", "test2"])?;
        assert_eq!(batch.len(), 2);
        assert_eq!(batch[0].len(), 384);
        Ok(())
    }
}