1use crate::error::EvalResult;
8use serde::{Deserialize, Serialize};
9
/// Per-entity feature vectors from the two modalities being compared.
///
/// The analyzer in this module pairs `tabular_features[i]` with
/// `graph_features[i]` across entities, so only the dimensions present in
/// both vectors take part in the comparison.
#[derive(Debug, Clone)]
pub struct EntityModalData {
    /// Identifier of the entity these features belong to.
    pub entity_id: String,
    /// Feature values for the tabular modality.
    pub tabular_features: Vec<f64>,
    /// Feature values for the graph modality.
    pub graph_features: Vec<f64>,
}
20
/// Pass/fail limits applied to a cross-modal analysis.
#[derive(Debug, Clone)]
pub struct CrossModalThresholds {
    /// Lowest consistency score that is still accepted; a score strictly
    /// below this value is reported as an issue.
    pub min_consistency: f64,
}
27
28impl Default for CrossModalThresholds {
29 fn default() -> Self {
30 Self {
31 min_consistency: 0.60,
32 }
33 }
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct CrossModalAnalysis {
39 pub tabular_graph_correlation: f64,
41 pub consistency_score: f64,
43 pub total_entities: usize,
45 pub passes: bool,
47 pub issues: Vec<String>,
49}
50
51pub struct CrossModalAnalyzer {
53 thresholds: CrossModalThresholds,
54}
55
56impl CrossModalAnalyzer {
57 pub fn new() -> Self {
59 Self {
60 thresholds: CrossModalThresholds::default(),
61 }
62 }
63
64 pub fn with_thresholds(thresholds: CrossModalThresholds) -> Self {
66 Self { thresholds }
67 }
68
69 pub fn analyze(&self, entities: &[EntityModalData]) -> EvalResult<CrossModalAnalysis> {
71 let mut issues = Vec::new();
72 let total_entities = entities.len();
73
74 if entities.is_empty() {
75 return Ok(CrossModalAnalysis {
76 tabular_graph_correlation: 0.0,
77 consistency_score: 0.0,
78 total_entities: 0,
79 passes: true,
80 issues: vec!["No entities provided".to_string()],
81 });
82 }
83
84 let min_dim = entities
86 .iter()
87 .map(|e| e.tabular_features.len().min(e.graph_features.len()))
88 .min()
89 .unwrap_or(0);
90
91 if min_dim == 0 {
92 return Ok(CrossModalAnalysis {
93 tabular_graph_correlation: 0.0,
94 consistency_score: 0.0,
95 total_entities,
96 passes: false,
97 issues: vec!["No common feature dimensions".to_string()],
98 });
99 }
100
101 let mut correlations = Vec::new();
103
104 for dim in 0..min_dim {
105 let tabular_vals: Vec<f64> = entities.iter().map(|e| e.tabular_features[dim]).collect();
106 let graph_vals: Vec<f64> = entities.iter().map(|e| e.graph_features[dim]).collect();
107
108 if let Some(corr) = pearson_correlation(&tabular_vals, &graph_vals) {
109 correlations.push(corr);
110 }
111 }
112
113 let tabular_graph_correlation = if correlations.is_empty() {
114 0.0
115 } else {
116 correlations.iter().sum::<f64>() / correlations.len() as f64
117 };
118
119 let consistency_score = ((tabular_graph_correlation + 1.0) / 2.0).clamp(0.0, 1.0);
121
122 if consistency_score < self.thresholds.min_consistency {
123 issues.push(format!(
124 "Cross-modal consistency {:.4} < {:.4} (threshold)",
125 consistency_score, self.thresholds.min_consistency
126 ));
127 }
128
129 let passes = issues.is_empty();
130
131 Ok(CrossModalAnalysis {
132 tabular_graph_correlation,
133 consistency_score,
134 total_entities,
135 passes,
136 issues,
137 })
138 }
139}
140
/// Computes the Pearson correlation coefficient of two samples.
///
/// Only the first `min(x.len(), y.len())` pairs are used.
///
/// Returns `None` when no meaningful coefficient exists:
/// - fewer than three pairs are available;
/// - the denominator is below `1e-12` (at least one sample is effectively
///   constant);
/// - the inputs contain NaN or infinite values, or the intermediate sums
///   overflow, so the result would not be finite.
///
/// Otherwise returns the coefficient, clamped to `[-1.0, 1.0]` so that
/// floating-point rounding cannot push it outside the valid range.
fn pearson_correlation(x: &[f64], y: &[f64]) -> Option<f64> {
    let n = x.len().min(y.len());
    if n < 3 {
        return None;
    }
    let (x, y) = (&x[..n], &y[..n]);

    let mean_x = x.iter().sum::<f64>() / n as f64;
    let mean_y = y.iter().sum::<f64>() / n as f64;

    let mut cov = 0.0;
    let mut var_x = 0.0;
    let mut var_y = 0.0;
    for (&xi, &yi) in x.iter().zip(y) {
        let dx = xi - mean_x;
        let dy = yi - mean_y;
        cov += dx * dy;
        var_x += dx * dx;
        var_y += dy * dy;
    }

    let denom = (var_x * var_y).sqrt();
    // `NaN < 1e-12` is false, so a non-finite denominator has to be rejected
    // explicitly; otherwise NaN would leak out as `Some(NaN)`.
    if !denom.is_finite() || denom < 1e-12 {
        return None;
    }

    let r = cov / denom;
    if r.is_finite() {
        Some(r.clamp(-1.0, 1.0))
    } else {
        None
    }
}
170
171impl Default for CrossModalAnalyzer {
172 fn default() -> Self {
173 Self::new()
174 }
175}
176
#[cfg(test)]
// Tests unwrap freely: a failed unwrap is itself a test failure.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds one entity from its id and the two feature vectors.
    fn entity(id: &str, tabular: &[f64], graph: &[f64]) -> EntityModalData {
        EntityModalData {
            entity_id: id.into(),
            tabular_features: tabular.to_vec(),
            graph_features: graph.to_vec(),
        }
    }

    /// Graph features that track the tabular ones closely should pass.
    #[test]
    fn test_consistent_modalities() {
        let entities = [
            entity("e1", &[1.0, 2.0, 3.0], &[1.1, 2.1, 3.1]),
            entity("e2", &[4.0, 5.0, 6.0], &[4.2, 5.1, 6.3]),
            entity("e3", &[7.0, 8.0, 9.0], &[7.1, 8.2, 9.1]),
            entity("e4", &[10.0, 11.0, 12.0], &[10.0, 11.1, 12.2]),
        ];

        let result = CrossModalAnalyzer::new().analyze(&entities).unwrap();

        assert_eq!(result.total_entities, 4);
        assert!(result.tabular_graph_correlation > 0.9);
        assert!(result.consistency_score > 0.9);
        assert!(result.passes);
    }

    /// Graph features that move against the tabular ones should fail.
    #[test]
    fn test_inconsistent_modalities() {
        let entities = [
            entity("e1", &[1.0, 2.0], &[10.0, 1.0]),
            entity("e2", &[2.0, 1.0], &[9.0, 2.0]),
            entity("e3", &[3.0, 0.5], &[8.0, 3.5]),
            entity("e4", &[4.0, 0.1], &[7.0, 4.0]),
        ];

        let result = CrossModalAnalyzer::new().analyze(&entities).unwrap();

        assert!(result.consistency_score < 0.6);
        assert!(!result.passes);
    }

    /// An empty input yields a zeroed result rather than an error.
    #[test]
    fn test_empty_entities() {
        let result = CrossModalAnalyzer::new().analyze(&[]).unwrap();

        assert_eq!(result.total_entities, 0);
        assert_eq!(result.consistency_score, 0.0);
    }
}