1use std::collections::{HashMap, HashSet};
10
11use chrono::NaiveDate;
12use serde::{Deserialize, Serialize};
13
14use crate::models::{Graph, NodeId};
15
/// Tunable parameters shared by the relationship and counterparty-risk
/// feature computations in this module.
#[derive(Debug, Clone)]
pub struct RelationshipFeatureConfig {
    /// Length, in days, of the window ending at `reference_date`; a
    /// counterparty whose earliest dated edge falls on or after the start of
    /// that window is counted as a "new" relationship.
    pub new_relationship_days: i64,
    /// Date against which relationship ages and the "new" window are measured.
    pub reference_date: NaiveDate,
    /// Counterparties whose risk score is greater than or equal to this value
    /// are treated as high risk.
    pub high_risk_threshold: f64,
    /// NOTE(review): not read by any function visible in this module;
    /// presumably meant to toggle amount-weighted metrics — confirm intent.
    pub weight_by_amount: bool,
    /// NOTE(review): not read by any function visible in this module;
    /// presumably a minimum edge count for a counterparty to be considered —
    /// confirm intent.
    pub min_transactions: usize,
}
30
31impl Default for RelationshipFeatureConfig {
32 fn default() -> Self {
33 Self {
34 new_relationship_days: 30,
35 reference_date: NaiveDate::from_ymd_opt(2024, 12, 31).unwrap(),
36 high_risk_threshold: 0.5,
37 weight_by_amount: true,
38 min_transactions: 1,
39 }
40 }
41}
42
/// Per-node relationship metrics, as produced by
/// `compute_relationship_features`. A node without any edges gets the
/// all-zero `Default` value.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RelationshipFeatures {
    /// Number of distinct nodes found on the other end of the node's
    /// outgoing and incoming edges.
    pub unique_counterparties: usize,
    /// Fraction of counterparties whose earliest dated edge lies on or after
    /// `reference_date - new_relationship_days`.
    pub new_relationship_ratio: f64,
    /// Sum of squared volume shares across counterparties (a
    /// Herfindahl-style index); `1 / unique_counterparties` when the total
    /// volume is not positive.
    pub counterparty_concentration: f64,
    /// Fraction of counterparties that appear both as an outgoing target and
    /// as an incoming source.
    pub relationship_reciprocity: f64,
    /// Mean number of days between each counterparty's earliest dated edge
    /// and the reference date (negative ages are clamped to 0); `0.0` when no
    /// edge carries a timestamp.
    pub avg_relationship_age_days: f64,
    /// Unique counterparties per 30-day period spanned by the first-contact
    /// dates; equals the counterparty count when that span is zero.
    pub relationship_velocity: f64,
    /// Total number of edges touching the node (outgoing plus incoming).
    pub total_relationships: usize,
    /// Largest single counterparty's share of the total volume; `0.0` when
    /// the total volume is not positive.
    pub dominant_counterparty_share: f64,
}
63
64impl RelationshipFeatures {
65 pub fn to_features(&self) -> Vec<f64> {
67 vec![
68 self.unique_counterparties as f64,
69 self.new_relationship_ratio,
70 self.counterparty_concentration,
71 self.relationship_reciprocity,
72 self.avg_relationship_age_days / 365.0, self.relationship_velocity,
74 self.total_relationships as f64,
75 self.dominant_counterparty_share,
76 ]
77 }
78
79 pub fn feature_count() -> usize {
81 8
82 }
83
84 pub fn feature_names() -> Vec<&'static str> {
86 vec![
87 "unique_counterparties",
88 "new_relationship_ratio",
89 "counterparty_concentration_hhi",
90 "relationship_reciprocity",
91 "avg_relationship_age_years",
92 "relationship_velocity",
93 "total_relationships",
94 "dominant_counterparty_share",
95 ]
96 }
97}
98
/// Aggregated risk picture of a node's counterparties, as produced by
/// `compute_counterparty_risk`. A node without any edges gets the all-zero
/// `Default` value.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CounterpartyRisk {
    /// Fraction of counterparties whose risk score reaches the configured
    /// `high_risk_threshold`.
    pub high_risk_counterparty_ratio: f64,
    /// Mean risk score over all counterparties.
    pub avg_counterparty_risk_score: f64,
    /// Sum of squared shares of `volume * risk_score` across counterparties;
    /// `0.0` when the total risk-weighted volume is not positive.
    pub risk_concentration: f64,
    /// Number of counterparties flagged anomalous, either because their node
    /// is marked anomalous or because an edge shared with them is.
    pub anomalous_counterparty_count: usize,
    /// Total edge weight exchanged with counterparties whose risk score
    /// reaches the configured `high_risk_threshold`.
    pub high_risk_exposure: f64,
}
113
114impl CounterpartyRisk {
115 pub fn to_features(&self) -> Vec<f64> {
117 vec![
118 self.high_risk_counterparty_ratio,
119 self.avg_counterparty_risk_score,
120 self.risk_concentration,
121 self.anomalous_counterparty_count as f64,
122 (self.high_risk_exposure + 1.0).ln(),
123 ]
124 }
125
126 pub fn feature_count() -> usize {
128 5
129 }
130
131 pub fn feature_names() -> Vec<&'static str> {
133 vec![
134 "high_risk_counterparty_ratio",
135 "avg_counterparty_risk_score",
136 "risk_concentration",
137 "anomalous_counterparty_count",
138 "high_risk_exposure_log",
139 ]
140 }
141}
142
/// Per-counterparty accumulator used internally while scanning a node's
/// edges. Not every field is filled by every caller: the relationship pass
/// leaves `is_anomalous`/`risk_score` at their defaults, and the risk pass
/// leaves `first_contact` as `None`.
#[derive(Debug, Clone, Default)]
struct CounterpartyInfo {
    /// Earliest timestamp among the dated edges shared with the counterparty;
    /// `None` when no such edge carries a timestamp.
    first_contact: Option<NaiveDate>,
    /// Number of edges shared with the counterparty (both directions).
    transaction_count: usize,
    /// Sum of the weights of the edges shared with the counterparty.
    total_volume: f64,
    /// Set when the counterparty's node or any shared edge is marked as an
    /// anomaly.
    is_anomalous: bool,
    /// Risk score assigned by `compute_counterparty_risk`, capped at 1.0.
    risk_score: f64,
}
157
158pub fn compute_relationship_features(
160 node_id: NodeId,
161 graph: &Graph,
162 config: &RelationshipFeatureConfig,
163) -> RelationshipFeatures {
164 let outgoing = graph.outgoing_edges(node_id);
165 let incoming = graph.incoming_edges(node_id);
166
167 if outgoing.is_empty() && incoming.is_empty() {
168 return RelationshipFeatures::default();
169 }
170
171 let mut counterparties: HashMap<NodeId, CounterpartyInfo> = HashMap::new();
173 let mut outgoing_targets: HashSet<NodeId> = HashSet::new();
174 let mut incoming_sources: HashSet<NodeId> = HashSet::new();
175
176 for edge in &outgoing {
178 outgoing_targets.insert(edge.target);
179 let info = counterparties.entry(edge.target).or_default();
180 info.transaction_count += 1;
181 info.total_volume += edge.weight;
182
183 if let Some(date) = edge.timestamp {
184 match info.first_contact {
185 None => info.first_contact = Some(date),
186 Some(existing) if date < existing => info.first_contact = Some(date),
187 _ => {}
188 }
189 }
190 }
191
192 for edge in &incoming {
194 incoming_sources.insert(edge.source);
195 let info = counterparties.entry(edge.source).or_default();
196 info.transaction_count += 1;
197 info.total_volume += edge.weight;
198
199 if let Some(date) = edge.timestamp {
200 match info.first_contact {
201 None => info.first_contact = Some(date),
202 Some(existing) if date < existing => info.first_contact = Some(date),
203 _ => {}
204 }
205 }
206 }
207
208 let unique_counterparties = counterparties.len();
209 let total_relationships = outgoing.len() + incoming.len();
210
211 if unique_counterparties == 0 {
212 return RelationshipFeatures::default();
213 }
214
215 let new_threshold =
217 config.reference_date - chrono::Duration::days(config.new_relationship_days);
218 let new_count = counterparties
219 .values()
220 .filter(|info| {
221 info.first_contact
222 .map(|d| d >= new_threshold)
223 .unwrap_or(false)
224 })
225 .count();
226 let new_relationship_ratio = new_count as f64 / unique_counterparties as f64;
227
228 let total_volume: f64 = counterparties.values().map(|i| i.total_volume).sum();
230 let counterparty_concentration = if total_volume > 0.0 {
231 counterparties
232 .values()
233 .map(|info| {
234 let share = info.total_volume / total_volume;
235 share * share
236 })
237 .sum()
238 } else {
239 1.0 / unique_counterparties as f64 };
241
242 let bidirectional_count = outgoing_targets.intersection(&incoming_sources).count();
244 let relationship_reciprocity = if unique_counterparties > 0 {
245 bidirectional_count as f64 / unique_counterparties as f64
246 } else {
247 0.0
248 };
249
250 let ages: Vec<i64> = counterparties
252 .values()
253 .filter_map(|info| info.first_contact)
254 .map(|date| (config.reference_date - date).num_days().max(0))
255 .collect();
256
257 let avg_relationship_age_days = if !ages.is_empty() {
258 ages.iter().sum::<i64>() as f64 / ages.len() as f64
259 } else {
260 0.0
261 };
262
263 let date_range = counterparties
265 .values()
266 .filter_map(|info| info.first_contact)
267 .fold((None, None), |(min, max), date| {
268 let new_min = min.map_or(date, |m: NaiveDate| m.min(date));
269 let new_max = max.map_or(date, |m: NaiveDate| m.max(date));
270 (Some(new_min), Some(new_max))
271 });
272
273 let relationship_velocity = if let (Some(min_date), Some(max_date)) = date_range {
274 let months = (max_date - min_date).num_days() as f64 / 30.0;
275 if months > 0.0 {
276 unique_counterparties as f64 / months
277 } else {
278 unique_counterparties as f64
279 }
280 } else {
281 0.0
282 };
283
284 let max_volume = counterparties
286 .values()
287 .map(|i| i.total_volume)
288 .fold(0.0, f64::max);
289 let dominant_counterparty_share = if total_volume > 0.0 {
290 max_volume / total_volume
291 } else {
292 0.0
293 };
294
295 RelationshipFeatures {
296 unique_counterparties,
297 new_relationship_ratio,
298 counterparty_concentration,
299 relationship_reciprocity,
300 avg_relationship_age_days,
301 relationship_velocity,
302 total_relationships,
303 dominant_counterparty_share,
304 }
305}
306
307pub fn compute_counterparty_risk(
309 node_id: NodeId,
310 graph: &Graph,
311 config: &RelationshipFeatureConfig,
312) -> CounterpartyRisk {
313 let outgoing = graph.outgoing_edges(node_id);
314 let incoming = graph.incoming_edges(node_id);
315
316 if outgoing.is_empty() && incoming.is_empty() {
317 return CounterpartyRisk::default();
318 }
319
320 let mut counterparties: HashMap<NodeId, CounterpartyInfo> = HashMap::new();
322
323 for edge in outgoing.iter().chain(incoming.iter()) {
325 let counterparty_id = if edge.source == node_id {
326 edge.target
327 } else {
328 edge.source
329 };
330
331 let info = counterparties.entry(counterparty_id).or_default();
332 info.transaction_count += 1;
333 info.total_volume += edge.weight;
334
335 if edge.is_anomaly {
337 info.is_anomalous = true;
338 }
339 }
340
341 for (&cp_id, info) in counterparties.iter_mut() {
343 let cp_node = graph.get_node(cp_id);
344
345 let mut risk = 0.0;
347
348 if let Some(node) = cp_node {
349 if node.is_anomaly {
350 risk += 0.5;
351 info.is_anomalous = true;
352 }
353 }
354
355 let cp_edges: Vec<_> = outgoing
357 .iter()
358 .chain(incoming.iter())
359 .filter(|e| e.source == cp_id || e.target == cp_id)
360 .collect();
361
362 let anomalous_edge_ratio =
363 cp_edges.iter().filter(|e| e.is_anomaly).count() as f64 / cp_edges.len().max(1) as f64;
364 risk += anomalous_edge_ratio * 0.3;
365
366 if let Some(node) = cp_node {
368 let suspicious_labels = ["fraud", "suspicious", "high_risk", "flagged"];
369 for label in &node.labels {
370 if suspicious_labels
371 .iter()
372 .any(|s| label.to_lowercase().contains(s))
373 {
374 risk += 0.2;
375 break;
376 }
377 }
378 }
379
380 info.risk_score = risk.min(1.0);
381 }
382
383 let unique_counterparties = counterparties.len();
384 if unique_counterparties == 0 {
385 return CounterpartyRisk::default();
386 }
387
388 let high_risk_count = counterparties
390 .values()
391 .filter(|info| info.risk_score >= config.high_risk_threshold)
392 .count();
393 let high_risk_counterparty_ratio = high_risk_count as f64 / unique_counterparties as f64;
394
395 let total_risk: f64 = counterparties.values().map(|i| i.risk_score).sum();
397 let avg_counterparty_risk_score = total_risk / unique_counterparties as f64;
398
399 let total_risk_weighted: f64 = counterparties
401 .values()
402 .map(|i| i.total_volume * i.risk_score)
403 .sum();
404
405 let risk_concentration = if total_risk_weighted > 0.0 {
406 counterparties
407 .values()
408 .map(|info| {
409 let weighted = info.total_volume * info.risk_score;
410 let share = weighted / total_risk_weighted;
411 share * share
412 })
413 .sum()
414 } else {
415 0.0
416 };
417
418 let anomalous_counterparty_count = counterparties.values().filter(|i| i.is_anomalous).count();
420
421 let high_risk_exposure: f64 = counterparties
423 .values()
424 .filter(|info| info.risk_score >= config.high_risk_threshold)
425 .map(|info| info.total_volume)
426 .sum();
427
428 CounterpartyRisk {
429 high_risk_counterparty_ratio,
430 avg_counterparty_risk_score,
431 risk_concentration,
432 anomalous_counterparty_count,
433 high_risk_exposure,
434 }
435}
436
437pub fn compute_all_relationship_features(
439 graph: &Graph,
440 config: &RelationshipFeatureConfig,
441) -> HashMap<NodeId, RelationshipFeatures> {
442 let mut features = HashMap::new();
443
444 for &node_id in graph.nodes.keys() {
445 features.insert(
446 node_id,
447 compute_relationship_features(node_id, graph, config),
448 );
449 }
450
451 features
452}
453
454pub fn compute_all_counterparty_risk(
456 graph: &Graph,
457 config: &RelationshipFeatureConfig,
458) -> HashMap<NodeId, CounterpartyRisk> {
459 let mut risks = HashMap::new();
460
461 for &node_id in graph.nodes.keys() {
462 risks.insert(node_id, compute_counterparty_risk(node_id, graph, config));
463 }
464
465 risks
466}
467
/// Relationship metrics and counterparty-risk metrics of a single node,
/// bundled so they can be flattened into one feature vector.
#[derive(Debug, Clone, Default)]
pub struct CombinedRelationshipFeatures {
    /// Relationship metrics; these come first in the combined feature vector.
    pub relationship: RelationshipFeatures,
    /// Counterparty-risk metrics; appended after the relationship metrics.
    pub risk: CounterpartyRisk,
}
476
477impl CombinedRelationshipFeatures {
478 pub fn to_features(&self) -> Vec<f64> {
480 let mut features = self.relationship.to_features();
481 features.extend(self.risk.to_features());
482 features
483 }
484
485 pub fn feature_count() -> usize {
487 RelationshipFeatures::feature_count() + CounterpartyRisk::feature_count()
488 }
489
490 pub fn feature_names() -> Vec<&'static str> {
492 let mut names = RelationshipFeatures::feature_names();
493 names.extend(CounterpartyRisk::feature_names());
494 names
495 }
496}
497
498pub fn compute_all_combined_features(
500 graph: &Graph,
501 config: &RelationshipFeatureConfig,
502) -> HashMap<NodeId, CombinedRelationshipFeatures> {
503 let mut features = HashMap::new();
504
505 for &node_id in graph.nodes.keys() {
506 features.insert(
507 node_id,
508 CombinedRelationshipFeatures {
509 relationship: compute_relationship_features(node_id, graph, config),
510 risk: compute_counterparty_risk(node_id, graph, config),
511 },
512 );
513 }
514
515 features
516}
517
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{GraphEdge, GraphNode, GraphType, NodeType};
    use crate::EdgeType;

    /// Builds a four-account graph (A, B, C, D) with five dated transactions:
    /// A->B twice, A->C, B->A and A->D. Only the A/B relationship runs in both
    /// directions, and only the A->D edge is dated inside the default 30-day
    /// "new relationship" window.
    fn create_test_graph() -> Graph {
        let mut graph = Graph::new("test", GraphType::Transaction);

        let [n1, n2, n3, n4] = ["A", "B", "C", "D"].map(|code| {
            graph.add_node(GraphNode::new(
                0,
                NodeType::Account,
                code.to_string(),
                code.to_string(),
            ))
        });

        // (source, target, amount, (year, month, day)), in insertion order.
        let transfers = [
            (n1, n2, 1000.0, (2024, 1, 1)),
            (n1, n2, 2000.0, (2024, 6, 1)),
            (n1, n3, 500.0, (2024, 3, 1)),
            (n2, n1, 1500.0, (2024, 4, 1)),
            (n1, n4, 300.0, (2024, 12, 15)),
        ];
        for (source, target, amount, (year, month, day)) in transfers {
            graph.add_edge(
                GraphEdge::new(0, source, target, EdgeType::Transaction)
                    .with_weight(amount)
                    .with_timestamp(NaiveDate::from_ymd_opt(year, month, day).unwrap()),
            );
        }

        graph
    }

    /// Relationship features of node 1 in the fixture graph under the default
    /// configuration.
    fn node_one_features() -> RelationshipFeatures {
        let graph = create_test_graph();
        let config = RelationshipFeatureConfig::default();
        compute_relationship_features(1, &graph, &config)
    }

    #[test]
    fn test_relationship_features() {
        let features = node_one_features();

        assert_eq!(features.unique_counterparties, 3);
        assert!(features.new_relationship_ratio > 0.0);
        assert!(features.counterparty_concentration > 0.0);
        assert!(features.relationship_reciprocity > 0.0);
    }

    #[test]
    fn test_herfindahl_index() {
        let features = node_one_features();

        // The index must lie in (0, 1].
        assert!(features.counterparty_concentration > 0.0);
        assert!(features.counterparty_concentration <= 1.0);

        // Volume is skewed towards one counterparty, so the index must exceed
        // the value an even three-way split would give.
        assert!(features.counterparty_concentration > 0.33);
    }

    #[test]
    fn test_reciprocity() {
        let features = node_one_features();

        // One of the three counterparties is bidirectional.
        assert!((features.relationship_reciprocity - 0.333).abs() < 0.1);
    }

    #[test]
    fn test_counterparty_risk_basic() {
        let graph = create_test_graph();
        let config = RelationshipFeatureConfig::default();

        let risk = compute_counterparty_risk(1, &graph, &config);

        // Nothing in the fixture is flagged.
        assert_eq!(risk.anomalous_counterparty_count, 0);
        assert_eq!(risk.avg_counterparty_risk_score, 0.0);
    }

    #[test]
    fn test_counterparty_risk_with_anomalies() {
        let mut graph = create_test_graph();

        // Flag one edge as anomalous.
        if let Some(edge) = graph.get_edge_mut(1) {
            edge.is_anomaly = true;
        }

        let config = RelationshipFeatureConfig::default();
        let risk = compute_counterparty_risk(1, &graph, &config);

        assert!(risk.avg_counterparty_risk_score > 0.0);
    }

    #[test]
    fn test_feature_vector_length() {
        assert_eq!(RelationshipFeatures::feature_count(), 8);
        assert_eq!(CounterpartyRisk::feature_count(), 5);
        assert_eq!(CombinedRelationshipFeatures::feature_count(), 13);

        let relationship_vector = RelationshipFeatures::default().to_features();
        assert_eq!(
            relationship_vector.len(),
            RelationshipFeatures::feature_count()
        );

        let risk_vector = CounterpartyRisk::default().to_features();
        assert_eq!(risk_vector.len(), CounterpartyRisk::feature_count());
    }

    #[test]
    fn test_all_relationship_features() {
        let graph = create_test_graph();
        let config = RelationshipFeatureConfig::default();

        let all_features = compute_all_relationship_features(&graph, &config);

        // One entry per node in the fixture.
        assert_eq!(all_features.len(), 4);
    }

    #[test]
    fn test_combined_features() {
        let graph = create_test_graph();
        let config = RelationshipFeatureConfig::default();

        let combined = compute_all_combined_features(&graph, &config);

        for features in combined.values() {
            assert_eq!(
                features.to_features().len(),
                CombinedRelationshipFeatures::feature_count()
            );
        }
    }
}