1use std::collections::HashMap;
2
3use crate::error::RetrievalError;
4use ahash::{AHashMap, AHashSet};
5use issundb_core::{EdgeId, Graph, NodeId};
6use issundb_text::{TextGraphExt, TextSearchOptions};
7use issundb_vector::{VectorGraphExt, VectorSearchOptions};
8
9pub struct Subgraph {
18 pub nodes: Vec<NodeId>,
19 pub edges: Vec<EdgeId>,
20 pub scores: HashMap<NodeId, f32>,
21}
22
/// Tuning knobs for [`retrieve_with`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RetrieveOptions {
    /// Number of nearest vector hits requested as seeds.
    pub k: usize,
    /// Number of expansion steps taken from the seeds (`0` keeps only seeds).
    pub hops: u8,
    /// Vector hits whose distance exceeds this value are not used as seeds.
    pub max_distance: f32,
    /// Optional cap on the number of nodes returned by the expansion.
    pub max_nodes: Option<usize>,
}

impl Default for RetrieveOptions {
    /// `k = 10`, `hops = 2`, `max_distance = f32::MAX`, `max_nodes = None`.
    fn default() -> Self {
        Self {
            k: 10,
            hops: 2,
            max_distance: f32::MAX,
            max_nodes: None,
        }
    }
}
49
50pub fn retrieve(graph: &Graph, q: &[f32], k: usize, hops: u8) -> Result<Subgraph, RetrievalError> {
53 retrieve_with(
54 graph,
55 q,
56 &RetrieveOptions {
57 k,
58 hops,
59 ..Default::default()
60 },
61 )
62}
63
64pub fn retrieve_with(
71 graph: &Graph,
72 q: &[f32],
73 opts: &RetrieveOptions,
74) -> Result<Subgraph, RetrievalError> {
75 let hits = graph.vector_search(q, opts.k)?;
76
77 let mut scores: AHashMap<NodeId, f32> = AHashMap::new();
78 let mut seeds = Vec::new();
79 for hit in &hits {
80 if hit.distance <= opts.max_distance {
81 scores.insert(hit.node, hit.distance);
82 seeds.push(hit.node);
83 }
84 }
85
86 if seeds.is_empty() {
87 return Ok(Subgraph {
88 nodes: Vec::new(),
89 edges: Vec::new(),
90 scores: HashMap::new(),
91 });
92 }
93
94 let node_list = graph.bfs_multi_source_graphblas(&seeds, opts.hops, opts.max_nodes)?;
95 let node_set: AHashSet<NodeId> = node_list.into_iter().collect();
96
97 scores.retain(|n, _| node_set.contains(n));
102
103 let mut edge_set: AHashSet<EdgeId> = AHashSet::new();
104 for &node in &node_set {
105 for ne in graph.out_neighbors(node)? {
106 if node_set.contains(&ne.node) {
107 edge_set.insert(ne.edge);
108 }
109 }
110 }
111
112 Ok(Subgraph {
113 nodes: node_set.into_iter().collect(),
114 edges: edge_set.into_iter().collect(),
115 scores: scores.into_iter().collect(),
116 })
117}
118
/// How [`retrieve_hybrid`] merges the vector and text rankings into one score.
///
/// Both strategies look only at each hit's 0-based rank within its result
/// list, not at raw distances or text relevance values. A node missing from
/// one list contributes `0` for that list.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum FusionStrategy {
    /// Reciprocal Rank Fusion: each list contributes `1 / (k + rank + 1)`,
    /// and the contributions are summed.
    Rrf { k: u32 },
    /// Linear rank decay per list, `(n - rank) / n`, where `n` is the
    /// requested top-k for that list (at least 1). The two values are combined
    /// as `vector_weight * v + text_weight * t`.
    WeightedSum {
        vector_weight: f32,
        text_weight: f32,
    },
}

