#[cfg(feature = "tract")]
mod tract_backend;
#[cfg(feature = "candle")]
mod candle_backend;
use orbok_core::{OrbokError, OrbokResult};
use orbok_models::{EmbeddingModel, EmbeddingModelConfig, InferenceBackend, MockEmbeddingModel};
/// Name of the recommended embedding model; [`recommended_config`] stores it
/// in the `model_name` field.
pub const RECOMMENDED_MODEL_NAME: &str = "multilingual-e5-small";
/// Version label of the recommended model; [`recommended_config`] stores it
/// in the `model_version` field.
pub const RECOMMENDED_MODEL_VERSION: &str = "v1";
/// Value [`recommended_config`] uses for the `dimension` field
/// (presumably the length of each embedding vector — confirm against the model).
pub const RECOMMENDED_MODEL_DIMENSION: u32 = 384;
/// Value [`recommended_config`] uses for the `max_seq_len` field
/// (presumably measured in tokens — confirm against the tokenizer/backend).
pub const RECOMMENDED_MODEL_MAX_SEQ_LEN: u32 = 512;
/// Identifier of the recommended model in `owner/name` form
/// (presumably a Hugging Face Hub repository ID — confirm).
pub const RECOMMENDED_HF_MODEL_ID: &str = "intfloat/multilingual-e5-small";
/// Relative path of the ONNX weights file
/// (presumably relative to the root of the model repository — confirm).
pub const RECOMMENDED_ONNX_FILE: &str = "onnx/model.onnx";
/// Builds the embedding model selected by `config.backend`.
///
/// * `Mock` always succeeds and returns a [`MockEmbeddingModel`].
/// * `OnnxRuntime` is served by the `tract` backend module.
/// * `CandleCpu` / `CandleCuda` are served by the `candle` backend module.
///
/// # Errors
///
/// Returns [`OrbokError::Cache`] with a rebuild hint when the Cargo feature
/// for the requested backend (`tract` or `candle`) was not enabled at compile
/// time. When the feature is enabled, whatever the backend's `create`
/// function returns is passed through unchanged.
pub fn create_embedding_model(
    config: &EmbeddingModelConfig,
) -> OrbokResult<Box<dyn EmbeddingModel>> {
    match &config.backend {
        InferenceBackend::Mock => Ok(Box::new(MockEmbeddingModel)),
        InferenceBackend::OnnxRuntime => onnx_model(config),
        InferenceBackend::CandleCpu | InferenceBackend::CandleCuda => candle_model(config),
    }
}

/// ONNX path when the `tract` feature is compiled in: delegate to the backend.
#[cfg(feature = "tract")]
fn onnx_model(config: &EmbeddingModelConfig) -> OrbokResult<Box<dyn EmbeddingModel>> {
    tract_backend::create(config)
}

/// ONNX path when the `tract` feature is missing: report how to enable it.
#[cfg(not(feature = "tract"))]
fn onnx_model(_config: &EmbeddingModelConfig) -> OrbokResult<Box<dyn EmbeddingModel>> {
    let reason = "ONNX inference is not compiled in. \
                  Rebuild with: --features orbok-embed/tract";
    Err(OrbokError::Cache(reason.into()))
}

/// Candle path when the `candle` feature is compiled in: delegate to the backend.
#[cfg(feature = "candle")]
fn candle_model(config: &EmbeddingModelConfig) -> OrbokResult<Box<dyn EmbeddingModel>> {
    candle_backend::create(config)
}

/// Candle path when the `candle` feature is missing: report how to enable it.
#[cfg(not(feature = "candle"))]
fn candle_model(_config: &EmbeddingModelConfig) -> OrbokResult<Box<dyn EmbeddingModel>> {
    let reason = "Candle inference is not compiled in. \
                  Rebuild with: --features orbok-embed/candle";
    Err(OrbokError::Cache(reason.into()))
}
/// Returns an [`EmbeddingModelConfig`] pre-filled with the recommended model's
/// settings.
///
/// `weights_path` is converted into an owned `String` and stored as-is. The
/// remaining fields come from the `RECOMMENDED_*` constants, the backend is
/// [`InferenceBackend::OnnxRuntime`], and no tokenizer path is set.
pub fn recommended_config(weights_path: impl Into<String>) -> EmbeddingModelConfig {
    let weights_path: String = weights_path.into();
    let model_name = String::from(RECOMMENDED_MODEL_NAME);
    let model_version = String::from(RECOMMENDED_MODEL_VERSION);

    EmbeddingModelConfig {
        weights_path,
        tokenizer_path: None,
        dimension: RECOMMENDED_MODEL_DIMENSION,
        max_seq_len: RECOMMENDED_MODEL_MAX_SEQ_LEN,
        backend: InferenceBackend::OnnxRuntime,
        model_name,
        model_version,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The mock backend is built without any weights file and must return one
    /// vector per input, each as long as the model's reported dimension.
    #[test]
    fn mock_backend_always_works() {
        let mock_cfg = EmbeddingModelConfig {
            backend: InferenceBackend::Mock,
            model_name: "mock".into(),
            model_version: "v1".into(),
            weights_path: String::new(),
            tokenizer_path: None,
            dimension: 8,
            max_seq_len: 512,
        };

        let model = create_embedding_model(&mock_cfg).unwrap();
        let embeddings = model.embed_batch(&["hello world"]).unwrap();

        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings[0].len(), model.dimension() as usize);
    }

    /// Without the `tract` feature, requesting the ONNX backend must fail with
    /// a message that points at the missing feature.
    #[cfg(not(feature = "tract"))]
    #[test]
    fn onnx_backend_without_feature_returns_error() {
        let onnx_cfg = EmbeddingModelConfig {
            backend: InferenceBackend::OnnxRuntime,
            model_name: "test".into(),
            model_version: "v1".into(),
            weights_path: "/nonexistent/model.onnx".into(),
            tokenizer_path: None,
            dimension: 384,
            max_seq_len: 512,
        };

        if let Err(err) = create_embedding_model(&onnx_cfg) {
            let rendered = err.to_string();
            let mentions_feature = rendered.contains("tract") || rendered.contains("compiled");
            assert!(mentions_feature, "error should mention feature flag");
        } else {
            panic!("ONNX without tract feature should fail");
        }
    }

    /// `recommended_config` must pick up the recommended model's defaults.
    #[test]
    fn recommended_config_correct_defaults() {
        let recommended = recommended_config("/models/multilingual-e5-small.onnx");

        assert_eq!(recommended.dimension, RECOMMENDED_MODEL_DIMENSION);
        assert_eq!(recommended.model_name, RECOMMENDED_MODEL_NAME);
        assert_eq!(recommended.max_seq_len, 512);
    }

    /// Sanity check of per-vector storage at two dimensions, using 4 bytes per
    /// component (presumably `f32` components — confirm).
    #[test]
    fn storage_impact_per_dimension() {
        let bytes_per_vector = |dimension: i32| dimension * 4;
        let bytes_384 = bytes_per_vector(384);
        let bytes_768 = bytes_per_vector(768);

        assert_eq!(bytes_384, 1536);
        assert_eq!(bytes_768, 3072);
        assert!(bytes_384 < bytes_768);
    }
}