use std::collections::HashMap;
use crate::error::RetrievalError;
use ahash::{AHashMap, AHashSet};
use issundb_core::{EdgeId, Graph, NodeId};
use issundb_text::{TextGraphExt, TextSearchOptions};
use issundb_vector::{VectorGraphExt, VectorSearchOptions};
/// A neighbourhood of the graph produced by [`retrieve`], [`retrieve_with`]
/// or [`retrieve_hybrid`].
///
/// `Subgraph::default()` is the empty subgraph (no nodes, edges or scores).
#[derive(Default)]
pub struct Subgraph {
    /// Every node in the subgraph: the surviving seeds plus the nodes reached
    /// by the BFS expansion.
    pub nodes: Vec<NodeId>,
    /// Out-edges whose source and target are both members of `nodes`.
    pub edges: Vec<EdgeId>,
    /// Score of each seed that is also present in `nodes`.
    ///
    /// The meaning depends on the producer. For `retrieve` / `retrieve_with`
    /// it is the raw vector distance (lower is closer). For `retrieve_hybrid`
    /// it is the fused relevance score (higher is better). Nodes reached only
    /// through expansion have no entry.
    pub scores: HashMap<NodeId, f32>,
}
/// Tuning knobs for [`retrieve_with`].
#[derive(Debug, Clone, PartialEq)]
pub struct RetrieveOptions {
    /// Number of nearest neighbours requested from the vector index.
    pub k: usize,
    /// Number of BFS hops expanded from the seeds; `0` keeps only the seeds.
    pub hops: u8,
    /// Vector hits whose distance exceeds this value are discarded before
    /// expansion (a NaN distance never passes).
    pub max_distance: f32,
    /// Optional upper bound on the number of nodes, forwarded to the BFS
    /// expansion.
    pub max_nodes: Option<usize>,
}
impl Default for RetrieveOptions {
fn default() -> Self {
Self {
k: 10,
hops: 2,
max_distance: f32::MAX,
max_nodes: None,
}
}
}
pub fn retrieve(graph: &Graph, q: &[f32], k: usize, hops: u8) -> Result<Subgraph, RetrievalError> {
retrieve_with(
graph,
q,
&RetrieveOptions {
k,
hops,
..Default::default()
},
)
}
/// Retrieves the subgraph surrounding the nearest vector-search hits for `q`.
///
/// The top `opts.k` vector hits whose distance is `<= opts.max_distance`
/// become BFS seeds. The result contains:
///
/// * `nodes`: every node reached from the seeds within `opts.hops` hops
///   (bounded by `opts.max_nodes` when set). They appear in the order the
///   BFS expansion returned them, with duplicates removed (first occurrence
///   wins). The output order therefore no longer depends on hash-set
///   iteration order.
/// * `edges`: out-edges whose source and target are both in `nodes`. Their
///   order is unspecified.
/// * `scores`: the vector distance of each seed that survived the expansion
///   (lower is closer).
///
/// If no hit passes the distance filter, an empty subgraph is returned
/// without running the expansion.
///
/// # Errors
///
/// Propagates any [`RetrievalError`] from the vector search, the BFS
/// expansion or the neighbour lookup.
pub fn retrieve_with(
    graph: &Graph,
    q: &[f32],
    opts: &RetrieveOptions,
) -> Result<Subgraph, RetrievalError> {
    let hits = graph.vector_search(q, opts.k)?;
    let mut scores: AHashMap<NodeId, f32> = AHashMap::new();
    let mut seeds = Vec::new();
    for hit in &hits {
        // `<=` is false for NaN, so hits with a NaN distance are dropped too.
        if hit.distance <= opts.max_distance {
            scores.insert(hit.node, hit.distance);
            seeds.push(hit.node);
        }
    }
    if seeds.is_empty() {
        return Ok(Subgraph {
            nodes: Vec::new(),
            edges: Vec::new(),
            scores: HashMap::new(),
        });
    }

    let node_list = graph.bfs_multi_source_graphblas(&seeds, opts.hops, opts.max_nodes)?;
    // Dedupe while keeping the expansion's order. The set gives O(1)
    // membership tests; the Vec keeps a stable output order.
    let mut node_set: AHashSet<NodeId> = AHashSet::new();
    let mut nodes: Vec<NodeId> = Vec::new();
    for node in node_list {
        if node_set.insert(node) {
            nodes.push(node);
        }
    }

    // A seed can fall outside the subgraph when `max_nodes` truncates the
    // expansion; its score must go too.
    scores.retain(|n, _| node_set.contains(n));

    // Keep only edges that stay inside the subgraph.
    let mut edge_set: AHashSet<EdgeId> = AHashSet::new();
    for &node in &nodes {
        for ne in graph.out_neighbors(node)? {
            if node_set.contains(&ne.node) {
                edge_set.insert(ne.edge);
            }
        }
    }

    Ok(Subgraph {
        nodes,
        edges: edge_set.into_iter().collect(),
        scores: scores.into_iter().collect(),
    })
}
/// How [`retrieve_hybrid`] merges the vector and text result lists into one
/// score per node.
///
/// Both strategies work on 0-based ranks within each list, not on raw
/// distances or text relevance. A list that does not contain a node
/// contributes `0` for that node. Higher fused scores are better.
#[derive(Debug, Clone, PartialEq)]
pub enum FusionStrategy {
    /// Reciprocal-rank fusion: each list contributes `1 / (k + rank + 1)`.
    ///
    /// A larger `k` shrinks the gap between top-ranked and lower-ranked hits.
    Rrf {
        /// Smoothing constant added to every rank.
        k: u32,
    },
    /// Weighted blend of rank-normalised scores.
    ///
    /// Each list contributes `weight * (list_k - rank) / list_k`, where
    /// `list_k` is the requested result count (`vector_k` or `text_k`,
    /// clamped to at least 1).
    WeightedSum {
        /// Multiplier for the vector-list contribution.
        vector_weight: f32,
        /// Multiplier for the text-list contribution.
        text_weight: f32,
    },
}
impl Default for FusionStrategy {
fn default() -> Self {
Self::Rrf { k: 60 }
}
}
/// Tuning knobs for [`retrieve_hybrid`].
#[derive(Debug, Clone)]
pub struct HybridRetrieveOptions {
    /// Number of vector hits requested; `0` disables the vector side.
    pub vector_k: usize,
    /// Maximum number of text hits requested; `0` disables the text side.
    pub text_k: usize,
    /// Forwarded as the text search's `label` option.
    pub text_label: Option<String>,
    /// Forwarded as the text search's `property` option.
    pub text_property: Option<String>,
    /// Number of BFS hops expanded from the fused seeds.
    pub hops: u8,
    /// Vector hits farther than this are dropped. Text hits are not affected.
    pub max_distance: f32,
    /// Optional upper bound on the number of nodes, forwarded to the BFS
    /// expansion.
    pub max_nodes: Option<usize>,
    /// Forwarded as the vector search's `label` option.
    pub vector_label: Option<String>,
    /// How vector and text ranks are combined into one score.
    pub fusion: FusionStrategy,
}
impl Default for HybridRetrieveOptions {
fn default() -> Self {
Self {
vector_k: 10,
text_k: 10,
text_label: None,
text_property: None,
hops: 2,
max_distance: f32::MAX,
max_nodes: None,
vector_label: None,
fusion: FusionStrategy::default(),
}
}
}
pub fn retrieve_hybrid(
graph: &Graph,
q: &[f32],
text_query: &str,
opts: &HybridRetrieveOptions,
) -> Result<Subgraph, RetrievalError> {
let mut vec_ranks: AHashMap<NodeId, usize> = AHashMap::new();
let mut vec_scores: AHashMap<NodeId, f32> = AHashMap::new();
if opts.vector_k > 0 && !q.is_empty() {
let hits = graph.vector_search_with(
q,
&VectorSearchOptions {
k: opts.vector_k,
label: opts.vector_label.clone(),
properties: None,
rescore_factor: None,
},
)?;
for (rank, hit) in hits.iter().enumerate() {
if hit.distance <= opts.max_distance {
vec_ranks.insert(hit.node, rank);
vec_scores.insert(hit.node, hit.distance);
}
}
}
let mut text_ranks: AHashMap<NodeId, usize> = AHashMap::new();
if opts.text_k > 0 && !text_query.is_empty() {
let text_opts = TextSearchOptions {
label: opts.text_label.clone(),
property: opts.text_property.clone(),
limit: opts.text_k,
..Default::default()
};
let text_hits = graph.text_search(text_query, &text_opts)?;
for (rank, hit) in text_hits.iter().enumerate() {
text_ranks.insert(hit.node, rank);
}
}
let mut fused: AHashMap<NodeId, f32> = AHashMap::new();
let all_nodes: AHashSet<NodeId> = vec_ranks.keys().chain(text_ranks.keys()).copied().collect();
for node in &all_nodes {
let score = match &opts.fusion {
FusionStrategy::Rrf { k } => {
let kf = *k as f32;
let vs = vec_ranks
.get(node)
.map(|r| 1.0 / (kf + *r as f32 + 1.0))
.unwrap_or(0.0);
let ts = text_ranks
.get(node)
.map(|r| 1.0 / (kf + *r as f32 + 1.0))
.unwrap_or(0.0);
vs + ts
}
FusionStrategy::WeightedSum {
vector_weight,
text_weight,
} => {
let total_vec = opts.vector_k.max(1) as f32;
let total_txt = opts.text_k.max(1) as f32;
let vs = vec_ranks
.get(node)
.map(|r| (total_vec - *r as f32) / total_vec)
.unwrap_or(0.0);
let ts = text_ranks
.get(node)
.map(|r| (total_txt - *r as f32) / total_txt)
.unwrap_or(0.0);
vector_weight * vs + text_weight * ts
}
};
fused.insert(*node, score);
}
let seeds: Vec<NodeId> = fused.keys().copied().collect();
if seeds.is_empty() {
return Ok(Subgraph {
nodes: Vec::new(),
edges: Vec::new(),
scores: HashMap::new(),
});
}
let node_list = graph.bfs_multi_source_graphblas(&seeds, opts.hops, opts.max_nodes)?;
let node_set: AHashSet<NodeId> = node_list.into_iter().collect();
let mut scores: AHashMap<NodeId, f32> = fused;
scores.retain(|n, _| node_set.contains(n));
let mut edge_set: AHashSet<EdgeId> = AHashSet::new();
for &node in &node_set {
for ne in graph.out_neighbors(node)? {
if node_set.contains(&ne.node) {
edge_set.insert(ne.edge);
}
}
}
Ok(Subgraph {
nodes: node_set.into_iter().collect(),
edges: edge_set.into_iter().collect(),
scores: scores.into_iter().collect(),
})
}
#[cfg(test)]
mod tests {
    use serde_json::json;
    use tempfile::TempDir;

