alopex_core/kv/memory.rs

//! An in-memory key-value store implementation with Write-Ahead Logging
//! and Optimistic Concurrency Control for Snapshot Isolation.
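//!
//! A minimal usage sketch (hedged: assumes the crate is named `alopex_core`
//! and re-exports `KVTransaction` and `TxnManager` at the root, as the tests
//! below do; the module path to `MemoryKV` is an assumption):
//!
//! ```ignore
//! use alopex_core::{KVTransaction, TxnManager};
//! use alopex_core::kv::{memory::MemoryKV, KVStore};
//! use alopex_core::types::TxnMode;
//!
//! let store = MemoryKV::new();
//! let manager = store.txn_manager();
//! let mut txn = manager.begin(TxnMode::ReadWrite).unwrap();
//! txn.put(b"hello".to_vec(), b"world".to_vec()).unwrap();
//! manager.commit(txn).unwrap();
//! ```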

use crate::error::{Error, Result};
use crate::kv::{KVStore, KVTransaction};
use crate::log::wal::{WalReader, WalRecord, WalWriter};
use crate::storage::flush::write_empty_vector_segment;
use crate::storage::sstable::{SstableReader, SstableWriter};
use crate::txn::TxnManager;
use crate::types::{Key, TxnId, TxnMode, TxnState, Value};
use std::collections::{BTreeMap, HashMap};
use std::ops::Bound::{Excluded, Included};
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::{Arc, RwLock, RwLockReadGuard};
use tracing::warn;

/// Memory usage statistics (in bytes).
#[derive(Debug, Clone, Default)]
pub struct MemoryStats {
    /// Total memory usage.
    pub total_bytes: usize,
    /// Memory used by KV data.
    pub kv_bytes: usize,
    /// Memory used by auxiliary indexes.
    pub index_bytes: usize,
}

/// An in-memory key-value store.
#[derive(Clone)]
pub struct MemoryKV {
    manager: Arc<MemoryTxnManager>,
}

impl MemoryKV {
    /// Creates a new, purely transient in-memory KV store.
    pub fn new() -> Self {
        Self {
            manager: Arc::new(MemoryTxnManager::new(None, None, None)),
        }
    }

    /// Returns current in-memory usage statistics.
    pub fn memory_stats(&self) -> MemoryStats {
        self.manager.memory_stats()
    }

    /// Creates a new in-memory KV store with an optional memory limit.
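    ///
    /// A sketch (hedged: the 1 MiB figure is arbitrary):
    ///
    /// ```ignore
    /// // Commits that would push tracked usage past 1 MiB fail with
    /// // `Error::MemoryLimitExceeded`.
    /// let store = MemoryKV::new_with_limit(Some(1024 * 1024));
    /// ```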
    pub fn new_with_limit(limit: Option<usize>) -> Self {
        Self {
            manager: Arc::new(MemoryTxnManager::new_with_limit(limit)),
        }
    }

    /// Opens a persistent in-memory KV store from a file path.
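    ///
    /// A persistence round-trip sketch (hedged: mirrors the tests below;
    /// `wal.log` is an arbitrary file name, and the SSTable lands next to it
    /// as `wal.sst`):
    ///
    /// ```ignore
    /// use std::path::Path;
    ///
    /// let store = MemoryKV::open(Path::new("wal.log"))?;
    /// // ... commit some transactions, then snapshot to the SSTable:
    /// store.flush()?;
    /// ```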
    pub fn open(path: &Path) -> Result<Self> {
        let wal_writer = WalWriter::new(path)?;
        let sstable_path = path.with_extension("sst");
        let manager = Arc::new(MemoryTxnManager::new(
            Some(wal_writer),
            Some(path.to_path_buf()),
            Some(sstable_path),
        ));
        manager.recover()?;
        Ok(Self { manager })
    }

    /// Flushes the in-memory data to an SSTable.
    pub fn flush(&self) -> Result<()> {
        self.manager.flush()
    }
}

impl Default for MemoryKV {
    fn default() -> Self {
        Self::new()
    }
}

impl KVStore for MemoryKV {
    type Transaction<'a> = MemoryTransaction<'a>;
    type Manager<'a> = &'a MemoryTxnManager;

    fn txn_manager(&self) -> Self::Manager<'_> {
        &self.manager
    }

    fn begin(&self, mode: TxnMode) -> Result<Self::Transaction<'_>> {
        self.manager.begin_internal(mode)
    }
}

// The internal value stored in the BTreeMap, containing the data and its version.
type VersionedValue = (Value, u64);

/// The underlying shared state for the in-memory store.
struct MemorySharedState {
    /// The main data store, mapping keys to versioned values.
    data: RwLock<BTreeMap<Key, VersionedValue>>,
    /// The next transaction ID to be allocated.
    next_txn_id: AtomicU64,
    /// The current commit version of the database. Incremented on every successful commit.
    commit_version: AtomicU64,
    /// The WAL writer. If None, the store is transient.
    wal_writer: Option<RwLock<WalWriter>>,
    /// Optional WAL path for replay on reopen.
    wal_path: Option<PathBuf>,
    /// Optional SSTable reader for read-through.
    sstable: RwLock<Option<SstableReader>>,
    /// Optional SSTable path for flush/reopen.
    sstable_path: Option<PathBuf>,
    /// Optional memory upper limit (bytes) for in-memory mode.
    memory_limit: RwLock<Option<usize>>,
    /// Current memory consumption (bytes) tracked across operations.
    current_memory: AtomicUsize,
}

impl MemorySharedState {
    /// Checks whether adding `additional` bytes would exceed the memory limit.
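    ///
    /// For example, with `limit = 10`, `current = 6`, and `additional = 8`, the
    /// request is rejected with `MemoryLimitExceeded { limit: 10, requested: 14 }`.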
    fn check_memory_limit(&self, additional: usize) -> Result<()> {
        if let Some(limit) = *self.memory_limit.read().unwrap() {
            let current = self.current_memory.load(Ordering::Relaxed);
            let requested = current.saturating_add(additional);
            if requested > limit {
                return Err(Error::MemoryLimitExceeded { limit, requested });
            }
        }
        Ok(())
    }

    /// Returns current memory usage statistics.
    fn memory_stats(&self) -> MemoryStats {
        let kv_bytes = self.current_memory.load(Ordering::Relaxed);
        MemoryStats {
            total_bytes: kv_bytes,
            kv_bytes,
            index_bytes: 0,
        }
    }

    /// Recomputes tracked memory usage from existing data (used after recovery).
    fn recompute_current_memory(&self) {
        let data = self.data.read().unwrap();
        let mut total = 0usize;
        for (k, (v, _)) in data.iter() {
            total = total.saturating_add(k.len() + v.len());
        }
        self.current_memory.store(total, Ordering::Relaxed);
    }
}

/// A transaction manager backed by an in-memory map and an optional WAL.
pub struct MemoryTxnManager {
    state: Arc<MemorySharedState>,
}

impl MemoryTxnManager {
    fn new_with_params(
        wal_writer: Option<WalWriter>,
        wal_path: Option<PathBuf>,
        sstable_path: Option<PathBuf>,
        memory_limit: Option<usize>,
    ) -> Self {
        Self {
            state: Arc::new(MemorySharedState {
                data: RwLock::new(BTreeMap::new()),
                next_txn_id: AtomicU64::new(1),
                commit_version: AtomicU64::new(0),
                wal_writer: wal_writer.map(RwLock::new),
                wal_path,
                sstable: RwLock::new(None),
                sstable_path,
                memory_limit: RwLock::new(memory_limit),
                current_memory: AtomicUsize::new(0),
            }),
        }
    }

    fn new(
        wal_writer: Option<WalWriter>,
        wal_path: Option<PathBuf>,
        sstable_path: Option<PathBuf>,
    ) -> Self {
        Self::new_with_params(wal_writer, wal_path, sstable_path, None)
    }

    /// Creates an in-memory manager with an optional memory limit.
    pub fn new_with_limit(limit: Option<usize>) -> Self {
        Self::new_with_params(None, None, None, limit)
    }

    /// Returns current memory usage statistics.
    pub fn memory_stats(&self) -> MemoryStats {
        self.state.memory_stats()
    }

    /// Updates the configured memory limit at runtime.
    pub fn set_memory_limit(&self, limit: Option<usize>) {
        let mut guard = self.state.memory_limit.write().unwrap();
        *guard = limit;
    }

    /// Returns a snapshot clone of all key/value pairs.
    pub fn snapshot(&self) -> Vec<(Key, Value)> {
        let data = self.state.data.read().unwrap();
        data.iter()
            .map(|(k, (v, _))| (k.clone(), v.clone()))
            .collect()
    }

    /// Clears all data and resets memory accounting.
    pub fn clear_all(&self) {
        let mut data = self.state.data.write().unwrap();
        data.clear();
        drop(data);
        self.state.current_memory.store(0, Ordering::Relaxed);
        self.state.commit_version.store(0, Ordering::Relaxed);
    }

    /// Runs compaction if it can fit within the configured memory limit.
    /// Returns `Ok(true)` when compaction executed, `Ok(false)` when it was skipped.
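    ///
    /// For example (mirroring the test below), with current usage of 8 bytes,
    /// `input_bytes = 2`, `output_bytes = 10`, and a limit of 12, the prospective
    /// usage is 8 - 2 + 10 = 16 > 12, so the compaction is skipped.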
    pub fn compact_with_limit<F>(
        &self,
        input_bytes: usize,
        output_bytes: usize,
        run: F,
    ) -> Result<bool>
    where
        F: FnOnce() -> Result<()>,
    {
        if let Some(limit) = *self.state.memory_limit.read().unwrap() {
            let current = self.state.current_memory.load(Ordering::Relaxed);
            // Predicted usage after compaction: current - input + output (clamped at 0).
            let prospective = current
                .saturating_sub(input_bytes)
                .saturating_add(output_bytes);
            if prospective > limit {
                warn!(
                    limit,
                    requested = prospective,
                    "compaction skipped due to memory limit"
                );
                return Ok(false);
            }
        }

        run()?;

        // Update tracked memory to reflect the compaction result.
        let current = self.state.current_memory.load(Ordering::Relaxed);
        let new_usage = current
            .saturating_sub(input_bytes)
            .saturating_add(output_bytes);
        self.state
            .current_memory
            .store(new_usage, Ordering::Relaxed);
        Ok(true)
    }

    /// In-memory compaction entrypoint that rebuilds the map while honoring memory limits.
    pub fn compact_in_memory(&self) -> Result<bool> {
        let snapshot_bytes = {
            let data = self.state.data.read().unwrap();
            let mut bytes = 0usize;
            for (k, (v, _)) in data.iter() {
                bytes = bytes.saturating_add(k.len() + v.len());
            }
            bytes
        };

        self.compact_with_limit(snapshot_bytes, snapshot_bytes, || {
            let data = self.state.data.read().unwrap();
            let mut rebuilt = BTreeMap::new();
            for (k, (v, version)) in data.iter() {
                rebuilt.insert(k.clone(), (v.clone(), *version));
            }
            drop(data);

            let mut write_guard = self.state.data.write().unwrap();
            *write_guard = rebuilt;
            Ok(())
        })
    }

    /// Flushes the current in-memory data to an SSTable file.
    pub fn flush(&self) -> Result<()> {
        let Some(path) = self.state.sstable_path.as_ref() else {
            return Ok(());
        };

        let data = self.state.data.read().unwrap();
        let mut writer = SstableWriter::create(path)?;
        for (key, (value, _version)) in data.iter() {
            writer.append(key, value)?;
        }
        drop(data);

        let _footer = writer.finish()?;
        let reader = SstableReader::open(path)?;
        // Also emit a placeholder vector segment alongside the SSTable for future vector recovery.
        let vec_path = path.with_extension("vec");
        write_empty_vector_segment(&vec_path)?;

        let mut slot = self.state.sstable.write().unwrap();
        *slot = Some(reader);
        Ok(())
    }

    /// Replays the WAL to restore the state of the in-memory map.
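    ///
    /// Only transactions whose `Commit` record made it to the log are applied;
    /// e.g. the record stream `Begin(1), Put(1, k, v), Commit(1), Begin(2),
    /// Put(2, k2, v2)` restores `k = v` and drops transaction 2's unfinished write.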
    fn replay(&self) -> Result<()> {
        let path = match &self.state.wal_path {
            Some(p) => p,
            None => return Ok(()),
        };
        if !path.exists() || std::fs::metadata(path)?.len() == 0 {
            return Ok(());
        }

        let mut data = self.state.data.write().unwrap();
        let mut max_txn_id = 0;
        let mut max_version = self.state.commit_version.load(Ordering::Acquire);
        let reader = WalReader::new(path)?;
        let mut pending_txns: HashMap<TxnId, Vec<(Key, Option<Value>)>> = HashMap::new();

        for record_result in reader {
            match record_result? {
                WalRecord::Begin(txn_id) => {
                    max_txn_id = max_txn_id.max(txn_id.0);
                    pending_txns.entry(txn_id).or_default();
                }
                WalRecord::Put(txn_id, key, value) => {
                    max_txn_id = max_txn_id.max(txn_id.0);
                    pending_txns
                        .entry(txn_id)
                        .or_default()
                        .push((key, Some(value)));
                }
                WalRecord::Delete(txn_id, key) => {
                    max_txn_id = max_txn_id.max(txn_id.0);
                    pending_txns.entry(txn_id).or_default().push((key, None));
                }
                WalRecord::Commit(txn_id) => {
                    if let Some(writes) = pending_txns.remove(&txn_id) {
                        max_version += 1;
                        for (key, value) in writes {
                            if let Some(v) = value {
                                data.insert(key, (v, max_version));
                            } else {
                                data.remove(&key);
                            }
                        }
                    }
                }
            }
        }

        self.state
            .next_txn_id
            .store(max_txn_id + 1, Ordering::SeqCst);
        self.state
            .commit_version
            .store(max_version, Ordering::SeqCst);
        Ok(())
    }

    fn load_sstable(&self) -> Result<()> {
        let path = match &self.state.sstable_path {
            Some(p) => p,
            None => return Ok(()),
        };
        if !path.exists() {
            return Ok(());
        }

        let mut reader = SstableReader::open(path)?;
        let mut data = self.state.data.write().unwrap();
        let mut version = self.state.commit_version.load(Ordering::Acquire);

        let keys: Vec<Key> = reader
            .index()
            .iter()
            .map(|entry| entry.key.clone())
            .collect();

        for key in keys {
            if let Some(value) = reader.get(&key)? {
                version += 1;
                data.insert(key, (value, version));
            }
        }

        self.state.commit_version.store(version, Ordering::SeqCst);
        let mut slot = self.state.sstable.write().unwrap();
        *slot = Some(reader);
        Ok(())
    }

    /// Loads the SSTable, then replays the WAL to restore state.
    fn recover(&self) -> Result<()> {
        self.load_sstable()?;
        self.replay()?;
        self.state.recompute_current_memory();
        Ok(())
    }

    fn sstable_get(&self, key: &Key) -> Result<Option<Value>> {
        let mut guard = self.state.sstable.write().unwrap();
        if let Some(reader) = guard.as_mut() {
            return reader.get(key);
        }
        Ok(None)
    }

    fn begin_internal(&self, mode: TxnMode) -> Result<MemoryTransaction<'_>> {
        let txn_id = self.state.next_txn_id.fetch_add(1, Ordering::SeqCst);
        let start_version = self.state.commit_version.load(Ordering::Acquire);
        Ok(MemoryTransaction::new(
            self,
            TxnId(txn_id),
            mode,
            start_version,
        ))
    }
}

impl<'a> TxnManager<'a, MemoryTransaction<'a>> for &'a MemoryTxnManager {
    fn begin(&'a self, mode: TxnMode) -> Result<MemoryTransaction<'a>> {
        self.begin_internal(mode)
    }

    fn commit(&'a self, mut txn: MemoryTransaction<'a>) -> Result<()> {
        if txn.state != TxnState::Active {
            return Err(Error::TxnClosed);
        }
        if txn.mode == TxnMode::ReadOnly || txn.writes.is_empty() {
            txn.state = TxnState::Committed;
            return Ok(());
        }

        let mut data = self.state.data.write().unwrap();

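        // Check the read set for conflicts: any key this transaction read must
        // not have been committed at a newer version since the snapshot was taken.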
        for key in txn.read_set.keys() {
            let current_version = data.get(key).map(|(_, v)| *v).unwrap_or(0);
            if current_version > txn.start_version {
                return Err(Error::TxnConflict);
            }
        }

        // Detect write-write conflicts even when the key was never read.
        for key in txn.writes.keys() {
            let current_version = data.get(key).map(|(_, v)| *v).unwrap_or(0);
            if current_version > txn.start_version {
                return Err(Error::TxnConflict);
            }
        }

        // Compute prospective memory usage and enforce limits before mutating state.
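        // (e.g. overwriting key "k1" (2 bytes) holding a 4-byte value with a
        // 6-byte value gives delta = (2 + 6) - (2 + 4) = +2.)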
        let mut delta: isize = 0;
        for (key, value) in &txn.writes {
            let current_size = data.get(key).map(|(v, _)| key.len() + v.len()).unwrap_or(0);
            let new_size = match value {
                Some(v) => key.len() + v.len(),
                None => 0,
            };
            delta += new_size as isize - current_size as isize;
        }

        let current_mem = self.state.current_memory.load(Ordering::Relaxed);
        let prospective = if delta >= 0 {
            current_mem.saturating_add(delta as usize)
        } else {
            current_mem.saturating_sub(delta.unsigned_abs())
        };

        if delta > 0 {
            self.state.check_memory_limit(delta as usize)?;
        }

        let commit_version = self.state.commit_version.fetch_add(1, Ordering::AcqRel) + 1;

        if let Some(wal_lock) = &self.state.wal_writer {
            let mut wal = wal_lock.write().unwrap();
            wal.append(&WalRecord::Begin(txn.id))?;
            for (key, value) in &txn.writes {
                let record = match value {
                    Some(v) => WalRecord::Put(txn.id, key.clone(), v.clone()),
                    None => WalRecord::Delete(txn.id, key.clone()),
                };
                wal.append(&record)?;
            }
            wal.append(&WalRecord::Commit(txn.id))?;
        }

        for (key, value) in std::mem::take(&mut txn.writes) {
            if let Some(v) = value {
                data.insert(key, (v, commit_version));
            } else {
                data.remove(&key);
            }
        }

        self.state
            .current_memory
            .store(prospective, Ordering::Relaxed);

        txn.state = TxnState::Committed;
        Ok(())
    }

    fn rollback(&'a self, mut txn: MemoryTransaction<'a>) -> Result<()> {
        if txn.state != TxnState::Active {
            return Err(Error::TxnClosed);
        }
        txn.state = TxnState::RolledBack;
        Ok(())
    }
}

/// An in-memory transaction that enforces snapshot isolation.
pub struct MemoryTransaction<'a> {
    manager: &'a MemoryTxnManager,
    id: TxnId,
    mode: TxnMode,
    state: TxnState,
    start_version: u64,
    writes: BTreeMap<Key, Option<Value>>,
    read_set: HashMap<Key, u64>,
}

impl<'a> MemoryTransaction<'a> {
    fn new(manager: &'a MemoryTxnManager, id: TxnId, mode: TxnMode, start_version: u64) -> Self {
        Self {
            manager,
            id,
            mode,
            state: TxnState::Active,
            start_version,
            writes: BTreeMap::new(),
            read_set: HashMap::new(),
        }
    }

    fn ensure_active(&self) -> Result<()> {
        if self.state != TxnState::Active {
            return Err(Error::TxnClosed);
        }
        Ok(())
    }

    /// Rolls back the transaction without consuming it.
    pub(crate) fn rollback_in_place(&mut self) -> Result<()> {
        if self.state != TxnState::Active {
            return Err(Error::TxnClosed);
        }
        self.state = TxnState::RolledBack;
        Ok(())
    }

    fn scan_range_internal(&mut self, start: &[u8], end: &[u8]) -> MergedScanIter<'_> {
        let start_vec = start.to_vec();
        let end_vec = end.to_vec();
        let data_guard = self.manager.state.data.read().unwrap();
        let data_ptr: *const BTreeMap<Key, VersionedValue> = &*data_guard;
        let data_iter = unsafe {
            // Safety: `data_guard` is moved into the returned `MergedScanIter`,
            // which keeps the map alive for the lifetime of `data_iter`.
            (&*data_ptr).range((Included(start_vec.clone()), Excluded(end_vec.clone())))
        };
        let write_iter = self
            .writes
            .range((Included(start_vec.clone()), Excluded(end_vec.clone())));

        MergedScanIter::new(
            data_guard,
            data_iter,
            write_iter,
            None,
            Some(end_vec),
            self.start_version,
            &mut self.read_set,
        )
    }

    fn scan_prefix_internal(&mut self, prefix: &[u8]) -> MergedScanIter<'_> {
        let prefix_vec = prefix.to_vec();
        let data_guard = self.manager.state.data.read().unwrap();
        let data_ptr: *const BTreeMap<Key, VersionedValue> = &*data_guard;
        let data_iter = unsafe {
            // Safety: `data_guard` is moved into the returned `MergedScanIter`,
            // which keeps the map alive for the lifetime of `data_iter`.
            (&*data_ptr).range(prefix_vec.clone()..)
        };
        let write_iter = self.writes.range(prefix_vec.clone()..);
        MergedScanIter::new(
            data_guard,
            data_iter,
            write_iter,
            Some(prefix_vec),
            None,
            self.start_version,
            &mut self.read_set,
        )
    }
}

impl<'a> KVTransaction<'a> for MemoryTransaction<'a> {
    fn id(&self) -> TxnId {
        self.id
    }

    fn mode(&self) -> TxnMode {
        self.mode
    }

    fn get(&mut self, key: &Key) -> Result<Option<Value>> {
        if self.state != TxnState::Active {
            return Err(Error::TxnClosed);
        }

        if let Some(value) = self.writes.get(key) {
            return Ok(value.clone());
        }

        let result = {
            let data = self.manager.state.data.read().unwrap();
            data.get(key).cloned()
        };

        if let Some((v, version)) = result {
            self.read_set.insert(key.clone(), version);
            return Ok(Some(v));
        }

        // Read-through to SSTable if not found in memory.
        if let Some(value) = self.manager.sstable_get(key)? {
            let version = self.manager.state.commit_version.load(Ordering::Acquire);
            self.read_set.insert(key.clone(), version);
            return Ok(Some(value));
        }

        Ok(None)
    }

    fn put(&mut self, key: Key, value: Value) -> Result<()> {
        if self.state != TxnState::Active {
            return Err(Error::TxnClosed);
        }
        if self.mode == TxnMode::ReadOnly {
            return Err(Error::TxnReadOnly);
        }
        self.writes.insert(key, Some(value));
        Ok(())
    }

    fn delete(&mut self, key: Key) -> Result<()> {
        if self.state != TxnState::Active {
            return Err(Error::TxnClosed);
        }
        if self.mode == TxnMode::ReadOnly {
            return Err(Error::TxnReadOnly);
        }
        self.writes.insert(key, None);
        Ok(())
    }

    fn scan_prefix(
        &mut self,
        prefix: &[u8],
    ) -> Result<Box<dyn Iterator<Item = (Key, Value)> + '_>> {
        self.ensure_active()?;
        let iter = self
            .scan_prefix_internal(prefix)
            .filter_map(|(k, v)| v.map(|val| (k, val)));
        Ok(Box::new(iter))
    }

    fn scan_range(
        &mut self,
        start: &[u8],
        end: &[u8],
    ) -> Result<Box<dyn Iterator<Item = (Key, Value)> + '_>> {
        self.ensure_active()?;
        let iter = self
            .scan_range_internal(start, end)
            .filter_map(|(k, v)| v.map(|val| (k, val)));
        Ok(Box::new(iter))
    }

    fn commit_self(mut self) -> Result<()> {
        if self.state != TxnState::Active {
            return Err(Error::TxnClosed);
        }
        if self.mode == TxnMode::ReadOnly || self.writes.is_empty() {
            self.state = TxnState::Committed;
            return Ok(());
        }

        let mut data = self.manager.state.data.write().unwrap();

        // Check read-set for conflicts
        for key in self.read_set.keys() {
            let current_version = data.get(key).map(|(_, v)| *v).unwrap_or(0);
            if current_version > self.start_version {
                return Err(Error::TxnConflict);
            }
        }

        // Check write-write conflicts
        for key in self.writes.keys() {
            let current_version = data.get(key).map(|(_, v)| *v).unwrap_or(0);
            if current_version > self.start_version {
                return Err(Error::TxnConflict);
            }
        }

        // Compute prospective memory usage
        let mut delta: isize = 0;
        for (key, value) in &self.writes {
            let current_size = data.get(key).map(|(v, _)| key.len() + v.len()).unwrap_or(0);
            let new_size = match value {
                Some(v) => key.len() + v.len(),
                None => 0,
            };
            delta += new_size as isize - current_size as isize;
        }

        let current_mem = self.manager.state.current_memory.load(Ordering::Relaxed);
        let prospective = if delta >= 0 {
            current_mem.saturating_add(delta as usize)
        } else {
            current_mem.saturating_sub(delta.unsigned_abs())
        };

        if delta > 0 {
            self.manager.state.check_memory_limit(delta as usize)?;
        }

        let commit_version = self
            .manager
            .state
            .commit_version
            .fetch_add(1, Ordering::AcqRel)
            + 1;

        // WAL write
        if let Some(wal_lock) = &self.manager.state.wal_writer {
            let mut wal = wal_lock.write().unwrap();
            wal.append(&WalRecord::Begin(self.id))?;
            for (key, value) in &self.writes {
                let record = match value {
                    Some(v) => WalRecord::Put(self.id, key.clone(), v.clone()),
                    None => WalRecord::Delete(self.id, key.clone()),
                };
                wal.append(&record)?;
            }
            wal.append(&WalRecord::Commit(self.id))?;
        }

        // Apply writes
        for (key, value) in std::mem::take(&mut self.writes) {
            if let Some(v) = value {
                data.insert(key, (v, commit_version));
            } else {
                data.remove(&key);
            }
        }

        self.manager
            .state
            .current_memory
            .store(prospective, Ordering::Relaxed);

        self.state = TxnState::Committed;
        Ok(())
    }

    fn rollback_self(mut self) -> Result<()> {
        if self.state != TxnState::Active {
            return Err(Error::TxnClosed);
        }
        self.state = TxnState::RolledBack;
        Ok(())
    }
}

/// Lazy merge iterator that overlays in-flight writes onto a snapshot guard.
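///
/// For example, overlaying pending writes `{b -> Some(x), c -> None}` onto
/// snapshot entries `{a, c}` yields `(a, Some(..))`, `(b, Some(x))`, and
/// `(c, None)`; callers filter out the `None` tombstones.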
struct MergedScanIter<'a> {
    _data_guard: RwLockReadGuard<'a, BTreeMap<Key, VersionedValue>>,
    data_iter: std::collections::btree_map::Range<'a, Key, VersionedValue>,
    write_iter: std::collections::btree_map::Range<'a, Key, Option<Value>>,
    data_peek: Option<(Key, (Value, u64))>,
    write_peek: Option<(Key, Option<Value>)>,
    prefix: Option<Vec<u8>>,
    end: Option<Key>,
    start_version: u64,
    read_set: &'a mut HashMap<Key, u64>,
}

impl<'a> MergedScanIter<'a> {
    #[allow(clippy::too_many_arguments)]
    fn new(
        data_guard: std::sync::RwLockReadGuard<'a, BTreeMap<Key, VersionedValue>>,
        data_iter: std::collections::btree_map::Range<'a, Key, VersionedValue>,
        write_iter: std::collections::btree_map::Range<'a, Key, Option<Value>>,
        prefix: Option<Vec<u8>>,
        end: Option<Key>,
        start_version: u64,
        read_set: &'a mut HashMap<Key, u64>,
    ) -> Self {
        let mut iter = Self {
            _data_guard: data_guard,
            data_iter,
            write_iter,
            data_peek: None,
            write_peek: None,
            prefix,
            end,
            start_version,
            read_set,
        };
        iter.advance_data();
        iter.advance_write();
        iter
    }

    fn advance_data(&mut self) {
        self.data_peek = None;
        while let Some((k, (v, ver))) = self.data_iter.next().map(|(k, v)| (k.clone(), v.clone())) {
            if let Some(end) = &self.end {
                if k >= *end {
                    return;
                }
            }
            if let Some(prefix) = &self.prefix {
                if !k.starts_with(prefix) {
                    return;
                }
            }
            if ver > self.start_version {
                continue;
            }
            self.data_peek = Some((k, (v, ver)));
            return;
        }
    }

    fn advance_write(&mut self) {
        self.write_peek = None;
        if let Some((k, v)) = self.write_iter.next().map(|(k, v)| (k.clone(), v.clone())) {
            if let Some(end) = &self.end {
                if k >= *end {
                    return;
                }
            }
            if let Some(prefix) = &self.prefix {
                if !k.starts_with(prefix) {
                    return;
                }
            }
            self.write_peek = Some((k, v));
        }
    }
}

impl<'a> Iterator for MergedScanIter<'a> {
    type Item = (Key, Option<Value>);

    fn next(&mut self) -> Option<Self::Item> {
        let data_key = self.data_peek.as_ref().map(|(k, _)| k.clone());
        let write_key = self.write_peek.as_ref().map(|(k, _)| k.clone());

        match (data_key, write_key) {
            (Some(dk), Some(wk)) => {
                if dk == wk {
                    let (_, (_, ver)) = self.data_peek.take().unwrap();
                    let (_, write_val) = self.write_peek.take().unwrap();
                    self.read_set.insert(dk.clone(), ver);
                    self.advance_data();
                    self.advance_write();
                    Some((dk, write_val))
                } else if dk < wk {
                    let (k, (v, ver)) = self.data_peek.take().unwrap();
                    self.read_set.insert(k.clone(), ver);
                    self.advance_data();
                    Some((k, Some(v)))
                } else {
                    let (k, write_val) = self.write_peek.take().unwrap();
                    self.advance_write();
                    Some((k, write_val))
                }
            }
            (Some(_), None) => {
                let (k, (v, ver)) = self.data_peek.take().unwrap();
                self.read_set.insert(k.clone(), ver);
                self.advance_data();
                Some((k, Some(v)))
            }
            (None, Some(_)) => {
                let (k, write_val) = self.write_peek.take().unwrap();
                self.advance_write();
                Some((k, write_val))
            }
            (None, None) => None,
        }
    }
}

impl<'a> Drop for MemoryTransaction<'a> {
    fn drop(&mut self) {
        if self.state == TxnState::Active {
            self.state = TxnState::RolledBack;
        }
    }
}

#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    use super::*;
    use crate::{KVTransaction, TxnManager};
    use tempfile::tempdir;
    use tracing::Level;

    fn key(s: &str) -> Key {
        s.as_bytes().to_vec()
    }

    fn value(s: &str) -> Value {
        s.as_bytes().to_vec()
    }

    #[test]
    fn test_put_and_get_transient() {
        let store = MemoryKV::new();
        let manager = store.txn_manager();
        let mut txn = manager.begin(TxnMode::ReadWrite).unwrap();
        txn.put(key("hello"), value("world")).unwrap();
        let val = txn.get(&key("hello")).unwrap();
        assert_eq!(val, Some(value("world")));
        manager.commit(txn).unwrap();

        let mut txn2 = manager.begin(TxnMode::ReadOnly).unwrap();
        let val2 = txn2.get(&key("hello")).unwrap();
        assert_eq!(val2, Some(value("world")));
    }

    #[test]
    fn test_occ_conflict() {
        let store = MemoryKV::new();
        let manager = store.txn_manager();

        let mut t1 = manager.begin(TxnMode::ReadWrite).unwrap();
        t1.get(&key("k1")).unwrap();

        let mut t2 = manager.begin(TxnMode::ReadWrite).unwrap();
        t2.put(key("k1"), value("v2")).unwrap();
        assert!(manager.commit(t2).is_ok());

        t1.put(key("k1"), value("v1")).unwrap();
        let result = manager.commit(t1);
        assert!(matches!(result, Err(Error::TxnConflict)));
    }

    #[test]
    fn test_blind_write_conflict() {
        let store = MemoryKV::new();
        let manager = store.txn_manager();

        let mut t1 = manager.begin(TxnMode::ReadWrite).unwrap();
        t1.put(key("k1"), value("v1")).unwrap();

        let mut t2 = manager.begin(TxnMode::ReadWrite).unwrap();
        t2.put(key("k1"), value("v2")).unwrap();
        assert!(manager.commit(t2).is_ok());

        let result = manager.commit(t1);
        assert!(matches!(result, Err(Error::TxnConflict)));
    }

    #[test]
    fn test_read_only_write_fails() {
        let store = MemoryKV::new();
        let manager = store.txn_manager();
        let mut txn = manager.begin(TxnMode::ReadOnly).unwrap();
        assert!(matches!(
            txn.put(key("k1"), value("v1")),
            Err(Error::TxnReadOnly)
        ));
        assert!(matches!(txn.delete(key("k1")), Err(Error::TxnReadOnly)));
    }

    #[test]
    fn test_txn_closed_error() {
        let store = MemoryKV::new();
        let manager = store.txn_manager();
        let txn = manager.begin(TxnMode::ReadWrite).unwrap();
        manager.commit(txn).unwrap();

        // This is tricky to test because commit takes ownership, so we create
        // a new txn and manually set its state.
        let mut closed_txn = manager.begin(TxnMode::ReadWrite).unwrap();
        closed_txn.state = TxnState::Committed;
        assert!(matches!(closed_txn.get(&key("k1")), Err(Error::TxnClosed)));
        assert!(matches!(
            closed_txn.put(key("k1"), value("v1")),
            Err(Error::TxnClosed)
        ));
    }

    #[test]
    fn test_get_not_found() {
        let store = MemoryKV::new();
        let manager = store.txn_manager();
        let mut txn = manager.begin(TxnMode::ReadOnly).unwrap();
        let res = txn.get(&key("non-existent"));
        assert!(res.is_ok());
        assert!(res.unwrap().is_none());
    }

    #[test]
    fn flush_and_reopen_reads_from_sstable() {
        let dir = tempdir().unwrap();
        let wal_path = dir.path().join("wal.log");
        {
            let store = MemoryKV::open(&wal_path).unwrap();
            let manager = store.txn_manager();
            let mut txn = manager.begin(TxnMode::ReadWrite).unwrap();
            txn.put(key("k1"), value("v1")).unwrap();
            manager.commit(txn).unwrap();
            store.flush().unwrap();
        }

        let reopened = MemoryKV::open(&wal_path).unwrap();
        let manager = reopened.txn_manager();
        let mut txn = manager.begin(TxnMode::ReadOnly).unwrap();
        assert_eq!(txn.get(&key("k1")).unwrap(), Some(value("v1")));
    }

    #[test]
    fn wal_overlays_sstable_on_reopen() {
        let dir = tempdir().unwrap();
        let wal_path = dir.path().join("wal.log");
        {
            let store = MemoryKV::open(&wal_path).unwrap();
            let manager = store.txn_manager();
            let mut txn = manager.begin(TxnMode::ReadWrite).unwrap();
            txn.put(key("k1"), value("v1")).unwrap();
            manager.commit(txn).unwrap();
            store.flush().unwrap();

            let mut txn2 = manager.begin(TxnMode::ReadWrite).unwrap();
            txn2.put(key("k1"), value("v2")).unwrap();
            manager.commit(txn2).unwrap();
        }

        let reopened = MemoryKV::open(&wal_path).unwrap();
        let manager = reopened.txn_manager();
        let mut txn = manager.begin(TxnMode::ReadOnly).unwrap();
        assert_eq!(txn.get(&key("k1")).unwrap(), Some(value("v2")));
    }

    #[test]
    fn scan_prefix_merges_snapshot_and_writes() {
        let store = MemoryKV::new();
        let manager = store.txn_manager();

        let mut seed = manager.begin(TxnMode::ReadWrite).unwrap();
        seed.put(key("p:1"), value("old1")).unwrap();
        seed.put(key("p:2"), value("old2")).unwrap();
        seed.put(key("q:1"), value("other")).unwrap();
        manager.commit(seed).unwrap();

        let mut txn = manager.begin(TxnMode::ReadWrite).unwrap();
        txn.put(key("p:1"), value("new1")).unwrap();
        txn.delete(key("p:2")).unwrap();
        txn.put(key("p:3"), value("new3")).unwrap();

        let results: Vec<_> = txn.scan_prefix(b"p:").unwrap().collect();
        assert_eq!(
            results,
            vec![(key("p:1"), value("new1")), (key("p:3"), value("new3"))]
        );
    }

    #[test]
    fn scan_range_skips_newer_versions() {
        let store = MemoryKV::new();
        let manager = store.txn_manager();

        let mut seed = manager.begin(TxnMode::ReadWrite).unwrap();
        seed.put(key("b"), value("v1")).unwrap();
        manager.commit(seed).unwrap();

        let mut txn1 = manager.begin(TxnMode::ReadWrite).unwrap();

        let mut txn2 = manager.begin(TxnMode::ReadWrite).unwrap();
        txn2.put(key("ba"), value("v2")).unwrap();
        manager.commit(txn2).unwrap();

        let results: Vec<_> = txn1.scan_range(b"b", b"c").unwrap().collect();
        assert_eq!(results, vec![(key("b"), value("v1"))]);
    }

    #[test]
    fn scan_range_records_reads_for_conflict_detection() {
        let store = MemoryKV::new();
        let manager = store.txn_manager();

        let mut seed = manager.begin(TxnMode::ReadWrite).unwrap();
        seed.put(key("k1"), value("v1")).unwrap();
        manager.commit(seed).unwrap();

        let mut t1 = manager.begin(TxnMode::ReadWrite).unwrap();
        let results: Vec<_> = t1.scan_range(b"k0", b"kz").unwrap().collect();
        assert_eq!(results, vec![(key("k1"), value("v1"))]);
        t1.put(key("k_new"), value("v_new")).unwrap();

        let mut t2 = manager.begin(TxnMode::ReadWrite).unwrap();
        t2.put(key("k1"), value("v2")).unwrap();
        manager.commit(t2).unwrap();

        let result = manager.commit(t1);
        assert!(matches!(result, Err(Error::TxnConflict)));
    }

    #[test]
    fn memory_stats_tracks_put_and_delete() {
        let store = MemoryKV::new();
        let manager = store.txn_manager();

        let stats = manager.memory_stats();
        assert_eq!(stats.total_bytes, 0);
        assert_eq!(stats.kv_bytes, 0);
        assert_eq!(stats.index_bytes, 0);

        // Insert a value and commit.
        let mut txn = manager.begin(TxnMode::ReadWrite).unwrap();
        txn.put(key("a"), value("1234")).unwrap(); // key=1, value=4 => 5 bytes
        manager.commit(txn).unwrap();

        let stats = manager.memory_stats();
        assert_eq!(stats.total_bytes, 5);
        assert_eq!(stats.kv_bytes, 5);
        assert_eq!(stats.index_bytes, 0);

        // Delete and ensure usage returns to zero.
        let mut txn = manager.begin(TxnMode::ReadWrite).unwrap();
        txn.delete(key("a")).unwrap();
        manager.commit(txn).unwrap();

        let stats = manager.memory_stats();
        assert_eq!(stats.total_bytes, 0);
        assert_eq!(stats.kv_bytes, 0);
    }

    #[test]
    fn memory_limit_error_does_not_break_reads() {
        let store = MemoryKV::new_with_limit(Some(10));
        let manager = store.txn_manager();

        // First insert within limit: key(2) + value(4) = 6.
        let mut txn = manager.begin_internal(TxnMode::ReadWrite).unwrap();
        txn.put(key("k1"), value("vvvv")).unwrap();
        manager.commit(txn).unwrap();

        // Next insert would exceed limit: key(2) + value(6) + existing(6) -> 14 > 10.
        let mut txn2 = manager.begin_internal(TxnMode::ReadWrite).unwrap();
        txn2.put(key("k2"), value("vvvvvv")).unwrap();
        let result = manager.commit(txn2);
        assert!(matches!(result, Err(Error::MemoryLimitExceeded { .. })));

        // Reads still work and existing data remains intact.
        let mut read_txn = manager.begin_internal(TxnMode::ReadOnly).unwrap();
        let got = read_txn.get(&key("k1")).unwrap();
        assert_eq!(got, Some(value("vvvv")));

        // Memory usage stays at the previous successful commit.
        let stats = manager.memory_stats();
        assert_eq!(stats.total_bytes, 6);
    }

    struct VecWriter(std::sync::Arc<std::sync::Mutex<Vec<u8>>>);

    impl std::io::Write for VecWriter {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            let mut guard = self.0.lock().unwrap();
            guard.extend_from_slice(buf);
            Ok(buf.len())
        }

        fn flush(&mut self) -> std::io::Result<()> {
            Ok(())
        }
    }

    #[test]
    fn compaction_skips_when_over_limit_and_logs_warning() {
        let store = MemoryKV::new_with_limit(Some(12));
        let manager = store.txn_manager();

        // Populate data to track current memory: key(2)+val(6)=8 bytes.
        let mut txn = manager.begin_internal(TxnMode::ReadWrite).unwrap();
        txn.put(key("k1"), value("123456")).unwrap();
        manager.commit(txn).unwrap();

        // Prepare log capture.
        let buffer = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
        let make_writer = {
            let buf = buffer.clone();
            move || VecWriter(buf.clone())
        };
        let subscriber = tracing_subscriber::fmt()
            .with_max_level(Level::WARN)
            .with_writer(make_writer)
            .without_time()
            .finish();
        let _guard = tracing::subscriber::set_default(subscriber);

        // input=2 (assume one entry), output=10 => projected 8-2+10=16 > 12 -> skip.
        let ran = manager.compact_with_limit(2, 10, || Ok(())).unwrap();
        assert!(!ran);

        // Memory usage unchanged.
        assert_eq!(manager.memory_stats().total_bytes, 8);

        // Verify the warning was logged.
        let log = String::from_utf8(buffer.lock().unwrap().clone()).unwrap();
        assert!(
            log.contains("compaction skipped due to memory limit"),
            "expected warning log, got: {}",
            log
        );
    }
}