use std::collections::HashMap;
use std::sync::RwLock;
13
/// Key identifying one vector: a collection name plus a numeric id.
///
/// Equality and hashing are derived over both fields, so the key can be used
/// directly in `HashMap`/`HashSet`.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct VectorKey {
    /// Collection name part of the key.
    pub collection: String,
    /// Numeric id part of the key.
    pub vector_id: u64,
}

impl VectorKey {
    /// Builds a key; `collection` may be anything convertible into a `String`
    /// (`&str`, `String`, ...), which is then stored as an owned value.
    pub fn new(collection: impl Into<String>, vector_id: u64) -> Self {
        Self {
            collection: collection.into(),
            vector_id,
        }
    }
}
29
/// Key identifying one table row: a table name plus a numeric row id.
///
/// Equality and hashing are derived over both fields, so the key can be used
/// directly in `HashMap`/`HashSet`.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct RowKey {
    /// Table name part of the key.
    pub table: String,
    /// Numeric row id part of the key.
    pub row_id: u64,
}

impl RowKey {
    /// Builds a key; `table` may be anything convertible into a `String`
    /// (`&str`, `String`, ...), which is then stored as an owned value.
    pub fn new(table: impl Into<String>, row_id: u64) -> Self {
        Self {
            table: table.into(),
            row_id,
        }
    }
}
45
46#[derive(Debug, Clone, Hash, PartialEq, Eq)]
48pub enum StorageRef {
49 Node(String),
51 Edge(String),
53 Vector(VectorKey),
55 Row(RowKey),
57}
58
59impl StorageRef {
60 pub fn node(id: impl Into<String>) -> Self {
61 StorageRef::Node(id.into())
62 }
63
64 pub fn edge(id: impl Into<String>) -> Self {
65 StorageRef::Edge(id.into())
66 }
67
68 pub fn vector(collection: impl Into<String>, vector_id: u64) -> Self {
69 StorageRef::Vector(VectorKey::new(collection, vector_id))
70 }
71
72 pub fn row(table: impl Into<String>, row_id: u64) -> Self {
73 StorageRef::Row(RowKey::new(table, row_id))
74 }
75}
76
77#[derive(Debug, Clone)]
79pub struct CrossRef {
80 pub source: StorageRef,
81 pub target: StorageRef,
82 pub metadata: Option<HashMap<String, String>>,
84}
85
86impl CrossRef {
87 pub fn new(source: StorageRef, target: StorageRef) -> Self {
88 Self {
89 source,
90 target,
91 metadata: None,
92 }
93 }
94
95 pub fn with_metadata(mut self, key: &str, value: &str) -> Self {
96 self.metadata
97 .get_or_insert_with(HashMap::new)
98 .insert(key.to_string(), value.to_string());
99 self
100 }
101}
102
/// Snapshot of link counts, as produced by `UnifiedIndex::stats`.
///
/// `PartialEq`/`Eq` are derived so two snapshots can be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct UnifiedIndexStats {
    /// Total number of vectors listed across all nodes.
    pub node_to_vector_count: usize,
    /// Total number of rows listed across all nodes.
    pub node_to_row_count: usize,
    /// Number of vectors that are mapped to a row.
    pub vector_to_row_count: usize,
    /// Sum of the three counters above.
    pub total_refs: usize,
}
111
112pub struct UnifiedIndex {
119 node_to_vectors: RwLock<HashMap<String, Vec<VectorKey>>>,
121 vector_to_node: RwLock<HashMap<VectorKey, String>>,
123
124 node_to_rows: RwLock<HashMap<String, Vec<RowKey>>>,
126 row_to_node: RwLock<HashMap<RowKey, String>>,
128
129 edge_to_nodes: RwLock<HashMap<String, (String, String)>>,
131
132 vector_to_row: RwLock<HashMap<VectorKey, RowKey>>,
134 row_to_vectors: RwLock<HashMap<RowKey, Vec<VectorKey>>>,
136}
137
138impl UnifiedIndex {
139 pub fn new() -> Self {
141 Self {
142 node_to_vectors: RwLock::new(HashMap::new()),
143 vector_to_node: RwLock::new(HashMap::new()),
144 node_to_rows: RwLock::new(HashMap::new()),
145 row_to_node: RwLock::new(HashMap::new()),
146 edge_to_nodes: RwLock::new(HashMap::new()),
147 vector_to_row: RwLock::new(HashMap::new()),
148 row_to_vectors: RwLock::new(HashMap::new()),
149 }
150 }
151
152 pub fn link_node_to_vector(&self, node_id: &str, collection: &str, vector_id: u64) {
158 let key = VectorKey::new(collection, vector_id);
159
160 if let Ok(mut map) = self.node_to_vectors.write() {
162 map.entry(node_id.to_string())
163 .or_insert_with(Vec::new)
164 .push(key.clone());
165 }
166
167 if let Ok(mut map) = self.vector_to_node.write() {
169 map.insert(key, node_id.to_string());
170 }
171 }
172
173 pub fn get_node_vectors(&self, node_id: &str) -> Vec<VectorKey> {
175 self.node_to_vectors
176 .read()
177 .ok()
178 .and_then(|map| map.get(node_id).cloned())
179 .unwrap_or_default()
180 }
181
182 pub fn get_vector_node(&self, collection: &str, vector_id: u64) -> Option<String> {
184 let key = VectorKey::new(collection, vector_id);
185 self.vector_to_node
186 .read()
187 .ok()
188 .and_then(|map| map.get(&key).cloned())
189 }
190
191 pub fn unlink_node_from_vector(&self, node_id: &str, collection: &str, vector_id: u64) {
193 let key = VectorKey::new(collection, vector_id);
194
195 if let Ok(mut map) = self.node_to_vectors.write() {
196 if let Some(vectors) = map.get_mut(node_id) {
197 vectors.retain(|v| v != &key);
198 if vectors.is_empty() {
199 map.remove(node_id);
200 }
201 }
202 }
203
204 if let Ok(mut map) = self.vector_to_node.write() {
205 map.remove(&key);
206 }
207 }
208
209 pub fn link_node_to_row(&self, node_id: &str, table: &str, row_id: u64) {
215 let key = RowKey::new(table, row_id);
216
217 if let Ok(mut map) = self.node_to_rows.write() {
219 map.entry(node_id.to_string())
220 .or_insert_with(Vec::new)
221 .push(key.clone());
222 }
223
224 if let Ok(mut map) = self.row_to_node.write() {
226 map.insert(key, node_id.to_string());
227 }
228 }
229
230 pub fn get_node_rows(&self, node_id: &str) -> Vec<RowKey> {
232 self.node_to_rows
233 .read()
234 .ok()
235 .and_then(|map| map.get(node_id).cloned())
236 .unwrap_or_default()
237 }
238
239 pub fn get_row_node(&self, table: &str, row_id: u64) -> Option<String> {
241 let key = RowKey::new(table, row_id);
242 self.row_to_node
243 .read()
244 .ok()
245 .and_then(|map| map.get(&key).cloned())
246 }
247
248 pub fn unlink_node_from_row(&self, node_id: &str, table: &str, row_id: u64) {
250 let key = RowKey::new(table, row_id);
251
252 if let Ok(mut map) = self.node_to_rows.write() {
253 if let Some(rows) = map.get_mut(node_id) {
254 rows.retain(|r| r != &key);
255 if rows.is_empty() {
256 map.remove(node_id);
257 }
258 }
259 }
260
261 if let Ok(mut map) = self.row_to_node.write() {
262 map.remove(&key);
263 }
264 }
265
266 pub fn register_edge(&self, edge_id: &str, source_node: &str, target_node: &str) {
272 if let Ok(mut map) = self.edge_to_nodes.write() {
273 map.insert(
274 edge_id.to_string(),
275 (source_node.to_string(), target_node.to_string()),
276 );
277 }
278 }
279
280 pub fn get_edge_nodes(&self, edge_id: &str) -> Option<(String, String)> {
282 self.edge_to_nodes
283 .read()
284 .ok()
285 .and_then(|map| map.get(edge_id).cloned())
286 }
287
288 pub fn unregister_edge(&self, edge_id: &str) {
290 if let Ok(mut map) = self.edge_to_nodes.write() {
291 map.remove(edge_id);
292 }
293 }
294
295 pub fn link_vector_to_row(&self, collection: &str, vector_id: u64, table: &str, row_id: u64) {
301 let vkey = VectorKey::new(collection, vector_id);
302 let rkey = RowKey::new(table, row_id);
303
304 if let Ok(mut map) = self.vector_to_row.write() {
306 map.insert(vkey.clone(), rkey.clone());
307 }
308
309 if let Ok(mut map) = self.row_to_vectors.write() {
311 map.entry(rkey).or_insert_with(Vec::new).push(vkey);
312 }
313 }
314
315 pub fn get_vector_row(&self, collection: &str, vector_id: u64) -> Option<RowKey> {
317 let key = VectorKey::new(collection, vector_id);
318 self.vector_to_row
319 .read()
320 .ok()
321 .and_then(|map| map.get(&key).cloned())
322 }
323
324 pub fn get_row_vectors(&self, table: &str, row_id: u64) -> Vec<VectorKey> {
326 let key = RowKey::new(table, row_id);
327 self.row_to_vectors
328 .read()
329 .ok()
330 .and_then(|map| map.get(&key).cloned())
331 .unwrap_or_default()
332 }
333
334 pub fn resolve(&self, source: &StorageRef) -> Vec<StorageRef> {
345 let mut results = Vec::new();
346
347 match source {
348 StorageRef::Node(node_id) => {
349 for vkey in self.get_node_vectors(node_id) {
351 results.push(StorageRef::Vector(vkey));
352 }
353 for rkey in self.get_node_rows(node_id) {
355 results.push(StorageRef::Row(rkey));
356 }
357 }
358 StorageRef::Vector(vkey) => {
359 if let Some(node_id) = self.get_vector_node(&vkey.collection, vkey.vector_id) {
361 results.push(StorageRef::Node(node_id));
362 }
363 if let Some(rkey) = self.get_vector_row(&vkey.collection, vkey.vector_id) {
365 results.push(StorageRef::Row(rkey));
366 }
367 }
368 StorageRef::Row(rkey) => {
369 if let Some(node_id) = self.get_row_node(&rkey.table, rkey.row_id) {
371 results.push(StorageRef::Node(node_id));
372 }
373 for vkey in self.get_row_vectors(&rkey.table, rkey.row_id) {
375 results.push(StorageRef::Vector(vkey));
376 }
377 }
378 StorageRef::Edge(edge_id) => {
379 if let Some((src, tgt)) = self.get_edge_nodes(edge_id) {
381 results.push(StorageRef::Node(src));
382 results.push(StorageRef::Node(tgt));
383 }
384 }
385 }
386
387 results
388 }
389
390 pub fn resolve_transitive(&self, source: &StorageRef, max_depth: usize) -> Vec<StorageRef> {
394 let mut visited = std::collections::HashSet::new();
395 let mut results = Vec::new();
396 let mut frontier = vec![source.clone()];
397
398 for _ in 0..max_depth {
399 let mut next_frontier = Vec::new();
400 for current in frontier {
401 if !visited.insert(current.clone()) {
402 continue;
403 }
404 for related in self.resolve(¤t) {
405 if !visited.contains(&related) {
406 results.push(related.clone());
407 next_frontier.push(related);
408 }
409 }
410 }
411 if next_frontier.is_empty() {
412 break;
413 }
414 frontier = next_frontier;
415 }
416
417 results
418 }
419
420 pub fn remove_node(&self, node_id: &str) {
426 if let Ok(mut nv) = self.node_to_vectors.write() {
428 if let Some(vectors) = nv.remove(node_id) {
429 if let Ok(mut vn) = self.vector_to_node.write() {
430 for v in vectors {
431 vn.remove(&v);
432 }
433 }
434 }
435 }
436
437 if let Ok(mut nr) = self.node_to_rows.write() {
439 if let Some(rows) = nr.remove(node_id) {
440 if let Ok(mut rn) = self.row_to_node.write() {
441 for r in rows {
442 rn.remove(&r);
443 }
444 }
445 }
446 }
447 }
448
449 pub fn remove_vector(&self, collection: &str, vector_id: u64) {
451 let key = VectorKey::new(collection, vector_id);
452
453 if let Ok(mut vn) = self.vector_to_node.write() {
455 if let Some(node_id) = vn.remove(&key) {
456 if let Ok(mut nv) = self.node_to_vectors.write() {
457 if let Some(vectors) = nv.get_mut(&node_id) {
458 vectors.retain(|v| v != &key);
459 if vectors.is_empty() {
460 nv.remove(&node_id);
461 }
462 }
463 }
464 }
465 }
466
467 if let Ok(mut vr) = self.vector_to_row.write() {
469 if let Some(rkey) = vr.remove(&key) {
470 if let Ok(mut rv) = self.row_to_vectors.write() {
471 if let Some(vectors) = rv.get_mut(&rkey) {
472 vectors.retain(|v| v != &key);
473 if vectors.is_empty() {
474 rv.remove(&rkey);
475 }
476 }
477 }
478 }
479 }
480 }
481
482 pub fn stats(&self) -> UnifiedIndexStats {
484 let node_to_vector_count = self
485 .node_to_vectors
486 .read()
487 .map(|m| m.values().map(|v| v.len()).sum())
488 .unwrap_or(0);
489 let node_to_row_count = self
490 .node_to_rows
491 .read()
492 .map(|m| m.values().map(|v| v.len()).sum())
493 .unwrap_or(0);
494 let vector_to_row_count = self.vector_to_row.read().map(|m| m.len()).unwrap_or(0);
495
496 UnifiedIndexStats {
497 node_to_vector_count,
498 node_to_row_count,
499 vector_to_row_count,
500 total_refs: node_to_vector_count + node_to_row_count + vector_to_row_count,
501 }
502 }
503
504 pub fn clear(&self) {
506 if let Ok(mut m) = self.node_to_vectors.write() {
507 m.clear();
508 }
509 if let Ok(mut m) = self.vector_to_node.write() {
510 m.clear();
511 }
512 if let Ok(mut m) = self.node_to_rows.write() {
513 m.clear();
514 }
515 if let Ok(mut m) = self.row_to_node.write() {
516 m.clear();
517 }
518 if let Ok(mut m) = self.edge_to_nodes.write() {
519 m.clear();
520 }
521 if let Ok(mut m) = self.vector_to_row.write() {
522 m.clear();
523 }
524 if let Ok(mut m) = self.row_to_vectors.write() {
525 m.clear();
526 }
527 }
528}
529
530impl Default for UnifiedIndex {
531 fn default() -> Self {
532 Self::new()
533 }
534}
535
#[cfg(test)]
mod tests {
    use super::*;

