use std::path::Path;
use std::sync::Arc;
use crate::core::types::Compression;
use crate::core::{DbConfig, WriteBatch};
use super::engine::{Engine, StorageConfig, StorageError};
use super::sstable::CompressionType;
/// Result type used throughout the database API.
///
/// The error type defaults to [`DbError`], so `Result<T>` keeps its original
/// meaning. Because this alias shadows `std::result::Result`, the defaulted second
/// parameter lets code in this module (and glob importers) still spell out
/// `Result<T, SomeOtherError>` without fully qualifying the std type.
pub type Result<T, E = DbError> = std::result::Result<T, E>;
/// Errors returned by the [`Db`] API.
///
/// The `#[from]` attributes generate `From` impls, which is what lets `?`
/// convert storage-engine and I/O errors into this type.
#[derive(Debug, thiserror::Error)]
pub enum DbError {
/// An error reported by the storage engine (`super::engine::StorageError`).
#[error("Storage error: {0}")]
Storage(#[from] StorageError),
/// An underlying I/O error.
#[error("I/O error: {0}")]
Io(#[from] std::io::Error),
/// Any other failure, described by a free-form message.
#[error("Database error: {0}")]
Other(String),
}
/// User-facing options for opening a [`Db`].
///
/// Start from [`DbOptions::default`] or one of the presets (`fast`, `durable`,
/// `paranoid`) and adjust as needed.
#[derive(Debug, Clone)]
pub struct DbOptions {
/// Forwarded to `DbConfig::wal_enabled` when the database is opened.
pub wal_enabled: bool,
/// Forwarded to `DbConfig::sync_writes` when the database is opened.
pub sync_writes: bool,
/// Forwarded to `DbConfig::memtable_size` (presumably bytes — TODO confirm).
pub memtable_size: usize,
/// Forwarded to `DbConfig::block_cache_size` (presumably bytes — TODO confirm).
pub block_cache_size: usize,
/// SSTable compression; mapped onto the storage layer's `CompressionType`
/// in `Db::open_with_options`.
pub compression: Compression,
}
impl Default for DbOptions {
fn default() -> Self {
Self {
wal_enabled: true,
sync_writes: false, memtable_size: 64 * 1024 * 1024,
block_cache_size: 64 * 1024 * 1024,
compression: Compression::Lz4, }
}
}
impl DbOptions {
    /// Memtable and block-cache budget shared by every preset (64 MiB).
    const PRESET_BUFFER_SIZE: usize = 64 << 20;

    /// Builds a Snappy-compressed preset with the given WAL / sync flags and
    /// the shared 64 MiB buffer budget.
    fn snappy_preset(wal_enabled: bool, sync_writes: bool) -> Self {
        Self {
            wal_enabled,
            sync_writes,
            memtable_size: Self::PRESET_BUFFER_SIZE,
            block_cache_size: Self::PRESET_BUFFER_SIZE,
            compression: Compression::Snappy,
        }
    }

    /// Preset with the write-ahead log disabled and unsynced writes (Snappy).
    pub fn fast() -> Self {
        Self::snappy_preset(false, false)
    }

    /// Preset with the write-ahead log enabled but writes not synced (Snappy).
    pub fn durable() -> Self {
        Self::snappy_preset(true, false)
    }

    /// Preset with the write-ahead log enabled and synced writes (Snappy).
    pub fn paranoid() -> Self {
        Self::snappy_preset(true, true)
    }

    /// Returns these options with `compression` replacing the current setting;
    /// every other field is kept.
    pub fn with_compression(self, compression: Compression) -> Self {
        Self { compression, ..self }
    }
}
/// Handle to an open database.
///
/// The storage engine is held behind an [`Arc`], so cloning a `Db` is cheap
/// (a reference-count bump) and every clone operates on the same engine.
#[derive(Clone)]
pub struct Db {
    engine: Arc<Engine>,
}
impl Db {
    /// Opens the database stored under `path` using [`DbOptions::default`].
    ///
    /// # Errors
    ///
    /// Returns an error if the storage engine cannot be opened.
    pub async fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
        Self::open_with_options(path, DbOptions::default()).await
    }

    /// Opens the database stored under `path` with explicit `options`.
    ///
    /// The options are copied into a [`DbConfig`]. Every field not covered by
    /// [`DbOptions`] keeps its `DbConfig::default()` value. The config is then
    /// converted into the engine's [`StorageConfig`], and the chosen
    /// [`Compression`] is applied to the SSTable configuration before the engine
    /// is opened.
    ///
    /// # Errors
    ///
    /// Returns an error if the storage engine cannot be opened.
    pub async fn open_with_options<P: AsRef<Path>>(path: P, options: DbOptions) -> Result<Self> {
        let db_config = DbConfig {
            wal_enabled: options.wal_enabled,
            sync_writes: options.sync_writes,
            memtable_size: options.memtable_size,
            block_cache_size: options.block_cache_size,
            ..Default::default()
        };
        let mut storage_config =
            StorageConfig::from_db_config(&db_config, path.as_ref().to_path_buf());
        // The public `Compression` enum and the SSTable layer's `CompressionType`
        // are separate types. The match has no wildcard arm, so adding a
        // `Compression` variant fails to compile here until it is mapped.
        storage_config.sstable_config.compression = match options.compression {
            Compression::None => CompressionType::None,
            Compression::Snappy => CompressionType::Snappy,
            Compression::Zstd => CompressionType::Zstd,
            Compression::Lz4 => CompressionType::Lz4,
        };
        let engine = Engine::open(storage_config).await?;
        Ok(Self {
            engine: Arc::new(engine),
        })
    }

    /// Stores `value` under `key`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the storage engine.
    pub async fn insert<K: AsRef<[u8]>, V: AsRef<[u8]>>(&self, key: K, value: V) -> Result<()> {
        self.engine.insert(key.as_ref(), value.as_ref()).await?;
        Ok(())
    }

    /// Returns the value stored under `key`, or `None` if there is none.
    ///
    /// # Errors
    ///
    /// Propagates any error from the storage engine.
    pub async fn get<K: AsRef<[u8]>>(&self, key: K) -> Result<Option<Vec<u8>>> {
        Ok(self.engine.get(key.as_ref()).await?)
    }

    /// Deletes `key` (delegates to the engine's `delete`).
    ///
    /// # Errors
    ///
    /// Propagates any error from the storage engine.
    pub async fn remove<K: AsRef<[u8]>>(&self, key: K) -> Result<()> {
        self.engine.delete(key.as_ref()).await?;
        Ok(())
    }

    /// Returns `true` if `key` currently has a value.
    ///
    /// This performs a full `get`, so the value is read even though only its
    /// presence is reported.
    ///
    /// # Errors
    ///
    /// Propagates any error from the storage engine.
    pub async fn contains_key<K: AsRef<[u8]>>(&self, key: K) -> Result<bool> {
        Ok(self.engine.get(key.as_ref()).await?.is_some())
    }

    /// Collects every key/value pair with `start <= key < end` into a `Vec`.
    ///
    /// `start` and `end` may be different key types. For example, byte-string
    /// literals of different lengths such as `b"a"` and `b"zz"` are distinct
    /// array types and could not share a single type parameter.
    ///
    /// # Errors
    ///
    /// Propagates any error from the storage engine.
    pub async fn range<S, E>(&self, start: S, end: E) -> Result<Vec<(Vec<u8>, Vec<u8>)>>
    where
        S: AsRef<[u8]>,
        E: AsRef<[u8]>,
    {
        Ok(self.engine.range(start.as_ref(), end.as_ref()).await?)
    }

    /// Collects every key/value pair whose key starts with `prefix` into a `Vec`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the storage engine.
    pub async fn scan_prefix<K: AsRef<[u8]>>(&self, prefix: K) -> Result<Vec<(Vec<u8>, Vec<u8>)>> {
        Ok(self.engine.scan_prefix(prefix.as_ref()).await?)
    }

    /// Iterator form of [`Db::range`] over `start <= key < end`.
    ///
    /// As with `range`, `start` and `end` may be different key types.
    ///
    /// # Errors
    ///
    /// Propagates any error from the storage engine.
    pub async fn range_iter<S, E>(&self, start: S, end: E) -> Result<super::iter::RangeIter>
    where
        S: AsRef<[u8]>,
        E: AsRef<[u8]>,
    {
        Ok(self.engine.range_iter(start.as_ref(), end.as_ref()).await?)
    }

    /// Iterator form of [`Db::scan_prefix`].
    ///
    /// # Errors
    ///
    /// Propagates any error from the storage engine.
    pub async fn scan_prefix_iter<K: AsRef<[u8]>>(
        &self,
        prefix: K,
    ) -> Result<super::iter::PrefixIter> {
        Ok(self.engine.scan_prefix_iter(prefix.as_ref()).await?)
    }

    /// Applies every operation in `batch` via the engine's `write_batch`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the storage engine.
    pub async fn write_batch(&self, batch: &WriteBatch) -> Result<()> {
        self.engine.write_batch(batch).await?;
        Ok(())
    }

    /// Flushes buffered data.
    ///
    /// First the engine's write buffers are flushed (a synchronous call), then
    /// the engine's own asynchronous `flush` runs.
    ///
    /// # Errors
    ///
    /// Propagates the first error from either step.
    pub async fn flush(&self) -> Result<()> {
        self.engine.flush_write_buffers()?;
        self.engine.flush().await?;
        Ok(())
    }

    /// Asks the storage engine to run a compaction.
    ///
    /// # Errors
    ///
    /// Propagates any error from the storage engine.
    pub async fn compact(&self) -> Result<()> {
        self.engine.compact().await?;
        Ok(())
    }

    /// Returns a snapshot of the engine's statistics.
    pub fn stats(&self) -> DbStats {
        let stats = self.engine.stats();
        DbStats {
            total_keys: stats.total_keys,
            total_bytes: stats.total_bytes,
            wal_size: stats.wal_size,
            sstable_count: stats.sstable_count as u64,
            memtable_size: stats.memtable_size,
        }
    }
}
/// Statistics snapshot returned by [`Db::stats`].
///
/// The values are copied out of the engine's statistics when `Db::stats` is
/// called; this struct is plain data and is not updated afterwards.
/// `Default`, `PartialEq`, `Eq` and `Hash` are derived so snapshots can be
/// compared and used as zeroed baselines.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub struct DbStats {
    /// Number of keys as reported by the engine.
    pub total_keys: u64,
    /// Total data size as reported by the engine (presumably bytes — TODO confirm).
    pub total_bytes: u64,
    /// Write-ahead-log size as reported by the engine (presumably bytes — TODO confirm).
    pub wal_size: u64,
    /// Number of SSTables, widened to `u64` from the engine's count.
    pub sstable_count: u64,
    /// Memtable size as reported by the engine (presumably bytes — TODO confirm).
    pub memtable_size: u64,
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Opens a `Db` configured with `DbOptions::fast()` inside a fresh temporary
    /// directory.
    ///
    /// The `TempDir` is returned together with the database so the caller keeps
    /// the directory alive for the whole test. Bind it to a named variable
    /// (e.g. `_dir`, not `_`), because dropping it removes the directory.
    async fn open_fast() -> (TempDir, Db) {
        let dir = TempDir::new().unwrap();
        let db = Db::open_with_options(dir.path(), DbOptions::fast())
            .await
            .unwrap();
        (dir, db)
    }

