Skip to main content

datasynth_graph/ml/
relationship_features.rs

//! Entity relationship feature computation for fraud detection.
//!
//! This module provides relationship-based features including:
//! - Counterparty concentration (Herfindahl index)
//! - Relationship age and velocity
//! - Reciprocity (bidirectional transaction patterns)
//! - Counterparty risk propagation
8
9use std::collections::{HashMap, HashSet};
10
11use chrono::NaiveDate;
12use serde::{Deserialize, Serialize};
13
14use crate::models::{Graph, NodeId};
15
16/// Configuration for relationship feature computation.
17#[derive(Debug, Clone)]
18pub struct RelationshipFeatureConfig {
19    /// Number of days to consider a relationship "new".
20    pub new_relationship_days: i64,
21    /// Reference date for age calculations.
22    pub reference_date: NaiveDate,
23    /// Threshold for high-risk counterparty classification.
24    pub high_risk_threshold: f64,
25    /// Whether to weight features by transaction amount.
26    pub weight_by_amount: bool,
27    /// Minimum number of transactions for meaningful features.
28    pub min_transactions: usize,
29}
30
31impl Default for RelationshipFeatureConfig {
32    fn default() -> Self {
33        Self {
34            new_relationship_days: 30,
35            reference_date: NaiveDate::from_ymd_opt(2024, 12, 31).unwrap(),
36            high_risk_threshold: 0.5,
37            weight_by_amount: true,
38            min_transactions: 1,
39        }
40    }
41}
42
43/// Relationship features for a node.
44#[derive(Debug, Clone, Default, Serialize, Deserialize)]
45pub struct RelationshipFeatures {
46    /// Number of unique counterparties.
47    pub unique_counterparties: usize,
48    /// Ratio of new relationships (< new_relationship_days old).
49    pub new_relationship_ratio: f64,
50    /// Herfindahl-Hirschman Index for counterparty concentration.
51    pub counterparty_concentration: f64,
52    /// Ratio of bidirectional relationships.
53    pub relationship_reciprocity: f64,
54    /// Average relationship age in days.
55    pub avg_relationship_age_days: f64,
56    /// Rate of new relationships per month.
57    pub relationship_velocity: f64,
58    /// Total number of relationships (including multiple txns per counterparty).
59    pub total_relationships: usize,
60    /// Share of transactions with dominant counterparty.
61    pub dominant_counterparty_share: f64,
62}
63
64impl RelationshipFeatures {
65    /// Converts to a feature vector.
66    pub fn to_features(&self) -> Vec<f64> {
67        vec![
68            self.unique_counterparties as f64,
69            self.new_relationship_ratio,
70            self.counterparty_concentration,
71            self.relationship_reciprocity,
72            self.avg_relationship_age_days / 365.0, // Normalize to years
73            self.relationship_velocity,
74            self.total_relationships as f64,
75            self.dominant_counterparty_share,
76        ]
77    }
78
79    /// Returns the number of features.
80    pub fn feature_count() -> usize {
81        8
82    }
83
84    /// Returns feature names.
85    pub fn feature_names() -> Vec<&'static str> {
86        vec![
87            "unique_counterparties",
88            "new_relationship_ratio",
89            "counterparty_concentration_hhi",
90            "relationship_reciprocity",
91            "avg_relationship_age_years",
92            "relationship_velocity",
93            "total_relationships",
94            "dominant_counterparty_share",
95        ]
96    }
97}
98
99/// Counterparty risk features for a node.
100#[derive(Debug, Clone, Default, Serialize, Deserialize)]
101pub struct CounterpartyRisk {
102    /// Ratio of high-risk counterparties.
103    pub high_risk_counterparty_ratio: f64,
104    /// Average risk score of counterparties.
105    pub avg_counterparty_risk_score: f64,
106    /// Concentration of risk in few counterparties.
107    pub risk_concentration: f64,
108    /// Number of anomalous counterparties.
109    pub anomalous_counterparty_count: usize,
110    /// Total exposure to high-risk counterparties (by amount).
111    pub high_risk_exposure: f64,
112}
113
114impl CounterpartyRisk {
115    /// Converts to a feature vector.
116    pub fn to_features(&self) -> Vec<f64> {
117        vec![
118            self.high_risk_counterparty_ratio,
119            self.avg_counterparty_risk_score,
120            self.risk_concentration,
121            self.anomalous_counterparty_count as f64,
122            (self.high_risk_exposure + 1.0).ln(),
123        ]
124    }
125
126    /// Returns the number of features.
127    pub fn feature_count() -> usize {
128        5
129    }
130
131    /// Returns feature names.
132    pub fn feature_names() -> Vec<&'static str> {
133        vec![
134            "high_risk_counterparty_ratio",
135            "avg_counterparty_risk_score",
136            "risk_concentration",
137            "anomalous_counterparty_count",
138            "high_risk_exposure_log",
139        ]
140    }
141}
142
143/// Internal structure for tracking counterparty relationships.
144#[derive(Debug, Clone, Default)]
145struct CounterpartyInfo {
146    /// First transaction date with this counterparty.
147    first_contact: Option<NaiveDate>,
148    /// Total transaction count.
149    transaction_count: usize,
150    /// Total transaction volume.
151    total_volume: f64,
152    /// Is this counterparty anomalous.
153    is_anomalous: bool,
154    /// Risk score for this counterparty.
155    risk_score: f64,
156}
157
158/// Computes relationship features for a single node.
159pub fn compute_relationship_features(
160    node_id: NodeId,
161    graph: &Graph,
162    config: &RelationshipFeatureConfig,
163) -> RelationshipFeatures {
164    let outgoing = graph.outgoing_edges(node_id);
165    let incoming = graph.incoming_edges(node_id);
166
167    if outgoing.is_empty() && incoming.is_empty() {
168        return RelationshipFeatures::default();
169    }
170
171    // Build counterparty info
172    let mut counterparties: HashMap<NodeId, CounterpartyInfo> = HashMap::new();
173    let mut outgoing_targets: HashSet<NodeId> = HashSet::new();
174    let mut incoming_sources: HashSet<NodeId> = HashSet::new();
175
176    // Process outgoing edges
177    for edge in &outgoing {
178        outgoing_targets.insert(edge.target);
179        let info = counterparties.entry(edge.target).or_default();
180        info.transaction_count += 1;
181        info.total_volume += edge.weight;
182
183        if let Some(date) = edge.timestamp {
184            match info.first_contact {
185                None => info.first_contact = Some(date),
186                Some(existing) if date < existing => info.first_contact = Some(date),
187                _ => {}
188            }
189        }
190    }
191
192    // Process incoming edges
193    for edge in &incoming {
194        incoming_sources.insert(edge.source);
195        let info = counterparties.entry(edge.source).or_default();
196        info.transaction_count += 1;
197        info.total_volume += edge.weight;
198
199        if let Some(date) = edge.timestamp {
200            match info.first_contact {
201                None => info.first_contact = Some(date),
202                Some(existing) if date < existing => info.first_contact = Some(date),
203                _ => {}
204            }
205        }
206    }
207
208    let unique_counterparties = counterparties.len();
209    let total_relationships = outgoing.len() + incoming.len();
210
211    if unique_counterparties == 0 {
212        return RelationshipFeatures::default();
213    }
214
215    // Calculate new relationship ratio
216    let new_threshold =
217        config.reference_date - chrono::Duration::days(config.new_relationship_days);
218    let new_count = counterparties
219        .values()
220        .filter(|info| {
221            info.first_contact
222                .map(|d| d >= new_threshold)
223                .unwrap_or(false)
224        })
225        .count();
226    let new_relationship_ratio = new_count as f64 / unique_counterparties as f64;
227
228    // Calculate HHI for concentration
229    let total_volume: f64 = counterparties.values().map(|i| i.total_volume).sum();
230    let counterparty_concentration = if total_volume > 0.0 {
231        counterparties
232            .values()
233            .map(|info| {
234                let share = info.total_volume / total_volume;
235                share * share
236            })
237            .sum()
238    } else {
239        1.0 / unique_counterparties as f64 // Equal distribution
240    };
241
242    // Calculate reciprocity (bidirectional relationships)
243    let bidirectional_count = outgoing_targets.intersection(&incoming_sources).count();
244    let relationship_reciprocity = if unique_counterparties > 0 {
245        bidirectional_count as f64 / unique_counterparties as f64
246    } else {
247        0.0
248    };
249
250    // Calculate average relationship age
251    let ages: Vec<i64> = counterparties
252        .values()
253        .filter_map(|info| info.first_contact)
254        .map(|date| (config.reference_date - date).num_days().max(0))
255        .collect();
256
257    let avg_relationship_age_days = if !ages.is_empty() {
258        ages.iter().sum::<i64>() as f64 / ages.len() as f64
259    } else {
260        0.0
261    };
262
263    // Calculate relationship velocity (new relationships per month)
264    let date_range = counterparties
265        .values()
266        .filter_map(|info| info.first_contact)
267        .fold((None, None), |(min, max), date| {
268            let new_min = min.map_or(date, |m: NaiveDate| m.min(date));
269            let new_max = max.map_or(date, |m: NaiveDate| m.max(date));
270            (Some(new_min), Some(new_max))
271        });
272
273    let relationship_velocity = if let (Some(min_date), Some(max_date)) = date_range {
274        let months = (max_date - min_date).num_days() as f64 / 30.0;
275        if months > 0.0 {
276            unique_counterparties as f64 / months
277        } else {
278            unique_counterparties as f64
279        }
280    } else {
281        0.0
282    };
283
284    // Calculate dominant counterparty share
285    let max_volume = counterparties
286        .values()
287        .map(|i| i.total_volume)
288        .fold(0.0, f64::max);
289    let dominant_counterparty_share = if total_volume > 0.0 {
290        max_volume / total_volume
291    } else {
292        0.0
293    };
294
295    RelationshipFeatures {
296        unique_counterparties,
297        new_relationship_ratio,
298        counterparty_concentration,
299        relationship_reciprocity,
300        avg_relationship_age_days,
301        relationship_velocity,
302        total_relationships,
303        dominant_counterparty_share,
304    }
305}
306
307/// Computes counterparty risk features for a node.
308pub fn compute_counterparty_risk(
309    node_id: NodeId,
310    graph: &Graph,
311    config: &RelationshipFeatureConfig,
312) -> CounterpartyRisk {
313    let outgoing = graph.outgoing_edges(node_id);
314    let incoming = graph.incoming_edges(node_id);
315
316    if outgoing.is_empty() && incoming.is_empty() {
317        return CounterpartyRisk::default();
318    }
319
320    // Build counterparty info with risk scores
321    let mut counterparties: HashMap<NodeId, CounterpartyInfo> = HashMap::new();
322
323    // Process all edges
324    for edge in outgoing.iter().chain(incoming.iter()) {
325        let counterparty_id = if edge.source == node_id {
326            edge.target
327        } else {
328            edge.source
329        };
330
331        let info = counterparties.entry(counterparty_id).or_default();
332        info.transaction_count += 1;
333        info.total_volume += edge.weight;
334
335        // Inherit anomaly status from edge
336        if edge.is_anomaly {
337            info.is_anomalous = true;
338        }
339    }
340
341    // Calculate risk scores for counterparties
342    for (&cp_id, info) in counterparties.iter_mut() {
343        let cp_node = graph.get_node(cp_id);
344
345        // Base risk from counterparty's anomaly status
346        let mut risk = 0.0;
347
348        if let Some(node) = cp_node {
349            if node.is_anomaly {
350                risk += 0.5;
351                info.is_anomalous = true;
352            }
353        }
354
355        // Risk from edge anomalies with this counterparty
356        let cp_edges: Vec<_> = outgoing
357            .iter()
358            .chain(incoming.iter())
359            .filter(|e| e.source == cp_id || e.target == cp_id)
360            .collect();
361
362        let anomalous_edge_ratio =
363            cp_edges.iter().filter(|e| e.is_anomaly).count() as f64 / cp_edges.len().max(1) as f64;
364        risk += anomalous_edge_ratio * 0.3;
365
366        // Risk from having suspicious labels
367        if let Some(node) = cp_node {
368            let suspicious_labels = ["fraud", "suspicious", "high_risk", "flagged"];
369            for label in &node.labels {
370                if suspicious_labels
371                    .iter()
372                    .any(|s| label.to_lowercase().contains(s))
373                {
374                    risk += 0.2;
375                    break;
376                }
377            }
378        }
379
380        info.risk_score = risk.min(1.0);
381    }
382
383    let unique_counterparties = counterparties.len();
384    if unique_counterparties == 0 {
385        return CounterpartyRisk::default();
386    }
387
388    // Calculate high-risk counterparty ratio
389    let high_risk_count = counterparties
390        .values()
391        .filter(|info| info.risk_score >= config.high_risk_threshold)
392        .count();
393    let high_risk_counterparty_ratio = high_risk_count as f64 / unique_counterparties as f64;
394
395    // Calculate average risk score
396    let total_risk: f64 = counterparties.values().map(|i| i.risk_score).sum();
397    let avg_counterparty_risk_score = total_risk / unique_counterparties as f64;
398
399    // Calculate risk concentration (HHI of risk-weighted volume)
400    let total_risk_weighted: f64 = counterparties
401        .values()
402        .map(|i| i.total_volume * i.risk_score)
403        .sum();
404
405    let risk_concentration = if total_risk_weighted > 0.0 {
406        counterparties
407            .values()
408            .map(|info| {
409                let weighted = info.total_volume * info.risk_score;
410                let share = weighted / total_risk_weighted;
411                share * share
412            })
413            .sum()
414    } else {
415        0.0
416    };
417
418    // Count anomalous counterparties
419    let anomalous_counterparty_count = counterparties.values().filter(|i| i.is_anomalous).count();
420
421    // Calculate high-risk exposure
422    let high_risk_exposure: f64 = counterparties
423        .values()
424        .filter(|info| info.risk_score >= config.high_risk_threshold)
425        .map(|info| info.total_volume)
426        .sum();
427
428    CounterpartyRisk {
429        high_risk_counterparty_ratio,
430        avg_counterparty_risk_score,
431        risk_concentration,
432        anomalous_counterparty_count,
433        high_risk_exposure,
434    }
435}
436
437/// Computes relationship features for all nodes in a graph.
438pub fn compute_all_relationship_features(
439    graph: &Graph,
440    config: &RelationshipFeatureConfig,
441) -> HashMap<NodeId, RelationshipFeatures> {
442    let mut features = HashMap::new();
443
444    for &node_id in graph.nodes.keys() {
445        features.insert(
446            node_id,
447            compute_relationship_features(node_id, graph, config),
448        );
449    }
450
451    features
452}
453
454/// Computes counterparty risk for all nodes in a graph.
455pub fn compute_all_counterparty_risk(
456    graph: &Graph,
457    config: &RelationshipFeatureConfig,
458) -> HashMap<NodeId, CounterpartyRisk> {
459    let mut risks = HashMap::new();
460
461    for &node_id in graph.nodes.keys() {
462        risks.insert(node_id, compute_counterparty_risk(node_id, graph, config));
463    }
464
465    risks
466}
467
468/// Combined relationship and risk features.
469#[derive(Debug, Clone, Default)]
470pub struct CombinedRelationshipFeatures {
471    /// Base relationship features.
472    pub relationship: RelationshipFeatures,
473    /// Counterparty risk features.
474    pub risk: CounterpartyRisk,
475}
476
477impl CombinedRelationshipFeatures {
478    /// Converts to a combined feature vector.
479    pub fn to_features(&self) -> Vec<f64> {
480        let mut features = self.relationship.to_features();
481        features.extend(self.risk.to_features());
482        features
483    }
484
485    /// Returns total feature count.
486    pub fn feature_count() -> usize {
487        RelationshipFeatures::feature_count() + CounterpartyRisk::feature_count()
488    }
489
490    /// Returns all feature names.
491    pub fn feature_names() -> Vec<&'static str> {
492        let mut names = RelationshipFeatures::feature_names();
493        names.extend(CounterpartyRisk::feature_names());
494        names
495    }
496}
497
498/// Computes combined features for all nodes.
499pub fn compute_all_combined_features(
500    graph: &Graph,
501    config: &RelationshipFeatureConfig,
502) -> HashMap<NodeId, CombinedRelationshipFeatures> {
503    let mut features = HashMap::new();
504
505    for &node_id in graph.nodes.keys() {
506        features.insert(
507            node_id,
508            CombinedRelationshipFeatures {
509                relationship: compute_relationship_features(node_id, graph, config),
510                risk: compute_counterparty_risk(node_id, graph, config),
511            },
512        );
513    }
514
515    features
516}
517
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{GraphEdge, GraphNode, GraphType, NodeType};
    use crate::EdgeType;

    /// Builds a calendar date for the fixtures below.
    fn date(year: i32, month: u32, day: u32) -> NaiveDate {
        NaiveDate::from_ymd_opt(year, month, day).expect("test fixture dates are valid")
    }

    /// Builds an account node named `name`.
    fn account(name: &str) -> GraphNode {
        GraphNode::new(0, NodeType::Account, name.to_string(), name.to_string())
    }

    /// Builds a four-account graph centred on the first node (N1):
    ///
    /// - N1 -> N2 twice (1000 and 2000) and N2 -> N1 once (1500): bidirectional
    /// - N1 -> N3 once (500)
    /// - N1 -> N4 once (300), dated inside the default 30-day "new" window
    ///
    /// The tests below address N1 as node id `1`.
    fn create_test_graph() -> Graph {
        let mut graph = Graph::new("test", GraphType::Transaction);

        // Create nodes
        let n1 = graph.add_node(account("A"));
        let n2 = graph.add_node(account("B"));
        let n3 = graph.add_node(account("C"));
        let n4 = graph.add_node(account("D"));

        // Create edges with timestamps
        // N1 -> N2 (2 transactions)
        graph.add_edge(
            GraphEdge::new(0, n1, n2, EdgeType::Transaction)
                .with_weight(1000.0)
                .with_timestamp(date(2024, 1, 1)),
        );
        graph.add_edge(
            GraphEdge::new(0, n1, n2, EdgeType::Transaction)
                .with_weight(2000.0)
                .with_timestamp(date(2024, 6, 1)),
        );

        // N1 -> N3 (1 transaction)
        graph.add_edge(
            GraphEdge::new(0, n1, n3, EdgeType::Transaction)
                .with_weight(500.0)
                .with_timestamp(date(2024, 3, 1)),
        );

        // N2 -> N1 (bidirectional relationship)
        graph.add_edge(
            GraphEdge::new(0, n2, n1, EdgeType::Transaction)
                .with_weight(1500.0)
                .with_timestamp(date(2024, 4, 1)),
        );

        // N1 -> N4 (recent new relationship)
        graph.add_edge(
            GraphEdge::new(0, n1, n4, EdgeType::Transaction)
                .with_weight(300.0)
                .with_timestamp(date(2024, 12, 15)),
        );

        graph
    }

    #[test]
    fn test_relationship_features() {
        let graph = create_test_graph();
        let config = RelationshipFeatureConfig::default();

        let features = compute_relationship_features(1, &graph, &config);

        assert_eq!(features.unique_counterparties, 3); // N2, N3, N4
        assert_eq!(features.total_relationships, 5); // 4 outgoing + 1 incoming
        // Only N4's first contact falls inside the 30-day window.
        assert!((features.new_relationship_ratio - 1.0 / 3.0).abs() < 1e-9);
        assert!(features.counterparty_concentration > 0.0);
        assert!(features.relationship_reciprocity > 0.0); // N2 is bidirectional
    }

    #[test]
    fn test_herfindahl_index() {
        let graph = create_test_graph();
        let config = RelationshipFeatureConfig::default();

        let features = compute_relationship_features(1, &graph, &config);

        // HHI should be between 0 and 1
        assert!(features.counterparty_concentration > 0.0);
        assert!(features.counterparty_concentration <= 1.0);

        // Volumes are N2 = 4500, N3 = 500, N4 = 300 out of 5300 in total.
        let total = 5300.0_f64;
        let expected_hhi =
            (4500.0_f64.powi(2) + 500.0_f64.powi(2) + 300.0_f64.powi(2)) / total.powi(2);
        assert!((features.counterparty_concentration - expected_hhi).abs() < 1e-9);
        assert!((features.dominant_counterparty_share - 4500.0 / total).abs() < 1e-9);
    }

    #[test]
    fn test_reciprocity() {
        let graph = create_test_graph();
        let config = RelationshipFeatureConfig::default();

        let features = compute_relationship_features(1, &graph, &config);

        // N1 has 3 unique counterparties, 1 bidirectional (N2)
        assert!((features.relationship_reciprocity - 1.0 / 3.0).abs() < 1e-9);
    }

    #[test]
    fn test_counterparty_risk_basic() {
        let graph = create_test_graph();
        let config = RelationshipFeatureConfig::default();

        let risk = compute_counterparty_risk(1, &graph, &config);

        // No anomalies in test graph
        assert_eq!(risk.anomalous_counterparty_count, 0);
        assert_eq!(risk.avg_counterparty_risk_score, 0.0);
    }

    #[test]
    fn test_counterparty_risk_with_anomalies() {
        let mut graph = create_test_graph();

        // Mark an edge as anomalous; fail loudly if the fixture edge is missing
        // instead of silently skipping the setup.
        graph
            .get_edge_mut(1)
            .expect("edge 1 should exist in the test graph")
            .is_anomaly = true;

        let config = RelationshipFeatureConfig::default();
        let risk = compute_counterparty_risk(1, &graph, &config);

        // Should detect the anomalous relationship
        assert!(risk.avg_counterparty_risk_score > 0.0);
        assert_eq!(risk.anomalous_counterparty_count, 1);
    }

    #[test]
    fn test_feature_vector_length() {
        assert_eq!(RelationshipFeatures::feature_count(), 8);
        assert_eq!(CounterpartyRisk::feature_count(), 5);
        assert_eq!(CombinedRelationshipFeatures::feature_count(), 13);

        let features = RelationshipFeatures::default();
        assert_eq!(
            features.to_features().len(),
            RelationshipFeatures::feature_count()
        );

        let risk = CounterpartyRisk::default();
        assert_eq!(risk.to_features().len(), CounterpartyRisk::feature_count());

        // Names must stay in step with the vectors they describe.
        assert_eq!(
            RelationshipFeatures::feature_names().len(),
            RelationshipFeatures::feature_count()
        );
        assert_eq!(
            CounterpartyRisk::feature_names().len(),
            CounterpartyRisk::feature_count()
        );
        assert_eq!(
            CombinedRelationshipFeatures::feature_names().len(),
            CombinedRelationshipFeatures::feature_count()
        );
    }

    #[test]
    fn test_all_relationship_features() {
        let graph = create_test_graph();
        let config = RelationshipFeatureConfig::default();

        let all_features = compute_all_relationship_features(&graph, &config);

        assert_eq!(all_features.len(), 4); // 4 nodes
    }

    #[test]
    fn test_combined_features() {
        let graph = create_test_graph();
        let config = RelationshipFeatureConfig::default();

        let combined = compute_all_combined_features(&graph, &config);

        assert_eq!(combined.len(), 4); // 4 nodes
        for features in combined.values() {
            assert_eq!(
                features.to_features().len(),
                CombinedRelationshipFeatures::feature_count()
            );
        }
    }
}