Skip to main content

datasynth_eval/ml/
graph.rs

1//! Graph structure analysis for ML.
2//!
3//! Analyzes graph properties relevant for graph neural networks.
4
5use crate::error::EvalResult;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9/// Results of graph analysis.
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct GraphAnalysis {
12    /// Basic graph metrics.
13    pub metrics: GraphMetrics,
14    /// Degree distribution analysis.
15    pub degree_distribution: DegreeDistribution,
16    /// Node type balance.
17    pub node_type_balance: HashMap<String, f64>,
18    /// Edge type balance.
19    pub edge_type_balance: HashMap<String, f64>,
20    /// Connectivity score (0.0-1.0).
21    pub connectivity_score: f64,
22    /// Whether graph meets quality criteria.
23    pub is_valid: bool,
24    /// Issues found.
25    pub issues: Vec<String>,
26}
27
28/// Basic graph metrics.
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct GraphMetrics {
31    /// Number of nodes.
32    pub node_count: usize,
33    /// Number of edges.
34    pub edge_count: usize,
35    /// Graph density (edges / max_possible_edges).
36    pub density: f64,
37    /// Number of connected components.
38    pub connected_components: usize,
39    /// Size of largest connected component.
40    pub largest_component_size: usize,
41    /// Percentage of nodes in largest component.
42    pub largest_component_ratio: f64,
43    /// Average degree.
44    pub average_degree: f64,
45    /// Maximum degree.
46    pub max_degree: usize,
47    /// Number of isolated nodes (degree 0).
48    pub isolated_nodes: usize,
49}
50
51/// Degree distribution analysis.
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct DegreeDistribution {
54    /// Degree histogram (degree -> count).
55    pub histogram: HashMap<usize, usize>,
56    /// Mean degree.
57    pub mean: f64,
58    /// Median degree.
59    pub median: f64,
60    /// Standard deviation.
61    pub std_dev: f64,
62    /// Whether distribution follows power law.
63    pub is_power_law: bool,
64    /// Power law exponent (if applicable).
65    pub power_law_exponent: Option<f64>,
66}
67
/// Input for graph analysis: an edge list over the node ids `0..node_count`, plus optional
/// type labels for nodes and edges.
#[derive(Clone, Debug)]
pub struct GraphData {
    /// Number of nodes; valid node ids are `0..node_count`.
    pub node_count: usize,
    /// Edge list as `(source, target)` node-id pairs.
    pub edges: Vec<(usize, usize)>,
    /// Node type labels keyed by node id (may be empty).
    pub node_types: HashMap<usize, String>,
    /// Edge type labels keyed by position in `edges` (may be empty).
    pub edge_types: HashMap<usize, String>,
    /// `true` for a directed graph, `false` for an undirected one.
    pub is_directed: bool,
}

