1use std::borrow::Borrow;
15use std::collections::{BTreeMap, BTreeSet};
16use std::fmt::Debug;
17
18use crate::metadata::{AnnotationValue, Field, ReadMetadata, WeightValue};
19
/// Adds `amount` to the counter stored under `field`.
///
/// An existing counter is increased with saturating arithmetic, so it never
/// wraps past `usize::MAX`. When no counter exists yet, one is created with
/// `amount` as its initial value; the key is cloned only in that case.
fn increment_count<F: Clone + Ord>(map: &mut BTreeMap<F, usize>, field: &F, amount: usize) {
    if let Some(count) = map.get_mut(field) {
        *count = count.saturating_add(amount);
        return;
    }

    map.insert(field.clone(), amount);
}
30
/// Subtracts `amount` from the counter stored under `field`.
///
/// The subtraction saturates at zero. A key without a counter is ignored, and
/// a counter that reaches zero stays in the map.
fn decrement_count<F: Ord>(map: &mut BTreeMap<F, usize>, field: &F, amount: usize) {
    let Some(count) = map.get_mut(field) else {
        return;
    };

    *count = count.saturating_sub(amount);
}
36
37#[derive(Clone, Debug, Default, PartialEq)]
39#[cfg_attr(
40 feature = "serde",
41 derive(serde::Serialize, serde::Deserialize),
42 serde(rename_all = "camelCase")
43)]
44#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
45#[cfg_attr(feature = "reactive", derive(reactive_stores::Store))]
46pub enum Aggregator<N, K, L, W, A>
47where
48 N: Field,
49 K: Field,
50 L: Field,
51 W: Field,
52 A: Field,
53{
54 #[default]
56 Binary,
57 Names(Option<BTreeSet<N>>),
59 Kinds(Option<BTreeSet<K>>),
61 Labels(Option<BTreeSet<L>>),
63 Weights {
65 fields: Option<BTreeSet<W>>,
66 absolute: bool,
67 },
68 Annotations(Option<BTreeSet<A>>),
70}
71impl<N, K, L, W, A> Aggregator<N, K, L, W, A>
72where
73 N: Field,
74 K: Field,
75 L: Field,
76 W: Field,
77 A: Field,
78{
79 pub fn all_names() -> Self {
81 Self::Names(None)
82 }
83
84 pub fn for_name(value: N) -> Self {
86 Self::Names(Some(BTreeSet::from([value])))
87 }
88
89 pub fn all_kinds() -> Self {
91 Self::Kinds(None)
92 }
93
94 pub fn for_kind(value: K) -> Self {
96 Self::Kinds(Some(BTreeSet::from([value])))
97 }
98
99 pub fn all_labels() -> Self {
101 Self::Labels(None)
102 }
103 pub fn for_label(value: L) -> Self {
105 Self::Labels(Some(BTreeSet::from([value])))
106 }
107
108 pub fn all_weights(absolute: bool) -> Self {
110 Self::Weights {
111 fields: None,
112 absolute,
113 }
114 }
115
116 pub fn for_weight(value: W, absolute: bool) -> Self {
118 Self::Weights {
119 fields: Some(BTreeSet::from([value])),
120 absolute,
121 }
122 }
123
124 pub fn all_annotations() -> Self {
126 Self::Annotations(None)
127 }
128
129 pub fn for_annotation(value: A) -> Self {
131 Self::Annotations(Some(BTreeSet::from([value])))
132 }
133
134 pub fn as_all(&self) -> Self {
136 match *self {
137 Self::Binary => Self::Binary,
138 Self::Names(_) => Self::Names(None),
139 Self::Kinds(_) => Self::Kinds(None),
140 Self::Labels(_) => Self::Labels(None),
141 Self::Weights {
142 fields: _,
143 absolute,
144 } => Self::Weights {
145 fields: None,
146 absolute,
147 },
148 Self::Annotations(_) => Self::Annotations(None),
149 }
150 }
151}
152
153#[derive(Clone, Debug, PartialEq, bon::Builder)]
155#[cfg_attr(
156 feature = "serde",
157 derive(serde::Serialize, serde::Deserialize),
158 serde(default, rename_all = "camelCase")
159)]
160#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
161#[cfg_attr(feature = "reactive", derive(reactive_stores::Store))]
162pub struct Aggregate<N, K, L, W, WV, A>
163where
164 N: Field,
165 K: Field,
166 L: Field,
167 W: Field,
168 WV: WeightValue,
169 A: Field,
170{
171 #[builder(default)]
173 pub items: usize,
174
175 #[builder(default=BTreeMap::new())]
177 pub names: BTreeMap<N, usize>,
178
179 #[builder(default=BTreeMap::new())]
181 pub kinds: BTreeMap<K, usize>,
182
183 #[builder(default=BTreeMap::new())]
185 pub labels: BTreeMap<L, usize>,
186
187 #[builder(default=BTreeMap::new())]
189 pub weights: BTreeMap<W, WV>,
190
191 #[builder(default=BTreeMap::new())]
193 pub annotations: BTreeMap<A, usize>,
194}
195
196impl<N, K, L, W, WV, A> Default for Aggregate<N, K, L, W, WV, A>
197where
198 N: Field,
199 K: Field,
200 L: Field,
201 W: Field,
202 WV: WeightValue,
203 A: Field,
204{
205 fn default() -> Self {
206 Self {
207 items: 0,
208 names: BTreeMap::new(),
209 kinds: BTreeMap::new(),
210 labels: BTreeMap::new(),
211 weights: BTreeMap::new(),
212 annotations: BTreeMap::new(),
213 }
214 }
215}
216
217impl<N, K, L, W, WV, A> Aggregate<N, K, L, W, WV, A>
218where
219 N: Field,
220 K: Field,
221 L: Field,
222 W: Field,
223 WV: WeightValue,
224 A: Field,
225{
226 pub fn new() -> Self {
228 Self::default()
229 }
230
231 pub fn add<'a, M: ReadMetadata<'a, N, K, L, W, WV, A, AV>, AV: 'a + AnnotationValue>(
233 &'a mut self,
234 item: &'a M,
235 ) {
236 self.items += 1;
237
238 if let Some(name) = item.name() {
239 increment_count(&mut self.names, name, 1);
240 }
241
242 if let Some(kind) = item.kind() {
243 increment_count(&mut self.kinds, kind, 1);
244 }
245
246 item.labels()
247 .for_each(|field| increment_count(&mut self.labels, field, 1));
248
249 item.weights()
250 .for_each(|(field, &value)| match self.weights.get_mut(field) {
251 Some(v) => {
252 v.add_assign(value);
253 }
254 None => {
255 self.weights.insert(field.to_owned(), value);
256 }
257 });
258
259 item.annotations()
260 .for_each(|(field, _)| increment_count(&mut self.annotations, field, 1));
261 }
262
263 pub fn subtract<'a, M: ReadMetadata<'a, N, K, L, W, WV, A, AV>, AV: 'a + AnnotationValue>(
265 &'a mut self,
266 item: &'a M,
267 ) {
268 self.items = self.items.saturating_sub(1);
269
270 if let Some(field) = item.name() {
271 decrement_count(&mut self.names, field, 1);
272 }
273
274 if let Some(field) = item.kind() {
275 decrement_count(&mut self.kinds, field, 1);
276 }
277
278 item.labels().for_each(|field| {
279 decrement_count(&mut self.labels, field, 1);
280 });
281
282 item.weights()
283 .for_each(|(field, &value)| match self.weights.get_mut(field) {
284 Some(v) => {
285 v.sub_assign(value);
286 }
287 None => {
288 self.weights.insert(field.clone(), -value);
289 }
290 });
291
292 item.annotations()
293 .for_each(|(field, _)| decrement_count(&mut self.annotations, field, 1));
294 }
295
296 pub fn extend(&mut self, other: Self) {
298 let Self {
299 items,
300 names,
301 kinds,
302 labels,
303 weights,
304 annotations,
305 } = other;
306
307 self.items += items;
308
309 names
310 .into_iter()
311 .for_each(|(field, amount)| increment_count(&mut self.names, &field, amount));
312
313 kinds
314 .into_iter()
315 .for_each(|(field, amount)| increment_count(&mut self.kinds, &field, amount));
316
317 labels
318 .into_iter()
319 .for_each(|(field, amount)| increment_count(&mut self.labels, &field, amount));
320
321 weights.into_iter().for_each(|(field, value)| {
322 match self.weights.get_mut(&field) {
323 Some(v) => {
324 v.sub_assign(value);
325 }
326 None => {
327 self.weights.insert(field.to_owned(), value);
328 }
329 };
330 });
331
332 annotations
333 .into_iter()
334 .for_each(|(field, amount)| increment_count(&mut self.annotations, &field, amount));
335 }
336
337 pub fn aggregate(&self, aggregator: &Aggregator<N, K, L, W, A>) -> f64 {
339 match aggregator {
340 Aggregator::Binary => {
341 if self.items > 0 {
342 1.0
343 } else {
344 0.0
345 }
346 }
347 Aggregator::Names(fields) => match fields {
348 None => self.names.values().sum::<usize>() as f64,
349 Some(fields) => fields
350 .iter()
351 .filter_map(|field| self.names.get(field))
352 .sum::<usize>() as f64,
353 },
354 Aggregator::Kinds(fields) => match fields {
355 None => self.kinds.values().sum::<usize>() as f64,
356 Some(fields) => fields
357 .iter()
358 .filter_map(|field| self.kinds.get(field))
359 .sum::<usize>() as f64,
360 },
361 Aggregator::Labels(fields) => match fields {
362 None => self.labels.values().sum::<usize>() as f64,
363 Some(fields) => fields
364 .iter()
365 .filter_map(|field| self.labels.get(field))
366 .sum::<usize>() as f64,
367 },
368 Aggregator::Weights { fields, absolute } => match fields {
369 Some(fields) => {
370 let values = fields.iter().filter_map(|field| self.weights.get(field));
371 if *absolute {
372 values.map(|v| v.as_().abs()).sum()
373 } else {
374 values.map(|v| v.as_()).sum()
375 }
376 }
377 None => {
378 let values = self.weights.values();
379 if *absolute {
380 values.map(|v| v.as_().abs()).sum()
381 } else {
382 values.map(|v| v.as_()).sum()
383 }
384 }
385 },
386 Aggregator::Annotations(fields) => match fields {
387 None => self.annotations.values().sum::<usize>() as f64,
388 Some(fields) => fields
389 .iter()
390 .filter_map(|field| self.annotations.get(field))
391 .sum::<usize>() as f64,
392 },
393 }
394 }
395
396 pub fn fraction(&self, aggregator: &Aggregator<N, K, L, W, A>) -> f64 {
398 let total = self.aggregate(&aggregator.as_all());
399 if total == 0.0 {
400 0.0
401 } else {
402 self.aggregate(aggregator) / total
403 }
404 }
405
406 pub fn fractions(&self, aggregator: &Aggregator<N, K, L, W, A>, factor: f64) -> Vec<f64> {
408 let sum = self.aggregate(&aggregator.as_all());
409 let factor = { if sum == 0.0 { 1.0 } else { factor / sum } };
410 match aggregator {
411 Aggregator::Binary => vec![factor],
412 Aggregator::Names(None) => self.names.values().map(|&v| factor * v as f64).collect(),
413 Aggregator::Names(Some(fields)) => fields
414 .iter()
415 .filter_map(|field| self.names.get(field))
416 .map(|&v| factor * v as f64)
417 .collect(),
418 Aggregator::Kinds(None) => self.kinds.values().map(|&v| factor * v as f64).collect(),
419 Aggregator::Kinds(Some(fields)) => fields
420 .iter()
421 .filter_map(|field| self.kinds.get(field))
422 .map(|&v| factor * v as f64)
423 .collect(),
424 Aggregator::Labels(None) => self.labels.values().map(|&v| factor * v as f64).collect(),
425 Aggregator::Labels(Some(fields)) => fields
426 .iter()
427 .filter_map(|field| self.labels.get(field))
428 .map(|&v| factor * v as f64)
429 .collect(),
430 Aggregator::Weights {
431 fields: None,
432 absolute,
433 } => self
434 .weights
435 .values()
436 .map(|&v| {
437 factor * {
438 let value = v.as_();
439 if *absolute { value.abs() } else { value }
440 }
441 })
442 .collect(),
443 Aggregator::Weights {
444 fields: Some(fields),
445 absolute,
446 } => fields
447 .iter()
448 .filter_map(|field| self.weights.get(field))
449 .map(|&v| {
450 factor * {
451 let value = v.as_();
452 if *absolute { value.abs() } else { value }
453 }
454 })
455 .collect(),
456 Aggregator::Annotations(None) => self
457 .annotations
458 .values()
459 .map(|&v| factor * v as f64)
460 .collect(),
461 Aggregator::Annotations(Some(fields)) => fields
462 .iter()
463 .filter_map(|field| self.annotations.get(field))
464 .map(|&v| factor * v as f64)
465 .collect(),
466 }
467 }
468}
469
470#[derive(Clone, Debug, PartialEq)]
472#[cfg_attr(
473 feature = "serde",
474 derive(serde::Serialize, serde::Deserialize),
475 serde(default)
476)]
477pub struct Domains<F: Field> {
478 pub bounds: BTreeMap<F, (f64, f64)>,
479}
480impl<F: Field> Default for Domains<F> {
481 fn default() -> Self {
482 Self {
483 bounds: BTreeMap::new(),
484 }
485 }
486}
487
488impl<F: Field> Domains<F> {
489 pub fn new() -> Self {
491 Self::default()
492 }
493
494 pub fn update_map(&mut self, values: &BTreeMap<F, f64>) {
496 self.update_iter(values.iter());
497 }
498
499 pub fn update_iter<'a, I: Iterator<Item = (&'a F, &'a f64)>>(&'a mut self, iter: I) {
501 iter.for_each(|(key, &value)| self.update_key(key, value))
502 }
503
504 pub fn update_key(&mut self, key: &F, value: f64) {
506 if let Some(entry) = self.bounds.get_mut(key) {
507 entry.0 = entry.0.min(value);
508 entry.1 = entry.1.max(value);
509 } else {
510 self.bounds.insert(key.to_owned(), (value, value));
511 }
512 }
513
514 pub fn get<Q: Ord>(&self, key: &Q) -> Option<&(f64, f64)>
516 where
517 F: Borrow<Q>,
518 {
519 self.bounds.get(key)
520 }
521
522 pub fn interpolate<Q: Ord>(&self, key: &Q, value: f64) -> Option<f64>
525 where
526 F: Borrow<Q>,
527 {
528 self.get(key).map(|&(lower, upper)| {
529 if lower == upper {
530 1.0
531 } else {
532 (value - lower) / (upper - lower)
533 }
534 })
535 }
536}
537
#[cfg(test)]
pub mod tests {
    use super::*;
    use crate::metadata::{Metadata, SimpleMetadata};

