1use crate::error::EvalResult;
8use serde::{Deserialize, Serialize};
9
/// Feature vectors describing a single entity in two modalities.
///
/// `CrossModalAnalyzer::analyze` compares position `i` of
/// `tabular_features` with position `i` of `graph_features`, and only looks
/// at positions that exist in both vectors of every entity.
#[derive(Debug, Clone)]
pub struct EntityModalData {
    /// Identifier of the entity. The analysis code in this file does not read it.
    pub entity_id: String,

    /// Feature values taken from the tabular modality.
    pub tabular_features: Vec<f64>,

    /// Feature values taken from the graph modality.
    pub graph_features: Vec<f64>,
}
20
/// Pass/fail limits applied by `CrossModalAnalyzer::analyze`.
#[derive(Debug, Clone)]
pub struct CrossModalThresholds {
    /// Lowest consistency score that is still accepted. `analyze` records an
    /// issue when the computed score is strictly below this value.
    pub min_consistency: f64,
}

impl Default for CrossModalThresholds {
    /// Returns thresholds with `min_consistency` set to `0.60`.
    fn default() -> Self {
        const DEFAULT_MIN_CONSISTENCY: f64 = 0.60;

        CrossModalThresholds {
            min_consistency: DEFAULT_MIN_CONSISTENCY,
        }
    }
}
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct CrossModalAnalysis {
39 pub tabular_graph_correlation: f64,
41 pub consistency_score: f64,
43 pub total_entities: usize,
45 pub passes: bool,
47 pub issues: Vec<String>,
49}
50
51pub struct CrossModalAnalyzer {
53 thresholds: CrossModalThresholds,
54}
55
56impl CrossModalAnalyzer {
57 pub fn new() -> Self {
59 Self {
60 thresholds: CrossModalThresholds::default(),
61 }
62 }
63
64 pub fn with_thresholds(thresholds: CrossModalThresholds) -> Self {
66 Self { thresholds }
67 }
68
69 pub fn analyze(&self, entities: &[EntityModalData]) -> EvalResult<CrossModalAnalysis> {
71 let mut issues = Vec::new();
72 let total_entities = entities.len();
73
74 if entities.is_empty() {
75 return Ok(CrossModalAnalysis {
76 tabular_graph_correlation: 0.0,
77 consistency_score: 0.0,
78 total_entities: 0,
79 passes: true,
80 issues: vec!["No entities provided".to_string()],
81 });
82 }
83
84 let min_dim = entities
86 .iter()
87 .map(|e| e.tabular_features.len().min(e.graph_features.len()))
88 .min()
89 .unwrap_or(0);
90
91 if min_dim == 0 {
92 return Ok(CrossModalAnalysis {
93 tabular_graph_correlation: 0.0,
94 consistency_score: 0.0,
95 total_entities,
96 passes: false,
97 issues: vec!["No common feature dimensions".to_string()],
98 });
99 }
100
101 let mut correlations = Vec::new();
103
104 for dim in 0..min_dim {
105 let tabular_vals: Vec<f64> = entities.iter().map(|e| e.tabular_features[dim]).collect();
106 let graph_vals: Vec<f64> = entities.iter().map(|e| e.graph_features[dim]).collect();
107
108 if let Some(corr) = pearson_correlation(&tabular_vals, &graph_vals) {
109 correlations.push(corr);
110 }
111 }
112
113 let tabular_graph_correlation = if correlations.is_empty() {
114 0.0
115 } else {
116 correlations.iter().sum::<f64>() / correlations.len() as f64
117 };
118
119 let consistency_score = ((tabular_graph_correlation + 1.0) / 2.0).clamp(0.0, 1.0);
121
122 if consistency_score < self.thresholds.min_consistency {
123 issues.push(format!(
124 "Cross-modal consistency {:.4} < {:.4} (threshold)",
125 consistency_score, self.thresholds.min_consistency
126 ));
127 }
128
129 let passes = issues.is_empty();
130
131 Ok(CrossModalAnalysis {
132 tabular_graph_correlation,
133 consistency_score,
134 total_entities,
135 passes,
136 issues,
137 })
138 }
139}
140
/// Computes the Pearson correlation coefficient between `x` and `y`.
///
/// Only the first `min(x.len(), y.len())` values of each slice are used, so
/// the observations are paired by index.
///
/// Returns `None` when the coefficient is undefined or unreliable:
/// - fewer than three paired observations are available,
/// - the square root of the product of the two sums of squared deviations
///   is below `1e-12` (at least one side is effectively constant),
/// - that denominator is NaN or infinite, which happens when an input value
///   is NaN or infinite or when the intermediate sums overflow.
///
/// Otherwise returns the coefficient, clamped to `[-1, 1]` so floating-point
/// rounding cannot push it marginally outside the valid range.
fn pearson_correlation(x: &[f64], y: &[f64]) -> Option<f64> {
    let n = x.len().min(y.len());
    if n < 3 {
        return None;
    }

    // Work on equally long prefixes so the two slices can be zipped.
    let (x, y) = (&x[..n], &y[..n]);

    let mean_x = x.iter().sum::<f64>() / n as f64;
    let mean_y = y.iter().sum::<f64>() / n as f64;

    let mut cov = 0.0_f64;
    let mut var_x = 0.0_f64;
    let mut var_y = 0.0_f64;

    for (&xi, &yi) in x.iter().zip(y) {
        let dx = xi - mean_x;
        let dy = yi - mean_y;
        cov += dx * dy;
        var_x += dx * dx;
        var_y += dy * dy;
    }

    let denom = (var_x * var_y).sqrt();

    // NaN fails every ordered comparison, so `denom < 1e-12` alone would let
    // a NaN denominator through and produce `Some(NaN)`. An infinite
    // denominator would produce a meaningless `Some(0.0)`.
    if !denom.is_finite() || denom < 1e-12 {
        return None;
    }

    Some((cov / denom).clamp(-1.0, 1.0))
}
170
171impl Default for CrossModalAnalyzer {
172 fn default() -> Self {
173 Self::new()
174 }
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds one fixture entity from an id and its two feature vectors.
    fn entity(id: &str, tabular: &[f64], graph: &[f64]) -> EntityModalData {
        EntityModalData {
            entity_id: id.into(),
            tabular_features: tabular.to_vec(),
            graph_features: graph.to_vec(),
        }
    }

    /// Graph features that track the tabular ones closely must pass.
    #[test]
    fn test_consistent_modalities() {
        let entities = [
            entity("e1", &[1.0, 2.0, 3.0], &[1.1, 2.1, 3.1]),
            entity("e2", &[4.0, 5.0, 6.0], &[4.2, 5.1, 6.3]),
            entity("e3", &[7.0, 8.0, 9.0], &[7.1, 8.2, 9.1]),
            entity("e4", &[10.0, 11.0, 12.0], &[10.0, 11.1, 12.2]),
        ];

        let result = CrossModalAnalyzer::new().analyze(&entities).unwrap();

        assert_eq!(result.total_entities, 4);
        assert!(result.tabular_graph_correlation > 0.9);
        assert!(result.consistency_score > 0.9);
        assert!(result.passes);
    }

    /// Graph features that move against the tabular ones must fail.
    #[test]
    fn test_inconsistent_modalities() {
        let entities = [
            entity("e1", &[1.0, 2.0], &[10.0, 1.0]),
            entity("e2", &[2.0, 1.0], &[9.0, 2.0]),
            entity("e3", &[3.0, 0.5], &[8.0, 3.5]),
            entity("e4", &[4.0, 0.1], &[7.0, 4.0]),
        ];

        let result = CrossModalAnalyzer::new().analyze(&entities).unwrap();

        assert!(result.consistency_score < 0.6);
        assert!(!result.passes);
    }

    /// An empty input yields a zeroed result instead of an error.
    #[test]
    fn test_empty_entities() {
        let result = CrossModalAnalyzer::new().analyze(&[]).unwrap();

        assert_eq!(result.total_entities, 0);
        assert_eq!(result.consistency_score, 0.0);
    }
}