use std::collections::HashMap;
use std::sync::RwLock;
use crate::primitives::private_key::PrivateKey;
use crate::primitives::public_key::PublicKey;
use crate::primitives::symmetric_key::SymmetricKey;
use crate::wallet::error::WalletError;
use crate::wallet::key_deriver::KeyDeriver;
use crate::wallet::types::{Counterparty, Protocol};
/// Cache capacity used when `CachedKeyDeriver::new` receives `None` or `Some(0)`.
const DEFAULT_MAX_CACHE_SIZE: usize = 1_000;
/// A memoized derivation result; the variant records which derive method produced it.
enum CachedValue {
    /// Result of `derive_private_key`.
    Private(PrivateKey),
    /// Result of `derive_public_key`.
    Public(PublicKey),
    /// Raw bytes of a `derive_symmetric_key` result, turned back into a `SymmetricKey` on lookup.
    Symmetric(Vec<u8>),
}
/// Wraps a `KeyDeriver` and memoizes its derivations in a size-bounded map.
pub struct CachedKeyDeriver {
    /// Uncached deriver that performs the real derivation work.
    key_deriver: KeyDeriver,
    /// Memoized results, keyed by a string built from the method name and its arguments.
    cache: RwLock<HashMap<String, CachedValue>>,
    /// Entry count at which the cache is flushed before inserting a new entry.
    max_cache_size: usize,
}
impl CachedKeyDeriver {
    /// Creates a caching deriver whose root is `private_key`.
    ///
    /// `max_cache_size` is the number of entries at which the cache is flushed before the next
    /// insert; `None` or `Some(0)` selects `DEFAULT_MAX_CACHE_SIZE`.
    pub fn new(private_key: PrivateKey, max_cache_size: Option<usize>) -> Self {
        let max_cache_size = max_cache_size
            .filter(|&size| size > 0)
            .unwrap_or(DEFAULT_MAX_CACHE_SIZE);
        CachedKeyDeriver {
            key_deriver: KeyDeriver::new(private_key),
            cache: RwLock::new(HashMap::new()),
            max_cache_size,
        }
    }

    /// Returns the root private key held by the wrapped `KeyDeriver`.
    pub fn root_key(&self) -> &PrivateKey {
        self.key_deriver.root_key()
    }

    /// Returns the identity public key from the wrapped `KeyDeriver` (not cached).
    pub fn identity_key(&self) -> PublicKey {
        self.key_deriver.identity_key()
    }

    /// Returns the identity key as hex, delegating to the wrapped `KeyDeriver` (not cached).
    pub fn identity_key_hex(&self) -> String {
        self.key_deriver.identity_key_hex()
    }

    /// Derives a private key, serving repeated requests from the cache.
    ///
    /// # Errors
    /// Propagates any error from `KeyDeriver::derive_private_key`; failures are not cached.
    pub fn derive_private_key(
        &self,
        protocol: &Protocol,
        key_id: &str,
        counterparty: &Counterparty,
    ) -> Result<PrivateKey, WalletError> {
        let cache_key =
            Self::make_cache_key("derivePrivateKey", protocol, key_id, counterparty, false);
        {
            let cache = self.read_cache();
            if let Some(CachedValue::Private(pk)) = cache.get(&cache_key) {
                return Ok(pk.clone());
            }
        } // read guard dropped here so cache_set can take the write lock
        let result = self
            .key_deriver
            .derive_private_key(protocol, key_id, counterparty)?;
        self.cache_set(cache_key, CachedValue::Private(result.clone()));
        Ok(result)
    }

    /// Derives a public key, serving repeated requests from the cache.
    ///
    /// `for_self` is part of the cache key, so the two variants are cached independently.
    ///
    /// # Errors
    /// Propagates any error from `KeyDeriver::derive_public_key`; failures are not cached.
    pub fn derive_public_key(
        &self,
        protocol: &Protocol,
        key_id: &str,
        counterparty: &Counterparty,
        for_self: bool,
    ) -> Result<PublicKey, WalletError> {
        let cache_key =
            Self::make_cache_key("derivePublicKey", protocol, key_id, counterparty, for_self);
        {
            let cache = self.read_cache();
            if let Some(CachedValue::Public(pk)) = cache.get(&cache_key) {
                return Ok(pk.clone());
            }
        }
        let result = self
            .key_deriver
            .derive_public_key(protocol, key_id, counterparty, for_self)?;
        self.cache_set(cache_key, CachedValue::Public(result.clone()));
        Ok(result)
    }

