pub mod index;
use std::path::Path;
use std::sync::Mutex;
use ndarray::{Array2, ArrayView2, Axis};
use ort::session::Session;
use ort::value::Tensor;
use thiserror::Error;
use crate::embedder::{create_session, select_provider};
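/// Maps an `ort` error into this module's inference error variant.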
fn ort_err(e: ort::Error) -> SpladeError {
SpladeError::InferenceFailed(e.to_string())
}
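/// Sparse term-weight vector: `(token_id, weight)` pairs for every vocabulary
/// entry whose activation exceeds the encoder threshold.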
pub type SparseVector = Vec<(u32, f32)>;
#[derive(Error, Debug)]
pub enum SpladeError {
#[error("SPLADE model not found: {0}")]
ModelNotFound(String),
#[error("SPLADE inference failed: {0}")]
InferenceFailed(String),
#[error("SPLADE tokenization failed: {0}")]
TokenizationFailed(String),
}
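/// SPLADE sparse encoder backed by an ONNX Runtime session and a tokenizer
/// loaded from the same model directory.
///
/// A minimal usage sketch (marked `ignore`: the `cqs` crate path is assumed
/// from the `CQS_` env prefix, and the model directory is illustrative):
///
/// ```ignore
/// use std::path::Path;
/// use cqs::splade::SpladeEncoder;
///
/// let encoder = SpladeEncoder::new(
///     Path::new("/path/to/splade-onnx"),
///     SpladeEncoder::default_threshold(),
/// )?;
/// let sparse = encoder.encode("parse configuration file")?;
/// for (token_id, weight) in &sparse {
///     println!("{:?} -> {weight:.3}", encoder.decode_token(*token_id));
/// }
/// ```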
pub struct SpladeEncoder {
session: Mutex<Option<Session>>,
model_path: std::path::PathBuf,
tokenizer: tokenizers::Tokenizer,
threshold: f32,
vocab_size: usize,
}
impl SpladeEncoder {
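/// Activation threshold used when the caller does not supply one.
/// Reads `CQS_SPLADE_THRESHOLD` from the environment, falling back to 0.01.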
pub fn default_threshold() -> f32 {
std::env::var("CQS_SPLADE_THRESHOLD")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(0.01)
}
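/// Loads the encoder from `model_dir`, which must contain `model.onnx` and
/// `tokenizer.json`. The ONNX session uses whichever execution provider
/// `select_provider` picks for the host.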
pub fn new(model_dir: &Path, threshold: f32) -> Result<Self, SpladeError> {
let _span = tracing::info_span!("splade_encoder_new", dir = %model_dir.display()).entered();
let onnx_path = model_dir.join("model.onnx");
if !onnx_path.exists() {
return Err(SpladeError::ModelNotFound(format!(
"No model.onnx at {}",
model_dir.display()
)));
}
let tokenizer_path = model_dir.join("tokenizer.json");
if !tokenizer_path.exists() {
return Err(SpladeError::ModelNotFound(format!(
"No tokenizer.json at {}",
model_dir.display()
)));
}
let provider = select_provider();
let session = create_session(&onnx_path, provider)
.map_err(|e| SpladeError::InferenceFailed(format!("ORT session: {e}")))?;
let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
.map_err(|e| SpladeError::TokenizationFailed(e.to_string()))?;
let vocab_size = tokenizer.get_vocab_size(true);
tracing::info!(threshold, vocab_size, "SPLADE encoder loaded");
Ok(Self {
session: Mutex::new(Some(session)),
model_path: onnx_path,
tokenizer,
threshold,
vocab_size,
})
}
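/// Encodes `text` into a sparse vocabulary-space vector. Input longer than
/// 4000 characters is truncated before tokenization; an empty string yields
/// an empty vector.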
pub fn encode(&self, text: &str) -> Result<SparseVector, SpladeError> {
let _span = tracing::debug_span!("splade_encode", text_len = text.len()).entered();
if text.is_empty() {
return Ok(Vec::new());
}
// Cheap byte-length guard, then cut at a char boundary so multi-byte UTF-8
// is never split; only log when something was actually removed.
let text = if text.len() > 4000 {
let cut = text
.char_indices()
.nth(4000)
.map(|(i, _)| i)
.unwrap_or(text.len());
let truncated = &text[..cut];
if truncated.len() < text.len() {
tracing::debug!(
original_len = text.len(),
truncated_len = truncated.len(),
"Truncated SPLADE input to 4000 chars"
);
}
truncated
} else {
text
};
let encoding = self
.tokenizer
.encode(text, true)
.map_err(|e| SpladeError::TokenizationFailed(e.to_string()))?;
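// ORT expects i64 input tensors shaped [batch, seq_len]; batch is always 1 here.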
let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&id| id as i64).collect();
let attention_mask: Vec<i64> = encoding
.get_attention_mask()
.iter()
.map(|&m| m as i64)
.collect();
let seq_len = input_ids.len();
let ids_array = Array2::from_shape_vec((1, seq_len), input_ids).map_err(|e| {
SpladeError::InferenceFailed(format!("Failed to build input tensor: {e}"))
})?;
let mask_array = Array2::from_shape_vec((1, seq_len), attention_mask).map_err(|e| {
SpladeError::InferenceFailed(format!("Failed to build mask tensor: {e}"))
})?;
let ids_tensor = Tensor::from_array(ids_array)
.map_err(|e| SpladeError::InferenceFailed(format!("Tensor: {e}")))?;
let mask_tensor = Tensor::from_array(mask_array)
.map_err(|e| SpladeError::InferenceFailed(format!("Tensor: {e}")))?;
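// Re-create the session lazily if clear_session() dropped it; a poisoned
// mutex is recovered by taking its inner value.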
let mut session_guard = self.session.lock().unwrap_or_else(|p| p.into_inner());
if session_guard.is_none() {
let provider = select_provider();
let new_session = create_session(&self.model_path, provider)
.map_err(|e| SpladeError::InferenceFailed(format!("ORT session re-init: {e}")))?;
*session_guard = Some(new_session);
tracing::debug!("SPLADE session re-created after clear");
}
let session = session_guard.as_mut().expect("session just initialized");
let outputs = session
.run(ort::inputs![
"input_ids" => ids_tensor,
"attention_mask" => mask_tensor,
])
.map_err(ort_err)?;
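// Two export formats are supported: a pre-pooled [batch, vocab] "sparse_vector"
// output, or raw [batch, seq, vocab] "logits" that still need max-pooling and
// the log(1 + ReLU) activation.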
let sparse = if let Some(sv_output) = outputs.get("sparse_vector") {
let (shape, data) = sv_output.try_extract_tensor::<f32>().map_err(ort_err)?;
if shape.len() != 2 {
return Err(SpladeError::InferenceFailed(format!(
"Pre-pooled sparse_vector expected 2D [batch, vocab], got {}D",
shape.len()
)));
}
let vocab = shape[1] as usize;
tracing::debug!(vocab, format = "pre_pooled", "SPLADE output detected");
let sv: SparseVector = data
.iter()
.enumerate()
.filter_map(|(id, &val)| {
if val > self.threshold {
Some((id as u32, val))
} else {
None
}
})
.collect();
sv
} else if let Some(logits_output) = outputs.get("logits") {
let (shape, data) = logits_output.try_extract_tensor::<f32>().map_err(ort_err)?;
if shape.len() != 3 {
return Err(SpladeError::InferenceFailed(format!(
"Expected 3D logits [batch, seq, vocab], got {}D",
shape.len()
)));
}
let vocab = shape[2] as usize;
tracing::debug!(vocab, format = "raw_logits", "SPLADE output detected");
let logits = ArrayView2::from_shape((seq_len, vocab), data).map_err(|e| {
SpladeError::InferenceFailed(format!("Failed to reshape logits: {e}"))
})?;
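// Max-pool logits over the sequence axis, then apply log(1 + ReLU).
// Since log(1 + max(x, 0)) is monotonically increasing, pooling before the
// activation gives the same result as activating each token first.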
let pooled = logits.fold_axis(Axis(0), f32::NEG_INFINITY, |&a, &b| a.max(b));
let sv: SparseVector = pooled
.iter()
.enumerate()
.filter_map(|(id, &val)| {
let activated = (1.0 + val.max(0.0)).ln();
if activated > self.threshold {
Some((id as u32, activated))
} else {
None
}
})
.collect();
sv
} else {
return Err(SpladeError::InferenceFailed(format!(
"No recognized SPLADE output. Expected 'sparse_vector' or 'logits'. Available: {:?}",
outputs.keys().collect::<Vec<_>>()
)));
};
tracing::debug!(non_zero = sparse.len(), "SPLADE encoding complete");
Ok(sparse)
}
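/// Encodes each text independently by calling `encode` sequentially;
/// there is no cross-text batching at the ONNX level.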
pub fn encode_batch(&self, texts: &[&str]) -> Result<Vec<SparseVector>, SpladeError> {
let _span = tracing::debug_span!("splade_encode_batch", count = texts.len()).entered();
texts.iter().map(|t| self.encode(t)).collect()
}
pub fn vocab_size(&self) -> usize {
self.vocab_size
}
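/// Decodes a single token id back to its surface string, or `None` if the
/// tokenizer cannot decode it.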
pub fn decode_token(&self, token_id: u32) -> Option<String> {
self.tokenizer.decode(&[token_id], false).ok()
}
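/// Drops the ONNX session to release memory; the next call to `encode`
/// re-creates it from `model_path`.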
pub fn clear_session(&self) {
let mut guard = self.session.lock().unwrap_or_else(|p| p.into_inner());
if guard.is_some() {
*guard = None;
tracing::debug!("SPLADE session cleared");
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::path::PathBuf;
fn splade_model_dir() -> Option<PathBuf> {
let dir = dirs::home_dir()?.join(".cache/huggingface/splade-onnx");
if dir.join("model.onnx").exists() {
Some(dir)
} else {
None
}
}
#[test]
#[ignore]
fn test_encode_produces_sparse_vector() {
let dir = splade_model_dir().expect("SPLADE model not downloaded");
let encoder = SpladeEncoder::new(&dir, 0.01).unwrap();
let sparse = encoder.encode("parse configuration file").unwrap();
assert!(!sparse.is_empty(), "Sparse vector should not be empty");
assert!(
sparse.len() < encoder.vocab_size(),
"Sparse vector should be sparse (< vocab size)"
);
}
#[test]
#[ignore]
fn test_encode_respects_threshold() {
let dir = splade_model_dir().expect("SPLADE model not downloaded");
let encoder = SpladeEncoder::new(&dir, 0.5).unwrap();
let sparse = encoder.encode("search filtered results").unwrap();
for &(_, weight) in &sparse {
assert!(
weight > 0.5,
"All weights should exceed threshold, got {}",
weight
);
}
}
#[test]
#[ignore]
fn test_encode_empty_string() {
let dir = splade_model_dir().expect("SPLADE model not downloaded");
let encoder = SpladeEncoder::new(&dir, 0.01).unwrap();
let sparse = encoder.encode("").unwrap();
assert!(
sparse.is_empty(),
"Empty string should produce empty vector"
);
}
#[test]
#[ignore]
fn test_encode_batch_matches_single() {
let dir = splade_model_dir().expect("SPLADE model not downloaded");
let encoder = SpladeEncoder::new(&dir, 0.01).unwrap();
let text = "find dead code functions";
let single = encoder.encode(text).unwrap();
let batch = encoder.encode_batch(&[text]).unwrap();
assert_eq!(single.len(), batch[0].len());
for (s, b) in single.iter().zip(batch[0].iter()) {
assert_eq!(s.0, b.0, "Token IDs should match");
assert!(
(s.1 - b.1).abs() < 1e-5,
"Weights should match: {} vs {}",
s.1,
b.1
);
}
}
#[test]
fn test_model_not_found() {
let result = SpladeEncoder::new(Path::new("/nonexistent"), 0.01);
assert!(result.is_err(), "Should fail for nonexistent path");
match result {
Err(e) => assert!(
e.to_string().contains("not found"),
"Error should mention not found: {e}"
),
Ok(_) => unreachable!(),
}
}
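// Sketch of an extra ignored test covering the clear_session() + lazy re-init
// path; it assumes the same local model layout as the other ignored tests and
// that inference stays numerically stable across sessions.
#[test]
#[ignore]
fn test_encode_after_clear_session() {
let dir = splade_model_dir().expect("SPLADE model not downloaded");
let encoder = SpladeEncoder::new(&dir, 0.01).unwrap();
let before = encoder.encode("parse configuration file").unwrap();
encoder.clear_session();
let after = encoder.encode("parse configuration file").unwrap();
assert_eq!(before.len(), after.len(), "Length should match after re-init");
for (b, a) in before.iter().zip(after.iter()) {
assert_eq!(b.0, a.0, "Token IDs should match after re-init");
assert!((b.1 - a.1).abs() < 1e-5, "Weights should match after re-init");
}
}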
}