Skip to main content

datasynth_eval/ml/
graph.rs

//! Graph structure analysis for ML.
//!
//! Analyzes graph properties relevant to graph neural networks: connectivity,
//! degree distribution, and node/edge type balance.
4
5use crate::error::EvalResult;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9/// Results of graph analysis.
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct GraphAnalysis {
12    /// Basic graph metrics.
13    pub metrics: GraphMetrics,
14    /// Degree distribution analysis.
15    pub degree_distribution: DegreeDistribution,
16    /// Node type balance.
17    pub node_type_balance: HashMap<String, f64>,
18    /// Edge type balance.
19    pub edge_type_balance: HashMap<String, f64>,
20    /// Connectivity score (0.0-1.0).
21    pub connectivity_score: f64,
22    /// Whether graph meets quality criteria.
23    pub is_valid: bool,
24    /// Issues found.
25    pub issues: Vec<String>,
26}
27
28/// Basic graph metrics.
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct GraphMetrics {
31    /// Number of nodes.
32    pub node_count: usize,
33    /// Number of edges.
34    pub edge_count: usize,
35    /// Graph density (edges / max_possible_edges).
36    pub density: f64,
37    /// Number of connected components.
38    pub connected_components: usize,
39    /// Size of largest connected component.
40    pub largest_component_size: usize,
41    /// Percentage of nodes in largest component.
42    pub largest_component_ratio: f64,
43    /// Average degree.
44    pub average_degree: f64,
45    /// Maximum degree.
46    pub max_degree: usize,
47    /// Number of isolated nodes (degree 0).
48    pub isolated_nodes: usize,
49}
50
51/// Degree distribution analysis.
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct DegreeDistribution {
54    /// Degree histogram (degree -> count).
55    pub histogram: HashMap<usize, usize>,
56    /// Mean degree.
57    pub mean: f64,
58    /// Median degree.
59    pub median: f64,
60    /// Standard deviation.
61    pub std_dev: f64,
62    /// Whether distribution follows power law.
63    pub is_power_law: bool,
64    /// Power law exponent (if applicable).
65    pub power_law_exponent: Option<f64>,
66}
67
/// Graph description consumed by [`GraphAnalyzer::analyze`].
///
/// Nodes are identified by the indices `0..node_count`.
#[derive(Clone, Debug)]
pub struct GraphData {
    /// Total number of nodes; valid node ids are `0..node_count`.
    pub node_count: usize,
    /// Edges as `(source, target)` node-id pairs.
    pub edges: Vec<(usize, usize)>,
    /// Type label per node, keyed by node id (unlabeled nodes are absent).
    pub node_types: HashMap<usize, String>,
    /// Type label per edge, keyed by the edge's index in `edges`.
    pub edge_types: HashMap<usize, String>,
    /// `true` for a directed graph, `false` for an undirected one.
    pub is_directed: bool,
}
82
83impl Default for GraphData {
84    fn default() -> Self {
85        Self {
86            node_count: 0,
87            edges: Vec::new(),
88            node_types: HashMap::new(),
89            edge_types: HashMap::new(),
90            is_directed: true,
91        }
92    }
93}
94
/// Analyzer for graph structure.
///
/// Holds the thresholds that [`GraphAnalyzer::analyze`] uses to decide whether
/// a graph is valid.
#[derive(Debug, Clone, PartialEq)]
pub struct GraphAnalyzer {
    /// Minimum fraction (0.0-1.0) of nodes that must lie in the largest
    /// connected component.
    min_connectivity: f64,
    /// Maximum fraction (0.0-1.0) of nodes allowed to be isolated.
    max_isolated_ratio: f64,
}
102
103impl GraphAnalyzer {
104    /// Create a new analyzer.
105    pub fn new() -> Self {
106        Self {
107            min_connectivity: 0.95,
108            max_isolated_ratio: 0.05,
109        }
110    }
111
112    /// Analyze graph structure.
113    pub fn analyze(&self, data: &GraphData) -> EvalResult<GraphAnalysis> {
114        let mut issues = Vec::new();
115
116        if data.node_count == 0 {
117            return Ok(GraphAnalysis {
118                metrics: GraphMetrics {
119                    node_count: 0,
120                    edge_count: 0,
121                    density: 0.0,
122                    connected_components: 0,
123                    largest_component_size: 0,
124                    largest_component_ratio: 1.0,
125                    average_degree: 0.0,
126                    max_degree: 0,
127                    isolated_nodes: 0,
128                },
129                degree_distribution: DegreeDistribution {
130                    histogram: HashMap::new(),
131                    mean: 0.0,
132                    median: 0.0,
133                    std_dev: 0.0,
134                    is_power_law: false,
135                    power_law_exponent: None,
136                },
137                node_type_balance: HashMap::new(),
138                edge_type_balance: HashMap::new(),
139                connectivity_score: 1.0,
140                is_valid: true,
141                issues: vec![],
142            });
143        }
144
145        // Calculate degree for each node
146        let mut degrees: Vec<usize> = vec![0; data.node_count];
147        for (src, tgt) in &data.edges {
148            if *src < data.node_count {
149                degrees[*src] += 1;
150            }
151            if !data.is_directed && *tgt < data.node_count {
152                degrees[*tgt] += 1;
153            }
154        }
155
156        // Calculate metrics
157        let edge_count = data.edges.len();
158        let max_edges = if data.is_directed {
159            data.node_count * (data.node_count - 1)
160        } else {
161            data.node_count * (data.node_count - 1) / 2
162        };
163        let density = if max_edges > 0 {
164            edge_count as f64 / max_edges as f64
165        } else {
166            0.0
167        };
168
169        let average_degree = if data.node_count > 0 {
170            degrees.iter().sum::<usize>() as f64 / data.node_count as f64
171        } else {
172            0.0
173        };
174
175        let max_degree = degrees.iter().max().copied().unwrap_or(0);
176        let isolated_nodes = degrees.iter().filter(|d| **d == 0).count();
177
178        // Find connected components using union-find
179        let (connected_components, component_sizes) = self.find_components(data);
180        let largest_component_size = component_sizes.iter().max().copied().unwrap_or(0);
181        let largest_component_ratio = if data.node_count > 0 {
182            largest_component_size as f64 / data.node_count as f64
183        } else {
184            1.0
185        };
186
187        let connectivity_score = largest_component_ratio;
188
189        // Calculate degree distribution
190        let degree_distribution = self.calculate_degree_distribution(&degrees);
191
192        // Calculate node/edge type balance
193        let node_type_balance = self.calculate_type_balance(&data.node_types, data.node_count);
194        let edge_type_balance = self.calculate_type_balance_usize(&data.edge_types, edge_count);
195
196        let metrics = GraphMetrics {
197            node_count: data.node_count,
198            edge_count,
199            density,
200            connected_components,
201            largest_component_size,
202            largest_component_ratio,
203            average_degree,
204            max_degree,
205            isolated_nodes,
206        };
207
208        // Check for issues
209        if connectivity_score < self.min_connectivity {
210            issues.push(format!(
211                "Low connectivity: {:.2}% of nodes in largest component",
212                connectivity_score * 100.0
213            ));
214        }
215
216        let isolated_ratio = if data.node_count > 0 {
217            isolated_nodes as f64 / data.node_count as f64
218        } else {
219            0.0
220        };
221        if isolated_ratio > self.max_isolated_ratio {
222            issues.push(format!(
223                "High isolated node ratio: {:.2}%",
224                isolated_ratio * 100.0
225            ));
226        }
227
228        if connected_components > 1 {
229            issues.push(format!(
230                "Graph has {} connected components",
231                connected_components
232            ));
233        }
234
235        let is_valid = connectivity_score >= self.min_connectivity
236            && isolated_ratio <= self.max_isolated_ratio;
237
238        Ok(GraphAnalysis {
239            metrics,
240            degree_distribution,
241            node_type_balance,
242            edge_type_balance,
243            connectivity_score,
244            is_valid,
245            issues,
246        })
247    }
248
249    /// Find connected components using union-find.
250    fn find_components(&self, data: &GraphData) -> (usize, Vec<usize>) {
251        let mut parent: Vec<usize> = (0..data.node_count).collect();
252        let mut rank: Vec<usize> = vec![0; data.node_count];
253
254        fn find(parent: &mut [usize], x: usize) -> usize {
255            if parent[x] != x {
256                parent[x] = find(parent, parent[x]);
257            }
258            parent[x]
259        }
260
261        fn union(parent: &mut [usize], rank: &mut [usize], x: usize, y: usize) {
262            let px = find(parent, x);
263            let py = find(parent, y);
264            if px != py {
265                if rank[px] < rank[py] {
266                    parent[px] = py;
267                } else if rank[px] > rank[py] {
268                    parent[py] = px;
269                } else {
270                    parent[py] = px;
271                    rank[px] += 1;
272                }
273            }
274        }
275
276        for (src, tgt) in &data.edges {
277            if *src < data.node_count && *tgt < data.node_count {
278                union(&mut parent, &mut rank, *src, *tgt);
279            }
280        }
281
282        // Count components and their sizes
283        let mut component_sizes: HashMap<usize, usize> = HashMap::new();
284        for i in 0..data.node_count {
285            let root = find(&mut parent, i);
286            *component_sizes.entry(root).or_insert(0) += 1;
287        }
288
289        let num_components = component_sizes.len();
290        let sizes: Vec<usize> = component_sizes.values().copied().collect();
291
292        (num_components, sizes)
293    }
294
295    /// Calculate degree distribution statistics.
296    fn calculate_degree_distribution(&self, degrees: &[usize]) -> DegreeDistribution {
297        if degrees.is_empty() {
298            return DegreeDistribution {
299                histogram: HashMap::new(),
300                mean: 0.0,
301                median: 0.0,
302                std_dev: 0.0,
303                is_power_law: false,
304                power_law_exponent: None,
305            };
306        }
307
308        // Build histogram
309        let mut histogram: HashMap<usize, usize> = HashMap::new();
310        for &d in degrees {
311            *histogram.entry(d).or_insert(0) += 1;
312        }
313
314        // Calculate statistics
315        let n = degrees.len() as f64;
316        let mean = degrees.iter().sum::<usize>() as f64 / n;
317
318        let mut sorted = degrees.to_vec();
319        sorted.sort_unstable();
320        let median = if sorted.len().is_multiple_of(2) {
321            (sorted[sorted.len() / 2 - 1] + sorted[sorted.len() / 2]) as f64 / 2.0
322        } else {
323            sorted[sorted.len() / 2] as f64
324        };
325
326        let variance: f64 = degrees
327            .iter()
328            .map(|&d| (d as f64 - mean).powi(2))
329            .sum::<f64>()
330            / n;
331        let std_dev = variance.sqrt();
332
333        // Simple power-law check: log-log linear relationship
334        // This is a simplified heuristic
335        let non_zero_degrees: Vec<_> = degrees.iter().filter(|&&d| d > 0).collect();
336        let is_power_law = if non_zero_degrees.len() > 10 && std_dev > mean {
337            // High variance relative to mean suggests heavy tail
338            true
339        } else {
340            false
341        };
342
343        let power_law_exponent = if is_power_law {
344            // Simplified estimate using Hill estimator
345            let k_min = 1.0;
346            let valid: Vec<f64> = non_zero_degrees
347                .iter()
348                .filter(|&&d| (*d as f64) >= k_min)
349                .map(|&&d| d as f64)
350                .collect();
351            if valid.len() > 2 {
352                let n = valid.len() as f64;
353                let sum_log: f64 = valid.iter().map(|&x| (x / k_min).ln()).sum();
354                Some(1.0 + n / sum_log)
355            } else {
356                None
357            }
358        } else {
359            None
360        };
361
362        DegreeDistribution {
363            histogram,
364            mean,
365            median,
366            std_dev,
367            is_power_law,
368            power_law_exponent,
369        }
370    }
371
372    /// Calculate type balance (usize keys).
373    fn calculate_type_balance(
374        &self,
375        types: &HashMap<usize, String>,
376        total: usize,
377    ) -> HashMap<String, f64> {
378        let mut counts: HashMap<String, usize> = HashMap::new();
379        for t in types.values() {
380            *counts.entry(t.clone()).or_insert(0) += 1;
381        }
382
383        counts
384            .into_iter()
385            .map(|(k, v)| {
386                (
387                    k,
388                    if total > 0 {
389                        v as f64 / total as f64
390                    } else {
391                        0.0
392                    },
393                )
394            })
395            .collect()
396    }
397
398    /// Calculate type balance for edge types.
399    fn calculate_type_balance_usize(
400        &self,
401        types: &HashMap<usize, String>,
402        total: usize,
403    ) -> HashMap<String, f64> {
404        self.calculate_type_balance(types, total)
405    }
406}
407
408impl Default for GraphAnalyzer {
409    fn default() -> Self {
410        Self::new()
411    }
412}
413
#[cfg(test)]
mod tests {
    use super::*;

