1use std::marker::PhantomData;
2
3use alopex_core::kv::KVTransaction;
4use alopex_core::types::{Key, Value};
5
6use crate::catalog::TableMetadata;
7
8use super::error::{Result, StorageError};
9use super::{KeyEncoder, RowCodec, SqlValue};
10
11pub struct TableStorage<'a, 'txn, T: KVTransaction<'txn>> {
20 txn: &'a mut T,
21 table_meta: TableMetadata,
22 table_id: u32,
23 _txn_lifetime: PhantomData<&'txn ()>,
24}
25
26impl<'a, 'txn, T: KVTransaction<'txn>> TableStorage<'a, 'txn, T> {
27 pub fn new(txn: &'a mut T, table_meta: &TableMetadata) -> Self {
31 Self {
32 txn,
33 table_id: table_meta.table_id,
34 table_meta: table_meta.clone(),
35 _txn_lifetime: PhantomData,
36 }
37 }
38
39 pub fn insert(&mut self, row_id: u64, row: &[SqlValue]) -> Result<()> {
41 self.validate_row(row)?;
42 let key = self.row_key(row_id);
43
44 if self.txn.get(&key)?.is_some() {
45 return Err(StorageError::PrimaryKeyViolation {
46 table_id: self.table_id,
47 row_id,
48 });
49 }
50
51 let encoded = RowCodec::encode(row);
52 self.txn.put(key, encoded)?;
53 Ok(())
54 }
55
56 pub fn get(&mut self, row_id: u64) -> Result<Option<Vec<SqlValue>>> {
58 let key = self.row_key(row_id);
59 match self.txn.get(&key)? {
60 Some(value) => {
61 let row = RowCodec::decode(&value)?;
62 Ok(Some(row))
63 }
64 None => Ok(None),
65 }
66 }
67
68 pub fn update(&mut self, row_id: u64, row: &[SqlValue]) -> Result<()> {
70 self.validate_row(row)?;
71 let key = self.row_key(row_id);
72 if self.txn.get(&key)?.is_none() {
73 return Err(StorageError::RowNotFound {
74 table_id: self.table_id,
75 row_id,
76 });
77 }
78 let encoded = RowCodec::encode(row);
79 self.txn.put(key, encoded)?;
80 Ok(())
81 }
82
83 pub fn delete(&mut self, row_id: u64) -> Result<()> {
85 let key = self.row_key(row_id);
86 self.txn.delete(key)?;
87 Ok(())
88 }
89
90 pub fn scan(&mut self) -> Result<TableScanIterator<'_>> {
92 let prefix = KeyEncoder::table_prefix(self.table_id);
93 let table_id = self.table_id;
94 let inner = self.txn.scan_prefix(&prefix)?;
95 Ok(TableScanIterator::new(inner, table_id))
96 }
97
98 pub fn range_scan(
100 &mut self,
101 start_row_id: u64,
102 end_row_id: u64,
103 ) -> Result<TableScanIterator<'_>> {
104 let start = KeyEncoder::row_key(self.table_id, start_row_id);
105 let end = if end_row_id == u64::MAX {
106 if self.table_id == u32::MAX {
107 vec![0x02]
109 } else {
110 KeyEncoder::table_prefix(self.table_id.saturating_add(1))
111 }
112 } else {
113 KeyEncoder::row_key(self.table_id, end_row_id.saturating_add(1))
114 };
115 let table_id = self.table_id;
116 let inner = self.txn.scan_range(&start, &end)?;
117 Ok(TableScanIterator::new(inner, table_id))
118 }
119
120 pub fn next_row_id(&mut self) -> Result<u64> {
122 let seq_key = KeyEncoder::sequence_key(self.table_id);
123 let current = self
124 .txn
125 .get(&seq_key)?
126 .map(|bytes| {
127 let mut arr = [0u8; 8];
128 arr.copy_from_slice(&bytes);
129 u64::from_be_bytes(arr)
130 })
131 .unwrap_or(0);
132 let next = current.saturating_add(1);
133 self.txn.put(seq_key, next.to_be_bytes().to_vec())?;
134 Ok(next)
135 }
136
137 fn validate_row(&self, row: &[SqlValue]) -> Result<()> {
138 let expected = self.table_meta.column_count();
139 if row.len() != expected {
140 return Err(StorageError::TypeMismatch {
141 expected: format!("{} columns", expected),
142 actual: format!("{} columns", row.len()),
143 });
144 }
145
146 for (idx, col) in self.table_meta.columns.iter().enumerate() {
147 if (col.not_null || col.primary_key) && row[idx].is_null() {
148 return Err(StorageError::NullConstraintViolation {
149 column: col.name.clone(),
150 });
151 }
152 }
153 Ok(())
154 }
155
156 fn row_key(&self, row_id: u64) -> Key {
157 KeyEncoder::row_key(self.table_id, row_id)
158 }
159}
160
161pub struct TableScanIterator<'a> {
163 inner: Box<dyn Iterator<Item = (Key, Value)> + 'a>,
164 table_id: u32,
165}
166
167impl<'a> TableScanIterator<'a> {
168 pub fn new(inner: Box<dyn Iterator<Item = (Key, Value)> + 'a>, table_id: u32) -> Self {
170 Self { inner, table_id }
171 }
172}
173
174impl<'a> Iterator for TableScanIterator<'a> {
175 type Item = Result<(u64, Vec<SqlValue>)>;
176
177 fn next(&mut self) -> Option<Self::Item> {
178 self.inner.next().map(|(key, value)| {
179 let (table_id, row_id) = KeyEncoder::decode_row_key(&key)?;
180 if table_id != self.table_id {
181 return Err(StorageError::InvalidKeyFormat);
182 }
183 let row = RowCodec::decode(&value)?;
184 Ok((row_id, row))
185 })
186 }
187}
188
#[cfg(test)]
mod tests {
    use super::*;
    use crate::planner::types::ResolvedType;
    use alopex_core::kv::KVStore;
    use alopex_core::kv::memory::MemoryKV;
    use alopex_core::types::TxnMode;

