Skip to main content

aurora_db/storage/
cold.rs

1use crate::error::{AqlError, ErrorCode, Result};
2use crate::types::{AuroraConfig, ColdStoreMode};
3use sled::Db;
4use std::sync::Arc;
5
/// Marker byte prepended to every value this store compresses with zstd.
///
/// On read, stored values that do not begin with this byte are returned unchanged.
const ZSTD_MAGIC: u8 = b'Z'; // == 0x5A
7
8pub struct ColdStore {
9    db: Db,
10    #[allow(dead_code)]
11    db_path: String,
12}
13
14impl ColdStore {
15    pub fn new(path: &str) -> Result<Self> {
16        let config = AuroraConfig::default();
17        Self::with_config(
18            path,
19            config.cold_cache_capacity_mb,
20            config.cold_flush_interval_ms,
21            config.cold_mode,
22        )
23    }
24
25    pub fn with_config(
26        path: &str,
27        cache_capacity_mb: usize,
28        flush_interval_ms: Option<u64>,
29        mode: ColdStoreMode,
30    ) -> Result<Self> {
31        let db_path = if !path.ends_with(".db") {
32            format!("{}.db", path)
33        } else {
34            path.to_string()
35        };
36
37        let mut sled_config = sled::Config::new()
38            .path(&db_path)
39            .cache_capacity((cache_capacity_mb * 1024 * 1024) as u64)
40            .flush_every_ms(flush_interval_ms);
41
42        sled_config = match mode {
43            ColdStoreMode::HighThroughput => sled_config.mode(sled::Mode::HighThroughput),
44            ColdStoreMode::LowSpace => sled_config.mode(sled::Mode::LowSpace),
45        };
46
47        let db = sled_config.open().map_err(|e| {
48            let error_msg = e.to_string();
49
50            if error_msg.contains("Access is denied") || error_msg.contains("os error 5") {
51                AqlError::new(
52                    ErrorCode::IoError,
53                    format!(
54                        "Cannot open database at '{}': file is locked or access denied.",
55                        db_path,
56                    ),
57                )
58            } else {
59                AqlError::from(e)
60            }
61        })?;
62
63        Ok(Self { db, db_path })
64    }
65
66    pub fn open_tree(&self, name: &str) -> Result<sled::Tree> {
67        Ok(self.db.open_tree(name)?)
68    }
69
70    pub fn try_remove_stale_lock(db_path: &str) -> Result<bool> {
71        use std::path::Path;
72
73        let path = if !db_path.ends_with(".db") {
74            format!("{}.db", db_path)
75        } else {
76            db_path.to_string()
77        };
78
79        let db_dir = Path::new(&path);
80        if !db_dir.exists() {
81            return Ok(false);
82        }
83
84        let lock_file = db_dir.join(".lock");
85        if lock_file.exists() {
86            std::fs::remove_file(&lock_file)?;
87            Ok(true)
88        } else {
89            Ok(false)
90        }
91    }
92
93    pub fn get(&self, key: &str) -> Result<Option<Vec<u8>>> {
94        let val = self.db.get(key.as_bytes())?;
95        match val {
96            Some(ivec) => {
97                let bytes = ivec.as_ref();
98                if bytes.starts_with(&[ZSTD_MAGIC]) {
99                    let decompressed = zstd::decode_all(&bytes[1..]).map_err(|e| {
100                        AqlError::new(
101                            ErrorCode::SerializationError,
102                            format!("Decompression failed: {}", e),
103                        )
104                    })?;
105                    Ok(Some(decompressed))
106                } else {
107                    Ok(Some(bytes.to_vec()))
108                }
109            }
110            None => Ok(None),
111        }
112    }
113
114    pub fn set(&self, key: String, value: Vec<u8>) -> Result<()> {
115        let mut compressed = vec![ZSTD_MAGIC];
116        zstd::stream::copy_encode(&value[..], &mut compressed, 3).map_err(|e| {
117            AqlError::new(
118                ErrorCode::SerializationError,
119                format!("Compression failed: {}", e),
120            )
121        })?;
122
123        self.db.insert(key.as_bytes(), compressed)?;
124        Ok(())
125    }
126
127    pub fn delete(&self, key: &str) -> Result<()> {
128        self.db.remove(key.as_bytes())?;
129        Ok(())
130    }
131
132    pub fn scan(&self) -> impl Iterator<Item = Result<(String, Vec<u8>)>> + '_ {
133        self.db.iter().map(|result| {
134            result.map_err(AqlError::from).and_then(|(key, value)| {
135                let bytes = value.as_ref();
136                let data = if bytes.starts_with(&[ZSTD_MAGIC]) {
137                    zstd::decode_all(&bytes[1..]).map_err(|e| {
138                        AqlError::new(
139                            ErrorCode::SerializationError,
140                            format!("Decompression failed: {}", e),
141                        )
142                    })?
143                } else {
144                    bytes.to_vec()
145                };
146
147                Ok((
148                    String::from_utf8(key.to_vec()).map_err(|_| {
149                        AqlError::new(ErrorCode::ProtocolError, "Invalid UTF-8 in key".to_string())
150                    })?,
151                    data,
152                ))
153            })
154        })
155    }
156
157    pub fn scan_prefix(
158        &self,
159        prefix: &str,
160    ) -> impl Iterator<Item = Result<(String, Vec<u8>)>> + '_ {
161        self.db.scan_prefix(prefix.as_bytes()).map(|result| {
162            result.map_err(AqlError::from).and_then(|(key, value)| {
163                let bytes = value.as_ref();
164                let data = if bytes.starts_with(&[ZSTD_MAGIC]) {
165                    zstd::decode_all(&bytes[1..]).map_err(|e| {
166                        AqlError::new(
167                            ErrorCode::SerializationError,
168                            format!("Decompression failed: {}", e),
169                        )
170                    })?
171                } else {
172                    bytes.to_vec()
173                };
174
175                Ok((
176                    String::from_utf8(key.to_vec()).map_err(|_| {
177                        AqlError::new(ErrorCode::ProtocolError, "Invalid UTF-8 in key".to_string())
178                    })?,
179                    data,
180                ))
181            })
182        })
183    }
184
185    pub fn batch_set(&self, pairs: Vec<(String, Vec<u8>)>) -> Result<()> {
186        let batch = pairs
187            .into_iter()
188            .fold(sled::Batch::default(), |mut batch, (key, value)| {
189                let mut compressed = vec![ZSTD_MAGIC];
190                let _ = zstd::stream::copy_encode(&value[..], &mut compressed, 3);
191                batch.insert(key.as_bytes(), compressed);
192                batch
193            });
194
195        self.db.apply_batch(batch)?;
196        Ok(())
197    }
198
199    pub fn batch_set_arc(&self, pairs: Vec<(Arc<String>, Arc<Vec<u8>>)>) -> Result<()> {
200        let batch = pairs
201            .into_iter()
202            .fold(sled::Batch::default(), |mut batch, (key, value)| {
203                let mut compressed = vec![ZSTD_MAGIC];
204                let _ = zstd::stream::copy_encode(value.as_slice(), &mut compressed, 3);
205                batch.insert(key.as_bytes(), compressed);
206                batch
207            });
208
209        self.db.apply_batch(batch)?;
210        Ok(())
211    }
212
213    pub fn flush(&self) -> Result<()> {
214        self.db.flush()?;
215        Ok(())
216    }
217
218    pub fn compact(&self) -> Result<()> {
219        // Sled compaction is handled by its internal GC, but we can call checksum or other things
220        // or just let it flush.
221        self.db.flush()?;
222        Ok(())
223    }
224
225    pub fn get_stats(&self) -> Result<ColdStoreStats> {
226        Ok(ColdStoreStats {
227            size_on_disk: self.estimated_size(),
228            tree_count: self.db.tree_names().len() as u64,
229        })
230    }
231
232    pub fn estimated_size(&self) -> u64 {
233        self.db.size_on_disk().unwrap_or(0)
234    }
235}
236
237impl Drop for ColdStore {
238    fn drop(&mut self) {
239        if let Err(e) = self.db.flush() {
240            eprintln!("Error flushing database: {}", e);
241        }
242    }
243}
244
/// Snapshot of storage statistics produced by `ColdStore::get_stats`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ColdStoreStats {
    /// Estimated on-disk size in bytes (`0` if it could not be determined).
    pub size_on_disk: u64,
    /// Number of trees reported by the underlying sled database.
    pub tree_count: u64,
}