// datasynth_graph/ml/features.rs
1//! Feature engineering utilities for graph neural networks.
2
3use std::collections::HashMap;
4
5use chrono::{Datelike, NaiveDate};
6use rust_decimal::Decimal;
7
8use crate::models::{Graph, NodeId};
9
/// Feature normalization method.
///
/// Defaults to [`NormalizationMethod::None`]. The enum is field-less, so it
/// is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum NormalizationMethod {
    /// No normalization.
    #[default]
    None,
    /// Min-max normalization to [0, 1].
    MinMax,
    /// Z-score standardization.
    ZScore,
    /// Log transformation (log1p).
    Log,
    /// Robust scaling (using median and IQR).
    Robust,
}
24
/// Feature normalizer for graph features.
///
/// Call one of the `fit_*` methods first to collect per-dimension
/// statistics, then `transform` to normalize feature vectors with the
/// configured [`NormalizationMethod`].
pub struct FeatureNormalizer {
    /// Normalization strategy applied by `transform`.
    method: NormalizationMethod,
    /// Statistics per feature dimension.
    /// Empty until one of the `fit_*` methods has been called.
    stats: Vec<FeatureStats>,
}
31
/// Statistics for a single feature dimension.
///
/// Computed over one dimension across all fitted feature vectors.
/// The `Default` (all zeros) is used for dimensions that were never fitted.
#[derive(Debug, Clone, Default)]
struct FeatureStats {
    min: f64,
    max: f64,
    mean: f64,
    /// Population standard deviation (divides by n, not n - 1).
    std: f64,
    median: f64,
    /// First quartile — index-based approximation `sorted[len / 4]`.
    q1: f64,
    /// Third quartile — index-based approximation `sorted[3 * len / 4]`.
    q3: f64,
}
43
44impl FeatureNormalizer {
45    /// Creates a new feature normalizer.
46    pub fn new(method: NormalizationMethod) -> Self {
47        Self {
48            method,
49            stats: Vec::new(),
50        }
51    }
52
53    /// Fits the normalizer to node features.
54    pub fn fit_nodes(&mut self, graph: &Graph) {
55        let features = graph.node_features();
56        self.fit(&features);
57    }
58
59    /// Fits the normalizer to edge features.
60    pub fn fit_edges(&mut self, graph: &Graph) {
61        let features = graph.edge_features();
62        self.fit(&features);
63    }
64
65    /// Fits the normalizer to features.
66    fn fit(&mut self, features: &[Vec<f64>]) {
67        if features.is_empty() {
68            return;
69        }
70
71        let dim = features[0].len();
72        self.stats = (0..dim)
73            .map(|d| {
74                let values: Vec<f64> = features
75                    .iter()
76                    .map(|f| f.get(d).copied().unwrap_or(0.0))
77                    .collect();
78                Self::compute_stats(&values)
79            })
80            .collect();
81    }
82
83    /// Computes statistics for a feature dimension.
84    fn compute_stats(values: &[f64]) -> FeatureStats {
85        if values.is_empty() {
86            return FeatureStats::default();
87        }
88
89        let n = values.len() as f64;
90        let min = values.iter().cloned().fold(f64::INFINITY, f64::min);
91        let max = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
92        let sum: f64 = values.iter().sum();
93        let mean = sum / n;
94        let variance: f64 = values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n;
95        let std = variance.sqrt();
96
97        // Compute quartiles
98        let mut sorted = values.to_vec();
99        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
100
101        let median = if sorted.len().is_multiple_of(2) {
102            (sorted[sorted.len() / 2 - 1] + sorted[sorted.len() / 2]) / 2.0
103        } else {
104            sorted[sorted.len() / 2]
105        };
106
107        let q1_idx = sorted.len() / 4;
108        let q3_idx = 3 * sorted.len() / 4;
109        let q1 = sorted.get(q1_idx).copied().unwrap_or(min);
110        let q3 = sorted.get(q3_idx).copied().unwrap_or(max);
111
112        FeatureStats {
113            min,
114            max,
115            mean,
116            std,
117            median,
118            q1,
119            q3,
120        }
121    }
122
123    /// Transforms features using fitted statistics.
124    pub fn transform(&self, features: &[Vec<f64>]) -> Vec<Vec<f64>> {
125        features.iter().map(|f| self.transform_single(f)).collect()
126    }
127
128    /// Transforms a single feature vector.
129    fn transform_single(&self, features: &[f64]) -> Vec<f64> {
130        features
131            .iter()
132            .enumerate()
133            .map(|(i, &x)| {
134                let stats = self.stats.get(i).cloned().unwrap_or_default();
135                self.normalize_value(x, &stats)
136            })
137            .collect()
138    }
139
140    /// Normalizes a single value.
141    fn normalize_value(&self, x: f64, stats: &FeatureStats) -> f64 {
142        match self.method {
143            NormalizationMethod::None => x,
144            NormalizationMethod::MinMax => {
145                let range = stats.max - stats.min;
146                if range.abs() < 1e-10 {
147                    0.0
148                } else {
149                    (x - stats.min) / range
150                }
151            }
152            NormalizationMethod::ZScore => {
153                if stats.std.abs() < 1e-10 {
154                    0.0
155                } else {
156                    (x - stats.mean) / stats.std
157                }
158            }
159            NormalizationMethod::Log => (x.abs() + 1.0).ln() * x.signum(),
160            NormalizationMethod::Robust => {
161                let iqr = stats.q3 - stats.q1;
162                if iqr.abs() < 1e-10 {
163                    0.0
164                } else {
165                    (x - stats.median) / iqr
166                }
167            }
168        }
169    }
170}
171
172/// Computes structural features for nodes.
173pub fn compute_structural_features(graph: &Graph) -> HashMap<NodeId, Vec<f64>> {
174    let mut features = HashMap::new();
175
176    for &node_id in graph.nodes.keys() {
177        let mut node_features = Vec::new();
178
179        // Degree features
180        let in_degree = graph.in_degree(node_id) as f64;
181        let out_degree = graph.out_degree(node_id) as f64;
182        let total_degree = in_degree + out_degree;
183
184        node_features.push(in_degree);
185        node_features.push(out_degree);
186        node_features.push(total_degree);
187
188        // Log degree (common in GNNs)
189        node_features.push((in_degree + 1.0).ln());
190        node_features.push((out_degree + 1.0).ln());
191
192        // Degree ratio
193        if total_degree > 0.0 {
194            node_features.push(in_degree / total_degree);
195            node_features.push(out_degree / total_degree);
196        } else {
197            node_features.push(0.5);
198            node_features.push(0.5);
199        }
200
201        // Weight-based features (sum of incident edge weights)
202        let in_weight: f64 = graph.incoming_edges(node_id).iter().map(|e| e.weight).sum();
203        let out_weight: f64 = graph.outgoing_edges(node_id).iter().map(|e| e.weight).sum();
204
205        node_features.push((in_weight + 1.0).ln());
206        node_features.push((out_weight + 1.0).ln());
207
208        // Average edge weight
209        if in_degree > 0.0 {
210            node_features.push(in_weight / in_degree);
211        } else {
212            node_features.push(0.0);
213        }
214        if out_degree > 0.0 {
215            node_features.push(out_weight / out_degree);
216        } else {
217            node_features.push(0.0);
218        }
219
220        // Local clustering coefficient (simplified)
221        let neighbors = graph.neighbors(node_id);
222        let k = neighbors.len();
223        if k > 1 {
224            let mut triangle_count = 0;
225            for i in 0..k {
226                for j in i + 1..k {
227                    if graph.neighbors(neighbors[i]).contains(&neighbors[j]) {
228                        triangle_count += 1;
229                    }
230                }
231            }
232            let max_triangles = k * (k - 1) / 2;
233            node_features.push(triangle_count as f64 / max_triangles as f64);
234        } else {
235            node_features.push(0.0);
236        }
237
238        features.insert(node_id, node_features);
239    }
240
241    features
242}
243
244/// Computes temporal features for edges.
245pub fn compute_temporal_features(date: NaiveDate) -> Vec<f64> {
246    let mut features = Vec::new();
247
248    // Day of week (0-6)
249    let weekday = date.weekday().num_days_from_monday() as f64;
250    features.push(weekday / 6.0);
251
252    // Day of month (1-31)
253    let day = date.day() as f64;
254    features.push(day / 31.0);
255
256    // Month (1-12)
257    let month = date.month() as f64;
258    features.push(month / 12.0);
259
260    // Quarter (1-4)
261    let quarter = ((month - 1.0) / 3.0).floor() + 1.0;
262    features.push(quarter / 4.0);
263
264    // Is weekend
265    features.push(if weekday >= 5.0 { 1.0 } else { 0.0 });
266
267    // Is month end (last 3 days)
268    features.push(if day >= 28.0 { 1.0 } else { 0.0 });
269
270    // Is quarter end
271    let is_quarter_end = matches!(month as u32, 3 | 6 | 9 | 12) && day >= 28.0;
272    features.push(if is_quarter_end { 1.0 } else { 0.0 });
273
274    // Is year end (December)
275    features.push(if month == 12.0 { 1.0 } else { 0.0 });
276
277    // Cyclical encoding for periodicity
278    let day_of_year = date.ordinal() as f64;
279    features.push((2.0 * std::f64::consts::PI * day_of_year / 365.0).sin());
280    features.push((2.0 * std::f64::consts::PI * day_of_year / 365.0).cos());
281
282    // Cyclical encoding for week
283    features.push((2.0 * std::f64::consts::PI * weekday / 7.0).sin());
284    features.push((2.0 * std::f64::consts::PI * weekday / 7.0).cos());
285
286    features
287}
288
/// Computes Benford's law features for an amount.
///
/// Layout (12 values): theoretical first-digit probability, deviation from a
/// rounded lookup table, a 9-way one-hot of the first digit, and the second
/// digit scaled to [0, 1].
pub fn compute_benford_features(amount: f64) -> Vec<f64> {
    // Rounded reference table for digits 1..=9.
    const EXPECTED: [f64; 9] = [
        0.301, 0.176, 0.125, 0.097, 0.079, 0.067, 0.058, 0.051, 0.046,
    ];

    let first_digit = extract_first_digit(amount);
    let benford_prob = benford_probability(first_digit);

    let mut features = vec![benford_prob];

    // NOTE(review): benford_prob is itself the theoretical probability, so
    // this "deviation" only measures rounding in the table above (~0 always).
    let deviation = match first_digit {
        1..=9 => (EXPECTED[first_digit as usize - 1] - benford_prob).abs(),
        _ => 0.0,
    };
    features.push(deviation);

    // One-hot encoding of the first digit over 1..=9.
    features.extend((1..=9).map(|d| if first_digit == d { 1.0 } else { 0.0 }));

    // Second significant digit, scaled to [0, 1].
    features.push(extract_second_digit(amount) as f64 / 9.0);

    features
}