impl Default for GraphData {
    /// An empty directed graph: no nodes, no edges, no type labels.
    fn default() -> Self {
        let is_directed = true;
        Self {
            is_directed,
            node_count: 0,
            edges: Vec::default(),
            node_types: HashMap::default(),
            edge_types: HashMap::default(),
        }
    }
}
94
/// Evaluates [`GraphData`] structure against connectivity and isolation thresholds.
pub struct GraphAnalyzer {
    /// Smallest acceptable fraction of nodes inside the largest connected component.
    min_connectivity: f64,
    /// Largest acceptable fraction of isolated nodes.
    max_isolated_ratio: f64,
}
102
103impl GraphAnalyzer {
104    /// Create a new analyzer.
105    pub fn new() -> Self {
106        Self {
107            min_connectivity: 0.95,
108            max_isolated_ratio: 0.05,
109        }
110    }
111
112    /// Analyze graph structure.
113    pub fn analyze(&self, data: &GraphData) -> EvalResult<GraphAnalysis> {
114        let mut issues = Vec::new();
115
116        if data.node_count == 0 {
117            return Ok(GraphAnalysis {
118                metrics: GraphMetrics {
119                    node_count: 0,
120                    edge_count: 0,
121                    density: 0.0,
122                    connected_components: 0,
123                    largest_component_size: 0,
124                    largest_component_ratio: 1.0,
125                    average_degree: 0.0,
126                    max_degree: 0,
127                    isolated_nodes: 0,
128                },
129                degree_distribution: DegreeDistribution {
130                    histogram: HashMap::new(),
131                    mean: 0.0,
132                    median: 0.0,
133                    std_dev: 0.0,
134                    is_power_law: false,
135                    power_law_exponent: None,
136                },
137                node_type_balance: HashMap::new(),
138                edge_type_balance: HashMap::new(),
139                connectivity_score: 1.0,
140                is_valid: true,
141                issues: vec![],
142            });
143        }
144
145        // Calculate degree for each node
146        let mut degrees: Vec<usize> = vec![0; data.node_count];
147        for (src, tgt) in &data.edges {
148            if *src < data.node_count {
149                degrees[*src] += 1;
150            }
151            if !data.is_directed && *tgt < data.node_count {
152                degrees[*tgt] += 1;
153            }
154        }
155
156        // Calculate metrics
157        let edge_count = data.edges.len();
158        let max_edges = if data.is_directed {
159            data.node_count * (data.node_count - 1)
160        } else {
161            data.node_count * (data.node_count - 1) / 2
162        };
163        let density = if max_edges > 0 {
164            edge_count as f64 / max_edges as f64
165        } else {
166            0.0
167        };
168
169        let average_degree = if data.node_count > 0 {
170            degrees.iter().sum::<usize>() as f64 / data.node_count as f64
171        } else {
172            0.0
173        };
174
175        let max_degree = degrees.iter().max().copied().unwrap_or(0);
176        let isolated_nodes = degrees.iter().filter(|d| **d == 0).count();
177
178        // Find connected components using union-find
179        let (connected_components, component_sizes) = self.find_components(data);
180        let largest_component_size = component_sizes.iter().max().copied().unwrap_or(0);
181        let largest_component_ratio = if data.node_count > 0 {
182            largest_component_size as f64 / data.node_count as f64
183        } else {
184            1.0
185        };
186
187        let connectivity_score = largest_component_ratio;
188
189        // Calculate degree distribution
190        let degree_distribution = self.calculate_degree_distribution(&degrees);
191
192        // Calculate node/edge type balance
193        let node_type_balance = self.calculate_type_balance(&data.node_types, data.node_count);
194        let edge_type_balance = self.calculate_type_balance_usize(&data.edge_types, edge_count);
195
196        let metrics = GraphMetrics {
197            node_count: data.node_count,
198            edge_count,
199            density,
200            connected_components,
201            largest_component_size,
202            largest_component_ratio,
203            average_degree,
204            max_degree,
205            isolated_nodes,
206        };
207
208        // Check for issues
209        if connectivity_score < self.min_connectivity {
210            issues.push(format!(
211                "Low connectivity: {:.2}% of nodes in largest component",
212                connectivity_score * 100.0
213            ));
214        }
215
216        let isolated_ratio = if data.node_count > 0 {
217            isolated_nodes as f64 / data.node_count as f64
218        } else {
219            0.0
220        };
221        if isolated_ratio > self.max_isolated_ratio {
222            issues.push(format!(
223                "High isolated node ratio: {:.2}%",
224                isolated_ratio * 100.0
225            ));
226        }
227
228        if connected_components > 1 {
229            issues.push(format!(
230                "Graph has {connected_components} connected components"
231            ));
232        }
233
234        let is_valid = connectivity_score >= self.min_connectivity
235            && isolated_ratio <= self.max_isolated_ratio;
236
237        Ok(GraphAnalysis {
238            metrics,
239            degree_distribution,
240            node_type_balance,
241            edge_type_balance,
242            connectivity_score,
243            is_valid,
244            issues,
245        })
246    }
247
248    /// Find connected components using union-find.
249    fn find_components(&self, data: &GraphData) -> (usize, Vec<usize>) {
250        let mut parent: Vec<usize> = (0..data.node_count).collect();
251        let mut rank: Vec<usize> = vec![0; data.node_count];
252
253        fn find(parent: &mut [usize], x: usize) -> usize {
254            if parent[x] != x {
255                parent[x] = find(parent, parent[x]);
256            }
257            parent[x]
258        }
259
260        fn union(parent: &mut [usize], rank: &mut [usize], x: usize, y: usize) {
261            let px = find(parent, x);
262            let py = find(parent, y);
263            if px != py {
264                if rank[px] < rank[py] {
265                    parent[px] = py;
266                } else if rank[px] > rank[py] {
267                    parent[py] = px;
268                } else {
269                    parent[py] = px;
270                    rank[px] += 1;
271                }
272            }
273        }
274
275        for (src, tgt) in &data.edges {
276            if *src < data.node_count && *tgt < data.node_count {
277                union(&mut parent, &mut rank, *src, *tgt);
278            }
279        }
280
281        // Count components and their sizes
282        let mut component_sizes: HashMap<usize, usize> = HashMap::new();
283        for i in 0..data.node_count {
284            let root = find(&mut parent, i);
285            *component_sizes.entry(root).or_insert(0) += 1;
286        }
287
288        let num_components = component_sizes.len();
289        let sizes: Vec<usize> = component_sizes.values().copied().collect();
290
291        (num_components, sizes)
292    }
293
294    /// Calculate degree distribution statistics.
295    fn calculate_degree_distribution(&self, degrees: &[usize]) -> DegreeDistribution {
296        if degrees.is_empty() {
297            return DegreeDistribution {
298                histogram: HashMap::new(),
299                mean: 0.0,
300                median: 0.0,
301                std_dev: 0.0,
302                is_power_law: false,
303                power_law_exponent: None,
304            };
305        }
306
307        // Build histogram
308        let mut histogram: HashMap<usize, usize> = HashMap::new();
309        for &d in degrees {
310            *histogram.entry(d).or_insert(0) += 1;
311        }
312
313        // Calculate statistics
314        let n = degrees.len() as f64;
315        let mean = degrees.iter().sum::<usize>() as f64 / n;
316
317        let mut sorted = degrees.to_vec();
318        sorted.sort_unstable();
319        let median = if sorted.len().is_multiple_of(2) {
320            (sorted[sorted.len() / 2 - 1] + sorted[sorted.len() / 2]) as f64 / 2.0
321        } else {
322            sorted[sorted.len() / 2] as f64
323        };
324
325        let variance: f64 = degrees
326            .iter()
327            .map(|&d| (d as f64 - mean).powi(2))
328            .sum::<f64>()
329            / n;
330        let std_dev = variance.sqrt();
331
332        // Simple power-law check: log-log linear relationship
333        // This is a simplified heuristic
334        let non_zero_degrees: Vec<_> = degrees.iter().filter(|&&d| d > 0).collect();
335        let is_power_law = if non_zero_degrees.len() > 10 && std_dev > mean {
336            // High variance relative to mean suggests heavy tail
337            true
338        } else {
339            false
340        };
341
342        let power_law_exponent = if is_power_law {
343            // Simplified estimate using Hill estimator
344            let k_min = 1.0;
345            let valid: Vec<f64> = non_zero_degrees
346                .iter()
347                .filter(|&&d| (*d as f64) >= k_min)
348                .map(|&&d| d as f64)
349                .collect();
350            if valid.len() > 2 {
351                let n = valid.len() as f64;
352                let sum_log: f64 = valid.iter().map(|&x| (x / k_min).ln()).sum();
353                Some(1.0 + n / sum_log)
354            } else {
355                None
356            }
357        } else {
358            None
359        };
360
361        DegreeDistribution {
362            histogram,
363            mean,
364            median,
365            std_dev,
366            is_power_law,
367            power_law_exponent,
368        }
369    }
370
371    /// Calculate type balance (usize keys).
372    fn calculate_type_balance(
373        &self,
374        types: &HashMap<usize, String>,
375        total: usize,
376    ) -> HashMap<String, f64> {
377        let mut counts: HashMap<String, usize> = HashMap::new();
378        for t in types.values() {
379            *counts.entry(t.clone()).or_insert(0) += 1;
380        }
381
382        counts
383            .into_iter()
384            .map(|(k, v)| {
385                (
386                    k,
387                    if total > 0 {
388                        v as f64 / total as f64
389                    } else {
390                        0.0
391                    },
392                )
393            })
394            .collect()
395    }
396
397    /// Calculate type balance for edge types.
398    fn calculate_type_balance_usize(
399        &self,
400        types: &HashMap<usize, String>,
401        total: usize,
402    ) -> HashMap<String, f64> {
403        self.calculate_type_balance(types, total)
404    }
405}
406
407impl Default for GraphAnalyzer {
408    fn default() -> Self {
409        Self::new()
410    }
411}
412
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Build a graph without type labels from an edge list.
    fn graph(node_count: usize, edges: Vec<(usize, usize)>, is_directed: bool) -> GraphData {
        GraphData {
            node_count,
            edges,
            node_types: HashMap::new(),
            edge_types: HashMap::new(),
            is_directed,
        }
    }

    #[test]
    fn test_connected_graph() {
        // Undirected path 0-1-2-3: degrees [1, 2, 2, 1].
        let data = graph(4, vec![(0, 1), (1, 2), (2, 3)], false);
        let result = GraphAnalyzer::new().analyze(&data).unwrap();

        assert_eq!(result.metrics.connected_components, 1);
        assert_eq!(result.metrics.largest_component_size, 4);
        assert_eq!(result.metrics.largest_component_ratio, 1.0);
        assert_eq!(result.metrics.isolated_nodes, 0);
        assert_eq!(result.metrics.max_degree, 2);
        assert_eq!(result.metrics.average_degree, 1.5);
        assert_eq!(result.metrics.density, 0.5);
        assert_eq!(result.degree_distribution.median, 1.5);
        assert!(result.is_valid);
        assert!(result.issues.is_empty());
    }

    #[test]
    fn test_disconnected_graph() {
        // Components {0, 1}, {2}, {3}; nodes 2 and 3 have no edges.
        let data = graph(4, vec![(0, 1)], false);
        let result = GraphAnalyzer::new().analyze(&data).unwrap();

        assert_eq!(result.metrics.connected_components, 3);
        assert_eq!(result.metrics.largest_component_size, 2);
        assert_eq!(result.metrics.isolated_nodes, 2);
        assert_eq!(result.connectivity_score, 0.5);
        assert!(!result.is_valid);
        // Low connectivity, high isolated ratio, and multiple components.
        assert_eq!(result.issues.len(), 3);
    }

    #[test]
    fn test_empty_graph() {
        let result = GraphAnalyzer::new().analyze(&GraphData::default()).unwrap();

        assert_eq!(result.metrics.node_count, 0);
        assert_eq!(result.metrics.connected_components, 0);
        assert_eq!(result.connectivity_score, 1.0);
        assert!(result.is_valid);
        assert!(result.issues.is_empty());
    }

    #[test]
    fn test_degree_distribution() {
        // Directed star out of node 0 plus edge 1 -> 2: out-degrees [4, 1, 0, 0, 0].
        let data = graph(5, vec![(0, 1), (0, 2), (0, 3), (0, 4), (1, 2)], true);
        let result = GraphAnalyzer::new().analyze(&data).unwrap();

        assert_eq!(result.metrics.max_degree, 4);
        assert_eq!(result.metrics.connected_components, 1);
        assert_eq!(result.metrics.density, 0.25);
        assert_eq!(result.degree_distribution.mean, 1.0);
        assert_eq!(result.degree_distribution.median, 0.0);
        assert_eq!(result.degree_distribution.histogram.get(&0), Some(&3));
        assert_eq!(result.degree_distribution.histogram.get(&4), Some(&1));
        assert!(!result.degree_distribution.is_power_law);
        assert!(result.degree_distribution.power_law_exponent.is_none());
    }

    #[test]
    fn test_type_balance() {
        let mut data = graph(4, vec![(0, 1), (1, 2), (2, 3)], false);
        data.node_types = (0..4)
            .map(|i| {
                let label = if i < 2 { "a" } else { "b" };
                (i, label.to_string())
            })
            .collect();
        data.edge_types = [(0, "x"), (1, "y"), (2, "x")]
            .into_iter()
            .map(|(i, t)| (i, t.to_string()))
            .collect();
        let result = GraphAnalyzer::new().analyze(&data).unwrap();

        assert_eq!(result.node_type_balance.get("a"), Some(&0.5));
        assert_eq!(result.node_type_balance.get("b"), Some(&0.5));
        let x_share = result.edge_type_balance.get("x").copied().unwrap();
        let y_share = result.edge_type_balance.get("y").copied().unwrap();
        assert!((x_share - 2.0 / 3.0).abs() < 1e-12);
        assert!((y_share - 1.0 / 3.0).abs() < 1e-12);
    }
}