1use std::collections::HashMap;
11use std::fmt::Debug;
12use std::fs;
13use std::io::Write;
14use std::path::{Path, PathBuf};
15use std::sync::{Arc, Mutex};
16
17use tempfile::NamedTempFile;
18
19use aes_gcm_siv::Aes256GcmSiv;
20use elements::hashes::hex::DisplayHex;
21
22use crate::encrypt::{
23 cipher_from_key_bytes, decrypt_with_nonce_prefix, encrypt_with_deterministic_nonce,
24 encrypt_with_random_nonce, EncryptError,
25};
26
/// Boxed, thread-safe error type returned by every [`DynStore`] method.
pub type BoxError = Box<dyn std::error::Error + Send + Sync>;
29
/// Byte-oriented key/value storage.
///
/// Keys and values are arbitrary byte strings; individual implementations may restrict
/// which keys they accept (for example, `FileStore` requires UTF-8, file-name-safe keys).
pub trait Store: Send + Sync + Debug {
    /// Error reported by this store's operations.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Fetches the value stored under `key`, or `None` when nothing is stored there.
    fn get<K>(&self, key: K) -> Result<Option<Vec<u8>>, Self::Error>
    where
        K: AsRef<[u8]>;

    /// Stores `value` under `key`.
    fn put<K, V>(&self, key: K, value: V) -> Result<(), Self::Error>
    where
        K: AsRef<[u8]>,
        V: AsRef<[u8]>;

    /// Deletes whatever is stored under `key`.
    fn remove<K>(&self, key: K) -> Result<(), Self::Error>
    where
        K: AsRef<[u8]>;

    /// Reports whether this store writes its data to durable storage.
    ///
    /// The default is `false`; durable backends (such as `FileStore`) override it.
    fn is_persisted(&self) -> bool {
        false
    }
}
67
68pub trait DynStore: Send + Sync + Debug {
77 fn get(&self, key: &str) -> Result<Option<Vec<u8>>, BoxError>;
79 fn put(&self, key: &str, value: &[u8]) -> Result<(), BoxError>;
81 fn remove(&self, key: &str) -> Result<(), BoxError>;
83 fn is_persisted(&self) -> bool {
87 false
88 }
89}
90
91#[derive(Debug)]
94pub struct ArcDynStoreError(BoxError);
95
96impl std::fmt::Display for ArcDynStoreError {
97 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98 std::fmt::Display::fmt(&self.0, f)
99 }
100}
101
102impl std::error::Error for ArcDynStoreError {
103 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
104 self.0.source()
105 }
106}
107
108impl Store for Arc<dyn DynStore> {
112 type Error = ArcDynStoreError;
113
114 fn get<K: AsRef<[u8]>>(&self, key: K) -> Result<Option<Vec<u8>>, ArcDynStoreError> {
115 let key = std::str::from_utf8(key.as_ref()).map_err(|e| ArcDynStoreError(Box::new(e)))?;
116 DynStore::get(self.as_ref(), key).map_err(ArcDynStoreError)
117 }
118
119 fn put<K: AsRef<[u8]>, V: AsRef<[u8]>>(
120 &self,
121 key: K,
122 value: V,
123 ) -> Result<(), ArcDynStoreError> {
124 let key = std::str::from_utf8(key.as_ref()).map_err(|e| ArcDynStoreError(Box::new(e)))?;
125 DynStore::put(self.as_ref(), key, value.as_ref()).map_err(ArcDynStoreError)
126 }
127
128 fn remove<K: AsRef<[u8]>>(&self, key: K) -> Result<(), ArcDynStoreError> {
129 let key = std::str::from_utf8(key.as_ref()).map_err(|e| ArcDynStoreError(Box::new(e)))?;
130 DynStore::remove(self.as_ref(), key).map_err(ArcDynStoreError)
131 }
132
133 fn is_persisted(&self) -> bool {
134 DynStore::is_persisted(self.as_ref())
135 }
136}
137
138impl<S: Store> DynStore for S {
140 fn get(&self, key: &str) -> Result<Option<Vec<u8>>, BoxError> {
141 Store::get(self, key).map_err(|e| Box::new(e) as BoxError)
142 }
143
144 fn put(&self, key: &str, value: &[u8]) -> Result<(), BoxError> {
145 Store::put(self, key, value).map_err(|e| Box::new(e) as BoxError)
146 }
147
148 fn remove(&self, key: &str) -> Result<(), BoxError> {
149 Store::remove(self, key).map_err(|e| Box::new(e) as BoxError)
150 }
151
152 fn is_persisted(&self) -> bool {
153 Store::is_persisted(self)
154 }
155}
156
/// In-process [`Store`] backed by a mutex-protected `HashMap`; nothing is written to disk.
#[derive(Debug, Default)]
pub struct MemoryStore {
    data: Mutex<HashMap<Vec<u8>, Vec<u8>>>,
}

