1use std::collections::{BTreeMap, BTreeSet};
14use std::fmt::Debug;
15use std::marker::PhantomData;
16
17use num_traits::AsPrimitive;
18use uuid::Uuid;
19
20use crate::metadata::{AnnotationValue, Field, ReadMetadata, WeightValue};
21
22pub trait MetadataFilter<IF, NF, KF, LF, WF, AF, N, K, L, W, WV, A, AV>
24where
25 IF: FieldsFilter<Uuid>,
26 NF: FieldsFilter<N>,
27 KF: FieldsFilter<K>,
28 LF: FieldsFilter<L>,
29 WF: FieldValuesFilter<W, WV>,
30 AF: FieldValuesFilter<A, AV>,
31 N: Field,
32 K: Field,
33 L: Field,
34 W: Field,
35 WV: WeightValue,
36 A: Field,
37 AV: AnnotationValue,
38{
39 fn matches_metadata<'a, M: ReadMetadata<'a, N, K, L, W, WV, A, AV>>(
41 &'a self,
42 meta: &'a M,
43 ) -> bool
44 where
45 N: 'a,
46 K: 'a,
47 L: 'a,
48 W: 'a,
49 WV: 'a,
50 A: 'a,
51 AV: 'a,
52 LF: 'a,
53 WF: 'a,
54 AF: 'a,
55 {
56 self.id_filters().all(|f| f.match_one(meta.id()))
57 && self
58 .name_filters()
59 .all(|f| f.matches(meta.name().into_iter()))
60 && self
61 .kind_filters()
62 .all(|f| f.matches(meta.kind().into_iter()))
63 && self.label_filters().all(|f| f.matches(meta.labels()))
64 && self.weight_filters().all(|f| f.matches(meta.weights()))
65 && self
66 .annotation_filters()
67 .all(|f| f.matches(meta.annotations()))
68 }
69
70 fn id_filters<'a>(&'a self) -> impl Iterator<Item = &'a IF>
72 where
73 IF: 'a;
74
75 fn name_filters<'a>(&'a self) -> impl Iterator<Item = &'a NF>
77 where
78 NF: 'a;
79
80 fn kind_filters<'a>(&'a self) -> impl Iterator<Item = &'a KF>
82 where
83 KF: 'a;
84
85 fn label_filters<'a>(&'a self) -> impl Iterator<Item = &'a LF>
87 where
88 LF: 'a;
89
90 fn weight_filters<'a>(&'a self) -> impl Iterator<Item = &'a WF>
92 where
93 WF: 'a;
94
95 fn annotation_filters<'a>(&'a self) -> impl Iterator<Item = &'a AF>
97 where
98 AF: 'a;
99}
100
101pub trait FieldsFilter<F: Field> {
103 fn matches<'a, Fs: Iterator<Item = &'a F>>(&'a self, fields: Fs) -> bool
105 where
106 F: 'a,
107 {
108 match self.matching_mode() {
109 MatchingMode::All => self.match_all(fields),
110 MatchingMode::Any => self.match_any(fields),
111 MatchingMode::None => self.match_none(fields),
112 }
113 }
114
115 fn matching_mode(&self) -> &MatchingMode;
117
118 fn match_one(&self, field: &F) -> bool;
120
121 fn match_all<'a, Fs: Iterator<Item = &'a F>>(&'a self, fields: Fs) -> bool
123 where
124 F: 'a,
125 {
126 fields
127 .fold(BTreeSet::new(), |mut matched, f| {
128 if self.match_one(f) {
129 matched.insert(f);
130 }
131 matched
132 })
133 .len()
134 == self.max_matches()
135 }
136
137 fn max_matches(&self) -> usize;
139
140 fn match_any<'a, Fs: Iterator<Item = &'a F>>(&'a self, mut fields: Fs) -> bool
142 where
143 F: 'a,
144 {
145 fields.any(|f| self.match_one(f))
146 }
147
148 fn match_none<'a, Fs: Iterator<Item = &'a F>>(&'a self, fields: Fs) -> bool
150 where
151 F: 'a,
152 {
153 !self.match_any(fields)
154 }
155}
156
157pub trait FieldValuesFilter<F: Field, V> {
161 fn matches<'a, I: Iterator<Item = (&'a F, &'a V)>>(&'a self, items: I) -> bool
163 where
164 F: 'a,
165 V: 'a,
166 {
167 match self.matching_mode() {
168 MatchingMode::All => self.match_all(items),
169 MatchingMode::Any => self.match_any(items),
170 MatchingMode::None => self.match_none(items),
171 }
172 }
173
174 fn matching_mode(&self) -> &MatchingMode;
176
177 fn match_one(&self, field: &F, value: &V) -> bool;
179
180 fn match_all<'a, I: Iterator<Item = (&'a F, &'a V)>>(&'a self, items: I) -> bool
182 where
183 F: 'a,
184 V: 'a,
185 {
186 items
187 .fold(BTreeSet::new(), |mut matched, (f, v)| {
188 if self.match_one(f, v) {
189 matched.insert(f);
190 }
191 matched
192 })
193 .len()
194 == self.max_matches()
195 }
196
197 fn max_matches(&self) -> usize;
199
200 fn match_any<'a, I: Iterator<Item = (&'a F, &'a V)>>(&'a self, mut items: I) -> bool
202 where
203 F: 'a,
204 V: 'a,
205 {
206 items.any(|(f, v)| self.match_one(f, v))
207 }
208
209 fn match_none<'a, I: Iterator<Item = (&'a F, &'a V)>>(&'a self, items: I) -> bool
211 where
212 F: 'a,
213 V: 'a,
214 {
215 !self.match_any(items)
216 }
217}
218
/// A predicate over a single value of type `V`.
pub trait ValueFilter<V> {
    /// Returns `true` when `value` passes the filter.
    fn matches<'a>(&'a self, value: &'a V) -> bool;
}

