composable_indexes/index/im/
grouped.rs

1//! A combinator that groups entries by a key and maintains separate indexes for each group.
2//! This enables functionality similar to the "group by" expression.
3
4use core::hash::Hash;
5
6use crate::{
7    ShallowClone, aggregation,
8    core::{DefaultHasher, Index, Insert, Remove, Seal, Update},
9};
10
11pub struct Grouped<T, GroupKey, InnerIndex, S = DefaultHasher> {
12    group_key: fn(&T) -> GroupKey,
13    mk_index: fn() -> InnerIndex,
14    groups: imbl::GenericHashMap<
15        GroupKey,
16        (InnerIndex, aggregation::Count),
17        S,
18        imbl::shared_ptr::DefaultSharedPtr,
19    >,
20    empty: InnerIndex,
21    _marker: core::marker::PhantomData<fn() -> T>,
22}
23
24impl<In, GroupKey, InnerIndex, S> Clone for Grouped<In, GroupKey, InnerIndex, S>
25where
26    InnerIndex: Clone,
27    GroupKey: Clone,
28    S: Clone,
29{
30    fn clone(&self) -> Self {
31        Self {
32            group_key: self.group_key,
33            mk_index: self.mk_index,
34            groups: self.groups.clone(),
35            empty: self.empty.clone(),
36            _marker: core::marker::PhantomData,
37        }
38    }
39}
40
41impl<In, GroupKey: Clone, InnerIndex: ShallowClone, S: Clone> ShallowClone
42    for Grouped<In, GroupKey, InnerIndex, S>
43{
44}
45
46impl<In, GroupKey, InnerIndex> Grouped<In, GroupKey, InnerIndex> {
47    pub fn new(group_key: fn(&In) -> GroupKey, mk_index: fn() -> InnerIndex) -> Self {
48        Grouped {
49            group_key,
50            mk_index,
51            empty: mk_index(),
52            groups: imbl::GenericHashMap::with_hasher(DefaultHasher::default()),
53            _marker: core::marker::PhantomData,
54        }
55    }
56
57    pub fn with_hasher<S: core::hash::BuildHasher>(
58        group_key: fn(&In) -> GroupKey,
59        mk_index: fn() -> InnerIndex,
60        hasher: S,
61    ) -> Grouped<In, GroupKey, InnerIndex, S> {
62        Grouped {
63            group_key,
64            mk_index,
65            empty: mk_index(),
66            groups: imbl::GenericHashMap::with_hasher(hasher),
67            _marker: core::marker::PhantomData,
68        }
69    }
70}
71
72impl<T, GroupKey, InnerIndex, S> Grouped<T, GroupKey, InnerIndex, S>
73where
74    GroupKey: Eq + Hash + Clone,
75    InnerIndex: Clone,
76    S: core::hash::BuildHasher + Clone,
77{
78    fn get_ix(&mut self, elem: &T) -> &mut (InnerIndex, aggregation::Count) {
79        let key = (self.group_key)(elem);
80        self.groups.entry(key).or_insert_with(|| {
81            let ix = (self.mk_index)();
82            (ix, aggregation::Count::new())
83        })
84    }
85}
86
87impl<In, GroupKey, InnerIndex, S> Index<In> for Grouped<In, GroupKey, InnerIndex, S>
88where
89    GroupKey: Eq + Hash + Clone,
90    InnerIndex: Index<In> + Clone,
91    S: core::hash::BuildHasher + Clone,
92{
93    fn insert(&mut self, seal: Seal, op: &Insert<In>) {
94        self.get_ix(op.new).insert(seal, op);
95    }
96
97    fn update(&mut self, seal: Seal, op: &Update<In>) {
98        let existing_key = (self.group_key)(op.existing);
99        let new_key = (self.group_key)(op.new);
100
101        if existing_key == new_key {
102            self.get_ix(op.new).update(seal, op);
103        } else {
104            self.get_ix(op.existing).remove(
105                seal,
106                &Remove {
107                    key: op.key,
108                    existing: op.existing,
109                },
110            );
111            self.get_ix(op.new).insert(
112                seal,
113                &Insert {
114                    key: op.key,
115                    new: op.new,
116                },
117            );
118        }
119    }
120
121    fn remove(&mut self, seal: Seal, op: &Remove<In>) {
122        let key = (self.group_key)(op.existing);
123        let ix = self.groups.get_mut(&key).unwrap();
124        ix.remove(seal, op);
125        if ix.1.count() == 0 {
126            self.groups.remove(&key);
127        }
128    }
129}
130
131impl<In, GroupKey, InnerIndex, S> Grouped<In, GroupKey, InnerIndex, S>
132where
133    GroupKey: Eq + Hash,
134    S: core::hash::BuildHasher + Clone,
135{
136    pub fn get(&self, key: &GroupKey) -> &InnerIndex {
137        self.groups.get(key).map(|i| &i.0).unwrap_or(&self.empty)
138    }
139
140    pub fn groups(&self) -> impl Iterator<Item = (&GroupKey, &InnerIndex)> {
141        self.groups.iter().map(|(k, v)| (k, &v.0))
142    }
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use crate::aggregation::Sum;
    use crate::core::Collection;
    use crate::index::im::btree::BTree;
    use crate::index::premap::PremapOwned;
    use crate::testutils::{SortedVec, prop_assert_reference};

    /// Element type for the grouping test: grouped by `ty`, indexed by `value`.
    #[derive(Debug, Clone, PartialEq, Eq)]
    struct Payload {
        ty: String,
        value: u32,
    }

    /// Three payloads: two in group "a" (values 1 and 3) and one in group "b" (value 2).
    fn sample_data() -> Vec<Payload> {
        vec![("a", 1), ("b", 2), ("a", 3)]
            .into_iter()
            .map(|(ty, value)| Payload {
                ty: ty.to_string(),
                value,
            })
            .collect()
    }

    /// The maximum is tracked per group, and a group that was never populated
    /// reports no maximum.
    #[test]
    fn group_ix() {
        let index = Grouped::new(
            |p: &Payload| p.ty.clone(),
            || PremapOwned::new(|p: &Payload| p.value, BTree::<u32>::new()),
        );
        let mut db = Collection::<Payload, _>::new(index);

        for payload in sample_data() {
            db.insert(payload);
        }

        let a_max = db.query(|ix| ix.get(&"a".to_string()).inner().max_one());
        assert_eq!(a_max.as_ref().map(|p| p.value), Some(3));

        let b_max = db.query(|ix| ix.get(&"b".to_string()).inner().max_one());
        assert_eq!(b_max.as_ref().map(|p| p.value), Some(2));

        let c_max = db.query(|ix| ix.get(&"c".to_string()).inner().max_one());
        assert_eq!(c_max, None);
    }

    /// Compares per-group sums (grouping by `x % 4`) against a plain `HashMap`
    /// reference; groups whose sum is zero are ignored on both sides.
    #[test]
    fn test_reference() {
        prop_assert_reference(
            || {
                Grouped::new(
                    |p: &u8| p % 4,
                    || PremapOwned::new(|x| *x as u64, Sum::new()),
                )
            },
            |db| {
                db.query(|ix| {
                    ix.groups()
                        .map(|(group, inner)| (*group, inner.inner().get()))
                        .filter(|(_, total)| *total > 0)
                        .collect::<Vec<_>>()
                })
                .into()
            },
            |xs| {
                let mut sums = std::collections::HashMap::new();
                for x in xs {
                    *sums.entry(x % 4).or_insert(0) += x as u64;
                }
                sums.into_iter()
                    .filter(|(_, total)| *total > 0)
                    .collect::<SortedVec<_>>()
            },
            None,
        );
    }
}