lace_data/
feature.rs

1use crate::{Category, Datum};
2use crate::{Container, SparseContainer};
3use serde::{Deserialize, Serialize};
4
5#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
6#[serde(rename_all = "snake_case")]
7pub enum SummaryStatistics {
8    #[serde(rename = "binary")]
9    Binary {
10        n: usize,
11        pos: usize,
12    },
13    #[serde(rename = "continuous")]
14    Continuous {
15        min: f64,
16        max: f64,
17        mean: f64,
18        median: f64,
19        variance: f64,
20    },
21    #[serde(rename = "categorical")]
22    Categorical {
23        min: u8,
24        max: u8,
25        mode: Vec<u8>,
26    },
27    #[serde(rename = "count")]
28    Count {
29        min: u32,
30        max: u32,
31        median: f64,
32        mean: f64,
33        mode: Vec<u32>,
34    },
35    None,
36}
37
38// NOTE: If you change the order of the variants, serialization into binary
39// formats will not work the same
40/// Used when pulling data from features for saving
41#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
42pub enum FeatureData {
43    /// Univariate continuous data
44    Continuous(SparseContainer<f64>),
45    /// Categorical data
46    Categorical(SparseContainer<u8>),
47    /// Count data
48    Count(SparseContainer<u32>),
49    /// Binary data
50    Binary(SparseContainer<bool>),
51}
52
53impl FeatureData {
54    pub fn len(&self) -> usize {
55        match self {
56            Self::Binary(xs) => xs.len(),
57            Self::Continuous(xs) => xs.len(),
58            Self::Categorical(xs) => xs.len(),
59            Self::Count(xs) => xs.len(),
60        }
61    }
62
63    pub fn is_empty(&self) -> bool {
64        self.len() == 0
65    }
66
67    /// Get the datum at [row_ix, col_ix] as a `Datum`
68    pub fn is_present(&self, ix: usize) -> bool {
69        match self {
70            Self::Binary(xs) => xs.is_present(ix),
71            Self::Continuous(xs) => xs.is_present(ix),
72            Self::Categorical(xs) => xs.is_present(ix),
73            Self::Count(xs) => xs.is_present(ix),
74        }
75    }
76
77    /// Get the datum at [row_ix, col_ix] as a `Datum`
78    pub fn get(&self, ix: usize) -> Datum {
79        // TODO: SparseContainer index get (xs[i]) should return an option
80        match self {
81            FeatureData::Binary(xs) => {
82                xs.get(ix).map(Datum::Binary).unwrap_or(Datum::Missing)
83            }
84            FeatureData::Continuous(xs) => {
85                xs.get(ix).map(Datum::Continuous).unwrap_or(Datum::Missing)
86            }
87            FeatureData::Categorical(xs) => xs
88                .get(ix)
89                .map(|x| Datum::Categorical(Category::U8(x)))
90                .unwrap_or(Datum::Missing),
91            FeatureData::Count(xs) => {
92                xs.get(ix).map(Datum::Count).unwrap_or(Datum::Missing)
93            }
94        }
95    }
96
97    /// Get the summary statistic for a column
98    pub fn summarize(&self) -> SummaryStatistics {
99        match self {
100            FeatureData::Binary(ref container) => SummaryStatistics::Binary {
101                n: container.n_present(),
102                pos: container
103                    .get_slices()
104                    .iter()
105                    .map(|(_, xs)| xs.len())
106                    .sum::<usize>(),
107            },
108            FeatureData::Continuous(ref container) => {
109                summarize_continuous(container)
110            }
111            FeatureData::Categorical(ref container) => {
112                summarize_categorical(container)
113            }
114            FeatureData::Count(ref container) => summarize_count(container),
115        }
116    }
117}
118
119pub fn summarize_continuous(
120    container: &SparseContainer<f64>,
121) -> SummaryStatistics {
122    use lace_utils::{mean, var};
123    let mut xs: Vec<f64> = container.present_cloned();
124
125    xs.sort_by(|a, b| a.partial_cmp(b).unwrap());
126
127    let n = xs.len();
128    SummaryStatistics::Continuous {
129        min: xs[0],
130        max: xs[n - 1],
131        mean: mean(&xs),
132        variance: var(&xs),
133        median: if n % 2 == 0 {
134            (xs[n / 2] + xs[n / 2 - 1]) / 2.0
135        } else {
136            xs[n / 2]
137        },
138    }
139}
140
141pub fn summarize_categorical(
142    container: &SparseContainer<u8>,
143) -> SummaryStatistics {
144    use lace_utils::{bincount, minmax};
145    let xs: Vec<u8> = container.present_cloned();
146
147    let (min, max) = minmax(&xs);
148    let counts = bincount(&xs, (max + 1) as usize);
149    let max_ct = counts
150        .iter()
151        .fold(0_usize, |acc, &ct| if ct > acc { ct } else { acc });
152    let mode = counts
153        .iter()
154        .enumerate()
155        .filter(|(_, &ct)| ct == max_ct)
156        .map(|(ix, _)| ix as u8)
157        .collect();
158
159    SummaryStatistics::Categorical { min, max, mode }
160}
161
162pub fn summarize_count(container: &SparseContainer<u32>) -> SummaryStatistics {
163    use lace_utils::{bincount, minmax};
164    let xs: Vec<usize> = {
165        let mut xs: Vec<usize> =
166            container.present_iter().map(|&x| x as usize).collect();
167        xs.sort_unstable();
168        xs
169    };
170
171    let n = xs.len();
172    let nf = n as f64;
173
174    let (min, max) = {
175        let (min, max) = minmax(&xs);
176        (min as u32, max as u32)
177    };
178
179    let counts = bincount(&xs, (max + 1) as usize);
180
181    let max_ct = counts
182        .iter()
183        .fold(0_usize, |acc, &ct| if ct > acc { ct } else { acc });
184
185    let mode = counts
186        .iter()
187        .enumerate()
188        .filter(|(_, &ct)| ct == max_ct)
189        .map(|(ix, _)| ix as u32)
190        .collect();
191
192    let mean = xs.iter().sum::<usize>() as f64 / nf;
193
194    let median = if n % 2 == 0 {
195        (xs[n / 2] + xs[n / 2 - 1]) as f64 / 2.0
196    } else {
197        xs[n / 2] as f64
198    };
199
200    SummaryStatistics::Count {
201        min,
202        max,
203        median,
204        mean,
205        mode,
206    }
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use approx::*;

    /// Continuous fixture: five entries, with the entry at index 1 missing.
    fn get_continuous() -> FeatureData {
        let dc1: SparseContainer<f64> = SparseContainer::from(vec![
            (4.0, true),
            (3.0, false),
            (2.0, true),
            (1.0, true),
            (0.0, true),
        ]);

        FeatureData::Continuous(dc1)
    }

    /// Categorical fixture: five entries, with the entry at index 3 missing.
    fn get_categorical() -> FeatureData {
        let dc2: SparseContainer<u8> = SparseContainer::from(vec![
            (5, true),
            (3, true),
            (2, true),
            (1, false),
            (4, true),
        ]);

        FeatureData::Categorical(dc2)
    }

    #[test]
    fn gets_present_continuous_data() {
        let ds = get_continuous();
        assert_eq!(ds.get(0), Datum::Continuous(4.0));
        assert_eq!(ds.get(2), Datum::Continuous(2.0));
    }

    #[test]
    fn gets_present_categorical_data() {
        let ds = get_categorical();
        assert_eq!(ds.get(0), Datum::Categorical(Category::U8(5)));
        assert_eq!(ds.get(4), Datum::Categorical(Category::U8(4)));
    }

    #[test]
    fn gets_missing_continuous_data() {
        let ds = get_continuous();
        assert_eq!(ds.get(1), Datum::Missing);
    }

    #[test]
    fn gets_missing_categorical_data() {
        let ds = get_categorical();
        assert_eq!(ds.get(3), Datum::Missing);
    }

    #[test]
    fn summarize_categorical_works_with_fixture() {
        let summary = get_categorical().summarize();
        match summary {
            SummaryStatistics::Categorical { min, max, mode } => {
                assert_eq!(min, 2);
                assert_eq!(max, 5);
                // Each present value occurs once, so all of them are modes
                assert_eq!(mode, vec![2, 3, 4, 5]);
            }
            _ => panic!("Unexpected summary type"),
        }
    }

    #[test]
    fn summarize_categorical_works_one_mode() {
        let container: SparseContainer<u8> = SparseContainer::from(vec![
            (5, true),
            (3, true),
            (2, true),
            (2, true),
            (1, true),
            (4, true),
        ]);

        let summary = summarize_categorical(&container);
        match summary {
            SummaryStatistics::Categorical { min, max, mode } => {
                assert_eq!(min, 1);
                assert_eq!(max, 5);
                assert_eq!(mode, vec![2]);
            }
            _ => panic!("Unexpected summary type"),
        }
    }

    #[test]
    fn summarize_categorical_works_two_modes() {
        let container: SparseContainer<u8> = SparseContainer::from(vec![
            (5, true),
            (3, true),
            (2, true),
            (2, true),
            (3, true),
            (4, true),
        ]);

        let summary = summarize_categorical(&container);
        match summary {
            SummaryStatistics::Categorical { min, max, mode } => {
                assert_eq!(min, 2);
                assert_eq!(max, 5);
                assert_eq!(mode, vec![2, 3]);
            }
            _ => panic!("Unexpected summary type"),
        }
    }

    #[test]
    fn summarize_continuous_works_with_fixture() {
        let summary = get_continuous().summarize();
        match summary {
            SummaryStatistics::Continuous {
                min,
                max,
                mean,
                median,
                variance,
            } => {
                // Present values are 4, 2, 1, 0
                assert_relative_eq!(min, 0.0, epsilon = 1E-10);
                assert_relative_eq!(max, 4.0, epsilon = 1E-10);
                assert_relative_eq!(mean, 1.75, epsilon = 1E-10);
                assert_relative_eq!(median, 1.5, epsilon = 1E-10);
                assert_relative_eq!(variance, 2.1875, epsilon = 1E-10);
            }
            _ => panic!("Unexpected summary type"),
        }
    }

    #[test]
    fn summarize_continuous_works_with_odd_number_data() {
        let container: SparseContainer<f64> = SparseContainer::from(vec![
            (4.0, true),
            (3.0, true),
            (2.0, true),
            (1.0, true),
            (0.0, true),
        ]);

        let summary = summarize_continuous(&container);
        match summary {
            SummaryStatistics::Continuous { median, .. } => {
                assert_relative_eq!(median, 2.0, epsilon = 1E-10);
            }
            _ => panic!("Unexpected summary type"),
        }
    }

    #[test]
    fn summarize_count_works_with_missing_data() {
        let container: SparseContainer<u32> = SparseContainer::from(vec![
            (3, true),
            (2, true),
            (7, false),
            (1, true),
            (2, true),
        ]);

        let summary = FeatureData::Count(container).summarize();
        match summary {
            SummaryStatistics::Count {
                min,
                max,
                median,
                mean,
                mode,
            } => {
                // Present values are 1, 2, 2, 3; the missing 7 is ignored
                assert_eq!(min, 1);
                assert_eq!(max, 3);
                assert_relative_eq!(median, 2.0, epsilon = 1E-10);
                assert_relative_eq!(mean, 2.0, epsilon = 1E-10);
                assert_eq!(mode, vec![2]);
            }
            _ => panic!("Unexpected summary type"),
        }
    }
}
360}