    /// Aggregate keyed by `String` fields with `f64` weights, used throughout these tests.
    pub type SimpleAggregate = Aggregate<String, String, String, String, f64, String>;

    /// An item without any fields still counts towards the binary aggregate
    /// until it is subtracted again.
    #[test]
    fn test_aggregate_binary() {
        let mut totals: SimpleAggregate = Aggregate::new();
        let item: SimpleMetadata = Metadata::builder().build();

        totals.add(&item.as_ref());
        assert_eq!(totals.aggregate(&Aggregator::Binary), 1.0);

        totals.subtract(&item.as_ref());
        assert_eq!(totals.aggregate(&Aggregator::Binary), 0.0);
    }

    /// Names are counted per distinct name; unknown names aggregate to zero.
    #[test]
    fn test_aggregate_names() {
        let mut totals: SimpleAggregate = Aggregate::new();
        let first = SimpleMetadata::builder().name("test1".to_string()).build();
        let second = SimpleMetadata::builder().name("test2".to_string()).build();
        totals.add(&first.as_ref());
        totals.add(&second.as_ref());

        assert_eq!(totals.aggregate(&Aggregator::all_names()), 2.0);
        for (name, expected) in [("test1", 1.0), ("test3", 0.0)] {
            assert_eq!(
                totals.aggregate(&Aggregator::for_name(name.to_string())),
                expected
            );
        }
    }

