use crate::session::{SessionError, API_TIMEOUT};
use base64::{engine::general_purpose, Engine as _};
use http::Extensions;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::RwLock;
/// Name of the HTTP header that carries cache-invalidation tokens. The middleware in
/// this file sets it on every outgoing request and reads it back from responses.
const CACHE_INVALIDATION_HEADER: &str = "antimatter-cache-invalidation";
/// Cloneable wrapper around a `reqwest_middleware::ClientWithMiddleware`.
#[derive(Clone)]
pub struct HTTPClient {
// The wrapped middleware-enabled client; handed out (cloned) by `client()`.
client: reqwest_middleware::ClientWithMiddleware,
}
impl HTTPClient {
pub fn new() -> Result<Self, SessionError> {
Ok(Self {
client: reqwest_middleware::ClientBuilder::new(
reqwest::ClientBuilder::new()
.timeout(API_TIMEOUT)
.build()
.map_err(|e| {
SessionError::Error(format!("failed creating reqwest client: {}", e))
})?,
)
.with(CacheInvalidationMiddleware {
cached_tokens: RwLock::new(HashMap::new()),
handle_err: |e: SessionError| {
eprintln!("HTTP client middleware error: {}", e);
},
})
.build(),
})
}
pub fn client(&self) -> reqwest_middleware::ClientWithMiddleware {
self.client.clone()
}
}
/// Payload of an eviction token. The `serde(rename = ...)` attributes give the field
/// names used in the serialized form.
#[derive(Serialize, Deserialize, Debug)]
struct EvictionRecord {
// Lifetime of the token in seconds; added to the current timestamp to compute its expiry.
#[serde(rename = "TTLSeconds")]
ttl_seconds: u16,
// Epoch value compared against a cached entry's epoch to decide whether this token replaces it.
#[serde(rename = "CGEpoch")]
cg_epoch: i64,
// Together with `key_values`, identifies the cache entry this token refers to.
#[serde(rename = "CacheName")]
cache_name: String,
#[serde(rename = "KeyValues")]
key_values: Vec<String>,
}
/// A cache-eviction token; its data lives in the nested record, serialized under the
/// name `Record`.
#[derive(Serialize, Deserialize, Debug)]
struct EvictionToken {
#[serde(rename = "Record")]
record: EvictionRecord,
}
/// Hash-map key under which a token is cached: the cache name plus the key values
/// copied from the token's record.
#[derive(Hash, Eq, PartialEq, Debug)]
struct CacheInvalidationKey {
cache_name: String,
key_values: Vec<String>,
}
/// Cached state for one token.
struct CacheInvalidationValue {
// Timestamp (as returned by chrono's `timestamp()`) after which the entry is purged.
expiry: i64,
// Epoch of the token currently stored in this entry.
cg_epoch: i64,
// Token bytes as received (after base64 decoding); re-encoded when building the request header.
token_bytes: Vec<u8>,
}
impl EvictionToken {
    /// Builds the map key for this token from its cache name and key values.
    fn cache_key(&self) -> CacheInvalidationKey {
        let record = &self.record;
        CacheInvalidationKey {
            cache_name: record.cache_name.clone(),
            key_values: record.key_values.clone(),
        }
    }

    /// Timestamp at which this token expires: the current time plus its TTL in seconds.
    fn expiry(&self) -> i64 {
        let ttl = i64::from(self.record.ttl_seconds);
        chrono::Local::now().timestamp() + ttl
    }

    /// Builds a fresh cache entry for this token, taking ownership of its raw bytes.
    fn cache_value(&self, token_bytes: Vec<u8>) -> CacheInvalidationValue {
        let expiry = self.expiry();
        let cg_epoch = self.record.cg_epoch;
        CacheInvalidationValue {
            expiry,
            cg_epoch,
            token_bytes,
        }
    }

    /// Overwrites `entry` with this token unless the entry already holds a strictly
    /// newer epoch, in which case the entry is left untouched.
    fn update_entry(&self, entry: &mut CacheInvalidationValue, token_bytes: Vec<u8>) {
        if entry.cg_epoch > self.record.cg_epoch {
            return;
        }
        // Every field of the entry is refreshed, so replace it wholesale.
        *entry = self.cache_value(token_bytes);
    }
}
/// Request middleware that remembers eviction tokens seen in responses and sends the
/// cached tokens along with later requests.
struct CacheInvalidationMiddleware {
// Cached tokens keyed by cache name + key values; guarded by a lock because the
// middleware is only ever accessed through `&self`.
cached_tokens: RwLock<HashMap<CacheInvalidationKey, CacheInvalidationValue>>,
// Callback invoked with errors raised inside the middleware, which are not
// propagated to the request's caller.
handle_err: fn(SessionError),
}
impl CacheInvalidationMiddleware {
    /// Removes every cached token whose expiry timestamp is not strictly in the future.
    fn purge_expired_tokens(&self) {
        // Read the clock before locking so the write lock is held only for the sweep.
        let now = chrono::Local::now().timestamp();
        let mut map = self.cached_tokens.write().unwrap();
        map.retain(|_, v| v.expiry > now);
    }

    /// Renders all cached tokens as a single header value: each token's bytes are
    /// base64-encoded (standard alphabet) and the results are joined with `;`.
    ///
    /// Returns an empty string when no tokens are cached. The order of tokens is
    /// the hash map's iteration order and therefore unspecified.
    fn header_value(&self) -> String {
        // Read-only access: a shared lock is enough and lets concurrent requests
        // build their headers in parallel.
        let map = self.cached_tokens.read().unwrap();
        map.values()
            .map(|val| general_purpose::STANDARD.encode(&val.token_bytes))
            .collect::<Vec<_>>()
            .join(";")
    }

