blueprint_auth/
db.rs

1use std::num::NonZeroUsize;
2use std::path::Path;
3use std::sync::Arc;
4
5use rocksdb::ColumnFamilyDescriptor;
6use serde::{Deserialize, Serialize};
7
8type MultiThreadedRocksDb = rocksdb::OptimisticTransactionDB<rocksdb::MultiThreaded>;
9
10/// RocksDB instance
11///
12/// This is cheap to clone, as it uses an [`Arc`] internally.
13#[derive(Debug, Clone)]
14pub struct RocksDb {
15    db: Arc<MultiThreadedRocksDb>,
16}
17
18impl std::ops::Deref for RocksDb {
19    type Target = Arc<MultiThreadedRocksDb>;
20
21    fn deref(&self) -> &Self::Target {
22        &self.db
23    }
24}
25
26impl RocksDb {
27    pub fn open<P>(path: P, config: &RocksDbConfig) -> Result<Self, crate::Error>
28    where
29        P: AsRef<Path>,
30    {
31        let mut db_opts = rocksdb::Options::default();
32        db_opts.create_if_missing(config.create_if_missing);
33        db_opts.create_missing_column_families(config.create_missing_column_families);
34        db_opts.increase_parallelism(config.parallelism);
35        db_opts.set_write_buffer_size(config.write_buffer_size);
36        db_opts.set_max_open_files(config.max_open_files);
37        db_opts.set_allow_mmap_reads(true);
38        db_opts.set_allow_mmap_writes(true);
39
40        if let Some(max_background_jobs) = config.max_background_jobs {
41            db_opts.set_max_background_jobs(max_background_jobs);
42        }
43        if let Some(compaction_style) = &config.compaction_style {
44            db_opts.set_compaction_style(compaction_style_from_str(compaction_style)?);
45        }
46        if let Some(compression_type) = &config.compression_type {
47            db_opts.set_compression_type(compression_type_from_str(compression_type)?);
48        }
49        if config.enable_statistics {
50            db_opts.enable_statistics();
51        };
52
53        // Set the merge operator for the sequence column family
54        let mut seq_cf_opts = db_opts.clone();
55        seq_cf_opts.set_merge_operator_associative("add", adder_merge_operator);
56
57        let db = MultiThreadedRocksDb::open_cf_descriptors(
58            &db_opts,
59            path,
60            [
61                ColumnFamilyDescriptor::new(cf::SEQ_CF, seq_cf_opts),
62                ColumnFamilyDescriptor::new(cf::USER_TOKENS_CF, db_opts.clone()),
63                ColumnFamilyDescriptor::new(cf::TOKENS_OPTS_CF, db_opts.clone()),
64                ColumnFamilyDescriptor::new(cf::SERVICES_USER_KEYS_CF, db_opts.clone()),
65                ColumnFamilyDescriptor::new(cf::API_KEYS_CF, db_opts.clone()),
66                ColumnFamilyDescriptor::new(cf::API_KEYS_BY_ID_CF, db_opts.clone()),
67                ColumnFamilyDescriptor::new(cf::SERVICES_OAUTH_POLICY_CF, db_opts.clone()),
68                ColumnFamilyDescriptor::new(cf::OAUTH_JTI_CF, db_opts.clone()),
69                ColumnFamilyDescriptor::new(cf::OAUTH_RL_CF, db_opts.clone()),
70                ColumnFamilyDescriptor::new(cf::TLS_ASSETS_CF, db_opts.clone()),
71                ColumnFamilyDescriptor::new(cf::TLS_CERT_METADATA_CF, db_opts.clone()),
72                ColumnFamilyDescriptor::new(cf::TLS_ISSUANCE_LOG_CF, db_opts.clone()),
73            ],
74        )?;
75        Ok(Self { db: Arc::new(db) })
76    }
77}
78
/// Names of the column families registered by [`RocksDb::open`].
pub mod cf {
    /// Monotonic sequence counters, updated through the `"add"` merge operator.
    pub const SEQ_CF: &str = "seq";
    /// Tokens issued to each user.
    pub const USER_TOKENS_CF: &str = "usr_tkns";
    /// Per-token options such as the expiration time.
    pub const TOKENS_OPTS_CF: &str = "tkns_opts";
    /// Services together with their users' keys.
    pub const SERVICES_USER_KEYS_CF: &str = "svs_usr_keys";
    /// Long-lived API keys, keyed by `key_id`.
    pub const API_KEYS_CF: &str = "api_keys";
    /// Secondary index of API keys, keyed by database ID.
    pub const API_KEYS_BY_ID_CF: &str = "api_keys_by_id";
    /// Per-service OAuth policy configuration.
    pub const SERVICES_OAUTH_POLICY_CF: &str = "services_oauth_policy";
    /// OAuth assertion replay cache mapping `jti` to its expiry.
    pub const OAUTH_JTI_CF: &str = "oauth_jti";
    /// Rate-limit buckets for the OAuth token endpoint.
    pub const OAUTH_RL_CF: &str = "oauth_rl";
    /// Encrypted TLS material: certificates, keys and CA bundles.
    pub const TLS_ASSETS_CF: &str = "tls_assets";
    /// TLS certificate metadata keyed by `(service_id, cert_id)`.
    pub const TLS_CERT_METADATA_CF: &str = "tls_cert_metadata";
    /// Append-only audit log of issued TLS certificates.
    pub const TLS_ISSUANCE_LOG_CF: &str = "tls_issuance_log";
}
106
107/// RocksDbConfig is used to configure RocksDb.
108#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
109#[serde(default)]
110pub struct RocksDbConfig {
111    pub create_if_missing: bool,
112    pub create_missing_column_families: bool,
113    pub parallelism: i32,
114    pub write_buffer_size: usize,
115    pub max_open_files: i32,
116    pub max_background_jobs: Option<i32>,
117    pub compression_type: Option<String>,
118    pub compaction_style: Option<String>,
119    pub enable_statistics: bool,
120}
121
122impl Default for RocksDbConfig {
123    fn default() -> Self {
124        Self {
125            create_if_missing: true,
126            create_missing_column_families: true,
127            parallelism: std::thread::available_parallelism()
128                .unwrap_or(NonZeroUsize::new(1).unwrap())
129                .saturating_mul(NonZeroUsize::new(2).unwrap())
130                .get() as i32,
131            write_buffer_size: 256 * 1024 * 1024,
132            max_open_files: 1024,
133            max_background_jobs: None,
134            compaction_style: None,
135            compression_type: Some("none".into()),
136            enable_statistics: false,
137        }
138    }
139}
140
141/// Converts string to a compaction style RocksDB variant.
142pub(crate) fn compaction_style_from_str(
143    s: &str,
144) -> Result<rocksdb::DBCompactionStyle, crate::Error> {
145    match s.to_lowercase().as_str() {
146        "level" => Ok(rocksdb::DBCompactionStyle::Level),
147        "universal" => Ok(rocksdb::DBCompactionStyle::Universal),
148        "fifo" => Ok(rocksdb::DBCompactionStyle::Fifo),
149        _ => Err(crate::Error::InvalidDBCompactionStyle(s.into())),
150    }
151}
152
153/// Converts string to a compression type RocksDB variant.
154pub(crate) fn compression_type_from_str(
155    s: &str,
156) -> Result<rocksdb::DBCompressionType, crate::Error> {
157    match s.to_lowercase().as_str() {
158        "bz2" => Ok(rocksdb::DBCompressionType::Bz2),
159        "lz4" => Ok(rocksdb::DBCompressionType::Lz4),
160        "lz4hc" => Ok(rocksdb::DBCompressionType::Lz4hc),
161        "snappy" => Ok(rocksdb::DBCompressionType::Snappy),
162        "zlib" => Ok(rocksdb::DBCompressionType::Zlib),
163        "zstd" => Ok(rocksdb::DBCompressionType::Zstd),
164        "none" => Ok(rocksdb::DBCompressionType::None),
165        _ => Err(crate::Error::InvalidDBCompressionType(s.into())),
166    }
167}
168
169#[allow(unused)]
170fn concat_merge_operator(
171    _key: &[u8],
172    existing_value: Option<&[u8]>,
173    operands: &rocksdb::merge_operator::MergeOperands,
174) -> Option<Vec<u8>> {
175    let mut result = existing_value.unwrap_or(&[]).to_vec();
176    for operand in operands {
177        result.extend_from_slice(operand);
178    }
179    Some(result)
180}
181
182/// merge operator that will add all values together
183///
184/// Note that it treats the values as u64 big endian encoded numbers.
185///
186/// This will wrap around if the value will overflow.
187#[allow(unused)]
188fn adder_merge_operator(
189    _key: &[u8],
190    existing_value: Option<&[u8]>,
191    operands: &rocksdb::merge_operator::MergeOperands,
192) -> Option<Vec<u8>> {
193    let current_value = existing_value
194        .and_then(|v| v.try_into().ok())
195        .map(u64::from_be_bytes)
196        .unwrap_or(0);
197    let mut sum = current_value;
198    for operand in operands {
199        let v = operand.try_into().ok().map(u64::from_be_bytes).unwrap_or(0);
200        // No overflow needed, we will wrap around back to 0
201        sum = sum.wrapping_add(v);
202    }
203    let result = sum.to_be_bytes().to_vec();
204    Some(result)
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    /// Every column family that `RocksDb::open` is expected to register.
    const ALL_CFS: [&str; 12] = [
        cf::SEQ_CF,
        cf::USER_TOKENS_CF,
        cf::TOKENS_OPTS_CF,
        cf::SERVICES_USER_KEYS_CF,
        cf::API_KEYS_CF,
        cf::API_KEYS_BY_ID_CF,
        cf::SERVICES_OAUTH_POLICY_CF,
        cf::OAUTH_JTI_CF,
        cf::OAUTH_RL_CF,
        cf::TLS_ASSETS_CF,
        cf::TLS_CERT_METADATA_CF,
        cf::TLS_ISSUANCE_LOG_CF,
    ];

    /// Reads a counter maintained by the `"add"` merge operator back as a `u64`.
    fn read_counter(db: &RocksDb, key: &[u8]) -> u64 {
        let seq_cf = db.cf_handle(cf::SEQ_CF).expect("seq column family");
        let bytes = db
            .get_cf(&seq_cf, key)
            .expect("read counter")
            .expect("counter should exist after a merge");
        let bytes: [u8; 8] = bytes.as_slice().try_into().expect("counter is 8 bytes");
        u64::from_be_bytes(bytes)
    }

    #[test]
    fn test_rocksdb_config_default() {
        let config = RocksDbConfig::default();

        assert!(config.create_if_missing);
        assert!(config.create_missing_column_families);
        assert_eq!(config.write_buffer_size, 256 * 1024 * 1024);
        assert_eq!(config.max_open_files, 1024);
        assert_eq!(config.max_background_jobs, None);
        assert_eq!(config.compaction_style, None);
        assert_eq!(config.compression_type, Some("none".to_string()));
        assert!(!config.enable_statistics);

        // Twice the core count, with at least one core assumed.
        assert!(config.parallelism >= 2);
    }

    #[test]
    fn test_rocksdb_open() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let db = RocksDb::open(tmp_dir.path(), &RocksDbConfig::default())
            .expect("RocksDb::open with the default config should succeed");

        // Every column family must have been registered.
        for name in ALL_CFS {
            assert!(db.cf_handle(name).is_some(), "missing column family {}", name);
        }
    }

    #[test]
    fn test_rocksdb_open_rejects_invalid_options() {
        let tmp_dir = tempfile::tempdir().unwrap();

        let bad_compaction = RocksDbConfig {
            compaction_style: Some("invalid_compaction".into()),
            ..RocksDbConfig::default()
        };
        assert!(RocksDb::open(tmp_dir.path(), &bad_compaction).is_err());

        let bad_compression = RocksDbConfig {
            compression_type: Some("invalid_compression".into()),
            ..RocksDbConfig::default()
        };
        assert!(RocksDb::open(tmp_dir.path(), &bad_compression).is_err());
    }

    #[test]
    fn test_rocksdb_compression_types() {
        // Valid compression types
        for name in ["none", "lz4", "snappy", "zlib", "zstd", "bz2", "lz4hc"] {
            assert!(compression_type_from_str(name).is_ok(), "{} should parse", name);
        }

        // Case insensitivity
        assert!(compression_type_from_str("LZ4").is_ok());
        assert!(compression_type_from_str("Snappy").is_ok());

        // Invalid compression type
        let invalid = compression_type_from_str("invalid_compression");
        assert!(invalid.is_err());
        assert!(invalid.unwrap_err().to_string().contains("Invalid"));
    }

    #[test]
    fn test_rocksdb_compaction_styles() {
        // Valid compaction styles
        for name in ["level", "universal", "fifo"] {
            assert!(compaction_style_from_str(name).is_ok(), "{} should parse", name);
        }

        // Case insensitivity
        assert!(compaction_style_from_str("LEVEL").is_ok());
        assert!(compaction_style_from_str("Universal").is_ok());

        // Invalid compaction style
        let invalid = compaction_style_from_str("invalid_compaction");
        assert!(invalid.is_err());
        assert!(invalid.unwrap_err().to_string().contains("Invalid"));
    }

    #[test]
    fn test_adder_merge_operator() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let db = RocksDb::open(tmp_dir.path(), &RocksDbConfig::default()).unwrap();
        let seq_cf = db.cf_handle(cf::SEQ_CF).unwrap();

        db.merge_cf(&seq_cf, b"test_counter", 1u64.to_be_bytes())
            .expect("first merge");
        db.merge_cf(&seq_cf, b"test_counter", 2u64.to_be_bytes())
            .expect("second merge");

        // 1 + 2
        assert_eq!(read_counter(&db, b"test_counter"), 3);
    }

    #[test]
    fn test_adder_merge_operator_wraps_on_overflow() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let db = RocksDb::open(tmp_dir.path(), &RocksDbConfig::default()).unwrap();
        let seq_cf = db.cf_handle(cf::SEQ_CF).unwrap();

        db.merge_cf(&seq_cf, b"wrapping_counter", u64::MAX.to_be_bytes())
            .expect("first merge");
        db.merge_cf(&seq_cf, b"wrapping_counter", 1u64.to_be_bytes())
            .expect("second merge");

        // u64::MAX + 1 wraps back around to 0.
        assert_eq!(read_counter(&db, b"wrapping_counter"), 0);
    }
}