1use std::collections::BTreeMap;
2
3use atuin_common::record::{DecryptedData, Host, HostId};
4use eyre::{Result, bail, ensure, eyre};
5use serde::Deserialize;
6
7use crate::record::encryption::PASETO_V4;
8use crate::record::store::Store;
9
/// Record version written by `KvStore::set`; records tagged `v0` are still
/// accepted when reading.
const KV_VERSION: &str = "v1";
/// Tag under which every kv record is stored.
const KV_TAG: &str = "kv";
/// Maximum accepted value length in bytes (100 KiB), checked by `KvStore::set`.
const KV_VAL_MAX_LEN: usize = 100 * 1024;
13
14#[derive(Debug, Clone, PartialEq, Eq)]
15pub struct KvRecord {
16 pub namespace: String,
17 pub key: String,
18 pub value: Option<String>,
19}
20
21impl KvRecord {
22 pub fn serialize(&self) -> Result<DecryptedData> {
23 use rmp::encode;
24
25 let mut output = vec![];
26
27 encode::write_array_len(&mut output, 4)?;
29
30 encode::write_str(&mut output, &self.namespace)?;
31 encode::write_str(&mut output, &self.key)?;
32 encode::write_bool(&mut output, self.value.is_some())?;
33
34 if let Some(value) = &self.value {
35 encode::write_str(&mut output, value)?;
36 }
37
38 Ok(DecryptedData(output))
39 }
40
41 pub fn deserialize(data: &DecryptedData, version: &str) -> Result<Self> {
42 use rmp::decode;
43
44 fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report {
45 eyre!("{err:?}")
46 }
47
48 match version {
49 "v0" => {
50 let mut bytes = decode::Bytes::new(&data.0);
51
52 let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?;
53 ensure!(nfields == 3, "too many entries in v0 kv record");
54
55 let bytes = bytes.remaining_slice();
56
57 let (namespace, bytes) =
58 decode::read_str_from_slice(bytes).map_err(error_report)?;
59 let (key, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?;
60 let (value, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?;
61
62 if !bytes.is_empty() {
63 bail!("trailing bytes in encoded kvrecord. malformed")
64 }
65
66 Ok(KvRecord {
67 namespace: namespace.to_owned(),
68 key: key.to_owned(),
69 value: Some(value.to_owned()),
70 })
71 }
72 KV_VERSION => {
73 let mut bytes = decode::Bytes::new(&data.0);
74
75 let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?;
76 ensure!(nfields == 4, "too many entries in v1 kv record");
77
78 let bytes = bytes.remaining_slice();
79
80 let (namespace, bytes) =
81 decode::read_str_from_slice(bytes).map_err(error_report)?;
82 let (key, mut bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?;
83 let has_value = decode::read_bool(&mut bytes).map_err(error_report)?;
84
85 let (value, bytes) = if has_value {
86 let (value, bytes) =
87 decode::read_str_from_slice(bytes).map_err(error_report)?;
88 (Some(value.to_owned()), bytes)
89 } else {
90 (None, bytes)
91 };
92
93 if !bytes.is_empty() {
94 bail!("trailing bytes in encoded kvrecord. malformed")
95 }
96
97 Ok(KvRecord {
98 namespace: namespace.to_owned(),
99 key: key.to_owned(),
100 value,
101 })
102 }
103 _ => {
104 bail!("unknown version {version:?}")
105 }
106 }
107 }
108}
109
110#[derive(Debug, Clone, Deserialize)]
111pub struct KvStore;
112
113impl Default for KvStore {
114 fn default() -> Self {
115 Self::new()
116 }
117}
118
119impl KvStore {
120 pub fn new() -> KvStore {
122 KvStore {}
123 }
124
125 pub async fn set(
126 &self,
127 store: &(impl Store + Send + Sync),
128 encryption_key: &[u8; 32],
129 host_id: HostId,
130 namespace: &str,
131 key: &str,
132 value: Option<&str>,
133 ) -> Result<()> {
134 if value.is_some() && value.unwrap().len() > KV_VAL_MAX_LEN {
135 return Err(eyre!(
136 "kv value too large: max len {} bytes",
137 KV_VAL_MAX_LEN
138 ));
139 }
140
141 let record = KvRecord {
142 namespace: namespace.to_string(),
143 key: key.to_string(),
144 value: value.map(|v| v.to_string()),
145 };
146
147 let bytes = record.serialize()?;
148
149 let idx = store
150 .last(host_id, KV_TAG)
151 .await?
152 .map_or(0, |entry| entry.idx + 1);
153
154 let record = atuin_common::record::Record::builder()
155 .host(Host::new(host_id))
156 .version(KV_VERSION.to_string())
157 .tag(KV_TAG.to_string())
158 .idx(idx)
159 .data(bytes)
160 .build();
161
162 store
163 .push(&record.encrypt::<PASETO_V4>(encryption_key))
164 .await?;
165
166 Ok(())
167 }
168
169 pub async fn get(
172 &self,
173 store: &impl Store,
174 encryption_key: &[u8; 32],
175 namespace: &str,
176 key: &str,
177 ) -> Result<Option<KvRecord>> {
178 let map = self.build_kv(store, encryption_key).await?;
180
181 let res = map.get(namespace);
182
183 if let Some(ns) = res {
184 let value = ns.get(key);
185
186 Ok(value.cloned())
187 } else {
188 Ok(None)
189 }
190 }
191
192 pub async fn build_kv(
197 &self,
198 store: &impl Store,
199 encryption_key: &[u8; 32],
200 ) -> Result<BTreeMap<String, BTreeMap<String, KvRecord>>> {
201 let mut map = BTreeMap::new();
202
203 let tagged = store.all_tagged(KV_TAG).await?;
207
208 for record in tagged {
212 let decrypted = match record.version.as_str() {
213 "v0" | KV_VERSION => record.decrypt::<PASETO_V4>(encryption_key)?,
214 version => bail!("unknown version {version:?}"),
215 };
216
217 let kv = KvRecord::deserialize(&decrypted.data, &decrypted.version)?;
218
219 let ns = map
220 .entry(kv.namespace.clone())
221 .or_insert_with(BTreeMap::new);
222
223 ns.insert(kv.key.clone(), kv);
224 }
225
226 Ok(map)
227 }
228}
229
#[cfg(test)]
mod tests {
    use crypto_secretbox::{KeyInit, XSalsa20Poly1305};
    use rand::rngs::OsRng;

    use crate::record::sqlite_store::SqliteStore;
    use crate::settings::test_local_timeout;

    use super::{DecryptedData, HostId, KV_VAL_MAX_LEN, KV_VERSION, KvRecord, KvStore};

    /// Build an expected [`KvRecord`] from borrowed parts.
    fn kv_record(namespace: &str, key: &str, value: Option<&str>) -> KvRecord {
        KvRecord {
            namespace: namespace.to_owned(),
            key: key.to_owned(),
            value: value.map(str::to_owned),
        }
    }

    /// Create an in-memory store, a kv handle, a fresh encryption key and a host id.
    async fn setup() -> (SqliteStore, KvStore, [u8; 32], HostId) {
        let store = SqliteStore::new(":memory:", test_local_timeout())
            .await
            .unwrap();
        let key: [u8; 32] = XSalsa20Poly1305::generate_key(&mut OsRng).into();
        let host_id = HostId(atuin_common::utils::uuid_v7());

        (store, KvStore::new(), key, host_id)
    }

    #[test]
    fn encode_decode_some() {
        let kv = kv_record("foo", "bar", Some("baz"));
        let snapshot = [
            0x94, 0xa3, b'f', b'o', b'o', 0xa3, b'b', b'a', b'r', 0xc3, 0xa3, b'b', b'a', b'z',
        ];

        let encoded = kv.serialize().unwrap();
        let decoded = KvRecord::deserialize(&encoded, KV_VERSION).unwrap();

        assert_eq!(encoded.0, &snapshot);
        assert_eq!(decoded, kv);
    }

    #[test]
    fn encode_decode_none() {
        let kv = kv_record("foo", "bar", None);
        let snapshot = [0x94, 0xa3, b'f', b'o', b'o', 0xa3, b'b', b'a', b'r', 0xc2];

        let encoded = kv.serialize().unwrap();
        let decoded = KvRecord::deserialize(&encoded, KV_VERSION).unwrap();

        assert_eq!(encoded.0, &snapshot);
        assert_eq!(decoded, kv);
    }

    #[test]
    fn decode_v0() {
        let kv = kv_record("foo", "bar", Some("baz"));

        let snapshot = vec![
            0x93, 0xa3, b'f', b'o', b'o', 0xa3, b'b', b'a', b'r', 0xa3, b'b', b'a', b'z',
        ];

        let decoded = KvRecord::deserialize(&DecryptedData(snapshot), "v0").unwrap();

        assert_eq!(decoded, kv);
    }

    #[test]
    fn decode_v1_rejects_malformed() {
        // Header declares 3 entries; v1 requires 4.
        let wrong_len = vec![0x93, 0xa3, b'f', b'o', b'o', 0xa3, b'b', b'a', b'r', 0xc2];
        assert!(KvRecord::deserialize(&DecryptedData(wrong_len), KV_VERSION).is_err());

        // A complete deletion record followed by a stray byte.
        let trailing = vec![
            0x94, 0xa3, b'f', b'o', b'o', 0xa3, b'b', b'a', b'r', 0xc2, 0x00,
        ];
        assert!(KvRecord::deserialize(&DecryptedData(trailing), KV_VERSION).is_err());
    }

    #[test]
    fn decode_unknown_version() {
        let data = DecryptedData(vec![
            0x94, 0xa3, b'f', b'o', b'o', 0xa3, b'b', b'a', b'r', 0xc2,
        ]);
        assert!(KvRecord::deserialize(&data, "v999").is_err());
    }

    #[tokio::test]
    async fn build_kv() {
        let (store, kv, key, host_id) = setup().await;

        kv.set(&store, &key, host_id, "test-kv", "foo", Some("bar"))
            .await
            .unwrap();

        kv.set(&store, &key, host_id, "test-kv", "1", Some("2"))
            .await
            .unwrap();

        kv.set(&store, &key, host_id, "test-kv", "deleted", Some("hello"))
            .await
            .unwrap();

        kv.set(&store, &key, host_id, "test-kv", "deleted", None)
            .await
            .unwrap();

        let map = kv.build_kv(&store, &key).await.unwrap();
        let ns = map.get("test-kv").expect("map namespace not set");

        assert_eq!(
            *ns.get("foo").expect("map key not set"),
            kv_record("test-kv", "foo", Some("bar"))
        );
        assert_eq!(
            *ns.get("1").expect("map key not set"),
            kv_record("test-kv", "1", Some("2"))
        );
        assert_eq!(
            *ns.get("deleted").expect("map key not set"),
            kv_record("test-kv", "deleted", None)
        );
    }

    #[tokio::test]
    async fn get() {
        let (store, kv, key, host_id) = setup().await;

        kv.set(&store, &key, host_id, "test-kv", "foo", Some("bar"))
            .await
            .unwrap();

        assert_eq!(
            kv.get(&store, &key, "test-kv", "foo").await.unwrap(),
            Some(kv_record("test-kv", "foo", Some("bar")))
        );
        assert_eq!(
            kv.get(&store, &key, "test-kv", "missing").await.unwrap(),
            None
        );
        assert_eq!(
            kv.get(&store, &key, "no-such-namespace", "foo")
                .await
                .unwrap(),
            None
        );
    }

    #[tokio::test]
    async fn set_rejects_oversized_value() {
        let (store, kv, key, host_id) = setup().await;

        let too_big = "x".repeat(KV_VAL_MAX_LEN + 1);
        let res = kv
            .set(&store, &key, host_id, "test-kv", "big", Some(too_big.as_str()))
            .await;
        assert!(res.is_err());

        // The rejected write must not have been stored.
        assert_eq!(kv.get(&store, &key, "test-kv", "big").await.unwrap(), None);
    }
}