amadeus_utils/
rocksdb.rs

1//! Deterministic wrapper API over RocksDB v10.
2
3// Re-export commonly used types for downstream crates
4pub use rust_librocksdb_sys;
5pub use rust_rocksdb::{
6    AsColumnFamilyRef, BlockBasedIndexType, BlockBasedOptions, BottommostLevelCompaction, BoundColumnFamily, Cache,
7    ColumnFamilyDescriptor, CompactOptions, DBCompressionType, DBRawIteratorWithThreadMode, DBRecoveryMode, Direction,
8    Error as RocksDbError, IteratorMode, LruCacheOptions, MultiThreaded, Options, ReadOptions, SliceTransform,
9    Transaction, TransactionDB, TransactionDBOptions, TransactionOptions, WriteOptions, statistics,
10};
11use tokio::fs::create_dir_all;
12
// Per-thread slot that owns the database opened by `init_for_test`.
// It starts out empty and is reset to `None` when a `TestDbGuard` is dropped
// on the same thread.
#[cfg(test)]
thread_local! {
    static TEST_DB: std::cell::RefCell<Option<TransactionDB<MultiThreaded>>> = std::cell::RefCell::new(None);
}
17
/// Guard returned by `init_for_test`.
///
/// Dropping it clears the thread-local test database and then removes the
/// `base` directory on a best-effort basis.
#[cfg(test)]
pub struct TestDbGuard {
    /// Base directory that was handed to `init_for_test`.
    base: String,
}
22
#[cfg(test)]
impl Drop for TestDbGuard {
    /// Releases the thread-local test database, then deletes the base directory.
    fn drop(&mut self) {
        // The database handle must go first so RocksDB no longer holds its files open.
        TEST_DB.with(|slot| drop(slot.borrow_mut().take()));
        // Removing the directory is best-effort; a failure here is deliberately ignored.
        let _ = std::fs::remove_dir_all(&self.base);
    }
}
34
#[cfg(test)]
impl TestDbGuard {
    /// Returns the base directory that this guard removes when it is dropped.
    pub fn base(&self) -> &str {
        self.base.as_str()
    }
}
41
/// Errors produced by this module.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// An error reported by the underlying RocksDB bindings.
    #[error(transparent)]
    RocksDb(#[from] rust_rocksdb::Error),
    /// An I/O error; the snapshot module also uses this variant to report
    /// malformed or mismatching snapshot data.
    #[error(transparent)]
    TokioIo(#[from] tokio::io::Error),
    /// The named column family is not known to the open database.
    #[error("Column family not found: {0}")]
    ColumnFamilyNotFound(String),
}
51
/// Instance-oriented wrapper to be used from Context.
///
/// Cloning is cheap: clones share the same underlying database through an `Arc`.
#[derive(Clone)]
pub struct RocksDb {
    /// Shared handle to the multi-threaded transactional database.
    pub inner: std::sync::Arc<TransactionDB<MultiThreaded>>,
}
57
58impl std::fmt::Debug for RocksDb {
59    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60        f.debug_struct("RocksDb").finish_non_exhaustive()
61    }
62}
63
/// Names of every column family the database is opened with, `default` first.
fn cf_names() -> &'static [&'static str] {
    static NAMES: [&str; 9] = [
        "default",
        "sysconf",
        "entry",
        "entry_meta",
        "attestation",
        "tx",
        "tx_account_nonce",
        "tx_receiver_nonce",
        "contractstate",
    ];
    &NAMES
}
77
/// Opens a transactional database under `<base>/db` and parks it in the
/// thread-local `TEST_DB` slot.
///
/// The options mirror the ones used by `RocksDb::open` so tests exercise the
/// production configuration. The returned guard clears the slot and removes
/// `base` when it is dropped.
///
/// # Errors
/// Returns an error when the directory cannot be created or RocksDB fails to
/// open the database.
#[cfg(test)]
pub fn init_for_test(base: &str) -> Result<TestDbGuard, Error> {
    let path = format!("{}/db", base);
    std::fs::create_dir_all(&path)?;

    // One LRU block cache shared by every column family.
    let block_cache = Cache::new_lru_cache(4 * 1024 * 1024 * 1024);

    let mut db_opts = Options::default();
    db_opts.create_if_missing(true);
    db_opts.create_missing_column_families(true);
    db_opts.set_max_open_files(30000);
    db_opts.increase_parallelism(4);
    db_opts.set_max_background_jobs(2);
    db_opts.set_max_total_wal_size(2 * 1024 * 1024 * 1024);
    db_opts.set_target_file_size_base(8 * 1024 * 1024 * 1024);
    db_opts.set_max_compaction_bytes(20 * 1024 * 1024 * 1024);
    db_opts.enable_statistics();
    db_opts.set_statistics_level(statistics::StatsLevel::All);
    db_opts.set_skip_stats_update_on_db_open(true);
    // Bigger L0 flushes
    db_opts.set_write_buffer_size(512 * 1024 * 1024);
    db_opts.set_max_write_buffer_number(6);
    db_opts.set_min_write_buffer_number_to_merge(2);
    // L0 thresholds
    db_opts.set_level_zero_file_num_compaction_trigger(8);
    db_opts.set_level_zero_slowdown_writes_trigger(30);
    db_opts.set_level_zero_stop_writes_trigger(100);
    db_opts.set_max_subcompactions(2);

    let cf_descs: Vec<_> = cf_names()
        .iter()
        .map(|&name| {
            let mut cf_opts = Options::default();

            let mut block_based_options = BlockBasedOptions::default();
            block_based_options.set_block_cache(&block_cache);
            block_based_options.set_index_type(BlockBasedIndexType::TwoLevelIndexSearch);
            block_based_options.set_partition_filters(true);
            block_based_options.set_cache_index_and_filter_blocks(true);
            // `set_cache_index_and_filter_blocks_with_high_priority` is intentionally not
            // called: `RocksDb::open` notes that it is unavailable in the crates.io
            // bindings, and the test database should match the production setup.
            block_based_options.set_pin_top_level_index_and_filter(true);
            block_based_options.set_pin_l0_filter_and_index_blocks_in_cache(false);
            cf_opts.set_block_based_table_factory(&block_based_options);

            // L0 and L1 stay uncompressed, deeper levels use Zstd with a trained dictionary.
            let dict_bytes = 32 * 1024;
            cf_opts.set_compression_per_level(&[
                DBCompressionType::None, // L0
                DBCompressionType::None, // L1
                DBCompressionType::Zstd, // L2
                DBCompressionType::Zstd, // L3
                DBCompressionType::Zstd, // L4
                DBCompressionType::Zstd, // L5
                DBCompressionType::Zstd, // L6
            ]);
            cf_opts.set_compression_type(DBCompressionType::Zstd);
            cf_opts.set_compression_options(-14, 2, 0, dict_bytes);
            cf_opts.set_zstd_max_train_bytes(100 * dict_bytes);

            cf_opts.set_max_total_wal_size(2 * 1024 * 1024 * 1024);
            cf_opts.set_target_file_size_base(8 * 1024 * 1024 * 1024);
            cf_opts.set_max_compaction_bytes(20 * 1024 * 1024 * 1024);

            // Bigger L0 flushes
            cf_opts.set_write_buffer_size(512 * 1024 * 1024);
            cf_opts.set_max_write_buffer_number(6);
            cf_opts.set_min_write_buffer_number_to_merge(2);
            // L0 thresholds
            cf_opts.set_level_zero_file_num_compaction_trigger(20);
            cf_opts.set_level_zero_slowdown_writes_trigger(40);
            cf_opts.set_level_zero_stop_writes_trigger(100);
            cf_opts.set_max_subcompactions(2);

            ColumnFamilyDescriptor::new(name, cf_opts)
        })
        .collect();

    let mut txn_db_opts = TransactionDBOptions::default();
    txn_db_opts.set_default_lock_timeout(3000);
    txn_db_opts.set_txn_lock_timeout(3000);
    txn_db_opts.set_num_stripes(32);

    let db = TransactionDB::open_cf_descriptors(&db_opts, &txn_db_opts, path, cf_descs)?;

    // Park the handle in the thread-local slot; the guard's `Drop` releases it.
    TEST_DB.with(|cell| {
        *cell.borrow_mut() = Some(db);
    });

    Ok(TestDbGuard { base: base.to_string() })
}
158
/// Lightweight transaction wrapper for instance API.
///
/// Every operation is forwarded to the wrapped [`SimpleTransaction`].
pub struct RocksDbTxn<'a> {
    /// The transaction all calls are delegated to.
    inner: SimpleTransaction<'a>,
}
163
164impl<'a> RocksDbTxn<'a> {
165    /// Get access to the inner transaction for advanced operations
166    pub fn inner(&self) -> &SimpleTransaction<'a> {
167        &self.inner
168    }
169}
170
171impl RocksDb {
172    pub async fn open(path: String) -> Result<Self, Error> {
173        create_dir_all(&path).await?;
174
175        let block_cache = Cache::new_lru_cache(4 * 1024 * 1024 * 1024);
176
177        let mut db_opts = Options::default();
178        db_opts.create_if_missing(true);
179        db_opts.create_missing_column_families(true);
180        db_opts.set_max_open_files(30000);
181        db_opts.increase_parallelism(4);
182        db_opts.set_max_background_jobs(2);
183
184        db_opts.set_max_total_wal_size(2 * 1024 * 1024 * 1024); // 2GB
185        db_opts.set_target_file_size_base(8 * 1024 * 1024 * 1024);
186        db_opts.set_max_compaction_bytes(20 * 1024 * 1024 * 1024);
187
188        db_opts.enable_statistics();
189        db_opts.set_statistics_level(statistics::StatsLevel::All);
190        db_opts.set_skip_stats_update_on_db_open(true);
191
192        // Bigger L0 flushes
193        db_opts.set_write_buffer_size(512 * 1024 * 1024);
194        db_opts.set_max_write_buffer_number(6);
195        db_opts.set_min_write_buffer_number_to_merge(2);
196        // L0 thresholds
197        db_opts.set_level_zero_file_num_compaction_trigger(8);
198        db_opts.set_level_zero_slowdown_writes_trigger(30);
199        db_opts.set_level_zero_stop_writes_trigger(100);
200        db_opts.set_max_subcompactions(2);
201
202        let cf_descs: Vec<_> = cf_names()
203            .iter()
204            .map(|&name| {
205                let mut cf_opts = Options::default();
206
207                let mut block_based_options = BlockBasedOptions::default();
208                block_based_options.set_block_cache(&block_cache);
209                block_based_options.set_index_type(BlockBasedIndexType::TwoLevelIndexSearch);
210                block_based_options.set_partition_filters(true);
211                block_based_options.set_cache_index_and_filter_blocks(true);
212                // Note: set_cache_index_and_filter_blocks_with_high_priority not available in crates.io version
213                // block_based_options.set_cache_index_and_filter_blocks_with_high_priority(true);
214                block_based_options.set_pin_top_level_index_and_filter(true);
215                block_based_options.set_pin_l0_filter_and_index_blocks_in_cache(false);
216                cf_opts.set_block_based_table_factory(&block_based_options);
217
218                let dict_bytes = 32 * 1024;
219                cf_opts.set_compression_per_level(&[
220                    DBCompressionType::None, // L0
221                    DBCompressionType::None, // L1
222                    DBCompressionType::Zstd, // L2
223                    DBCompressionType::Zstd, // L3
224                    DBCompressionType::Zstd, // L4
225                    DBCompressionType::Zstd, // L5
226                    DBCompressionType::Zstd, // L6
227                ]);
228
229                cf_opts.set_compression_type(DBCompressionType::Zstd);
230                cf_opts.set_compression_options(-14, 2, 0, dict_bytes);
231                cf_opts.set_zstd_max_train_bytes(100 * dict_bytes);
232
233                cf_opts.set_max_total_wal_size(2 * 1024 * 1024 * 1024); // 2GB
234                cf_opts.set_target_file_size_base(8 * 1024 * 1024 * 1024);
235                cf_opts.set_max_compaction_bytes(20 * 1024 * 1024 * 1024);
236
237                // Bigger L0 flushes
238                cf_opts.set_write_buffer_size(512 * 1024 * 1024);
239                cf_opts.set_max_write_buffer_number(6);
240                cf_opts.set_min_write_buffer_number_to_merge(2);
241                // L0 thresholds
242                cf_opts.set_level_zero_file_num_compaction_trigger(20);
243                cf_opts.set_level_zero_slowdown_writes_trigger(40);
244                cf_opts.set_level_zero_stop_writes_trigger(100);
245                cf_opts.set_max_subcompactions(2);
246
247                ColumnFamilyDescriptor::new(name, cf_opts)
248            })
249            .collect();
250
251        let mut txn_db_opts = TransactionDBOptions::default();
252        txn_db_opts.set_default_lock_timeout(3000);
253        txn_db_opts.set_txn_lock_timeout(3000);
254        txn_db_opts.set_num_stripes(32);
255
256        let db: TransactionDB<MultiThreaded> =
257            TransactionDB::open_cf_descriptors(&db_opts, &txn_db_opts, path.clone(), cf_descs)?;
258        // Note: flush methods not available on TransactionDB in crates.io version
259        // db.flush()?;
260        // db.flush_wal(true)?;
261
262        Ok(RocksDb { inner: std::sync::Arc::new(db) })
263    }
264
265    pub fn get(&self, cf: &str, key: &[u8]) -> Result<Option<Vec<u8>>, RocksDbError> {
266        let cf_h = self.inner.cf_handle(cf).unwrap();
267        Ok(self.inner.get_cf(&cf_h, key)?)
268    }
269    pub fn put(&self, cf: &str, key: &[u8], value: &[u8]) -> Result<(), RocksDbError> {
270        let cf_h = self.inner.cf_handle(cf).unwrap();
271        Ok(self.inner.put_cf(&cf_h, key, value)?)
272    }
273    pub fn delete(&self, cf: &str, key: &[u8]) -> Result<(), RocksDbError> {
274        let cf_h = self.inner.cf_handle(cf).unwrap();
275        Ok(self.inner.delete_cf(&cf_h, key)?)
276    }
277    pub fn iter_prefix(&self, cf: &str, prefix: &[u8]) -> Result<Vec<(Vec<u8>, Vec<u8>)>, RocksDbError> {
278        let cf_h = self.inner.cf_handle(cf).unwrap();
279        let opts = ReadOptions::default();
280        let it_mode = IteratorMode::From(prefix, Direction::Forward);
281        let iter = self.inner.iterator_cf_opt(&cf_h, opts, it_mode);
282        let mut out = Vec::new();
283        for item in iter {
284            let (k, v) = item?;
285            if !k.starts_with(prefix) {
286                break;
287            }
288            out.push((k.to_vec(), v.to_vec()));
289        }
290        Ok(out)
291    }
292    pub fn get_prev_or_first(
293        &self,
294        cf: &str,
295        prefix: &str,
296        key_suffix: &str,
297    ) -> Result<Option<(Vec<u8>, Vec<u8>)>, RocksDbError> {
298        let Some(cf_h) = self.inner.cf_handle(cf) else {
299            return Ok(None);
300        };
301        let opts = ReadOptions::default();
302        let key = format!("{}{}", prefix, key_suffix);
303        let it_mode = IteratorMode::From(key.as_bytes(), Direction::Reverse);
304        let mut iter = self.inner.iterator_cf_opt(&cf_h, opts, it_mode);
305        if let Some(item) = iter.next() {
306            let (k, v) = item?;
307            if !k.starts_with(prefix.as_bytes()) {
308                return Ok(None);
309            }
310            return Ok(Some((k.to_vec(), v.to_vec())));
311        }
312        // fallback: first forward
313        let it_mode_f = IteratorMode::From(prefix.as_bytes(), Direction::Forward);
314        let mut iter_f = self.inner.iterator_cf_opt(&cf_h, ReadOptions::default(), it_mode_f);
315        if let Some(item) = iter_f.next() {
316            let (k, v) = item?;
317            if k.starts_with(prefix.as_bytes()) {
318                return Ok(Some((k.to_vec(), v.to_vec())));
319            }
320        }
321        Ok(None)
322    }
323    pub fn begin_transaction(&self) -> Transaction<'_, TransactionDB<MultiThreaded>> {
324        let txn_opts = TransactionOptions::default();
325        let write_opts = WriteOptions::default();
326        self.inner.transaction_opt(&write_opts, &txn_opts)
327    }
328
329    /// Flush write-ahead log to disk
330    /// Note: Not available in crates.io version of rust-rocksdb
331    pub fn flush_wal(&self, _sync: bool) -> Result<(), Error> {
332        // self.inner.flush_wal(sync).map_err(Into::into)
333        Ok(()) // No-op for crates.io compatibility
334    }
335
336    /// Flush all memtables to disk
337    /// Note: Not available in crates.io version of rust-rocksdb
338    pub fn flush(&self) -> Result<(), Error> {
339        // self.inner.flush().map_err(Into::into)
340        Ok(()) // No-op for crates.io compatibility
341    }
342
343    /// Flush a specific column family's memtable to disk
344    /// Note: Not available in crates.io version of rust-rocksdb
345    pub fn flush_cf(&self, _cf: &str) -> Result<(), Error> {
346        // let cf_h = self.inner.cf_handle(cf).ok_or_else(|| Error::ColumnFamilyNotFound(cf.to_string()))?;
347        // self.inner.flush_cf(&cf_h).map_err(Into::into)
348        Ok(()) // No-op for crates.io compatibility
349    }
350
351    /// Close the database gracefully by flushing pending writes
352    /// Note: RocksDB will be properly closed when this struct is dropped
353    pub fn close(&self) -> Result<(), Error> {
354        // Flush WAL before closing
355        self.flush_wal(true)?;
356        // Flush all memtables
357        self.flush()?;
358        // Database will be closed when Arc is dropped
359        Ok(())
360    }
361
362    /// Create a checkpoint (snapshot) of the database at the given path
363    /// This is a native RocksDB checkpoint operation
364    /// Note: Not available in crates.io version of rust-rocksdb
365    pub fn checkpoint(&self, _path: &str) -> Result<(), Error> {
366        // self.inner.create_checkpoint(path).map_err(Into::into)
367        Err(Error::ColumnFamilyNotFound("checkpoint not available in crates.io version".to_string()))
368    }
369}
370
371impl<'a> RocksDbTxn<'a> {
372    pub fn put(&self, cf: &str, key: &[u8], value: &[u8]) -> Result<(), Error> {
373        self.inner.put(cf, key, value)
374    }
375    pub fn delete(&self, cf: &str, key: &[u8]) -> Result<(), Error> {
376        self.inner.delete(cf, key)
377    }
378    pub fn get(&self, cf: &str, key: &[u8]) -> Result<Option<Vec<u8>>, Error> {
379        self.inner.get(cf, key)
380    }
381    pub fn raw_iterator_cf(
382        &self,
383        cf: &str,
384    ) -> Result<DBRawIteratorWithThreadMode<'_, Transaction<'_, TransactionDB<MultiThreaded>>>, Error> {
385        self.inner.raw_iterator_cf(cf)
386    }
387    pub fn commit(self) -> Result<(), Error> {
388        self.inner.commit()
389    }
390    pub fn rollback(self) -> Result<(), Error> {
391        self.inner.rollback()
392    }
393}
394
/// RocksDB transaction trait.
///
/// Column families are addressed by name. `commit` and `rollback` take the
/// transaction by value, so it cannot be used afterwards.
pub trait RocksDbTransaction {
    /// Put a key-value pair in the transaction
    fn put(&self, cf: &str, key: &[u8], value: &[u8]) -> Result<(), Error>;

