1use std::collections::HashMap;
2use std::sync::Arc;
3
4use alopex_core::kv::{KVStore, KVTransaction};
5use alopex_core::types::TxnMode;
6use alopex_core::vector::hnsw::{HnswIndex, HnswTransactionState};
7use alopex_core::{Error as CoreError, Result as CoreResult};
8
9use crate::catalog::CatalogOverlay;
10use crate::catalog::TableMetadata;
11
12use super::error::Result;
13use super::{IndexStorage, TableStorage};
14
15pub trait SqlTxn<'txn, S: KVStore + 'txn> {
16 fn mode(&self) -> TxnMode;
17
18 fn ensure_write_txn(&self) -> CoreResult<()>;
19
20 fn inner_mut(&mut self) -> &mut S::Transaction<'txn>;
21
22 fn hnsw_entry(&mut self, name: &str) -> CoreResult<&HnswIndex>;
23
24 fn hnsw_entry_mut(&mut self, name: &str) -> CoreResult<&mut HnswTxnEntry>;
25
26 fn flush_hnsw(&mut self) -> Result<()>;
27
28 fn abandon_hnsw(&mut self) -> Result<()>;
29
30 fn delete_prefix(&mut self, prefix: &[u8]) -> Result<()>;
31
32 fn table_storage<'a>(
33 &'a mut self,
34 table_meta: &TableMetadata,
35 ) -> TableStorage<'a, 'txn, S::Transaction<'txn>> {
36 TableStorage::new(self.inner_mut(), table_meta)
37 }
38
39 fn index_storage<'a>(
40 &'a mut self,
41 index_id: u32,
42 unique: bool,
43 column_indices: Vec<usize>,
44 ) -> IndexStorage<'a, 'txn, S::Transaction<'txn>> {
45 IndexStorage::new(self.inner_mut(), index_id, unique, column_indices)
46 }
47
48 fn with_table<R, F>(&mut self, table_meta: &TableMetadata, f: F) -> Result<R>
49 where
50 F: FnOnce(&mut TableStorage<'_, 'txn, S::Transaction<'txn>>) -> Result<R>,
51 {
52 let mut storage = self.table_storage(table_meta);
53 f(&mut storage)
54 }
55
56 fn with_index<R, F>(
57 &mut self,
58 index_id: u32,
59 unique: bool,
60 column_indices: Vec<usize>,
61 f: F,
62 ) -> Result<R>
63 where
64 F: FnOnce(&mut IndexStorage<'_, 'txn, S::Transaction<'txn>>) -> Result<R>,
65 {
66 let mut storage = self.index_storage(index_id, unique, column_indices);
67 f(&mut storage)
68 }
69}
70
/// Entry point that opens SQL-layer transactions against a shared KV store.
pub struct TxnBridge<S: KVStore> {
    /// Store that every transaction opened through this bridge runs against.
    store: Arc<S>,
}
75
76impl<S: KVStore> TxnBridge<S> {
77 pub fn new(store: Arc<S>) -> Self {
79 Self { store }
80 }
81
82 pub fn begin_read(&self) -> Result<SqlTransaction<'_, S>> {
84 let txn = self.store.begin(TxnMode::ReadOnly)?;
85 Ok(SqlTransaction {
86 inner: txn,
87 mode: TxnMode::ReadOnly,
88 hnsw_indices: HashMap::new(),
89 })
90 }
91
92 pub fn begin_write(&self) -> Result<SqlTransaction<'_, S>> {
94 let txn = self.store.begin(TxnMode::ReadWrite)?;
95 Ok(SqlTransaction {
96 inner: txn,
97 mode: TxnMode::ReadWrite,
98 hnsw_indices: HashMap::new(),
99 })
100 }
101
102 pub fn wrap_external<'a, 'b, 'c>(
103 txn: &'a mut S::Transaction<'b>,
104 mode: TxnMode,
105 overlay: &'c mut CatalogOverlay,
106 ) -> BorrowedSqlTransaction<'a, 'b, 'c, S> {
107 BorrowedSqlTransaction {
108 inner: txn,
109 mode,
110 overlay,
111 hnsw_indices: HashMap::new(),
112 }
113 }
114
115 pub fn with_read_txn<R, F>(&self, f: F) -> Result<R>
120 where
121 F: FnOnce(&mut SqlTransaction<'_, S>) -> Result<R>,
122 {
123 let mut txn = self.begin_read()?;
124 let result = f(&mut txn)?;
125 txn.commit()?;
126 Ok(result)
127 }
128
129 pub fn with_write_txn<R, F>(&self, f: F) -> Result<R>
134 where
135 F: FnOnce(&mut SqlTransaction<'_, S>) -> Result<R>,
136 {
137 let mut txn = self.begin_write()?;
138 let result = f(&mut txn)?;
139 txn.commit()?;
140 Ok(result)
141 }
142
143 pub fn with_write_txn_explicit<R, F>(&self, f: F) -> Result<R>
148 where
149 F: FnOnce(&mut SqlTransaction<'_, S>) -> Result<(R, bool)>,
150 {
151 let mut txn = self.begin_write()?;
152 let (result, should_commit) = f(&mut txn)?;
153 if should_commit {
154 txn.commit()?;
155 } else {
156 txn.rollback()?;
157 }
158 Ok(result)
159 }
160}
161
/// SQL-layer transaction that owns its underlying KV transaction.
///
/// Alongside the KV transaction it caches the HNSW indexes loaded during the
/// transaction; [`SqlTransaction::commit`] commits the dirty ones and
/// [`SqlTransaction::rollback`] rolls them back.
pub struct SqlTransaction<'a, S: KVStore + 'a> {
    /// Underlying KV transaction.
    inner: S::Transaction<'a>,
    /// Mode the transaction was opened with; checked by `ensure_write_txn`.
    mode: TxnMode,
    /// HNSW indexes loaded in this transaction, keyed by index name.
    hnsw_indices: HashMap<String, HnswTxnEntry>,
}
171
/// Per-transaction cache entry for one loaded HNSW index.
pub struct HnswTxnEntry {
    /// Index as loaded from the transaction via `HnswIndex::load`.
    pub index: HnswIndex,
    /// Transaction-local state handed to `commit_staged` / `rollback`.
    pub state: HnswTransactionState,
    /// Whether the entry has staged changes; only dirty entries are committed
    /// or rolled back.
    /// NOTE(review): nothing in this file sets it to `true`; presumably callers
    /// of `hnsw_entry_mut` do when they stage changes — confirm.
    pub dirty: bool,
}
177
178impl<'a, S: KVStore + 'a> SqlTransaction<'a, S> {
179 pub fn mode(&self) -> TxnMode {
181 self.mode
182 }
183
184 pub fn table_storage<'b>(
191 &'b mut self,
192 table_meta: &TableMetadata,
193 ) -> TableStorage<'b, 'a, S::Transaction<'a>> {
194 TableStorage::new(&mut self.inner, table_meta)
195 }
196
197 pub fn index_storage<'b>(
202 &'b mut self,
203 index_id: u32,
204 unique: bool,
205 column_indices: Vec<usize>,
206 ) -> IndexStorage<'b, 'a, S::Transaction<'a>> {
207 IndexStorage::new(&mut self.inner, index_id, unique, column_indices)
208 }
209
210 #[allow(dead_code)]
212 pub(crate) fn hnsw_entry(&mut self, name: &str) -> CoreResult<&HnswIndex> {
213 if !self.hnsw_indices.contains_key(name) {
214 let index = HnswIndex::load(name, &mut self.inner)?;
215 self.hnsw_indices.insert(
216 name.to_string(),
217 HnswTxnEntry {
218 index,
219 state: HnswTransactionState::default(),
220 dirty: false,
221 },
222 );
223 }
224 Ok(&self.hnsw_indices.get(name).expect("inserted above").index)
225 }
226
227 pub(crate) fn hnsw_entry_mut(&mut self, name: &str) -> CoreResult<&mut HnswTxnEntry> {
229 if !self.hnsw_indices.contains_key(name) {
230 let index = HnswIndex::load(name, &mut self.inner)?;
231 self.hnsw_indices.insert(
232 name.to_string(),
233 HnswTxnEntry {
234 index,
235 state: HnswTransactionState::default(),
236 dirty: false,
237 },
238 );
239 }
240 Ok(self.hnsw_indices.get_mut(name).expect("inserted above"))
241 }
242
243 pub(crate) fn ensure_write_txn(&self) -> CoreResult<()> {
244 if self.mode != TxnMode::ReadWrite {
245 return Err(CoreError::TxnConflict);
246 }
247 Ok(())
248 }
249
250 pub(crate) fn inner_mut(&mut self) -> &mut S::Transaction<'a> {
252 &mut self.inner
253 }
254
255 pub fn delete_prefix(&mut self, prefix: &[u8]) -> Result<()> {
257 const BATCH: usize = 512;
259 loop {
260 let mut keys = Vec::with_capacity(BATCH);
261 {
262 let iter = self.inner.scan_prefix(prefix)?;
263 for (key, _) in iter.take(BATCH) {
264 keys.push(key);
265 }
266 }
267
268 if keys.is_empty() {
269 break;
270 }
271
272 for key in keys {
273 self.inner.delete(key)?;
274 }
275 }
276
277 Ok(())
278 }
279
280 pub fn with_table<R, F>(&mut self, table_meta: &TableMetadata, f: F) -> Result<R>
287 where
288 F: FnOnce(&mut TableStorage<'_, 'a, S::Transaction<'a>>) -> Result<R>,
289 {
290 let mut storage = self.table_storage(table_meta);
291 f(&mut storage)
292 }
293
294 pub fn with_index<R, F>(
299 &mut self,
300 index_id: u32,
301 unique: bool,
302 column_indices: Vec<usize>,
303 f: F,
304 ) -> Result<R>
305 where
306 F: FnOnce(&mut IndexStorage<'_, 'a, S::Transaction<'a>>) -> Result<R>,
307 {
308 let mut storage = self.index_storage(index_id, unique, column_indices);
309 f(&mut storage)
310 }
311
312 pub fn commit(mut self) -> Result<()> {
316 self.commit_hnsw()?;
317 self.inner.commit_self()?;
318 Ok(())
319 }
320
321 pub fn rollback(mut self) -> Result<()> {
325 self.rollback_hnsw()?;
326 self.inner.rollback_self()?;
327 Ok(())
328 }
329
330 fn commit_hnsw(&mut self) -> Result<()> {
331 for entry in self.hnsw_indices.values_mut() {
332 if entry.dirty {
333 entry
334 .index
335 .commit_staged(&mut self.inner, &mut entry.state)?;
336 }
337 }
338 self.hnsw_indices.clear();
339 Ok(())
340 }
341
342 fn rollback_hnsw(&mut self) -> Result<()> {
343 for entry in self.hnsw_indices.values_mut() {
344 if entry.dirty {
345 entry.index.rollback(&mut entry.state)?;
346 }
347 }
348 self.hnsw_indices.clear();
349 Ok(())
350 }
351}
352
353impl<'a, S: KVStore + 'a> SqlTxn<'a, S> for SqlTransaction<'a, S> {
354 fn mode(&self) -> TxnMode {
355 self.mode()
356 }
357
358 fn ensure_write_txn(&self) -> CoreResult<()> {
359 self.ensure_write_txn()
360 }
361
362 fn inner_mut(&mut self) -> &mut S::Transaction<'a> {
363 self.inner_mut()
364 }
365
366 fn hnsw_entry(&mut self, name: &str) -> CoreResult<&HnswIndex> {
367 self.hnsw_entry(name)
368 }
369
370 fn hnsw_entry_mut(&mut self, name: &str) -> CoreResult<&mut HnswTxnEntry> {
371 self.hnsw_entry_mut(name)
372 }
373
374 fn flush_hnsw(&mut self) -> Result<()> {
375 self.commit_hnsw()
376 }
377
378 fn abandon_hnsw(&mut self) -> Result<()> {
379 self.rollback_hnsw()
380 }
381
382 fn delete_prefix(&mut self, prefix: &[u8]) -> Result<()> {
383 self.delete_prefix(prefix)
384 }
385}
386
/// SQL-layer view over a KV transaction and catalog overlay owned elsewhere.
///
/// Created by [`TxnBridge::wrap_external`]. It never commits or rolls back the
/// borrowed transaction; dirty HNSW entries still cached when it is dropped are
/// rolled back on a best-effort basis.
pub struct BorrowedSqlTransaction<'a, 'b, 'c, S: KVStore + 'b> {
    /// Borrowed KV transaction.
    inner: &'a mut S::Transaction<'b>,
    /// Mode supplied by the caller of `wrap_external`.
    mode: TxnMode,
    /// Borrowed catalog overlay, handed back out by `split_parts`.
    overlay: &'c mut CatalogOverlay,
    /// HNSW indexes loaded through this wrapper, keyed by index name.
    hnsw_indices: HashMap<String, HnswTxnEntry>,
}
393
394impl<'a, 'b, 'c, S: KVStore + 'b> BorrowedSqlTransaction<'a, 'b, 'c, S> {
395 pub fn mode(&self) -> TxnMode {
396 self.mode
397 }
398
399 pub fn split_parts(&mut self) -> (BorrowedSqlTxn<'_, 'b, S>, &mut CatalogOverlay) {
400 (
401 BorrowedSqlTxn {
402 inner: self.inner,
403 mode: self.mode,
404 hnsw_indices: &mut self.hnsw_indices,
405 },
406 self.overlay,
407 )
408 }
409}
410
411impl<'a, 'b, 'c, S: KVStore + 'b> Drop for BorrowedSqlTransaction<'a, 'b, 'c, S> {
412 fn drop(&mut self) {
413 for entry in self.hnsw_indices.values_mut() {
414 if entry.dirty {
415 let _ = entry.index.rollback(&mut entry.state);
416 entry.dirty = false;
417 }
418 }
419 self.hnsw_indices.clear();
420 }
421}
422
/// Short-lived [`SqlTxn`] handle produced by
/// [`BorrowedSqlTransaction::split_parts`].
pub struct BorrowedSqlTxn<'a, 'b, S: KVStore + 'b> {
    /// KV transaction reborrowed from the parent wrapper.
    inner: &'a mut S::Transaction<'b>,
    /// Mode copied from the parent wrapper.
    mode: TxnMode,
    /// Parent wrapper's HNSW cache; the parent sees changes made here.
    hnsw_indices: &'a mut HashMap<String, HnswTxnEntry>,
}
428
429impl<'a, 'b, S: KVStore + 'b> SqlTxn<'b, S> for BorrowedSqlTxn<'a, 'b, S> {
430 fn mode(&self) -> TxnMode {
431 self.mode
432 }
433
434 fn ensure_write_txn(&self) -> CoreResult<()> {
435 if self.mode != TxnMode::ReadWrite {
436 return Err(CoreError::TxnReadOnly);
437 }
438 Ok(())
439 }
440
441 fn inner_mut(&mut self) -> &mut S::Transaction<'b> {
442 self.inner
443 }
444
445 fn hnsw_entry(&mut self, name: &str) -> CoreResult<&HnswIndex> {
446 if !self.hnsw_indices.contains_key(name) {
447 let index = HnswIndex::load(name, self.inner)?;
448 self.hnsw_indices.insert(
449 name.to_string(),
450 HnswTxnEntry {
451 index,
452 state: HnswTransactionState::default(),
453 dirty: false,
454 },
455 );
456 }
457 Ok(&self.hnsw_indices.get(name).expect("inserted above").index)
458 }
459
460 fn hnsw_entry_mut(&mut self, name: &str) -> CoreResult<&mut HnswTxnEntry> {
461 if !self.hnsw_indices.contains_key(name) {
462 let index = HnswIndex::load(name, self.inner)?;
463 self.hnsw_indices.insert(
464 name.to_string(),
465 HnswTxnEntry {
466 index,
467 state: HnswTransactionState::default(),
468 dirty: false,
469 },
470 );
471 }
472 Ok(self.hnsw_indices.get_mut(name).expect("inserted above"))
473 }
474
475 fn flush_hnsw(&mut self) -> Result<()> {
476 for entry in self.hnsw_indices.values_mut() {
477 if entry.dirty {
478 entry.index.commit_staged(self.inner, &mut entry.state)?;
479 entry.dirty = false;
480 }
481 }
482 self.hnsw_indices.clear();
483 Ok(())
484 }
485
486 fn abandon_hnsw(&mut self) -> Result<()> {
487 for entry in self.hnsw_indices.values_mut() {
488 if entry.dirty {
489 entry.index.rollback(&mut entry.state)?;
490 entry.dirty = false;
491 }
492 }
493 self.hnsw_indices.clear();
494 Ok(())
495 }
496
497 fn delete_prefix(&mut self, prefix: &[u8]) -> Result<()> {
498 const BATCH: usize = 512;
500 loop {
501 let mut keys = Vec::with_capacity(BATCH);
502 {
503 let iter = self.inner.scan_prefix(prefix)?;
504 for (key, _) in iter.take(BATCH) {
505 keys.push(key);
506 }
507 }
508
509 if keys.is_empty() {
510 break;
511 }
512
513 for key in keys {
514 self.inner.delete(key)?;
515 }
516 }
517
518 Ok(())
519 }
520}
521
/// Alias for [`SqlTransaction`].
/// NOTE(review): presumably kept for call sites using the older name — confirm.
pub type TxnContext<'a, S> = SqlTransaction<'a, S>;
524
#[cfg(test)]
mod tests {
    use super::super::SqlValue;
    use super::*;
    use crate::catalog::ColumnMetadata;
    use crate::planner::types::ResolvedType;
    use alopex_core::kv::memory::MemoryKV;
    use alopex_core::types::TxnMode;
    use std::sync::Arc;

