modav_core/models/
stacked_bar.rs

1use std::{
2    collections::{HashMap, HashSet},
3    fmt::{self, Debug},
4};
5
6use super::{Point, Scale};
7use crate::repr::Data;
8
9#[derive(Clone, Debug, PartialEq)]
10pub struct StackedBar {
11    /// The (x, y) points for the bar
12    pub point: Point,
13    /// The percentage makeup  of the bar. For all
14    /// k, v in `fractions` v1 + v2 + v3 + .. = 1.0
15    pub fractions: HashMap<String, f64>,
16    /// Is true of all points within the bar are negative
17    pub is_negative: bool,
18    /// The full value of the stacked bar
19    true_y: Data,
20    /// Keeps track of sections removed from the bar
21    removed_sections: HashSet<String>,
22}
23
24impl StackedBar {
25    pub(crate) fn new(point: Point, fractions: HashMap<String, f64>, is_negative: bool) -> Self {
26        let true_y = point.y.clone();
27        Self {
28            point,
29            fractions,
30            is_negative,
31            true_y,
32            removed_sections: HashSet::new(),
33        }
34    }
35
36    pub fn from_point(point: impl Into<Point>, is_negative: bool) -> Self {
37        let point = point.into();
38        let true_y = point.y.clone();
39        Self {
40            point,
41            fractions: HashMap::default(),
42            is_negative,
43            true_y,
44            removed_sections: HashSet::new(),
45        }
46    }
47
48    pub fn restore(&mut self) {
49        self.point.y = self.true_y.clone();
50    }
51
52    pub fn get_fractions(&self) -> &HashMap<String, f64> {
53        &self.fractions
54    }
55
56    pub fn get_point(&self) -> &Point {
57        &self.point
58    }
59
60    /// Returns true if the point is empty. For a Stacked bar chart, an empty point
61    /// is defined as one which has a y data value of 0 or 0.0
62    pub(crate) fn is_empty(&self) -> bool {
63        match &self.point.y {
64            Data::Integer(i) => *i == 0,
65            Data::Number(n) => *n == 0,
66            Data::Float(f) => *f == 0.0,
67            _ => false,
68        }
69    }
70
71    /// Effectively removes the contribution of specified section from the
72    /// stacked bar if it exists
73    pub fn remove_section(&mut self, section: impl Into<String>) {
74        let section = section.into();
75
76        if self.removed_sections.contains(&section) {
77            return;
78        }
79
80        let fraction = self.fractions.get(&section);
81
82        let Some(fraction) = fraction else { return };
83
84        let contribution = match self.true_y {
85            Data::Number(n) => (n as f64) * fraction,
86            Data::Integer(i) => (i as f64) * fraction,
87            Data::Float(f) => (f as f64) * fraction,
88            _ => 0.0,
89        };
90
91        match self.point.y {
92            Data::Number(n) => self.point.y = Data::Number(((n as f64) - contribution) as isize),
93            Data::Integer(i) => self.point.y = Data::Integer(((i as f64) - contribution) as i32),
94            Data::Float(f) => self.point.y = Data::Float(((f as f64) - contribution) as f32),
95            _ => {}
96        };
97
98        self.removed_sections.insert(section);
99    }
100
101    /// Effectively re-adds the contribution of specified section to the
102    /// stacked bar if it exists
103    pub fn add_section(&mut self, section: impl Into<String>) {
104        let section = section.into();
105
106        if !self.removed_sections.contains(&section) {
107            return;
108        }
109
110        let fraction = self.fractions.get(&section);
111
112        let Some(fraction) = fraction else { return };
113
114        let contribution = match self.true_y {
115            Data::Number(n) => (n as f64) * fraction,
116            Data::Integer(i) => (i as f64) * fraction,
117            Data::Float(f) => (f as f64) * fraction,
118            _ => 0.0,
119        };
120
121        match self.point.y {
122            Data::Number(n) => self.point.y = Data::Number(((n as f64) + contribution) as isize),
123            Data::Integer(i) => self.point.y = Data::Integer(((i as f64) + contribution) as i32),
124            Data::Float(f) => self.point.y = Data::Float(((f as f64) + contribution) as f32),
125            _ => {}
126        }
127
128        self.removed_sections.remove(&section);
129    }
130}
131
132#[derive(Clone, Debug, PartialEq)]
133pub struct StackedBarChart {
134    pub bars: Vec<StackedBar>,
135    pub x_axis: Option<String>,
136    pub y_axis: Option<String>,
137    pub labels: HashSet<String>,
138    pub x_scale: Scale,
139    pub y_scale: Scale,
140}
141
142#[allow(dead_code)]
143impl StackedBarChart {
144    pub(crate) fn new(
145        bars: Vec<StackedBar>,
146        x_scale: Scale,
147        y_scale: Scale,
148        labels: HashSet<String>,
149    ) -> Result<Self, StackedBarChartError> {
150        Self::assert_x_scale(&x_scale, &bars)?;
151        Self::assert_y_scale(&y_scale, &bars)?;
152
153        Ok(Self {
154            x_scale,
155            y_scale,
156            bars,
157            x_axis: None,
158            y_axis: None,
159            labels,
160        })
161    }
162
163    fn assert_x_scale(scale: &Scale, bars: &[StackedBar]) -> Result<(), StackedBarChartError> {
164        for x in bars.iter().map(|bar| &bar.point.x) {
165            if !scale.contains(x) {
166                return Err(StackedBarChartError::OutOfRange(
167                    "X".to_string(),
168                    x.to_string(),
169                ));
170            }
171        }
172
173        Ok(())
174    }
175
176    fn assert_y_scale(scale: &Scale, bars: &[StackedBar]) -> Result<(), StackedBarChartError> {
177        for y in bars.iter().map(|bar| &bar.point.y) {
178            if !scale.contains(y) {
179                return Err(StackedBarChartError::OutOfRange(
180                    "Y".to_string(),
181                    y.to_string(),
182                ));
183            }
184        }
185
186        Ok(())
187    }
188
189    pub fn x_axis(mut self, label: impl Into<String>) -> Self {
190        self.x_axis = Some(label.into());
191        self
192    }
193
194    pub fn y_axis(mut self, label: impl Into<String>) -> Self {
195        self.y_axis = Some(label.into());
196        self
197    }
198
199    /// Returns true any negative bar is not completely empty. For a Stacked bar chart, an empty point
200    /// is defined as one which has a y data value of 0 or 0.0
201    pub fn has_true_negatives(&self) -> bool {
202        self.bars
203            .iter()
204            .any(|bar| bar.is_negative && !bar.is_empty())
205    }
206
207    /// Returns true any positive bar is not completely empty. For a Stacked bar chart, an empty point
208    /// is defined as one which has a y data value of 0 or 0.0
209    pub fn has_true_positives(&self) -> bool {
210        self.bars
211            .iter()
212            .any(|bar| !bar.is_negative && !bar.is_empty())
213    }
214
215    pub fn remove_section(&mut self, bar: usize, section: impl Into<String>) {
216        if let Some(bar) = self.bars.get_mut(bar) {
217            bar.remove_section(section);
218        };
219    }
220
221    pub fn remove_section_all(&mut self, section: impl Into<String>) {
222        let section: String = section.into();
223        self.bars.iter_mut().for_each(|bar| {
224            bar.remove_section(section.clone());
225        });
226    }
227
228    pub fn add_section(&mut self, bar: usize, section: impl Into<String>) {
229        if let Some(bar) = self.bars.get_mut(bar) {
230            bar.add_section(section);
231        };
232    }
233
234    pub fn add_section_all(&mut self, section: impl Into<String>) {
235        let section: String = section.into();
236        self.bars.iter_mut().for_each(|bar| {
237            bar.add_section(section.clone());
238        });
239    }
240}
241
/// Errors raised while building a stacked bar chart.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StackedBarChartError {
    /// A value does not fit its scale. Holds the axis name followed by the
    /// textual form of the offending value.
    OutOfRange(String, String),
}

