use std::sync::{Arc, RwLock};
use rayon::{iter::Either, prelude::*};
use crate::{CompositeEncryptable, Decryptable, IdentifyKey, KeySlotId, KeySlotIds};
mod backend;
mod context;
use backend::{StoreBackend, create_store};
use context::GlobalKeys;
pub use context::KeyStoreContext;
mod key_rotation;
pub use key_rotation::*;
/// Thread-safe container for the keys whose slot identifiers are described by `Ids`.
///
/// All state lives behind an `Arc<RwLock<..>>`, so cloning a `KeyStore` yields another handle
/// to the same keys rather than a copy of them. Keys are accessed through a
/// [`KeyStoreContext`] obtained from `context` (read lock) or `context_mut` (write lock).
pub struct KeyStore<Ids: KeySlotIds> {
    // Shared global key storage; every clone of this handle points at the same allocation.
    inner: Arc<RwLock<KeyStoreInner<Ids>>>,
}
impl<Ids: KeySlotIds> Clone for KeyStore<Ids> {
    /// Produces a new handle that shares the same underlying key storage.
    ///
    /// Only the reference count of the inner `Arc` is bumped; no key material is copied.
    fn clone(&self) -> Self {
        let shared = Arc::clone(&self.inner);
        Self { inner: shared }
    }
}
impl<Ids: KeySlotIds> std::fmt::Debug for KeyStore<Ids> {
    /// Writes only the type name, so no key material ever appears in debug output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Equivalent to `f.debug_struct("KeyStore").finish()` with no fields.
        f.write_str("KeyStore")
    }
}
/// Global key storage shared by every handle of a [`KeyStore`], guarded by its `RwLock`.
struct KeyStoreInner<Ids: KeySlotIds> {
    // Backend holding keys addressed by `Ids::Symmetric` slot identifiers.
    symmetric_keys: Box<dyn StoreBackend<Ids::Symmetric>>,
    // Backend holding keys addressed by `Ids::Private` slot identifiers.
    private_keys: Box<dyn StoreBackend<Ids::Private>>,
    // Backend holding keys addressed by `Ids::Signing` slot identifiers.
    signing_keys: Box<dyn StoreBackend<Ids::Signing>>,
    // Copied into each `KeyStoreContext` when it is created; `Default` initialises it to 1
    // and `KeyStore::set_security_state_version` overwrites it.
    security_state_version: u64,
}
impl<Ids: KeySlotIds> Default for KeyStore<Ids> {
fn default() -> Self {
Self {
inner: Arc::new(RwLock::new(KeyStoreInner {
symmetric_keys: create_store(),
private_keys: create_store(),
signing_keys: create_store(),
security_state_version: 1,
})),
}
}
}
impl<Ids: KeySlotIds> KeyStore<Ids> {
    /// Removes every key from the store's global symmetric, private and signing slots.
    ///
    /// Only the three key backends are cleared; the security state version is left untouched.
    ///
    /// # Panics
    /// Panics if the internal lock is poisoned.
    pub fn clear(&self) {
        let mut keys = self.inner.write().expect("RwLock is poisoned");
        keys.symmetric_keys.clear();
        keys.private_keys.clear();
        keys.signing_keys.clear();
    }

    /// Replaces the security state version that contexts copy when they are created.
    ///
    /// # Panics
    /// Panics if the internal lock is poisoned.
    pub fn set_security_state_version(&self, version: u64) {
        let mut data = self.inner.write().expect("RwLock is poisoned");
        data.security_state_version = version;
    }

