1use std::sync::{Arc, Mutex, MutexGuard, RwLock};
2
3use crate::{common::errors::KeyNotFound, memtable::table::Memtable};
4use anyhow::Result;
5use bytes::Bytes;
6
7mod common;
8mod memtable;
9
10pub fn new(config: Config) -> Storage {
11 Storage {
12 config,
13 state_lock: Mutex::new(()),
14 state: RwLock::new(StorageState {
15 memtable: Arc::new(Memtable::new()),
16 frozen_memtables: Vec::new(),
17 }),
18 }
19}
20
21impl Storage {
22 pub fn put(&self, key: &[u8], value: &[u8]) -> Result<()> {
23 let size;
24
25 {
26 let guard = self.state.read().unwrap();
27 let memtable = guard.memtable.clone();
28 memtable.put(key, value)?;
29 size = memtable.get_size();
30 }
31
32 self.try_freeze(size);
33 Ok(())
34 }
35
36 pub fn get(&self, key: &[u8]) -> Result<Bytes, KeyNotFound> {
37 let guard = self.state.read().unwrap();
38 let memtable = guard.memtable.clone();
39
40 let Some(value) = memtable.get(key) else {
41 for frozen_table in guard.frozen_memtables.clone() {
42 if let Some(value) = frozen_table.get(key) {
43 return Ok(value);
44 }
45 }
46
47 return Err(KeyNotFound);
48 };
49
50 Ok(value)
51 }
52
53 fn try_freeze(&self, size: usize) {
54 if size >= self.config.sst_size {
55 let lock = self.state_lock.lock().unwrap();
56 self.freeze(&lock);
57 }
58 }
59
60 fn freeze(&self, _state_lock: &MutexGuard<()>) {
61 let mut guard = self.state.write().unwrap();
62 let memtable = guard.memtable.clone();
63
64 if memtable.get_size() >= self.config.sst_size {
66 guard.frozen_memtables.insert(0, memtable);
67 guard.memtable = Arc::new(Memtable::new());
68 }
69 }
70}
71
/// Storage front-end holding one active memtable plus a list of frozen memtables.
///
/// Instances are created with the module-level `new` function.
#[derive(Debug)]
pub struct Storage {
    /// Active and frozen memtables.
    ///
    /// `put`/`get` take the read lock; replacing the active memtable during a freeze takes
    /// the write lock.
    state: RwLock<StorageState>,
    /// Held for the duration of a freeze so that only one thread freezes at a time.
    ///
    /// It wraps `()` and protects no data of its own.
    state_lock: Mutex<()>,
    /// Configuration supplied at construction (e.g. the freeze threshold `sst_size`).
    config: Config,
}
78
/// Mutable part of `Storage`, kept behind `Storage::state`.
#[derive(Debug)]
struct StorageState {
    /// Memtable that currently receives writes from `put`.
    memtable: Arc<Memtable>,
    /// Memtables that reached the size threshold and no longer receive writes.
    ///
    /// Ordered newest first: a freeze inserts at index 0.
    frozen_memtables: Vec<Arc<Memtable>>,
}
84
/// Tuning parameters for a `Storage` instance.
///
/// Derives `Clone` and `PartialEq`/`Eq` so one configuration can be reused for several
/// instances and compared in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Config {
    /// Threshold at which the active memtable is frozen.
    ///
    /// Compared against the memtable's reported size; the unit is presumably bytes of
    /// key + value (TODO confirm against `Memtable::get_size`).
    pub sst_size: usize,
}
89
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a storage with a tiny freeze threshold and writes keys "1".."5" (value = key).
    fn filled_storage() -> Storage {
        let storage = new(Config { sst_size: 4 });
        for entry in [b"1", b"2", b"3", b"4", b"5"] {
            storage.put(entry, entry).unwrap();
        }
        storage
    }

    #[test]
    fn filled_up_memtables_are_frozen() {
        let storage = filled_storage();

        let state = storage.state.read().unwrap();
        assert_eq!(2, state.frozen_memtables.len());
        assert_eq!(2, state.memtable.get_size());
    }

    #[test]
    fn get_reads_from_active_and_frozen_memtables() {
        let storage = filled_storage();

        // After these writes two memtables are frozen (asserted in the test above).
        // Lookups must therefore also search the frozen tables, not just the active one.
        for key in [b"1", b"2", b"3", b"4", b"5"] {
            assert_eq!(storage.get(key).ok().as_deref(), Some(&key[..]));
        }
        assert!(storage.get(b"6").is_err());
    }
}