    /// Delete a key in the transaction
    fn delete(&self, cf: &str, key: &[u8]) -> Result<(), Error>;

    /// Get a value from the transaction; `Ok(None)` means the key is absent
    fn get(&self, cf: &str, key: &[u8]) -> Result<Option<Vec<u8>>, Error>;

    /// Commit the transaction
    fn commit(self) -> Result<(), Error>;

    /// Rollback the transaction
    fn rollback(self) -> Result<(), Error>;
}
412
/// RocksDB trait for database operations with transaction support
pub trait RocksDbTrait {
    /// Transaction type handed out by [`RocksDbTrait::txn`]; it may borrow from
    /// the database for `'a`.
    type Transaction<'a>: RocksDbTransaction
    where
        Self: 'a;

    /// Create a new transaction
    fn txn(&self) -> Self::Transaction<'_>;

    /// Direct get operation without transaction (for read-only operations)
    // NOTE(review): the signature has no error channel, so implementations
    // presumably map lookup failures to `None` — confirm against implementors.
    fn get(&self, cf: &str, key: &[u8]) -> Option<Vec<u8>>;
}
425
/// Typed identifiers for the column families known to this module.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Cf {
    Default,
    SysConf,
    Entry,
    EntryMeta,
    Attestation,
    Tx,
    TxAccountNonce,
    TxReceiverNonce,
    ContractState,
}

