// datasynth_graph/ml/features.rs
1//! Feature engineering utilities for graph neural networks.
2
3use std::collections::HashMap;
4
5use chrono::{Datelike, NaiveDate};
6use rust_decimal::Decimal;
7
8use crate::models::{Graph, NodeId};
9
10/// Feature normalization method.
/// Feature normalization method.
///
/// All variants are unit-like, so the enum is `Copy` and cheap to pass by
/// value; `PartialEq`/`Eq` allow direct comparison in tests and config code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NormalizationMethod {
    /// No normalization; values pass through unchanged.
    None,
    /// Min-max normalization to [0, 1].
    MinMax,
    /// Z-score standardization: (x - mean) / std.
    ZScore,
    /// Signed log transformation: sign(x) * ln(|x| + 1).
    Log,
    /// Robust scaling: (x - median) / IQR.
    Robust,
}
24
25/// Feature normalizer for graph features.
26pub struct FeatureNormalizer {
27    method: NormalizationMethod,
28    /// Statistics per feature dimension.
29    stats: Vec<FeatureStats>,
30}
31
32/// Statistics for a single feature dimension.
/// Statistics for a single feature dimension.
///
/// All fields default to 0.0; the normalizer treats degenerate spreads
/// (zero range / std / IQR) as "no signal" and maps values to 0.0.
#[derive(Debug, Clone, Default)]
struct FeatureStats {
    min: f64,
    max: f64,
    mean: f64,
    // Population standard deviation (computed with divisor n, not n - 1).
    std: f64,
    median: f64,
    // Approximate first quartile: sorted element at index len / 4.
    q1: f64,
    // Approximate third quartile: sorted element at index 3 * len / 4.
    q3: f64,
}
43
44impl FeatureNormalizer {
45    /// Creates a new feature normalizer.
46    pub fn new(method: NormalizationMethod) -> Self {
47        Self {
48            method,
49            stats: Vec::new(),
50        }
51    }
52
53    /// Fits the normalizer to node features.
54    pub fn fit_nodes(&mut self, graph: &Graph) {
55        let features = graph.node_features();
56        self.fit(&features);
57    }
58
59    /// Fits the normalizer to edge features.
60    pub fn fit_edges(&mut self, graph: &Graph) {
61        let features = graph.edge_features();
62        self.fit(&features);
63    }
64
65    /// Fits the normalizer to features.
66    fn fit(&mut self, features: &[Vec<f64>]) {
67        if features.is_empty() {
68            return;
69        }
70
71        let dim = features[0].len();
72        self.stats = (0..dim)
73            .map(|d| {
74                let values: Vec<f64> = features
75                    .iter()
76                    .map(|f| f.get(d).copied().unwrap_or(0.0))
77                    .collect();
78                Self::compute_stats(&values)
79            })
80            .collect();
81    }
82
83    /// Computes statistics for a feature dimension.
84    fn compute_stats(values: &[f64]) -> FeatureStats {
85        if values.is_empty() {
86            return FeatureStats::default();
87        }
88
89        let n = values.len() as f64;
90        let min = values.iter().cloned().fold(f64::INFINITY, f64::min);
91        let max = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
92        let sum: f64 = values.iter().sum();
93        let mean = sum / n;
94        let variance: f64 = values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n;
95        let std = variance.sqrt();
96
97        // Compute quartiles
98        let mut sorted = values.to_vec();
99        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
100
101        let median = if sorted.len() % 2 == 0 {
102            (sorted[sorted.len() / 2 - 1] + sorted[sorted.len() / 2]) / 2.0
103        } else {
104            sorted[sorted.len() / 2]
105        };
106
107        let q1_idx = sorted.len() / 4;
108        let q3_idx = 3 * sorted.len() / 4;
109        let q1 = sorted.get(q1_idx).copied().unwrap_or(min);
110        let q3 = sorted.get(q3_idx).copied().unwrap_or(max);
111
112        FeatureStats {
113            min,
114            max,
115            mean,
116            std,
117            median,
118            q1,
119            q3,
120        }
121    }
122
123    /// Transforms features using fitted statistics.
124    pub fn transform(&self, features: &[Vec<f64>]) -> Vec<Vec<f64>> {
125        features.iter().map(|f| self.transform_single(f)).collect()
126    }
127
128    /// Transforms a single feature vector.
129    fn transform_single(&self, features: &[f64]) -> Vec<f64> {
130        features
131            .iter()
132            .enumerate()
133            .map(|(i, &x)| {
134                let stats = self.stats.get(i).cloned().unwrap_or_default();
135                self.normalize_value(x, &stats)
136            })
137            .collect()
138    }
139
140    /// Normalizes a single value.
141    fn normalize_value(&self, x: f64, stats: &FeatureStats) -> f64 {
142        match self.method {
143            NormalizationMethod::None => x,
144            NormalizationMethod::MinMax => {
145                let range = stats.max - stats.min;
146                if range.abs() < 1e-10 {
147                    0.0
148                } else {
149                    (x - stats.min) / range
150                }
151            }
152            NormalizationMethod::ZScore => {
153                if stats.std.abs() < 1e-10 {
154                    0.0
155                } else {
156                    (x - stats.mean) / stats.std
157                }
158            }
159            NormalizationMethod::Log => (x.abs() + 1.0).ln() * x.signum(),
160            NormalizationMethod::Robust => {
161                let iqr = stats.q3 - stats.q1;
162                if iqr.abs() < 1e-10 {
163                    0.0
164                } else {
165                    (x - stats.median) / iqr
166                }
167            }
168        }
169    }
170}
171
172/// Computes structural features for nodes.
173pub fn compute_structural_features(graph: &Graph) -> HashMap<NodeId, Vec<f64>> {
174    let mut features = HashMap::new();
175
176    for &node_id in graph.nodes.keys() {
177        let mut node_features = Vec::new();
178
179        // Degree features
180        let in_degree = graph.in_degree(node_id) as f64;
181        let out_degree = graph.out_degree(node_id) as f64;
182        let total_degree = in_degree + out_degree;
183
184        node_features.push(in_degree);
185        node_features.push(out_degree);
186        node_features.push(total_degree);
187
188        // Log degree (common in GNNs)
189        node_features.push((in_degree + 1.0).ln());
190        node_features.push((out_degree + 1.0).ln());
191
192        // Degree ratio
193        if total_degree > 0.0 {
194            node_features.push(in_degree / total_degree);
195            node_features.push(out_degree / total_degree);
196        } else {
197            node_features.push(0.5);
198            node_features.push(0.5);
199        }
200
201        // Weight-based features (sum of incident edge weights)
202        let in_weight: f64 = graph.incoming_edges(node_id).iter().map(|e| e.weight).sum();
203        let out_weight: f64 = graph.outgoing_edges(node_id).iter().map(|e| e.weight).sum();
204
205        node_features.push((in_weight + 1.0).ln());
206        node_features.push((out_weight + 1.0).ln());
207
208        // Average edge weight
209        if in_degree > 0.0 {
210            node_features.push(in_weight / in_degree);
211        } else {
212            node_features.push(0.0);
213        }
214        if out_degree > 0.0 {
215            node_features.push(out_weight / out_degree);
216        } else {
217            node_features.push(0.0);
218        }
219
220        // Local clustering coefficient (simplified)
221        let neighbors = graph.neighbors(node_id);
222        let k = neighbors.len();
223        if k > 1 {
224            let mut triangle_count = 0;
225            for i in 0..k {
226                for j in i + 1..k {
227                    if graph.neighbors(neighbors[i]).contains(&neighbors[j]) {
228                        triangle_count += 1;
229                    }
230                }
231            }
232            let max_triangles = k * (k - 1) / 2;
233            node_features.push(triangle_count as f64 / max_triangles as f64);
234        } else {
235            node_features.push(0.0);
236        }
237
238        features.insert(node_id, node_features);
239    }
240
241    features
242}
243
244/// Computes temporal features for edges.
245pub fn compute_temporal_features(date: NaiveDate) -> Vec<f64> {
246    let mut features = Vec::new();
247
248    // Day of week (0-6)
249    let weekday = date.weekday().num_days_from_monday() as f64;
250    features.push(weekday / 6.0);
251
252    // Day of month (1-31)
253    let day = date.day() as f64;
254    features.push(day / 31.0);
255
256    // Month (1-12)
257    let month = date.month() as f64;
258    features.push(month / 12.0);
259
260    // Quarter (1-4)
261    let quarter = ((month - 1.0) / 3.0).floor() + 1.0;
262    features.push(quarter / 4.0);
263
264    // Is weekend
265    features.push(if weekday >= 5.0 { 1.0 } else { 0.0 });
266
267    // Is month end (last 3 days)
268    features.push(if day >= 28.0 { 1.0 } else { 0.0 });
269
270    // Is quarter end
271    let is_quarter_end = matches!(month as u32, 3 | 6 | 9 | 12) && day >= 28.0;
272    features.push(if is_quarter_end { 1.0 } else { 0.0 });
273
274    // Is year end (December)
275    features.push(if month == 12.0 { 1.0 } else { 0.0 });
276
277    // Cyclical encoding for periodicity
278    let day_of_year = date.ordinal() as f64;
279    features.push((2.0 * std::f64::consts::PI * day_of_year / 365.0).sin());
280    features.push((2.0 * std::f64::consts::PI * day_of_year / 365.0).cos());
281
282    // Cyclical encoding for week
283    features.push((2.0 * std::f64::consts::PI * weekday / 7.0).sin());
284    features.push((2.0 * std::f64::consts::PI * weekday / 7.0).cos());
285
286    features
287}
288
289/// Computes Benford's law features for an amount.
290pub fn compute_benford_features(amount: f64) -> Vec<f64> {
291    let mut features = Vec::new();
292
293    // First digit
294    let first_digit = extract_first_digit(amount);
295    let benford_prob = benford_probability(first_digit);
296    features.push(benford_prob);
297
298    // Deviation from expected Benford distribution
299    let expected_benford = [
300        0.301, 0.176, 0.125, 0.097, 0.079, 0.067, 0.058, 0.051, 0.046,
301    ];
302    if (1..=9).contains(&first_digit) {
303        let deviation = (expected_benford[first_digit as usize - 1] - benford_prob).abs();
304        features.push(deviation);
305    } else {
306        features.push(0.0);
307    }
308
309    // First digit one-hot encoding
310    for d in 1..=9 {
311        features.push(if first_digit == d { 1.0 } else { 0.0 });
312    }
313
314    // Second digit (if available)
315    let second_digit = extract_second_digit(amount);
316    features.push(second_digit as f64 / 9.0);
317
318    features
319}
320
321/// Extracts the first significant digit.
/// Extracts the first significant digit (1..=9), or 0 for a zero input.
fn extract_first_digit(value: f64) -> u32 {
    let magnitude = value.abs();
    if magnitude == 0.0 {
        return 0;
    }
    // Divide out the power of ten to land in [1, 10), then truncate.
    let exponent = magnitude.log10().floor();
    (magnitude / 10_f64.powf(exponent)).floor() as u32
}
331
332/// Extracts the second significant digit.
/// Extracts the second significant digit (0..=9), or 0 for a zero input.
fn extract_second_digit(value: f64) -> u32 {
    let magnitude = value.abs();
    if magnitude == 0.0 {
        return 0;
    }
    let exponent = magnitude.log10().floor();
    // Mantissa lies in [1, 10); its fractional part carries the second digit.
    let mantissa = magnitude / 10_f64.powf(exponent);
    let fractional = mantissa - mantissa.floor();
    (fractional * 10.0).floor() as u32
}
342
343/// Returns the expected Benford's law probability for a digit.
/// Returns the expected Benford's law probability log10(1 + 1/d) for a
/// digit in 1..=9, and 0.0 for any other input.
fn benford_probability(digit: u32) -> f64 {
    match digit {
        1..=9 => (1.0 + 1.0 / f64::from(digit)).log10(),
        _ => 0.0,
    }
}
350
351/// Computes amount-based features.
352pub fn compute_amount_features(amount: Decimal) -> Vec<f64> {
353    let amount_f64: f64 = amount.try_into().unwrap_or(0.0);
354    let mut features = Vec::new();
355
356    // Log amount
357    features.push((amount_f64.abs() + 1.0).ln());
358
359    // Sign
360    features.push(if amount_f64 >= 0.0 { 1.0 } else { 0.0 });
361
362    // Is round number
363    let is_round = (amount_f64 % 100.0).abs() < 0.01;
364    features.push(if is_round { 1.0 } else { 0.0 });
365
366    // Magnitude bucket
367    let magnitude = if amount_f64.abs() < 1.0 {
368        0
369    } else {
370        (amount_f64.abs().log10().floor() as i32).clamp(0, 9)
371    };
372    for m in 0..10 {
373        features.push(if magnitude == m { 1.0 } else { 0.0 });
374    }
375
376    // Benford features
377    features.extend(compute_benford_features(amount_f64));
378
379    features
380}
381
382/// One-hot encodes a categorical value.
/// One-hot encodes a categorical value.
///
/// Returns a vector the length of `categories` with 1.0 at the first
/// matching category and 0.0 elsewhere (all zeros when `value` is absent).
pub fn one_hot_encode(value: &str, categories: &[&str]) -> Vec<f64> {
    let hit = categories.iter().position(|&c| c == value);
    (0..categories.len())
        .map(|i| if Some(i) == hit { 1.0 } else { 0.0 })
        .collect()
}
390
391/// Label encodes a categorical value.
/// Label encodes a categorical value.
///
/// Returns the index of `value` within `categories` as f64, or -1.0 when
/// the value is not present.
pub fn label_encode(value: &str, categories: &[&str]) -> f64 {
    match categories.iter().position(|&c| c == value) {
        Some(index) => index as f64,
        None => -1.0,
    }
}
399
400/// Creates positional encoding for graph nodes (similar to transformer positional encoding).
/// Creates positional encoding for graph nodes (similar to transformer
/// positional encoding).
///
/// Even dimensions carry sin, odd dimensions cos, with each sin/cos pair
/// sharing the frequency 1 / 10000^(2*(i/2)/d_model).
pub fn positional_encoding(position: usize, d_model: usize) -> Vec<f64> {
    (0..d_model)
        .map(|i| {
            // Integer division i / 2 pairs adjacent dims on one frequency.
            let exponent = 2.0 * (i / 2) as f64 / d_model as f64;
            let angle = position as f64 / 10000_f64.powf(exponent);
            if i % 2 == 0 {
                angle.sin()
            } else {
                angle.cos()
            }
        })
        .collect()
}
415
416/// Computes edge direction features for directed graphs.
/// Computes edge direction features for directed graphs.
///
/// For each aligned feature pair produces the signed difference
/// (target - source), the absolute difference, and the Hadamard product,
/// followed by a single indicator of whether the source's feature sum
/// exceeds the target's. Trailing unpaired features are ignored.
pub fn compute_edge_direction_features(
    source_features: &[f64],
    target_features: &[f64],
) -> Vec<f64> {
    // Re-creatable iterator over aligned (source, target) pairs.
    let pairs = || {
        source_features
            .iter()
            .copied()
            .zip(target_features.iter().copied())
    };

    // Signed differences in the source -> target direction.
    let mut features: Vec<f64> = pairs().map(|(s, t)| t - s).collect();
    // Magnitude of each difference.
    features.extend(pairs().map(|(s, t)| (t - s).abs()));
    // Element-wise (Hadamard) product.
    features.extend(pairs().map(|(s, t)| s * t));

    // Indicator: source carries strictly more total feature mass.
    let source_total: f64 = source_features.iter().sum();
    let target_total: f64 = target_features.iter().sum();
    features.push(if source_total > target_total { 1.0 } else { 0.0 });

    features
}
445
#[cfg(test)]
mod tests {
    use super::*;

