use crate::session::policy_engine::PolicyEngine;
use crate::session::{SessionError, RUNTIME};
use antimatter_api::models::{CapsuleOpenRequest, CapsuleOpenResponse};
use chrono::{DateTime, TimeDelta, Utc};
use digest::Digest;
use lru::LruCache;
use sha2::Sha256;
use std::num::NonZeroUsize;
use std::sync::{Arc, Mutex};
/// How long a policy engine stays reusable after it is built and inserted into the engine cache.
const POLICY_CACHE_TTL: TimeDelta = TimeDelta::minutes(10);
/// Builds a fresh policy engine from the policy assembly carried by a capsule-open response.
///
/// `PolicyEngine::new` is driven to completion synchronously on the shared `RUNTIME`, and the
/// resulting engine is wrapped in `Arc<Mutex<_>>` so the same instance can be stored in several
/// cache entries and handed back to callers.
///
/// # Errors
/// Propagates (via `?`) any error produced while constructing the engine.
fn make_policy_engine(rsp: CapsuleOpenResponse) -> Result<Arc<Mutex<PolicyEngine>>, SessionError> {
    let assembly = &rsp.read_context_configuration.policy_assembly;
    let engine = RUNTIME.block_on(PolicyEngine::new(assembly))?;
    let shared = Arc::new(Mutex::new(engine));
    Ok(shared)
}
/// Key of the read cache: one entry per (domain, capsule, read context) combination.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct ReadCacheKey {
    /// Domain passed to the open call.
    domain: String,
    /// Identifier of the capsule being opened.
    capsule_id: String,
    /// Read context the capsule is opened under.
    read_context: String,
}
/// A cached capsule-open response together with the policy engine associated with it.
#[derive(Clone)]
struct ReadCacheValue {
    /// The entry is served only while this instant is still in the future.
    expires: DateTime<Utc>,
    /// Response returned by the open call.
    resp: CapsuleOpenResponse,
    /// Policy engine for `resp`; cloning the value shares (does not copy) the engine.
    engine: Arc<Mutex<PolicyEngine>>,
}
/// A cached policy engine and the instant after which it is no longer reused.
#[derive(Clone)]
struct EngineCacheValue {
    /// The engine is reused only while this instant is still in the future.
    expires: DateTime<Utc>,
    /// Shared policy engine; cloning the value shares (does not copy) it.
    engine: Arc<Mutex<PolicyEngine>>,
}
/// Two-level LRU cache used when opening capsules.
///
/// * The read cache maps (domain, capsule id, read context) to a previously returned open
///   response and its policy engine.
/// * The engine cache maps the SHA-256 digest of a policy assembly to a built policy engine, so
///   responses carrying the same policy can share one engine.
///
/// Constructing either cache with a size of `0` leaves it absent and its `*_enabled` flag false.
#[derive(Debug, Clone)]
pub struct ReadCache {
    /// LRU of open responses; `None` when read caching is disabled.
    read_cache: Option<LruCache<ReadCacheKey, ReadCacheValue>>,
    /// LRU of policy engines keyed by policy-assembly digest; `None` when disabled.
    engine_cache: Option<LruCache<Vec<u8>, EngineCacheValue>>,
    /// Whether the read cache is in use.
    read_enabled: bool,
    /// Whether the engine cache is in use.
    engine_enabled: bool,
    /// Capacity requested for the read cache.
    read_cache_size: usize,
    /// Capacity requested for the engine cache.
    engine_cache_size: usize,
}
impl ReadCache {
    /// Creates a cache with the given capacities.
    ///
    /// `read_cache_size` bounds the number of cached open responses and `engine_cache_size`
    /// bounds the number of cached policy engines. A size of `0` disables that cache.
    pub fn new(read_cache_size: usize, engine_cache_size: usize) -> Self {
        let read_cache = NonZeroUsize::new(read_cache_size).map(LruCache::new);
        let engine_cache = NonZeroUsize::new(engine_cache_size).map(LruCache::new);
        let read_enabled = read_cache.is_some();
        let engine_enabled = engine_cache.is_some();
        Self {
            read_cache,
            engine_cache,
            read_enabled,
            engine_enabled,
            read_cache_size,
            engine_cache_size,
        }
    }

