Skip to main content

datasynth_graph/ml/
aggregation.rs

1//! Feature aggregation for entity groups.
2//!
3//! This module provides aggregation functions for computing
4//! group-level features from individual node features.
5
6use std::collections::HashMap;
7
8use serde::{Deserialize, Serialize};
9
10use crate::models::{Graph, NodeId};
11
12use super::entity_groups::EntityGroup;
13
/// Aggregation method for combining features.
///
/// Selects how a slice of values is collapsed into a single number by
/// `aggregate_values` and `aggregate_weighted`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AggregationType {
    /// Sum of all values.
    Sum,
    /// Arithmetic mean.
    Mean,
    /// Weighted mean.
    ///
    /// `aggregate_weighted` applies the caller-supplied weights;
    /// `aggregate_values` has no weights available and computes the plain
    /// arithmetic mean for this variant.
    WeightedMean,
    /// Maximum value.
    Max,
    /// Minimum value.
    Min,
    /// Median value (the two middle values are averaged when the number of
    /// values is even).
    Median,
}
30
/// Aggregated features for an entity group.
///
/// The field descriptions below state what `aggregate_features` stores in
/// each field; values built by hand may of course hold anything.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AggregatedFeatures {
    /// Sum of the weights of every outgoing edge of every group member,
    /// regardless of whether the edge targets a member or an outside node.
    pub total_volume: f64,
    /// `total_volume` divided by the number of outgoing member edges
    /// (0.0 when there are none).
    pub avg_transaction_size: f64,
    /// Combined risk score for the group: 0.6 x the fraction of anomalous
    /// members plus 0.4 x the fraction of anomalous incident edges, capped
    /// at 1.0.
    pub combined_risk_score: f64,
    /// Share of outgoing volume whose target is also a group member
    /// (0.0 when there is no outgoing volume).
    pub internal_flow_ratio: f64,
    /// Share of outgoing volume whose target lies outside the group
    /// (0.0 when there is no outgoing volume).
    pub external_flow_ratio: f64,
    /// Number of unique non-member nodes that are the target of an outgoing
    /// member edge or the source of an incoming member edge.
    pub external_counterparty_count: usize,
    /// Dispersion of per-member outgoing volume. Despite the name this is
    /// not a raw variance: it is the population standard deviation divided
    /// by `mean + 1.0` (0.0 for groups with fewer than two members).
    pub activity_variance: f64,
    /// Number of members.
    pub member_count: usize,
}
51
52impl AggregatedFeatures {
53    /// Converts to a feature vector.
54    pub fn to_features(&self) -> Vec<f64> {
55        vec![
56            (self.total_volume + 1.0).ln(),
57            (self.avg_transaction_size + 1.0).ln(),
58            self.combined_risk_score,
59            self.internal_flow_ratio,
60            self.external_flow_ratio,
61            self.external_counterparty_count as f64,
62            self.activity_variance,
63            self.member_count as f64,
64        ]
65    }
66
67    /// Returns feature count.
68    pub fn feature_count() -> usize {
69        8
70    }
71
72    /// Returns feature names.
73    pub fn feature_names() -> Vec<&'static str> {
74        vec![
75            "total_volume_log",
76            "avg_transaction_size_log",
77            "combined_risk_score",
78            "internal_flow_ratio",
79            "external_flow_ratio",
80            "external_counterparty_count",
81            "activity_variance",
82            "member_count",
83        ]
84    }
85}
86
87/// Aggregates features for a group of nodes.
88pub fn aggregate_features(
89    group: &EntityGroup,
90    graph: &Graph,
91    _agg_type: AggregationType,
92) -> AggregatedFeatures {
93    let member_set: std::collections::HashSet<NodeId> = group.members.iter().copied().collect();
94
95    let mut total_volume = 0.0;
96    let mut internal_volume = 0.0;
97    let mut external_volume = 0.0;
98    let mut transaction_count = 0;
99    let mut external_counterparties = std::collections::HashSet::new();
100    let mut member_activities = Vec::new();
101
102    // Calculate volumes and counterparties
103    for &member in &group.members {
104        let mut member_activity = 0.0;
105
106        for edge in graph.outgoing_edges(member) {
107            total_volume += edge.weight;
108            member_activity += edge.weight;
109            transaction_count += 1;
110
111            if member_set.contains(&edge.target) {
112                internal_volume += edge.weight;
113            } else {
114                external_volume += edge.weight;
115                external_counterparties.insert(edge.target);
116            }
117        }
118
119        for edge in graph.incoming_edges(member) {
120            if !member_set.contains(&edge.source) {
121                external_counterparties.insert(edge.source);
122            }
123        }
124
125        member_activities.push(member_activity);
126    }
127
128    // Calculate averages and ratios
129    let avg_transaction_size = if transaction_count > 0 {
130        total_volume / transaction_count as f64
131    } else {
132        0.0
133    };
134
135    let total_flow = internal_volume + external_volume;
136    let internal_flow_ratio = if total_flow > 0.0 {
137        internal_volume / total_flow
138    } else {
139        0.0
140    };
141    let external_flow_ratio = if total_flow > 0.0 {
142        external_volume / total_flow
143    } else {
144        0.0
145    };
146
147    // Calculate activity variance
148    let mean_activity = if !member_activities.is_empty() {
149        member_activities.iter().sum::<f64>() / member_activities.len() as f64
150    } else {
151        0.0
152    };
153
154    let activity_variance = if member_activities.len() > 1 {
155        let variance: f64 = member_activities
156            .iter()
157            .map(|&a| (a - mean_activity).powi(2))
158            .sum::<f64>()
159            / member_activities.len() as f64;
160        variance.sqrt() / (mean_activity + 1.0) // Coefficient of variation
161    } else {
162        0.0
163    };
164
165    // Calculate combined risk score
166    let anomalous_members = group
167        .members
168        .iter()
169        .filter(|&&n| {
170            graph
171                .get_node(n)
172                .map(|node| node.is_anomaly)
173                .unwrap_or(false)
174        })
175        .count();
176
177    let anomalous_edges = group
178        .members
179        .iter()
180        .flat_map(|&n| {
181            graph
182                .outgoing_edges(n)
183                .into_iter()
184                .chain(graph.incoming_edges(n))
185        })
186        .filter(|e| e.is_anomaly)
187        .count();
188
189    let total_edges = group
190        .members
191        .iter()
192        .map(|&n| graph.degree(n))
193        .sum::<usize>();
194
195    let member_risk = anomalous_members as f64 / group.members.len().max(1) as f64;
196    let edge_risk = anomalous_edges as f64 / total_edges.max(1) as f64;
197    let combined_risk_score = (member_risk * 0.6 + edge_risk * 0.4).min(1.0);
198
199    AggregatedFeatures {
200        total_volume,
201        avg_transaction_size,
202        combined_risk_score,
203        internal_flow_ratio,
204        external_flow_ratio,
205        external_counterparty_count: external_counterparties.len(),
206        activity_variance,
207        member_count: group.members.len(),
208    }
209}
210
211/// Aggregates a specific feature across multiple values.
212pub fn aggregate_values(values: &[f64], agg_type: AggregationType) -> f64 {
213    if values.is_empty() {
214        return 0.0;
215    }
216
217    match agg_type {
218        AggregationType::Sum => values.iter().sum(),
219        AggregationType::Mean => values.iter().sum::<f64>() / values.len() as f64,
220        AggregationType::WeightedMean => {
221            // Without weights, defaults to mean
222            values.iter().sum::<f64>() / values.len() as f64
223        }
224        AggregationType::Max => values.iter().cloned().fold(f64::NEG_INFINITY, f64::max),
225        AggregationType::Min => values.iter().cloned().fold(f64::INFINITY, f64::min),
226        AggregationType::Median => {
227            let mut sorted = values.to_vec();
228            sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
229            let mid = sorted.len() / 2;
230            if sorted.len() % 2 == 0 {
231                (sorted[mid - 1] + sorted[mid]) / 2.0
232            } else {
233                sorted[mid]
234            }
235        }
236    }
237}
238
239/// Aggregates weighted values.
240pub fn aggregate_weighted(values: &[f64], weights: &[f64], agg_type: AggregationType) -> f64 {
241    if values.is_empty() || weights.is_empty() || values.len() != weights.len() {
242        return aggregate_values(values, agg_type);
243    }
244
245    match agg_type {
246        AggregationType::WeightedMean => {
247            let total_weight: f64 = weights.iter().sum();
248            if total_weight > 0.0 {
249                let weighted_sum: f64 = values.iter().zip(weights.iter()).map(|(v, w)| v * w).sum();
250                weighted_sum / total_weight
251            } else {
252                aggregate_values(values, AggregationType::Mean)
253            }
254        }
255        _ => aggregate_values(values, agg_type),
256    }
257}
258
259/// Aggregates features for all groups.
260pub fn aggregate_all_groups(
261    groups: &[EntityGroup],
262    graph: &Graph,
263    agg_type: AggregationType,
264) -> HashMap<u64, AggregatedFeatures> {
265    let mut result = HashMap::new();
266
267    for group in groups {
268        let features = aggregate_features(group, graph, agg_type);
269        result.insert(group.group_id, features);
270    }
271
272    result
273}
274
/// Result of aggregating several feature dimensions at once.
#[derive(Debug, Clone)]
pub struct MultiFeatureAggregation {
    /// One aggregated value per feature dimension.
    pub features: Vec<f64>,
    /// Label for each entry of `features`.
    pub names: Vec<String>,
}