    // Benford probabilities for the extreme digits match the rounded
    // reference values 0.301 (digit 1) and 0.046 (digit 9).
    #[test]
    fn test_benford_probability() {
        let prob1 = benford_probability(1);
        assert!((prob1 - 0.301).abs() < 0.001);

        let prob9 = benford_probability(9);
        assert!((prob9 - 0.046).abs() < 0.001);
    }

    // First-digit extraction works for values above and below 1.
    #[test]
    fn test_extract_first_digit() {
        assert_eq!(extract_first_digit(1234.56), 1);
        assert_eq!(extract_first_digit(9876.54), 9);
        assert_eq!(extract_first_digit(0.0123), 1);
    }

    // 2024-12-31 must set the year-end flag (feature index 7).
    #[test]
    fn test_temporal_features() {
        let date = NaiveDate::from_ymd_opt(2024, 12, 31).unwrap();
        let features = compute_temporal_features(date);

        assert!(!features.is_empty());
        // Should indicate year end
        assert!(features[7] > 0.5); // is_year_end
    }

    // Min-max normalization maps the column minimum to 0 and maximum to 1.
    #[test]
    fn test_normalization() {
        let features = vec![vec![1.0, 100.0], vec![2.0, 200.0], vec![3.0, 300.0]];

        let mut normalizer = FeatureNormalizer::new(NormalizationMethod::MinMax);
        normalizer.fit(&features);

        let normalized = normalizer.transform(&features);
        assert_eq!(normalized[0][0], 0.0); // Min
        assert_eq!(normalized[2][0], 1.0); // Max
    }

    // One-hot encoding places a single 1.0 at the matching category.
    #[test]
    fn test_one_hot_encode() {
        let categories = ["A", "B", "C"];
        let encoded = one_hot_encode("B", &categories);
        assert_eq!(encoded, vec![0.0, 1.0, 0.0]);
    }

    // Position 0 yields sin(0) = 0 in the first (even) dimension.
    #[test]
    fn test_positional_encoding() {
        let encoding = positional_encoding(0, 8);
        assert_eq!(encoding.len(), 8);
        assert_eq!(encoding[0], 0.0); // sin(0) = 0
    }
}
501}