1use std::{
2 collections::{HashMap, HashSet},
3 fmt::{self, Debug},
4};
5
6use super::{Point, Scale};
7use crate::repr::Data;
8
9#[derive(Clone, Debug, PartialEq)]
10pub struct StackedBar {
11 pub point: Point,
13 pub fractions: HashMap<String, f64>,
16 pub is_negative: bool,
18 true_y: Data,
20 removed_sections: HashSet<String>,
22}
23
24impl StackedBar {
25 pub(crate) fn new(point: Point, fractions: HashMap<String, f64>, is_negative: bool) -> Self {
26 let true_y = point.y.clone();
27 Self {
28 point,
29 fractions,
30 is_negative,
31 true_y,
32 removed_sections: HashSet::new(),
33 }
34 }
35
36 pub fn from_point(point: impl Into<Point>, is_negative: bool) -> Self {
37 let point = point.into();
38 let true_y = point.y.clone();
39 Self {
40 point,
41 fractions: HashMap::default(),
42 is_negative,
43 true_y,
44 removed_sections: HashSet::new(),
45 }
46 }
47
48 pub fn restore(&mut self) {
49 self.point.y = self.true_y.clone();
50 }
51
52 pub fn get_fractions(&self) -> &HashMap<String, f64> {
53 &self.fractions
54 }
55
56 pub fn get_point(&self) -> &Point {
57 &self.point
58 }
59
60 pub(crate) fn is_empty(&self) -> bool {
63 match &self.point.y {
64 Data::Integer(i) => *i == 0,
65 Data::Number(n) => *n == 0,
66 Data::Float(f) => *f == 0.0,
67 _ => false,
68 }
69 }
70
71 pub fn remove_section(&mut self, section: impl Into<String>) {
74 let section = section.into();
75
76 if self.removed_sections.contains(§ion) {
77 return;
78 }
79
80 let fraction = self.fractions.get(§ion);
81
82 let Some(fraction) = fraction else { return };
83
84 let contribution = match self.true_y {
85 Data::Number(n) => (n as f64) * fraction,
86 Data::Integer(i) => (i as f64) * fraction,
87 Data::Float(f) => (f as f64) * fraction,
88 _ => 0.0,
89 };
90
91 match self.point.y {
92 Data::Number(n) => self.point.y = Data::Number(((n as f64) - contribution) as isize),
93 Data::Integer(i) => self.point.y = Data::Integer(((i as f64) - contribution) as i32),
94 Data::Float(f) => self.point.y = Data::Float(((f as f64) - contribution) as f32),
95 _ => {}
96 };
97
98 self.removed_sections.insert(section);
99 }
100
101 pub fn add_section(&mut self, section: impl Into<String>) {
104 let section = section.into();
105
106 if !self.removed_sections.contains(§ion) {
107 return;
108 }
109
110 let fraction = self.fractions.get(§ion);
111
112 let Some(fraction) = fraction else { return };
113
114 let contribution = match self.true_y {
115 Data::Number(n) => (n as f64) * fraction,
116 Data::Integer(i) => (i as f64) * fraction,
117 Data::Float(f) => (f as f64) * fraction,
118 _ => 0.0,
119 };
120
121 match self.point.y {
122 Data::Number(n) => self.point.y = Data::Number(((n as f64) + contribution) as isize),
123 Data::Integer(i) => self.point.y = Data::Integer(((i as f64) + contribution) as i32),
124 Data::Float(f) => self.point.y = Data::Float(((f as f64) + contribution) as f32),
125 _ => {}
126 }
127
128 self.removed_sections.remove(§ion);
129 }
130}
131
132#[derive(Clone, Debug, PartialEq)]
133pub struct StackedBarChart {
134 pub bars: Vec<StackedBar>,
135 pub x_axis: Option<String>,
136 pub y_axis: Option<String>,
137 pub labels: HashSet<String>,
138 pub x_scale: Scale,
139 pub y_scale: Scale,
140}
141
142#[allow(dead_code)]
143impl StackedBarChart {
144 pub(crate) fn new(
145 bars: Vec<StackedBar>,
146 x_scale: Scale,
147 y_scale: Scale,
148 labels: HashSet<String>,
149 ) -> Result<Self, StackedBarChartError> {
150 Self::assert_x_scale(&x_scale, &bars)?;
151 Self::assert_y_scale(&y_scale, &bars)?;
152
153 Ok(Self {
154 x_scale,
155 y_scale,
156 bars,
157 x_axis: None,
158 y_axis: None,
159 labels,
160 })
161 }
162
163 fn assert_x_scale(scale: &Scale, bars: &[StackedBar]) -> Result<(), StackedBarChartError> {
164 for x in bars.iter().map(|bar| &bar.point.x) {
165 if !scale.contains(x) {
166 return Err(StackedBarChartError::OutOfRange(
167 "X".to_string(),
168 x.to_string(),
169 ));
170 }
171 }
172
173 Ok(())
174 }
175
176 fn assert_y_scale(scale: &Scale, bars: &[StackedBar]) -> Result<(), StackedBarChartError> {
177 for y in bars.iter().map(|bar| &bar.point.y) {
178 if !scale.contains(y) {
179 return Err(StackedBarChartError::OutOfRange(
180 "Y".to_string(),
181 y.to_string(),
182 ));
183 }
184 }
185
186 Ok(())
187 }
188
189 pub fn x_axis(mut self, label: impl Into<String>) -> Self {
190 self.x_axis = Some(label.into());
191 self
192 }
193
194 pub fn y_axis(mut self, label: impl Into<String>) -> Self {
195 self.y_axis = Some(label.into());
196 self
197 }
198
199 pub fn has_true_negatives(&self) -> bool {
202 self.bars
203 .iter()
204 .any(|bar| bar.is_negative && !bar.is_empty())
205 }
206
207 pub fn has_true_positives(&self) -> bool {
210 self.bars
211 .iter()
212 .any(|bar| !bar.is_negative && !bar.is_empty())
213 }
214
215 pub fn remove_section(&mut self, bar: usize, section: impl Into<String>) {
216 if let Some(bar) = self.bars.get_mut(bar) {
217 bar.remove_section(section);
218 };
219 }
220
221 pub fn remove_section_all(&mut self, section: impl Into<String>) {
222 let section: String = section.into();
223 self.bars.iter_mut().for_each(|bar| {
224 bar.remove_section(section.clone());
225 });
226 }
227
228 pub fn add_section(&mut self, bar: usize, section: impl Into<String>) {
229 if let Some(bar) = self.bars.get_mut(bar) {
230 bar.add_section(section);
231 };
232 }
233
234 pub fn add_section_all(&mut self, section: impl Into<String>) {
235 let section: String = section.into();
236 self.bars.iter_mut().for_each(|bar| {
237 bar.add_section(section.clone());
238 });
239 }
240}
241
/// Errors produced while building a [`StackedBarChart`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StackedBarChartError {
    /// A bar's coordinate lies outside its axis scale.
    ///
    /// The fields are, in order, the axis name and the offending value rendered as text.
    OutOfRange(String, String),
}

