akv/persistence/
engine.rs

1use std::collections::HashMap;
2
3use tokio::{sync::mpsc, time::interval};
4
5use crate::{
6    persistence::{
7        config::PersistenceConfig, snapshot::SnapshotManager, wal::WalEntry, WalManager,
8    },
9    value::Entry,
10};
11
12#[derive(Clone)]
13pub struct PersistenceEngine {
14    #[allow(dead_code)]
15    config: PersistenceConfig,
16    wal_manager: WalManager,
17    snapshot_manager: SnapshotManager,
18    #[allow(dead_code)]
19    snapshot_trigger: mpsc::Sender<()>,
20}
21
22impl PersistenceEngine {
23    pub async fn new(config: PersistenceConfig) -> Result<Self, Box<dyn std::error::Error>> {
24        let wal_manager = WalManager::new(config.clone()).await?;
25        let snapshot_manager = SnapshotManager::new(config.clone()).await?;
26
27        let (snapshot_trigger, mut snapshot_receiver) = mpsc::channel::<()>(1);
28
29        let engine = Self {
30            config: config.clone(),
31            wal_manager,
32            snapshot_manager,
33            snapshot_trigger: snapshot_trigger.clone(),
34        };
35
36        let _snapshot_manager = engine.snapshot_manager.clone();
37        let wal_manager = engine.wal_manager.clone();
38        let config = config.clone();
39
40        tokio::spawn(async move {
41            let mut snapshot_timer = interval(config.snapshot_interval);
42            snapshot_timer.tick().await;
43
44            loop {
45                tokio::select! {
46                    _ = snapshot_timer.tick() => {
47                        let count = wal_manager.get_entry_count().await;
48                        if count >= config.snapshot_wal_threshold {
49                            let _ = snapshot_trigger.send(()).await;
50                        }
51                    }
52                    Some(_) = snapshot_receiver.recv() => {
53                    }
54                }
55            }
56        });
57
58        Ok(engine)
59    }
60
61    pub async fn set(
62        &self,
63        db: String,
64        key: String,
65        entry: Entry,
66    ) -> Result<(), Box<dyn std::error::Error>> {
67        self.wal_manager
68            .append(WalEntry::Set { db, key, entry })
69            .await
70    }
71
72    pub async fn delete(&self, db: String, key: String) -> Result<(), Box<dyn std::error::Error>> {
73        self.wal_manager.append(WalEntry::Delete { db, key }).await
74    }
75
76    pub async fn expire(
77        &self,
78        db: String,
79        key: String,
80        expires_at: u64,
81    ) -> Result<(), Box<dyn std::error::Error>> {
82        self.wal_manager
83            .append(WalEntry::Expire { db, key, expires_at })
84            .await
85    }
86
87    pub async fn create_snapshot(
88        &self,
89        databases: &HashMap<String, HashMap<String, Entry>>,
90    ) -> Result<u64, Box<dyn std::error::Error>> {
91        let timestamp = self
92            .snapshot_manager
93            .create_snapshot(databases)
94            .await?;
95
96        self.wal_manager.cleanup_old_wals(timestamp).await?;
97
98        Ok(timestamp)
99    }
100
101    pub async fn recover(
102        &self,
103    ) -> Result<HashMap<String, HashMap<String, Entry>>, Box<dyn std::error::Error>> {
104        let mut databases = HashMap::new();
105
106        if let Some(snapshot) = self.snapshot_manager.load_latest_snapshot().await? {
107            databases = snapshot.databases;
108        }
109
110        let wal_entries = self.wal_manager.recover().await?;
111
112        for entry in wal_entries {
113            match entry {
114                WalEntry::Set { db, key, entry } => {
115                    let db_map = databases.entry(db).or_insert_with(HashMap::new);
116                    db_map.insert(key, entry);
117                }
118                WalEntry::Delete { db, key } => {
119                    if let Some(db_map) = databases.get_mut(&db) {
120                        db_map.remove(&key);
121                    }
122                }
123                WalEntry::Expire {
124                    db,
125                    key,
126                    expires_at,
127                } => {
128                    if let Some(db_map) = databases.get_mut(&db) {
129                        if let Some(entry) = db_map.get_mut(&key) {
130                            entry.expires_at = Some(expires_at);
131                        }
132                    }
133                }
134            }
135        }
136
137        Ok(databases)
138    }
139
140    pub async fn shutdown(&self) -> Result<(), Box<dyn std::error::Error>> {
141        Ok(())
142    }
143}