    /// Kinds are counted per distinct kind; unknown kinds aggregate to zero.
    #[test]
    fn test_aggregate_kinds() {
        let mut totals: SimpleAggregate = Aggregate::new();
        let first = SimpleMetadata::builder().kind("kind1".to_string()).build();
        let second = SimpleMetadata::builder().kind("kind2".to_string()).build();
        totals.add(&first.as_ref());
        totals.add(&second.as_ref());

        assert_eq!(totals.aggregate(&Aggregator::all_kinds()), 2.0);
        for (kind, expected) in [("kind1", 1.0), ("kind3", 0.0)] {
            assert_eq!(
                totals.aggregate(&Aggregator::for_kind(kind.to_string())),
                expected
            );
        }
    }

    /// Labels are counted per distinct label; unknown labels aggregate to zero.
    #[test]
    fn test_aggregate_labels() {
        let mut totals = SimpleAggregate::new();
        let first = SimpleMetadata::builder()
            .labels(bon::set!["label1".to_string()])
            .build();
        let second = SimpleMetadata::builder()
            .labels(bon::set!["label2".to_string()])
            .build();
        totals.add(&first.as_ref());
        totals.add(&second.as_ref());

        assert_eq!(totals.aggregate(&Aggregator::all_labels()), 2.0);
        for (label, expected) in [("label1", 1.0), ("label3", 0.0)] {
            assert_eq!(
                totals.aggregate(&Aggregator::for_label(label.to_string())),
                expected
            );
        }
    }