    /// Builds a three-column `users` table:
    /// `id` (primary key, not null), `name` (not null), and a nullable `age`.
    fn sample_table_meta(table_id: u32) -> TableMetadata {
        TableMetadata::new(
            "users",
            vec![
                crate::catalog::ColumnMetadata::new("id", ResolvedType::Integer)
                    .with_primary_key(true)
                    .with_not_null(true),
                crate::catalog::ColumnMetadata::new("name", ResolvedType::Text).with_not_null(true),
                crate::catalog::ColumnMetadata::new("age", ResolvedType::Integer),
            ],
        )
        .with_table_id(table_id)
    }

    /// Runs `f` against a `TableStorage` for `meta`, inside a fresh read-write
    /// transaction begun on a clone of `store`.
    ///
    /// The store clone and the transaction are leaked on purpose to obtain the
    /// `'static` lifetimes in `F`'s signature. This is acceptable in tests only.
    fn with_table<F>(store: &MemoryKV, meta: &TableMetadata, f: F)
    where
        F: FnOnce(
            &mut TableStorage<
                'static,
                'static,
                <MemoryKV as alopex_core::kv::KVStore>::Transaction<'static>,
            >,
        ),
    {
        let store_static: &'static MemoryKV = Box::leak(Box::new(store.clone()));
        let txn = store_static.begin(TxnMode::ReadWrite).unwrap();
        let txn_static: &'static mut _ = Box::leak(Box::new(txn));
        let mut table = TableStorage::new(txn_static, meta);
        f(&mut table);
    }

    #[test]
    fn insert_and_get_roundtrip() {
        let store = MemoryKV::new();
        let meta = sample_table_meta(1);
        with_table(&store, &meta, |table| {
            let row = vec![
                SqlValue::Integer(1),
                SqlValue::Text("alice".into()),
                SqlValue::Integer(20),
            ];
            table.insert(1, &row).unwrap();
            let fetched = table.get(1).unwrap().unwrap();
            assert_eq!(fetched, row);
        });
    }

    #[test]
    fn duplicate_primary_key_is_rejected() {
        let store = MemoryKV::new();
        let meta = sample_table_meta(1);
        with_table(&store, &meta, |table| {
            let row = vec![
                SqlValue::Integer(1),
                SqlValue::Text("alice".into()),
                SqlValue::Integer(20),
            ];
            table.insert(1, &row).unwrap();
            let err = table.insert(1, &row).unwrap_err();
            // `matches!` alone only produces a bool; it must be asserted.
            assert!(
                matches!(err, StorageError::PrimaryKeyViolation { .. }),
                "unexpected error: {err:?}"
            );
        });
    }

