Skip to main content

open_kioku_vector/
lib.rs

1use open_kioku_errors::{OkError, Result};
2use serde::{Deserialize, Serialize};
3use std::collections::{BTreeMap, HashSet};
4use std::fs;
5use std::path::Path;
6
7#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
8pub struct VectorId(pub u64);
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct VectorRecord {
12    pub id: VectorId,
13    pub target_id: String,
14    pub target_kind: String,
15    pub vector: Vec<f32>,
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct VectorHit {
20    pub id: VectorId,
21    pub target_id: String,
22    pub target_kind: String,
23    pub score: f32,
24}
25
26#[derive(Debug, Clone, Default)]
27pub struct VectorSearchOptions {
28    pub limit: usize,
29    pub allowlist: Option<HashSet<VectorId>>,
30    pub target_kind: Option<String>,
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct VectorIndexStats {
35    pub backend: String,
36    pub dimensions: usize,
37    pub vector_count: usize,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct ExactFlatVectorIndex {
42    dimensions: usize,
43    records: BTreeMap<VectorId, VectorRecord>,
44}
45
46impl ExactFlatVectorIndex {
47    pub fn new(dimensions: usize) -> Result<Self> {
48        if dimensions == 0 {
49            return Err(OkError::Unsupported(
50                "exact-flat vector index requires dimensions > 0".into(),
51            ));
52        }
53        Ok(Self {
54            dimensions,
55            records: BTreeMap::new(),
56        })
57    }
58
59    pub fn add(&mut self, record: VectorRecord) -> Result<()> {
60        if record.vector.len() != self.dimensions {
61            return Err(OkError::Storage(format!(
62                "vector {} has {} dimensions, expected {}",
63                record.id.0,
64                record.vector.len(),
65                self.dimensions
66            )));
67        }
68        if self
69            .records
70            .get(&record.id)
71            .is_some_and(|existing| existing.target_id != record.target_id)
72        {
73            return Err(OkError::Storage(format!(
74                "vector id collision for {}",
75                record.id.0
76            )));
77        }
78        self.records.insert(record.id, record);
79        Ok(())
80    }
81
82    pub fn remove(&mut self, id: VectorId) -> Option<VectorRecord> {
83        self.records.remove(&id)
84    }
85
86    pub fn search(&self, query: &[f32], options: VectorSearchOptions) -> Result<Vec<VectorHit>> {
87        if query.len() != self.dimensions {
88            return Err(OkError::Storage(format!(
89                "query vector has {} dimensions, expected {}",
90                query.len(),
91                self.dimensions
92            )));
93        }
94        let mut hits = Vec::new();
95        for record in self.records.values() {
96            if options
97                .allowlist
98                .as_ref()
99                .is_some_and(|allowlist| !allowlist.contains(&record.id))
100            {
101                continue;
102            }
103            if options
104                .target_kind
105                .as_ref()
106                .is_some_and(|kind| kind != &record.target_kind)
107            {
108                continue;
109            }
110            let score = dot(query, &record.vector);
111            if score <= 0.0 {
112                continue;
113            }
114            hits.push(VectorHit {
115                id: record.id,
116                target_id: record.target_id.clone(),
117                target_kind: record.target_kind.clone(),
118                score,
119            });
120        }
121        hits.sort_by(|left, right| {
122            right
123                .score
124                .partial_cmp(&left.score)
125                .unwrap_or(std::cmp::Ordering::Equal)
126                .then_with(|| left.target_id.cmp(&right.target_id))
127        });
128        hits.truncate(options.limit.max(1));
129        Ok(hits)
130    }
131
132    pub fn save(&self, path: &Path) -> Result<()> {
133        if let Some(parent) = path.parent() {
134            fs::create_dir_all(parent)?;
135        }
136        fs::write(path, serde_json::to_vec_pretty(self)?)?;
137        Ok(())
138    }
139
140    pub fn load(path: &Path) -> Result<Self> {
141        let raw = fs::read(path)?;
142        Ok(serde_json::from_slice(&raw)?)
143    }
144
145    pub fn stats(&self) -> VectorIndexStats {
146        VectorIndexStats {
147            backend: "exact-flat".into(),
148            dimensions: self.dimensions,
149            vector_count: self.records.len(),
150        }
151    }
152}
153
/// Dot product of two slices: the sum of the pairwise component products.
///
/// Components are paired with `zip`, so when the slices differ in length the
/// extra components of the longer one are ignored.
fn dot(left: &[f32], right: &[f32]) -> f32 {
    let products = left.iter().zip(right.iter()).map(|(a, b)| a * b);
    products.sum()
}
160
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a record of kind `"chunk"` with the given id, target, and vector.
    fn chunk(id: u64, target_id: &str, vector: Vec<f32>) -> VectorRecord {
        VectorRecord {
            id: VectorId(id),
            target_id: target_id.into(),
            target_kind: "chunk".into(),
            vector,
        }
    }

    #[test]
    fn exact_backend_searches_with_allowlist() {
        let mut index = ExactFlatVectorIndex::new(2).unwrap();
        for record in [chunk(1, "a", vec![1.0, 0.0]), chunk(2, "b", vec![0.0, 1.0])] {
            index.add(record).unwrap();
        }
        let query = [1.0_f32, 0.0];

        // Restricting the search to id 1 must return only target "a".
        let only_first = VectorSearchOptions {
            limit: 5,
            allowlist: Some(HashSet::from([VectorId(1)])),
            target_kind: None,
        };
        let hits = index.search(&query, only_first).unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].target_id, "a");

        // After removing id 1, the remaining vector is orthogonal to the query.
        let removed = index.remove(VectorId(1)).unwrap();
        assert_eq!(removed.target_id, "a");

        let unrestricted = VectorSearchOptions {
            limit: 5,
            allowlist: None,
            target_kind: None,
        };
        let hits_after_remove = index.search(&query, unrestricted).unwrap();
        assert!(hits_after_remove.is_empty());
    }

    #[test]
    fn detects_vector_id_collision() {
        let mut index = ExactFlatVectorIndex::new(1).unwrap();
        index.add(chunk(1, "a", vec![1.0])).unwrap();

        // Same id, different target: must be rejected.
        let err = index.add(chunk(1, "b", vec![1.0])).unwrap_err();
        assert!(err.to_string().contains("collision"));
    }
}