antimatter 2.0.13

antimatter.io Rust library for data control
Documentation
use crate::session::{SessionError, API_TIMEOUT};
use base64::{engine::general_purpose, Engine as _};
use http::Extensions;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::RwLock;

// CACHE_INVALIDATION_HEADER is the name of the HTTP header that carries
// the cache invalidation tokens. The same header name is read from
// responses and written to outbound requests.
const CACHE_INVALIDATION_HEADER: &str = "antimatter-cache-invalidation";

// HTTPClient should be used as the client with the generated API client.
// Its purpose is to implement the cache invalidation scheme: the cache
// invalidation header returned by the server is caught and its tokens are
// cached internally, and the tokens that have not yet expired are attached
// to outbound requests.
#[derive(Clone)]
pub struct HTTPClient {
    // client is the reqwest client wrapped with the
    // CacheInvalidationMiddleware; it is assembled in HTTPClient::new.
    client: reqwest_middleware::ClientWithMiddleware,
}

impl HTTPClient {
    pub fn new() -> Result<Self, SessionError> {
        Ok(Self {
            client: reqwest_middleware::ClientBuilder::new(
                reqwest::ClientBuilder::new()
                    .timeout(API_TIMEOUT)
                    .build()
                    .map_err(|e| {
                        SessionError::Error(format!("failed creating reqwest client: {}", e))
                    })?,
            )
            .with(CacheInvalidationMiddleware {
                cached_tokens: RwLock::new(HashMap::new()),
                handle_err: |e: SessionError| {
                    eprintln!("HTTP client middleware error: {}", e);
                },
            })
            .build(),
        })
    }

    pub fn client(&self) -> reqwest_middleware::ClientWithMiddleware {
        self.client.clone()
    }
}

// EvictionRecord is a translation of github.com/antimatterhq/antimatter/
// srv/pkg/l2arc.EvictionRecord. The serde renames pin the serialized
// field names to the capitalized names given in the attributes.
#[derive(Serialize, Deserialize, Debug)]
struct EvictionRecord {
    // ttl_seconds is the lifetime of the record in seconds; it is added
    // to the current time to compute the expiry of the cached token.
    #[serde(rename = "TTLSeconds")]
    ttl_seconds: u16,
    // cg_epoch is the coherence group epoch; a cached entry is only
    // replaced by a token whose epoch is not older than the cached one.
    #[serde(rename = "CGEpoch")]
    cg_epoch: i64,
    // cache_name and key_values together form the key under which the
    // token is cached.
    #[serde(rename = "CacheName")]
    cache_name: String,
    #[serde(rename = "KeyValues")]
    key_values: Vec<String>,
}

// EvictionToken is a translation of github.com/antimatterhq/antimatter/
// srv/pkg/l2arc.EvictionToken. Note that the signature is missing; this
// is because decoding a typed object would mean maintaining multiple
// typed object registries, but it's possible to implement the cache
// invalidation scheme without ever decoding the signature: the received
// token bytes are stored and sent back verbatim, and only the record is
// decoded here.
#[derive(Serialize, Deserialize, Debug)]
struct EvictionToken {
    // record holds the decoded fields used for keying, expiry and epoch
    // comparison.
    #[serde(rename = "Record")]
    record: EvictionRecord,
}

// CacheInvalidationKey is the key into the cached tokens stored
// on the CacheInvalidationMiddleware. Two tokens with the same cache
// name and key values map to the same cached entry.
#[derive(Hash, Eq, PartialEq, Debug)]
struct CacheInvalidationKey {
    // cache_name is copied from EvictionRecord::cache_name.
    cache_name: String,
    // key_values is copied from EvictionRecord::key_values.
    key_values: Vec<String>,
}

// CacheInvalidationValue is the value of the cached tokens stored
// on the CacheInvalidationMiddleware.
struct CacheInvalidationValue {
    // expiry is the timestamp, in epoch seconds, at or after which the
    // entry is purged from the cache.
    expiry: i64,
    // cg_epoch is the coherence group epoch of the cached token.
    cg_epoch: i64,
    // token_bytes is the token exactly as received (after base64
    // decoding); it is base64-encoded again when building the outbound
    // request header.
    token_bytes: Vec<u8>,
}

