Skip to main content

dwbase_vector_hnsw/
lib.rs

1//! HNSW-backed VectorEngine implementation for DWBase.
2//!
3//! Vectors are stored per world in an in-memory HNSW index. Persistence is out of scope for v1.
4//! The index assumes a consistent dimension per world (inferred from the first insert).
5
6use std::collections::HashMap;
7use std::sync::RwLock;
8
9use dwbase_core::{AtomId, AtomKind, WorldKey};
10use dwbase_engine::{AtomFilter, DwbaseError, Result, VectorEngine};
11use hnsw_rs::prelude::*;
12use serde::{Deserialize, Serialize};
13use thiserror::Error;
14
15#[derive(Debug, Clone, Serialize, Deserialize, Default)]
16pub struct VectorMetadata {
17    pub kind: Option<AtomKind>,
18    pub labels: Vec<String>,
19    pub flags: Vec<String>,
20}
21
22#[derive(Debug, Error)]
23enum VectorError {
24    #[error("dimension mismatch: expected {expected}, got {got}")]
25    DimensionMismatch { expected: usize, got: usize },
26}
27
28impl From<VectorError> for DwbaseError {
29    fn from(err: VectorError) -> Self {
30        DwbaseError::Vector(err.to_string())
31    }
32}
33
34struct WorldIndex {
35    hnsw: Hnsw<'static, f32, DistL2>,
36    dim: usize,
37    meta: HashMap<usize, (AtomId, VectorMetadata)>,
38    next_point: usize,
39}
40
41impl WorldIndex {
42    fn new(dim: usize) -> Self {
43        // Parameters chosen for small in-memory indices; tweak as needed later.
44        let max_nb_connection = 16;
45        let max_elements = 10_000;
46        let max_layers = 12;
47        let ef_c = 100;
48        let hnsw =
49            Hnsw::<f32, DistL2>::new(max_nb_connection, max_elements, max_layers, ef_c, DistL2 {});
50        Self {
51            hnsw,
52            dim,
53            meta: HashMap::new(),
54            next_point: 0,
55        }
56    }
57}
58
59/// HNSW-based vector search engine.
60pub struct HnswVectorEngine {
61    worlds: RwLock<HashMap<WorldKey, WorldIndex>>,
62}
63
64impl HnswVectorEngine {
65    pub fn new() -> Self {
66        Self {
67            worlds: RwLock::new(HashMap::new()),
68        }
69    }
70
71    fn metadata_matches(meta: &VectorMetadata, filter: &AtomFilter) -> bool {
72        if !filter.kinds.is_empty() {
73            if let Some(kind) = &meta.kind {
74                if !filter.kinds.contains(kind) {
75                    return false;
76                }
77            } else {
78                return false;
79            }
80        }
81        if !filter.labels.is_empty() && !filter.labels.iter().all(|l| meta.labels.contains(l)) {
82            return false;
83        }
84        if !filter.flags.is_empty() && !filter.flags.iter().all(|f| meta.flags.contains(f)) {
85            return false;
86        }
87        true
88    }
89}
90
91impl Default for HnswVectorEngine {
92    fn default() -> Self {
93        Self::new()
94    }
95}
96
97impl VectorEngine for HnswVectorEngine {
98    fn upsert(&self, world: &WorldKey, atom_id: &AtomId, vector: &[f32]) -> Result<()> {
99        self.upsert_with_metadata(world, atom_id, vector, VectorMetadata::default())
100    }
101
102    fn search(
103        &self,
104        world: &WorldKey,
105        query: &[f32],
106        k: usize,
107        filter: &AtomFilter,
108    ) -> Result<Vec<AtomId>> {
109        if let Some(filter_world) = &filter.world {
110            if filter_world != world {
111                return Ok(Vec::new());
112            }
113        }
114        let guard = self.worlds.read().expect("poisoned world index lock");
115        let world_idx = match guard.get(world) {
116            Some(idx) => idx,
117            None => return Ok(Vec::new()),
118        };
119        if world_idx.dim != query.len() {
120            return Err(VectorError::DimensionMismatch {
121                expected: world_idx.dim,
122                got: query.len(),
123            }
124            .into());
125        }
126
127        let ef_search = 200;
128        let results = world_idx.hnsw.search(query, k, ef_search);
129
130        let mut out = Vec::new();
131        for point in results {
132            if let Some((atom_id, meta)) = world_idx.meta.get(&point.d_id) {
133                if Self::metadata_matches(meta, filter) {
134                    out.push(atom_id.clone());
135                }
136            }
137        }
138        // Fallback: if filtered results are empty, return the first matching items regardless of ANN score.
139        if out.is_empty() {
140            for (_pid, (atom_id, meta)) in world_idx.meta.iter() {
141                if Self::metadata_matches(meta, filter) {
142                    out.push(atom_id.clone());
143                    if out.len() >= k {
144                        break;
145                    }
146                }
147            }
148        }
149        Ok(out)
150    }
151
152    fn rebuild(&self, _world: &WorldKey) -> Result<()> {
153        Ok(())
154    }
155}
156
157impl HnswVectorEngine {
158    pub fn upsert_with_metadata(
159        &self,
160        world: &WorldKey,
161        atom_id: &AtomId,
162        vector: &[f32],
163        metadata: VectorMetadata,
164    ) -> Result<()> {
165        let dim = vector.len();
166        let mut guard = self.worlds.write().expect("poisoned world index lock");
167        if let Some(existing) = guard.get(world) {
168            if existing.dim != dim {
169                return Err(VectorError::DimensionMismatch {
170                    expected: existing.dim,
171                    got: dim,
172                }
173                .into());
174            }
175        }
176        let world_idx = guard
177            .entry(world.clone())
178            .or_insert_with(|| WorldIndex::new(dim));
179
180        let point_id = world_idx.next_point;
181        world_idx.next_point += 1;
182        world_idx.hnsw.insert((vector, point_id));
183        world_idx.meta.insert(point_id, (atom_id.clone(), metadata));
184        Ok(())
185    }
186}
187
#[cfg(test)]
mod tests {
    use super::*;