impl MemoryStore {
    /// Creates an empty in-memory store.
    pub fn new() -> Self {
        MemoryStore {
            data: Mutex::new(HashMap::new()),
        }
    }
}
171
172impl Store for MemoryStore {
173 type Error = std::convert::Infallible;
174
175 fn get<K: AsRef<[u8]>>(&self, key: K) -> Result<Option<Vec<u8>>, Self::Error> {
176 Ok(self
177 .data
178 .lock()
179 .expect("lock poisoned")
180 .get(key.as_ref())
181 .cloned())
182 }
183
184 fn put<K: AsRef<[u8]>, V: AsRef<[u8]>>(&self, key: K, value: V) -> Result<(), Self::Error> {
185 self.data
186 .lock()
187 .expect("lock poisoned")
188 .insert(key.as_ref().to_vec(), value.as_ref().to_vec());
189 Ok(())
190 }
191
192 fn remove<K: AsRef<[u8]>>(&self, key: K) -> Result<(), Self::Error> {
193 self.data
194 .lock()
195 .expect("lock poisoned")
196 .remove(key.as_ref());
197 Ok(())
198 }
199}
200
/// A [`Store`] that discards every write and never finds anything.
#[derive(Debug, Default, Clone, Copy)]
pub struct FakeStore;

impl FakeStore {
    /// Returns the (stateless) fake store.
    pub fn new() -> Self {
        FakeStore
    }
}
213
214impl Store for FakeStore {
215 type Error = std::convert::Infallible;
216
217 fn get<K: AsRef<[u8]>>(&self, _key: K) -> Result<Option<Vec<u8>>, Self::Error> {
218 Ok(None)
219 }
220
221 fn put<K: AsRef<[u8]>, V: AsRef<[u8]>>(&self, _key: K, _value: V) -> Result<(), Self::Error> {
222 Ok(())
223 }
224
225 fn remove<K: AsRef<[u8]>>(&self, _key: K) -> Result<(), Self::Error> {
226 Ok(())
227 }
228}
229
/// A [`Store`] that keeps every entry in its own file directly under a root directory.
///
/// Keys are used verbatim as file names, so they must pass the checks in
/// `FileStore::file_path`. Every operation locks the root-path mutex, so operations on a
/// single `FileStore` instance are serialized.
#[derive(Debug)]
pub struct FileStore {
    root: Mutex<PathBuf>,
}

