pub mod colbert;
pub mod diversity;
pub mod embedding;
pub mod explain;
pub mod matryoshka;
pub mod quantization;
pub mod scoring;
pub mod simd;
pub use colbert::{rank as maxsim_rank, refine as maxsim_refine};
pub use diversity::{dpp, mmr as diversity_mmr, DppConfig, MmrConfig as DiversityMmrConfig};
pub use matryoshka::refine as matryoshka_refine;
pub use quantization::{dequantize_int8, quantize_int8, QuantizationError};
pub use scoring::{Scorer, TokenScorer};
/// Errors reported by the reranking routines in this crate.
///
/// Every variant carries the offending value(s), so the `Display` message
/// can say exactly what was rejected.
#[derive(Debug, Clone, PartialEq)]
pub enum RerankError {
    /// `head_dims` was rejected against the query length (the message reports
    /// it as `head_dims >= query length`).
    InvalidHeadDims {
        /// Requested number of head dimensions.
        head_dims: usize,
        /// Query length the request was checked against.
        query_len: usize,
    },
    /// An input did not have the required number of dimensions.
    DimensionMismatch {
        /// Dimension count that was required.
        expected: usize,
        /// Dimension count that was actually supplied.
        got: usize,
    },
    /// `pool_factor` was below the minimum of 1.
    InvalidPoolFactor {
        /// The rejected pool factor.
        pool_factor: usize,
    },
    /// `window_size` was below the minimum of 1.
    InvalidWindowSize {
        /// The rejected window size.
        window_size: usize,
    },
}

impl std::fmt::Display for RerankError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every payload is a `usize` (Copy), so matching on `*self` binds plain
        // values that the inline format captures can use directly.
        match *self {
            Self::InvalidHeadDims { head_dims, query_len } => {
                write!(f, "invalid head_dims: {head_dims} >= query length {query_len}")
            }
            Self::DimensionMismatch { expected, got } => {
                write!(f, "expected {expected} dimensions, got {got}")
            }
            Self::InvalidPoolFactor { pool_factor } => {
                write!(f, "pool_factor must be >= 1, got {pool_factor}")
            }
            Self::InvalidWindowSize { window_size } => {
                write!(f, "window_size must be >= 1, got {window_size}")
            }
        }
    }
}

// No variant wraps an underlying error, so the default `source()` (None) is kept.
impl std::error::Error for RerankError {}

/// Shorthand for a `std::result::Result` whose error type is [`RerankError`].
pub type Result<T> = std::result::Result<T, RerankError>;
/// Controls how reranked results are blended and (optionally) truncated.
///
/// `alpha` is a blend weight: [`RerankConfig::original_only`] sets it to `1.0`
/// and [`RerankConfig::refinement_only`] sets it to `0.0`. The builder
/// [`RerankConfig::with_alpha`] keeps it inside `[0.0, 1.0]`. Because the field
/// is public, values assigned directly are not validated.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RerankConfig {
    /// Blend weight, kept in `[0.0, 1.0]` when set via [`RerankConfig::with_alpha`].
    /// `1.0` matches [`RerankConfig::original_only`]; `0.0` matches
    /// [`RerankConfig::refinement_only`].
    pub alpha: f32,
    /// Optional result cap; `None` means no cap is configured.
    /// NOTE(review): presumably keeps only the best `k` results — confirm in callers.
    pub top_k: Option<usize>,
}

impl Default for RerankConfig {
    /// An even blend (`alpha = 0.5`) with no `top_k` cap.
    fn default() -> Self {
        Self {
            alpha: 0.5,
            top_k: None,
        }
    }
}

impl RerankConfig {
    /// Returns this config with `alpha` set to the given value clamped to `[0.0, 1.0]`.
    ///
    /// A NaN `alpha` is ignored and the current value is kept. `f32::clamp`
    /// returns NaN unchanged, so without this guard the builder would produce a
    /// config outside its documented range.
    #[must_use]
    pub fn with_alpha(mut self, alpha: f32) -> Self {
        if !alpha.is_nan() {
            self.alpha = alpha.clamp(0.0, 1.0);
        }
        self
    }

    /// Returns this config with `top_k` set to `Some(top_k)`.
    #[must_use]
    pub const fn with_top_k(mut self, top_k: usize) -> Self {
        self.top_k = Some(top_k);
        self
    }

    /// A config with `alpha = 0.0` and no `top_k` cap.
    #[must_use]
    pub const fn refinement_only() -> Self {
        Self {
            alpha: 0.0,
            top_k: None,
        }
    }

    /// A config with `alpha = 1.0` and no `top_k` cap.
    #[must_use]
    pub const fn original_only() -> Self {
        Self {
            alpha: 1.0,
            top_k: None,
        }
    }
}
/// Sorts `(item, score)` pairs in place so the highest score comes first.
///
/// Scores are compared with `f32::total_cmp`, so NaN never causes a panic.
/// Under that total order a positive NaN ranks above `+inf` and therefore
/// lands at the front, while a negative NaN lands at the back. The sort is
/// unstable: pairs with equal scores may change relative order.
#[inline]
pub(crate) fn sort_scored_desc<T>(results: &mut [(T, f32)]) {
    results.sort_unstable_by(|lhs, rhs| {
        let (_, left_score) = lhs;
        let (_, right_score) = rhs;
        // Ascending total order, flipped to get descending.
        left_score.total_cmp(right_score).reverse()
    });
}