antimatter 2.0.13

antimatter.io Rust library for data control
Documentation
use crate::session::policy_engine::PolicyEngine;
use crate::session::{SessionError, RUNTIME};
use antimatter_api::models::{CapsuleOpenRequest, CapsuleOpenResponse};
use chrono::{DateTime, TimeDelta, Utc};
use digest::Digest;
use lru::LruCache;
use sha2::Sha256;
use std::num::NonZeroUsize;
use std::sync::{Arc, Mutex};

/// How long a newly built policy engine stays valid in the engine cache
/// before it must be rebuilt.
const POLICY_CACHE_TTL: TimeDelta = TimeDelta::minutes(10);

/// Builds a shareable policy engine from the policy assembly carried in
/// `rsp`'s read-context configuration, driving the async constructor to
/// completion on the shared `RUNTIME`.
fn make_policy_engine(rsp: CapsuleOpenResponse) -> Result<Arc<Mutex<PolicyEngine>>, SessionError> {
    let assembly = &rsp.read_context_configuration.policy_assembly;
    let engine = RUNTIME.block_on(PolicyEngine::new(assembly))?;
    Ok(Arc::new(Mutex::new(engine)))
}

/// Key for the read cache: a cached open response is only reused for the
/// exact same (domain, capsule id, read context) triple.
#[derive(Hash, Eq, PartialEq, Clone, Debug)]
struct ReadCacheKey {
    domain: String,
    capsule_id: String,
    read_context: String,
}

/// A cached open-capsule response together with the policy engine that was
/// resolved for it.
#[derive(Clone)]
struct ReadCacheValue {
    resp: CapsuleOpenResponse,
    engine: Arc<Mutex<PolicyEngine>>,
    /// The entry is usable only while this instant is in the future; it is
    /// derived from the response's `key_cache_ttl` (in seconds) at insertion.
    expires: DateTime<Utc>,
}

/// A cached policy engine together with its expiry time.
#[derive(Clone)]
struct EngineCacheValue {
    engine: Arc<Mutex<PolicyEngine>>,
    /// The entry is usable only while this instant is in the future; it is set
    /// to the build time plus `POLICY_CACHE_TTL`.
    expires: DateTime<Utc>,
}

/// Two optional LRU caches used when reading capsules: one for open-capsule
/// responses and one for the policy engines built from them.
#[derive(Clone, Debug)]
pub struct ReadCache {
    /// Open responses keyed by (domain, capsule id, read context); `None` when
    /// constructed with a read cache size of 0.
    read_cache: Option<LruCache<ReadCacheKey, ReadCacheValue>>,
    /// Policy engines keyed by the SHA-256 digest of their policy assembly;
    /// `None` when constructed with an engine cache size of 0.
    engine_cache: Option<LruCache<Vec<u8>, EngineCacheValue>>,
    /// Set by `new` to whether `read_cache` is `Some`.
    read_enabled: bool,
    /// Set by `new` to whether `engine_cache` is `Some`.
    engine_enabled: bool,
    /// Read cache capacity exactly as passed to `new` (0 = disabled).
    read_cache_size: usize,
    /// Engine cache capacity exactly as passed to `new` (0 = disabled).
    engine_cache_size: usize,
}

/// ReadCache is a cache for OpenCapsule responses to be used when reading
/// capsules. For a given (domain, capsule id, read context), we cache the
/// CapsuleOpenResponse for the duration specified in the CapsuleOpenResponse.
/// This allows for potentially much faster decryption when there is a series
/// of capsules which were created with key reuse enabled.
///
/// Independently, policy engines are cached by the SHA-256 digest of their
/// policy assembly for `POLICY_CACHE_TTL`, so responses that carry the same
/// policy assembly share one engine instance.
impl ReadCache {
    /// Creates a new `ReadCache`.
    ///
    /// `read_cache_size` bounds the number of cached open responses and
    /// `engine_cache_size` the number of cached policy engines (both LRU).
    /// Pass a size of 0 to disable the corresponding cache.
    pub fn new(read_cache_size: usize, engine_cache_size: usize) -> Self {
        // `NonZeroUsize::new` yields `None` for 0, which doubles as "disabled".
        let read_cache = NonZeroUsize::new(read_cache_size).map(LruCache::new);
        let engine_cache = NonZeroUsize::new(engine_cache_size).map(LruCache::new);

        Self {
            read_enabled: read_cache.is_some(),
            engine_enabled: engine_cache.is_some(),
            read_cache,
            engine_cache,
            read_cache_size,
            engine_cache_size,
        }
    }

