1use std::sync::Arc;
2
3use async_trait::async_trait;
4use rocksdb::{BoundColumnFamily, DBAccess, DBIteratorWithThreadMode, IteratorMode};
5
6use super::ty::{DBType, TxType};
7use crate::{
8 err::Error,
9 interface::{
10 kv::{Key, Val},
11 KeyValuePair,
12 },
13 model::{DBTransaction, SimpleTransaction},
14 TagBucket, CF,
15};
16
17fn take_with_prefix<T: DBAccess>(
18 iterator: DBIteratorWithThreadMode<T>,
19 prefix: Vec<u8>,
20) -> impl Iterator<Item = Result<(Box<[u8]>, Box<[u8]>), rocksdb::Error>> + '_ {
21 iterator.take_while(move |item| -> bool {
22 if let Ok((ref k, _)) = *item {
23 k.starts_with(&prefix)
24 } else {
25 true
26 }
27 })
28}
29
30fn take_with_suffix<T: DBAccess>(
31 iterator: DBIteratorWithThreadMode<T>,
32 suffix: Vec<u8>,
33) -> impl Iterator<Item = Result<(Box<[u8]>, Box<[u8]>), rocksdb::Error>> + '_ {
34 iterator.take_while(move |item| -> bool {
35 if let Ok((ref k, _)) = *item {
36 k.ends_with(&suffix)
37 } else {
38 true
39 }
40 })
41}
42
43impl DBTransaction<DBType, TxType> {
44 fn get_column_family(&self, cf: CF) -> Result<Arc<BoundColumnFamily>, Error> {
45 if cf.is_none() {
46 return Err(Error::DsColumnFamilyIsNotValid);
47 }
48 let cf_name = String::from_utf8(cf.unwrap()).unwrap();
49 let bounded_cf = self._db.cf_handle(&cf_name);
50
51 match bounded_cf {
52 Some(cf) => Ok(cf),
53 _ => Err(Error::DsNoColumnFamilyFound),
54 }
55 }
56}
57
58#[async_trait(?Send)]
59impl SimpleTransaction for DBTransaction<DBType, TxType> {
60 fn closed(&self) -> bool {
61 self.ok
62 }
63
64 async fn count(&mut self, tags: TagBucket) -> Result<usize, Error> {
65 if self.closed() {
66 return Err(Error::TxFinished);
67 }
68
69 let guarded_tx = self.tx.lock().await;
70 let tx = guarded_tx.as_ref().unwrap();
71 let cf = tags.get_bytes("column_family");
72 let cf = &self.get_column_family(cf).unwrap();
73 Ok(tx.iterator_cf(cf, IteratorMode::Start).count())
74 }
75
76 async fn cancel(&mut self) -> Result<(), Error> {
77 if self.ok {
78 return Err(Error::TxFinished);
79 }
80
81 self.ok = true;
83
84 let mut tx = self.tx.lock().await;
85 match tx.take() {
86 Some(tx) => tx.rollback()?,
87 None => unreachable!(),
88 }
89
90 Ok(())
91 }
92
93 async fn commit(&mut self) -> Result<(), Error> {
94 if self.closed() {
95 return Err(Error::TxFinished);
96 }
97
98 if !self.writable {
100 return Err(Error::TxReadonly);
101 }
102
103 self.ok = true;
105
106 let mut tx = self.tx.lock().await;
107 match tx.take() {
108 Some(tx) => tx.commit()?,
109 None => unreachable!(),
110 }
111
112 Ok(())
113 }
114
115 async fn exi<K>(&self, key: K, tags: TagBucket) -> Result<bool, Error>
116 where
117 K: Into<Key> + Send,
118 {
119 if self.closed() {
120 return Err(Error::TxFinished);
121 }
122
123 let tx = self.tx.lock().await;
124 let cf = tags.get_bytes("column_family");
125 match cf {
126 Some(_) => {
127 let cf = &self.get_column_family(cf).unwrap();
128 let result = tx.as_ref().unwrap().get_cf(cf, key.into()).unwrap().is_some();
129 Ok(result)
130 }
131 None => {
132 let result = tx.as_ref().unwrap().get(key.into()).unwrap().is_some();
133 Ok(result)
134 }
135 }
136 }
137 async fn get<K>(&self, key: K, tags: TagBucket) -> Result<Option<Val>, Error>
139 where
140 K: Into<Key> + Send,
141 {
142 if self.closed() {
143 return Err(Error::TxFinished);
144 }
145
146 let guarded_tx = self.tx.lock().await;
147 let tx = guarded_tx.as_ref().unwrap();
148 let cf = tags.get_bytes("column_family");
149 Ok(match cf {
150 Some(_) => {
151 let cf = &self.get_column_family(cf).unwrap();
152 tx.get_cf(cf, key.into()).unwrap()
153 }
154 None => tx.get(key.into()).unwrap(),
155 })
156 }
157
158 async fn set<K, V>(&mut self, key: K, val: V, tags: TagBucket) -> Result<(), Error>
160 where
161 K: Into<Key> + Send,
162 V: Into<Key> + Send,
163 {
164 if self.closed() {
165 return Err(Error::TxFinished);
166 }
167
168 if !self.writable {
170 return Err(Error::TxReadonly);
171 }
172
173 let guarded_tx = self.tx.lock().await;
175 let tx = guarded_tx.as_ref().unwrap();
176 let cf = tags.get_bytes("column_family");
177 match cf {
178 Some(_) => {
179 let cf = &self.get_column_family(cf).unwrap();
180 tx.put_cf(cf, key.into(), val.into())?;
181 }
182 None => tx.put(key.into(), val.into())?,
183 };
184 Ok(())
185 }
186
187 async fn put<K, V>(&mut self, key: K, val: V, tags: TagBucket) -> Result<(), Error>
189 where
190 K: Into<Key> + Send,
191 V: Into<Key> + Send,
192 {
193 if self.closed() {
194 return Err(Error::TxFinished);
195 }
196
197 if !self.writable {
199 return Err(Error::TxReadonly);
200 }
201
202 let guarded_tx = self.tx.lock().await;
204 let tx = guarded_tx.as_ref().unwrap();
205 let (key, val) = (key.into(), val.into());
206 let cf = tags.get_bytes("column_family");
207 match cf {
208 Some(_) => {
209 let cf = &self.get_column_family(cf).unwrap();
210 match tx.get_cf(cf, &key)? {
211 None => tx.put_cf(cf, key, val)?,
212 _ => return Err(Error::TxConditionNotMet),
213 };
214 }
215 None => {
216 match tx.get(&key)? {
217 None => tx.put(key, val)?,
218 _ => return Err(Error::TxConditionNotMet),
219 };
220 }
221 };
222
223 Ok(())
224 }
225
226 async fn del<K>(&mut self, key: K, tags: TagBucket) -> Result<(), Error>
228 where
229 K: Into<Key> + Send,
230 {
231 if self.closed() {
232 return Err(Error::TxFinished);
233 }
234
235 if !self.writable {
237 return Err(Error::TxReadonly);
238 }
239
240 let key = key.into();
241 let guarded_tx = self.tx.lock().await;
242 let tx = guarded_tx.as_ref().unwrap();
243 let cf = tags.get_bytes("column_family");
244 let cf = &self.get_column_family(cf).unwrap();
245 match tx.get_cf(cf, &key)? {
246 Some(_v) => tx.delete_cf(cf, key)?,
247 None => return Err(Error::TxnKeyNotFound),
248 };
249
250 Ok(())
251 }
252
253 async fn iterate(&self, tags: TagBucket) -> Result<Vec<Result<KeyValuePair, Error>>, Error> {
255 if self.closed() {
256 return Err(Error::TxFinished);
257 }
258
259 let guarded_tx = self.tx.lock().await;
260 let tx = guarded_tx.as_ref().unwrap();
261
262 let cf = tags.get_bytes("column_family");
263 let get_iterator = match cf {
264 Some(_) => {
265 let get_cf = self.get_column_family(cf);
266 match get_cf {
267 Ok(cf) => Ok(tx.iterator_cf(&cf, IteratorMode::Start)),
268 Err(err) => Err(err),
269 }
270 }
271 None => Ok(tx.iterator(IteratorMode::Start)),
272 };
273
274 match get_iterator {
275 Ok(iterator) => Ok(iterator
276 .map(|pair| {
277 let (k, v) = pair.unwrap();
278 Ok((k.to_vec(), v.to_vec()))
279 })
280 .collect()),
281 Err(err) => Err(err),
282 }
283 }
284
285 async fn suffix_iterate<S>(
286 &self,
287 suffix: S,
288 tags: TagBucket,
289 ) -> Result<Vec<Result<KeyValuePair, Error>>, Error>
290 where
291 S: Into<Key> + Send,
292 {
293 if self.closed() {
294 return Err(Error::TxFinished);
295 }
296
297 let guarded_tx = self.tx.lock().await;
298 let tx = guarded_tx.as_ref().unwrap();
299 let suffix: Key = suffix.into();
300 let cf = tags.get_bytes("column_family");
301 let iterator = match cf {
302 Some(_) => {
303 let cf = &self.get_column_family(cf).unwrap();
304 tx.iterator_cf(cf, IteratorMode::Start)
305 }
306 None => tx.iterator(IteratorMode::Start),
307 };
308 let taken_iterator = take_with_suffix(iterator, suffix);
309
310 Ok(taken_iterator
311 .map(|pair| {
312 let (k, v) = pair.unwrap();
313 Ok((k.to_vec(), v.to_vec()))
314 })
315 .collect())
316 }
317
318 async fn prefix_iterate<P>(
320 &self,
321 prefix: P,
322 tags: TagBucket,
323 ) -> Result<Vec<Result<KeyValuePair, Error>>, Error>
324 where
325 P: Into<Key> + Send,
326 {
327 if self.closed() {
328 return Err(Error::TxFinished);
329 }
330
331 let guarded_tx = self.tx.lock().await;
332 let tx = guarded_tx.as_ref().unwrap();
333 let prefix: Key = prefix.into();
334 let cf = tags.get_bytes("column_family");
335 let iterator = match cf {
336 Some(_) => {
337 let cf = &self.get_column_family(cf).unwrap();
338 tx.iterator_cf(cf, IteratorMode::Start)
339 }
340 None => tx.iterator(IteratorMode::Start),
341 };
342 let taken_iterator = take_with_prefix(iterator, prefix);
343
344 Ok(taken_iterator
345 .map(|v| {
346 let (k, v) = v.unwrap();
347 Ok((k.to_vec(), v.to_vec()))
348 })
349 .collect())
350 }
351}