use serde::Deserialize;
/// Serde default for `InputNames::ids`; also used by `InputNames::bert`.
fn default_ids_name() -> String {
    String::from("input_ids")
}
/// Serde default for `InputNames::mask`; also used by `InputNames::bert`.
fn default_mask_name() -> String {
    String::from("attention_mask")
}
/// Default model output tensor name used by every preset and by custom
/// configs that do not set `output_name`.
fn default_output_name() -> String {
    String::from("last_hidden_state")
}
/// Names of the model's input tensors.
///
/// When deserialized, missing `ids`/`mask` fall back to `"input_ids"` /
/// `"attention_mask"`, and a missing `token_types` becomes `None`. Note that
/// this differs from `InputNames::default()`, which sets
/// `token_types = Some("token_type_ids")`.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
pub struct InputNames {
    /// Token-id input name.
    #[serde(default = "default_ids_name")]
    pub ids: String,
    /// Attention-mask input name.
    #[serde(default = "default_mask_name")]
    pub mask: String,
    /// Optional token-type (segment id) input name; `None` means the model
    /// takes no such input.
    #[serde(default)]
    pub token_types: Option<String>,
}
impl Default for InputNames {
fn default() -> Self {
Self::bert()
}
}
impl InputNames {
pub fn bert() -> Self {
Self {
ids: default_ids_name(),
mask: default_mask_name(),
token_types: Some("token_type_ids".to_string()),
}
}
pub fn bert_no_token_types() -> Self {
Self {
ids: default_ids_name(),
mask: default_mask_name(),
token_types: None,
}
}
}
/// Pooling strategy used to reduce the model output to a single embedding.
/// The pooling itself is implemented outside this module.
///
/// Deserialized from lowercase names: `"mean"`, `"cls"`, `"lasttoken"`.
/// Because `rename_all = "lowercase"` produces the non-obvious `"lasttoken"`,
/// the spellings `"last_token"` and `"last-token"` are also accepted for
/// [`PoolingStrategy::LastToken`].
#[derive(Debug, Clone, Copy, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "lowercase")]
pub enum PoolingStrategy {
    /// Mean pooling (the default). NOTE(review): whether padding is masked
    /// out is decided by the pooling code — confirm there.
    #[default]
    Mean,
    /// Presumably takes the first (`[CLS]`) token's vector — confirm in the
    /// pooling code.
    Cls,
    /// Presumably takes the last (non-padding) token's vector — confirm in the
    /// pooling code.
    #[serde(alias = "last_token", alias = "last-token")]
    LastToken,
}
/// Fully resolved description of an embedding model: where to fetch it, how
/// to feed it, and how to read its output.
#[derive(Clone, Debug)]
pub struct ModelConfig {
    /// Short preset name (e.g. `bge-large`) or the user's custom model name.
    pub name: String,
    /// Repository id, e.g. `BAAI/bge-large-en-v1.5`.
    pub repo: String,
    /// Relative path of the ONNX model file (presumably within `repo`).
    pub onnx_path: String,
    /// Relative path of the tokenizer file (presumably within `repo`).
    pub tokenizer_path: String,
    /// Embedding dimensionality.
    pub dim: usize,
    /// Maximum input sequence length.
    pub max_seq_length: usize,
    /// Prefix for query text (may be empty).
    pub query_prefix: String,
    /// Prefix for document text (may be empty).
    pub doc_prefix: String,
    /// Names of the model's input tensors.
    pub input_names: InputNames,
    /// Name of the output tensor to read.
    pub output_name: String,
    /// How the output is pooled into one vector.
    pub pooling: PoolingStrategy,
}
/// Repository id of the default model; must equal
/// `ModelConfig::default_model().repo` (asserted by
/// `test_default_model_consts_consistent`).
pub const DEFAULT_MODEL_REPO: &str = "BAAI/bge-large-en-v1.5";
/// Embedding dimension of the default model; must equal
/// `ModelConfig::default_model().dim`.
pub const DEFAULT_DIM: usize = 1_024;
impl ModelConfig {
pub fn default_model() -> Self {
Self::bge_large()
}
pub fn e5_base() -> Self {
Self {
name: "e5-base".to_string(),
repo: "intfloat/e5-base-v2".to_string(),
onnx_path: "onnx/model.onnx".to_string(),
tokenizer_path: "tokenizer.json".to_string(),
dim: 768,
max_seq_length: 512,
query_prefix: "query: ".to_string(),
doc_prefix: "passage: ".to_string(),
input_names: InputNames::bert(),
output_name: default_output_name(),
pooling: PoolingStrategy::Mean,
}
}
pub fn v9_200k() -> Self {
Self {
name: "v9-200k".to_string(),
repo: "jamie8johnson/e5-base-v2-code-search".to_string(),
onnx_path: "model.onnx".to_string(),
tokenizer_path: "tokenizer.json".to_string(),
dim: 768,
max_seq_length: 512,
query_prefix: "query: ".to_string(),
doc_prefix: "passage: ".to_string(),
input_names: InputNames::bert(),
output_name: default_output_name(),
pooling: PoolingStrategy::Mean,
}
}
pub fn bge_large() -> Self {
Self {
name: "bge-large".to_string(),
repo: "BAAI/bge-large-en-v1.5".to_string(),
onnx_path: "onnx/model.onnx".to_string(),
tokenizer_path: "tokenizer.json".to_string(),
dim: 1024,
max_seq_length: 512,
query_prefix: "Represent this sentence for searching relevant passages: ".to_string(),
doc_prefix: String::new(),
input_names: InputNames::bert(),
output_name: default_output_name(),
pooling: PoolingStrategy::Mean,
}
}
pub const PRESET_NAMES: &'static [&'static str] = &["e5-base", "v9-200k", "bge-large"];
pub fn from_preset(name: &str) -> Option<Self> {
match name {
"e5-base" | "intfloat/e5-base-v2" => Some(Self::e5_base()),
"v9-200k" | "jamie8johnson/e5-base-v2-code-search" => Some(Self::v9_200k()),
"bge-large" | "BAAI/bge-large-en-v1.5" => Some(Self::bge_large()),
_ => None,
}
}
pub fn resolve(cli_model: Option<&str>, config_embedding: Option<&EmbeddingConfig>) -> Self {
let _span = tracing::info_span!("resolve_model_config").entered();
if let Some(name) = cli_model {
if let Some(cfg) = Self::from_preset(name) {
tracing::info!(model = %cfg.name, source = "cli", "Resolved model config");
return cfg;
}
tracing::warn!(
model = name,
"Unknown model from CLI flag, falling back to default"
);
return Self::default_model();
}
if let Ok(env_val) = std::env::var("CQS_EMBEDDING_MODEL") {
if !env_val.is_empty() {
if let Some(cfg) = Self::from_preset(&env_val) {
tracing::info!(model = %cfg.name, source = "env", "Resolved model config");
return cfg;
}
tracing::warn!(
model = %env_val,
"Unknown CQS_EMBEDDING_MODEL env var value, falling back to default"
);
return Self::default_model();
}
}
if let Some(embedding_cfg) = config_embedding {
if let Some(cfg) = Self::from_preset(&embedding_cfg.model) {
tracing::info!(model = %cfg.name, source = "config", "Resolved model config");
return cfg;
}
let has_repo = embedding_cfg.repo.is_some();
let has_dim = embedding_cfg.dim.is_some();
if has_repo && has_dim {
let dim = embedding_cfg.dim.expect("guarded by has_dim");
if dim == 0 {
tracing::warn!(model = %embedding_cfg.model, "Custom model has dim=0, falling back to default");
return Self::default_model();
}
let repo = embedding_cfg.repo.as_ref().expect("guarded by has_repo");
if !repo.contains('/')
|| repo.contains('"')
|| repo.contains('\n')
|| repo.contains('\\')
|| repo.contains(' ')
|| repo.starts_with('/')
|| repo.contains("..")
{
tracing::warn!(
%repo,
"Custom model repo contains invalid characters, falling back to default"
);
return Self::default_model();
}
let onnx_path = embedding_cfg
.onnx_path
.clone()
.unwrap_or_else(|| "onnx/model.onnx".to_string());
let tokenizer_path = embedding_cfg
.tokenizer_path
.clone()
.unwrap_or_else(|| "tokenizer.json".to_string());
for (label, path) in [
("onnx_path", &onnx_path),
("tokenizer_path", &tokenizer_path),
] {
if path.contains("..") || std::path::Path::new(path).is_absolute() {
tracing::warn!(%label, %path, "Custom model path contains traversal or is absolute, falling back to default");
return Self::default_model();
}
}
let input_names = embedding_cfg
.input_names
.clone()
.unwrap_or_else(InputNames::bert);
let output_name = embedding_cfg
.output_name
.clone()
.unwrap_or_else(default_output_name);
let pooling = embedding_cfg.pooling.unwrap_or(PoolingStrategy::Mean);
let cfg = Self {
name: embedding_cfg.model.clone(),
repo: embedding_cfg.repo.clone().expect("guarded by has_repo"),
onnx_path,
tokenizer_path,
dim,
max_seq_length: embedding_cfg.max_seq_length.unwrap_or(512),
query_prefix: embedding_cfg.query_prefix.clone().unwrap_or_default(),
doc_prefix: embedding_cfg.doc_prefix.clone().unwrap_or_default(),
input_names,
output_name,
pooling,
};
tracing::info!(model = %cfg.name, source = "config-custom", "Resolved custom model config");
return cfg;
}
tracing::warn!(
model = %embedding_cfg.model,
has_repo,
has_dim,
"Unknown model in config and missing required custom fields (repo, dim), falling back to default"
);
}
tracing::info!(
model = "bge-large",
source = "default",
"Resolved model config"
);
Self::default_model()
}
pub fn apply_env_overrides(mut self) -> Self {
if let Ok(val) = std::env::var("CQS_MAX_SEQ_LENGTH") {
if let Ok(seq) = val.parse::<usize>() {
tracing::info!(max_seq_length = seq, "CQS_MAX_SEQ_LENGTH override active");
self.max_seq_length = seq;
}
}
if let Ok(val) = std::env::var("CQS_EMBEDDING_DIM") {
if let Ok(dim) = val.parse::<usize>() {
if dim > 0 {
tracing::info!(dim, "CQS_EMBEDDING_DIM override active");
self.dim = dim;
}
}
}
self
}
}
/// Embedding-model settings deserialized from user configuration.
///
/// `model` names a preset (short name or repo id). For any other name,
/// `ModelConfig::resolve` builds a custom model, which requires both `repo`
/// and `dim`. The remaining fields are optional overrides.
#[derive(Clone, Debug, Deserialize)]
pub struct EmbeddingConfig {
    /// Preset name / repo id, or a custom model's name. Defaults to the
    /// default model's name.
    #[serde(default = "default_model_name")]
    pub model: String,
    /// Repo id of a custom model (required, together with `dim`, for custom models).
    pub repo: Option<String>,
    /// Relative ONNX file path; `"onnx/model.onnx"` when unset.
    pub onnx_path: Option<String>,
    /// Relative tokenizer file path; `"tokenizer.json"` when unset.
    pub tokenizer_path: Option<String>,
    /// Embedding dimension of a custom model (required, together with `repo`).
    pub dim: Option<usize>,
    /// Maximum sequence length; 512 when unset.
    pub max_seq_length: Option<usize>,
    /// Query prefix; empty when unset.
    pub query_prefix: Option<String>,
    /// Document prefix; empty when unset.
    pub doc_prefix: Option<String>,
    /// Input tensor names; BERT names (with token types) when unset.
    #[serde(default)]
    pub input_names: Option<InputNames>,
    /// Output tensor name; `"last_hidden_state"` when unset.
    #[serde(default)]
    pub output_name: Option<String>,
    /// Pooling strategy; mean pooling when unset.
    #[serde(default)]
    pub pooling: Option<PoolingStrategy>,
}
fn default_model_name() -> String {
ModelConfig::default_model().name
}
impl Default for EmbeddingConfig {
fn default() -> Self {
Self {
model: default_model_name(),
repo: None,
onnx_path: None,
tokenizer_path: None,
dim: None,
max_seq_length: None,
query_prefix: None,
doc_prefix: None,
input_names: None,
output_name: None,
pooling: None,
}
}
}
/// Serializable summary of an embedding model: name, dimension and a
/// version tag.
#[derive(serde::Serialize, Clone, Debug)]
pub struct ModelInfo {
    /// Model identifier; the defaults use `DEFAULT_MODEL_REPO`.
    pub name: String,
    /// Embedding dimensionality.
    pub dimensions: usize,
    /// Version tag; the constructors in this module set `"2"`.
    pub version: String,
}
impl ModelInfo {
pub fn new(name: impl Into<String>, dim: usize) -> Self {
ModelInfo {
name: name.into(),
dimensions: dim,
version: "2".to_string(),
}
}
pub fn with_dim(dim: usize) -> Self {
Self::new(DEFAULT_MODEL_REPO, dim)
}
}
impl Default for ModelInfo {
    /// Info for the default model (`DEFAULT_MODEL_REPO`, `DEFAULT_DIM`).
    ///
    /// Delegates to [`ModelInfo::new`] so the version tag is defined in one
    /// place instead of being duplicated here.
    fn default() -> Self {
        Self::new(DEFAULT_MODEL_REPO, DEFAULT_DIM)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes tests that read or mutate process-wide environment variables.
    static ENV_MUTEX: Mutex<()> = Mutex::new(());

    /// Env var consulted by `ModelConfig::resolve`.
    const MODEL_ENV: &str = "CQS_EMBEDDING_MODEL";

    /// Locks `ENV_MUTEX`, recovering from poisoning.
    ///
    /// Without this, one failing env test would turn every later env test into
    /// a spurious `PoisonError` panic.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_MUTEX
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Puts `CQS_EMBEDDING_MODEL` in a known state and removes it on drop,
    /// even when the test panics.
    ///
    /// Bind it to a named variable (`_env`, not `_`) declared after the lock
    /// guard, so it is dropped while the lock is still held.
    struct ModelEnvGuard;

    impl ModelEnvGuard {
        fn set(value: &str) -> Self {
            std::env::set_var(MODEL_ENV, value);
            ModelEnvGuard
        }

        fn cleared() -> Self {
            std::env::remove_var(MODEL_ENV);
            ModelEnvGuard
        }
    }

    impl Drop for ModelEnvGuard {
        fn drop(&mut self) {
            std::env::remove_var(MODEL_ENV);
        }
    }

    /// An `EmbeddingConfig` with the given model/repo/dim and every other field unset.
    fn custom_cfg(model: &str, repo: Option<&str>, dim: Option<usize>) -> EmbeddingConfig {
        EmbeddingConfig {
            model: model.to_string(),
            repo: repo.map(str::to_string),
            dim,
            ..EmbeddingConfig::default()
        }
    }

    /// Resolves `cfg` with no CLI flag and the env var cleared, under the env lock.
    /// Do not call while already holding `lock_env()`; the mutex is not reentrant.
    fn resolve_from_config(cfg: &EmbeddingConfig) -> ModelConfig {
        let _lock = lock_env();
        let _env = ModelEnvGuard::cleared();
        ModelConfig::resolve(None, Some(cfg))
    }

    #[test]
    fn test_e5_base_preset() {
        let cfg = ModelConfig::e5_base();
        assert_eq!(cfg.name, "e5-base");
        assert_eq!(cfg.repo, "intfloat/e5-base-v2");
        assert_eq!(cfg.dim, 768);
        assert_eq!(cfg.max_seq_length, 512);
        assert_eq!(cfg.query_prefix, "query: ");
        assert_eq!(cfg.doc_prefix, "passage: ");
        assert_eq!(cfg.onnx_path, "onnx/model.onnx");
        assert_eq!(cfg.tokenizer_path, "tokenizer.json");
        assert_eq!(cfg.input_names, InputNames::bert());
        assert_eq!(cfg.output_name, "last_hidden_state");
        assert_eq!(cfg.pooling, PoolingStrategy::Mean);
    }

    #[test]
    fn test_bge_large_preset() {
        let cfg = ModelConfig::bge_large();
        assert_eq!(cfg.name, "bge-large");
        assert_eq!(cfg.repo, "BAAI/bge-large-en-v1.5");
        assert_eq!(cfg.dim, 1024);
        assert_eq!(cfg.max_seq_length, 512);
        assert_eq!(
            cfg.query_prefix,
            "Represent this sentence for searching relevant passages: "
        );
        assert_eq!(cfg.doc_prefix, "");
        assert_eq!(cfg.input_names, InputNames::bert());
        assert_eq!(cfg.output_name, "last_hidden_state");
        assert_eq!(cfg.pooling, PoolingStrategy::Mean);
    }

    #[test]
    fn test_v9_200k_preset() {
        let cfg = ModelConfig::v9_200k();
        assert_eq!(cfg.name, "v9-200k");
        assert_eq!(cfg.repo, "jamie8johnson/e5-base-v2-code-search");
        assert_eq!(cfg.dim, 768);
        assert_eq!(cfg.onnx_path, "model.onnx");
        assert_eq!(cfg.query_prefix, "query: ");
        assert_eq!(cfg.doc_prefix, "passage: ");
        assert_eq!(cfg.input_names, InputNames::bert());
        assert_eq!(cfg.output_name, "last_hidden_state");
        assert_eq!(cfg.pooling, PoolingStrategy::Mean);
    }

    #[test]
    fn input_names_bert_defaults() {
        let n = InputNames::bert();
        assert_eq!(n.ids, "input_ids");
        assert_eq!(n.mask, "attention_mask");
        assert_eq!(n.token_types.as_deref(), Some("token_type_ids"));
    }

    #[test]
    fn input_names_no_token_types() {
        let n = InputNames::bert_no_token_types();
        assert_eq!(n.ids, "input_ids");
        assert_eq!(n.mask, "attention_mask");
        assert!(
            n.token_types.is_none(),
            "bert_no_token_types should drop segment embeddings"
        );
    }

    #[test]
    fn input_names_default_matches_bert() {
        assert_eq!(InputNames::default(), InputNames::bert());
    }

    #[test]
    fn input_names_serde_empty_fills_defaults() {
        let parsed: InputNames = serde_json::from_str("{}").unwrap();
        assert_eq!(parsed.ids, "input_ids");
        assert_eq!(parsed.mask, "attention_mask");
        assert!(parsed.token_types.is_none());
    }

    #[test]
    fn input_names_serde_custom() {
        let j = r#"{ "ids": "tokens", "mask": "mask", "token_types": null }"#;
        let parsed: InputNames = serde_json::from_str(j).unwrap();
        assert_eq!(parsed.ids, "tokens");
        assert_eq!(parsed.mask, "mask");
        assert!(parsed.token_types.is_none());
    }

    #[test]
    fn pooling_strategy_serde_roundtrip() {
        let mean: PoolingStrategy = serde_json::from_str("\"mean\"").unwrap();
        assert_eq!(mean, PoolingStrategy::Mean);
        let cls: PoolingStrategy = serde_json::from_str("\"cls\"").unwrap();
        assert_eq!(cls, PoolingStrategy::Cls);
        let last: PoolingStrategy = serde_json::from_str("\"lasttoken\"").unwrap();
        assert_eq!(last, PoolingStrategy::LastToken);
    }

    #[test]
    fn pooling_strategy_default_is_mean() {
        assert_eq!(PoolingStrategy::default(), PoolingStrategy::Mean);
    }

    #[test]
    fn resolve_custom_non_bert_architecture() {
        let embedding_cfg = EmbeddingConfig {
            onnx_path: Some("model.onnx".to_string()),
            max_seq_length: Some(128),
            input_names: Some(InputNames::bert_no_token_types()),
            output_name: Some("sentence_embedding".to_string()),
            pooling: Some(PoolingStrategy::Cls),
            ..custom_cfg("synthetic-distilbert", Some("org/distil"), Some(384))
        };
        let resolved = resolve_from_config(&embedding_cfg);
        assert_eq!(resolved.name, "synthetic-distilbert");
        assert_eq!(resolved.dim, 384);
        assert_eq!(resolved.pooling, PoolingStrategy::Cls);
        assert_eq!(resolved.output_name, "sentence_embedding");
        assert!(
            resolved.input_names.token_types.is_none(),
            "Custom config must not re-introduce token_type_ids"
        );
    }

    #[test]
    fn resolve_custom_without_architecture_uses_bert_defaults() {
        let cfg = custom_cfg("legacy-custom", Some("org/legacy"), Some(768));
        let resolved = resolve_from_config(&cfg);
        assert_eq!(resolved.name, "legacy-custom");
        assert_eq!(resolved.input_names, InputNames::bert());
        assert_eq!(resolved.output_name, "last_hidden_state");
        assert_eq!(resolved.pooling, PoolingStrategy::Mean);
    }

    #[test]
    fn embedding_config_serde_with_architecture() {
        let json = r#"{
            "model": "custom",
            "repo": "org/model",
            "dim": 768,
            "pooling": "cls",
            "output_name": "pooled",
            "input_names": { "ids": "tok", "mask": "m" }
        }"#;
        let cfg: EmbeddingConfig = serde_json::from_str(json).unwrap();
        assert_eq!(cfg.pooling, Some(PoolingStrategy::Cls));
        assert_eq!(cfg.output_name.as_deref(), Some("pooled"));
        let names = cfg.input_names.as_ref().unwrap();
        assert_eq!(names.ids, "tok");
        assert_eq!(names.mask, "m");
        assert!(
            names.token_types.is_none(),
            "Absent token_types deserializes to None"
        );
    }

