1use crate::{
7 Result, Key, Item, Record, Error,
8 memory_wal::MemoryWal,
9 memory_sst::{MemorySstWriter, MemorySstReader},
10 index::TableSchema,
11 iterator::{QueryParams, QueryResult, ScanParams, ScanResult},
12 expression::{UpdateAction, UpdateExecutor, ExpressionContext, ExpressionEvaluator, Expr},
13 lsm::TransactWriteOperation,
14};
15use std::collections::{BTreeMap, HashMap, HashSet};
16use std::sync::{Arc, RwLock};
17
/// Number of independent stripes the key space is hashed into (2^8).
const NUM_STRIPES: usize = 1 << 8;
/// Memtable size, in entries, at which a stripe is flushed to a new SST.
const MEMTABLE_THRESHOLD: usize = 1_000;
20
21fn stripe_id(pk: &[u8]) -> usize {
23 crc32fast::hash(pk) as usize % NUM_STRIPES
24}
25
26struct MemoryStripe {
28 memtable: BTreeMap<Vec<u8>, Record>,
30 ssts: Vec<MemorySstReader>,
32}
33
34impl MemoryStripe {
35 fn new() -> Self {
36 Self {
37 memtable: BTreeMap::new(),
38 ssts: Vec::new(),
39 }
40 }
41}
42
43struct MemoryLsmInner {
45 wal: MemoryWal,
47 stripes: Vec<MemoryStripe>,
49 next_seq: u64,
51 next_sst_id: u64,
53 schema: TableSchema,
55}
56
57#[derive(Clone)]
59pub struct MemoryLsmEngine {
60 inner: Arc<RwLock<MemoryLsmInner>>,
61}
62
63impl MemoryLsmEngine {
64 pub fn create() -> Result<Self> {
66 Self::create_with_schema(TableSchema::new())
67 }
68
69 pub fn create_with_schema(schema: TableSchema) -> Result<Self> {
71 let wal = MemoryWal::create()?;
72 let stripes = (0..NUM_STRIPES).map(|_| MemoryStripe::new()).collect();
73
74 Ok(Self {
75 inner: Arc::new(RwLock::new(MemoryLsmInner {
76 wal,
77 stripes,
78 next_seq: 1,
79 next_sst_id: 1,
80 schema,
81 })),
82 })
83 }
84
85 pub fn put(&self, key: Key, item: Item) -> Result<()> {
87 let mut inner = self.inner.write().unwrap();
88
89 let seq = inner.next_seq;
90 inner.next_seq += 1;
91
92 let record = Record::put(key.clone(), item, seq);
93
94 inner.wal.append(record.clone())?;
96
97 let stripe_idx = stripe_id(&key.pk);
99 inner.stripes[stripe_idx].memtable.insert(key.encode().to_vec(), record);
100
101 if inner.stripes[stripe_idx].memtable.len() >= MEMTABLE_THRESHOLD {
103 Self::flush_stripe(&mut inner, stripe_idx)?;
104 }
105
106 Ok(())
107 }
108
109 pub fn get(&self, key: &Key) -> Result<Option<Item>> {
111 let inner = self.inner.read().unwrap();
112 let stripe_idx = stripe_id(&key.pk);
113 let stripe = &inner.stripes[stripe_idx];
114 let key_bytes = key.encode();
115
116 if let Some(record) = stripe.memtable.get(key_bytes.as_ref()) {
118 return Ok(record.value.clone());
119 }
120
121 for sst in stripe.ssts.iter().rev() {
123 if let Some(record) = sst.get(key) {
124 return Ok(record.value.clone());
125 }
126 }
127
128 Ok(None)
129 }
130
131 pub fn delete(&self, key: Key) -> Result<()> {
133 let mut inner = self.inner.write().unwrap();
134
135 let seq = inner.next_seq;
136 inner.next_seq += 1;
137
138 let record = Record::delete(key.clone(), seq);
139
140 inner.wal.append(record.clone())?;
142
143 let stripe_idx = stripe_id(&key.pk);
145 inner.stripes[stripe_idx].memtable.insert(key.encode().to_vec(), record);
146
147 if inner.stripes[stripe_idx].memtable.len() >= MEMTABLE_THRESHOLD {
149 Self::flush_stripe(&mut inner, stripe_idx)?;
150 }
151
152 Ok(())
153 }
154
155 fn flush_stripe(inner: &mut MemoryLsmInner, stripe_idx: usize) -> Result<()> {
157 let stripe = &mut inner.stripes[stripe_idx];
158
159 if stripe.memtable.is_empty() {
160 return Ok(());
161 }
162
163 let mut writer = MemorySstWriter::new();
165 for record in stripe.memtable.values() {
166 writer.add(record.clone());
167 }
168
169 let sst_id = inner.next_sst_id;
170 inner.next_sst_id += 1;
171
172 let sst_name = format!("mem-{:03}-{}.sst", stripe_idx, sst_id);
173 let reader = writer.finish(&sst_name)?;
174
175 stripe.ssts.push(reader);
176 stripe.memtable.clear();
177
178 inner.wal.clear();
180
181 Ok(())
182 }
183
184 pub fn flush(&self) -> Result<()> {
186 let mut inner = self.inner.write().unwrap();
187
188 for stripe_idx in 0..NUM_STRIPES {
189 if !inner.stripes[stripe_idx].memtable.is_empty() {
190 Self::flush_stripe(&mut inner, stripe_idx)?;
191 }
192 }
193
194 inner.wal.flush()?;
195 Ok(())
196 }
197
198 pub fn clear(&self) -> Result<()> {
200 let mut inner = self.inner.write().unwrap();
201
202 for stripe in &mut inner.stripes {
203 stripe.memtable.clear();
204 stripe.ssts.clear();
205 }
206
207 inner.wal.clear();
208 inner.next_seq = 1;
209 inner.next_sst_id = 1;
210
211 Ok(())
212 }
213
214 pub fn len(&self) -> usize {
216 let inner = self.inner.read().unwrap();
217 let mut count = 0;
218
219 for stripe in &inner.stripes {
220 count += stripe.memtable.len();
221 for sst in &stripe.ssts {
222 count += sst.len();
223 }
224 }
225
226 count
227 }
228
229 pub fn is_empty(&self) -> bool {
231 self.len() == 0
232 }
233
234 pub fn query(&self, params: QueryParams) -> Result<QueryResult> {
236 let inner = self.inner.read().unwrap();
237
238 let stripe_id = stripe_id(¶ms.pk);
240 let stripe = &inner.stripes[stripe_id];
241
242 let mut all_records: BTreeMap<Vec<u8>, Record> = BTreeMap::new();
243 let mut scanned_count = 0;
244
245 for (key_enc, record) in &stripe.memtable {
247 if record.key.pk != params.pk {
249 continue;
250 }
251
252 if !params.matches_sk(&record.key.sk) {
254 continue;
255 }
256
257 all_records.insert(key_enc.clone(), record.clone());
258 }
259
260 for sst in &stripe.ssts {
262 for record in sst.iter() {
263 if record.key.pk != params.pk {
265 continue;
266 }
267
268 if !params.matches_sk(&record.key.sk) {
270 continue;
271 }
272
273 let key_enc = record.key.encode().to_vec();
274 all_records.entry(key_enc).or_insert(record.clone());
276 }
277 }
278
279 let mut sorted_records: Vec<(Vec<u8>, Record)> = all_records.into_iter().collect();
281
282 if !params.forward {
283 sorted_records.reverse();
284 }
285
286 let mut items = Vec::new();
288 let mut last_key = None;
289 let mut seen_keys: HashSet<Vec<u8>> = HashSet::new();
290
291 for (key_enc, record) in sorted_records {
292 if params.should_skip(&record.key) {
294 continue;
295 }
296
297 scanned_count += 1;
298
299 if seen_keys.contains(&key_enc) {
301 continue;
302 }
303 seen_keys.insert(key_enc);
304
305 if record.value.is_none() {
307 continue;
308 }
309
310 last_key = Some(record.key.clone());
311
312 if let Some(item) = record.value {
313 items.push(item);
314
315 if let Some(limit) = params.limit {
317 if items.len() >= limit {
318 break;
319 }
320 }
321 }
322 }
323
324 Ok(QueryResult::new(items, last_key, scanned_count))
325 }
326
327 pub fn scan(&self, params: ScanParams) -> Result<ScanResult> {
329 let inner = self.inner.read().unwrap();
330
331 let mut all_records: BTreeMap<Vec<u8>, Record> = BTreeMap::new();
333
334 for stripe_id in 0..NUM_STRIPES {
335 if !params.should_scan_stripe(stripe_id) {
337 continue;
338 }
339
340 let stripe = &inner.stripes[stripe_id];
341
342 for (key_enc, record) in &stripe.memtable {
344 if record.value.is_none() {
346 continue;
347 }
348
349 all_records.insert(key_enc.clone(), record.clone());
350 }
351
352 for sst in &stripe.ssts {
354 for record in sst.iter() {
355 if record.value.is_none() {
357 continue;
358 }
359
360 let key_enc = record.key.encode().to_vec();
361 all_records.entry(key_enc).or_insert(record.clone());
363 }
364 }
365 }
366
367 let mut items = Vec::new();
369 let mut scanned_count = 0;
370 let mut last_key = None;
371
372 for (_, record) in all_records {
373 if params.should_skip(&record.key) {
375 continue;
376 }
377
378 scanned_count += 1;
379
380 last_key = Some(record.key.clone());
381
382 if let Some(item) = record.value {
383 items.push(item);
384
385 if let Some(limit) = params.limit {
387 if items.len() >= limit {
388 return Ok(ScanResult::new(items, last_key, scanned_count));
389 }
390 }
391 }
392 }
393
394 Ok(ScanResult::new(items, last_key, scanned_count))
395 }
396
397 pub fn update(&self, key: &Key, actions: &[UpdateAction], context: &ExpressionContext) -> Result<Item> {
399 let current_item = self.get(key)?.unwrap_or_else(|| HashMap::new());
401
402 let executor = UpdateExecutor::new(context);
404 let updated_item = executor.execute(¤t_item, actions)?;
405
406 self.put(key.clone(), updated_item.clone())?;
408
409 Ok(updated_item)
410 }
411
412 pub fn update_conditional(
414 &self,
415 key: &Key,
416 actions: &[UpdateAction],
417 condition: &Expr,
418 context: &ExpressionContext,
419 ) -> Result<Item> {
420 let current_item = self.get(key)?.unwrap_or_else(|| HashMap::new());
422
423 let evaluator = ExpressionEvaluator::new(¤t_item, context);
425 let condition_passed = evaluator.evaluate(condition)?;
426
427 if !condition_passed {
428 return Err(Error::ConditionalCheckFailed("Update condition failed".into()));
429 }
430
431 let executor = UpdateExecutor::new(context);
433 let updated_item = executor.execute(¤t_item, actions)?;
434
435 self.put(key.clone(), updated_item.clone())?;
437
438 Ok(updated_item)
439 }
440
441 pub fn put_conditional(&self, key: Key, item: Item, condition: &Expr, context: &ExpressionContext) -> Result<()> {
443 let current_item = self.get(&key)?.unwrap_or_else(|| HashMap::new());
445
446 let evaluator = ExpressionEvaluator::new(¤t_item, context);
448 let condition_passed = evaluator.evaluate(condition)?;
449
450 if !condition_passed {
451 return Err(Error::ConditionalCheckFailed("Put condition failed".into()));
452 }
453
454 self.put(key, item)
456 }
457
458 pub fn delete_conditional(&self, key: Key, condition: &Expr, context: &ExpressionContext) -> Result<()> {
460 let current_item = self.get(&key)?.unwrap_or_else(|| HashMap::new());
462
463 let evaluator = ExpressionEvaluator::new(¤t_item, context);
465 let condition_passed = evaluator.evaluate(condition)?;
466
467 if !condition_passed {
468 return Err(Error::ConditionalCheckFailed("Delete condition failed".into()));
469 }
470
471 self.delete(key)
473 }
474
475 pub fn batch_get(&self, keys: &[Key]) -> Result<HashMap<Key, Option<Item>>> {
477 let mut results = HashMap::new();
478
479 for key in keys {
480 let item = self.get(key)?;
481 results.insert(key.clone(), item);
482 }
483
484 Ok(results)
485 }
486
487 pub fn batch_write(&self, operations: &[(Key, Option<Item>)]) -> Result<usize> {
489 let mut processed = 0;
490
491 for (key, item_opt) in operations {
492 match item_opt {
493 Some(item) => {
494 self.put(key.clone(), item.clone())?;
495 processed += 1;
496 }
497 None => {
498 self.delete(key.clone())?;
499 processed += 1;
500 }
501 }
502 }
503
504 Ok(processed)
505 }
506
507 pub fn transact_get(&self, keys: &[Key]) -> Result<Vec<Option<Item>>> {
509 let _inner = self.inner.read().unwrap();
511
512 let mut items = Vec::new();
513 for key in keys {
514 let item = self.get(key)?;
515 items.push(item);
516 }
517
518 Ok(items)
519 }
520
521 pub fn transact_write(
523 &self,
524 operations: &[(Key, TransactWriteOperation)],
525 context: &ExpressionContext,
526 ) -> Result<usize> {
527 let mut inner = self.inner.write().unwrap();
529
530 let mut current_items: Vec<Option<Item>> = Vec::new();
532 for (key, op) in operations {
533 let item = {
534 let stripe_id = stripe_id(&key.pk);
535 let stripe = &inner.stripes[stripe_id];
536 let key_enc = key.encode().to_vec();
537
538 if let Some(record) = stripe.memtable.get(&key_enc) {
540 record.value.clone()
541 } else {
542 let mut found = None;
544 for sst in &stripe.ssts {
545 if let Some(record) = sst.get(key) {
546 found = record.value.clone();
547 break;
548 }
549 }
550 found
551 }
552 };
553
554 current_items.push(item.clone());
555
556 if let Some(condition_expr) = op.condition() {
558 let current_item = item.unwrap_or_else(|| HashMap::new());
559 let evaluator = ExpressionEvaluator::new(¤t_item, context);
560 let condition_passed = evaluator.evaluate(condition_expr)?;
561
562 if !condition_passed {
563 return Err(Error::TransactionCanceled(format!(
564 "Condition failed for key {:?}",
565 key
566 )));
567 }
568 }
569 }
570
571 let mut committed = 0;
573 for (i, (key, op)) in operations.iter().enumerate() {
574 match op {
575 TransactWriteOperation::Put { item, .. } => {
576 let seq = inner.next_seq;
578 inner.next_seq += 1;
579 let record = Record::put(key.clone(), item.clone(), seq);
580 inner.wal.append(record.clone())?;
581
582 let stripe_id = stripe_id(&key.pk);
583 let key_enc = key.encode().to_vec();
584 inner.stripes[stripe_id].memtable.insert(key_enc, record);
585
586 if inner.stripes[stripe_id].memtable.len() >= MEMTABLE_THRESHOLD {
587 Self::flush_stripe(&mut inner, stripe_id)?;
588 }
589
590 committed += 1;
591 }
592 TransactWriteOperation::Delete { .. } => {
593 let seq = inner.next_seq;
595 inner.next_seq += 1;
596 let record = Record::delete(key.clone(), seq);
597 inner.wal.append(record.clone())?;
598
599 let stripe_id = stripe_id(&key.pk);
600 let key_enc = key.encode().to_vec();
601 inner.stripes[stripe_id].memtable.insert(key_enc, record);
602
603 if inner.stripes[stripe_id].memtable.len() >= MEMTABLE_THRESHOLD {
604 Self::flush_stripe(&mut inner, stripe_id)?;
605 }
606
607 committed += 1;
608 }
609 TransactWriteOperation::Update { actions, .. } => {
610 let current_item = current_items[i].clone().unwrap_or_else(|| HashMap::new());
612 let executor = UpdateExecutor::new(context);
613 let updated_item = executor.execute(¤t_item, actions)?;
614
615 let seq = inner.next_seq;
616 inner.next_seq += 1;
617 let record = Record::put(key.clone(), updated_item, seq);
618 inner.wal.append(record.clone())?;
619
620 let stripe_id = stripe_id(&key.pk);
621 let key_enc = key.encode().to_vec();
622 inner.stripes[stripe_id].memtable.insert(key_enc, record);
623
624 if inner.stripes[stripe_id].memtable.len() >= MEMTABLE_THRESHOLD {
625 Self::flush_stripe(&mut inner, stripe_id)?;
626 }
627
628 committed += 1;
629 }
630 TransactWriteOperation::ConditionCheck { .. } => {
631 committed += 1;
633 }
634 }
635 }
636
637 Ok(committed)
638 }
639}
640
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Value;
    use std::collections::HashMap;

    /// Builds an item holding a single `"test"` string attribute.
    fn sample_item(value: &str) -> Item {
        let entries = std::iter::once(("test".to_string(), Value::string(value)));
        entries.collect::<HashMap<_, _>>()
    }

    /// Builds a partition-only key whose bytes are `prefix` followed by `i`.
    fn numbered_key(prefix: &str, i: usize) -> Key {
        Key::new(format!("{}{}", prefix, i).into_bytes())
    }

    #[test]
    fn test_memory_lsm_create() {
        assert!(MemoryLsmEngine::create().unwrap().is_empty());
    }

    #[test]
    fn test_memory_lsm_put_get() {
        let engine = MemoryLsmEngine::create().unwrap();
        let key = Key::new(b"key1".to_vec());
        let expected = sample_item("value1");

        engine.put(key.clone(), expected.clone()).unwrap();

        assert_eq!(engine.get(&key).unwrap(), Some(expected));
    }

    #[test]
    fn test_memory_lsm_delete() {
        let engine = MemoryLsmEngine::create().unwrap();
        let key = Key::new(b"key1".to_vec());

        engine.put(key.clone(), sample_item("value1")).unwrap();
        assert!(engine.get(&key).unwrap().is_some());

        engine.delete(key.clone()).unwrap();
        assert!(engine.get(&key).unwrap().is_none());
    }

    #[test]
    fn test_memory_lsm_overwrite() {
        let engine = MemoryLsmEngine::create().unwrap();
        let key = Key::new(b"key1".to_vec());

        for value in &["value1", "value2"] {
            engine.put(key.clone(), sample_item(value)).unwrap();
        }

        let stored = engine.get(&key).unwrap().unwrap();
        assert_eq!(stored.get("test").unwrap().as_string(), Some("value2"));
    }

    #[test]
    fn test_memory_lsm_flush() {
        let engine = MemoryLsmEngine::create().unwrap();
        let total = 1500;

        for i in 0..total {
            engine
                .put(numbered_key("key", i), sample_item(&format!("value{}", i)))
                .unwrap();
        }

        for i in 0..total {
            assert!(engine.get(&numbered_key("key", i)).unwrap().is_some());
        }
    }

    #[test]
    fn test_memory_lsm_clear() {
        let engine = MemoryLsmEngine::create().unwrap();
        let key = Key::new(b"key1".to_vec());

        engine.put(key.clone(), sample_item("value1")).unwrap();
        assert!(!engine.is_empty());

        engine.clear().unwrap();
        assert!(engine.is_empty());
        assert!(engine.get(&key).unwrap().is_none());
    }

    #[test]
    fn test_memory_lsm_multiple_stripes() {
        let engine = MemoryLsmEngine::create().unwrap();

        (0..100).for_each(|i| {
            engine
                .put(numbered_key("user", i), sample_item(&format!("value{}", i)))
                .unwrap();
        });

        assert!((0..100).all(|i| engine.get(&numbered_key("user", i)).unwrap().is_some()));
    }

    #[test]
    fn test_memory_lsm_with_sort_key() {
        let engine = MemoryLsmEngine::create().unwrap();
        let present = Key::with_sk(b"pk1".to_vec(), b"sk1".to_vec());
        let absent = Key::with_sk(b"pk1".to_vec(), b"sk2".to_vec());
        let item = sample_item("value1");

        engine.put(present.clone(), item.clone()).unwrap();

        assert_eq!(engine.get(&present).unwrap(), Some(item));
        assert!(engine.get(&absent).unwrap().is_none());
    }
}