kermit_ds/ds/column_trie/implementation.rs

1use {
2    crate::relation::{Relation, RelationHeader},
3    kermit_iters::{JoinIterable, TrieIterable},
4    std::fmt,
5};
6
/// One depth level of a [`ColumnTrie`]: the keys stored for a single attribute.
///
/// `data` holds the keys of every node at this depth, grouped by parent and kept sorted within
/// each group. `interval[j]` is the index in `data` where the children of the `j`-th key of the
/// previous layer begin; a group ends where the next group begins, or at `data.len()` for the
/// last group. The first layer has exactly one group, so its `interval` is `[0]` once non-empty.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ColumnTrieLayer {
    /// Keys at this depth, concatenated group by group.
    pub data: Vec<usize>,
    /// Start offset into `data` of each sibling group (one group per key of the parent layer).
    pub interval: Vec<usize>,
}
11
12pub struct ColumnTrie {
13    header: RelationHeader,
14    pub layers: Vec<ColumnTrieLayer>,
15}
16
17impl ColumnTrie {
18    pub fn layer(&self, layer_i: usize) -> &ColumnTrieLayer { &self.layers[layer_i] }
19
20    fn internal_insert(&mut self, tuple: &[usize]) -> bool {
21        /// Adds an interval to a layer at some index.
22        fn add_interval(layer: &mut ColumnTrieLayer, i: usize) {
23            if i == layer.interval.len() {
24                // If the index is greater than the length of the layer, we push a new interval
25                layer.interval.push(layer.data.len());
26            } else {
27                // Otherwise, we insert the interval at the specified index
28                layer.interval.insert(i, layer.interval[i]);
29            }
30        }
31
32        let mut interval_index = 0;
33        'layer_loop: for (layer_i, &k) in tuple.iter().enumerate() {
34            // There are still keys to insert
35
36            if self.layers[layer_i].data.is_empty() {
37                // layer is empty, so we can just add the key and continue
38                self.layers[layer_i].data.push(k);
39                self.layers[layer_i].interval.push(0);
40                interval_index = 0;
41            } else {
42                // layer is not empty, so we must find the place to insert it
43                let start_index = self.layers[layer_i].interval[interval_index];
44                let end_index = if interval_index == self.layers[layer_i].interval.len() - 1 {
45                    self.layers[layer_i].data.len()
46                } else {
47                    self.layers[layer_i].interval[interval_index + 1]
48                };
49
50                for i in start_index..end_index {
51                    if self.layers[layer_i].data[i] == k {
52                        // key exists in data, so we can just continue
53                        continue 'layer_loop;
54                    } else if k < self.layers[layer_i].data[i] {
55                        // we need to insert at position i
56                        self.layers[layer_i].data.insert(i, k);
57                        // now we increment all intervals after this index
58                        for j in (interval_index + 1)..self.layers[layer_i].interval.len() {
59                            self.layers[layer_i].interval[j] += 1;
60                        }
61                        // if this is the last layer, we're finished
62                        if layer_i == self.header().arity() - 1 {
63                            return true;
64                        }
65                        add_interval(&mut self.layers[layer_i + 1], i);
66                        interval_index = i;
67                        continue 'layer_loop;
68                    }
69                }
70
71                // key is greater than all existing keys, so we add it to the end (at end index)
72                if end_index == self.layers[layer_i].data.len() {
73                    // if we're at the end, we have to push
74                    self.layers[layer_i].data.push(k);
75                } else {
76                    // otherwise insert
77                    self.layers[layer_i].data.insert(end_index, k);
78                    // increment all intervals after this index
79                    for j in interval_index + 1..self.layers[layer_i].interval.len() {
80                        self.layers[layer_i].interval[j] += 1;
81                    }
82                }
83                if layer_i == self.header().arity() - 1 {
84                    // if there are no more layers, we are done
85                    return true;
86                }
87                add_interval(&mut self.layers[layer_i + 1], end_index);
88                interval_index = end_index;
89            }
90        }
91        true
92    }
93}
94
95impl fmt::Display for ColumnTrie {
96    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
97        for (layer_i, layer) in self.layers.iter().enumerate() {
98            writeln!(f, "LAYER {layer_i}")?;
99            write!(f, "Data: [")?;
100            for (i, data) in layer.data.iter().enumerate() {
101                if i > 0 {
102                    write!(f, ", ")?;
103                }
104                write!(f, "{data}")?;
105            }
106            writeln!(f, "]")?;
107            write!(f, "Interval: [")?;
108            for (i, interval) in layer.interval.iter().enumerate() {
109                if i > 0 {
110                    write!(f, ", ")?;
111                }
112                write!(f, "{interval}")?;
113            }
114            writeln!(f, "]")?;
115        }
116        Ok(())
117    }
118}
119
// Opts `ColumnTrie` into `JoinIterable`. No trait items are overridden here, so any methods
// come from the trait's default implementations.
impl JoinIterable for ColumnTrie {}
121
122impl crate::relation::Projectable for ColumnTrie {
123    fn project(&self, columns: Vec<usize>) -> Self {
124        // Create a new header based on the current header but with projected attributes
125        let current_header = self.header();
126        let projected_attrs: Vec<String> = columns
127            .iter()
128            .filter_map(|&col_idx| current_header.attrs().get(col_idx).cloned())
129            .collect();
130
131        let new_header = if projected_attrs.is_empty() {
132            // If no named attributes, create a positional header
133            crate::relation::RelationHeader::new_nameless_positional(columns.len())
134        } else {
135            // Create a header with the projected attributes
136            crate::relation::RelationHeader::new_nameless(projected_attrs)
137        };
138
139        // Collect all tuples from the current relation using the iterator
140        let all_tuples: Vec<Vec<usize>> = self.trie_iter().into_iter().collect();
141
142        // Project each tuple to the specified columns
143        let projected_tuples: Vec<Vec<usize>> = all_tuples
144            .into_iter()
145            .map(|tuple| columns.iter().map(|&col_idx| tuple[col_idx]).collect())
146            .collect();
147
148        // Create new relation from projected tuples
149        Self::from_tuples(new_header, projected_tuples)
150    }
151}
152
153impl Relation for ColumnTrie {
154    fn header(&self) -> &RelationHeader { &self.header }
155
156    fn new(header: RelationHeader) -> Self {
157        ColumnTrie {
158            layers: (0..header.arity())
159                .map(|_| ColumnTrieLayer {
160                    data: vec![],
161                    interval: vec![],
162                })
163                .collect::<Vec<_>>(),
164            header,
165        }
166    }
167
168    fn from_tuples(header: RelationHeader, mut tuples: Vec<Vec<usize>>) -> Self {
169        if tuples.is_empty() {
170            Self::new(header)
171        } else {
172            tuples.sort_unstable_by(|a, b| {
173                for i in 0..a.len() {
174                    match a[i].cmp(&b[i]) {
175                        | std::cmp::Ordering::Less => return std::cmp::Ordering::Less,
176                        | std::cmp::Ordering::Greater => return std::cmp::Ordering::Greater,
177                        | std::cmp::Ordering::Equal => continue,
178                    }
179                }
180                std::cmp::Ordering::Equal
181            });
182
183            let mut trie = Self::new(header);
184            for tuple in tuples {
185                trie.insert(tuple);
186            }
187            trie
188        }
189    }
190
191    fn insert(&mut self, tuple: Vec<usize>) -> bool {
192        debug_assert!(
193            tuple.len() == self.header().arity(),
194            "Tuple length must match the arity of the trie."
195        );
196        self.internal_insert(&tuple)
197    }
198
199    fn insert_all(&mut self, tuples: Vec<Vec<usize>>) -> bool {
200        for tuple in tuples {
201            if !self.insert(tuple) {
202                return false;
203            }
204        }
205        true
206    }
207}
208
#[cfg(test)]
mod tests {
    use {
        super::ColumnTrie,
        crate::relation::{Projectable, Relation as _},
        kermit_iters::TrieIterable,
    };