impl FileStore {
    /// Opens a file store rooted at `path`, creating the directory (and any missing
    /// parents) if it does not exist yet.
    ///
    /// # Errors
    ///
    /// Returns an [`std::io::ErrorKind::InvalidInput`] error if `path` is an existing file,
    /// or the error from [`fs::create_dir_all`] if the directory cannot be created.
    pub fn new(path: PathBuf) -> Result<Self, std::io::Error> {
        if path.is_file() {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "FileStore root path is a file",
            ));
        }
        if !path.exists() {
            fs::create_dir_all(&path)?;
        }
        Ok(Self {
            root: Mutex::new(path),
        })
    }

    /// Returns `true` for characters that may not appear in a store key.
    ///
    /// This is the set Windows forbids in file names: the path separators, `: * ? " < > |`,
    /// and the ASCII control characters U+0000..=U+001F. It is enforced on every platform
    /// so a store directory written on one OS stays readable on another.
    fn is_forbidden_key_char(c: char) -> bool {
        matches!(c, '/' | '\\' | ':' | '*' | '?' | '"' | '<' | '>' | '|') || c <= '\u{1f}'
    }

    /// Maps a store key to the path of its backing file under `root`.
    ///
    /// # Errors
    ///
    /// Returns an [`std::io::ErrorKind::InvalidInput`] error if the key is not valid UTF-8,
    /// is empty, is longer than 255 bytes, is `.` or `..`, or contains a character rejected
    /// by `is_forbidden_key_char`.
    fn file_path(root: &Path, key: &[u8]) -> Result<PathBuf, std::io::Error> {
        let name = std::str::from_utf8(key).map_err(|_| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "store key is not valid UTF-8",
            )
        })?;

        if name.is_empty() {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "store key is empty",
            ));
        }

        // 255 bytes is the common per-component limit for file names.
        if name.len() > 255 {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "store key exceeds maximum file name length (255 bytes)",
            ));
        }

        // Reject directory references and anything that could escape `root` or is not a
        // portable file name.
        if name == "." || name == ".." || name.chars().any(Self::is_forbidden_key_char) {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "store key contains invalid file name characters",
            ));
        }

        Ok(root.join(name))
    }

    /// Flushes the directory itself to disk so that a completed rename survives a crash.
    #[cfg(not(target_os = "windows"))]
    fn sync_dir(path: &Path) -> Result<(), std::io::Error> {
        fs::File::open(path)?.sync_all()
    }

    /// Directory syncing is skipped on Windows.
    #[cfg(target_os = "windows")]
    fn sync_dir(_path: &Path) -> Result<(), std::io::Error> {
        Ok(())
    }
}
320impl Store for FileStore {
321 type Error = std::io::Error;
322
323 fn is_persisted(&self) -> bool {
324 true
325 }
326
327 fn get<K: AsRef<[u8]>>(&self, key: K) -> Result<Option<Vec<u8>>, Self::Error> {
328 let root = self.root.lock().expect("lock poisoned");
329 let path = Self::file_path(&root, key.as_ref())?;
330 match fs::read(path) {
331 Ok(bytes) => Ok(Some(bytes)),
332 Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
333 Err(e) => Err(e),
334 }
335 }
336
337 fn put<K: AsRef<[u8]>, V: AsRef<[u8]>>(&self, key: K, value: V) -> Result<(), Self::Error> {
338 let root = self.root.lock().expect("lock poisoned");
339 let path = Self::file_path(&root, key.as_ref())?;
340
341 let mut tmp = NamedTempFile::new_in(&*root)?;
343 tmp.write_all(value.as_ref())?;
344 tmp.as_file().sync_all()?;
345
346 match tmp.persist(&path) {
347 Ok(_) => {}
348 Err(e) if e.error.kind() == std::io::ErrorKind::AlreadyExists => {
349 match fs::remove_file(&path) {
352 Ok(()) => {}
353 Err(remove_err) if remove_err.kind() == std::io::ErrorKind::NotFound => {}
354 Err(remove_err) => return Err(remove_err),
355 }
356
357 e.file
358 .persist(&path)
359 .map_err(|persist_err| persist_err.error)?;
360 }
361 Err(e) => return Err(e.error),
362 }
363
364 Self::sync_dir(root.as_path())?;
366
367 Ok(())
368 }
369
370 fn remove<K: AsRef<[u8]>>(&self, key: K) -> Result<(), Self::Error> {
371 let root = self.root.lock().expect("lock poisoned");
372 let path = Self::file_path(&root, key.as_ref())?;
373 match fs::remove_file(path) {
374 Ok(()) => Ok(()),
375 Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
376 Err(e) => Err(e),
377 }
378 }
379}
380
381#[derive(Debug)]
385pub enum EncryptedStoreError<E: std::error::Error + Send + Sync + 'static> {
386 Store(E),
388 Encrypt(EncryptError),
390}
391
392impl<E: std::error::Error + Send + Sync + 'static> std::fmt::Display for EncryptedStoreError<E> {
393 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
394 match self {
395 EncryptedStoreError::Store(e) => write!(f, "store error: {e}"),
396 EncryptedStoreError::Encrypt(e) => write!(f, "encryption error: {e}"),
397 }
398 }
399}
400
401impl<E: std::error::Error + Send + Sync + 'static> std::error::Error for EncryptedStoreError<E> {
402 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
403 match self {
404 EncryptedStoreError::Store(e) => Some(e),
405 EncryptedStoreError::Encrypt(e) => Some(e),
406 }
407 }
408}
409
/// A [`Store`] wrapper that encrypts values (and optionally keys) before handing them to
/// the inner store.
pub struct EncryptedStore<S> {
    inner: S,
    /// Raw 32-byte encryption key; never printed (see the `Debug` impl below).
    key_bytes: [u8; 32],
    /// When `true`, keys are encrypted too (see `effective_key`).
    encrypt_keys: bool,
}

