1use crossbeam::sync::{ShardedLock, ShardedLockReadGuard, ShardedLockWriteGuard};
10use std::cmp::Ordering;
11use std::collections::btree_map::Range;
12use std::collections::BTreeMap;
13use std::default::Default;
14use std::iter::Fuse;
15use std::mem;
16use std::ops::Bound;
17use std::sync::Arc;
18
19use itertools::Itertools;
20use miette::{bail, Result};
21
22use crate::data::tuple::{check_key_for_validity, Tuple};
23use crate::data::value::ValidityTs;
24use crate::runtime::relation::{decode_tuple_from_kv, extend_tuple_from_v};
25use crate::storage::{Storage, StoreTx};
26use crate::utils::swap_option_result;
27
28pub fn new_cozo_mem() -> Result<crate::Db<MemStorage>> {
32 let ret = crate::Db::new(MemStorage::default())?;
33
34 ret.initialize()?;
35 Ok(ret)
36}
37
38#[derive(Default, Clone)]
40pub struct MemStorage {
41 store: Arc<ShardedLock<BTreeMap<Vec<u8>, Vec<u8>>>>,
42}
43
44impl<'s> Storage<'s> for MemStorage {
45 type Tx = MemTx<'s>;
46
47 fn storage_kind(&self) -> &'static str {
48 "mem"
49 }
50
51 fn transact(&'s self, write: bool) -> Result<Self::Tx> {
52 Ok(if write {
53 let wtr = self.store.write().unwrap();
54 MemTx::Writer(wtr, Default::default())
55 } else {
56 let rdr = self.store.read().unwrap();
57 MemTx::Reader(rdr)
58 })
59 }
60
61 fn range_compact(&'s self, _lower: &[u8], _upper: &[u8]) -> Result<()> {
62 Ok(())
63 }
64
65 fn batch_put<'a>(
66 &'a self,
67 data: Box<dyn Iterator<Item = Result<(Vec<u8>, Vec<u8>)>> + 'a>,
68 ) -> Result<()> {
69 let mut store = self.store.write().unwrap();
70 for pair in data {
71 let (k, v) = pair?;
72 store.insert(k, v);
73 }
74 Ok(())
75 }
76}
77
78pub enum MemTx<'s> {
79 Reader(ShardedLockReadGuard<'s, BTreeMap<Vec<u8>, Vec<u8>>>),
80 Writer(
81 ShardedLockWriteGuard<'s, BTreeMap<Vec<u8>, Vec<u8>>>,
82 BTreeMap<Vec<u8>, Option<Vec<u8>>>,
83 ),
84}
85
86impl<'s> StoreTx<'s> for MemTx<'s> {
87 fn get(&self, key: &[u8], _for_update: bool) -> Result<Option<Vec<u8>>> {
88 Ok(match self {
89 MemTx::Reader(rdr) => rdr.get(key).cloned(),
90 MemTx::Writer(wtr, cache) => match cache.get(key) {
91 Some(r) => r.clone(),
92 None => wtr.get(key).cloned(),
93 },
94 })
95 }
96
97 fn put(&mut self, key: &[u8], val: &[u8]) -> Result<()> {
98 match self {
99 MemTx::Reader(_) => {
100 bail!("write in read transaction")
101 }
102 MemTx::Writer(_, cache) => {
103 cache.insert(key.to_vec(), Some(val.to_vec()));
104 Ok(())
105 }
106 }
107 }
108
109 fn supports_par_put(&self) -> bool {
110 false
111 }
112
113 fn par_put(&self, _key: &[u8], _val: &[u8]) -> Result<()> {
114 panic!()
115 }
116
117 fn del(&mut self, key: &[u8]) -> Result<()> {
118 match self {
119 MemTx::Reader(_) => {
120 bail!("write in read transaction")
121 }
122 MemTx::Writer(_, cache) => {
123 cache.insert(key.to_vec(), None);
124 Ok(())
125 }
126 }
127 }
128
129 fn del_range_from_persisted(&mut self, lower: &[u8], upper: &[u8]) -> Result<()> {
130 match self {
131 MemTx::Reader(_) => {
132 bail!("write in read transaction")
133 }
134 MemTx::Writer(ref mut wtr, _) => {
135 let keys = wtr
136 .range(lower.to_vec()..upper.to_vec())
137 .map(|kv| kv.0.clone())
138 .collect_vec();
139 for k in keys.iter() {
140 wtr.remove(k);
141 }
142 }
143 }
144
145 Ok(())
146 }
147
148 fn exists(&self, key: &[u8], _for_update: bool) -> Result<bool> {
149 Ok(match self {
150 MemTx::Reader(rdr) => rdr.contains_key(key),
151 MemTx::Writer(wtr, cache) => match cache.get(key) {
152 Some(r) => r.is_some(),
153 None => wtr.contains_key(key),
154 },
155 })
156 }
157
158 fn commit(&mut self) -> Result<()> {
159 match self {
160 MemTx::Reader(_) => Ok(()),
161 MemTx::Writer(wtr, cached) => {
162 let mut cache = BTreeMap::default();
163 mem::swap(&mut cache, cached);
164 for (k, mv) in cache {
165 match mv {
166 None => {
167 wtr.remove(&k);
168 }
169 Some(v) => {
170 wtr.insert(k, v);
171 }
172 }
173 }
174 Ok(())
175 }
176 }
177 }
178
179 fn range_scan_tuple<'a>(
180 &'a self,
181 lower: &[u8],
182 upper: &[u8],
183 ) -> Box<dyn Iterator<Item = Result<Tuple>> + 'a>
184 where
185 's: 'a,
186 {
187 match self {
188 MemTx::Reader(rdr) => Box::new(
189 rdr.range(lower.to_vec()..upper.to_vec())
190 .map(|(k, v)| Ok(decode_tuple_from_kv(k, v, None))),
191 ),
192 MemTx::Writer(wtr, cache) => Box::new(CacheIter {
193 change_iter: cache.range(lower.to_vec()..upper.to_vec()).fuse(),
194 db_iter: wtr.range(lower.to_vec()..upper.to_vec()).fuse(),
195 change_cache: None,
196 db_cache: None,
197 }),
198 }
199 }
200
201 fn range_skip_scan_tuple<'a>(
202 &'a self,
203 lower: &[u8],
204 upper: &[u8],
205 valid_at: ValidityTs,
206 ) -> Box<dyn Iterator<Item = Result<Tuple>> + 'a> {
207 match self {
208 MemTx::Reader(stored) => Box::new(
209 SkipIterator {
210 inner: stored,
211 upper: upper.to_vec(),
212 valid_at,
213 next_bound: lower.to_vec(),
214 size_hint: None,
215 }
216 .map(Ok),
217 ),
218 MemTx::Writer(stored, delta) => Box::new(
219 SkipDualIterator {
220 stored,
221 delta,
222 upper: upper.to_vec(),
223 valid_at,
224 next_bound: lower.to_vec(),
225 }
226 .map(Ok),
227 ),
228 }
229 }
230
231 fn range_scan<'a>(
232 &'a self,
233 lower: &[u8],
234 upper: &[u8],
235 ) -> Box<dyn Iterator<Item = Result<(Vec<u8>, Vec<u8>)>> + 'a>
236 where
237 's: 'a,
238 {
239 match self {
240 MemTx::Reader(rdr) => Box::new(
241 rdr.range(lower.to_vec()..upper.to_vec())
242 .map(|(k, v)| Ok((k.clone(), v.clone()))),
243 ),
244 MemTx::Writer(wtr, cache) => Box::new(CacheIterRaw {
245 change_iter: cache.range(lower.to_vec()..upper.to_vec()).fuse(),
246 db_iter: wtr.range(lower.to_vec()..upper.to_vec()).fuse(),
247 change_cache: None,
248 db_cache: None,
249 }),
250 }
251 }
252
253 fn range_count<'a>(&'a self, lower: &[u8], upper: &[u8]) -> Result<usize>
254 where
255 's: 'a,
256 {
257 Ok(match self {
258 MemTx::Reader(rdr) => rdr.range(lower.to_vec()..upper.to_vec()).count(),
259 MemTx::Writer(wtr, cache) => (CacheIterRaw {
260 change_iter: cache.range(lower.to_vec()..upper.to_vec()).fuse(),
261 db_iter: wtr.range(lower.to_vec()..upper.to_vec()).fuse(),
262 change_cache: None,
263 db_cache: None,
264 })
265 .count(),
266 })
267 }
268
269 fn total_scan<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(Vec<u8>, Vec<u8>)>> + 'a>
270 where
271 's: 'a,
272 {
273 match self {
274 MemTx::Reader(rdr) => Box::new(rdr.iter().map(|(k, v)| Ok((k.clone(), v.clone())))),
275 MemTx::Writer(wtr, cache) => Box::new(CacheIterRaw {
276 change_iter: cache.iter().fuse(),
277 db_iter: wtr.iter().fuse(),
278 change_cache: None,
279 db_cache: None,
280 }),
281 }
282 }
283}
284
/// Merges a writer's staged changes (`change_iter`) with the persisted map
/// (`db_iter`) into one ordered stream of owned key/value pairs.
///
/// The merge assumes both inputs yield keys in ascending order.
/// `change_cache` and `db_cache` each hold at most one look-ahead entry.
///
/// The struct itself carries no trait bounds. They are stated on the `impl`
/// blocks that actually need them.
struct CacheIterRaw<'a, C, T> {
    /// Staged changes; a `None` value is a staged delete.
    change_iter: C,
    /// Persisted entries.
    db_iter: T,
    /// Look-ahead slot for `change_iter`.
    change_cache: Option<(&'a Vec<u8>, &'a Option<Vec<u8>>)>,
    /// Look-ahead slot for `db_iter`.
    db_cache: Option<(&'a Vec<u8>, &'a Vec<u8>)>,
}
295
296impl<'a, C, T> CacheIterRaw<'a, C, T>
297where
298 C: Iterator<Item = (&'a Vec<u8>, &'a Option<Vec<u8>>)> + 'a,
299 T: Iterator<Item = (&'a Vec<u8>, &'a Vec<u8>)>,
300{
301 #[inline]
302 fn fill_cache(&mut self) -> Result<()> {
303 if self.change_cache.is_none() {
304 if let Some(kmv) = self.change_iter.next() {
305 self.change_cache = Some(kmv)
306 }
307 }
308
309 if self.db_cache.is_none() {
310 if let Some(kv) = self.db_iter.next() {
311 self.db_cache = Some(kv);
312 }
313 }
314
315 Ok(())
316 }
317
318 #[inline]
319 fn next_inner(&mut self) -> Result<Option<(Vec<u8>, Vec<u8>)>> {
320 loop {
321 self.fill_cache()?;
322 match (&self.change_cache, &self.db_cache) {
323 (None, None) => return Ok(None),
324 (Some(_), None) => {
325 let (k, cv) = self.change_cache.take().unwrap();
326 match cv {
327 None => continue,
328 Some(v) => return Ok(Some((k.clone(), v.clone()))),
329 }
330 }
331 (None, Some(_)) => {
332 let (k, v) = self.db_cache.take().unwrap();
333 return Ok(Some((k.clone(), v.clone())));
334 }
335 (Some((ck, _)), Some((dk, _))) => match ck.cmp(dk) {
336 Ordering::Less => {
337 let (k, sv) = self.change_cache.take().unwrap();
338 match sv {
339 None => continue,
340 Some(v) => return Ok(Some((k.clone(), v.clone()))),
341 }
342 }
343 Ordering::Greater => {
344 let (k, v) = self.db_cache.take().unwrap();
345 return Ok(Some((k.clone(), v.clone())));
346 }
347 Ordering::Equal => {
348 self.db_cache.take();
349 continue;
350 }
351 },
352 }
353 }
354 }
355}
356
357impl<'a, C, T> Iterator for CacheIterRaw<'a, C, T>
358where
359 C: Iterator<Item = (&'a Vec<u8>, &'a Option<Vec<u8>>)> + 'a,
360 T: Iterator<Item = (&'a Vec<u8>, &'a Vec<u8>)>,
361{
362 type Item = Result<(Vec<u8>, Vec<u8>)>;
363
364 #[inline]
365 fn next(&mut self) -> Option<Self::Item> {
366 swap_option_result(self.next_inner())
367 }
368}
369
/// Merged view used by a writer's `range_scan_tuple`.
///
/// It walks two `BTreeMap` ranges, the staged changes and the persisted
/// entries, both in ascending key order. Each surviving entry is decoded into
/// a `Tuple`.
struct CacheIter<'a> {
    /// Staged changes in range; `None` values are staged deletes.
    change_iter: Fuse<Range<'a, Vec<u8>, Option<Vec<u8>>>>,
    /// Persisted entries in range.
    db_iter: Fuse<Range<'a, Vec<u8>, Vec<u8>>>,
    /// Single-entry look-ahead taken from `change_iter`.
    change_cache: Option<(&'a Vec<u8>, &'a Option<Vec<u8>>)>,
    /// Single-entry look-ahead taken from `db_iter`.
    db_cache: Option<(&'a Vec<u8>, &'a Vec<u8>)>,
}
376
377impl CacheIter<'_> {
378 #[inline]
379 fn fill_cache(&mut self) -> Result<()> {
380 if self.change_cache.is_none() {
381 if let Some(kmv) = self.change_iter.next() {
382 self.change_cache = Some(kmv)
383 }
384 }
385
386 if self.db_cache.is_none() {
387 if let Some(kv) = self.db_iter.next() {
388 self.db_cache = Some(kv);
389 }
390 }
391
392 Ok(())
393 }
394
395 #[inline]
396 fn next_inner(&mut self) -> Result<Option<Tuple>> {
397 loop {
398 self.fill_cache()?;
399 match (&self.change_cache, &self.db_cache) {
400 (None, None) => return Ok(None),
401 (Some(_), None) => {
402 let (k, cv) = self.change_cache.take().unwrap();
403 match cv {
404 None => continue,
405 Some(v) => return Ok(Some(decode_tuple_from_kv(k, v, None))),
406 }
407 }
408 (None, Some(_)) => {
409 let (k, v) = self.db_cache.take().unwrap();
410 return Ok(Some(decode_tuple_from_kv(k, v, None)));
411 }
412 (Some((ck, _)), Some((dk, _))) => match ck.cmp(dk) {
413 Ordering::Less => {
414 let (k, sv) = self.change_cache.take().unwrap();
415 match sv {
416 None => continue,
417 Some(v) => return Ok(Some(decode_tuple_from_kv(k, v, None))),
418 }
419 }
420 Ordering::Greater => {
421 let (k, v) = self.db_cache.take().unwrap();
422 return Ok(Some(decode_tuple_from_kv(k, v, None)));
423 }
424 Ordering::Equal => {
425 self.db_cache.take();
426 continue;
427 }
428 },
429 }
430 }
431 }
432}
433
434impl Iterator for CacheIter<'_> {
435 type Item = Result<Tuple>;
436
437 #[inline]
438 fn next(&mut self) -> Option<Self::Item> {
439 swap_option_result(self.next_inner())
440 }
441}
442
443pub(crate) struct SkipIterator<'a> {
445 pub(crate) inner: &'a BTreeMap<Vec<u8>, Vec<u8>>,
446 pub(crate) upper: Vec<u8>,
447 pub(crate) valid_at: ValidityTs,
448 pub(crate) next_bound: Vec<u8>,
449 pub(crate) size_hint: Option<usize>,
450}
451
452impl<'a> Iterator for SkipIterator<'a> {
453 type Item = Tuple;
454
455 fn next(&mut self) -> Option<Self::Item> {
456 loop {
457 let nxt = self
458 .inner
459 .range::<Vec<u8>, (Bound<&Vec<u8>>, Bound<&Vec<u8>>)>((
460 Bound::Included(&self.next_bound),
461 Bound::Excluded(&self.upper),
462 ))
463 .next();
464 match nxt {
465 None => return None,
466 Some((candidate_key, candidate_val)) => {
467 let (ret, nxt_bound) =
468 check_key_for_validity(candidate_key, self.valid_at, self.size_hint);
469 self.next_bound = nxt_bound;
470 if let Some(mut nk) = ret {
471 extend_tuple_from_v(&mut nk, candidate_val);
472 return Some(nk);
473 }
474 }
475 }
476 }
477 }
478}
479
480struct SkipDualIterator<'a> {
481 stored: &'a BTreeMap<Vec<u8>, Vec<u8>>,
482 delta: &'a BTreeMap<Vec<u8>, Option<Vec<u8>>>,
483 upper: Vec<u8>,
484 valid_at: ValidityTs,
485 next_bound: Vec<u8>,
486}
487
488impl<'a> Iterator for SkipDualIterator<'a> {
489 type Item = Tuple;
490
491 fn next(&mut self) -> Option<Self::Item> {
492 loop {
493 let stored_nxt = self
494 .stored
495 .range::<Vec<u8>, (Bound<&Vec<u8>>, Bound<&Vec<u8>>)>((
496 Bound::Included(&self.next_bound),
497 Bound::Excluded(&self.upper),
498 ))
499 .next();
500 let delta_nxt = self
501 .delta
502 .range::<Vec<u8>, (Bound<&Vec<u8>>, Bound<&Vec<u8>>)>((
503 Bound::Included(&self.next_bound),
504 Bound::Excluded(&self.upper),
505 ))
506 .next();
507 let (candidate_key, candidate_val) = match (stored_nxt, delta_nxt) {
508 (None, None) => return None,
509 (None, Some((delta_key, maybe_delta_val))) => match maybe_delta_val {
510 None => {
511 let (_, nxt_seek) = check_key_for_validity(delta_key, self.valid_at, None);
512 self.next_bound = nxt_seek;
513 continue;
514 }
515 Some(delta_val) => (delta_key, delta_val),
516 },
517 (Some((stored_key, stored_val)), None) => (stored_key, stored_val),
518 (Some((stored_key, stored_val)), Some((delta_key, maybe_delta_val))) => {
519 if stored_key < delta_key {
520 (stored_key, stored_val)
521 } else {
522 match maybe_delta_val {
523 None => {
524 let (_, nxt_seek) =
525 check_key_for_validity(delta_key, self.valid_at, None);
526 self.next_bound = nxt_seek;
527 continue;
528 }
529 Some(delta_val) => (delta_key, delta_val),
530 }
531 }
532 }
533 };
534 let (ret, nxt_bound) = check_key_for_validity(candidate_key, self.valid_at, None);
535 self.next_bound = nxt_bound;
536 if let Some(mut nk) = ret {
537 extend_tuple_from_v(&mut nk, candidate_val);
538 return Some(nk);
539 }
540 }
541 }
542}