Skip to main content

ripvec_core/backend/
mod.rs

//! Backend abstraction layer.
//!
//! Post-v3.0.0 the only backend is the CPU cross-encoder reranker
//! ([`cpu::CpuRerankBackend`], backed by [`cpu::CpuBertModel`]). The
//! [`RerankBackend`] trait, [`Encoding`] input type, and [`BackendKind`]
//! discriminant survive; the `EmbedBackend` trait and the bi-encoder
//! `load_backend` / `detect_backends` entry points were removed when
//! the transformer engines were removed.
9
10pub mod blas_info;
// `cpu` covers CpuBertModel + CpuRerankBackend (both keep-anchors per
// the surgery's backend_split.md §3). The CpuBackend wrapper struct
// was removed with the bi-encoder backends; the trunk + reranker survived.
// Gate widened from `cfg(feature = "cpu")` so the macOS default build
// (which uses `cpu-accelerate`) gets the reranker.
16#[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
17pub mod cpu;
18
/// A single tokenized input, ready to be handed to a backend for inference.
///
/// The three vectors are parallel: `input_ids`, `attention_mask`, and
/// `token_type_ids` must all have the same length. The tokenizer caps the
/// token count at `MODEL_MAX_TOKENS` (512) before an encoding reaches the
/// backend.
#[derive(Clone, Debug)]
pub struct Encoding {
    /// Token IDs emitted by the tokenizer.
    pub input_ids: Vec<i64>,
    /// Attention mask: 1 marks a real token, 0 marks padding.
    pub attention_mask: Vec<i64>,
    /// Token type IDs (0 throughout for single-sequence models).
    pub token_type_ids: Vec<i64>,
}
33
34/// Trait for cross-encoder rerank backends.
35///
36/// Parallel to [`EmbedBackend`], but the forward pass terminates in a
37/// scalar relevance score per pair instead of a pooled vector. Used by
38/// the retrieve-then-rerank pipeline: a bi-encoder ([`EmbedBackend`])
39/// retrieves top-K cheaply, then [`RerankBackend`] re-scores those K
40/// candidates with the cross-encoder's higher-quality cross-attention
41/// over the concatenated `[CLS] query [SEP] doc [SEP]` sequence.
42///
43/// # Why a separate trait
44///
45/// Cross-encoders share BERT's trunk with bi-encoders, but the head and
46/// pooling differ: bi-encoder = CLS pool + L2-normalize, cross-encoder
47/// = CLS pool + linear(hidden -> 1) + sigmoid. The two return shapes are
48/// incompatible (`Vec<Vec<f32>>` vs `Vec<f32>`), so unifying them under
49/// a single trait would force every caller to handle an awkward sum
50/// type. Sibling traits keep both call sites direct.
51pub trait RerankBackend: Send + Sync {
52    /// Score a batch of pre-tokenized pairs and return one score per
53    /// encoding. Scores are sigmoid-activated and lie in `[0, 1]`.
54    ///
55    /// The encoding's `token_type_ids` should mark the query side as
56    /// 0 and the doc side as 1 (standard BERT pair convention); this
57    /// is what `tokenizers::Tokenizer::encode((query, doc), ..)`
58    /// produces.
59    ///
60    /// # Errors
61    ///
62    /// Returns an error if tensor construction or the forward pass fails.
63    fn score_batch(&self, encodings: &[Encoding]) -> crate::Result<Vec<f32>>;
64
65    /// Maximum token count this model supports.
66    fn max_tokens(&self) -> usize {
67        512
68    }
69
70    /// Whether this backend runs on a GPU.
71    fn is_gpu(&self) -> bool;
72
73    /// Short human-readable label for this backend.
74    fn name(&self) -> &'static str {
75        if self.is_gpu() { "GPU" } else { "CPU" }
76    }
77}
78
/// Load a cross-encoder rerank model for CPU inference.
///
/// `model_repo` names the model repository to load. MS-MARCO family
/// rerankers (the default `cross-encoder/ms-marco-MiniLM-L-6-v2`) are
/// ClassicBert-shaped, so they route through [`cpu::CpuRerankBackend`] -
/// same trunk as the bi-encoder, plus a `Linear(hidden -> 1)` classifier
/// head.
///
/// The rerank path is load-bearing for the document-search use case
/// (cacheless prose queries) and must work in the default build, so this
/// implementation is compiled whenever either CPU feature is enabled -
/// `feature = "cpu"` or `feature = "cpu-accelerate"`. The underlying
/// `CpuRerankBackend` uses the same ndarray BLAS setup as the former
/// `CpuBackend`, so it works wherever the CPU embedding backend did.
/// Builds with neither feature get a same-named stub that returns an
/// error instead.
///
/// # Errors
///
/// Returns an error if the model cannot be downloaded, if it lacks a
/// classifier head (i.e., the caller pointed at a bi-encoder by
/// mistake), or if the weights fail to parse.
#[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
pub fn load_reranker_cpu(model_repo: &str) -> crate::Result<Box<dyn RerankBackend>> {
    // `?` propagates any load failure; the concrete backend is then erased
    // behind the trait object callers receive.
    let backend = cpu::CpuRerankBackend::load(model_repo)?;
    Ok(Box::new(backend))
}
112
113#[cfg(not(any(feature = "cpu", feature = "cpu-accelerate")))]
114pub fn load_reranker_cpu(_model_repo: &str) -> crate::Result<Box<dyn RerankBackend>> {
115    Err(crate::Error::Other(anyhow::anyhow!(
116        "cross-encoder rerank requires building with --features cpu \
117         or --features cpu-accelerate"
118    )))
119}
120
#[cfg(test)]
mod tests {
    use super::*;

    // EmbedBackend trait object-safety + Send/Sync tests removed in v3.0.0:
    // the trait itself was deleted (zero impls post-surgery). The surviving
    // RerankBackend trait needs no equivalent assertions here:
    // `load_reranker_cpu` returns `Box<dyn RerankBackend>`, so the compiler
    // already rejects a non-object-safe trait, and `Send + Sync` are
    // supertraits of the trait itself.

    /// A hand-built encoding keeps the lengths it was given in all three
    /// parallel vectors.
    #[test]
    fn encoding_construction() {
        let enc = Encoding {
            input_ids: vec![101, 2023, 2003, 1037, 3231, 102],
            attention_mask: vec![1, 1, 1, 1, 1, 1],
            token_type_ids: vec![0, 0, 0, 0, 0, 0],
        };
        assert_eq!(enc.input_ids.len(), 6);
        assert_eq!(enc.attention_mask.len(), 6);
        assert_eq!(enc.token_type_ids.len(), 6);
    }

    /// `Clone` copies every field, and the copy does not alias the original.
    #[test]
    fn encoding_clone() {
        // Distinct values per field so a mixed-up copy would be detected.
        let enc = Encoding {
            input_ids: vec![101, 102],
            attention_mask: vec![1, 1],
            token_type_ids: vec![0, 1],
        };
        let mut cloned = enc.clone();
        assert_eq!(enc.input_ids, cloned.input_ids);
        assert_eq!(enc.attention_mask, cloned.attention_mask);
        assert_eq!(enc.token_type_ids, cloned.token_type_ids);

        // Mutating the clone must leave the original untouched.
        cloned.input_ids.push(0);
        cloned.attention_mask.clear();
        assert_eq!(enc.input_ids, vec![101, 102]);
        assert_eq!(enc.attention_mask, vec![1, 1]);
    }
}