Skip to main content

nodedb_vector/collection/
codec_dispatch.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Per-collection codec selection. Wraps the generic `HnswCodecIndex<C>`
4//! for codecs other than Sq8 (which retains its specialised fast path
5//! in `quantize.rs` / `search.rs`).
6
7use nodedb_codec::vector_quant::bbq::BbqCodec;
8use nodedb_codec::vector_quant::rabitq::RaBitQCodec;
9
10use crate::codec_index::HnswCodecIndex;
11
12/// One built codec-index per collection (other than Sq8). Variants match
13/// the publicly-selectable quantization choices that route through
14/// `HnswCodecIndex`.
15#[non_exhaustive]
16pub enum CollectionCodec {
17    RaBitQ(HnswCodecIndex<RaBitQCodec>),
18    Bbq(HnswCodecIndex<BbqCodec>),
19}
20
21impl CollectionCodec {
22    /// Forwarding `search` so the collection layer doesn't have to match
23    /// on the variant for the common case.
24    pub fn search(
25        &self,
26        query: &[f32],
27        k: usize,
28        ef_search: usize,
29    ) -> Vec<crate::codec_index::CodecSearchResult> {
30        match self {
31            Self::RaBitQ(idx) => idx.search(query, k, ef_search),
32            Self::Bbq(idx) => idx.search(query, k, ef_search),
33        }
34    }
35
36    /// Forwarding `insert`.
37    pub fn insert(&mut self, id: u32, v: &[f32]) {
38        match self {
39            Self::RaBitQ(idx) => idx.insert(id, v),
40            Self::Bbq(idx) => idx.insert(id, v),
41        }
42    }
43
44    /// Total nodes (including deleted).
45    pub fn len(&self) -> usize {
46        match self {
47            Self::RaBitQ(idx) => idx.len(),
48            Self::Bbq(idx) => idx.len(),
49        }
50    }
51
52    pub fn is_empty(&self) -> bool {
53        self.len() == 0
54    }
55
56    /// Quantization tag for stats reporting.
57    pub fn quantization(&self) -> &'static str {
58        match self {
59            Self::RaBitQ(_) => "rabitq",
60            Self::Bbq(_) => "bbq",
61        }
62    }
63}
64
65/// Build a `CollectionCodec` from a quantization tag and training vectors.
66///
67/// Returns `None` for unsupported or unrecognised tags (e.g. "sq8", "pq",
68/// "none") — those variants use separate per-segment code paths.
69pub fn build_collection_codec(
70    quantization: &str,
71    vectors: &[Vec<f32>],
72    dim: usize,
73    m: usize,
74    ef_construction: usize,
75    seed: u64,
76) -> Option<CollectionCodec> {
77    if vectors.is_empty() {
78        return None;
79    }
80    let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
81    match quantization {
82        "rabitq" => {
83            let codec = RaBitQCodec::calibrate(&refs, dim, seed);
84            let mut idx = HnswCodecIndex::new(dim, m, ef_construction, codec, seed);
85            for (i, v) in vectors.iter().enumerate() {
86                idx.insert(i as u32, v);
87            }
88            Some(CollectionCodec::RaBitQ(idx))
89        }
90        "bbq" => {
91            let codec = BbqCodec::calibrate(&refs, dim, 3);
92            let mut idx = HnswCodecIndex::new(dim, m, ef_construction, codec, seed);
93            for (i, v) in vectors.iter().enumerate() {
94                idx.insert(i as u32, v);
95            }
96            Some(CollectionCodec::Bbq(idx))
97        }
98        _ => None,
99    }
100}
101
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic fixture with `n` rows of width `dim`. Element `d` of
    /// row `i` is `(i * dim + d) * 0.01`.
    fn make_vectors(n: usize, dim: usize) -> Vec<Vec<f32>> {
        let mut rows: Vec<Vec<f32>> = Vec::with_capacity(n);
        for row in 0..n {
            let start = row * dim;
            rows.push((start..start + dim).map(|x| x as f32 * 0.01).collect());
        }
        rows
    }

    #[test]
    fn build_rabitq_returns_some() {
        let vecs = make_vectors(50, 8);
        match build_collection_codec("rabitq", &vecs, 8, 16, 100, 42) {
            Some(CollectionCodec::RaBitQ(_)) => {}
            _ => panic!("expected RaBitQ variant"),
        }
    }

    #[test]
    fn build_bbq_returns_some() {
        let vecs = make_vectors(50, 8);
        match build_collection_codec("bbq", &vecs, 8, 16, 100, 42) {
            Some(CollectionCodec::Bbq(_)) => {}
            _ => panic!("expected Bbq variant"),
        }
    }

    #[test]
    fn unknown_codec_returns_none() {
        let vecs = make_vectors(50, 8);
        let built = build_collection_codec("unknown_codec", &vecs, 8, 16, 100, 42);
        assert!(built.is_none(), "unknown codec should return None");
    }

    #[test]
    fn sq8_tag_returns_none() {
        let vecs = make_vectors(50, 8);
        let built = build_collection_codec("sq8", &vecs, 8, 16, 100, 42);
        assert!(
            built.is_none(),
            "sq8 tag should fall through to per-segment path"
        );
    }

    #[test]
    fn empty_vectors_returns_none() {
        let built = build_collection_codec("rabitq", &[], 8, 16, 100, 42);
        assert!(built.is_none(), "empty vectors should return None");
    }

    #[test]
    fn len_and_is_empty() {
        let vecs = make_vectors(20, 4);
        let built = build_collection_codec("bbq", &vecs, 4, 8, 50, 1).unwrap();
        assert!(!built.is_empty());
        assert_eq!(built.len(), 20);
    }

    #[test]
    fn quantization_tag() {
        let vecs = make_vectors(10, 4);
        for tag in ["rabitq", "bbq"] {
            let built = build_collection_codec(tag, &vecs, 4, 8, 50, 1).unwrap();
            assert_eq!(built.quantization(), tag);
        }
    }
}