cozo_ce/storage/
mem.rs

1/*
2 * Copyright 2022, The Cozo Project Authors.
3 *
4 * This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0.
5 * If a copy of the MPL was not distributed with this file,
6 * You can obtain one at https://mozilla.org/MPL/2.0/.
7 */
8
9use crossbeam::sync::{ShardedLock, ShardedLockReadGuard, ShardedLockWriteGuard};
10use std::cmp::Ordering;
11use std::collections::btree_map::Range;
12use std::collections::BTreeMap;
13use std::default::Default;
14use std::iter::Fuse;
15use std::mem;
16use std::ops::Bound;
17use std::sync::Arc;
18
19use itertools::Itertools;
20use miette::{bail, Result};
21
22use crate::data::tuple::{check_key_for_validity, Tuple};
23use crate::data::value::ValidityTs;
24use crate::runtime::relation::{decode_tuple_from_kv, extend_tuple_from_v};
25use crate::storage::{Storage, StoreTx};
26use crate::utils::swap_option_result;
27
28/// Create a database backed by memory.
29/// This is the fastest storage, but non-persistent.
30/// Supports concurrent readers but only a single writer.
31pub fn new_cozo_mem() -> Result<crate::Db<MemStorage>> {
32    let ret = crate::Db::new(MemStorage::default())?;
33
34    ret.initialize()?;
35    Ok(ret)
36}
37
38/// The non-persistent storage
39#[derive(Default, Clone)]
40pub struct MemStorage {
41    store: Arc<ShardedLock<BTreeMap<Vec<u8>, Vec<u8>>>>,
42}
43
44impl<'s> Storage<'s> for MemStorage {
45    type Tx = MemTx<'s>;
46
47    fn storage_kind(&self) -> &'static str {
48        "mem"
49    }
50
51    fn transact(&'s self, write: bool) -> Result<Self::Tx> {
52        Ok(if write {
53            let wtr = self.store.write().unwrap();
54            MemTx::Writer(wtr, Default::default())
55        } else {
56            let rdr = self.store.read().unwrap();
57            MemTx::Reader(rdr)
58        })
59    }
60
61    fn range_compact(&'s self, _lower: &[u8], _upper: &[u8]) -> Result<()> {
62        Ok(())
63    }
64
65    fn batch_put<'a>(
66        &'a self,
67        data: Box<dyn Iterator<Item = Result<(Vec<u8>, Vec<u8>)>> + 'a>,
68    ) -> Result<()> {
69        let mut store = self.store.write().unwrap();
70        for pair in data {
71            let (k, v) = pair?;
72            store.insert(k, v);
73        }
74        Ok(())
75    }
76}
77
78pub enum MemTx<'s> {
79    Reader(ShardedLockReadGuard<'s, BTreeMap<Vec<u8>, Vec<u8>>>),
80    Writer(
81        ShardedLockWriteGuard<'s, BTreeMap<Vec<u8>, Vec<u8>>>,
82        BTreeMap<Vec<u8>, Option<Vec<u8>>>,
83    ),
84}
85
86impl<'s> StoreTx<'s> for MemTx<'s> {
87    fn get(&self, key: &[u8], _for_update: bool) -> Result<Option<Vec<u8>>> {
88        Ok(match self {
89            MemTx::Reader(rdr) => rdr.get(key).cloned(),
90            MemTx::Writer(wtr, cache) => match cache.get(key) {
91                Some(r) => r.clone(),
92                None => wtr.get(key).cloned(),
93            },
94        })
95    }
96
97    fn put(&mut self, key: &[u8], val: &[u8]) -> Result<()> {
98        match self {
99            MemTx::Reader(_) => {
100                bail!("write in read transaction")
101            }
102            MemTx::Writer(_, cache) => {
103                cache.insert(key.to_vec(), Some(val.to_vec()));
104                Ok(())
105            }
106        }
107    }
108
109    fn supports_par_put(&self) -> bool {
110        false
111    }
112
113    fn par_put(&self, _key: &[u8], _val: &[u8]) -> Result<()> {
114        panic!()
115    }
116
117    fn del(&mut self, key: &[u8]) -> Result<()> {
118        match self {
119            MemTx::Reader(_) => {
120                bail!("write in read transaction")
121            }
122            MemTx::Writer(_, cache) => {
123                cache.insert(key.to_vec(), None);
124                Ok(())
125            }
126        }
127    }
128
129    fn del_range_from_persisted(&mut self, lower: &[u8], upper: &[u8]) -> Result<()> {
130        match self {
131            MemTx::Reader(_) => {
132                bail!("write in read transaction")
133            }
134            MemTx::Writer(ref mut wtr, _) => {
135                let keys = wtr
136                    .range(lower.to_vec()..upper.to_vec())
137                    .map(|kv| kv.0.clone())
138                    .collect_vec();
139                for k in keys.iter() {
140                    wtr.remove(k);
141                }
142            }
143        }
144
145        Ok(())
146    }
147
148    fn exists(&self, key: &[u8], _for_update: bool) -> Result<bool> {
149        Ok(match self {
150            MemTx::Reader(rdr) => rdr.contains_key(key),
151            MemTx::Writer(wtr, cache) => match cache.get(key) {
152                Some(r) => r.is_some(),
153                None => wtr.contains_key(key),
154            },
155        })
156    }
157
158    fn commit(&mut self) -> Result<()> {
159        match self {
160            MemTx::Reader(_) => Ok(()),
161            MemTx::Writer(wtr, cached) => {
162                let mut cache = BTreeMap::default();
163                mem::swap(&mut cache, cached);
164                for (k, mv) in cache {
165                    match mv {
166                        None => {
167                            wtr.remove(&k);
168                        }
169                        Some(v) => {
170                            wtr.insert(k, v);
171                        }
172                    }
173                }
174                Ok(())
175            }
176        }
177    }
178
179    fn range_scan_tuple<'a>(
180        &'a self,
181        lower: &[u8],
182        upper: &[u8],
183    ) -> Box<dyn Iterator<Item = Result<Tuple>> + 'a>
184    where
185        's: 'a,
186    {
187        match self {
188            MemTx::Reader(rdr) => Box::new(
189                rdr.range(lower.to_vec()..upper.to_vec())
190                    .map(|(k, v)| Ok(decode_tuple_from_kv(k, v, None))),
191            ),
192            MemTx::Writer(wtr, cache) => Box::new(CacheIter {
193                change_iter: cache.range(lower.to_vec()..upper.to_vec()).fuse(),
194                db_iter: wtr.range(lower.to_vec()..upper.to_vec()).fuse(),
195                change_cache: None,
196                db_cache: None,
197            }),
198        }
199    }
200
201    fn range_skip_scan_tuple<'a>(
202        &'a self,
203        lower: &[u8],
204        upper: &[u8],
205        valid_at: ValidityTs,
206    ) -> Box<dyn Iterator<Item = Result<Tuple>> + 'a> {
207        match self {
208            MemTx::Reader(stored) => Box::new(
209                SkipIterator {
210                    inner: stored,
211                    upper: upper.to_vec(),
212                    valid_at,
213                    next_bound: lower.to_vec(),
214                    size_hint: None,
215                }
216                .map(Ok),
217            ),
218            MemTx::Writer(stored, delta) => Box::new(
219                SkipDualIterator {
220                    stored,
221                    delta,
222                    upper: upper.to_vec(),
223                    valid_at,
224                    next_bound: lower.to_vec(),
225                }
226                .map(Ok),
227            ),
228        }
229    }
230
231    fn range_scan<'a>(
232        &'a self,
233        lower: &[u8],
234        upper: &[u8],
235    ) -> Box<dyn Iterator<Item = Result<(Vec<u8>, Vec<u8>)>> + 'a>
236    where
237        's: 'a,
238    {
239        match self {
240            MemTx::Reader(rdr) => Box::new(
241                rdr.range(lower.to_vec()..upper.to_vec())
242                    .map(|(k, v)| Ok((k.clone(), v.clone()))),
243            ),
244            MemTx::Writer(wtr, cache) => Box::new(CacheIterRaw {
245                change_iter: cache.range(lower.to_vec()..upper.to_vec()).fuse(),
246                db_iter: wtr.range(lower.to_vec()..upper.to_vec()).fuse(),
247                change_cache: None,
248                db_cache: None,
249            }),
250        }
251    }
252
253    fn range_count<'a>(&'a self, lower: &[u8], upper: &[u8]) -> Result<usize>
254    where
255        's: 'a,
256    {
257        Ok(match self {
258            MemTx::Reader(rdr) => rdr.range(lower.to_vec()..upper.to_vec()).count(),
259            MemTx::Writer(wtr, cache) => (CacheIterRaw {
260                change_iter: cache.range(lower.to_vec()..upper.to_vec()).fuse(),
261                db_iter: wtr.range(lower.to_vec()..upper.to_vec()).fuse(),
262                change_cache: None,
263                db_cache: None,
264            })
265            .count(),
266        })
267    }
268
269    fn total_scan<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(Vec<u8>, Vec<u8>)>> + 'a>
270    where
271        's: 'a,
272    {
273        match self {
274            MemTx::Reader(rdr) => Box::new(rdr.iter().map(|(k, v)| Ok((k.clone(), v.clone())))),
275            MemTx::Writer(wtr, cache) => Box::new(CacheIterRaw {
276                change_iter: cache.iter().fuse(),
277                db_iter: wtr.iter().fuse(),
278                change_cache: None,
279                db_cache: None,
280            }),
281        }
282    }
283}
284
285struct CacheIterRaw<'a, C, T>
286where
287    C: Iterator<Item = (&'a Vec<u8>, &'a Option<Vec<u8>>)> + 'a,
288    T: Iterator<Item = (&'a Vec<u8>, &'a Vec<u8>)>,
289{
290    change_iter: C,
291    db_iter: T,
292    change_cache: Option<(&'a Vec<u8>, &'a Option<Vec<u8>>)>,
293    db_cache: Option<(&'a Vec<u8>, &'a Vec<u8>)>,
294}
295
296impl<'a, C, T> CacheIterRaw<'a, C, T>
297where
298    C: Iterator<Item = (&'a Vec<u8>, &'a Option<Vec<u8>>)> + 'a,
299    T: Iterator<Item = (&'a Vec<u8>, &'a Vec<u8>)>,
300{
301    #[inline]
302    fn fill_cache(&mut self) -> Result<()> {
303        if self.change_cache.is_none() {
304            if let Some(kmv) = self.change_iter.next() {
305                self.change_cache = Some(kmv)
306            }
307        }
308
309        if self.db_cache.is_none() {
310            if let Some(kv) = self.db_iter.next() {
311                self.db_cache = Some(kv);
312            }
313        }
314
315        Ok(())
316    }
317
318    #[inline]
319    fn next_inner(&mut self) -> Result<Option<(Vec<u8>, Vec<u8>)>> {
320        loop {
321            self.fill_cache()?;
322            match (&self.change_cache, &self.db_cache) {
323                (None, None) => return Ok(None),
324                (Some(_), None) => {
325                    let (k, cv) = self.change_cache.take().unwrap();
326                    match cv {
327                        None => continue,
328                        Some(v) => return Ok(Some((k.clone(), v.clone()))),
329                    }
330                }
331                (None, Some(_)) => {
332                    let (k, v) = self.db_cache.take().unwrap();
333                    return Ok(Some((k.clone(), v.clone())));
334                }
335                (Some((ck, _)), Some((dk, _))) => match ck.cmp(dk) {
336                    Ordering::Less => {
337                        let (k, sv) = self.change_cache.take().unwrap();
338                        match sv {
339                            None => continue,
340                            Some(v) => return Ok(Some((k.clone(), v.clone()))),
341                        }
342                    }
343                    Ordering::Greater => {
344                        let (k, v) = self.db_cache.take().unwrap();
345                        return Ok(Some((k.clone(), v.clone())));
346                    }
347                    Ordering::Equal => {
348                        self.db_cache.take();
349                        continue;
350                    }
351                },
352            }
353        }
354    }
355}
356
357impl<'a, C, T> Iterator for CacheIterRaw<'a, C, T>
358where
359    C: Iterator<Item = (&'a Vec<u8>, &'a Option<Vec<u8>>)> + 'a,
360    T: Iterator<Item = (&'a Vec<u8>, &'a Vec<u8>)>,
361{
362    type Item = Result<(Vec<u8>, Vec<u8>)>;
363
364    #[inline]
365    fn next(&mut self) -> Option<Self::Item> {
366        swap_option_result(self.next_inner())
367    }
368}
369
370struct CacheIter<'a> {
371    change_iter: Fuse<Range<'a, Vec<u8>, Option<Vec<u8>>>>,
372    db_iter: Fuse<Range<'a, Vec<u8>, Vec<u8>>>,
373    change_cache: Option<(&'a Vec<u8>, &'a Option<Vec<u8>>)>,
374    db_cache: Option<(&'a Vec<u8>, &'a Vec<u8>)>,
375}
376
377impl CacheIter<'_> {
378    #[inline]
379    fn fill_cache(&mut self) -> Result<()> {
380        if self.change_cache.is_none() {
381            if let Some(kmv) = self.change_iter.next() {
382                self.change_cache = Some(kmv)
383            }
384        }
385
386        if self.db_cache.is_none() {
387            if let Some(kv) = self.db_iter.next() {
388                self.db_cache = Some(kv);
389            }
390        }
391
392        Ok(())
393    }
394
395    #[inline]
396    fn next_inner(&mut self) -> Result<Option<Tuple>> {
397        loop {
398            self.fill_cache()?;
399            match (&self.change_cache, &self.db_cache) {
400                (None, None) => return Ok(None),
401                (Some(_), None) => {
402                    let (k, cv) = self.change_cache.take().unwrap();
403                    match cv {
404                        None => continue,
405                        Some(v) => return Ok(Some(decode_tuple_from_kv(k, v, None))),
406                    }
407                }
408                (None, Some(_)) => {
409                    let (k, v) = self.db_cache.take().unwrap();
410                    return Ok(Some(decode_tuple_from_kv(k, v, None)));
411                }
412                (Some((ck, _)), Some((dk, _))) => match ck.cmp(dk) {
413                    Ordering::Less => {
414                        let (k, sv) = self.change_cache.take().unwrap();
415                        match sv {
416                            None => continue,
417                            Some(v) => return Ok(Some(decode_tuple_from_kv(k, v, None))),
418                        }
419                    }
420                    Ordering::Greater => {
421                        let (k, v) = self.db_cache.take().unwrap();
422                        return Ok(Some(decode_tuple_from_kv(k, v, None)));
423                    }
424                    Ordering::Equal => {
425                        self.db_cache.take();
426                        continue;
427                    }
428                },
429            }
430        }
431    }
432}
433
434impl Iterator for CacheIter<'_> {
435    type Item = Result<Tuple>;
436
437    #[inline]
438    fn next(&mut self) -> Option<Self::Item> {
439        swap_option_result(self.next_inner())
440    }
441}
442
443/// Keep an eye on https://github.com/rust-lang/rust/issues/49638
444pub(crate) struct SkipIterator<'a> {
445    pub(crate) inner: &'a BTreeMap<Vec<u8>, Vec<u8>>,
446    pub(crate) upper: Vec<u8>,
447    pub(crate) valid_at: ValidityTs,
448    pub(crate) next_bound: Vec<u8>,
449    pub(crate) size_hint: Option<usize>,
450}
451
452impl<'a> Iterator for SkipIterator<'a> {
453    type Item = Tuple;
454
455    fn next(&mut self) -> Option<Self::Item> {
456        loop {
457            let nxt = self
458                .inner
459                .range::<Vec<u8>, (Bound<&Vec<u8>>, Bound<&Vec<u8>>)>((
460                    Bound::Included(&self.next_bound),
461                    Bound::Excluded(&self.upper),
462                ))
463                .next();
464            match nxt {
465                None => return None,
466                Some((candidate_key, candidate_val)) => {
467                    let (ret, nxt_bound) =
468                        check_key_for_validity(candidate_key, self.valid_at, self.size_hint);
469                    self.next_bound = nxt_bound;
470                    if let Some(mut nk) = ret {
471                        extend_tuple_from_v(&mut nk, candidate_val);
472                        return Some(nk);
473                    }
474                }
475            }
476        }
477    }
478}
479
480struct SkipDualIterator<'a> {
481    stored: &'a BTreeMap<Vec<u8>, Vec<u8>>,
482    delta: &'a BTreeMap<Vec<u8>, Option<Vec<u8>>>,
483    upper: Vec<u8>,
484    valid_at: ValidityTs,
485    next_bound: Vec<u8>,
486}
487
488impl<'a> Iterator for SkipDualIterator<'a> {
489    type Item = Tuple;
490
491    fn next(&mut self) -> Option<Self::Item> {
492        loop {
493            let stored_nxt = self
494                .stored
495                .range::<Vec<u8>, (Bound<&Vec<u8>>, Bound<&Vec<u8>>)>((
496                    Bound::Included(&self.next_bound),
497                    Bound::Excluded(&self.upper),
498                ))
499                .next();
500            let delta_nxt = self
501                .delta
502                .range::<Vec<u8>, (Bound<&Vec<u8>>, Bound<&Vec<u8>>)>((
503                    Bound::Included(&self.next_bound),
504                    Bound::Excluded(&self.upper),
505                ))
506                .next();
507            let (candidate_key, candidate_val) = match (stored_nxt, delta_nxt) {
508                (None, None) => return None,
509                (None, Some((delta_key, maybe_delta_val))) => match maybe_delta_val {
510                    None => {
511                        let (_, nxt_seek) = check_key_for_validity(delta_key, self.valid_at, None);
512                        self.next_bound = nxt_seek;
513                        continue;
514                    }
515                    Some(delta_val) => (delta_key, delta_val),
516                },
517                (Some((stored_key, stored_val)), None) => (stored_key, stored_val),
518                (Some((stored_key, stored_val)), Some((delta_key, maybe_delta_val))) => {
519                    if stored_key < delta_key {
520                        (stored_key, stored_val)
521                    } else {
522                        match maybe_delta_val {
523                            None => {
524                                let (_, nxt_seek) =
525                                    check_key_for_validity(delta_key, self.valid_at, None);
526                                self.next_bound = nxt_seek;
527                                continue;
528                            }
529                            Some(delta_val) => (delta_key, delta_val),
530                        }
531                    }
532                }
533            };
534            let (ret, nxt_bound) = check_key_for_validity(candidate_key, self.valid_at, None);
535            self.next_bound = nxt_bound;
536            if let Some(mut nk) = ret {
537                extend_tuple_from_v(&mut nk, candidate_val);
538                return Some(nk);
539            }
540        }
541    }
542}