use crate::types::CsrGraph;
use rustkernel_core::{domain::Domain, kernel::KernelMetadata, traits::GpuKernel};
use std::collections::HashSet;
/// Output of [`TriangleCounting::compute`].
#[derive(Debug, Clone)]
pub struct TriangleCountResult {
    /// Total number of triangles found in the graph.
    pub total_triangles: u64,
    /// Number of triangles each node participates in, indexed by node id.
    pub per_node_triangles: Vec<u64>,
    /// Local clustering coefficient per node (triangles at the node divided by
    /// its number of neighbor pairs); 0.0 for nodes with degree below 2.
    pub clustering_coefficients: Vec<f64>,
    /// Per-node triangle counts summed over all nodes, divided by neighbor
    /// pairs summed over nodes of degree >= 2 (i.e. 3 * triangles / connected
    /// triples); 0.0 when no node has degree >= 2.
    pub global_clustering_coefficient: f64,
}
/// Batch kernel that counts triangles and derives clustering coefficients.
#[derive(Debug, Clone)]
pub struct TriangleCounting {
    /// Registry metadata, exposed through `GpuKernel::metadata`.
    metadata: KernelMetadata,
}
impl TriangleCounting {
    /// Creates a triangle-counting kernel with its registry metadata.
    #[must_use]
    pub fn new() -> Self {
        Self {
            metadata: KernelMetadata::batch("graph/triangle-counting", Domain::GraphAnalytics)
                .with_description("Triangle counting (node-iterator algorithm)")
                .with_throughput(50_000)
                .with_latency_us(20.0),
        }
    }

    /// Counts triangles in `graph` and derives clustering coefficients.
    ///
    /// Each triangle `{u, v, w}` is counted exactly once by only following
    /// neighbor chains with `u < v < w`. Adjacency is checked along `u -> v`,
    /// `v -> w` and `u -> w`, so undirected edges are expected to be stored in
    /// both directions.
    ///
    /// Self-loops and repeated (parallel) edges are ignored: each node's
    /// neighborhood is reduced to its set of distinct neighbors other than
    /// itself, and that set's size is the degree used for the clustering
    /// coefficients. This keeps triangle counts and degrees consistent.
    ///
    /// Returns the total triangle count, the per-node participation counts,
    /// the local clustering coefficient of each node (0.0 when its degree is
    /// below 2) and the global clustering coefficient
    /// (`3 * triangles / connected triples`, 0.0 when there are no triples).
    pub fn compute(graph: &CsrGraph) -> TriangleCountResult {
        let n = graph.num_nodes;
        let adjacency: Vec<HashSet<u64>> = (0..n)
            .map(|u| Self::distinct_neighbors(graph, u as u64))
            .collect();

        let mut total_triangles = 0u64;
        let mut per_node_triangles = vec![0u64; n];
        for (u, neighbors_u) in adjacency.iter().enumerate() {
            for &v in neighbors_u {
                let v = v as usize;
                // Only walk "upwards" so each triangle is seen once.
                if u >= v {
                    continue;
                }
                for &w in &adjacency[v] {
                    let w_usize = w as usize;
                    if v >= w_usize {
                        continue;
                    }
                    if neighbors_u.contains(&w) {
                        total_triangles += 1;
                        per_node_triangles[u] += 1;
                        per_node_triangles[v] += 1;
                        per_node_triangles[w_usize] += 1;
                    }
                }
            }
        }

        let mut clustering_coefficients = vec![0.0f64; n];
        let mut total_possible = 0u64;
        let mut total_actual = 0u64;
        for (i, neighbors) in adjacency.iter().enumerate() {
            let degree = neighbors.len() as u64;
            if degree >= 2 {
                let possible = degree * (degree - 1) / 2;
                total_possible += possible;
                total_actual += per_node_triangles[i];
                clustering_coefficients[i] = per_node_triangles[i] as f64 / possible as f64;
            }
        }
        let global_clustering_coefficient = if total_possible > 0 {
            total_actual as f64 / total_possible as f64
        } else {
            0.0
        };

        TriangleCountResult {
            total_triangles,
            per_node_triangles,
            clustering_coefficients,
            global_clustering_coefficient,
        }
    }