    /// Creates a read-only [`KeyStoreContext`] over the global keys, with empty local stores.
    ///
    /// The returned context owns the store's read guard until it is dropped. Operations that
    /// need the write lock ([`KeyStore::clear`], [`KeyStore::set_security_state_version`],
    /// [`KeyStore::context_mut`]) block until then. Requesting the write lock from the same
    /// thread while this context is alive may deadlock or panic.
    ///
    /// # Panics
    /// Panics if the internal lock is poisoned.
    pub fn context(&self) -> KeyStoreContext<'_, Ids> {
        let data = self.inner.read().expect("RwLock is poisoned");
        let security_state_version = data.security_state_version;
        KeyStoreContext {
            global_keys: GlobalKeys::ReadOnly(data),
            local_symmetric_keys: create_store(),
            local_private_keys: create_store(),
            local_signing_keys: create_store(),
            security_state_version,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Creates a read-write [`KeyStoreContext`] over the global keys, with empty local stores.
    ///
    /// The returned context owns the store's write guard until it is dropped. Every other
    /// access to this store blocks until then, including `context`, the `encrypt*`/`decrypt*`
    /// helpers and `clear`. Calling any of them from the same thread while this context is
    /// alive may deadlock or panic.
    ///
    /// # Panics
    /// Panics if the internal lock is poisoned.
    pub fn context_mut(&self) -> KeyStoreContext<'_, Ids> {
        let inner = self.inner.write().expect("RwLock is poisoned");
        let security_state_version = inner.security_state_version;
        KeyStoreContext {
            global_keys: GlobalKeys::ReadWrite(inner),
            local_symmetric_keys: create_store(),
            local_private_keys: create_store(),
            local_signing_keys: create_store(),
            security_state_version,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Decrypts a single item with the key it identifies, using a temporary read-only context.
    ///
    /// # Errors
    /// Returns the [`crate::CryptoError`] produced by the item's `Decryptable` implementation.
    pub fn decrypt<
        Key: KeySlotId,
        Data: Decryptable<Ids, Key, Output> + IdentifyKey<Key>,
        Output,
    >(
        &self,
        data: &Data,
    ) -> Result<Output, crate::CryptoError> {
        let key = data.key_identifier();
        data.decrypt(&mut self.context(), key)
    }

    /// Encrypts a single item with the key it identifies, using a temporary read-only context.
    ///
    /// # Errors
    /// Returns the [`crate::CryptoError`] produced by the item's `CompositeEncryptable`
    /// implementation.
    pub fn encrypt<
        Key: KeySlotId,
        Data: CompositeEncryptable<Ids, Key, Output> + IdentifyKey<Key>,
        Output,
    >(
        &self,
        data: Data,
    ) -> Result<Output, crate::CryptoError> {
        let key = data.key_identifier();
        data.encrypt_composite(&mut self.context(), key)
    }

    /// Decrypts every item of `data` in parallel. The output is in the same order as the input.
    ///
    /// # Errors
    /// If any item fails, an error is returned and all successful results are discarded. When
    /// several items fail, which error is reported depends on the parallel schedule.
    pub fn decrypt_list<
        Key: KeySlotId,
        Data: Decryptable<Ids, Key, Output> + IdentifyKey<Key> + Send + Sync,
        Output: Send + Sync,
    >(
        &self,
        data: &[Data],
    ) -> Result<Vec<Output>, crate::CryptoError> {
        self.par_map_with_context(data, |item, ctx| {
            item.decrypt(ctx, item.key_identifier())
        })
    }

    /// Decrypts every item of `data` in parallel without failing the whole batch.
    ///
    /// Returns `(decrypted, failed)`. `decrypted` holds the outputs of the items that decrypted
    /// successfully. `failed` holds references to the input items that could not be decrypted.
    /// Both preserve the relative input order. The individual error values are discarded.
    pub fn decrypt_list_with_failures<
        'a,
        Key: KeySlotId,
        Data: Decryptable<Ids, Key, Output> + IdentifyKey<Key> + Send + Sync + 'a,
        Output: Send + Sync,
    >(
        &self,
        data: &'a [Data],
    ) -> (Vec<Output>, Vec<&'a Data>) {
        data.par_chunks(batch_chunk_size(data.len()))
            .flat_map(|chunk| {
                // One context (one read guard) per chunk rather than per item.
                let mut ctx = self.context();
                chunk
                    .iter()
                    .map(|item| {
                        let result = item
                            .decrypt(&mut ctx, item.key_identifier())
                            .map_err(|_| item);
                        ctx.clear_local();
                        result
                    })
                    .collect::<Vec<_>>()
            })
            .partition_map(|result| match result {
                Ok(output) => Either::Left(output),
                Err(original_item) => Either::Right(original_item),
            })
    }

    /// Encrypts every item of `data` in parallel. The output is in the same order as the input.
    ///
    /// # Errors
    /// If any item fails, an error is returned and all successful results are discarded. When
    /// several items fail, which error is reported depends on the parallel schedule.
    pub fn encrypt_list<
        Key: KeySlotId,
        Data: CompositeEncryptable<Ids, Key, Output> + IdentifyKey<Key> + Send + Sync,
        Output: Send + Sync,
    >(
        &self,
        data: &[Data],
    ) -> Result<Vec<Output>, crate::CryptoError> {
        self.par_map_with_context(data, |item, ctx| {
            item.encrypt_composite(ctx, item.key_identifier())
        })
    }

    /// Shared driver for the fallible list operations.
    ///
    /// Splits `data` into chunks of `batch_chunk_size(data.len())` items and processes the
    /// chunks in parallel with rayon. Each chunk gets its own read-only context. `clear_local`
    /// is called after every item, presumably so that local keys created for one item do not
    /// accumulate across the chunk. Results are collected in input order.
    fn par_map_with_context<Data, Output, F>(
        &self,
        data: &[Data],
        process: F,
    ) -> Result<Vec<Output>, crate::CryptoError>
    where
        Data: Sync,
        Output: Send,
        F: Fn(&Data, &mut KeyStoreContext<'_, Ids>) -> Result<Output, crate::CryptoError> + Sync,
    {
        data.par_chunks(batch_chunk_size(data.len()))
            .map(|chunk| {
                let mut ctx = self.context();
                let mut results = Vec::with_capacity(chunk.len());
                for item in chunk {
                    results.push(process(item, &mut ctx));
                    ctx.clear_local();
                }
                results
            })
            .flatten()
            .collect()
    }
}
/// Chooses how many items each parallel batch handles.
///
/// The work is spread evenly across rayon's current thread count, but a batch never holds
/// fewer than `MINIMUM_CHUNK_SIZE` items. This also guarantees a non-zero size, even for an
/// empty input.
fn batch_chunk_size(len: usize) -> usize {
    const MINIMUM_CHUNK_SIZE: usize = 50;
    let thread_count = rayon::current_num_threads();
    len.div_ceil(thread_count).max(MINIMUM_CHUNK_SIZE)
}
#[cfg(test)]
pub(crate) mod tests {
    use crate::{
        EncString, PrimitiveEncryptable, SymmetricKeyAlgorithm,
        store::{KeyStore, KeyStoreContext},
        traits::tests::{TestIds, TestSymmKey},
    };

    /// Plaintext test item: a string plus the key slot used to protect it.
    pub struct DataView(pub String, pub TestSymmKey);
    /// Encrypted counterpart of [`DataView`].
    pub struct Data(pub EncString, pub TestSymmKey);

    impl crate::IdentifyKey<TestSymmKey> for DataView {
        fn key_identifier(&self) -> TestSymmKey {
            self.1
        }
    }
    impl crate::IdentifyKey<TestSymmKey> for Data {
        fn key_identifier(&self) -> TestSymmKey {
            self.1
        }
    }

    impl crate::CompositeEncryptable<TestIds, TestSymmKey, Data> for DataView {
        fn encrypt_composite(
            &self,
            ctx: &mut KeyStoreContext<TestIds>,
            key: TestSymmKey,
        ) -> Result<Data, crate::CryptoError> {
            Ok(Data(self.0.encrypt(ctx, key)?, key))
        }
    }

    impl crate::Decryptable<TestIds, TestSymmKey, DataView> for Data {
        fn decrypt(
            &self,
            ctx: &mut KeyStoreContext<TestIds>,
            key: TestSymmKey,
        ) -> Result<DataView, crate::CryptoError> {
            Ok(DataView(self.0.decrypt(ctx, key)?, key))
        }
    }

    /// Creates `count` random AES-256-CBC-HMAC keys persisted as `TestSymmKey::A(0..count)`.
    fn persist_test_keys(store: &KeyStore<TestIds>, count: u8) {
        for n in 0..count {
            let mut ctx = store.context_mut();
            let local_key_id = ctx.make_symmetric_key(SymmetricKeyAlgorithm::Aes256CbcHmac);
            ctx.persist_symmetric_key(local_key_id, TestSymmKey::A(n))
                .unwrap();
        }
    }

    #[test]
    fn test_multithread_decrypt_keeps_order() {
        let store: KeyStore<TestIds> = KeyStore::default();
        persist_test_keys(&store, 15);

        let data: Vec<_> = (0..300usize)
            .map(|n| DataView(format!("Test {n}"), TestSymmKey::A((n % 15) as u8)))
            .collect();

        let encrypted: Vec<_> = store.encrypt_list(&data).unwrap();
        let decrypted: Vec<_> = store.decrypt_list(&encrypted).unwrap();

        // `zip` silently truncates, so pin the lengths first or a dropped item would go unnoticed.
        assert_eq!(encrypted.len(), data.len());
        assert_eq!(decrypted.len(), data.len());
        for (orig, dec) in data.iter().zip(decrypted.iter()) {
            assert_eq!(orig.0, dec.0);
            assert_eq!(orig.1, dec.1);
        }
    }

    #[test]
    fn test_decrypt_list_with_failures_partitions_in_order() {
        let store: KeyStore<TestIds> = KeyStore::default();
        persist_test_keys(&store, 2);

        let data: Vec<_> = (0..100usize)
            .map(|n| DataView(format!("Test {n}"), TestSymmKey::A(0)))
            .collect();
        let encrypted: Vec<Data> = store.encrypt_list(&data).unwrap();

        // Re-label every odd item with a different (but existing) key so its decryption fails.
        let tampered: Vec<Data> = encrypted
            .into_iter()
            .enumerate()
            .map(|(i, Data(enc, key))| {
                let key = if i % 2 == 1 { TestSymmKey::A(1) } else { key };
                Data(enc, key)
            })
            .collect();

        let (decrypted, failed): (Vec<DataView>, Vec<&Data>) =
            store.decrypt_list_with_failures(&tampered);

        assert_eq!(decrypted.len(), 50);
        assert_eq!(failed.len(), 50);
        for (k, dec) in decrypted.iter().enumerate() {
            assert_eq!(dec.0, format!("Test {}", 2 * k));
            assert_eq!(dec.1, TestSymmKey::A(0));
        }
        for (k, item) in failed.iter().enumerate() {
            // Failures are references back into the input slice, in input order.
            assert!(std::ptr::eq(*item, &tampered[2 * k + 1]));
        }
    }
}