pub mod index;
use std::path::Path;
use std::sync::Mutex;
use ndarray::{Array2, ArrayView2, Axis};
use ort::session::Session;
use ort::value::Tensor;
use thiserror::Error;
use crate::embedder::{create_session, select_provider};
/// Convert an `ort` runtime error into `SpladeError::InferenceFailed`,
/// preserving the original message text.
fn ort_err(e: ort::Error) -> SpladeError {
    let message = e.to_string();
    SpladeError::InferenceFailed(message)
}
/// A sparse embedding: `(token_id, weight)` pairs for vocabulary tokens whose
/// activation exceeded the encoder's threshold. Token ids index into the
/// tokenizer vocabulary.
pub type SparseVector = Vec<(u32, f32)>;
/// Errors produced while loading or running the SPLADE encoder.
#[derive(Error, Debug)]
pub enum SpladeError {
    /// model.onnx or tokenizer.json missing from the model directory.
    #[error("SPLADE model not found: {0}")]
    ModelNotFound(String),
    /// ONNX Runtime failure, tensor-shape problem, or unexpected model output.
    #[error("SPLADE inference failed: {0}")]
    InferenceFailed(String),
    /// tokenizer.json failed to load, or an input could not be tokenized.
    #[error("SPLADE tokenization failed: {0}")]
    TokenizationFailed(String),
    /// The tokenizer and model in the same directory disagree on vocab size
    /// (detected by the probe inference in `SpladeEncoder::new`).
    #[error(
        "SPLADE config mismatch: tokenizer vocab is {tokenizer_vocab}, model vocab is \
        {model_vocab}. The tokenizer.json and model.onnx in {dir:?} are from different \
        models — replace tokenizer.json with the one matching the model architecture."
    )]
    ConfigMismatch {
        dir: std::path::PathBuf,
        tokenizer_vocab: usize,
        model_vocab: usize,
    },
}
/// SPLADE sparse encoder backed by an ONNX Runtime session.
///
/// The session is wrapped in `Mutex<Option<..>>` so it can be dropped via
/// `clear_session` (freeing memory) and lazily re-created on the next encode.
pub struct SpladeEncoder {
    // ONNX session; `None` after `clear_session`, re-created on demand.
    session: Mutex<Option<Session>>,
    // Path to model.onnx, retained so the session can be re-created.
    model_path: std::path::PathBuf,
    tokenizer: tokenizers::Tokenizer,
    // Minimum activation for a token to appear in the output sparse vector.
    threshold: f32,
    // Tokenizer vocab size (the model's output dim may be padded slightly larger).
    vocab_size: usize,
}
/// Run a one-token probe inference to discover the model's output vocab dim.
///
/// Consumes `session` by value — the caller (`SpladeEncoder::new`) creates a
/// fresh session after validation succeeds. Recognizes the two SPLADE export
/// layouts: a pre-pooled 2D `sparse_vector` output (`[batch, vocab]`) or raw
/// 3D `logits` (`[batch, seq, vocab]`).
///
/// # Errors
/// `TokenizationFailed` if the probe string cannot be tokenized;
/// `InferenceFailed` for tensor-build/run failures, unexpected output rank,
/// or when neither recognized output name is present.
fn probe_model_vocab(
    mut session: Session,
    tokenizer: &tokenizers::Tokenizer,
    onnx_path: &Path,
) -> Result<usize, SpladeError> {
    let _span = tracing::debug_span!("probe_model_vocab", path = %onnx_path.display()).entered();
    // Tokenize a minimal input; `true` adds special tokens ([CLS]/[SEP] etc.).
    let encoding = tokenizer
        .encode("test", true)
        .map_err(|e| SpladeError::TokenizationFailed(format!("probe tokenization: {e}")))?;
    let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&id| id as i64).collect();
    let attention_mask: Vec<i64> = encoding
        .get_attention_mask()
        .iter()
        .map(|&m| m as i64)
        .collect();
    let seq_len = input_ids.len();
    // Shape both inputs as [1, seq_len] — batch of one.
    let ids_array = Array2::from_shape_vec((1, seq_len), input_ids)
        .map_err(|e| SpladeError::InferenceFailed(format!("probe ids tensor: {e}")))?;
    let mask_array = Array2::from_shape_vec((1, seq_len), attention_mask)
        .map_err(|e| SpladeError::InferenceFailed(format!("probe mask tensor: {e}")))?;
    let ids_tensor = Tensor::from_array(ids_array)
        .map_err(|e| SpladeError::InferenceFailed(format!("probe ids: {e}")))?;
    let mask_tensor = Tensor::from_array(mask_array)
        .map_err(|e| SpladeError::InferenceFailed(format!("probe mask: {e}")))?;
    let outputs = session
        .run(ort::inputs![
            "input_ids" => ids_tensor,
            "attention_mask" => mask_tensor,
        ])
        .map_err(ort_err)?;
    // Prefer the pre-pooled output when both are exported.
    let vocab = if let Some(sv_output) = outputs.get("sparse_vector") {
        let (shape, _data) = sv_output.try_extract_tensor::<f32>().map_err(ort_err)?;
        if shape.len() != 2 {
            return Err(SpladeError::InferenceFailed(format!(
                "probe: pre-pooled sparse_vector expected 2D [batch, vocab], got {}D",
                shape.len()
            )));
        }
        shape[1] as usize
    } else if let Some(logits_output) = outputs.get("logits") {
        let (shape, _data) = logits_output.try_extract_tensor::<f32>().map_err(ort_err)?;
        if shape.len() != 3 {
            return Err(SpladeError::InferenceFailed(format!(
                "probe: expected 3D logits [batch, seq, vocab], got {}D",
                shape.len()
            )));
        }
        shape[2] as usize
    } else {
        let names: Vec<&str> = outputs.keys().collect();
        return Err(SpladeError::InferenceFailed(format!(
            "probe: no recognized SPLADE output. Expected 'sparse_vector' or 'logits'. \
            Available: {names:?}"
        )));
    };
    tracing::debug!(model_vocab = vocab, "Probed SPLADE model vocab");
    Ok(vocab)
}
/// Resolve the directory holding the SPLADE model files.
///
/// `CQS_SPLADE_MODEL` (with a leading `~/` expanded against the home dir)
/// takes precedence; otherwise the default Hugging Face cache location is
/// used. Returns `None` — disabling hybrid search — unless both model.onnx
/// and tokenizer.json exist in the resolved directory.
pub fn resolve_splade_model_dir() -> Option<std::path::PathBuf> {
    let _span = tracing::debug_span!("resolve_splade_model_dir").entered();
    let env_value = std::env::var("CQS_SPLADE_MODEL")
        .ok()
        .filter(|v| !v.is_empty());
    let dir: std::path::PathBuf = match env_value {
        Some(raw) => {
            // Expand "~/rest" against $HOME; fall back to the literal value
            // when no home directory is available.
            let expanded: std::path::PathBuf = match raw.strip_prefix("~/") {
                Some(rest) => match dirs::home_dir() {
                    Some(home) => home.join(rest),
                    None => raw.clone().into(),
                },
                None => raw.clone().into(),
            };
            tracing::info!(
                source = "CQS_SPLADE_MODEL",
                path = %expanded.display(),
                "SPLADE model dir resolved from env var"
            );
            expanded
        }
        None => {
            let fallback = dirs::home_dir()
                .map(|h| h.join(".cache/huggingface/splade-onnx"))
                .unwrap_or_default();
            tracing::debug!(path = %fallback.display(), "Using default SPLADE model dir");
            fallback
        }
    };
    let model_file = dir.join("model.onnx");
    if !model_file.exists() {
        tracing::warn!(
            path = %model_file.display(),
            "SPLADE model.onnx not found — hybrid search will be disabled"
        );
        return None;
    }
    let tokenizer_file = dir.join("tokenizer.json");
    if !tokenizer_file.exists() {
        tracing::warn!(
            path = %tokenizer_file.display(),
            "SPLADE tokenizer.json not found — hybrid search will be disabled"
        );
        return None;
    }
    Some(dir)
}
/// Maximum number of characters fed to the tokenizer per text, overridable
/// via `CQS_SPLADE_MAX_CHARS`. Unset, unparsable, or zero values fall back
/// to 4000.
fn splade_max_chars() -> usize {
    const DEFAULT: usize = 4000;
    match std::env::var("CQS_SPLADE_MAX_CHARS") {
        Ok(raw) => match raw.parse::<usize>() {
            Ok(n) if n > 0 => n,
            _ => DEFAULT,
        },
        Err(_) => DEFAULT,
    }
}
impl SpladeEncoder {
/// Default sparse-activation threshold, overridable via
/// `CQS_SPLADE_THRESHOLD`. Unset or unparsable values fall back to 0.01.
pub fn default_threshold() -> f32 {
    const FALLBACK: f32 = 0.01;
    match std::env::var("CQS_SPLADE_THRESHOLD") {
        Ok(raw) => raw.parse::<f32>().unwrap_or(FALLBACK),
        Err(_) => FALLBACK,
    }
}
/// Load a SPLADE encoder from `model_dir` (expects model.onnx + tokenizer.json).
///
/// Before the encoder is returned, a one-token probe inference verifies that
/// the model's output vocab dimension is compatible with the tokenizer:
/// - model vocab smaller than tokenizer vocab → `ConfigMismatch`;
/// - model vocab more than 1.5% larger → `ConfigMismatch` (likely mismatched files);
/// - small overhang (zero-trained lm_head padding) → accepted with a warning.
///
/// # Errors
/// `ModelNotFound` if either file is missing, `TokenizationFailed` if
/// tokenizer.json cannot be loaded, `InferenceFailed` on ORT failures,
/// `ConfigMismatch` as described above.
pub fn new(model_dir: &Path, threshold: f32) -> Result<Self, SpladeError> {
    let _span = tracing::info_span!("splade_encoder_new", dir = %model_dir.display()).entered();
    let onnx_path = model_dir.join("model.onnx");
    if !onnx_path.exists() {
        return Err(SpladeError::ModelNotFound(format!(
            "No model.onnx at {}",
            model_dir.display()
        )));
    }
    let tokenizer_path = model_dir.join("tokenizer.json");
    if !tokenizer_path.exists() {
        return Err(SpladeError::ModelNotFound(format!(
            "No tokenizer.json at {}",
            model_dir.display()
        )));
    }
    let provider = select_provider();
    let session = create_session(&onnx_path, provider)
        .map_err(|e| SpladeError::InferenceFailed(format!("ORT session: {e}")))?;
    let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
        .map_err(|e| SpladeError::TokenizationFailed(e.to_string()))?;
    // `true` includes added/special tokens in the vocab count.
    let tokenizer_vocab = tokenizer.get_vocab_size(true);
    // The probe consumes the session; a fresh one is created below once the
    // model passes validation (costs one extra session init at load time).
    let model_vocab = probe_model_vocab(session, &tokenizer, &onnx_path)?;
    if model_vocab < tokenizer_vocab {
        tracing::error!(
            tokenizer_vocab,
            model_vocab,
            dir = %model_dir.display(),
            "SPLADE model output dim is smaller than tokenizer vocab — refusing to load"
        );
        return Err(SpladeError::ConfigMismatch {
            dir: model_dir.to_path_buf(),
            tokenizer_vocab,
            model_vocab,
        });
    }
    // How much larger the model's output dim is than the tokenizer vocab,
    // as a percentage; small overhangs are lm_head padding (benign).
    let padding_pct = if tokenizer_vocab > 0 {
        (model_vocab - tokenizer_vocab) as f32 * 100.0 / tokenizer_vocab as f32
    } else {
        0.0
    };
    if padding_pct > 1.5 {
        tracing::error!(
            tokenizer_vocab,
            model_vocab,
            padding_pct,
            dir = %model_dir.display(),
            "SPLADE model vocab is suspiciously larger than tokenizer (> 1.5%) — refusing to load"
        );
        return Err(SpladeError::ConfigMismatch {
            dir: model_dir.to_path_buf(),
            tokenizer_vocab,
            model_vocab,
        });
    }
    if model_vocab > tokenizer_vocab {
        tracing::warn!(
            tokenizer_vocab,
            model_vocab,
            padding_pct,
            "SPLADE model vocab is padded above tokenizer vocab — \
            extra slots are zero-trained and ignored at encode time"
        );
    }
    // Re-create the session the probe consumed.
    let session = create_session(&onnx_path, provider)
        .map_err(|e| SpladeError::InferenceFailed(format!("ORT session re-init: {e}")))?;
    tracing::info!(
        threshold,
        vocab_size = tokenizer_vocab,
        "SPLADE encoder loaded (vocab consistency verified)"
    );
    Ok(Self {
        session: Mutex::new(Some(session)),
        model_path: onnx_path,
        tokenizer,
        threshold,
        vocab_size: tokenizer_vocab,
    })
}
pub fn encode(&self, text: &str) -> Result<SparseVector, SpladeError> {
let _span = tracing::debug_span!("splade_encode", text_len = text.len()).entered();
if text.is_empty() {
return Ok(Vec::new());
}
let max_chars = splade_max_chars();
let text = if text.len() > max_chars {
let truncated = &text[..text
.char_indices()
.nth(max_chars)
.map(|(i, _)| i)
.unwrap_or(text.len())];
tracing::debug!(
original_len = text.len(),
truncated_len = truncated.len(),
max_chars,
"Truncated SPLADE input"
);
truncated
} else {
text
};
let encoding = self
.tokenizer
.encode(text, true)
.map_err(|e| SpladeError::TokenizationFailed(e.to_string()))?;
let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&id| id as i64).collect();
let attention_mask: Vec<i64> = encoding
.get_attention_mask()
.iter()
.map(|&m| m as i64)
.collect();
let seq_len = input_ids.len();
let ids_array = Array2::from_shape_vec((1, seq_len), input_ids).map_err(|e| {
SpladeError::InferenceFailed(format!("Failed to build input tensor: {e}"))
})?;
let mask_array = Array2::from_shape_vec((1, seq_len), attention_mask).map_err(|e| {
SpladeError::InferenceFailed(format!("Failed to build mask tensor: {e}"))
})?;
let ids_tensor = Tensor::from_array(ids_array)
.map_err(|e| SpladeError::InferenceFailed(format!("Tensor: {e}")))?;
let mask_tensor = Tensor::from_array(mask_array)
.map_err(|e| SpladeError::InferenceFailed(format!("Tensor: {e}")))?;
let mut session_guard = self.session.lock().unwrap_or_else(|p| p.into_inner());
if session_guard.is_none() {
let provider = select_provider();
let new_session = create_session(&self.model_path, provider)
.map_err(|e| SpladeError::InferenceFailed(format!("ORT session re-init: {e}")))?;
*session_guard = Some(new_session);
tracing::debug!("SPLADE session re-created after clear");
}
let session = session_guard.as_mut().expect("session just initialized");
let outputs = session
.run(ort::inputs![
"input_ids" => ids_tensor,
"attention_mask" => mask_tensor,
])
.map_err(ort_err)?;
let sparse = if let Some(sv_output) = outputs.get("sparse_vector") {
let (shape, data) = sv_output.try_extract_tensor::<f32>().map_err(ort_err)?;
if shape.len() != 2 {
return Err(SpladeError::InferenceFailed(format!(
"Pre-pooled sparse_vector expected 2D [batch, vocab], got {}D",
shape.len()
)));
}
let vocab = shape[1] as usize;
tracing::debug!(vocab, format = "pre_pooled", "SPLADE output detected");
let sv: SparseVector = data
.iter()
.enumerate()
.filter_map(|(id, &val)| {
if val > self.threshold {
Some((id as u32, val))
} else {
None
}
})
.collect();
sv
} else if let Some(logits_output) = outputs.get("logits") {
let (shape, data) = logits_output.try_extract_tensor::<f32>().map_err(ort_err)?;
if shape.len() != 3 {
return Err(SpladeError::InferenceFailed(format!(
"Expected 3D logits [batch, seq, vocab], got {}D",
shape.len()
)));
}
let vocab = shape[2] as usize;
tracing::debug!(vocab, format = "raw_logits", "SPLADE output detected");
let logits = ArrayView2::from_shape((seq_len, vocab), data).map_err(|e| {
SpladeError::InferenceFailed(format!("Failed to reshape logits: {e}"))
})?;
let pooled = logits.fold_axis(Axis(0), f32::NEG_INFINITY, |&a, &b| a.max(b));
let sv: SparseVector = pooled
.iter()
.enumerate()
.filter_map(|(id, &val)| {
let activated = (1.0 + val.max(0.0)).ln();
if activated > self.threshold {
Some((id as u32, activated))
} else {
None
}
})
.collect();
sv
} else {
return Err(SpladeError::InferenceFailed(format!(
"No recognized SPLADE output. Expected 'sparse_vector' or 'logits'. Available: {:?}",
outputs.keys().collect::<Vec<_>>()
)));
};
tracing::debug!(non_zero = sparse.len(), "SPLADE encoding complete");
Ok(sparse)
}
/// Encode a batch of texts into sparse vectors, one per input, order preserved.
///
/// Texts are truncated to `CQS_SPLADE_MAX_CHARS` chars; empty (or
/// truncated-to-empty) inputs map to empty vectors without being run through
/// the model. Sequences are padded/truncated to `CQS_SPLADE_MAX_SEQ` tokens
/// (default 256, minimum 8). Handles the same two output layouts as `encode`
/// (`sparse_vector` pre-pooled or raw `logits`), max-pooling logits only over
/// each example's real (unpadded) positions. Token ids at or above the
/// tokenizer vocab (zero-trained lm_head padding — see `new`) are dropped.
///
/// # Errors
/// `TokenizationFailed` on tokenizer errors; `InferenceFailed` on
/// tensor-shape or ORT runtime problems.
pub fn encode_batch(&self, texts: &[&str]) -> Result<Vec<SparseVector>, SpladeError> {
    let _span = tracing::debug_span!("splade_encode_batch", count = texts.len()).entered();
    if texts.is_empty() {
        return Ok(Vec::new());
    }
    let max_chars = splade_max_chars();
    // Truncate each text on a char boundary (same policy as `encode`).
    let truncated: Vec<&str> = texts
        .iter()
        .map(|t| {
            if t.len() > max_chars {
                let end = t
                    .char_indices()
                    .nth(max_chars)
                    .map(|(i, _)| i)
                    .unwrap_or(t.len());
                &t[..end]
            } else {
                *t
            }
        })
        .collect();
    // Only non-empty texts go to the model; empty slots are filled with
    // empty vectors when results are assembled at the end.
    let non_empty_indices: Vec<usize> = truncated
        .iter()
        .enumerate()
        .filter_map(|(i, t)| if t.is_empty() { None } else { Some(i) })
        .collect();
    if non_empty_indices.is_empty() {
        return Ok(vec![Vec::new(); texts.len()]);
    }
    let non_empty_texts: Vec<&str> = non_empty_indices.iter().map(|&i| truncated[i]).collect();
    let encodings: Vec<_> = non_empty_texts
        .iter()
        .map(|t| {
            self.tokenizer
                .encode(*t, true)
                .map_err(|e| SpladeError::TokenizationFailed(e.to_string()))
        })
        .collect::<Result<_, _>>()?;
    let batch_size = encodings.len();
    let max_seq_len: usize = std::env::var("CQS_SPLADE_MAX_SEQ")
        .ok()
        .and_then(|v| v.parse().ok())
        .filter(|&n: &usize| n >= 8)
        .unwrap_or(256);
    // Pack the batch into rectangular [batch, max_seq_len] buffers: short
    // rows are padded with id 0 / mask 0, long rows are truncated.
    let mut input_ids: Vec<i64> = Vec::with_capacity(batch_size * max_seq_len);
    let mut attention_mask: Vec<i64> = Vec::with_capacity(batch_size * max_seq_len);
    let mut truncations = 0usize;
    for enc in &encodings {
        let ids = enc.get_ids();
        let mask = enc.get_attention_mask();
        let n = ids.len();
        if n > max_seq_len {
            truncations += 1;
        }
        for i in 0..max_seq_len {
            if i < n {
                input_ids.push(ids[i] as i64);
                attention_mask.push(mask[i] as i64);
            } else {
                input_ids.push(0);
                attention_mask.push(0);
            }
        }
    }
    if truncations > 0 {
        tracing::debug!(
            truncations,
            batch_size,
            max_seq_len,
            "SPLADE batch had truncated inputs"
        );
    }
    let ids_array =
        Array2::from_shape_vec((batch_size, max_seq_len), input_ids).map_err(|e| {
            SpladeError::InferenceFailed(format!("Failed to build batch input tensor: {e}"))
        })?;
    let mask_array = Array2::from_shape_vec((batch_size, max_seq_len), attention_mask)
        .map_err(|e| {
            SpladeError::InferenceFailed(format!("Failed to build batch mask tensor: {e}"))
        })?;
    let ids_tensor = Tensor::from_array(ids_array)
        .map_err(|e| SpladeError::InferenceFailed(format!("Batch ids tensor: {e}")))?;
    let mask_tensor = Tensor::from_array(mask_array)
        .map_err(|e| SpladeError::InferenceFailed(format!("Batch mask tensor: {e}")))?;
    // Recover from a poisoned lock and lazily re-create the session if
    // `clear_session` dropped it.
    let mut session_guard = self.session.lock().unwrap_or_else(|p| p.into_inner());
    if session_guard.is_none() {
        let provider = select_provider();
        let new_session = create_session(&self.model_path, provider)
            .map_err(|e| SpladeError::InferenceFailed(format!("ORT session re-init: {e}")))?;
        *session_guard = Some(new_session);
        tracing::debug!("SPLADE session re-created after clear");
    }
    let session = session_guard.as_mut().expect("session just initialized");
    let outputs = session
        .run(ort::inputs![
            "input_ids" => ids_tensor,
            "attention_mask" => mask_tensor,
        ])
        .map_err(ort_err)?;
    // Hoisted so the per-row closures don't re-read fields each iteration.
    let threshold = self.threshold;
    let vocab_limit = self.vocab_size;
    let per_example: Vec<SparseVector> = if let Some(sv_output) = outputs.get("sparse_vector") {
        let (shape, data) = sv_output.try_extract_tensor::<f32>().map_err(ort_err)?;
        if shape.len() != 2 {
            return Err(SpladeError::InferenceFailed(format!(
                "Pre-pooled sparse_vector expected 2D [batch, vocab], got {}D",
                shape.len()
            )));
        }
        if shape[0] as usize != batch_size {
            return Err(SpladeError::InferenceFailed(format!(
                "sparse_vector batch dim {} != input batch {}",
                shape[0], batch_size
            )));
        }
        let vocab = shape[1] as usize;
        tracing::debug!(
            vocab,
            batch = batch_size,
            format = "pre_pooled",
            "SPLADE batch output"
        );
        (0..batch_size)
            .map(|b| {
                let row = &data[b * vocab..(b + 1) * vocab];
                // Keep activations above threshold; drop ids beyond the
                // tokenizer vocab (zero-trained padding slots — see `new`).
                row.iter()
                    .enumerate()
                    .filter_map(|(id, &val)| {
                        if id < vocab_limit && val > threshold {
                            Some((id as u32, val))
                        } else {
                            None
                        }
                    })
                    .collect()
            })
            .collect()
    } else if let Some(logits_output) = outputs.get("logits") {
        let (shape, data) = logits_output.try_extract_tensor::<f32>().map_err(ort_err)?;
        if shape.len() != 3 {
            return Err(SpladeError::InferenceFailed(format!(
                "Expected 3D logits [batch, seq, vocab], got {}D",
                shape.len()
            )));
        }
        if shape[0] as usize != batch_size {
            return Err(SpladeError::InferenceFailed(format!(
                "logits batch dim {} != input batch {}",
                shape[0], batch_size
            )));
        }
        if shape[1] as usize != max_seq_len {
            return Err(SpladeError::InferenceFailed(format!(
                "logits seq dim {} != padded max_seq_len {}",
                shape[1], max_seq_len
            )));
        }
        let vocab = shape[2] as usize;
        tracing::debug!(
            vocab,
            batch = batch_size,
            format = "raw_logits",
            "SPLADE batch output"
        );
        let example_stride = max_seq_len * vocab;
        (0..batch_size)
            .map(|b| {
                let example = &data[b * example_stride..(b + 1) * example_stride];
                // Max-pool only over this example's real (unpadded) positions,
                // row by row (cache-friendly: logits are [seq, vocab] row-major).
                let real_seq_len = encodings[b].get_ids().len().min(max_seq_len);
                let mut pooled = vec![f32::NEG_INFINITY; vocab];
                for s in 0..real_seq_len {
                    let row = &example[s * vocab..(s + 1) * vocab];
                    for (best, &val) in pooled.iter_mut().zip(row) {
                        if val > *best {
                            *best = val;
                        }
                    }
                }
                pooled
                    .iter()
                    .enumerate()
                    .filter_map(|(id, &val)| {
                        let activated = (1.0 + val.max(0.0)).ln();
                        // Same padded-slot filter as the pre-pooled branch.
                        if id < vocab_limit && activated > threshold {
                            Some((id as u32, activated))
                        } else {
                            None
                        }
                    })
                    .collect()
            })
            .collect()
    } else {
        let names: Vec<&str> = outputs.keys().collect();
        return Err(SpladeError::InferenceFailed(format!(
            "No recognized SPLADE output. Expected 'sparse_vector' or 'logits'. \
            Available: {names:?}"
        )));
    };
    // Scatter model outputs back to their original positions; empty inputs
    // keep their default empty vectors. Move (not clone) each sparse vector.
    let mut results: Vec<SparseVector> = vec![Vec::new(); texts.len()];
    for (sv, &orig_idx) in per_example.into_iter().zip(non_empty_indices.iter()) {
        results[orig_idx] = sv;
    }
    Ok(results)
}
/// Tokenizer vocabulary size — the exclusive upper bound on sparse token ids.
pub fn vocab_size(&self) -> usize {
    self.vocab_size
}
/// Decode a single token id back to its surface string, if the tokenizer can.
pub fn decode_token(&self, token_id: u32) -> Option<String> {
    self.tokenizer.decode(&[token_id], false).ok()
}
/// Drop the ONNX session to free memory; `encode`/`encode_batch` re-create
/// it transparently on next use.
pub fn clear_session(&self) {
    // Recover from a poisoned lock: the session is replaced wholesale, so a
    // panicked holder cannot have left it in a half-written state.
    let mut guard = self.session.lock().unwrap_or_else(|p| p.into_inner());
    if guard.is_some() {
        *guard = None;
        tracing::debug!("SPLADE session cleared");
    }
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    // Directory of the real downloaded SPLADE model, if present. Tests that
    // need it are #[ignore]d so CI without the model stays green.
    fn splade_model_dir() -> Option<PathBuf> {
        let dir = dirs::home_dir()?.join(".cache/huggingface/splade-onnx");
        if dir.join("model.onnx").exists() {
            Some(dir)
        } else {
            None
        }
    }

    #[test]
    #[ignore]
    fn test_encode_produces_sparse_vector() {
        let dir = splade_model_dir().expect("SPLADE model not downloaded");
        let encoder = SpladeEncoder::new(&dir, 0.01).unwrap();
        let sparse = encoder.encode("parse configuration file").unwrap();
        assert!(!sparse.is_empty(), "Sparse vector should not be empty");
        assert!(
            sparse.len() < encoder.vocab_size(),
            "Sparse vector should be sparse (< vocab size)"
        );
    }

    #[test]
    #[ignore]
    fn test_encode_respects_threshold() {
        let dir = splade_model_dir().expect("SPLADE model not downloaded");
        let encoder = SpladeEncoder::new(&dir, 0.5).unwrap();
        let sparse = encoder.encode("search filtered results").unwrap();
        for &(_, weight) in &sparse {
            assert!(
                weight > 0.5,
                "All weights should exceed threshold, got {}",
                weight
            );
        }
    }

    #[test]
    #[ignore]
    fn test_encode_empty_string() {
        let dir = splade_model_dir().expect("SPLADE model not downloaded");
        let encoder = SpladeEncoder::new(&dir, 0.01).unwrap();
        let sparse = encoder.encode("").unwrap();
        assert!(
            sparse.is_empty(),
            "Empty string should produce empty vector"
        );
    }

    // Single-text and batch paths should agree for short inputs (longer ones
    // may differ: encode_batch caps sequences at CQS_SPLADE_MAX_SEQ).
    #[test]
    #[ignore]
    fn test_encode_batch_matches_single() {
        let dir = splade_model_dir().expect("SPLADE model not downloaded");
        let encoder = SpladeEncoder::new(&dir, 0.01).unwrap();
        let text = "find dead code functions";
        let single = encoder.encode(text).unwrap();
        let batch = encoder.encode_batch(&[text]).unwrap();
        assert_eq!(single.len(), batch[0].len());
        for (s, b) in single.iter().zip(batch[0].iter()) {
            assert_eq!(s.0, b.0, "Token IDs should match");
            assert!(
                (s.1 - b.1).abs() < 1e-5,
                "Weights should match: {} vs {}",
                s.1,
                b.1
            );
        }
    }

    #[test]
    #[ignore]
    fn test_encode_batch_multiple_matches_serial() {
        let dir = splade_model_dir().expect("SPLADE model not downloaded");
        let encoder = SpladeEncoder::new(&dir, 0.01).unwrap();
        let texts = vec![
            "find a function that parses configuration files and validates the result",
            "search for dead code",
            "Vec::new",
        ];
        let serial: Vec<_> = texts.iter().map(|t| encoder.encode(t).unwrap()).collect();
        let batched = encoder.encode_batch(&texts).unwrap();
        assert_eq!(serial.len(), batched.len());
        for (i, (s, b)) in serial.iter().zip(batched.iter()).enumerate() {
            assert_eq!(
                s.len(),
                b.len(),
                "example {i}: token count mismatch (serial {} vs batched {})",
                s.len(),
                b.len()
            );
            for (j, ((s_id, s_w), (b_id, b_w))) in s.iter().zip(b.iter()).enumerate() {
                assert_eq!(s_id, b_id, "example {i} token {j}: id mismatch");
                assert!(
                    (s_w - b_w).abs() < 1e-4,
                    "example {i} token {j}: weight mismatch ({s_w} vs {b_w})"
                );
            }
        }
    }

    // NOTE(review): empty body — this test asserts nothing and always passes,
    // giving a false green signal. The real check lives in
    // test_encode_batch_empty_input_real_model below (constructing an encoder
    // needs the downloaded model). Consider removing or filling this in.
    #[test]
    fn test_encode_batch_empty_input_list() {
    }

    #[test]
    #[ignore]
    fn test_encode_batch_empty_input_real_model() {
        let dir = splade_model_dir().expect("SPLADE model not downloaded");
        let encoder = SpladeEncoder::new(&dir, 0.01).unwrap();
        let result = encoder.encode_batch(&[]).unwrap();
        assert!(result.is_empty(), "empty input list → empty result");
    }

    #[test]
    #[ignore]
    fn test_encode_batch_all_empty_strings() {
        let dir = splade_model_dir().expect("SPLADE model not downloaded");
        let encoder = SpladeEncoder::new(&dir, 0.01).unwrap();
        let result = encoder.encode_batch(&["", "", ""]).unwrap();
        assert_eq!(result.len(), 3);
        for (i, sv) in result.iter().enumerate() {
            assert!(
                sv.is_empty(),
                "position {i}: empty input should produce empty vector"
            );
        }
    }

    #[test]
    #[ignore]
    fn test_encode_batch_mixed_empty_and_nonempty() {
        let dir = splade_model_dir().expect("SPLADE model not downloaded");
        let encoder = SpladeEncoder::new(&dir, 0.01).unwrap();
        let result = encoder
            .encode_batch(&["", "find dead code", "", "search for parser bugs", ""])
            .unwrap();
        assert_eq!(result.len(), 5);
        assert!(result[0].is_empty(), "position 0 (empty) → empty");
        assert!(!result[1].is_empty(), "position 1 (non-empty) → non-empty");
        assert!(result[2].is_empty(), "position 2 (empty) → empty");
        assert!(!result[3].is_empty(), "position 3 (non-empty) → non-empty");
        assert!(result[4].is_empty(), "position 4 (empty) → empty");
        let serial_1 = encoder.encode("find dead code").unwrap();
        let serial_3 = encoder.encode("search for parser bugs").unwrap();
        assert_eq!(result[1].len(), serial_1.len());
        assert_eq!(result[3].len(), serial_3.len());
    }

    #[test]
    fn test_model_not_found() {
        let result = SpladeEncoder::new(Path::new("/nonexistent"), 0.01);
        assert!(result.is_err(), "Should fail for nonexistent path");
        match result {
            Err(e) => assert!(
                e.to_string().contains("not found"),
                "Error should mention not found: {e}"
            ),
            Ok(_) => unreachable!(),
        }
    }

    // Needed here: the file-level `use std::sync::Mutex` is private and is
    // not re-exported through `use super::*`.
    use std::sync::Mutex;
    // Env-var tests mutate process-global state (set_var/remove_var), so they
    // must be serialized against each other.
    static SPLADE_ENV_LOCK: Mutex<()> = Mutex::new(());

    // Create stub model.onnx/tokenizer.json files so resolve_* sees them;
    // resolution only checks existence, not content.
    fn write_stub_splade_dir(dir: &Path) {
        std::fs::write(dir.join("model.onnx"), b"stub").unwrap();
        std::fs::write(dir.join("tokenizer.json"), b"stub").unwrap();
    }

    #[test]
    fn test_resolve_env_var_override() {
        let _guard = SPLADE_ENV_LOCK.lock().unwrap();
        let tmp = tempfile::TempDir::new().unwrap();
        write_stub_splade_dir(tmp.path());
        std::env::set_var("CQS_SPLADE_MODEL", tmp.path());
        let resolved = resolve_splade_model_dir();
        std::env::remove_var("CQS_SPLADE_MODEL");
        assert_eq!(resolved.as_deref(), Some(tmp.path()));
    }

    #[test]
    fn test_resolve_env_var_tilde_expansion() {
        let _guard = SPLADE_ENV_LOCK.lock().unwrap();
        let home = dirs::home_dir().expect("HOME must be set in test env");
        // Process-id suffix keeps parallel test runs from colliding.
        let stub_subdir = format!(".cqs-test-splade-{}", std::process::id());
        let stub_dir = home.join(&stub_subdir);
        std::fs::create_dir_all(&stub_dir).unwrap();
        write_stub_splade_dir(&stub_dir);
        std::env::set_var("CQS_SPLADE_MODEL", format!("~/{stub_subdir}"));
        let resolved = resolve_splade_model_dir();
        std::env::remove_var("CQS_SPLADE_MODEL");
        let _ = std::fs::remove_dir_all(&stub_dir);
        assert_eq!(
            resolved.as_deref(),
            Some(stub_dir.as_path()),
            "tilde-prefixed CQS_SPLADE_MODEL should expand against $HOME"
        );
    }

    #[test]
    fn test_resolve_env_var_missing_model_returns_none() {
        let _guard = SPLADE_ENV_LOCK.lock().unwrap();
        let tmp = tempfile::TempDir::new().unwrap();
        std::fs::write(tmp.path().join("tokenizer.json"), b"stub").unwrap();
        std::env::set_var("CQS_SPLADE_MODEL", tmp.path());
        let resolved = resolve_splade_model_dir();
        std::env::remove_var("CQS_SPLADE_MODEL");
        assert!(
            resolved.is_none(),
            "should return None when model.onnx is missing"
        );
    }

    #[test]
    fn test_resolve_env_var_missing_tokenizer_returns_none() {
        let _guard = SPLADE_ENV_LOCK.lock().unwrap();
        let tmp = tempfile::TempDir::new().unwrap();
        std::fs::write(tmp.path().join("model.onnx"), b"stub").unwrap();
        std::env::set_var("CQS_SPLADE_MODEL", tmp.path());
        let resolved = resolve_splade_model_dir();
        std::env::remove_var("CQS_SPLADE_MODEL");
        assert!(
            resolved.is_none(),
            "should return None when tokenizer.json is missing — \
            a model+wrong-tokenizer dir must not silently fall through"
        );
    }

    // These two tests depend on whether the host actually has the default
    // model installed, so they branch on that instead of assuming.
    #[test]
    fn test_resolve_env_var_empty_falls_back_to_default() {
        let _guard = SPLADE_ENV_LOCK.lock().unwrap();
        std::env::set_var("CQS_SPLADE_MODEL", "");
        let resolved = resolve_splade_model_dir();
        std::env::remove_var("CQS_SPLADE_MODEL");
        let expected_default = dirs::home_dir()
            .map(|h| h.join(".cache/huggingface/splade-onnx"))
            .unwrap_or_default();
        if expected_default.join("model.onnx").exists()
            && expected_default.join("tokenizer.json").exists()
        {
            assert_eq!(
                resolved.as_deref(),
                Some(expected_default.as_path()),
                "empty env var should fall back to default cache dir"
            );
        } else {
            assert!(
                resolved.is_none(),
                "empty env var with no default model installed → None"
            );
        }
    }

    #[test]
    fn test_resolve_no_env_var() {
        let _guard = SPLADE_ENV_LOCK.lock().unwrap();
        std::env::remove_var("CQS_SPLADE_MODEL");
        let resolved = resolve_splade_model_dir();
        let expected_default = dirs::home_dir()
            .map(|h| h.join(".cache/huggingface/splade-onnx"))
            .unwrap_or_default();
        if expected_default.join("model.onnx").exists()
            && expected_default.join("tokenizer.json").exists()
        {
            assert_eq!(resolved.as_deref(), Some(expected_default.as_path()));
        } else {
            assert!(resolved.is_none());
        }
    }

    #[test]
    fn test_config_mismatch_error_message_is_actionable() {
        let err = SpladeError::ConfigMismatch {
            dir: PathBuf::from("/some/where/splade-onnx"),
            tokenizer_vocab: 30522,
            model_vocab: 151936,
        };
        let msg = err.to_string();
        assert!(
            msg.contains("30522"),
            "should include tokenizer vocab: {msg}"
        );
        assert!(msg.contains("151936"), "should include model vocab: {msg}");
        assert!(
            msg.contains("/some/where/splade-onnx"),
            "should include the directory: {msg}"
        );
        assert!(
            msg.to_lowercase().contains("tokenizer"),
            "should mention tokenizer.json as the fix-point: {msg}"
        );
    }

    // Local mirror of the acceptance logic in SpladeEncoder::new so the
    // thresholds can be unit-tested without a model. Keep in sync with `new`.
    // Returns Ok(true) when the model vocab is padded above the tokenizer's.
    fn check_vocab_compatibility(
        tokenizer_vocab: usize,
        model_vocab: usize,
    ) -> Result<bool, &'static str> {
        if model_vocab < tokenizer_vocab {
            return Err("model_vocab < tokenizer_vocab");
        }
        let padding_pct = if tokenizer_vocab > 0 {
            (model_vocab - tokenizer_vocab) as f32 * 100.0 / tokenizer_vocab as f32
        } else {
            0.0
        };
        if padding_pct > 1.5 {
            return Err("padding > 1.5%");
        }
        Ok(model_vocab > tokenizer_vocab)
    }

    #[test]
    fn test_vocab_compat_exact_match_accepted() {
        assert_eq!(check_vocab_compatibility(30522, 30522), Ok(false));
        assert_eq!(check_vocab_compatibility(151669, 151669), Ok(false));
    }

    #[test]
    fn test_vocab_compat_benign_padding_accepted() {
        assert_eq!(
            check_vocab_compatibility(151669, 151936),
            Ok(true),
            "SPLADE-Code 0.6B's 0.18% lm_head padding must be accepted"
        );
        assert_eq!(
            check_vocab_compatibility(30000, 30300),
            Ok(true),
            "1% padding should be accepted"
        );
        assert_eq!(
            check_vocab_compatibility(30000, 30449),
            Ok(true),
            "1.49% padding should be accepted"
        );
    }

    #[test]
    fn test_vocab_compat_large_padding_rejected() {
        assert_eq!(
            check_vocab_compatibility(30000, 30460),
            Err("padding > 1.5%"),
            "1.53% padding should be rejected"
        );
        assert_eq!(
            check_vocab_compatibility(30522, 121936),
            Err("padding > 1.5%"),
        );
    }

    #[test]
    fn test_vocab_compat_tokenizer_larger_rejected() {
        assert_eq!(
            check_vocab_compatibility(151669, 30522),
            Err("model_vocab < tokenizer_vocab"),
            "tokenizer larger than model must hard-fail"
        );
        assert_eq!(
            check_vocab_compatibility(151936, 151935),
            Err("model_vocab < tokenizer_vocab"),
            "even by 1 must hard-fail"
        );
    }

    #[test]
    fn test_vocab_compat_zero_tokenizer_vocab() {
        assert_eq!(check_vocab_compatibility(0, 0), Ok(false));
        assert_eq!(check_vocab_compatibility(0, 100), Ok(true));
    }
}