1use std::collections::HashMap;
7
8use serde::{Deserialize, Serialize};
9
10use crate::models::{Graph, NodeId};
11
12use super::entity_groups::EntityGroup;
13
14#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
16pub enum AggregationType {
17 Sum,
19 Mean,
21 WeightedMean,
23 Max,
25 Min,
27 Median,
29}
30
31#[derive(Debug, Clone, Default, Serialize, Deserialize)]
33pub struct AggregatedFeatures {
34 pub total_volume: f64,
36 pub avg_transaction_size: f64,
38 pub combined_risk_score: f64,
40 pub internal_flow_ratio: f64,
42 pub external_flow_ratio: f64,
44 pub external_counterparty_count: usize,
46 pub activity_variance: f64,
48 pub member_count: usize,
50}
51
52impl AggregatedFeatures {
53 pub fn to_features(&self) -> Vec<f64> {
55 vec![
56 (self.total_volume + 1.0).ln(),
57 (self.avg_transaction_size + 1.0).ln(),
58 self.combined_risk_score,
59 self.internal_flow_ratio,
60 self.external_flow_ratio,
61 self.external_counterparty_count as f64,
62 self.activity_variance,
63 self.member_count as f64,
64 ]
65 }
66
67 pub fn feature_count() -> usize {
69 8
70 }
71
72 pub fn feature_names() -> Vec<&'static str> {
74 vec![
75 "total_volume_log",
76 "avg_transaction_size_log",
77 "combined_risk_score",
78 "internal_flow_ratio",
79 "external_flow_ratio",
80 "external_counterparty_count",
81 "activity_variance",
82 "member_count",
83 ]
84 }
85}
86
87pub fn aggregate_features(
89 group: &EntityGroup,
90 graph: &Graph,
91 _agg_type: AggregationType,
92) -> AggregatedFeatures {
93 let member_set: std::collections::HashSet<NodeId> = group.members.iter().copied().collect();
94
95 let mut total_volume = 0.0;
96 let mut internal_volume = 0.0;
97 let mut external_volume = 0.0;
98 let mut transaction_count = 0;
99 let mut external_counterparties = std::collections::HashSet::new();
100 let mut member_activities = Vec::new();
101
102 for &member in &group.members {
104 let mut member_activity = 0.0;
105
106 for edge in graph.outgoing_edges(member) {
107 total_volume += edge.weight;
108 member_activity += edge.weight;
109 transaction_count += 1;
110
111 if member_set.contains(&edge.target) {
112 internal_volume += edge.weight;
113 } else {
114 external_volume += edge.weight;
115 external_counterparties.insert(edge.target);
116 }
117 }
118
119 for edge in graph.incoming_edges(member) {
120 if !member_set.contains(&edge.source) {
121 external_counterparties.insert(edge.source);
122 }
123 }
124
125 member_activities.push(member_activity);
126 }
127
128 let avg_transaction_size = if transaction_count > 0 {
130 total_volume / transaction_count as f64
131 } else {
132 0.0
133 };
134
135 let total_flow = internal_volume + external_volume;
136 let internal_flow_ratio = if total_flow > 0.0 {
137 internal_volume / total_flow
138 } else {
139 0.0
140 };
141 let external_flow_ratio = if total_flow > 0.0 {
142 external_volume / total_flow
143 } else {
144 0.0
145 };
146
147 let mean_activity = if !member_activities.is_empty() {
149 member_activities.iter().sum::<f64>() / member_activities.len() as f64
150 } else {
151 0.0
152 };
153
154 let activity_variance = if member_activities.len() > 1 {
155 let variance: f64 = member_activities
156 .iter()
157 .map(|&a| (a - mean_activity).powi(2))
158 .sum::<f64>()
159 / member_activities.len() as f64;
160 variance.sqrt() / (mean_activity + 1.0) } else {
162 0.0
163 };
164
165 let anomalous_members = group
167 .members
168 .iter()
169 .filter(|&&n| {
170 graph
171 .get_node(n)
172 .map(|node| node.is_anomaly)
173 .unwrap_or(false)
174 })
175 .count();
176
177 let anomalous_edges = group
178 .members
179 .iter()
180 .flat_map(|&n| {
181 graph
182 .outgoing_edges(n)
183 .into_iter()
184 .chain(graph.incoming_edges(n))
185 })
186 .filter(|e| e.is_anomaly)
187 .count();
188
189 let total_edges = group
190 .members
191 .iter()
192 .map(|&n| graph.degree(n))
193 .sum::<usize>();
194
195 let member_risk = anomalous_members as f64 / group.members.len().max(1) as f64;
196 let edge_risk = anomalous_edges as f64 / total_edges.max(1) as f64;
197 let combined_risk_score = (member_risk * 0.6 + edge_risk * 0.4).min(1.0);
198
199 AggregatedFeatures {
200 total_volume,
201 avg_transaction_size,
202 combined_risk_score,
203 internal_flow_ratio,
204 external_flow_ratio,
205 external_counterparty_count: external_counterparties.len(),
206 activity_variance,
207 member_count: group.members.len(),
208 }
209}
210
211pub fn aggregate_values(values: &[f64], agg_type: AggregationType) -> f64 {
213 if values.is_empty() {
214 return 0.0;
215 }
216
217 match agg_type {
218 AggregationType::Sum => values.iter().sum(),
219 AggregationType::Mean => values.iter().sum::<f64>() / values.len() as f64,
220 AggregationType::WeightedMean => {
221 values.iter().sum::<f64>() / values.len() as f64
223 }
224 AggregationType::Max => values.iter().cloned().fold(f64::NEG_INFINITY, f64::max),
225 AggregationType::Min => values.iter().cloned().fold(f64::INFINITY, f64::min),
226 AggregationType::Median => {
227 let mut sorted = values.to_vec();
228 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
229 let mid = sorted.len() / 2;
230 if sorted.len() % 2 == 0 {
231 (sorted[mid - 1] + sorted[mid]) / 2.0
232 } else {
233 sorted[mid]
234 }
235 }
236 }
237}
238
239pub fn aggregate_weighted(values: &[f64], weights: &[f64], agg_type: AggregationType) -> f64 {
241 if values.is_empty() || weights.is_empty() || values.len() != weights.len() {
242 return aggregate_values(values, agg_type);
243 }
244
245 match agg_type {
246 AggregationType::WeightedMean => {
247 let total_weight: f64 = weights.iter().sum();
248 if total_weight > 0.0 {
249 let weighted_sum: f64 = values.iter().zip(weights.iter()).map(|(v, w)| v * w).sum();
250 weighted_sum / total_weight
251 } else {
252 aggregate_values(values, AggregationType::Mean)
253 }
254 }
255 _ => aggregate_values(values, agg_type),
256 }
257}
258
259pub fn aggregate_all_groups(
261 groups: &[EntityGroup],
262 graph: &Graph,
263 agg_type: AggregationType,
264) -> HashMap<u64, AggregatedFeatures> {
265 let mut result = HashMap::new();
266
267 for group in groups {
268 let features = aggregate_features(group, graph, agg_type);
269 result.insert(group.group_id, features);
270 }
271
272 result
273}
274
/// Result of aggregating several feature dimensions at once: one value per
/// dimension plus a parallel list of dimension names.
#[derive(Debug, Clone)]
pub struct MultiFeatureAggregation {
    /// Aggregated value for each feature dimension.
    pub features: Vec<f64>,
    /// Name of each dimension, supplied alongside `features` by the constructor's caller.
    pub names: Vec<String>,
}

