Skip to main content

issundb_retrieval/
retrieve.rs

1use std::collections::HashMap;
2
3use crate::error::RetrievalError;
4use ahash::{AHashMap, AHashSet};
5use issundb_core::{EdgeId, Graph, NodeId};
6use issundb_text::{TextGraphExt, TextSearchOptions};
7use issundb_vector::{VectorGraphExt, VectorSearchOptions};
8
/// A subgraph extracted by a retrieval call.
///
/// `nodes` and `edges` are deduplicated but unordered. `scores` maps each seed
/// node to its relevance value; expansion-only nodes are absent from the map.
/// For [`retrieve`] and [`retrieve_with`] the value is the seed's cosine
/// distance from the query (lower is closer). For [`retrieve_hybrid`] it is the
/// fused score produced by the configured [`FusionStrategy`] over the vector
/// and text seeds (higher is more relevant).
pub struct Subgraph {
    /// Nodes returned by the BFS expansion from the seeds. Collected from a
    /// hash set, so the order is arbitrary and there are no duplicates.
    pub nodes: Vec<NodeId>,
    /// Out-edges whose source and target are both in `nodes`. Collected from a
    /// hash set, so the order is arbitrary and there are no duplicates.
    pub edges: Vec<EdgeId>,
    /// Relevance value per seed node. Only seeds that are also present in
    /// `nodes` are kept, so the key set is always a subset of `nodes`.
    pub scores: HashMap<NodeId, f32>,
}
22
/// Options for [`retrieve_with`].
///
/// Use struct-update syntax to override only the fields you care about, e.g.
/// `RetrieveOptions { k: 5, ..Default::default() }`.
///
/// Derives `Debug`, `Clone` and `PartialEq` so option sets can be logged,
/// reused across calls and compared in tests, like [`FusionStrategy`].
#[derive(Debug, Clone, PartialEq)]
pub struct RetrieveOptions {
    /// Number of hits requested from the vector search. Hits beyond
    /// `max_distance` are dropped afterwards, so fewer than `k` seeds may
    /// remain.
    pub k: usize,
    /// BFS expansion depth from each seed node.
    pub hops: u8,
    /// Maximum cosine distance for a vector hit to qualify as a seed.
    /// Hits with `distance > max_distance` are dropped before BFS expansion.
    /// Default: `f32::MAX` (keep all k hits).
    pub max_distance: f32,
    /// Hard cap on the total number of nodes in the returned subgraph,
    /// handed to the BFS expansion as its node limit.
    /// `None` means no cap.
    pub max_nodes: Option<usize>,
}

impl Default for RetrieveOptions {
    /// Ten seeds, two hops, no distance threshold and no node cap.
    fn default() -> Self {
        Self {
            k: 10,
            hops: 2,
            max_distance: f32::MAX,
            max_nodes: None,
        }
    }
}
49
50/// Convenience wrapper: vector search to k seeds, then `hops`-hop BFS expansion to
51/// subgraph materialization.
52pub fn retrieve(graph: &Graph, q: &[f32], k: usize, hops: u8) -> Result<Subgraph, RetrievalError> {
53    retrieve_with(
54        graph,
55        q,
56        &RetrieveOptions {
57            k,
58            hops,
59            ..Default::default()
60        },
61    )
62}
63
64/// Full retrieve with configurable options.
65///
66/// GraphBLAS SpMV k-hop expansion for hybrid retrieval.
67///
68/// Runs multi-source SpMV BFS from the filtered seed nodes up to `hops` hops.
69/// Stops early or caps the results if `max_nodes` is specified and exceeded.
70pub fn retrieve_with(
71    graph: &Graph,
72    q: &[f32],
73    opts: &RetrieveOptions,
74) -> Result<Subgraph, RetrievalError> {
75    let hits = graph.vector_search(q, opts.k)?;
76
77    let mut scores: AHashMap<NodeId, f32> = AHashMap::new();
78    let mut seeds = Vec::new();
79    for hit in &hits {
80        if hit.distance <= opts.max_distance {
81            scores.insert(hit.node, hit.distance);
82            seeds.push(hit.node);
83        }
84    }
85
86    if seeds.is_empty() {
87        return Ok(Subgraph {
88            nodes: Vec::new(),
89            edges: Vec::new(),
90            scores: HashMap::new(),
91        });
92    }
93
94    let node_list = graph.bfs_multi_source_graphblas(&seeds, opts.hops, opts.max_nodes)?;
95    let node_set: AHashSet<NodeId> = node_list.into_iter().collect();
96
97    // Keep only scores whose seed node actually appears in the BFS result.
98    // `bfs_multi_source_graphblas` guarantees this when every seed is present in
99    // the CSR snapshot; this retain is a defensive guard to ensure
100    // `scores.keys() ⊆ nodes` even if that invariant is ever broken upstream.
101    scores.retain(|n, _| node_set.contains(n));
102
103    let mut edge_set: AHashSet<EdgeId> = AHashSet::new();
104    for &node in &node_set {
105        for ne in graph.out_neighbors(node)? {
106            if node_set.contains(&ne.node) {
107                edge_set.insert(ne.edge);
108            }
109        }
110    }
111
112    Ok(Subgraph {
113        nodes: node_set.into_iter().collect(),
114        edges: edge_set.into_iter().collect(),
115        scores: scores.into_iter().collect(),
116    })
117}
118
/// How the vector and text rankings are combined into one relevance score.
///
/// Both variants work on a node's rank within each hit list rather than on
/// raw distances or text scores; a list the node is missing from contributes
/// nothing.
#[derive(Debug, Clone)]
pub enum FusionStrategy {
    /// Reciprocal Rank Fusion: score = Σ 1 / (k + rank), with `rank` counted
    /// from 1 within each list.
    /// `k` is a smoothing constant; default 60.
    Rrf { k: u32 },
    /// Weighted linear combination: score = α·vector_score + β·text_score,
    /// where each per-source score is the rank-normalized value
    /// `(n - rank) / n` (0-based `rank`, `n` = the requested hit count for
    /// that source, at least 1).
    WeightedSum {
        vector_weight: f32,
        text_weight: f32,
    },
}

