rank_refine/
lib.rs

1//! Rerank search candidates with embeddings. SIMD-accelerated.
2//!
3//! You bring embeddings, this crate scores them. No model weights, no inference.
4//!
5//! # Quick Start
6//!
7//! ```rust
8//! use rank_refine::simd;
9//!
10//! // Dense scoring
11//! let score = simd::cosine(&[1.0, 0.0], &[0.707, 0.707]);
12//!
13//! // Late interaction (`ColBERT` `MaxSim`)
14//! let query = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
15//! let doc = vec![vec![0.9, 0.1], vec![0.1, 0.9]];
16//! let score = simd::maxsim_vecs(&query, &doc);
17//! ```
18//!
19//! # Modules
20//!
21//! | Module | Purpose |
22//! |--------|---------|
23//! | [`simd`] | SIMD vector ops (dot, cosine, maxsim) |
24//! | [`colbert`] | Late interaction (`MaxSim`), token pooling |
25//! | [`diversity`] | MMR + DPP diversity selection |
26//! | [`crossencoder`] | Cross-encoder trait |
27//! | [`matryoshka`] | MRL tail refinement |
28//!
29//! Advanced: [`scoring`] for trait-based polymorphism, [`embedding`] for type-safe wrappers.
30
31pub mod colbert;
32pub mod crossencoder;
33pub mod diversity;
34pub mod embedding;
35pub mod matryoshka;
36pub mod scoring;
37pub mod simd;
38
39/// Common imports for reranking.
40///
41/// ```rust
42/// use rank_refine::prelude::*;
43/// ```
44pub mod prelude {
45    // Core SIMD functions
46    pub use crate::simd::{cosine, cosine_truncating, dot, dot_truncating, maxsim, norm};
47
48    // Score utilities
49    pub use crate::simd::{normalize_maxsim, softmax_scores, top_k_indices};
50
51    // `ColBERT`
52    pub use crate::colbert::{pool_tokens, rank as colbert_rank};
53
54    // Diversity
55    pub use crate::diversity::{dpp, mmr, DppConfig, MmrConfig};
56
57    // Cross-encoder trait
58    pub use crate::crossencoder::CrossEncoderModel;
59}
60
61// ─────────────────────────────────────────────────────────────────────────────
62// Error Types
63// ─────────────────────────────────────────────────────────────────────────────
64
/// Errors from refinement operations.
///
/// Every payload is a plain `usize`, so the type supports total equality
/// (`Eq`) in addition to `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RefineError {
    /// `head_dims` must be less than `query.len()` for tail refinement.
    InvalidHeadDims {
        /// The head dimension that was provided.
        head_dims: usize,
        /// The query length (must be > `head_dims`).
        query_len: usize,
    },
    /// Vector dimensions must match.
    DimensionMismatch {
        /// Expected dimension.
        expected: usize,
        /// Actual dimension received.
        got: usize,
    },
}

impl std::fmt::Display for RefineError {
    /// Writes a single-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive on purpose (no `_` arm): adding a variant must fail to
        // compile here until it gets its own message.
        match self {
            Self::InvalidHeadDims {
                head_dims,
                query_len,
            } => write!(
                f,
                "invalid head_dims: {head_dims} >= query length {query_len}"
            ),
            Self::DimensionMismatch { expected, got } => {
                write!(f, "expected {expected} dimensions, got {got}")
            }
        }
    }
}