impl Cf {
    /// Returns the on-disk column family name for this identifier.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Default => "default",
            Self::SysConf => "sysconf",
            Self::Entry => "entry",
            Self::EntryMeta => "entry_meta",
            Self::Attestation => "attestation",
            Self::Tx => "tx",
            Self::TxAccountNonce => "tx_account_nonce",
            Self::TxReceiverNonce => "tx_receiver_nonce",
            Self::ContractState => "contractstate",
        }
    }
}
454
/// Simple transaction for TransactionDB.
///
/// Pairs a RocksDB transaction with the database it belongs to, so column
/// families can be resolved by name for every operation.
pub struct SimpleTransaction<'a> {
    /// The underlying RocksDB transaction.
    pub txn: Transaction<'a, TransactionDB<MultiThreaded>>,
    /// Database used to look up column family handles by name.
    pub db: &'a TransactionDB<MultiThreaded>,
}
460
461impl<'a> SimpleTransaction<'a> {
462    pub fn raw_iterator_cf(
463        &self,
464        cf: &str,
465    ) -> Result<DBRawIteratorWithThreadMode<'_, Transaction<'_, TransactionDB<MultiThreaded>>>, Error> {
466        let cf_handle = self.db.cf_handle(cf).ok_or_else(|| Error::ColumnFamilyNotFound(cf.to_string()))?;
467        Ok(self.txn.raw_iterator_cf(&cf_handle))
468    }
469}
470
471impl<'a> RocksDbTransaction for SimpleTransaction<'a> {
472    fn put(&self, cf: &str, key: &[u8], value: &[u8]) -> Result<(), Error> {
473        let cf_handle = self.db.cf_handle(cf).ok_or_else(|| Error::ColumnFamilyNotFound(cf.to_string()))?;
474        self.txn.put_cf(&cf_handle, key, value).map_err(Into::into)
475    }
476
477    fn delete(&self, cf: &str, key: &[u8]) -> Result<(), Error> {
478        let cf_handle = self.db.cf_handle(cf).ok_or_else(|| Error::ColumnFamilyNotFound(cf.to_string()))?;
479        self.txn.delete_cf(&cf_handle, key).map_err(Into::into)
480    }
481
482    fn get(&self, cf: &str, key: &[u8]) -> Result<Option<Vec<u8>>, Error> {
483        let cf_handle = self.db.cf_handle(cf).ok_or_else(|| Error::ColumnFamilyNotFound(cf.to_string()))?;
484        self.txn.get_cf(&cf_handle, key).map_err(Into::into)
485    }
486
487    fn commit(self) -> Result<(), Error> {
488        self.txn.commit().map_err(Into::into)
489    }
490
491    fn rollback(self) -> Result<(), Error> {
492        self.txn.rollback().map_err(Into::into)
493    }
494}
495
496/// Snapshot module for deterministic export/import of column families
497pub mod snapshot {
498    use super::*;
499    use blake3::Hasher;
500    use serde::{Deserialize, Serialize};
501    use std::path::Path;
502    use tokio::fs::File;
503    use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter};
504
    /// Four-byte header written at the start of every `.spk` file and checked on import.
    const MAGIC: &[u8] = b"SPK1";
    /// Domain separator fed to the hasher before any record and recorded in the manifest.
    const DOMAIN_SEP: &str = "statepack-v1";
