1use std::marker::PhantomData;
2
3use alopex_core::kv::KVTransaction;
4use alopex_core::types::{Key, Value};
5
6use crate::catalog::TableMetadata;
7
8use super::error::{Result, StorageError};
9use super::{KeyEncoder, RowCodec, SqlValue};
10
11pub struct TableStorage<'a, 'txn, T: KVTransaction<'txn>> {
20 txn: &'a mut T,
21 table_meta: TableMetadata,
22 table_id: u32,
23 _txn_lifetime: PhantomData<&'txn ()>,
24}
25
26impl<'a, 'txn, T: KVTransaction<'txn>> TableStorage<'a, 'txn, T> {
27 pub fn new(txn: &'a mut T, table_meta: &TableMetadata) -> Self {
31 Self {
32 txn,
33 table_id: table_meta.table_id,
34 table_meta: table_meta.clone(),
35 _txn_lifetime: PhantomData,
36 }
37 }
38
39 pub fn insert(&mut self, row_id: u64, row: &[SqlValue]) -> Result<()> {
41 self.validate_row(row)?;
42 let key = self.row_key(row_id);
43
44 if self.txn.get(&key)?.is_some() {
45 return Err(StorageError::PrimaryKeyViolation {
46 table_id: self.table_id,
47 row_id,
48 });
49 }
50
51 let encoded = RowCodec::encode(row);
52 self.txn.put(key, encoded)?;
53 Ok(())
54 }
55
56 pub fn get(&mut self, row_id: u64) -> Result<Option<Vec<SqlValue>>> {
58 let key = self.row_key(row_id);
59 match self.txn.get(&key)? {
60 Some(value) => {
61 let row = RowCodec::decode(&value)?;
62 Ok(Some(row))
63 }
64 None => Ok(None),
65 }
66 }
67
68 pub fn update(&mut self, row_id: u64, row: &[SqlValue]) -> Result<()> {
70 self.validate_row(row)?;
71 let key = self.row_key(row_id);
72 if self.txn.get(&key)?.is_none() {
73 return Err(StorageError::RowNotFound {
74 table_id: self.table_id,
75 row_id,
76 });
77 }
78 let encoded = RowCodec::encode(row);
79 self.txn.put(key, encoded)?;
80 Ok(())
81 }
82
83 pub fn delete(&mut self, row_id: u64) -> Result<()> {
85 let key = self.row_key(row_id);
86 self.txn.delete(key)?;
87 Ok(())
88 }
89
90 pub fn scan(&mut self) -> Result<TableScanIterator<'_>> {
92 let prefix = KeyEncoder::table_prefix(self.table_id);
93 let table_id = self.table_id;
94 let inner = self.txn.scan_prefix(&prefix)?;
95 Ok(TableScanIterator::new(inner, table_id))
96 }
97
98 pub fn range_scan(
100 &mut self,
101 start_row_id: u64,
102 end_row_id: u64,
103 ) -> Result<TableScanIterator<'_>> {
104 let start = KeyEncoder::row_key(self.table_id, start_row_id);
105 let end = if end_row_id == u64::MAX {
106 if self.table_id == u32::MAX {
107 vec![0x02]
109 } else {
110 KeyEncoder::table_prefix(self.table_id.saturating_add(1))
111 }
112 } else {
113 KeyEncoder::row_key(self.table_id, end_row_id.saturating_add(1))
114 };
115 let table_id = self.table_id;
116 let inner = self.txn.scan_range(&start, &end)?;
117 Ok(TableScanIterator::new(inner, table_id))
118 }
119
120 pub fn next_row_id(&mut self) -> Result<u64> {
122 let seq_key = KeyEncoder::sequence_key(self.table_id);
123 let current = self
124 .txn
125 .get(&seq_key)?
126 .map(|bytes| {
127 let mut arr = [0u8; 8];
128 arr.copy_from_slice(&bytes);
129 u64::from_be_bytes(arr)
130 })
131 .unwrap_or(0);
132 let next = current.saturating_add(1);
133 self.txn.put(seq_key, next.to_be_bytes().to_vec())?;
134 Ok(next)
135 }
136
137 fn validate_row(&self, row: &[SqlValue]) -> Result<()> {
138 let expected = self.table_meta.column_count();
139 if row.len() != expected {
140 return Err(StorageError::TypeMismatch {
141 expected: format!("{} columns", expected),
142 actual: format!("{} columns", row.len()),
143 });
144 }
145
146 for (idx, col) in self.table_meta.columns.iter().enumerate() {
147 if (col.not_null || col.primary_key) && row[idx].is_null() {
148 return Err(StorageError::NullConstraintViolation {
149 column: col.name.clone(),
150 });
151 }
152 }
153 Ok(())
154 }
155
156 fn row_key(&self, row_id: u64) -> Key {
157 KeyEncoder::row_key(self.table_id, row_id)
158 }
159}
160
161pub struct TableScanIterator<'a> {
163 inner: Box<dyn Iterator<Item = (Key, Value)> + 'a>,
164 table_id: u32,
165}
166
167impl<'a> TableScanIterator<'a> {
168 fn new(inner: Box<dyn Iterator<Item = (Key, Value)> + 'a>, table_id: u32) -> Self {
169 Self { inner, table_id }
170 }
171}
172
173impl<'a> Iterator for TableScanIterator<'a> {
174 type Item = Result<(u64, Vec<SqlValue>)>;
175
176 fn next(&mut self) -> Option<Self::Item> {
177 self.inner.next().map(|(key, value)| {
178 let (table_id, row_id) = KeyEncoder::decode_row_key(&key)?;
179 if table_id != self.table_id {
180 return Err(StorageError::InvalidKeyFormat);
181 }
182 let row = RowCodec::decode(&value)?;
183 Ok((row_id, row))
184 })
185 }
186}
187
#[cfg(test)]
mod tests {
    use super::*;
    use crate::planner::types::ResolvedType;
    use alopex_core::kv::KVStore;
    use alopex_core::kv::memory::MemoryKV;
    use alopex_core::types::TxnMode;