// No `source()` override: neither variant wraps an underlying error.
impl std::error::Error for RefineError {}
102
/// Result type for refinement operations.
///
/// Shorthand for [`std::result::Result`] with the error type fixed to
/// [`RefineError`]. Inside this module a bare `Result<T>` resolves to this
/// alias rather than the standard two-parameter type.
pub type Result<T> = std::result::Result<T, RefineError>;
105
/// Borrow each owned token vector as a slice: `&[Vec<f32>]` -> `Vec<&[f32]>`.
///
/// The returned vector has one entry per input vector, in the same order, and
/// every entry borrows from `tokens` (no element data is copied). Handy when
/// you hold owned token vectors but need to call a slice-based API:
///
/// ```rust
/// use rank_refine::{simd, as_slices};
///
/// let tokens = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
/// let refs = as_slices(&tokens);
/// let score = simd::maxsim(&refs, &refs);
/// ```
#[inline]
#[must_use]
pub fn as_slices(tokens: &[Vec<f32>]) -> Vec<&[f32]> {
    // The output length is known up front, so reserve it once.
    let mut views = Vec::with_capacity(tokens.len());
    for token in tokens {
        views.push(token.as_slice());
    }
    views
}
122
123// ─────────────────────────────────────────────────────────────────────────────
124// Sorting Utilities
125// ─────────────────────────────────────────────────────────────────────────────
126
/// Sort `(item, score)` pairs in place so the highest score comes first.
///
/// Scores are compared with `f32::total_cmp`, which gives every value,
/// including NaN, a fixed position: positive NaN ranks ahead of `+inf`,
/// negative NaN after `-inf`, and `0.0` ahead of `-0.0`. The sort is stable,
/// so pairs whose scores compare equal keep their input order.
#[inline]
pub(crate) fn sort_scored_desc<T>(results: &mut [(T, f32)]) {
    // Operands are swapped (right vs. left) to get descending order.
    results.sort_by(|(_, left), (_, right)| right.total_cmp(left));
}
134
/// Configuration for score blending and truncation.
///
/// ```rust
/// use rank_refine::RefineConfig;
///
/// let config = RefineConfig::default()
///     .with_alpha(0.7)  // 70% original, 30% refinement
///     .with_top_k(10);
/// ```
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RefineConfig {
    /// Blending weight: 0.0 = all refinement, 1.0 = all original. Default: 0.5.
    ///
    /// The field is public, so only values set through
    /// [`RefineConfig::with_alpha`] are guaranteed to lie in \[0, 1\].
    pub alpha: f32,
    /// Truncate to top k results. Default: None (return all).
    pub top_k: Option<usize>,
}

impl Default for RefineConfig {
    /// Equal blend (`alpha = 0.5`) with no truncation.
    fn default() -> Self {
        Self {
            alpha: 0.5,
            top_k: None,
        }
    }
}

impl RefineConfig {
    /// Set blending weight. Clamped to \[0, 1\].
    ///
    /// - `0.0` = all refinement score
    /// - `0.5` = equal blend (default)
    /// - `1.0` = all original score
    ///
    /// Infinite inputs clamp to the nearest bound. A NaN input is ignored and
    /// the previously configured weight is kept, so the stored value always
    /// stays inside the documented range.
    #[must_use]
    pub fn with_alpha(mut self, alpha: f32) -> Self {
        // `f32::clamp` returns NaN when called on NaN, which would silently
        // break the [0, 1] guarantee; skip the assignment in that case.
        if !alpha.is_nan() {
            self.alpha = alpha.clamp(0.0, 1.0);
        }
        self
    }

    /// Limit output to top k.
    #[must_use]
    pub const fn with_top_k(mut self, top_k: usize) -> Self {
        self.top_k = Some(top_k);
        self
    }

    /// Only use refinement scores (alpha = 0).
    #[must_use]
    pub const fn refinement_only() -> Self {
        Self {
            alpha: 0.0,
            top_k: None,
        }
    }

    /// Only use original scores (alpha = 1).
    #[must_use]
    pub const fn original_only() -> Self {
        Self {
            alpha: 1.0,
            top_k: None,
        }
    }
}
198
#[cfg(test)]
mod tests {
    use super::*;

    /// `with_alpha` pins out-of-range weights to the nearest bound and leaves
    /// in-range weights untouched.
    #[test]
    fn refine_config_clamps_alpha() {
        // (requested alpha, alpha expected to be stored)
        let cases = [(-0.5_f32, 0.0_f32), (1.5, 1.0), (0.7, 0.7)];
        for (requested, expected) in cases {
            let config = RefineConfig::default().with_alpha(requested);
            assert_eq!(config.alpha, expected);
        }
    }
}