/// Extracts the first significant digit (0 for a zero input).
fn extract_first_digit(value: f64) -> u32 {
    if value == 0.0 {
        return 0;
    }
    let magnitude = value.abs();
    // Scale into [1, 10) by dividing out the power of ten.
    let scale = 10_f64.powf(magnitude.log10().floor());
    (magnitude / scale).floor() as u32
}

/// Extracts the second significant digit (0 for a zero input).
fn extract_second_digit(value: f64) -> u32 {
    if value == 0.0 {
        return 0;
    }
    let magnitude = value.abs();
    let scale = 10_f64.powf(magnitude.log10().floor());
    let leading = magnitude / scale; // in [1, 10)
    (leading.fract() * 10.0).floor() as u32
}

/// Returns the expected Benford's law probability for a digit.
///
/// Defined as log10(1 + 1/d) for d in 1..=9; any other digit yields 0.0.
fn benford_probability(digit: u32) -> f64 {
    match digit {
        1..=9 => (1.0 + 1.0 / f64::from(digit)).log10(),
        _ => 0.0,
    }
}
350
351/// Computes amount-based features.
352pub fn compute_amount_features(amount: Decimal) -> Vec<f64> {
353    let amount_f64: f64 = amount.try_into().unwrap_or(0.0);
354    let mut features = Vec::new();
355
356    // Log amount
357    features.push((amount_f64.abs() + 1.0).ln());
358
359    // Sign
360    features.push(if amount_f64 >= 0.0 { 1.0 } else { 0.0 });
361
362    // Is round number
363    let is_round = (amount_f64 % 100.0).abs() < 0.01;
364    features.push(if is_round { 1.0 } else { 0.0 });
365
366    // Magnitude bucket
367    let magnitude = if amount_f64.abs() < 1.0 {
368        0
369    } else {
370        (amount_f64.abs().log10().floor() as i32).clamp(0, 9)
371    };
372    for m in 0..10 {
373        features.push(if magnitude == m { 1.0 } else { 0.0 });
374    }
375
376    // Benford features
377    features.extend(compute_benford_features(amount_f64));
378
379    features
380}
381
/// One-hot encodes a categorical value.
///
/// Only the first matching category (if any) is set to 1.0; an unknown
/// value yields an all-zero vector.
pub fn one_hot_encode(value: &str, categories: &[&str]) -> Vec<f64> {
    let hit = categories.iter().position(|&category| category == value);
    (0..categories.len())
        .map(|idx| if Some(idx) == hit { 1.0 } else { 0.0 })
        .collect()
}
390
/// Label encodes a categorical value.
///
/// Returns the index of `value` in `categories` as f64, or -1.0 when the
/// value is not present.
pub fn label_encode(value: &str, categories: &[&str]) -> f64 {
    match categories.iter().position(|&category| category == value) {
        Some(idx) => idx as f64,
        None => -1.0,
    }
}
399
/// Creates positional encoding for graph nodes (similar to transformer positional encoding).
///
/// Even indices hold `sin(position / 10000^(2i/d))`, odd indices the cosine
/// of the same angle; each sin/cos pair shares one frequency.
pub fn positional_encoding(position: usize, d_model: usize) -> Vec<f64> {
    (0..d_model)
        .map(|i| {
            // Integer division pairs up indices (0,1), (2,3), … on a frequency.
            let exponent = 2.0 * (i / 2) as f64 / d_model as f64;
            let angle = position as f64 / 10000_f64.powf(exponent);
            if i % 2 == 0 {
                angle.sin()
            } else {
                angle.cos()
            }
        })
        .collect()
}
415
/// Computes edge direction features for directed graphs.
///
/// Emits, per shared dimension: signed difference (target - source),
/// absolute difference, and Hadamard product — followed by one indicator of
/// which endpoint has the larger feature sum (ties give 0.0).
pub fn compute_edge_direction_features(
    source_features: &[f64],
    target_features: &[f64],
) -> Vec<f64> {
    // `zip` truncates to the shorter of the two slices.
    let paired = || {
        source_features
            .iter()
            .copied()
            .zip(target_features.iter().copied())
    };

    let mut features: Vec<f64> = paired().map(|(s, t)| t - s).collect();
    features.extend(paired().map(|(s, t)| (t - s).abs()));
    features.extend(paired().map(|(s, t)| s * t));

    let source_total: f64 = source_features.iter().sum();
    let target_total: f64 = target_features.iter().sum();
    features.push(if source_total > target_total { 1.0 } else { 0.0 });

    features
}
445
/// Compute velocity features for a node: per rolling window, the number of
/// edges incident to `node_id` and their summed weights.
///
/// # Parameters
/// - `node_id`: the graph node whose activity is being measured.
/// - `edges`: `(source, target, timestamp_days_since_epoch)` tuples; the
///   timestamp is a Unix-epoch day count (`days = unix_seconds / 86400`).
/// - `window_days`: rolling-window lengths in days, e.g. `[7, 30, 90]`.
/// - `reference_day`: the reference day the windows count back from.
///
/// # Returns
/// A `Vec<f64>` of `2 * window_days.len()` values, `[edge_count, weight_sum]`
/// per window. Each edge carries an implicit weight of 1.0 (the tuple has no
/// weight field), so the weight sum mirrors the edge count.
pub fn compute_velocity_features(
    node_id: usize,
    edges: &[(usize, usize, f64)], // (source, target, timestamp_as_day)
    window_days: &[u32],
    reference_day: f64,
) -> Vec<f64> {
    let mut features = Vec::with_capacity(window_days.len() * 2);

    for &window in window_days {
        // Inclusive cutoff: an edge exactly `window` days old still counts.
        let cutoff = reference_day - f64::from(window);
        let qualifying = edges
            .iter()
            .filter(|&&(src, tgt, ts)| (src == node_id || tgt == node_id) && ts >= cutoff)
            .count() as f64;

        features.push(qualifying);
        // Weight sum: 1.0 per qualifying edge, i.e. equal to the count.
        features.push(qualifying);
    }

    features
}
494
/// Simple iterative PageRank computation.
///
/// `adjacency[i]` lists the nodes that node `i` links to, `damping` is the
/// usual damping factor (typically 0.85), and `iterations` is the number of
/// power-iteration steps. Returns one score per node, normalised so the
/// scores sum to 1.0. Dangling nodes spread their damped rank uniformly.
pub fn pagerank(adjacency: &[Vec<usize>], damping: f64, iterations: usize) -> Vec<f64> {
    let n = adjacency.len();
    if n == 0 {
        return Vec::new();
    }

    let node_count = n as f64;
    let teleport = (1.0 - damping) / node_count;
    let mut rank = vec![1.0 / node_count; n];

    for _ in 0..iterations {
        let mut next = vec![teleport; n];

        for (src, targets) in adjacency.iter().enumerate() {
            if targets.is_empty() {
                // Dangling node: distribute its rank evenly across all nodes.
                let share = damping * rank[src] / node_count;
                next.iter_mut().for_each(|slot| *slot += share);
            } else {
                let share = damping * rank[src] / targets.len() as f64;
                // Out-of-range targets are ignored rather than panicking.
                for &tgt in targets.iter().filter(|&&t| t < n) {
                    next[tgt] += share;
                }
            }
        }

        rank = next;
    }

    // Normalise so scores sum to 1.0.
    let total: f64 = rank.iter().sum();
    if total > 0.0 {
        for score in &mut rank {
            *score /= total;
        }
    }

    rank
}
546
/// Degree centrality (normalised).
///
/// Returns `degree(v) / (n - 1)` for each node, where degree is the total
/// (in + out) degree derived from the outgoing adjacency list. The result
/// lies in `[0, 1]` for simple graphs with `n >= 2`.
pub fn degree_centrality(adjacency: &[Vec<usize>]) -> Vec<f64> {
    let n = adjacency.len();
    if n == 0 {
        return Vec::new();
    }

    // Out-degree from each row; in-degree by scanning all targets.
    let mut total_degree = vec![0usize; n];
    for (src, targets) in adjacency.iter().enumerate() {
        total_degree[src] += targets.len();
        for &tgt in targets.iter().filter(|&&t| t < n) {
            total_degree[tgt] += 1;
        }
    }

    // Guard the single-node case against division by zero.
    let denominator = if n > 1 { (n - 1) as f64 } else { 1.0 };
    total_degree
        .into_iter()
        .map(|d| d as f64 / denominator)
        .collect()
}
577
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_benford_probability() {
        // P(first digit = 1) = log10(1 + 1/1) ≈ 0.301.
        let prob1 = benford_probability(1);
        assert!((prob1 - 0.301).abs() < 0.001);

        // P(first digit = 9) = log10(1 + 1/9) ≈ 0.046.
        let prob9 = benford_probability(9);
        assert!((prob9 - 0.046).abs() < 0.001);
    }

    #[test]
    fn test_extract_first_digit() {
        assert_eq!(extract_first_digit(1234.56), 1);
        assert_eq!(extract_first_digit(9876.54), 9);
        // Values below 1 still yield the first *significant* digit.
        assert_eq!(extract_first_digit(0.0123), 1);
    }

    #[test]
    fn test_temporal_features() {
        let date = NaiveDate::from_ymd_opt(2024, 12, 31).unwrap();
        let features = compute_temporal_features(date);

        assert!(!features.is_empty());
        // Should indicate year end (index 7 in the feature layout)
        assert!(features[7] > 0.5); // is_year_end
    }

    #[test]
    fn test_normalization() {
        let features = vec![vec![1.0, 100.0], vec![2.0, 200.0], vec![3.0, 300.0]];

        let mut normalizer = FeatureNormalizer::new(NormalizationMethod::MinMax);
        normalizer.fit(&features);

        // Min-max maps the column extremes exactly to 0.0 and 1.0.
        let normalized = normalizer.transform(&features);
        assert_eq!(normalized[0][0], 0.0); // Min
        assert_eq!(normalized[2][0], 1.0); // Max
    }

    #[test]
    fn test_one_hot_encode() {
        let categories = ["A", "B", "C"];
        let encoded = one_hot_encode("B", &categories);
        assert_eq!(encoded, vec![0.0, 1.0, 0.0]);
    }

    #[test]
    fn test_positional_encoding() {
        let encoding = positional_encoding(0, 8);
        assert_eq!(encoding.len(), 8);
        assert_eq!(encoding[0], 0.0); // sin(0) = 0
    }

    #[test]
    fn test_velocity_features_length() {
        // node 1 has edges at days 995, 998, 1000 (all within window 7 of day 1002)
        let edges = vec![(1, 2, 995.0), (1, 3, 998.0), (1, 4, 1000.0)];
        let features = compute_velocity_features(1, &edges, &[7, 30, 90], 1002.0);
        assert_eq!(features.len(), 6, "2 metrics per window × 3 windows");
    }

    #[test]
    fn test_velocity_features_windowing() {
        // Day 1000 reference; edges at days 994 (within 7), 960 (within 90 only),
        // 900 (outside every window).
        let edges = vec![(1, 2, 994.0), (1, 3, 960.0), (1, 4, 900.0)];
        let features = compute_velocity_features(1, &edges, &[7, 30, 90], 1000.0);
        // window 7: only edge at 994 qualifies (1000 - 7 = 993; 994 >= 993 ✓)
        assert_eq!(features[0], 1.0, "7-day count");
        // window 30: only 994 qualifies (1000 - 30 = 970; 994 ≥ 970 ✓, 960 < 970 ✗)
        assert_eq!(features[2], 1.0, "30-day count");
        // window 90: 994 and 960 qualify (1000 - 90 = 910; 900 < 910 ✗)
        assert_eq!(features[4], 2.0, "90-day count");
    }

    #[test]
    fn test_pagerank_basic() {
        // Simple 3-node graph: 0→1, 1→2, 2→0 (cycle)
        let adjacency = vec![vec![1], vec![2], vec![0]];
        let pr = pagerank(&adjacency, 0.85, 50);
        assert_eq!(pr.len(), 3);
        // In a balanced cycle, all nodes should have equal rank (~0.333)
        for &r in &pr {
            assert!((r - 1.0 / 3.0).abs() < 0.01, "Expected ~0.333 but got {r}");
        }
        // Scores must sum to 1.0
        let total: f64 = pr.iter().sum();
        assert!((total - 1.0).abs() < 1e-9, "PageRank must sum to 1.0");
    }

    #[test]
    fn test_pagerank_empty() {
        let pr = pagerank(&[], 0.85, 10);
        assert!(pr.is_empty());
    }

    #[test]
    fn test_degree_centrality_basic() {
        // Star graph: node 0 connects to nodes 1, 2, 3
        let adjacency = vec![vec![1, 2, 3], vec![], vec![], vec![]];
        let dc = degree_centrality(&adjacency);
        assert_eq!(dc.len(), 4);
        // Node 0: out-degree 3, max possible = n-1 = 3, so centrality = 3/3 = 1.0
        assert!((dc[0] - 1.0).abs() < 1e-9, "Hub should have centrality 1.0");
        // Leaf nodes: in-degree 1, centrality = 1/3
        for &c in &dc[1..] {
            assert!(
                (c - 1.0 / 3.0).abs() < 1e-9,
                "Leaf centrality should be ~0.333"
            );
        }
    }

    #[test]
    fn test_degree_centrality_empty() {
        let dc = degree_centrality(&[]);
        assert!(dc.is_empty());
    }
}
699}