// datasynth_eval/ml/cross_modal.rs

//! Cross-modal consistency evaluation.
//!
//! Measures how well the graph and tabular feature representations of the same
//! entities agree, by computing the Pearson correlation across entities for each
//! pair of corresponding feature dimensions.

7use crate::error::EvalResult;
8use serde::{Deserialize, Serialize};
9
/// One entity observed through two modalities: a tabular feature vector and a
/// graph-derived feature vector.
///
/// The two vectors are compared position by position: index `i` of
/// `tabular_features` is paired with index `i` of `graph_features`.
#[derive(Clone, Debug)]
pub struct EntityModalData {
    /// Identifier of the entity these features describe.
    pub entity_id: String,
    /// Features taken from the tabular representation of the entity.
    pub tabular_features: Vec<f64>,
    /// Features derived from the graph representation, aligned with `tabular_features`.
    pub graph_features: Vec<f64>,
}

/// Pass/fail thresholds used by the cross-modal consistency analyzer.
#[derive(Clone, Debug)]
pub struct CrossModalThresholds {
    /// Smallest acceptable consistency score (on the `[0, 1]` scale); a lower
    /// score is reported as an issue.
    pub min_consistency: f64,
}

impl Default for CrossModalThresholds {
    /// Requires a consistency score of at least `0.60`.
    fn default() -> Self {
        let min_consistency = 0.60;
        Self { min_consistency }
    }
}

36/// Results of cross-modal consistency analysis.
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct CrossModalAnalysis {
39    /// Average Pearson correlation between tabular and graph features.
40    pub tabular_graph_correlation: f64,
41    /// Overall consistency score (0.0-1.0).
42    pub consistency_score: f64,
43    /// Total number of entities analyzed.
44    pub total_entities: usize,
45    /// Whether the analysis passes all thresholds.
46    pub passes: bool,
47    /// Issues found during analysis.
48    pub issues: Vec<String>,
49}
50
51/// Analyzer for cross-modal consistency.
52pub struct CrossModalAnalyzer {
53    thresholds: CrossModalThresholds,
54}
55
56impl CrossModalAnalyzer {
57    /// Create a new analyzer with default thresholds.
58    pub fn new() -> Self {
59        Self {
60            thresholds: CrossModalThresholds::default(),
61        }
62    }
63
64    /// Create an analyzer with custom thresholds.
65    pub fn with_thresholds(thresholds: CrossModalThresholds) -> Self {
66        Self { thresholds }
67    }
68
69    /// Analyze cross-modal consistency.
70    pub fn analyze(&self, entities: &[EntityModalData]) -> EvalResult<CrossModalAnalysis> {
71        let mut issues = Vec::new();
72        let total_entities = entities.len();
73
74        if entities.is_empty() {
75            return Ok(CrossModalAnalysis {
76                tabular_graph_correlation: 0.0,
77                consistency_score: 0.0,
78                total_entities: 0,
79                passes: true,
80                issues: vec!["No entities provided".to_string()],
81            });
82        }
83
84        // Determine the common feature dimension
85        let min_dim = entities
86            .iter()
87            .map(|e| e.tabular_features.len().min(e.graph_features.len()))
88            .min()
89            .unwrap_or(0);
90
91        if min_dim == 0 {
92            return Ok(CrossModalAnalysis {
93                tabular_graph_correlation: 0.0,
94                consistency_score: 0.0,
95                total_entities,
96                passes: false,
97                issues: vec!["No common feature dimensions".to_string()],
98            });
99        }
100
101        // Compute per-dimension Pearson correlation across entities
102        let mut correlations = Vec::new();
103
104        for dim in 0..min_dim {
105            let tabular_vals: Vec<f64> = entities.iter().map(|e| e.tabular_features[dim]).collect();
106            let graph_vals: Vec<f64> = entities.iter().map(|e| e.graph_features[dim]).collect();
107
108            if let Some(corr) = pearson_correlation(&tabular_vals, &graph_vals) {
109                correlations.push(corr);
110            }
111        }
112
113        let tabular_graph_correlation = if correlations.is_empty() {
114            0.0
115        } else {
116            correlations.iter().sum::<f64>() / correlations.len() as f64
117        };
118
119        // Consistency score: map correlation from [-1, 1] to [0, 1]
120        let consistency_score = ((tabular_graph_correlation + 1.0) / 2.0).clamp(0.0, 1.0);
121
122        if consistency_score < self.thresholds.min_consistency {
123            issues.push(format!(
124                "Cross-modal consistency {:.4} < {:.4} (threshold)",
125                consistency_score, self.thresholds.min_consistency
126            ));
127        }
128
129        let passes = issues.is_empty();
130
131        Ok(CrossModalAnalysis {
132            tabular_graph_correlation,
133            consistency_score,
134            total_entities,
135            passes,
136            issues,
137        })
138    }
139}
140
/// Compute the Pearson correlation coefficient between two samples.
///
/// Only the first `min(x.len(), y.len())` elements of each slice are paired and used.
///
/// Returns `None` when the coefficient is undefined or would not be meaningful:
/// - fewer than 3 paired observations;
/// - either sample contains a non-finite value (NaN or ±∞);
/// - either sample is constant (zero variance);
/// - the variance product underflows to zero or overflows.
///
/// Otherwise the result is clamped to `[-1.0, 1.0]` to absorb floating-point rounding.
fn pearson_correlation(x: &[f64], y: &[f64]) -> Option<f64> {
    let n = x.len().min(y.len());
    if n < 3 {
        return None;
    }
    let (x, y) = (&x[..n], &y[..n]);

    // A single NaN/inf poisons every moment below and would otherwise yield `Some(NaN)`.
    if !x.iter().chain(y).all(|v| v.is_finite()) {
        return None;
    }

    // Detect zero variance exactly rather than with an absolute epsilon on the
    // denominator. An absolute cut-off rejects legitimately small-scale features
    // (e.g. values around 1e-9). It also accepts constant large-magnitude
    // features, whose computed mean can be off by a few ULPs and leave a spurious
    // non-zero variance. The exact equality comparison is intentional.
    let is_constant = |s: &[f64]| s.iter().all(|&v| v == s[0]);
    if is_constant(x) || is_constant(y) {
        return None;
    }

    let mean_x = x.iter().sum::<f64>() / n as f64;
    let mean_y = y.iter().sum::<f64>() / n as f64;

    let (cov, var_x, var_y) = x.iter().zip(y).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(c, vx, vy), (&xi, &yi)| {
            let dx = xi - mean_x;
            let dy = yi - mean_y;
            (c + dx * dy, vx + dx * dx, vy + dy * dy)
        },
    );

    // Even for non-constant samples, tiny (subnormal) spreads can underflow to zero
    // and huge magnitudes can overflow to infinity.
    let denom = (var_x * var_y).sqrt();
    if denom <= 0.0 || !denom.is_finite() {
        return None;
    }

    let r = cov / denom;
    if r.is_finite() {
        Some(r.clamp(-1.0, 1.0))
    } else {
        None
    }
}