impl Default for FusionStrategy {
    /// Reciprocal Rank Fusion with `k = 60`.
    fn default() -> Self {
        Self::Rrf { k: 60 }
    }
}
137
138pub struct HybridRetrieveOptions {
140 pub vector_k: usize,
142 pub text_k: usize,
144 pub text_label: Option<String>,
146 pub text_property: Option<String>,
148 pub hops: u8,
150 pub max_distance: f32,
152 pub max_nodes: Option<usize>,
154 pub vector_label: Option<String>,
156 pub fusion: FusionStrategy,
158}
159
160impl Default for HybridRetrieveOptions {
161 fn default() -> Self {
162 Self {
163 vector_k: 10,
164 text_k: 10,
165 text_label: None,
166 text_property: None,
167 hops: 2,
168 max_distance: f32::MAX,
169 max_nodes: None,
170 vector_label: None,
171 fusion: FusionStrategy::default(),
172 }
173 }
174}
175
176pub fn retrieve_hybrid(
183 graph: &Graph,
184 q: &[f32],
185 text_query: &str,
186 opts: &HybridRetrieveOptions,
187) -> Result<Subgraph, RetrievalError> {
188 let mut vec_ranks: AHashMap<NodeId, usize> = AHashMap::new();
190 let mut vec_scores: AHashMap<NodeId, f32> = AHashMap::new();
191
192 if opts.vector_k > 0 && !q.is_empty() {
193 let hits = graph.vector_search_with(
194 q,
195 &VectorSearchOptions {
196 k: opts.vector_k,
197 label: opts.vector_label.clone(),
198 properties: None,
199 rescore_factor: None,
200 },
201 )?;
202 for (rank, hit) in hits.iter().enumerate() {
203 if hit.distance <= opts.max_distance {
204 vec_ranks.insert(hit.node, rank);
205 vec_scores.insert(hit.node, hit.distance);
206 }
207 }
208 }
209
210 let mut text_ranks: AHashMap<NodeId, usize> = AHashMap::new();
212
213 if opts.text_k > 0 && !text_query.is_empty() {
214 let text_opts = TextSearchOptions {
215 label: opts.text_label.clone(),
216 property: opts.text_property.clone(),
217 limit: opts.text_k,
218 ..Default::default()
219 };
220 let text_hits = graph.text_search(text_query, &text_opts)?;
221 for (rank, hit) in text_hits.iter().enumerate() {
222 text_ranks.insert(hit.node, rank);
223 }
224 }
225
226 let mut fused: AHashMap<NodeId, f32> = AHashMap::new();
228
229 let all_nodes: AHashSet<NodeId> = vec_ranks.keys().chain(text_ranks.keys()).copied().collect();
230
231 for node in &all_nodes {
232 let score = match &opts.fusion {
233 FusionStrategy::Rrf { k } => {
234 let kf = *k as f32;
235 let vs = vec_ranks
236 .get(node)
237 .map(|r| 1.0 / (kf + *r as f32 + 1.0))
238 .unwrap_or(0.0);
239 let ts = text_ranks
240 .get(node)
241 .map(|r| 1.0 / (kf + *r as f32 + 1.0))
242 .unwrap_or(0.0);
243 vs + ts
244 }
245 FusionStrategy::WeightedSum {
246 vector_weight,
247 text_weight,
248 } => {
249 let total_vec = opts.vector_k.max(1) as f32;
250 let total_txt = opts.text_k.max(1) as f32;
251 let vs = vec_ranks
252 .get(node)
253 .map(|r| (total_vec - *r as f32) / total_vec)
254 .unwrap_or(0.0);
255 let ts = text_ranks
256 .get(node)
257 .map(|r| (total_txt - *r as f32) / total_txt)
258 .unwrap_or(0.0);
259 vector_weight * vs + text_weight * ts
260 }
261 };
262 fused.insert(*node, score);
263 }
264
265 let seeds: Vec<NodeId> = fused.keys().copied().collect();
266
267 if seeds.is_empty() {
268 return Ok(Subgraph {
269 nodes: Vec::new(),
270 edges: Vec::new(),
271 scores: HashMap::new(),
272 });
273 }
274
275 let node_list = graph.bfs_multi_source_graphblas(&seeds, opts.hops, opts.max_nodes)?;
277 let node_set: AHashSet<NodeId> = node_list.into_iter().collect();
278
279 let mut scores: AHashMap<NodeId, f32> = fused;
280 scores.retain(|n, _| node_set.contains(n));
281
282 let mut edge_set: AHashSet<EdgeId> = AHashSet::new();
283 for &node in &node_set {
284 for ne in graph.out_neighbors(node)? {
285 if node_set.contains(&ne.node) {
286 edge_set.insert(ne.edge);
287 }
288 }
289 }
290
291 Ok(Subgraph {
292 nodes: node_set.into_iter().collect(),
293 edges: edge_set.into_iter().collect(),
294 scores: scores.into_iter().collect(),
295 })
296}
297
#[cfg(test)]
mod tests {
    use serde_json::json;
    use tempfile::TempDir;