    #[tokio::test]
    async fn test_basic_operations() {
        let (_dir, db) = open_fast().await;
        db.insert(b"key1", b"value1").await.unwrap();
        db.insert(b"key2", b"value2").await.unwrap();

        // Lookups for present and absent keys.
        assert_eq!(db.get(b"key1").await.unwrap(), Some(b"value1".to_vec()));
        assert_eq!(db.get(b"key2").await.unwrap(), Some(b"value2".to_vec()));
        assert_eq!(db.get(b"key3").await.unwrap(), None);
        assert!(db.contains_key(b"key1").await.unwrap());
        assert!(!db.contains_key(b"key3").await.unwrap());

        // A removed key is no longer visible.
        db.remove(b"key1").await.unwrap();
        assert_eq!(db.get(b"key1").await.unwrap(), None);
    }

    #[tokio::test]
    async fn test_range_scan() {
        let (_dir, db) = open_fast().await;
        for (key, value) in [(b"a", b"1"), (b"b", b"2"), (b"c", b"3"), (b"d", b"4")] {
            db.insert(key, value).await.unwrap();
        }

        // "b" is included, "d" (the end bound) is not.
        let results = db.range(b"b", b"d").await.unwrap();
        let expected = vec![
            (b"b".to_vec(), b"2".to_vec()),
            (b"c".to_vec(), b"3".to_vec()),
        ];
        assert_eq!(results, expected);
    }

