use std::fmt;
use serde::{Deserialize, Serialize};
use crate::superfile::vector::distance::Metric;
/// Largest vector dimension that still counts as "low-dim" when choosing a
/// rerank multiplier floor; any `dim` strictly greater than this uses the
/// high-dim floors.
const LOW_DIM_RERANK_FLOOR_THRESHOLD: usize = 384;

// Floors for `dim <= LOW_DIM_RERANK_FLOOR_THRESHOLD`.
/// Low-dim rerank multiplier floor for the fp32 codec.
const FP32_LOW_DIM_RERANK_FLOOR: usize = 20;
/// Low-dim rerank multiplier floor for the SQ8 residual codec.
const SQ8_LOW_DIM_RERANK_FLOOR: usize = 50;

// Floors for `dim > LOW_DIM_RERANK_FLOOR_THRESHOLD`.
/// High-dim rerank multiplier floor for the fp32 codec.
const FP32_HIGH_DIM_RERANK_FLOOR: usize = 50;
/// High-dim rerank multiplier floor for the SQ8 residual codec.
const SQ8_HIGH_DIM_RERANK_FLOOR: usize = 100;
/// Encoding used for the per-vector payload written for reranking.
///
/// NOTE: the derived serde impls pass each variant's declaration index to the
/// serializer, so variant order is kept as-is; reordering could change the
/// encoding in index-based formats.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize)]
pub enum RerankCodec {
    /// Full-precision payload: 4 bytes per dimension.
    Fp32,
    /// SQ8 residual payload: 2 bytes per dimension, plus per-centroid
    /// scale/offset metadata (and per-document norms for some metrics).
    Sq8ResidualEpsilon,
    /// No per-vector payload is written.
    // NOTE(review): presumably reranking falls back to RaBitQ codes alone —
    // confirm against the search path.
    RabitqOnly,
}
impl Default for RerankCodec {
fn default() -> Self {
Self::Sq8ResidualEpsilon
}
}
impl RerankCodec {
#[inline]
pub const fn codec_id(self) -> u8 {
match self {
Self::Fp32 => 0,
Self::Sq8ResidualEpsilon => 1,
Self::RabitqOnly => 2,
}
}
#[inline]
pub const fn from_codec_id(id: u8) -> Option<Self> {
match id {
0 => Some(Self::Fp32),
1 => Some(Self::Sq8ResidualEpsilon),
2 => Some(Self::RabitqOnly),
_ => None,
}
}
#[inline]
pub const fn name(self) -> &'static str {
match self {
Self::Fp32 => "fp32",
Self::Sq8ResidualEpsilon => "sq8_residual",
Self::RabitqOnly => "rabitq_only",
}
}
#[inline]
pub const fn per_vector_bytes(self, dim: usize) -> usize {
match self {
Self::Fp32 => dim * 4,
Self::Sq8ResidualEpsilon => dim * 2,
Self::RabitqOnly => 0,
}
}
#[inline]
pub const fn writes_full(self) -> bool {
!matches!(self, Self::RabitqOnly)
}
#[inline]
pub const fn is_implemented(self) -> bool {
matches!(
self,
Self::Fp32 | Self::Sq8ResidualEpsilon | Self::RabitqOnly
)
}
#[inline]
pub const fn recommended_rerank_mult_floor(self, dim: usize) -> Option<usize> {
let high_dim = dim > LOW_DIM_RERANK_FLOOR_THRESHOLD;
match self {
Self::Fp32 => Some(if high_dim {
FP32_HIGH_DIM_RERANK_FLOOR
} else {
FP32_LOW_DIM_RERANK_FLOOR
}),
Self::Sq8ResidualEpsilon => Some(if high_dim {
SQ8_HIGH_DIM_RERANK_FLOOR
} else {
SQ8_LOW_DIM_RERANK_FLOOR
}),
Self::RabitqOnly => None,
}
}
#[inline]
pub const fn codec_meta_bytes(
self,
dim: usize,
n_docs: usize,
n_cent: usize,
metric: Metric,
) -> usize {
match self {
Self::Fp32 | Self::RabitqOnly => 0,
Self::Sq8ResidualEpsilon => {
let scale_offset_bytes = 2 * n_cent * dim * 4;
let norms_bytes = match metric {
Metric::L2Sq | Metric::Cosine => n_docs * 4,
Metric::NegDot => 0,
};
scale_offset_bytes + norms_bytes
}
}
}
}
impl fmt::Display for RerankCodec {
    /// Writes the codec's stable name, as returned by [`RerankCodec::name`].
    ///
    /// Uses [`fmt::Formatter::pad`] rather than `write_str` so width, fill,
    /// alignment and precision flags (e.g. `{:>12}`) are honoured exactly as
    /// they are for a plain `str`. Without flags the output is unchanged.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.name())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant, in declaration order; shared by the per-variant checks.
    const ALL: [RerankCodec; 3] = [
        RerankCodec::Fp32,
        RerankCodec::Sq8ResidualEpsilon,
        RerankCodec::RabitqOnly,
    ];

