// nvs_core/vector_store.rs

use std::cmp::Ordering;
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::sync::Arc;

use rayon::prelude::*;

use crate::bundle::Bundle;
use crate::{bm25, hybrid};

8pub struct VectorStore {
9    bundle: Arc<Bundle>,
10}
11
12impl VectorStore {
13    pub fn open<P: AsRef<std::path::Path>>(path: P) -> crate::errors::Result<Self> {
14        let b = Bundle::open(path)?;
15        Ok(Self {
16            bundle: Arc::new(b),
17        })
18    }
19
20    pub fn from_bundle(bundle: Bundle) -> Self {
21        Self {
22            bundle: Arc::new(bundle),
23        }
24    }
25
26    /// Returns a shared reference-counted handle to the underlying bundle.
27    pub fn bundle(&self) -> Arc<Bundle> {
28        Arc::clone(&self.bundle)
29    }
30
31    pub fn size(&self) -> usize {
32        self.bundle.manifest.num_docs as usize
33    }
34    pub fn dimensions(&self) -> usize {
35        self.bundle.manifest.dim as usize
36    }
37
38    pub fn get_document(&self, doc_id: u32) -> Option<(String, String, String)> {
39        self.bundle.get_document(doc_id)
40    }
41
42    pub fn search_vector(&self, query: &[f32], k: usize) -> Vec<(u32, f32)> {
43        if k == 0 || query.len() != self.dimensions() {
44            return Vec::new();
45        }
46        let dtype = self.bundle.manifest.embedding.dtype.to_lowercase();
47        if dtype == "f16" {
48            self.search_vector_f16(query, k)
49        } else {
50            let store_f32 = self.bundle.vectors_as_f32();
51            let mut res = crate::search::search_parallel(
52                query,
53                self.size(),
54                self.dimensions(),
55                self.bundle.row_stride_f32(),
56                store_f32,
57                k,
58            );
59            res.sort_by(|a, b| {
60                b.1.partial_cmp(&a.1)
61                    .unwrap_or(Ordering::Equal)
62                    .then_with(|| a.0.cmp(&b.0))
63            });
64            res
65        }
66    }
67
68    fn search_vector_f16(&self, query: &[f32], k: usize) -> Vec<(u32, f32)> {
69        use half::f16;
70        let n = self.size();
71        let dim = self.dimensions();
72        let stride = self.bundle.row_stride_bytes();
73        let base = self.bundle.vectors_raw();
74        // Per-thread heap and buffer
75        #[derive(Default)]
76        struct Tk {
77            k: usize,
78            heap: std::collections::BinaryHeap<
79                std::cmp::Reverse<(ordered_float::OrderedFloat<f32>, u32)>,
80            >,
81        }
82        impl Tk {
83            fn new(k: usize) -> Self {
84                Self {
85                    k,
86                    heap: Default::default(),
87                }
88            }
89            fn push(&mut self, s: f32, id: u32) {
90                let item = std::cmp::Reverse((ordered_float::OrderedFloat(s), id));
91                if self.heap.len() < self.k {
92                    self.heap.push(item);
93                } else if let Some(mut top) = self.heap.peek_mut() {
94                    if item.0 .0 > top.0 .0 {
95                        *top = item;
96                    }
97                }
98            }
99            fn merge(mut self, other: Self) -> Self {
100                for it in other.heap.into_iter() {
101                    self.push(it.0 .0 .0, it.0 .1);
102                }
103                self
104            }
105        }
106        let (heap, _) = (0..n as u32)
107            .into_par_iter()
108            .with_min_len(1024)
109            .fold(
110                || (Tk::new(k), vec![0f32; dim]),
111                |(mut tk, mut buf), id| {
112                    let start = (id as usize) * stride;
113                    let row = &base[start..start + dim * 2];
114                    for j in 0..dim {
115                        let lo = row[j * 2] as u16;
116                        let hi = row[j * 2 + 1] as u16;
117                        let bits = lo | (hi << 8);
118                        buf[j] = f16::from_bits(bits).to_f32();
119                    }
120                    let s = crate::simd::dot(query, &buf);
121                    tk.push(s, id);
122                    (tk, buf)
123                },
124            )
125            .reduce(|| (Tk::new(k), vec![]), |a, b| (a.0.merge(b.0), vec![]));
126        let mut out: Vec<(u32, f32)> = heap
127            .heap
128            .into_sorted_vec()
129            .into_iter()
130            .map(|r| (r.0 .1, r.0 .0 .0))
131            .collect();
132        // Ensure deterministic order: score desc, id asc
133        out.sort_by(|a, b| {
134            b.1.partial_cmp(&a.1)
135                .unwrap_or(Ordering::Equal)
136                .then_with(|| a.0.cmp(&b.0))
137        });
138        out
139    }
140
141    pub fn search_bm25(&self, query: &str, k: usize) -> Vec<(u32, f32)> {
142        bm25::search(&self.bundle, query, k)
143    }
144
145    pub fn search_hybrid(
146        &self,
147        query_vec: &[f32],
148        query_text: &str,
149        k: usize,
150        mut vector_weight: f32,
151    ) -> Vec<(u32, f32)> {
152        if k == 0 || query_vec.len() != self.dimensions() {
153            return Vec::new();
154        }
155        if vector_weight.is_nan() {
156            vector_weight = 0.5;
157        }
158        if vector_weight < 0.0 {
159            vector_weight = 0.0;
160        }
161        if vector_weight > 1.0 {
162            vector_weight = 1.0;
163        }
164        let kk = std::cmp::min(k * 2, self.size());
165        let vres = self.search_vector(query_vec, kk);
166        let bres = self.search_bm25(query_text, kk);
167        hybrid::fuse_rrf(&vres, &bres, k, vector_weight)
168    }
169}