1use std::sync::Arc;
2
3use crate::bundle::Bundle;
4use crate::{bm25, hybrid};
5use rayon::prelude::*;
6use std::cmp::Ordering;
7
8pub struct VectorStore {
9 bundle: Arc<Bundle>,
10}
11
12impl VectorStore {
13 pub fn open<P: AsRef<std::path::Path>>(path: P) -> crate::errors::Result<Self> {
14 let b = Bundle::open(path)?;
15 Ok(Self {
16 bundle: Arc::new(b),
17 })
18 }
19
20 pub fn from_bundle(bundle: Bundle) -> Self {
21 Self {
22 bundle: Arc::new(bundle),
23 }
24 }
25
26 pub fn bundle(&self) -> Arc<Bundle> {
28 Arc::clone(&self.bundle)
29 }
30
31 pub fn size(&self) -> usize {
32 self.bundle.manifest.num_docs as usize
33 }
34 pub fn dimensions(&self) -> usize {
35 self.bundle.manifest.dim as usize
36 }
37
38 pub fn get_document(&self, doc_id: u32) -> Option<(String, String, String)> {
39 self.bundle.get_document(doc_id)
40 }
41
42 pub fn search_vector(&self, query: &[f32], k: usize) -> Vec<(u32, f32)> {
43 if k == 0 || query.len() != self.dimensions() {
44 return Vec::new();
45 }
46 let dtype = self.bundle.manifest.embedding.dtype.to_lowercase();
47 if dtype == "f16" {
48 self.search_vector_f16(query, k)
49 } else {
50 let store_f32 = self.bundle.vectors_as_f32();
51 let mut res = crate::search::search_parallel(
52 query,
53 self.size(),
54 self.dimensions(),
55 self.bundle.row_stride_f32(),
56 store_f32,
57 k,
58 );
59 res.sort_by(|a, b| {
60 b.1.partial_cmp(&a.1)
61 .unwrap_or(Ordering::Equal)
62 .then_with(|| a.0.cmp(&b.0))
63 });
64 res
65 }
66 }
67
68 fn search_vector_f16(&self, query: &[f32], k: usize) -> Vec<(u32, f32)> {
69 use half::f16;
70 let n = self.size();
71 let dim = self.dimensions();
72 let stride = self.bundle.row_stride_bytes();
73 let base = self.bundle.vectors_raw();
74 #[derive(Default)]
76 struct Tk {
77 k: usize,
78 heap: std::collections::BinaryHeap<
79 std::cmp::Reverse<(ordered_float::OrderedFloat<f32>, u32)>,
80 >,
81 }
82 impl Tk {
83 fn new(k: usize) -> Self {
84 Self {
85 k,
86 heap: Default::default(),
87 }
88 }
89 fn push(&mut self, s: f32, id: u32) {
90 let item = std::cmp::Reverse((ordered_float::OrderedFloat(s), id));
91 if self.heap.len() < self.k {
92 self.heap.push(item);
93 } else if let Some(mut top) = self.heap.peek_mut() {
94 if item.0 .0 > top.0 .0 {
95 *top = item;
96 }
97 }
98 }
99 fn merge(mut self, other: Self) -> Self {
100 for it in other.heap.into_iter() {
101 self.push(it.0 .0 .0, it.0 .1);
102 }
103 self
104 }
105 }
106 let (heap, _) = (0..n as u32)
107 .into_par_iter()
108 .with_min_len(1024)
109 .fold(
110 || (Tk::new(k), vec![0f32; dim]),
111 |(mut tk, mut buf), id| {
112 let start = (id as usize) * stride;
113 let row = &base[start..start + dim * 2];
114 for j in 0..dim {
115 let lo = row[j * 2] as u16;
116 let hi = row[j * 2 + 1] as u16;
117 let bits = lo | (hi << 8);
118 buf[j] = f16::from_bits(bits).to_f32();
119 }
120 let s = crate::simd::dot(query, &buf);
121 tk.push(s, id);
122 (tk, buf)
123 },
124 )
125 .reduce(|| (Tk::new(k), vec![]), |a, b| (a.0.merge(b.0), vec![]));
126 let mut out: Vec<(u32, f32)> = heap
127 .heap
128 .into_sorted_vec()
129 .into_iter()
130 .map(|r| (r.0 .1, r.0 .0 .0))
131 .collect();
132 out.sort_by(|a, b| {
134 b.1.partial_cmp(&a.1)
135 .unwrap_or(Ordering::Equal)
136 .then_with(|| a.0.cmp(&b.0))
137 });
138 out
139 }
140
141 pub fn search_bm25(&self, query: &str, k: usize) -> Vec<(u32, f32)> {
142 bm25::search(&self.bundle, query, k)
143 }
144
145 pub fn search_hybrid(
146 &self,
147 query_vec: &[f32],
148 query_text: &str,
149 k: usize,
150 mut vector_weight: f32,
151 ) -> Vec<(u32, f32)> {
152 if k == 0 || query_vec.len() != self.dimensions() {
153 return Vec::new();
154 }
155 if vector_weight.is_nan() {
156 vector_weight = 0.5;
157 }
158 if vector_weight < 0.0 {
159 vector_weight = 0.0;
160 }
161 if vector_weight > 1.0 {
162 vector_weight = 1.0;
163 }
164 let kk = std::cmp::min(k * 2, self.size());
165 let vres = self.search_vector(query_vec, kk);
166 let bres = self.search_bm25(query_text, kk);
167 hybrid::fuse_rrf(&vres, &bres, k, vector_weight)
168 }
169}