    /// open_capsule either returns the cached CapsuleOpenResponse, if
    /// the cache is enabled and one exists and is valid for the argument
    /// (domain, capsule_id, read_context), or makes the request to the
    /// server and updates the cache accordingly if not.
    ///
    /// Returns `Ok(None)` when `open_capsule` fails with
    /// `SessionError::Status401`.
    ///
    /// # Errors
    /// Any other error from `open_capsule`, and any error raised while building
    /// the policy engine.
    pub fn open_capsule(
        &mut self,
        domain: &str,
        capsule_id: &str,
        read_context: &str,
        req: CapsuleOpenRequest,
        open_capsule: impl FnMut(
            &str,
            &str,
            Option<&str>,
            CapsuleOpenRequest,
        ) -> Result<CapsuleOpenResponse, SessionError>,
    ) -> Result<Option<(CapsuleOpenResponse, Arc<Mutex<PolicyEngine>>)>, SessionError> {
        self.open_capsule_internal(
            domain,
            capsule_id,
            read_context,
            req,
            &make_policy_engine,
            open_capsule,
        )
    }

    /// Implementation of [`ReadCache::open_capsule`] with an injectable policy
    /// engine builder (`make_engine`), so tests can observe cache behavior.
    ///
    /// 1. If the read cache is enabled and holds an unexpired entry for
    ///    (domain, capsule_id, read_context), that entry is returned and neither
    ///    `open_capsule` nor `make_engine` is called. An expired entry is evicted.
    /// 2. Otherwise `open_capsule` is called. A `Status401` error yields
    ///    `Ok(None)`; any other error is returned.
    /// 3. The policy engine is resolved via `policy_engine_for`. When the read
    ///    cache is enabled and the response's `key_cache_ttl` gives a
    ///    representable expiry in the future, the response and engine are cached.
    ///
    /// NOTE(review): `req` is not part of the cache key, so requests for the same
    /// key are assumed interchangeable — confirm with callers.
    fn open_capsule_internal(
        &mut self,
        domain: &str,
        capsule_id: &str,
        read_context: &str,
        req: CapsuleOpenRequest,
        make_engine: &dyn Fn(CapsuleOpenResponse) -> Result<Arc<Mutex<PolicyEngine>>, SessionError>,
        mut open_capsule: impl FnMut(
            &str,
            &str,
            Option<&str>,
            CapsuleOpenRequest,
        ) -> Result<CapsuleOpenResponse, SessionError>,
    ) -> Result<Option<(CapsuleOpenResponse, Arc<Mutex<PolicyEngine>>)>, SessionError> {
        // Only allocate the (three-String) key when the read cache can use it.
        let key = self.read_enabled.then(|| ReadCacheKey {
            domain: domain.to_string(),
            capsule_id: capsule_id.to_string(),
            read_context: read_context.to_string(),
        });

        if let Some(key) = &key {
            let cache = self.read_cache.as_mut().expect("null cache");
            if let Some(value) = cache.get(key) {
                if value.expires > Utc::now() {
                    return Ok(Some((value.resp.clone(), value.engine.clone())));
                }
                cache.pop(key);
            }
        }

        let open_resp = match open_capsule(capsule_id, read_context, Some(domain), req) {
            Ok(resp) => resp,
            // A 401 from the server is reported as "no response", not as an error.
            Err(SessionError::Status401(_)) => return Ok(None),
            Err(e) => return Err(e),
        };

        let engine = self.policy_engine_for(&open_resp, make_engine)?;

        if let Some(key) = key {
            let ttl_secs: i64 = open_resp.read_context_configuration.key_cache_ttl.into();
            let now = Utc::now();
            // The TTL is server-supplied: skip caching rather than panic when it is
            // out of range, and don't evict a live LRU entry to store one that is
            // already expired (TTL <= 0).
            let expires = TimeDelta::try_seconds(ttl_secs)
                .and_then(|ttl| now.checked_add_signed(ttl))
                .filter(|expires| *expires > now);
            if let Some(expires) = expires {
                self.read_cache.as_mut().expect("null cache").push(
                    key,
                    ReadCacheValue {
                        resp: open_resp.clone(),
                        engine: engine.clone(),
                        expires,
                    },
                );
            }
        }

        Ok(Some((open_resp, engine)))
    }

    /// Returns a policy engine for `resp`'s policy assembly.
    ///
    /// With the engine cache disabled, a new engine is always built. Otherwise the
    /// SHA-256 digest of the assembly is the cache key: an unexpired cached engine
    /// is reused, an expired one is evicted, and on a miss a freshly built engine
    /// is cached for `POLICY_CACHE_TTL`. Errors from `make_engine` are returned
    /// and nothing is cached.
    fn policy_engine_for(
        &mut self,
        resp: &CapsuleOpenResponse,
        make_engine: &dyn Fn(CapsuleOpenResponse) -> Result<Arc<Mutex<PolicyEngine>>, SessionError>,
    ) -> Result<Arc<Mutex<PolicyEngine>>, SessionError> {
        if !self.engine_enabled {
            return make_engine(resp.clone());
        }

        let mut hasher = Sha256::new();
        hasher.update(&resp.read_context_configuration.policy_assembly);
        let hash = hasher.finalize().to_vec();

        let cache = self.engine_cache.as_mut().expect("null cache");
        if let Some(value) = cache.get(&hash) {
            if value.expires > Utc::now() {
                return Ok(value.engine.clone());
            }
            cache.pop(&hash);
        }

        let engine = make_engine(resp.clone())?;
        cache.push(
            hash,
            EngineCacheValue {
                engine: engine.clone(),
                expires: Utc::now() + POLICY_CACHE_TTL,
            },
        );
        Ok(engine)
    }