impl fmt::Display for StackedBarChartError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self::OutOfRange(axis, value) = self;
        write!(f, "The point with value {value} on the {axis} axis is out of range")
    }
}

/// Makes [`StackedBarChartError`] usable as a standard error. It wraps no underlying cause.
impl std::error::Error for StackedBarChartError {}
262
#[cfg(test)]
mod stacked_barchart_tests {
    use crate::models::ScaleKind;

    use super::*;

    /// Section labels shared by every bar in the fixture chart.
    const SECTIONS: [&str; 4] = ["Soda", "Cream", "Coffee", "Choco"];

    /// Builds a five-bar categorical chart whose section counts add up to each bar's total.
    fn create_barchart() -> StackedBarChart {
        // (x category, bar total, per-section counts in `SECTIONS` order)
        let rows = [
            ("One", 19, [3.0, 3.0, 5.0, 8.0]),
            ("Two", 19, [3.0, 6.0, 10.0, 0.0]),
            ("Three", 14, [6.0, 0.0, 8.0, 0.0]),
            ("Four", 16, [3.0, 0.0, 7.0, 6.0]),
            ("Five", 19, [9.0, 0.0, 10.0, 0.0]),
        ];

        let bars = rows
            .into_iter()
            .map(|(x, total, counts)| {
                let fractions = SECTIONS
                    .into_iter()
                    .zip(counts)
                    .map(|(label, count)| (label.to_string(), count / f64::from(total)))
                    .collect();
                let pnt = Point::new(Data::Text(x.into()), Data::Integer(total));
                StackedBar::new(pnt, fractions, false)
            })
            .collect();

        let x_scale = Scale::new(
            vec!["One", "Two", "Three", "Four", "Five"],
            ScaleKind::Categorical,
        );
        let y_scale = vec![14, 16, 19].into();
        let labels = SECTIONS.iter().map(|label| label.to_string()).collect();

        StackedBarChart::new(bars, x_scale, y_scale, labels)
            .unwrap_or_else(|e| panic!("{}", e))
            .x_axis("Number")
            .y_axis("Total")
    }

