rustkernel_graph/
motif.rs

//! Graph motif detection kernels.
//!
//! This module provides algorithms for detecting graph motifs:
//! - Triangle counting (local and global), including clustering coefficients
//! - Motif detection (3-node triad census: triangles, open wedges, edges)
//! - K-clique enumeration and counting
6
7use crate::types::CsrGraph;
8use rustkernel_core::{domain::Domain, kernel::KernelMetadata, traits::GpuKernel};
9use std::collections::HashSet;
10
11// ============================================================================
12// Triangle Counting Kernel
13// ============================================================================
14
/// Output of [`TriangleCounting::compute`].
///
/// The per-node vectors are indexed by node id and hold one entry per node.
#[derive(Debug, Clone)]
pub struct TriangleCountResult {
    /// Number of distinct triangles found in the whole graph.
    pub total_triangles: u64,
    /// For every node, how many triangles it is a corner of.
    pub per_node_triangles: Vec<u64>,
    /// Local clustering coefficient of every node (`0.0` for degree below 2).
    pub clustering_coefficients: Vec<f64>,
    /// Graph-wide clustering coefficient (transitivity).
    pub global_clustering_coefficient: f64,
}
27
28/// Triangle counting kernel.
29///
30/// Counts triangles using the node-iterator algorithm.
31/// Each triangle is counted once.
32#[derive(Debug, Clone)]
33pub struct TriangleCounting {
34    metadata: KernelMetadata,
35}
36
37impl TriangleCounting {
38    /// Create a new triangle counting kernel.
39    #[must_use]
40    pub fn new() -> Self {
41        Self {
42            metadata: KernelMetadata::batch("graph/triangle-counting", Domain::GraphAnalytics)
43                .with_description("Triangle counting (node-iterator algorithm)")
44                .with_throughput(50_000)
45                .with_latency_us(20.0),
46        }
47    }
48
49    /// Count all triangles in the graph.
50    ///
51    /// Uses the node-iterator algorithm with degree ordering
52    /// to ensure each triangle is counted exactly once.
53    pub fn compute(graph: &CsrGraph) -> TriangleCountResult {
54        let n = graph.num_nodes;
55        let mut total_triangles = 0u64;
56        let mut per_node_triangles = vec![0u64; n];
57
58        // For each node, look at pairs of its neighbors and check if they're connected
59        for u in 0..n {
60            let neighbors_u: HashSet<u64> = graph.neighbors(u as u64).iter().copied().collect();
61
62            for &v in graph.neighbors(u as u64) {
63                let v = v as usize;
64
65                // Only count once: process when u < v
66                if u >= v {
67                    continue;
68                }
69
70                // Check common neighbors (nodes that form triangles with u and v)
71                for &w in graph.neighbors(v as u64) {
72                    let w_usize = w as usize;
73
74                    // Only count once: ensure u < v < w (by node index)
75                    if v >= w_usize {
76                        continue;
77                    }
78
79                    if neighbors_u.contains(&w) {
80                        // Found triangle: u-v-w
81                        total_triangles += 1;
82                        per_node_triangles[u] += 1;
83                        per_node_triangles[v] += 1;
84                        per_node_triangles[w_usize] += 1;
85                    }
86                }
87            }
88        }
89
90        // Compute clustering coefficients
91        let mut clustering_coefficients = vec![0.0f64; n];
92        let mut total_possible = 0u64;
93        let mut total_actual = 0u64;
94
95        for i in 0..n {
96            let degree = graph.out_degree(i as u64);
97            if degree >= 2 {
98                let possible = degree * (degree - 1) / 2;
99                total_possible += possible;
100                total_actual += per_node_triangles[i];
101                clustering_coefficients[i] = per_node_triangles[i] as f64 / possible as f64;
102            }
103        }
104
105        let global_clustering_coefficient = if total_possible > 0 {
106            total_actual as f64 / total_possible as f64
107        } else {
108            0.0
109        };
110
111        TriangleCountResult {
112            total_triangles,
113            per_node_triangles,
114            clustering_coefficients,
115            global_clustering_coefficient,
116        }
117    }
118
119    /// Count triangles for a specific node.
120    pub fn count_node_triangles(graph: &CsrGraph, node: u64) -> u64 {
121        let neighbors: HashSet<u64> = graph.neighbors(node).iter().copied().collect();
122
123        let mut count = 0u64;
124
125        for &v in graph.neighbors(node) {
126            for &w in graph.neighbors(v) {
127                if w != node && neighbors.contains(&w) {
128                    count += 1;
129                }
130            }
131        }
132
133        // Each triangle is counted twice (once for each direction)
134        count / 2
135    }
136}
137
138impl Default for TriangleCounting {
139    fn default() -> Self {
140        Self::new()
141    }
142}
143
144impl GpuKernel for TriangleCounting {
145    fn metadata(&self) -> &KernelMetadata {
146        &self.metadata
147    }
148}
149
150// ============================================================================
151// Motif Detection Kernel
152// ============================================================================
153
/// Shape of a 3-node subgraph (triad), determined by how many of its three
/// node pairs are connected.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TriadType {
    /// None of the three pairs is connected.
    Empty,
    /// Exactly one pair is connected.
    Edge,
    /// Exactly two pairs are connected: an open path of length two (wedge).
    Wedge,
    /// All three pairs are connected: a closed triangle.
    Triangle,
}
166
167/// Result of motif detection.
168#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
169pub struct MotifResult {
170    /// Count of each motif type.
171    pub motif_counts: std::collections::HashMap<String, u64>,
172}
173
174/// Motif detection kernel.
175///
176/// Counts occurrences of small subgraph patterns.
177#[derive(Debug, Clone)]
178pub struct MotifDetection {
179    metadata: KernelMetadata,
180}
181
182impl MotifDetection {
183    /// Create a new motif detection kernel.
184    #[must_use]
185    pub fn new() -> Self {
186        Self {
187            metadata: KernelMetadata::batch("graph/motif-detection", Domain::GraphAnalytics)
188                .with_description("Motif detection (k-node subgraph census)")
189                .with_throughput(10_000)
190                .with_latency_us(100.0),
191        }
192    }
193
194    /// Count 3-node motifs (triads) in the graph.
195    pub fn count_triads(graph: &CsrGraph) -> MotifResult {
196        let n = graph.num_nodes;
197        let mut triangles = 0u64;
198        let mut wedges = 0u64;
199        let mut edges = 0u64;
200
201        // Count triangles and wedges
202        for u in 0..n {
203            let neighbors_u: HashSet<u64> = graph.neighbors(u as u64).iter().copied().collect();
204            let degree_u = neighbors_u.len();
205
206            // Wedges centered at u: C(degree, 2)
207            if degree_u >= 2 {
208                let potential_wedges = (degree_u * (degree_u - 1)) / 2;
209
210                // Count how many of these are actually triangles
211                let mut triangles_at_u = 0u64;
212                for &v in graph.neighbors(u as u64) {
213                    for &w in graph.neighbors(v) {
214                        if w != u as u64 && neighbors_u.contains(&w) && v < w {
215                            triangles_at_u += 1;
216                        }
217                    }
218                }
219
220                wedges += potential_wedges as u64 - triangles_at_u;
221                triangles += triangles_at_u;
222            }
223        }
224
225        // Triangles were counted 3 times (once per vertex)
226        triangles /= 3;
227
228        // Edges count
229        edges = graph.num_edges as u64 / 2; // Undirected edges
230
231        let mut motif_counts = std::collections::HashMap::new();
232        motif_counts.insert("triangles".to_string(), triangles);
233        motif_counts.insert("wedges".to_string(), wedges);
234        motif_counts.insert("edges".to_string(), edges);
235
236        MotifResult { motif_counts }
237    }
238
239    /// Classify a triad (set of 3 nodes) by its structure.
240    pub fn classify_triad(graph: &CsrGraph, nodes: [u64; 3]) -> TriadType {
241        let [a, b, c] = nodes;
242
243        let neighbors_a: HashSet<u64> = graph.neighbors(a).iter().copied().collect();
244        let neighbors_b: HashSet<u64> = graph.neighbors(b).iter().copied().collect();
245
246        let ab = neighbors_a.contains(&b);
247        let ac = neighbors_a.contains(&c);
248        let bc = neighbors_b.contains(&c);
249
250        let edge_count = ab as u8 + ac as u8 + bc as u8;
251
252        match edge_count {
253            0 => TriadType::Empty,
254            1 => TriadType::Edge,
255            2 => TriadType::Wedge,
256            3 => TriadType::Triangle,
257            _ => unreachable!(),
258        }
259    }
260}
261
262impl Default for MotifDetection {
263    fn default() -> Self {
264        Self::new()
265    }
266}
267
268impl GpuKernel for MotifDetection {
269    fn metadata(&self) -> &KernelMetadata {
270        &self.metadata
271    }
272}
273
274// ============================================================================
275// K-Clique Detection Kernel
276// ============================================================================
277
278/// K-clique detection kernel.
279///
280/// Finds all cliques of size k in the graph.
281#[derive(Debug, Clone)]
282pub struct KCliqueDetection {
283    metadata: KernelMetadata,
284}
285
286impl KCliqueDetection {
287    /// Create a new k-clique detection kernel.
288    #[must_use]
289    pub fn new() -> Self {
290        Self {
291            metadata: KernelMetadata::batch("graph/k-clique", Domain::GraphAnalytics)
292                .with_description("K-clique detection")
293                .with_throughput(1_000)
294                .with_latency_us(1000.0),
295        }
296    }
297
298    /// Find all cliques of size k.
299    ///
300    /// Uses Bron-Kerbosch algorithm with pivoting.
301    pub fn find_cliques(graph: &CsrGraph, k: usize) -> Vec<Vec<u64>> {
302        let n = graph.num_nodes;
303        let mut cliques = Vec::new();
304
305        // Build adjacency set for each node
306        let adj: Vec<HashSet<u64>> = (0..n)
307            .map(|i| graph.neighbors(i as u64).iter().copied().collect())
308            .collect();
309
310        // Bron-Kerbosch with size limit
311        let mut current_clique = Vec::new();
312        let candidates: HashSet<u64> = (0..n as u64).collect();
313        let excluded: HashSet<u64> = HashSet::new();
314
315        Self::bron_kerbosch(
316            &adj,
317            &mut current_clique,
318            candidates,
319            excluded,
320            k,
321            &mut cliques,
322        );
323
324        cliques
325    }
326
327    fn bron_kerbosch(
328        adj: &[HashSet<u64>],
329        current: &mut Vec<u64>,
330        mut candidates: HashSet<u64>,
331        mut excluded: HashSet<u64>,
332        k: usize,
333        cliques: &mut Vec<Vec<u64>>,
334    ) {
335        // Found a clique of size k
336        if current.len() == k {
337            cliques.push(current.clone());
338            return;
339        }
340
341        // Can't reach size k
342        if current.len() + candidates.len() < k {
343            return;
344        }
345
346        // No more candidates
347        if candidates.is_empty() {
348            return;
349        }
350
351        // Choose pivot (node with most connections to candidates)
352        let pivot = candidates
353            .iter()
354            .chain(excluded.iter())
355            .max_by_key(|&&v| adj[v as usize].intersection(&candidates).count())
356            .copied();
357
358        let pivot_neighbors = pivot.map(|p| adj[p as usize].clone()).unwrap_or_default();
359
360        let to_explore: Vec<u64> = candidates.difference(&pivot_neighbors).copied().collect();
361
362        for v in to_explore {
363            current.push(v);
364
365            let new_candidates: HashSet<u64> =
366                candidates.intersection(&adj[v as usize]).copied().collect();
367            let new_excluded: HashSet<u64> =
368                excluded.intersection(&adj[v as usize]).copied().collect();
369
370            Self::bron_kerbosch(adj, current, new_candidates, new_excluded, k, cliques);
371
372            current.pop();
373            candidates.remove(&v);
374            excluded.insert(v);
375        }
376    }
377
378    /// Count cliques of size k (more efficient than enumerating all).
379    pub fn count_cliques(graph: &CsrGraph, k: usize) -> u64 {
380        Self::find_cliques(graph, k).len() as u64
381    }
382}
383
384impl Default for KCliqueDetection {
385    fn default() -> Self {
386        Self::new()
387    }
388}
389
390impl GpuKernel for KCliqueDetection {
391    fn metadata(&self) -> &KernelMetadata {
392        &self.metadata
393    }
394}
395
#[cfg(test)]
mod tests {
    use super::*;