    /// `users(id INTEGER PRIMARY KEY NOT NULL, name TEXT NOT NULL)` with table id 1.
    fn users_meta() -> TableMetadata {
        let id_col = ColumnMetadata::new("id", ResolvedType::Integer)
            .with_primary_key(true)
            .with_not_null(true);
        let name_col = ColumnMetadata::new("name", ResolvedType::Text).with_not_null(true);
        TableMetadata::new("users", vec![id_col, name_col]).with_table_id(1)
    }

    /// Row values in `users_meta` column order.
    fn user_row(id: i32, name: &str) -> [SqlValue; 2] {
        [SqlValue::Integer(id), SqlValue::Text(name.into())]
    }

    #[test]
    fn read_txn_mode_is_readonly() {
        let bridge = TxnBridge::new(Arc::new(MemoryKV::new()));

        let outcome = bridge.with_read_txn(|ctx| {
            assert_eq!(ctx.mode(), TxnMode::ReadOnly);
            Ok(())
        });
        outcome.unwrap();
    }

    #[test]
    fn write_txn_mode_is_readwrite() {
        let bridge = TxnBridge::new(Arc::new(MemoryKV::new()));

        let outcome = bridge.with_write_txn(|ctx| {
            assert_eq!(ctx.mode(), TxnMode::ReadWrite);
            Ok(())
        });
        outcome.unwrap();
    }