171impl Default for CrossModalAnalyzer {
172    fn default() -> Self {
173        Self::new()
174    }
175}
176
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Build a test entity from borrowed feature slices.
    fn entity(id: &str, tabular: &[f64], graph: &[f64]) -> EntityModalData {
        EntityModalData {
            entity_id: id.into(),
            tabular_features: tabular.to_vec(),
            graph_features: graph.to_vec(),
        }
    }

    #[test]
    fn test_consistent_modalities() {
        // Graph features closely track the tabular ones on every dimension.
        let entities = [
            entity("e1", &[1.0, 2.0, 3.0], &[1.1, 2.1, 3.1]),
            entity("e2", &[4.0, 5.0, 6.0], &[4.2, 5.1, 6.3]),
            entity("e3", &[7.0, 8.0, 9.0], &[7.1, 8.2, 9.1]),
            entity("e4", &[10.0, 11.0, 12.0], &[10.0, 11.1, 12.2]),
        ];

        let result = CrossModalAnalyzer::new().analyze(&entities).unwrap();

        assert_eq!(result.total_entities, 4);
        assert!(result.tabular_graph_correlation > 0.9);
        assert!(result.consistency_score > 0.9);
        assert!(result.passes);
    }

    #[test]
    fn test_inconsistent_modalities() {
        // The first dimension is perfectly anti-correlated and the second is also
        // negatively correlated, so consistency must fall below the default threshold.
        let entities = [
            entity("e1", &[1.0, 2.0], &[10.0, 1.0]),
            entity("e2", &[2.0, 1.0], &[9.0, 2.0]),
            entity("e3", &[3.0, 0.5], &[8.0, 3.5]),
            entity("e4", &[4.0, 0.1], &[7.0, 4.0]),
        ];

        let result = CrossModalAnalyzer::new().analyze(&entities).unwrap();

        assert!(result.consistency_score < 0.6);
        assert!(!result.passes);
    }

    #[test]
    fn test_empty_entities() {
        let result = CrossModalAnalyzer::new().analyze(&[]).unwrap();

        assert_eq!(result.total_entities, 0);
        assert_eq!(result.consistency_score, 0.0);
    }
}