rustkernel_graph/
motif.rs

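//! Graph motif detection kernels.
//!
//! This module provides algorithms for detecting graph motifs:
//! - Triangle counting (local and global), including clustering coefficients
//! - Motif detection (subgraph census; currently 3-node triads)
//! - K-clique detection (enumeration and counting of cliques of a fixed size)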
6
7use crate::types::CsrGraph;
8use rustkernel_core::{domain::Domain, kernel::KernelMetadata, traits::GpuKernel};
9use std::collections::HashSet;
10
11// ============================================================================
12// Triangle Counting Kernel
13// ============================================================================
14
/// Output of a triangle-counting pass over a graph.
///
/// All per-node vectors are indexed by node id.
#[derive(Clone, Debug)]
pub struct TriangleCountResult {
    /// Number of distinct triangles found in the whole graph.
    pub total_triangles: u64,
    /// For each node, the number of triangles that node is a corner of.
    pub per_node_triangles: Vec<u64>,
    /// Local clustering coefficient of each node: its triangle count divided
    /// by the number of neighbor pairs; `0.0` for nodes of degree below 2.
    pub clustering_coefficients: Vec<f64>,
    /// Graph-wide ratio of summed per-node triangle counts to summed
    /// per-node neighbor-pair counts; `0.0` when no node has two neighbors.
    pub global_clustering_coefficient: f64,
}
27
28/// Triangle counting kernel.
29///
30/// Counts triangles using the node-iterator algorithm.
31/// Each triangle is counted once.
32#[derive(Debug, Clone)]
33pub struct TriangleCounting {
34    metadata: KernelMetadata,
35}
36
37impl TriangleCounting {
38    /// Create a new triangle counting kernel.
39    #[must_use]
40    pub fn new() -> Self {
41        Self {
42            metadata: KernelMetadata::batch("graph/triangle-counting", Domain::GraphAnalytics)
43                .with_description("Triangle counting (node-iterator algorithm)")
44                .with_throughput(50_000)
45                .with_latency_us(20.0),
46        }
47    }
48
49    /// Count all triangles in the graph.
50    ///
51    /// Uses the node-iterator algorithm with degree ordering
52    /// to ensure each triangle is counted exactly once.
53    pub fn compute(graph: &CsrGraph) -> TriangleCountResult {
54        let n = graph.num_nodes;
55        let mut total_triangles = 0u64;
56        let mut per_node_triangles = vec![0u64; n];
57
58        // For each node, look at pairs of its neighbors and check if they're connected
59        for u in 0..n {
60            let neighbors_u: HashSet<u64> = graph.neighbors(u as u64).iter().copied().collect();
61
62            for &v in graph.neighbors(u as u64) {
63                let v = v as usize;
64
65                // Only count once: process when u < v
66                if u >= v {
67                    continue;
68                }
69
70                // Check common neighbors (nodes that form triangles with u and v)
71                for &w in graph.neighbors(v as u64) {
72                    let w_usize = w as usize;
73
74                    // Only count once: ensure u < v < w (by node index)
75                    if v >= w_usize {
76                        continue;
77                    }
78
79                    if neighbors_u.contains(&w) {
80                        // Found triangle: u-v-w
81                        total_triangles += 1;
82                        per_node_triangles[u] += 1;
83                        per_node_triangles[v] += 1;
84                        per_node_triangles[w_usize] += 1;
85                    }
86                }
87            }
88        }
89
90        // Compute clustering coefficients
91        let mut clustering_coefficients = vec![0.0f64; n];
92        let mut total_possible = 0u64;
93        let mut total_actual = 0u64;
94
95        for i in 0..n {
96            let degree = graph.out_degree(i as u64);
97            if degree >= 2 {
98                let possible = degree * (degree - 1) / 2;
99                total_possible += possible;
100                total_actual += per_node_triangles[i];
101                clustering_coefficients[i] = per_node_triangles[i] as f64 / possible as f64;
102            }
103        }
104
105        let global_clustering_coefficient = if total_possible > 0 {
106            total_actual as f64 / total_possible as f64
107        } else {
108            0.0
109        };
110
111        TriangleCountResult {
112            total_triangles,
113            per_node_triangles,
114            clustering_coefficients,
115            global_clustering_coefficient,
116        }
117    }
118
119    /// Count triangles for a specific node.
120    pub fn count_node_triangles(graph: &CsrGraph, node: u64) -> u64 {
121        let neighbors: HashSet<u64> = graph.neighbors(node).iter().copied().collect();
122
123        let mut count = 0u64;
124
125        for &v in graph.neighbors(node) {
126            for &w in graph.neighbors(v) {
127                if w != node && neighbors.contains(&w) {
128                    count += 1;
129                }
130            }
131        }
132
133        // Each triangle is counted twice (once for each direction)
134        count / 2
135    }
136}
137
138impl Default for TriangleCounting {
139    fn default() -> Self {
140        Self::new()
141    }
142}
143
144impl GpuKernel for TriangleCounting {
145    fn metadata(&self) -> &KernelMetadata {
146        &self.metadata
147    }
148}
149
150// ============================================================================
151// Motif Detection Kernel
152// ============================================================================
153
/// Structural classes of a triad (three nodes), distinguished by how many of
/// the three possible edges between them are present.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TriadType {
    /// Zero edges: the three nodes are mutually unconnected.
    Empty,
    /// Exactly one edge among the three nodes.
    Edge,
    /// Exactly two edges, i.e. a path through a middle node (open wedge).
    Wedge,
    /// All three edges present: a closed triangle.
    Triangle,
}
166
167/// Result of motif detection.
168#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
169pub struct MotifResult {
170    /// Count of each motif type.
171    pub motif_counts: std::collections::HashMap<String, u64>,
172}
173
174/// Motif detection kernel.
175///
176/// Counts occurrences of small subgraph patterns.
177#[derive(Debug, Clone)]
178pub struct MotifDetection {
179    metadata: KernelMetadata,
180}
181
182impl MotifDetection {
183    /// Create a new motif detection kernel.
184    #[must_use]
185    pub fn new() -> Self {
186        Self {
187            metadata: KernelMetadata::batch("graph/motif-detection", Domain::GraphAnalytics)
188                .with_description("Motif detection (k-node subgraph census)")
189                .with_throughput(10_000)
190                .with_latency_us(100.0),
191        }
192    }
193
194    /// Count 3-node motifs (triads) in the graph.
195    pub fn count_triads(graph: &CsrGraph) -> MotifResult {
196        let n = graph.num_nodes;
197        let mut triangles = 0u64;
198        let mut wedges = 0u64;
199
200        // Count triangles and wedges
201        for u in 0..n {
202            let neighbors_u: HashSet<u64> = graph.neighbors(u as u64).iter().copied().collect();
203            let degree_u = neighbors_u.len();
204
205            // Wedges centered at u: C(degree, 2)
206            if degree_u >= 2 {
207                let potential_wedges = (degree_u * (degree_u - 1)) / 2;
208
209                // Count how many of these are actually triangles
210                let mut triangles_at_u = 0u64;
211                for &v in graph.neighbors(u as u64) {
212                    for &w in graph.neighbors(v) {
213                        if w != u as u64 && neighbors_u.contains(&w) && v < w {
214                            triangles_at_u += 1;
215                        }
216                    }
217                }
218
219                wedges += potential_wedges as u64 - triangles_at_u;
220                triangles += triangles_at_u;
221            }
222        }
223
224        // Triangles were counted 3 times (once per vertex)
225        triangles /= 3;
226
227        // Edges count
228        let edges = graph.num_edges as u64 / 2; // Undirected edges
229
230        let mut motif_counts = std::collections::HashMap::new();
231        motif_counts.insert("triangles".to_string(), triangles);
232        motif_counts.insert("wedges".to_string(), wedges);
233        motif_counts.insert("edges".to_string(), edges);
234
235        MotifResult { motif_counts }
236    }
237
238    /// Classify a triad (set of 3 nodes) by its structure.
239    pub fn classify_triad(graph: &CsrGraph, nodes: [u64; 3]) -> TriadType {
240        let [a, b, c] = nodes;
241
242        let neighbors_a: HashSet<u64> = graph.neighbors(a).iter().copied().collect();
243        let neighbors_b: HashSet<u64> = graph.neighbors(b).iter().copied().collect();
244
245        let ab = neighbors_a.contains(&b);
246        let ac = neighbors_a.contains(&c);
247        let bc = neighbors_b.contains(&c);
248
249        let edge_count = ab as u8 + ac as u8 + bc as u8;
250
251        match edge_count {
252            0 => TriadType::Empty,
253            1 => TriadType::Edge,
254            2 => TriadType::Wedge,
255            3 => TriadType::Triangle,
256            _ => unreachable!(),
257        }
258    }
259}
260
261impl Default for MotifDetection {
262    fn default() -> Self {
263        Self::new()
264    }
265}
266
267impl GpuKernel for MotifDetection {
268    fn metadata(&self) -> &KernelMetadata {
269        &self.metadata
270    }
271}
272
273// ============================================================================
274// K-Clique Detection Kernel
275// ============================================================================
276
277/// K-clique detection kernel.
278///
279/// Finds all cliques of size k in the graph.
280#[derive(Debug, Clone)]
281pub struct KCliqueDetection {
282    metadata: KernelMetadata,
283}
284
285impl KCliqueDetection {
286    /// Create a new k-clique detection kernel.
287    #[must_use]
288    pub fn new() -> Self {
289        Self {
290            metadata: KernelMetadata::batch("graph/k-clique", Domain::GraphAnalytics)
291                .with_description("K-clique detection")
292                .with_throughput(1_000)
293                .with_latency_us(1000.0),
294        }
295    }
296
297    /// Find all cliques of size k.
298    ///
299    /// Uses Bron-Kerbosch algorithm with pivoting.
300    pub fn find_cliques(graph: &CsrGraph, k: usize) -> Vec<Vec<u64>> {
301        let n = graph.num_nodes;
302        let mut cliques = Vec::new();
303
304        // Build adjacency set for each node
305        let adj: Vec<HashSet<u64>> = (0..n)
306            .map(|i| graph.neighbors(i as u64).iter().copied().collect())
307            .collect();
308
309        // Bron-Kerbosch with size limit
310        let mut current_clique = Vec::new();
311        let candidates: HashSet<u64> = (0..n as u64).collect();
312        let excluded: HashSet<u64> = HashSet::new();
313
314        Self::bron_kerbosch(
315            &adj,
316            &mut current_clique,
317            candidates,
318            excluded,
319            k,
320            &mut cliques,
321        );
322
323        cliques
324    }
325
326    fn bron_kerbosch(
327        adj: &[HashSet<u64>],
328        current: &mut Vec<u64>,
329        mut candidates: HashSet<u64>,
330        mut excluded: HashSet<u64>,
331        k: usize,
332        cliques: &mut Vec<Vec<u64>>,
333    ) {
334        // Found a clique of size k
335        if current.len() == k {
336            cliques.push(current.clone());
337            return;
338        }
339
340        // Can't reach size k
341        if current.len() + candidates.len() < k {
342            return;
343        }
344
345        // No more candidates
346        if candidates.is_empty() {
347            return;
348        }
349
350        // Choose pivot (node with most connections to candidates)
351        let pivot = candidates
352            .iter()
353            .chain(excluded.iter())
354            .max_by_key(|&&v| adj[v as usize].intersection(&candidates).count())
355            .copied();
356
357        let pivot_neighbors = pivot.map(|p| adj[p as usize].clone()).unwrap_or_default();
358
359        let to_explore: Vec<u64> = candidates.difference(&pivot_neighbors).copied().collect();
360
361        for v in to_explore {
362            current.push(v);
363
364            let new_candidates: HashSet<u64> =
365                candidates.intersection(&adj[v as usize]).copied().collect();
366            let new_excluded: HashSet<u64> =
367                excluded.intersection(&adj[v as usize]).copied().collect();
368
369            Self::bron_kerbosch(adj, current, new_candidates, new_excluded, k, cliques);
370
371            current.pop();
372            candidates.remove(&v);
373            excluded.insert(v);
374        }
375    }
376
377    /// Count cliques of size k (more efficient than enumerating all).
378    pub fn count_cliques(graph: &CsrGraph, k: usize) -> u64 {
379        Self::find_cliques(graph, k).len() as u64
380    }
381}
382
383impl Default for KCliqueDetection {
384    fn default() -> Self {
385        Self::new()
386    }
387}
388
389impl GpuKernel for KCliqueDetection {
390    fn metadata(&self) -> &KernelMetadata {
391        &self.metadata
392    }
393}
394
#[cfg(test)]
mod tests {
    use super::*;

    /// Three mutually connected nodes (0, 1, 2); each edge is listed in both
    /// directions.
    fn create_triangle_graph() -> CsrGraph {
        CsrGraph::from_edges(3, &[(0, 1), (1, 0), (1, 2), (2, 1), (2, 0), (0, 2)])
    }

    /// Four-node cycle 0-1-2-3-0 (contains no triangle); each edge is listed
    /// in both directions.
    fn create_square_graph() -> CsrGraph {
        CsrGraph::from_edges(
            4,
            &[
                (0, 1),
                (1, 0),
                (1, 2),
                (2, 1),
                (2, 3),
                (3, 2),
                (3, 0),
                (0, 3),
            ],
        )
    }

    #[test]
    fn test_triangle_counting_metadata() {
        let kernel = TriangleCounting::new();
        let meta = kernel.metadata();
        assert_eq!(meta.id, "graph/triangle-counting");
        assert_eq!(meta.domain, Domain::GraphAnalytics);
    }

    #[test]
    fn test_triangle_counting() {
        let result = TriangleCounting::compute(&create_triangle_graph());

        assert_eq!(result.total_triangles, 1, "Expected 1 triangle");

        // Every node is a corner of the single triangle.
        assert!(result.per_node_triangles.iter().all(|&count| count == 1));

        // A complete graph has a global clustering coefficient of 1.0.
        let deviation = (result.global_clustering_coefficient - 1.0).abs();
        assert!(deviation < 0.01);
    }

    #[test]
    fn test_no_triangles() {
        let result = TriangleCounting::compute(&create_square_graph());

        assert_eq!(result.total_triangles, 0, "Expected 0 triangles in square");
        assert!(result.global_clustering_coefficient.abs() < 0.01);
    }

    #[test]
    fn test_triad_classification() {
        let graph = create_triangle_graph();

        assert_eq!(
            MotifDetection::classify_triad(&graph, [0, 1, 2]),
            TriadType::Triangle
        );
    }

    #[test]
    fn test_motif_detection() {
        let result = MotifDetection::count_triads(&create_triangle_graph());

        assert_eq!(result.motif_counts.get("triangles"), Some(&1));
    }

    #[test]
    fn test_k_clique_triangles() {
        let cliques = KCliqueDetection::find_cliques(&create_triangle_graph(), 3);

        // The triangle itself is the only 3-clique.
        assert_eq!(cliques.len(), 1);
    }

    #[test]
    fn test_k_clique_edges() {
        let cliques = KCliqueDetection::find_cliques(&create_square_graph(), 2);

        // The four sides of the square are its 2-cliques.
        assert_eq!(cliques.len(), 4);
    }
}
486}