use std::path::PathBuf;
/// Static metadata for a built-in (preset) embedding model.
///
/// All fields are `'static` so presets can be declared as `const` items.
struct ModelSpec {
    /// Hugging Face Hub repository id (e.g. "sentence-transformers/all-MiniLM-L6-v2").
    repo_id: &'static str,
    /// Path of the ONNX model file within the repository.
    model_file: &'static str,
    /// Path of the tokenizer definition within the repository.
    tokenizer_file: &'static str,
    /// Embedding vector length this model is known to produce.
    expected_dimensions: usize,
    /// Human-readable name used for logs/UI.
    display_name: &'static str,
}
/// Preset spec for sentence-transformers/all-MiniLM-L6-v2 (384-dim embeddings).
const MINILM_L6_V2: ModelSpec = ModelSpec {
    repo_id: "sentence-transformers/all-MiniLM-L6-v2",
    model_file: "onnx/model.onnx",
    tokenizer_file: "tokenizer.json",
    expected_dimensions: 384,
    display_name: "all-MiniLM-L6-v2",
};
/// Preset spec for sentence-transformers/all-MiniLM-L12-v2 (384-dim embeddings).
const MINILM_L12_V2: ModelSpec = ModelSpec {
    repo_id: "sentence-transformers/all-MiniLM-L12-v2",
    model_file: "onnx/model.onnx",
    tokenizer_file: "tokenizer.json",
    expected_dimensions: 384,
    display_name: "all-MiniLM-L12-v2",
};
/// Preset spec for BAAI/bge-small-en-v1.5 (384-dim embeddings).
const BGE_SMALL_EN_V15: ModelSpec = ModelSpec {
    repo_id: "BAAI/bge-small-en-v1.5",
    model_file: "onnx/model.onnx",
    tokenizer_file: "tokenizer.json",
    expected_dimensions: 384,
    display_name: "bge-small-en-v1.5",
};
/// Selects which embedding model to load: a known preset, an arbitrary
/// Hugging Face Hub repository, or files already present on disk.
#[derive(Debug, Clone)]
pub enum EmbeddingModelConfig {
    /// Preset: sentence-transformers/all-MiniLM-L6-v2.
    MiniLmL6v2,
    /// Preset: sentence-transformers/all-MiniLM-L12-v2.
    MiniLmL12v2,
    /// Preset: BAAI/bge-small-en-v1.5.
    BgeSmallEnV15,
    /// Any model hosted on the Hugging Face Hub; dimensions are unknown
    /// until the model is loaded.
    HuggingFace {
        /// Repository id, e.g. "org/model".
        repo_id: String,
        /// Model file path within the repository.
        model_file: String,
        /// Tokenizer file path within the repository.
        tokenizer_file: String,
    },
    /// Model and tokenizer files already on the local filesystem.
    Local {
        /// Path to the model file on disk.
        model_path: PathBuf,
        /// Path to the tokenizer file on disk.
        tokenizer_path: PathBuf,
    },
}
impl EmbeddingModelConfig {
    /// Embedding dimensionality for the built-in presets.
    ///
    /// Returns `None` for [`Self::HuggingFace`] and [`Self::Local`], whose
    /// output size is not known until the model is actually loaded.
    #[must_use]
    pub fn expected_dimensions(&self) -> Option<usize> {
        self.model_spec().map(|spec| spec.expected_dimensions)
    }

    /// Human-readable name for logs and display.
    ///
    /// Presets use their canonical model name, `HuggingFace` uses the repo
    /// id, and `Local` falls back to the model file's stem (or
    /// `"custom-model"` when the path has no UTF-8 stem).
    #[must_use]
    pub fn display_name(&self) -> String {
        match self {
            // All presets share one path through their static spec.
            Self::MiniLmL6v2 | Self::MiniLmL12v2 | Self::BgeSmallEnV15 => self
                .model_spec()
                .expect("preset variants always have a model spec")
                .display_name
                .to_string(),
            Self::HuggingFace { repo_id, .. } => repo_id.clone(),
            Self::Local {
                model_path: path, ..
            } => path
                .file_stem()
                .and_then(|stem| stem.to_str())
                .unwrap_or("custom-model")
                .to_string(),
        }
    }

