use crate::graph::schema::{ConnectivityTriple, DirGraph, GraphBackend, InternedKey};
use crate::graph::storage::GraphRead;
use std::collections::{HashMap, HashSet};
use super::{NeighborConnection, NeighborsSchema};
/// Edge tally keyed by `(source node type, connection type, target node type)`.
type CountMap = HashMap<(InternedKey, InternedKey, InternedKey), usize>;
/// Counts the edges of `graph` per `(source type, connection type, target type)`
/// combination and returns one [`ConnectivityTriple`] per combination.
///
/// Disk-backed graphs are tallied by the parallel edge scan; every other
/// backend goes through the serial edge visitor.
///
/// The result is ordered by descending `count`. Ties are broken by `src`,
/// then `conn`, then `tgt`, so the output is identical from run to run
/// instead of depending on `HashMap` iteration order.
pub fn compute_type_connectivity(graph: &DirGraph) -> Vec<ConnectivityTriple> {
    let backend = &graph.graph;
    let counts: CountMap = match backend {
        GraphBackend::Disk(dg) => compute_disk_parallel(dg),
        _ => compute_serial(backend),
    };

    // Turn interned keys back into their string form for the public result type.
    let mut triples: Vec<ConnectivityTriple> = counts
        .into_iter()
        .map(|((sk, ck, tk), count)| ConnectivityTriple {
            src: graph.interner.resolve(sk).to_string(),
            conn: graph.interner.resolve(ck).to_string(),
            tgt: graph.interner.resolve(tk).to_string(),
            count,
        })
        .collect();

    // The comparator is a total order over the fields, so an unstable sort is
    // safe and still yields a deterministic sequence.
    triples.sort_unstable_by(|a, b| {
        b.count
            .cmp(&a.count)
            .then_with(|| a.src.cmp(&b.src))
            .then_with(|| a.conn.cmp(&b.conn))
            .then_with(|| a.tgt.cmp(&b.tgt))
    });
    triples
}
/// Tallies edges per `(source type, connection type, target type)` for a
/// disk-backed graph by scanning the edge-endpoint table in parallel shards.
///
/// Only the first `min(next_edge_idx, edge_endpoints.len())` slots are read.
/// Slots whose `source` equals `TOMBSTONE_EDGE` are skipped, as are edges
/// where either endpoint has no node type. Each shard builds its own map and
/// the shard maps are summed on the calling thread afterwards.
fn compute_disk_parallel(dg: &crate::graph::storage::disk::graph::DiskGraph) -> CountMap {
    use crate::graph::storage::disk::csr::TOMBSTONE_EDGE;
    use petgraph::graph::NodeIndex;
    use rayon::prelude::*;

    // Lower bound on the number of edge slots handled by one shard.
    const MIN_SHARD_LEN: usize = 1 << 20;

    let edge_total = (dg.next_edge_idx as usize).min(dg.edge_endpoints.len());
    if edge_total == 0 {
        return HashMap::new();
    }

    // NOTE(review): presumably an access-pattern hint to the backing storage — confirm.
    dg.edge_endpoints.advise_sequential();

    // Split `0..edge_total` into contiguous `[start, end)` shards.
    let shard_len = (edge_total / rayon::current_num_threads().max(1)).max(MIN_SHARD_LEN);
    let mut shards: Vec<(usize, usize)> = Vec::new();
    let mut start = 0;
    while start < edge_total {
        let end = (start + shard_len).min(edge_total);
        shards.push((start, end));
        start = end;
    }

    let per_shard: Vec<CountMap> = shards
        .into_par_iter()
        .map(|(first, last)| {
            let mut tally: CountMap = HashMap::new();
            for slot in first..last {
                let endpoints = dg.edge_endpoints.get(slot);
                if endpoints.source != TOMBSTONE_EDGE {
                    let from = NodeIndex::new(endpoints.source as usize);
                    let to = NodeIndex::new(endpoints.target as usize);
                    let types = dg.node_type_of(from).zip(dg.node_type_of(to));
                    if let Some((from_type, to_type)) = types {
                        let conn = InternedKey::from_u64(endpoints.connection_type);
                        *tally.entry((from_type, conn, to_type)).or_insert(0) += 1;
                    }
                }
            }
            tally
        })
        .collect();

    // NOTE(review): presumably tells the storage the scanned pages are no longer needed — confirm.
    dg.edge_endpoints.advise_dontneed();

    // Sum the per-shard tallies into one map.
    let mut merged: CountMap = HashMap::new();
    for (key, hits) in per_shard.into_iter().flatten() {
        *merged.entry(key).or_insert(0) += hits;
    }
    merged
}
/// Tallies edges per `(source type, connection type, target type)` by
/// visiting every edge the backend reports through
/// `for_each_edge_endpoint_key`.
///
/// Edges where either endpoint has no node type are not counted.
fn compute_serial(backend: &GraphBackend) -> CountMap {
    let mut tally: CountMap = HashMap::new();
    backend.for_each_edge_endpoint_key(|from, to, conn| {
        // Both lookups run before the check, matching a tuple pattern match.
        let types = backend.node_type_of(from).zip(backend.node_type_of(to));
        if let Some((from_type, to_type)) = types {
            *tally.entry((from_type, conn, to_type)).or_insert(0) += 1;
        }
    });
    tally
}
/// Builds the neighbor schema of `node_type` from precomputed triples.
///
/// `outgoing` holds one entry per triple whose `src` is `node_type` (the
/// other side is the triple's `tgt`); `incoming` holds one entry per triple
/// whose `tgt` is `node_type` (the other side is the triple's `src`). A
/// triple that starts and ends at `node_type` appears in both lists.
///
/// Each list is ordered by descending `count`; entries with equal counts keep
/// the relative order they had in `triples`.
pub fn neighbors_from_triples(triples: &[ConnectivityTriple], node_type: &str) -> NeighborsSchema {
    let mut outgoing: Vec<NeighborConnection> = triples
        .iter()
        .filter(|triple| triple.src == node_type)
        .map(|triple| NeighborConnection {
            connection_type: triple.conn.clone(),
            other_type: triple.tgt.clone(),
            count: triple.count,
        })
        .collect();
    let mut incoming: Vec<NeighborConnection> = triples
        .iter()
        .filter(|triple| triple.tgt == node_type)
        .map(|triple| NeighborConnection {
            connection_type: triple.conn.clone(),
            other_type: triple.src.clone(),
            count: triple.count,
        })
        .collect();

    // Stable sorts, largest count first.
    outgoing.sort_by(|a, b| b.count.cmp(&a.count));
    incoming.sort_by(|a, b| b.count.cmp(&a.count));

    NeighborsSchema { outgoing, incoming }
}
/// Lookup table from a node type name to that type's neighbor schema.
pub struct TypeConnectivityIndex {
    // Node type name -> its outgoing and incoming connections.
    index: HashMap<String, NeighborsSchema>,
}
impl TypeConnectivityIndex {
pub fn from_triples(triples: &[ConnectivityTriple]) -> Self {
let mut out_map: HashMap<String, Vec<NeighborConnection>> = HashMap::new();
let mut in_map: HashMap<String, Vec<NeighborConnection>> = HashMap::new();
for t in triples {
out_map
.entry(t.src.clone())
.or_default()
.push(NeighborConnection {
connection_type: t.conn.clone(),
other_type: t.tgt.clone(),
count: t.count,
});
in_map
.entry(t.tgt.clone())
.or_default()
.push(NeighborConnection {
connection_type: t.conn.clone(),
other_type: t.src.clone(),
count: t.count,
});
}
let all_types: HashSet<String> = out_map.keys().chain(in_map.keys()).cloned().collect();
let mut index = HashMap::with_capacity(all_types.len());
for nt in all_types {
let mut outgoing = out_map.remove(&nt).unwrap_or_default();
outgoing.sort_by_key(|o| std::cmp::Reverse(o.count));
let mut incoming = in_map.remove(&nt).unwrap_or_default();
incoming.sort_by_key(|i| std::cmp::Reverse(i.count));
index.insert(nt, NeighborsSchema { outgoing, incoming });
}
TypeConnectivityIndex { index }
}
pub fn get(&self, node_type: &str) -> NeighborsSchema {
self.index
.get(node_type)
.cloned()
.unwrap_or(NeighborsSchema {
outgoing: Vec::new(),
incoming: Vec::new(),
})
}
}
/// Per-connection-type statistics derived from connectivity triples.
pub struct DerivedEdgeStats {
    /// Connection type -> sum of the `count` of every triple with that connection type.
    pub counts: HashMap<String, usize>,
    /// Connection type -> (set of source types, set of target types) seen with it.
    pub endpoints: HashMap<String, (HashSet<String>, HashSet<String>)>,
}
/// Aggregates connectivity triples by connection type.
///
/// For each connection type the result holds the sum of the triples' counts
/// and the distinct source and target types that appear with it.
pub fn derive_edge_counts_from_triples(triples: &[ConnectivityTriple]) -> DerivedEdgeStats {
    let mut counts = HashMap::<String, usize>::new();
    let mut endpoints = HashMap::<String, (HashSet<String>, HashSet<String>)>::new();

    for triple in triples {
        let conn = &triple.conn;
        *counts.entry(conn.clone()).or_default() += triple.count;

        let (sources, targets) = endpoints.entry(conn.clone()).or_default();
        sources.insert(triple.src.clone());
        targets.insert(triple.tgt.clone());
    }

    DerivedEdgeStats { counts, endpoints }
}