    #[test]
    fn embedding_config_serde_without_architecture_keeps_all_none() {
        let json = r#"{ "model": "bge-large" }"#;
        let cfg: EmbeddingConfig = serde_json::from_str(json).unwrap();
        assert!(cfg.pooling.is_none());
        assert!(cfg.output_name.is_none());
        assert!(cfg.input_names.is_none());
    }

    #[test]
    fn test_from_preset_short_name() {
        assert!(ModelConfig::from_preset("e5-base").is_some());
        assert!(ModelConfig::from_preset("v9-200k").is_some());
        assert!(ModelConfig::from_preset("bge-large").is_some());
    }

    #[test]
    fn test_from_preset_repo_id() {
        let cfg = ModelConfig::from_preset("intfloat/e5-base-v2").unwrap();
        assert_eq!(cfg.name, "e5-base");
        let cfg = ModelConfig::from_preset("jamie8johnson/e5-base-v2-code-search").unwrap();
        assert_eq!(cfg.name, "v9-200k");
        let cfg = ModelConfig::from_preset("BAAI/bge-large-en-v1.5").unwrap();
        assert_eq!(cfg.name, "bge-large");
    }

    #[test]
    fn test_from_preset_unknown() {
        assert!(ModelConfig::from_preset("unknown-model").is_none());
        assert!(ModelConfig::from_preset("").is_none());
    }