    /// Weights are summed per weight field; unknown fields aggregate to zero.
    #[test]
    fn test_aggregate_weights() {
        let mut totals = SimpleAggregate::new();
        let first = SimpleMetadata::builder()
            .weights(bon::map! {"weight1": 10.0})
            .build();
        let second = SimpleMetadata::builder()
            .weights(bon::map! {"weight2": 20.0})
            .build();
        totals.add(&first.as_ref());
        totals.add(&second.as_ref());

        assert_eq!(totals.aggregate(&Aggregator::all_weights(false)), 30.0);
        for (field, expected) in [("weight1", 10.0), ("weight3", 0.0)] {
            assert_eq!(
                totals.aggregate(&Aggregator::for_weight(field.to_string(), false)),
                expected
            );
        }
    }

    /// Annotations are counted per key; unknown keys aggregate to zero.
    #[test]
    fn test_aggregate_annotations() {
        let mut totals = SimpleAggregate::new();
        let first = SimpleMetadata::builder()
            .annotations(bon::map! {"key1": "value1"})
            .build();
        let second = SimpleMetadata::builder()
            .annotations(bon::map! {"key2": "value2"})
            .build();
        totals.add(&first.as_ref());
        totals.add(&second.as_ref());

        assert_eq!(totals.aggregate(&Aggregator::all_annotations()), 2.0);
        for (key, expected) in [("key1", 1.0), ("key3", 0.0)] {
            assert_eq!(
                totals.aggregate(&Aggregator::for_annotation(key.to_string())),
                expected
            );
        }
    }