    /// Static metadata for preset variants; `None` for user-supplied models.
    ///
    /// The match is exhaustive (no `_` arm) on purpose: adding a new preset
    /// variant becomes a compile error here instead of a silent `None`.
    fn model_spec(&self) -> Option<&'static ModelSpec> {
        match self {
            Self::MiniLmL6v2 => Some(&MINILM_L6_V2),
            Self::MiniLmL12v2 => Some(&MINILM_L12_V2),
            Self::BgeSmallEnV15 => Some(&BGE_SMALL_EN_V15),
            Self::HuggingFace { .. } | Self::Local { .. } => None,
        }
    }

    /// Where to obtain the model and tokenizer files from.
    ///
    /// Presets and `HuggingFace` resolve to the Hub; `Local` points at files
    /// on disk. Preset `&'static str` fields coerce to the borrowed lifetime.
    pub(crate) fn resolve_info(&self) -> ResolveInfo<'_> {
        // Presets share one code path via their static spec rather than
        // three duplicated match arms.
        if let Some(spec) = self.model_spec() {
            return ResolveInfo::Hub {
                repo_id: spec.repo_id,
                model_file: spec.model_file,
                tokenizer_file: spec.tokenizer_file,
            };
        }
        match self {
            Self::HuggingFace {
                repo_id,
                model_file,
                tokenizer_file,
            } => ResolveInfo::Hub {
                repo_id,
                model_file,
                tokenizer_file,
            },
            Self::Local {
                model_path,
                tokenizer_path,
            } => ResolveInfo::Local {
                model_path,
                tokenizer_path,
            },
            // Every preset variant was already handled via model_spec() above.
            Self::MiniLmL6v2 | Self::MiniLmL12v2 | Self::BgeSmallEnV15 => {
                unreachable!("preset variants resolve via their model spec")
            }
        }
    }
}
/// Internal resolution target produced by `EmbeddingModelConfig::resolve_info`:
/// either files to fetch from the Hugging Face Hub or files already on disk.
pub(crate) enum ResolveInfo<'a> {
    /// Fetch `model_file` and `tokenizer_file` from the Hub repo `repo_id`.
    Hub {
        repo_id: &'a str,
        model_file: &'a str,
        tokenizer_file: &'a str,
    },
    /// Load model and tokenizer directly from local paths.
    Local {
        // NOTE(review): `&'a Path` would be more idiomatic than `&'a PathBuf`,
        // but changing it requires updating the construction site in
        // `resolve_info` as well — left as-is here.
        model_path: &'a PathBuf,
        tokenizer_path: &'a PathBuf,
    },
}
/// Tuning knobs for embedding inference.
#[derive(Debug, Clone)]
pub struct EmbeddingOptions {
    /// Number of texts processed per inference call.
    pub batch_size: usize,
    /// Intra-op thread count handed to the inference backend.
    pub intra_threads: usize,
    /// Inter-op thread count handed to the inference backend.
    pub inter_threads: usize,
}

impl Default for EmbeddingOptions {
    /// Conservative defaults: batches of 32, single-threaded execution.
    fn default() -> Self {
        EmbeddingOptions {
            batch_size: 32,
            intra_threads: 1,
            inter_threads: 1,
        }
    }
}

impl EmbeddingOptions {
    /// Equivalent to [`EmbeddingOptions::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a copy with `batch_size` replaced.
    #[must_use]
    pub fn with_batch_size(self, batch_size: usize) -> Self {
        Self { batch_size, ..self }
    }

    /// Returns a copy with the intra-op thread count replaced.
    #[must_use]
    pub fn with_intra_threads(self, threads: usize) -> Self {
        Self {
            intra_threads: threads,
            ..self
        }
    }

    /// Returns a copy with the inter-op thread count replaced.
    #[must_use]
    pub fn with_inter_threads(self, threads: usize) -> Self {
        Self {
            inter_threads: threads,
            ..self
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every preset advertises the 384-dim embedding size of its model.
    #[test]
    fn preset_expected_dimensions() {
        for preset in [
            EmbeddingModelConfig::MiniLmL6v2,
            EmbeddingModelConfig::MiniLmL12v2,
            EmbeddingModelConfig::BgeSmallEnV15,
        ] {
            assert_eq!(preset.expected_dimensions(), Some(384));
        }
    }

    /// User-supplied models expose no dimensionality up front.
    #[test]
    fn custom_has_no_expected_dimensions() {
        let local = EmbeddingModelConfig::Local {
            model_path: PathBuf::from("model.onnx"),
            tokenizer_path: PathBuf::from("tokenizer.json"),
        };
        let hf = EmbeddingModelConfig::HuggingFace {
            repo_id: String::from("org/model"),
            model_file: String::from("model.onnx"),
            tokenizer_file: String::from("tokenizer.json"),
        };
        for config in [local, hf] {
            assert_eq!(config.expected_dimensions(), None);
        }
    }

    /// Presets display their canonical model names.
    #[test]
    fn preset_display_names() {
        let cases = [
            (EmbeddingModelConfig::MiniLmL6v2, "all-MiniLM-L6-v2"),
            (EmbeddingModelConfig::MiniLmL12v2, "all-MiniLM-L12-v2"),
            (EmbeddingModelConfig::BgeSmallEnV15, "bge-small-en-v1.5"),
        ];
        for (config, expected) in cases {
            assert_eq!(config.display_name(), expected);
        }
    }

    /// Hub models display their repository id.
    #[test]
    fn huggingface_display_name_is_repo_id() {
        let config = EmbeddingModelConfig::HuggingFace {
            repo_id: String::from("org/my-model"),
            model_file: String::from("model.onnx"),
            tokenizer_file: String::from("tokenizer.json"),
        };
        assert_eq!(config.display_name(), "org/my-model");
    }

    /// Local models display the model file's stem.
    #[test]
    fn local_display_name_from_file_stem() {
        let config = EmbeddingModelConfig::Local {
            model_path: PathBuf::from("/path/to/my-model.onnx"),
            tokenizer_path: PathBuf::from("/path/to/tokenizer.json"),
        };
        assert_eq!(config.display_name(), "my-model");
    }

    /// `Default` produces the documented baseline values.
    #[test]
    fn options_default_values() {
        let EmbeddingOptions {
            batch_size,
            intra_threads,
            inter_threads,
        } = EmbeddingOptions::default();
        assert_eq!(batch_size, 32);
        assert_eq!(intra_threads, 1);
        assert_eq!(inter_threads, 1);
    }

    /// Builder setters are independent, so ordering does not matter.
    #[test]
    fn options_builder_chaining() {
        let opts = EmbeddingOptions::new()
            .with_inter_threads(2)
            .with_intra_threads(4)
            .with_batch_size(64);
        assert_eq!(opts.batch_size, 64);
        assert_eq!(opts.intra_threads, 4);
        assert_eq!(opts.inter_threads, 2);
    }
}