use crate::core::{Graph, IgraphError, IgraphResult};
/// The 16 isomorphism classes of directed triads, in MAN notation: the digits
/// count the Mutual, Asymmetric and Null dyads of the triple, and the optional
/// letter suffix separates classes that share the same counts.
///
/// The discriminant of each variant is its index in [`TriadCensus::counts`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TriadType {
    /// No edges: `A, B, C`.
    T003 = 0,
    /// A single directed edge: `A->B, C`.
    T012 = 1,
    /// A single mutual edge: `A<->B, C`.
    T102 = 2,
    /// Out-star: `A<-B->C`.
    T021D = 3,
    /// In-star: `A->B<-C`.
    T021U = 4,
    /// Directed path: `A->B->C`.
    T021C = 5,
    /// `A<->B<-C`.
    T111D = 6,
    /// `A<->B->C`.
    T111U = 7,
    /// Transitive: `A->B<-C, A->C`.
    T030T = 8,
    /// Directed cycle: `A<-B<-C, A->C`.
    T030C = 9,
    /// Two mutual edges: `A<->B<->C`.
    T201 = 10,
    /// `A<-B->C, A<->C`.
    T120D = 11,
    /// `A->B<-C, A<->C`.
    T120U = 12,
    /// `A->B->C, A<->C`.
    T120C = 13,
    /// `A->B<->C, A<->C`.
    T210 = 14,
    /// Complete: `A<->B<->C, A<->C`.
    T300 = 15,
}

impl TriadType {
    /// Every triad type, ordered by discriminant, so `ALL[i] as usize == i`.
    ///
    /// Handy for iterating a census, e.g. pairing each type with its count.
    pub const ALL: [Self; 16] = [
        Self::T003,
        Self::T012,
        Self::T102,
        Self::T021D,
        Self::T021U,
        Self::T021C,
        Self::T111D,
        Self::T111U,
        Self::T030T,
        Self::T030C,
        Self::T201,
        Self::T120D,
        Self::T120U,
        Self::T120C,
        Self::T210,
        Self::T300,
    ];
}
/// Result of [`triad_census`]: one count per [`TriadType`].
///
/// `counts[t as usize]` is the number of unordered vertex triples whose
/// induced subgraph falls into class `t`. Counts are reported as `f64`.
#[derive(Debug, Clone, PartialEq)]
pub struct TriadCensus {
    /// Triad counts, indexed by the [`TriadType`] discriminant.
    pub counts: [f64; 16],
}