    #[tokio::test]
    async fn test_prefix_scan() {
        let (_dir, db) = open_fast().await;
        let rows: [(&[u8], &[u8]); 3] = [
            (b"user:1", b"alice"),
            (b"user:2", b"bob"),
            (b"post:1", b"hello"),
        ];
        for (key, value) in rows {
            db.insert(key, value).await.unwrap();
        }

        let users = db.scan_prefix(b"user:").await.unwrap();
        assert_eq!(users.len(), 2);
    }

    #[tokio::test]
    async fn test_fast_mode_optimized() {
        let (_dir, db) = open_fast().await;
        let pairs = [(b"key1", b"value1"), (b"key2", b"value2")];
        for (key, value) in pairs {
            db.insert(key, value).await.unwrap();
        }

        // Values must still be readable after an explicit flush.
        db.flush().await.unwrap();
        for (key, value) in pairs {
            assert_eq!(db.get(key).await.unwrap(), Some(value.to_vec()));
        }
    }

    #[tokio::test]
    async fn test_fast_mode_many_inserts() {
        const COUNT: usize = 1000;
        let (_dir, db) = open_fast().await;
        let entry = |i: usize| (format!("key{:04}", i), format!("value{:04}", i));

        for i in 0..COUNT {
            let (key, value) = entry(i);
            db.insert(key.as_bytes(), value.as_bytes()).await.unwrap();
        }
        db.flush().await.unwrap();

        for i in 0..COUNT {
            let (key, expected) = entry(i);
            assert_eq!(
                db.get(key.as_bytes()).await.unwrap(),
                Some(expected.into_bytes())
            );
        }
    }