impl fmt::Display for StackedBarChartError {
    /// Writes a message naming the offending value and the axis it belongs to.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Single-variant enum, so the pattern is irrefutable.
        let StackedBarChartError::OutOfRange(axis, value) = self;

        write!(
            f,
            "The point with value {} on the {} axis is out of range",
            value, axis
        )
    }
}

impl std::error::Error for StackedBarChartError {}
262
#[cfg(test)]
mod stacked_barchart_tests {
    use crate::models::ScaleKind;

    use super::*;

    /// Section names shared by every bar of the fixture chart.
    const SECTIONS: [&str; 4] = ["Soda", "Cream", "Coffee", "Choco"];

    /// Builds a valid five-bar chart with axis labels "Number" and "Total".
    fn create_barchart() -> StackedBarChart {
        // (x label, bar total, amount contributed by each entry of SECTIONS)
        let rows: [(&str, i32, [f64; 4]); 5] = [
            ("One", 19, [3.0, 3.0, 5.0, 8.0]),
            ("Two", 19, [3.0, 6.0, 10.0, 0.0]),
            ("Three", 14, [6.0, 0.0, 8.0, 0.0]),
            ("Four", 16, [3.0, 0.0, 7.0, 6.0]),
            ("Five", 19, [9.0, 0.0, 10.0, 0.0]),
        ];

        let bars: Vec<StackedBar> = rows
            .into_iter()
            .map(|(name, total, amounts)| {
                let fractions: HashMap<String, f64> = SECTIONS
                    .into_iter()
                    .zip(amounts)
                    .map(|(section, amount)| (String::from(section), amount / f64::from(total)))
                    .collect();

                let pnt = Point::new(Data::Text(name.into()), Data::Integer(total));

                StackedBar::new(pnt, fractions, false)
            })
            .collect();

        let x_scale = {
            let values = vec!["One", "Two", "Three", "Four", "Five"];

            Scale::new(values, ScaleKind::Categorical)
        };

        let y_scale = vec![14, 16, 19].into();

        let labels: HashSet<String> = SECTIONS.into_iter().map(String::from).collect();

        match StackedBarChart::new(bars, x_scale, y_scale, labels) {
            Ok(bar) => bar.x_axis("Number").y_axis("Total"),
            Err(e) => panic!("{}", e),
        }
    }