    /// Counts the triangles that contain `node`.
    ///
    /// Uses the same rules as [`TriangleCounting::compute`]: self-loops and
    /// repeated edges are ignored, and each triangle `{node, v, w}` is counted
    /// once (for `v < w`), so the result matches `per_node_triangles[node]`.
    pub fn count_node_triangles(graph: &CsrGraph, node: u64) -> u64 {
        let neighbors = Self::distinct_neighbors(graph, node);
        neighbors
            .iter()
            .map(|&v| {
                // Distinct `w > v` that close the triangle; `w != node` holds
                // because `neighbors` excludes `node`.
                let closing: HashSet<u64> = graph
                    .neighbors(v)
                    .iter()
                    .copied()
                    .filter(|&w| w > v && neighbors.contains(&w))
                    .collect();
                closing.len() as u64
            })
            .sum()
    }

    /// Distinct neighbors of `node`, excluding `node` itself (self-loops).
    fn distinct_neighbors(graph: &CsrGraph, node: u64) -> HashSet<u64> {
        graph
            .neighbors(node)
            .iter()
            .copied()
            .filter(|&v| v != node)
            .collect()
    }
}
impl Default for TriangleCounting {
fn default() -> Self {
Self::new()
}
}
/// Exposes the stored registry metadata through the shared kernel trait.
impl GpuKernel for TriangleCounting {
    fn metadata(&self) -> &KernelMetadata {
        let Self { metadata, .. } = self;
        metadata
    }
}
/// Classification of a node triple by how many of its three node pairs are
/// connected (see `MotifDetection::classify_triad`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TriadType {
    /// No connected pair.
    Empty,
    /// Exactly one connected pair.
    Edge,
    /// Exactly two connected pairs (an open path through one node).
    Wedge,
    /// All three pairs connected.
    Triangle,
}
/// Motif counts produced by `MotifDetection::count_triads`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct MotifResult {
    /// Count per motif name; `count_triads` fills the keys `"triangles"`,
    /// `"wedges"` and `"edges"`.
    pub motif_counts: std::collections::HashMap<String, u64>,
}
/// Batch kernel that counts small (three-node) subgraph motifs.
#[derive(Debug, Clone)]
pub struct MotifDetection {
    /// Registry metadata, exposed through `GpuKernel::metadata`.
    metadata: KernelMetadata,
}
impl MotifDetection {
    /// Creates a motif-detection kernel with its registry metadata.
    #[must_use]
    pub fn new() -> Self {
        Self {
            metadata: KernelMetadata::batch("graph/motif-detection", Domain::GraphAnalytics)
                .with_description("Motif detection (k-node subgraph census)")
                .with_throughput(10_000)
                .with_latency_us(100.0),
        }
    }

    /// Counts three-node motifs of a graph storing undirected edges both ways.
    ///
    /// The returned [`MotifResult`] contains:
    /// - `"triangles"`: number of triangles;
    /// - `"wedges"`: open wedges, i.e. pairs of neighbors of a center node that
    ///   are not adjacent to each other, summed over all centers;
    /// - `"edges"`: `num_edges / 2`, assuming each edge is stored twice.
    ///
    /// Self-loops and repeated edges are ignored when forming neighborhoods,
    /// so the closed pairs at a center never exceed its `C(d, 2)` neighbor
    /// pairs and the wedge subtraction cannot underflow.
    pub fn count_triads(graph: &CsrGraph) -> MotifResult {
        let n = graph.num_nodes;
        let mut triangles = 0u64;
        let mut wedges = 0u64;
        for u in 0..n {
            let u = u as u64;
            // Distinct neighbors of `u`, excluding a self-loop.
            let neighbors_u: HashSet<u64> = graph
                .neighbors(u)
                .iter()
                .copied()
                .filter(|&v| v != u)
                .collect();
            let degree_u = neighbors_u.len() as u64;
            if degree_u < 2 {
                continue;
            }
            let potential_wedges = degree_u * (degree_u - 1) / 2;
            // Neighbor pairs {v, w} (v < w) of `u` that are themselves adjacent.
            let triangles_at_u: u64 = neighbors_u
                .iter()
                .map(|&v| {
                    let mut closing: Vec<u64> = graph
                        .neighbors(v)
                        .iter()
                        .copied()
                        .filter(|&w| w > v && neighbors_u.contains(&w))
                        .collect();
                    // Parallel v-w edges must not count the same pair twice.
                    closing.sort_unstable();
                    closing.dedup();
                    closing.len() as u64
                })
                .sum();
            wedges += potential_wedges - triangles_at_u;
            triangles += triangles_at_u;
        }
        // Every triangle was counted once at each of its three corners.
        triangles /= 3;
        let edges = graph.num_edges as u64 / 2;
        let mut motif_counts = std::collections::HashMap::new();
        motif_counts.insert("triangles".to_string(), triangles);
        motif_counts.insert("wedges".to_string(), wedges);
        motif_counts.insert("edges".to_string(), edges);
        MotifResult { motif_counts }
    }

