use crate::KnowledgeGraph;
use rand::prelude::*;
use rand_xorshift::XorShiftRng;
use std::collections::{BTreeMap, BTreeSet, HashMap};
#[must_use]
pub fn sample_neighbors(
kg: &KnowledgeGraph,
nodes: &[String],
k: usize,
seed: u64,
) -> HashMap<String, Vec<String>> {
let mut rng = XorShiftRng::seed_from_u64(seed);
let graph = kg.as_petgraph();
let mut result = HashMap::with_capacity(nodes.len());
for node_id in nodes {
let entity_id = crate::EntityId::from(node_id.as_str());
let neighbors = kg.get_node_index(&entity_id).map_or_else(Vec::new, |idx| {
let all_neighbors: Vec<String> = graph
.neighbors(idx)
.map(|n_idx| graph[n_idx].id.as_str().to_owned())
.collect();
if all_neighbors.len() <= k {
all_neighbors
} else {
all_neighbors
.choose_multiple(&mut rng, k)
.cloned()
.collect()
}
});
result.insert(node_id.clone(), neighbors);
}
result
}
/// Draws exactly `k` neighbors **with replacement** for each node in `nodes`.
///
/// For every entry of `nodes` the result maps the node id to a vector of neighbor ids:
///
/// * a node that is not present in `kg`, or that has no neighbors, maps to an empty
///   vector;
/// * otherwise the vector holds `k` independent draws from the node's neighbors, so
///   the same neighbor may appear several times.
///
/// A single RNG seeded with `seed` is shared across all nodes, so the output is
/// reproducible for a given `(kg, nodes, k, seed)`. If an id occurs more than once in
/// `nodes`, the entry produced by its last occurrence is the one kept in the map.
#[must_use]
pub fn sample_neighbors_with_replacement(
    kg: &KnowledgeGraph,
    nodes: &[String],
    k: usize,
    seed: u64,
) -> HashMap<String, Vec<String>> {
    let mut rng = XorShiftRng::seed_from_u64(seed);
    let graph = kg.as_petgraph();
    let mut draws_by_node = HashMap::with_capacity(nodes.len());

    for name in nodes {
        let lookup_key = crate::EntityId::from(name.as_str());
        let draws: Vec<String> = match kg.get_node_index(&lookup_key) {
            None => Vec::new(),
            Some(node_idx) => {
                let pool: Vec<_> = graph.neighbors(node_idx).collect();
                if pool.is_empty() {
                    Vec::new()
                } else {
                    let mut drawn = Vec::with_capacity(k);
                    for _ in 0..k {
                        let picked_idx = match pool.choose(&mut rng) {
                            Some(&idx) => idx,
                            None => unreachable!("checked is_empty above"),
                        };
                        // Only the drawn neighbors are converted to owned ids.
                        drawn.push(graph[picked_idx].id.as_str().to_owned());
                    }
                    drawn
                }
            }
        };
        draws_by_node.insert(name.clone(), draws);
    }

    draws_by_node
}
/// A sampled multi-layer subgraph, expressed with batch-local node indices.
#[derive(Debug, Clone)]
pub struct SubgraphBatch {
    /// Ids of the nodes the batch was sampled for.
    pub target_nodes: Vec<String>,
    /// Ids of every node in the batch; a node's position is its batch-local index.
    pub all_nodes: Vec<String>,
    /// Lookup from node id to its position in `all_nodes`.
    pub node_to_idx: HashMap<String, usize>,
    /// One edge list per sampled layer; each edge is a pair of batch-local indices.
    pub edges_per_layer: Vec<Vec<(usize, usize)>>,
    /// Node counts recorded while the batch was built, one entry per stage.
    pub layer_sizes: Vec<usize>,
}

impl SubgraphBatch {
    /// Number of sampled layers, i.e. the number of edge lists in `edges_per_layer`.
    pub fn num_layers(&self) -> usize {
        let layers = &self.edges_per_layer;
        layers.len()
    }

    /// Total number of nodes in the batch (length of `all_nodes`).
    pub fn num_nodes(&self) -> usize {
        let nodes = &self.all_nodes;
        nodes.len()
    }

