Skip to main content

ratio_metadata/
aggregate.rs

1//! # Aggregate module
2//!
3//! Module with aggregate functions for different types of data, such as summing up values or
4//! counting occurrences.
5//!
6//! ## License
7//!
8//! This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0.
9//! If a copy of the MPL was not distributed with this file,
10//! You can obtain one at <https://mozilla.org/MPL/2.0/>.
11//!
12//! **Code examples both in the docstrings and rendered documentation are free to use.**
13
14use std::borrow::Borrow;
15use std::collections::{BTreeMap, BTreeSet};
16use std::fmt::Debug;
17
18use crate::metadata::{AnnotationValue, Field, ReadMetadata, WeightValue};
19
/// Add `amount` to the counter stored for `field`, creating the counter when it is missing.
///
/// Existing counters saturate at `usize::MAX` instead of overflowing. The key is only cloned
/// when a new counter has to be inserted.
fn increment_count<F: Clone + Ord>(map: &mut BTreeMap<F, usize>, field: &F, amount: usize) {
    if let Some(counter) = map.get_mut(field) {
        *counter = counter.saturating_add(amount);
    } else {
        map.insert(field.clone(), amount);
    }
}
30
/// Subtract `amount` from the counter stored for `field`.
///
/// Unknown fields are ignored. Existing counters bottom out at zero and stay in the map.
fn decrement_count<F: Ord>(map: &mut BTreeMap<F, usize>, field: &F, amount: usize) {
    let Some(counter) = map.get_mut(field) else {
        return;
    };
    *counter = counter.saturating_sub(amount);
}
36
/// What to aggregate by.
///
/// Every field-selecting variant carries an optional set of fields; `None` selects all fields.
#[derive(Clone, Debug, Default, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(serde::Serialize, serde::Deserialize),
    serde(rename_all = "camelCase")
)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[cfg_attr(feature = "reactive", derive(reactive_stores::Store))]
pub enum Aggregator<N, K, L, W, A>
where
    N: Field,
    K: Field,
    L: Field,
    W: Field,
    A: Field,
{
    /// Whether any objects have been counted. This is the default aggregator.
    #[default]
    Binary,
    /// Aggregated occurrences of these names. None means all.
    Names(Option<BTreeSet<N>>),
    /// Aggregated occurrences of these kinds. None means all.
    Kinds(Option<BTreeSet<K>>),
    /// Aggregated occurrences of these labels. None means all.
    Labels(Option<BTreeSet<L>>),
    /// Aggregate weight value.
    Weights {
        /// Weight fields to sum. None means all.
        fields: Option<BTreeSet<W>>,
        /// Whether to sum the absolute values of the weight sums instead of the signed values.
        absolute: bool,
    },
    /// Aggregate annotation key occurrence. None means all.
    Annotations(Option<BTreeSet<A>>),
}
71impl<N, K, L, W, A> Aggregator<N, K, L, W, A>
72where
73    N: Field,
74    K: Field,
75    L: Field,
76    W: Field,
77    A: Field,
78{
79    /// Aggregator for all names.
80    pub fn all_names() -> Self {
81        Self::Names(None)
82    }
83
84    /// Aggregator for a single name.
85    pub fn for_name(value: N) -> Self {
86        Self::Names(Some(BTreeSet::from([value])))
87    }
88
89    /// Aggregator for all kinds.
90    pub fn all_kinds() -> Self {
91        Self::Kinds(None)
92    }
93
94    /// Aggregator for a single kind.
95    pub fn for_kind(value: K) -> Self {
96        Self::Kinds(Some(BTreeSet::from([value])))
97    }
98
99    /// Aggregator for all labels.
100    pub fn all_labels() -> Self {
101        Self::Labels(None)
102    }
103    /// Aggregator for a single kind.
104    pub fn for_label(value: L) -> Self {
105        Self::Labels(Some(BTreeSet::from([value])))
106    }
107
108    /// Aggregator for all weights.
109    pub fn all_weights(absolute: bool) -> Self {
110        Self::Weights {
111            fields: None,
112            absolute,
113        }
114    }
115
116    /// Aggregator for a single weight.
117    pub fn for_weight(value: W, absolute: bool) -> Self {
118        Self::Weights {
119            fields: Some(BTreeSet::from([value])),
120            absolute,
121        }
122    }
123
124    /// Aggregator for all annotations.
125    pub fn all_annotations() -> Self {
126        Self::Annotations(None)
127    }
128
129    /// Aggregator for a single annotation.
130    pub fn for_annotation(value: A) -> Self {
131        Self::Annotations(Some(BTreeSet::from([value])))
132    }
133
134    /// Create an instance that selects all fields.
135    pub fn as_all(&self) -> Self {
136        match *self {
137            Self::Binary => Self::Binary,
138            Self::Names(_) => Self::Names(None),
139            Self::Kinds(_) => Self::Kinds(None),
140            Self::Labels(_) => Self::Labels(None),
141            Self::Weights {
142                fields: _,
143                absolute,
144            } => Self::Weights {
145                fields: None,
146                absolute,
147            },
148            Self::Annotations(_) => Self::Annotations(None),
149        }
150    }
151}
152
/// Metadata aggregate values.
///
/// Occurrence counters and weight sums are kept per field. Counters that drop to zero through
/// subtraction remain present in their map with a value of 0.
#[derive(Clone, Debug, PartialEq, bon::Builder)]
#[cfg_attr(
    feature = "serde",
    derive(serde::Serialize, serde::Deserialize),
    serde(default, rename_all = "camelCase")
)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[cfg_attr(feature = "reactive", derive(reactive_stores::Store))]
pub struct Aggregate<N, K, L, W, WV, A>
where
    N: Field,
    K: Field,
    L: Field,
    W: Field,
    WV: WeightValue,
    A: Field,
{
    /// Over how many items the aggregate has been taken.
    #[builder(default)]
    pub items: usize,

    /// Name occurrence: number of aggregated items per name.
    #[builder(default=BTreeMap::new())]
    pub names: BTreeMap<N, usize>,

    /// Kind occurrence: number of aggregated items per kind.
    #[builder(default=BTreeMap::new())]
    pub kinds: BTreeMap<K, usize>,

    /// Label occurrence: number of aggregated items carrying each label.
    #[builder(default=BTreeMap::new())]
    pub labels: BTreeMap<L, usize>,

    /// Weight sum value per weight field (subtracted items contribute negatively).
    #[builder(default=BTreeMap::new())]
    pub weights: BTreeMap<W, WV>,

    /// Annotation key occurrence: number of aggregated items carrying each annotation key.
    #[builder(default=BTreeMap::new())]
    pub annotations: BTreeMap<A, usize>,
}
195
196impl<N, K, L, W, WV, A> Default for Aggregate<N, K, L, W, WV, A>
197where
198    N: Field,
199    K: Field,
200    L: Field,
201    W: Field,
202    WV: WeightValue,
203    A: Field,
204{
205    fn default() -> Self {
206        Self {
207            items: 0,
208            names: BTreeMap::new(),
209            kinds: BTreeMap::new(),
210            labels: BTreeMap::new(),
211            weights: BTreeMap::new(),
212            annotations: BTreeMap::new(),
213        }
214    }
215}
216
217impl<N, K, L, W, WV, A> Aggregate<N, K, L, W, WV, A>
218where
219    N: Field,
220    K: Field,
221    L: Field,
222    W: Field,
223    WV: WeightValue,
224    A: Field,
225{
226    /// Create a new aggregate instance.
227    pub fn new() -> Self {
228        Self::default()
229    }
230
231    /// Add metadata to the aggregate.
232    pub fn add<'a, M: ReadMetadata<'a, N, K, L, W, WV, A, AV>, AV: 'a + AnnotationValue>(
233        &'a mut self,
234        item: &'a M,
235    ) {
236        self.items += 1;
237
238        if let Some(name) = item.name() {
239            increment_count(&mut self.names, name, 1);
240        }
241
242        if let Some(kind) = item.kind() {
243            increment_count(&mut self.kinds, kind, 1);
244        }
245
246        item.labels()
247            .for_each(|field| increment_count(&mut self.labels, field, 1));
248
249        item.weights()
250            .for_each(|(field, &value)| match self.weights.get_mut(field) {
251                Some(v) => {
252                    v.add_assign(value);
253                }
254                None => {
255                    self.weights.insert(field.to_owned(), value);
256                }
257            });
258
259        item.annotations()
260            .for_each(|(field, _)| increment_count(&mut self.annotations, field, 1));
261    }
262
263    /// Subtract metadata from the aggregate.
264    pub fn subtract<'a, M: ReadMetadata<'a, N, K, L, W, WV, A, AV>, AV: 'a + AnnotationValue>(
265        &'a mut self,
266        item: &'a M,
267    ) {
268        self.items = self.items.saturating_sub(1);
269
270        if let Some(field) = item.name() {
271            decrement_count(&mut self.names, field, 1);
272        }
273
274        if let Some(field) = item.kind() {
275            decrement_count(&mut self.kinds, field, 1);
276        }
277
278        item.labels().for_each(|field| {
279            decrement_count(&mut self.labels, field, 1);
280        });
281
282        item.weights()
283            .for_each(|(field, &value)| match self.weights.get_mut(field) {
284                Some(v) => {
285                    v.sub_assign(value);
286                }
287                None => {
288                    self.weights.insert(field.clone(), -value);
289                }
290            });
291
292        item.annotations()
293            .for_each(|(field, _)| decrement_count(&mut self.annotations, field, 1));
294    }
295
296    /// Add another Aggregate's values to this.
297    pub fn extend(&mut self, other: Self) {
298        let Self {
299            items,
300            names,
301            kinds,
302            labels,
303            weights,
304            annotations,
305        } = other;
306
307        self.items += items;
308
309        names
310            .into_iter()
311            .for_each(|(field, amount)| increment_count(&mut self.names, &field, amount));
312
313        kinds
314            .into_iter()
315            .for_each(|(field, amount)| increment_count(&mut self.kinds, &field, amount));
316
317        labels
318            .into_iter()
319            .for_each(|(field, amount)| increment_count(&mut self.labels, &field, amount));
320
321        weights.into_iter().for_each(|(field, value)| {
322            match self.weights.get_mut(&field) {
323                Some(v) => {
324                    v.sub_assign(value);
325                }
326                None => {
327                    self.weights.insert(field.to_owned(), value);
328                }
329            };
330        });
331
332        annotations
333            .into_iter()
334            .for_each(|(field, amount)| increment_count(&mut self.annotations, &field, amount));
335    }
336
337    /// Get the sum for all given fields' values for a given aggregator.
338    pub fn aggregate(&self, aggregator: &Aggregator<N, K, L, W, A>) -> f64 {
339        match aggregator {
340            Aggregator::Binary => {
341                if self.items > 0 {
342                    1.0
343                } else {
344                    0.0
345                }
346            }
347            Aggregator::Names(fields) => match fields {
348                None => self.names.values().sum::<usize>() as f64,
349                Some(fields) => fields
350                    .iter()
351                    .filter_map(|field| self.names.get(field))
352                    .sum::<usize>() as f64,
353            },
354            Aggregator::Kinds(fields) => match fields {
355                None => self.kinds.values().sum::<usize>() as f64,
356                Some(fields) => fields
357                    .iter()
358                    .filter_map(|field| self.kinds.get(field))
359                    .sum::<usize>() as f64,
360            },
361            Aggregator::Labels(fields) => match fields {
362                None => self.labels.values().sum::<usize>() as f64,
363                Some(fields) => fields
364                    .iter()
365                    .filter_map(|field| self.labels.get(field))
366                    .sum::<usize>() as f64,
367            },
368            Aggregator::Weights { fields, absolute } => match fields {
369                Some(fields) => {
370                    let values = fields.iter().filter_map(|field| self.weights.get(field));
371                    if *absolute {
372                        values.map(|v| v.as_().abs()).sum()
373                    } else {
374                        values.map(|v| v.as_()).sum()
375                    }
376                }
377                None => {
378                    let values = self.weights.values();
379                    if *absolute {
380                        values.map(|v| v.as_().abs()).sum()
381                    } else {
382                        values.map(|v| v.as_()).sum()
383                    }
384                }
385            },
386            Aggregator::Annotations(fields) => match fields {
387                None => self.annotations.values().sum::<usize>() as f64,
388                Some(fields) => fields
389                    .iter()
390                    .filter_map(|field| self.annotations.get(field))
391                    .sum::<usize>() as f64,
392            },
393        }
394    }
395
396    /// Get a fraction that a field represents with respect to the sum of all fields for a certain aggregator.
397    pub fn fraction(&self, aggregator: &Aggregator<N, K, L, W, A>) -> f64 {
398        let total = self.aggregate(&aggregator.as_all());
399        if total == 0.0 {
400            0.0
401        } else {
402            self.aggregate(aggregator) / total
403        }
404    }
405
406    /// Get all fractions for the given fields with respect to the sum of all fields for a certain aggregator. It can be scaled with a constant factor. Set the factor to 1.0 to ignore scaling.
407    pub fn fractions(&self, aggregator: &Aggregator<N, K, L, W, A>, factor: f64) -> Vec<f64> {
408        let sum = self.aggregate(&aggregator.as_all());
409        let factor = { if sum == 0.0 { 1.0 } else { factor / sum } };
410        match aggregator {
411            Aggregator::Binary => vec![factor],
412            Aggregator::Names(None) => self.names.values().map(|&v| factor * v as f64).collect(),
413            Aggregator::Names(Some(fields)) => fields
414                .iter()
415                .filter_map(|field| self.names.get(field))
416                .map(|&v| factor * v as f64)
417                .collect(),
418            Aggregator::Kinds(None) => self.kinds.values().map(|&v| factor * v as f64).collect(),
419            Aggregator::Kinds(Some(fields)) => fields
420                .iter()
421                .filter_map(|field| self.kinds.get(field))
422                .map(|&v| factor * v as f64)
423                .collect(),
424            Aggregator::Labels(None) => self.labels.values().map(|&v| factor * v as f64).collect(),
425            Aggregator::Labels(Some(fields)) => fields
426                .iter()
427                .filter_map(|field| self.labels.get(field))
428                .map(|&v| factor * v as f64)
429                .collect(),
430            Aggregator::Weights {
431                fields: None,
432                absolute,
433            } => self
434                .weights
435                .values()
436                .map(|&v| {
437                    factor * {
438                        let value = v.as_();
439                        if *absolute { value.abs() } else { value }
440                    }
441                })
442                .collect(),
443            Aggregator::Weights {
444                fields: Some(fields),
445                absolute,
446            } => fields
447                .iter()
448                .filter_map(|field| self.weights.get(field))
449                .map(|&v| {
450                    factor * {
451                        let value = v.as_();
452                        if *absolute { value.abs() } else { value }
453                    }
454                })
455                .collect(),
456            Aggregator::Annotations(None) => self
457                .annotations
458                .values()
459                .map(|&v| factor * v as f64)
460                .collect(),
461            Aggregator::Annotations(Some(fields)) => fields
462                .iter()
463                .filter_map(|field| self.annotations.get(field))
464                .map(|&v| factor * v as f64)
465                .collect(),
466        }
467    }
468}
469
/// Track lower and upper bounds for float BTreeMaps.
///
/// Each key maps to a `(lower, upper)` pair spanning the values reported for that key through
/// the `update_*` methods.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(serde::Serialize, serde::Deserialize),
    serde(default)
)]
pub struct Domains<F: Field> {
    /// Per-key `(lower, upper)` bounds, with `lower <= upper` for non-NaN inputs.
    pub bounds: BTreeMap<F, (f64, f64)>,
}
480impl<F: Field> Default for Domains<F> {
481    fn default() -> Self {
482        Self {
483            bounds: BTreeMap::new(),
484        }
485    }
486}
487
488impl<F: Field> Domains<F> {
489    /// Create a new domains instance.
490    pub fn new() -> Self {
491        Self::default()
492    }
493
494    /// Update the lower and upper bounds according to the values from this map.
495    pub fn update_map(&mut self, values: &BTreeMap<F, f64>) {
496        self.update_iter(values.iter());
497    }
498
499    /// Update from an iterator of (key, value).
500    pub fn update_iter<'a, I: Iterator<Item = (&'a F, &'a f64)>>(&'a mut self, iter: I) {
501        iter.for_each(|(key, &value)| self.update_key(key, value))
502    }
503
504    /// Update the lower and upper bounds for this key using the given value.
505    pub fn update_key(&mut self, key: &F, value: f64) {
506        if let Some(entry) = self.bounds.get_mut(key) {
507            entry.0 = entry.0.min(value);
508            entry.1 = entry.1.max(value);
509        } else {
510            self.bounds.insert(key.to_owned(), (value, value));
511        }
512    }
513
514    /// Get the domains for a specified key.
515    pub fn get<Q: Ord>(&self, key: &Q) -> Option<&(f64, f64)>
516    where
517        F: Borrow<Q>,
518    {
519        self.bounds.get(key)
520    }
521
522    /// Interpolate a value for a specified key's domain. Returns 1.0 when the domain has no size
523    /// (lower and upper bound are equal).
524    pub fn interpolate<Q: Ord>(&self, key: &Q, value: f64) -> Option<f64>
525    where
526        F: Borrow<Q>,
527    {
528        self.get(key).map(|&(lower, upper)| {
529            if lower == upper {
530                1.0
531            } else {
532                (value - lower) / (upper - lower)
533            }
534        })
535    }
536}
537
#[cfg(test)]
pub mod tests {
    use super::*;
    use crate::metadata::{Metadata, SimpleMetadata};

