1use crate::error::EvalResult;
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, HashSet};
10
/// Graph description consumed by `GnnReadinessAnalyzer::analyze`.
///
/// The per-node vectors are read positionally: entry `i` of `node_labels` and
/// `node_feature_counts` is treated as belonging to `node_ids[i]`. The analyzer
/// uses `node_ids.len()` as the node count and tolerates shorter per-node
/// vectors (missing entries are simply treated as absent).
#[derive(Debug, Clone)]
pub struct GraphData {
    /// Node identifiers; the length of this vector defines the node count.
    pub node_ids: Vec<String>,
    /// Optional class label per node; `None` marks an unlabelled node.
    pub node_labels: Vec<Option<String>>,
    /// Feature count per node; a node counts as "complete" when its count is
    /// greater than zero.
    pub node_feature_counts: Vec<usize>,
    /// Edges as `(source, target)` index pairs into the node vectors. Edges
    /// with an endpoint `>= node_ids.len()` are skipped by the analyzer.
    pub edges: Vec<(usize, usize)>,
    /// Presumably the feature count per edge, parallel to `edges`; not read by
    /// the analyzer in this file. TODO confirm against callers.
    pub edge_feature_counts: Vec<usize>,
}
25
/// Pass/fail thresholds applied by `GnnReadinessAnalyzer`.
#[derive(Debug, Clone)]
pub struct GnnReadinessThresholds {
    /// Minimum acceptable composite readiness score; a score strictly below
    /// this value is reported as an issue.
    pub min_gnn_readiness: f64,
}
32
33impl Default for GnnReadinessThresholds {
34 fn default() -> Self {
35 Self {
36 min_gnn_readiness: 0.65,
37 }
38 }
39}
40
/// Result of a GNN-readiness analysis over one graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GnnReadinessAnalysis {
    /// Composite score in `[0, 1]`: a weighted sum of feature completeness
    /// (0.3), homophily (0.3), one minus absolute label leakage (0.2) and
    /// neighbourhood diversity (0.2).
    pub gnn_readiness_score: f64,
    /// Fraction of edges whose endpoints are both labelled and carry the same
    /// label; `0.0` when no edge has two labelled endpoints.
    pub homophily_ratio: f64,
    /// Pearson correlation between node degree and a numeric label index
    /// (labels are numbered in order of first appearance); `0.0` when fewer
    /// than three labelled nodes exist or the correlation is undefined.
    pub structural_label_leakage: f64,
    /// Share of nodes whose feature count is greater than zero.
    pub feature_completeness_score: f64,
    /// Mean, over nodes with at least one labelled neighbour, of the number of
    /// distinct neighbour labels divided by the number of distinct labels in
    /// the graph. Reported as `1.0` when the graph has exactly one distinct
    /// label and `0.0` when it has none.
    pub avg_neighborhood_diversity: f64,
    /// Number of nodes (`node_ids.len()`).
    pub total_nodes: usize,
    /// Number of edges supplied (`edges.len()`), including any edges the
    /// analyzer skipped because an endpoint was out of range.
    pub total_edges: usize,
    /// `true` when `issues` is empty. An empty graph is the one exception: it
    /// is reported with `passes == true` together with a "No nodes provided"
    /// issue.
    pub passes: bool,
    /// Human-readable descriptions of every problem found.
    pub issues: Vec<String>,
}
63
64pub struct GnnReadinessAnalyzer {
66 thresholds: GnnReadinessThresholds,
67}
68
69impl GnnReadinessAnalyzer {
70 pub fn new() -> Self {
72 Self {
73 thresholds: GnnReadinessThresholds::default(),
74 }
75 }
76
77 pub fn with_thresholds(thresholds: GnnReadinessThresholds) -> Self {
79 Self { thresholds }
80 }
81
82 pub fn analyze(&self, data: &GraphData) -> EvalResult<GnnReadinessAnalysis> {
84 let mut issues = Vec::new();
85 let total_nodes = data.node_ids.len();
86 let total_edges = data.edges.len();
87
88 if total_nodes == 0 {
89 return Ok(GnnReadinessAnalysis {
90 gnn_readiness_score: 0.0,
91 homophily_ratio: 0.0,
92 structural_label_leakage: 0.0,
93 feature_completeness_score: 0.0,
94 avg_neighborhood_diversity: 0.0,
95 total_nodes: 0,
96 total_edges: 0,
97 passes: true,
98 issues: vec!["No nodes provided".to_string()],
99 });
100 }
101
102 let complete_nodes = data.node_feature_counts.iter().filter(|&&c| c > 0).count();
104 let feature_completeness_score = complete_nodes as f64 / total_nodes as f64;
105
106 let mut adjacency: HashMap<usize, Vec<usize>> = HashMap::new();
108 for &(src, tgt) in &data.edges {
109 if src < total_nodes && tgt < total_nodes {
110 adjacency.entry(src).or_default().push(tgt);
111 adjacency.entry(tgt).or_default().push(src);
112 }
113 }
114
115 let homophily_ratio = self.compute_homophily(data, total_nodes);
117
118 let structural_label_leakage = self.compute_label_leakage(data, &adjacency, total_nodes);
120
121 let avg_neighborhood_diversity =
123 self.compute_neighborhood_diversity(data, &adjacency, total_nodes);
124
125 let gnn_readiness_score = (feature_completeness_score * 0.3
127 + homophily_ratio.clamp(0.0, 1.0) * 0.3
128 + (1.0 - structural_label_leakage.abs()).clamp(0.0, 1.0) * 0.2
129 + avg_neighborhood_diversity.clamp(0.0, 1.0) * 0.2)
130 .clamp(0.0, 1.0);
131
132 if gnn_readiness_score < self.thresholds.min_gnn_readiness {
133 issues.push(format!(
134 "GNN readiness score {:.4} < {:.4} (threshold)",
135 gnn_readiness_score, self.thresholds.min_gnn_readiness
136 ));
137 }
138
139 if feature_completeness_score < 0.5 {
140 issues.push(format!(
141 "Low feature completeness: {:.2}%",
142 feature_completeness_score * 100.0
143 ));
144 }
145
146 let passes = issues.is_empty();
147
148 Ok(GnnReadinessAnalysis {
149 gnn_readiness_score,
150 homophily_ratio,
151 structural_label_leakage,
152 feature_completeness_score,
153 avg_neighborhood_diversity,
154 total_nodes,
155 total_edges,
156 passes,
157 issues,
158 })
159 }
160
161 fn compute_homophily(&self, data: &GraphData, total_nodes: usize) -> f64 {
163 if data.edges.is_empty() {
164 return 0.0;
165 }
166
167 let mut same_label = 0usize;
168 let mut labeled_edges = 0usize;
169
170 for &(src, tgt) in &data.edges {
171 if src >= total_nodes || tgt >= total_nodes {
172 continue;
173 }
174 let src_label = data.node_labels.get(src).and_then(|l| l.as_ref());
175 let tgt_label = data.node_labels.get(tgt).and_then(|l| l.as_ref());
176
177 if let (Some(sl), Some(tl)) = (src_label, tgt_label) {
178 labeled_edges += 1;
179 if sl == tl {
180 same_label += 1;
181 }
182 }
183 }
184
185 if labeled_edges == 0 {
186 return 0.0;
187 }
188
189 same_label as f64 / labeled_edges as f64
190 }
191
192 fn compute_label_leakage(
197 &self,
198 data: &GraphData,
199 adjacency: &HashMap<usize, Vec<usize>>,
200 total_nodes: usize,
201 ) -> f64 {
202 let mut label_map: HashMap<&str, f64> = HashMap::new();
204 let mut next_idx = 0.0;
205 for label in data.node_labels.iter().flatten() {
206 if !label_map.contains_key(label.as_str()) {
207 label_map.insert(label.as_str(), next_idx);
208 next_idx += 1.0;
209 }
210 }
211
212 let mut degrees = Vec::new();
213 let mut label_indices = Vec::new();
214
215 for i in 0..total_nodes {
216 if let Some(Some(ref label)) = data.node_labels.get(i) {
217 if let Some(&idx) = label_map.get(label.as_str()) {
218 let degree = adjacency.get(&i).map_or(0, |v| v.len());
219 degrees.push(degree as f64);
220 label_indices.push(idx);
221 }
222 }
223 }
224
225 if degrees.len() < 3 {
226 return 0.0;
227 }
228
229 pearson_correlation_slices(°rees, &label_indices).unwrap_or(0.0)
230 }
231
232 fn compute_neighborhood_diversity(
237 &self,
238 data: &GraphData,
239 adjacency: &HashMap<usize, Vec<usize>>,
240 total_nodes: usize,
241 ) -> f64 {
242 let all_labels: HashSet<&str> = data
243 .node_labels
244 .iter()
245 .filter_map(|l| l.as_deref())
246 .collect();
247
248 if all_labels.is_empty() || all_labels.len() == 1 {
249 return if all_labels.len() == 1 { 1.0 } else { 0.0 };
250 }
251
252 let label_count = all_labels.len() as f64;
253 let mut total_diversity = 0.0;
254 let mut counted_nodes = 0usize;
255
256 for i in 0..total_nodes {
257 if let Some(neighbors) = adjacency.get(&i) {
258 if neighbors.is_empty() {
259 continue;
260 }
261 let neighbor_labels: HashSet<&str> = neighbors
262 .iter()
263 .filter_map(|&n| data.node_labels.get(n).and_then(|l| l.as_deref()))
264 .collect();
265
266 if !neighbor_labels.is_empty() {
267 total_diversity += neighbor_labels.len() as f64 / label_count;
268 counted_nodes += 1;
269 }
270 }
271 }
272
273 if counted_nodes == 0 {
274 return 0.0;
275 }
276
277 total_diversity / counted_nodes as f64
278 }
279}
280
/// Computes the Pearson correlation coefficient of two samples.
///
/// Only the first `min(x.len(), y.len())` elements of each slice are used.
/// Returns `None` when fewer than three paired observations are available or
/// when the denominator `sqrt(var_x * var_y)` is smaller than `1e-12`
/// (for example when either sample is constant).
fn pearson_correlation_slices(x: &[f64], y: &[f64]) -> Option<f64> {
    let len = x.len().min(y.len());
    if len < 3 {
        return None;
    }

    let (xs, ys) = (&x[..len], &y[..len]);
    let count = len as f64;
    let mean_x = xs.iter().sum::<f64>() / count;
    let mean_y = ys.iter().sum::<f64>() / count;

    // Accumulate covariance and both sums of squared deviations in one pass.
    let (cov, var_x, var_y) = xs.iter().zip(ys).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(cov, var_x, var_y), (&xi, &yi)| {
            let dx = xi - mean_x;
            let dy = yi - mean_y;
            (cov + dx * dy, var_x + dx * dx, var_y + dy * dy)
        },
    );

    let denom = (var_x * var_y).sqrt();
    if denom < 1e-12 {
        None
    } else {
        Some(cov / denom)
    }
}
310
311impl Default for GnnReadinessAnalyzer {
312 fn default() -> Self {
313 Self::new()
314 }
315}
316
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds a four-node graph (`n0`..`n3`) labelled `A, A, B, B` with the
    /// given per-node feature counts, edges and per-edge feature counts.
    fn four_node_graph(
        node_feature_counts: Vec<usize>,
        edges: Vec<(usize, usize)>,
        edge_feature_counts: Vec<usize>,
    ) -> GraphData {
        GraphData {
            node_ids: ["n0", "n1", "n2", "n3"]
                .iter()
                .map(|id| id.to_string())
                .collect(),
            node_labels: ["A", "A", "B", "B"]
                .iter()
                .map(|label| Some(label.to_string()))
                .collect(),
            node_feature_counts,
            edges,
            edge_feature_counts,
        }
    }

    /// A fully featured, connected graph reports its sizes and a positive score.
    #[test]
    fn test_valid_graph() {
        let graph = four_node_graph(
            vec![10, 10, 10, 10],
            vec![(0, 1), (1, 2), (2, 3), (0, 3)],
            vec![5, 5, 5, 5],
        );

        let analysis = GnnReadinessAnalyzer::new().analyze(&graph).unwrap();

        assert_eq!(analysis.total_nodes, 4);
        assert_eq!(analysis.total_edges, 4);
        assert!(analysis.feature_completeness_score > 0.99);
        assert!(analysis.gnn_readiness_score > 0.0);
    }

    /// Nodes without any features drive completeness to zero and fail the check.
    #[test]
    fn test_invalid_graph_missing_features() {
        let graph = four_node_graph(vec![0, 0, 0, 0], vec![(0, 1)], vec![0]);

        let analysis = GnnReadinessAnalyzer::new().analyze(&graph).unwrap();

        assert!(analysis.feature_completeness_score < 0.01);
        assert!(!analysis.passes);
    }

    /// A graph without nodes yields a zero score instead of an error.
    #[test]
    fn test_empty_graph() {
        let graph = GraphData {
            node_ids: Vec::new(),
            node_labels: Vec::new(),
            node_feature_counts: Vec::new(),
            edges: Vec::new(),
            edge_feature_counts: Vec::new(),
        };

        let analysis = GnnReadinessAnalyzer::new().analyze(&graph).unwrap();

        assert_eq!(analysis.total_nodes, 0);
        assert_eq!(analysis.gnn_readiness_score, 0.0);
    }
}