    /// Construct an untyped graph (no node or edge labels).
    fn build(node_count: usize, edges: Vec<(usize, usize)>, is_directed: bool) -> GraphData {
        GraphData {
            node_count,
            edges,
            is_directed,
            ..GraphData::default()
        }
    }

    /// Run the default analyzer, panicking if analysis fails.
    fn analyze(data: &GraphData) -> GraphAnalysis {
        GraphAnalyzer::new().analyze(data).unwrap()
    }

    #[test]
    fn test_connected_graph() {
        // Undirected path 0-1-2-3: one component and no isolated nodes.
        let metrics = analyze(&build(4, vec![(0, 1), (1, 2), (2, 3)], false)).metrics;
        assert_eq!(metrics.connected_components, 1);
        assert_eq!(metrics.largest_component_ratio, 1.0);
        assert_eq!(metrics.isolated_nodes, 0);
    }

    #[test]
    fn test_disconnected_graph() {
        // Only nodes 0 and 1 are linked; nodes 2 and 3 stand alone.
        let metrics = analyze(&build(4, vec![(0, 1)], false)).metrics;
        assert!(metrics.connected_components > 1);
        assert!(metrics.isolated_nodes > 0);
    }

    #[test]
    fn test_empty_graph() {
        let analysis = analyze(&GraphData::default());
        assert_eq!(analysis.metrics.node_count, 0);
        assert!(analysis.is_valid);
    }

    #[test]
    fn test_degree_distribution() {
        // Node 0 points at every other node, plus one extra edge 1 -> 2.
        let analysis = analyze(&build(
            5,
            vec![(0, 1), (0, 2), (0, 3), (0, 4), (1, 2)],
            true,
        ));
        assert_eq!(analysis.metrics.max_degree, 4); // node 0 has four outgoing edges
        assert!(analysis.degree_distribution.mean > 0.0);
    }
}