    /// Opens a capsule, serving the response and policy engine from cache when possible.
    ///
    /// If the read cache holds a fresh entry for `(domain, capsule_id, read_context)`, that
    /// entry is returned without calling `open_capsule`. Otherwise `open_capsule` is invoked.
    /// The policy engine is then taken from the engine cache or built with
    /// `make_policy_engine`.
    ///
    /// # Returns
    /// `Ok(None)` when `open_capsule` fails with `SessionError::Status401`; otherwise the
    /// response together with its policy engine.
    ///
    /// # Errors
    /// Any other error from `open_capsule`, or an error raised while building the engine.
    pub fn open_capsule(
        &mut self,
        domain: &str,
        capsule_id: &str,
        read_context: &str,
        req: CapsuleOpenRequest,
        open_capsule: impl FnMut(
            &str,
            &str,
            Option<&str>,
            CapsuleOpenRequest,
        ) -> Result<CapsuleOpenResponse, SessionError>,
    ) -> Result<Option<(CapsuleOpenResponse, Arc<Mutex<PolicyEngine>>)>, SessionError> {
        self.open_capsule_internal(
            domain,
            capsule_id,
            read_context,
            req,
            &make_policy_engine,
            open_capsule,
        )
    }

    /// Core of [`ReadCache::open_capsule`] with an injectable engine factory (used by tests).
    fn open_capsule_internal(
        &mut self,
        domain: &str,
        capsule_id: &str,
        read_context: &str,
        req: CapsuleOpenRequest,
        make_engine: &dyn Fn(CapsuleOpenResponse) -> Result<Arc<Mutex<PolicyEngine>>, SessionError>,
        mut open_capsule: impl FnMut(
            &str,
            &str,
            Option<&str>,
            CapsuleOpenRequest,
        ) -> Result<CapsuleOpenResponse, SessionError>,
    ) -> Result<Option<(CapsuleOpenResponse, Arc<Mutex<PolicyEngine>>)>, SessionError> {
        let key = ReadCacheKey {
            domain: domain.to_string(),
            capsule_id: capsule_id.to_string(),
            read_context: read_context.to_string(),
        };

        if let Some(cache) = self.read_cache_mut() {
            if let Some(value) = cache.get(&key) {
                if value.expires > Utc::now() {
                    return Ok(Some((value.resp.clone(), value.engine.clone())));
                }
                // Stale entry: drop it so it no longer occupies an LRU slot.
                cache.pop(&key);
            }
        }

        let open_resp = match open_capsule(capsule_id, read_context, Some(domain), req) {
            Ok(resp) => resp,
            // A 401 from the open call is reported as "no result" rather than as an error.
            Err(SessionError::Status401(_)) => return Ok(None),
            Err(e) => return Err(e),
        };

        let engine = self.policy_engine_for(&open_resp, make_engine)?;

        if let Some(cache) = self.read_cache_mut() {
            // Only insert entries that will still be fresh. An out-of-range TTL would otherwise
            // panic on overflow, and a non-positive TTL would insert an already-stale entry that
            // can evict a live one.
            let now = Utc::now();
            let ttl =
                TimeDelta::try_seconds(open_resp.read_context_configuration.key_cache_ttl.into());
            let expires = ttl.and_then(|ttl| now.checked_add_signed(ttl));
            if let Some(expires) = expires.filter(|expires| *expires > now) {
                cache.push(
                    key,
                    ReadCacheValue {
                        resp: open_resp.clone(),
                        engine: engine.clone(),
                        expires,
                    },
                );
            }
        }

        Ok(Some((open_resp, engine)))
    }

    /// Returns a policy engine for `open_resp`.
    ///
    /// When the engine cache is enabled, a fresh entry keyed by the SHA-256 digest of the
    /// response's policy assembly is reused. On a miss (or a stale entry), a new engine is built
    /// with `make_engine` and cached for `POLICY_CACHE_TTL`. When the engine cache is disabled,
    /// a new engine is built every time.
    fn policy_engine_for(
        &mut self,
        open_resp: &CapsuleOpenResponse,
        make_engine: &dyn Fn(CapsuleOpenResponse) -> Result<Arc<Mutex<PolicyEngine>>, SessionError>,
    ) -> Result<Arc<Mutex<PolicyEngine>>, SessionError> {
        let cache = match self.engine_cache_mut() {
            Some(cache) => cache,
            None => return make_engine(open_resp.clone()),
        };

        let mut hasher = Sha256::new();
        hasher.update(&open_resp.read_context_configuration.policy_assembly);
        let hash = hasher.finalize().to_vec();

        if let Some(value) = cache.get(&hash) {
            if value.expires > Utc::now() {
                return Ok(value.engine.clone());
            }
            cache.pop(&hash);
        }

        let engine = make_engine(open_resp.clone())?;
        cache.push(
            hash,
            EngineCacheValue {
                engine: engine.clone(),
                expires: Utc::now() + POLICY_CACHE_TTL,
            },
        );
        Ok(engine)
    }

    /// The read cache, if read caching is enabled.
    fn read_cache_mut(&mut self) -> Option<&mut LruCache<ReadCacheKey, ReadCacheValue>> {
        if self.read_enabled {
            self.read_cache.as_mut()
        } else {
            None
        }
    }