impl Default for FusionStrategy {
    /// Reciprocal Rank Fusion with `k = 60`.
    fn default() -> Self {
        FusionStrategy::Rrf { k: 60 }
    }
}
137
138/// Options for `retrieve_hybrid`.
139pub struct HybridRetrieveOptions {
140    /// Number of seed nodes from the vector search. `0` disables vector search.
141    pub vector_k: usize,
142    /// Number of seed nodes from the text search. `0` disables text search.
143    pub text_k: usize,
144    /// Label to restrict the text search. `None` searches all indexed labels.
145    pub text_label: Option<String>,
146    /// Property to restrict the text search. `None` searches all indexed properties.
147    pub text_property: Option<String>,
148    /// BFS expansion depth from each seed.
149    pub hops: u8,
150    /// Maximum cosine distance for a vector hit to qualify as a seed.
151    pub max_distance: f32,
152    /// Hard cap on total subgraph nodes.
153    pub max_nodes: Option<usize>,
154    /// If set, only nodes with this label qualify as vector-search seeds.
155    pub vector_label: Option<String>,
156    /// Score fusion strategy.
157    pub fusion: FusionStrategy,
158}
159
160impl Default for HybridRetrieveOptions {
161    fn default() -> Self {
162        Self {
163            vector_k: 10,
164            text_k: 10,
165            text_label: None,
166            text_property: None,
167            hops: 2,
168            max_distance: f32::MAX,
169            max_nodes: None,
170            vector_label: None,
171            fusion: FusionStrategy::default(),
172        }
173    }
174}
175
176/// Hybrid retrieval: merges vector search seeds with full-text search seeds,
177/// fuses their scores using `opts.fusion`, then expands via BFS.
178///
179/// Vector search is run when `opts.vector_k > 0` and `q` is non-empty.
180/// Text search is run when `opts.text_k > 0` and `text_query` is non-empty.
181/// Both may run simultaneously; their ranked lists are merged before BFS.
182pub fn retrieve_hybrid(
183    graph: &Graph,
184    q: &[f32],
185    text_query: &str,
186    opts: &HybridRetrieveOptions,
187) -> Result<Subgraph, RetrievalError> {
188    // ---- collect vector hits -----------------------------------------------
189    let mut vec_ranks: AHashMap<NodeId, usize> = AHashMap::new();
190    let mut vec_scores: AHashMap<NodeId, f32> = AHashMap::new();
191
192    if opts.vector_k > 0 && !q.is_empty() {
193        let hits = graph.vector_search_with(
194            q,
195            &VectorSearchOptions {
196                k: opts.vector_k,
197                label: opts.vector_label.clone(),
198                properties: None,
199                rescore_factor: None,
200            },
201        )?;
202        for (rank, hit) in hits.iter().enumerate() {
203            if hit.distance <= opts.max_distance {
204                vec_ranks.insert(hit.node, rank);
205                vec_scores.insert(hit.node, hit.distance);
206            }
207        }
208    }
209
210    // ---- collect text hits -------------------------------------------------
211    let mut text_ranks: AHashMap<NodeId, usize> = AHashMap::new();
212
213    if opts.text_k > 0 && !text_query.is_empty() {
214        let text_opts = TextSearchOptions {
215            label: opts.text_label.clone(),
216            property: opts.text_property.clone(),
217            limit: opts.text_k,
218            ..Default::default()
219        };
220        let text_hits = graph.text_search(text_query, &text_opts)?;
221        for (rank, hit) in text_hits.iter().enumerate() {
222            text_ranks.insert(hit.node, rank);
223        }
224    }
225
226    // ---- fuse scores -------------------------------------------------------
227    let mut fused: AHashMap<NodeId, f32> = AHashMap::new();
228
229    let all_nodes: AHashSet<NodeId> = vec_ranks.keys().chain(text_ranks.keys()).copied().collect();
230
231    for node in &all_nodes {
232        let score = match &opts.fusion {
233            FusionStrategy::Rrf { k } => {
234                let kf = *k as f32;
235                let vs = vec_ranks
236                    .get(node)
237                    .map(|r| 1.0 / (kf + *r as f32 + 1.0))
238                    .unwrap_or(0.0);
239                let ts = text_ranks
240                    .get(node)
241                    .map(|r| 1.0 / (kf + *r as f32 + 1.0))
242                    .unwrap_or(0.0);
243                vs + ts
244            }
245            FusionStrategy::WeightedSum {
246                vector_weight,
247                text_weight,
248            } => {
249                let total_vec = opts.vector_k.max(1) as f32;
250                let total_txt = opts.text_k.max(1) as f32;
251                let vs = vec_ranks
252                    .get(node)
253                    .map(|r| (total_vec - *r as f32) / total_vec)
254                    .unwrap_or(0.0);
255                let ts = text_ranks
256                    .get(node)
257                    .map(|r| (total_txt - *r as f32) / total_txt)
258                    .unwrap_or(0.0);
259                vector_weight * vs + text_weight * ts
260            }
261        };
262        fused.insert(*node, score);
263    }
264
265    let seeds: Vec<NodeId> = fused.keys().copied().collect();
266
267    if seeds.is_empty() {
268        return Ok(Subgraph {
269            nodes: Vec::new(),
270            edges: Vec::new(),
271            scores: HashMap::new(),
272        });
273    }
274
275    // ---- BFS expansion -----------------------------------------------------
276    let node_list = graph.bfs_multi_source_graphblas(&seeds, opts.hops, opts.max_nodes)?;
277    let node_set: AHashSet<NodeId> = node_list.into_iter().collect();
278
279    let mut scores: AHashMap<NodeId, f32> = fused;
280    scores.retain(|n, _| node_set.contains(n));
281
282    let mut edge_set: AHashSet<EdgeId> = AHashSet::new();
283    for &node in &node_set {
284        for ne in graph.out_neighbors(node)? {
285            if node_set.contains(&ne.node) {
286                edge_set.insert(ne.edge);
287            }
288        }
289    }
290
291    Ok(Subgraph {
292        nodes: node_set.into_iter().collect(),
293        edges: edge_set.into_iter().collect(),
294        scores: scores.into_iter().collect(),
295    })
296}
297
298#[cfg(test)]
299mod tests {
300    use serde_json::json;
301    use tempfile::TempDir;
302
303    use super::*;
304
305    fn open_tmp() -> (TempDir, Graph) {
306        let dir = TempDir::new().unwrap();
307        let g = Graph::open(dir.path(), 1).unwrap();
308        (dir, g)
309    }
310
311    #[test]
312    fn retrieve_empty_vector_index_returns_empty_subgraph() {
313        let (_dir, g) = open_tmp();
314        let sub = retrieve(&g, &[1.0f32, 0.0], 5, 2).unwrap();
315        assert!(sub.nodes.is_empty());
316        assert!(sub.edges.is_empty());
317        assert!(sub.scores.is_empty());
318    }
319
320    #[test]
321    fn retrieve_hops_zero_returns_only_seed_nodes() {
322        let (_dir, g) = open_tmp();
323        let a = g.add_node("N", &json!({})).unwrap();
324        let b = g.add_node("N", &json!({})).unwrap();
325        let c = g.add_node("N", &json!({})).unwrap();
326        g.upsert_vector(a, &[1.0f32, 0.0, 0.0]).unwrap();
327        g.upsert_vector(b, &[0.0f32, 1.0, 0.0]).unwrap();
328        g.add_edge(a, c, "E", &json!({})).unwrap();
329
330        // hops=0: no BFS expansion; c is only reachable via a's out-edge.
331        let sub = retrieve(&g, &[1.0f32, 0.0, 0.0], 1, 0).unwrap();
332        assert_eq!(sub.nodes.len(), 1);
333        assert_eq!(sub.nodes[0], a);
334        assert!(!sub.nodes.contains(&c));
335    }
336
337    #[test]
338    fn retrieve_expands_bfs_to_correct_depth() {
339        let (_dir, g) = open_tmp();
340        // Chain: a to b to c to d; only a has a vector.
341        let a = g.add_node("N", &json!({})).unwrap();
342        let b = g.add_node("N", &json!({})).unwrap();
343        let c = g.add_node("N", &json!({})).unwrap();
344        let d = g.add_node("N", &json!({})).unwrap();
345        g.upsert_vector(a, &[1.0f32, 0.0]).unwrap();
346        g.add_edge(a, b, "E", &json!({})).unwrap();
347        g.add_edge(b, c, "E", &json!({})).unwrap();
348        g.add_edge(c, d, "E", &json!({})).unwrap();
349
350        let sub1 = retrieve(&g, &[1.0f32, 0.0], 1, 1).unwrap();
351        let sub2 = retrieve(&g, &[1.0f32, 0.0], 1, 2).unwrap();
352
353        let mut n1 = sub1.nodes.clone();
354        n1.sort_unstable();
355        assert_eq!(n1, vec![a, b]);
356
357        let mut n2 = sub2.nodes.clone();
358        n2.sort_unstable();
359        assert_eq!(n2, vec![a, b, c]);
360    }
361
362    #[test]
363    fn retrieve_subgraph_edges_connect_only_nodes_in_set() {
364        let (_dir, g) = open_tmp();
365        // a to b to c; only a and b are in the subgraph (hops=1 from a).
366        let a = g.add_node("N", &json!({})).unwrap();
367        let b = g.add_node("N", &json!({})).unwrap();
368        let c = g.add_node("N", &json!({})).unwrap();
369        g.upsert_vector(a, &[1.0f32, 0.0]).unwrap();
370        let e_ab = g.add_edge(a, b, "E", &json!({})).unwrap();
371        let _e_bc = g.add_edge(b, c, "E", &json!({})).unwrap();
372
373        let sub = retrieve(&g, &[1.0f32, 0.0], 1, 1).unwrap();
374        assert!(sub.edges.contains(&e_ab));
375        // b to c edge must NOT appear: c is outside the 1-hop subgraph.
376        assert_eq!(sub.edges.len(), 1);
377    }
378
379    #[test]
380    fn retrieve_scores_map_contains_seed_distances() {
381        let (_dir, g) = open_tmp();
382        let a = g.add_node("N", &json!({})).unwrap();
383        g.upsert_vector(a, &[1.0f32, 0.0]).unwrap();
384
385        let sub = retrieve(&g, &[1.0f32, 0.0], 1, 0).unwrap();
386        assert!(sub.scores.contains_key(&a));
387        assert!(sub.scores[&a] < 1e-5);
388    }
389
390    #[test]
391    fn retrieve_with_max_distance_filters_far_seeds() {
392        let (_dir, g) = open_tmp();
393        let a = g.add_node("N", &json!({})).unwrap();
394        let b = g.add_node("N", &json!({})).unwrap();
395        // a is at distance ~0 from the query; b is orthogonal (distance ~1).
396        g.upsert_vector(a, &[1.0f32, 0.0, 0.0]).unwrap();
397        g.upsert_vector(b, &[0.0f32, 1.0, 0.0]).unwrap();
398
399        let sub = retrieve_with(
400            &g,
401            &[1.0f32, 0.0, 0.0],
402            &RetrieveOptions {
403                k: 2,
404                hops: 0,
405                max_distance: 0.1,
406                max_nodes: None,
407            },
408        )
409        .unwrap();
410
411        // Only a is within 0.1 cosine distance of the query.
412        assert_eq!(sub.nodes.len(), 1);
413        assert_eq!(sub.nodes[0], a);
414    }
415
416    #[test]
417    fn retrieve_with_max_nodes_caps_subgraph() {
418        let (_dir, g) = open_tmp();
419        // Star: a to b, c, d, e
420        let a = g.add_node("N", &json!({})).unwrap();
421        let b = g.add_node("N", &json!({})).unwrap();
422        let c = g.add_node("N", &json!({})).unwrap();
423        let d = g.add_node("N", &json!({})).unwrap();
424        let e = g.add_node("N", &json!({})).unwrap();
425        g.upsert_vector(a, &[1.0f32, 0.0]).unwrap();
426        g.add_edge(a, b, "E", &json!({})).unwrap();
427        g.add_edge(a, c, "E", &json!({})).unwrap();
428        g.add_edge(a, d, "E", &json!({})).unwrap();
429        g.add_edge(a, e, "E", &json!({})).unwrap();
430
431        let sub = retrieve_with(
432            &g,
433            &[1.0f32, 0.0],
434            &RetrieveOptions {
435                k: 1,
436                hops: 1,
437                max_distance: f32::MAX,
438                max_nodes: Some(3),
439            },
440        )
441        .unwrap();
442
443        assert!(sub.nodes.len() <= 3);
444    }
445
446    #[test]
447    fn retrieve_with_multiple_seeds_each_expand_independently() {
448        let (_dir, g) = open_tmp();
449        // Two disconnected chains: a to b to c; d to e to f
450        // Both a and d have vectors and qualify as seeds.
451        // With hops=1 the subgraph must include {a, b, d, e} but not {c, f}.
452        // With hops=2 it must include all six nodes.
453        let a = g.add_node("N", &json!({})).unwrap();
454        let b = g.add_node("N", &json!({})).unwrap();
455        let c = g.add_node("N", &json!({})).unwrap();
456        let d = g.add_node("N", &json!({})).unwrap();
457        let e = g.add_node("N", &json!({})).unwrap();
458        let f = g.add_node("N", &json!({})).unwrap();
459        g.upsert_vector(a, &[1.0f32, 0.0, 0.0]).unwrap();
460        g.upsert_vector(d, &[0.0f32, 1.0, 0.0]).unwrap();
461        g.add_edge(a, b, "E", &json!({})).unwrap();
462        g.add_edge(b, c, "E", &json!({})).unwrap();
463        g.add_edge(d, e, "E", &json!({})).unwrap();
464        g.add_edge(e, f, "E", &json!({})).unwrap();
465
466        let sub1 = retrieve_with(
467            &g,
468            &[1.0f32, 0.0, 0.0],
469            &RetrieveOptions {
470                k: 2,
471                hops: 1,
472                max_distance: f32::MAX,
473                max_nodes: None,
474            },
475        )
476        .unwrap();
477        let mut n1 = sub1.nodes.clone();
478        n1.sort_unstable();
479        assert!(n1.contains(&a), "seed a must be present at hops=1");
480        assert!(n1.contains(&b), "b is 1 hop from seed a");
481        assert!(n1.contains(&d), "seed d must be present at hops=1");
482        assert!(n1.contains(&e), "e is 1 hop from seed d");
483        assert!(!n1.contains(&c), "c is 2 hops from a, out of range");
484        assert!(!n1.contains(&f), "f is 2 hops from d, out of range");
485        assert_eq!(n1.len(), 4);
486
487        let sub2 = retrieve_with(
488            &g,
489            &[1.0f32, 0.0, 0.0],
490            &RetrieveOptions {
491                k: 2,
492                hops: 2,
493                max_distance: f32::MAX,
494                max_nodes: None,
495            },
496        )
497        .unwrap();
498        assert_eq!(sub2.nodes.len(), 6, "all six nodes reachable within 2 hops");
499        assert!(sub2.scores.contains_key(&a));
500        assert!(sub2.scores.contains_key(&d));
501    }
502
503    // --- retrieve_with (GraphBLAS) ---
504    //
505    // Each test calls `rebuild_csr()` after graph mutations so the GraphBLAS
506    // adjacency matrix is current before retrieve_with is invoked.
507
508    #[test]
509    fn graphblas_retrieve_k_hop_expansion() {
510        let (_dir, g) = open_tmp();
511        let a = g.add_node("N", &json!({})).unwrap();
512        let b = g.add_node("N", &json!({})).unwrap();
513        g.upsert_vector(a, &[1.0f32, 0.0]).unwrap();
514        g.add_edge(a, b, "E", &json!({})).unwrap();
515        g.rebuild_csr().unwrap();
516
517        let sub = retrieve_with(
518            &g,
519            &[1.0f32, 0.0],
520            &RetrieveOptions {
521                k: 1,
522                hops: 1,
523                max_distance: f32::MAX,
524                max_nodes: None,
525            },
526        )
527        .unwrap();
528
529        assert_eq!(sub.nodes.len(), 2);
530        assert!(sub.nodes.contains(&a));
531        assert!(sub.nodes.contains(&b));
532    }
533
534    #[test]
535    fn graphblas_retrieve_hops_zero_returns_only_seed() {
536        let (_dir, g) = open_tmp();
537        let a = g.add_node("N", &json!({})).unwrap();
538        let b = g.add_node("N", &json!({})).unwrap();
539        g.upsert_vector(a, &[1.0f32, 0.0]).unwrap();
540        g.add_edge(a, b, "E", &json!({})).unwrap();
541        g.rebuild_csr().unwrap();
542
543        let sub = retrieve_with(
544            &g,
545            &[1.0f32, 0.0],
546            &RetrieveOptions {
547                k: 1,
548                hops: 0,
549                max_distance: f32::MAX,
550                max_nodes: None,
551            },
552        )
553        .unwrap();
554
555        assert_eq!(sub.nodes, vec![a]);
556        assert!(sub.edges.is_empty(), "no edges when hops=0");
557    }
558
559    #[test]
560    fn graphblas_retrieve_scores_keys_are_subset_of_nodes() {
561        let (_dir, g) = open_tmp();
562        let a = g.add_node("N", &json!({})).unwrap();
563        let b = g.add_node("N", &json!({})).unwrap();
564        let c = g.add_node("N", &json!({})).unwrap();
565        g.upsert_vector(a, &[1.0f32, 0.0, 0.0]).unwrap();
566        g.upsert_vector(b, &[0.9f32, 0.1, 0.0]).unwrap();
567        g.add_edge(a, c, "E", &json!({})).unwrap();
568        g.rebuild_csr().unwrap();
569
570        let sub = retrieve_with(
571            &g,
572            &[1.0f32, 0.0, 0.0],
573            &RetrieveOptions {
574                k: 2,
575                hops: 1,
576                max_distance: f32::MAX,
577                max_nodes: None,
578            },
579        )
580        .unwrap();
581
582        // Every key in scores must be present in nodes.
583        for node_id in sub.scores.keys() {
584            assert!(
585                sub.nodes.contains(node_id),
586                "scores key {node_id:?} is absent from nodes"
587            );
588        }
589    }
590
591    #[test]
592    fn graphblas_retrieve_edges_connect_only_nodes_in_subgraph() {
593        let (_dir, g) = open_tmp();
594        // Chain: a to b to c to d; seed is a (hops=1 includes {a, b}).
595        let a = g.add_node("N", &json!({})).unwrap();
596        let b = g.add_node("N", &json!({})).unwrap();
597        let c = g.add_node("N", &json!({})).unwrap();
598        let d = g.add_node("N", &json!({})).unwrap();
599        g.upsert_vector(a, &[1.0f32, 0.0]).unwrap();
600        let e_ab = g.add_edge(a, b, "E", &json!({})).unwrap();
601        let _e_bc = g.add_edge(b, c, "E", &json!({})).unwrap();
602        g.add_edge(c, d, "E", &json!({})).unwrap();
603        g.rebuild_csr().unwrap();
604
605        let sub = retrieve_with(
606            &g,
607            &[1.0f32, 0.0],
608            &RetrieveOptions {
609                k: 1,
610                hops: 1,
611                max_distance: f32::MAX,
612                max_nodes: None,
613            },
614        )
615        .unwrap();
616
617        assert!(sub.nodes.contains(&a));
618        assert!(sub.nodes.contains(&b));
619        assert!(!sub.nodes.contains(&c));
620        assert!(sub.edges.contains(&e_ab), "edge a to b must be in subgraph");
621        assert_eq!(
622            sub.edges.len(),
623            1,
624            "only a to b is within the 1-hop subgraph"
625        );
626    }
627
628    #[test]
629    fn graphblas_retrieve_max_distance_filters_far_seeds() {
630        let (_dir, g) = open_tmp();
631        let a = g.add_node("N", &json!({})).unwrap();
632        let b = g.add_node("N", &json!({})).unwrap();
633        // a is close to query; b is orthogonal (distance ~1).
634        g.upsert_vector(a, &[1.0f32, 0.0, 0.0]).unwrap();
635        g.upsert_vector(b, &[0.0f32, 1.0, 0.0]).unwrap();
636        g.rebuild_csr().unwrap();
637
638        let sub = retrieve_with(
639            &g,
640            &[1.0f32, 0.0, 0.0],
641            &RetrieveOptions {
642                k: 2,
643                hops: 0,
644                max_distance: 0.1,
645                max_nodes: None,
646            },
647        )
648        .unwrap();
649
650        assert_eq!(sub.nodes.len(), 1);
651        assert_eq!(sub.nodes[0], a);
652        assert!(sub.scores.contains_key(&a));
653        assert!(!sub.scores.contains_key(&b));
654    }
655
656    #[test]
657    fn graphblas_retrieve_max_nodes_caps_subgraph() {
658        let (_dir, g) = open_tmp();
659        // Star: a to b, c, d, e
660        let a = g.add_node("N", &json!({})).unwrap();
661        let b = g.add_node("N", &json!({})).unwrap();
662        let c = g.add_node("N", &json!({})).unwrap();
663        let d = g.add_node("N", &json!({})).unwrap();
664        let e = g.add_node("N", &json!({})).unwrap();
665        g.upsert_vector(a, &[1.0f32, 0.0]).unwrap();
666        g.add_edge(a, b, "E", &json!({})).unwrap();
667        g.add_edge(a, c, "E", &json!({})).unwrap();
668        g.add_edge(a, d, "E", &json!({})).unwrap();
669        g.add_edge(a, e, "E", &json!({})).unwrap();
670        g.rebuild_csr().unwrap();
671
672        let sub = retrieve_with(
673            &g,
674            &[1.0f32, 0.0],
675            &RetrieveOptions {
676                k: 1,
677                hops: 1,
678                max_distance: f32::MAX,
679                max_nodes: Some(3),
680            },
681        )
682        .unwrap();
683
684        assert!(
685            sub.nodes.len() <= 3,
686            "expected at most 3 nodes, got {}",
687            sub.nodes.len()
688        );
689    }
690
691    #[test]
692    fn graphblas_retrieve_scores_contain_seed_distances() {
693        let (_dir, g) = open_tmp();
694        let a = g.add_node("N", &json!({})).unwrap();
695        g.upsert_vector(a, &[1.0f32, 0.0]).unwrap();
696        g.rebuild_csr().unwrap();
697
698        let sub = retrieve_with(
699            &g,
700            &[1.0f32, 0.0],
701            &RetrieveOptions {
702                k: 1,
703                hops: 0,
704                max_distance: f32::MAX,
705                max_nodes: None,
706            },
707        )
708        .unwrap();
709
710        assert!(sub.scores.contains_key(&a));
711        assert!(
712            sub.scores[&a] < 1e-5,
713            "distance to identical vector must be ~0"
714        );
715    }
716
717    #[test]
718    fn graphblas_retrieve_empty_vector_index_returns_empty() {
719        let (_dir, g) = open_tmp();
720        g.rebuild_csr().unwrap();
721
722        let sub = retrieve_with(&g, &[1.0f32, 0.0], &RetrieveOptions::default()).unwrap();
723
724        assert!(sub.nodes.is_empty());
725        assert!(sub.edges.is_empty());
726        assert!(sub.scores.is_empty());
727    }
728
729    #[test]
730    fn graphblas_retrieve_multiple_seeds_each_expand_independently() {
731        let (_dir, g) = open_tmp();
732        // Mirrors the non-graphblas variant: two disconnected chains
733        // a to b to c; d to e to f, with vectors on a and d.
734        let a = g.add_node("N", &json!({})).unwrap();
735        let b = g.add_node("N", &json!({})).unwrap();
736        let c = g.add_node("N", &json!({})).unwrap();
737        let d = g.add_node("N", &json!({})).unwrap();
738        let e = g.add_node("N", &json!({})).unwrap();
739        let f = g.add_node("N", &json!({})).unwrap();
740        g.upsert_vector(a, &[1.0f32, 0.0, 0.0]).unwrap();
741        g.upsert_vector(d, &[0.0f32, 1.0, 0.0]).unwrap();
742        g.add_edge(a, b, "E", &json!({})).unwrap();
743        g.add_edge(b, c, "E", &json!({})).unwrap();
744        g.add_edge(d, e, "E", &json!({})).unwrap();
745        g.add_edge(e, f, "E", &json!({})).unwrap();
746        g.rebuild_csr().unwrap();
747
748        let sub1 = retrieve_with(
749            &g,
750            &[1.0f32, 0.0, 0.0],
751            &RetrieveOptions {
752                k: 2,
753                hops: 1,
754                max_distance: f32::MAX,
755                max_nodes: None,
756            },
757        )
758        .unwrap();
759        assert!(sub1.nodes.contains(&a), "seed a must be present at hops=1");
760        assert!(sub1.nodes.contains(&b), "b is 1 hop from seed a");
761        assert!(sub1.nodes.contains(&d), "seed d must be present at hops=1");
762        assert!(sub1.nodes.contains(&e), "e is 1 hop from seed d");
763        assert!(!sub1.nodes.contains(&c), "c is 2 hops from a, out of range");
764        assert!(!sub1.nodes.contains(&f), "f is 2 hops from d, out of range");
765        assert_eq!(sub1.nodes.len(), 4);
766
767        let sub2 = retrieve_with(
768            &g,
769            &[1.0f32, 0.0, 0.0],
770            &RetrieveOptions {
771                k: 2,
772                hops: 2,
773                max_distance: f32::MAX,
774                max_nodes: None,
775            },
776        )
777        .unwrap();
778        assert_eq!(sub2.nodes.len(), 6, "all six nodes reachable within 2 hops");
779        assert!(sub2.scores.contains_key(&a));
780        assert!(sub2.scores.contains_key(&d));
781    }
782
783    #[test]
784    fn hybrid_retrieve_vector_only_matches_pure_vector_search() {
785        let (_dir, g) = open_tmp();
786        let a = g.add_node("N", &json!({})).unwrap();
787        let b = g.add_node("N", &json!({})).unwrap();
788        g.upsert_vector(a, &[1.0f32, 0.0, 0.0]).unwrap();
789        g.upsert_vector(b, &[0.0f32, 1.0, 0.0]).unwrap();
790        g.rebuild_csr().unwrap();
791
792        let sub = retrieve_hybrid(
793            &g,
794            &[1.0f32, 0.0, 0.0],
795            "",
796            &HybridRetrieveOptions {
797                vector_k: 1,
798                text_k: 0,
799                hops: 0,
800                ..Default::default()
801            },
802        )
803        .unwrap();
804        assert_eq!(sub.nodes.len(), 1);
805        assert_eq!(sub.nodes[0], a);
806    }
807
808    #[test]
809    fn hybrid_retrieve_fuses_both_sources() {
810        let (_dir, g) = open_tmp();
811        let a = g
812            .add_node("Doc", &json!({"body": "rust graph database storage"}))
813            .unwrap();
814        let b = g
815            .add_node("Doc", &json!({"body": "vector search nearest neighbor"}))
816            .unwrap();
817        g.upsert_vector(a, &[1.0f32, 0.0]).unwrap();
818        g.upsert_vector(b, &[0.0f32, 1.0]).unwrap();
819        g.update(|txn| txn.create_node_text_index("Doc", "body"))
820            .unwrap();
821        g.rebuild_csr().unwrap();
822
823        // b has text match for "vector"; a has vector match for [1, 0].
824        let sub = retrieve_hybrid(
825            &g,
826            &[1.0f32, 0.0],
827            "vector",
828            &HybridRetrieveOptions {
829                vector_k: 1,
830                text_k: 1,
831                text_label: Some("Doc".into()),
832                text_property: Some("body".into()),
833                hops: 0,
834                ..Default::default()
835            },
836        )
837        .unwrap();
838        // Both a (vector hit) and b (text hit) should be in the result.
839        assert!(sub.nodes.contains(&a), "vector hit a must be present");
840        assert!(sub.nodes.contains(&b), "text hit b must be present");
841    }
842
843    #[test]
844    fn hybrid_retrieve_weighted_sum_produces_correct_scores() {
845        let (_dir, g) = open_tmp();
846        let a = g.add_node("Doc", &json!({"body": "alpha bravo"})).unwrap();
847        let b = g
848            .add_node("Doc", &json!({"body": "charlie delta"}))
849            .unwrap();
850        g.upsert_vector(a, &[1.0f32, 0.0]).unwrap();
851        g.upsert_vector(b, &[0.0f32, 1.0]).unwrap();
852        g.update(|txn| txn.create_node_text_index("Doc", "body"))
853            .unwrap();
854        g.rebuild_csr().unwrap();
855
856        // a is the top vector hit (rank 0) and b is the top text hit (rank 0).
857        // vector_k=1, text_k=1, so normalized rank score = (k - rank) / k = 1.0.
858        // WeightedSum: score = 0.7 * vec_norm + 0.3 * text_norm.
859        // a: vec_norm = 1.0, text_norm = 0.0 => 0.7
860        // b: vec_norm = 0.0, text_norm = 1.0 => 0.3
861        let sub = retrieve_hybrid(
862            &g,
863            &[1.0f32, 0.0],
864            "charlie",
865            &HybridRetrieveOptions {
866                vector_k: 1,
867                text_k: 1,
868                text_label: Some("Doc".into()),
869                text_property: Some("body".into()),
870                hops: 0,
871                fusion: FusionStrategy::WeightedSum {
872                    vector_weight: 0.7,
873                    text_weight: 0.3,
874                },
875                ..Default::default()
876            },
877        )
878        .unwrap();
879
880        assert!(
881            sub.scores.contains_key(&a),
882            "vector seed a must have a score"
883        );
884        assert!(sub.scores.contains_key(&b), "text seed b must have a score");
885        assert!(
886            (sub.scores[&a] - 0.7).abs() < 1e-5,
887            "a score should be 0.7, got {}",
888            sub.scores[&a]
889        );
890        assert!(
891            (sub.scores[&b] - 0.3).abs() < 1e-5,
892            "b score should be 0.3, got {}",
893            sub.scores[&b]
894        );
895    }
896
897    #[test]
898    fn hybrid_retrieve_text_only_returns_text_seeds() {
899        let (_dir, g) = open_tmp();
900        let a = g
901            .add_node("Doc", &json!({"body": "quantum computing research"}))
902            .unwrap();
903        let b = g
904            .add_node("Doc", &json!({"body": "classical music orchestra"}))
905            .unwrap();
906        g.update(|txn| txn.create_node_text_index("Doc", "body"))
907            .unwrap();
908        g.rebuild_csr().unwrap();
909
910        // vector_k=0 disables vector search; only text seeds are used.
911        let sub = retrieve_hybrid(
912            &g,
913            &[],
914            "quantum",
915            &HybridRetrieveOptions {
916                vector_k: 0,
917                text_k: 5,
918                text_label: Some("Doc".into()),
919                text_property: Some("body".into()),
920                hops: 0,
921                ..Default::default()
922            },
923        )
924        .unwrap();
925
926        assert_eq!(
927            sub.nodes.len(),
928            1,
929            "only the text-matching node should appear"
930        );
931        assert_eq!(sub.nodes[0], a);
932        assert!(sub.scores.contains_key(&a));
933        assert!(!sub.nodes.contains(&b), "non-matching node must be absent");
934    }
935}