    /// Records the eviction tokens carried by `resp` in the cache-invalidation header.
    ///
    /// The header value is split on `;`; each non-empty segment is base64-decoded and
    /// parsed as a CBOR `EvictionToken`. A token for an unknown key is inserted; a
    /// token for a known key replaces the cached one unless the cached epoch is newer.
    /// A response without the header is a no-op.
    ///
    /// # Errors
    ///
    /// Returns `SessionError::Error` if the header is not valid visible ASCII, or if a
    /// segment fails base64 or CBOR decoding. Processing stops at the first bad
    /// segment; tokens from earlier segments remain recorded.
    fn update_from_response(&self, resp: &reqwest::Response) -> Result<(), SessionError> {
        let header = match resp.headers().get(CACHE_INVALIDATION_HEADER) {
            Some(value) => value.to_str().map_err(|e| {
                SessionError::Error(format!("converting header value to string: {}", e))
            })?,
            None => return Ok(()),
        };
        // An empty segment (empty header value, or a stray `;`) carries no token, so
        // skip it rather than reporting a spurious CBOR decoding error.
        for token_b64 in header.split(';').filter(|segment| !segment.is_empty()) {
            let decoded = general_purpose::STANDARD
                .decode(token_b64.as_bytes())
                .map_err(|e| SessionError::Error(format!("decoding token base64: {}", e)))?;
            // A byte slice is itself a reader, so the buffer need not be cloned.
            let token = ciborium::from_reader::<EvictionToken, _>(decoded.as_slice())
                .map_err(|e| SessionError::Error(format!("decoding EvictionToken CBOR: {}", e)))?;
            let mut map = self.cached_tokens.write().unwrap();
            // Single lookup via the entry API instead of `get_mut` followed by `insert`.
            match map.entry(token.cache_key()) {
                std::collections::hash_map::Entry::Occupied(mut slot) => {
                    token.update_entry(slot.get_mut(), decoded);
                }
                std::collections::hash_map::Entry::Vacant(slot) => {
                    slot.insert(token.cache_value(decoded));
                }
            }
        }
        Ok(())
    }
}
#[async_trait::async_trait]
impl reqwest_middleware::Middleware for CacheInvalidationMiddleware {
async fn handle(
&self,
mut req: reqwest::Request,
extensions: &mut Extensions,
next: reqwest_middleware::Next<'_>,
) -> reqwest_middleware::Result<reqwest::Response> {
self.purge_expired_tokens();
match reqwest::header::HeaderValue::from_str(self.header_value().as_str())
.map_err(|e| SessionError::Error(format!("creating header value from string: {}", e)))
{
Ok(header) => req.headers_mut().insert(CACHE_INVALIDATION_HEADER, header),
Err(e) => {
(self.handle_err)(e);
None
}
};
let res = next.run(req, extensions).await;
if let Ok(ref resp) = res {
if let Err(e) = self.update_from_response(resp) {
(self.handle_err)(e);
}
}
res
}
}
#[cfg(test)]
pub mod tests {
    use super::*;

    /// Builds a token for the fixed `domain` / `dm-domain4test` key values.
    fn sample_token(cache_name: &str, ttl_seconds: u16, cg_epoch: i64) -> EvictionToken {
        EvictionToken {
            record: EvictionRecord {
                ttl_seconds,
                cg_epoch,
                cache_name: cache_name.to_string(),
                key_values: vec!["domain".to_string(), "dm-domain4test".to_string()],
            },
        }
    }

    /// Middleware pre-loaded with two tokens: `test` (TTL 0, epoch 123) and
    /// `test2` (TTL 100, epoch 125). Each token's bytes are its cache name.
    fn seeded_middleware() -> CacheInvalidationMiddleware {
        let mw = CacheInvalidationMiddleware {
            cached_tokens: RwLock::new(HashMap::new()),
            handle_err: |_: SessionError| {},
        };
        {
            let mut map = mw.cached_tokens.write().unwrap();
            for (name, ttl, epoch) in [("test", 0, 123), ("test2", 100, 125)] {
                let token = sample_token(name, ttl, epoch);
                map.insert(
                    token.cache_key(),
                    token.cache_value(name.to_string().into_bytes()),
                );
            }
        }
        mw
    }

    #[test]
    fn test_header_value() {
        let mw = seeded_middleware();
        let header = mw.header_value();
        // Both tokens must appear, base64-encoded, as exactly two `;`-separated parts.
        for raw in ["test", "test2"] {
            assert!(header.contains(general_purpose::STANDARD.encode(raw).as_str()));
        }
        assert_eq!(header.split(";").collect::<Vec<&str>>().len(), 2);
    }

    #[test]
    fn test_purge_expired_tokens() {
        let mw = seeded_middleware();
        mw.purge_expired_tokens();
        // The TTL-0 token is already expired; only `test2` may survive.
        let map = mw.cached_tokens.read().unwrap();
        assert_eq!(map.len(), 1);
        let survivor = CacheInvalidationKey {
            cache_name: "test2".to_string(),
            key_values: vec!["domain".to_string(), "dm-domain4test".to_string()],
        };
        assert!(map.contains_key(&survivor));
    }
}