    /// The engine cache, if engine caching is enabled.
    fn engine_cache_mut(&mut self) -> Option<&mut LruCache<Vec<u8>, EngineCacheValue>> {
        if self.engine_enabled {
            self.engine_cache.as_mut()
        } else {
            None
        }
    }

    /// Capacity requested for the read cache (`0` means disabled).
    pub fn read_cache_size(&self) -> usize {
        self.read_cache_size
    }

    /// Capacity requested for the engine cache (`0` means disabled).
    pub fn engine_cache_size(&self) -> usize {
        self.engine_cache_size
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::session::policy_engine::PolicyEngine;
    use antimatter_api::models::{CapsuleOpenRequest, CapsuleOpenResponse};
    use std::sync::{Arc, Mutex};

    /// Builds a real policy engine from the `allow_all.wasm` fixture, ignoring the response.
    fn mock_make_policy_engine(
        _rsp: CapsuleOpenResponse,
    ) -> Result<Arc<Mutex<PolicyEngine>>, SessionError> {
        let mut path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        path.push("static/fixtures/allow_all.wasm");
        let wasm_bytes = std::fs::read(path).expect("unable to read WASM file");
        Ok(Arc::new(Mutex::new(
            RUNTIME.block_on(PolicyEngine::new(&wasm_bytes))?,
        )))
    }

    /// Engine factory that must never be reached; used to prove a cache hit.
    fn mock_make_policy_engine_unused(
        _rsp: CapsuleOpenResponse,
    ) -> Result<Arc<Mutex<PolicyEngine>>, SessionError> {
        panic!("should not be invoked unless cache is not behaving")
    }

    /// A minimal open response with a default read-context configuration.
    fn stub_open_response() -> CapsuleOpenResponse {
        CapsuleOpenResponse {
            decryption_key: vec![],
            read_context_configuration: Box::new(Default::default()),
            open_token: "".to_string(),
            capsule_tags: vec![],
        }
    }

    #[test]
    fn test_policy_cache_used_when_enabled() {
        let mut read_cache = ReadCache::new(0, 10);
        assert!(read_cache.engine_cache.is_some());
        assert!(read_cache.read_cache.is_none());
        let domain = "test_domain";
        let capsule_id = "test_capsule_id";
        let read_context = "test_read_context";
        let req = CapsuleOpenRequest::default();
        assert_eq!(read_cache.engine_cache.as_ref().map(|c| c.len()), Some(0));

        let result = read_cache.open_capsule_internal(
            domain,
            capsule_id,
            read_context,
            req.clone(),
            &mock_make_policy_engine,
            |_capsule_id, _read_context, _domain_id, _open_req| Ok(stub_open_response()),
        );
        assert!(result.is_ok());
        assert_eq!(read_cache.engine_cache.as_ref().map(|c| c.len()), Some(1));

        // Same policy assembly: the engine must come from the cache, not the factory.
        let result = read_cache.open_capsule_internal(
            domain,
            capsule_id,
            read_context,
            req.clone(),
            &mock_make_policy_engine_unused,
            |_capsule_id, _read_context, _domain_id, _open_req| Ok(stub_open_response()),
        );
        assert!(result.is_ok());
        assert_eq!(read_cache.engine_cache.as_ref().map(|c| c.len()), Some(1));
    }

    #[test]
    fn test_read_cache_used_when_enabled() {
        let mut read_cache = ReadCache::new(10, 0);
        assert!(read_cache.read_cache.is_some());
        assert!(read_cache.engine_cache.is_none());
        let domain = "test_domain";
        let capsule_id = "test_capsule_id";
        let read_context = "test_read_context";
        let req = CapsuleOpenRequest::default();

        let mut resp = stub_open_response();
        resp.read_context_configuration.key_cache_ttl = 60;

        let first = read_cache.open_capsule_internal(
            domain,
            capsule_id,
            read_context,
            req.clone(),
            &mock_make_policy_engine,
            move |_capsule_id, _read_context, _domain_id, _open_req| Ok(resp.clone()),
        );
        assert!(matches!(first, Ok(Some(_))));
        assert_eq!(read_cache.read_cache.as_ref().map(|c| c.len()), Some(1));

        // Fresh read-cache hit: neither the open call nor the engine factory may run.
        let second = read_cache.open_capsule_internal(
            domain,
            capsule_id,
            read_context,
            req.clone(),
            &mock_make_policy_engine_unused,
            |_capsule_id, _read_context, _domain_id, _open_req| {
                panic!("open_capsule should not be called on a read-cache hit")
            },
        );
        assert!(matches!(second, Ok(Some(_))));
        assert_eq!(read_cache.read_cache.as_ref().map(|c| c.len()), Some(1));
    }
}