    #[test]
    fn not_null_constraint_is_enforced() {
        let store = MemoryKV::new();
        let meta = sample_table_meta(1);
        with_table(&store, &meta, |table| {
            let row = vec![
                SqlValue::Null,
                SqlValue::Text("bob".into()),
                SqlValue::Integer(30),
            ];
            let err = table.insert(2, &row).unwrap_err();
            assert!(
                matches!(err, StorageError::NullConstraintViolation { .. }),
                "unexpected error: {err:?}"
            );
        });
    }

    #[test]
    fn update_overwrites_existing_row() {
        let store = MemoryKV::new();
        let meta = sample_table_meta(1);
        with_table(&store, &meta, |table| {
            let row1 = vec![
                SqlValue::Integer(1),
                SqlValue::Text("alice".into()),
                SqlValue::Integer(20),
            ];
            table.insert(1, &row1).unwrap();

            let row2 = vec![
                SqlValue::Integer(1),
                SqlValue::Text("alice-updated".into()),
                SqlValue::Integer(25),
            ];
            table.update(1, &row2).unwrap();
            let fetched = table.get(1).unwrap().unwrap();
            assert_eq!(fetched, row2);
        });
    }

    #[test]
    fn update_nonexistent_returns_not_found() {
        let store = MemoryKV::new();
        let meta = sample_table_meta(1);
        with_table(&store, &meta, |table| {
            let row = vec![
                SqlValue::Integer(99),
                SqlValue::Text("ghost".into()),
                SqlValue::Integer(0),
            ];
            let err = table.update(99, &row).unwrap_err();
            assert!(
                matches!(err, StorageError::RowNotFound { .. }),
                "unexpected error: {err:?}"
            );
        });
    }

    #[test]
    fn delete_removes_row() {
        let store = MemoryKV::new();
        let meta = sample_table_meta(1);
        with_table(&store, &meta, |table| {
            let row = vec![
                SqlValue::Integer(1),
                SqlValue::Text("alice".into()),
                SqlValue::Integer(20),
            ];
            table.insert(1, &row).unwrap();
            table.delete(1).unwrap();
            assert!(table.get(1).unwrap().is_none());
        });
    }

    #[test]
    fn scan_returns_all_rows_in_order() {
        let store = MemoryKV::new();
        let meta = sample_table_meta(1);
        with_table(&store, &meta, |table| {
            for i in 1..=3 {
                let row = vec![
                    SqlValue::Integer(i as i32),
                    SqlValue::Text(format!("user{i}")),
                    SqlValue::Integer(10 + i as i32),
                ];
                table.insert(i, &row).unwrap();
            }

            let rows: Vec<_> = table.scan().unwrap().map(|res| res.unwrap().0).collect();
            assert_eq!(rows, vec![1, 2, 3]);
        });
    }

    #[test]
    fn range_scan_respects_bounds() {
        let store = MemoryKV::new();
        let meta = sample_table_meta(1);
        with_table(&store, &meta, |table| {
            for i in 1..=5 {
                let row = vec![
                    SqlValue::Integer(i as i32),
                    SqlValue::Text(format!("user{i}")),
                    SqlValue::Integer(10 + i as i32),
                ];
                table.insert(i, &row).unwrap();
            }

            let rows: Vec<_> = table
                .range_scan(2, 4)
                .unwrap()
                .map(|res| res.unwrap().0)
                .collect();
            assert_eq!(rows, vec![2, 3, 4]);
        });
    }

    #[test]
    fn range_scan_handles_max_table_id_end_bound() {
        let store = MemoryKV::new();
        let meta = sample_table_meta(u32::MAX);
        with_table(&store, &meta, |table| {
            let row = vec![
                SqlValue::Integer(1),
                SqlValue::Text("max".into()),
                SqlValue::Integer(1),
            ];
            table.insert(1, &row).unwrap();
            let rows: Vec<_> = table
                .range_scan(1, u64::MAX)
                .unwrap()
                .map(|res| res.unwrap().0)
                .collect();
            assert_eq!(rows, vec![1]);
        });
    }

    #[test]
    fn next_row_id_increments_sequence() {
        let store = MemoryKV::new();
        let meta = sample_table_meta(1);
        with_table(&store, &meta, |table| {
            let id1 = table.next_row_id().unwrap();
            let id2 = table.next_row_id().unwrap();
            assert_eq!((id1, id2), (1, 2));
        });
    }
}