// composable_indexes/index/btree.rs

1use alloc::collections::BTreeMap;
2use alloc::string::String;
3use alloc::vec::Vec;
4
5use crate::core::{Index, Insert, Key, Remove, Seal};
6use crate::index::generic::DefaultKeySet;
7use crate::index::generic::KeySet;
8
9/// An index for equivalence and range lookups backed by a B-Tree.
10#[derive(Clone)]
11pub struct BTree<T, KeySet = DefaultKeySet> {
12    data: BTreeMap<T, KeySet>,
13}
14
15impl<T, KeySet_> Default for BTree<T, KeySet_>
16where
17    T: Ord + Clone,
18    KeySet_: KeySet + Default,
19{
20    fn default() -> Self {
21        Self {
22            data: BTreeMap::new(),
23        }
24    }
25}
26
27impl<T, KeySet_> BTree<T, KeySet_>
28where
29    T: Ord + Clone,
30    KeySet_: KeySet + Default,
31{
32    pub fn new() -> Self {
33        Self::default()
34    }
35}
36
37impl<In: Ord + Clone, KeySet_: KeySet> Index<In> for BTree<In, KeySet_> {
38    #[inline]
39    fn insert(&mut self, _seal: Seal, op: &Insert<In>) {
40        self.data.entry(op.new.clone()).or_default().insert(op.key);
41    }
42
43    #[inline]
44    fn remove(&mut self, _seal: Seal, op: &Remove<In>) {
45        let existing = self.data.get_mut(op.existing).unwrap();
46        existing.remove(&op.key);
47        if existing.is_empty() {
48            self.data.remove(op.existing);
49        }
50    }
51}
52
53impl<T, KeySet_: KeySet> BTree<T, KeySet_> {
54    #[inline]
55    pub fn contains(&self, key: &T) -> bool
56    where
57        T: Ord + Eq,
58    {
59        self.data.contains_key(key)
60    }
61
62    #[inline]
63    pub fn count_distinct(&self) -> usize
64    where
65        T: Ord + Eq,
66    {
67        self.data.len()
68    }
69
70    #[inline]
71    pub fn get_one(&self, key: &T) -> Option<Key>
72    where
73        T: Ord + Eq,
74    {
75        self.data.get(key).and_then(|v| v.iter().next())
76    }
77
78    #[inline]
79    pub fn get_all(&self, key: &T) -> Vec<Key>
80    where
81        T: Ord + Eq,
82    {
83        let keys = self.data.get(key);
84        keys.map(|v| v.iter().collect()).unwrap_or_default()
85    }
86
87    #[inline]
88    pub fn range<R>(&self, range: R) -> Vec<Key>
89    where
90        T: Ord + Eq,
91        R: core::ops::RangeBounds<T>,
92    {
93        self.data.range(range).flat_map(|(_, v)| v.iter()).collect()
94    }
95
96    #[inline]
97    pub fn min_one(&self) -> Option<Key>
98    where
99        T: Ord + Eq,
100    {
101        self.data
102            .iter()
103            .next()
104            .map(|(_, v)| v.iter().next().unwrap())
105    }
106
107    #[inline]
108    pub fn max_one(&self) -> Option<Key>
109    where
110        T: Ord + Eq,
111    {
112        self.data
113            .iter()
114            .next_back()
115            .map(|(_, v)| v.iter().next().unwrap())
116    }
117}
118
119impl BTree<String> {
120    pub fn starts_with(&self, prefix: &str) -> Vec<Key> {
121        let start = alloc::string::ToString::to_string(prefix);
122        // Increment the last character to get the exclusive upper bound
123        let mut end = start.clone();
124        if let Some(last_char) = end.pop() {
125            let next_char = (last_char as u8 + 1) as char;
126            end.push(next_char);
127        } else {
128            end.push('\u{10FFFF}'); // Push the maximum valid Unicode character
129        }
130
131        self.data
132            .range(start..end)
133            .flat_map(|(_, v)| v.iter().cloned())
134            .collect()
135    }
136}
137
#[cfg(test)]
mod tests {
    //! Property tests for `BTree`: each test hands `prop_assert_reference` a
    //! query run against the index and the same computation over the raw
    //! inserted values (`xs`), and expects the two to agree.

    use std::collections::BTreeSet;

    use super::*;
    use crate::index::premap::PremapOwned;
    use crate::testutils::{SortedVec, prop_assert_reference};
    use proptest_derive::Arbitrary;

    /// A small, totally ordered value type used as an indexed value.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Arbitrary)]
    enum Month {
        Jan,
        Feb,
        Mar,
        Apr,
    }

    #[test]
    fn test_aggrs() {
        prop_assert_reference(
            BTree::<Month>::new,
            |db| {
                // Tuple order is (max, min) on both sides of the comparison.
                let (max, min) = db.query(|ix| (ix.max_one(), ix.min_one()));
                (max.cloned(), min.cloned())
            },
            |xs| {
                let max = xs.iter().max().cloned();
                let min = xs.iter().min().cloned();
                (max, min)
            },
            None,
        );
    }

    #[test]
    fn test_lookup() {
        prop_assert_reference(
            || PremapOwned::new(|i: &(Month, u32)| i.1, BTree::<u32>::new()),
            |db| {
                db.query(|ix| ix.get_all(&1))
                    .into_iter()
                    .cloned()
                    .collect::<SortedVec<_>>()
            },
            |xs| {
                xs.iter()
                    .filter(|i| i.1 == 1)
                    .cloned()
                    .collect::<SortedVec<_>>()
            },
            None,
        );
    }

    #[test]
    fn test_range() {
        prop_assert_reference(
            || PremapOwned::new(|i: &(Month, u8)| i.0, BTree::<Month>::new()),
            |db| {
                db.query(|ix| ix.range(Month::Jan..=Month::Feb))
                    .into_iter()
                    .cloned()
                    .collect::<SortedVec<_>>()
            },
            |xs| {
                xs.iter()
                    .filter(|i| i.0 >= Month::Jan && i.0 <= Month::Feb)
                    .cloned()
                    .collect::<SortedVec<_>>()
            },
            None,
        );
    }

    #[test]
    fn test_count_distinct() {
        prop_assert_reference(
            BTree::<u8>::new,
            |db| db.query(|ix| ix.count_distinct()),
            |xs| xs.iter().collect::<BTreeSet<_>>().len(),
            None,
        );
    }

    #[test]
    fn test_starts_with() {
        prop_assert_reference(
            BTree::<String>::new,
            |db| {
                db.query(|ix| ix.starts_with("ab"))
                    .into_iter()
                    .cloned()
                    .collect::<SortedVec<_>>()
            },
            |xs| {
                xs.iter()
                    .filter(|s| s.starts_with("ab"))
                    .cloned()
                    .collect::<SortedVec<_>>()
            },
            None,
        );
    }
}