Skip to main content

datasynth_eval/ml/
gnn_readiness.rs

//! GNN readiness evaluation.
//!
//! Evaluates graph structure suitability for Graph Neural Network training,
//! including feature completeness, homophily ratio, label leakage, and
//! neighborhood diversity.
6
7use crate::error::EvalResult;
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, HashSet};
10
/// Graph description consumed by the GNN readiness analyzer.
///
/// Nodes are addressed by their position in `node_ids`; the per-node vectors
/// (`node_labels`, `node_feature_counts`) are looked up by that same index.
#[derive(Debug, Clone)]
pub struct GraphData {
    /// Identifier of every node; its length defines the node count.
    pub node_ids: Vec<String>,
    /// Label of each node, or `None` when the node is unlabeled.
    pub node_labels: Vec<Option<String>>,
    /// Number of features attached to each node; `0` means the features are missing.
    pub node_feature_counts: Vec<usize>,
    /// Edges as `(source, target)` index pairs into `node_ids`.
    ///
    /// The analyzer skips edges with an endpoint outside the node range when it
    /// builds neighborhoods and computes homophily.
    pub edges: Vec<(usize, usize)>,
    /// Number of features attached to each edge.
    pub edge_feature_counts: Vec<usize>,
}
25
/// Pass/fail limits applied by the GNN readiness analyzer.
#[derive(Debug, Clone)]
pub struct GnnReadinessThresholds {
    /// Lowest composite readiness score that does not raise an issue.
    pub min_gnn_readiness: f64,
}
32
33impl Default for GnnReadinessThresholds {
34    fn default() -> Self {
35        Self {
36            min_gnn_readiness: 0.65,
37        }
38    }
39}
40
41/// Results of GNN readiness analysis.
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct GnnReadinessAnalysis {
44    /// Overall GNN readiness score (0.0-1.0).
45    pub gnn_readiness_score: f64,
46    /// Fraction of edges connecting nodes with the same label.
47    pub homophily_ratio: f64,
48    /// Correlation between node degree and label (structural label leakage).
49    pub structural_label_leakage: f64,
50    /// Fraction of nodes with complete (non-zero) feature vectors.
51    pub feature_completeness_score: f64,
52    /// Average number of distinct labels in each node's 1-hop neighborhood.
53    pub avg_neighborhood_diversity: f64,
54    /// Total number of nodes.
55    pub total_nodes: usize,
56    /// Total number of edges.
57    pub total_edges: usize,
58    /// Whether the analysis passes all thresholds.
59    pub passes: bool,
60    /// Issues found during analysis.
61    pub issues: Vec<String>,
62}
63
64/// Analyzer for GNN readiness.
65pub struct GnnReadinessAnalyzer {
66    thresholds: GnnReadinessThresholds,
67}
68
69impl GnnReadinessAnalyzer {
70    /// Create a new analyzer with default thresholds.
71    pub fn new() -> Self {
72        Self {
73            thresholds: GnnReadinessThresholds::default(),
74        }
75    }
76
77    /// Create an analyzer with custom thresholds.
78    pub fn with_thresholds(thresholds: GnnReadinessThresholds) -> Self {
79        Self { thresholds }
80    }
81
82    /// Analyze GNN readiness.
83    pub fn analyze(&self, data: &GraphData) -> EvalResult<GnnReadinessAnalysis> {
84        let mut issues = Vec::new();
85        let total_nodes = data.node_ids.len();
86        let total_edges = data.edges.len();
87
88        if total_nodes == 0 {
89            return Ok(GnnReadinessAnalysis {
90                gnn_readiness_score: 0.0,
91                homophily_ratio: 0.0,
92                structural_label_leakage: 0.0,
93                feature_completeness_score: 0.0,
94                avg_neighborhood_diversity: 0.0,
95                total_nodes: 0,
96                total_edges: 0,
97                passes: true,
98                issues: vec!["No nodes provided".to_string()],
99            });
100        }
101
102        // Feature completeness: fraction of nodes with non-zero feature count
103        let complete_nodes = data.node_feature_counts.iter().filter(|&&c| c > 0).count();
104        let feature_completeness_score = complete_nodes as f64 / total_nodes as f64;
105
106        // Build adjacency list
107        let mut adjacency: HashMap<usize, Vec<usize>> = HashMap::new();
108        for &(src, tgt) in &data.edges {
109            if src < total_nodes && tgt < total_nodes {
110                adjacency.entry(src).or_default().push(tgt);
111                adjacency.entry(tgt).or_default().push(src);
112            }
113        }
114
115        // Homophily ratio: fraction of edges where both endpoints share a label
116        let homophily_ratio = self.compute_homophily(data, total_nodes);
117
118        // Structural label leakage: correlation between degree and label
119        let structural_label_leakage = self.compute_label_leakage(data, &adjacency, total_nodes);
120
121        // Neighborhood diversity: average distinct labels in 1-hop neighborhood
122        let avg_neighborhood_diversity =
123            self.compute_neighborhood_diversity(data, &adjacency, total_nodes);
124
125        // Composite readiness score
126        let gnn_readiness_score = (feature_completeness_score * 0.3
127            + homophily_ratio.clamp(0.0, 1.0) * 0.3
128            + (1.0 - structural_label_leakage.abs()).clamp(0.0, 1.0) * 0.2
129            + avg_neighborhood_diversity.clamp(0.0, 1.0) * 0.2)
130            .clamp(0.0, 1.0);
131
132        if gnn_readiness_score < self.thresholds.min_gnn_readiness {
133            issues.push(format!(
134                "GNN readiness score {:.4} < {:.4} (threshold)",
135                gnn_readiness_score, self.thresholds.min_gnn_readiness
136            ));
137        }
138
139        if feature_completeness_score < 0.5 {
140            issues.push(format!(
141                "Low feature completeness: {:.2}%",
142                feature_completeness_score * 100.0
143            ));
144        }
145
146        let passes = issues.is_empty();
147
148        Ok(GnnReadinessAnalysis {
149            gnn_readiness_score,
150            homophily_ratio,
151            structural_label_leakage,
152            feature_completeness_score,
153            avg_neighborhood_diversity,
154            total_nodes,
155            total_edges,
156            passes,
157            issues,
158        })
159    }
160
161    /// Compute homophily ratio: fraction of edges connecting same-label nodes.
162    fn compute_homophily(&self, data: &GraphData, total_nodes: usize) -> f64 {
163        if data.edges.is_empty() {
164            return 0.0;
165        }
166
167        let mut same_label = 0usize;
168        let mut labeled_edges = 0usize;
169
170        for &(src, tgt) in &data.edges {
171            if src >= total_nodes || tgt >= total_nodes {
172                continue;
173            }
174            let src_label = data.node_labels.get(src).and_then(|l| l.as_ref());
175            let tgt_label = data.node_labels.get(tgt).and_then(|l| l.as_ref());
176
177            if let (Some(sl), Some(tl)) = (src_label, tgt_label) {
178                labeled_edges += 1;
179                if sl == tl {
180                    same_label += 1;
181                }
182            }
183        }
184
185        if labeled_edges == 0 {
186            return 0.0;
187        }
188
189        same_label as f64 / labeled_edges as f64
190    }
191
192    /// Compute structural label leakage as correlation between degree and label.
193    ///
194    /// Encodes labels as ordinal indices and computes Pearson correlation
195    /// with node degree.
196    fn compute_label_leakage(
197        &self,
198        data: &GraphData,
199        adjacency: &HashMap<usize, Vec<usize>>,
200        total_nodes: usize,
201    ) -> f64 {
202        // Build label-to-index mapping
203        let mut label_map: HashMap<&str, f64> = HashMap::new();
204        let mut next_idx = 0.0;
205        for label in data.node_labels.iter().flatten() {
206            if !label_map.contains_key(label.as_str()) {
207                label_map.insert(label.as_str(), next_idx);
208                next_idx += 1.0;
209            }
210        }
211
212        let mut degrees = Vec::new();
213        let mut label_indices = Vec::new();
214
215        for i in 0..total_nodes {
216            if let Some(Some(ref label)) = data.node_labels.get(i) {
217                if let Some(&idx) = label_map.get(label.as_str()) {
218                    let degree = adjacency.get(&i).map_or(0, |v| v.len());
219                    degrees.push(degree as f64);
220                    label_indices.push(idx);
221                }
222            }
223        }
224
225        if degrees.len() < 3 {
226            return 0.0;
227        }
228
229        pearson_correlation_slices(&degrees, &label_indices).unwrap_or(0.0)
230    }
231
232    /// Compute average neighborhood diversity.
233    ///
234    /// For each node with a label, count distinct labels among its 1-hop neighbors,
235    /// normalized by the total number of distinct labels.
236    fn compute_neighborhood_diversity(
237        &self,
238        data: &GraphData,
239        adjacency: &HashMap<usize, Vec<usize>>,
240        total_nodes: usize,
241    ) -> f64 {
242        let all_labels: HashSet<&str> = data
243            .node_labels
244            .iter()
245            .filter_map(|l| l.as_deref())
246            .collect();
247
248        if all_labels.is_empty() || all_labels.len() == 1 {
249            return if all_labels.len() == 1 { 1.0 } else { 0.0 };
250        }
251
252        let label_count = all_labels.len() as f64;
253        let mut total_diversity = 0.0;
254        let mut counted_nodes = 0usize;
255
256        for i in 0..total_nodes {
257            if let Some(neighbors) = adjacency.get(&i) {
258                if neighbors.is_empty() {
259                    continue;
260                }
261                let neighbor_labels: HashSet<&str> = neighbors
262                    .iter()
263                    .filter_map(|&n| data.node_labels.get(n).and_then(|l| l.as_deref()))
264                    .collect();
265
266                if !neighbor_labels.is_empty() {
267                    total_diversity += neighbor_labels.len() as f64 / label_count;
268                    counted_nodes += 1;
269                }
270            }
271        }
272
273        if counted_nodes == 0 {
274            return 0.0;
275        }
276
277        total_diversity / counted_nodes as f64
278    }
279}
280
/// Compute the Pearson correlation coefficient between two slices.
///
/// The slices are paired element by element; when their lengths differ only
/// the common prefix is used. The result is clamped to `[-1.0, 1.0]` so that
/// rounding error cannot push it outside the valid range.
///
/// Returns `None` when fewer than three pairs are available, when either
/// series is (numerically) constant, or when the inputs are non-finite or
/// overflow so that no meaningful coefficient exists.
fn pearson_correlation_slices(x: &[f64], y: &[f64]) -> Option<f64> {
    let n = x.len().min(y.len());
    if n < 3 {
        return None;
    }
    let (x, y) = (&x[..n], &y[..n]);

    let mean_x = x.iter().sum::<f64>() / n as f64;
    let mean_y = y.iter().sum::<f64>() / n as f64;

    // Accumulate covariance and both variances (all un-normalized) in one pass.
    let (cov, var_x, var_y) = x.iter().zip(y).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(cov, var_x, var_y), (&xi, &yi)| {
            let dx = xi - mean_x;
            let dy = yi - mean_y;
            (cov + dx * dy, var_x + dx * dx, var_y + dy * dy)
        },
    );

    let denom = (var_x * var_y).sqrt();
    // A NaN or infinite denominator would otherwise slip past the `<` comparison.
    if !denom.is_finite() || denom < 1e-12 {
        return None;
    }

    let r = cov / denom;
    r.is_finite().then(|| r.clamp(-1.0, 1.0))
}
310
311impl Default for GnnReadinessAnalyzer {
312    fn default() -> Self {
313        Self::new()
314    }
315}
316
#[cfg(test)]
// `unwrap` is intentional in tests: an unexpected error should fail the test immediately.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Tolerance used when comparing floating-point scores.
    const EPS: f64 = 1e-9;

    /// Build a four-node graph labeled A, A, B, B where every node has
    /// `feature_count` features and every edge has `feature_count` features.
    fn two_class_graph(feature_count: usize, edges: Vec<(usize, usize)>) -> GraphData {
        GraphData {
            node_ids: vec!["n0".into(), "n1".into(), "n2".into(), "n3".into()],
            node_labels: vec![
                Some("A".into()),
                Some("A".into()),
                Some("B".into()),
                Some("B".into()),
            ],
            node_feature_counts: vec![feature_count; 4],
            edge_feature_counts: vec![feature_count; edges.len()],
            edges,
        }
    }

    #[test]
    fn test_valid_graph() {
        let data = two_class_graph(10, vec![(0, 1), (1, 2), (2, 3), (0, 3)]);

        let analyzer = GnnReadinessAnalyzer::new();
        let result = analyzer.analyze(&data).unwrap();

        assert_eq!(result.total_nodes, 4);
        assert_eq!(result.total_edges, 4);
        assert!(result.feature_completeness_score > 0.99);
        // Two of the four edges connect same-label nodes.
        assert!((result.homophily_ratio - 0.5).abs() < EPS);
        // Every node has degree 2, so degree carries no information about the label.
        assert!(result.structural_label_leakage.abs() < EPS);
        // Every node sees both labels among its neighbors.
        assert!((result.avg_neighborhood_diversity - 1.0).abs() < EPS);
        // 0.3 * 1.0 + 0.3 * 0.5 + 0.2 * 1.0 + 0.2 * 1.0
        assert!((result.gnn_readiness_score - 0.85).abs() < EPS);
        assert!(result.passes);
        assert!(result.issues.is_empty());
    }

    #[test]
    fn test_invalid_graph_missing_features() {
        let data = two_class_graph(0, vec![(0, 1)]);

        let analyzer = GnnReadinessAnalyzer::new();
        let result = analyzer.analyze(&data).unwrap();

        assert!(result.feature_completeness_score < 0.01);
        assert!(!result.passes);
        assert!(result
            .issues
            .iter()
            .any(|issue| issue.starts_with("Low feature completeness")));
        assert!(result
            .issues
            .iter()
            .any(|issue| issue.starts_with("GNN readiness score")));
    }

    #[test]
    fn test_out_of_range_edges_are_ignored() {
        let data = GraphData {
            node_ids: vec!["n0".into(), "n1".into()],
            node_labels: vec![Some("A".into()), Some("A".into())],
            node_feature_counts: vec![1, 1],
            // The second edge points at a node that does not exist.
            edges: vec![(0, 1), (0, 5)],
            edge_feature_counts: vec![1, 1],
        };

        let analyzer = GnnReadinessAnalyzer::new();
        let result = analyzer.analyze(&data).unwrap();

        assert_eq!(result.total_edges, 2);
        assert!((result.homophily_ratio - 1.0).abs() < EPS);
        assert!(result.passes);
    }

    #[test]
    fn test_empty_graph() {
        let data = GraphData {
            node_ids: Vec::new(),
            node_labels: Vec::new(),
            node_feature_counts: Vec::new(),
            edges: Vec::new(),
            edge_feature_counts: Vec::new(),
        };

        let analyzer = GnnReadinessAnalyzer::new();
        let result = analyzer.analyze(&data).unwrap();

        assert_eq!(result.total_nodes, 0);
        assert_eq!(result.total_edges, 0);
        assert_eq!(result.gnn_readiness_score, 0.0);
        assert_eq!(result.issues, vec!["No nodes provided".to_string()]);
    }
}