Skip to main content

datasynth_graph/ml/
temporal.rs

//! Temporal sequence feature computation for graph nodes.
//!
//! This module provides temporal analysis features including:
//! - Transaction velocity (amount per day)
//! - Inter-event interval statistics
//! - Burst detection from the peak-to-mean ratio of daily event counts
//! - Trend analysis via linear regression
//! - Seasonality scoring via variation of weekday activity
//! - Per-window aggregation features
10
11use std::collections::HashMap;
12
13use chrono::{Datelike, NaiveDate};
14
15use crate::models::{EdgeId, Graph, NodeId};
16
17/// Configuration for temporal feature computation.
18#[derive(Debug, Clone)]
19pub struct TemporalConfig {
20    /// Window sizes in days for aggregation (e.g., [7, 30, 90]).
21    pub window_sizes: Vec<i64>,
22    /// Reference date for computing recency. If None, uses max date in data.
23    pub reference_date: Option<NaiveDate>,
24    /// Minimum number of edges for a node to have temporal features computed.
25    pub min_edge_count: usize,
26    /// Threshold multiplier for burst detection (events > threshold * mean = burst).
27    pub burst_threshold: f64,
28}
29
30impl Default for TemporalConfig {
31    fn default() -> Self {
32        Self {
33            window_sizes: vec![7, 30, 90],
34            reference_date: None,
35            min_edge_count: 2,
36            burst_threshold: 3.0,
37        }
38    }
39}
40
/// Aggregate statistics for one node's events inside a single time window.
#[derive(Debug, Clone, Default)]
pub struct WindowFeatures {
    /// How many events fall inside the window.
    pub event_count: usize,
    /// Sum of edge weights over the window's events.
    pub total_amount: f64,
    /// Mean edge weight per event in the window.
    pub avg_amount: f64,
    /// Largest edge weight seen in the window.
    pub max_amount: f64,
    /// Count of distinct nodes on the other end of the window's edges.
    pub unique_counterparties: usize,
}

