1use crate::{Category, Datum};
2use crate::{Container, SparseContainer};
3use serde::{Deserialize, Serialize};
4
5#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
6#[serde(rename_all = "snake_case")]
7pub enum SummaryStatistics {
8 #[serde(rename = "binary")]
9 Binary {
10 n: usize,
11 pos: usize,
12 },
13 #[serde(rename = "continuous")]
14 Continuous {
15 min: f64,
16 max: f64,
17 mean: f64,
18 median: f64,
19 variance: f64,
20 },
21 #[serde(rename = "categorical")]
22 Categorical {
23 min: u8,
24 max: u8,
25 mode: Vec<u8>,
26 },
27 #[serde(rename = "count")]
28 Count {
29 min: u32,
30 max: u32,
31 median: f64,
32 mean: f64,
33 mode: Vec<u32>,
34 },
35 None,
36}
37
38#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
42pub enum FeatureData {
43 Continuous(SparseContainer<f64>),
45 Categorical(SparseContainer<u8>),
47 Count(SparseContainer<u32>),
49 Binary(SparseContainer<bool>),
51}
52
53impl FeatureData {
54 pub fn len(&self) -> usize {
55 match self {
56 Self::Binary(xs) => xs.len(),
57 Self::Continuous(xs) => xs.len(),
58 Self::Categorical(xs) => xs.len(),
59 Self::Count(xs) => xs.len(),
60 }
61 }
62
63 pub fn is_empty(&self) -> bool {
64 self.len() == 0
65 }
66
67 pub fn is_present(&self, ix: usize) -> bool {
69 match self {
70 Self::Binary(xs) => xs.is_present(ix),
71 Self::Continuous(xs) => xs.is_present(ix),
72 Self::Categorical(xs) => xs.is_present(ix),
73 Self::Count(xs) => xs.is_present(ix),
74 }
75 }
76
77 pub fn get(&self, ix: usize) -> Datum {
79 match self {
81 FeatureData::Binary(xs) => {
82 xs.get(ix).map(Datum::Binary).unwrap_or(Datum::Missing)
83 }
84 FeatureData::Continuous(xs) => {
85 xs.get(ix).map(Datum::Continuous).unwrap_or(Datum::Missing)
86 }
87 FeatureData::Categorical(xs) => xs
88 .get(ix)
89 .map(|x| Datum::Categorical(Category::U8(x)))
90 .unwrap_or(Datum::Missing),
91 FeatureData::Count(xs) => {
92 xs.get(ix).map(Datum::Count).unwrap_or(Datum::Missing)
93 }
94 }
95 }
96
97 pub fn summarize(&self) -> SummaryStatistics {
99 match self {
100 FeatureData::Binary(ref container) => SummaryStatistics::Binary {
101 n: container.n_present(),
102 pos: container
103 .get_slices()
104 .iter()
105 .map(|(_, xs)| xs.len())
106 .sum::<usize>(),
107 },
108 FeatureData::Continuous(ref container) => {
109 summarize_continuous(container)
110 }
111 FeatureData::Categorical(ref container) => {
112 summarize_categorical(container)
113 }
114 FeatureData::Count(ref container) => summarize_count(container),
115 }
116 }
117}
118
119pub fn summarize_continuous(
120 container: &SparseContainer<f64>,
121) -> SummaryStatistics {
122 use lace_utils::{mean, var};
123 let mut xs: Vec<f64> = container.present_cloned();
124
125 xs.sort_by(|a, b| a.partial_cmp(b).unwrap());
126
127 let n = xs.len();
128 SummaryStatistics::Continuous {
129 min: xs[0],
130 max: xs[n - 1],
131 mean: mean(&xs),
132 variance: var(&xs),
133 median: if n % 2 == 0 {
134 (xs[n / 2] + xs[n / 2 - 1]) / 2.0
135 } else {
136 xs[n / 2]
137 },
138 }
139}
140
141pub fn summarize_categorical(
142 container: &SparseContainer<u8>,
143) -> SummaryStatistics {
144 use lace_utils::{bincount, minmax};
145 let xs: Vec<u8> = container.present_cloned();
146
147 let (min, max) = minmax(&xs);
148 let counts = bincount(&xs, (max + 1) as usize);
149 let max_ct = counts
150 .iter()
151 .fold(0_usize, |acc, &ct| if ct > acc { ct } else { acc });
152 let mode = counts
153 .iter()
154 .enumerate()
155 .filter(|(_, &ct)| ct == max_ct)
156 .map(|(ix, _)| ix as u8)
157 .collect();
158
159 SummaryStatistics::Categorical { min, max, mode }
160}
161
162pub fn summarize_count(container: &SparseContainer<u32>) -> SummaryStatistics {
163 use lace_utils::{bincount, minmax};
164 let xs: Vec<usize> = {
165 let mut xs: Vec<usize> =
166 container.present_iter().map(|&x| x as usize).collect();
167 xs.sort_unstable();
168 xs
169 };
170
171 let n = xs.len();
172 let nf = n as f64;
173
174 let (min, max) = {
175 let (min, max) = minmax(&xs);
176 (min as u32, max as u32)
177 };
178
179 let counts = bincount(&xs, (max + 1) as usize);
180
181 let max_ct = counts
182 .iter()
183 .fold(0_usize, |acc, &ct| if ct > acc { ct } else { acc });
184
185 let mode = counts
186 .iter()
187 .enumerate()
188 .filter(|(_, &ct)| ct == max_ct)
189 .map(|(ix, _)| ix as u32)
190 .collect();
191
192 let mean = xs.iter().sum::<usize>() as f64 / nf;
193
194 let median = if n % 2 == 0 {
195 (xs[n / 2] + xs[n / 2 - 1]) as f64 / 2.0
196 } else {
197 xs[n / 2] as f64
198 };
199
200 SummaryStatistics::Count {
201 min,
202 max,
203 median,
204 mean,
205 mode,
206 }
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use approx::*;

    /// Continuous fixture with five slots; slot 1 is missing.
    fn get_continuous() -> FeatureData {
        let entries: Vec<(f64, bool)> = vec![
            (4.0, true),
            (3.0, false),
            (2.0, true),
            (1.0, true),
            (0.0, true),
        ];
        let column: SparseContainer<f64> = SparseContainer::from(entries);
        FeatureData::Continuous(column)
    }

    /// Categorical fixture with five slots; slot 3 is missing.
    fn get_categorical() -> FeatureData {
        let entries: Vec<(u8, bool)> = vec![
            (5, true),
            (3, true),
            (2, true),
            (1, false),
            (4, true),
        ];
        let column: SparseContainer<u8> = SparseContainer::from(entries);
        FeatureData::Categorical(column)
    }

    #[test]
    fn gets_present_continuous_data() {
        let data = get_continuous();
        assert_eq!(data.get(0), Datum::Continuous(4.0));
        assert_eq!(data.get(2), Datum::Continuous(2.0));
    }

    #[test]
    fn gets_present_categorical_data() {
        let data = get_categorical();
        assert_eq!(data.get(0), Datum::Categorical(Category::U8(5)));
        assert_eq!(data.get(4), Datum::Categorical(Category::U8(4)));
    }

    #[test]
    fn gets_missing_continuous_data() {
        assert_eq!(get_continuous().get(1), Datum::Missing);
    }

    #[test]
    fn gets_missing_categorical_data() {
        assert_eq!(get_categorical().get(3), Datum::Missing);
    }

    #[test]
    fn summarize_categorical_works_with_fixture() {
        if let SummaryStatistics::Categorical { min, max, mode } =
            get_categorical().summarize()
        {
            assert_eq!(min, 2);
            assert_eq!(max, 5);
            assert_eq!(mode, vec![2, 3, 4, 5]);
        } else {
            panic!("Unexpected summary type");
        }
    }

    #[test]
    fn summarize_categorical_works_one_mode() {
        let entries: Vec<(u8, bool)> = vec![
            (5, true),
            (3, true),
            (2, true),
            (2, true),
            (1, true),
            (4, true),
        ];
        let column: SparseContainer<u8> = SparseContainer::from(entries);

        if let SummaryStatistics::Categorical { min, max, mode } =
            summarize_categorical(&column)
        {
            assert_eq!(min, 1);
            assert_eq!(max, 5);
            assert_eq!(mode, vec![2]);
        } else {
            panic!("Unexpected summary type");
        }
    }

    #[test]
    fn summarize_categorical_works_two_modes() {
        let entries: Vec<(u8, bool)> = vec![
            (5, true),
            (3, true),
            (2, true),
            (2, true),
            (3, true),
            (4, true),
        ];
        let column: SparseContainer<u8> = SparseContainer::from(entries);

        if let SummaryStatistics::Categorical { min, max, mode } =
            summarize_categorical(&column)
        {
            assert_eq!(min, 2);
            assert_eq!(max, 5);
            assert_eq!(mode, vec![2, 3]);
        } else {
            panic!("Unexpected summary type");
        }
    }

    #[test]
    fn summarize_continuous_works_with_fixture() {
        if let SummaryStatistics::Continuous {
            min,
            max,
            mean,
            median,
            variance,
        } = get_continuous().summarize()
        {
            assert_relative_eq!(min, 0.0, epsilon = 1E-10);
            assert_relative_eq!(max, 4.0, epsilon = 1E-10);
            assert_relative_eq!(mean, 1.75, epsilon = 1E-10);
            assert_relative_eq!(median, 1.5, epsilon = 1E-10);
            assert_relative_eq!(variance, 2.1875, epsilon = 1E-10);
        } else {
            panic!("Unexpected summary type");
        }
    }

    #[test]
    fn summarize_continuous_works_with_odd_number_data() {
        let entries: Vec<(f64, bool)> = vec![
            (4.0, true),
            (3.0, true),
            (2.0, true),
            (1.0, true),
            (0.0, true),
        ];
        let column: SparseContainer<f64> = SparseContainer::from(entries);

        if let SummaryStatistics::Continuous { median, .. } =
            summarize_continuous(&column)
        {
            assert_relative_eq!(median, 2.0, epsilon = 1E-10);
        } else {
            panic!("Unexpected summary type");
        }
    }
}
360}