/// The unit type is the "accept everything" filter: it imposes no condition
/// on the value.
impl<T> ValueFilter<T> for () {
    fn matches<'a>(&'a self, _value: &'a T) -> bool {
        // Nothing to inspect, so every value is accepted.
        true
    }
}
229
/// Concrete [`MetadataFilter`] that stores one `Vec` of B-tree based filters
/// per metadata aspect.
///
/// Every list defaults to empty (both through the builder and through
/// `Default`). An empty list places no restriction on the corresponding aspect,
/// because `matches_metadata` requires *all* listed filters to pass.
#[derive(Clone, Debug, PartialEq, bon::Builder)]
#[cfg_attr(
    feature = "serde",
    derive(serde::Serialize, serde::Deserialize),
    serde(default, rename_all = "camelCase")
)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[cfg_attr(feature = "reactive", derive(reactive_stores::Store))]
pub struct BTreeMetadataFilter<N, K, L, W, WV, A, AV>
where
    N: Field,
    K: Field,
    L: Field,
    W: Field,
    WV: WeightValue,
    A: Field,
    AV: AnnotationValue,
{
    /// Filters applied to the record id.
    #[builder(default)]
    pub id_filters: Vec<BTreeFieldsFilter<Uuid>>,

    /// Filters applied to the record name.
    #[builder(default)]
    pub name_filters: Vec<BTreeFieldsFilter<N>>,

    /// Filters applied to the record kind.
    #[builder(default)]
    pub kind_filters: Vec<BTreeFieldsFilter<K>>,

    /// Filters applied to the record labels.
    #[builder(default)]
    pub label_filters: Vec<BTreeFieldsFilter<L>>,

    /// Filters applied to `(weight field, weight value)` pairs; each value must
    /// lie within the [`Domain1D`] configured for its field.
    #[builder(default)]
    pub weight_filters: Vec<BTreeFieldValuesFilter<W, Domain1D<WV>, WV>>,

    /// Filters applied to `(annotation field, annotation value)` pairs. The
    /// per-value filter is `()`, so only the presence of the field is checked.
    #[builder(default)]
    pub annotation_filters: Vec<BTreeFieldValuesFilter<A, (), AV>>,
}
268impl<N, K, L, W, WV, A, AV> Default for BTreeMetadataFilter<N, K, L, W, WV, A, AV>
269where
270 N: Field,
271 K: Field,
272 L: Field,
273 W: Field,
274 WV: WeightValue,
275 A: Field,
276 AV: AnnotationValue,
277{
278 fn default() -> Self {
279 Self {
280 id_filters: Vec::new(),
281 name_filters: Vec::new(),
282 kind_filters: Vec::new(),
283 label_filters: Vec::new(),
284 weight_filters: Vec::new(),
285 annotation_filters: Vec::new(),
286 }
287 }
288}
289impl<N, K, L, W, WV, A, AV>
290 MetadataFilter<
291 BTreeFieldsFilter<Uuid>,
292 BTreeFieldsFilter<N>,
293 BTreeFieldsFilter<K>,
294 BTreeFieldsFilter<L>,
295 BTreeFieldValuesFilter<W, Domain1D<WV>, WV>,
296 BTreeFieldValuesFilter<A, (), AV>,
297 N,
298 K,
299 L,
300 W,
301 WV,
302 A,
303 AV,
304 > for BTreeMetadataFilter<N, K, L, W, WV, A, AV>
305where
306 N: Field,
307 K: Field,
308 L: Field,
309 W: Field,
310 WV: WeightValue,
311 A: Field,
312 AV: AnnotationValue,
313{
314 fn id_filters<'a>(&'a self) -> impl Iterator<Item = &'a BTreeFieldsFilter<Uuid>>
315 where
316 BTreeFieldsFilter<Uuid>: 'a,
317 {
318 self.id_filters.iter()
319 }
320 fn name_filters<'a>(&'a self) -> impl Iterator<Item = &'a BTreeFieldsFilter<N>>
321 where
322 BTreeFieldsFilter<N>: 'a,
323 {
324 self.name_filters.iter()
325 }
326 fn kind_filters<'a>(&'a self) -> impl Iterator<Item = &'a BTreeFieldsFilter<K>>
327 where
328 BTreeFieldsFilter<K>: 'a,
329 {
330 self.kind_filters.iter()
331 }
332 fn label_filters<'a>(&'a self) -> impl Iterator<Item = &'a BTreeFieldsFilter<L>>
333 where
334 BTreeFieldsFilter<L>: 'a,
335 {
336 self.label_filters.iter()
337 }
338 fn weight_filters<'a>(
339 &'a self,
340 ) -> impl Iterator<Item = &'a BTreeFieldValuesFilter<W, Domain1D<WV>, WV>>
341 where
342 BTreeFieldValuesFilter<W, Domain1D<WV>, WV>: 'a,
343 {
344 self.weight_filters.iter()
345 }
346 fn annotation_filters<'a>(
347 &'a self,
348 ) -> impl Iterator<Item = &'a BTreeFieldValuesFilter<A, (), AV>>
349 where
350 BTreeFieldsFilter<A>: 'a,
351 {
352 self.annotation_filters.iter()
353 }
354}
355
/// A [`FieldsFilter`] backed by a [`BTreeSet`] of the fields it looks for.
#[derive(Clone, Debug, PartialEq, bon::Builder)]
#[cfg_attr(
    feature = "serde",
    derive(serde::Serialize, serde::Deserialize),
    serde(default, rename_all = "camelCase")
)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[cfg_attr(feature = "reactive", derive(reactive_stores::Store))]
pub struct BTreeFieldsFilter<F>
where
    F: Field,
{
    /// The fields this filter accepts; a candidate field matches when it is
    /// contained in this set.
    #[builder(default)]
    fields: BTreeSet<F>,

    /// How the per-field results are combined (defaults to
    /// [`MatchingMode::Any`]).
    #[builder(default)]
    mode: MatchingMode,
}
377impl<F: Field> Default for BTreeFieldsFilter<F> {
378 fn default() -> Self {
379 Self {
380 fields: BTreeSet::<F>::new(),
381 mode: MatchingMode::Any,
382 }
383 }
384}
385impl<F: Field> FieldsFilter<F> for BTreeFieldsFilter<F> {
386 fn max_matches(&self) -> usize {
387 self.fields.len()
388 }
389
390 fn match_one(&self, field: &F) -> bool {
391 self.fields.contains(field)
392 }
393
394 fn matching_mode(&self) -> &MatchingMode {
395 &self.mode
396 }
397}
398
399#[derive(Clone, Debug, PartialEq, bon::Builder)]
400#[cfg_attr(
401 feature = "serde",
402 derive(serde::Serialize, serde::Deserialize),
403 serde(default, rename_all = "camelCase")
404)]
405#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
406#[cfg_attr(feature = "reactive", derive(reactive_stores::Store))]
407pub struct BTreeFieldValuesFilter<F, VF, V>
408where
409 F: Field,
410 VF: ValueFilter<V>,
411{
412 #[builder(default)]
414 pub fields: BTreeMap<F, VF>,
415
416 #[builder(default)]
418 pub mode: MatchingMode,
419
420 #[builder(default)]
422 _value: PhantomData<V>,
423}
424impl<F: Field, VF: ValueFilter<V>, V> Default for BTreeFieldValuesFilter<F, VF, V> {
425 fn default() -> Self {
426 Self {
427 fields: BTreeMap::new(),
428 _value: PhantomData,
429 mode: MatchingMode::Any,
430 }
431 }
432}
433impl<F: Field, V> From<BTreeFieldsFilter<F>> for BTreeFieldValuesFilter<F, (), V> {
434 fn from(value: BTreeFieldsFilter<F>) -> Self {
435 Self {
436 fields: value.fields.into_iter().map(|f| (f, ())).collect(),
437 _value: PhantomData,
438 mode: value.mode,
439 }
440 }
441}
442impl<F: Field, VF: ValueFilter<V>, V> FieldValuesFilter<F, V> for BTreeFieldValuesFilter<F, VF, V> {
443 fn matching_mode(&self) -> &MatchingMode {
444 &self.mode
445 }
446 fn match_one(&self, field: &F, value: &V) -> bool {
447 self.fields
448 .get(field)
449 .map(|vf| vf.matches(value))
450 .unwrap_or_default()
451 }
452 fn max_matches(&self) -> usize {
453 self.fields.len()
454 }
455}
456
/// A one-dimensional interval with optional, inclusive bounds, used as a
/// [`ValueFilter`] for weight values.
///
/// A missing bound leaves that side unbounded, so the default (no bounds)
/// accepts every value.
#[derive(Clone, Debug, PartialEq, bon::Builder)]
#[cfg_attr(
    feature = "serde",
    derive(serde::Serialize, serde::Deserialize),
    serde(default, rename_all = "camelCase")
)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[cfg_attr(feature = "reactive", derive(reactive_stores::Store))]
pub struct Domain1D<WV>
where
    WV: PartialEq + PartialOrd,
{
    /// Inclusive lower bound; `None` means unbounded below.
    lower: Option<WV>,

    /// Inclusive upper bound; `None` means unbounded above.
    upper: Option<WV>,
}
476impl<WV: PartialEq + PartialOrd> Default for Domain1D<WV> {
477 fn default() -> Self {
478 Domain1D {
479 lower: None,
480 upper: None,
481 }
482 }
483}
484impl<T: AsPrimitive<Q>, Q: 'static + Copy + PartialEq + PartialOrd> From<(T, T)> for Domain1D<Q> {
485 fn from(value: (T, T)) -> Self {
486 Domain1D::builder()
487 .lower(value.0.as_())
488 .upper(value.1.as_())
489 .build()
490 }
491}
492impl<WV: PartialEq + PartialOrd> ValueFilter<WV> for Domain1D<WV> {
493 fn matches<'a>(&'a self, value: &'a WV) -> bool {
494 match self {
495 Self {
496 lower: Some(lb),
497 upper: Some(ub),
498 } => value >= lb && value <= ub,
499 Self {
500 lower: Some(lb),
501 upper: None,
502 } => value >= lb,
503 Self {
504 lower: None,
505 upper: Some(ub),
506 } => value <= ub,
507 _ => true,
508 }
509 }
510}
511
/// How a filter combines the results of testing individual fields (or
/// `(field, value)` pairs).
#[derive(Clone, Copy, Debug, Default, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(serde::Serialize, serde::Deserialize),
    serde(rename_all = "camelCase")
)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[cfg_attr(feature = "reactive", derive(reactive_stores::Store))]
pub enum MatchingMode {
    /// At least one inspected item must be accepted (the default).
    #[default]
    Any,
    /// The number of distinct accepted fields must equal the filter's
    /// `max_matches`, i.e. every configured field must be hit.
    All,
    /// No inspected item may be accepted.
    None,
}
527
// Unit tests for the filter types. Records such as `TestMeta::foo()` and the
// enums `N`, `K`, `L`, `W`, `A` are fixtures from `crate::metadata::tests`
// (not visible here); expected outcomes depend on how they are defined there.
#[cfg(test)]
pub mod tests {

