use crate::{Error, Result, db::Database};
use std::collections::HashSet;
use std::path::{Path, PathBuf};
/// Builder-style configuration for opening a RocksDB-backed `Database`.
///
/// Construct with `DatabaseConfig::new`, chain the setter methods, then
/// call `open` to produce a `Database`. Unset options fall back to the
/// RocksDB defaults.
#[derive(Debug, Clone)]
pub struct DatabaseConfig {
    /// Filesystem location of the database directory.
    path: PathBuf,
    /// Create the database on open if it does not exist (RocksDB `create_if_missing`).
    create_if_missing: bool,
    /// Fail the open if the database already exists (RocksDB `set_error_if_exists`).
    error_if_exists: bool,
    /// Cap on simultaneously open files; `None` leaves the RocksDB default.
    max_open_files: Option<i32>,
    /// Background-thread parallelism passed to `increase_parallelism`.
    parallelism: Option<i32>,
    /// Memtable size in bytes (`set_write_buffer_size`).
    write_buffer_size: Option<usize>,
    /// Maximum number of memtables (`set_max_write_buffer_number`).
    max_write_buffer_number: Option<i32>,
    /// Whether to collect internal RocksDB statistics.
    enable_statistics: bool,
    /// Block-cache size in MB for point-lookup tuning, if enabled.
    optimize_for_point_lookup: Option<u64>,
    /// Column families to open; always contains "default" (see `new`).
    column_families: HashSet<String>,
}
impl DatabaseConfig {
    /// Creates a new configuration for a database rooted at `path`.
    ///
    /// All tuning options start unset; the mandatory "default" column
    /// family is pre-registered because RocksDB requires it to exist.
    pub fn new<P: AsRef<Path>>(path: P) -> Self {
        let mut column_families = HashSet::new();
        column_families.insert("default".to_string());
        Self {
            path: path.as_ref().to_path_buf(),
            create_if_missing: false,
            error_if_exists: false,
            max_open_files: None,
            parallelism: None,
            write_buffer_size: None,
            max_write_buffer_number: None,
            enable_statistics: false,
            optimize_for_point_lookup: None,
            column_families,
        }
    }

    /// Creates the database files on open if they do not already exist.
    pub fn create_if_missing(mut self, create: bool) -> Self {
        self.create_if_missing = create;
        self
    }

    /// Makes the open fail when the database already exists.
    pub fn error_if_exists(mut self, error: bool) -> Self {
        self.error_if_exists = error;
        self
    }

    /// Caps the number of files RocksDB may hold open at once.
    pub fn set_max_open_files(mut self, max: i32) -> Self {
        self.max_open_files = Some(max);
        self
    }

    /// Sets the size of RocksDB's background thread pool.
    pub fn increase_parallelism(mut self, parallelism: i32) -> Self {
        self.parallelism = Some(parallelism);
        self
    }

    /// Sets the in-memory write buffer (memtable) size, in bytes.
    pub fn set_write_buffer_size(mut self, size: usize) -> Self {
        self.write_buffer_size = Some(size);
        self
    }

    /// Sets the maximum number of memtables kept in memory.
    pub fn set_max_write_buffer_number(mut self, num: i32) -> Self {
        self.max_write_buffer_number = Some(num);
        self
    }

    /// Enables collection of internal RocksDB statistics.
    pub fn enable_statistics(mut self, enable: bool) -> Self {
        self.enable_statistics = enable;
        self
    }

    /// Tunes the database for point lookups, using a block cache of
    /// `block_cache_size_mb` megabytes.
    pub fn optimize_for_point_lookup(mut self, block_cache_size_mb: u64) -> Self {
        self.optimize_for_point_lookup = Some(block_cache_size_mb);
        self
    }

    /// Registers one additional column family to open.
    ///
    /// Accepts anything convertible into a `String` (`&str`, `String`, ...),
    /// consistent with [`Self::add_column_families`]; existing `&str`
    /// callers keep working. Duplicate names are ignored.
    pub fn add_column_family(mut self, name: impl Into<String>) -> Self {
        self.column_families.insert(name.into());
        self
    }

    /// Registers several column families at once.
    pub fn add_column_families<I, S>(mut self, names: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.column_families
            .extend(names.into_iter().map(Into::into));
        self
    }

    /// Opens the database described by this configuration.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidConfig` when the configured path is empty,
    /// or `Error::Database` when RocksDB fails to open the database.
    pub fn open(self) -> Result<Database> {
        self.open_internal()
    }