    /// A node/vector link is visible from both directions.
    #[test]
    fn test_node_vector_linking() {
        let idx = UnifiedIndex::new();

        idx.link_node_to_vector("host:1", "embeddings", 42);

        let vectors = idx.get_node_vectors("host:1");
        assert_eq!(vectors.len(), 1);
        assert_eq!(vectors[0].collection, "embeddings");
        assert_eq!(vectors[0].vector_id, 42);

        let node = idx.get_vector_node("embeddings", 42);
        assert_eq!(node, Some("host:1".to_string()));
    }

    /// A node/row link is visible from both directions.
    #[test]
    fn test_node_row_linking() {
        let idx = UnifiedIndex::new();

        idx.link_node_to_row("host:1", "hosts", 100);

        let rows = idx.get_node_rows("host:1");
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].table, "hosts");
        assert_eq!(rows[0].row_id, 100);

        let node = idx.get_row_node("hosts", 100);
        assert_eq!(node, Some("host:1".to_string()));
    }

    /// One-hop resolution from a node and from a vector.
    #[test]
    fn test_resolve() {
        let idx = UnifiedIndex::new();

        idx.link_node_to_vector("host:1", "embeddings", 42);
        idx.link_node_to_row("host:1", "hosts", 100);

        let refs = idx.resolve(&StorageRef::node("host:1"));
        assert_eq!(refs.len(), 2);

        let refs = idx.resolve(&StorageRef::vector("embeddings", 42));
        assert_eq!(refs.len(), 1);
        assert!(matches!(&refs[0], StorageRef::Node(id) if id == "host:1"));
    }

    /// Two hops from a row reach the node and then the node's vector.
    #[test]
    fn test_transitive_resolve() {
        let idx = UnifiedIndex::new();

        idx.link_node_to_row("host:1", "hosts", 100);
        idx.link_node_to_vector("host:1", "embeddings", 42);

        let refs = idx.resolve_transitive(&StorageRef::row("hosts", 100), 2);

        assert!(refs
            .iter()
            .any(|r| matches!(r, StorageRef::Node(id) if id == "host:1")));
        assert!(refs.iter().any(|r| matches!(
            r,
            StorageRef::Vector(vk) if vk.collection == "embeddings" && vk.vector_id == 42
        )));
    }

    /// A node can own several vectors, including across collections.
    #[test]
    fn test_multiple_vectors_per_node() {
        let idx = UnifiedIndex::new();

        idx.link_node_to_vector("host:1", "embeddings", 1);
        idx.link_node_to_vector("host:1", "embeddings", 2);
        idx.link_node_to_vector("host:1", "descriptions", 1);

        let vectors = idx.get_node_vectors("host:1");
        assert_eq!(vectors.len(), 3);
    }

    /// Unlinking removes the link from both directions.
    #[test]
    fn test_unlink() {
        let idx = UnifiedIndex::new();

        idx.link_node_to_vector("host:1", "embeddings", 42);
        assert!(idx.get_vector_node("embeddings", 42).is_some());

        idx.unlink_node_from_vector("host:1", "embeddings", 42);
        assert!(idx.get_vector_node("embeddings", 42).is_none());
        assert!(idx.get_node_vectors("host:1").is_empty());
    }

    /// Counters reflect each relation and their sum.
    #[test]
    fn test_stats() {
        let idx = UnifiedIndex::new();

        idx.link_node_to_vector("host:1", "embeddings", 1);
        idx.link_node_to_vector("host:1", "embeddings", 2);
        idx.link_node_to_row("host:1", "hosts", 100);
        idx.link_vector_to_row("embeddings", 3, "hosts", 200);

        let stats = idx.stats();
        assert_eq!(stats.node_to_vector_count, 2);
        assert_eq!(stats.node_to_row_count, 1);
        assert_eq!(stats.vector_to_row_count, 1);
        assert_eq!(stats.total_refs, 4);
    }
}