    #[test]
    fn test_resolve_default() {
        let _lock = lock_env();
        let _env = ModelEnvGuard::cleared();
        let cfg = ModelConfig::resolve(None, None);
        assert_eq!(cfg.name, "bge-large");
    }

    #[test]
    fn test_resolve_env_by_name() {
        let _lock = lock_env();
        let _env = ModelEnvGuard::set("bge-large");
        let cfg = ModelConfig::resolve(None, None);
        assert_eq!(cfg.name, "bge-large");
    }

    #[test]
    fn test_resolve_env_by_repo_id() {
        let _lock = lock_env();
        let _env = ModelEnvGuard::set("BAAI/bge-large-en-v1.5");
        let cfg = ModelConfig::resolve(None, None);
        assert_eq!(cfg.name, "bge-large");
    }

    #[test]
    fn test_resolve_cli_overrides_env() {
        let _lock = lock_env();
        let _env = ModelEnvGuard::set("bge-large");
        let cfg = ModelConfig::resolve(Some("e5-base"), None);
        assert_eq!(cfg.name, "e5-base");
    }

    #[test]
    fn test_resolve_unknown_env_warns_and_defaults() {
        let _lock = lock_env();
        let _env = ModelEnvGuard::set("nonexistent-model");
        let cfg = ModelConfig::resolve(None, None);
        assert_eq!(cfg.name, "bge-large");
    }

