// diff_priv/data_manipulation/aggregation.rs

1use crate::data_manipulation::anonymizable::{
2    IntervalType, NominalType, OrdinalType, QuasiIdentifierType, QuasiIdentifierTypes,
3};
4use itertools::Itertools;
5
6/// TODO: check adding MEDIAN
7pub enum AggregateType {
8    Mean(Vec<QuasiIdentifierTypes>),
9    Mode(Vec<QuasiIdentifierTypes>),
10}
11
12impl AggregateType {
13    /// the vector should contain only the same type of QI type
14    pub fn aggregate(self) -> QuasiIdentifierTypes {
15        match self {
16            AggregateType::Mean(mut list) => match list.pop().unwrap() {
17                QuasiIdentifierTypes::Interval(interval) => {
18                    Self::aggregate_interval(interval, list)
19                }
20                _ => panic!("Wrong QI type found during aggregation for Mean"),
21            },
22            AggregateType::Mode(mut list) => {
23                // we need to pop the first element one to
24                // get the shared attributes between different elements in the QI list
25                match list.pop().unwrap() {
26                    QuasiIdentifierTypes::Ordinal(ordinal) => {
27                        Self::aggregate_ordinal(ordinal, list)
28                    }
29                    QuasiIdentifierTypes::Nominal(nominal) => {
30                        Self::aggregate_nominal(nominal, list)
31                    }
32                    _ => panic!("Wrong QI type for calculating mode"),
33                }
34            }
35        }
36    }
37
38    /// aggregate interval QI type
39    fn aggregate_interval(
40        interval: IntervalType,
41        list: Vec<QuasiIdentifierTypes>,
42    ) -> QuasiIdentifierTypes {
43        // we need to increment is by one as the first interval has already been popped
44        let size = list.len() + 1;
45        match interval {
46            (
47                QuasiIdentifierType::Float(value),
48                QuasiIdentifierType::Float(min),
49                QuasiIdentifierType::Float(max),
50                weight,
51            ) => {
52                let sum: f64 = list
53                    .into_iter()
54                    .map(|x| match x.extract_value() {
55                        QuasiIdentifierType::Float(temp) => temp,
56                        _ => panic!("Wrong type found for Mean aggregation"),
57                    })
58                    .sum();
59
60                QuasiIdentifierTypes::Interval((
61                    QuasiIdentifierType::Float((value + sum) / size as f64),
62                    QuasiIdentifierType::Float(min),
63                    QuasiIdentifierType::Float(max),
64                    weight,
65                ))
66            }
67            (
68                QuasiIdentifierType::Integer(value),
69                QuasiIdentifierType::Integer(min),
70                QuasiIdentifierType::Integer(max),
71                weight,
72            ) => {
73                let sum: i32 = list
74                    .into_iter()
75                    .map(|x| match x.extract_value() {
76                        QuasiIdentifierType::Integer(temp) => temp,
77                        _ => panic!("Wrong type found for Mean aggregation"),
78                    })
79                    .sum();
80
81                QuasiIdentifierTypes::Interval((
82                    QuasiIdentifierType::Integer((value + sum) / size as i32),
83                    QuasiIdentifierType::Integer(min),
84                    QuasiIdentifierType::Integer(max),
85                    weight,
86                ))
87            }
88            _ => panic!("Wrong interval type set found during aggregation"),
89        }
90    }
91
92    /// aggregate ordinal QI type
93    fn aggregate_ordinal(
94        ordinal: OrdinalType,
95        list: Vec<QuasiIdentifierTypes>,
96    ) -> QuasiIdentifierTypes {
97        let (rank, max_rank, weight) = ordinal;
98        let mut mode_list = Vec::new();
99        list.into_iter().for_each(|x| match x {
100            QuasiIdentifierTypes::Ordinal((temp, _, _)) => mode_list.push(temp),
101            _ => panic!("Wrong QI type"),
102        });
103
104        mode_list.push(rank);
105
106        let mode = Self::get_mode(mode_list);
107
108        QuasiIdentifierTypes::Ordinal((mode, max_rank, weight))
109    }
110
111    /// aggregate nominal QI type
112    fn aggregate_nominal(
113        nominal: NominalType,
114        list: Vec<QuasiIdentifierTypes>,
115    ) -> QuasiIdentifierTypes {
116        let (value, max_value, weight) = nominal;
117        let mut mode_list = Vec::new();
118        list.into_iter().for_each(|x| match x {
119            QuasiIdentifierTypes::Nominal((temp, _, _)) => mode_list.push(temp),
120            _ => panic!("Wrong QI type"),
121        });
122
123        mode_list.push(value);
124
125        let mode = Self::get_mode(mode_list);
126
127        QuasiIdentifierTypes::Nominal((mode, max_value, weight))
128    }
129
130    /// retrieve mode from list of i32
131    fn get_mode(mode_list: Vec<i32>) -> i32 {
132        let mut mode_grouped: Vec<(i32, Vec<i32>)> = Vec::new();
133        for (key, group) in &mode_list.into_iter().sorted().group_by(|&x| x) {
134            mode_grouped.push((key, group.collect()))
135        }
136
137        let (mode, _) = mode_grouped
138            .into_iter()
139            .map(|(key, group)| (key, group.len()))
140            .max_by_key(|(_, group)| *group)
141            .unwrap();
142
143        mode
144    }
145}
146
/// Clamps `value` to the domain bounded by `min` and `max`.
///
/// Returns `min` when `value <= min`, otherwise `max` when `value >= max`,
/// otherwise `value` itself. The lower bound is tested first, so it wins if
/// the bounds are passed in inverted order and `value` lies below both.
/// A value that compares as neither (e.g. a floating point NaN) is returned
/// unchanged.
pub fn truncate_to_domain<T: PartialOrd>(value: T, min: T, max: T) -> T {
    if value <= min {
        min
    } else if value >= max {
        max
    } else {
        value
    }
}
155
/// Unit tests for [`truncate_to_domain`].
#[cfg(test)]
mod tests {
    use super::*;