    /// A single kind out of two equally frequent kinds makes up half of the total.
    #[test]
    fn test_fraction() {
        let mut totals = SimpleAggregate::new();
        let first = SimpleMetadata::builder().kind("kind1".to_string()).build();
        let second = SimpleMetadata::builder().kind("kind2".to_string()).build();
        totals.add(&first.as_ref());
        totals.add(&second.as_ref());

        for (kind, expected) in [("kind1", 0.5), ("kind3", 0.0)] {
            assert_eq!(
                totals.fraction(&Aggregator::for_kind(kind.to_string())),
                expected
            );
        }
    }

    /// Two equally frequent kinds each receive half of the factor.
    #[test]
    fn test_fractions() {
        let mut totals = SimpleAggregate::new();
        let first = SimpleMetadata::builder().kind("kind1".to_string()).build();
        let second = SimpleMetadata::builder().kind("kind2".to_string()).build();
        totals.add(&first.as_ref());
        totals.add(&second.as_ref());

        let shares = totals.fractions(&Aggregator::all_kinds(), 1.0);
        assert_eq!(shares, vec![0.5, 0.5]);
    }

    /// Bounds start collapsed on the first value and widen as new values arrive.
    #[test]
    fn test_domains() {
        let mut domains = Domains::new();
        let initial = BTreeMap::from([("key1".to_string(), 10.0), ("key2".to_string(), 20.0)]);
        domains.update_map(&initial);

        let key = "key1".to_string();
        assert_eq!(domains.get(&key), Some(&(10.0, 10.0)));
        assert_eq!(domains.interpolate(&key, 10.0), Some(1.0));

        domains.update_key(&key, 5.0);
        assert_eq!(domains.get(&key), Some(&(5.0, 10.0)));
        assert_eq!(domains.interpolate(&key, 7.5), Some(0.5));
    }
}