    /// Upserts one vector with explicit metadata into `world`, panicking on error.
    fn put(
        engine: &HnswVectorEngine,
        world: &WorldKey,
        id: &str,
        vector: &[f32],
        kind: AtomKind,
        labels: &[&str],
        flags: &[&str],
    ) {
        let metadata = VectorMetadata {
            kind: Some(kind),
            labels: labels.iter().map(|s| s.to_string()).collect(),
            flags: flags.iter().map(|s| s.to_string()).collect(),
        };
        engine
            .upsert_with_metadata(world, &AtomId::new(id), vector, metadata)
            .unwrap();
    }

    #[test]
    fn insert_and_search_returns_nearest_ids() {
        let engine = HnswVectorEngine::new();
        let world = WorldKey::new("w1");
        put(&engine, &world, "a1", &[0.0, 0.0], AtomKind::Observation, &["foo"], &[]);
        put(&engine, &world, "a2", &[10.0, 10.0], AtomKind::Observation, &["bar"], &[]);

        let hits = engine
            .search(&world, &[0.1, 0.1], 1, &AtomFilter::default())
            .unwrap();
        assert_eq!(hits, vec![AtomId::new("a1")]);
    }

    #[test]
    fn filter_by_labels_and_kinds() {
        let engine = HnswVectorEngine::new();
        let world = WorldKey::new("w1");
        put(&engine, &world, "a1", &[1.0, 1.0], AtomKind::Observation, &["x"], &[]);
        put(&engine, &world, "a2", &[1.1, 1.1], AtomKind::Reflection, &["y"], &["skip"]);

        let filter = AtomFilter {
            world: Some(world.clone()),
            kinds: vec![AtomKind::Reflection],
            labels: vec!["y".into()],
            flags: vec!["skip".into()],
            since: None,
            until: None,
            limit: Some(5),
        };
        let hits = engine.search(&world, &[1.0, 1.0], 2, &filter).unwrap();
        assert_eq!(hits, vec![AtomId::new("a2")]);
    }
}