    #[test]
    fn test_insert() {
        // Keys arrive out of order, so later inserts land before existing keys and open new
        // child groups in the middle of the second layer.
        let mut trie = ColumnTrie::new(2.into());
        assert!(trie.insert(vec![2, 3]));
        assert!(trie.insert(vec![3, 1]));
        assert!(trie.insert(vec![1, 2]));

        assert_eq!(trie.layer(0).data, vec![1, 2, 3]);
        assert_eq!(trie.layer(0).interval, vec![0]);
        assert_eq!(trie.layer(1).data, vec![2, 3, 1]);
        assert_eq!(trie.layer(1).interval, vec![0, 1, 2]);

        let mut all_tuples: Vec<Vec<usize>> = trie.trie_iter().into_iter().collect();
        all_tuples.sort();
        assert_eq!(all_tuples, vec![vec![1, 2], vec![2, 3], vec![3, 1]]);
    }

    #[test]
    fn test_insert_duplicate_is_noop() {
        let mut trie = ColumnTrie::new(2.into());
        trie.insert(vec![1, 2]);
        trie.insert(vec![1, 2]);

        assert_eq!(trie.layer(0).data, vec![1]);
        assert_eq!(trie.layer(0).interval, vec![0]);
        assert_eq!(trie.layer(1).data, vec![2]);
        assert_eq!(trie.layer(1).interval, vec![0]);
    }

    #[test]
    fn test_project() {
        let mut trie = ColumnTrie::new(3.into());
        trie.insert(vec![1, 2, 3]);
        trie.insert(vec![4, 5, 6]);
        trie.insert(vec![7, 8, 9]);

        // Project to columns 0 and 2 (first and third columns)
        let projected = trie.project(vec![0, 2]);
        assert_eq!(projected.header().arity(), 2);

        // Collect all tuples from the projected relation using iterator
        let mut all_tuples: Vec<Vec<usize>> = projected.trie_iter().into_iter().collect();

        // Sort for comparison
        all_tuples.sort();
        assert_eq!(all_tuples, vec![vec![1, 3], vec![4, 6], vec![7, 9]]);
    }

    #[test]
    fn test_project_with_named_attributes() {
        // Create a relation with named attributes
        let header = crate::relation::RelationHeader::new_nameless(vec![
            "a".to_string(),
            "b".to_string(),
            "c".to_string(),
        ]);
        let mut trie = ColumnTrie::new(header);
        trie.insert(vec![1, 2, 3]);
        trie.insert(vec![4, 5, 6]);

        // Project to columns 0 and 2 (first and third columns)
        let projected = trie.project(vec![0, 2]);
        assert_eq!(projected.header().arity(), 2);
        assert_eq!(projected.header().attrs(), &[
            "a".to_string(),
            "c".to_string()
        ]);

        // Collect all tuples from the projected relation using iterator
        let mut all_tuples: Vec<Vec<usize>> = projected.trie_iter().into_iter().collect();

        // Sort for comparison
        all_tuples.sort();
        assert_eq!(all_tuples, vec![vec![1, 3], vec![4, 6]]);
    }
}