1use crate::{Error, Result, Storable, serialization::helpers, traits::KeyType};
7use parking_lot::{Mutex, RwLock};
8use rocksdb::{BoundColumnFamily, WriteBatch as RocksWriteBatch};
9use std::collections::HashMap;
10use std::marker::PhantomData;
11use std::path::Path;
12use std::sync::Arc;
13use tracing::{debug, error, info, instrument, warn};
14
/// Cheaply-cloneable handle to an open RocksDB instance.
///
/// Cloning only bumps the inner `Arc`, so every clone shares the same
/// underlying database and shutdown flag.
#[derive(Clone)]
pub struct Database {
    // Shared state; see `DatabaseInner`.
    pub(crate) inner: Arc<DatabaseInner>,
}
60
/// State shared between all clones of a [`Database`].
pub(crate) struct DatabaseInner {
    // The raw RocksDB handle; also cloned into Collections/Transactions.
    pub(crate) db: Arc<rocksdb::DB>,
    // `true` once `Database::shutdown` has completed; public operations
    // take the read side and fail fast when set.
    shutdown: Arc<RwLock<bool>>,
}
66
67impl Database {
68 pub(crate) fn new(db: rocksdb::DB) -> Self {
70 Self {
71 inner: Arc::new(DatabaseInner {
72 db: Arc::new(db),
73 shutdown: Arc::new(RwLock::new(false)),
74 }),
75 }
76 }
77
78 #[instrument(skip(self))]
108 pub fn collection<T: Storable>(&self, name: &str) -> Result<Collection<T>> {
109 let shutdown_guard = self.inner.shutdown.read();
112
113 if *shutdown_guard {
114 return Err(Error::Database("Database has been shut down".to_string()));
115 }
116
117 self.inner.db.cf_handle(name).ok_or_else(|| {
119 error!("Column family '{}' not found", name);
120 Error::Database(format!(
121 "Column family '{}' does not exist. Ensure it was declared in DatabaseConfig::add_column_family() before opening the database.",
122 name
123 ))
124 })?;
125
126 debug!("Created collection for column family '{}'", name);
127
128 drop(shutdown_guard);
131
132 Ok(Collection::new(
133 Arc::clone(&self.inner.db),
134 name,
135 Arc::clone(&self.inner.shutdown),
136 ))
137 }
138
139 pub fn list_collections(&self) -> Result<Vec<String>> {
147 let _guard = self.check_shutdown()?;
148
149 rocksdb::DB::list_cf(&rocksdb::Options::default(), self.inner.db.path())
150 .map_err(|e| Error::Database(format!("Failed to list collections: {}", e)))
151 }
152
153 #[instrument(skip(self))]
157 pub fn flush(&self) -> Result<()> {
158 info!("Flushing database");
159 self.inner.db.flush().map_err(|e| {
160 error!("Flush failed: {}", e);
161 Error::Database(format!("Flush failed: {}", e))
162 })
163 }
164
165 #[instrument(skip(self))]
169 pub fn compact_all(&self) -> Result<()> {
170 info!("Compacting entire database");
171 self.inner.db.compact_range::<&[u8], &[u8]>(None, None);
172 Ok(())
173 }
174
175 #[instrument(skip(self, backup_path))]
181 pub fn backup<P: AsRef<Path>>(&self, backup_path: P) -> Result<()> {
182 use rocksdb::backup::{BackupEngine, BackupEngineOptions};
183
184 let _guard = self.check_shutdown()?;
186
187 let path = backup_path.as_ref();
188 info!("Creating backup at {:?}", path);
189
190 let backup_opts = BackupEngineOptions::new(path).map_err(|e| {
191 error!("Failed to create backup options: {}", e);
192 Error::Database(format!("Failed to create backup options: {}", e))
193 })?;
194
195 let mut backup_engine =
196 BackupEngine::open(&backup_opts, &rocksdb::Env::new()?).map_err(|e| {
197 error!("Failed to open backup engine: {}", e);
198 Error::Database(format!("Failed to open backup engine: {}", e))
199 })?;
200
201 backup_engine
202 .create_new_backup(&self.inner.db)
203 .map_err(|e| {
204 error!("Failed to create backup: {}", e);
205 Error::Database(format!("Failed to create backup: {}", e))
206 })?;
207
208 info!("Backup created successfully");
209 Ok(())
210 }
211
212 #[inline]
213 fn check_shutdown(&self) -> Result<parking_lot::RwLockReadGuard<'_, bool>> {
214 let guard = self.inner.shutdown.read();
215
216 if *guard {
217 return Err(Error::Database("Database has been shut down".to_string()));
218 }
219
220 Ok(guard)
221 }
222
223 pub fn restore_from_backup<P: AsRef<Path>>(backup_path: P, restore_path: P) -> Result<()> {
230 use rocksdb::backup::{BackupEngine, BackupEngineOptions, RestoreOptions};
231
232 let backup_path = backup_path.as_ref();
233 let restore_path = restore_path.as_ref();
234
235 info!(
236 "Restoring from backup {:?} to {:?}",
237 backup_path, restore_path
238 );
239
240 let backup_opts = BackupEngineOptions::new(backup_path).map_err(|e| {
241 error!("Failed to create backup options: {}", e);
242 Error::Database(format!("Failed to create backup options: {}", e))
243 })?;
244
245 let mut backup_engine =
246 BackupEngine::open(&backup_opts, &rocksdb::Env::new()?).map_err(|e| {
247 error!("Failed to open backup engine: {}", e);
248 Error::Database(format!("Failed to open backup engine: {}", e))
249 })?;
250
251 let restore_opts = RestoreOptions::default();
252 backup_engine
253 .restore_from_latest_backup(restore_path, restore_path, &restore_opts)
254 .map_err(|e| {
255 error!("Failed to restore backup: {}", e);
256 Error::Database(format!("Failed to restore backup: {}", e))
257 })?;
258
259 info!("Backup restored successfully");
260 Ok(())
261 }
262
263 pub fn list_backups<P: AsRef<Path>>(backup_path: P) -> Result<Vec<BackupInfo>> {
265 use rocksdb::backup::{BackupEngine, BackupEngineOptions};
266
267 let path = backup_path.as_ref();
268 let backup_opts = BackupEngineOptions::new(path)
269 .map_err(|e| Error::Database(format!("Failed to create backup options: {}", e)))?;
270
271 let backup_engine = BackupEngine::open(&backup_opts, &rocksdb::Env::new()?)
272 .map_err(|e| Error::Database(format!("Failed to open backup engine: {}", e)))?;
273
274 let infos = backup_engine.get_backup_info();
275 Ok(infos
276 .iter()
277 .map(|info| BackupInfo {
278 backup_id: info.backup_id,
279 timestamp: info.timestamp,
280 size: info.size,
281 })
282 .collect())
283 }
284
285 #[instrument(skip(self))]
312 pub fn transaction(&self) -> Result<Transaction> {
313 let shutdown = self.inner.shutdown.read();
315
316 if *shutdown {
317 return Err(Error::Database("Database has been shut down".to_string()));
318 }
319
320 Ok(Transaction::new(
321 Arc::clone(&self.inner.db),
322 Arc::clone(&self.inner.shutdown),
323 ))
324 }
325
326 #[instrument(skip(self))]
337 pub fn shutdown(&self) -> Result<()> {
338 info!("Shutting down database");
339
340 let mut shutdown_guard = self.inner.shutdown.write();
342
343 let flush_result = self.flush();
346
347 if flush_result.is_ok() {
349 *shutdown_guard = true;
350 info!("Database shutdown complete");
351 } else {
352 error!("Shutdown failed: flush error, database remains operational");
353 }
354
355 flush_result
357 }
358}
359
// SAFETY: `Database` only contains `Arc`s over `rocksdb::DB` and a
// `parking_lot::RwLock<bool>`. NOTE(review): if all fields are already
// Send + Sync these impls are redundant and should be auto-derived —
// confirm they are actually required and remove them if not.
unsafe impl Send for Database {}
// SAFETY: see the `Send` impl above; no interior state is thread-affine.
unsafe impl Sync for Database {}
368
/// Metadata describing one backup, as reported by the RocksDB backup engine.
#[derive(Debug, Clone)]
pub struct BackupInfo {
    // Engine-assigned identifier of the backup.
    pub backup_id: u32,
    // Creation time as reported by the engine (unix seconds, per RocksDB).
    pub timestamp: i64,
    // Total size of the backup in bytes.
    pub size: u64,
}
379
/// Typed view over a single column family, storing values of type `T`.
///
/// Created via `Database::collection`; shares the database's shutdown flag
/// so operations fail once the database is shut down.
#[derive(Debug)]
pub struct Collection<T: Storable> {
    db: Arc<rocksdb::DB>,
    // Name of the column family this collection is bound to.
    cf_name: String,
    // Shared with `Database`; read on every operation.
    shutdown: Arc<RwLock<bool>>,
    // Marker only: no `T` values are stored in the collection itself.
    _phantom: PhantomData<T>,
}
391
392impl<T: Storable> Collection<T> {
393 fn new(db: Arc<rocksdb::DB>, name: &str, shutdown: Arc<RwLock<bool>>) -> Self {
394 Self {
395 db,
396 cf_name: name.to_string(),
397 shutdown,
398 _phantom: PhantomData,
399 }
400 }
401
402 fn cf<'a>(&'a self) -> Result<Arc<BoundColumnFamily<'a>>> {
403 self.db
405 .cf_handle(&self.cf_name)
406 .ok_or_else(|| Error::Database(format!("Column family '{}' not found", self.cf_name)))
407 }
408
409 #[inline]
410 fn check_shutdown(&self) -> Result<parking_lot::RwLockReadGuard<'_, bool>> {
411 let guard = self.shutdown.read();
414
415 if *guard {
416 return Err(Error::Database("Database has been shut down".to_string()));
417 }
418
419 Ok(guard)
420 }
421
422 #[instrument(skip(self, value))]
451 pub fn put(&self, value: &T) -> Result<()> {
452 let _guard = self.check_shutdown()?;
453
454 value.validate()?;
456
457 let key = value.key();
458 let key_bytes = key.to_bytes()?;
459 let value_bytes = helpers::serialize(value)?;
460
461 debug!("Putting value in collection '{}'", self.cf_name);
462
463 let cf = self.cf()?;
464 self.db.put_cf(&cf, key_bytes, value_bytes).map_err(|e| {
465 error!("Failed to put value: {}", e);
466 Error::Database(format!("Failed to put value: {}", e))
467 })?;
468
469 value.on_stored();
470 Ok(())
471 }
472
473 #[instrument(skip(self))]
488 pub fn get(&self, key: &T::Key) -> Result<Option<T>> {
489 let _guard = self.check_shutdown()?;
490
491 let key_bytes = key.to_bytes()?;
492 let cf = self.cf()?;
493
494 match self.db.get_cf(&cf, key_bytes)? {
495 Some(value_bytes) => {
496 let value: T = helpers::deserialize(&value_bytes)?;
497 Ok(Some(value))
498 }
499 None => Ok(None),
500 }
501 }
502
503 #[instrument(skip(self, db))]
553 pub fn get_with_refs(&self, key: &T::Key, db: &crate::Database) -> Result<Option<T>>
554 where
555 T: crate::Referable,
556 {
557 let _guard = self.check_shutdown()?;
558
559 let key_bytes = key.to_bytes()?;
560 let cf = self.cf()?;
561
562 match self.db.get_cf(&cf, key_bytes)? {
563 Some(value_bytes) => {
564 let value: T = helpers::deserialize(&value_bytes)?;
565 value.resolve_all(db)?;
566 Ok(Some(value))
567 }
568 None => Ok(None),
569 }
570 }
571
572 #[instrument(skip(self, keys))]
585 pub fn get_many(&self, keys: &[T::Key]) -> Result<Vec<Option<T>>> {
586 let _guard = self.check_shutdown()?;
587
588 if keys.is_empty() {
589 return Ok(Vec::new());
590 }
591
592 let key_bytes: Result<Vec<Vec<u8>>> = keys.iter().map(|k| k.to_bytes()).collect();
594 let key_bytes = key_bytes?;
595
596 let cf = self.cf()?;
598 let cf_refs: Vec<_> = key_bytes.iter().map(|k| (&cf, k.as_slice())).collect();
599
600 let results = self.db.multi_get_cf(cf_refs);
602
603 let mut output = Vec::with_capacity(keys.len());
605 for result in results {
606 match result {
607 Ok(Some(value_bytes)) => {
608 let value: T = helpers::deserialize(&value_bytes)?;
609 output.push(Some(value));
610 }
611 Ok(None) => output.push(None),
612 Err(e) => {
613 return Err(Error::Database(format!("Multi-get failed: {}", e)));
614 }
615 }
616 }
617
618 Ok(output)
619 }
620
621 #[instrument(skip(self, keys, db))]
636 pub fn get_many_with_refs(
637 &self,
638 keys: &[T::Key],
639 db: &crate::Database,
640 ) -> Result<Vec<Option<T>>>
641 where
642 T: crate::Referable,
643 {
644 let _guard = self.check_shutdown()?;
645
646 if keys.is_empty() {
647 return Ok(Vec::new());
648 }
649
650 let key_bytes: Result<Vec<Vec<u8>>> = keys.iter().map(|k| k.to_bytes()).collect();
652 let key_bytes = key_bytes?;
653
654 let cf = self.cf()?;
656 let cf_refs: Vec<_> = key_bytes.iter().map(|k| (&cf, k.as_slice())).collect();
657
658 let results = self.db.multi_get_cf(cf_refs);
660
661 let mut output = Vec::with_capacity(keys.len());
663 for result in results {
664 match result {
665 Ok(Some(value_bytes)) => {
666 let value: T = helpers::deserialize(&value_bytes)?;
667 value.resolve_all(db)?;
668 output.push(Some(value));
669 }
670 Ok(None) => output.push(None),
671 Err(e) => {
672 return Err(Error::Database(format!("Multi-get failed: {}", e)));
673 }
674 }
675 }
676
677 Ok(output)
678 }
679
680 #[instrument(skip(self))]
690 pub fn delete(&self, key: &T::Key) -> Result<()> {
691 let _guard = self.check_shutdown()?;
692
693 let key_bytes = key.to_bytes()?;
694
695 debug!("Deleting key from collection '{}'", self.cf_name);
696
697 let cf = self.cf()?;
698 self.db.delete_cf(&cf, key_bytes).map_err(|e| {
699 error!("Failed to delete: {}", e);
700 Error::Database(format!("Failed to delete: {}", e))
701 })
702 }
703
704 #[instrument(skip(self))]
706 pub fn exists(&self, key: &T::Key) -> Result<bool> {
707 let _guard = self.check_shutdown()?;
708 Ok(self.get(key)?.is_some())
709 }
710
711 pub fn batch(&self) -> Batch<T> {
716 Batch::new(Arc::clone(&self.db), self.cf_name.clone())
717 }
718
719 pub fn snapshot(&self) -> Snapshot<T> {
723 Snapshot::new(Arc::clone(&self.db), self.cf_name.clone())
724 }
725
726 pub fn iter(&self) -> Result<Iterator<T>> {
731 let _guard = self.check_shutdown()?;
732 Ok(Iterator::new(
733 Arc::clone(&self.db),
734 self.cf_name.clone(),
735 IteratorMode::Start,
736 Arc::clone(&self.shutdown),
737 ))
738 }
739
740 pub fn iter_from(&self, key: &T::Key) -> Result<Iterator<T>> {
742 let _guard = self.check_shutdown()?;
743 let key_bytes = key.to_bytes()?;
744 Ok(Iterator::new(
745 Arc::clone(&self.db),
746 self.cf_name.clone(),
747 IteratorMode::From(key_bytes),
748 Arc::clone(&self.shutdown),
749 ))
750 }
751
752 pub fn estimate_num_keys(&self) -> Result<u64> {
756 let cf = self.cf()?;
757 self.db
758 .property_int_value_cf(&cf, "rocksdb.estimate-num-keys")
759 .map(|v| v.unwrap_or(0))
760 .map_err(|e| Error::Database(format!("Failed to get estimate: {}", e)))
761 }
762
763 #[instrument(skip(self))]
765 pub fn flush(&self) -> Result<()> {
766 info!("Flushing collection '{}'", self.cf_name);
767 let cf = self.cf()?;
768 self.db.flush_cf(&cf).map_err(|e| {
769 error!("Flush failed: {}", e);
770 Error::Database(format!("Flush failed: {}", e))
771 })
772 }
773
774 #[instrument(skip(self, start, end))]
776 pub fn compact_range(&self, start: Option<&T::Key>, end: Option<&T::Key>) -> Result<()> {
777 let start_bytes = start.map(|k| k.to_bytes()).transpose()?;
778 let end_bytes = end.map(|k| k.to_bytes()).transpose()?;
779
780 info!("Compacting range in collection '{}'", self.cf_name);
781
782 let cf = self.cf()?;
783 self.db
784 .compact_range_cf(&cf, start_bytes.as_deref(), end_bytes.as_deref());
785 Ok(())
786 }
787
788 pub fn name(&self) -> &str {
790 &self.cf_name
791 }
792}
793
// SAFETY: `Collection` never stores a `T` value — `_phantom` is only a
// marker — so Send/Sync do not need to depend on `T`. The real fields are
// `Arc`s and a `String`. NOTE(review): confirm `rocksdb::DB` is Send + Sync
// in the pinned crate version; otherwise these impls are unsound.
unsafe impl<T: Storable> Send for Collection<T> {}
// SAFETY: see the `Send` impl above.
unsafe impl<T: Storable> Sync for Collection<T> {}
800
/// Buffered set of writes applied atomically to one column family on
/// `commit`. Nothing touches the database until `commit` is called.
pub struct Batch<T: Storable> {
    db: Arc<rocksdb::DB>,
    // Target column family for every operation in the batch.
    cf_name: String,
    // The underlying RocksDB write batch being accumulated.
    batch: RocksWriteBatch,
    // Marker only: no `T` values are stored.
    _phantom: PhantomData<T>,
}
810
811impl<T: Storable> Batch<T> {
812 fn new(db: Arc<rocksdb::DB>, cf_name: String) -> Self {
813 Self {
814 db,
815 cf_name,
816 batch: RocksWriteBatch::default(),
817 _phantom: PhantomData,
818 }
819 }
820
821 pub fn put(&mut self, value: &T) -> Result<()> {
823 value.validate()?;
824
825 let key = value.key();
826 let key_bytes = key.to_bytes()?;
827 let value_bytes = helpers::serialize(value)?;
828
829 let cf = self.db.cf_handle(&self.cf_name).ok_or_else(|| {
830 Error::Database(format!("Column family '{}' not found", self.cf_name))
831 })?;
832 self.batch.put_cf(&cf, &key_bytes, &value_bytes);
833 Ok(())
834 }
835
836 pub fn delete(&mut self, key: &T::Key) -> Result<()> {
838 let key_bytes = key.to_bytes()?;
839
840 let cf = self.db.cf_handle(&self.cf_name).ok_or_else(|| {
841 Error::Database(format!("Column family '{}' not found", self.cf_name))
842 })?;
843 self.batch.delete_cf(&cf, &key_bytes);
844 Ok(())
845 }
846
847 pub fn clear(&mut self) {
849 self.batch.clear();
850 }
851
852 pub fn len(&self) -> usize {
854 self.batch.len()
855 }
856
857 pub fn is_empty(&self) -> bool {
859 self.batch.is_empty()
860 }
861
862 #[instrument(skip(self))]
864 pub fn commit(self) -> Result<()> {
865 let op_count = self.batch.len();
866 debug!(
867 "Committing batch with {} operations to '{}'",
868 op_count, self.cf_name
869 );
870
871 self.db.write(self.batch).map_err(|e| {
872 error!("Batch commit failed: {}", e);
873 Error::Database(format!("Batch commit failed: {}", e))
874 })
875 }
876}
877
/// Point-in-time read view of one column family.
///
/// Self-referential: `snapshot_ptr` points to a heap-allocated snapshot whose
/// real lifetime is tied to `db` but has been unsafely erased to `'static`
/// (see `Snapshot::new`). The `Arc` on `db` keeps the database alive for as
/// long as this struct exists, and `Drop` reclaims the box.
pub struct Snapshot<T: Storable> {
    // Keeps the DB alive; the snapshot below borrows from it.
    db: Arc<rocksdb::DB>,
    // Raw pointer to the boxed, lifetime-erased snapshot. Never null;
    // exclusively owned by this struct and freed in `Drop`.
    snapshot_ptr: *const rocksdb::SnapshotWithThreadMode<'static, rocksdb::DB>,
    cf_name: String,
    // Marker only: no `T` values are stored.
    _phantom: PhantomData<T>,
}
894
impl<T: Storable> Snapshot<T> {
    /// Takes a snapshot of `db` and stores it behind a lifetime-erased raw
    /// pointer so the snapshot and the DB it borrows can live in one struct.
    fn new(db: Arc<rocksdb::DB>, cf_name: String) -> Self {
        // SAFETY: the snapshot borrows `db`, but `self.db` holds an `Arc`
        // clone of the same database, so the DB outlives the snapshot. The
        // lifetime is erased to 'static via transmute and the allocation is
        // reclaimed in `Drop`. NOTE(review): this relies on the Arc in
        // `self.db` never being dropped before `snapshot_ptr` is freed —
        // confirm no code path moves `db` out of the struct.
        let snapshot_ptr = unsafe {
            let snapshot = db.snapshot();
            let static_snapshot: rocksdb::SnapshotWithThreadMode<'static, rocksdb::DB> =
                std::mem::transmute(snapshot);
            let boxed = Box::new(static_snapshot);
            Box::into_raw(boxed) as *const _
        };

        Self {
            db,
            snapshot_ptr,
            cf_name,
            _phantom: PhantomData,
        }
    }

