1use std::path::Path;
4use sled::{Db, Tree};
5use tokio::sync::RwLock;
6use tracing::{debug, error};
7
8use crate::{
9 KVError, KVResult, Key, Entry, DatabaseId, Storage, StorageStats,
10};
11
/// Disk-backed storage built on a `sled` database.
///
/// Each logical database is stored in its own sled tree; the trees are opened
/// once in [`DiskStorage::new`] and looked up by database ID afterwards.
pub struct DiskStorage {
    /// Handle to the underlying sled database; used to open trees and to flush.
    db: Db,
    /// One slot per logical database, indexed by database ID and guarded by an
    /// async read/write lock.
    trees: RwLock<Vec<Option<Tree>>>,
}
19
20impl DiskStorage {
21 pub fn new<P: AsRef<Path>>(path: P) -> KVResult<Self> {
26 let db_path = path.as_ref();
27 std::fs::create_dir_all(db_path)
28 .map_err(|e| KVError::Storage(format!("Failed to create data directory: {e}")))?;
29
30 let db = sled::open(db_path)
31 .map_err(|e| KVError::Storage(format!("Failed to open sled database: {e}")))?;
32
33 let mut trees = Vec::with_capacity(16);
35 for i in 0..16 {
36 let tree_name = format!("db_{i}");
37 let tree = db.open_tree(&tree_name)
38 .map_err(|e| KVError::Storage(format!("Failed to open tree {tree_name}: {e}")))?;
39 trees.push(Some(tree));
40 }
41
42 Ok(Self {
43 db,
44 trees: RwLock::new(trees),
45 })
46 }
47
48 async fn get_tree(&self, database_id: DatabaseId) -> KVResult<Tree> {
50 if database_id >= 16 {
51 return Err(KVError::InvalidKey(format!("Database ID {database_id} out of range (0-15)")));
52 }
53
54 let trees = self.trees.read().await;
55 trees[database_id as usize]
56 .as_ref()
57 .ok_or_else(|| KVError::Storage("Database tree not found".to_string())).cloned()
58 }
59
60 fn serialize_entry(entry: &Entry) -> KVResult<Vec<u8>> {
62 serde_json::to_vec(entry)
63 .map_err(|e| KVError::Internal(format!("Serialization failed: {e}")))
64 }
65
66 fn deserialize_entry(data: &[u8]) -> KVResult<Entry> {
68 serde_json::from_slice(data)
69 .map_err(|e| KVError::Internal(format!("Deserialization failed: {e}")))
70 }
71}
72
73#[async_trait::async_trait]
74impl Storage for DiskStorage {
75 async fn get(&self, database_id: DatabaseId, key: &Key) -> KVResult<Option<Entry>> {
76 let tree = self.get_tree(database_id).await?;
77
78 match tree.get(key.as_bytes()) {
79 Ok(Some(data)) => {
80 let entry = Self::deserialize_entry(&data)?;
81 Ok(Some(entry))
82 }
83 Ok(None) => Ok(None),
84 Err(e) => Err(KVError::Storage(format!("Failed to get key: {e}"))),
85 }
86 }
87
88 async fn set(&self, database_id: DatabaseId, key: Key, entry: Entry) -> KVResult<()> {
89 let tree = self.get_tree(database_id).await?;
90 let data = Self::serialize_entry(&entry)?;
91
92 tree.insert(key.as_bytes(), data)
93 .map_err(|e| KVError::Storage(format!("Failed to set key: {e}")))?;
94
95 Ok(())
96 }
97
98 async fn delete(&self, database_id: DatabaseId, key: &Key) -> KVResult<bool> {
99 let tree = self.get_tree(database_id).await?;
100
101 match tree.remove(key.as_bytes()) {
102 Ok(Some(_)) => Ok(true),
103 Ok(None) => Ok(false),
104 Err(e) => Err(KVError::Storage(format!("Failed to delete key: {e}"))),
105 }
106 }
107
108 async fn exists(&self, database_id: DatabaseId, key: &Key) -> KVResult<bool> {
109 let tree = self.get_tree(database_id).await?;
110
111 match tree.contains_key(key.as_bytes()) {
112 Ok(exists) => Ok(exists),
113 Err(e) => Err(KVError::Storage(format!("Failed to check key existence: {e}"))),
114 }
115 }
116
117 async fn keys(&self, database_id: DatabaseId) -> KVResult<Vec<Key>> {
118 let tree = self.get_tree(database_id).await?;
119 let mut keys = Vec::new();
120
121 for result in &tree {
122 match result {
123 Ok((key_bytes, _)) => {
124 let key = String::from_utf8(key_bytes.to_vec())
125 .map_err(|e| KVError::Storage(format!("Invalid key encoding: {e}")))?;
126 keys.push(key);
127 }
128 Err(e) => {
129 error!("Error iterating keys: {}", e);
130 return Err(KVError::Storage(format!("Failed to iterate keys: {e}")));
131 }
132 }
133 }
134
135 Ok(keys)
136 }
137
138 async fn keys_pattern(&self, database_id: DatabaseId, pattern: &str) -> KVResult<Vec<Key>> {
139 let all_keys = self.keys(database_id).await?;
140 let matching_keys: Vec<Key> = all_keys
141 .into_iter()
142 .filter(|key| matches_pattern(key, pattern))
143 .collect();
144
145 Ok(matching_keys)
146 }
147
148 async fn clear_database(&self, database_id: DatabaseId) -> KVResult<()> {
149 let tree = self.get_tree(database_id).await?;
150
151 tree.clear()
152 .map_err(|e| KVError::Storage(format!("Failed to clear database: {e}")))?;
153
154 Ok(())
155 }
156
157 async fn get_stats(&self, database_id: DatabaseId) -> KVResult<StorageStats> {
158 let tree = self.get_tree(database_id).await?;
159
160 let total_keys = tree.len() as u64;
161 let memory_usage = 0; Ok(StorageStats {
164 total_keys,
165 memory_usage,
166 disk_usage: Some(memory_usage),
167 last_flush: None, })
169 }
170
171 async fn flush(&self) -> KVResult<()> {
172 self.db.flush()
173 .map_err(|e| KVError::Storage(format!("Failed to flush database: {e}")))?;
174 debug!("Disk storage flushed");
175 Ok(())
176 }
177
178 async fn close(&self) -> KVResult<()> {
179 self.db.flush()
180 .map_err(|e| KVError::Storage(format!("Failed to flush on close: {e}")))?;
181 debug!("Disk storage closed");
182 Ok(())
183 }
184}
185
/// Reports whether `key` matches a glob-style `pattern`.
///
/// The only special character is `*`, which matches any run of characters
/// (including none). A pattern without `*` must equal the key exactly.
///
/// The literal text before the first `*` must be a prefix of the key, the text
/// after the last `*` must be a suffix of what remains, and every literal
/// segment between wildcards must occur, in order and without overlapping, in
/// the part of the key between that prefix and suffix.
fn matches_pattern(key: &str, pattern: &str) -> bool {
    let (prefix, rest) = match pattern.split_once('*') {
        Some(parts) => parts,
        // No wildcard at all: plain equality.
        None => return key == pattern,
    };

    // With a single `*` there are no interior segments, only a suffix.
    let (middle, suffix) = match rest.rsplit_once('*') {
        Some(parts) => parts,
        None => ("", rest),
    };

    // Strip the prefix before the suffix so the two can never overlap
    // (e.g. "aba" must not match "ab*ba").
    let after_prefix = match key.strip_prefix(prefix) {
        Some(tail) => tail,
        None => return false,
    };
    let mut remaining = match after_prefix.strip_suffix(suffix) {
        Some(body) => body,
        None => return false,
    };

    // Taking the leftmost occurrence of each interior segment leaves the most
    // room for the segments that follow, so a greedy scan is sufficient.
    for segment in middle.split('*').filter(|s| !s.is_empty()) {
        match remaining.find(segment) {
            Some(pos) => remaining = &remaining[pos + segment.len()..],
            None => return false,
        }
    }

    true
}
214
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Value, Entry};
    use tempfile::TempDir;

    /// Exercises set, get, exists and delete against a single database.
    #[tokio::test]
    async fn test_disk_storage_basic_operations() {
        let temp_dir = TempDir::new().unwrap();
        let storage = DiskStorage::new(temp_dir.path()).unwrap();
        let database_id = 0;
        let key = "test_key".to_string();

        let entry = Entry::new(Value::String("test_value".to_string()), None);
        storage.set(database_id, key.clone(), entry).await.unwrap();

        // The stored value must come back unchanged.
        let fetched = storage.get(database_id, &key).await.unwrap();
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().value.as_string().unwrap(), "test_value");

        assert!(storage.exists(database_id, &key).await.unwrap());

        assert!(storage.delete(database_id, &key).await.unwrap());

        // After deletion the key must be gone.
        assert!(!storage.exists(database_id, &key).await.unwrap());
    }

    /// Verifies that a flushed value is visible after reopening the same path.
    #[tokio::test]
    async fn test_disk_storage_persistence() {
        let temp_dir = TempDir::new().unwrap();
        let storage_path = temp_dir.path();
        let key = "persistent_key".to_string();

        // First session: write and flush, then drop the storage at scope end.
        {
            let storage = DiskStorage::new(storage_path).unwrap();
            let entry = Entry::new(Value::String("persistent_value".to_string()), None);
            storage.set(0, key.clone(), entry).await.unwrap();
            storage.flush().await.unwrap();
        }

        // Second session: reopen the same directory and read the value back.
        {
            let storage = DiskStorage::new(storage_path).unwrap();
            let fetched = storage.get(0, &key).await.unwrap();
            assert!(fetched.is_some());
            assert_eq!(fetched.unwrap().value.as_string().unwrap(), "persistent_value");
        }
    }

    /// Verifies that the same key in two databases holds independent values.
    #[tokio::test]
    async fn test_disk_storage_multiple_databases() {
        let temp_dir = TempDir::new().unwrap();
        let storage = DiskStorage::new(temp_dir.path()).unwrap();
        let key = "key".to_string();

        let entry_db0 = Entry::new(Value::String("db0_value".to_string()), None);
        let entry_db1 = Entry::new(Value::String("db1_value".to_string()), None);

        storage.set(0, key.clone(), entry_db0).await.unwrap();
        storage.set(1, key.clone(), entry_db1).await.unwrap();

        let from_db0 = storage.get(0, &key).await.unwrap().unwrap();
        let from_db1 = storage.get(1, &key).await.unwrap().unwrap();

        assert_eq!(from_db0.value.as_string().unwrap(), "db0_value");
        assert_eq!(from_db1.value.as_string().unwrap(), "db1_value");
    }
}