    use super::*;

    /// Opens an empty graph in a fresh temporary directory.
    ///
    /// The returned `TempDir` must be kept alive for the duration of the test;
    /// dropping it deletes the directory.
    fn open_tmp() -> (TempDir, Graph) {
        let dir = TempDir::new().unwrap();
        let g = Graph::open(dir.path(), 1).unwrap();
        (dir, g)
    }

    #[test]
    fn retrieve_empty_vector_index_returns_empty_subgraph() {
        let (_dir, g) = open_tmp();
        let sub = retrieve(&g, &[1.0f32, 0.0], 5, 2).unwrap();
        assert!(sub.nodes.is_empty());
        assert!(sub.edges.is_empty());
        assert!(sub.scores.is_empty());
    }

    #[test]
    fn retrieve_hops_zero_returns_only_seed_nodes() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("N", &json!({})).unwrap();
        let b = g.add_node("N", &json!({})).unwrap();
        let c = g.add_node("N", &json!({})).unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0, 0.0]).unwrap();
        g.upsert_vector(b, &[0.0f32, 1.0, 0.0]).unwrap();
        g.add_edge(a, c, "E", &json!({})).unwrap();
        let sub = retrieve(&g, &[1.0f32, 0.0, 0.0], 1, 0).unwrap();
        assert_eq!(sub.nodes.len(), 1);
        assert_eq!(sub.nodes[0], a);
        assert!(!sub.nodes.contains(&c));
    }

    #[test]
    fn retrieve_expands_bfs_to_correct_depth() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("N", &json!({})).unwrap();
        let b = g.add_node("N", &json!({})).unwrap();
        let c = g.add_node("N", &json!({})).unwrap();
        let d = g.add_node("N", &json!({})).unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0]).unwrap();
        g.add_edge(a, b, "E", &json!({})).unwrap();
        g.add_edge(b, c, "E", &json!({})).unwrap();
        g.add_edge(c, d, "E", &json!({})).unwrap();
        let sub1 = retrieve(&g, &[1.0f32, 0.0], 1, 1).unwrap();
        let sub2 = retrieve(&g, &[1.0f32, 0.0], 1, 2).unwrap();
        let mut n1 = sub1.nodes.clone();
        n1.sort_unstable();
        assert_eq!(n1, vec![a, b]);
        let mut n2 = sub2.nodes.clone();
        n2.sort_unstable();
        assert_eq!(n2, vec![a, b, c]);
    }

    #[test]
    fn retrieve_subgraph_edges_connect_only_nodes_in_set() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("N", &json!({})).unwrap();
        let b = g.add_node("N", &json!({})).unwrap();
        let c = g.add_node("N", &json!({})).unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0]).unwrap();
        let e_ab = g.add_edge(a, b, "E", &json!({})).unwrap();
        let _e_bc = g.add_edge(b, c, "E", &json!({})).unwrap();
        let sub = retrieve(&g, &[1.0f32, 0.0], 1, 1).unwrap();
        assert!(sub.edges.contains(&e_ab));
        assert_eq!(sub.edges.len(), 1);
    }

    #[test]
    fn retrieve_scores_map_contains_seed_distances() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("N", &json!({})).unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0]).unwrap();
        let sub = retrieve(&g, &[1.0f32, 0.0], 1, 0).unwrap();
        assert!(sub.scores.contains_key(&a));
        assert!(sub.scores[&a] < 1e-5);
    }

    #[test]
    fn retrieve_with_max_distance_filters_far_seeds() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("N", &json!({})).unwrap();
        let b = g.add_node("N", &json!({})).unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0, 0.0]).unwrap();
        g.upsert_vector(b, &[0.0f32, 1.0, 0.0]).unwrap();
        let sub = retrieve_with(
            &g,
            &[1.0f32, 0.0, 0.0],
            &RetrieveOptions {
                k: 2,
                hops: 0,
                max_distance: 0.1,
                max_nodes: None,
            },
        )
        .unwrap();
        assert_eq!(sub.nodes.len(), 1);
        assert_eq!(sub.nodes[0], a);
    }

    #[test]
    fn retrieve_with_max_nodes_caps_subgraph() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("N", &json!({})).unwrap();
        let b = g.add_node("N", &json!({})).unwrap();
        let c = g.add_node("N", &json!({})).unwrap();
        let d = g.add_node("N", &json!({})).unwrap();
        let e = g.add_node("N", &json!({})).unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0]).unwrap();
        g.add_edge(a, b, "E", &json!({})).unwrap();
        g.add_edge(a, c, "E", &json!({})).unwrap();
        g.add_edge(a, d, "E", &json!({})).unwrap();
        g.add_edge(a, e, "E", &json!({})).unwrap();
        let sub = retrieve_with(
            &g,
            &[1.0f32, 0.0],
            &RetrieveOptions {
                k: 1,
                hops: 1,
                max_distance: f32::MAX,
                max_nodes: Some(3),
            },
        )
        .unwrap();
        assert!(sub.nodes.len() <= 3);
    }

    #[test]
    fn retrieve_with_multiple_seeds_each_expand_independently() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("N", &json!({})).unwrap();
        let b = g.add_node("N", &json!({})).unwrap();
        let c = g.add_node("N", &json!({})).unwrap();
        let d = g.add_node("N", &json!({})).unwrap();
        let e = g.add_node("N", &json!({})).unwrap();
        let f = g.add_node("N", &json!({})).unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0, 0.0]).unwrap();
        g.upsert_vector(d, &[0.0f32, 1.0, 0.0]).unwrap();
        g.add_edge(a, b, "E", &json!({})).unwrap();
        g.add_edge(b, c, "E", &json!({})).unwrap();
        g.add_edge(d, e, "E", &json!({})).unwrap();
        g.add_edge(e, f, "E", &json!({})).unwrap();
        let sub1 = retrieve_with(
            &g,
            &[1.0f32, 0.0, 0.0],
            &RetrieveOptions {
                k: 2,
                hops: 1,
                max_distance: f32::MAX,
                max_nodes: None,
            },
        )
        .unwrap();
        let mut n1 = sub1.nodes.clone();
        n1.sort_unstable();
        assert!(n1.contains(&a), "seed a must be present at hops=1");
        assert!(n1.contains(&b), "b is 1 hop from seed a");
        assert!(n1.contains(&d), "seed d must be present at hops=1");
        assert!(n1.contains(&e), "e is 1 hop from seed d");
        assert!(!n1.contains(&c), "c is 2 hops from a, out of range");
        assert!(!n1.contains(&f), "f is 2 hops from d, out of range");
        assert_eq!(n1.len(), 4);
        let sub2 = retrieve_with(
            &g,
            &[1.0f32, 0.0, 0.0],
            &RetrieveOptions {
                k: 2,
                hops: 2,
                max_distance: f32::MAX,
                max_nodes: None,
            },
        )
        .unwrap();
        assert_eq!(sub2.nodes.len(), 6, "all six nodes reachable within 2 hops");
        assert!(sub2.scores.contains_key(&a));
        assert!(sub2.scores.contains_key(&d));
    }

    // The `graphblas_*` tests repeat the scenarios above but call
    // `rebuild_csr()` before retrieving. Presumably this exercises the
    // CSR-backed expansion path; confirm against `Graph::rebuild_csr`.

    #[test]
    fn graphblas_retrieve_k_hop_expansion() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("N", &json!({})).unwrap();
        let b = g.add_node("N", &json!({})).unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0]).unwrap();
        g.add_edge(a, b, "E", &json!({})).unwrap();
        g.rebuild_csr().unwrap();
        let sub = retrieve_with(
            &g,
            &[1.0f32, 0.0],
            &RetrieveOptions {
                k: 1,
                hops: 1,
                max_distance: f32::MAX,
                max_nodes: None,
            },
        )
        .unwrap();
        assert_eq!(sub.nodes.len(), 2);
        assert!(sub.nodes.contains(&a));
        assert!(sub.nodes.contains(&b));
    }

    #[test]
    fn graphblas_retrieve_hops_zero_returns_only_seed() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("N", &json!({})).unwrap();
        let b = g.add_node("N", &json!({})).unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0]).unwrap();
        g.add_edge(a, b, "E", &json!({})).unwrap();
        g.rebuild_csr().unwrap();
        let sub = retrieve_with(
            &g,
            &[1.0f32, 0.0],
            &RetrieveOptions {
                k: 1,
                hops: 0,
                max_distance: f32::MAX,
                max_nodes: None,
            },
        )
        .unwrap();
        assert_eq!(sub.nodes, vec![a]);
        assert!(sub.edges.is_empty(), "no edges when hops=0");
    }

    #[test]
    fn graphblas_retrieve_scores_keys_are_subset_of_nodes() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("N", &json!({})).unwrap();
        let b = g.add_node("N", &json!({})).unwrap();
        let c = g.add_node("N", &json!({})).unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0, 0.0]).unwrap();
        g.upsert_vector(b, &[0.9f32, 0.1, 0.0]).unwrap();
        g.add_edge(a, c, "E", &json!({})).unwrap();
        g.rebuild_csr().unwrap();
        let sub = retrieve_with(
            &g,
            &[1.0f32, 0.0, 0.0],
            &RetrieveOptions {
                k: 2,
                hops: 1,
                max_distance: f32::MAX,
                max_nodes: None,
            },
        )
        .unwrap();
        for node_id in sub.scores.keys() {
            assert!(
                sub.nodes.contains(node_id),
                "scores key {node_id:?} is absent from nodes"
            );
        }
    }

    #[test]
    fn graphblas_retrieve_edges_connect_only_nodes_in_subgraph() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("N", &json!({})).unwrap();
        let b = g.add_node("N", &json!({})).unwrap();
        let c = g.add_node("N", &json!({})).unwrap();
        let d = g.add_node("N", &json!({})).unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0]).unwrap();
        let e_ab = g.add_edge(a, b, "E", &json!({})).unwrap();
        let _e_bc = g.add_edge(b, c, "E", &json!({})).unwrap();
        g.add_edge(c, d, "E", &json!({})).unwrap();
        g.rebuild_csr().unwrap();
        let sub = retrieve_with(
            &g,
            &[1.0f32, 0.0],
            &RetrieveOptions {
                k: 1,
                hops: 1,
                max_distance: f32::MAX,
                max_nodes: None,
            },
        )
        .unwrap();
        assert!(sub.nodes.contains(&a));
        assert!(sub.nodes.contains(&b));
        assert!(!sub.nodes.contains(&c));
        assert!(sub.edges.contains(&e_ab), "edge a to b must be in subgraph");
        assert_eq!(
            sub.edges.len(),
            1,
            "only a to b is within the 1-hop subgraph"
        );
    }

    #[test]
    fn graphblas_retrieve_max_distance_filters_far_seeds() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("N", &json!({})).unwrap();
        let b = g.add_node("N", &json!({})).unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0, 0.0]).unwrap();
        g.upsert_vector(b, &[0.0f32, 1.0, 0.0]).unwrap();
        g.rebuild_csr().unwrap();
        let sub = retrieve_with(
            &g,
            &[1.0f32, 0.0, 0.0],
            &RetrieveOptions {
                k: 2,
                hops: 0,
                max_distance: 0.1,
                max_nodes: None,
            },
        )
        .unwrap();
        assert_eq!(sub.nodes.len(), 1);
        assert_eq!(sub.nodes[0], a);
        assert!(sub.scores.contains_key(&a));
        assert!(!sub.scores.contains_key(&b));
    }

    #[test]
    fn graphblas_retrieve_max_nodes_caps_subgraph() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("N", &json!({})).unwrap();
        let b = g.add_node("N", &json!({})).unwrap();
        let c = g.add_node("N", &json!({})).unwrap();
        let d = g.add_node("N", &json!({})).unwrap();
        let e = g.add_node("N", &json!({})).unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0]).unwrap();
        g.add_edge(a, b, "E", &json!({})).unwrap();
        g.add_edge(a, c, "E", &json!({})).unwrap();
        g.add_edge(a, d, "E", &json!({})).unwrap();
        g.add_edge(a, e, "E", &json!({})).unwrap();
        g.rebuild_csr().unwrap();
        let sub = retrieve_with(
            &g,
            &[1.0f32, 0.0],
            &RetrieveOptions {
                k: 1,
                hops: 1,
                max_distance: f32::MAX,
                max_nodes: Some(3),
            },
        )
        .unwrap();
        assert!(
            sub.nodes.len() <= 3,
            "expected at most 3 nodes, got {}",
            sub.nodes.len()
        );
    }

    #[test]
    fn graphblas_retrieve_scores_contain_seed_distances() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("N", &json!({})).unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0]).unwrap();
        g.rebuild_csr().unwrap();
        let sub = retrieve_with(
            &g,
            &[1.0f32, 0.0],
            &RetrieveOptions {
                k: 1,
                hops: 0,
                max_distance: f32::MAX,
                max_nodes: None,
            },
        )
        .unwrap();
        assert!(sub.scores.contains_key(&a));
        assert!(
            sub.scores[&a] < 1e-5,
            "distance to identical vector must be ~0"
        );
    }

    #[test]
    fn graphblas_retrieve_empty_vector_index_returns_empty() {
        let (_dir, g) = open_tmp();
        g.rebuild_csr().unwrap();
        let sub = retrieve_with(&g, &[1.0f32, 0.0], &RetrieveOptions::default()).unwrap();
        assert!(sub.nodes.is_empty());
        assert!(sub.edges.is_empty());
        assert!(sub.scores.is_empty());
    }

    #[test]
    fn graphblas_retrieve_multiple_seeds_each_expand_independently() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("N", &json!({})).unwrap();
        let b = g.add_node("N", &json!({})).unwrap();
        let c = g.add_node("N", &json!({})).unwrap();
        let d = g.add_node("N", &json!({})).unwrap();
        let e = g.add_node("N", &json!({})).unwrap();
        let f = g.add_node("N", &json!({})).unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0, 0.0]).unwrap();
        g.upsert_vector(d, &[0.0f32, 1.0, 0.0]).unwrap();
        g.add_edge(a, b, "E", &json!({})).unwrap();
        g.add_edge(b, c, "E", &json!({})).unwrap();
        g.add_edge(d, e, "E", &json!({})).unwrap();
        g.add_edge(e, f, "E", &json!({})).unwrap();
        g.rebuild_csr().unwrap();
        let sub1 = retrieve_with(
            &g,
            &[1.0f32, 0.0, 0.0],
            &RetrieveOptions {
                k: 2,
                hops: 1,
                max_distance: f32::MAX,
                max_nodes: None,
            },
        )
        .unwrap();
        assert!(sub1.nodes.contains(&a), "seed a must be present at hops=1");
        assert!(sub1.nodes.contains(&b), "b is 1 hop from seed a");
        assert!(sub1.nodes.contains(&d), "seed d must be present at hops=1");
        assert!(sub1.nodes.contains(&e), "e is 1 hop from seed d");
        assert!(!sub1.nodes.contains(&c), "c is 2 hops from a, out of range");
        assert!(!sub1.nodes.contains(&f), "f is 2 hops from d, out of range");
        assert_eq!(sub1.nodes.len(), 4);
        let sub2 = retrieve_with(
            &g,
            &[1.0f32, 0.0, 0.0],
            &RetrieveOptions {
                k: 2,
                hops: 2,
                max_distance: f32::MAX,
                max_nodes: None,
            },
        )
        .unwrap();
        assert_eq!(sub2.nodes.len(), 6, "all six nodes reachable within 2 hops");
        assert!(sub2.scores.contains_key(&a));
        assert!(sub2.scores.contains_key(&d));
    }

    #[test]
    fn hybrid_retrieve_vector_only_matches_pure_vector_search() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("N", &json!({})).unwrap();
        let b = g.add_node("N", &json!({})).unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0, 0.0]).unwrap();
        g.upsert_vector(b, &[0.0f32, 1.0, 0.0]).unwrap();
        g.rebuild_csr().unwrap();
        let sub = retrieve_hybrid(
            &g,
            &[1.0f32, 0.0, 0.0],
            "",
            &HybridRetrieveOptions {
                vector_k: 1,
                text_k: 0,
                hops: 0,
                ..Default::default()
            },
        )
        .unwrap();
        assert_eq!(sub.nodes.len(), 1);
        assert_eq!(sub.nodes[0], a);
    }

    #[test]
    fn hybrid_retrieve_fuses_both_sources() {
        let (_dir, g) = open_tmp();
        let a = g
            .add_node("Doc", &json!({"body": "rust graph database storage"}))
            .unwrap();
        let b = g
            .add_node("Doc", &json!({"body": "vector search nearest neighbor"}))
            .unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0]).unwrap();
        g.upsert_vector(b, &[0.0f32, 1.0]).unwrap();
        g.update(|txn| txn.create_node_text_index("Doc", "body"))
            .unwrap();
        g.rebuild_csr().unwrap();
        let sub = retrieve_hybrid(
            &g,
            &[1.0f32, 0.0],
            "vector",
            &HybridRetrieveOptions {
                vector_k: 1,
                text_k: 1,
                text_label: Some("Doc".into()),
                text_property: Some("body".into()),
                hops: 0,
                ..Default::default()
            },
        )
        .unwrap();
        assert!(sub.nodes.contains(&a), "vector hit a must be present");
        assert!(sub.nodes.contains(&b), "text hit b must be present");
    }

    #[test]
    fn hybrid_retrieve_weighted_sum_produces_correct_scores() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("Doc", &json!({"body": "alpha bravo"})).unwrap();
        let b = g
            .add_node("Doc", &json!({"body": "charlie delta"}))
            .unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0]).unwrap();
        g.upsert_vector(b, &[0.0f32, 1.0]).unwrap();
        g.update(|txn| txn.create_node_text_index("Doc", "body"))
            .unwrap();
        g.rebuild_csr().unwrap();
        let sub = retrieve_hybrid(
            &g,
            &[1.0f32, 0.0],
            "charlie",
            &HybridRetrieveOptions {
                vector_k: 1,
                text_k: 1,
                text_label: Some("Doc".into()),
                text_property: Some("body".into()),
                hops: 0,
                fusion: FusionStrategy::WeightedSum {
                    vector_weight: 0.7,
                    text_weight: 0.3,
                },
                ..Default::default()
            },
        )
        .unwrap();
        assert!(
            sub.scores.contains_key(&a),
            "vector seed a must have a score"
        );
        assert!(sub.scores.contains_key(&b), "text seed b must have a score");
        assert!(
            (sub.scores[&a] - 0.7).abs() < 1e-5,
            "a score should be 0.7, got {}",
            sub.scores[&a]
        );
        assert!(
            (sub.scores[&b] - 0.3).abs() < 1e-5,
            "b score should be 0.3, got {}",
            sub.scores[&b]
        );
    }

    #[test]
    fn hybrid_retrieve_text_only_returns_text_seeds() {
        let (_dir, g) = open_tmp();
        let a = g
            .add_node("Doc", &json!({"body": "quantum computing research"}))
            .unwrap();
        let b = g
            .add_node("Doc", &json!({"body": "classical music orchestra"}))
            .unwrap();
        g.update(|txn| txn.create_node_text_index("Doc", "body"))
            .unwrap();
        g.rebuild_csr().unwrap();
        let sub = retrieve_hybrid(
            &g,
            &[],
            "quantum",
            &HybridRetrieveOptions {
                vector_k: 0,
                text_k: 5,
                text_label: Some("Doc".into()),
                text_property: Some("body".into()),
                hops: 0,
                ..Default::default()
            },
        )
        .unwrap();
        assert_eq!(
            sub.nodes.len(),
            1,
            "only the text-matching node should appear"
        );
        assert_eq!(sub.nodes[0], a);
        assert!(sub.scores.contains_key(&a));
        assert!(!sub.nodes.contains(&b), "non-matching node must be absent");
    }

    /// Default RRF (k = 60): a node ranked first in exactly one list scores
    /// 1 / (60 + 0 + 1).
    #[test]
    fn hybrid_retrieve_rrf_scores_single_list_hits() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("Doc", &json!({"body": "alpha bravo"})).unwrap();
        let b = g
            .add_node("Doc", &json!({"body": "charlie delta"}))
            .unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0]).unwrap();
        g.upsert_vector(b, &[0.0f32, 1.0]).unwrap();
        g.update(|txn| txn.create_node_text_index("Doc", "body"))
            .unwrap();
        g.rebuild_csr().unwrap();
        let sub = retrieve_hybrid(
            &g,
            &[1.0f32, 0.0],
            "charlie",
            &HybridRetrieveOptions {
                vector_k: 1,
                text_k: 1,
                text_label: Some("Doc".into()),
                text_property: Some("body".into()),
                hops: 0,
                ..Default::default()
            },
        )
        .unwrap();
        let expected = 1.0f32 / 61.0;
        assert!(sub.scores.contains_key(&a), "vector seed a must have a score");
        assert!(sub.scores.contains_key(&b), "text seed b must have a score");
        assert!(
            (sub.scores[&a] - expected).abs() < 1e-6,
            "a score should be {expected}, got {}",
            sub.scores[&a]
        );
        assert!(
            (sub.scores[&b] - expected).abs() < 1e-6,
            "b score should be {expected}, got {}",
            sub.scores[&b]
        );
    }

    /// A node ranked first by both retrievers accumulates both RRF terms.
    #[test]
    fn hybrid_retrieve_rrf_sums_contributions_from_both_lists() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("Doc", &json!({"body": "alpha bravo"})).unwrap();
        let b = g
            .add_node("Doc", &json!({"body": "charlie delta"}))
            .unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0]).unwrap();
        g.upsert_vector(b, &[0.0f32, 1.0]).unwrap();
        g.update(|txn| txn.create_node_text_index("Doc", "body"))
            .unwrap();
        g.rebuild_csr().unwrap();
        let sub = retrieve_hybrid(
            &g,
            &[1.0f32, 0.0],
            "alpha",
            &HybridRetrieveOptions {
                vector_k: 1,
                text_k: 1,
                text_label: Some("Doc".into()),
                text_property: Some("body".into()),
                hops: 0,
                ..Default::default()
            },
        )
        .unwrap();
        assert_eq!(sub.nodes, vec![a]);
        assert!(!sub.scores.contains_key(&b));
        let expected = 2.0f32 / 61.0;
        assert!(
            (sub.scores[&a] - expected).abs() < 1e-6,
            "a score should be {expected}, got {}",
            sub.scores[&a]
        );
    }

    #[test]
    fn hybrid_retrieve_max_distance_filters_far_vector_hits() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("N", &json!({})).unwrap();
        let b = g.add_node("N", &json!({})).unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0, 0.0]).unwrap();
        g.upsert_vector(b, &[0.0f32, 1.0, 0.0]).unwrap();
        g.rebuild_csr().unwrap();
        let sub = retrieve_hybrid(
            &g,
            &[1.0f32, 0.0, 0.0],
            "",
            &HybridRetrieveOptions {
                vector_k: 2,
                text_k: 0,
                hops: 0,
                max_distance: 0.1,
                ..Default::default()
            },
        )
        .unwrap();
        assert_eq!(sub.nodes, vec![a]);
        assert!(sub.scores.contains_key(&a));
        assert!(!sub.scores.contains_key(&b));
    }

    #[test]
    fn hybrid_retrieve_with_no_queries_returns_empty_subgraph() {
        let (_dir, g) = open_tmp();
        let a = g.add_node("N", &json!({})).unwrap();
        g.upsert_vector(a, &[1.0f32, 0.0]).unwrap();
        let sub = retrieve_hybrid(&g, &[], "", &HybridRetrieveOptions::default()).unwrap();
        assert!(sub.nodes.is_empty());
        assert!(sub.edges.is_empty());
        assert!(sub.scores.is_empty());
    }
}