rank_refine/lib.rs
1//! Rerank search candidates with embeddings. SIMD-accelerated.
2//!
3//! You bring embeddings, this crate scores them. No model weights, no inference.
4//!
5//! # Quick Start
6//!
7//! ```rust
8//! use rank_refine::simd;
9//!
10//! // Dense scoring
11//! let score = simd::cosine(&[1.0, 0.0], &[0.707, 0.707]);
12//!
13//! // Late interaction (`ColBERT` `MaxSim`)
14//! let query = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
15//! let doc = vec![vec![0.9, 0.1], vec![0.1, 0.9]];
16//! let score = simd::maxsim_vecs(&query, &doc);
17//! ```
18//!
19//! # Modules
20//!
21//! | Module | Purpose |
22//! |--------|---------|
23//! | [`simd`] | SIMD vector ops (dot, cosine, maxsim) |
24//! | [`colbert`] | Late interaction (`MaxSim`), token pooling |
25//! | [`diversity`] | MMR + DPP diversity selection |
26//! | [`crossencoder`] | Cross-encoder trait |
27//! | [`matryoshka`] | MRL tail refinement |
28//!
29//! Advanced: [`scoring`] for trait-based polymorphism, [`embedding`] for type-safe wrappers.
30
31pub mod colbert;
32pub mod crossencoder;
33pub mod diversity;
34pub mod embedding;
35pub mod matryoshka;
36pub mod scoring;
37pub mod simd;
38
39/// Common imports for reranking.
40///
41/// ```rust
42/// use rank_refine::prelude::*;
43/// ```
44pub mod prelude {
45 // Core SIMD functions
46 pub use crate::simd::{cosine, cosine_truncating, dot, dot_truncating, maxsim, norm};
47
48 // Score utilities
49 pub use crate::simd::{normalize_maxsim, softmax_scores, top_k_indices};
50
51 // `ColBERT`
52 pub use crate::colbert::{pool_tokens, rank as colbert_rank};
53
54 // Diversity
55 pub use crate::diversity::{dpp, mmr, DppConfig, MmrConfig};
56
57 // Cross-encoder trait
58 pub use crate::crossencoder::CrossEncoderModel;
59}
60
61// ─────────────────────────────────────────────────────────────────────────────
62// Error Types
63// ─────────────────────────────────────────────────────────────────────────────
64
/// Errors from refinement operations.
///
/// Every variant carries only `usize` payloads, so the type supports full
/// equality (`Eq`) and hashing in addition to `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum RefineError {
    /// `head_dims` must be less than `query.len()` for tail refinement.
    InvalidHeadDims {
        /// The head dimension that was provided.
        head_dims: usize,
        /// The query length (must be > `head_dims`).
        query_len: usize,
    },
    /// Vector dimensions must match.
    DimensionMismatch {
        /// Expected dimension.
        expected: usize,
        /// Actual dimension received.
        got: usize,
    },
}

impl std::fmt::Display for RefineError {
    /// Writes a short, lowercase description of the error including the
    /// offending values.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidHeadDims {
                head_dims,
                query_len,
            } => write!(
                f,
                "invalid head_dims: {head_dims} >= query length {query_len}"
            ),
            Self::DimensionMismatch { expected, got } => {
                write!(f, "expected {expected} dimensions, got {got}")
            }
        }
    }
}

/// Leaf error type: no variant wraps an underlying `source()`.
impl std::error::Error for RefineError {}
102
103/// Result type for refinement operations.
104pub type Result<T> = std::result::Result<T, RefineError>;
105
/// Borrow each owned token vector as a slice, yielding `Vec<&[f32]>`.
///
/// Useful when token embeddings live in a `Vec<Vec<f32>>` but the scoring
/// API takes slices of slices. The output has one entry per input token, in
/// the same order, and each entry borrows from `tokens` (no float data is
/// copied):
///
/// ```rust
/// use rank_refine::{simd, as_slices};
///
/// let tokens = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
/// let refs = as_slices(&tokens);
/// let score = simd::maxsim(&refs, &refs);
/// ```
#[inline]
#[must_use]
pub fn as_slices(tokens: &[Vec<f32>]) -> Vec<&[f32]> {
    tokens.iter().map(|token| &token[..]).collect()
}
122
123// ─────────────────────────────────────────────────────────────────────────────
124// Sorting Utilities
125// ─────────────────────────────────────────────────────────────────────────────
126
/// Reorder `(item, score)` pairs in place, from highest score to lowest.
///
/// Scores are compared with [`f32::total_cmp`], so the order is total and
/// deterministic even with NaN present: positive NaN sorts to the front and
/// negative NaN to the back. The sort is stable, so pairs with equal scores
/// keep their input order.
#[inline]
pub(crate) fn sort_scored_desc<T>(results: &mut [(T, f32)]) {
    results.sort_by(|(_, lhs), (_, rhs)| lhs.total_cmp(rhs).reverse());
}
134
/// Configuration for score blending and truncation.
///
/// ```rust
/// use rank_refine::RefineConfig;
///
/// let config = RefineConfig::default()
///     .with_alpha(0.7) // 70% original, 30% refinement
///     .with_top_k(10);
/// ```
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RefineConfig {
    /// Blending weight: 0.0 = all refinement, 1.0 = all original. Default: 0.5.
    ///
    /// Prefer [`RefineConfig::with_alpha`] over assigning this field directly:
    /// the builder clamps the value to \[0, 1\] and ignores NaN.
    pub alpha: f32,
    /// Truncate to top k results. Default: None (return all).
    pub top_k: Option<usize>,
}

impl Default for RefineConfig {
    /// Equal blend (`alpha = 0.5`) with no truncation.
    fn default() -> Self {
        Self {
            alpha: 0.5,
            top_k: None,
        }
    }
}

impl RefineConfig {
    /// Set blending weight. Clamped to \[0, 1\].
    ///
    /// - `0.0` = all refinement score
    /// - `0.5` = equal blend (default)
    /// - `1.0` = all original score
    ///
    /// Infinite inputs clamp to the nearest bound. A NaN `alpha` is ignored and
    /// the current weight is kept: `f32::clamp` passes NaN through unchanged,
    /// and a NaN weight would make any weighted blend NaN.
    #[must_use]
    pub fn with_alpha(mut self, alpha: f32) -> Self {
        if !alpha.is_nan() {
            self.alpha = alpha.clamp(0.0, 1.0);
        }
        self
    }

    /// Limit output to top k.
    #[must_use]
    pub const fn with_top_k(mut self, top_k: usize) -> Self {
        self.top_k = Some(top_k);
        self
    }

    /// Only use refinement scores (alpha = 0).
    #[must_use]
    pub const fn refinement_only() -> Self {
        Self {
            alpha: 0.0,
            top_k: None,
        }
    }

    /// Only use original scores (alpha = 1).
    #[must_use]
    pub const fn original_only() -> Self {
        Self {
            alpha: 1.0,
            top_k: None,
        }
    }
}
198
#[cfg(test)]
mod tests {
    use super::*;

    /// `with_alpha` pins out-of-range weights to the nearest bound and passes
    /// in-range weights through unchanged.
    #[test]
    fn refine_config_clamps_alpha() {
        let cases: [(f32, f32); 3] = [(-0.5, 0.0), (1.5, 1.0), (0.7, 0.7)];
        for (requested, expected) in cases {
            assert_eq!(RefineConfig::default().with_alpha(requested).alpha, expected);
        }
    }
}