    #[tokio::test]
    async fn test_range_iter_count() {
        let (_dir, db) = open_fast().await;
        for (key, value) in [(b"a", b"1"), (b"b", b"2"), (b"c", b"3"), (b"d", b"4")] {
            db.insert(key, value).await.unwrap();
        }

        // The end bound "d" is excluded, leaving a, b and c.
        let iter = db.range_iter(b"a", b"d").await.unwrap();
        assert_eq!(iter.count(), 3);
    }

    #[tokio::test]
    async fn test_range_iter_keys_only() {
        let (_dir, db) = open_fast().await;
        db.insert(b"user:1", b"alice").await.unwrap();
        db.insert(b"user:2", b"bob").await.unwrap();

        let keys = db.scan_prefix_iter(b"user:").await.unwrap().keys();
        assert_eq!(keys.len(), 2);
        for expected in [b"user:1", b"user:2"] {
            assert!(keys.contains(&expected.to_vec()));
        }
    }

    #[tokio::test]
    async fn test_range_iter_filter_by_key() {
        let (_dir, db) = open_fast().await;
        let rows: [(&[u8], &[u8]); 4] = [
            (b"user:1:name", b"alice"),
            (b"user:1:email", b"alice@example.com"),
            (b"user:2:name", b"bob"),
            (b"user:2:email", b"bob@example.com"),
        ];
        for (key, value) in rows {
            db.insert(key, value).await.unwrap();
        }

        // Keep only the ":name" entries and pull out their values.
        let names: Vec<_> = db
            .scan_prefix_iter(b"user:")
            .await
            .unwrap()
            .filter(|g| g.key().ends_with(b":name"))
            .map(|g| g.into_value())
            .collect();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&b"alice".to_vec()));
        assert!(names.contains(&b"bob".to_vec()));
    }

    #[tokio::test]
    async fn test_range_iter_paginate() {
        let (_dir, db) = open_fast().await;
        for n in 0..10 {
            let key = format!("key:{:02}", n);
            let value = format!("value:{:02}", n);
            db.insert(key.as_bytes(), value.as_bytes()).await.unwrap();
        }

        // The requested page holds key:03 through key:06.
        let page: Vec<String> = db
            .scan_prefix_iter(b"key:")
            .await
            .unwrap()
            .paginate(3, 4)
            .map(|g| String::from_utf8_lossy(g.key()).to_string())
            .collect();
        assert_eq!(page.len(), 4);
        assert_eq!(page[0], "key:03");
        assert_eq!(page[3], "key:06");
    }
}