    use super::*;
    use crate::metadata::tests::{A, K, L, N, TestMeta, W};

    // Integer bounds are converted to the target float type.
    #[test]
    fn test_domain1d_from() {
        let dom = Domain1D::<f64>::from((1234, 4678));
        assert_eq!(dom.lower, Some(1234.0));
        assert_eq!(dom.upper, Some(4678.0));
    }

    // `Any`: at least one of the listed names must be present.
    #[test]
    fn test_btree_fields_filter() {
        let filter = BTreeFieldsFilter {
            fields: BTreeSet::from_iter([N::Foo, N::Bar]),
            mode: MatchingMode::Any,
        };

        assert!(filter.matches(vec![&N::Foo].into_iter()));
        assert!(filter.matches(vec![&N::Bar].into_iter()));
        assert!(filter.matches(vec![&N::Foo, &N::Baz].into_iter()));
        assert!(!filter.matches(vec![&N::Baz].into_iter()));
    }

    // `All`: every listed name must be present.
    #[test]
    fn test_btree_fields_filter_all_mode() {
        let filter = BTreeFieldsFilter {
            fields: BTreeSet::from_iter([N::Foo, N::Bar]),
            mode: MatchingMode::All,
        };

        assert!(filter.matches(vec![&N::Foo, &N::Bar].into_iter()));
        assert!(!filter.matches(vec![&N::Foo].into_iter()));
        assert!(!filter.matches(vec![&N::Bar].into_iter()));
    }