/// Hand-written instead of derived so that `{:?}` output (e.g. in logs or panic messages)
/// never contains the secret `key_bytes`.
impl<S: Debug> Debug for EncryptedStore<S> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("EncryptedStore")
            .field("inner", &self.inner)
            .field("key_bytes", &format_args!("<redacted>"))
            .field("encrypt_keys", &self.encrypt_keys)
            .finish()
    }
}
422
423impl<S> EncryptedStore<S> {
424 pub fn new(inner: S, key_bytes: [u8; 32]) -> Self {
429 Self {
430 inner,
431 key_bytes,
432 encrypt_keys: false,
433 }
434 }
435
436 pub fn new_with_key_encryption(inner: S, key_bytes: [u8; 32]) -> Self {
438 Self {
439 inner,
440 key_bytes,
441 encrypt_keys: true,
442 }
443 }
444
445 pub fn inner(&self) -> &S {
447 &self.inner
448 }
449
450 pub fn into_inner(self) -> S {
452 self.inner
453 }
454
455 pub fn cipher(&self) -> Aes256GcmSiv {
457 cipher_from_key_bytes(self.key_bytes)
458 }
459}
460
461impl<S: Store> EncryptedStore<S> {
462 fn effective_key<K: AsRef<[u8]>>(
463 &self,
464 key: K,
465 ) -> Result<Vec<u8>, EncryptedStoreError<S::Error>> {
466 if self.encrypt_keys {
467 let mut cipher = cipher_from_key_bytes(self.key_bytes);
468 let encrypted = encrypt_with_deterministic_nonce(&mut cipher, key.as_ref())
469 .map_err(EncryptedStoreError::Encrypt)?;
470 Ok(encrypted.to_lower_hex_string().into_bytes())
471 } else {
472 Ok(key.as_ref().to_vec())
473 }
474 }
475}
476
477impl<S: Store> Store for EncryptedStore<S> {
478 type Error = EncryptedStoreError<S::Error>;
479
480 fn is_persisted(&self) -> bool {
481 self.inner.is_persisted()
482 }
483
484 fn get<K: AsRef<[u8]>>(&self, key: K) -> Result<Option<Vec<u8>>, Self::Error> {
485 let key = self.effective_key(key)?;
486 match self.inner.get(&key).map_err(EncryptedStoreError::Store)? {
487 Some(ciphertext) => {
488 let mut cipher = cipher_from_key_bytes(self.key_bytes);
489 let plaintext = decrypt_with_nonce_prefix(&mut cipher, &ciphertext)
490 .map_err(EncryptedStoreError::Encrypt)?;
491 Ok(Some(plaintext))
492 }
493 None => Ok(None),
494 }
495 }
496
497 fn put<K: AsRef<[u8]>, V: AsRef<[u8]>>(&self, key: K, value: V) -> Result<(), Self::Error> {
498 let key = self.effective_key(key)?;
499 let mut cipher = cipher_from_key_bytes(self.key_bytes);
500 let ciphertext = encrypt_with_random_nonce(&mut cipher, value.as_ref())
501 .map_err(EncryptedStoreError::Encrypt)?;
502 self.inner
503 .put(&key, ciphertext)
504 .map_err(EncryptedStoreError::Store)?;
505 Ok(())
506 }
507
508 fn remove<K: AsRef<[u8]>>(&self, key: K) -> Result<(), Self::Error> {
509 let key = self.effective_key(key)?;
510 self.inner
511 .remove(&key)
512 .map_err(EncryptedStoreError::Store)?;
513 Ok(())
514 }
515}
516
#[cfg(test)]
mod test {
    // Only `Store` is imported (not `DynStore`): with both traits in scope, method calls
    // such as `store.get(..)` on a concrete store would be ambiguous via the blanket impl.
    use super::{EncryptedStore, FakeStore, FileStore, MemoryStore, Store};

    #[test]
    fn memory_store() {
        let store = MemoryStore::new();

        assert_eq!(store.get("key").unwrap(), None);

        store.put("key", b"value").unwrap();
        assert_eq!(store.get("key").unwrap(), Some(b"value".to_vec()));

        store.put("key", b"new_value").unwrap();
        assert_eq!(store.get("key").unwrap(), Some(b"new_value".to_vec()));

        store.remove("key").unwrap();
        assert_eq!(store.get("key").unwrap(), None);

        // Removing a missing key is not an error.
        store.remove("key").unwrap();
    }

    #[test]
    fn file_store_roundtrip() {
        let dir = tempfile::tempdir().unwrap();
        let store = FileStore::new(dir.path().to_path_buf()).unwrap();

        assert_eq!(store.get("key").unwrap(), None);

        store.put("key", b"value").unwrap();
        assert_eq!(store.get("key").unwrap(), Some(b"value".to_vec()));

        store.put("key2", b"value2").unwrap();
        assert_eq!(store.get("key2").unwrap(), Some(b"value2".to_vec()));

        store.put("key", b"new_value").unwrap();
        assert_eq!(store.get("key").unwrap(), Some(b"new_value".to_vec()));

        // Keys must be valid UTF-8 to be used as file names.
        let non_utf8_key = [0u8, 255u8, 1u8];
        assert!(store.put(non_utf8_key, b"bin").is_err());

        store.remove("key").unwrap();
        assert_eq!(store.get("key").unwrap(), None);

        store.remove("key").unwrap();

        // Reopening the same directory sees what was persisted.
        drop(store);
        let store = FileStore::new(dir.path().to_path_buf()).unwrap();

        assert_eq!(store.get("key").unwrap(), None);
        assert_eq!(store.get("key2").unwrap(), Some(b"value2".to_vec()));
    }