    #[test]
    fn test_resolve_unknown_cli_warns_and_defaults() {
        let _lock = lock_env();
        let cfg = ModelConfig::resolve(Some("nonexistent"), None);
        assert_eq!(cfg.name, "bge-large");
    }

    #[test]
    fn test_resolve_config_preset() {
        let embedding_cfg = custom_cfg("bge-large", None, None);
        let cfg = resolve_from_config(&embedding_cfg);
        assert_eq!(cfg.name, "bge-large");
    }

    #[test]
    fn test_resolve_config_custom_model() {
        let embedding_cfg = EmbeddingConfig {
            onnx_path: Some("model.onnx".to_string()),
            max_seq_length: Some(256),
            query_prefix: Some("search: ".to_string()),
            ..custom_cfg("my-custom", Some("my-org/my-model"), Some(384))
        };
        let cfg = resolve_from_config(&embedding_cfg);
        assert_eq!(cfg.name, "my-custom");
        assert_eq!(cfg.repo, "my-org/my-model");
        assert_eq!(cfg.dim, 384);
        assert_eq!(cfg.max_seq_length, 256);
        assert_eq!(cfg.onnx_path, "model.onnx");
        assert_eq!(cfg.tokenizer_path, "tokenizer.json");
        assert_eq!(cfg.query_prefix, "search: ");
        assert_eq!(cfg.doc_prefix, "");
    }