    /// Returns the read cache capacity passed to [`ReadCache::new`] (0 = disabled).
    pub fn read_cache_size(&self) -> usize {
        self.read_cache_size
    }

    /// Returns the engine cache capacity passed to [`ReadCache::new`] (0 = disabled).
    pub fn engine_cache_size(&self) -> usize {
        self.engine_cache_size
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::session::policy_engine::PolicyEngine;
    use antimatter_api::models::{CapsuleOpenRequest, CapsuleOpenResponse};
    use std::sync::{Arc, Mutex};

    /// Builds a policy engine from the `allow_all.wasm` fixture, ignoring the
    /// response it is given.
    fn mock_make_policy_engine(
        _rsp: CapsuleOpenResponse,
    ) -> Result<Arc<Mutex<PolicyEngine>>, SessionError> {
        let mut path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        path.push("static/fixtures/allow_all.wasm");
        let wasm_bytes = std::fs::read(path).expect("unable to read WASM file");
        Ok(Arc::new(Mutex::new(
            RUNTIME.block_on(PolicyEngine::new(&wasm_bytes))?,
        )))
    }

    /// Engine builder that fails the test if invoked; used to prove a cache hit.
    fn mock_make_policy_engine_unused(
        _rsp: CapsuleOpenResponse,
    ) -> Result<Arc<Mutex<PolicyEngine>>, SessionError> {
        panic!("should not be invoked unless cache is not behaving")
    }

    /// A minimal open response with a default read-context configuration.
    fn empty_open_response() -> CapsuleOpenResponse {
        CapsuleOpenResponse {
            decryption_key: vec![],
            read_context_configuration: Box::new(Default::default()),
            open_token: "".to_string(),
            capsule_tags: vec![],
        }
    }

    #[test]
    fn test_policy_cache_used_when_enabled() {
        let mut read_cache = ReadCache::new(0, 10);
        assert!(read_cache.engine_cache.is_some());
        assert!(read_cache.read_cache.is_none());

        let domain = "test_domain";
        let capsule_id = "test_capsule_id";
        let read_context = "test_read_context";
        let req = CapsuleOpenRequest::default();
        assert_eq!(read_cache.engine_cache.as_ref().unwrap().len(), 0);

        // First pass: cache miss, so the policy engine builder is invoked.
        let result = read_cache.open_capsule_internal(
            domain,
            capsule_id,
            read_context,
            req.clone(),
            &mock_make_policy_engine,
            |_capsule_id, _read_context, _domain_id, _open_req| Ok(empty_open_response()),
        );
        assert!(result.is_ok());
        assert_eq!(read_cache.engine_cache.as_ref().unwrap().len(), 1);

        // Second pass: the same policy assembly hits the engine cache, so the
        // builder must not be invoked.
        let result = read_cache.open_capsule_internal(
            domain,
            capsule_id,
            read_context,
            req,
            &mock_make_policy_engine_unused,
            |_capsule_id, _read_context, _domain_id, _open_req| Ok(empty_open_response()),
        );
        assert!(result.is_ok());
        assert_eq!(read_cache.engine_cache.as_ref().unwrap().len(), 1);
    }

    #[test]
    fn test_read_cache_used_when_enabled() {
        let mut read_cache = ReadCache::new(10, 10);
        assert!(read_cache.read_cache.is_some());

        let domain = "test_domain";
        let capsule_id = "test_capsule_id";
        let read_context = "test_read_context";
        let req = CapsuleOpenRequest::default();

        // First pass: nothing cached, so the server callback and engine builder run.
        let result = read_cache.open_capsule_internal(
            domain,
            capsule_id,
            read_context,
            req.clone(),
            &mock_make_policy_engine,
            |_capsule_id, _read_context, _domain_id, _open_req| {
                let mut resp = empty_open_response();
                // A positive key TTL (seconds) makes the response cacheable.
                resp.read_context_configuration.key_cache_ttl = 60;
                Ok(resp)
            },
        );
        assert!(matches!(result, Ok(Some(_))));
        assert_eq!(read_cache.read_cache.as_ref().unwrap().len(), 1);

        // Second pass: served from the read cache, so neither callback may run.
        let result = read_cache.open_capsule_internal(
            domain,
            capsule_id,
            read_context,
            req,
            &mock_make_policy_engine_unused,
            |_capsule_id, _read_context, _domain_id, _open_req| {
                panic!("open_capsule should not be invoked on a read cache hit")
            },
        );
        assert!(matches!(result, Ok(Some(_))));
        assert_eq!(read_cache.read_cache.as_ref().unwrap().len(), 1);
    }
}