    /// A value inside the domain is returned unchanged.
    #[test]
    fn truncating_to_domain_integer() {
        let min = 0;
        let value = 5;
        let max = 10;

        assert_eq!(truncate_to_domain(value, min, max), 5)
    }

    /// A value above the domain is clamped to the upper bound.
    #[test]
    fn truncate_to_domain_integer_max() {
        let min = 0;
        let value = 11;
        let max = 10;

        assert_eq!(truncate_to_domain(value, min, max), 10)
    }

    /// A value below the domain is clamped to the lower bound.
    #[test]
    fn truncate_to_domain_integer_min() {
        let min = 0;
        let value = -1;
        let max = 10;

        assert_eq!(truncate_to_domain(value, min, max), 0)
    }

    /// A float inside the domain is returned unchanged.
    #[test]
    fn truncate_to_domain_float() {
        let min = 0.0;
        let value = 5.0;
        let max = 10.0;

        let result: f64 = truncate_to_domain(value, min, max);
        // compare the absolute difference, otherwise any result below 5.0 passes
        assert!((result - 5.0).abs() <= f64::EPSILON)
    }

    /// A float above the domain is clamped to the upper bound.
    #[test]
    fn truncate_to_domain_float_max() {
        let min = 0.0;
        let value = 10.0;
        let max = 5.0;

        let result: f64 = truncate_to_domain(value, min, max);
        // compare the absolute difference, otherwise any result below 5.0 passes
        assert!((result - 5.0).abs() <= f64::EPSILON)
    }

    /// A float below the domain is clamped to the lower bound.
    #[test]
    fn truncate_to_domain_float_min() {
        let min = 5.0;
        let value = 1.0;
        let max = 10.0;

        let result: f64 = truncate_to_domain(value, min, max);
        assert!((result - 5.0).abs() <= f64::EPSILON)
    }
}