    /// Borrows the boxed snapshot, re-binding its lifetime to `&self`.
    fn snapshot(&self) -> &rocksdb::SnapshotWithThreadMode<'_, rocksdb::DB> {
        // SAFETY: `snapshot_ptr` was created from `Box::into_raw` in `new`,
        // is only freed in `Drop`, and `self.db` keeps the DB alive, so the
        // pointee is valid for the duration of this borrow.
        unsafe { &*self.snapshot_ptr }
    }

    // Resolves the column family handle (same pattern as `Collection::cf`).
    fn cf<'a>(&'a self) -> Result<Arc<BoundColumnFamily<'a>>> {
        self.db
            .cf_handle(&self.cf_name)
            .ok_or_else(|| Error::Database(format!("Column family '{}' not found", self.cf_name)))
    }

    /// Reads `key` as of the moment the snapshot was taken; later writes to
    /// the live database are not visible here.
    pub fn get(&self, key: &T::Key) -> Result<Option<T>> {
        let key_bytes = key.to_bytes()?;
        let cf = self.cf()?;

        match self.snapshot().get_cf(&cf, key_bytes)? {
            Some(value_bytes) => {
                let value: T = helpers::deserialize(&value_bytes)?;
                Ok(Some(value))
            }
            None => Ok(None),
        }
    }

    /// Whether `key` existed when the snapshot was taken.
    pub fn exists(&self, key: &T::Key) -> Result<bool> {
        Ok(self.get(key)?.is_some())
    }
}
955
impl<T: Storable> Drop for Snapshot<T> {
    fn drop(&mut self) {
        // SAFETY: `snapshot_ptr` came from `Box::into_raw` in `Snapshot::new`
        // and is never freed anywhere else, so reconstructing the Box here
        // (to drop it) happens exactly once.
        unsafe {
            let _ = Box::from_raw(
                self.snapshot_ptr as *mut rocksdb::SnapshotWithThreadMode<'static, rocksdb::DB>,
            );
        }
    }
}
967
/// Where an `Iterator` begins scanning (always scans forward).
enum IteratorMode {
    // From the first key of the column family.
    Start,
    // From the given serialized key (inclusive).
    From(Vec<u8>),
}
973
/// Outcome of `Iterator::for_each`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IterationStatus {
    // Every entry was visited.
    Completed,
    // The callback returned `false` and iteration stopped early.
    StoppedEarly,
}
982
/// Forward scan over a collection's column family.
///
/// NOTE: the name shadows `std::iter::Iterator` inside this module; the std
/// trait must be referenced by full path here.
pub struct Iterator<T: Storable> {
    db: Arc<rocksdb::DB>,
    cf_name: String,
    // Starting position of the scan.
    mode: IteratorMode,
    // Shared shutdown flag; checked before each consuming operation.
    shutdown: Arc<RwLock<bool>>,
    // Marker only: no `T` values are stored.
    _phantom: PhantomData<T>,
}
993
994impl<T: Storable> Iterator<T> {
995 fn new(
996 db: Arc<rocksdb::DB>,
997 cf_name: String,
998 mode: IteratorMode,
999 shutdown: Arc<RwLock<bool>>,
1000 ) -> Self {
1001 Self {
1002 db,
1003 cf_name,
1004 mode,
1005 shutdown,
1006 _phantom: PhantomData,
1007 }
1008 }
1009
1010 fn cf<'a>(&'a self) -> Result<Arc<BoundColumnFamily<'a>>> {
1011 self.db
1012 .cf_handle(&self.cf_name)
1013 .ok_or_else(|| Error::Database(format!("Column family '{}' not found", self.cf_name)))
1014 }
1015
1016 #[inline]
1017 fn check_shutdown(&self) -> Result<parking_lot::RwLockReadGuard<'_, bool>> {
1018 let guard = self.shutdown.read();
1019
1020 if *guard {
1021 return Err(Error::Database("Database has been shut down".to_string()));
1022 }
1023
1024 Ok(guard)
1025 }
1026
1027 pub fn collect_all(&self) -> Result<Vec<T>> {
1033 let _guard = self.check_shutdown()?;
1034
1035 let mut results = Vec::new();
1036 let cf = self.cf()?;
1037 let iter = match &self.mode {
1038 IteratorMode::Start => self.db.iterator_cf(&cf, rocksdb::IteratorMode::Start),
1039 IteratorMode::From(key) => self.db.iterator_cf(
1040 &cf,
1041 rocksdb::IteratorMode::From(key, rocksdb::Direction::Forward),
1042 ),
1043 };
1044
1045 for item in iter {
1046 let (_key, value_bytes) =
1047 item.map_err(|e| Error::IteratorError(format!("Iterator error: {}", e)))?;
1048
1049 let value: T = helpers::deserialize(&value_bytes)?;
1050 results.push(value);
1051 }
1052
1053 Ok(results)
1054 }
1055
1056 pub fn for_each<F>(&self, mut f: F) -> Result<IterationStatus>
1069 where
1070 F: FnMut(T) -> bool,
1071 {
1072 let _guard = self.check_shutdown()?;
1073
1074 let cf = self.cf()?;
1075 let iter = match &self.mode {
1076 IteratorMode::Start => self.db.iterator_cf(&cf, rocksdb::IteratorMode::Start),
1077 IteratorMode::From(key) => self.db.iterator_cf(
1078 &cf,
1079 rocksdb::IteratorMode::From(key, rocksdb::Direction::Forward),
1080 ),
1081 };
1082
1083 for item in iter {
1084 let (_key, value_bytes) =
1085 item.map_err(|e| Error::IteratorError(format!("Iterator error: {}", e)))?;
1086
1087 let value: T = helpers::deserialize(&value_bytes)?;
1088
1089 if !f(value) {
1090 return Ok(IterationStatus::StoppedEarly);
1091 }
1092 }
1093
1094 Ok(IterationStatus::Completed)
1095 }
1096
1097 pub fn count(&self) -> Result<usize> {
1099 let _guard = self.check_shutdown()?;
1100
1101 let cf = self.cf()?;
1102 let iter = match &self.mode {
1103 IteratorMode::Start => self.db.iterator_cf(&cf, rocksdb::IteratorMode::Start),
1104 IteratorMode::From(key) => self.db.iterator_cf(
1105 &cf,
1106 rocksdb::IteratorMode::From(key, rocksdb::Direction::Forward),
1107 ),
1108 };
1109
1110 let mut count = 0;
1111 for item in iter {
1112 item.map_err(|e| Error::IteratorError(format!("Iterator error: {}", e)))?;
1113 count += 1;
1114 }
1115
1116 Ok(count)
1117 }
1118}
1119
/// Buffered multi-collection transaction.
///
/// Writes accumulate in `batch` and are applied atomically on `commit`.
/// `cache` mirrors the batch contents so reads within the transaction see
/// its own uncommitted writes and deletes.
pub struct Transaction {
    db: Arc<rocksdb::DB>,
    // Pending writes/deletes, applied atomically on commit.
    batch: Mutex<RocksWriteBatch>,
    // Read-your-own-writes overlay keyed by (cf_name, key_bytes).
    cache: Mutex<TransactionCache>,
    // Shared with `Database`; checked before starting collections/commit.
    shutdown: Arc<RwLock<bool>>,
}
1174
/// In-memory overlay of a transaction's pending operations.
struct TransactionCache {
    // (cf_name, key_bytes) -> Some(value_bytes) for puts, None for deletes.
    data: HashMap<(String, Vec<u8>), Option<Vec<u8>>>,
    // Number of distinct keys touched (used to enforce the op limit).
    operation_count: usize,
    // Approximate memory held by `data`, in bytes.
    total_bytes: usize,
}
1180
1181impl TransactionCache {
1182 fn new() -> Self {
1183 Self {
1184 data: HashMap::new(),
1185 operation_count: 0,
1186 total_bytes: 0,
1187 }
1188 }
1189
1190 fn insert(&mut self, key: (String, Vec<u8>), value: Option<Vec<u8>>) -> Result<()> {
1191 const MAX_OPERATIONS: usize = 100_000;
1192 const MAX_BYTES: usize = 100 * 1024 * 1024; const HASHMAP_OVERHEAD: usize = 32; if self.operation_count >= MAX_OPERATIONS {
1197 return Err(Error::Database(format!(
1198 "Transaction limit exceeded: maximum {} operations allowed",
1199 MAX_OPERATIONS
1200 )));
1201 }
1202
1203 let entry_size = key.0.len()
1205 + key.1.len()
1206 + value.as_ref().map(|v| v.len()).unwrap_or(0)
1207 + HASHMAP_OVERHEAD;
1208
1209 let size_delta = if let Some(old_value) = self.data.get(&key) {
1211 let old_size = key.0.len()
1212 + key.1.len()
1213 + old_value.as_ref().map(|v| v.len()).unwrap_or(0)
1214 + HASHMAP_OVERHEAD;
1215 entry_size as i64 - old_size as i64
1216 } else {
1217 entry_size as i64
1218 };
1219
1220 let new_total = (self.total_bytes as i64 + size_delta) as usize;
1222 if new_total > MAX_BYTES {
1223 return Err(Error::Database(format!(
1224 "Transaction memory limit exceeded: maximum {}MB allowed",
1225 MAX_BYTES / (1024 * 1024)
1226 )));
1227 }
1228
1229 let is_new_entry = !self.data.contains_key(&key);
1231 if is_new_entry {
1232 self.operation_count += 1;
1233 }
1234 self.total_bytes = new_total;
1235
1236 self.data.insert(key, value);
1237 Ok(())
1238 }
1239
1240 fn get(&self, key: &(String, Vec<u8>)) -> Option<&Option<Vec<u8>>> {
1241 self.data.get(key)
1242 }
1243
1244 fn clear(&mut self) {
1245 self.data.clear();
1246 self.operation_count = 0;
1247 self.total_bytes = 0;
1248 }
1249}
1250
1251impl Transaction {
1252 fn new(db: Arc<rocksdb::DB>, shutdown: Arc<RwLock<bool>>) -> Self {
1253 Self {
1254 db,
1255 batch: Mutex::new(RocksWriteBatch::default()),
1256 cache: Mutex::new(TransactionCache::new()),
1257 shutdown,
1258 }
1259 }
1260
1261 #[inline]
1262 fn check_shutdown(&self) -> Result<parking_lot::RwLockReadGuard<'_, bool>> {
1263 let guard = self.shutdown.read();
1264
1265 if *guard {
1266 return Err(Error::Database("Database has been shut down".to_string()));
1267 }
1268
1269 Ok(guard)
1270 }
1271
1272 #[instrument(skip(self))]
1274 pub fn collection<'txn, T: Storable>(
1275 &'txn self,
1276 name: &str,
1277 ) -> Result<TransactionCollection<'txn, T>> {
1278 let _guard = self.check_shutdown()?;
1279
1280 self.db.cf_handle(name).ok_or_else(|| {
1282 error!("Column family '{}' not found", name);
1283 Error::Database(format!("Column family '{}' not found", name))
1284 })?;
1285
1286 debug!("Created transaction collection for '{}'", name);
1287 Ok(TransactionCollection::new(
1288 Arc::clone(&self.db),
1289 name.to_string(),
1290 &self.batch,
1291 &self.cache,
1292 ))
1293 }
1294
1295 #[instrument(skip(self))]
1299 pub fn commit(self) -> Result<()> {
1300 let guard = self.check_shutdown()?;
1301 drop(guard); let db = self.db;
1304 let batch = self.batch.into_inner();
1305 let op_count = batch.len();
1306
1307 info!("Committing transaction with {} operations", op_count);
1308
1309 db.write(batch).map_err(|e| {
1310 error!("Failed to commit transaction: {}", e);
1311 Error::Database(format!("Failed to commit transaction: {}", e))
1312 })
1313 }
1314
1315 #[instrument(skip(self))]
1319 pub fn rollback(self) -> Result<()> {
1320 let op_count = self.batch.lock().len();
1321 warn!("Rolling back transaction with {} operations", op_count);
1322 Ok(())
1323 }
1324
1325 pub fn clear(&self) -> Result<()> {
1327 self.batch.lock().clear();
1328 self.cache.lock().clear();
1329 Ok(())
1330 }
1331
1332 pub fn len(&self) -> Result<usize> {
1334 Ok(self.batch.lock().len())
1335 }
1336
1337 pub fn is_empty(&self) -> Result<bool> {
1339 Ok(self.batch.lock().is_empty())
1340 }
1341}
1342
// SAFETY: all mutable state (`batch`, `cache`) is behind `parking_lot::Mutex`,
// and the remaining fields are `Arc`s. NOTE(review): this assumes
// `rocksdb::WriteBatch` is safe to move between threads when externally
// synchronized — confirm against the pinned rocksdb crate version.
unsafe impl Send for Transaction {}
// SAFETY: shared access is serialized through the mutexes above.
unsafe impl Sync for Transaction {}
1350
/// Typed view of one column family inside a [`Transaction`].
///
/// Writes go into the transaction's shared batch/cache; reads consult the
/// cache first so the transaction sees its own uncommitted changes.
pub struct TransactionCollection<'txn, T: Storable> {
    db: Arc<rocksdb::DB>,
    cf_name: String,
    // Borrowed from the owning Transaction for its lifetime.
    batch: &'txn Mutex<RocksWriteBatch>,
    cache: &'txn Mutex<TransactionCache>,
    // Marker only: no `T` values are stored.
    _phantom: PhantomData<T>,
}
1362
1363impl<'txn, T: Storable> TransactionCollection<'txn, T> {
1364 fn new(
1365 db: Arc<rocksdb::DB>,
1366 cf_name: String,
1367 batch: &'txn Mutex<RocksWriteBatch>,
1368 cache: &'txn Mutex<TransactionCache>,
1369 ) -> Self {
1370 Self {
1371 db,
1372 cf_name,
1373 batch,
1374 cache,
1375 _phantom: PhantomData,
1376 }
1377 }
1378
1379 fn cf<'a>(&'a self) -> Result<Arc<BoundColumnFamily<'a>>> {
1380 self.db
1381 .cf_handle(&self.cf_name)
1382 .ok_or_else(|| Error::Database(format!("Column family '{}' not found", self.cf_name)))
1383 }
1384
1385 #[instrument(skip(self, value))]
1390 pub fn put(&self, value: &T) -> Result<()> {
1391 value.validate()?;
1393
1394 let key = value.key();
1395 let key_bytes = key.to_bytes()?;
1396 let value_bytes = helpers::serialize(value)?;
1397
1398 debug!("Transaction put in collection '{}'", self.cf_name);
1399
1400 let mut batch = self.batch.lock();
1402 let mut cache = self.cache.lock();
1403
1404 let cf = self.cf()?;
1406 batch.put_cf(&cf, &key_bytes, &value_bytes);
1407
1408 cache.insert((self.cf_name.clone(), key_bytes), Some(value_bytes))?;
1410
1411 value.on_stored();
1412 Ok(())
1413 }
1414
1415 #[instrument(skip(self))]
1420 pub fn get(&self, key: &T::Key) -> Result<Option<T>> {
1421 let key_bytes = key.to_bytes()?;
1422
1423 let cache_key = (self.cf_name.clone(), key_bytes.clone());
1425 let cached_value = self.cache.lock().get(&cache_key).cloned();
1426
1427 if let Some(cached) = cached_value {
1428 debug!("Transaction cache hit for key in '{}'", self.cf_name);
1429 return match cached {
1430 Some(value_bytes) => {
1431 let value: T = helpers::deserialize(&value_bytes)?;
1432 Ok(Some(value))
1433 }
1434 None => Ok(None), };
1436 }
1437
1438 let cf = self.cf()?;
1440 match self.db.get_cf(&cf, key_bytes)? {
1441 Some(value_bytes) => {
1442 let value: T = helpers::deserialize(&value_bytes)?;
1443 Ok(Some(value))
1444 }
1445 None => Ok(None),
1446 }
1447 }
1448
1449 #[instrument(skip(self))]
1453 pub fn delete(&self, key: &T::Key) -> Result<()> {
1454 let key_bytes = key.to_bytes()?;
1455
1456 debug!("Transaction delete in collection '{}'", self.cf_name);
1457
1458 let mut batch = self.batch.lock();
1460 let mut cache = self.cache.lock();
1461
1462 let cf = self.cf()?;
1464 batch.delete_cf(&cf, &key_bytes);
1465
1466 cache.insert((self.cf_name.clone(), key_bytes), None)?;
1468
1469 Ok(())
1470 }
1471
1472 pub fn exists(&self, key: &T::Key) -> Result<bool> {
1474 Ok(self.get(key)?.is_some())
1475 }
1476
1477 #[instrument(skip(self, keys))]
1491 pub fn get_many(&self, keys: &[T::Key]) -> Result<Vec<Option<T>>> {
1492 if keys.is_empty() {
1493 return Ok(Vec::new());
1494 }
1495
1496 let key_bytes: Vec<Vec<u8>> = keys
1498 .iter()
1499 .map(|k| k.to_bytes())
1500 .collect::<Result<Vec<Vec<u8>>>>()?;
1501
1502 let mut results: Vec<Option<T>> = (0..keys.len()).map(|_| None).collect();
1504 let mut uncached_indices = Vec::new();
1505 let mut uncached_keys = Vec::new();
1506
1507 {
1510 let cache = self.cache.lock();
1511
1512 for (i, kb) in key_bytes.iter().enumerate() {
1513 let cache_key = (self.cf_name.clone(), kb.clone());
1514
1515 if let Some(cached) = cache.get(&cache_key) {
1516 results[i] = match cached {
1518 Some(value_bytes) => Some(helpers::deserialize(value_bytes)?),
1519 None => None, };
1521 } else {
1522 uncached_indices.push(i);
1524 uncached_keys.push(kb.clone());
1525 }
1526 }
1527 } if !uncached_keys.is_empty() {
1531 let cf = self.cf()?;
1532 let cf_refs: Vec<_> = uncached_keys.iter().map(|k| (&cf, k.as_slice())).collect();
1533 let db_results = self.db.multi_get_cf(cf_refs);
1534
1535 debug_assert_eq!(
1537 db_results.len(),
1538 uncached_keys.len(),
1539 "RocksDB multi_get violated contract: got {} results but expected {}",
1540 db_results.len(),
1541 uncached_keys.len()
1542 );
1543
1544 for (result_idx, db_result) in db_results.into_iter().enumerate() {
1545 let original_idx = uncached_indices[result_idx];
1546 results[original_idx] = match db_result {
1547 Ok(Some(value_bytes)) => Some(helpers::deserialize(&value_bytes)?),
1548 Ok(None) => None,
1549 Err(e) => return Err(Error::Database(format!("Multi-get failed: {}", e))),
1550 };
1551 }
1552 }
1553
1554 Ok(results)
1555 }
1556}
1557
// SAFETY: `TransactionCollection` holds no `T` values (PhantomData only);
// its mutable state lives behind the Transaction's mutexes and the rest are
// `Arc`/`String`. NOTE(review): soundness piggybacks on the Transaction
// Send/Sync impls above — revisit if those change.
unsafe impl<'txn, T: Storable> Send for TransactionCollection<'txn, T> {}
// SAFETY: see the `Send` impl above.
unsafe impl<'txn, T: Storable> Sync for TransactionCollection<'txn, T> {}
1565
1566#[cfg(test)]
1567mod tests {
1568 use borsh::{BorshDeserialize, BorshSerialize};
1569
1570 use super::*;
1571 use crate::DatabaseConfig;
1572
    /// Minimal storable record used throughout the unit tests.
    #[derive(Debug, Clone, PartialEq, BorshSerialize, BorshDeserialize)]
    struct TestItem {
        id: u64,
        data: String,
    }
1578
    impl Storable for TestItem {
        type Key = u64;
        // The numeric id doubles as the storage key.
        fn key(&self) -> Self::Key {
            self.id
        }
    }
1585
    /// Opens a fresh database in a unique temp directory with a "test"
    /// column family. The atomic counter keeps concurrently-running tests
    /// from colliding on the same path.
    fn create_test_db() -> Database {
        use std::sync::atomic::{AtomicU64, Ordering};
        static COUNTER: AtomicU64 = AtomicU64::new(0);

        let id = COUNTER.fetch_add(1, Ordering::SeqCst);
        let path = std::env::temp_dir().join(format!("ngdb_test_{}", id));
        // Best-effort cleanup of a previous run's leftovers.
        let _ = std::fs::remove_dir_all(&path);

        DatabaseConfig::new(&path)
            .create_if_missing(true)
            .add_column_family("test")
            .open()
            .expect("Failed to create test database")
    }
1600
    // A stored item round-trips through put/get unchanged.
    #[test]
    fn test_collection_put_and_get() {
        let db = create_test_db();
        let collection = db.collection::<TestItem>("test").unwrap();

        let item = TestItem {
            id: 1,
            data: "test".to_string(),
        };

        collection.put(&item).unwrap();
        let retrieved = collection.get(&1).unwrap();

        assert_eq!(Some(item), retrieved);
    }
1616
    // Deleting a stored key makes subsequent gets return None.
    #[test]
    fn test_collection_delete() {
        let db = create_test_db();
        let collection = db.collection::<TestItem>("test").unwrap();

        let item = TestItem {
            id: 1,
            data: "test".to_string(),
        };

        collection.put(&item).unwrap();
        collection.delete(&1).unwrap();

        assert_eq!(None, collection.get(&1).unwrap());
    }
1632
    // All operations staged in a batch become visible after commit.
    #[test]
    fn test_batch() {
        let db = create_test_db();
        let collection = db.collection::<TestItem>("test").unwrap();

        let mut batch = collection.batch();
        for i in 0..10 {
            batch
                .put(&TestItem {
                    id: i,
                    data: format!("item_{}", i),
                })
                .unwrap();
        }
        batch.commit().unwrap();

        for i in 0..10 {
            let item = collection.get(&i).unwrap().unwrap();
            assert_eq!(i, item.id);
        }
    }
1654
    // collect_all visits every stored entry.
    #[test]
    fn test_iterator() {
        let db = create_test_db();
        let collection = db.collection::<TestItem>("test").unwrap();

        for i in 0..5 {
            collection
                .put(&TestItem {
                    id: i,
                    data: format!("item_{}", i),
                })
                .unwrap();
        }

        let items = collection.iter().unwrap().collect_all().unwrap();
        assert_eq!(5, items.len());
    }
1672
    // get_many preserves key order and yields None for missing keys.
    #[test]
    fn test_get_many() {
        let db = create_test_db();
        let collection = db.collection::<TestItem>("test").unwrap();

        for i in 0..10 {
            collection
                .put(&TestItem {
                    id: i,
                    data: format!("item_{}", i),
                })
                .unwrap();
        }

        // 99 was never stored, so its slot must come back as None.
        let keys = vec![1, 3, 5, 99];
        let results = collection.get_many(&keys).unwrap();

        assert_eq!(4, results.len());
        assert!(results[0].is_some());
        assert_eq!(1, results[0].as_ref().unwrap().id);
        assert!(results[1].is_some());
        assert_eq!(3, results[1].as_ref().unwrap().id);
        assert!(results[2].is_some());
        assert_eq!(5, results[2].as_ref().unwrap().id);
        assert!(results[3].is_none());
    }
1701
    // A transactional put is visible inside the transaction before commit,
    // and visible through a regular collection after commit.
    #[test]
    fn test_transaction() {
        let db = create_test_db();
        let txn = db.transaction().unwrap();
        let collection = txn.collection::<TestItem>("test").unwrap();

        collection
            .put(&TestItem {
                id: 1,
                data: "test".to_string(),
            })
            .unwrap();

        // Read-your-own-writes: visible within the transaction.
        assert!(collection.get(&1).unwrap().is_some());

        txn.commit().unwrap();

        // And persisted after commit.
        let regular_collection = db.collection::<TestItem>("test").unwrap();
        assert!(regular_collection.get(&1).unwrap().is_some());
    }
1724
1725 #[test]
1726 fn test_transaction_get_many() {
1727 let db = create_test_db();
1728 let collection = db.collection::<TestItem>("test").unwrap();
1729
1730 collection
1732 .put(&TestItem {
1733 id: 1,
1734 data: "one".to_string(),
1735 })
1736 .unwrap();
1737 collection
1738 .put(&TestItem {
1739 id: 2,
1740 data: "two".to_string(),
1741 })
1742 .unwrap();
1743 collection
1744 .put(&TestItem {
1745 id: 5,
1746 data: "five".to_string(),
1747 })
1748 .unwrap();
1749
1750 let txn = db.transaction().unwrap();
1751 let txn_collection = txn.collection::<TestItem>("test").unwrap();
1752
1753 txn_collection
1755 .put(&TestItem {
1756 id: 3,
1757 data: "three".to_string(),
1758 })
1759 .unwrap();
1760 txn_collection
1761 .put(&TestItem {
1762 id: 4,
1763 data: "four".to_string(),
1764 })
1765 .unwrap();
1766
1767 txn_collection.delete(&5).unwrap();
1769
1770 let keys = vec![1, 2, 3, 4, 5, 6];
1772 let results = txn_collection.get_many(&keys).unwrap();
1773
1774 assert!(results[0].is_some()); assert_eq!(results[0].as_ref().unwrap().data, "one");
1777
1778 assert!(results[1].is_some()); assert_eq!(results[1].as_ref().unwrap().data, "two");
1780
1781 assert!(results[2].is_some()); assert_eq!(results[2].as_ref().unwrap().data, "three");
1783
1784 assert!(results[3].is_some()); assert_eq!(results[3].as_ref().unwrap().data, "four");
1786
1787 assert!(results[4].is_none()); assert!(results[5].is_none()); let committed_results = collection.get_many(&keys).unwrap();
1792 assert!(committed_results[2].is_none()); assert!(committed_results[3].is_none()); assert!(committed_results[4].is_some()); txn.commit().unwrap();
1798
1799 let final_results = collection.get_many(&keys).unwrap();
1800 assert!(final_results[2].is_some()); assert!(final_results[3].is_some()); assert!(final_results[4].is_none()); }
1804
1805 #[test]
1806 fn test_transaction_limits() {
1807 let db = create_test_db();
1808 let txn = db.transaction().unwrap();
1809 let collection = txn.collection::<TestItem>("test").unwrap();
1810
1811 for i in 0..100_001 {
1813 let result = collection.put(&TestItem {
1814 id: i,
1815 data: format!("item_{}", i),
1816 });
1817
1818 if i < 100_000 {
1819 assert!(result.is_ok());
1820 } else {
1821 assert!(result.is_err());
1823 assert!(result.unwrap_err().to_string().contains("limit exceeded"));
1824 break;
1825 }
1826 }
1827 }
1828
1829 #[test]
1830 fn test_shutdown_prevents_operations() {
1831 let db = create_test_db();
1832 let collection = db.collection::<TestItem>("test").unwrap();
1833
1834 let item = TestItem {
1836 id: 1,
1837 data: "test".to_string(),
1838 };
1839 assert!(collection.put(&item).is_ok());
1840
1841 db.shutdown().unwrap();
1843
1844 let item2 = TestItem {
1846 id: 2,
1847 data: "test2".to_string(),
1848 };
1849 let result = collection.put(&item2);
1850 assert!(result.is_err());
1851 assert!(result.unwrap_err().to_string().contains("shut down"));
1852
1853 let result = db.collection::<TestItem>("test");
1855 assert!(result.is_err());
1856 assert!(result.unwrap_err().to_string().contains("shut down"));
1857 }
1858
1859 #[test]
1860 fn test_iterator_checks_shutdown() {
1861 let db = create_test_db();
1862 let collection = db.collection::<TestItem>("test").unwrap();
1863
1864 for i in 0..5 {
1866 collection
1867 .put(&TestItem {
1868 id: i,
1869 data: format!("item_{}", i),
1870 })
1871 .unwrap();
1872 }
1873
1874 let iter = collection.iter().unwrap();
1876
1877 db.shutdown().unwrap();
1879
1880 let result = iter.collect_all();
1882 assert!(result.is_err());
1883 assert!(result.unwrap_err().to_string().contains("shut down"));
1884 }
1885
1886 #[test]
1887 fn test_shutdown_lock_is_released() {
1888 use std::sync::Arc;
1889 use std::thread;
1890 use std::time::Duration;
1891
1892 let db = Arc::new(create_test_db());
1893 let collection = db.collection::<TestItem>("test").unwrap();
1894
1895 collection
1897 .put(&TestItem {
1898 id: 1,
1899 data: "test".to_string(),
1900 })
1901 .unwrap();
1902
1903 let db_clone = Arc::clone(&db);
1905 let shutdown_handle = thread::spawn(move || {
1906 thread::sleep(Duration::from_millis(50));
1907 db_clone.shutdown()
1909 });
1910
1911 thread::sleep(Duration::from_millis(100));
1913
1914 let shutdown_result = shutdown_handle.join().unwrap();
1916
1917 assert!(shutdown_result.is_ok());
1919
1920 let result = collection.put(&TestItem {
1922 id: 2,
1923 data: "test2".to_string(),
1924 });
1925 assert!(result.is_err());
1926 assert!(result.unwrap_err().to_string().contains("shut down"));
1927
1928 let result = db.collection::<TestItem>("another");
1930 assert!(result.is_err());
1931 assert!(result.unwrap_err().to_string().contains("shut down"));
1932 }
1933
1934 #[test]
1935 fn test_list_collections_checks_shutdown() {
1936 let db = create_test_db();
1937
1938 let collections = db.list_collections();
1940 assert!(collections.is_ok());
1941
1942 db.shutdown().unwrap();
1944
1945 let result = db.list_collections();
1947 assert!(result.is_err());
1948 assert!(result.unwrap_err().to_string().contains("shut down"));
1949 }
1950}