    #[test]
    fn fake_store() {
        let store = FakeStore::new();

        assert_eq!(store.get("key").unwrap(), None);
        store.put("key", b"value").unwrap();
        assert_eq!(store.get("key").unwrap(), None);
        store.remove("key").unwrap();
    }

    #[test]
    fn is_persisted_flags() {
        let dir = tempfile::tempdir().unwrap();
        let file = FileStore::new(dir.path().to_path_buf()).unwrap();
        assert!(file.is_persisted());
        assert!(!MemoryStore::new().is_persisted());
        assert!(!FakeStore::new().is_persisted());

        // EncryptedStore reports whatever its inner store reports.
        assert!(EncryptedStore::new(file, [1u8; 32]).is_persisted());
        assert!(!EncryptedStore::new(MemoryStore::new(), [1u8; 32]).is_persisted());
    }

    #[test]
    fn arc_dyn_store_adapter() {
        use std::sync::Arc;

        let store: Arc<dyn super::DynStore> = Arc::new(MemoryStore::new());
        assert!(!Store::is_persisted(&store));

        Store::put(&store, "key", b"value").unwrap();
        assert_eq!(Store::get(&store, "key").unwrap(), Some(b"value".to_vec()));

        // DynStore keys are `&str`, so non-UTF-8 byte keys are rejected by the adapter.
        assert!(Store::put(&store, [0u8, 255u8], b"bin").is_err());

        Store::remove(&store, "key").unwrap();
        assert_eq!(Store::get(&store, "key").unwrap(), None);
    }

    #[test]
    fn encrypted_store_memory() {
        let key_bytes = [7u8; 32];
        let inner = MemoryStore::new();
        let store = EncryptedStore::new(inner, key_bytes);

        assert_eq!(store.get("key").unwrap(), None);

        store.put("key", b"secret value").unwrap();
        assert_eq!(store.get("key").unwrap(), Some(b"secret value".to_vec()));

        // The inner store only ever sees ciphertext.
        let raw = store.inner().get("key").unwrap().unwrap();
        assert_ne!(raw, b"secret value".to_vec());

        store.put("key", b"new secret").unwrap();
        assert_eq!(store.get("key").unwrap(), Some(b"new secret".to_vec()));

        store.remove("key").unwrap();
        assert_eq!(store.get("key").unwrap(), None);
    }

    #[test]
    fn encrypted_store_with_key_encryption() {
        let store = EncryptedStore::new_with_key_encryption(MemoryStore::new(), [9u8; 32]);

        assert_eq!(store.get("key").unwrap(), None);

        store.put("key", b"secret").unwrap();
        assert_eq!(store.get("key").unwrap(), Some(b"secret".to_vec()));

        // The plaintext key is never used verbatim by the inner store.
        assert_eq!(store.inner().get("key").unwrap(), None);

        store.remove("key").unwrap();
        assert_eq!(store.get("key").unwrap(), None);
    }

    #[test]
    fn encrypted_store_file() {
        let key_bytes = [42u8; 32];
        let dir = tempfile::tempdir().unwrap();
        let inner = FileStore::new(dir.path().to_path_buf()).unwrap();
        let store = EncryptedStore::new(inner, key_bytes);

        store.put("000000000000", b"update data").unwrap();
        assert_eq!(
            store.get("000000000000").unwrap(),
            Some(b"update data".to_vec())
        );

        // The file on disk holds ciphertext, not the plaintext value.
        let file_path = dir.path().join("000000000000");
        let raw_bytes = std::fs::read(&file_path).unwrap();
        assert_ne!(raw_bytes, b"update data".to_vec());

        // Reopening with the same key decrypts the persisted value.
        drop(store);
        let inner = FileStore::new(dir.path().to_path_buf()).unwrap();
        let store = EncryptedStore::new(inner, key_bytes);
        assert_eq!(
            store.get("000000000000").unwrap(),
            Some(b"update data".to_vec())
        );

        // A different key must fail to decrypt rather than return garbage.
        let inner = FileStore::new(dir.path().to_path_buf()).unwrap();
        let wrong_store = EncryptedStore::new(inner, [0u8; 32]);
        assert!(wrong_store.get("000000000000").is_err());
    }
}