    /// Derives a symmetric key, serving repeated requests from the cache.
    ///
    /// The cache stores the key's raw bytes, and the returned key is always rebuilt from them,
    /// so a cache miss and a later hit yield the same value.
    ///
    /// # Errors
    /// Propagates any error from `KeyDeriver::derive_symmetric_key` or from rebuilding the key
    /// with `SymmetricKey::from_bytes`; failures are not cached.
    pub fn derive_symmetric_key(
        &self,
        protocol: &Protocol,
        key_id: &str,
        counterparty: &Counterparty,
    ) -> Result<SymmetricKey, WalletError> {
        let cache_key =
            Self::make_cache_key("deriveSymmetricKey", protocol, key_id, counterparty, false);
        {
            let cache = self.read_cache();
            if let Some(CachedValue::Symmetric(bytes)) = cache.get(&cache_key) {
                return SymmetricKey::from_bytes(bytes).map_err(WalletError::from);
            }
        }
        let derived = self
            .key_deriver
            .derive_symmetric_key(protocol, key_id, counterparty)?;
        let bytes = derived.to_bytes();
        // Rebuild from the bytes instead of re-running the derivation: the previous version
        // derived the key twice on every miss, which defeated the purpose of the cache.
        let key = SymmetricKey::from_bytes(&bytes).map_err(WalletError::from)?;
        self.cache_set(cache_key, CachedValue::Symmetric(bytes));
        Ok(key)
    }

    /// Number of entries currently cached (test-only introspection).
    #[cfg(test)]
    pub(crate) fn cache_len(&self) -> usize {
        self.read_cache().len()
    }

    /// Builds the cache key for one derive call.
    ///
    /// The string fields are written with `Debug` formatting (quoted and escaped), so a `:`
    /// inside a protocol name or key ID cannot make two different argument sets produce the
    /// same key. A plain `:` join would map ("a:b", "c") and ("a", "b:c") to the same entry.
    /// The counterparty type is always included, with the public key's DER hex appended when
    /// one is present.
    fn make_cache_key(
        method: &str,
        protocol: &Protocol,
        key_id: &str,
        counterparty: &Counterparty,
        for_self: bool,
    ) -> String {
        let counterparty_part = match &counterparty.public_key {
            Some(pk) => format!("{:?}/{}", counterparty.counterparty_type, pk.to_der_hex()),
            None => format!("{:?}", counterparty.counterparty_type),
        };
        format!(
            "{}:{}:{:?}:{:?}:{}:{}",
            method, protocol.security_level, protocol.protocol, key_id, counterparty_part, for_self
        )
    }

    /// Inserts `value` under `key`, flushing the whole cache first if it is full.
    ///
    /// Clearing everything keeps the size bound without any LRU bookkeeping. The flush is
    /// skipped when `key` is already present, for example after two threads raced on the same
    /// miss, because overwriting an entry does not grow the map.
    fn cache_set(&self, key: String, value: CachedValue) {
        let mut cache = self.write_cache();
        if cache.len() >= self.max_cache_size && !cache.contains_key(&key) {
            cache.clear();
        }
        cache.insert(key, value);
    }

    /// Acquires the cache for reading, recovering the guard if the lock was poisoned.
    ///
    /// The map only holds memoized values that can always be recomputed. A panic in another
    /// thread should therefore not make every later derivation panic.
    fn read_cache(&self) -> std::sync::RwLockReadGuard<'_, HashMap<String, CachedValue>> {
        self.cache
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Acquires the cache for writing, recovering the guard if the lock was poisoned.
    fn write_cache(&self) -> std::sync::RwLockWriteGuard<'_, HashMap<String, CachedValue>> {
        self.cache
            .write()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::wallet::types::CounterpartyType;

    /// Counterparty of type `Self_` with no explicit public key, shared by the tests below.
    fn self_counterparty() -> Counterparty {
        Counterparty {
            counterparty_type: CounterpartyType::Self_,
            public_key: None,
        }
    }

