// datasynth_graph/ml/relationship_features.rs

//! Entity relationship feature computation for fraud detection.
//!
//! This module provides relationship-based features including:
//! - Counterparty concentration (Herfindahl-Hirschman index)
//! - Relationship age and velocity
//! - Reciprocity (bidirectional transaction patterns)
//! - Counterparty risk propagation (risk derived from directly connected counterparties)

9use std::collections::{HashMap, HashSet};
10
11use chrono::NaiveDate;
12use serde::{Deserialize, Serialize};
13
14use crate::models::{Graph, NodeId};
15
/// Configuration for relationship feature computation.
///
/// The same configuration is passed to [`compute_relationship_features`],
/// [`compute_counterparty_risk`] and the `compute_all_*` helpers.
#[derive(Debug, Clone)]
pub struct RelationshipFeatureConfig {
    /// Number of days to consider a relationship "new": a counterparty whose
    /// first contact is on or after `reference_date - new_relationship_days`
    /// counts as new.
    pub new_relationship_days: i64,
    /// Reference ("as of") date for relationship age and recency calculations.
    pub reference_date: NaiveDate,
    /// Risk score at or above which a counterparty is classified as high risk.
    pub high_risk_threshold: f64,
    /// Whether to weight features by transaction amount (edge weight).
    pub weight_by_amount: bool,
    /// Minimum number of transactions for meaningful features.
    pub min_transactions: usize,
}
30
31impl Default for RelationshipFeatureConfig {
32    fn default() -> Self {
33        Self {
34            new_relationship_days: 30,
35            reference_date: NaiveDate::from_ymd_opt(2024, 12, 31).unwrap(),
36            high_risk_threshold: 0.5,
37            weight_by_amount: true,
38            min_transactions: 1,
39        }
40    }
41}
42
/// Relationship features for a node.
///
/// A node's counterparties are the distinct nodes at the other end of its
/// outgoing and incoming edges.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RelationshipFeatures {
    /// Number of distinct counterparties.
    pub unique_counterparties: usize,
    /// Fraction of counterparties whose first contact is on or after
    /// `reference_date - new_relationship_days`. Counterparties with no
    /// timestamped edge never count as new.
    pub new_relationship_ratio: f64,
    /// Herfindahl-Hirschman index: the sum of squared per-counterparty shares.
    /// For non-negative weights it ranges from `1 / unique_counterparties`
    /// (even spread) to 1 (a single counterparty).
    pub counterparty_concentration: f64,
    /// Fraction of counterparties that are both a target of this node's
    /// outgoing edges and a source of its incoming edges.
    pub relationship_reciprocity: f64,
    /// Mean number of days from each counterparty's first contact to the
    /// reference date (negative ages clamped to 0). Only counterparties with a
    /// timestamped edge are included; the value is 0 when there are none.
    pub avg_relationship_age_days: f64,
    /// Rate of new relationships per (30-day) month.
    pub relationship_velocity: f64,
    /// Total number of incident edges (outgoing + incoming), so repeated
    /// transactions with the same counterparty are each counted.
    pub total_relationships: usize,
    /// Share held by the single largest counterparty.
    pub dominant_counterparty_share: f64,
}
63
64impl RelationshipFeatures {
65    /// Converts to a feature vector.
66    pub fn to_features(&self) -> Vec<f64> {
67        vec![
68            self.unique_counterparties as f64,
69            self.new_relationship_ratio,
70            self.counterparty_concentration,
71            self.relationship_reciprocity,
72            self.avg_relationship_age_days / 365.0, // Normalize to years
73            self.relationship_velocity,
74            self.total_relationships as f64,
75            self.dominant_counterparty_share,
76        ]
77    }
78
79    /// Returns the number of features.
80    pub fn feature_count() -> usize {
81        8
82    }
83
84    /// Returns feature names.
85    pub fn feature_names() -> Vec<&'static str> {
86        vec![
87            "unique_counterparties",
88            "new_relationship_ratio",
89            "counterparty_concentration_hhi",
90            "relationship_reciprocity",
91            "avg_relationship_age_years",
92            "relationship_velocity",
93            "total_relationships",
94            "dominant_counterparty_share",
95        ]
96    }
97}
98
/// Counterparty risk features for a node.
///
/// Each counterparty (the node at the other end of an incident edge) carries
/// a risk score in `[0, 1]`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CounterpartyRisk {
    /// Fraction of counterparties whose risk score is at least
    /// `high_risk_threshold`.
    pub high_risk_counterparty_ratio: f64,
    /// Mean risk score across all counterparties.
    pub avg_counterparty_risk_score: f64,
    /// Herfindahl-Hirschman index of per-counterparty `volume × risk score`.
    /// The value is 0 when no counterparty carries both volume and risk.
    pub risk_concentration: f64,
    /// Number of counterparties that are anomalous nodes or share at least one
    /// anomalous edge with this node.
    pub anomalous_counterparty_count: usize,
    /// Summed edge weight (volume) exchanged with high-risk counterparties.
    pub high_risk_exposure: f64,
}
113
114impl CounterpartyRisk {
115    /// Converts to a feature vector.
116    pub fn to_features(&self) -> Vec<f64> {
117        vec![
118            self.high_risk_counterparty_ratio,
119            self.avg_counterparty_risk_score,
120            self.risk_concentration,
121            self.anomalous_counterparty_count as f64,
122            (self.high_risk_exposure + 1.0).ln(),
123        ]
124    }
125
126    /// Returns the number of features.
127    pub fn feature_count() -> usize {
128        5
129    }
130
131    /// Returns feature names.
132    pub fn feature_names() -> Vec<&'static str> {
133        vec![
134            "high_risk_counterparty_ratio",
135            "avg_counterparty_risk_score",
136            "risk_concentration",
137            "anomalous_counterparty_count",
138            "high_risk_exposure_log",
139        ]
140    }
141}
142
/// Internal per-counterparty accumulator used while scanning a node's edges.
#[derive(Debug, Clone, Default)]
struct CounterpartyInfo {
    /// Earliest timestamp among edges with this counterparty. `None` if none
    /// of those edges carried a timestamp.
    first_contact: Option<NaiveDate>,
    /// Number of incident edges attributed to this counterparty.
    transaction_count: usize,
    /// Sum of the weights of those edges.
    total_volume: f64,
    /// Set when the counterparty node, or any edge with it, is flagged anomalous.
    is_anomalous: bool,
    /// Risk score for this counterparty, capped at 1.0.
    risk_score: f64,
}
157
158/// Computes relationship features for a single node.
159pub fn compute_relationship_features(
160    node_id: NodeId,
161    graph: &Graph,
162    config: &RelationshipFeatureConfig,
163) -> RelationshipFeatures {
164    let outgoing = graph.outgoing_edges(node_id);
165    let incoming = graph.incoming_edges(node_id);
166
167    if outgoing.is_empty() && incoming.is_empty() {
168        return RelationshipFeatures::default();
169    }
170
171    // Build counterparty info
172    let mut counterparties: HashMap<NodeId, CounterpartyInfo> = HashMap::new();
173    let mut outgoing_targets: HashSet<NodeId> = HashSet::new();
174    let mut incoming_sources: HashSet<NodeId> = HashSet::new();
175
176    // Process outgoing edges
177    for edge in &outgoing {
178        outgoing_targets.insert(edge.target);
179        let info = counterparties.entry(edge.target).or_default();
180        info.transaction_count += 1;
181        info.total_volume += edge.weight;
182
183        if let Some(date) = edge.timestamp {
184            match info.first_contact {
185                None => info.first_contact = Some(date),
186                Some(existing) if date < existing => info.first_contact = Some(date),
187                _ => {}
188            }
189        }
190    }
191
192    // Process incoming edges
193    for edge in &incoming {
194        incoming_sources.insert(edge.source);
195        let info = counterparties.entry(edge.source).or_default();
196        info.transaction_count += 1;
197        info.total_volume += edge.weight;
198
199        if let Some(date) = edge.timestamp {
200            match info.first_contact {
201                None => info.first_contact = Some(date),
202                Some(existing) if date < existing => info.first_contact = Some(date),
203                _ => {}
204            }
205        }
206    }
207
208    let unique_counterparties = counterparties.len();
209    let total_relationships = outgoing.len() + incoming.len();
210
211    if unique_counterparties == 0 {
212        return RelationshipFeatures::default();
213    }
214
215    // Calculate new relationship ratio
216    let new_threshold =
217        config.reference_date - chrono::Duration::days(config.new_relationship_days);
218    let new_count = counterparties
219        .values()
220        .filter(|info| {
221            info.first_contact
222                .map(|d| d >= new_threshold)
223                .unwrap_or(false)
224        })
225        .count();
226    let new_relationship_ratio = new_count as f64 / unique_counterparties as f64;
227
228    // Calculate HHI for concentration
229    let total_volume: f64 = counterparties.values().map(|i| i.total_volume).sum();
230    let counterparty_concentration = if total_volume > 0.0 {
231        counterparties
232            .values()
233            .map(|info| {
234                let share = info.total_volume / total_volume;
235                share * share
236            })
237            .sum()
238    } else {
239        1.0 / unique_counterparties as f64 // Equal distribution
240    };
241
242    // Calculate reciprocity (bidirectional relationships)
243    let bidirectional_count = outgoing_targets.intersection(&incoming_sources).count();
244    let relationship_reciprocity = if unique_counterparties > 0 {
245        bidirectional_count as f64 / unique_counterparties as f64
246    } else {
247        0.0
248    };
249
250    // Calculate average relationship age
251    let ages: Vec<i64> = counterparties
252        .values()
253        .filter_map(|info| info.first_contact)
254        .map(|date| (config.reference_date - date).num_days().max(0))
255        .collect();
256
257    let avg_relationship_age_days = if !ages.is_empty() {
258        ages.iter().sum::<i64>() as f64 / ages.len() as f64
259    } else {
260        0.0
261    };
262
263    // Calculate relationship velocity (new relationships per month)
264    let date_range = counterparties
265        .values()
266        .filter_map(|info| info.first_contact)
267        .fold((None, None), |(min, max), date| {
268            let new_min = min.map_or(date, |m: NaiveDate| m.min(date));
269            let new_max = max.map_or(date, |m: NaiveDate| m.max(date));
270            (Some(new_min), Some(new_max))
271        });
272
273    let relationship_velocity = if let (Some(min_date), Some(max_date)) = date_range {
274        let months = (max_date - min_date).num_days() as f64 / 30.0;
275        if months > 0.0 {
276            unique_counterparties as f64 / months
277        } else {
278            unique_counterparties as f64
279        }
280    } else {
281        0.0
282    };
283
284    // Calculate dominant counterparty share
285    let max_volume = counterparties
286        .values()
287        .map(|i| i.total_volume)
288        .fold(0.0, f64::max);
289    let dominant_counterparty_share = if total_volume > 0.0 {
290        max_volume / total_volume
291    } else {
292        0.0
293    };
294
295    RelationshipFeatures {
296        unique_counterparties,
297        new_relationship_ratio,
298        counterparty_concentration,
299        relationship_reciprocity,
300        avg_relationship_age_days,
301        relationship_velocity,
302        total_relationships,
303        dominant_counterparty_share,
304    }
305}
306
307/// Computes counterparty risk features for a node.
308pub fn compute_counterparty_risk(
309    node_id: NodeId,
310    graph: &Graph,
311    config: &RelationshipFeatureConfig,
312) -> CounterpartyRisk {
313    let outgoing = graph.outgoing_edges(node_id);
314    let incoming = graph.incoming_edges(node_id);
315
316    if outgoing.is_empty() && incoming.is_empty() {
317        return CounterpartyRisk::default();
318    }
319
320    // Build counterparty info with risk scores
321    let mut counterparties: HashMap<NodeId, CounterpartyInfo> = HashMap::new();
322
323    // Process all edges
324    for edge in outgoing.iter().chain(incoming.iter()) {
325        let counterparty_id = if edge.source == node_id {
326            edge.target
327        } else {
328            edge.source
329        };
330
331        let info = counterparties.entry(counterparty_id).or_default();
332        info.transaction_count += 1;
333        info.total_volume += edge.weight;
334
335        // Inherit anomaly status from edge
336        if edge.is_anomaly {
337            info.is_anomalous = true;
338        }
339    }
340
341    // Calculate risk scores for counterparties
342    for (&cp_id, info) in counterparties.iter_mut() {
343        let cp_node = graph.get_node(cp_id);
344
345        // Base risk from counterparty's anomaly status
346        let mut risk = 0.0;
347
348        if let Some(node) = cp_node {
349            if node.is_anomaly {
350                risk += 0.5;
351                info.is_anomalous = true;
352            }
353        }
354
355        // Risk from edge anomalies with this counterparty
356        let cp_edges: Vec<_> = outgoing
357            .iter()
358            .chain(incoming.iter())
359            .filter(|e| e.source == cp_id || e.target == cp_id)
360            .collect();
361
362        let anomalous_edge_ratio =
363            cp_edges.iter().filter(|e| e.is_anomaly).count() as f64 / cp_edges.len().max(1) as f64;
364        risk += anomalous_edge_ratio * 0.3;
365
366        // Risk from having suspicious labels
367        if let Some(node) = cp_node {
368            let suspicious_labels = ["fraud", "suspicious", "high_risk", "flagged"];
369            for label in &node.labels {
370                if suspicious_labels
371                    .iter()
372                    .any(|s| label.to_lowercase().contains(s))
373                {
374                    risk += 0.2;
375                    break;
376                }
377            }
378        }
379
380        info.risk_score = risk.min(1.0);
381    }
382
383    let unique_counterparties = counterparties.len();
384    if unique_counterparties == 0 {
385        return CounterpartyRisk::default();
386    }
387
388    // Calculate high-risk counterparty ratio
389    let high_risk_count = counterparties
390        .values()
391        .filter(|info| info.risk_score >= config.high_risk_threshold)
392        .count();
393    let high_risk_counterparty_ratio = high_risk_count as f64 / unique_counterparties as f64;
394
395    // Calculate average risk score
396    let total_risk: f64 = counterparties.values().map(|i| i.risk_score).sum();
397    let avg_counterparty_risk_score = total_risk / unique_counterparties as f64;
398
399    // Calculate risk concentration (HHI of risk-weighted volume)
400    let total_risk_weighted: f64 = counterparties
401        .values()
402        .map(|i| i.total_volume * i.risk_score)
403        .sum();
404
405    let risk_concentration = if total_risk_weighted > 0.0 {
406        counterparties
407            .values()
408            .map(|info| {
409                let weighted = info.total_volume * info.risk_score;
410                let share = weighted / total_risk_weighted;
411                share * share
412            })
413            .sum()
414    } else {
415        0.0
416    };
417
418    // Count anomalous counterparties
419    let anomalous_counterparty_count = counterparties.values().filter(|i| i.is_anomalous).count();
420
421    // Calculate high-risk exposure
422    let high_risk_exposure: f64 = counterparties
423        .values()
424        .filter(|info| info.risk_score >= config.high_risk_threshold)
425        .map(|info| info.total_volume)
426        .sum();
427
428    CounterpartyRisk {
429        high_risk_counterparty_ratio,
430        avg_counterparty_risk_score,
431        risk_concentration,
432        anomalous_counterparty_count,
433        high_risk_exposure,
434    }
435}
436
437/// Computes relationship features for all nodes in a graph.
438pub fn compute_all_relationship_features(
439    graph: &Graph,
440    config: &RelationshipFeatureConfig,
441) -> HashMap<NodeId, RelationshipFeatures> {
442    let mut features = HashMap::new();
443
444    for &node_id in graph.nodes.keys() {
445        features.insert(
446            node_id,
447            compute_relationship_features(node_id, graph, config),
448        );
449    }
450
451    features
452}
453
454/// Computes counterparty risk for all nodes in a graph.
455pub fn compute_all_counterparty_risk(
456    graph: &Graph,
457    config: &RelationshipFeatureConfig,
458) -> HashMap<NodeId, CounterpartyRisk> {
459    let mut risks = HashMap::new();
460
461    for &node_id in graph.nodes.keys() {
462        risks.insert(node_id, compute_counterparty_risk(node_id, graph, config));
463    }
464
465    risks
466}
467
/// Relationship features and counterparty risk features for a single node,
/// bundled together.
#[derive(Debug, Clone, Default)]
pub struct CombinedRelationshipFeatures {
    /// Structural relationship features (concentration, reciprocity, age, ...).
    pub relationship: RelationshipFeatures,
    /// Counterparty risk features.
    pub risk: CounterpartyRisk,
}
476
477impl CombinedRelationshipFeatures {
478    /// Converts to a combined feature vector.
479    pub fn to_features(&self) -> Vec<f64> {
480        let mut features = self.relationship.to_features();
481        features.extend(self.risk.to_features());
482        features
483    }
484
485    /// Returns total feature count.
486    pub fn feature_count() -> usize {
487        RelationshipFeatures::feature_count() + CounterpartyRisk::feature_count()
488    }
489
490    /// Returns all feature names.
491    pub fn feature_names() -> Vec<&'static str> {
492        let mut names = RelationshipFeatures::feature_names();
493        names.extend(CounterpartyRisk::feature_names());
494        names
495    }
496}
497
498/// Computes combined features for all nodes.
499pub fn compute_all_combined_features(
500    graph: &Graph,
501    config: &RelationshipFeatureConfig,
502) -> HashMap<NodeId, CombinedRelationshipFeatures> {
503    let mut features = HashMap::new();
504
505    for &node_id in graph.nodes.keys() {
506        features.insert(
507            node_id,
508            CombinedRelationshipFeatures {
509                relationship: compute_relationship_features(node_id, graph, config),
510                risk: compute_counterparty_risk(node_id, graph, config),
511            },
512        );
513    }
514
515    features
516}
517
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::create_relationship_test_graph;

    #[test]
    fn test_relationship_features() {
        let graph = create_relationship_test_graph();
        let config = RelationshipFeatureConfig::default();

        let features = compute_relationship_features(1, &graph, &config);

        assert_eq!(features.unique_counterparties, 3); // N2, N3, N4
        assert!(features.new_relationship_ratio > 0.0); // N4 is new
        assert!(features.counterparty_concentration > 0.0);
        assert!(features.relationship_reciprocity > 0.0); // N2 is bidirectional
    }

    #[test]
    fn test_herfindahl_index() {
        let graph = create_relationship_test_graph();
        let config = RelationshipFeatureConfig::default();

        let features = compute_relationship_features(1, &graph, &config);

        // HHI should be between 0 and 1
        assert!(features.counterparty_concentration > 0.0);
        assert!(features.counterparty_concentration <= 1.0);

        // With 3 counterparties and unequal volumes, should be > 1/3
        assert!(features.counterparty_concentration > 0.33);
    }

    #[test]
    fn test_reciprocity() {
        let graph = create_relationship_test_graph();
        let config = RelationshipFeatureConfig::default();

        let features = compute_relationship_features(1, &graph, &config);

        // N1 has 3 unique counterparties, exactly 1 of them bidirectional (N2),
        // so reciprocity is exactly 1/3. A loose tolerance would hide an
        // off-by-one in either count.
        assert!((features.relationship_reciprocity - 1.0 / 3.0).abs() < 1e-9);
    }

    #[test]
    fn test_counterparty_risk_basic() {
        let graph = create_relationship_test_graph();
        let config = RelationshipFeatureConfig::default();

        let risk = compute_counterparty_risk(1, &graph, &config);

        // No anomalies in test graph
        assert_eq!(risk.anomalous_counterparty_count, 0);
        assert_eq!(risk.avg_counterparty_risk_score, 0.0);
    }

    #[test]
    fn test_counterparty_risk_with_anomalies() {
        let mut graph = create_relationship_test_graph();

        // Mark an edge as anomalous. Fail loudly if the fixture changes and
        // edge 1 disappears, rather than silently skipping the setup.
        graph
            .get_edge_mut(1)
            .expect("test graph should contain edge 1")
            .is_anomaly = true;

        let config = RelationshipFeatureConfig::default();
        let risk = compute_counterparty_risk(1, &graph, &config);

        // Should detect the anomalous relationship
        assert!(risk.avg_counterparty_risk_score > 0.0);
    }

    #[test]
    fn test_feature_vector_length() {
        assert_eq!(RelationshipFeatures::feature_count(), 8);
        assert_eq!(CounterpartyRisk::feature_count(), 5);
        assert_eq!(CombinedRelationshipFeatures::feature_count(), 13);

        // Name lists must line up with the reported counts.
        assert_eq!(
            RelationshipFeatures::feature_names().len(),
            RelationshipFeatures::feature_count()
        );
        assert_eq!(
            CounterpartyRisk::feature_names().len(),
            CounterpartyRisk::feature_count()
        );
        assert_eq!(
            CombinedRelationshipFeatures::feature_names().len(),
            CombinedRelationshipFeatures::feature_count()
        );

        let features = RelationshipFeatures::default();
        assert_eq!(
            features.to_features().len(),
            RelationshipFeatures::feature_count()
        );

        let risk = CounterpartyRisk::default();
        assert_eq!(risk.to_features().len(), CounterpartyRisk::feature_count());
    }

    #[test]
    fn test_all_relationship_features() {
        let graph = create_relationship_test_graph();
        let config = RelationshipFeatureConfig::default();

        let all_features = compute_all_relationship_features(&graph, &config);

        assert_eq!(all_features.len(), 4); // 4 nodes
    }

    #[test]
    fn test_combined_features() {
        let graph = create_relationship_test_graph();
        let config = RelationshipFeatureConfig::default();

        let combined = compute_all_combined_features(&graph, &config);

        for (_node_id, features) in combined {
            assert_eq!(
                features.to_features().len(),
                CombinedRelationshipFeatures::feature_count()
            );
        }
    }
}