    use super::*;

    /// Opens a fresh graph in a temporary directory. The returned `TempDir`
    /// must be kept alive for as long as the graph is used.
    fn open_tmp() -> (TempDir, Graph) {
        let dir = TempDir::new().unwrap();
        let g = Graph::open(dir.path(), 1).unwrap();
        (dir, g)
    }

    // ---------------------------------------------------------------------
    // retrieve / retrieve_with
    // ---------------------------------------------------------------------

    #[test]
    fn retrieve_empty_vector_index_returns_empty_subgraph() {
        let (_dir, g) = open_tmp();
        let sub = retrieve(&g, &[1.0f32, 0.0], 5, 2).unwrap();
        assert!(sub.nodes.is_empty());
        assert!(sub.edges.is_empty());
        assert!(sub.scores.is_empty());
    }

    #[test]
    fn retrieve_hops_zero_returns_only_seed_nodes() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("N", &json!({})).unwrap();
        let b = g.add_node("N", &json!({})).unwrap();
        let c = g.add_node("N", &json!({})).unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0, 0.0]).unwrap();
        g.upsert_vector(b, &[0.0f32, 1.0, 0.0]).unwrap();
        g.add_edge(a, c, "E", &json!({})).unwrap();

        let sub = retrieve(&g, &[1.0f32, 0.0, 0.0], 1, 0).unwrap();
        assert_eq!(sub.nodes.len(), 1);
        assert_eq!(sub.nodes[0], a);
        assert!(!sub.nodes.contains(&c));
    }

    #[test]
    fn retrieve_expands_bfs_to_correct_depth() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("N", &json!({})).unwrap();
        let b = g.add_node("N", &json!({})).unwrap();
        let c = g.add_node("N", &json!({})).unwrap();
        let d = g.add_node("N", &json!({})).unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0]).unwrap();
        // Chain a -> b -> c -> d.
        g.add_edge(a, b, "E", &json!({})).unwrap();
        g.add_edge(b, c, "E", &json!({})).unwrap();
        g.add_edge(c, d, "E", &json!({})).unwrap();

        let sub1 = retrieve(&g, &[1.0f32, 0.0], 1, 1).unwrap();
        let sub2 = retrieve(&g, &[1.0f32, 0.0], 1, 2).unwrap();

        let mut n1 = sub1.nodes.clone();
        n1.sort_unstable();
        assert_eq!(n1, vec![a, b]);

        let mut n2 = sub2.nodes.clone();
        n2.sort_unstable();
        assert_eq!(n2, vec![a, b, c]);
    }

    #[test]
    fn retrieve_subgraph_edges_connect_only_nodes_in_set() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("N", &json!({})).unwrap();
        let b = g.add_node("N", &json!({})).unwrap();
        let c = g.add_node("N", &json!({})).unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0]).unwrap();
        let e_ab = g.add_edge(a, b, "E", &json!({})).unwrap();
        let _e_bc = g.add_edge(b, c, "E", &json!({})).unwrap();

        let sub = retrieve(&g, &[1.0f32, 0.0], 1, 1).unwrap();
        assert!(sub.edges.contains(&e_ab));
        assert_eq!(sub.edges.len(), 1);
    }

    #[test]
    fn retrieve_scores_map_contains_seed_distances() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("N", &json!({})).unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0]).unwrap();

        let sub = retrieve(&g, &[1.0f32, 0.0], 1, 0).unwrap();
        assert!(sub.scores.contains_key(&a));
        assert!(sub.scores[&a] < 1e-5);
    }

    #[test]
    fn retrieve_with_max_distance_filters_far_seeds() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("N", &json!({})).unwrap();
        let b = g.add_node("N", &json!({})).unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0, 0.0]).unwrap();
        g.upsert_vector(b, &[0.0f32, 1.0, 0.0]).unwrap();

        let sub = retrieve_with(
            &g,
            &[1.0f32, 0.0, 0.0],
            &RetrieveOptions {
                k: 2,
                hops: 0,
                max_distance: 0.1,
                max_nodes: None,
            },
        )
        .unwrap();

        assert_eq!(sub.nodes.len(), 1);
        assert_eq!(sub.nodes[0], a);
    }

    #[test]
    fn retrieve_with_max_nodes_caps_subgraph() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("N", &json!({})).unwrap();
        let b = g.add_node("N", &json!({})).unwrap();
        let c = g.add_node("N", &json!({})).unwrap();
        let d = g.add_node("N", &json!({})).unwrap();
        let e = g.add_node("N", &json!({})).unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0]).unwrap();
        // Star: a has four direct neighbours.
        g.add_edge(a, b, "E", &json!({})).unwrap();
        g.add_edge(a, c, "E", &json!({})).unwrap();
        g.add_edge(a, d, "E", &json!({})).unwrap();
        g.add_edge(a, e, "E", &json!({})).unwrap();

        let sub = retrieve_with(
            &g,
            &[1.0f32, 0.0],
            &RetrieveOptions {
                k: 1,
                hops: 1,
                max_distance: f32::MAX,
                max_nodes: Some(3),
            },
        )
        .unwrap();

        assert!(sub.nodes.len() <= 3);
    }

    #[test]
    fn retrieve_with_multiple_seeds_each_expand_independently() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("N", &json!({})).unwrap();
        let b = g.add_node("N", &json!({})).unwrap();
        let c = g.add_node("N", &json!({})).unwrap();
        let d = g.add_node("N", &json!({})).unwrap();
        let e = g.add_node("N", &json!({})).unwrap();
        let f = g.add_node("N", &json!({})).unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0, 0.0]).unwrap();
        g.upsert_vector(d, &[0.0f32, 1.0, 0.0]).unwrap();
        // Two disjoint chains: a -> b -> c and d -> e -> f.
        g.add_edge(a, b, "E", &json!({})).unwrap();
        g.add_edge(b, c, "E", &json!({})).unwrap();
        g.add_edge(d, e, "E", &json!({})).unwrap();
        g.add_edge(e, f, "E", &json!({})).unwrap();

        let sub1 = retrieve_with(
            &g,
            &[1.0f32, 0.0, 0.0],
            &RetrieveOptions {
                k: 2,
                hops: 1,
                max_distance: f32::MAX,
                max_nodes: None,
            },
        )
        .unwrap();
        let mut n1 = sub1.nodes.clone();
        n1.sort_unstable();
        assert!(n1.contains(&a), "seed a must be present at hops=1");
        assert!(n1.contains(&b), "b is 1 hop from seed a");
        assert!(n1.contains(&d), "seed d must be present at hops=1");
        assert!(n1.contains(&e), "e is 1 hop from seed d");
        assert!(!n1.contains(&c), "c is 2 hops from a, out of range");
        assert!(!n1.contains(&f), "f is 2 hops from d, out of range");
        assert_eq!(n1.len(), 4);

        let sub2 = retrieve_with(
            &g,
            &[1.0f32, 0.0, 0.0],
            &RetrieveOptions {
                k: 2,
                hops: 2,
                max_distance: f32::MAX,
                max_nodes: None,
            },
        )
        .unwrap();
        assert_eq!(sub2.nodes.len(), 6, "all six nodes reachable within 2 hops");
        assert!(sub2.scores.contains_key(&a));
        assert!(sub2.scores.contains_key(&d));
    }

    // ---------------------------------------------------------------------
    // The same retrieve_with scenarios after `rebuild_csr()`.
    // ---------------------------------------------------------------------

    #[test]
    fn graphblas_retrieve_k_hop_expansion() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("N", &json!({})).unwrap();
        let b = g.add_node("N", &json!({})).unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0]).unwrap();
        g.add_edge(a, b, "E", &json!({})).unwrap();
        g.rebuild_csr().unwrap();

        let sub = retrieve_with(
            &g,
            &[1.0f32, 0.0],
            &RetrieveOptions {
                k: 1,
                hops: 1,
                max_distance: f32::MAX,
                max_nodes: None,
            },
        )
        .unwrap();

        assert_eq!(sub.nodes.len(), 2);
        assert!(sub.nodes.contains(&a));
        assert!(sub.nodes.contains(&b));
    }

    #[test]
    fn graphblas_retrieve_hops_zero_returns_only_seed() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("N", &json!({})).unwrap();
        let b = g.add_node("N", &json!({})).unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0]).unwrap();
        g.add_edge(a, b, "E", &json!({})).unwrap();
        g.rebuild_csr().unwrap();

        let sub = retrieve_with(
            &g,
            &[1.0f32, 0.0],
            &RetrieveOptions {
                k: 1,
                hops: 0,
                max_distance: f32::MAX,
                max_nodes: None,
            },
        )
        .unwrap();

        assert_eq!(sub.nodes, vec![a]);
        assert!(sub.edges.is_empty(), "no edges when hops=0");
    }

    #[test]
    fn graphblas_retrieve_scores_keys_are_subset_of_nodes() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("N", &json!({})).unwrap();
        let b = g.add_node("N", &json!({})).unwrap();
        let c = g.add_node("N", &json!({})).unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0, 0.0]).unwrap();
        g.upsert_vector(b, &[0.9f32, 0.1, 0.0]).unwrap();
        g.add_edge(a, c, "E", &json!({})).unwrap();
        g.rebuild_csr().unwrap();

        let sub = retrieve_with(
            &g,
            &[1.0f32, 0.0, 0.0],
            &RetrieveOptions {
                k: 2,
                hops: 1,
                max_distance: f32::MAX,
                max_nodes: None,
            },
        )
        .unwrap();

        for node_id in sub.scores.keys() {
            assert!(
                sub.nodes.contains(node_id),
                "scores key {node_id:?} is absent from nodes"
            );
        }
    }

    #[test]
    fn graphblas_retrieve_edges_connect_only_nodes_in_subgraph() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("N", &json!({})).unwrap();
        let b = g.add_node("N", &json!({})).unwrap();
        let c = g.add_node("N", &json!({})).unwrap();
        let d = g.add_node("N", &json!({})).unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0]).unwrap();
        let e_ab = g.add_edge(a, b, "E", &json!({})).unwrap();
        let _e_bc = g.add_edge(b, c, "E", &json!({})).unwrap();
        g.add_edge(c, d, "E", &json!({})).unwrap();
        g.rebuild_csr().unwrap();

        let sub = retrieve_with(
            &g,
            &[1.0f32, 0.0],
            &RetrieveOptions {
                k: 1,
                hops: 1,
                max_distance: f32::MAX,
                max_nodes: None,
            },
        )
        .unwrap();

        assert!(sub.nodes.contains(&a));
        assert!(sub.nodes.contains(&b));
        assert!(!sub.nodes.contains(&c));
        assert!(sub.edges.contains(&e_ab), "edge a to b must be in subgraph");
        assert_eq!(
            sub.edges.len(),
            1,
            "only a to b is within the 1-hop subgraph"
        );
    }

    #[test]
    fn graphblas_retrieve_max_distance_filters_far_seeds() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("N", &json!({})).unwrap();
        let b = g.add_node("N", &json!({})).unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0, 0.0]).unwrap();
        g.upsert_vector(b, &[0.0f32, 1.0, 0.0]).unwrap();
        g.rebuild_csr().unwrap();

        let sub = retrieve_with(
            &g,
            &[1.0f32, 0.0, 0.0],
            &RetrieveOptions {
                k: 2,
                hops: 0,
                max_distance: 0.1,
                max_nodes: None,
            },
        )
        .unwrap();

        assert_eq!(sub.nodes.len(), 1);
        assert_eq!(sub.nodes[0], a);
        assert!(sub.scores.contains_key(&a));
        assert!(!sub.scores.contains_key(&b));
    }

    #[test]
    fn graphblas_retrieve_max_nodes_caps_subgraph() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("N", &json!({})).unwrap();
        let b = g.add_node("N", &json!({})).unwrap();
        let c = g.add_node("N", &json!({})).unwrap();
        let d = g.add_node("N", &json!({})).unwrap();
        let e = g.add_node("N", &json!({})).unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0]).unwrap();
        g.add_edge(a, b, "E", &json!({})).unwrap();
        g.add_edge(a, c, "E", &json!({})).unwrap();
        g.add_edge(a, d, "E", &json!({})).unwrap();
        g.add_edge(a, e, "E", &json!({})).unwrap();
        g.rebuild_csr().unwrap();

        let sub = retrieve_with(
            &g,
            &[1.0f32, 0.0],
            &RetrieveOptions {
                k: 1,
                hops: 1,
                max_distance: f32::MAX,
                max_nodes: Some(3),
            },
        )
        .unwrap();

        assert!(
            sub.nodes.len() <= 3,
            "expected at most 3 nodes, got {}",
            sub.nodes.len()
        );
    }

    #[test]
    fn graphblas_retrieve_scores_contain_seed_distances() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("N", &json!({})).unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0]).unwrap();
        g.rebuild_csr().unwrap();

        let sub = retrieve_with(
            &g,
            &[1.0f32, 0.0],
            &RetrieveOptions {
                k: 1,
                hops: 0,
                max_distance: f32::MAX,
                max_nodes: None,
            },
        )
        .unwrap();

        assert!(sub.scores.contains_key(&a));
        assert!(
            sub.scores[&a] < 1e-5,
            "distance to identical vector must be ~0"
        );
    }

    #[test]
    fn graphblas_retrieve_empty_vector_index_returns_empty() {
        let (_dir, g) = open_tmp();
        g.rebuild_csr().unwrap();

        let sub = retrieve_with(&g, &[1.0f32, 0.0], &RetrieveOptions::default()).unwrap();

        assert!(sub.nodes.is_empty());
        assert!(sub.edges.is_empty());
        assert!(sub.scores.is_empty());
    }

    #[test]
    fn graphblas_retrieve_multiple_seeds_each_expand_independently() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("N", &json!({})).unwrap();
        let b = g.add_node("N", &json!({})).unwrap();
        let c = g.add_node("N", &json!({})).unwrap();
        let d = g.add_node("N", &json!({})).unwrap();
        let e = g.add_node("N", &json!({})).unwrap();
        let f = g.add_node("N", &json!({})).unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0, 0.0]).unwrap();
        g.upsert_vector(d, &[0.0f32, 1.0, 0.0]).unwrap();
        g.add_edge(a, b, "E", &json!({})).unwrap();
        g.add_edge(b, c, "E", &json!({})).unwrap();
        g.add_edge(d, e, "E", &json!({})).unwrap();
        g.add_edge(e, f, "E", &json!({})).unwrap();
        g.rebuild_csr().unwrap();

        let sub1 = retrieve_with(
            &g,
            &[1.0f32, 0.0, 0.0],
            &RetrieveOptions {
                k: 2,
                hops: 1,
                max_distance: f32::MAX,
                max_nodes: None,
            },
        )
        .unwrap();
        assert!(sub1.nodes.contains(&a), "seed a must be present at hops=1");
        assert!(sub1.nodes.contains(&b), "b is 1 hop from seed a");
        assert!(sub1.nodes.contains(&d), "seed d must be present at hops=1");
        assert!(sub1.nodes.contains(&e), "e is 1 hop from seed d");
        assert!(!sub1.nodes.contains(&c), "c is 2 hops from a, out of range");
        assert!(!sub1.nodes.contains(&f), "f is 2 hops from d, out of range");
        assert_eq!(sub1.nodes.len(), 4);

        let sub2 = retrieve_with(
            &g,
            &[1.0f32, 0.0, 0.0],
            &RetrieveOptions {
                k: 2,
                hops: 2,
                max_distance: f32::MAX,
                max_nodes: None,
            },
        )
        .unwrap();
        assert_eq!(sub2.nodes.len(), 6, "all six nodes reachable within 2 hops");
        assert!(sub2.scores.contains_key(&a));
        assert!(sub2.scores.contains_key(&d));
    }

    // ---------------------------------------------------------------------
    // retrieve_hybrid
    // ---------------------------------------------------------------------

    #[test]
    fn hybrid_retrieve_vector_only_matches_pure_vector_search() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("N", &json!({})).unwrap();
        let b = g.add_node("N", &json!({})).unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0, 0.0]).unwrap();
        g.upsert_vector(b, &[0.0f32, 1.0, 0.0]).unwrap();
        g.rebuild_csr().unwrap();

        let sub = retrieve_hybrid(
            &g,
            &[1.0f32, 0.0, 0.0],
            "",
            &HybridRetrieveOptions {
                vector_k: 1,
                text_k: 0,
                hops: 0,
                ..Default::default()
            },
        )
        .unwrap();
        assert_eq!(sub.nodes.len(), 1);
        assert_eq!(sub.nodes[0], a);
    }

    #[test]
    fn hybrid_retrieve_without_any_query_returns_empty() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("N", &json!({})).unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0]).unwrap();
        g.rebuild_csr().unwrap();

        // Empty vector and empty text query: neither side runs, so no seeds.
        let sub = retrieve_hybrid(&g, &[], "", &HybridRetrieveOptions::default()).unwrap();
        assert!(sub.nodes.is_empty());
        assert!(sub.edges.is_empty());
        assert!(sub.scores.is_empty());
    }

    #[test]
    fn hybrid_retrieve_fuses_both_sources() {
        let (_dir, g) = open_tmp();
        let a = g
            .add_node("Doc", &json!({"body": "rust graph database storage"}))
            .unwrap();
        let b = g
            .add_node("Doc", &json!({"body": "vector search nearest neighbor"}))
            .unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0]).unwrap();
        g.upsert_vector(b, &[0.0f32, 1.0]).unwrap();
        g.update(|txn| txn.create_node_text_index("Doc", "body"))
            .unwrap();
        g.rebuild_csr().unwrap();

        let sub = retrieve_hybrid(
            &g,
            &[1.0f32, 0.0],
            "vector",
            &HybridRetrieveOptions {
                vector_k: 1,
                text_k: 1,
                text_label: Some("Doc".into()),
                text_property: Some("body".into()),
                hops: 0,
                ..Default::default()
            },
        )
        .unwrap();
        assert!(sub.nodes.contains(&a), "vector hit a must be present");
        assert!(sub.nodes.contains(&b), "text hit b must be present");
    }

    #[test]
    fn hybrid_retrieve_rrf_scores_are_reciprocal_ranks() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("Doc", &json!({"body": "alpha bravo"})).unwrap();
        let b = g
            .add_node("Doc", &json!({"body": "charlie delta"}))
            .unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0]).unwrap();
        g.upsert_vector(b, &[0.0f32, 1.0]).unwrap();
        g.update(|txn| txn.create_node_text_index("Doc", "body"))
            .unwrap();
        g.rebuild_csr().unwrap();

        // a is vector rank 0 only, b is text rank 0 only; default RRF k = 60.
        let sub = retrieve_hybrid(
            &g,
            &[1.0f32, 0.0],
            "charlie",
            &HybridRetrieveOptions {
                vector_k: 1,
                text_k: 1,
                text_label: Some("Doc".into()),
                text_property: Some("body".into()),
                hops: 0,
                ..Default::default()
            },
        )
        .unwrap();

        let expected = 1.0f32 / 61.0;
        assert!(
            (sub.scores[&a] - expected).abs() < 1e-6,
            "a score should be 1/61, got {}",
            sub.scores[&a]
        );
        assert!(
            (sub.scores[&b] - expected).abs() < 1e-6,
            "b score should be 1/61, got {}",
            sub.scores[&b]
        );
    }

    #[test]
    fn hybrid_retrieve_rrf_rewards_nodes_found_by_both_sources() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("Doc", &json!({"body": "alpha bravo"})).unwrap();
        let b = g
            .add_node("Doc", &json!({"body": "charlie delta"}))
            .unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0]).unwrap();
        g.upsert_vector(b, &[0.0f32, 1.0]).unwrap();
        g.update(|txn| txn.create_node_text_index("Doc", "body"))
            .unwrap();
        g.rebuild_csr().unwrap();

        // a is rank 0 in both lists, so its RRF score is 2 / 61.
        let sub = retrieve_hybrid(
            &g,
            &[1.0f32, 0.0],
            "alpha",
            &HybridRetrieveOptions {
                vector_k: 1,
                text_k: 1,
                text_label: Some("Doc".into()),
                text_property: Some("body".into()),
                hops: 0,
                ..Default::default()
            },
        )
        .unwrap();

        assert_eq!(sub.nodes, vec![a]);
        assert!(!sub.scores.contains_key(&b));
        assert!(
            (sub.scores[&a] - 2.0f32 / 61.0).abs() < 1e-6,
            "a score should be 2/61, got {}",
            sub.scores[&a]
        );
    }

    #[test]
    fn hybrid_retrieve_weighted_sum_produces_correct_scores() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("Doc", &json!({"body": "alpha bravo"})).unwrap();
        let b = g
            .add_node("Doc", &json!({"body": "charlie delta"}))
            .unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0]).unwrap();
        g.upsert_vector(b, &[0.0f32, 1.0]).unwrap();
        g.update(|txn| txn.create_node_text_index("Doc", "body"))
            .unwrap();
        g.rebuild_csr().unwrap();

        // a is vector rank 0 only -> 0.7 * 1.0; b is text rank 0 only -> 0.3 * 1.0.
        let sub = retrieve_hybrid(
            &g,
            &[1.0f32, 0.0],
            "charlie",
            &HybridRetrieveOptions {
                vector_k: 1,
                text_k: 1,
                text_label: Some("Doc".into()),
                text_property: Some("body".into()),
                hops: 0,
                fusion: FusionStrategy::WeightedSum {
                    vector_weight: 0.7,
                    text_weight: 0.3,
                },
                ..Default::default()
            },
        )
        .unwrap();

        assert!(
            sub.scores.contains_key(&a),
            "vector seed a must have a score"
        );
        assert!(sub.scores.contains_key(&b), "text seed b must have a score");
        assert!(
            (sub.scores[&a] - 0.7).abs() < 1e-5,
            "a score should be 0.7, got {}",
            sub.scores[&a]
        );
        assert!(
            (sub.scores[&b] - 0.3).abs() < 1e-5,
            "b score should be 0.3, got {}",
            sub.scores[&b]
        );
    }

    #[test]
    fn hybrid_retrieve_text_only_returns_text_seeds() {
        let (_dir, g) = open_tmp();
        let a = g
            .add_node("Doc", &json!({"body": "quantum computing research"}))
            .unwrap();
        let b = g
            .add_node("Doc", &json!({"body": "classical music orchestra"}))
            .unwrap();
        g.update(|txn| txn.create_node_text_index("Doc", "body"))
            .unwrap();
        g.rebuild_csr().unwrap();

        let sub = retrieve_hybrid(
            &g,
            &[],
            "quantum",
            &HybridRetrieveOptions {
                vector_k: 0,
                text_k: 5,
                text_label: Some("Doc".into()),
                text_property: Some("body".into()),
                hops: 0,
                ..Default::default()
            },
        )
        .unwrap();

        assert_eq!(
            sub.nodes.len(),
            1,
            "only the text-matching node should appear"
        );
        assert_eq!(sub.nodes[0], a);
        assert!(sub.scores.contains_key(&a));
        assert!(!sub.nodes.contains(&b), "non-matching node must be absent");
    }
}
935}