    #[test]
    fn default_is_sq8_residual() {
        assert_eq!(RerankCodec::default(), RerankCodec::Sq8ResidualEpsilon);
    }

    #[test]
    fn fp32_codec_id_is_zero() {
        assert_eq!(RerankCodec::Fp32.codec_id(), 0u8);
    }

    #[test]
    fn codec_id_roundtrips_every_variant() {
        for c in ALL {
            assert_eq!(
                RerankCodec::from_codec_id(c.codec_id()),
                Some(c),
                "round-trip mismatch for {c:?}"
            );
        }
    }

    #[test]
    fn unknown_codec_id_is_none() {
        // Sweep the whole u8 space. An id may decode only if some codec
        // produces it; together with the round-trip test this pins the mapping.
        for id in 0..=u8::MAX {
            let owned = ALL.iter().any(|c| c.codec_id() == id);
            assert_eq!(
                RerankCodec::from_codec_id(id).is_some(),
                owned,
                "id {id}: only ids owned by a codec may decode"
            );
        }
    }

    #[test]
    fn per_vector_bytes_matches_spec() {
        assert_eq!(RerankCodec::Fp32.per_vector_bytes(384), 1536);
        assert_eq!(RerankCodec::Sq8ResidualEpsilon.per_vector_bytes(384), 768);
        assert_eq!(RerankCodec::RabitqOnly.per_vector_bytes(384), 0);
        assert_eq!(RerankCodec::Fp32.per_vector_bytes(1024), 4096);
        assert_eq!(RerankCodec::Sq8ResidualEpsilon.per_vector_bytes(1024), 2048);
        assert_eq!(RerankCodec::RabitqOnly.per_vector_bytes(1024), 0);
    }

    #[test]
    fn writes_full_matches_per_vector_bytes() {
        for c in ALL {
            for dim in [1usize, 384, 1024] {
                assert_eq!(
                    c.writes_full(),
                    c.per_vector_bytes(dim) > 0,
                    "writes_full disagrees with per_vector_bytes for {c:?} at dim {dim}"
                );
            }
        }
    }

    #[test]
    fn all_codecs_implemented() {
        for c in ALL {
            assert!(c.is_implemented(), "{c:?} should be implemented");
        }
    }

    #[test]
    fn recommended_rerank_mult_floor_matches_calibration_table() {
        // Low-dim bucket, including the inclusive threshold itself.
        assert_eq!(RerankCodec::Fp32.recommended_rerank_mult_floor(384), Some(20));
        assert_eq!(
            RerankCodec::Sq8ResidualEpsilon.recommended_rerank_mult_floor(384),
            Some(50)
        );
        assert_eq!(RerankCodec::RabitqOnly.recommended_rerank_mult_floor(384), None);
        assert_eq!(RerankCodec::Fp32.recommended_rerank_mult_floor(1), Some(20));

        // High-dim bucket.
        assert_eq!(RerankCodec::Fp32.recommended_rerank_mult_floor(1024), Some(50));
        assert_eq!(
            RerankCodec::Sq8ResidualEpsilon.recommended_rerank_mult_floor(1024),
            Some(100)
        );
        assert_eq!(RerankCodec::RabitqOnly.recommended_rerank_mult_floor(1024), None);

        // First dimension past the threshold switches every codec to high-dim.
        assert_eq!(RerankCodec::Fp32.recommended_rerank_mult_floor(385), Some(50));
        assert_eq!(
            RerankCodec::Sq8ResidualEpsilon.recommended_rerank_mult_floor(385),
            Some(100)
        );
    }

    #[test]
    fn display_renders_stable_name() {
        assert_eq!(RerankCodec::Fp32.to_string(), "fp32");
        assert_eq!(RerankCodec::Sq8ResidualEpsilon.to_string(), "sq8_residual");
        assert_eq!(RerankCodec::RabitqOnly.to_string(), "rabitq_only");
        for c in ALL {
            assert_eq!(c.to_string(), c.name());
        }
    }

    #[test]
    fn codec_meta_bytes_matches_layout_spec() {
        for c in [RerankCodec::Fp32, RerankCodec::RabitqOnly] {
            for m in [Metric::L2Sq, Metric::Cosine, Metric::NegDot] {
                assert_eq!(
                    c.codec_meta_bytes(384, 1_000_000, 1024, m),
                    0,
                    "{c:?} / {m:?}"
                );
            }
        }

        let sq8 = RerankCodec::Sq8ResidualEpsilon;
        let so_bytes = 2 * 1024 * 384 * 4;
        assert_eq!(sq8.codec_meta_bytes(384, 1_000_000, 1024, Metric::NegDot), so_bytes);
        assert_eq!(
            sq8.codec_meta_bytes(384, 1_000_000, 1024, Metric::Cosine),
            so_bytes + 1_000_000 * 4
        );
        assert_eq!(
            sq8.codec_meta_bytes(384, 1_000_000, 1024, Metric::L2Sq),
            so_bytes + 1_000_000 * 4
        );
        // With no documents only the per-centroid scale/offset table remains.
        assert_eq!(sq8.codec_meta_bytes(384, 0, 1024, Metric::Cosine), so_bytes);
    }
}