impl TriadCensus {
    /// Returns how many triads of class `triad_type` were counted.
    pub fn get(&self, triad_type: TriadType) -> f64 {
        let slot = triad_type as usize;
        self.counts[slot]
    }
}
/// Computes the triad census of `graph`.
///
/// Every unordered triple of distinct vertices is assigned to one of the 16
/// [`TriadType`] classes, based on the dyads (null, asymmetric or mutual)
/// between its three pairs. The number of triples in each class is returned.
/// Self-loops are ignored and parallel edges count once. In an undirected
/// graph every edge is treated as a mutual dyad. Graphs with fewer than three
/// vertices yield an all-zero census.
///
/// # Algorithm
///
/// This does not visit all `C(n, 3)` triples. Instead it follows the sparse
/// scheme of Batagelj & Mrvar: for every connected pair `v < u`, it merges the
/// sorted neighbour lists of `v` and `u`.
///
/// * Triples with two or more connected dyads are classified individually,
///   each exactly once.
/// * Triples whose only connected dyad is `{v, u}` are counted in bulk.
/// * `T003` is whatever remains of `C(n, 3)`.
///
/// Memory is `O(n + m)` rather than the `O(n^2)` of a dense dyad matrix. Time is
/// proportional to the sum of `deg(v) + deg(u)` over connected pairs.
///
/// # Errors
///
/// Propagates any error returned by `Graph::edge` while reading the edge list.
pub fn triad_census(graph: &Graph) -> IgraphResult<TriadCensus> {
    let n = graph.vcount();
    if n < 3 {
        return Ok(TriadCensus { counts: [0.0; 16] });
    }
    let nn = n as usize;
    let neighbors = build_neighbor_lists(graph)?;
    // Exact integer tallies; converted to the public f64 representation at the end.
    let mut tallies = [0u64; 16];
    for (v, v_list) in neighbors.iter().enumerate() {
        // Lists are sorted, so neighbours with a larger index form a suffix.
        // Visiting each pair only from its smaller endpoint sees it exactly once.
        let first_larger = v_list.partition_point(|&(w, _)| w <= v);
        for &(u, d_vu) in &v_list[first_larger..] {
            let u_list = &neighbors[u];
            // |N(v) ∪ N(u) \ {v, u}|: third vertices adjacent to v or u.
            let mut union_size = 0usize;
            let (mut i, mut j) = (0, 0);
            loop {
                // Sorted merge: yields each w in N(v) ∪ N(u) once, together with
                // the dyad codes v->w and u->w (0 when not adjacent).
                let (w, d_vw, d_uw) = match (v_list.get(i), u_list.get(j)) {
                    (None, None) => break,
                    (Some(&(a, da)), None) => {
                        i += 1;
                        (a, da, 0)
                    }
                    (None, Some(&(b, db))) => {
                        j += 1;
                        (b, 0, db)
                    }
                    (Some(&(a, da)), Some(&(b, db))) => {
                        if a < b {
                            i += 1;
                            (a, da, 0)
                        } else if b < a {
                            j += 1;
                            (b, 0, db)
                        } else {
                            i += 1;
                            j += 1;
                            (a, da, db)
                        }
                    }
                };
                if w == v || w == u {
                    continue;
                }
                union_size += 1;
                // A triple {x < y < z} with several connected dyads is reachable
                // from more than one pair. Count it from (x, y) when that pair is
                // connected (w > u), otherwise from (x, z) with y not adjacent to x.
                if u < w || (v < w && w < u && d_vw == 0) {
                    tallies[lookup_triad_type(d_vu, d_vw, d_uw)] += 1;
                }
            }
            // Vertices adjacent to neither v nor u form triads whose only
            // connected dyad is {v, u} (type 012 or 102).
            let lone_dyad_triads = (nn - 2 - union_size) as u64;
            tallies[lookup_triad_type(d_vu, 0, 0)] += lone_dyad_triads;
        }
    }
    // C(n, 3) in u128: exact even when n^3 would overflow u64.
    let nb = nn as u128;
    let all_triads = nb * (nb - 1) * (nb - 2) / 6;
    let connected: u128 = tallies.iter().map(|&c| u128::from(c)).sum();
    // Triples with no connected dyad (T003) are everything not enumerated above.
    let empty_triads = all_triads - connected;
    // The public API reports counts as f64; values beyond 2^53 lose precision.
    #[allow(clippy::cast_precision_loss)]
    let counts = {
        let mut counts = tallies.map(|c| c as f64);
        counts[lookup_triad_type(0, 0, 0)] = empty_triads as f64;
        counts
    };
    Ok(TriadCensus { counts })
}