    /// Three nodes forming one triangle 0-1-2-0 (both directions stored).
    fn create_triangle_graph() -> CsrGraph {
        CsrGraph::from_edges(3, &[(0, 1), (1, 0), (1, 2), (2, 1), (2, 0), (0, 2)])
    }

    /// Four-cycle 0-1-2-3-0: every node has degree 2, but there are no triangles.
    fn create_square_graph() -> CsrGraph {
        CsrGraph::from_edges(
            4,
            &[
                (0, 1),
                (1, 0),
                (1, 2),
                (2, 1),
                (2, 3),
                (3, 2),
                (3, 0),
                (0, 3),
            ],
        )
    }

    #[test]
    fn test_triangle_counting_metadata() {
        let kernel = TriangleCounting::new();
        let meta = kernel.metadata();
        assert_eq!(meta.id, "graph/triangle-counting");
        assert_eq!(meta.domain, Domain::GraphAnalytics);
    }

    #[test]
    fn test_triangle_counting() {
        let result = TriangleCounting::compute(&create_triangle_graph());

        assert_eq!(result.total_triangles, 1, "Expected 1 triangle");
        // Every node is a corner of the single triangle.
        assert!(result.per_node_triangles.iter().all(|&count| count == 1));
        // A complete graph is perfectly clustered.
        assert!((result.global_clustering_coefficient - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_no_triangles() {
        let result = TriangleCounting::compute(&create_square_graph());

        assert_eq!(result.total_triangles, 0, "Expected 0 triangles in square");
        assert!(result.global_clustering_coefficient.abs() < 0.01);
    }

    #[test]
    fn test_triad_classification() {
        let graph = create_triangle_graph();
        assert_eq!(
            MotifDetection::classify_triad(&graph, [0, 1, 2]),
            TriadType::Triangle
        );
    }

    #[test]
    fn test_motif_detection() {
        let result = MotifDetection::count_triads(&create_triangle_graph());
        assert_eq!(result.motif_counts.get("triangles"), Some(&1));
    }

    #[test]
    fn test_k_clique_triangles() {
        // The only 3-clique is the triangle itself.
        let cliques = KCliqueDetection::find_cliques(&create_triangle_graph(), 3);
        assert_eq!(cliques.len(), 1);
    }

    #[test]
    fn test_k_clique_edges() {
        // Each of the four sides of the square is a 2-clique.
        let cliques = KCliqueDetection::find_cliques(&create_square_graph(), 2);
        assert_eq!(cliques.len(), 4);
    }
}