    /// Builds a three-column `users` table: `id` (primary key, NOT NULL), `name`
    /// (NOT NULL), and a nullable `age`.
    fn sample_table_meta(table_id: u32) -> TableMetadata {
        TableMetadata::new(
            "users",
            vec![
                crate::catalog::ColumnMetadata::new("id", ResolvedType::Integer)
                    .with_primary_key(true)
                    .with_not_null(true),
                crate::catalog::ColumnMetadata::new("name", ResolvedType::Text).with_not_null(true),
                crate::catalog::ColumnMetadata::new("age", ResolvedType::Integer),
            ],
        )
        .with_table_id(table_id)
    }

    /// Builds a fully populated `users` row.
    fn user_row(id: i32, name: &str, age: i32) -> Vec<SqlValue> {
        vec![
            SqlValue::Integer(id),
            SqlValue::Text(name.to_string()),
            SqlValue::Integer(age),
        ]
    }

    /// Runs `f` against a `TableStorage` backed by a fresh read-write transaction.
    ///
    /// The store clone and the transaction are leaked (`Box::leak`) to obtain the
    /// `'static` borrows the closure signature requires; acceptable in tests only.
    fn with_table<F>(store: &MemoryKV, meta: &TableMetadata, f: F)
    where
        F: FnOnce(
            &mut TableStorage<
                'static,
                'static,
                <MemoryKV as alopex_core::kv::KVStore>::Transaction<'static>,
            >,
        ),
    {
        let store_static: &'static MemoryKV = Box::leak(Box::new(store.clone()));
        let txn = store_static.begin(TxnMode::ReadWrite).unwrap();
        let txn_static: &'static mut _ = Box::leak(Box::new(txn));
        let mut table = TableStorage::new(txn_static, meta);
        f(&mut table);
    }

    #[test]
    fn insert_and_get_roundtrip() {
        let store = MemoryKV::new();
        let meta = sample_table_meta(1);
        with_table(&store, &meta, |table| {
            let row = user_row(1, "alice", 20);
            table.insert(1, &row).unwrap();
            let fetched = table.get(1).unwrap().unwrap();
            assert_eq!(fetched, row);
        });
    }

