kstone_core/
memory_lsm.rs

//! In-memory LSM engine for testing and temporary databases.
//!
//! Provides the same API as the disk-based LSM engine but keeps all data in memory.
//! All data is lost when the last handle to a `MemoryLsmEngine` is dropped.
5
6use crate::{
7    Result, Key, Item, Record, Error,
8    memory_wal::MemoryWal,
9    memory_sst::{MemorySstWriter, MemorySstReader},
10    index::TableSchema,
11    iterator::{QueryParams, QueryResult, ScanParams, ScanResult},
12    expression::{UpdateAction, UpdateExecutor, ExpressionContext, ExpressionEvaluator, Expr},
13    lsm::TransactWriteOperation,
14};
15use std::collections::{BTreeMap, HashMap, HashSet};
16use std::sync::{Arc, RwLock};
17
/// Number of independent stripes the key space is hashed into.
const NUM_STRIPES: usize = 1 << 8;
/// Memtable size (in records) at which a stripe's memtable is flushed to an in-memory SST.
const MEMTABLE_THRESHOLD: usize = 1_000;
20
21/// Calculate stripe ID from partition key
22fn stripe_id(pk: &[u8]) -> usize {
23    crc32fast::hash(pk) as usize % NUM_STRIPES
24}
25
26/// In-memory stripe
27struct MemoryStripe {
28    /// In-memory memtable
29    memtable: BTreeMap<Vec<u8>, Record>,
30    /// In-memory SSTs
31    ssts: Vec<MemorySstReader>,
32}
33
34impl MemoryStripe {
35    fn new() -> Self {
36        Self {
37            memtable: BTreeMap::new(),
38            ssts: Vec::new(),
39        }
40    }
41}
42
43/// Inner mutable state
44struct MemoryLsmInner {
45    /// In-memory WAL
46    wal: MemoryWal,
47    /// Stripes
48    stripes: Vec<MemoryStripe>,
49    /// Next sequence number
50    next_seq: u64,
51    /// Next SST ID
52    next_sst_id: u64,
53    /// Table schema (for indexes, TTL, streams)
54    schema: TableSchema,
55}
56
57/// In-memory LSM Engine
58#[derive(Clone)]
59pub struct MemoryLsmEngine {
60    inner: Arc<RwLock<MemoryLsmInner>>,
61}
62
63impl MemoryLsmEngine {
64    /// Create a new in-memory database
65    pub fn create() -> Result<Self> {
66        Self::create_with_schema(TableSchema::new())
67    }
68
69    /// Create a new in-memory database with a table schema
70    pub fn create_with_schema(schema: TableSchema) -> Result<Self> {
71        let wal = MemoryWal::create()?;
72        let stripes = (0..NUM_STRIPES).map(|_| MemoryStripe::new()).collect();
73
74        Ok(Self {
75            inner: Arc::new(RwLock::new(MemoryLsmInner {
76                wal,
77                stripes,
78                next_seq: 1,
79                next_sst_id: 1,
80                schema,
81            })),
82        })
83    }
84
85    /// Put an item
86    pub fn put(&self, key: Key, item: Item) -> Result<()> {
87        let mut inner = self.inner.write().unwrap();
88
89        let seq = inner.next_seq;
90        inner.next_seq += 1;
91
92        let record = Record::put(key.clone(), item, seq);
93
94        // Append to WAL
95        inner.wal.append(record.clone())?;
96
97        // Add to memtable
98        let stripe_idx = stripe_id(&key.pk);
99        inner.stripes[stripe_idx].memtable.insert(key.encode().to_vec(), record);
100
101        // Check if memtable needs flushing
102        if inner.stripes[stripe_idx].memtable.len() >= MEMTABLE_THRESHOLD {
103            Self::flush_stripe(&mut inner, stripe_idx)?;
104        }
105
106        Ok(())
107    }
108
109    /// Get an item
110    pub fn get(&self, key: &Key) -> Result<Option<Item>> {
111        let inner = self.inner.read().unwrap();
112        let stripe_idx = stripe_id(&key.pk);
113        let stripe = &inner.stripes[stripe_idx];
114        let key_bytes = key.encode();
115
116        // Check memtable first
117        if let Some(record) = stripe.memtable.get(key_bytes.as_ref()) {
118            return Ok(record.value.clone());
119        }
120
121        // Check SSTs (newest to oldest)
122        for sst in stripe.ssts.iter().rev() {
123            if let Some(record) = sst.get(key) {
124                return Ok(record.value.clone());
125            }
126        }
127
128        Ok(None)
129    }
130
131    /// Delete an item
132    pub fn delete(&self, key: Key) -> Result<()> {
133        let mut inner = self.inner.write().unwrap();
134
135        let seq = inner.next_seq;
136        inner.next_seq += 1;
137
138        let record = Record::delete(key.clone(), seq);
139
140        // Append to WAL
141        inner.wal.append(record.clone())?;
142
143        // Add tombstone to memtable
144        let stripe_idx = stripe_id(&key.pk);
145        inner.stripes[stripe_idx].memtable.insert(key.encode().to_vec(), record);
146
147        // Check if memtable needs flushing
148        if inner.stripes[stripe_idx].memtable.len() >= MEMTABLE_THRESHOLD {
149            Self::flush_stripe(&mut inner, stripe_idx)?;
150        }
151
152        Ok(())
153    }
154
155    /// Flush memtable to SST
156    fn flush_stripe(inner: &mut MemoryLsmInner, stripe_idx: usize) -> Result<()> {
157        let stripe = &mut inner.stripes[stripe_idx];
158
159        if stripe.memtable.is_empty() {
160            return Ok(());
161        }
162
163        // Create SST from memtable
164        let mut writer = MemorySstWriter::new();
165        for record in stripe.memtable.values() {
166            writer.add(record.clone());
167        }
168
169        let sst_id = inner.next_sst_id;
170        inner.next_sst_id += 1;
171
172        let sst_name = format!("mem-{:03}-{}.sst", stripe_idx, sst_id);
173        let reader = writer.finish(&sst_name)?;
174
175        stripe.ssts.push(reader);
176        stripe.memtable.clear();
177
178        // Clear WAL (in-memory, so just clear it)
179        inner.wal.clear();
180
181        Ok(())
182    }
183
184    /// Flush all stripes
185    pub fn flush(&self) -> Result<()> {
186        let mut inner = self.inner.write().unwrap();
187
188        for stripe_idx in 0..NUM_STRIPES {
189            if !inner.stripes[stripe_idx].memtable.is_empty() {
190                Self::flush_stripe(&mut inner, stripe_idx)?;
191            }
192        }
193
194        inner.wal.flush()?;
195        Ok(())
196    }
197
198    /// Clear all data (for testing)
199    pub fn clear(&self) -> Result<()> {
200        let mut inner = self.inner.write().unwrap();
201
202        for stripe in &mut inner.stripes {
203            stripe.memtable.clear();
204            stripe.ssts.clear();
205        }
206
207        inner.wal.clear();
208        inner.next_seq = 1;
209        inner.next_sst_id = 1;
210
211        Ok(())
212    }
213
214    /// Get the number of items in memory (approximate)
215    pub fn len(&self) -> usize {
216        let inner = self.inner.read().unwrap();
217        let mut count = 0;
218
219        for stripe in &inner.stripes {
220            count += stripe.memtable.len();
221            for sst in &stripe.ssts {
222                count += sst.len();
223            }
224        }
225
226        count
227    }
228
229    /// Check if empty
230    pub fn is_empty(&self) -> bool {
231        self.len() == 0
232    }
233
234    /// Query items within a partition
235    pub fn query(&self, params: QueryParams) -> Result<QueryResult> {
236        let inner = self.inner.read().unwrap();
237
238        // Route to correct stripe
239        let stripe_id = stripe_id(&params.pk);
240        let stripe = &inner.stripes[stripe_id];
241
242        let mut all_records: BTreeMap<Vec<u8>, Record> = BTreeMap::new();
243        let mut scanned_count = 0;
244
245        // Collect from memtable
246        for (key_enc, record) in &stripe.memtable {
247            // Check if PK matches
248            if record.key.pk != params.pk {
249                continue;
250            }
251
252            // Check sort key condition
253            if !params.matches_sk(&record.key.sk) {
254                continue;
255            }
256
257            all_records.insert(key_enc.clone(), record.clone());
258        }
259
260        // Collect from SSTs
261        for sst in &stripe.ssts {
262            for record in sst.iter() {
263                // Check if PK matches
264                if record.key.pk != params.pk {
265                    continue;
266                }
267
268                // Check sort key condition
269                if !params.matches_sk(&record.key.sk) {
270                    continue;
271                }
272
273                let key_enc = record.key.encode().to_vec();
274                // Only add if we don't already have this key (memtable is newer)
275                all_records.entry(key_enc).or_insert(record.clone());
276            }
277        }
278
279        // Convert to sorted vec
280        let mut sorted_records: Vec<(Vec<u8>, Record)> = all_records.into_iter().collect();
281
282        if !params.forward {
283            sorted_records.reverse();
284        }
285
286        // Apply pagination and limit
287        let mut items = Vec::new();
288        let mut last_key = None;
289        let mut seen_keys: HashSet<Vec<u8>> = HashSet::new();
290
291        for (key_enc, record) in sorted_records {
292            // Skip based on pagination
293            if params.should_skip(&record.key) {
294                continue;
295            }
296
297            scanned_count += 1;
298
299            // Skip if already seen
300            if seen_keys.contains(&key_enc) {
301                continue;
302            }
303            seen_keys.insert(key_enc);
304
305            // Skip tombstones
306            if record.value.is_none() {
307                continue;
308            }
309
310            last_key = Some(record.key.clone());
311
312            if let Some(item) = record.value {
313                items.push(item);
314
315                // Check limit
316                if let Some(limit) = params.limit {
317                    if items.len() >= limit {
318                        break;
319                    }
320                }
321            }
322        }
323
324        Ok(QueryResult::new(items, last_key, scanned_count))
325    }
326
327    /// Scan all items across all stripes
328    pub fn scan(&self, params: ScanParams) -> Result<ScanResult> {
329        let inner = self.inner.read().unwrap();
330
331        // Collect all records from all stripes
332        let mut all_records: BTreeMap<Vec<u8>, Record> = BTreeMap::new();
333
334        for stripe_id in 0..NUM_STRIPES {
335            // Skip stripes not assigned to this segment
336            if !params.should_scan_stripe(stripe_id) {
337                continue;
338            }
339
340            let stripe = &inner.stripes[stripe_id];
341
342            // Collect from memtable
343            for (key_enc, record) in &stripe.memtable {
344                // Skip tombstones
345                if record.value.is_none() {
346                    continue;
347                }
348
349                all_records.insert(key_enc.clone(), record.clone());
350            }
351
352            // Collect from SSTs
353            for sst in &stripe.ssts {
354                for record in sst.iter() {
355                    // Skip tombstones
356                    if record.value.is_none() {
357                        continue;
358                    }
359
360                    let key_enc = record.key.encode().to_vec();
361                    // Only add if we don't already have this key (memtable is newer)
362                    all_records.entry(key_enc).or_insert(record.clone());
363                }
364            }
365        }
366
367        // Apply pagination and limit
368        let mut items = Vec::new();
369        let mut scanned_count = 0;
370        let mut last_key = None;
371
372        for (_, record) in all_records {
373            // Skip based on pagination
374            if params.should_skip(&record.key) {
375                continue;
376            }
377
378            scanned_count += 1;
379
380            last_key = Some(record.key.clone());
381
382            if let Some(item) = record.value {
383                items.push(item);
384
385                // Check limit
386                if let Some(limit) = params.limit {
387                    if items.len() >= limit {
388                        return Ok(ScanResult::new(items, last_key, scanned_count));
389                    }
390                }
391            }
392        }
393
394        Ok(ScanResult::new(items, last_key, scanned_count))
395    }
396
397    /// Update an item using update expression
398    pub fn update(&self, key: &Key, actions: &[UpdateAction], context: &ExpressionContext) -> Result<Item> {
399        // Get current item (or create empty if doesn't exist)
400        let current_item = self.get(key)?.unwrap_or_else(|| HashMap::new());
401
402        // Execute update actions
403        let executor = UpdateExecutor::new(context);
404        let updated_item = executor.execute(&current_item, actions)?;
405
406        // Put the updated item
407        self.put(key.clone(), updated_item.clone())?;
408
409        Ok(updated_item)
410    }
411
412    /// Update an item with a condition expression
413    pub fn update_conditional(
414        &self,
415        key: &Key,
416        actions: &[UpdateAction],
417        condition: &Expr,
418        context: &ExpressionContext,
419    ) -> Result<Item> {
420        // Get current item (or create empty if doesn't exist)
421        let current_item = self.get(key)?.unwrap_or_else(|| HashMap::new());
422
423        // Evaluate condition
424        let evaluator = ExpressionEvaluator::new(&current_item, context);
425        let condition_passed = evaluator.evaluate(condition)?;
426
427        if !condition_passed {
428            return Err(Error::ConditionalCheckFailed("Update condition failed".into()));
429        }
430
431        // Condition passed, execute update
432        let executor = UpdateExecutor::new(context);
433        let updated_item = executor.execute(&current_item, actions)?;
434
435        // Put the updated item
436        self.put(key.clone(), updated_item.clone())?;
437
438        Ok(updated_item)
439    }
440
441    /// Put an item with a condition expression
442    pub fn put_conditional(&self, key: Key, item: Item, condition: &Expr, context: &ExpressionContext) -> Result<()> {
443        // Get current item (or empty if doesn't exist)
444        let current_item = self.get(&key)?.unwrap_or_else(|| HashMap::new());
445
446        // Evaluate condition
447        let evaluator = ExpressionEvaluator::new(&current_item, context);
448        let condition_passed = evaluator.evaluate(condition)?;
449
450        if !condition_passed {
451            return Err(Error::ConditionalCheckFailed("Put condition failed".into()));
452        }
453
454        // Condition passed, perform put
455        self.put(key, item)
456    }
457
458    /// Delete an item with a condition expression
459    pub fn delete_conditional(&self, key: Key, condition: &Expr, context: &ExpressionContext) -> Result<()> {
460        // Get current item (or empty if doesn't exist)
461        let current_item = self.get(&key)?.unwrap_or_else(|| HashMap::new());
462
463        // Evaluate condition
464        let evaluator = ExpressionEvaluator::new(&current_item, context);
465        let condition_passed = evaluator.evaluate(condition)?;
466
467        if !condition_passed {
468            return Err(Error::ConditionalCheckFailed("Delete condition failed".into()));
469        }
470
471        // Condition passed, perform delete
472        self.delete(key)
473    }
474
475    /// Batch get multiple items
476    pub fn batch_get(&self, keys: &[Key]) -> Result<HashMap<Key, Option<Item>>> {
477        let mut results = HashMap::new();
478
479        for key in keys {
480            let item = self.get(key)?;
481            results.insert(key.clone(), item);
482        }
483
484        Ok(results)
485    }
486
487    /// Batch write multiple items
488    pub fn batch_write(&self, operations: &[(Key, Option<Item>)]) -> Result<usize> {
489        let mut processed = 0;
490
491        for (key, item_opt) in operations {
492            match item_opt {
493                Some(item) => {
494                    self.put(key.clone(), item.clone())?;
495                    processed += 1;
496                }
497                None => {
498                    self.delete(key.clone())?;
499                    processed += 1;
500                }
501            }
502        }
503
504        Ok(processed)
505    }
506
507    /// Transaction get - read multiple items atomically
508    pub fn transact_get(&self, keys: &[Key]) -> Result<Vec<Option<Item>>> {
509        // Hold read lock for consistent snapshot
510        let _inner = self.inner.read().unwrap();
511
512        let mut items = Vec::new();
513        for key in keys {
514            let item = self.get(key)?;
515            items.push(item);
516        }
517
518        Ok(items)
519    }
520
521    /// Transaction write - write multiple items atomically with conditions
522    pub fn transact_write(
523        &self,
524        operations: &[(Key, TransactWriteOperation)],
525        context: &ExpressionContext,
526    ) -> Result<usize> {
527        // Acquire write lock for atomicity
528        let mut inner = self.inner.write().unwrap();
529
530        // Phase 1: Read all items and check all conditions
531        let mut current_items: Vec<Option<Item>> = Vec::new();
532        for (key, op) in operations {
533            let item = {
534                let stripe_id = stripe_id(&key.pk);
535                let stripe = &inner.stripes[stripe_id];
536                let key_enc = key.encode().to_vec();
537
538                // Check memtable
539                if let Some(record) = stripe.memtable.get(&key_enc) {
540                    record.value.clone()
541                } else {
542                    // Check SSTs
543                    let mut found = None;
544                    for sst in &stripe.ssts {
545                        if let Some(record) = sst.get(key) {
546                            found = record.value.clone();
547                            break;
548                        }
549                    }
550                    found
551                }
552            };
553
554            current_items.push(item.clone());
555
556            // Check condition if present
557            if let Some(condition_expr) = op.condition() {
558                let current_item = item.unwrap_or_else(|| HashMap::new());
559                let evaluator = ExpressionEvaluator::new(&current_item, context);
560                let condition_passed = evaluator.evaluate(condition_expr)?;
561
562                if !condition_passed {
563                    return Err(Error::TransactionCanceled(format!(
564                        "Condition failed for key {:?}",
565                        key
566                    )));
567                }
568            }
569        }
570
571        // Phase 2: All conditions passed, perform all writes
572        let mut committed = 0;
573        for (i, (key, op)) in operations.iter().enumerate() {
574            match op {
575                TransactWriteOperation::Put { item, .. } => {
576                    // Perform put
577                    let seq = inner.next_seq;
578                    inner.next_seq += 1;
579                    let record = Record::put(key.clone(), item.clone(), seq);
580                    inner.wal.append(record.clone())?;
581
582                    let stripe_id = stripe_id(&key.pk);
583                    let key_enc = key.encode().to_vec();
584                    inner.stripes[stripe_id].memtable.insert(key_enc, record);
585
586                    if inner.stripes[stripe_id].memtable.len() >= MEMTABLE_THRESHOLD {
587                        Self::flush_stripe(&mut inner, stripe_id)?;
588                    }
589
590                    committed += 1;
591                }
592                TransactWriteOperation::Delete { .. } => {
593                    // Perform delete
594                    let seq = inner.next_seq;
595                    inner.next_seq += 1;
596                    let record = Record::delete(key.clone(), seq);
597                    inner.wal.append(record.clone())?;
598
599                    let stripe_id = stripe_id(&key.pk);
600                    let key_enc = key.encode().to_vec();
601                    inner.stripes[stripe_id].memtable.insert(key_enc, record);
602
603                    if inner.stripes[stripe_id].memtable.len() >= MEMTABLE_THRESHOLD {
604                        Self::flush_stripe(&mut inner, stripe_id)?;
605                    }
606
607                    committed += 1;
608                }
609                TransactWriteOperation::Update { actions, .. } => {
610                    // Perform update
611                    let current_item = current_items[i].clone().unwrap_or_else(|| HashMap::new());
612                    let executor = UpdateExecutor::new(context);
613                    let updated_item = executor.execute(&current_item, actions)?;
614
615                    let seq = inner.next_seq;
616                    inner.next_seq += 1;
617                    let record = Record::put(key.clone(), updated_item, seq);
618                    inner.wal.append(record.clone())?;
619
620                    let stripe_id = stripe_id(&key.pk);
621                    let key_enc = key.encode().to_vec();
622                    inner.stripes[stripe_id].memtable.insert(key_enc, record);
623
624                    if inner.stripes[stripe_id].memtable.len() >= MEMTABLE_THRESHOLD {
625                        Self::flush_stripe(&mut inner, stripe_id)?;
626                    }
627
628                    committed += 1;
629                }
630                TransactWriteOperation::ConditionCheck { .. } => {
631                    // Condition already checked in phase 1, no write needed
632                    committed += 1;
633                }
634            }
635        }
636
637        Ok(committed)
638    }
639}
640
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Value;
    use std::collections::HashMap;

    /// Build a single-attribute item `{ "test": value }`.
    fn create_test_item(value: &str) -> Item {
        let mut item = HashMap::new();
        item.insert("test".to_string(), Value::string(value));
        item
    }

    /// Read back the `"test"` attribute of `key`, if the item exists.
    fn read_test_attr(engine: &MemoryLsmEngine, key: &Key) -> Option<String> {
        engine
            .get(key)
            .unwrap()
            .and_then(|item| item.get("test").and_then(|v| v.as_string()).map(str::to_owned))
    }

    #[test]
    fn test_memory_lsm_create() {
        let engine = MemoryLsmEngine::create().unwrap();
        assert!(engine.is_empty());
    }

    #[test]
    fn test_memory_lsm_put_get() {
        let engine = MemoryLsmEngine::create().unwrap();

        let key = Key::new(b"key1".to_vec());
        let item = create_test_item("value1");

        engine.put(key.clone(), item.clone()).unwrap();

        let result = engine.get(&key).unwrap();
        assert_eq!(result, Some(item));
    }

    #[test]
    fn test_memory_lsm_delete() {
        let engine = MemoryLsmEngine::create().unwrap();

        let key = Key::new(b"key1".to_vec());
        let item = create_test_item("value1");

        engine.put(key.clone(), item).unwrap();
        assert!(engine.get(&key).unwrap().is_some());

        engine.delete(key.clone()).unwrap();
        assert!(engine.get(&key).unwrap().is_none());
    }

    #[test]
    fn test_memory_lsm_overwrite() {
        let engine = MemoryLsmEngine::create().unwrap();

        let key = Key::new(b"key1".to_vec());

        engine.put(key.clone(), create_test_item("value1")).unwrap();
        engine.put(key.clone(), create_test_item("value2")).unwrap();

        let result = engine.get(&key).unwrap().unwrap();
        assert_eq!(result.get("test").unwrap().as_string(), Some("value2"));
    }

    #[test]
    fn test_memory_lsm_flush() {
        let engine = MemoryLsmEngine::create().unwrap();

        // One partition key routes every write to the same stripe, so its memtable crosses
        // MEMTABLE_THRESHOLD and is flushed automatically. (Distinct partition keys would be
        // spread over NUM_STRIPES stripes and never reach the threshold.)
        let count = MEMTABLE_THRESHOLD + MEMTABLE_THRESHOLD / 2;
        let key_for = |i: usize| Key::with_sk(b"pk".to_vec(), format!("sk{}", i).into_bytes());

        for i in 0..count {
            engine.put(key_for(i), create_test_item(&format!("value{}", i))).unwrap();
        }

        // Items in the flushed SST and in the live memtable must both be readable.
        for i in 0..count {
            assert_eq!(read_test_attr(&engine, &key_for(i)), Some(format!("value{}", i)));
        }
    }

    #[test]
    fn test_memory_lsm_explicit_flush() {
        let engine = MemoryLsmEngine::create().unwrap();

        for i in 0..1500 {
            let key = Key::new(format!("key{}", i).into_bytes());
            engine.put(key, create_test_item(&format!("value{}", i))).unwrap();
        }

        engine.flush().unwrap();

        for i in 0..1500 {
            let key = Key::new(format!("key{}", i).into_bytes());
            assert_eq!(read_test_attr(&engine, &key), Some(format!("value{}", i)));
        }
    }

    #[test]
    fn test_memory_lsm_overwrite_across_flushes() {
        let engine = MemoryLsmEngine::create().unwrap();
        let key = Key::new(b"key1".to_vec());

        engine.put(key.clone(), create_test_item("value1")).unwrap();
        engine.flush().unwrap();
        engine.put(key.clone(), create_test_item("value2")).unwrap();
        engine.flush().unwrap();

        assert_eq!(read_test_attr(&engine, &key), Some("value2".to_string()));
    }

    #[test]
    fn test_memory_lsm_delete_after_flush() {
        let engine = MemoryLsmEngine::create().unwrap();
        let key = Key::new(b"key1".to_vec());

        engine.put(key.clone(), create_test_item("value1")).unwrap();
        engine.flush().unwrap();

        // Tombstone in the memtable shadows the flushed value...
        engine.delete(key.clone()).unwrap();
        assert!(engine.get(&key).unwrap().is_none());

        // ...and still does once the tombstone itself is flushed.
        engine.flush().unwrap();
        assert!(engine.get(&key).unwrap().is_none());
    }

    #[test]
    fn test_memory_lsm_clear() {
        let engine = MemoryLsmEngine::create().unwrap();

        let key = Key::new(b"key1".to_vec());
        engine.put(key.clone(), create_test_item("value1")).unwrap();
        assert!(!engine.is_empty());

        engine.clear().unwrap();
        assert!(engine.is_empty());
        assert!(engine.get(&key).unwrap().is_none());
    }

    #[test]
    fn test_memory_lsm_multiple_stripes() {
        let engine = MemoryLsmEngine::create().unwrap();

        // Put items that will go to different stripes
        for i in 0..100 {
            let key = Key::new(format!("user{}", i).into_bytes());
            engine.put(key, create_test_item(&format!("value{}", i))).unwrap();
        }

        // Verify all items are retrievable
        for i in 0..100 {
            let key = Key::new(format!("user{}", i).into_bytes());
            let result = engine.get(&key).unwrap();
            assert!(result.is_some());
        }
    }

    #[test]
    fn test_memory_lsm_with_sort_key() {
        let engine = MemoryLsmEngine::create().unwrap();

        let key = Key::with_sk(b"pk1".to_vec(), b"sk1".to_vec());
        let item = create_test_item("value1");

        engine.put(key.clone(), item.clone()).unwrap();

        let result = engine.get(&key).unwrap();
        assert_eq!(result, Some(item));

        // Different sort key should not match
        let key2 = Key::with_sk(b"pk1".to_vec(), b"sk2".to_vec());
        assert!(engine.get(&key2).unwrap().is_none());
    }

    #[test]
    fn test_memory_lsm_batch_ops() {
        let engine = MemoryLsmEngine::create().unwrap();
        let k1 = Key::new(b"k1".to_vec());
        let k2 = Key::new(b"k2".to_vec());

        let processed = engine
            .batch_write(&[
                (k1.clone(), Some(create_test_item("a"))),
                (k2.clone(), Some(create_test_item("b"))),
                (k1.clone(), None),
            ])
            .unwrap();
        assert_eq!(processed, 3);

        let results = engine.batch_get(&[k1.clone(), k2.clone()]).unwrap();
        assert_eq!(results[&k1], None);
        assert_eq!(results[&k2], Some(create_test_item("b")));
    }

    #[test]
    fn test_memory_lsm_transact_get() {
        let engine = MemoryLsmEngine::create().unwrap();
        let present = Key::new(b"present".to_vec());
        let missing = Key::new(b"missing".to_vec());

        engine.put(present.clone(), create_test_item("v")).unwrap();

        let items = engine.transact_get(&[present, missing]).unwrap();
        assert_eq!(items, vec![Some(create_test_item("v")), None]);
    }
}