1use std::collections::HashMap;
12
13use chrono::{Datelike, NaiveDate};
14
15use crate::models::{EdgeId, Graph, NodeId};
16
17#[derive(Debug, Clone)]
19pub struct TemporalConfig {
20 pub window_sizes: Vec<i64>,
22 pub reference_date: Option<NaiveDate>,
24 pub min_edge_count: usize,
26 pub burst_threshold: f64,
28}
29
30impl Default for TemporalConfig {
31 fn default() -> Self {
32 Self {
33 window_sizes: vec![7, 30, 90],
34 reference_date: None,
35 min_edge_count: 2,
36 burst_threshold: 3.0,
37 }
38 }
39}
40
/// Aggregate activity of one node inside a single look-back window.
#[derive(Debug, Clone, Default)]
pub struct WindowFeatures {
    /// Number of edges that fall inside the window.
    pub event_count: usize,
    /// Sum of the edge weights inside the window.
    pub total_amount: f64,
    /// Mean edge weight inside the window.
    pub avg_amount: f64,
    /// Largest edge weight observed inside the window.
    pub max_amount: f64,
    /// Number of distinct nodes on the other end of those edges.
    pub unique_counterparties: usize,
}

impl WindowFeatures {
    /// Flattens the window into five numeric features.
    ///
    /// Order: event count, `ln(total + 1)`, `ln(avg + 1)`, `ln(max + 1)`,
    /// unique counterparties. The counts are passed through unscaled.
    pub fn to_features(&self) -> Vec<f64> {
        // Amounts are log-compressed; the +1 keeps a zero amount at zero.
        let compress = |amount: f64| (amount + 1.0).ln();

        let mut out = Vec::with_capacity(5);
        out.push(self.event_count as f64);
        out.push(compress(self.total_amount));
        out.push(compress(self.avg_amount));
        out.push(compress(self.max_amount));
        out.push(self.unique_counterparties as f64);
        out
    }
}
68
69#[derive(Debug, Clone, Default)]
71pub struct TemporalFeatures {
72 pub transaction_velocity: f64,
74 pub inter_event_interval_mean: f64,
76 pub inter_event_interval_std: f64,
78 pub burst_score: f64,
80 pub trend_direction: f64,
82 pub seasonality_score: f64,
84 pub recency_days: f64,
86 pub window_features: HashMap<i64, WindowFeatures>,
88}
89
90impl TemporalFeatures {
91 pub fn to_features(&self, window_sizes: &[i64]) -> Vec<f64> {
94 let mut features = vec![
95 (self.transaction_velocity + 1.0).ln(),
96 self.inter_event_interval_mean,
97 self.inter_event_interval_std,
98 self.burst_score,
99 self.trend_direction,
100 self.seasonality_score,
101 self.recency_days / 365.0, ];
103
104 for &window in window_sizes {
106 if let Some(wf) = self.window_features.get(&window) {
107 features.extend(wf.to_features());
108 } else {
109 features.extend(vec![0.0; 5]);
111 }
112 }
113
114 features
115 }
116
117 pub fn feature_count(window_count: usize) -> usize {
119 7 + (5 * window_count) }
121}
122
123#[derive(Debug, Clone)]
125pub struct TemporalIndex {
126 node_edges_by_date: HashMap<NodeId, Vec<(NaiveDate, EdgeId)>>,
128 pub min_date: Option<NaiveDate>,
130 pub max_date: Option<NaiveDate>,
132}
133
134impl TemporalIndex {
135 pub fn build(graph: &Graph) -> Self {
138 let mut node_edges: HashMap<NodeId, Vec<(NaiveDate, EdgeId)>> = HashMap::new();
139 let mut min_date: Option<NaiveDate> = None;
140 let mut max_date: Option<NaiveDate> = None;
141
142 for (&edge_id, edge) in &graph.edges {
144 if let Some(date) = edge.timestamp {
145 min_date = Some(min_date.map_or(date, |d| d.min(date)));
147 max_date = Some(max_date.map_or(date, |d| d.max(date)));
148
149 node_edges
151 .entry(edge.source)
152 .or_default()
153 .push((date, edge_id));
154 node_edges
155 .entry(edge.target)
156 .or_default()
157 .push((date, edge_id));
158 }
159 }
160
161 for edges in node_edges.values_mut() {
163 edges.sort_by_key(|(date, _)| *date);
164 }
165
166 Self {
167 node_edges_by_date: node_edges,
168 min_date,
169 max_date,
170 }
171 }
172
173 pub fn edges_in_range(
175 &self,
176 node_id: NodeId,
177 start: NaiveDate,
178 end: NaiveDate,
179 ) -> Vec<(NaiveDate, EdgeId)> {
180 if let Some(edges) = self.node_edges_by_date.get(&node_id) {
181 let start_idx = edges.partition_point(|(d, _)| *d < start);
183 let end_idx = edges.partition_point(|(d, _)| *d <= end);
185
186 edges[start_idx..end_idx].to_vec()
187 } else {
188 Vec::new()
189 }
190 }
191
192 pub fn edges_for_node(&self, node_id: NodeId) -> &[(NaiveDate, EdgeId)] {
194 self.node_edges_by_date
195 .get(&node_id)
196 .map(|v| v.as_slice())
197 .unwrap_or(&[])
198 }
199
200 pub fn node_count(&self) -> usize {
202 self.node_edges_by_date.len()
203 }
204}
205
206pub fn compute_temporal_sequence_features(
208 node_id: NodeId,
209 graph: &Graph,
210 index: &TemporalIndex,
211 config: &TemporalConfig,
212) -> TemporalFeatures {
213 let edges = index.edges_for_node(node_id);
214
215 if edges.len() < config.min_edge_count {
217 return TemporalFeatures::default();
218 }
219
220 let reference_date = config
221 .reference_date
222 .or(index.max_date)
223 .unwrap_or_else(|| NaiveDate::from_ymd_opt(2024, 1, 1).unwrap());
224
225 let (interval_mean, interval_std) = compute_inter_event_intervals(edges);
227
228 let transaction_velocity = compute_transaction_velocity(edges, graph);
230
231 let burst_score = compute_burst_score(edges, config.burst_threshold);
233
234 let trend_direction = compute_trend_direction(edges, graph);
236
237 let seasonality_score = compute_seasonality_score(edges);
239
240 let recency_days = if let Some((last_date, _)) = edges.last() {
242 (reference_date - *last_date).num_days().max(0) as f64
243 } else {
244 f64::MAX
245 };
246
247 let mut window_features = HashMap::new();
249 for &window in &config.window_sizes {
250 let wf = compute_window_features(node_id, graph, index, reference_date, window);
251 window_features.insert(window, wf);
252 }
253
254 TemporalFeatures {
255 transaction_velocity,
256 inter_event_interval_mean: interval_mean,
257 inter_event_interval_std: interval_std,
258 burst_score,
259 trend_direction,
260 seasonality_score,
261 recency_days,
262 window_features,
263 }
264}
265
266pub fn compute_all_temporal_features(
268 graph: &Graph,
269 config: &TemporalConfig,
270) -> HashMap<NodeId, TemporalFeatures> {
271 let index = TemporalIndex::build(graph);
272 let mut features = HashMap::new();
273
274 for &node_id in graph.nodes.keys() {
275 let node_features = compute_temporal_sequence_features(node_id, graph, &index, config);
276 features.insert(node_id, node_features);
277 }
278
279 features
280}
281
282fn compute_inter_event_intervals(edges: &[(NaiveDate, EdgeId)]) -> (f64, f64) {
284 if edges.len() < 2 {
285 return (0.0, 0.0);
286 }
287
288 let intervals: Vec<f64> = edges
289 .windows(2)
290 .map(|w| (w[1].0 - w[0].0).num_days() as f64)
291 .collect();
292
293 let n = intervals.len() as f64;
294 let mean = intervals.iter().sum::<f64>() / n;
295 let variance = intervals.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n;
296 let std = variance.sqrt();
297
298 (mean, std)
299}
300
301fn compute_transaction_velocity(edges: &[(NaiveDate, EdgeId)], graph: &Graph) -> f64 {
303 if edges.len() < 2 {
304 return 0.0;
305 }
306
307 let first_date = edges.first().map(|(d, _)| *d);
308 let last_date = edges.last().map(|(d, _)| *d);
309
310 match (first_date, last_date) {
311 (Some(first), Some(last)) => {
312 let span_days = (last - first).num_days().max(1) as f64;
313 let total_amount: f64 = edges
314 .iter()
315 .filter_map(|(_, edge_id)| graph.get_edge(*edge_id))
316 .map(|e| e.weight)
317 .sum();
318 total_amount / span_days
319 }
320 _ => 0.0,
321 }
322}
323
324fn compute_burst_score(edges: &[(NaiveDate, EdgeId)], threshold: f64) -> f64 {
326 if edges.is_empty() {
327 return 0.0;
328 }
329
330 let mut daily_counts: HashMap<NaiveDate, usize> = HashMap::new();
332 for (date, _) in edges {
333 *daily_counts.entry(*date).or_insert(0) += 1;
334 }
335
336 let counts: Vec<f64> = daily_counts.values().map(|&c| c as f64).collect();
337 if counts.is_empty() {
338 return 0.0;
339 }
340
341 let mean_count = counts.iter().sum::<f64>() / counts.len() as f64;
342 let max_count = counts.iter().cloned().fold(0.0, f64::max);
343
344 if mean_count < 0.001 {
345 0.0
346 } else {
347 let ratio = max_count / mean_count;
348 if ratio > threshold {
350 (ratio - threshold).min(10.0) } else {
352 0.0
353 }
354 }
355}
356
357fn compute_trend_direction(edges: &[(NaiveDate, EdgeId)], graph: &Graph) -> f64 {
359 if edges.len() < 3 {
360 return 0.0;
361 }
362
363 let first_date = edges.first().map(|(d, _)| *d).unwrap();
364
365 let points: Vec<(f64, f64)> = edges
367 .iter()
368 .filter_map(|(date, edge_id)| {
369 let edge = graph.get_edge(*edge_id)?;
370 let x = (*date - first_date).num_days() as f64;
371 Some((x, edge.weight))
372 })
373 .collect();
374
375 if points.len() < 3 {
376 return 0.0;
377 }
378
379 let n = points.len() as f64;
381 let sum_x: f64 = points.iter().map(|(x, _)| x).sum();
382 let sum_y: f64 = points.iter().map(|(_, y)| y).sum();
383 let sum_xy: f64 = points.iter().map(|(x, y)| x * y).sum();
384 let sum_xx: f64 = points.iter().map(|(x, _)| x * x).sum();
385
386 let denominator = n * sum_xx - sum_x * sum_x;
387 if denominator.abs() < 1e-10 {
388 return 0.0;
389 }
390
391 let slope = (n * sum_xy - sum_x * sum_y) / denominator;
392
393 if slope > 0.001 {
395 1.0
396 } else if slope < -0.001 {
397 -1.0
398 } else {
399 0.0
400 }
401}
402
403fn compute_seasonality_score(edges: &[(NaiveDate, EdgeId)]) -> f64 {
405 if edges.len() < 7 {
406 return 0.0;
407 }
408
409 let mut weekday_counts = [0.0; 7];
411 for (date, _) in edges {
412 let weekday = date.weekday().num_days_from_monday() as usize;
413 weekday_counts[weekday] += 1.0;
414 }
415
416 let mean_count = weekday_counts.iter().sum::<f64>() / 7.0;
418 let variance = weekday_counts
419 .iter()
420 .map(|&c| (c - mean_count).powi(2))
421 .sum::<f64>()
422 / 7.0;
423
424 let total = edges.len() as f64;
426 if total < 1.0 {
427 0.0
428 } else {
429 (variance.sqrt() / mean_count.max(1.0)).min(1.0)
431 }
432}
433
434fn compute_window_features(
436 node_id: NodeId,
437 graph: &Graph,
438 index: &TemporalIndex,
439 reference_date: NaiveDate,
440 window_days: i64,
441) -> WindowFeatures {
442 let start_date = reference_date - chrono::Duration::days(window_days);
443 let edges = index.edges_in_range(node_id, start_date, reference_date);
444
445 if edges.is_empty() {
446 return WindowFeatures::default();
447 }
448
449 let mut total_amount = 0.0;
450 let mut max_amount = 0.0;
451 let mut counterparties = std::collections::HashSet::new();
452
453 for (_, edge_id) in &edges {
454 if let Some(edge) = graph.get_edge(*edge_id) {
455 total_amount += edge.weight;
456 if edge.weight > max_amount {
457 max_amount = edge.weight;
458 }
459 let node = graph.get_node(node_id);
461 if node.is_some() {
462 if edge.source == node_id {
463 counterparties.insert(edge.target);
464 } else {
465 counterparties.insert(edge.source);
466 }
467 }
468 }
469 }
470
471 let event_count = edges.len();
472 let avg_amount = if event_count > 0 {
473 total_amount / event_count as f64
474 } else {
475 0.0
476 };
477
478 WindowFeatures {
479 event_count,
480 total_amount,
481 avg_amount,
482 max_amount,
483 unique_counterparties: counterparties.len(),
484 }
485}
486
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{GraphEdge, GraphNode, GraphType, NodeType};
    use crate::EdgeType;

    /// Shorthand for a date in January 2024.
    fn jan(day: u32) -> NaiveDate {
        NaiveDate::from_ymd_opt(2024, 1, day).unwrap()
    }

    /// Builds a three-account graph with one `n1 -> n2` edge per day on
    /// 1-10 January 2024 (weights 100, 110, ... 190) plus, on every other day
    /// starting with the first, an `n1 -> n3` edge carrying twice that weight.
    fn create_test_graph() -> Graph {
        let mut graph = Graph::new("test", GraphType::Transaction);

        let account = |code: &str, name: &str| {
            GraphNode::new(0, NodeType::Account, code.to_string(), name.to_string())
        };
        let n1 = graph.add_node(account("1000", "Cash"));
        let n2 = graph.add_node(account("2000", "AP"));
        let n3 = graph.add_node(account("3000", "Revenue"));

        for offset in 0..10u32 {
            let date = jan(1 + offset);
            let amount = 100.0 + f64::from(offset) * 10.0;

            graph.add_edge(
                GraphEdge::new(0, n1, n2, EdgeType::Transaction)
                    .with_weight(amount)
                    .with_timestamp(date),
            );

            if offset % 2 == 0 {
                graph.add_edge(
                    GraphEdge::new(0, n1, n3, EdgeType::Transaction)
                        .with_weight(amount * 2.0)
                        .with_timestamp(date),
                );
            }
        }

        graph
    }

    /// The index records the earliest and latest edge dates of the graph.
    #[test]
    fn test_temporal_index_build() {
        let index = TemporalIndex::build(&create_test_graph());

        assert_eq!(index.min_date, Some(jan(1)));
        assert_eq!(index.max_date, Some(jan(10)));
    }

    /// A date range inside the populated period returns edges for node 1.
    #[test]
    fn test_edges_in_range() {
        let index = TemporalIndex::build(&create_test_graph());

        let edges = index.edges_in_range(1, jan(3), jan(7));

        assert!(!edges.is_empty());
    }

    /// Node 1 has rising weights, so velocity is positive, the trend is not
    /// negative, and window aggregates are produced.
    #[test]
    fn test_compute_temporal_features() {
        let graph = create_test_graph();
        let index = TemporalIndex::build(&graph);

        let features =
            compute_temporal_sequence_features(1, &graph, &index, &TemporalConfig::default());

        assert!(features.transaction_velocity > 0.0);
        assert!(features.trend_direction >= 0.0);
        assert!(!features.window_features.is_empty());
    }

    /// Gaps of 2 and 3 days average to 2.5 with a non-zero spread.
    #[test]
    fn test_inter_event_intervals() {
        let edges = vec![(jan(1), 1), (jan(3), 2), (jan(6), 3)];

        let (mean, std) = compute_inter_event_intervals(&edges);

        assert!((mean - 2.5).abs() < 0.01);
        assert!(std > 0.0);
    }

    /// Three single-event days followed by ten events on one day is a burst.
    #[test]
    fn test_burst_score() {
        let mut edges: Vec<_> = (0..3u32)
            .map(|i| (jan(1 + i), u64::from(i)))
            .collect();
        edges.extend((0..10u64).map(|i| (jan(10), 100 + i)));

        let score = compute_burst_score(&edges, 3.0);

        assert!(score > 0.0);
    }

    /// `to_features` always yields `feature_count` values, even for defaults.
    #[test]
    fn test_feature_vector_length() {
        let windows = vec![7, 30, 90];

        let flattened = TemporalFeatures::default().to_features(&windows);

        assert_eq!(
            flattened.len(),
            TemporalFeatures::feature_count(windows.len())
        );
    }

    /// Every node of the graph receives a feature entry.
    #[test]
    fn test_compute_all_temporal_features() {
        let graph = create_test_graph();

        let all_features = compute_all_temporal_features(&graph, &TemporalConfig::default());

        assert_eq!(all_features.len(), 3);
    }
}