impl MultiFeatureAggregation {
    /// Bundles aggregated values with their names. The two vectors are stored
    /// as given; no length check is performed.
    pub fn new(features: Vec<f64>, names: Vec<String>) -> Self {
        MultiFeatureAggregation { features, names }
    }

    /// Borrows the aggregated values as a slice.
    pub fn to_features(&self) -> &[f64] {
        self.features.as_slice()
    }
}
295
296pub fn aggregate_node_features(
298 node_ids: &[NodeId],
299 graph: &Graph,
300 agg_type: AggregationType,
301) -> MultiFeatureAggregation {
302 if node_ids.is_empty() {
303 return MultiFeatureAggregation::new(Vec::new(), Vec::new());
304 }
305
306 let node_features: Vec<Vec<f64>> = node_ids
308 .iter()
309 .filter_map(|&id| graph.get_node(id))
310 .map(|n| n.features.clone())
311 .filter(|f| !f.is_empty())
312 .collect();
313
314 if node_features.is_empty() {
315 return MultiFeatureAggregation::new(Vec::new(), Vec::new());
316 }
317
318 let dim = node_features[0].len();
320
321 let aggregated: Vec<f64> = (0..dim)
323 .map(|d| {
324 let values: Vec<f64> = node_features
325 .iter()
326 .map(|f| f.get(d).copied().unwrap_or(0.0))
327 .collect();
328 aggregate_values(&values, agg_type)
329 })
330 .collect();
331
332 let names: Vec<String> = (0..dim).map(|d| format!("feature_{}", d)).collect();
333
334 MultiFeatureAggregation::new(aggregated, names)
335}
336
#[cfg(test)]
mod tests {
    use super::super::entity_groups::GroupType;
    use super::*;
    use crate::models::{GraphEdge, GraphNode, GraphType, NodeType};
    use crate::EdgeType;

