alopex_sql/storage/
index.rs

1use std::convert::TryInto;
2use std::marker::PhantomData;
3
4use alopex_core::kv::KVTransaction;
5use alopex_core::types::{Key, Value};
6
7use super::error::{Result, StorageError};
8use super::{KeyEncoder, SqlValue};
9
10/// IndexStorage manages secondary index entries and lookups.
11///
12/// The two lifetime parameters serve distinct purposes:
13/// - `'a`: The borrow duration of the transaction reference
14/// - `'txn`: The lifetime parameter of the KVTransaction type itself
15///
16/// This separation is necessary to allow SqlTransaction to return IndexStorage
17/// instances while maintaining proper lifetime relationships with GATs.
18pub struct IndexStorage<'a, 'txn, T: KVTransaction<'txn>> {
19    txn: &'a mut T,
20    index_id: u32,
21    unique: bool,
22    column_indices: Vec<usize>,
23    _txn_lifetime: PhantomData<&'txn ()>,
24}
25
26impl<'a, 'txn, T: KVTransaction<'txn>> IndexStorage<'a, 'txn, T> {
27    /// Create a new IndexStorage for the given index definition.
28    pub fn new(txn: &'a mut T, index_id: u32, unique: bool, column_indices: Vec<usize>) -> Self {
29        Self {
30            txn,
31            index_id,
32            unique,
33            column_indices,
34            _txn_lifetime: PhantomData,
35        }
36    }
37
38    /// Insert an index entry for the provided row values and RowID.
39    pub fn insert(&mut self, row: &[SqlValue], row_id: u64) -> Result<()> {
40        let values = self.extract_values(row)?;
41        let key = self.build_key(&values, row_id)?;
42        if self.unique {
43            let prefix = self.value_prefix(&values)?;
44            self.ensure_unique(&prefix, row_id)?;
45        }
46        self.txn.put(key, Vec::new())?;
47        Ok(())
48    }
49
50    /// Delete an index entry associated with the provided row values and RowID.
51    pub fn delete(&mut self, row: &[SqlValue], row_id: u64) -> Result<()> {
52        let values = self.extract_values(row)?;
53        let key = self.build_key(&values, row_id)?;
54        self.txn.delete(key)?;
55        Ok(())
56    }
57
58    /// Equality lookup for single-column index.
59    pub fn lookup(&mut self, value: &SqlValue) -> Result<Vec<u64>> {
60        if self.column_indices.len() != 1 {
61            return Err(StorageError::TypeMismatch {
62                expected: "single-column index".into(),
63                actual: format!("{} columns", self.column_indices.len()),
64            });
65        }
66        self.lookup_internal(std::slice::from_ref(value))
67    }
68
69    /// Equality lookup for composite index.
70    pub fn lookup_composite(&mut self, values: &[SqlValue]) -> Result<Vec<u64>> {
71        if values.len() != self.column_indices.len() {
72            return Err(StorageError::TypeMismatch {
73                expected: format!("{} values", self.column_indices.len()),
74                actual: format!("{} values", values.len()),
75            });
76        }
77        self.lookup_internal(values)
78    }
79
80    /// Range lookup for single-column index.
81    pub fn range_scan(
82        &mut self,
83        start: Option<&SqlValue>,
84        end: Option<&SqlValue>,
85        start_inclusive: bool,
86        end_inclusive: bool,
87    ) -> Result<IndexScanIterator<'_>> {
88        // Caller contract: SQL BETWEEN a AND b should pass start_inclusive=true, end_inclusive=true
89        // so the range aligns with SQL semantics. Exclusive bounds are available for internal use
90        // (e.g., >, <) but should not be used for BETWEEN.
91        if self.column_indices.len() != 1 {
92            return Err(StorageError::TypeMismatch {
93                expected: "single-column index".into(),
94                actual: format!("{} columns", self.column_indices.len()),
95            });
96        }
97        let (start_key, end_key) = self.range_bounds(start, end, start_inclusive, end_inclusive)?;
98        let index_id = self.index_id;
99        let inner = self.txn.scan_range(&start_key, &end_key)?;
100        Ok(IndexScanIterator::new(inner, index_id))
101    }
102
103    fn lookup_internal(&mut self, values: &[SqlValue]) -> Result<Vec<u64>> {
104        let prefix = self.value_prefix(values)?;
105        let iter = self.txn.scan_prefix(&prefix)?;
106        iter.map(|(key, _)| extract_row_id(&key, self.index_id))
107            .collect()
108    }
109
110    fn ensure_unique(&mut self, prefix: &[u8], row_id: u64) -> Result<()> {
111        let iter = self.txn.scan_prefix(prefix)?;
112        for (key, _) in iter {
113            let existing = extract_row_id(&key, self.index_id)?;
114            if existing != row_id {
115                return Err(StorageError::UniqueViolation {
116                    index_id: self.index_id,
117                });
118            }
119        }
120        Ok(())
121    }
122
123    fn extract_values(&self, row: &[SqlValue]) -> Result<Vec<SqlValue>> {
124        let max_index = self.column_indices.iter().copied().max().unwrap_or(0);
125        if row.len() <= max_index {
126            return Err(StorageError::TypeMismatch {
127                expected: format!("row with >= {} columns", max_index + 1),
128                actual: format!("{} columns", row.len()),
129            });
130        }
131        Ok(self
132            .column_indices
133            .iter()
134            .map(|idx| row[*idx].clone())
135            .collect())
136    }
137
138    fn build_key(&self, values: &[SqlValue], row_id: u64) -> Result<Key> {
139        if self.column_indices.len() == 1 {
140            KeyEncoder::index_key(self.index_id, &values[0], row_id)
141        } else {
142            KeyEncoder::composite_index_key(self.index_id, values, row_id)
143        }
144    }
145
146    fn value_prefix(&self, values: &[SqlValue]) -> Result<Vec<u8>> {
147        if self.column_indices.len() == 1 {
148            KeyEncoder::index_value_prefix(self.index_id, &values[0])
149        } else {
150            KeyEncoder::composite_index_prefix(self.index_id, values)
151        }
152    }
153
154    fn range_bounds(
155        &self,
156        start: Option<&SqlValue>,
157        end: Option<&SqlValue>,
158        start_inclusive: bool,
159        end_inclusive: bool,
160    ) -> Result<(Key, Key)> {
161        let start_key = match start {
162            Some(value) if start_inclusive => KeyEncoder::index_key(self.index_id, value, 0)?,
163            Some(value) => KeyEncoder::index_key(self.index_id, value, u64::MAX)?,
164            None => KeyEncoder::index_prefix(self.index_id),
165        };
166
167        let end_key = match end {
168            Some(value) if end_inclusive => {
169                let mut prefix = KeyEncoder::index_value_prefix(self.index_id, value)?;
170                // Ensure the exclusive upper bound sits after any RowID for the value.
171                prefix.extend_from_slice(&[0xFF; 9]);
172                prefix
173            }
174            Some(value) => KeyEncoder::index_key(self.index_id, value, 0)?,
175            None if self.index_id == u32::MAX => vec![0x03],
176            None => KeyEncoder::index_prefix(self.index_id.saturating_add(1)),
177        };
178
179        Ok((start_key, end_key))
180    }
181}
182
183fn extract_row_id(key: &[u8], expected_index: u32) -> Result<u64> {
184    if key.len() < 1 + 4 + 8 || key[0] != 0x02 {
185        return Err(StorageError::InvalidKeyFormat);
186    }
187    let index_id = u32::from_be_bytes(key[1..5].try_into().unwrap());
188    if index_id != expected_index {
189        return Err(StorageError::InvalidKeyFormat);
190    }
191    let row_id_pos = key.len().saturating_sub(8);
192    let mut buf = [0u8; 8];
193    buf.copy_from_slice(&key[row_id_pos..]);
194    Ok(u64::from_be_bytes(buf))
195}
196
197/// Iterator over index entries that yields RowIDs.
198pub struct IndexScanIterator<'a> {
199    inner: Box<dyn Iterator<Item = (Key, Value)> + 'a>,
200    index_id: u32,
201}
202
203impl<'a> IndexScanIterator<'a> {
204    fn new(inner: Box<dyn Iterator<Item = (Key, Value)> + 'a>, index_id: u32) -> Self {
205        Self { inner, index_id }
206    }
207}
208
209impl<'a> Iterator for IndexScanIterator<'a> {
210    type Item = Result<u64>;
211
212    fn next(&mut self) -> Option<Self::Item> {
213        self.inner
214            .next()
215            .map(|(key, _)| extract_row_id(&key, self.index_id))
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use alopex_core::kv::KVStore;
    use alopex_core::kv::memory::MemoryKV;
    use alopex_core::types::TxnMode;

    /// Runs `f` against an index with id 10 on a fresh in-memory store.
    ///
    /// The store and transaction are leaked on purpose. That lets both be borrowed for
    /// `'static`, which is the simplest way to name the GAT transaction type in `F`'s bound.
    /// The leak only happens in tests.
    fn with_index<F>(unique: bool, column_indices: Vec<usize>, f: F)
    where
        F: FnOnce(&mut IndexStorage<'static, 'static, <MemoryKV as KVStore>::Transaction<'static>>),
    {
        let store = MemoryKV::new();
        let store_static: &'static MemoryKV = Box::leak(Box::new(store));
        let txn = store_static.begin(TxnMode::ReadWrite).unwrap();
        let txn_static: &'static mut _ = Box::leak(Box::new(txn));
        let mut index = IndexStorage::new(txn_static, 10, unique, column_indices);
        f(&mut index);
    }

    #[test]
    fn insert_lookup_and_delete_single_column() {
        with_index(false, vec![0], |index| {
            let row1 = vec![SqlValue::Integer(1), SqlValue::Text("a".into())];
            let row2 = vec![SqlValue::Integer(2), SqlValue::Text("b".into())];
            index.insert(&row1, 100).unwrap();
            index.insert(&row2, 200).unwrap();

            let mut results = index.lookup(&SqlValue::Integer(1)).unwrap();
            results.sort();
            assert_eq!(results, vec![100]);

            let mut range_ids = {
                let iter = index
                    .range_scan(
                        Some(&SqlValue::Integer(1)),
                        Some(&SqlValue::Integer(3)),
                        true,
                        false,
                    )
                    .unwrap();
                iter.collect::<Result<Vec<_>>>().unwrap()
            };
            range_ids.sort();
            assert_eq!(range_ids, vec![100, 200]);

            index.delete(&row1, 100).unwrap();
            assert!(index.lookup(&SqlValue::Integer(1)).unwrap().is_empty());
        });
    }

    #[test]
    fn unique_constraint_blocks_duplicates() {
        with_index(true, vec![0], |index| {
            let row = vec![SqlValue::Text("alice".into())];
            index.insert(&row, 1).unwrap();
            let err = index.insert(&row, 2).unwrap_err();
            // `matches!` only returns a bool; without `assert!` the check is a no-op.
            assert!(
                matches!(err, StorageError::UniqueViolation { index_id: 10 }),
                "expected UniqueViolation for index 10, got {err:?}"
            );
            // The rejected insert must not have written an entry.
            assert_eq!(
                index.lookup(&SqlValue::Text("alice".into())).unwrap(),
                vec![1]
            );
        });
    }

    #[test]
    fn unique_constraint_allows_reinserting_same_row_id() {
        with_index(true, vec![0], |index| {
            let row = vec![SqlValue::Text("bob".into())];
            index.insert(&row, 7).unwrap();
            index.insert(&row, 7).unwrap();
            assert_eq!(index.lookup(&SqlValue::Text("bob".into())).unwrap(), vec![7]);
        });
    }

    #[test]
    fn composite_lookup_returns_matching_row_ids() {
        with_index(false, vec![0, 1], |index| {
            let row1 = vec![SqlValue::Text("tokyo".into()), SqlValue::Integer(1)];
            let row2 = vec![SqlValue::Text("tokyo".into()), SqlValue::Integer(2)];
            let row3 = vec![SqlValue::Text("osaka".into()), SqlValue::Integer(1)];
            index.insert(&row1, 10).unwrap();
            index.insert(&row2, 20).unwrap();
            index.insert(&row3, 30).unwrap();

            let ids = index
                .lookup_composite(&[SqlValue::Text("tokyo".into()), SqlValue::Integer(2)])
                .unwrap();
            assert_eq!(ids, vec![20]);
        });
    }

    #[test]
    fn composite_index_rejects_wrong_arity() {
        with_index(false, vec![0, 1], |index| {
            let err = index
                .lookup_composite(&[SqlValue::Text("tokyo".into())])
                .unwrap_err();
            assert!(
                matches!(err, StorageError::TypeMismatch { .. }),
                "expected TypeMismatch, got {err:?}"
            );

            let err = index.lookup(&SqlValue::Integer(1)).unwrap_err();
            assert!(
                matches!(err, StorageError::TypeMismatch { .. }),
                "expected TypeMismatch, got {err:?}"
            );
        });
    }

    #[test]
    fn range_scan_respects_inclusive_and_exclusive_bounds() {
        with_index(false, vec![0], |index| {
            for (i, val) in [1, 2, 3, 4, 5].iter().enumerate() {
                let row = vec![SqlValue::Integer(*val)];
                index.insert(&row, (i + 1) as u64).unwrap();
            }

            let ids = {
                let iter = index
                    .range_scan(
                        Some(&SqlValue::Integer(2)),
                        Some(&SqlValue::Integer(4)),
                        false,
                        true,
                    )
                    .unwrap();
                iter.collect::<Result<Vec<_>>>().unwrap()
            };
            // start_exclusive => 3, end_inclusive => include 4
            assert_eq!(ids, vec![3, 4]);
        });
    }

    #[test]
    fn between_semantics_are_inclusive_on_both_ends() {
        with_index(false, vec![0], |index| {
            for (i, val) in [10, 20, 30, 40].iter().enumerate() {
                let row = vec![SqlValue::Integer(*val)];
                index.insert(&row, (i + 1) as u64).unwrap();
            }

            let ids = {
                // Simulate SQL: WHERE col BETWEEN 20 AND 40
                let iter = index
                    .range_scan(
                        Some(&SqlValue::Integer(20)),
                        Some(&SqlValue::Integer(40)),
                        true,
                        true,
                    )
                    .unwrap();
                iter.collect::<Result<Vec<_>>>().unwrap()
            };
            assert_eq!(ids, vec![2, 3, 4]);
        });
    }
}