    /// Builds a chart whose bar at x = 11 lies outside the `-5..11` x scale.
    fn out_of_range() -> Result<StackedBarChart, StackedBarChartError> {
        let xs = [1, 5, 6, 11, 15];
        let ys = [4, 5, 6, 7, 8];

        let bars = xs
            .into_iter()
            .zip(ys)
            .map(|(x, y)| StackedBar::from_point((Data::Integer(x), Data::Integer(y)), false))
            .collect();

        let x_scale = Scale::new(-5..11, ScaleKind::Integer);
        let y_scale = Scale::new(2..10, ScaleKind::Integer);

        StackedBarChart::new(bars, x_scale, y_scale, HashSet::default())
    }

    #[test]
    fn test_barchart() {
        let barchart = create_barchart();

        assert_eq!(barchart.x_axis.as_deref(), Some("Number"));
        assert_eq!(barchart.y_axis.as_deref(), Some("Total"));

        assert_eq!(
            barchart.bars[0].fractions.get("Soda"),
            Some(&(3.0 / 19.0))
        );

        assert_eq!(
            barchart.labels,
            HashSet::from([
                String::from("Soda"),
                String::from("Cream"),
                String::from("Coffee"),
                String::from("Choco"),
            ])
        );

        assert_eq!(barchart.bars.len(), 5);
    }

    #[test]
    fn test_faulty_barchart() {
        let expected = StackedBarChartError::OutOfRange(String::from("X"), String::from("11"));
        match out_of_range() {
            Ok(_) => panic!("Should not reach this test case"),
            Err(e) => assert_eq!(e, expected),
        }
    }

    #[test]
    fn test_section_toggling() {
        // Fractions chosen so every contribution is exact in floating point.
        let fractions = HashMap::from([
            (String::from("Top"), 0.25),
            (String::from("Bottom"), 0.75),
        ]);
        let pnt = Point::new(Data::Text("A".into()), Data::Integer(8));
        let mut bar = StackedBar::new(pnt, fractions, false);

        bar.remove_section("Top");
        assert_eq!(bar.point.y, Data::Integer(6));
        bar.add_section("Top");
        assert_eq!(bar.point.y, Data::Integer(8));

        bar.remove_section("Top");
        bar.remove_section("Bottom");
        assert_eq!(bar.point.y, Data::Integer(0));
        assert!(bar.is_empty());

        bar.restore();
        assert_eq!(bar.point.y, Data::Integer(8));
    }

    #[test]
    fn test_true_positives_and_negatives() {
        let mut chart = create_barchart();
        assert!(chart.has_true_positives());
        assert!(!chart.has_true_negatives());

        // An out-of-range bar index and an unknown section are both ignored.
        chart.remove_section(99, "Soda");
        chart.remove_section(0, "Unknown");
        assert_eq!(chart.bars[0].point.y, Data::Integer(19));
    }
}