    // `None`: no listed name may be present.
    #[test]
    fn test_btree_fields_filter_none_mode() {
        let filter = BTreeFieldsFilter {
            fields: BTreeSet::from_iter([N::Foo, N::Bar]),
            mode: MatchingMode::None,
        };

        assert!(filter.matches(vec![&N::Baz].into_iter()));
        assert!(!filter.matches(vec![&N::Foo].into_iter()));
        assert!(!filter.matches(vec![&N::Bar].into_iter()));
        assert!(!filter.matches(vec![&N::Foo, &N::Bar].into_iter()));
    }

    // `Any` with a value range: the weight must lie inside [1.0, 3.0].
    #[test]
    fn test_btree_field_values_filter() {
        let domain = Domain1D {
            lower: Some(1.0),
            upper: Some(3.0),
        };

        let filter = BTreeFieldValuesFilter {
            fields: BTreeMap::from_iter([(W::A, domain)]),
            _value: PhantomData,
            mode: MatchingMode::Any,
        };

        assert!(filter.matches(vec![(&W::A, &2.0)].into_iter()));
        assert!(!filter.matches(vec![(&W::A, &0.5)].into_iter()));
        assert!(!filter.matches(vec![(&W::A, &3.5)].into_iter()));
    }

    // `All`: both configured weights must be present and in range.
    #[test]
    fn test_btree_field_values_filter_all_mode() {
        let domain1 = Domain1D {
            lower: Some(1.0),
            upper: Some(3.0),
        };

        let domain2 = Domain1D {
            lower: Some(2.0),
            upper: Some(4.0),
        };

        let filter = BTreeFieldValuesFilter {
            fields: BTreeMap::from_iter([(W::A, domain1), (W::B, domain2)]),
            _value: PhantomData,
            mode: MatchingMode::All,
        };

        assert!(filter.matches(vec![(&W::A, &2.0), (&W::B, &3.0)].into_iter()));
        assert!(!filter.matches(vec![(&W::A, &2.0)].into_iter()));
        assert!(!filter.matches(vec![(&W::B, &3.0)].into_iter()));
    }

