datasynth_graph/ml/
features.rs1use std::collections::HashMap;
4
5use chrono::{Datelike, NaiveDate};
6use rust_decimal::Decimal;
7
8use crate::models::{Graph, NodeId};
9
/// Strategy used by `FeatureNormalizer` to rescale each feature column.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NormalizationMethod {
    /// Leave values unchanged.
    None,
    /// Rescale using the fitted column minimum and maximum: `(x - min) / (max - min)`.
    MinMax,
    /// Standardise using the fitted column mean and standard deviation.
    ZScore,
    /// Signed log transform `sign(x) * ln(|x| + 1)`; uses no fitted statistics.
    Log,
    /// Centre on the fitted median and scale by the interquartile range.
    Robust,
}
24
25pub struct FeatureNormalizer {
27 method: NormalizationMethod,
28 stats: Vec<FeatureStats>,
30}
31
/// Summary statistics of a single feature column, as fitted by `FeatureNormalizer`.
///
/// `Default` yields all-zero statistics.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
struct FeatureStats {
    /// Smallest observed value.
    min: f64,
    /// Largest observed value.
    max: f64,
    /// Arithmetic mean.
    mean: f64,
    /// Population standard deviation (variance divided by `n`).
    std: f64,
    /// Median value.
    median: f64,
    /// Lower-quartile estimate (used by robust scaling).
    q1: f64,
    /// Upper-quartile estimate (used by robust scaling).
    q3: f64,
}
43
44impl FeatureNormalizer {
45 pub fn new(method: NormalizationMethod) -> Self {
47 Self {
48 method,
49 stats: Vec::new(),
50 }
51 }
52
53 pub fn fit_nodes(&mut self, graph: &Graph) {
55 let features = graph.node_features();
56 self.fit(&features);
57 }
58
59 pub fn fit_edges(&mut self, graph: &Graph) {
61 let features = graph.edge_features();
62 self.fit(&features);
63 }
64
65 fn fit(&mut self, features: &[Vec<f64>]) {
67 if features.is_empty() {
68 return;
69 }
70
71 let dim = features[0].len();
72 self.stats = (0..dim)
73 .map(|d| {
74 let values: Vec<f64> = features
75 .iter()
76 .map(|f| f.get(d).copied().unwrap_or(0.0))
77 .collect();
78 Self::compute_stats(&values)
79 })
80 .collect();
81 }
82
83 fn compute_stats(values: &[f64]) -> FeatureStats {
85 if values.is_empty() {
86 return FeatureStats::default();
87 }
88
89 let n = values.len() as f64;
90 let min = values.iter().cloned().fold(f64::INFINITY, f64::min);
91 let max = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
92 let sum: f64 = values.iter().sum();
93 let mean = sum / n;
94 let variance: f64 = values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n;
95 let std = variance.sqrt();
96
97 let mut sorted = values.to_vec();
99 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
100
101 let median = if sorted.len() % 2 == 0 {
102 (sorted[sorted.len() / 2 - 1] + sorted[sorted.len() / 2]) / 2.0
103 } else {
104 sorted[sorted.len() / 2]
105 };
106
107 let q1_idx = sorted.len() / 4;
108 let q3_idx = 3 * sorted.len() / 4;
109 let q1 = sorted.get(q1_idx).copied().unwrap_or(min);
110 let q3 = sorted.get(q3_idx).copied().unwrap_or(max);
111
112 FeatureStats {
113 min,
114 max,
115 mean,
116 std,
117 median,
118 q1,
119 q3,
120 }
121 }
122
123 pub fn transform(&self, features: &[Vec<f64>]) -> Vec<Vec<f64>> {
125 features.iter().map(|f| self.transform_single(f)).collect()
126 }
127
128 fn transform_single(&self, features: &[f64]) -> Vec<f64> {
130 features
131 .iter()
132 .enumerate()
133 .map(|(i, &x)| {
134 let stats = self.stats.get(i).cloned().unwrap_or_default();
135 self.normalize_value(x, &stats)
136 })
137 .collect()
138 }
139
140 fn normalize_value(&self, x: f64, stats: &FeatureStats) -> f64 {
142 match self.method {
143 NormalizationMethod::None => x,
144 NormalizationMethod::MinMax => {
145 let range = stats.max - stats.min;
146 if range.abs() < 1e-10 {
147 0.0
148 } else {
149 (x - stats.min) / range
150 }
151 }
152 NormalizationMethod::ZScore => {
153 if stats.std.abs() < 1e-10 {
154 0.0
155 } else {
156 (x - stats.mean) / stats.std
157 }
158 }
159 NormalizationMethod::Log => (x.abs() + 1.0).ln() * x.signum(),
160 NormalizationMethod::Robust => {
161 let iqr = stats.q3 - stats.q1;
162 if iqr.abs() < 1e-10 {
163 0.0
164 } else {
165 (x - stats.median) / iqr
166 }
167 }
168 }
169 }
170}
171
172pub fn compute_structural_features(graph: &Graph) -> HashMap<NodeId, Vec<f64>> {
174 let mut features = HashMap::new();
175
176 for &node_id in graph.nodes.keys() {
177 let mut node_features = Vec::new();
178
179 let in_degree = graph.in_degree(node_id) as f64;
181 let out_degree = graph.out_degree(node_id) as f64;
182 let total_degree = in_degree + out_degree;
183
184 node_features.push(in_degree);
185 node_features.push(out_degree);
186 node_features.push(total_degree);
187
188 node_features.push((in_degree + 1.0).ln());
190 node_features.push((out_degree + 1.0).ln());
191
192 if total_degree > 0.0 {
194 node_features.push(in_degree / total_degree);
195 node_features.push(out_degree / total_degree);
196 } else {
197 node_features.push(0.5);
198 node_features.push(0.5);
199 }
200
201 let in_weight: f64 = graph.incoming_edges(node_id).iter().map(|e| e.weight).sum();
203 let out_weight: f64 = graph.outgoing_edges(node_id).iter().map(|e| e.weight).sum();
204
205 node_features.push((in_weight + 1.0).ln());
206 node_features.push((out_weight + 1.0).ln());
207
208 if in_degree > 0.0 {
210 node_features.push(in_weight / in_degree);
211 } else {
212 node_features.push(0.0);
213 }
214 if out_degree > 0.0 {
215 node_features.push(out_weight / out_degree);
216 } else {
217 node_features.push(0.0);
218 }
219
220 let neighbors = graph.neighbors(node_id);
222 let k = neighbors.len();
223 if k > 1 {
224 let mut triangle_count = 0;
225 for i in 0..k {
226 for j in i + 1..k {
227 if graph.neighbors(neighbors[i]).contains(&neighbors[j]) {
228 triangle_count += 1;
229 }
230 }
231 }
232 let max_triangles = k * (k - 1) / 2;
233 node_features.push(triangle_count as f64 / max_triangles as f64);
234 } else {
235 node_features.push(0.0);
236 }
237
238 features.insert(node_id, node_features);
239 }
240
241 features
242}
243
244pub fn compute_temporal_features(date: NaiveDate) -> Vec<f64> {
246 let mut features = Vec::new();
247
248 let weekday = date.weekday().num_days_from_monday() as f64;
250 features.push(weekday / 6.0);
251
252 let day = date.day() as f64;
254 features.push(day / 31.0);
255
256 let month = date.month() as f64;
258 features.push(month / 12.0);
259
260 let quarter = ((month - 1.0) / 3.0).floor() + 1.0;
262 features.push(quarter / 4.0);
263
264 features.push(if weekday >= 5.0 { 1.0 } else { 0.0 });
266
267 features.push(if day >= 28.0 { 1.0 } else { 0.0 });
269
270 let is_quarter_end = matches!(month as u32, 3 | 6 | 9 | 12) && day >= 28.0;
272 features.push(if is_quarter_end { 1.0 } else { 0.0 });
273
274 features.push(if month == 12.0 { 1.0 } else { 0.0 });
276
277 let day_of_year = date.ordinal() as f64;
279 features.push((2.0 * std::f64::consts::PI * day_of_year / 365.0).sin());
280 features.push((2.0 * std::f64::consts::PI * day_of_year / 365.0).cos());
281
282 features.push((2.0 * std::f64::consts::PI * weekday / 7.0).sin());
284 features.push((2.0 * std::f64::consts::PI * weekday / 7.0).cos());
285
286 features
287}
288
289pub fn compute_benford_features(amount: f64) -> Vec<f64> {
291 let mut features = Vec::new();
292
293 let first_digit = extract_first_digit(amount);
295 let benford_prob = benford_probability(first_digit);
296 features.push(benford_prob);
297
298 let expected_benford = [
300 0.301, 0.176, 0.125, 0.097, 0.079, 0.067, 0.058, 0.051, 0.046,
301 ];
302 if (1..=9).contains(&first_digit) {
303 let deviation = (expected_benford[first_digit as usize - 1] - benford_prob).abs();
304 features.push(deviation);
305 } else {
306 features.push(0.0);
307 }
308
309 for d in 1..=9 {
311 features.push(if first_digit == d { 1.0 } else { 0.0 });
312 }
313
314 let second_digit = extract_second_digit(amount);
316 features.push(second_digit as f64 / 9.0);
317
318 features
319}
320
/// Returns the first significant decimal digit of `|value|` (1..=9).
///
/// Returns 0 for zero, NaN and infinities.
///
/// The digit is read from the shortest round-trip decimal representation
/// (`{:e}` formatting) rather than from `log10`/`powf` arithmetic. That
/// arithmetic can land in the wrong decade near powers of ten, producing a
/// "mantissa" below 1 and a digit of 0. It also breaks for subnormals, where
/// `10^exponent` underflows.
fn extract_first_digit(value: f64) -> u32 {
    if value == 0.0 || !value.is_finite() {
        return 0;
    }
    format!("{:e}", value.abs())
        .chars()
        .next()
        .and_then(|c| c.to_digit(10))
        .unwrap_or(0)
}
331
/// Returns the second significant decimal digit of `|value|` (0..=9).
///
/// Returns 0 for zero, NaN, infinities, and values with a single significant digit.
///
/// Floating-point subtraction of the integer part is inexact. For example,
/// `(1.2 - 1.0) * 10` floors to 1, so the digit is instead read from the
/// shortest round-trip `{:e}` representation (e.g. `"1.2e1"` for `12.0`).
fn extract_second_digit(value: f64) -> u32 {
    if value == 0.0 || !value.is_finite() {
        return 0;
    }
    let repr = format!("{:e}", value.abs());
    // Skip the leading digit; a second digit exists only after a '.'.
    let mut chars = repr.chars().skip(1);
    match (chars.next(), chars.next()) {
        (Some('.'), Some(c)) => c.to_digit(10).unwrap_or(0),
        _ => 0,
    }
}
342
/// Benford's-law probability `log10(1 + 1/d)` that `digit` is the leading digit.
///
/// Only digits 1..=9 have non-zero probability; every other input yields 0.0.
fn benford_probability(digit: u32) -> f64 {
    match digit {
        1..=9 => (1.0 + f64::from(digit).recip()).log10(),
        _ => 0.0,
    }
}
350
351pub fn compute_amount_features(amount: Decimal) -> Vec<f64> {
353 let amount_f64: f64 = amount.try_into().unwrap_or(0.0);
354 let mut features = Vec::new();
355
356 features.push((amount_f64.abs() + 1.0).ln());
358
359 features.push(if amount_f64 >= 0.0 { 1.0 } else { 0.0 });
361
362 let is_round = (amount_f64 % 100.0).abs() < 0.01;
364 features.push(if is_round { 1.0 } else { 0.0 });
365
366 let magnitude = if amount_f64.abs() < 1.0 {
368 0
369 } else {
370 (amount_f64.abs().log10().floor() as i32).clamp(0, 9)
371 };
372 for m in 0..10 {
373 features.push(if magnitude == m { 1.0 } else { 0.0 });
374 }
375
376 features.extend(compute_benford_features(amount_f64));
378
379 features
380}
381
/// One-hot encodes `value` against `categories`.
///
/// Returns a vector as long as `categories`, with 1.0 at the first position
/// equal to `value` and 0.0 everywhere else. It is all zeros when `value` is
/// not a category.
pub fn one_hot_encode(value: &str, categories: &[&str]) -> Vec<f64> {
    let hit = categories.iter().position(|c| *c == value);
    (0..categories.len())
        .map(|slot| if hit == Some(slot) { 1.0 } else { 0.0 })
        .collect()
}
390
/// Label-encodes `value` as the index of its first occurrence in `categories`.
///
/// Returns `-1.0` when `value` is not one of the categories.
pub fn label_encode(value: &str, categories: &[&str]) -> f64 {
    match categories.iter().position(|c| *c == value) {
        Some(index) => index as f64,
        None => -1.0,
    }
}
399
/// Sinusoidal positional encoding of length `d_model` for `position`.
///
/// Component `i` uses angle `position / 10000^(2 * (i / 2) / d_model)` (integer
/// `i / 2`). Even components take its sine and odd components its cosine.
pub fn positional_encoding(position: usize, d_model: usize) -> Vec<f64> {
    let pos = position as f64;
    let width = d_model as f64;
    (0..d_model)
        .map(|i| {
            let pair = (i / 2) as f64;
            let angle = pos / 10000_f64.powf(2.0 * pair / width);
            if i % 2 == 0 {
                angle.sin()
            } else {
                angle.cos()
            }
        })
        .collect()
}
415
/// Directional features comparing the features of an edge's source and target nodes.
///
/// Output order: element-wise `target - source`, then `|target - source|`, then
/// `source * target`, each over the shared prefix of the two slices. A final
/// flag is 1.0 when the source's full sum exceeds the target's.
pub fn compute_edge_direction_features(
    source_features: &[f64],
    target_features: &[f64],
) -> Vec<f64> {
    let paired = || source_features.iter().zip(target_features);
    let width = source_features.len().min(target_features.len());

    let mut out = Vec::with_capacity(3 * width + 1);
    out.extend(paired().map(|(s, t)| t - s));
    out.extend(paired().map(|(s, t)| (t - s).abs()));
    out.extend(paired().map(|(s, t)| s * t));

    let source_total: f64 = source_features.iter().sum();
    let target_total: f64 = target_features.iter().sum();
    out.push(if source_total > target_total { 1.0 } else { 0.0 });
    out
}
445
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_benford_probability() {
        let prob1 = benford_probability(1);
        assert!((prob1 - 0.301).abs() < 0.001);

        let prob9 = benford_probability(9);
        assert!((prob9 - 0.046).abs() < 0.001);

        // Inputs outside 1..=9 have no Benford probability.
        assert_eq!(benford_probability(0), 0.0);
        assert_eq!(benford_probability(10), 0.0);
    }

    #[test]
    fn test_extract_first_digit() {
        assert_eq!(extract_first_digit(1234.56), 1);
        assert_eq!(extract_first_digit(9876.54), 9);
        assert_eq!(extract_first_digit(0.0123), 1);
        assert_eq!(extract_first_digit(-42.0), 4);
        assert_eq!(extract_first_digit(0.0), 0);
    }

    #[test]
    fn test_temporal_features() {
        let date = NaiveDate::from_ymd_opt(2024, 12, 31).unwrap();
        let features = compute_temporal_features(date);

        assert_eq!(features.len(), 12);
        // Index 7 is the December flag.
        assert!(features[7] > 0.5);
        // 2024-12-31 is a Tuesday, so the weekend flag (index 4) is off.
        assert_eq!(features[4], 0.0);
    }

    #[test]
    fn test_normalization() {
        let features = vec![vec![1.0, 100.0], vec![2.0, 200.0], vec![3.0, 300.0]];

        let mut normalizer = FeatureNormalizer::new(NormalizationMethod::MinMax);
        normalizer.fit(&features);

        let normalized = normalizer.transform(&features);
        // Column minimum maps to 0.0 and column maximum to 1.0.
        assert_eq!(normalized[0][0], 0.0);
        assert_eq!(normalized[2][0], 1.0);
        assert_eq!(normalized[2][1], 1.0);
    }

    #[test]
    fn test_zscore_and_constant_columns() {
        let features = vec![vec![1.0, 5.0], vec![2.0, 5.0], vec![3.0, 5.0]];

        let mut normalizer = FeatureNormalizer::new(NormalizationMethod::ZScore);
        normalizer.fit(&features);
        let normalized = normalizer.transform(&features);

        assert_eq!(normalized[1][0], 0.0);
        assert!(normalized[0][0] < 0.0 && normalized[2][0] > 0.0);
        // A zero-variance column collapses to 0.0 instead of dividing by zero.
        assert!(normalized.iter().all(|row| row[1] == 0.0));
    }

    #[test]
    fn test_log_normalization_is_signed() {
        let normalizer = FeatureNormalizer::new(NormalizationMethod::Log);
        let normalized = normalizer.transform(&[vec![0.0, -1.0, 1.0]]);
        assert_eq!(normalized[0], vec![0.0, -(2.0_f64.ln()), 2.0_f64.ln()]);
    }

    #[test]
    fn test_one_hot_encode() {
        let categories = ["A", "B", "C"];
        let encoded = one_hot_encode("B", &categories);
        assert_eq!(encoded, vec![0.0, 1.0, 0.0]);
        assert_eq!(one_hot_encode("Z", &categories), vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_label_encode() {
        let categories = ["A", "B", "C"];
        assert_eq!(label_encode("C", &categories), 2.0);
        assert_eq!(label_encode("Z", &categories), -1.0);
    }

    #[test]
    fn test_positional_encoding() {
        let encoding = positional_encoding(0, 8);
        assert_eq!(encoding.len(), 8);
        // At position 0 every sine term is 0 and every cosine term is 1.
        assert_eq!(encoding[0], 0.0);
        assert_eq!(encoding[1], 1.0);
    }

    #[test]
    fn test_edge_direction_features() {
        let features = compute_edge_direction_features(&[1.0, 2.0], &[3.0, 1.0]);
        assert_eq!(features, vec![2.0, -1.0, 2.0, 1.0, 3.0, 2.0, 0.0]);
    }

    #[test]
    fn test_benford_and_amount_feature_lengths() {
        assert_eq!(compute_benford_features(123.45).len(), 12);
        assert_eq!(compute_amount_features(Decimal::new(12345, 2)).len(), 25);
    }
}