    #[test]
    fn test_resolve_config_unknown_missing_fields_defaults() {
        let embedding_cfg = custom_cfg("unknown", None, None);
        let cfg = resolve_from_config(&embedding_cfg);
        assert_eq!(cfg.name, "bge-large");
    }

    #[test]
    fn test_embedding_config_default_model() {
        let json = r#"{}"#;
        let cfg: EmbeddingConfig = serde_json::from_str(json).unwrap();
        assert_eq!(cfg.model, "bge-large");
    }

    #[test]
    fn test_embedding_config_explicit_model() {
        let json = r#"{"model": "bge-large"}"#;
        let cfg: EmbeddingConfig = serde_json::from_str(json).unwrap();
        assert_eq!(cfg.model, "bge-large");
    }

    #[test]
    fn test_embedding_config_custom_fields() {
        let json = r#"{
            "model": "custom",
            "repo": "org/model",
            "dim": 384,
            "query_prefix": "q: "
        }"#;
        let cfg: EmbeddingConfig = serde_json::from_str(json).unwrap();
        assert_eq!(cfg.model, "custom");
        assert_eq!(cfg.repo.unwrap(), "org/model");
        assert_eq!(cfg.dim.unwrap(), 384);
        assert_eq!(cfg.query_prefix.unwrap(), "q: ");
        assert!(cfg.doc_prefix.is_none());
    }

    #[test]
    fn test_resolve_empty_env_ignored() {
        let _lock = lock_env();
        let _env = ModelEnvGuard::set("");
        let cfg = ModelConfig::resolve(None, None);
        assert_eq!(cfg.name, "bge-large");
    }

    #[test]
    fn test_resolve_cli_overrides_config() {
        let _lock = lock_env();
        let _env = ModelEnvGuard::cleared();
        let embedding_cfg = custom_cfg("bge-large", None, None);
        let cfg = ModelConfig::resolve(Some("e5-base"), Some(&embedding_cfg));
        assert_eq!(cfg.name, "e5-base");
    }

    #[test]
    fn tc31_resolve_config_dim_zero_falls_back_to_default() {
        let embedding_cfg = custom_cfg("zero-dim-model", Some("org/zero-dim"), Some(0));
        let cfg = resolve_from_config(&embedding_cfg);
        assert_eq!(
            cfg.name, "bge-large",
            "dim=0 should cause fallback to default bge-large"
        );
        assert_eq!(cfg.dim, 1024, "Fallback should have BGE-large dim=1024");
    }

    #[test]
    fn test_sec20_onnx_path_traversal_rejected() {
        let cfg = EmbeddingConfig {
            onnx_path: Some("../../../etc/passwd".to_string()),
            ..custom_cfg("evil-model", Some("evil/model"), Some(768))
        };
        let resolved = resolve_from_config(&cfg);
        assert_eq!(
            resolved.name, "bge-large",
            "Traversal in onnx_path should fall back to default"
        );
    }

    #[test]
    fn test_sec20_tokenizer_path_traversal_rejected() {
        let cfg = EmbeddingConfig {
            onnx_path: Some("model.onnx".to_string()),
            tokenizer_path: Some("../../secret/tokenizer.json".to_string()),
            ..custom_cfg("evil-model", Some("evil/model"), Some(768))
        };
        let resolved = resolve_from_config(&cfg);
        assert_eq!(
            resolved.name, "bge-large",
            "Traversal in tokenizer_path should fall back to default"
        );
    }

    #[test]
    fn test_sec20_absolute_onnx_path_rejected() {
        let cfg = EmbeddingConfig {
            onnx_path: Some("/etc/passwd".to_string()),
            ..custom_cfg("evil-model", Some("evil/model"), Some(768))
        };
        let resolved = resolve_from_config(&cfg);
        assert_eq!(
            resolved.name, "bge-large",
            "Absolute onnx_path should fall back to default"
        );
    }

    #[test]
    fn test_sec20_valid_custom_paths_accepted() {
        let cfg = EmbeddingConfig {
            onnx_path: Some("onnx/model.onnx".to_string()),
            tokenizer_path: Some("tokenizer.json".to_string()),
            ..custom_cfg("safe-model", Some("org/safe-model"), Some(384))
        };
        let resolved = resolve_from_config(&cfg);
        assert_eq!(
            resolved.name, "safe-model",
            "Valid paths should be accepted"
        );
        assert_eq!(resolved.onnx_path, "onnx/model.onnx");
        assert_eq!(resolved.tokenizer_path, "tokenizer.json");
    }

    #[test]
    fn test_sec20_dotdot_in_middle_rejected() {
        let cfg = EmbeddingConfig {
            onnx_path: Some("models/../../../etc/shadow".to_string()),
            ..custom_cfg("tricky", Some("org/tricky"), Some(768))
        };
        let resolved = resolve_from_config(&cfg);
        assert_eq!(
            resolved.name, "bge-large",
            ".. anywhere in path should fall back"
        );
    }

    #[test]
    fn test_sec28_repo_no_slash_rejected() {
        let cfg = custom_cfg("bad-repo", Some("no-slash-repo"), Some(768));
        let resolved = resolve_from_config(&cfg);
        assert_eq!(
            resolved.name, "bge-large",
            "Repo without slash should fall back to default"
        );
    }

    #[test]
    fn test_sec28_repo_traversal_rejected() {
        let cfg = custom_cfg("traversal-repo", Some("../../other-repo/model"), Some(768));
        let resolved = resolve_from_config(&cfg);
        assert_eq!(
            resolved.name, "bge-large",
            "Repo with .. should fall back to default"
        );
    }

    #[test]
    fn test_sec28_repo_absolute_path_rejected() {
        let cfg = custom_cfg("abs-repo", Some("/etc/passwd/model"), Some(768));
        let resolved = resolve_from_config(&cfg);
        assert_eq!(
            resolved.name, "bge-large",
            "Repo starting with / should fall back to default"
        );
    }

    #[test]
    fn test_default_model_consts_consistent() {
        let dm = ModelConfig::default_model();
        assert_eq!(
            dm.repo,
            super::DEFAULT_MODEL_REPO,
            "DEFAULT_MODEL_REPO must match default_model().repo"
        );
        assert_eq!(
            dm.dim,
            super::DEFAULT_DIM,
            "DEFAULT_DIM must match default_model().dim"
        );
        assert_eq!(
            dm.dim,
            crate::EMBEDDING_DIM,
            "EMBEDDING_DIM must match default_model().dim"
        );
    }
}