impl EvictionToken {
    // cache_key returns the cache key associated with this EvictionToken.
    fn cache_key(&self) -> CacheInvalidationKey {
        CacheInvalidationKey {
            cache_name: self.record.cache_name.clone(),
            key_values: self.record.key_values.clone(),
        }
    }

    // Expiry returns the record expiry as i64 timestamp in epoch seconds.
    fn expiry(&self) -> i64 {
        chrono::Local::now().timestamp() + self.record.ttl_seconds as i64
    }

    // cache_value returns the cache value associated with the EvictionToken.
    fn cache_value(&self, token_bytes: Vec<u8>) -> CacheInvalidationValue {
        CacheInvalidationValue {
            expiry: self.expiry(),
            cg_epoch: self.record.cg_epoch,
            token_bytes,
        }
    }

    // update_entry updates the referenced CacheInvalidationValue with
    // values from the EvictionToken, if necessary.
    fn update_entry(&self, entry: &mut CacheInvalidationValue, token_bytes: Vec<u8>) {
        if entry.cg_epoch <= self.record.cg_epoch {
            entry.expiry = self.expiry();
            entry.cg_epoch = self.record.cg_epoch;
            entry.token_bytes = token_bytes;
        }
    }
}

// CacheInvalidationMiddleware implements the cache invalidation scheme
// as a HTTP client middleware.
struct CacheInvalidationMiddleware {
    // cached_tokens stores previously received tokens and is used to
    // populate the outbound request header. It is guarded by a lock
    // because the middleware methods only receive &self.
    cached_tokens: RwLock<HashMap<CacheInvalidationKey, CacheInvalidationValue>>,
    // handle_err is called whenever an error is encountered in processing
    // the cache invalidation headers. Such errors are reported through
    // this callback and do not fail the request itself.
    handle_err: fn(SessionError),
}

impl CacheInvalidationMiddleware {
    // purge_expired_tokens removes all expired tokens from the cache.
    // It should be called before populating the outbound request headers.
    // An entry survives only while its expiry is strictly in the future.
    fn purge_expired_tokens(&self) {
        let mut map = self.cached_tokens.write().unwrap();
        let now = chrono::Local::now().timestamp();
        map.retain(|_, v| v.expiry > now);
    }

    // header_value returns the ';'-delimited concatenation of the
    // base64 encoding of the cbor encoding of all the currently cached
    // tokens. This is the format expected by the server in the cache
    // invalidation header. The result is the empty string when no tokens
    // are cached.
    fn header_value(&self) -> String {
        // The map is only read here, so a shared lock is sufficient and
        // does not serialize concurrent callers.
        let map = self.cached_tokens.read().unwrap();
        let encoded: Vec<String> = map
            .values()
            // Encode straight from the stored bytes; no copy is needed.
            .map(|val| general_purpose::STANDARD.encode(&val.token_bytes))
            .collect();
        encoded.join(";")
    }

    // update_from_response updates the locally cached tokens from the
    // response. This includes inserting new tokens and updating existing
    // tokens if the coherence group epoch has advanced.
    //
    // A response without the cache invalidation header is not an error,
    // and empty segments of the header are ignored. An error is returned
    // if the header is not valid text, or at the first token that fails
    // base64 or CBOR decoding; tokens preceding the failing one have
    // already been applied to the cache at that point.
    fn update_from_response(&self, resp: &reqwest::Response) -> Result<(), SessionError> {
        use std::collections::hash_map::Entry;

        let header = match resp.headers().get(CACHE_INVALIDATION_HEADER) {
            Some(value) => value.to_str().map_err(|e| {
                SessionError::Error(format!("converting header value to string: {}", e))
            })?,
            None => return Ok(()),
        };

        // Skip empty segments: an empty header value or a stray ';' would
        // otherwise decode to zero bytes and be reported as a CBOR error.
        let segments = header
            .split(';')
            .map(str::trim)
            .filter(|segment| !segment.is_empty());

        for token_b64 in segments {
            let decoded = general_purpose::STANDARD
                .decode(token_b64.as_bytes())
                .map_err(|e| SessionError::Error(format!("decoding token base64: {}", e)))?;

            // Decode from a borrowed slice so the bytes can be moved into
            // the cache afterwards without having been cloned.
            let token: EvictionToken = ciborium::from_reader(decoded.as_slice())
                .map_err(|e| SessionError::Error(format!("decoding EvictionToken CBOR: {}", e)))?;

            let key = token.cache_key();
            let mut map = self.cached_tokens.write().unwrap();
            // Single lookup: refresh the existing entry or insert a new one.
            match map.entry(key) {
                Entry::Occupied(mut slot) => token.update_entry(slot.get_mut(), decoded),
                Entry::Vacant(slot) => {
                    slot.insert(token.cache_value(decoded));
                }
            }
        }
        Ok(())
    }
}

