use std::borrow::Borrow;
use std::collections::{BTreeMap, BTreeSet};
use std::fmt::Debug;
use crate::metadata::{AnnotationValue, Field, ReadMetadata, WeightValue};
/// Adds `amount` to the tally stored under `field`, creating the entry when absent.
///
/// An existing tally saturates at `usize::MAX` instead of overflowing. The key is
/// only cloned when a brand-new entry must be inserted.
fn increment_count<F: Clone + Ord>(map: &mut BTreeMap<F, usize>, field: &F, amount: usize) {
    if let Some(tally) = map.get_mut(field) {
        *tally = tally.saturating_add(amount);
        return;
    }
    map.insert(field.clone(), amount);
}
/// Subtracts `amount` from the tally stored under `field`, saturating at zero.
///
/// Keys that are not present are left alone. An entry that reaches zero stays in
/// the map rather than being removed.
fn decrement_count<F: Ord>(map: &mut BTreeMap<F, usize>, field: &F, amount: usize) {
    let Some(tally) = map.get_mut(field) else {
        return;
    };
    *tally = tally.saturating_sub(amount);
}
/// Selects which part of an [`Aggregate`] is reduced to a single number.
///
/// Every variant except `Binary` carries an optional set of fields. `None`
/// selects every field recorded in the aggregate; `Some(set)` restricts the
/// reduction to the listed fields. Fields that were never recorded contribute
/// nothing.
#[derive(Clone, Debug, Default, PartialEq)]
#[cfg_attr(
feature = "serde",
derive(serde::Serialize, serde::Deserialize),
serde(rename_all = "camelCase")
)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[cfg_attr(feature = "reactive", derive(reactive_stores::Store))]
pub enum Aggregator<N, K, L, W, A>
where
N: Field,
K: Field,
L: Field,
W: Field,
A: Field,
{
/// Presence indicator: 1.0 when the aggregate holds at least one item, else 0.0.
#[default]
Binary,
/// Number of items per name.
Names(Option<BTreeSet<N>>),
/// Number of items per kind.
Kinds(Option<BTreeSet<K>>),
/// Number of items per label.
Labels(Option<BTreeSet<L>>),
/// Accumulated weight totals per weight field.
Weights {
/// Weight fields to include; `None` includes all of them.
fields: Option<BTreeSet<W>>,
/// When set, the absolute value of each per-field total is used.
absolute: bool,
},
/// Number of items per annotation key (annotation values are not counted separately).
Annotations(Option<BTreeSet<A>>),
}
impl<N, K, L, W, A> Aggregator<N, K, L, W, A>
where
N: Field,
K: Field,
L: Field,
W: Field,
A: Field,
{
pub fn all_names() -> Self {
Self::Names(None)
}
pub fn for_name(value: N) -> Self {
Self::Names(Some(BTreeSet::from([value])))
}
pub fn all_kinds() -> Self {
Self::Kinds(None)
}
pub fn for_kind(value: K) -> Self {
Self::Kinds(Some(BTreeSet::from([value])))
}
pub fn all_labels() -> Self {
Self::Labels(None)
}
pub fn for_label(value: L) -> Self {
Self::Labels(Some(BTreeSet::from([value])))
}
pub fn all_weights(absolute: bool) -> Self {
Self::Weights {
fields: None,
absolute,
}
}
pub fn for_weight(value: W, absolute: bool) -> Self {
Self::Weights {
fields: Some(BTreeSet::from([value])),
absolute,
}
}
pub fn all_annotations() -> Self {
Self::Annotations(None)
}
pub fn for_annotation(value: A) -> Self {
Self::Annotations(Some(BTreeSet::from([value])))
}
pub fn as_all(&self) -> Self {
match *self {
Self::Binary => Self::Binary,
Self::Names(_) => Self::Names(None),
Self::Kinds(_) => Self::Kinds(None),
Self::Labels(_) => Self::Labels(None),
Self::Weights {
fields: _,
absolute,
} => Self::Weights {
fields: None,
absolute,
},
Self::Annotations(_) => Self::Annotations(None),
}
}
}
/// Running tallies over a collection of metadata items.
///
/// Items are folded in with `add`, removed with `subtract`, and whole aggregates
/// are merged with `extend`. The tallies are reduced to numbers via an
/// [`Aggregator`].
#[derive(Clone, Debug, PartialEq, bon::Builder)]
#[cfg_attr(
feature = "serde",
derive(serde::Serialize, serde::Deserialize),
serde(default, rename_all = "camelCase")
)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[cfg_attr(feature = "reactive", derive(reactive_stores::Store))]
pub struct Aggregate<N, K, L, W, WV, A>
where
N: Field,
K: Field,
L: Field,
W: Field,
WV: WeightValue,
A: Field,
{
/// Number of items added minus items subtracted (never below zero).
#[builder(default)]
pub items: usize,
/// Number of items carrying each name.
#[builder(default=BTreeMap::new())]
pub names: BTreeMap<N, usize>,
/// Number of items carrying each kind.
#[builder(default=BTreeMap::new())]
pub kinds: BTreeMap<K, usize>,
/// Number of items carrying each label.
#[builder(default=BTreeMap::new())]
pub labels: BTreeMap<L, usize>,
/// Accumulated weight value per weight field.
#[builder(default=BTreeMap::new())]
pub weights: BTreeMap<W, WV>,
/// Number of items carrying each annotation key.
#[builder(default=BTreeMap::new())]
pub annotations: BTreeMap<A, usize>,
}
impl<N, K, L, W, WV, A> Default for Aggregate<N, K, L, W, WV, A>
where
N: Field,
K: Field,
L: Field,
W: Field,
WV: WeightValue,
A: Field,
{
fn default() -> Self {
Self {
items: 0,
names: BTreeMap::new(),
kinds: BTreeMap::new(),
labels: BTreeMap::new(),
weights: BTreeMap::new(),
annotations: BTreeMap::new(),
}
}
}
impl<N, K, L, W, WV, A> Aggregate<N, K, L, W, WV, A>
where
    N: Field,
    K: Field,
    L: Field,
    W: Field,
    WV: WeightValue,
    A: Field,
{
    /// Creates an empty aggregate; equivalent to `Default::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Counts one metadata item into this aggregate.
    ///
    /// Increments `items`, the tallies of the item's name and kind (when present),
    /// and the tallies of each of its labels and annotation keys. Each of its
    /// weight values is added to the running total of that weight field.
    pub fn add<'a, M: ReadMetadata<'a, N, K, L, W, WV, A, AV>, AV: 'a + AnnotationValue>(
        &'a mut self,
        item: &'a M,
    ) {
        // Saturate like every per-field tally does, rather than overflowing.
        self.items = self.items.saturating_add(1);
        if let Some(name) = item.name() {
            increment_count(&mut self.names, name, 1);
        }
        if let Some(kind) = item.kind() {
            increment_count(&mut self.kinds, kind, 1);
        }
        item.labels()
            .for_each(|field| increment_count(&mut self.labels, field, 1));
        item.weights()
            .for_each(|(field, &value)| Self::accumulate_weight(&mut self.weights, field, value));
        item.annotations()
            .for_each(|(field, _)| increment_count(&mut self.annotations, field, 1));
    }

    /// Removes one metadata item from this aggregate (the inverse of [`Self::add`]).
    ///
    /// The item count and all tallies saturate at zero. Entries that reach zero
    /// stay in their maps. A weight field that is not yet present is inserted
    /// with the negated value.
    pub fn subtract<'a, M: ReadMetadata<'a, N, K, L, W, WV, A, AV>, AV: 'a + AnnotationValue>(
        &'a mut self,
        item: &'a M,
    ) {
        self.items = self.items.saturating_sub(1);
        if let Some(field) = item.name() {
            decrement_count(&mut self.names, field, 1);
        }
        if let Some(field) = item.kind() {
            decrement_count(&mut self.kinds, field, 1);
        }
        item.labels()
            .for_each(|field| decrement_count(&mut self.labels, field, 1));
        item.weights()
            .for_each(|(field, &value)| Self::deduct_weight(&mut self.weights, field, value));
        item.annotations()
            .for_each(|(field, _)| decrement_count(&mut self.annotations, field, 1));
    }

    /// Merges `other` into this aggregate.
    ///
    /// Item counts and per-field tallies are summed (saturating), and weight
    /// totals are summed per field. The result matches what `add`-ing the items
    /// of both aggregates into one would have produced.
    pub fn extend(&mut self, other: Self) {
        let Self {
            items,
            names,
            kinds,
            labels,
            weights,
            annotations,
        } = other;
        self.items = self.items.saturating_add(items);
        for (field, amount) in names {
            increment_count(&mut self.names, &field, amount);
        }
        for (field, amount) in kinds {
            increment_count(&mut self.kinds, &field, amount);
        }
        for (field, amount) in labels {
            increment_count(&mut self.labels, &field, amount);
        }
        // Weights are summed, exactly as in `add`; they must never be subtracted here.
        for (field, value) in weights {
            Self::accumulate_weight(&mut self.weights, &field, value);
        }
        for (field, amount) in annotations {
            increment_count(&mut self.annotations, &field, amount);
        }
    }

    /// Reduces the aggregate to a single number as selected by `aggregator`.
    ///
    /// - `Binary`: 1.0 if at least one item is counted, else 0.0.
    /// - `Names` / `Kinds` / `Labels` / `Annotations`: sum of the tallies of the
    ///   selected fields (every field for `None`).
    /// - `Weights`: sum of the selected per-field weight totals. Each total is
    ///   taken as an absolute value when `absolute` is set.
    pub fn aggregate(&self, aggregator: &Aggregator<N, K, L, W, A>) -> f64 {
        match aggregator {
            Aggregator::Binary => {
                if self.items > 0 {
                    1.0
                } else {
                    0.0
                }
            }
            Aggregator::Names(fields) => Self::sum_counts(&self.names, fields.as_ref()),
            Aggregator::Kinds(fields) => Self::sum_counts(&self.kinds, fields.as_ref()),
            Aggregator::Labels(fields) => Self::sum_counts(&self.labels, fields.as_ref()),
            Aggregator::Weights { fields, absolute } => {
                self.sum_weights(fields.as_ref(), *absolute)
            }
            Aggregator::Annotations(fields) => {
                Self::sum_counts(&self.annotations, fields.as_ref())
            }
        }
    }

    /// Returns `aggregate(aggregator) / aggregate(aggregator.as_all())`.
    ///
    /// Returns 0.0 when the total is zero.
    pub fn fraction(&self, aggregator: &Aggregator<N, K, L, W, A>) -> f64 {
        let total = self.aggregate(&aggregator.as_all());
        if total == 0.0 {
            0.0
        } else {
            self.aggregate(aggregator) / total
        }
    }

    /// Returns the individual values selected by `aggregator`, each scaled by
    /// `factor / total`, where `total = aggregate(aggregator.as_all())`.
    ///
    /// Values are ordered by key, skipping requested fields that were never
    /// recorded. When `total` is zero the values are returned unscaled: they are
    /// all 0.0 for `Binary` and the count-based aggregators, but signed weight
    /// totals that cancel out may be non-zero.
    pub fn fractions(&self, aggregator: &Aggregator<N, K, L, W, A>, factor: f64) -> Vec<f64> {
        let total = self.aggregate(&aggregator.as_all());
        // Avoid dividing by zero; see the doc comment for the resulting behaviour.
        let scale = if total == 0.0 { 1.0 } else { factor / total };
        match aggregator {
            // Scale the binary value itself so that an empty aggregate yields 0.0,
            // consistent with `fraction`, and a non-empty one yields `factor`.
            Aggregator::Binary => vec![scale * total],
            Aggregator::Names(fields) => Self::scaled_counts(&self.names, fields.as_ref(), scale),
            Aggregator::Kinds(fields) => Self::scaled_counts(&self.kinds, fields.as_ref(), scale),
            Aggregator::Labels(fields) => {
                Self::scaled_counts(&self.labels, fields.as_ref(), scale)
            }
            Aggregator::Weights { fields, absolute } => {
                self.scaled_weights(fields.as_ref(), *absolute, scale)
            }
            Aggregator::Annotations(fields) => {
                Self::scaled_counts(&self.annotations, fields.as_ref(), scale)
            }
        }
    }

    /// Adds `value` to the total of `field`, inserting `value` if the field is new.
    fn accumulate_weight(weights: &mut BTreeMap<W, WV>, field: &W, value: WV) {
        match weights.get_mut(field) {
            Some(total) => total.add_assign(value),
            None => {
                weights.insert(field.clone(), value);
            }
        }
    }

    /// Subtracts `value` from the total of `field`, inserting `-value` if the field is new.
    fn deduct_weight(weights: &mut BTreeMap<W, WV>, field: &W, value: WV) {
        match weights.get_mut(field) {
            Some(total) => total.sub_assign(value),
            None => {
                weights.insert(field.clone(), -value);
            }
        }
    }

    /// Sums the tallies of the selected fields (every field when `fields` is `None`).
    fn sum_counts<F: Ord>(counts: &BTreeMap<F, usize>, fields: Option<&BTreeSet<F>>) -> f64 {
        let total: usize = match fields {
            None => counts.values().sum(),
            Some(fields) => fields.iter().filter_map(|field| counts.get(field)).sum(),
        };
        total as f64
    }

    /// Returns the tallies of the selected fields, each multiplied by `scale`.
    fn scaled_counts<F: Ord>(
        counts: &BTreeMap<F, usize>,
        fields: Option<&BTreeSet<F>>,
        scale: f64,
    ) -> Vec<f64> {
        match fields {
            None => counts.values().map(|&v| scale * v as f64).collect(),
            Some(fields) => fields
                .iter()
                .filter_map(|field| counts.get(field))
                .map(|&v| scale * v as f64)
                .collect(),
        }
    }

    /// Converts a weight total to `f64`, taking its absolute value when requested.
    fn weight_as_f64(value: WV, absolute: bool) -> f64 {
        let value = value.as_();
        if absolute { value.abs() } else { value }
    }

    /// Sums the selected weight totals (every field when `fields` is `None`).
    fn sum_weights(&self, fields: Option<&BTreeSet<W>>, absolute: bool) -> f64 {
        match fields {
            None => self
                .weights
                .values()
                .map(|&v| Self::weight_as_f64(v, absolute))
                .sum::<f64>(),
            Some(fields) => fields
                .iter()
                .filter_map(|field| self.weights.get(field))
                .map(|&v| Self::weight_as_f64(v, absolute))
                .sum::<f64>(),
        }
    }

    /// Returns the selected weight totals, each multiplied by `scale`.
    fn scaled_weights(&self, fields: Option<&BTreeSet<W>>, absolute: bool, scale: f64) -> Vec<f64> {
        match fields {
            None => self
                .weights
                .values()
                .map(|&v| scale * Self::weight_as_f64(v, absolute))
                .collect(),
            Some(fields) => fields
                .iter()
                .filter_map(|field| self.weights.get(field))
                .map(|&v| scale * Self::weight_as_f64(v, absolute))
                .collect(),
        }
    }
}
/// Per-key `(min, max)` bounds of the `f64` values observed so far.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(
feature = "serde",
derive(serde::Serialize, serde::Deserialize),
serde(default)
)]
pub struct Domains<F: Field> {
/// Smallest and largest value seen for each key, as `(lower, upper)`.
pub bounds: BTreeMap<F, (f64, f64)>,
}
/// Implemented by hand so that `F` is not required to implement `Default`.
impl<F: Field> Default for Domains<F> {
    /// Starts with no keys and therefore no bounds.
    fn default() -> Self {
        let bounds = BTreeMap::default();
        Self { bounds }
    }
}
impl<F: Field> Domains<F> {
    /// Creates an empty set of domains (no keys tracked).
    pub fn new() -> Self {
        Self::default()
    }