    #[test]
    fn test_cached_matches_uncached() {
        let priv_key = PrivateKey::from_hex("abcd").unwrap();
        let priv_key2 = PrivateKey::from_hex("abcd").unwrap();
        let kd = KeyDeriver::new(priv_key);
        let ckd = CachedKeyDeriver::new(priv_key2, None);
        let protocol = Protocol {
            security_level: 2,
            protocol: "test caching".to_string(),
        };
        let counterparty = self_counterparty();
        let pk_uncached = kd
            .derive_private_key(&protocol, "1", &counterparty)
            .unwrap();
        let pk_cached = ckd
            .derive_private_key(&protocol, "1", &counterparty)
            .unwrap();
        assert_eq!(pk_uncached.to_hex(), pk_cached.to_hex());
        // Second call is served from the cache and must still match.
        let pk_cached2 = ckd
            .derive_private_key(&protocol, "1", &counterparty)
            .unwrap();
        assert_eq!(pk_uncached.to_hex(), pk_cached2.to_hex());
        let pub_uncached = kd
            .derive_public_key(&protocol, "1", &counterparty, true)
            .unwrap();
        let pub_cached = ckd
            .derive_public_key(&protocol, "1", &counterparty, true)
            .unwrap();
        assert_eq!(pub_uncached.to_der_hex(), pub_cached.to_der_hex());
    }

    #[test]
    fn test_symmetric_key_cached_matches_uncached() {
        let kd = KeyDeriver::new(PrivateKey::from_hex("abcd").unwrap());
        let ckd = CachedKeyDeriver::new(PrivateKey::from_hex("abcd").unwrap(), None);
        let protocol = Protocol {
            security_level: 2,
            protocol: "test caching".to_string(),
        };
        let counterparty = self_counterparty();
        let expected = kd
            .derive_symmetric_key(&protocol, "1", &counterparty)
            .unwrap()
            .to_bytes();
        // First call populates the cache (miss path).
        let miss = ckd
            .derive_symmetric_key(&protocol, "1", &counterparty)
            .unwrap()
            .to_bytes();
        assert_eq!(expected, miss);
        // Second call rebuilds the key from cached bytes (hit path).
        let hit = ckd
            .derive_symmetric_key(&protocol, "1", &counterparty)
            .unwrap()
            .to_bytes();
        assert_eq!(expected, hit);
    }

    #[test]
    fn test_distinct_key_ids_cached_separately() {
        let kd = KeyDeriver::new(PrivateKey::from_hex("abcd").unwrap());
        let ckd = CachedKeyDeriver::new(PrivateKey::from_hex("abcd").unwrap(), None);
        let protocol = Protocol {
            security_level: 2,
            protocol: "test caching".to_string(),
        };
        let counterparty = self_counterparty();
        let cached_one = ckd
            .derive_private_key(&protocol, "1", &counterparty)
            .unwrap();
        let cached_two = ckd
            .derive_private_key(&protocol, "2", &counterparty)
            .unwrap();
        assert_ne!(cached_one.to_hex(), cached_two.to_hex());
        let uncached_two = kd
            .derive_private_key(&protocol, "2", &counterparty)
            .unwrap();
        assert_eq!(uncached_two.to_hex(), cached_two.to_hex());
        assert_eq!(ckd.cache_len(), 2);
    }

    #[test]
    fn test_cache_eviction() {
        let priv_key = PrivateKey::from_hex("abcd").unwrap();
        let ckd = CachedKeyDeriver::new(priv_key, Some(2));
        let protocol = Protocol {
            security_level: 0,
            protocol: "evict test".to_string(),
        };
        let counterparty = self_counterparty();
        let _ = ckd
            .derive_private_key(&protocol, "1", &counterparty)
            .unwrap();
        let _ = ckd
            .derive_private_key(&protocol, "2", &counterparty)
            .unwrap();
        // Third distinct entry hits the limit of 2, so the cache is flushed before inserting.
        let _ = ckd
            .derive_private_key(&protocol, "3", &counterparty)
            .unwrap();
        assert_eq!(ckd.cache_len(), 1);
    }

    #[test]
    fn test_identity_key_delegates() {
        let priv_key = PrivateKey::from_hex("ff").unwrap();
        let priv_key2 = PrivateKey::from_hex("ff").unwrap();
        let kd = KeyDeriver::new(priv_key);
        let ckd = CachedKeyDeriver::new(priv_key2, None);
        assert_eq!(kd.identity_key_hex(), ckd.identity_key_hex());
    }

    #[test]
    fn test_root_key_accessor() {
        let priv_key = PrivateKey::from_hex("abcd").unwrap();
        let expected_hex = priv_key.to_hex();
        let ckd = CachedKeyDeriver::new(priv_key, None);
        assert_eq!(ckd.root_key().to_hex(), expected_hex);
    }
}