1use crate::{cache::CachePolicy, db::DB, errors::StoreError};
2
3use super::prelude::{Cache, DbKey, DbWriter};
4use parking_lot::{RwLock, RwLockReadGuard};
5use rocksdb::{IterateBounds, IteratorMode, ReadOptions};
6use serde::{de::DeserializeOwned, Serialize};
7use std::{
8 collections::{hash_map::RandomState, HashSet},
9 fmt::Debug,
10 hash::BuildHasher,
11 marker::PhantomData,
12 sync::Arc,
13};
14
/// Read-only handle over a shared `Arc<RwLock<T>>`.
///
/// The wrapped lock is a private field and the only accessor defined on this
/// type is [`ReadLock::read`], so code holding a `ReadLock` can take read
/// guards but cannot obtain a write guard through it.
#[derive(Default, Debug)]
pub struct ReadLock<T>(Arc<RwLock<T>>);
18
19impl<T> ReadLock<T> {
20 pub fn new(rwlock: Arc<RwLock<T>>) -> Self {
21 Self(rwlock)
22 }
23
24 pub fn read(&self) -> RwLockReadGuard<T> {
25 self.0.read()
26 }
27}
28
29impl<T> From<T> for ReadLock<T> {
30 fn from(value: T) -> Self {
31 Self::new(Arc::new(RwLock::new(value)))
32 }
33}
34
/// Set-valued DB access (one `TKey` maps to a set of `TData`) that keeps a
/// cache in front of a [`DbSetAccess`].
///
/// `W` is the hasher type of the cached `HashSet`s. `S` is forwarded as the
/// third type argument of `Cache` (presumably the cache's own hasher; confirm
/// against `Cache`).
#[derive(Clone)]
pub struct CachedDbSetAccess<TKey, TData, S = RandomState, W = RandomState>
where
    TKey: Clone + std::hash::Hash + Eq + Send + Sync,
    TData: Clone + Send + Sync,
    W: Send + Sync,
{
    /// Uncached accessor that performs the actual DB reads and writes.
    inner: DbSetAccess<TKey, TData>,

    /// Cached buckets: each key maps to a shared, lock-protected set of its items.
    cache: Cache<TKey, Arc<RwLock<HashSet<TData, W>>>, S>,
}
49
50impl<TKey, TData, S, W> CachedDbSetAccess<TKey, TData, S, W>
51where
52 TKey: Clone + std::hash::Hash + Eq + Send + Sync + AsRef<[u8]>,
53 TData: Clone + std::hash::Hash + Eq + Send + Sync + DeserializeOwned + Serialize,
54 S: BuildHasher + Default,
55 W: BuildHasher + Default + Send + Sync,
56{
57 pub fn new(db: Arc<DB>, cache_policy: CachePolicy, prefix: Vec<u8>) -> Self {
58 Self { inner: DbSetAccess::new(db, prefix), cache: Cache::new(cache_policy) }
59 }
60
61 pub fn read_from_cache(&self, key: TKey) -> Option<ReadLock<HashSet<TData, W>>> {
62 self.cache.get(&key).map(ReadLock::new)
63 }
64
65 fn read_locked_entry(&self, key: TKey) -> Result<Arc<RwLock<HashSet<TData, W>>>, StoreError> {
67 if let Some(data) = self.cache.get(&key) {
68 Ok(data)
69 } else {
70 let data: HashSet<TData, _> = self.inner.bucket_iterator(key.clone()).collect::<Result<_, _>>()?;
71 let data = Arc::new(RwLock::new(data));
72 self.cache.insert(key, data.clone());
73 Ok(data)
74 }
75 }
76
77 pub fn read(&self, key: TKey) -> Result<ReadLock<HashSet<TData, W>>, StoreError> {
78 Ok(ReadLock::new(self.read_locked_entry(key)?))
79 }
80
81 pub fn write(&self, writer: impl DbWriter, key: TKey, data: TData) -> Result<(), StoreError> {
82 self.cache.update_if_entry_exists(key.clone(), |locked_entry| {
84 locked_entry.write().insert(data.clone());
85 });
86 self.inner.write(writer, key, data)
87 }
88
89 pub fn delete_bucket(&self, writer: impl DbWriter, key: TKey) -> Result<(), StoreError> {
90 self.cache.remove(&key);
91 self.inner.delete_bucket(writer, key)
92 }
93
94 pub fn delete(&self, writer: impl DbWriter, key: TKey, data: TData) -> Result<(), StoreError> {
95 self.cache.update_if_entry_exists(key.clone(), |locked_entry| {
97 locked_entry.write().remove(&data);
98 });
99 self.inner.delete(writer, key, data)?;
100 Ok(())
101 }
102
103 pub fn prefix(&self) -> &[u8] {
104 self.inner.prefix()
105 }
106}
107
/// Uncached, typed accessor for a set-valued store: each `TKey` names a
/// bucket and each `TData` is one member of that bucket.
#[derive(Clone)]
pub struct DbSetAccess<TKey, TData>
where
    TKey: Clone + std::hash::Hash + Eq + Send + Sync,
    TData: Clone + Send + Sync,
{
    /// Handle to the underlying database.
    db: Arc<DB>,

    /// Byte prefix passed to `DbKey::new_with_bucket` for every key this accessor builds.
    prefix: Vec<u8>,

    /// Ties the otherwise unused `TKey`/`TData` type parameters to the struct.
    _phantom: PhantomData<(TKey, TData)>,
}
122
123impl<TKey, TData> DbSetAccess<TKey, TData>
124where
125 TKey: Clone + std::hash::Hash + Eq + Send + Sync + AsRef<[u8]>,
126 TData: Clone + std::hash::Hash + Eq + Send + Sync + DeserializeOwned + Serialize,
127{
128 pub fn new(db: Arc<DB>, prefix: Vec<u8>) -> Self {
129 Self { db, prefix, _phantom: Default::default() }
130 }
131
132 pub fn write(&self, mut writer: impl DbWriter, key: TKey, data: TData) -> Result<(), StoreError> {
133 writer.put(self.get_db_key(&key, &data)?, [])?;
134 Ok(())
135 }
136
137 fn get_db_key(&self, key: &TKey, data: &TData) -> Result<DbKey, StoreError> {
138 let bin_data = bincode::serialize(&data)?;
139 Ok(DbKey::new_with_bucket(&self.prefix, key, bin_data))
140 }
141
142 pub fn delete_bucket(&self, mut writer: impl DbWriter, key: TKey) -> Result<(), StoreError> {
143 let db_key = DbKey::new_with_bucket(&self.prefix, &key, []);
144 let (from, to) = rocksdb::PrefixRange(db_key.as_ref()).into_bounds();
145 writer.delete_range(from.unwrap(), to.unwrap())?;
146 Ok(())
147 }
148
149 pub fn delete(&self, mut writer: impl DbWriter, key: TKey, data: TData) -> Result<(), StoreError> {
150 writer.delete(self.get_db_key(&key, &data)?)?;
151 Ok(())
152 }
153
154 fn seek_iterator(
155 &self,
156 key: TKey,
157 limit: usize, skip_first: bool, ) -> impl Iterator<Item = Result<Box<[u8]>, StoreError>> + '_
160 where
161 TKey: Clone + AsRef<[u8]>,
162 TData: DeserializeOwned,
163 {
164 let db_key = DbKey::new_with_bucket(&self.prefix, &key, []);
165 let mut read_opts = ReadOptions::default();
166 read_opts.set_iterate_range(rocksdb::PrefixRange(db_key.as_ref()));
167
168 let mut db_iterator = self.db.iterator_opt(IteratorMode::Start, read_opts);
169
170 if skip_first {
171 db_iterator.next();
172 }
173
174 db_iterator.take(limit).map(move |item| match item {
175 Ok((key_bytes, _)) => Ok(key_bytes[db_key.prefix_len()..].into()),
176 Err(err) => Err(err.into()),
177 })
178 }
179
180 pub fn prefix(&self) -> &[u8] {
181 &self.prefix
182 }
183
184 pub fn bucket_iterator(&self, key: TKey) -> impl Iterator<Item = Result<TData, StoreError>> + '_
185 where
186 TKey: Clone + AsRef<[u8]>,
187 TData: DeserializeOwned,
188 {
189 self.seek_iterator(key, usize::MAX, false).map(|res| match res {
190 Ok(data) => Ok(bincode::deserialize(&data)?),
191 Err(err) => Err(err),
192 })
193 }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        create_temp_db,
        prelude::{BatchDbWriter, ConnBuilder, DirectDbWriter},
    };
    use kaspa_hashes::Hash;
    use rocksdb::WriteBatch;

    /// Fills 16 buckets with 2 items each, then checks that `delete_bucket`
    /// empties a bucket both through a direct writer and through a batch.
    #[test]
    fn test_delete_bucket() {
        let (_lifetime, db) = create_temp_db!(ConnBuilder::default().with_files_limit(10));
        let access = DbSetAccess::<Hash, u64>::new(db.clone(), vec![1, 2]);

        // Populate: bucket `b` receives the items `b` and `b + 1`.
        for bucket in 0..16 {
            for offset in 0..2 {
                let writer = DirectDbWriter::new(&db);
                access.write(writer, bucket.into(), bucket + offset).unwrap();
            }
        }

        // Every bucket must now hold exactly its two items.
        for bucket in 0..16 {
            let stored = access.bucket_iterator(bucket.into()).count();
            assert_eq!(2, stored);
        }

        // Direct deletion takes effect immediately.
        access.delete_bucket(DirectDbWriter::new(&db), 3.into()).unwrap();
        let after_direct = access.bucket_iterator(3.into()).count();
        assert_eq!(0, after_direct);

        // Batched deletion takes effect once the batch is written.
        let mut batch = WriteBatch::default();
        access.delete_bucket(BatchDbWriter::new(&mut batch), 6.into()).unwrap();
        db.write(batch).unwrap();
        let after_batch = access.bucket_iterator(6.into()).count();
        assert_eq!(0, after_batch);
    }
}