    /// Widens the bounds of every key in `values` so they include its value.
    pub fn update_map(&mut self, values: &BTreeMap<F, f64>) {
        self.update_iter(values);
    }

    /// Widens the bounds for every `(key, value)` pair yielded by `iter`.
    ///
    /// Accepts anything iterable over borrowed pairs, e.g. `&BTreeMap<F, f64>` or
    /// `map.iter()`. The mutable borrow of `self` lasts only for this call; it is
    /// not tied to the lifetime of the borrowed pairs.
    pub fn update_iter<'a, I>(&mut self, iter: I)
    where
        F: 'a,
        I: IntoIterator<Item = (&'a F, &'a f64)>,
    {
        for (key, &value) in iter {
            self.update_key(key, value);
        }
    }

    /// Widens the bounds for `key` to include `value`.
    ///
    /// The first value seen for a key becomes both its lower and upper bound.
    /// Because `f64::min`/`f64::max` return the non-NaN operand, a NaN `value`
    /// never replaces an existing non-NaN bound.
    pub fn update_key(&mut self, key: &F, value: f64) {
        if let Some(entry) = self.bounds.get_mut(key) {
            entry.0 = entry.0.min(value);
            entry.1 = entry.1.max(value);
        } else {
            self.bounds.insert(key.to_owned(), (value, value));
        }
    }