    fn open_internal(self) -> Result<Database> {
        if self.path.as_os_str().is_empty() {
            return Err(Error::InvalidConfig(
                "Database path cannot be empty".to_string(),
            ));
        }

        let opts = self.build_db_options();
        let all_cfs = self.resolve_column_families(&opts);

        // Each column family gets its own options; compression is set
        // per-CF in addition to the DB-wide options.
        let cf_descriptors: Vec<_> = all_cfs
            .iter()
            .map(|name| {
                let mut cf_opts = rocksdb::Options::default();
                cf_opts.set_compression_type(rocksdb::DBCompressionType::Lz4);
                rocksdb::ColumnFamilyDescriptor::new(name, cf_opts)
            })
            .collect();

        let db = rocksdb::DB::open_cf_descriptors(&opts, &self.path, cf_descriptors)
            .map_err(|e| Error::Database(format!("Failed to open database: {}", e)))?;
        Ok(Database::new(db))
    }

    /// Translates the builder fields into DB-wide RocksDB options.
    fn build_db_options(&self) -> rocksdb::Options {
        let mut opts = rocksdb::Options::default();
        opts.create_if_missing(self.create_if_missing);
        opts.set_error_if_exists(self.error_if_exists);
        // Column families named in the config but absent on disk are
        // created automatically during open.
        opts.create_missing_column_families(true);
        if let Some(max) = self.max_open_files {
            opts.set_max_open_files(max);
        }
        if let Some(parallelism) = self.parallelism {
            opts.increase_parallelism(parallelism);
        }
        if let Some(size) = self.write_buffer_size {
            opts.set_write_buffer_size(size);
        }
        if let Some(num) = self.max_write_buffer_number {
            opts.set_max_write_buffer_number(num);
        }
        if self.enable_statistics {
            opts.enable_statistics();
        }
        if let Some(block_cache_size_mb) = self.optimize_for_point_lookup {
            opts.optimize_for_point_lookup(block_cache_size_mb);
        }
        opts.set_compression_type(rocksdb::DBCompressionType::Lz4);
        opts
    }

    /// Union of the configured column families with any already on disk:
    /// RocksDB refuses to open an existing database unless every on-disk
    /// column family is listed.
    fn resolve_column_families(&self, opts: &rocksdb::Options) -> HashSet<String> {
        if self.path.exists() {
            // `list_cf` can fail on a fresh or partially-created
            // directory; fall back to just "default" (best effort, as
            // before).
            let existing = rocksdb::DB::list_cf(opts, &self.path)
                .unwrap_or_else(|_| vec!["default".to_string()]);
            let mut cfs: HashSet<String> = existing.into_iter().collect();
            cfs.extend(self.column_families.iter().cloned());
            cfs
        } else {
            self.column_families.clone()
        }
    }
}
/// Options controlling how an opened database participates in a cluster.
#[derive(Debug, Clone, Default)]
pub struct OpenOptions {
    /// Addresses of peer nodes to replicate with.
    pub peer_nodes: Vec<String>,
    /// Whether replication to the peers is active.
    pub enable_replication: bool,
    /// Identifier of this node within the cluster, if one is assigned.
    pub node_id: Option<String>,
}

impl OpenOptions {
    /// Returns options with replication off, no peers, and no node id.
    pub fn new() -> Self {
        Self::default()
    }

    /// Replaces the peer list wholesale.
    pub fn with_peer_nodes(self, nodes: Vec<String>) -> Self {
        Self {
            peer_nodes: nodes,
            ..self
        }
    }

    /// Turns replication on or off.
    pub fn enable_replication(self, enable: bool) -> Self {
        Self {
            enable_replication: enable,
            ..self
        }
    }

    /// Assigns this node's cluster identifier.
    pub fn with_node_id(self, id: String) -> Self {
        Self {
            node_id: Some(id),
            ..self
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The DatabaseConfig builder records each chained setting.
    #[test]
    fn test_config_builder() {
        let cfg = DatabaseConfig::new("/tmp/test")
            .create_if_missing(true)
            .set_max_open_files(500)
            .increase_parallelism(2);

        assert_eq!(cfg.path, PathBuf::from("/tmp/test"));
        assert!(cfg.create_if_missing);
        assert_eq!(cfg.max_open_files, Some(500));
        assert_eq!(cfg.parallelism, Some(2));
    }

    /// OpenOptions chains peers, replication flag, and node id.
    #[test]
    fn test_open_options() {
        let peers: Vec<String> = ["http://node1:8080", "http://node2:8080"]
            .iter()
            .map(|s| s.to_string())
            .collect();

        let opts = OpenOptions::new()
            .with_peer_nodes(peers)
            .enable_replication(true)
            .with_node_id("node-1".to_string());

        assert_eq!(opts.peer_nodes.len(), 2);
        assert!(opts.enable_replication);
        assert_eq!(opts.node_id, Some("node-1".to_string()));
    }
}