1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
use super::Error;
use crate::model::Value;
use async_trait::async_trait;
use sled::transaction::{abort, TransactionError};
use std::path::Path;

impl From<TransactionError<Error>> for Error {
    fn from(e: TransactionError<Error>) -> Self {
        match e {
            TransactionError::Abort(e) => e,
            TransactionError::Storage(e) => Error::Sled(e),
        }
    }
}

/// A versioned key-value store backed by a [`sled::Db`].
///
/// Every method of this type reads and writes through the single `db`
/// handle held here.
pub struct SledDatabase {
    // Handle to the underlying sled database.
    db: sled::Db,
}

impl SledDatabase {
    /// Opens the sled database located at `path`.
    ///
    /// # Errors
    ///
    /// Returns the [`sled::Error`] reported by `sled::open` if the database
    /// cannot be opened.
    pub async fn new<P: AsRef<Path>>(path: P) -> Result<Self, sled::Error> {
        sled::open(path.as_ref()).map(|db| Self { db })
    }

    /// Opens the sled database located at `path` and then clears it via
    /// `sled::Db::clear`.
    ///
    /// # Errors
    ///
    /// Returns the [`sled::Error`] reported by either the open or the clear
    /// step.
    pub async fn new_and_clear<P: AsRef<Path>>(path: P) -> Result<Self, sled::Error> {
        let store = Self::new(path).await?;
        store.db.clear()?;
        Ok(store)
    }
}

#[async_trait]
impl super::Database for SledDatabase {
    async fn put(&self, client_id: &[u8], kvs: &Vec<(String, Value)>) -> Result<(), Error> {
        let client_id_prefix = hex::encode(client_id);
        self.db.transaction(|tx| {
            let mut conflicts = Vec::new();
            for (key_suffix, value) in kvs.iter() {
                let key = format!("{}/{}", client_id_prefix, key_suffix);
                let res_o = tx.get(key).unwrap();
                let (next_version, existing) = if let Some(res) = res_o {
                    let existing: Value = serde_cbor::from_reader(&res[..]).unwrap();
                    (existing.version + 1, Some(existing))
                } else {
                    (0, None)
                };
                if value.version != next_version {
                    conflicts.push((key_suffix.clone(), existing))
                }
            }
            if !conflicts.is_empty() {
                abort(Error::Conflict(conflicts))?;
            }
            for (key_suffix, value) in kvs.iter() {
                let key = format!("{}/{}", client_id_prefix, key_suffix);
                let mut value_vec = Vec::new();
                serde_cbor::to_writer(&mut value_vec, value).unwrap();
                tx.insert(key.as_str(), value_vec).unwrap();
            }
            Ok(())
        })?;
        Ok(())
    }

    /// Get all keys matching a prefix from the database
    async fn get_with_prefix(
        &self,
        client_id: &[u8],
        key_prefix: String,
    ) -> Result<Vec<(String, Value)>, Error> {
        let prefix = format!("{}/{}", hex::encode(client_id), key_prefix);
        let mut res = Vec::new();
        let prefix_bytes = prefix.as_bytes();
        for item in self.db.scan_prefix(prefix_bytes) {
            let (key, value) = item?;
            let value: Value = serde_cbor::from_reader(&value[..]).unwrap();
            let key_s = String::from_utf8(key.to_vec())
                .expect("keys must be utf-8")
                .split_off(client_id.len() * 2 + 1);
            res.push((key_s, value));
        }
        Ok(res)
    }
}