impl MultiFeatureAggregation {
    /// Bundles aggregated values with their labels.
    ///
    /// No check is made that the two vectors have the same length.
    pub fn new(features: Vec<f64>, names: Vec<String>) -> Self {
        MultiFeatureAggregation { features, names }
    }

    /// Borrows the aggregated values as a slice.
    pub fn to_features(&self) -> &[f64] {
        self.features.as_slice()
    }
}
295
296/// Aggregates multiple node feature vectors into a single vector.
297pub fn aggregate_node_features(
298    node_ids: &[NodeId],
299    graph: &Graph,
300    agg_type: AggregationType,
301) -> MultiFeatureAggregation {
302    if node_ids.is_empty() {
303        return MultiFeatureAggregation::new(Vec::new(), Vec::new());
304    }
305
306    // Collect features from all nodes
307    let node_features: Vec<Vec<f64>> = node_ids
308        .iter()
309        .filter_map(|&id| graph.get_node(id))
310        .map(|n| n.features.clone())
311        .filter(|f| !f.is_empty())
312        .collect();
313
314    if node_features.is_empty() {
315        return MultiFeatureAggregation::new(Vec::new(), Vec::new());
316    }
317
318    // Find feature dimension
319    let dim = node_features[0].len();
320
321    // Aggregate each dimension
322    let aggregated: Vec<f64> = (0..dim)
323        .map(|d| {
324            let values: Vec<f64> = node_features
325                .iter()
326                .map(|f| f.get(d).copied().unwrap_or(0.0))
327                .collect();
328            aggregate_values(&values, agg_type)
329        })
330        .collect();
331
332    let names: Vec<String> = (0..dim).map(|d| format!("feature_{}", d)).collect();
333
334    MultiFeatureAggregation::new(aggregated, names)
335}
336
/// Unit tests for the aggregation helpers in this module.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{GraphEdge, GraphNode, GraphType, NodeType};
    use crate::EdgeType;

    /// Builds a three-account transaction graph (A, B, C) with feature
    /// vectors `[1,2,3]`, `[4,5,6]`, `[7,8,9]` and weighted edges
    /// A->B (100), B->C (200), A->C (150).
    fn create_test_graph() -> Graph {
        let mut graph = Graph::new("test", GraphType::Transaction);

        // NOTE(review): every node and edge is constructed with id 0, yet the
        // tests below address the nodes as 1, 2, 3 — presumably `add_node`
        // assigns sequential ids starting at 1; confirm against `Graph`.
        let n1 = graph.add_node(
            GraphNode::new(0, NodeType::Account, "A".to_string(), "A".to_string())
                .with_features(vec![1.0, 2.0, 3.0]),
        );
        let n2 = graph.add_node(
            GraphNode::new(0, NodeType::Account, "B".to_string(), "B".to_string())
                .with_features(vec![4.0, 5.0, 6.0]),
        );
        let n3 = graph.add_node(
            GraphNode::new(0, NodeType::Account, "C".to_string(), "C".to_string())
                .with_features(vec![7.0, 8.0, 9.0]),
        );

        graph.add_edge(GraphEdge::new(0, n1, n2, EdgeType::Transaction).with_weight(100.0));
        graph.add_edge(GraphEdge::new(0, n2, n3, EdgeType::Transaction).with_weight(200.0));
        graph.add_edge(GraphEdge::new(0, n1, n3, EdgeType::Transaction).with_weight(150.0));

        graph
    }

    /// `Sum` adds all values.
    #[test]
    fn test_aggregate_values_sum() {
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        assert_eq!(aggregate_values(&values, AggregationType::Sum), 15.0);
    }

    /// `Mean` divides the sum by the number of values.
    #[test]
    fn test_aggregate_values_mean() {
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        assert_eq!(aggregate_values(&values, AggregationType::Mean), 3.0);
    }

    /// `Max` finds the largest value regardless of input order.
    #[test]
    fn test_aggregate_values_max() {
        let values = vec![1.0, 5.0, 3.0, 2.0, 4.0];
        assert_eq!(aggregate_values(&values, AggregationType::Max), 5.0);
    }

    /// `Min` finds the smallest value regardless of input order.
    #[test]
    fn test_aggregate_values_min() {
        let values = vec![1.0, 5.0, 3.0, 2.0, 4.0];
        assert_eq!(aggregate_values(&values, AggregationType::Min), 1.0);
    }

    /// With an odd count the median is the middle element.
    #[test]
    fn test_aggregate_values_median_odd() {
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        assert_eq!(aggregate_values(&values, AggregationType::Median), 3.0);
    }

    /// With an even count the median is the average of the two middle elements.
    #[test]
    fn test_aggregate_values_median_even() {
        let values = vec![1.0, 2.0, 3.0, 4.0];
        assert_eq!(aggregate_values(&values, AggregationType::Median), 2.5);
    }

    /// A weighted mean scales each value by its weight.
    #[test]
    fn test_aggregate_weighted() {
        let values = vec![10.0, 20.0, 30.0];
        let weights = vec![1.0, 2.0, 1.0];

        let result = aggregate_weighted(&values, &weights, AggregationType::WeightedMean);
        // (10*1 + 20*2 + 30*1) / 4 = 80/4 = 20
        assert_eq!(result, 20.0);
    }

    /// Smoke test: a group covering the whole test graph reports some volume
    /// and the right member count.
    #[test]
    fn test_aggregate_features() {
        let graph = create_test_graph();
        let group = EntityGroup::new(
            1,
            vec![1, 2, 3],
            super::super::entity_groups::GroupType::TransactionCluster,
        );

        let features = aggregate_features(&group, &graph, AggregationType::Sum);

        // Only loose checks: the exact volume depends on the node ids
        // resolving to the nodes created in `create_test_graph`.
        assert!(features.total_volume > 0.0);
        assert_eq!(features.member_count, 3);
    }

    /// Per-dimension mean across the three test nodes.
    #[test]
    fn test_aggregate_node_features() {
        let graph = create_test_graph();
        let result = aggregate_node_features(&[1, 2, 3], &graph, AggregationType::Mean);

        assert_eq!(result.features.len(), 3);
        // Mean of [1,4,7], [2,5,8], [3,6,9] = [4, 5, 6]
        assert_eq!(result.features[0], 4.0);
        assert_eq!(result.features[1], 5.0);
        assert_eq!(result.features[2], 6.0);
    }

    /// The feature vector length must agree with `feature_count()`.
    #[test]
    fn test_aggregated_features_to_vector() {
        let features = AggregatedFeatures {
            total_volume: 1000.0,
            avg_transaction_size: 100.0,
            combined_risk_score: 0.5,
            internal_flow_ratio: 0.6,
            external_flow_ratio: 0.4,
            external_counterparty_count: 5,
            activity_variance: 0.3,
            member_count: 3,
        };

        let vec = features.to_features();
        assert_eq!(vec.len(), AggregatedFeatures::feature_count());
    }
}