composable_indexes/index/im/
btree.rs

1//! An index backed by [`imbl::OrdMap`]. Provides efficient
2//! queries for the minimum/maximum keys and range queries using
3//! persistent immutable data structures.
4
5use alloc::string::String;
6use alloc::vec::Vec;
7
8use imbl::OrdMap;
9
10use crate::{
11    ShallowClone,
12    core::{Index, Insert, Key, Remove, Seal},
13    index::generic::{DefaultImmutableKeySet, KeySet},
14};
15
16#[derive(Clone)]
17pub struct BTree<T, KeySet = DefaultImmutableKeySet> {
18    data: OrdMap<T, KeySet>,
19}
20
21impl<T: Ord + Clone, KeySet_: KeySet + Default> Default for BTree<T, KeySet_> {
22    fn default() -> Self {
23        Self::new()
24    }
25}
26
27impl<T: Ord + Clone, KeySet_: KeySet + Default> BTree<T, KeySet_> {
28    pub fn new() -> Self {
29        BTree {
30            data: OrdMap::new(),
31        }
32    }
33}
34
35impl<T: Clone, KeySet_: Clone> ShallowClone for BTree<T, KeySet_> {}
36
37impl<In, KeySet_> Index<In> for BTree<In, KeySet_>
38where
39    In: Ord + Clone,
40    KeySet_: KeySet + Clone,
41{
42    fn insert(&mut self, _seal: Seal, op: &Insert<In>) {
43        self.data.entry(op.new.clone()).or_default().insert(op.key);
44    }
45
46    fn remove(&mut self, _seal: Seal, op: &Remove<In>) {
47        let existing = self.data.get_mut(op.existing).unwrap();
48        existing.remove(&op.key);
49        if existing.is_empty() {
50            self.data.remove(op.existing);
51        }
52    }
53}
54
55impl<T, KeySet_: KeySet> BTree<T, KeySet_> {
56    pub fn contains(&self, key: &T) -> bool
57    where
58        T: Ord + Clone,
59    {
60        self.data.contains_key(key)
61    }
62
63    pub fn count_distinct(&self) -> usize
64    where
65        T: Ord + Clone,
66    {
67        self.data.len()
68    }
69
70    pub fn get_one(&self, key: &T) -> Option<Key>
71    where
72        T: Ord + Clone,
73    {
74        self.data.get(key).and_then(|v| v.iter().next())
75    }
76
77    pub fn get_all(&self, key: &T) -> Vec<Key>
78    where
79        T: Ord + Clone,
80    {
81        self.data
82            .get(key)
83            .map(|v| v.iter().collect())
84            .unwrap_or_default()
85    }
86
87    pub fn range<R>(&self, range: R) -> Vec<Key>
88    where
89        T: Ord + Clone,
90        R: core::ops::RangeBounds<T>,
91    {
92        self.data.range(range).flat_map(|(_, v)| v.iter()).collect()
93    }
94
95    pub fn min_one(&self) -> Option<Key>
96    where
97        T: Ord + Clone,
98    {
99        self.data.iter().next().and_then(|(_, v)| v.iter().next())
100    }
101
102    pub fn max_one(&self) -> Option<Key>
103    where
104        T: Ord + Clone,
105    {
106        self.data
107            .iter()
108            .next_back()
109            .and_then(|(_, v)| v.iter().next())
110    }
111}
112
113impl BTree<String> {
114    pub fn starts_with(&self, prefix: &str) -> Vec<Key> {
115        let start = alloc::string::ToString::to_string(prefix);
116        // Increment the last character to get the exclusive upper bound
117        let mut end = start.clone();
118        if let Some(last_char) = end.pop() {
119            let next_char = (last_char as u8 + 1) as char;
120            end.push(next_char);
121        } else {
122            end.push('\u{10FFFF}'); // Push the maximum valid Unicode character
123        }
124
125        self.data
126            .range(start..end)
127            .flat_map(|(_, v)| v.iter().cloned())
128            .collect()
129    }
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use crate::index::premap::PremapOwned;
    use crate::testutils::{SortedVec, prop_assert_reference};
    use proptest_derive::Arbitrary;

    // Each test hands `prop_assert_reference` an index constructor, a query against
    // the indexed database, and a reference computation over the plain items.
    // NOTE(review): presumably the helper checks both agree for generated inputs —
    // confirm against `crate::testutils`.

    /// Small totally ordered value used as the indexed type.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Arbitrary)]
    enum Month {
        Jan,
        Feb,
        Mar,
        Apr,
    }

    /// `max_one` / `min_one` are compared with the reference maximum / minimum.
    #[test]
    fn test_aggrs() {
        prop_assert_reference(
            BTree::<Month>::new,
            |db| {
                let (largest, smallest) = db.query(|ix| (ix.max_one(), ix.min_one()));
                (largest.cloned(), smallest.cloned())
            },
            |xs| (xs.iter().max().cloned(), xs.iter().min().cloned()),
            None,
        );
    }

    /// `get_all(&1)` is compared with the items whose projected field equals 1.
    #[test]
    fn test_lookup() {
        prop_assert_reference(
            || PremapOwned::new(|item: &(Month, u32)| item.1, BTree::<u32>::new()),
            |db| {
                let hits = db.query(|ix| ix.get_all(&1));
                hits.into_iter().cloned().collect::<SortedVec<_>>()
            },
            |xs| {
                xs.iter()
                    .filter(|item| item.1 == 1)
                    .cloned()
                    .collect::<SortedVec<_>>()
            },
            None,
        );
    }

    /// `range(Jan..=Feb)` is compared with the items whose month lies in that
    /// inclusive range.
    #[test]
    fn test_range() {
        prop_assert_reference(
            || PremapOwned::new(|item: &(Month, u8)| item.0, BTree::<Month>::new()),
            |db| {
                let hits = db.query(|ix| ix.range(Month::Jan..=Month::Feb));
                hits.into_iter().cloned().collect::<SortedVec<_>>()
            },
            |xs| {
                xs.iter()
                    .filter(|item| (Month::Jan..=Month::Feb).contains(&item.0))
                    .cloned()
                    .collect::<SortedVec<_>>()
            },
            None,
        );
    }

    /// `count_distinct` is compared with the size of a set built from the items.
    #[test]
    fn test_count_distinct() {
        use alloc::collections::BTreeSet;
        prop_assert_reference(
            BTree::<u8>::new,
            |db| db.query(|ix| ix.count_distinct()),
            |xs| {
                let distinct: BTreeSet<_> = xs.iter().collect();
                distinct.len()
            },
            None,
        );
    }

    /// `starts_with("ab")` is compared with a plain `str::starts_with` filter.
    #[test]
    fn test_starts_with() {
        prop_assert_reference(
            BTree::<String>::new,
            |db| {
                let hits = db.query(|ix| ix.starts_with("ab"));
                hits.into_iter().cloned().collect::<SortedVec<_>>()
            },
            |xs| {
                xs.iter()
                    .filter(|s| s.starts_with("ab"))
                    .cloned()
                    .collect::<SortedVec<_>>()
            },
            None,
        );
    }
}