#[async_trait::async_trait]
impl reqwest_middleware::Middleware for CacheInvalidationMiddleware {
    // handle implements the Middleware trait. Before the outbound request
    // is sent, it purges old tokens, then encodes the remainder of the
    // tokens into the request header. After receiving the response, it
    // reads the cache invalidation header and updates the local cache.
    async fn handle(
        &self,
        mut req: reqwest::Request,
        extensions: &mut Extensions,
        next: reqwest_middleware::Next<'_>,
    ) -> reqwest_middleware::Result<reqwest::Response> {
        // purge expired tokens
        self.purge_expired_tokens();

        // attach remaining tokens to request
        match reqwest::header::HeaderValue::from_str(self.header_value().as_str())
            .map_err(|e| SessionError::Error(format!("creating header value from string: {}", e)))
        {
            Ok(header) => req.headers_mut().insert(CACHE_INVALIDATION_HEADER, header),
            Err(e) => {
                (self.handle_err)(e);
                None
            }
        };

        let res = next.run(req, extensions).await;

        // update cached tokens from response
        if let Ok(ref resp) = res {
            if let Err(e) = self.update_from_response(resp) {
                (self.handle_err)(e);
            }
        }

        res
    }
}

#[cfg(test)]
pub mod tests {
    use super::*;

    // test_key_values returns the key values shared by every test token.
    fn test_key_values() -> Vec<String> {
        vec!["domain".to_string(), "dm-domain4test".to_string()]
    }

    // middleware_with_two_tokens returns a middleware whose cache holds
    // two tokens: "test" with a TTL of 0 seconds (about to expire) and
    // "test2" with a TTL of 100 seconds. The stored token bytes are the
    // cache name itself.
    fn middleware_with_two_tokens() -> CacheInvalidationMiddleware {
        let mw = CacheInvalidationMiddleware {
            cached_tokens: RwLock::new(HashMap::new()),
            handle_err: |_: SessionError| {},
        };

        {
            let mut map = mw.cached_tokens.write().unwrap();
            for (ttl_seconds, cg_epoch, name) in [(0u16, 123i64, "test"), (100u16, 125i64, "test2")]
            {
                let token = EvictionToken {
                    record: EvictionRecord {
                        ttl_seconds,
                        cg_epoch,
                        cache_name: name.to_string(),
                        key_values: test_key_values(),
                    },
                };
                map.insert(
                    token.cache_key(),
                    token.cache_value(name.to_string().into_bytes()),
                );
            }
        }

        mw
    }

    // The header must contain the base64 form of both cached tokens,
    // separated by exactly one ';'.
    #[test]
    fn test_header_value() {
        let mw = middleware_with_two_tokens();

        let header = mw.header_value();
        let first = general_purpose::STANDARD.encode("test");
        let second = general_purpose::STANDARD.encode("test2");

        assert!(header.contains(first.as_str()));
        assert!(header.contains(second.as_str()));
        assert_eq!(header.split(";").count(), 2);
    }

    // Purging must drop the token with a zero TTL and keep the one that
    // is still valid.
    #[test]
    fn test_purge_expired_tokens() {
        let mw = middleware_with_two_tokens();

        mw.purge_expired_tokens();

        let surviving_key = CacheInvalidationKey {
            cache_name: "test2".to_string(),
            key_values: test_key_values(),
        };
        let map = mw.cached_tokens.read().unwrap();
        assert_eq!(map.len(), 1);
        assert!(map.contains_key(&surviving_key));
    }
}