507
    /// Metadata describing an exported snapshot file.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct Manifest {
        /// Manifest format version; `export_spk` writes `1`.
        pub version: u32,
        /// Name of the hash algorithm; `export_spk` writes `"blake3"`.
        pub algo: String,
        /// Column family the snapshot was taken from.
        pub cf: String,
        /// Number of key/value records in the snapshot.
        pub items_total: u64,
        /// Hex-encoded hash over the domain separator and all records.
        pub root_hex: String,
        /// Optional sequence number; `export_spk` leaves it as `None`.
        pub snapshot_seq: Option<u64>,
        /// Domain separator that was hashed ahead of the records.
        pub domain_sep: String,
    }
518
519    /// Write a varint (unsigned LEB128) to async writer
520    async fn write_varint(mut value: u64, writer: &mut (impl AsyncWrite + Unpin)) -> Result<(), Error> {
521        loop {
522            let mut byte = (value & 0x7f) as u8;
523            value >>= 7;
524            if value != 0 {
525                byte |= 0x80;
526            }
527            writer.write_u8(byte).await.map_err(Error::TokioIo)?;
528            if value == 0 {
529                break;
530            }
531        }
532        Ok(())
533    }
534
535    /// Read a varint (unsigned LEB128) from async reader
536    async fn read_varint(reader: &mut (impl AsyncRead + Unpin)) -> Result<u64, Error> {
537        let mut result = 0u64;
538        let mut shift = 0;
539        loop {
540            let byte = reader.read_u8().await.map_err(Error::TokioIo)?;
541            result |= ((byte & 0x7f) as u64) << shift;
542            if (byte & 0x80) == 0 {
543                break;
544            }
545            shift += 7;
546            if shift >= 64 {
547                return Err(Error::TokioIo(
548                    std::io::Error::new(std::io::ErrorKind::InvalidData, "varint too large").into(),
549                ));
550            }
551        }
552        Ok(result)
553    }
554
555    /// Encode a varint as bytes (for hashing)
556    fn encode_varint_bytes(mut value: u64) -> Vec<u8> {
557        let mut bytes = Vec::new();
558        loop {
559            let mut byte = (value & 0x7f) as u8;
560            value >>= 7;
561            if value != 0 {
562                byte |= 0x80;
563            }
564            bytes.push(byte);
565            if value == 0 {
566                break;
567            }
568        }
569        bytes
570    }
571
572    /// Export a column family to a deterministic snapshot file (.spk)
573    pub async fn export_spk(db: &super::RocksDb, cf_name: &str, output_path: &Path) -> Result<Manifest, Error> {
574        let cf_handle = db.inner.cf_handle(cf_name).ok_or_else(|| {
575            Error::TokioIo(
576                std::io::Error::new(std::io::ErrorKind::NotFound, format!("column family '{}' not found", cf_name))
577                    .into(),
578            )
579        })?;
580
581        let snapshot = db.inner.snapshot();
582        let mut read_opts = ReadOptions::default();
583        read_opts.set_total_order_seek(true);
584        read_opts.set_snapshot(&snapshot);
585
586        let iterator = db.inner.iterator_cf_opt(&cf_handle, read_opts, IteratorMode::From(&[], Direction::Forward));
587
588        let mut records = Vec::new();
589        let mut count = 0u64;
590
591        for item in iterator {
592            let (key, value) = item?;
593            records.push((key.to_vec(), value.to_vec()));
594            count += 1;
595        }
596
597        // Sort records by key for deterministic export
598        records.sort_by(|a, b| a.0.cmp(&b.0));
599
600        let file = File::create(output_path).await.map_err(Error::TokioIo)?;
601        let mut writer = BufWriter::new(file);
602        let mut hasher = Hasher::new();
603
604        // Write header
605        writer.write_all(MAGIC).await.map_err(Error::TokioIo)?;
606        // Note: MAGIC is not included in the hash for consistency with hash_cf
607
608        // Hash domain separator
609        hasher.update(DOMAIN_SEP.as_bytes());
610
611        // Write and hash records
612        for (key, value) in records {
613            // Hash key length, key, value length, value
614            let key_len_bytes = encode_varint_bytes(key.len() as u64);
615            let value_len_bytes = encode_varint_bytes(value.len() as u64);
616
617            hasher.update(&key_len_bytes);
618            hasher.update(&key);
619            hasher.update(&value_len_bytes);
620            hasher.update(&value);
621
622            // Write to file
623            write_varint(key.len() as u64, &mut writer).await?;
624            writer.write_all(&key).await.map_err(Error::TokioIo)?;
625            write_varint(value.len() as u64, &mut writer).await?;
626            writer.write_all(&value).await.map_err(Error::TokioIo)?;
627        }
628
629        writer.flush().await.map_err(Error::TokioIo)?;
630        let hash = hasher.finalize();
631        let hash_hex = hex::encode(hash.as_bytes());
632
633        Ok(Manifest {
634            version: 1,
635            algo: "blake3".to_string(),
636            cf: cf_name.to_string(),
637            items_total: count,
638            root_hex: hash_hex,
639            snapshot_seq: None,
640            domain_sep: DOMAIN_SEP.to_string(),
641        })
642    }
643
644    /// Import a snapshot file (.spk) into a column family using streaming with batching
645    pub async fn import_spk(
646        db: &super::RocksDb,
647        cf_name: &str,
648        spk_in: &Path,
649        manifest: &Manifest,
650        batch_bytes: usize,
651    ) -> Result<(), Error> {
652        use tokio::sync::mpsc;
653
654        // Verify manifest matches request
655        if manifest.cf != cf_name {
656            return Err(Error::TokioIo(
657                std::io::Error::new(
658                    std::io::ErrorKind::InvalidInput,
659                    format!("manifest cf '{}' != requested cf '{}'", manifest.cf, cf_name),
660                )
661                .into(),
662            ));
663        }
664
665        let file = File::open(spk_in).await.map_err(Error::TokioIo)?;
666        let mut reader = BufReader::new(file);
667
668        // Verify magic header
669        let mut magic_buf = [0u8; 4];
670        reader.read_exact(&mut magic_buf).await.map_err(Error::TokioIo)?;
671        if &magic_buf != MAGIC {
672            return Err(Error::TokioIo(
673                std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid magic header").into(),
674            ));
675        }
676
677        // Create a channel for batching writes
678        let (tx, mut rx) = mpsc::channel::<(Vec<u8>, Vec<u8>)>(100);
679
680        // Spawn task to handle batch writes
681        let cf_name_owned = cf_name.to_string();
682        let db_clone = db.clone();
683        let write_task = tokio::spawn(async move {
684            let db = db_clone;
685            let mut current_batch = Vec::new();
686            let mut current_size = 0;
687
688            while let Some((key, value)) = rx.recv().await {
689                let item_size = key.len() + value.len();
690                if current_size + item_size > batch_bytes && !current_batch.is_empty() {
691                    // Write current batch
692                    write_batch(&db, &cf_name_owned, &current_batch)?;
693                    current_batch.clear();
694                    current_size = 0;
695                }
696
697                current_batch.push((key, value));
698                current_size += item_size;
699            }
700
701            // Write final batch
702            if !current_batch.is_empty() {
703                write_batch(&db, &cf_name_owned, &current_batch)?;
704            }
705
706            Ok::<(), Error>(())
707        });
708
709        // Read and send records to write task
710        let mut records_read = 0u64;
711        while records_read < manifest.items_total {
712            let key_len = read_varint(&mut reader).await?;
713            let mut key = vec![0u8; key_len as usize];
714            reader.read_exact(&mut key).await.map_err(Error::TokioIo)?;
715
716            let value_len = read_varint(&mut reader).await?;
717            let mut value = vec![0u8; value_len as usize];
718            reader.read_exact(&mut value).await.map_err(Error::TokioIo)?;
719
720            tx.send((key, value)).await.map_err(|_| {
721                Error::TokioIo(std::io::Error::new(std::io::ErrorKind::BrokenPipe, "channel closed").into())
722            })?;
723
724            records_read += 1;
725        }
726
727        drop(tx); // Close channel
728        write_task.await.map_err(|e| Error::TokioIo(std::io::Error::new(std::io::ErrorKind::Other, e).into()))??;
729
730        Ok(())
731    }
732
733    /// Write a batch of key-value pairs to the database
734    fn write_batch(db: &super::RocksDb, cf_name: &str, batch: &[(Vec<u8>, Vec<u8>)]) -> Result<(), Error> {
735        let cf_handle = db.inner.cf_handle(cf_name).ok_or_else(|| Error::ColumnFamilyNotFound(cf_name.to_string()))?;
736
737        let mut write_opts = WriteOptions::default();
738        write_opts.set_sync(false); // Use async writes for better performance
739
740        for (key, value) in batch {
741            db.inner.put_cf_opt(&cf_handle, key, value, &write_opts)?;
742        }
743
744        Ok(())
745    }
746
747    /// Hash a column family in the database (for verification)
748    pub async fn hash_cf(db: &super::RocksDb, cf_name: &str) -> Result<[u8; 32], Error> {
749        let snapshot = db.inner.snapshot();
750        let mut read_opts = ReadOptions::default();
751        read_opts.set_total_order_seek(true);
752        read_opts.set_snapshot(&snapshot);
753
754        let cf_handle = db.inner.cf_handle(cf_name).ok_or_else(|| {
755            Error::TokioIo(
756                std::io::Error::new(std::io::ErrorKind::NotFound, format!("cf '{}' missing", cf_name)).into(),
757            )
758        })?;
759
760        let iterator = db.inner.iterator_cf_opt(&cf_handle, read_opts, IteratorMode::Start);
761
762        let mut hasher = Hasher::new();
763        hasher.update(DOMAIN_SEP.as_bytes());
764
765        let mut records = Vec::new();
766        for item in iterator {
767            let (key, value) = item?;
768            records.push((key.to_vec(), value.to_vec()));
769        }
770
771        // Sort for deterministic hashing
772        records.sort_by(|a, b| a.0.cmp(&b.0));
773
774        for (key, value) in records {
775            let key_len_bytes = encode_varint_bytes(key.len() as u64);
776            let value_len_bytes = encode_varint_bytes(value.len() as u64);
777
778            hasher.update(&key_len_bytes);
779            hasher.update(&key);
780            hasher.update(&value_len_bytes);
781            hasher.update(&value);
782        }
783
784        let hash_result = hasher.finalize();
785        let mut hash_array = [0u8; 32];
786        hash_array.copy_from_slice(hash_result.as_bytes());
787        Ok(hash_array)
788    }
789
790    #[cfg(test)]
791    mod tests {
792        use super::*;
793        use std::any::type_name_of_val;
794
795        fn tmp_base_for_test<F: ?Sized>(f: &F) -> String {
796            let secs = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs();
797            let fq = type_name_of_val(f);
798            format!("/tmp/{}{}", fq, secs)
799        }
800
801        #[tokio::test]
802        async fn test_snapshot_export_import() {
803            let base = tmp_base_for_test(&test_snapshot_export_import);
804            let db = super::RocksDb::open(base.clone()).await.expect("open test db");
805
806            // Put some test data
807            db.put(crate::constants::CF_DEFAULT, b"key1", b"value1").unwrap();
808            db.put(crate::constants::CF_DEFAULT, b"key2", b"value2").unwrap();
809            db.put(crate::constants::CF_DEFAULT, b"key3", b"value3").unwrap();
810
811            let spk_path = std::path::PathBuf::from(format!("{}/test.spk", base));
812
813            // Export snapshot
814            let manifest = export_spk(&db, crate::constants::CF_DEFAULT, &spk_path).await.unwrap();
815            assert_eq!(manifest.items_total, 3);
816            assert_eq!(manifest.cf, crate::constants::CF_DEFAULT);
817            assert_eq!(manifest.version, 1);
818
819            // Verify hash matches
820            let cf_hash = hash_cf(&db, crate::constants::CF_DEFAULT).await.unwrap();
821            assert_eq!(hex::encode(cf_hash), manifest.root_hex);
822
823            // Test import on a fresh database instance
824            let base2 = tmp_base_for_test(&"test_import_fresh");
825            let db2 = super::RocksDb::open(base2.clone()).await.expect("open test db 2");
826
827            // Import the snapshot to the fresh database
828            import_spk(&db2, crate::constants::CF_DEFAULT, &spk_path, &manifest, 1024).await.unwrap();
829
830            // Verify data was imported correctly
831            assert_eq!(db2.get(crate::constants::CF_DEFAULT, b"key1").unwrap(), Some(b"value1".to_vec()));
832            assert_eq!(db2.get(crate::constants::CF_DEFAULT, b"key2").unwrap(), Some(b"value2".to_vec()));
833            assert_eq!(db2.get(crate::constants::CF_DEFAULT, b"key3").unwrap(), Some(b"value3".to_vec()));
834
835            // Verify hash matches on imported data
836            let cf_hash_after = hash_cf(&db2, crate::constants::CF_DEFAULT).await.unwrap();
837            assert_eq!(hex::encode(cf_hash_after), manifest.root_hex);
838        }
839    }
840}
841
842// Implement Database trait for RocksDb
843impl crate::database::Database for RocksDb {
844    fn get(&self, column_family: &str, key: &[u8]) -> Result<Option<Vec<u8>>, crate::database::DatabaseError> {
845        self.get(column_family, key).map_err(|e| crate::database::DatabaseError::Generic(e.to_string()))
846    }
847
848    fn put(&self, column_family: &str, key: &[u8], value: &[u8]) -> Result<(), crate::database::DatabaseError> {
849        self.put(column_family, key, value).map_err(|e| crate::database::DatabaseError::Generic(e.to_string()))
850    }
851
852    fn delete(&self, column_family: &str, key: &[u8]) -> Result<(), crate::database::DatabaseError> {
853        self.delete(column_family, key).map_err(|e| crate::database::DatabaseError::Generic(e.to_string()))
854    }
855
856    fn iter_prefix(
857        &self,
858        column_family: &str,
859        prefix: &[u8],
860    ) -> Result<Vec<(Vec<u8>, Vec<u8>)>, crate::database::DatabaseError> {
861        self.iter_prefix(column_family, prefix).map_err(|e| crate::database::DatabaseError::Generic(e.to_string()))
862    }
863}
864
#[cfg(test)]
mod tests {
    use super::*;

