1use open_kioku_errors::{OkError, Result};
2use serde::{Deserialize, Serialize};
3use std::collections::{BTreeMap, HashSet};
4use std::fs;
5use std::path::Path;
6
7#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
8pub struct VectorId(pub u64);
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct VectorRecord {
12 pub id: VectorId,
13 pub target_id: String,
14 pub target_kind: String,
15 pub vector: Vec<f32>,
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct VectorHit {
20 pub id: VectorId,
21 pub target_id: String,
22 pub target_kind: String,
23 pub score: f32,
24}
25
26#[derive(Debug, Clone, Default)]
27pub struct VectorSearchOptions {
28 pub limit: usize,
29 pub allowlist: Option<HashSet<VectorId>>,
30 pub target_kind: Option<String>,
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct VectorIndexStats {
35 pub backend: String,
36 pub dimensions: usize,
37 pub vector_count: usize,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct ExactFlatVectorIndex {
42 dimensions: usize,
43 records: BTreeMap<VectorId, VectorRecord>,
44}
45
46impl ExactFlatVectorIndex {
47 pub fn new(dimensions: usize) -> Result<Self> {
48 if dimensions == 0 {
49 return Err(OkError::Unsupported(
50 "exact-flat vector index requires dimensions > 0".into(),
51 ));
52 }
53 Ok(Self {
54 dimensions,
55 records: BTreeMap::new(),
56 })
57 }
58
59 pub fn add(&mut self, record: VectorRecord) -> Result<()> {
60 if record.vector.len() != self.dimensions {
61 return Err(OkError::Storage(format!(
62 "vector {} has {} dimensions, expected {}",
63 record.id.0,
64 record.vector.len(),
65 self.dimensions
66 )));
67 }
68 if self
69 .records
70 .get(&record.id)
71 .is_some_and(|existing| existing.target_id != record.target_id)
72 {
73 return Err(OkError::Storage(format!(
74 "vector id collision for {}",
75 record.id.0
76 )));
77 }
78 self.records.insert(record.id, record);
79 Ok(())
80 }
81
82 pub fn remove(&mut self, id: VectorId) -> Option<VectorRecord> {
83 self.records.remove(&id)
84 }
85
86 pub fn search(&self, query: &[f32], options: VectorSearchOptions) -> Result<Vec<VectorHit>> {
87 if query.len() != self.dimensions {
88 return Err(OkError::Storage(format!(
89 "query vector has {} dimensions, expected {}",
90 query.len(),
91 self.dimensions
92 )));
93 }
94 let mut hits = Vec::new();
95 for record in self.records.values() {
96 if options
97 .allowlist
98 .as_ref()
99 .is_some_and(|allowlist| !allowlist.contains(&record.id))
100 {
101 continue;
102 }
103 if options
104 .target_kind
105 .as_ref()
106 .is_some_and(|kind| kind != &record.target_kind)
107 {
108 continue;
109 }
110 let score = dot(query, &record.vector);
111 if score <= 0.0 {
112 continue;
113 }
114 hits.push(VectorHit {
115 id: record.id,
116 target_id: record.target_id.clone(),
117 target_kind: record.target_kind.clone(),
118 score,
119 });
120 }
121 hits.sort_by(|left, right| {
122 right
123 .score
124 .partial_cmp(&left.score)
125 .unwrap_or(std::cmp::Ordering::Equal)
126 .then_with(|| left.target_id.cmp(&right.target_id))
127 });
128 hits.truncate(options.limit.max(1));
129 Ok(hits)
130 }
131
132 pub fn save(&self, path: &Path) -> Result<()> {
133 if let Some(parent) = path.parent() {
134 fs::create_dir_all(parent)?;
135 }
136 fs::write(path, serde_json::to_vec_pretty(self)?)?;
137 Ok(())
138 }
139
140 pub fn load(path: &Path) -> Result<Self> {
141 let raw = fs::read(path)?;
142 Ok(serde_json::from_slice(&raw)?)
143 }
144
145 pub fn stats(&self) -> VectorIndexStats {
146 VectorIndexStats {
147 backend: "exact-flat".into(),
148 dimensions: self.dimensions,
149 vector_count: self.records.len(),
150 }
151 }
152}
153
/// Dot product of two slices.
///
/// Components are paired with `zip`, so any extra elements of the longer slice are
/// ignored; callers are expected to pass equal-length slices.
fn dot(left: &[f32], right: &[f32]) -> f32 {
    let products = left.iter().zip(right.iter()).map(|(&a, &b)| a * b);
    products.sum::<f32>()
}
160
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `"chunk"`-kind record with the given id, target, and components.
    fn chunk(id: u64, target_id: &str, vector: Vec<f32>) -> VectorRecord {
        VectorRecord {
            id: VectorId(id),
            target_id: target_id.into(),
            target_kind: "chunk".into(),
            vector,
        }
    }

    #[test]
    fn exact_backend_searches_with_allowlist() {
        let mut index = ExactFlatVectorIndex::new(2).unwrap();
        for record in [chunk(1, "a", vec![1.0, 0.0]), chunk(2, "b", vec![0.0, 1.0])] {
            index.add(record).unwrap();
        }

        // Only id 1 is allowed, and it is also the only positive match.
        let allowlisted = VectorSearchOptions {
            limit: 5,
            allowlist: Some(HashSet::from([VectorId(1)])),
            target_kind: None,
        };
        let hits = index.search(&[1.0, 0.0], allowlisted).unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].target_id, "a");

        let removed = index.remove(VectorId(1)).unwrap();
        assert_eq!(removed.target_id, "a");

        // With "a" gone, the orthogonal "b" scores 0 and is filtered out.
        let unrestricted = VectorSearchOptions {
            limit: 5,
            ..Default::default()
        };
        let hits_after_remove = index.search(&[1.0, 0.0], unrestricted).unwrap();
        assert!(hits_after_remove.is_empty());
    }

    #[test]
    fn detects_vector_id_collision() {
        let mut index = ExactFlatVectorIndex::new(1).unwrap();
        index.add(chunk(1, "a", vec![1.0])).unwrap();

        // Same id, different target: must be rejected rather than overwritten.
        let err = index.add(chunk(1, "b", vec![1.0])).unwrap_err();
        assert!(err.to_string().contains("collision"));
    }
}