use std::num::NonZeroUsize;
use std::path::Path;
use std::sync::Arc;

use rocksdb::ColumnFamilyDescriptor;
use serde::{Deserialize, Serialize};

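/// RocksDB optimistic-transaction database opened in multi-threaded column-family mode.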
type MultiThreadedRocksDb = rocksdb::OptimisticTransactionDB<rocksdb::MultiThreaded>;

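/// Shared handle to the underlying RocksDB instance.
///
/// Cloning is cheap: the database is kept behind an [`Arc`] and shared between clones.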
#[derive(Debug, Clone)]
pub struct RocksDb {
    db: Arc<MultiThreadedRocksDb>,
}

impl std::ops::Deref for RocksDb {
    type Target = Arc<MultiThreadedRocksDb>;

    fn deref(&self) -> &Self::Target {
        &self.db
    }
}

impl RocksDb {
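    /// Opens (or creates) the database at `path`, applying the tuning options from
    /// `config` and registering every column family declared in [`cf`].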
    pub fn open<P>(path: P, config: &RocksDbConfig) -> Result<Self, crate::Error>
    where
        P: AsRef<Path>,
    {
        let mut db_opts = rocksdb::Options::default();
        db_opts.create_if_missing(config.create_if_missing);
        db_opts.create_missing_column_families(config.create_missing_column_families);
        db_opts.increase_parallelism(config.parallelism);
        db_opts.set_write_buffer_size(config.write_buffer_size);
        db_opts.set_max_open_files(config.max_open_files);
        db_opts.set_allow_mmap_reads(true);
        db_opts.set_allow_mmap_writes(true);

        if let Some(max_background_jobs) = config.max_background_jobs {
            db_opts.set_max_background_jobs(max_background_jobs);
        }
        if let Some(compaction_style) = &config.compaction_style {
            db_opts.set_compaction_style(compaction_style_from_str(compaction_style)?);
        }
        if let Some(compression_type) = &config.compression_type {
            db_opts.set_compression_type(compression_type_from_str(compression_type)?);
        }
        if config.enable_statistics {
            db_opts.enable_statistics();
        }

        let mut seq_cf_opts = db_opts.clone();
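        // The sequence column family gets the associative "add" merge operator so
        // callers can atomically bump big-endian u64 counters via merge_cf.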
        seq_cf_opts.set_merge_operator_associative("add", adder_merge_operator);

        let db = MultiThreadedRocksDb::open_cf_descriptors(
            &db_opts,
            path,
            [
                ColumnFamilyDescriptor::new(cf::SEQ_CF, seq_cf_opts),
                ColumnFamilyDescriptor::new(cf::USER_TOKENS_CF, db_opts.clone()),
                ColumnFamilyDescriptor::new(cf::TOKENS_OPTS_CF, db_opts.clone()),
                ColumnFamilyDescriptor::new(cf::SERVICES_USER_KEYS_CF, db_opts.clone()),
                ColumnFamilyDescriptor::new(cf::API_KEYS_CF, db_opts.clone()),
                ColumnFamilyDescriptor::new(cf::API_KEYS_BY_ID_CF, db_opts.clone()),
                ColumnFamilyDescriptor::new(cf::SERVICES_OAUTH_POLICY_CF, db_opts.clone()),
                ColumnFamilyDescriptor::new(cf::OAUTH_JTI_CF, db_opts.clone()),
                ColumnFamilyDescriptor::new(cf::OAUTH_RL_CF, db_opts.clone()),
                ColumnFamilyDescriptor::new(cf::TLS_ASSETS_CF, db_opts.clone()),
                ColumnFamilyDescriptor::new(cf::TLS_CERT_METADATA_CF, db_opts.clone()),
                ColumnFamilyDescriptor::new(cf::TLS_ISSUANCE_LOG_CF, db_opts.clone()),
            ],
        )?;
        Ok(Self { db: Arc::new(db) })
    }
}

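/// Names of the column families created by [`RocksDb::open`].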
pub mod cf {
    pub const SEQ_CF: &str = "seq";
    pub const TOKENS_OPTS_CF: &str = "tkns_opts";
    pub const USER_TOKENS_CF: &str = "usr_tkns";
    pub const SERVICES_USER_KEYS_CF: &str = "svs_usr_keys";
    pub const API_KEYS_CF: &str = "api_keys";
    pub const API_KEYS_BY_ID_CF: &str = "api_keys_by_id";
    pub const SERVICES_OAUTH_POLICY_CF: &str = "services_oauth_policy";
    pub const OAUTH_JTI_CF: &str = "oauth_jti";
    pub const OAUTH_RL_CF: &str = "oauth_rl";
    pub const TLS_ASSETS_CF: &str = "tls_assets";
    pub const TLS_CERT_METADATA_CF: &str = "tls_cert_metadata";
    pub const TLS_ISSUANCE_LOG_CF: &str = "tls_issuance_log";
}

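/// Tuning options for [`RocksDb::open`].
///
/// `#[serde(default)]` lets partially specified configs deserialize; missing fields
/// fall back to the values produced by the `Default` implementation below.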
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
#[serde(default)]
pub struct RocksDbConfig {
    pub create_if_missing: bool,
    pub create_missing_column_families: bool,
    pub parallelism: i32,
    pub write_buffer_size: usize,
    pub max_open_files: i32,
    pub max_background_jobs: Option<i32>,
    pub compression_type: Option<String>,
    pub compaction_style: Option<String>,
    pub enable_statistics: bool,
}

impl Default for RocksDbConfig {
    fn default() -> Self {
        Self {
            create_if_missing: true,
            create_missing_column_families: true,
            // Twice the detected parallelism; assume a single core if detection fails.
            parallelism: std::thread::available_parallelism()
                .unwrap_or(NonZeroUsize::new(1).unwrap())
                .saturating_mul(NonZeroUsize::new(2).unwrap())
                .get() as i32,
            // 256 MiB write buffer.
            write_buffer_size: 256 * 1024 * 1024,
            max_open_files: 1024,
            max_background_jobs: None,
            compaction_style: None,
            compression_type: Some("none".into()),
            enable_statistics: false,
        }
    }
}

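/// Parses a case-insensitive compaction style name (`"level"`, `"universal"`, or `"fifo"`).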
pub(crate) fn compaction_style_from_str(
    s: &str,
) -> Result<rocksdb::DBCompactionStyle, crate::Error> {
    match s.to_lowercase().as_str() {
        "level" => Ok(rocksdb::DBCompactionStyle::Level),
        "universal" => Ok(rocksdb::DBCompactionStyle::Universal),
        "fifo" => Ok(rocksdb::DBCompactionStyle::Fifo),
        _ => Err(crate::Error::InvalidDBCompactionStyle(s.into())),
    }
}

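/// Parses a case-insensitive compression type name (`"bz2"`, `"lz4"`, `"lz4hc"`,
/// `"snappy"`, `"zlib"`, `"zstd"`, or `"none"`).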
pub(crate) fn compression_type_from_str(
    s: &str,
) -> Result<rocksdb::DBCompressionType, crate::Error> {
    match s.to_lowercase().as_str() {
        "bz2" => Ok(rocksdb::DBCompressionType::Bz2),
        "lz4" => Ok(rocksdb::DBCompressionType::Lz4),
        "lz4hc" => Ok(rocksdb::DBCompressionType::Lz4hc),
        "snappy" => Ok(rocksdb::DBCompressionType::Snappy),
        "zlib" => Ok(rocksdb::DBCompressionType::Zlib),
        "zstd" => Ok(rocksdb::DBCompressionType::Zstd),
        "none" => Ok(rocksdb::DBCompressionType::None),
        _ => Err(crate::Error::InvalidDBCompressionType(s.into())),
    }
}

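/// Merge operator that appends every operand to the existing value, if any.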
#[allow(unused)]
fn concat_merge_operator(
    _key: &[u8],
    existing_value: Option<&[u8]>,
    operands: &rocksdb::merge_operator::MergeOperands,
) -> Option<Vec<u8>> {
    let mut result = existing_value.unwrap_or(&[]).to_vec();
    for operand in operands {
        result.extend_from_slice(operand);
    }
    Some(result)
}

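/// Associative merge operator that treats values as big-endian `u64` counters.
///
/// The existing value and every operand are decoded as 8-byte big-endian integers
/// (missing or malformed values count as zero) and summed with wrapping addition.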
#[allow(unused)]
fn adder_merge_operator(
    _key: &[u8],
    existing_value: Option<&[u8]>,
    operands: &rocksdb::merge_operator::MergeOperands,
) -> Option<Vec<u8>> {
    let current_value = existing_value
        .and_then(|v| v.try_into().ok())
        .map(u64::from_be_bytes)
        .unwrap_or(0);
    let mut sum = current_value;
    for operand in operands {
        let v = operand.try_into().ok().map(u64::from_be_bytes).unwrap_or(0);
        sum = sum.wrapping_add(v);
    }
    let result = sum.to_be_bytes().to_vec();
    Some(result)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rocksdb_config_default() {
        let config = RocksDbConfig::default();

        assert!(config.create_if_missing);
        assert!(config.create_missing_column_families);
        assert_eq!(config.compression_type, Some("none".to_string()));
        assert!(!config.enable_statistics);

        // Parallelism is derived from the host, so only sanity-check the lower bound.
        assert!(config.parallelism >= 1);
    }

    #[test]
    fn test_rocksdb_open() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let config = RocksDbConfig::default();

        let db_result = RocksDb::open(tmp_dir.path(), &config);
        assert!(db_result.is_ok());

        let db = db_result.unwrap();

        // A representative subset of the declared column families.
        let cf_names = vec![
            cf::SEQ_CF.to_string(),
            cf::USER_TOKENS_CF.to_string(),
            cf::TOKENS_OPTS_CF.to_string(),
            cf::SERVICES_USER_KEYS_CF.to_string(),
            cf::TLS_ASSETS_CF.to_string(),
            cf::TLS_CERT_METADATA_CF.to_string(),
            cf::TLS_ISSUANCE_LOG_CF.to_string(),
        ];

        // Every opened column family should be reachable by name.
        for name in cf_names {
            let cf_handle = db.cf_handle(&name);
            assert!(cf_handle.is_some());
        }
    }

    #[test]
    fn test_rocksdb_compression_types() {
        assert!(compression_type_from_str("none").is_ok());
        assert!(compression_type_from_str("lz4").is_ok());
        assert!(compression_type_from_str("snappy").is_ok());
        assert!(compression_type_from_str("zlib").is_ok());
        assert!(compression_type_from_str("zstd").is_ok());
        assert!(compression_type_from_str("bz2").is_ok());
        assert!(compression_type_from_str("lz4hc").is_ok());

        // Parsing is case-insensitive.
        assert!(compression_type_from_str("LZ4").is_ok());
        assert!(compression_type_from_str("Snappy").is_ok());

        // Unknown names are rejected with a descriptive error.
        let invalid = compression_type_from_str("invalid_compression");
        assert!(invalid.is_err());
        assert!(format!("{}", invalid.unwrap_err()).contains("Invalid"));
    }

    #[test]
    fn test_rocksdb_compaction_styles() {
        assert!(compaction_style_from_str("level").is_ok());
        assert!(compaction_style_from_str("universal").is_ok());
        assert!(compaction_style_from_str("fifo").is_ok());

        // Parsing is case-insensitive.
        assert!(compaction_style_from_str("LEVEL").is_ok());
        assert!(compaction_style_from_str("Universal").is_ok());

        // Unknown names are rejected with a descriptive error.
        let invalid = compaction_style_from_str("invalid_compaction");
        assert!(invalid.is_err());
        assert!(format!("{}", invalid.unwrap_err()).contains("Invalid"));
    }

    #[test]
    fn test_adder_merge_operator() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let config = RocksDbConfig::default();
        let db = RocksDb::open(tmp_dir.path(), &config).unwrap();

        let seq_cf = db.cf_handle(cf::SEQ_CF).unwrap();

        // Two merges of 1 and 2 should accumulate to 3 via the "add" merge operator.
        let result = db.merge_cf(&seq_cf, b"test_counter", 1u64.to_be_bytes());
        assert!(result.is_ok());

        let result = db.merge_cf(&seq_cf, b"test_counter", 2u64.to_be_bytes());
        assert!(result.is_ok());

        let value = db.get_cf(&seq_cf, b"test_counter").unwrap();
        assert!(value.is_some());

        // The stored value is a big-endian u64.
        let bytes = value.unwrap();
        let mut value_bytes = [0u8; 8];
        value_bytes.copy_from_slice(&bytes);
        let value = u64::from_be_bytes(value_bytes);
        assert_eq!(value, 3);
    }
}