    /// Yields the range `0..target_nodes.len()`.
    ///
    /// These are the targets' positions in `all_nodes` provided the targets are
    /// distinct and stored at the front of `all_nodes`.
    pub fn target_indices(&self) -> impl Iterator<Item = usize> {
        let target_count = self.target_nodes.len();
        0..target_count
    }
}
/// Layer-by-layer neighbor sampler over a [`KnowledgeGraph`].
///
/// Borrows the graph for `'a`; sampling itself happens in `sample`.
pub struct NeighborSampler<'a> {
/// Graph that neighbors are sampled from.
kg: &'a KnowledgeGraph,
/// One entry per layer: the maximum number of neighbors kept for each frontier
/// node at that layer.
fanout: Vec<usize>,
}
impl<'a> NeighborSampler<'a> {
    /// Creates a sampler over `kg`.
    ///
    /// `fanout[i]` is the maximum number of neighbors kept per frontier node when
    /// expanding layer `i`; the number of entries is the number of layers sampled.
    pub fn new(kg: &'a KnowledgeGraph, fanout: Vec<usize>) -> Self {
        Self { kg, fanout }
    }

    /// Samples a multi-layer subgraph around `seed_nodes`.
    ///
    /// Starting from the seeds, each layer visits every frontier node (in sorted id
    /// order, so the result is reproducible for a given `seed`), keeps all of its
    /// neighbors if there are at most `fanout[layer]` of them, or a random subset of
    /// that size otherwise, and uses the kept neighbors as the next frontier.
    /// Frontier nodes that are not present in the graph contribute no edges.
    ///
    /// In the returned [`SubgraphBatch`]:
    ///
    /// * `all_nodes` starts with the distinct seed nodes in first-occurrence order,
    ///   followed by newly discovered nodes in discovery order;
    /// * `target_nodes` is that same distinct seed list, so
    ///   `target_nodes.len() == layer_sizes[0]` and the targets occupy indices
    ///   `0..target_nodes.len()` of `all_nodes` even if `seed_nodes` has repeats;
    /// * `edges_per_layer[l]` holds `(neighbor_idx, frontier_idx)` pairs of
    ///   batch-local indices for layer `l`;
    /// * `layer_sizes[0]` is the number of distinct seeds and `layer_sizes[l + 1]`
    ///   is `all_nodes.len()` after layer `l` was sampled.
    ///
    /// # Panics
    ///
    /// Panics only if an internal invariant is broken (a frontier node missing from
    /// the index map); every frontier node is registered before it is enqueued.
    pub fn sample(&self, seed_nodes: &[String], seed: u64) -> SubgraphBatch {
        let mut rng = XorShiftRng::seed_from_u64(seed);
        let graph = self.kg.as_petgraph();

        // Register the seeds first so they own the leading batch-local indices.
        let mut all_nodes: Vec<String> = Vec::with_capacity(seed_nodes.len());
        let mut node_to_idx: HashMap<String, usize> = HashMap::with_capacity(seed_nodes.len());
        for node in seed_nodes {
            if !node_to_idx.contains_key(node) {
                node_to_idx.insert(node.clone(), all_nodes.len());
                all_nodes.push(node.clone());
            }
        }

        // Use the de-duplicated seed list as the targets. Returning the raw
        // `seed_nodes` would make `target_nodes.len()` exceed the number of leading
        // seed slots whenever a seed is repeated, so `0..target_nodes.len()` would
        // point at non-target nodes.
        let target_nodes = all_nodes.clone();

        let mut layer_sizes = Vec::with_capacity(self.fanout.len() + 1);
        layer_sizes.push(target_nodes.len());
        let mut edges_per_layer = Vec::with_capacity(self.fanout.len());

        // A sorted set keeps the visiting order (and thus RNG consumption) stable.
        let mut frontier: BTreeSet<String> = target_nodes.iter().cloned().collect();

        for &num_neighbors in &self.fanout {
            let mut layer_edges = Vec::new();
            let mut next_frontier = BTreeSet::new();

            for node_id in &frontier {
                let src_idx = *node_to_idx
                    .get(node_id)
                    .expect("internal: frontier node in index");
                let entity_id = crate::EntityId::from(node_id.as_str());

                if let Some(node_idx) = self.kg.get_node_index(&entity_id) {
                    let all_neighbors: Vec<String> = graph
                        .neighbors(node_idx)
                        .map(|n_idx| graph[n_idx].id.as_str().to_owned())
                        .collect();

                    let sampled: Vec<_> = if all_neighbors.len() <= num_neighbors {
                        all_neighbors
                    } else {
                        all_neighbors
                            .choose_multiple(&mut rng, num_neighbors)
                            .cloned()
                            .collect()
                    };

                    for neighbor in sampled {
                        let dst_idx = if let Some(&idx) = node_to_idx.get(&neighbor) {
                            idx
                        } else {
                            let idx = all_nodes.len();
                            node_to_idx.insert(neighbor.clone(), idx);
                            all_nodes.push(neighbor.clone());
                            idx
                        };
                        // Edge is recorded as (sampled neighbor, frontier node).
                        layer_edges.push((dst_idx, src_idx));
                        next_frontier.insert(neighbor);
                    }
                }
            }

            layer_sizes.push(all_nodes.len());
            edges_per_layer.push(layer_edges);
            frontier = next_frontier;
        }

        SubgraphBatch {
            target_nodes,
            all_nodes,
            node_to_idx,
            edges_per_layer,
            layer_sizes,
        }
    }
}
/// Layer-by-layer neighbor sampler over a heterogeneous graph, with a separate
/// fanout schedule per edge type.
pub struct HeteroNeighborSampler<'a> {
/// Heterogeneous graph that neighbors are sampled from.
kg: &'a crate::HeteroGraph,
/// Per edge type, the number of neighbors to keep at each layer. Stored in a
/// `BTreeMap` so edge types are iterated in a fixed (key-sorted) order.
fanout: BTreeMap<crate::EdgeType, Vec<usize>>,
}
impl<'a> HeteroNeighborSampler<'a> {
pub fn new(kg: &'a crate::HeteroGraph, fanout: HashMap<crate::EdgeType, Vec<usize>>) -> Self {
Self {
kg,
fanout: fanout.into_iter().collect(),
}
}
pub fn sample(
&self,
seed_type: &crate::NodeType,
seed_indices: &[usize],
seed: u64,
) -> HeteroSubgraphBatch {
let mut rng = XorShiftRng::seed_from_u64(seed);
let mut sampled_nodes: HashMap<crate::NodeType, Vec<usize>> = HashMap::new();
sampled_nodes.insert(seed_type.clone(), seed_indices.to_vec());
let mut sampled_edges: HashMap<crate::EdgeType, Vec<(usize, usize)>> = HashMap::new();
let max_layers = self.fanout.values().map(|v| v.len()).max().unwrap_or(0);
for layer in 0..max_layers {
for (edge_type, fanout_layers) in &self.fanout {
let num_sample = match fanout_layers.get(layer) {
Some(&n) => n,
None => continue, };
let mut new_dst_nodes = Vec::new();
let src_nodes: Vec<usize> = sampled_nodes
.get(&edge_type.src_type)
.cloned()
.unwrap_or_default();
let edges = sampled_edges.entry(edge_type.clone()).or_default();
for src_local_idx in src_nodes {
let neighbors = self.kg.neighbors(edge_type, src_local_idx);
let sampled: Vec<_> = if neighbors.len() <= num_sample {
neighbors
} else {
neighbors
.choose_multiple(&mut rng, num_sample)
.copied()
.collect()
};
for dst_local_idx in sampled {
edges.push((src_local_idx, dst_local_idx));
new_dst_nodes.push(dst_local_idx);
}
}
sampled_nodes
.entry(edge_type.dst_type.clone())
.or_default()
.extend(new_dst_nodes);
}
for nodes in sampled_nodes.values_mut() {
nodes.sort_unstable();
nodes.dedup();
}
}
HeteroSubgraphBatch {
seed_type: seed_type.clone(),
seed_indices: seed_indices.to_vec(),
sampled_nodes,
sampled_edges,
}
}
}
/// Result of heterogeneous neighbor sampling: the seeds plus the nodes and edges
/// collected per node type and per edge type.
#[derive(Debug, Clone)]
pub struct HeteroSubgraphBatch {
/// Node type of the seed nodes.
pub seed_type: crate::NodeType,
/// Seed node indices, exactly as passed to the sampler.
pub seed_indices: Vec<usize>,
/// Node indices collected for each node type (seeds included under `seed_type`).
pub sampled_nodes: HashMap<crate::NodeType, Vec<usize>>,
/// Sampled edges for each edge type, stored as `(source index, destination index)`.
pub sampled_edges: HashMap<crate::EdgeType, Vec<(usize, usize)>>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Triple;

