use std::sync::Arc;
use super::types::RerankError;
/// Identifies which quantization scheme a [`RerankCodec`] implementation uses.
///
/// The stable lowercase label for each variant is produced by `CodecName::as_str`.
/// `Hash` is derived alongside `Eq` so a codec name can key a `HashMap` or
/// `HashSet` (e.g. a per-codec registry or cache).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CodecName {
    /// Labelled `"sq8"`.
    Sq8,
    /// Labelled `"pq"`.
    Pq,
    /// Labelled `"binary"`.
    Binary,
    /// Labelled `"rabitq"`.
    RaBitQ,
    /// Labelled `"bbq"`.
    Bbq,
}
impl CodecName {
pub fn as_str(self) -> &'static str {
match self {
CodecName::Sq8 => "sq8",
CodecName::Pq => "pq",
CodecName::Binary => "binary",
CodecName::RaBitQ => "rabitq",
CodecName::Bbq => "bbq",
}
}
}
/// A query after codec-specific preprocessing.
///
/// It is produced by `RerankCodec::prepare_query` and consumed by
/// `RerankCodec::distance_prepared`. Which variant a codec produces is up to
/// that codec.
///
/// `Debug`, `Clone` and `PartialEq` are derived so prepared queries can be
/// logged, cached or compared in tests.
#[derive(Debug, Clone, PartialEq)]
pub enum PreparedQuery {
    /// The query kept as plain `f32` components.
    Raw(Vec<f32>),
    /// A table of `f32` rows. Presumably a per-subspace distance lookup table
    /// (PQ-style); confirm against the producing codec.
    Lut(Vec<Vec<f32>>),
    /// The query in an opaque byte encoding chosen by the codec.
    Bytes(Vec<u8>),
}
/// A vector codec used for reranking.
///
/// A codec encodes vectors to bytes, preprocesses a query once, and scores the
/// prepared query against encoded vectors. Implementors must be `Send + Sync`,
/// so a codec can be shared across threads (e.g. behind an `Arc`).
pub trait RerankCodec: Send + Sync {
    /// Reports which codec family this implementation belongs to.
    fn name(&self) -> CodecName;

    /// Fits codec parameters to `samples`.
    ///
    /// The default implementation does nothing and returns `Ok(())`, for codecs
    /// that need no training.
    fn train(&mut self, _samples: &[&[f32]]) -> Result<(), RerankError> {
        Ok(())
    }

    /// Encodes the vector `v` into this codec's byte representation.
    fn encode(&self, v: &[f32]) -> Result<Vec<u8>, RerankError>;

    /// Converts the raw query `q` into the form that `distance_prepared` consumes.
    ///
    /// Splitting this out from `distance_prepared` presumably lets per-query
    /// work be reused across many candidates.
    fn prepare_query(&self, q: &[f32]) -> Result<PreparedQuery, RerankError>;

    /// Scores one `encoded` vector against a query produced by `prepare_query`.
    fn distance_prepared(
        &self,
        prepared: &PreparedQuery,
        encoded: &[u8],
    ) -> Result<f32, RerankError>;

    /// Serializes this codec's state to bytes.
    ///
    /// The output is presumably what `rerank_codec_from_bytes` decodes; confirm
    /// per codec.
    fn to_bytes(&self) -> Result<Vec<u8>, RerankError>;
}
/// Rebuilds a shareable rerank codec from its serialized bytes.
///
/// `name` selects the decoder:
/// - SQ8, PQ, RaBitQ and BBQ call their inner codec's `from_bytes` and wrap the
///   result in the matching `*Rerank` adapter.
/// - The binary codec reads a small fixed header in this function and builds
///   `BinaryRerank` from the decoded dimension.
///
/// # Errors
///
/// Returns `RerankError::BadInput` when `bytes` cannot be decoded by the
/// selected codec. The message is prefixed with the codec label and `from_bytes`.
pub fn rerank_codec_from_bytes(
    name: CodecName,
    bytes: &[u8],
) -> Result<Arc<dyn RerankCodec>, RerankError> {
    use crate::quantize::{pq::PqCodec, sq8::Sq8Codec};
    use crate::rerank::codecs::{BbqRerank, BinaryRerank, PqRerank, RaBitQRerank, Sq8Rerank};
    use nodedb_codec::vector_quant::{bbq::BbqCodec, rabitq::RaBitQCodec};

    // Decodes the binary codec header and returns its dimension.
    // Header layout:
    //   - bytes 0..6: the magic `NDBIN\0`
    //   - byte 6: not inspected here
    //   - bytes 7..11: a little-endian u32 dimension
    // Anything after byte 11 is ignored. The length check happens before the magic check.
    fn binary_dim(bytes: &[u8]) -> Result<usize, RerankError> {
        const MAGIC: &[u8] = b"NDBIN\0";
        let header = match bytes.get(..11) {
            Some(h) => h,
            None => return Err(RerankError::BadInput("binary from_bytes: too short".into())),
        };
        if !header.starts_with(MAGIC) {
            return Err(RerankError::BadInput("binary from_bytes: bad magic".into()));
        }
        // NOTE(review): byte 6 is skipped. It looks like a version/flags byte;
        // confirm whether it should be validated.
        let dim_le = [header[7], header[8], header[9], header[10]];
        Ok(u32::from_le_bytes(dim_le) as usize)
    }

    let codec: Arc<dyn RerankCodec> = match name {
        CodecName::Binary => Arc::new(BinaryRerank::new(binary_dim(bytes)?)),
        CodecName::Sq8 => {
            let sq8 = Sq8Codec::from_bytes(bytes)
                .map_err(|e| RerankError::BadInput(format!("sq8 from_bytes: {e}")))?;
            Arc::new(Sq8Rerank::from_codec(sq8))
        }
        CodecName::Pq => {
            let pq = PqCodec::from_bytes(bytes)
                .map_err(|e| RerankError::BadInput(format!("pq from_bytes: {e}")))?;
            Arc::new(PqRerank::from_codec(pq))
        }
        CodecName::RaBitQ => {
            let rabitq = RaBitQCodec::from_bytes(bytes)
                .map_err(|e| RerankError::BadInput(format!("rabitq from_bytes: {e}")))?;
            Arc::new(RaBitQRerank::from_codec(rabitq))
        }
        CodecName::Bbq => {
            let bbq = BbqCodec::from_bytes(bytes)
                .map_err(|e| RerankError::BadInput(format!("bbq from_bytes: {e}")))?;
            Arc::new(BbqRerank::from_codec(bbq))
        }
    };
    Ok(codec)
}