    /// Aggregate where all fields are given by a String and weights as floats.
    pub type SimpleAggregate = Aggregate<String, String, String, String, f64, String>;

    #[test]
    fn test_aggregate_binary() {
        let empty: SimpleMetadata = Metadata::builder().build();
        let mut agg: SimpleAggregate = Aggregate::new();

        agg.add(&empty.as_ref());
        assert_eq!(agg.aggregate(&Aggregator::Binary), 1.0);

        agg.subtract(&empty.as_ref());
        assert_eq!(agg.aggregate(&Aggregator::Binary), 0.0);
    }

    #[test]
    fn test_aggregate_names() {
        let first = SimpleMetadata::builder().name(String::from("test1")).build();
        let second = SimpleMetadata::builder().name(String::from("test2")).build();
        let mut agg: SimpleAggregate = Aggregate::new();
        agg.add(&first.as_ref());
        agg.add(&second.as_ref());

        let present = Aggregator::for_name(String::from("test1"));
        let missing = Aggregator::for_name(String::from("test3"));
        assert_eq!(agg.aggregate(&Aggregator::all_names()), 2.0);
        assert_eq!(agg.aggregate(&present), 1.0);
        assert_eq!(agg.aggregate(&missing), 0.0);
    }

    #[test]
    fn test_aggregate_kinds() {
        let first = SimpleMetadata::builder().kind(String::from("kind1")).build();
        let second = SimpleMetadata::builder().kind(String::from("kind2")).build();
        let mut agg: SimpleAggregate = Aggregate::new();
        agg.add(&first.as_ref());
        agg.add(&second.as_ref());

        let present = Aggregator::for_kind(String::from("kind1"));
        let missing = Aggregator::for_kind(String::from("kind3"));
        assert_eq!(agg.aggregate(&Aggregator::all_kinds()), 2.0);
        assert_eq!(agg.aggregate(&present), 1.0);
        assert_eq!(agg.aggregate(&missing), 0.0);
    }

