use std::num::NonZeroUsize;
use std::path::Path;
use std::sync::Arc;
use rocksdb::ColumnFamilyDescriptor;
use serde::{Deserialize, Serialize};
type MultiThreadedRocksDb = rocksdb::OptimisticTransactionDB<rocksdb::MultiThreaded>;
#[derive(Debug, Clone)]
pub struct RocksDb {
db: Arc<MultiThreadedRocksDb>,
}
impl std::ops::Deref for RocksDb {
type Target = Arc<MultiThreadedRocksDb>;
fn deref(&self) -> &Self::Target {
&self.db
}
}
impl RocksDb {
pub fn open<P>(path: P, config: &RocksDbConfig) -> Result<Self, crate::Error>
where
P: AsRef<Path>,
{
let mut db_opts = rocksdb::Options::default();
db_opts.create_if_missing(config.create_if_missing);
db_opts.create_missing_column_families(config.create_missing_column_families);
db_opts.increase_parallelism(config.parallelism);
db_opts.set_write_buffer_size(config.write_buffer_size);
db_opts.set_max_open_files(config.max_open_files);
db_opts.set_allow_mmap_reads(true);
db_opts.set_allow_mmap_writes(true);
if let Some(max_background_jobs) = config.max_background_jobs {
db_opts.set_max_background_jobs(max_background_jobs);
}
if let Some(compaction_style) = &config.compaction_style {
db_opts.set_compaction_style(compaction_style_from_str(compaction_style)?);
}
if let Some(compression_type) = &config.compression_type {
db_opts.set_compression_type(compression_type_from_str(compression_type)?);
}
if config.enable_statistics {
db_opts.enable_statistics();
}
let mut seq_cf_opts = db_opts.clone();
seq_cf_opts.set_merge_operator_associative("add", adder_merge_operator);
let db = MultiThreadedRocksDb::open_cf_descriptors(
&db_opts,
path,
[
ColumnFamilyDescriptor::new(cf::SEQ_CF, seq_cf_opts),
ColumnFamilyDescriptor::new(cf::USER_TOKENS_CF, db_opts.clone()),
ColumnFamilyDescriptor::new(cf::TOKENS_OPTS_CF, db_opts.clone()),
ColumnFamilyDescriptor::new(cf::SERVICES_USER_KEYS_CF, db_opts.clone()),
ColumnFamilyDescriptor::new(cf::API_KEYS_CF, db_opts.clone()),
ColumnFamilyDescriptor::new(cf::API_KEYS_BY_ID_CF, db_opts.clone()),
ColumnFamilyDescriptor::new(cf::SERVICES_OAUTH_POLICY_CF, db_opts.clone()),
ColumnFamilyDescriptor::new(cf::OAUTH_JTI_CF, db_opts.clone()),
ColumnFamilyDescriptor::new(cf::OAUTH_RL_CF, db_opts.clone()),
ColumnFamilyDescriptor::new(cf::TLS_ASSETS_CF, db_opts.clone()),
ColumnFamilyDescriptor::new(cf::TLS_CERT_METADATA_CF, db_opts.clone()),
ColumnFamilyDescriptor::new(cf::TLS_ISSUANCE_LOG_CF, db_opts.clone()),
],
)?;
Ok(Self { db: Arc::new(db) })
}
}
pub mod cf {
pub const SEQ_CF: &str = "seq";
pub const TOKENS_OPTS_CF: &str = "tkns_opts";
pub const USER_TOKENS_CF: &str = "usr_tkns";
pub const SERVICES_USER_KEYS_CF: &str = "svs_usr_keys";
pub const API_KEYS_CF: &str = "api_keys";
pub const API_KEYS_BY_ID_CF: &str = "api_keys_by_id";
pub const SERVICES_OAUTH_POLICY_CF: &str = "services_oauth_policy";
pub const OAUTH_JTI_CF: &str = "oauth_jti";
pub const OAUTH_RL_CF: &str = "oauth_rl";
pub const TLS_ASSETS_CF: &str = "tls_assets";
pub const TLS_CERT_METADATA_CF: &str = "tls_cert_metadata";
pub const TLS_ISSUANCE_LOG_CF: &str = "tls_issuance_log";
}
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
#[serde(default)]
pub struct RocksDbConfig {
pub create_if_missing: bool,
pub create_missing_column_families: bool,
pub parallelism: i32,
pub write_buffer_size: usize,
pub max_open_files: i32,
pub max_background_jobs: Option<i32>,
pub compression_type: Option<String>,
pub compaction_style: Option<String>,
pub enable_statistics: bool,
}
impl Default for RocksDbConfig {
fn default() -> Self {
Self {
create_if_missing: true,
create_missing_column_families: true,
parallelism: std::thread::available_parallelism()
.unwrap_or(NonZeroUsize::new(1).unwrap())
.saturating_mul(NonZeroUsize::new(2).unwrap())
.get() as i32,
write_buffer_size: 256 * 1024 * 1024,
max_open_files: 1024,
max_background_jobs: None,
compaction_style: None,
compression_type: Some("none".into()),
enable_statistics: false,
}
}
}
pub(crate) fn compaction_style_from_str(
s: &str,
) -> Result<rocksdb::DBCompactionStyle, crate::Error> {
match s.to_lowercase().as_str() {
"level" => Ok(rocksdb::DBCompactionStyle::Level),
"universal" => Ok(rocksdb::DBCompactionStyle::Universal),
"fifo" => Ok(rocksdb::DBCompactionStyle::Fifo),
_ => Err(crate::Error::InvalidDBCompactionStyle(s.into())),
}
}
pub(crate) fn compression_type_from_str(
s: &str,
) -> Result<rocksdb::DBCompressionType, crate::Error> {
match s.to_lowercase().as_str() {
"bz2" => Ok(rocksdb::DBCompressionType::Bz2),
"lz4" => Ok(rocksdb::DBCompressionType::Lz4),
"lz4hc" => Ok(rocksdb::DBCompressionType::Lz4hc),
"snappy" => Ok(rocksdb::DBCompressionType::Snappy),
"zlib" => Ok(rocksdb::DBCompressionType::Zlib),
"zstd" => Ok(rocksdb::DBCompressionType::Zstd),
"none" => Ok(rocksdb::DBCompressionType::None),
_ => Err(crate::Error::InvalidDBCompressionType(s.into())),
}
}
#[allow(unused)]
fn concat_merge_operator(
_key: &[u8],
existing_value: Option<&[u8]>,
operands: &rocksdb::merge_operator::MergeOperands,
) -> Option<Vec<u8>> {
let mut result = existing_value.unwrap_or(&[]).to_vec();
for operand in operands {
result.extend_from_slice(operand);
}
Some(result)
}
#[allow(unused)]
fn adder_merge_operator(
_key: &[u8],
existing_value: Option<&[u8]>,
operands: &rocksdb::merge_operator::MergeOperands,
) -> Option<Vec<u8>> {
let current_value = existing_value
.and_then(|v| v.try_into().ok())
.map(u64::from_be_bytes)
.unwrap_or(0);
let mut sum = current_value;
for operand in operands {
let v = operand.try_into().ok().map(u64::from_be_bytes).unwrap_or(0);
sum = sum.wrapping_add(v);
}
let result = sum.to_be_bytes().to_vec();
Some(result)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_rocksdb_config_default() {
let config = RocksDbConfig::default();
assert!(config.create_if_missing);
assert!(config.create_missing_column_families);
assert_eq!(config.compression_type, Some("none".to_string()));
assert!(!config.enable_statistics);
assert!(config.parallelism >= 1);
}
#[test]
fn test_rocksdb_open() {
let tmp_dir = tempfile::tempdir().unwrap();
let config = RocksDbConfig::default();
let db_result = RocksDb::open(tmp_dir.path(), &config);
assert!(db_result.is_ok());
let db = db_result.unwrap();
let cf_names = vec![
cf::SEQ_CF.to_string(),
cf::USER_TOKENS_CF.to_string(),
cf::TOKENS_OPTS_CF.to_string(),
cf::SERVICES_USER_KEYS_CF.to_string(),
cf::TLS_ASSETS_CF.to_string(),
cf::TLS_CERT_METADATA_CF.to_string(),
cf::TLS_ISSUANCE_LOG_CF.to_string(),
];
for name in cf_names {
let cf_handle = db.cf_handle(&name);
assert!(cf_handle.is_some());
}
}
#[test]
fn test_rocksdb_compression_types() {
assert!(compression_type_from_str("none").is_ok());
assert!(compression_type_from_str("lz4").is_ok());
assert!(compression_type_from_str("snappy").is_ok());
assert!(compression_type_from_str("zlib").is_ok());
assert!(compression_type_from_str("zstd").is_ok());
assert!(compression_type_from_str("bz2").is_ok());
assert!(compression_type_from_str("lz4hc").is_ok());
assert!(compression_type_from_str("LZ4").is_ok());
assert!(compression_type_from_str("Snappy").is_ok());
let invalid = compression_type_from_str("invalid_compression");
assert!(invalid.is_err());
assert!(format!("{}", invalid.unwrap_err()).contains("Invalid"));
}
#[test]
fn test_rocksdb_compaction_styles() {
assert!(compaction_style_from_str("level").is_ok());
assert!(compaction_style_from_str("universal").is_ok());
assert!(compaction_style_from_str("fifo").is_ok());
assert!(compaction_style_from_str("LEVEL").is_ok());
assert!(compaction_style_from_str("Universal").is_ok());
let invalid = compaction_style_from_str("invalid_compaction");
assert!(invalid.is_err());
assert!(format!("{}", invalid.unwrap_err()).contains("Invalid"));
}
#[test]
fn test_adder_merge_operator() {
let tmp_dir = tempfile::tempdir().unwrap();
let config = RocksDbConfig::default();
let db = RocksDb::open(tmp_dir.path(), &config).unwrap();
let seq_cf = db.cf_handle(cf::SEQ_CF).unwrap();
let result = db.merge_cf(&seq_cf, b"test_counter", 1u64.to_be_bytes());
assert!(result.is_ok());
let result = db.merge_cf(&seq_cf, b"test_counter", 2u64.to_be_bytes());
assert!(result.is_ok());
let value = db.get_cf(&seq_cf, b"test_counter").unwrap();
assert!(value.is_some());
let bytes = value.unwrap();
let mut value_bytes = [0u8; 8];
value_bytes.copy_from_slice(&bytes);
let value = u64::from_be_bytes(value_bytes);
assert_eq!(value, 3);
}
}