    /// With three neighbors and `k = 2`, exactly two known neighbors are returned.
    #[test]
    fn test_sample_neighbors() {
        let mut kg = KnowledgeGraph::new();
        for &(head, tail) in &[("A", "B"), ("A", "C"), ("A", "D")] {
            kg.add_triple(Triple::new(head, "rel", tail));
        }

        let seeds = vec![String::from("A")];
        let result = sample_neighbors(&kg, &seeds, 2, 42);
        let sampled = &result["A"];

        assert_eq!(sampled.len(), 2);
        assert!(sampled
            .iter()
            .all(|s| matches!(s.as_str(), "B" | "C" | "D")));
    }

    /// Asking for more neighbors than exist returns all of them.
    #[test]
    fn test_sample_neighbors_all() {
        let mut kg = KnowledgeGraph::new();
        for &(head, tail) in &[("A", "B"), ("A", "C")] {
            kg.add_triple(Triple::new(head, "rel", tail));
        }

        let seeds = vec![String::from("A")];
        let result = sample_neighbors(&kg, &seeds, 10, 42);

        assert_eq!(result["A"].len(), 2);
    }

    /// A node absent from the graph still gets an (empty) entry.
    #[test]
    fn test_sample_neighbors_missing_node() {
        let kg = KnowledgeGraph::new();

        let seeds = vec![String::from("NotExist")];
        let result = sample_neighbors(&kg, &seeds, 5, 42);

        assert!(result["NotExist"].is_empty());
    }