    #[test]
    fn test_aggregate_labels() {
        let first = SimpleMetadata::builder()
            .labels(bon::set![String::from("label1")])
            .build();
        let second = SimpleMetadata::builder()
            .labels(bon::set![String::from("label2")])
            .build();
        let mut agg = SimpleAggregate::new();
        agg.add(&first.as_ref());
        agg.add(&second.as_ref());

        let present = Aggregator::for_label(String::from("label1"));
        let missing = Aggregator::for_label(String::from("label3"));
        assert_eq!(agg.aggregate(&Aggregator::all_labels()), 2.0);
        assert_eq!(agg.aggregate(&present), 1.0);
        assert_eq!(agg.aggregate(&missing), 0.0);
    }

    #[test]
    fn test_aggregate_weights() {
        let first = SimpleMetadata::builder()
            .weights(bon::map! {"weight1": 10.0})
            .build();
        let second = SimpleMetadata::builder()
            .weights(bon::map! {"weight2": 20.0})
            .build();
        let mut agg = SimpleAggregate::new();
        agg.add(&first.as_ref());
        agg.add(&second.as_ref());

        let present = Aggregator::for_weight(String::from("weight1"), false);
        let missing = Aggregator::for_weight(String::from("weight3"), false);
        assert_eq!(agg.aggregate(&Aggregator::all_weights(false)), 30.0);
        assert_eq!(agg.aggregate(&present), 10.0);
        assert_eq!(agg.aggregate(&missing), 0.0);
    }

