use super::common::GraphView;
use std::collections::HashSet;
use rayon::prelude::*;
/// Node count at or above which [`count_triangles`] switches to the rayon-parallel path.
/// Smaller graphs are counted sequentially, where rayon's scheduling overhead would
/// likely dominate the per-node work.
const TRIANGLE_PARALLEL_THRESHOLD: usize = 1000;

/// Counts the triangles in `view`, treating the graph as undirected.
///
/// A node's neighbourhood is the union of its successors and predecessors. Duplicate
/// edges, and edges present in both directions, are therefore counted once. Self-loops
/// never contribute to a triangle.
///
/// Each triangle `{u, v, w}` is counted exactly once, from its lowest-numbered node,
/// with `u < v < w`.
///
/// Every node's set of higher-numbered neighbours is built once up front. This takes
/// O(n + m) extra memory. When `view.node_count >= TRIANGLE_PARALLEL_THRESHOLD`, both
/// that preprocessing and the counting run in parallel with rayon.
///
/// # Panics
///
/// Panics if any neighbour id returned by `successors`/`predecessors` is
/// `>= view.node_count`, because such an id is used to index the precomputed sets.
pub fn count_triangles(view: &GraphView) -> usize {
    let n = view.node_count;
    let parallel = n >= TRIANGLE_PARALLEL_THRESHOLD;

    // higher[u] = distinct neighbours of u with an id strictly greater than u.
    let higher: Vec<HashSet<usize>> = if parallel {
        (0..n)
            .into_par_iter()
            .map(|u| triangle_higher_neighbors(view, u))
            .collect()
    } else {
        (0..n).map(|u| triangle_higher_neighbors(view, u)).collect()
    };

    if parallel {
        (0..n)
            .into_par_iter()
            .map(|u| triangles_rooted_at(&higher, u))
            .sum()
    } else {
        (0..n).map(|u| triangles_rooted_at(&higher, u)).sum()
    }
}

/// Returns the distinct neighbours of `u`, in either edge direction, whose id is
/// greater than `u`.
fn triangle_higher_neighbors(view: &GraphView, u: usize) -> HashSet<usize> {
    view.successors(u)
        .iter()
        .chain(view.predecessors(u).iter())
        .copied()
        // Dropping v <= u excludes self-loops and keeps the u < v < w ordering.
        .filter(|&v| v > u)
        .collect()
}

/// Counts the triangles whose lowest-numbered node is `u`.
///
/// For each higher neighbour `v` of `u`, every common higher neighbour `w` of `u`
/// and `v` closes exactly one triangle with `u < v < w`.
fn triangles_rooted_at(higher: &[HashSet<usize>], u: usize) -> usize {
    let from_u = &higher[u];
    from_u
        .iter()
        // `intersection` walks the smaller of the two sets.
        .map(|&v| from_u.intersection(&higher[v]).count())
        .sum()
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a `GraphView` with `node_count` nodes and the given directed `(from, to)`
    /// edges, using the same constructor arguments as the original test.
    fn view_from_edges(node_count: usize, edges: &[(usize, usize)]) -> GraphView {
        let mut outgoing = vec![Vec::new(); node_count];
        let mut incoming = vec![Vec::new(); node_count];
        for &(from, to) in edges {
            outgoing[from].push(to);
            incoming[to].push(from);
        }
        GraphView::from_adjacency_list(
            node_count,
            (0..node_count).map(|i| i as _).collect(),
            HashMap::new(),
            outgoing,
            incoming,
            None,
        )
    }

    #[test]
    fn test_triangle_counting() {
        // Complete graph K4 (edges i -> j for i < j) contains C(4, 3) = 4 triangles.
        let edges: Vec<(usize, usize)> = (0..4)
            .flat_map(|i| ((i + 1)..4).map(move |j| (i, j)))
            .collect();
        let view = view_from_edges(4, &edges);
        assert_eq!(count_triangles(&view), 4);
    }

    #[test]
    fn test_no_edges_has_no_triangles() {
        let view = view_from_edges(5, &[]);
        assert_eq!(count_triangles(&view), 0);
    }

    #[test]
    fn test_square_has_no_triangles() {
        let view = view_from_edges(4, &[(0, 1), (1, 2), (2, 3), (3, 0)]);
        assert_eq!(count_triangles(&view), 0);
    }

    #[test]
    fn test_direction_is_ignored() {
        // A directed 3-cycle is still one undirected triangle.
        let view = view_from_edges(3, &[(2, 0), (0, 1), (1, 2)]);
        assert_eq!(count_triangles(&view), 1);
    }

    #[test]
    fn test_bidirectional_and_self_loop_edges_are_deduplicated() {
        let view = view_from_edges(
            3,
            &[(0, 1), (1, 0), (1, 2), (2, 1), (0, 2), (2, 0), (1, 1)],
        );
        assert_eq!(count_triangles(&view), 1);
    }

    #[test]
    fn test_parallel_path_matches_expected() {
        // 1200 nodes (above the parallel threshold) forming 400 disjoint triangles.
        let node_count = 1200;
        let edges: Vec<(usize, usize)> = (0..node_count / 3)
            .flat_map(|k| {
                let a = 3 * k;
                [(a, a + 1), (a + 1, a + 2), (a, a + 2)]
            })
            .collect();
        let view = view_from_edges(node_count, &edges);
        assert_eq!(count_triangles(&view), node_count / 3);
    }
}