1use std::collections::HashMap;
7use std::sync::RwLock;
8
9use dwbase_core::{AtomId, AtomKind, WorldKey};
10use dwbase_engine::{AtomFilter, DwbaseError, Result, VectorEngine};
11use hnsw_rs::prelude::*;
12use serde::{Deserialize, Serialize};
13use thiserror::Error;
14
15#[derive(Debug, Clone, Serialize, Deserialize, Default)]
16pub struct VectorMetadata {
17 pub kind: Option<AtomKind>,
18 pub labels: Vec<String>,
19 pub flags: Vec<String>,
20}
21
22#[derive(Debug, Error)]
23enum VectorError {
24 #[error("dimension mismatch: expected {expected}, got {got}")]
25 DimensionMismatch { expected: usize, got: usize },
26}
27
28impl From<VectorError> for DwbaseError {
29 fn from(err: VectorError) -> Self {
30 DwbaseError::Vector(err.to_string())
31 }
32}
33
34struct WorldIndex {
35 hnsw: Hnsw<'static, f32, DistL2>,
36 dim: usize,
37 meta: HashMap<usize, (AtomId, VectorMetadata)>,
38 next_point: usize,
39}
40
41impl WorldIndex {
42 fn new(dim: usize) -> Self {
43 let max_nb_connection = 16;
45 let max_elements = 10_000;
46 let max_layers = 12;
47 let ef_c = 100;
48 let hnsw =
49 Hnsw::<f32, DistL2>::new(max_nb_connection, max_elements, max_layers, ef_c, DistL2 {});
50 Self {
51 hnsw,
52 dim,
53 meta: HashMap::new(),
54 next_point: 0,
55 }
56 }
57}
58
59pub struct HnswVectorEngine {
61 worlds: RwLock<HashMap<WorldKey, WorldIndex>>,
62}
63
64impl HnswVectorEngine {
65 pub fn new() -> Self {
66 Self {
67 worlds: RwLock::new(HashMap::new()),
68 }
69 }
70
71 fn metadata_matches(meta: &VectorMetadata, filter: &AtomFilter) -> bool {
72 if !filter.kinds.is_empty() {
73 if let Some(kind) = &meta.kind {
74 if !filter.kinds.contains(kind) {
75 return false;
76 }
77 } else {
78 return false;
79 }
80 }
81 if !filter.labels.is_empty() && !filter.labels.iter().all(|l| meta.labels.contains(l)) {
82 return false;
83 }
84 if !filter.flags.is_empty() && !filter.flags.iter().all(|f| meta.flags.contains(f)) {
85 return false;
86 }
87 true
88 }
89}
90
91impl Default for HnswVectorEngine {
92 fn default() -> Self {
93 Self::new()
94 }
95}
96
97impl VectorEngine for HnswVectorEngine {
98 fn upsert(&self, world: &WorldKey, atom_id: &AtomId, vector: &[f32]) -> Result<()> {
99 self.upsert_with_metadata(world, atom_id, vector, VectorMetadata::default())
100 }
101
102 fn search(
103 &self,
104 world: &WorldKey,
105 query: &[f32],
106 k: usize,
107 filter: &AtomFilter,
108 ) -> Result<Vec<AtomId>> {
109 if let Some(filter_world) = &filter.world {
110 if filter_world != world {
111 return Ok(Vec::new());
112 }
113 }
114 let guard = self.worlds.read().expect("poisoned world index lock");
115 let world_idx = match guard.get(world) {
116 Some(idx) => idx,
117 None => return Ok(Vec::new()),
118 };
119 if world_idx.dim != query.len() {
120 return Err(VectorError::DimensionMismatch {
121 expected: world_idx.dim,
122 got: query.len(),
123 }
124 .into());
125 }
126
127 let ef_search = 200;
128 let results = world_idx.hnsw.search(query, k, ef_search);
129
130 let mut out = Vec::new();
131 for point in results {
132 if let Some((atom_id, meta)) = world_idx.meta.get(&point.d_id) {
133 if Self::metadata_matches(meta, filter) {
134 out.push(atom_id.clone());
135 }
136 }
137 }
138 if out.is_empty() {
140 for (_pid, (atom_id, meta)) in world_idx.meta.iter() {
141 if Self::metadata_matches(meta, filter) {
142 out.push(atom_id.clone());
143 if out.len() >= k {
144 break;
145 }
146 }
147 }
148 }
149 Ok(out)
150 }
151
152 fn rebuild(&self, _world: &WorldKey) -> Result<()> {
153 Ok(())
154 }
155}
156
157impl HnswVectorEngine {
158 pub fn upsert_with_metadata(
159 &self,
160 world: &WorldKey,
161 atom_id: &AtomId,
162 vector: &[f32],
163 metadata: VectorMetadata,
164 ) -> Result<()> {
165 let dim = vector.len();
166 let mut guard = self.worlds.write().expect("poisoned world index lock");
167 if let Some(existing) = guard.get(world) {
168 if existing.dim != dim {
169 return Err(VectorError::DimensionMismatch {
170 expected: existing.dim,
171 got: dim,
172 }
173 .into());
174 }
175 }
176 let world_idx = guard
177 .entry(world.clone())
178 .or_insert_with(|| WorldIndex::new(dim));
179
180 let point_id = world_idx.next_point;
181 world_idx.next_point += 1;
182 world_idx.hnsw.insert((vector, point_id));
183 world_idx.meta.insert(point_id, (atom_id.clone(), metadata));
184 Ok(())
185 }
186}
187
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh engine with no worlds.
    fn test_engine() -> HnswVectorEngine {
        HnswVectorEngine::new()
    }

    /// Builds metadata with the given kind, labels and flags.
    fn metadata(kind: AtomKind, labels: &[&str], flags: &[&str]) -> VectorMetadata {
        let owned =
            |items: &[&str]| items.iter().map(|s| s.to_string()).collect::<Vec<String>>();
        VectorMetadata {
            kind: Some(kind),
            labels: owned(labels),
            flags: owned(flags),
        }
    }

    /// Indexes one atom and panics if the engine rejects it.
    fn index(
        engine: &HnswVectorEngine,
        world: &WorldKey,
        id: AtomId,
        vector: &[f32],
        meta: VectorMetadata,
    ) {
        engine
            .upsert_with_metadata(world, &id, vector, meta)
            .unwrap();
    }

    /// With an unrestricted filter, the closest vector's atom is returned.
    #[test]
    fn insert_and_search_returns_nearest_ids() {
        let engine = test_engine();
        let world = WorldKey::new("w1");
        index(
            &engine,
            &world,
            AtomId::new("a1"),
            &[0.0, 0.0],
            metadata(AtomKind::Observation, &["foo"], &[]),
        );
        index(
            &engine,
            &world,
            AtomId::new("a2"),
            &[10.0, 10.0],
            metadata(AtomKind::Observation, &["bar"], &[]),
        );

        let hits = engine
            .search(&world, &[0.1, 0.1], 1, &AtomFilter::default())
            .unwrap();
        assert_eq!(hits, vec![AtomId::new("a1")]);
    }

    /// Kind, label and flag constraints remove a nearer but non-matching atom.
    #[test]
    fn filter_by_labels_and_kinds() {
        let engine = test_engine();
        let world = WorldKey::new("w1");
        index(
            &engine,
            &world,
            AtomId::new("a1"),
            &[1.0, 1.0],
            metadata(AtomKind::Observation, &["x"], &[]),
        );
        index(
            &engine,
            &world,
            AtomId::new("a2"),
            &[1.1, 1.1],
            metadata(AtomKind::Reflection, &["y"], &["skip"]),
        );

        let filter = AtomFilter {
            world: Some(world.clone()),
            kinds: vec![AtomKind::Reflection],
            labels: vec!["y".into()],
            flags: vec!["skip".into()],
            since: None,
            until: None,
            limit: Some(5),
        };
        let hits = engine.search(&world, &[1.0, 1.0], 2, &filter).unwrap();
        assert_eq!(hits, vec![AtomId::new("a2")]);
    }
}