atuin_client/
kv.rs

1use std::collections::BTreeMap;
2
3use atuin_common::record::{DecryptedData, Host, HostId};
4use eyre::{Result, bail, ensure, eyre};
5use serde::Deserialize;
6
7use crate::record::encryption::PASETO_V4;
8use crate::record::store::Store;
9
/// Tag under which every kv record is written to and read from the record store.
const KV_TAG: &str = "kv";
/// Wire-format version stamped on newly written kv records. `KvRecord::deserialize`
/// also accepts the older `"v0"` format.
const KV_VERSION: &str = "v1";
/// Largest value, in bytes, that `KvStore::set` accepts.
const KV_VAL_MAX_LEN: usize = 100 * 1024;
13
/// One key/value entry as carried inside a kv record.
///
/// Keys are grouped by `namespace`. `value` is `None` when the key was set
/// without a value (the tests in this module use that to model a deleted key).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct KvRecord {
    /// Namespace that `key` lives in.
    pub namespace: String,
    /// Key within `namespace`.
    pub key: String,
    /// Stored value, or `None` if the key currently has no value.
    pub value: Option<String>,
}
20
21impl KvRecord {
22    pub fn serialize(&self) -> Result<DecryptedData> {
23        use rmp::encode;
24
25        let mut output = vec![];
26
27        // INFO: ensure this is updated when adding new fields
28        encode::write_array_len(&mut output, 4)?;
29
30        encode::write_str(&mut output, &self.namespace)?;
31        encode::write_str(&mut output, &self.key)?;
32        encode::write_bool(&mut output, self.value.is_some())?;
33
34        if let Some(value) = &self.value {
35            encode::write_str(&mut output, value)?;
36        }
37
38        Ok(DecryptedData(output))
39    }
40
41    pub fn deserialize(data: &DecryptedData, version: &str) -> Result<Self> {
42        use rmp::decode;
43
44        fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report {
45            eyre!("{err:?}")
46        }
47
48        match version {
49            "v0" => {
50                let mut bytes = decode::Bytes::new(&data.0);
51
52                let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?;
53                ensure!(nfields == 3, "too many entries in v0 kv record");
54
55                let bytes = bytes.remaining_slice();
56
57                let (namespace, bytes) =
58                    decode::read_str_from_slice(bytes).map_err(error_report)?;
59                let (key, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?;
60                let (value, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?;
61
62                if !bytes.is_empty() {
63                    bail!("trailing bytes in encoded kvrecord. malformed")
64                }
65
66                Ok(KvRecord {
67                    namespace: namespace.to_owned(),
68                    key: key.to_owned(),
69                    value: Some(value.to_owned()),
70                })
71            }
72            KV_VERSION => {
73                let mut bytes = decode::Bytes::new(&data.0);
74
75                let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?;
76                ensure!(nfields == 4, "too many entries in v1 kv record");
77
78                let bytes = bytes.remaining_slice();
79
80                let (namespace, bytes) =
81                    decode::read_str_from_slice(bytes).map_err(error_report)?;
82                let (key, mut bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?;
83                let has_value = decode::read_bool(&mut bytes).map_err(error_report)?;
84
85                let (value, bytes) = if has_value {
86                    let (value, bytes) =
87                        decode::read_str_from_slice(bytes).map_err(error_report)?;
88                    (Some(value.to_owned()), bytes)
89                } else {
90                    (None, bytes)
91                };
92
93                if !bytes.is_empty() {
94                    bail!("trailing bytes in encoded kvrecord. malformed")
95                }
96
97                Ok(KvRecord {
98                    namespace: namespace.to_owned(),
99                    key: key.to_owned(),
100                    value,
101                })
102            }
103            _ => {
104                bail!("unknown version {version:?}")
105            }
106        }
107    }
108}
109
110#[derive(Debug, Clone, Deserialize)]
111pub struct KvStore;
112
113impl Default for KvStore {
114    fn default() -> Self {
115        Self::new()
116    }
117}
118
119impl KvStore {
120    // will want to init the actual kv store when that is done
121    pub fn new() -> KvStore {
122        KvStore {}
123    }
124
125    pub async fn set(
126        &self,
127        store: &(impl Store + Send + Sync),
128        encryption_key: &[u8; 32],
129        host_id: HostId,
130        namespace: &str,
131        key: &str,
132        value: Option<&str>,
133    ) -> Result<()> {
134        if value.is_some() && value.unwrap().len() > KV_VAL_MAX_LEN {
135            return Err(eyre!(
136                "kv value too large: max len {} bytes",
137                KV_VAL_MAX_LEN
138            ));
139        }
140
141        let record = KvRecord {
142            namespace: namespace.to_string(),
143            key: key.to_string(),
144            value: value.map(|v| v.to_string()),
145        };
146
147        let bytes = record.serialize()?;
148
149        let idx = store
150            .last(host_id, KV_TAG)
151            .await?
152            .map_or(0, |entry| entry.idx + 1);
153
154        let record = atuin_common::record::Record::builder()
155            .host(Host::new(host_id))
156            .version(KV_VERSION.to_string())
157            .tag(KV_TAG.to_string())
158            .idx(idx)
159            .data(bytes)
160            .build();
161
162        store
163            .push(&record.encrypt::<PASETO_V4>(encryption_key))
164            .await?;
165
166        Ok(())
167    }
168
169    // TODO: setup an actual kv store, rebuild func, and do not pass the main store in here as
170    // well.
171    pub async fn get(
172        &self,
173        store: &impl Store,
174        encryption_key: &[u8; 32],
175        namespace: &str,
176        key: &str,
177    ) -> Result<Option<KvRecord>> {
178        // TODO: don't rebuild every time...
179        let map = self.build_kv(store, encryption_key).await?;
180
181        let res = map.get(namespace);
182
183        if let Some(ns) = res {
184            let value = ns.get(key);
185
186            Ok(value.cloned())
187        } else {
188            Ok(None)
189        }
190    }
191
192    // Build a kv map out of the linked list kv store
193    // Map is Namespace -> Key -> Value
194    // TODO(ellie): "cache" this into a real kv structure, which we can
195    // use as a write-through cache to avoid constant rebuilds.
196    pub async fn build_kv(
197        &self,
198        store: &impl Store,
199        encryption_key: &[u8; 32],
200    ) -> Result<BTreeMap<String, BTreeMap<String, KvRecord>>> {
201        let mut map = BTreeMap::new();
202
203        // TODO: maybe don't load the entire tag into memory to build the kv
204        // we can be smart about it and only load values since the last build
205        // or, iterate/paginate
206        let tagged = store.all_tagged(KV_TAG).await?;
207
208        // iterate through all tags and play each KV record at a time
209        // this is "last write wins"
210        // probably good enough for now, but revisit in future
211        for record in tagged {
212            let decrypted = match record.version.as_str() {
213                "v0" | KV_VERSION => record.decrypt::<PASETO_V4>(encryption_key)?,
214                version => bail!("unknown version {version:?}"),
215            };
216
217            let kv = KvRecord::deserialize(&decrypted.data, &decrypted.version)?;
218
219            let ns = map
220                .entry(kv.namespace.clone())
221                .or_insert_with(BTreeMap::new);
222
223            ns.insert(kv.key.clone(), kv);
224        }
225
226        Ok(map)
227    }
228}
229
#[cfg(test)]
mod tests {
    use crypto_secretbox::{KeyInit, XSalsa20Poly1305};
    use rand::rngs::OsRng;

    use crate::record::sqlite_store::SqliteStore;
    use crate::settings::test_local_timeout;

    use super::{DecryptedData, KV_VAL_MAX_LEN, KV_VERSION, KvRecord, KvStore};

    /// Build a `KvRecord` from borrowed parts.
    fn kv_record(namespace: &str, key: &str, value: Option<&str>) -> KvRecord {
        KvRecord {
            namespace: namespace.to_owned(),
            key: key.to_owned(),
            value: value.map(str::to_owned),
        }
    }

    /// An empty in-memory record store, a kv handle, a random encryption key
    /// and a fresh host id.
    async fn setup() -> (SqliteStore, KvStore, [u8; 32], atuin_common::record::HostId) {
        let store = SqliteStore::new(":memory:", test_local_timeout())
            .await
            .unwrap();
        let key: [u8; 32] = XSalsa20Poly1305::generate_key(&mut OsRng).into();
        let host_id = atuin_common::record::HostId(atuin_common::utils::uuid_v7());

        (store, KvStore::new(), key, host_id)
    }

    #[test]
    fn encode_decode_some() {
        let kv = kv_record("foo", "bar", Some("baz"));
        let snapshot = [
            0x94, 0xa3, b'f', b'o', b'o', 0xa3, b'b', b'a', b'r', 0xc3, 0xa3, b'b', b'a', b'z',
        ];

        let encoded = kv.serialize().unwrap();
        let decoded = KvRecord::deserialize(&encoded, KV_VERSION).unwrap();

        assert_eq!(encoded.0, &snapshot);
        assert_eq!(decoded, kv);
    }

    #[test]
    fn encode_decode_none() {
        let kv = kv_record("foo", "bar", None);
        let snapshot = [0x94, 0xa3, b'f', b'o', b'o', 0xa3, b'b', b'a', b'r', 0xc2];

        let encoded = kv.serialize().unwrap();
        let decoded = KvRecord::deserialize(&encoded, KV_VERSION).unwrap();

        assert_eq!(encoded.0, &snapshot);
        assert_eq!(decoded, kv);
    }

    #[test]
    fn decode_v0() {
        let kv = kv_record("foo", "bar", Some("baz"));

        let snapshot = vec![
            0x93, 0xa3, b'f', b'o', b'o', 0xa3, b'b', b'a', b'r', 0xa3, b'b', b'a', b'z',
        ];

        let decoded = KvRecord::deserialize(&DecryptedData(snapshot), "v0").unwrap();

        assert_eq!(decoded, kv);
    }

    #[test]
    fn decode_rejects_trailing_bytes() {
        // A valid v1 "no value" record followed by one stray byte.
        let snapshot = vec![0x94, 0xa3, b'f', b'o', b'o', 0xa3, b'b', b'a', b'r', 0xc2, 0x00];

        assert!(KvRecord::deserialize(&DecryptedData(snapshot), KV_VERSION).is_err());
    }

    #[test]
    fn decode_rejects_wrong_entry_count() {
        // v0 records must declare exactly three entries; this one declares two.
        let snapshot = vec![0x92, 0xa3, b'f', b'o', b'o', 0xa3, b'b', b'a', b'r'];

        assert!(KvRecord::deserialize(&DecryptedData(snapshot), "v0").is_err());
    }

    #[test]
    fn decode_rejects_unknown_version() {
        let encoded = kv_record("foo", "bar", Some("baz")).serialize().unwrap();

        assert!(KvRecord::deserialize(&encoded, "v999").is_err());
    }

    #[tokio::test]
    async fn build_kv() {
        let (mut store, kv, key, host_id) = setup().await;

        kv.set(&mut store, &key, host_id, "test-kv", "foo", Some("bar"))
            .await
            .unwrap();

        kv.set(&mut store, &key, host_id, "test-kv", "1", Some("2"))
            .await
            .unwrap();

        kv.set(
            &mut store,
            &key,
            host_id,
            "test-kv",
            "deleted",
            Some("hello"),
        )
        .await
        .unwrap();

        kv.set(&mut store, &key, host_id, "test-kv", "deleted", None)
            .await
            .unwrap();

        let map = kv.build_kv(&store, &key).await.unwrap();
        let ns = map.get("test-kv").expect("map namespace not set");

        assert_eq!(
            ns.get("foo").expect("map key not set"),
            &kv_record("test-kv", "foo", Some("bar"))
        );
        assert_eq!(
            ns.get("1").expect("map key not set"),
            &kv_record("test-kv", "1", Some("2"))
        );
        assert_eq!(
            ns.get("deleted").expect("map key not set"),
            &kv_record("test-kv", "deleted", None)
        );
    }

    #[tokio::test]
    async fn get_returns_latest_record() {
        let (mut store, kv, key, host_id) = setup().await;

        kv.set(&mut store, &key, host_id, "test-kv", "foo", Some("bar"))
            .await
            .unwrap();
        kv.set(&mut store, &key, host_id, "test-kv", "foo", Some("baz"))
            .await
            .unwrap();

        assert_eq!(
            kv.get(&store, &key, "test-kv", "foo").await.unwrap(),
            Some(kv_record("test-kv", "foo", Some("baz")))
        );
        assert_eq!(
            kv.get(&store, &key, "test-kv", "missing").await.unwrap(),
            None
        );
        assert_eq!(
            kv.get(&store, &key, "other-ns", "foo").await.unwrap(),
            None
        );
    }

    #[tokio::test]
    async fn set_rejects_oversized_value() {
        let (mut store, kv, key, host_id) = setup().await;

        let too_big = "x".repeat(KV_VAL_MAX_LEN + 1);
        let res = kv
            .set(
                &mut store,
                &key,
                host_id,
                "test-kv",
                "big",
                Some(too_big.as_str()),
            )
            .await;
        assert!(res.is_err());

        // The rejected write must not have reached the store.
        let map = kv.build_kv(&store, &key).await.unwrap();
        assert!(map.is_empty());
    }
}