    /// Classifies the triple `nodes = [a, b, c]` by how many of the pairs
    /// `a -> b`, `a -> c` and `b -> c` appear in the adjacency lists.
    pub fn classify_triad(graph: &CsrGraph, nodes: [u64; 3]) -> TriadType {
        let [a, b, c] = nodes;
        // One membership probe per pair: scanning the neighbor slice avoids
        // building a HashSet that would be queried only once.
        let ab = graph.neighbors(a).contains(&b);
        let ac = graph.neighbors(a).contains(&c);
        let bc = graph.neighbors(b).contains(&c);
        match u8::from(ab) + u8::from(ac) + u8::from(bc) {
            0 => TriadType::Empty,
            1 => TriadType::Edge,
            2 => TriadType::Wedge,
            3 => TriadType::Triangle,
            _ => unreachable!("a triple has at most three node pairs"),
        }
    }
}
impl Default for MotifDetection {
fn default() -> Self {
Self::new()
}
}
/// Exposes the stored registry metadata through the shared kernel trait.
impl GpuKernel for MotifDetection {
    fn metadata(&self) -> &KernelMetadata {
        let Self { metadata, .. } = self;
        metadata
    }
}
/// Batch kernel that detects cliques of a requested size `k`.
#[derive(Debug, Clone)]
pub struct KCliqueDetection {
    /// Registry metadata, exposed through `GpuKernel::metadata`.
    metadata: KernelMetadata,
}
impl KCliqueDetection {
    /// Creates a k-clique kernel with its registry metadata.
    #[must_use]
    pub fn new() -> Self {
        Self {
            metadata: KernelMetadata::batch("graph/k-clique", Domain::GraphAnalytics)
                .with_description("K-clique detection")
                .with_throughput(1_000)
                .with_latency_us(1000.0),
        }
    }

    /// Enumerates every clique of exactly `k` nodes in `graph`.
    ///
    /// Every k-clique is reported exactly once, including cliques that are
    /// not maximal (e.g. all three edges of a triangle for `k = 2`). Each
    /// clique lists its nodes in ascending order, and cliques come out in
    /// lexicographic order. `k = 0` yields a single empty clique.
    ///
    /// Adjacency between `v < w` is tested as `w` in `v`'s neighbor list, so
    /// undirected edges are expected to be stored in both directions.
    /// Self-loops and repeated edges have no effect.
    pub fn find_cliques(graph: &CsrGraph, k: usize) -> Vec<Vec<u64>> {
        let n = graph.num_nodes;
        let adj: Vec<HashSet<u64>> = (0..n)
            .map(|i| graph.neighbors(i as u64).iter().copied().collect())
            .collect();
        let candidates: Vec<u64> = (0..n as u64).collect();
        let mut current = Vec::new();
        let mut cliques = Vec::new();
        Self::extend_clique(&adj, &mut current, &candidates, k, &mut cliques);
        cliques
    }