    // `None`: a pair for an unconfigured field never matches, so it passes.
    #[test]
    fn test_btree_field_values_filter_none_mode() {
        let domain = Domain1D {
            lower: Some(1.0),
            upper: Some(3.0),
        };

        let filter = BTreeFieldValuesFilter {
            fields: BTreeMap::from_iter([(W::A, domain)]),
            _value: PhantomData,
            mode: MatchingMode::None,
        };

        assert!(filter.matches(vec![(&W::B, &2.0)].into_iter()));
        assert!(!filter.matches(vec![(&W::A, &2.0)].into_iter()));
    }

    // End-to-end check of `matches_metadata` across all aspects. Which fixture
    // records pass depends on the `TestMeta` definitions; `foo` is expected
    // to be the only one satisfying every filter.
    #[test]
    fn test_btree_metadata_filter() {
        let name_filter = BTreeFieldsFilter {
            fields: BTreeSet::from_iter([N::Foo, N::Bar]),
            mode: MatchingMode::Any,
        };

        let kind_filter = BTreeFieldsFilter {
            fields: BTreeSet::from_iter([K::A, K::B]),
            mode: MatchingMode::Any,
        };

        let label_filter = BTreeFieldsFilter {
            fields: BTreeSet::from_iter([L::A, L::B]),
            mode: MatchingMode::Any,
        };

        let weight_domain = Domain1D {
            lower: Some(1.0),
            upper: Some(3.0),
        };

        let weight_filter = BTreeFieldValuesFilter {
            fields: BTreeMap::from_iter([(W::A, weight_domain)]),
            _value: PhantomData,
            mode: MatchingMode::Any,
        };

        let annotation_filter = BTreeFieldValuesFilter::builder()
            .fields(bon::map! { A::A: ()})
            .mode(MatchingMode::Any)
            .build();

        let metadata_filter = BTreeMetadataFilter {
            id_filters: vec![],
            name_filters: vec![name_filter],
            kind_filters: vec![kind_filter],
            label_filters: vec![label_filter],
            weight_filters: vec![weight_filter],
            annotation_filters: vec![annotation_filter],
        };

        let foo = TestMeta::foo();
        let bar = TestMeta::bar();
        let baz = TestMeta::baz();
        let quux = TestMeta::quux();

        assert!(metadata_filter.matches_metadata(&foo));
        assert!(!metadata_filter.matches_metadata(&bar));
        assert!(!metadata_filter.matches_metadata(&baz));
        assert!(!metadata_filter.matches_metadata(&quux));

        // Check each aspect of `foo` individually against its filters.
        assert!(
            metadata_filter
                .id_filters()
                .all(|f| f.matches(vec![&foo.id].into_iter()))
        );
        assert!(
            metadata_filter
                .name_filters()
                .all(|f| f.matches(foo.name().into_iter()))
        );
        assert!(
            metadata_filter
                .kind_filters()
                .all(|f| f.matches(foo.kind().into_iter()))
        );
        assert!(
            metadata_filter
                .label_filters()
                .all(|f| f.matches(foo.labels()))
        );
        assert!(
            metadata_filter
                .weight_filters()
                .all(|f| f.matches(foo.weights()))
        );
        assert!(
            metadata_filter
                .annotation_filters()
                .all(|f| f.matches(foo.annotations()))
        );

        // A `None`-mode name filter excluding `N::Baz`.
        let none_filter = BTreeFieldsFilter {
            fields: BTreeSet::from_iter([N::Baz]),
            mode: MatchingMode::None,
        };
        assert!(none_filter.matches(foo.name().into_iter()));
        assert!(!none_filter.matches(baz.name().into_iter()));
    }

    // Inclusive bounds, each side optional.
    #[test]
    fn test_domain1d_value_filter() {
        let domain = Domain1D {
            lower: Some(1.0),
            upper: Some(3.0),
        };

        assert!(domain.matches(&2.0));
        assert!(!domain.matches(&0.5));
        assert!(!domain.matches(&3.5));

        let lower_only = Domain1D {
            lower: Some(1.0),
            upper: None,
        };

        assert!(lower_only.matches(&2.0));
        assert!(!lower_only.matches(&0.5));

        let upper_only = Domain1D {
            lower: None,
            upper: Some(3.0),
        };

        assert!(upper_only.matches(&2.0));
        assert!(!upper_only.matches(&3.5));

        let no_bounds = Domain1D {
            lower: None,
            upper: None,
        };

        assert!(no_bounds.matches(&2.0));
        assert!(no_bounds.matches(&0.5));
        assert!(no_bounds.matches(&3.5));
    }
}