    // Use jemalloc as the global allocator for the test binary.
    #[global_allocator]
    static GLOBAL: jemallocator::Jemalloc = jemallocator::Jemalloc;

    /// Read options with total-order seek enabled, as used by the full scans below.
    fn total_order_read_opts() -> ReadOptions {
        let mut read_opts = ReadOptions::default();
        read_opts.set_total_order_seek(true);
        read_opts
    }

    /// Manual stress test (ignored by default): 16 threads write random keys and
    /// values into randomly chosen column families of `/tmp/rocksdb_spam`.
    /// The worker loops never terminate, so the test runs until it is killed.
    #[tokio::test]
    #[ignore]
    async fn spam_random_writes() {
        use rand::Rng;
        let db = RocksDb::open("/tmp/rocksdb_spam".to_string()).await.unwrap();
        let txn_db = &db.inner;
        std::thread::scope(|scope| {
            for _worker in 0..16 {
                scope.spawn(|| {
                    let mut rng = rand::rng();
                    loop {
                        // Pick a column family at random.
                        let cf_index = rng.random_range(0..cf_names().len());
                        let cf = cf_names()[cf_index];

                        // Random key of 8..64 bytes, random value of 8000..12000 bytes.
                        let key_len = rng.random_range(8..64);
                        let key: Vec<u8> = (0..key_len).map(|_| rng.random()).collect();
                        let val_len = rng.random_range(8000..12000);
                        let val: Vec<u8> = (0..val_len).map(|_| rng.random()).collect();

                        let handle = txn_db.cf_handle(cf).unwrap();
                        txn_db.put_cf(&handle, &key, &val).unwrap();
                    }
                });
            }
        });
    }

