#![cfg(feature = "memstore")]
#![cfg_attr(docsrs, doc(cfg(feature = "memstore")))]
#![forbid(unsafe_code)]
extern crate alloc;
use alloc::{
boxed::Box,
collections::btree_map::{self, BTreeMap},
vec::Vec,
};
use core::marker::PhantomData;
use super::{Entry, ErrorKind, KeyStore, Occupied, Vacant};
use crate::{engine::WrappedKey, id::BaseId, util::cbor};
/// An in-memory [`KeyStore`] backed by a [`BTreeMap`] keyed by [`BaseId`].
///
/// Each key is stored CBOR-encoded and is decoded into the requested
/// wrapped-key type every time it is read.
#[derive(Debug, Default, Clone)]
pub struct MemStore {
    /// Encoded keys, indexed by their ID.
    keys: BTreeMap<BaseId, StoredKey>,
}

impl MemStore {
    /// Creates a store containing no keys.
    ///
    /// Because this is a `const fn`, it can be used in constant contexts.
    #[inline]
    pub const fn new() -> Self {
        let keys = BTreeMap::new();
        Self { keys }
    }
}
impl KeyStore for MemStore {
type Error = Error;
type Vacant<'a, T: WrappedKey> = VacantEntry<'a, T>;
type Occupied<'a, T: WrappedKey> = OccupiedEntry<'a, T>;
fn entry<T: WrappedKey>(&mut self, id: BaseId) -> Result<Entry<'_, Self, T>, Self::Error> {
match self.keys.entry(id) {
btree_map::Entry::Vacant(entry) => Ok(Entry::Vacant(VacantEntry {
entry,
_t: PhantomData,
})),
btree_map::Entry::Occupied(entry) => Ok(Entry::Occupied(OccupiedEntry {
entry,
_t: PhantomData,
})),
}
}
fn get<T: WrappedKey>(&self, id: BaseId) -> Result<Option<T>, Self::Error> {
match self.keys.get(&id) {
Some(v) => Ok(Some(v.to_wrapped()?)),
None => Ok(None),
}
}
}
#[derive(Clone, Debug)]
struct StoredKey(Vec<u8>);
impl StoredKey {
fn new<T: WrappedKey>(key: T) -> Result<Self, Error> {
let data =
cbor::to_allocvec(&key).map_err(|_| <Error as super::Error>::other(EncodingError))?;
Ok(Self(data))
}
fn to_wrapped<T: WrappedKey>(&self) -> Result<T, Error> {
cbor::from_bytes(&self.0).map_err(|_| <Error as super::Error>::other(DecodingError))
}
}
/// A [`Vacant`] slot of a [`MemStore`]: nothing is stored under its ID yet.
pub struct VacantEntry<'a, T> {
    /// The empty slot in the backing map.
    entry: btree_map::VacantEntry<'a, BaseId, StoredKey>,
    /// Ties the slot to the key type it was requested for.
    _t: PhantomData<T>,
}

impl<T: WrappedKey> Vacant<T> for VacantEntry<'_, T> {
    type Error = Error;

    /// Encodes `key` and stores it in this slot.
    ///
    /// Encoding happens first, so nothing is written if it fails.
    fn insert(self, key: T) -> Result<(), Self::Error> {
        let stored = StoredKey::new(key)?;
        self.entry.insert(stored);
        Ok(())
    }
}
/// An [`Occupied`] slot of a [`MemStore`]: a key is stored under its ID.
pub struct OccupiedEntry<'a, T> {
    /// The occupied slot in the backing map.
    entry: btree_map::OccupiedEntry<'a, BaseId, StoredKey>,
    /// Ties the slot to the key type it was requested for.
    _t: PhantomData<T>,
}