    /// Three account nodes A, B, C with 3-dimensional features, connected by
    /// the weighted edges A->B (100), B->C (200) and A->C (150).
    fn create_test_graph() -> Graph {
        let mut graph = Graph::new("test", GraphType::Transaction);

        // Both string arguments of the node constructor receive the same label.
        let account = |label: &str, features: Vec<f64>| {
            GraphNode::new(0, NodeType::Account, label.to_string(), label.to_string())
                .with_features(features)
        };

        let a = graph.add_node(account("A", vec![1.0, 2.0, 3.0]));
        let b = graph.add_node(account("B", vec![4.0, 5.0, 6.0]));
        let c = graph.add_node(account("C", vec![7.0, 8.0, 9.0]));

        graph.add_edge(GraphEdge::new(0, a, b, EdgeType::Transaction).with_weight(100.0));
        graph.add_edge(GraphEdge::new(0, b, c, EdgeType::Transaction).with_weight(200.0));
        graph.add_edge(GraphEdge::new(0, a, c, EdgeType::Transaction).with_weight(150.0));

        graph
    }

    #[test]
    fn test_aggregate_values_sum() {
        let total = aggregate_values(&[1.0, 2.0, 3.0, 4.0, 5.0], AggregationType::Sum);
        assert_eq!(total, 15.0);
    }

    #[test]
    fn test_aggregate_values_mean() {
        let mean = aggregate_values(&[1.0, 2.0, 3.0, 4.0, 5.0], AggregationType::Mean);
        assert_eq!(mean, 3.0);
    }

    #[test]
    fn test_aggregate_values_max() {
        let largest = aggregate_values(&[1.0, 5.0, 3.0, 2.0, 4.0], AggregationType::Max);
        assert_eq!(largest, 5.0);
    }

    #[test]
    fn test_aggregate_values_min() {
        let smallest = aggregate_values(&[1.0, 5.0, 3.0, 2.0, 4.0], AggregationType::Min);
        assert_eq!(smallest, 1.0);
    }

    #[test]
    fn test_aggregate_values_median_odd() {
        let median = aggregate_values(&[1.0, 2.0, 3.0, 4.0, 5.0], AggregationType::Median);
        assert_eq!(median, 3.0);
    }

    #[test]
    fn test_aggregate_values_median_even() {
        let median = aggregate_values(&[1.0, 2.0, 3.0, 4.0], AggregationType::Median);
        assert_eq!(median, 2.5);
    }

    #[test]
    fn test_aggregate_weighted() {
        let values = [10.0, 20.0, 30.0];
        let weights = [1.0, 2.0, 1.0];

        // (10*1 + 20*2 + 30*1) / (1 + 2 + 1) = 80 / 4
        assert_eq!(
            aggregate_weighted(&values, &weights, AggregationType::WeightedMean),
            20.0
        );
    }

    #[test]
    fn test_aggregate_features() {
        let graph = create_test_graph();
        let group = EntityGroup::new(1, vec![1, 2, 3], GroupType::TransactionCluster);

        let summary = aggregate_features(&group, &graph, AggregationType::Sum);

        assert!(summary.total_volume > 0.0);
        assert_eq!(summary.member_count, 3);
    }

    #[test]
    fn test_aggregate_node_features() {
        let graph = create_test_graph();
        let aggregation = aggregate_node_features(&[1, 2, 3], &graph, AggregationType::Mean);

        // Column-wise means of [1,2,3], [4,5,6] and [7,8,9].
        assert_eq!(aggregation.features.len(), 3);
        assert_eq!(aggregation.features[0], 4.0);
        assert_eq!(aggregation.features[1], 5.0);
        assert_eq!(aggregation.features[2], 6.0);
    }

    #[test]
    fn test_aggregated_features_to_vector() {
        let summary = AggregatedFeatures {
            total_volume: 1000.0,
            avg_transaction_size: 100.0,
            combined_risk_score: 0.5,
            internal_flow_ratio: 0.6,
            external_flow_ratio: 0.4,
            external_counterparty_count: 5,
            activity_variance: 0.3,
            member_count: 3,
        };

        let vector = summary.to_features();
        assert_eq!(vector.len(), AggregatedFeatures::feature_count());
    }
}