1use std::convert::TryInto;
2use std::marker::PhantomData;
3
4use alopex_core::kv::KVTransaction;
5use alopex_core::types::{Key, Value};
6
7use super::error::{Result, StorageError};
8use super::{KeyEncoder, SqlValue};
9
10pub struct IndexStorage<'a, 'txn, T: KVTransaction<'txn>> {
19 txn: &'a mut T,
20 index_id: u32,
21 unique: bool,
22 column_indices: Vec<usize>,
23 _txn_lifetime: PhantomData<&'txn ()>,
24}
25
26impl<'a, 'txn, T: KVTransaction<'txn>> IndexStorage<'a, 'txn, T> {
27 pub fn new(txn: &'a mut T, index_id: u32, unique: bool, column_indices: Vec<usize>) -> Self {
29 Self {
30 txn,
31 index_id,
32 unique,
33 column_indices,
34 _txn_lifetime: PhantomData,
35 }
36 }
37
38 pub fn insert(&mut self, row: &[SqlValue], row_id: u64) -> Result<()> {
40 let values = self.extract_values(row)?;
41 let key = self.build_key(&values, row_id)?;
42 if self.unique {
43 let prefix = self.value_prefix(&values)?;
44 self.ensure_unique(&prefix, row_id)?;
45 }
46 self.txn.put(key, Vec::new())?;
47 Ok(())
48 }
49
50 pub fn delete(&mut self, row: &[SqlValue], row_id: u64) -> Result<()> {
52 let values = self.extract_values(row)?;
53 let key = self.build_key(&values, row_id)?;
54 self.txn.delete(key)?;
55 Ok(())
56 }
57
58 pub fn lookup(&mut self, value: &SqlValue) -> Result<Vec<u64>> {
60 if self.column_indices.len() != 1 {
61 return Err(StorageError::TypeMismatch {
62 expected: "single-column index".into(),
63 actual: format!("{} columns", self.column_indices.len()),
64 });
65 }
66 self.lookup_internal(std::slice::from_ref(value))
67 }
68
69 pub fn lookup_composite(&mut self, values: &[SqlValue]) -> Result<Vec<u64>> {
71 if values.len() != self.column_indices.len() {
72 return Err(StorageError::TypeMismatch {
73 expected: format!("{} values", self.column_indices.len()),
74 actual: format!("{} values", values.len()),
75 });
76 }
77 self.lookup_internal(values)
78 }
79
80 pub fn range_scan(
82 &mut self,
83 start: Option<&SqlValue>,
84 end: Option<&SqlValue>,
85 start_inclusive: bool,
86 end_inclusive: bool,
87 ) -> Result<IndexScanIterator<'_>> {
88 if self.column_indices.len() != 1 {
92 return Err(StorageError::TypeMismatch {
93 expected: "single-column index".into(),
94 actual: format!("{} columns", self.column_indices.len()),
95 });
96 }
97 let (start_key, end_key) = self.range_bounds(start, end, start_inclusive, end_inclusive)?;
98 let index_id = self.index_id;
99 let inner = self.txn.scan_range(&start_key, &end_key)?;
100 Ok(IndexScanIterator::new(inner, index_id))
101 }
102
103 fn lookup_internal(&mut self, values: &[SqlValue]) -> Result<Vec<u64>> {
104 let prefix = self.value_prefix(values)?;
105 let iter = self.txn.scan_prefix(&prefix)?;
106 iter.map(|(key, _)| extract_row_id(&key, self.index_id))
107 .collect()
108 }
109
110 fn ensure_unique(&mut self, prefix: &[u8], row_id: u64) -> Result<()> {
111 let iter = self.txn.scan_prefix(prefix)?;
112 for (key, _) in iter {
113 let existing = extract_row_id(&key, self.index_id)?;
114 if existing != row_id {
115 return Err(StorageError::UniqueViolation {
116 index_id: self.index_id,
117 });
118 }
119 }
120 Ok(())
121 }
122
123 fn extract_values(&self, row: &[SqlValue]) -> Result<Vec<SqlValue>> {
124 let max_index = self.column_indices.iter().copied().max().unwrap_or(0);
125 if row.len() <= max_index {
126 return Err(StorageError::TypeMismatch {
127 expected: format!("row with >= {} columns", max_index + 1),
128 actual: format!("{} columns", row.len()),
129 });
130 }
131 Ok(self
132 .column_indices
133 .iter()
134 .map(|idx| row[*idx].clone())
135 .collect())
136 }
137
138 fn build_key(&self, values: &[SqlValue], row_id: u64) -> Result<Key> {
139 if self.column_indices.len() == 1 {
140 KeyEncoder::index_key(self.index_id, &values[0], row_id)
141 } else {
142 KeyEncoder::composite_index_key(self.index_id, values, row_id)
143 }
144 }
145
146 fn value_prefix(&self, values: &[SqlValue]) -> Result<Vec<u8>> {
147 if self.column_indices.len() == 1 {
148 KeyEncoder::index_value_prefix(self.index_id, &values[0])
149 } else {
150 KeyEncoder::composite_index_prefix(self.index_id, values)
151 }
152 }
153
154 fn range_bounds(
155 &self,
156 start: Option<&SqlValue>,
157 end: Option<&SqlValue>,
158 start_inclusive: bool,
159 end_inclusive: bool,
160 ) -> Result<(Key, Key)> {
161 let start_key = match start {
162 Some(value) if start_inclusive => KeyEncoder::index_key(self.index_id, value, 0)?,
163 Some(value) => KeyEncoder::index_key(self.index_id, value, u64::MAX)?,
164 None => KeyEncoder::index_prefix(self.index_id),
165 };
166
167 let end_key = match end {
168 Some(value) if end_inclusive => {
169 let mut prefix = KeyEncoder::index_value_prefix(self.index_id, value)?;
170 prefix.extend_from_slice(&[0xFF; 9]);
172 prefix
173 }
174 Some(value) => KeyEncoder::index_key(self.index_id, value, 0)?,
175 None if self.index_id == u32::MAX => vec![0x03],
176 None => KeyEncoder::index_prefix(self.index_id.saturating_add(1)),
177 };
178
179 Ok((start_key, end_key))
180 }
181}
182
183fn extract_row_id(key: &[u8], expected_index: u32) -> Result<u64> {
184 if key.len() < 1 + 4 + 8 || key[0] != 0x02 {
185 return Err(StorageError::InvalidKeyFormat);
186 }
187 let index_id = u32::from_be_bytes(key[1..5].try_into().unwrap());
188 if index_id != expected_index {
189 return Err(StorageError::InvalidKeyFormat);
190 }
191 let row_id_pos = key.len().saturating_sub(8);
192 let mut buf = [0u8; 8];
193 buf.copy_from_slice(&key[row_id_pos..]);
194 Ok(u64::from_be_bytes(buf))
195}
196
197pub struct IndexScanIterator<'a> {
199 inner: Box<dyn Iterator<Item = (Key, Value)> + 'a>,
200 index_id: u32,
201}
202
203impl<'a> IndexScanIterator<'a> {
204 fn new(inner: Box<dyn Iterator<Item = (Key, Value)> + 'a>, index_id: u32) -> Self {
205 Self { inner, index_id }
206 }
207}
208
209impl<'a> Iterator for IndexScanIterator<'a> {
210 type Item = Result<u64>;
211
212 fn next(&mut self) -> Option<Self::Item> {
213 self.inner
214 .next()
215 .map(|(key, _)| extract_row_id(&key, self.index_id))
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use alopex_core::kv::KVStore;
    use alopex_core::kv::memory::MemoryKV;
    use alopex_core::types::TxnMode;

    /// Runs `f` against index id 10, backed by a fresh in-memory store and a
    /// read-write transaction.
    ///
    /// The store and transaction are deliberately leaked with `Box::leak` so they
    /// satisfy the `'static` lifetimes in the closure's parameter type.
    fn with_index<F>(unique: bool, column_indices: Vec<usize>, f: F)
    where
        F: FnOnce(&mut IndexStorage<'static, 'static, <MemoryKV as KVStore>::Transaction<'static>>),
    {
        let store = MemoryKV::new();
        let store_static: &'static MemoryKV = Box::leak(Box::new(store));
        let txn = store_static.begin(TxnMode::ReadWrite).unwrap();
        let txn_static: &'static mut _ = Box::leak(Box::new(txn));
        let mut index = IndexStorage::new(txn_static, 10, unique, column_indices);
        f(&mut index);
    }

    #[test]
    fn insert_lookup_and_delete_single_column() {
        with_index(false, vec![0], |index| {
            let row1 = vec![SqlValue::Integer(1), SqlValue::Text("a".into())];
            let row2 = vec![SqlValue::Integer(2), SqlValue::Text("b".into())];
            index.insert(&row1, 100).unwrap();
            index.insert(&row2, 200).unwrap();

            let mut results = index.lookup(&SqlValue::Integer(1)).unwrap();
            results.sort();
            assert_eq!(results, vec![100]);

            let mut range_ids = {
                let iter = index
                    .range_scan(
                        Some(&SqlValue::Integer(1)),
                        Some(&SqlValue::Integer(3)),
                        true,
                        false,
                    )
                    .unwrap();
                iter.collect::<Result<Vec<_>>>().unwrap()
            };
            range_ids.sort();
            assert_eq!(range_ids, vec![100, 200]);

            index.delete(&row1, 100).unwrap();
            assert!(index.lookup(&SqlValue::Integer(1)).unwrap().is_empty());
        });
    }

    #[test]
    fn unique_constraint_blocks_duplicates() {
        with_index(true, vec![0], |index| {
            let row = vec![SqlValue::Text("alice".into())];
            index.insert(&row, 1).unwrap();
            // Re-indexing the same row id is not a conflict.
            index.insert(&row, 1).unwrap();

            let err = index.insert(&row, 2).unwrap_err();
            // `matches!` only returns a bool; it must be asserted to check anything.
            assert!(matches!(err, StorageError::UniqueViolation { index_id: 10 }));

            // The rejected insert must not have written an entry.
            assert_eq!(
                index.lookup(&SqlValue::Text("alice".into())).unwrap(),
                vec![1]
            );
        });
    }

    #[test]
    fn composite_lookup_returns_matching_row_ids() {
        with_index(false, vec![0, 1], |index| {
            let row1 = vec![SqlValue::Text("tokyo".into()), SqlValue::Integer(1)];
            let row2 = vec![SqlValue::Text("tokyo".into()), SqlValue::Integer(2)];
            let row3 = vec![SqlValue::Text("osaka".into()), SqlValue::Integer(1)];
            index.insert(&row1, 10).unwrap();
            index.insert(&row2, 20).unwrap();
            index.insert(&row3, 30).unwrap();

            let ids = index
                .lookup_composite(&[SqlValue::Text("tokyo".into()), SqlValue::Integer(2)])
                .unwrap();
            assert_eq!(ids, vec![20]);
        });
    }

    #[test]
    fn range_scan_respects_inclusive_and_exclusive_bounds() {
        with_index(false, vec![0], |index| {
            for (i, val) in [1, 2, 3, 4, 5].iter().enumerate() {
                let row = vec![SqlValue::Integer(*val)];
                index.insert(&row, (i + 1) as u64).unwrap();
            }

            let ids = {
                let iter = index
                    .range_scan(
                        Some(&SqlValue::Integer(2)),
                        Some(&SqlValue::Integer(4)),
                        false,
                        true,
                    )
                    .unwrap();
                iter.collect::<Result<Vec<_>>>().unwrap()
            };
            assert_eq!(ids, vec![3, 4]);
        });
    }

    #[test]
    fn between_semantics_are_inclusive_on_both_ends() {
        with_index(false, vec![0], |index| {
            for (i, val) in [10, 20, 30, 40].iter().enumerate() {
                let row = vec![SqlValue::Integer(*val)];
                index.insert(&row, (i + 1) as u64).unwrap();
            }

            let ids = {
                let iter = index
                    .range_scan(
                        Some(&SqlValue::Integer(20)),
                        Some(&SqlValue::Integer(40)),
                        true,
                        true,
                    )
                    .unwrap();
                iter.collect::<Result<Vec<_>>>().unwrap()
            };
            assert_eq!(ids, vec![2, 3, 4]);
        });
    }
}