    /// Manual stress test (ignored by default): repeatedly scans every column
    /// family and rewrites each value with 16..100 random bytes appended.
    /// The outer loop never terminates.
    #[test]
    #[ignore]
    fn append_to_all_keys() {
        use rand::Rng;
        let _guard = init_for_test("/tmp/rocksdb_spam").unwrap();
        TEST_DB.with(|cell| {
            let slot = cell.borrow();
            let db = slot.as_ref().unwrap();
            let mut rng = rand::rng();
            loop {
                for cf in cf_names() {
                    let handle = db.cf_handle(cf).unwrap();
                    let scan = db.iterator_cf_opt(&handle, total_order_read_opts(), IteratorMode::Start);
                    for entry in scan {
                        let (key, current) = entry.unwrap();
                        let mut grown = current.to_vec();
                        let extra_len = rng.random_range(16..100);
                        grown.extend((0..extra_len).map(|_| rng.random::<u8>()));
                        db.put_cf(&handle, &key, &grown).unwrap();
                    }
                }
            }
        });
    }

    /// Manual stress test (ignored by default): repeatedly scans every column
    /// family and deletes each key it finds. The outer loop never terminates.
    #[test]
    #[ignore]
    fn delete_all_keys() {
        let _guard = init_for_test("/tmp/rocksdb_spam").unwrap();
        TEST_DB.with(|cell| {
            let slot = cell.borrow();
            let db = slot.as_ref().unwrap();
            loop {
                for cf in cf_names() {
                    let handle = db.cf_handle(cf).unwrap();
                    let scan = db.iterator_cf_opt(&handle, total_order_read_opts(), IteratorMode::Start);
                    for entry in scan {
                        let (key, _value) = entry.unwrap();
                        db.delete_cf(&handle, &key).unwrap();
                    }
                }
            }
        });
    }
}