    #[test]
    fn commit_persists_changes_and_read_sees_them() {
        let bridge = TxnBridge::new(Arc::new(MemoryKV::new()));
        let meta = users_meta();

        bridge
            .with_write_txn(|ctx| {
                ctx.with_table(&meta, |table| table.insert(1, &user_row(1, "alice")))
            })
            .unwrap();

        let stored = bridge
            .with_read_txn(|ctx| ctx.with_table(&meta, |table| table.get(1)))
            .unwrap()
            .unwrap();

        assert_eq!(stored[1], SqlValue::Text("alice".into()));
    }

    #[test]
    fn rollback_discards_uncommitted_writes() {
        let bridge = TxnBridge::new(Arc::new(MemoryKV::new()));
        let meta = users_meta();

        let outcome = bridge.with_write_txn_explicit(|ctx| {
            ctx.with_table(&meta, |table| table.insert(1, &user_row(1, "bob")))?;
            // `false` asks the bridge to roll back instead of committing.
            Ok(((), false))
        });
        outcome.unwrap();

        let stored = bridge
            .with_read_txn(|ctx| ctx.with_table(&meta, |table| table.get(1)))
            .unwrap();

        assert!(stored.is_none());
    }

    #[test]
    fn conflicting_commits_trigger_transaction_conflict() {
        let bridge = TxnBridge::new(Arc::new(MemoryKV::new()));
        let meta = users_meta();

        let mut first = bridge.begin_write().unwrap();
        first
            .table_storage(&meta)
            .insert(1, &user_row(1, "alice"))
            .unwrap();

        let mut second = bridge.begin_write().unwrap();
        second
            .table_storage(&meta)
            .insert(1, &user_row(1, "bob"))
            .unwrap();

        first.commit().unwrap();
        let err = second.commit().unwrap_err();
        assert!(matches!(
            err,
            super::super::StorageError::TransactionConflict
        ));
    }

    #[test]
    fn scan_rows_in_transaction() {
        let bridge = TxnBridge::new(Arc::new(MemoryKV::new()));
        let meta = users_meta();

        bridge
            .with_write_txn(|ctx| {
                ctx.with_table(&meta, |table| {
                    for i in 1..=3 {
                        let name = format!("user{i}");
                        table.insert(i, &user_row(i as i32, &name))?;
                    }
                    Ok(())
                })
            })
            .unwrap();

        let ids: Vec<u64> = bridge
            .with_read_txn(|ctx| {
                ctx.with_table(&meta, |table| {
                    let collected: Vec<u64> = table
                        .scan()?
                        .filter_map(|r| r.ok().map(|(id, _)| id))
                        .collect();
                    Ok(collected)
                })
            })
            .unwrap();

        assert_eq!(ids, vec![1, 2, 3]);
    }
}