impl<T: WrappedKey> Occupied<T> for OccupiedEntry<'_, T> {
    type Error = Error;

    /// Decodes the stored key as a `T`, leaving it in the store.
    fn get(&self) -> Result<T, Self::Error> {
        self.entry.get().to_wrapped()
    }

    /// Removes the stored key and returns it decoded as a `T`.
    ///
    /// The key is decoded *before* the slot is removed. If the stored bytes
    /// cannot be decoded as `T` (for example, because the caller asked for
    /// the wrong key type), the error is returned and the key stays in the
    /// store rather than being destroyed.
    fn remove(self) -> Result<T, Self::Error> {
        let key = self.entry.get().to_wrapped::<T>()?;
        // Only drop the encoded bytes once we know they decode successfully.
        self.entry.remove();
        Ok(key)
    }
}
#[derive(Debug, thiserror::Error)]
#[error("{err}")]
pub struct Error {
kind: ErrorKind,
#[source]
err: Box<dyn core::error::Error + Send + Sync + 'static>,
}
impl Error {
#[inline]
pub fn downcast_ref<T: core::error::Error + 'static>(&self) -> Option<&T> {
self.err.downcast_ref::<T>()
}
}
impl super::Error for Error {
    /// Boxes `err` as the cause and records `kind` alongside it.
    fn new<E>(kind: ErrorKind, err: E) -> Self
    where
        E: core::error::Error + Send + Sync + 'static,
    {
        let err: Box<dyn core::error::Error + Send + Sync + 'static> = Box::new(err);
        Self { kind, err }
    }

    /// Returns the kind this error was created with.
    #[inline]
    fn kind(&self) -> ErrorKind {
        let Self { kind, .. } = self;
        *kind
    }
}
#[derive(Debug, thiserror::Error)]
#[error("unable to encode key")]
struct EncodingError;
#[derive(Debug, thiserror::Error)]
#[error("unable to decode key")]
struct DecodingError;
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};

    use super::*;
    use crate::{
        default::DefaultCipherSuite,
        id::{BaseId, Identified},
    };

    // Derives a deterministic `BaseId` from an integer so tests can refer
    // to keys by small numbers.
    macro_rules! id {
        ($id:expr) => {{
            let data = ($id as u64).to_le_bytes();
            $crate::id::IdExt::new::<DefaultCipherSuite>(
                b"TestKey",
                ::core::iter::once(data.as_slice()),
            )
        }};
    }

    /// A test key identified by an integer.
    #[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
    struct TestKey64(u64);

    impl WrappedKey for TestKey64 {}

    impl Identified for TestKey64 {
        type Id = BaseId;

        fn id(&self) -> Result<Self::Id, crate::id::IdError> {
            Ok(id!(self.0))
        }
    }

    /// A second test key type, used to read a stored key back as the
    /// wrong type.
    #[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
    struct TestKeyId(BaseId);

    impl WrappedKey for TestKeyId {}

    impl Identified for TestKeyId {
        type Id = BaseId;

        fn id(&self) -> Result<Self::Id, crate::id::IdError> {
            Ok(self.0)
        }
    }

    #[test]
    fn test_get() {
        let mut store = MemStore::new();
        let want = TestKey64(1);
        store
            .try_insert(id!(1), want)
            .expect("should be able to store key");
        let got = store
            .get::<TestKey64>(id!(1))
            .expect("`get` should not fail")
            .expect("should be able to find key");
        assert_eq!(got, want);
    }

    #[test]
    fn test_get_missing() {
        let store = MemStore::new();
        let got = store
            .get::<TestKey64>(id!(1))
            .expect("`get` should not fail");
        assert!(got.is_none());
    }

    #[test]
    fn test_get_wrong_key_type() {
        let mut store = MemStore::new();
        let want = TestKey64(1);
        store
            .try_insert(id!(1), want)
            .expect("should be able to store key");
        let err = store
            .get::<TestKeyId>(id!(1))
            .expect_err("should not be able to get key");
        // The failure must come from decoding, not from some unrelated problem.
        assert!(
            err.downcast_ref::<DecodingError>().is_some(),
            "unexpected error: {err}"
        );
        // A mistyped lookup must leave the stored key intact.
        let got = store
            .get::<TestKey64>(id!(1))
            .expect("`get` should not fail")
            .expect("should still be able to find key");
        assert_eq!(got, want);
    }

    #[test]
    fn test_remove() {
        let mut store = MemStore::new();
        store
            .try_insert(id!(1), TestKey64(1))
            .expect("should be able to store key");
        store
            .try_insert(id!(2), TestKey64(2))
            .expect("should be able to store key");
        let got = store
            .remove::<TestKey64>(id!(1))
            .expect("`remove` should not fail")
            .expect("should be able to find key");
        assert_eq!(got, TestKey64(1));
        let got = store
            .get::<TestKey64>(id!(2))
            .expect("`get` should not fail")
            .expect("should be able to find key");
        assert_eq!(got, TestKey64(2));
        assert!(
            store
                .get::<TestKey64>(id!(1))
                .expect("`get` should not fail")
                .is_none()
        );
    }
}