    /// Attempts to build a chart whose x values 11 and 15 are expected to fall
    /// outside the `-5..11` x scale.
    fn out_of_range() -> Result<StackedBarChart, StackedBarChartError> {
        let xs = [1, 5, 6, 11, 15];
        let ys = [4, 5, 6, 7, 8];

        let bars = xs
            .into_iter()
            .zip(ys)
            .map(|(x, y)| StackedBar::from_point((Data::Integer(x), Data::Integer(y)), false))
            .collect();

        let x_scale = {
            let rng = -5..11;

            Scale::new(rng, ScaleKind::Integer)
        };
        let y_scale = {
            let rng = 2..10;

            Scale::new(rng, ScaleKind::Integer)
        };

        StackedBarChart::new(bars, x_scale, y_scale, HashSet::default())
    }

    #[test]
    fn test_barchart() {
        let barchart = create_barchart();

        assert_eq!(barchart.x_axis.as_deref(), Some("Number"));
        assert_eq!(barchart.y_axis.as_deref(), Some("Total"));

        assert_eq!(
            barchart.bars[0].fractions.get("Soda"),
            Some(&(3.0 / 19.0))
        );

        assert_eq!(
            barchart.labels,
            HashSet::from([
                String::from("Soda"),
                String::from("Cream"),
                String::from("Coffee"),
                String::from("Choco"),
            ])
        );

        assert_eq!(barchart.bars.len(), 5)
    }

    #[test]
    fn test_barchart_signs() {
        // Every fixture bar is non-negative and has a non-zero total.
        let barchart = create_barchart();

        assert!(barchart.has_true_positives());
        assert!(!barchart.has_true_negatives());
    }

    #[test]
    fn test_faulty_barchart() {
        let expected = StackedBarChartError::OutOfRange(String::from("X"), String::from("11"));
        match out_of_range() {
            Ok(_) => panic!("Should not reach this test case"),
            Err(e) => assert_eq!(e, expected),
        }
    }
}