    #[test]
    fn duplicate_primary_key_is_rejected() {
        let store = MemoryKV::new();
        let meta = sample_table_meta(1);
        with_table(&store, &meta, |table| {
            let row = user_row(1, "alice", 20);
            table.insert(1, &row).unwrap();
            let err = table.insert(1, &row).unwrap_err();
            assert!(
                matches!(err, StorageError::PrimaryKeyViolation { .. }),
                "unexpected error: {err:?}"
            );
        });
    }

    #[test]
    fn not_null_constraint_is_enforced() {
        let store = MemoryKV::new();
        let meta = sample_table_meta(1);
        with_table(&store, &meta, |table| {
            let row = vec![
                SqlValue::Null,
                SqlValue::Text("bob".into()),
                SqlValue::Integer(30),
            ];
            let err = table.insert(2, &row).unwrap_err();
            assert!(
                matches!(err, StorageError::NullConstraintViolation { .. }),
                "unexpected error: {err:?}"
            );
        });
    }

    #[test]
    fn update_overwrites_existing_row() {
        let store = MemoryKV::new();
        let meta = sample_table_meta(1);
        with_table(&store, &meta, |table| {
            table.insert(1, &user_row(1, "alice", 20)).unwrap();

            let row2 = user_row(1, "alice-updated", 25);
            table.update(1, &row2).unwrap();
            let fetched = table.get(1).unwrap().unwrap();
            assert_eq!(fetched, row2);
        });
    }

    #[test]
    fn update_nonexistent_returns_not_found() {
        let store = MemoryKV::new();
        let meta = sample_table_meta(1);
        with_table(&store, &meta, |table| {
            let err = table.update(99, &user_row(99, "ghost", 0)).unwrap_err();
            assert!(
                matches!(err, StorageError::RowNotFound { .. }),
                "unexpected error: {err:?}"
            );
        });
    }

    #[test]
    fn delete_removes_row() {
        let store = MemoryKV::new();
        let meta = sample_table_meta(1);
        with_table(&store, &meta, |table| {
            table.insert(1, &user_row(1, "alice", 20)).unwrap();
            table.delete(1).unwrap();
            assert!(table.get(1).unwrap().is_none());
        });
    }

    #[test]
    fn scan_returns_all_rows_in_order() {
        let store = MemoryKV::new();
        let meta = sample_table_meta(1);
        with_table(&store, &meta, |table| {
            for i in 1..=3u64 {
                let row = user_row(i as i32, &format!("user{i}"), 10 + i as i32);
                table.insert(i, &row).unwrap();
            }

            let rows: Vec<_> = table.scan().unwrap().map(|res| res.unwrap().0).collect();
            assert_eq!(rows, vec![1, 2, 3]);
        });
    }

    #[test]
    fn range_scan_respects_bounds() {
        let store = MemoryKV::new();
        let meta = sample_table_meta(1);
        with_table(&store, &meta, |table| {
            for i in 1..=5u64 {
                let row = user_row(i as i32, &format!("user{i}"), 10 + i as i32);
                table.insert(i, &row).unwrap();
            }

            let rows: Vec<_> = table
                .range_scan(2, 4)
                .unwrap()
                .map(|res| res.unwrap().0)
                .collect();
            assert_eq!(rows, vec![2, 3, 4]);
        });
    }

    #[test]
    fn range_scan_handles_max_table_id_end_bound() {
        let store = MemoryKV::new();
        let meta = sample_table_meta(u32::MAX);
        with_table(&store, &meta, |table| {
            table.insert(1, &user_row(1, "max", 1)).unwrap();
            let rows: Vec<_> = table
                .range_scan(1, u64::MAX)
                .unwrap()
                .map(|res| res.unwrap().0)
                .collect();
            assert_eq!(rows, vec![1]);
        });
    }

    #[test]
    fn next_row_id_increments_sequence() {
        let store = MemoryKV::new();
        let meta = sample_table_meta(1);
        with_table(&store, &meta, |table| {
            let id1 = table.next_row_id().unwrap();
            let id2 = table.next_row_id().unwrap();
            assert_eq!((id1, id2), (1, 2));
        });
    }
}