    /// Extends `current` with ascending combinations of `candidates` until it
    /// has `k` nodes, pushing each completed clique into `cliques`.
    ///
    /// Invariant: `candidates` is sorted ascending, every candidate is greater
    /// than the last node of `current`, and every candidate is adjacent to
    /// every node of `current`. Only growing towards larger ids is what makes
    /// each clique appear exactly once.
    fn extend_clique(
        adj: &[HashSet<u64>],
        current: &mut Vec<u64>,
        candidates: &[u64],
        k: usize,
        cliques: &mut Vec<Vec<u64>>,
    ) {
        if current.len() == k {
            cliques.push(current.clone());
            return;
        }
        let needed = k - current.len();
        for (i, &v) in candidates.iter().enumerate() {
            // Not enough candidates left from `v` onward to reach size `k`.
            if candidates.len() - i < needed {
                break;
            }
            let rest = &candidates[i + 1..];
            let neighbors = &adj[v as usize];
            // Intersect `rest` with `v`'s neighbors, iterating the smaller side
            // and keeping the result sorted ascending.
            let next: Vec<u64> = if rest.len() <= neighbors.len() {
                rest.iter().copied().filter(|w| neighbors.contains(w)).collect()
            } else {
                let mut common: Vec<u64> = neighbors
                    .iter()
                    .copied()
                    .filter(|w| rest.binary_search(w).is_ok())
                    .collect();
                common.sort_unstable();
                common
            };
            current.push(v);
            Self::extend_clique(adj, current, &next, k, cliques);
            current.pop();
        }
    }

    /// Number of cliques of exactly `k` nodes (see [`KCliqueDetection::find_cliques`]).
    pub fn count_cliques(graph: &CsrGraph, k: usize) -> u64 {
        Self::find_cliques(graph, k).len() as u64
    }
}
impl Default for KCliqueDetection {
fn default() -> Self {
Self::new()
}
}
/// Exposes the stored registry metadata through the shared kernel trait.
impl GpuKernel for KCliqueDetection {
    fn metadata(&self) -> &KernelMetadata {
        let Self { metadata, .. } = self;
        metadata
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// K3: nodes 0, 1 and 2 all joined, each edge listed in both directions.
    fn create_triangle_graph() -> CsrGraph {
        let edges = [(0, 1), (1, 0), (1, 2), (2, 1), (2, 0), (0, 2)];
        CsrGraph::from_edges(3, &edges)
    }

    /// C4: the cycle 0-1-2-3-0 (no triangles), each edge listed in both directions.
    fn create_square_graph() -> CsrGraph {
        let edges = [(0, 1), (1, 0), (1, 2), (2, 1), (2, 3), (3, 2), (3, 0), (0, 3)];
        CsrGraph::from_edges(4, &edges)
    }

    #[test]
    fn test_triangle_counting_metadata() {
        let kernel = TriangleCounting::new();
        let metadata = kernel.metadata();
        assert_eq!(metadata.id, "graph/triangle-counting");
        assert_eq!(metadata.domain, Domain::GraphAnalytics);
    }

    #[test]
    fn test_triangle_counting() {
        let result = TriangleCounting::compute(&create_triangle_graph());
        assert_eq!(result.total_triangles, 1, "Expected 1 triangle");
        result
            .per_node_triangles
            .iter()
            .for_each(|&count| assert_eq!(count, 1));
        let deviation = (result.global_clustering_coefficient - 1.0).abs();
        assert!(deviation < 0.01);
    }

    #[test]
    fn test_no_triangles() {
        let result = TriangleCounting::compute(&create_square_graph());
        assert_eq!(result.total_triangles, 0, "Expected 0 triangles in square");
        assert!(result.global_clustering_coefficient.abs() < 0.01);
    }

    #[test]
    fn test_triad_classification() {
        let graph = create_triangle_graph();
        assert_eq!(
            MotifDetection::classify_triad(&graph, [0, 1, 2]),
            TriadType::Triangle
        );
    }

    #[test]
    fn test_motif_detection() {
        let result = MotifDetection::count_triads(&create_triangle_graph());
        let triangles = result.motif_counts.get("triangles").copied();
        assert_eq!(triangles, Some(1));
    }

    #[test]
    fn test_k_clique_triangles() {
        let cliques = KCliqueDetection::find_cliques(&create_triangle_graph(), 3);
        assert_eq!(cliques.len(), 1);
    }

    #[test]
    fn test_k_clique_edges() {
        let cliques = KCliqueDetection::find_cliques(&create_square_graph(), 2);
        assert_eq!(cliques.len(), 4);
    }
}