impl WindowFeatures {
    /// Flattens these statistics into a 5-element feature vector.
    ///
    /// Layout: `[event_count, ln(total + 1), ln(avg + 1), ln(max + 1), unique_counterparties]`.
    pub fn to_features(&self) -> Vec<f64> {
        let log_shift = |amount: f64| (amount + 1.0).ln();

        let mut out = Vec::with_capacity(5);
        out.push(self.event_count as f64);
        out.push(log_shift(self.total_amount));
        out.push(log_shift(self.avg_amount));
        out.push(log_shift(self.max_amount));
        out.push(self.unique_counterparties as f64);
        out
    }
}
68
69/// Temporal sequence features for a single node.
70#[derive(Debug, Clone, Default)]
71pub struct TemporalFeatures {
72    /// Transaction velocity: total amount / time span in days.
73    pub transaction_velocity: f64,
74    /// Mean inter-event interval in days.
75    pub inter_event_interval_mean: f64,
76    /// Standard deviation of inter-event intervals.
77    pub inter_event_interval_std: f64,
78    /// Burst score: max daily event count / mean daily event count.
79    pub burst_score: f64,
80    /// Trend direction: +1.0 (increasing), -1.0 (decreasing), 0.0 (stable).
81    pub trend_direction: f64,
82    /// Seasonality score: variance of weekday activity normalized.
83    pub seasonality_score: f64,
84    /// Days since last event (recency).
85    pub recency_days: f64,
86    /// Per-window aggregated features.
87    pub window_features: HashMap<i64, WindowFeatures>,
88}
89
90impl TemporalFeatures {
91    /// Converts temporal features to a flat feature vector.
92    /// Returns base features + window features for each configured window.
93    pub fn to_features(&self, window_sizes: &[i64]) -> Vec<f64> {
94        let mut features = vec![
95            (self.transaction_velocity + 1.0).ln(),
96            self.inter_event_interval_mean,
97            self.inter_event_interval_std,
98            self.burst_score,
99            self.trend_direction,
100            self.seasonality_score,
101            self.recency_days / 365.0, // Normalize to ~[0, 1] for yearly data
102        ];
103
104        // Add window features in order
105        for &window in window_sizes {
106            if let Some(wf) = self.window_features.get(&window) {
107                features.extend(wf.to_features());
108            } else {
109                // Default values if window not present
110                features.extend(vec![0.0; 5]);
111            }
112        }
113
114        features
115    }
116
117    /// Returns the number of features in the output vector.
118    pub fn feature_count(window_count: usize) -> usize {
119        7 + (5 * window_count) // 7 base features + 5 per window
120    }
121}
122
123/// Index for efficient temporal queries on graph edges.
124#[derive(Debug, Clone)]
125pub struct TemporalIndex {
126    /// For each node, sorted list of (date, edge_id) pairs.
127    node_edges_by_date: HashMap<NodeId, Vec<(NaiveDate, EdgeId)>>,
128    /// Minimum date in the index.
129    pub min_date: Option<NaiveDate>,
130    /// Maximum date in the index.
131    pub max_date: Option<NaiveDate>,
132}
133
134impl TemporalIndex {
135    /// Builds a temporal index from a graph.
136    /// Complexity: O(E log E) for sorting edges by date.
137    pub fn build(graph: &Graph) -> Self {
138        let mut node_edges: HashMap<NodeId, Vec<(NaiveDate, EdgeId)>> = HashMap::new();
139        let mut min_date: Option<NaiveDate> = None;
140        let mut max_date: Option<NaiveDate> = None;
141
142        // Collect edges with timestamps
143        for (&edge_id, edge) in &graph.edges {
144            if let Some(date) = edge.timestamp {
145                // Update global date range
146                min_date = Some(min_date.map_or(date, |d| d.min(date)));
147                max_date = Some(max_date.map_or(date, |d| d.max(date)));
148
149                // Add to source and target node indices
150                node_edges
151                    .entry(edge.source)
152                    .or_default()
153                    .push((date, edge_id));
154                node_edges
155                    .entry(edge.target)
156                    .or_default()
157                    .push((date, edge_id));
158            }
159        }
160
161        // Sort edges by date for each node
162        for edges in node_edges.values_mut() {
163            edges.sort_by_key(|(date, _)| *date);
164        }
165
166        Self {
167            node_edges_by_date: node_edges,
168            min_date,
169            max_date,
170        }
171    }
172
173    /// Returns edges for a node within a date range (inclusive).
174    pub fn edges_in_range(
175        &self,
176        node_id: NodeId,
177        start: NaiveDate,
178        end: NaiveDate,
179    ) -> Vec<(NaiveDate, EdgeId)> {
180        if let Some(edges) = self.node_edges_by_date.get(&node_id) {
181            // Binary search for start position
182            let start_idx = edges.partition_point(|(d, _)| *d < start);
183            // Binary search for end position
184            let end_idx = edges.partition_point(|(d, _)| *d <= end);
185
186            edges[start_idx..end_idx].to_vec()
187        } else {
188            Vec::new()
189        }
190    }
191
192    /// Returns all edges for a node, sorted by date.
193    pub fn edges_for_node(&self, node_id: NodeId) -> &[(NaiveDate, EdgeId)] {
194        self.node_edges_by_date
195            .get(&node_id)
196            .map(|v| v.as_slice())
197            .unwrap_or(&[])
198    }
199
200    /// Returns the number of nodes with temporal data.
201    pub fn node_count(&self) -> usize {
202        self.node_edges_by_date.len()
203    }
204}
205
206/// Computes temporal sequence features for a single node.
207pub fn compute_temporal_sequence_features(
208    node_id: NodeId,
209    graph: &Graph,
210    index: &TemporalIndex,
211    config: &TemporalConfig,
212) -> TemporalFeatures {
213    let edges = index.edges_for_node(node_id);
214
215    // Return default if insufficient data
216    if edges.len() < config.min_edge_count {
217        return TemporalFeatures::default();
218    }
219
220    let reference_date = config
221        .reference_date
222        .or(index.max_date)
223        .unwrap_or_else(|| NaiveDate::from_ymd_opt(2024, 1, 1).unwrap());
224
225    // Compute inter-event intervals
226    let (interval_mean, interval_std) = compute_inter_event_intervals(edges);
227
228    // Compute transaction velocity
229    let transaction_velocity = compute_transaction_velocity(edges, graph);
230
231    // Compute burst score
232    let burst_score = compute_burst_score(edges, config.burst_threshold);
233
234    // Compute trend direction
235    let trend_direction = compute_trend_direction(edges, graph);
236
237    // Compute seasonality score
238    let seasonality_score = compute_seasonality_score(edges);
239
240    // Compute recency
241    let recency_days = if let Some((last_date, _)) = edges.last() {
242        (reference_date - *last_date).num_days().max(0) as f64
243    } else {
244        f64::MAX
245    };
246
247    // Compute window features
248    let mut window_features = HashMap::new();
249    for &window in &config.window_sizes {
250        let wf = compute_window_features(node_id, graph, index, reference_date, window);
251        window_features.insert(window, wf);
252    }
253
254    TemporalFeatures {
255        transaction_velocity,
256        inter_event_interval_mean: interval_mean,
257        inter_event_interval_std: interval_std,
258        burst_score,
259        trend_direction,
260        seasonality_score,
261        recency_days,
262        window_features,
263    }
264}
265
266/// Computes temporal features for all nodes in the graph.
267pub fn compute_all_temporal_features(
268    graph: &Graph,
269    config: &TemporalConfig,
270) -> HashMap<NodeId, TemporalFeatures> {
271    let index = TemporalIndex::build(graph);
272    let mut features = HashMap::new();
273
274    for &node_id in graph.nodes.keys() {
275        let node_features = compute_temporal_sequence_features(node_id, graph, &index, config);
276        features.insert(node_id, node_features);
277    }
278
279    features
280}
281
282/// Computes mean and standard deviation of inter-event intervals.
283fn compute_inter_event_intervals(edges: &[(NaiveDate, EdgeId)]) -> (f64, f64) {
284    if edges.len() < 2 {
285        return (0.0, 0.0);
286    }
287
288    let intervals: Vec<f64> = edges
289        .windows(2)
290        .map(|w| (w[1].0 - w[0].0).num_days() as f64)
291        .collect();
292
293    let n = intervals.len() as f64;
294    let mean = intervals.iter().sum::<f64>() / n;
295    let variance = intervals.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n;
296    let std = variance.sqrt();
297
298    (mean, std)
299}
300
301/// Computes transaction velocity: total amount / time span.
302fn compute_transaction_velocity(edges: &[(NaiveDate, EdgeId)], graph: &Graph) -> f64 {
303    if edges.len() < 2 {
304        return 0.0;
305    }
306
307    let first_date = edges.first().map(|(d, _)| *d);
308    let last_date = edges.last().map(|(d, _)| *d);
309
310    match (first_date, last_date) {
311        (Some(first), Some(last)) => {
312            let span_days = (last - first).num_days().max(1) as f64;
313            let total_amount: f64 = edges
314                .iter()
315                .filter_map(|(_, edge_id)| graph.get_edge(*edge_id))
316                .map(|e| e.weight)
317                .sum();
318            total_amount / span_days
319        }
320        _ => 0.0,
321    }
322}
323
324/// Computes burst score using Kleinberg-style daily event counting.
325fn compute_burst_score(edges: &[(NaiveDate, EdgeId)], threshold: f64) -> f64 {
326    if edges.is_empty() {
327        return 0.0;
328    }
329
330    // Count events per day
331    let mut daily_counts: HashMap<NaiveDate, usize> = HashMap::new();
332    for (date, _) in edges {
333        *daily_counts.entry(*date).or_insert(0) += 1;
334    }
335
336    let counts: Vec<f64> = daily_counts.values().map(|&c| c as f64).collect();
337    if counts.is_empty() {
338        return 0.0;
339    }
340
341    let mean_count = counts.iter().sum::<f64>() / counts.len() as f64;
342    let max_count = counts.iter().cloned().fold(0.0, f64::max);
343
344    if mean_count < 0.001 {
345        0.0
346    } else {
347        let ratio = max_count / mean_count;
348        // Score is how much max exceeds threshold
349        if ratio > threshold {
350            (ratio - threshold).min(10.0) // Cap at 10 for stability
351        } else {
352            0.0
353        }
354    }
355}
356
357/// Computes trend direction using linear regression on amounts over time.
358fn compute_trend_direction(edges: &[(NaiveDate, EdgeId)], graph: &Graph) -> f64 {
359    if edges.len() < 3 {
360        return 0.0;
361    }
362
363    let first_date = edges.first().map(|(d, _)| *d).unwrap();
364
365    // Collect (days_since_start, amount) pairs
366    let points: Vec<(f64, f64)> = edges
367        .iter()
368        .filter_map(|(date, edge_id)| {
369            let edge = graph.get_edge(*edge_id)?;
370            let x = (*date - first_date).num_days() as f64;
371            Some((x, edge.weight))
372        })
373        .collect();
374
375    if points.len() < 3 {
376        return 0.0;
377    }
378
379    // Simple linear regression to find slope
380    let n = points.len() as f64;
381    let sum_x: f64 = points.iter().map(|(x, _)| x).sum();
382    let sum_y: f64 = points.iter().map(|(_, y)| y).sum();
383    let sum_xy: f64 = points.iter().map(|(x, y)| x * y).sum();
384    let sum_xx: f64 = points.iter().map(|(x, _)| x * x).sum();
385
386    let denominator = n * sum_xx - sum_x * sum_x;
387    if denominator.abs() < 1e-10 {
388        return 0.0;
389    }
390
391    let slope = (n * sum_xy - sum_x * sum_y) / denominator;
392
393    // Normalize slope direction
394    if slope > 0.001 {
395        1.0
396    } else if slope < -0.001 {
397        -1.0
398    } else {
399        0.0
400    }
401}
402
403/// Computes seasonality score based on weekday activity variance.
404fn compute_seasonality_score(edges: &[(NaiveDate, EdgeId)]) -> f64 {
405    if edges.len() < 7 {
406        return 0.0;
407    }
408
409    // Count events per weekday
410    let mut weekday_counts = [0.0; 7];
411    for (date, _) in edges {
412        let weekday = date.weekday().num_days_from_monday() as usize;
413        weekday_counts[weekday] += 1.0;
414    }
415
416    // Compute variance of weekday distribution
417    let mean_count = weekday_counts.iter().sum::<f64>() / 7.0;
418    let variance = weekday_counts
419        .iter()
420        .map(|&c| (c - mean_count).powi(2))
421        .sum::<f64>()
422        / 7.0;
423
424    // Normalize by total count to get relative variance
425    let total = edges.len() as f64;
426    if total < 1.0 {
427        0.0
428    } else {
429        // Coefficient of variation for weekday distribution
430        (variance.sqrt() / mean_count.max(1.0)).min(1.0)
431    }
432}
433
434/// Computes window-based aggregate features.
435fn compute_window_features(
436    node_id: NodeId,
437    graph: &Graph,
438    index: &TemporalIndex,
439    reference_date: NaiveDate,
440    window_days: i64,
441) -> WindowFeatures {
442    let start_date = reference_date - chrono::Duration::days(window_days);
443    let edges = index.edges_in_range(node_id, start_date, reference_date);
444
445    if edges.is_empty() {
446        return WindowFeatures::default();
447    }
448
449    let mut total_amount = 0.0;
450    let mut max_amount = 0.0;
451    let mut counterparties = std::collections::HashSet::new();
452
453    for (_, edge_id) in &edges {
454        if let Some(edge) = graph.get_edge(*edge_id) {
455            total_amount += edge.weight;
456            if edge.weight > max_amount {
457                max_amount = edge.weight;
458            }
459            // Track counterparty (the other end of the edge)
460            let node = graph.get_node(node_id);
461            if node.is_some() {
462                if edge.source == node_id {
463                    counterparties.insert(edge.target);
464                } else {
465                    counterparties.insert(edge.source);
466                }
467            }
468        }
469    }
470
471    let event_count = edges.len();
472    let avg_amount = if event_count > 0 {
473        total_amount / event_count as f64
474    } else {
475        0.0
476    };
477
478    WindowFeatures {
479        event_count,
480        total_amount,
481        avg_amount,
482        max_amount,
483        unique_counterparties: counterparties.len(),
484    }
485}
486
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{GraphEdge, GraphNode, GraphType, NodeType};
    use crate::EdgeType;

    /// Builds a three-account graph (Cash, AP, Revenue) with activity on each
    /// day from 2024-01-01 through 2024-01-10:
    /// - one Cash -> AP edge per day, with amounts 100, 110, ..., 190;
    /// - one Cash -> Revenue edge on 01-01, 01-03, 01-05, 01-07 and 01-09,
    ///   at twice that day's Cash -> AP amount.
    fn create_test_graph() -> Graph {
        let mut graph = Graph::new("test", GraphType::Transaction);

        let n1 = graph.add_node(GraphNode::new(
            0,
            NodeType::Account,
            "1000".to_string(),
            "Cash".to_string(),
        ));
        let n2 = graph.add_node(GraphNode::new(
            0,
            NodeType::Account,
            "2000".to_string(),
            "AP".to_string(),
        ));
        let n3 = graph.add_node(GraphNode::new(
            0,
            NodeType::Account,
            "3000".to_string(),
            "Revenue".to_string(),
        ));

        for i in 0..10 {
            let date = NaiveDate::from_ymd_opt(2024, 1, 1 + i).unwrap();
            let amount = 100.0 + (i as f64 * 10.0); // Increasing trend

            let edge = GraphEdge::new(0, n1, n2, EdgeType::Transaction)
                .with_weight(amount)
                .with_timestamp(date);
            graph.add_edge(edge);

            // i even <=> odd day of the month (01-01, 01-03, ...).
            if i % 2 == 0 {
                let edge = GraphEdge::new(0, n1, n3, EdgeType::Transaction)
                    .with_weight(amount * 2.0)
                    .with_timestamp(date);
                graph.add_edge(edge);
            }
        }

        graph
    }

    #[test]
    fn test_temporal_index_build() {
        let graph = create_test_graph();
        let index = TemporalIndex::build(&graph);

        assert!(index.min_date.is_some());
        assert!(index.max_date.is_some());
        assert_eq!(
            index.min_date.unwrap(),
            NaiveDate::from_ymd_opt(2024, 1, 1).unwrap()
        );
        assert_eq!(
            index.max_date.unwrap(),
            NaiveDate::from_ymd_opt(2024, 1, 10).unwrap()
        );
    }

    #[test]
    fn test_edges_in_range() {
        let graph = create_test_graph();
        let index = TemporalIndex::build(&graph);

        // Node id 1 is assumed to be the first node added (Cash). It has a
        // Cash -> AP edge on every day from 01-03 to 01-07, plus Cash -> Revenue
        // edges on 01-03, 01-05 and 01-07.
        let start = NaiveDate::from_ymd_opt(2024, 1, 3).unwrap();
        let end = NaiveDate::from_ymd_opt(2024, 1, 7).unwrap();
        let edges = index.edges_in_range(1, start, end);

        assert!(!edges.is_empty());
        // The range is inclusive on both ends, and results come back in date order.
        assert!(edges.iter().all(|&(date, _)| start <= date && date <= end));
        assert!(edges.windows(2).all(|pair| pair[0].0 <= pair[1].0));
    }

    #[test]
    fn test_compute_temporal_features() {
        let graph = create_test_graph();
        let index = TemporalIndex::build(&graph);
        let config = TemporalConfig::default();

        let features = compute_temporal_sequence_features(1, &graph, &index, &config);

        assert!(features.transaction_velocity > 0.0);

        // Amounts grow over time on every edge series, so the slope is positive.
        assert_eq!(features.trend_direction, 1.0);

        // One entry per configured window.
        assert_eq!(features.window_features.len(), config.window_sizes.len());
    }

    #[test]
    fn test_inter_event_intervals() {
        let edges = vec![
            (NaiveDate::from_ymd_opt(2024, 1, 1).unwrap(), 1),
            (NaiveDate::from_ymd_opt(2024, 1, 3).unwrap(), 2),
            (NaiveDate::from_ymd_opt(2024, 1, 6).unwrap(), 3),
        ];

        let (mean, std) = compute_inter_event_intervals(&edges);

        // Intervals are 2 and 3 days: mean 2.5, population std 0.5.
        assert!((mean - 2.5).abs() < 1e-9);
        assert!((std - 0.5).abs() < 1e-9);
    }

    #[test]
    fn test_inter_event_intervals_need_two_events() {
        let single = vec![(NaiveDate::from_ymd_opt(2024, 1, 1).unwrap(), 1)];
        assert_eq!(compute_inter_event_intervals(&single), (0.0, 0.0));
        assert_eq!(compute_inter_event_intervals(&[]), (0.0, 0.0));
    }

    #[test]
    fn test_burst_score() {
        let mut edges = Vec::new();
        for i in 0..3 {
            // Normal days with 1 event each
            edges.push((NaiveDate::from_ymd_opt(2024, 1, 1 + i).unwrap(), i as u64));
        }
        // Burst day with 10 events
        for i in 0..10 {
            edges.push((NaiveDate::from_ymd_opt(2024, 1, 10).unwrap(), 100 + i));
        }

        let score = compute_burst_score(&edges, 3.0);

        // Active-day counts are [1, 1, 1, 10], so the mean is 3.25 and
        // peak / mean = 10 / 3.25. That exceeds the threshold of 3.0 by ~0.077.
        let expected = 10.0 / 3.25 - 3.0;
        assert!((score - expected).abs() < 1e-9);
    }

    #[test]
    fn test_seasonality_score() {
        let day = |d: u32| NaiveDate::from_ymd_opt(2024, 1, d).unwrap();

        // Fewer than seven events: not enough data.
        let six: Vec<_> = (1..=6).map(|d| (day(d), u64::from(d))).collect();
        assert_eq!(compute_seasonality_score(&six), 0.0);

        // 01-01 ..= 01-07 hits each weekday exactly once: no weekday skew.
        let week: Vec<_> = (1..=7).map(|d| (day(d), u64::from(d))).collect();
        assert_eq!(compute_seasonality_score(&week), 0.0);

        // Seven events on the same weekday: maximal skew, capped at 1.0.
        let same_weekday: Vec<_> = (0..7)
            .map(|k| (day(1) + chrono::Duration::days(7 * k), k as u64))
            .collect();
        assert_eq!(compute_seasonality_score(&same_weekday), 1.0);
    }

    #[test]
    fn test_feature_vector_length() {
        let windows = vec![7, 30, 90];
        let expected_len = TemporalFeatures::feature_count(windows.len());

        let features = TemporalFeatures::default();
        let vec = features.to_features(&windows);

        assert_eq!(vec.len(), expected_len);
    }

    #[test]
    fn test_compute_all_temporal_features() {
        let graph = create_test_graph();
        let config = TemporalConfig::default();

        let all_features = compute_all_temporal_features(&graph, &config);

        // Should have features for all nodes
        assert_eq!(all_features.len(), 3);
    }
}