    /// Returns the `(lower, upper)` bounds recorded for `key`, if any.
    ///
    /// `Q` may be unsized (as in `BTreeMap::get`), so e.g. a `Domains<String>`
    /// can be queried with a `&str`.
    pub fn get<Q: ?Sized + Ord>(&self, key: &Q) -> Option<&(f64, f64)>
    where
        F: Borrow<Q>,
    {
        self.bounds.get(key)
    }

    /// Maps `value` linearly onto the bounds of `key`: `lower` → 0.0, `upper` → 1.0.
    ///
    /// Returns `None` when `key` is unknown. A degenerate domain
    /// (`lower == upper`) yields 1.0. Values outside the bounds are not clamped,
    /// so the result can fall outside `[0, 1]`.
    pub fn interpolate<Q: ?Sized + Ord>(&self, key: &Q, value: f64) -> Option<f64>
    where
        F: Borrow<Q>,
    {
        self.get(key).map(|&(lower, upper)| {
            if lower == upper {
                1.0
            } else {
                (value - lower) / (upper - lower)
            }
        })
    }
}
/// Unit tests for `Aggregate`, `Aggregator` and `Domains`.
///
/// Items are built with the `SimpleMetadata` builder and fed to the aggregate via `as_ref()`.
#[cfg(test)]
pub mod tests {
use super::*;
use crate::metadata::{Metadata, SimpleMetadata};
/// `Aggregate` with `String` fields and `f64` weights, shared by the tests below.
pub type SimpleAggregate = Aggregate<String, String, String, String, f64, String>;
// Binary is 1.0 after one item is added and back to 0.0 once it is subtracted.
#[test]
fn test_aggregate_binary() {
let mut aggregate: SimpleAggregate = Aggregate::new();
let metadata: SimpleMetadata = Metadata::builder().build();
aggregate.add(&metadata.as_ref());
assert_eq!(aggregate.aggregate(&Aggregator::Binary), 1.0);
aggregate.subtract(&metadata.as_ref());
assert_eq!(aggregate.aggregate(&Aggregator::Binary), 0.0);
}
// Name tallies: all names, a single recorded name, and an unknown name (0.0).
#[test]
fn test_aggregate_names() {
let mut aggregate: SimpleAggregate = Aggregate::new();
let metadata1 = SimpleMetadata::builder().name("test1".to_string()).build();
let metadata2 = SimpleMetadata::builder().name("test2".to_string()).build();
aggregate.add(&metadata1.as_ref());
aggregate.add(&metadata2.as_ref());
assert_eq!(aggregate.aggregate(&Aggregator::all_names()), 2.0);
assert_eq!(
aggregate.aggregate(&Aggregator::for_name("test1".to_string())),
1.0
);
assert_eq!(
aggregate.aggregate(&Aggregator::for_name("test3".to_string())),
0.0
);
}
// Kind tallies: all kinds, a single recorded kind, and an unknown kind (0.0).
#[test]
fn test_aggregate_kinds() {
let mut aggregate: SimpleAggregate = Aggregate::new();
let metadata1 = SimpleMetadata::builder().kind("kind1".to_string()).build();
let metadata2 = SimpleMetadata::builder().kind("kind2".to_string()).build();
aggregate.add(&metadata1.as_ref());
aggregate.add(&metadata2.as_ref());
assert_eq!(aggregate.aggregate(&Aggregator::all_kinds()), 2.0);
assert_eq!(
aggregate.aggregate(&Aggregator::for_kind("kind1".to_string())),
1.0
);
assert_eq!(
aggregate.aggregate(&Aggregator::for_kind("kind3".to_string())),
0.0
);
}
// Label tallies: all labels, a single recorded label, and an unknown label (0.0).
#[test]
fn test_aggregate_labels() {
let mut aggregate = SimpleAggregate::new();
let metadata1 = SimpleMetadata::builder()
.labels(bon::set!["label1".to_string()])
.build();
let metadata2 = SimpleMetadata::builder()
.labels(bon::set!["label2".to_string()])
.build();
aggregate.add(&metadata1.as_ref());
aggregate.add(&metadata2.as_ref());
assert_eq!(aggregate.aggregate(&Aggregator::all_labels()), 2.0);
assert_eq!(
aggregate.aggregate(&Aggregator::for_label("label1".to_string())),
1.0
);
assert_eq!(
aggregate.aggregate(&Aggregator::for_label("label3".to_string())),
0.0
);
}
// Weight totals: sum over all fields, a single field, and an unknown field (0.0).
#[test]
fn test_aggregate_weights() {
let mut aggregate = SimpleAggregate::new();
let metadata1 = SimpleMetadata::builder()
.weights(bon::map! {"weight1": 10.0})
.build();
let metadata2 = SimpleMetadata::builder()
.weights(bon::map! {"weight2": 20.0})
.build();
aggregate.add(&metadata1.as_ref());
aggregate.add(&metadata2.as_ref());
assert_eq!(aggregate.aggregate(&Aggregator::all_weights(false)), 30.0);
assert_eq!(
aggregate.aggregate(&Aggregator::for_weight("weight1".to_string(), false)),
10.0
);
assert_eq!(
aggregate.aggregate(&Aggregator::for_weight("weight3".to_string(), false)),
0.0
);
}
// Annotation tallies count keys only; an unknown key yields 0.0.
#[test]
fn test_aggregate_annotations() {
let mut aggregate = SimpleAggregate::new();
let metadata1 = SimpleMetadata::builder()
.annotations(bon::map! {"key1": "value1"})
.build();
let metadata2 = SimpleMetadata::builder()
.annotations(bon::map! {"key2": "value2"})
.build();
aggregate.add(&metadata1.as_ref());
aggregate.add(&metadata2.as_ref());
assert_eq!(aggregate.aggregate(&Aggregator::all_annotations()), 2.0);
assert_eq!(
aggregate.aggregate(&Aggregator::for_annotation("key1".to_string())),
1.0
);
assert_eq!(
aggregate.aggregate(&Aggregator::for_annotation("key3".to_string())),
0.0
);
}
// `fraction` divides the selected tally by the total over all fields of that kind.
#[test]
fn test_fraction() {
let mut aggregate = SimpleAggregate::new();
let metadata1 = SimpleMetadata::builder().kind("kind1".to_string()).build();
let metadata2 = SimpleMetadata::builder().kind("kind2".to_string()).build();
aggregate.add(&metadata1.as_ref());
aggregate.add(&metadata2.as_ref());
assert_eq!(
aggregate.fraction(&Aggregator::for_kind("kind1".to_string())),
0.5
);
assert_eq!(
aggregate.fraction(&Aggregator::for_kind("kind3".to_string())),
0.0
);
}
// `fractions` returns one scaled share per recorded kind, in key order.
#[test]
fn test_fractions() {
let mut aggregate = SimpleAggregate::new();
let metadata1 = SimpleMetadata::builder().kind("kind1".to_string()).build();
let metadata2 = SimpleMetadata::builder().kind("kind2".to_string()).build();
aggregate.add(&metadata1.as_ref());
aggregate.add(&metadata2.as_ref());
let fractions = aggregate.fractions(&Aggregator::all_kinds(), 1.0);
assert_eq!(fractions, vec![0.5, 0.5]);
}
// Bounds start as (v, v), widen on later updates, and drive linear interpolation.
#[test]
fn test_domains() {
let mut domains = Domains::new();
let mut map = BTreeMap::new();
map.insert("key1".to_string(), 10.0);
map.insert("key2".to_string(), 20.0);
domains.update_map(&map);
assert_eq!(domains.get(&"key1".to_string()), Some(&(10.0, 10.0)));
// A degenerate domain (lower == upper) interpolates to 1.0.
assert_eq!(domains.interpolate(&"key1".to_string(), 10.0), Some(1.0));
domains.update_key(&"key1".to_string(), 5.0);
assert_eq!(domains.get(&"key1".to_string()), Some(&(5.0, 10.0)));
assert_eq!(domains.interpolate(&"key1".to_string(), 7.5), Some(0.5));
}
}