1use std::{convert::Infallible, future::Future};
2
3use async_txn::{error::WtmError, PwmComparableRange};
4use skipdb_core::rev_range::WriteTransactionRevRange;
5
6use super::*;
7
8pub struct OptimisticTransaction<K, V, SP: AsyncSpawner, S = RandomState> {
10 db: OptimisticDb<K, V, SP, S>,
11 pub(super) wtm: AsyncWtm<K, V, HashCm<K, S>, BTreePwm<K, V>, SP>,
12}
13
14impl<K, V, SP, S> OptimisticTransaction<K, V, SP, S>
15where
16 K: Ord + Hash + Eq,
17 S: BuildHasher + Clone,
18 SP: AsyncSpawner,
19{
20 #[inline]
21 pub(super) async fn new(db: OptimisticDb<K, V, SP, S>, cap: Option<usize>) -> Self {
22 let wtm = db
23 .inner
24 .tm
25 .write_with_blocking_cm_and_pwm(
26 (),
27 HashCmOptions::with_capacity(db.inner.hasher.clone(), cap.unwrap_or(8)),
28 )
29 .await
30 .unwrap();
31 Self { db, wtm }
32 }
33}
34
35impl<K, V, SP, S> OptimisticTransaction<K, V, SP, S>
36where
37 K: Ord + Hash + Eq + Send + Sync + 'static,
38 V: Send + Sync + 'static,
39 S: BuildHasher + Send + Sync + 'static,
40 SP: AsyncSpawner,
41{
42 #[inline]
58 pub async fn commit(
59 &mut self,
60 ) -> Result<(), WtmError<Infallible, Infallible, core::convert::Infallible>> {
61 let db = self.db.clone();
62 self
63 .wtm
64 .commit(|ents| async move {
65 db.inner.map.apply(ents);
66 Ok(())
67 })
68 .await
69 }
70}
71
72impl<K, V, SP, S> OptimisticTransaction<K, V, SP, S>
73where
74 K: Ord + Hash + Eq + Send + Sync + 'static,
75 V: Send + Sync + 'static,
76 S: BuildHasher + Send + Sync + 'static,
77 SP: AsyncSpawner,
78{
79 #[inline]
95 pub async fn commit_with_task<Fut, E, R>(
96 &mut self,
97 callback: impl FnOnce(Result<(), E>) -> Fut + Send + 'static,
98 ) -> Result<SP::JoinHandle<R>, WtmError<Infallible, Infallible, E>>
99 where
100 Fut: Future<Output = R> + Send + 'static,
101 E: std::error::Error + Send,
102 R: Send + 'static,
103 {
104 let db = self.db.clone();
105
106 self
107 .wtm
108 .commit_with_task(
109 move |ents| async move {
110 db.inner.map.apply(ents);
111 Ok(())
112 },
113 callback,
114 )
115 .await
116 }
117}
118
119impl<K, V, SP, S> OptimisticTransaction<K, V, SP, S>
120where
121 K: Ord + Hash + Eq,
122 V: 'static,
123 S: BuildHasher,
124 SP: AsyncSpawner,
125{
126 #[inline]
128 pub fn version(&self) -> u64 {
129 self.wtm.version()
130 }
131
132 #[inline]
134 pub fn rollback(&mut self) -> Result<(), TransactionError<Infallible, Infallible>> {
135 self.wtm.rollback_blocking()
136 }
137
138 #[inline]
140 pub fn contains_key<Q>(
141 &mut self,
142 key: &Q,
143 ) -> Result<bool, TransactionError<Infallible, Infallible>>
144 where
145 K: Borrow<Q>,
146 Q: Hash + Eq + Ord + ?Sized,
147 {
148 let version = self.wtm.version();
149 match self
150 .wtm
151 .contains_key_equivalent_cm_comparable_pm_blocking(key)?
152 {
153 Some(true) => Ok(true),
154 Some(false) => Ok(false),
155 None => Ok(self.db.inner.map.contains_key(key, version)),
156 }
157 }
158
159 #[inline]
161 pub fn get<'a, 'b: 'a, Q>(
162 &'a mut self,
163 key: &'b Q,
164 ) -> Result<Option<Ref<'a, K, V>>, TransactionError<Infallible, Infallible>>
165 where
166 K: Borrow<Q>,
167 Q: Hash + Eq + Ord + ?Sized,
168 {
169 let version = self.wtm.version();
170 match self.wtm.get_equivalent_cm_comparable_pm_blocking(key)? {
171 Some(v) => {
172 if v.value().is_some() {
173 Ok(Some(v.into()))
174 } else {
175 Ok(None)
176 }
177 }
178 None => Ok(self.db.inner.map.get(key, version).map(Into::into)),
179 }
180 }
181
182 #[inline]
184 pub fn insert(
185 &mut self,
186 key: K,
187 value: V,
188 ) -> Result<(), TransactionError<Infallible, Infallible>> {
189 self.wtm.insert_blocking(key, value)
190 }
191
192 #[inline]
194 pub fn remove(&mut self, key: K) -> Result<(), TransactionError<Infallible, Infallible>> {
195 self.wtm.remove_blocking(key)
196 }
197
198 #[inline]
200 pub fn iter(
201 &mut self,
202 ) -> Result<TransactionIter<'_, K, V, HashCm<K, S>>, TransactionError<Infallible, Infallible>> {
203 let version = self.wtm.version();
204 let (marker, pm) = self
205 .wtm
206 .blocking_marker_with_pm()
207 .ok_or(TransactionError::Discard)?;
208
209 let committed = self.db.inner.map.iter(version);
210 let pendings = pm.iter();
211
212 Ok(TransactionIter::new(pendings, committed, Some(marker)))
213 }
214
215 #[inline]
217 pub fn iter_rev(
218 &mut self,
219 ) -> Result<
220 WriteTransactionRevIter<'_, K, V, HashCm<K, S>>,
221 TransactionError<Infallible, Infallible>,
222 > {
223 let version = self.wtm.version();
224 let (marker, pm) = self
225 .wtm
226 .blocking_marker_with_pm()
227 .ok_or(TransactionError::Discard)?;
228
229 let committed = self.db.inner.map.iter_rev(version);
230 let pendings = pm.iter().rev();
231
232 Ok(WriteTransactionRevIter::new(
233 pendings,
234 committed,
235 Some(marker),
236 ))
237 }
238
239 #[inline]
241 pub fn range<'a, Q, R>(
242 &'a mut self,
243 range: R,
244 ) -> Result<
245 TransactionRange<'a, Q, R, K, V, HashCm<K, S>>,
246 TransactionError<Infallible, Infallible>,
247 >
248 where
249 K: Borrow<Q>,
250 R: RangeBounds<Q> + 'a,
251 Q: Ord + ?Sized,
252 {
253 let version = self.wtm.version();
254 let (marker, pm) = self
255 .wtm
256 .blocking_marker_with_pm()
257 .ok_or(TransactionError::Discard)?;
258 let start = range.start_bound();
259 let end = range.end_bound();
260 let pendings = pm.range_comparable((start, end));
261 let committed = self.db.inner.map.range(range, version);
262
263 Ok(TransactionRange::new(pendings, committed, Some(marker)))
264 }
265
266 #[inline]
268 pub fn range_rev<'a, Q, R>(
269 &'a mut self,
270 range: R,
271 ) -> Result<
272 WriteTransactionRevRange<'a, Q, R, K, V, HashCm<K, S>>,
273 TransactionError<Infallible, Infallible>,
274 >
275 where
276 K: Borrow<Q>,
277 R: RangeBounds<Q> + 'a,
278 Q: Ord + ?Sized,
279 {
280 let version = self.wtm.version();
281 let (marker, pm) = self
282 .wtm
283 .blocking_marker_with_pm()
284 .ok_or(TransactionError::Discard)?;
285 let start = range.start_bound();
286 let end = range.end_bound();
287 let pendings = pm.range_comparable((start, end));
288 let committed = self.db.inner.map.range_rev(range, version);
289
290 Ok(WriteTransactionRevRange::new(
291 pendings.rev(),
292 committed,
293 Some(marker),
294 ))
295 }
296}