nodedb_vector/rerank/
codec.rs1use std::sync::Arc;
4
5use super::types::RerankError;
6
/// Identifies which rerank codec implementation is in play.
///
/// The value is `Copy`, so it can be passed around by value and compared
/// directly with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CodecName {
    /// The `sq8` codec.
    Sq8,
    /// The `pq` codec.
    Pq,
    /// The `binary` codec.
    Binary,
    /// The `rabitq` codec.
    RaBitQ,
    /// The `bbq` codec.
    Bbq,
}

impl CodecName {
    /// Returns the lowercase string identifier associated with this codec.
    ///
    /// The returned string is a `'static` literal; no allocation takes place.
    pub fn as_str(self) -> &'static str {
        // Exhaustive on purpose: adding a variant must force an update here.
        let label = match self {
            Self::Sq8 => "sq8",
            Self::Pq => "pq",
            Self::Binary => "binary",
            Self::RaBitQ => "rabitq",
            Self::Bbq => "bbq",
        };
        label
    }
}
29
/// Query-side state produced by `RerankCodec::prepare_query` and consumed by
/// `RerankCodec::distance_prepared`.
///
/// Derives `Debug`, `Clone` and `PartialEq` so values can be logged, duplicated
/// and compared in tests, and so `Result<PreparedQuery, _>` supports helpers
/// such as `unwrap_err` that require `T: Debug`.
#[derive(Debug, Clone, PartialEq)]
pub enum PreparedQuery {
    /// A flat list of `f32` values.
    ///
    /// NOTE(review): presumably the unmodified query vector — confirm against
    /// the codecs that emit this variant.
    Raw(Vec<f32>),
    /// A list of `f32` tables.
    ///
    /// NOTE(review): presumably per-subspace distance lookup tables — confirm
    /// against the codecs that emit this variant.
    Lut(Vec<Vec<f32>>),
    /// An opaque byte buffer.
    ///
    /// NOTE(review): presumably the query encoded in the codec's own byte
    /// format — confirm against the codecs that emit this variant.
    Bytes(Vec<u8>),
}
43
44pub trait RerankCodec: Send + Sync {
49 fn encode(&self, v: &[f32]) -> Result<Vec<u8>, RerankError>;
51
52 fn prepare_query(&self, q: &[f32]) -> Result<PreparedQuery, RerankError>;
54
55 fn distance_prepared(
57 &self,
58 prepared: &PreparedQuery,
59 encoded: &[u8],
60 ) -> Result<f32, RerankError>;
61
62 fn name(&self) -> CodecName;
64
65 fn train(&mut self, _samples: &[&[f32]]) -> Result<(), RerankError> {
68 Ok(())
69 }
70
71 fn to_bytes(&self) -> Result<Vec<u8>, RerankError>;
75}
76
77pub fn rerank_codec_from_bytes(
80 name: CodecName,
81 bytes: &[u8],
82) -> Result<Arc<dyn RerankCodec>, RerankError> {
83 use crate::quantize::pq::PqCodec;
84 use crate::quantize::sq8::Sq8Codec;
85 use crate::rerank::codecs::{BbqRerank, BinaryRerank, PqRerank, RaBitQRerank, Sq8Rerank};
86 use nodedb_codec::vector_quant::bbq::BbqCodec;
87 use nodedb_codec::vector_quant::rabitq::RaBitQCodec;
88
89 match name {
90 CodecName::Sq8 => {
91 let inner = Sq8Codec::from_bytes(bytes)
92 .map_err(|e| RerankError::BadInput(format!("sq8 from_bytes: {e}")))?;
93 Ok(Arc::new(Sq8Rerank::from_codec(inner)))
94 }
95 CodecName::Binary => {
96 if bytes.len() < 11 {
98 return Err(RerankError::BadInput("binary from_bytes: too short".into()));
99 }
100 if &bytes[..6] != b"NDBIN\0" {
101 return Err(RerankError::BadInput("binary from_bytes: bad magic".into()));
102 }
103 let dim = u32::from_le_bytes([bytes[7], bytes[8], bytes[9], bytes[10]]) as usize;
105 Ok(Arc::new(BinaryRerank::new(dim)))
106 }
107 CodecName::Pq => {
108 let inner = PqCodec::from_bytes(bytes)
109 .map_err(|e| RerankError::BadInput(format!("pq from_bytes: {e}")))?;
110 Ok(Arc::new(PqRerank::from_codec(inner)))
111 }
112 CodecName::RaBitQ => {
113 let inner = RaBitQCodec::from_bytes(bytes)
114 .map_err(|e| RerankError::BadInput(format!("rabitq from_bytes: {e}")))?;
115 Ok(Arc::new(RaBitQRerank::from_codec(inner)))
116 }
117 CodecName::Bbq => {
118 let inner = BbqCodec::from_bytes(bytes)
119 .map_err(|e| RerankError::BadInput(format!("bbq from_bytes: {e}")))?;
120 Ok(Arc::new(BbqRerank::from_codec(inner)))
121 }
122 }
123}