/// Builds a sorted, deduplicated neighbour list for every vertex.
///
/// Entry `(w, d)` in list `v` means `v` and `w` are connected, with dyad code
/// `d` seen from `v`:
///
/// * bit `1` is set for an edge `v -> w`;
/// * bit `2` is set for an edge `w -> v`;
/// * so `3` is a mutual dyad.
///
/// Self-loops are skipped and parallel edges are merged into one entry. In an
/// undirected graph every code is `3`.
fn build_neighbor_lists(graph: &Graph) -> IgraphResult<Vec<Vec<(usize, u8)>>> {
    let nn = graph.vcount() as usize;
    let mut lists: Vec<Vec<(usize, u8)>> = vec![Vec::new(); nn];
    for eid in 0..graph.ecount() {
        // NOTE(review): cast kept from the original; assumes edge ids fit in u32 — confirm.
        #[allow(clippy::cast_possible_truncation)]
        let (src, tgt) = graph.edge(eid as u32)?;
        if src == tgt {
            continue;
        }
        let (s, t) = (src as usize, tgt as usize);
        lists[s].push((t, 1));
        lists[t].push((s, 2));
    }
    let undirected = !graph.is_directed();
    for list in &mut lists {
        list.sort_unstable_by_key(|&(w, _)| w);
        // Collapse parallel edges by OR-ing their bits.
        // `dedup_by` passes (later, earlier-kept).
        list.dedup_by(|later, kept| {
            let same_neighbor = later.0 == kept.0;
            if same_neighbor {
                kept.1 |= later.1;
            }
            same_neighbor
        });
        if undirected {
            for entry in list.iter_mut() {
                entry.1 = 3;
            }
        }
    }
    Ok(lists)
}
/// Maps the dyad codes of a vertex triple `{a, b, c}` to its triad-class index
/// (the `TriadType` discriminant).
///
/// `ab`, `ac` and `bc` are the codes of pairs `(a, b)`, `(a, c)` and `(b, c)`,
/// each seen from the first vertex of the pair:
///
/// * `0`: null;
/// * `1`: forward edge only;
/// * `2`: backward edge only;
/// * `3`: mutual.
///
/// Any other non-zero value is treated as asymmetric. The counts of mutual,
/// asymmetric and null dyads pick the class. When several classes share the
/// same counts, a `classify_*` helper breaks the tie.
fn lookup_triad_type(ab: u8, ac: u8, bc: u8) -> usize {
    let dyads = [ab, ac, bc];
    let mutual = dyads.iter().filter(|&&d| d == 3).count();
    let null = dyads.iter().filter(|&&d| d == 0).count();
    let asymmetric = dyads.len() - mutual - null;
    match (mutual, asymmetric, null) {
        (0, 0, 3) => 0, // T003
        (0, 1, 2) => 1, // T012
        (1, 0, 2) => 2, // T102
        (0, 2, 1) => classify_021(ab, ac, bc),
        (1, 1, 1) => classify_111(ab, ac, bc),
        (0, 3, 0) => classify_030(ab, ac, bc),
        (2, 0, 1) => 10, // T201
        (1, 2, 0) => classify_120(ab, ac, bc),
        (2, 1, 0) => 14, // T210
        (3, 0, 0) => 15, // T300
        // The three counts always sum to 3, and all ten such triples are listed above.
        _ => unreachable!("dyad counts must sum to 3"),
    }
}
/// Separates the three `021` triads (two asymmetric dyads, one null dyad).
///
/// The two asymmetric dyads share a centre vertex. Both codes are re-read from
/// the centre's side:
///
/// * two outgoing edges give `T021D` (3);
/// * two incoming edges give `T021U` (4);
/// * anything else is the path `T021C` (5).
fn classify_021(ab: u8, ac: u8, bc: u8) -> usize {
    // Called with exactly one null code; the centre is the vertex it does not touch.
    let from_centre = match (ab, ac, bc) {
        (_, _, 0) => [ab, ac],               // centre a
        (_, 0, _) => [flip_dyad(ab), bc],    // centre b
        _ => [flip_dyad(ac), flip_dyad(bc)], // centre c
    };
    match from_centre {
        [1, 1] => 3,
        [2, 2] => 4,
        _ => 5,
    }
}
/// Separates `T111D` (6) from `T111U` (7): one mutual, one asymmetric and one
/// null dyad.
///
/// Takes the vertex shared by the mutual and the asymmetric dyad and reads the
/// asymmetric edge from its side. If that vertex sends the edge, the triad is
/// `A<->B->C` (`T111U`); otherwise it is `A<->B<-C` (`T111D`).
fn classify_111(ab: u8, ac: u8, bc: u8) -> usize {
    let asym_seen_from_shared = match (ab, ac) {
        (3, _) if ac != 0 => ac,       // mutual a-b, asymmetric a-c: shared a
        (3, _) => bc,                  // mutual a-b, asymmetric b-c: shared b
        (_, 3) if ab != 0 => ab,       // mutual a-c, asymmetric a-b: shared a
        (_, 3) => flip_dyad(bc),       // mutual a-c, asymmetric b-c: shared c
        _ if ab != 0 => flip_dyad(ab), // mutual b-c, asymmetric a-b: shared b
        _ => flip_dyad(ac),            // mutual b-c, asymmetric a-c: shared c
    };
    match asym_seen_from_shared {
        1 => 7,
        _ => 6,
    }
}
/// Separates `T030T` (8) from `T030C` (9): three asymmetric dyads.
///
/// Tallies each vertex's out-degree inside the triple. A code of `1` means the
/// pair's first vertex sends the edge; any other code credits the second
/// vertex. If some vertex sends two edges, the triad is transitive (`T030T`);
/// if every vertex sends exactly one, it is a cycle (`T030C`).
fn classify_030(ab: u8, ac: u8, bc: u8) -> usize {
    // (code, first vertex, second vertex) with a = 0, b = 1, c = 2.
    let pairs = [(ab, 0usize, 1usize), (ac, 0, 2), (bc, 1, 2)];
    let mut out_degree = [0u8; 3];
    for &(code, first, second) in &pairs {
        let sender = if code == 1 { first } else { second };
        out_degree[sender] += 1;
    }
    if out_degree.contains(&2) {
        8
    } else {
        9
    }
}
/// Separates the three `120` triads (one mutual dyad, two asymmetric dyads).
///
/// Both asymmetric edges join the mutual pair to the remaining third vertex.
/// They are read from the mutual pair's side:
///
/// * both arriving from the third vertex give `T120D` (11);
/// * both leaving toward it give `T120U` (12);
/// * mixed directions give `T120C` (13).
fn classify_120(ab: u8, ac: u8, bc: u8) -> usize {
    let toward_third = match (ab, ac) {
        (3, _) => [ac, bc],                      // third vertex is c
        (_, 3) => [ab, flip_dyad(bc)],           // third vertex is b
        _ => [flip_dyad(ab), flip_dyad(ac)],     // third vertex is a
    };
    match toward_third {
        [2, 2] => 11,
        [1, 1] => 12,
        _ => 13,
    }
}
/// Reverses the point of view of a dyad code.
///
/// "Forward only" (`1`) and "backward only" (`2`) swap. Null (`0`), mutual
/// (`3`) and any other value come back unchanged.
fn flip_dyad(d: u8) -> u8 {
    if d == 1 || d == 2 {
        3 - d
    } else {
        d
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// O(n^3) reference census over an explicit edge list.
    ///
    /// Used to cross-check `triad_census` on graphs too large to classify by
    /// hand. Self-loops are skipped and parallel edges count once.
    fn brute_force_census(n: usize, edges: &[(u32, u32)]) -> [f64; 16] {
        let mut dyad = vec![0u8; n * n];
        for &(s, t) in edges {
            let (s, t) = (s as usize, t as usize);
            if s != t {
                dyad[s * n + t] |= 1;
                dyad[t * n + s] |= 2;
            }
        }
        let mut counts = [0.0_f64; 16];
        for i in 0..n {
            for j in (i + 1)..n {
                for k in (j + 1)..n {
                    let idx =
                        lookup_triad_type(dyad[i * n + j], dyad[i * n + k], dyad[j * n + k]);
                    counts[idx] += 1.0;
                }
            }
        }
        counts
    }

    #[test]
    fn test_empty_graph() {
        let g = Graph::new(0, true).unwrap();
        let tc = triad_census(&g).unwrap();
        assert!(tc.counts.iter().all(|&c| c.abs() < 1e-10));
    }
    #[test]
    fn test_two_vertices() {
        let g = Graph::new(2, true).unwrap();
        let tc = triad_census(&g).unwrap();
        assert!(tc.counts.iter().all(|&c| c.abs() < 1e-10));
    }
    #[test]
    fn test_three_vertices_no_edges() {
        let g = Graph::new(3, true).unwrap();
        let tc = triad_census(&g).unwrap();
        assert!((tc.get(TriadType::T003) - 1.0).abs() < 1e-10);
        assert!(tc.counts[1..].iter().all(|&c| c.abs() < 1e-10));
    }
    #[test]
    fn test_single_edge() {
        let mut g = Graph::new(3, true).unwrap();
        g.add_edge(0, 1).unwrap();
        let tc = triad_census(&g).unwrap();
        assert!((tc.get(TriadType::T012) - 1.0).abs() < 1e-10);
        assert!((tc.get(TriadType::T003)).abs() < 1e-10);
    }
    #[test]
    fn test_mutual_edge() {
        let mut g = Graph::new(3, true).unwrap();
        g.add_edge(0, 1).unwrap();
        g.add_edge(1, 0).unwrap();
        let tc = triad_census(&g).unwrap();
        assert!((tc.get(TriadType::T102) - 1.0).abs() < 1e-10);
    }
    #[test]
    fn test_directed_3_cycle() {
        let mut g = Graph::new(3, true).unwrap();
        g.add_edge(0, 1).unwrap();
        g.add_edge(1, 2).unwrap();
        g.add_edge(2, 0).unwrap();
        let tc = triad_census(&g).unwrap();
        assert!((tc.get(TriadType::T030C) - 1.0).abs() < 1e-10);
    }
    #[test]
    fn test_transitive_triple() {
        let mut g = Graph::new(3, true).unwrap();
        g.add_edge(0, 1).unwrap();
        g.add_edge(0, 2).unwrap();
        g.add_edge(1, 2).unwrap();
        let tc = triad_census(&g).unwrap();
        assert!((tc.get(TriadType::T030T) - 1.0).abs() < 1e-10);
    }
    #[test]
    fn test_complete_directed() {
        let mut g = Graph::new(3, true).unwrap();
        for i in 0..3u32 {
            for j in 0..3u32 {
                if i != j {
                    g.add_edge(i, j).unwrap();
                }
            }
        }
        let tc = triad_census(&g).unwrap();
        assert!((tc.get(TriadType::T300) - 1.0).abs() < 1e-10);
    }
    #[test]
    fn test_021d_out_star() {
        let mut g = Graph::new(3, true).unwrap();
        g.add_edge(1, 0).unwrap();
        g.add_edge(1, 2).unwrap();
        let tc = triad_census(&g).unwrap();
        assert!((tc.get(TriadType::T021D) - 1.0).abs() < 1e-10);
    }
    #[test]
    fn test_021u_in_star() {
        let mut g = Graph::new(3, true).unwrap();
        g.add_edge(0, 1).unwrap();
        g.add_edge(2, 1).unwrap();
        let tc = triad_census(&g).unwrap();
        assert!((tc.get(TriadType::T021U) - 1.0).abs() < 1e-10);
    }
    #[test]
    fn test_021c_chain() {
        let mut g = Graph::new(3, true).unwrap();
        g.add_edge(0, 1).unwrap();
        g.add_edge(1, 2).unwrap();
        let tc = triad_census(&g).unwrap();
        assert!((tc.get(TriadType::T021C) - 1.0).abs() < 1e-10);
    }
    #[test]
    fn test_four_vertices_sum() {
        let mut g = Graph::new(4, true).unwrap();
        g.add_edge(0, 1).unwrap();
        g.add_edge(1, 2).unwrap();
        g.add_edge(2, 0).unwrap();
        let tc = triad_census(&g).unwrap();
        let total: f64 = tc.counts.iter().sum();
        assert!((total - 4.0).abs() < 1e-10);
    }
    #[test]
    fn test_201_two_mutual() {
        let mut g = Graph::new(3, true).unwrap();
        g.add_edge(0, 1).unwrap();
        g.add_edge(1, 0).unwrap();
        g.add_edge(0, 2).unwrap();
        g.add_edge(2, 0).unwrap();
        let tc = triad_census(&g).unwrap();
        assert!((tc.get(TriadType::T201) - 1.0).abs() < 1e-10);
    }
    #[test]
    fn test_undirected_triangle() {
        let mut g = Graph::with_vertices(3);
        g.add_edge(0, 1).unwrap();
        g.add_edge(1, 2).unwrap();
        g.add_edge(0, 2).unwrap();
        let tc = triad_census(&g).unwrap();
        assert!((tc.get(TriadType::T300) - 1.0).abs() < 1e-10);
    }
    #[test]
    fn test_undirected_path() {
        let mut g = Graph::with_vertices(3);
        g.add_edge(0, 1).unwrap();
        g.add_edge(1, 2).unwrap();
        let tc = triad_census(&g).unwrap();
        assert!((tc.get(TriadType::T201) - 1.0).abs() < 1e-10);
    }
    #[test]
    fn test_counts_sum_to_total() {
        let mut g = Graph::new(5, true).unwrap();
        g.add_edge(0, 1).unwrap();
        g.add_edge(1, 2).unwrap();
        g.add_edge(2, 3).unwrap();
        g.add_edge(3, 4).unwrap();
        g.add_edge(4, 0).unwrap();
        let tc = triad_census(&g).unwrap();
        let total: f64 = tc.counts.iter().sum();
        assert!((total - 10.0).abs() < 1e-10);
    }
    #[test]
    fn test_111d() {
        let mut g = Graph::new(3, true).unwrap();
        g.add_edge(0, 1).unwrap();
        g.add_edge(1, 0).unwrap();
        g.add_edge(2, 1).unwrap();
        let tc = triad_census(&g).unwrap();
        assert!((tc.get(TriadType::T111D) - 1.0).abs() < 1e-10);
    }
    #[test]
    fn test_111u() {
        let mut g = Graph::new(3, true).unwrap();
        g.add_edge(0, 1).unwrap();
        g.add_edge(1, 0).unwrap();
        g.add_edge(1, 2).unwrap();
        let tc = triad_census(&g).unwrap();
        assert!((tc.get(TriadType::T111U) - 1.0).abs() < 1e-10);
    }
    #[test]
    fn test_210() {
        let mut g = Graph::new(3, true).unwrap();
        g.add_edge(1, 2).unwrap();
        g.add_edge(2, 1).unwrap();
        g.add_edge(0, 2).unwrap();
        g.add_edge(2, 0).unwrap();
        g.add_edge(0, 1).unwrap();
        let tc = triad_census(&g).unwrap();
        assert!((tc.get(TriadType::T210) - 1.0).abs() < 1e-10);
    }
    #[test]
    fn test_self_loops_ignored() {
        let mut g = Graph::new(3, true).unwrap();
        g.add_edge(0, 0).unwrap();
        g.add_edge(0, 1).unwrap();
        let tc = triad_census(&g).unwrap();
        assert!((tc.get(TriadType::T012) - 1.0).abs() < 1e-10);
    }
    #[test]
    fn test_120d() {
        let mut g = Graph::new(3, true).unwrap();
        g.add_edge(0, 2).unwrap();
        g.add_edge(2, 0).unwrap();
        g.add_edge(1, 0).unwrap();
        g.add_edge(1, 2).unwrap();
        let tc = triad_census(&g).unwrap();
        assert!((tc.get(TriadType::T120D) - 1.0).abs() < 1e-10);
    }
    #[test]
    fn test_120u() {
        let mut g = Graph::new(3, true).unwrap();
        g.add_edge(0, 2).unwrap();
        g.add_edge(2, 0).unwrap();
        g.add_edge(0, 1).unwrap();
        g.add_edge(2, 1).unwrap();
        let tc = triad_census(&g).unwrap();
        assert!((tc.get(TriadType::T120U) - 1.0).abs() < 1e-10);
    }
    #[test]
    fn test_120c() {
        let mut g = Graph::new(3, true).unwrap();
        g.add_edge(0, 2).unwrap();
        g.add_edge(2, 0).unwrap();
        g.add_edge(0, 1).unwrap();
        g.add_edge(1, 2).unwrap();
        let tc = triad_census(&g).unwrap();
        assert!((tc.get(TriadType::T120C) - 1.0).abs() < 1e-10);
    }
    #[test]
    fn test_sparse_graph_with_isolated_vertices() {
        let mut g = Graph::new(10, true).unwrap();
        g.add_edge(0, 1).unwrap();
        let tc = triad_census(&g).unwrap();
        // 8 triples contain the edge 0->1; the other C(10,3) - 8 = 112 are empty.
        assert!((tc.get(TriadType::T012) - 8.0).abs() < 1e-10);
        assert!((tc.get(TriadType::T003) - 112.0).abs() < 1e-10);
    }
    #[test]
    fn test_undirected_star_with_isolated_vertex() {
        let mut g = Graph::with_vertices(5);
        g.add_edge(0, 1).unwrap();
        g.add_edge(0, 2).unwrap();
        g.add_edge(0, 3).unwrap();
        let tc = triad_census(&g).unwrap();
        // Triad breakdown:
        // - centre + two leaves: 3 triads of type 201;
        // - centre + one leaf + vertex 4: 3 of type 102;
        // - the remaining 4 triples have no edges.
        assert!((tc.get(TriadType::T201) - 3.0).abs() < 1e-10);
        assert!((tc.get(TriadType::T102) - 3.0).abs() < 1e-10);
        assert!((tc.get(TriadType::T003) - 4.0).abs() < 1e-10);
    }
    #[test]
    fn test_matches_brute_force_on_pseudo_random_digraph() {
        // Deterministic LCG so the graph is reproducible. Roughly 30% density
        // mixes null, asymmetric and mutual dyads, and may include self-loops.
        let mut state: u32 = 12_345;
        let mut edges = Vec::new();
        for i in 0..9u32 {
            for j in 0..9u32 {
                state = state.wrapping_mul(1_103_515_245).wrapping_add(12_345);
                if (state >> 16) % 10 < 3 {
                    edges.push((i, j));
                }
            }
        }
        let mut g = Graph::new(9, true).unwrap();
        for &(s, t) in &edges {
            g.add_edge(s, t).unwrap();
        }
        let tc = triad_census(&g).unwrap();
        assert_eq!(tc.counts, brute_force_census(9, &edges));
        let total: f64 = tc.counts.iter().sum();
        assert!((total - 84.0).abs() < 1e-10);
    }
}