    /// Sampling with replacement from a single neighbor repeats it `k` times.
    #[test]
    fn test_sample_with_replacement() {
        let mut kg = KnowledgeGraph::new();
        kg.add_triple(Triple::new("A", "rel", "B"));

        let seeds = vec![String::from("A")];
        let result = sample_neighbors_with_replacement(&kg, &seeds, 5, 42);
        let sampled = &result["A"];

        assert_eq!(sampled.len(), 5);
        for s in sampled {
            assert_eq!(s, "B");
        }
    }

    /// One hop with fanout 2 yields the seed plus at least one sampled neighbor.
    #[test]
    fn test_neighbor_sampler_single_hop() {
        let mut kg = KnowledgeGraph::new();
        for &(head, tail) in &[("A", "B"), ("A", "C"), ("A", "D")] {
            kg.add_triple(Triple::new(head, "rel", tail));
        }

        let sampler = NeighborSampler::new(&kg, vec![2]);
        let seeds = vec![String::from("A")];
        let batch = sampler.sample(&seeds, 42);

        assert_eq!(batch.target_nodes, vec!["A"]);
        assert!(batch.all_nodes.iter().any(|n| n == "A"));
        assert_eq!(batch.num_layers(), 1);
        assert!((2..=4).contains(&batch.all_nodes.len()));
    }

    /// Two hops produce two edge layers and reach beyond the seed.
    #[test]
    fn test_neighbor_sampler_multi_hop() {
        let mut kg = KnowledgeGraph::new();
        for &(head, tail) in &[("A", "B"), ("A", "C"), ("B", "D"), ("C", "D")] {
            kg.add_triple(Triple::new(head, "rel", tail));
        }

        let sampler = NeighborSampler::new(&kg, vec![2, 2]);
        let seeds = vec![String::from("A")];
        let batch = sampler.sample(&seeds, 42);

        assert_eq!(batch.num_layers(), 2);
        assert!(batch.all_nodes.len() >= 3);
        assert!(batch.all_nodes.iter().any(|n| n == "A"));
    }
}