    #[test]
    fn test_aggregate_annotations() {
        let first = SimpleMetadata::builder()
            .annotations(bon::map! {"key1": "value1"})
            .build();
        let second = SimpleMetadata::builder()
            .annotations(bon::map! {"key2": "value2"})
            .build();
        let mut agg = SimpleAggregate::new();
        agg.add(&first.as_ref());
        agg.add(&second.as_ref());

        let present = Aggregator::for_annotation(String::from("key1"));
        let missing = Aggregator::for_annotation(String::from("key3"));
        assert_eq!(agg.aggregate(&Aggregator::all_annotations()), 2.0);
        assert_eq!(agg.aggregate(&present), 1.0);
        assert_eq!(agg.aggregate(&missing), 0.0);
    }

    #[test]
    fn test_fraction() {
        let first = SimpleMetadata::builder().kind(String::from("kind1")).build();
        let second = SimpleMetadata::builder().kind(String::from("kind2")).build();
        let mut agg = SimpleAggregate::new();
        agg.add(&first.as_ref());
        agg.add(&second.as_ref());

        let half = agg.fraction(&Aggregator::for_kind(String::from("kind1")));
        let none = agg.fraction(&Aggregator::for_kind(String::from("kind3")));
        assert_eq!(half, 0.5);
        assert_eq!(none, 0.0);
    }

    #[test]
    fn test_fractions() {
        let first = SimpleMetadata::builder().kind(String::from("kind1")).build();
        let second = SimpleMetadata::builder().kind(String::from("kind2")).build();
        let mut agg = SimpleAggregate::new();
        agg.add(&first.as_ref());
        agg.add(&second.as_ref());

        assert_eq!(
            agg.fractions(&Aggregator::all_kinds(), 1.0),
            vec![0.5, 0.5]
        );
    }

    #[test]
    fn test_domains() {
        let key1 = String::from("key1");
        let map = BTreeMap::from([(key1.clone(), 10.0), (String::from("key2"), 20.0)]);
        let mut domains = Domains::new();

        domains.update_map(&map);
        assert_eq!(domains.get(&key1), Some(&(10.0, 10.0)));
        assert_eq!(domains.interpolate(&key1, 10.0), Some(1.0));

        domains.update_key(&key1, 5.0);
        assert_eq!(domains.get(&key1), Some(&(5.0, 10.0)));
        assert_eq!(domains.interpolate(&key1, 7.5), Some(0.5));
    }
}