use crate::AppState;
use actix_web::web::Data;
use once_cell::sync::Lazy;
use std::collections::HashMap;
use std::sync::Mutex;
use std::time::{Duration, Instant};
use tokio::time::sleep;
use crate::config_validation::runtime_env_settings;
/// Report returned by [`invalidate_scoped_gateway_cache`] describing what a
/// single invalidation request actually did.
#[derive(Clone, Copy, Debug, Default)]
pub struct CacheInvalidationSummary {
    /// `true` only when an invalidation pass executed during this call;
    /// `false` when the cache was empty or the request was debounced.
    pub did_run: bool,
    /// Drop in the cache's `entry_count()` across the pass, saturating at 0.
    pub invalidated_entries: usize,
    /// Cache `entry_count()` observed at the end of the call.
    pub remaining_entries: usize,
    /// When debounced: milliseconds until the pending trailing invalidation.
    pub debounce_delay_ms: Option<u64>,
}
/// Per-`client|table` debounce bookkeeping used by
/// [`invalidate_scoped_gateway_cache`].
#[derive(Clone, Copy, Debug)]
struct DebounceState {
    /// Instant at which the debounce window for this key was last restarted.
    last_run: Instant,
    /// Whether a trailing invalidation task is already pending for this key.
    scheduled: bool,
}
/// Process-wide debounce table, keyed by [`debounce_key`] (normalized
/// `client|table`), created empty on first access.
static CACHE_INVALIDATION_DEBOUNCE: Lazy<Mutex<HashMap<String, DebounceState>>> =
    Lazy::new(Default::default);
/// Debounce window for scoped invalidation, in milliseconds, taken from
/// `runtime_env_settings()` on each call. Callers treat `0` as "no debouncing".
fn invalidation_window_ms() -> u64 {
    let settings = runtime_env_settings();
    settings.cache_invalidation_window_ms
}
/// Builds the debounce-table key for a `(client, table)` pair.
///
/// Both parts are trimmed and ASCII-lowercased, then joined with `|`, so
/// whitespace/casing variants of the same pair share one debounce slot.
fn debounce_key(client_name: &str, table_name: &str) -> String {
    let table_part = table_name.trim().to_ascii_lowercase();
    let mut composed = client_name.trim().to_ascii_lowercase();
    composed.reserve(1 + table_part.len());
    composed.push('|');
    composed.push_str(&table_part);
    composed
}
/// Returns `true` when the gateway cache entry stored under `key` must be
/// dropped after a write to `table_name`.
///
/// `table_name` is trimmed; a blank table name never matches anything.
/// Comparisons are ASCII case-insensitive. Any trailing `:__raw_json`
/// markers are stripped first so raw-JSON variants follow their base key.
/// A key then matches when it:
/// - starts with `<table>:` (post-fetch style keys),
/// - contains `get_data_route:<table>:` anywhere, or
/// - starts with `query_count:` (this family is dropped for every table).
///
/// This runs once per cache entry inside the cache's invalidation predicate,
/// so it is written to perform no heap allocation.
pub(crate) fn gateway_cache_entry_matches_table_invalidation(key: &str, table_name: &str) -> bool {
    /// ASCII case-insensitive `str::strip_prefix`.
    fn strip_prefix_ci<'a>(text: &'a str, prefix: &str) -> Option<&'a str> {
        let head = text.as_bytes().get(..prefix.len())?;
        if head.eq_ignore_ascii_case(prefix.as_bytes()) {
            // `get` keeps this panic-free even if the split point were not a
            // char boundary.
            text.get(prefix.len()..)
        } else {
            None
        }
    }

    /// Whether `text` begins with `<table>:` (ASCII case-insensitive).
    fn starts_with_segment(text: &str, table: &str) -> bool {
        matches!(strip_prefix_ci(text, table), Some(rest) if rest.starts_with(':'))
    }

    const RAW_JSON_SUFFIX: &str = ":__raw_json";
    const ROUTE_MARKER: &str = "get_data_route:";
    const QUERY_COUNT_PREFIX: &str = "query_count:";

    let table = table_name.trim();
    if table.is_empty() {
        return false;
    }

    let mut key = key;
    while let Some(base) = key.strip_suffix(RAW_JSON_SUFFIX) {
        key = base;
    }

    if starts_with_segment(key, table) {
        return true;
    }

    // The route marker may sit anywhere in the key, so probe every char
    // boundary (equivalent to a case-insensitive `contains`).
    let route_hit = key.char_indices().any(|(idx, _)| {
        matches!(
            strip_prefix_ci(&key[idx..], ROUTE_MARKER),
            Some(rest) if starts_with_segment(rest, table)
        )
    });
    if route_hit {
        return true;
    }

    // Matched case-insensitively like every other family.
    strip_prefix_ci(key, QUERY_COUNT_PREFIX).is_some()
}
/// Immediately invalidates every cache entry whose key belongs to
/// `table_name` (per [`gateway_cache_entry_matches_table_invalidation`]) and
/// reports how much the cache's entry count dropped.
///
/// `_client_name` is unused: keys are matched by table only, so the
/// invalidation is not scoped to a client.
///
/// Returns `did_run: false` with zero counts when the cache is empty.
async fn run_invalidation_now(
    app_state: Data<AppState>,
    _client_name: &str,
    table_name: &str,
) -> CacheInvalidationSummary {
    // `entry_count()` can lag behind recent inserts until pending maintenance
    // runs (NOTE(review): documented behaviour of moka caches — assuming
    // `cache` is one). Flush first so the empty shortcut below cannot skip
    // entries that were already inserted, and so `before` is meaningful.
    app_state.cache.run_pending_tasks().await;
    let before = app_state.cache.entry_count() as usize;
    if before == 0 {
        return CacheInvalidationSummary {
            did_run: false,
            invalidated_entries: 0,
            remaining_entries: 0,
            debounce_delay_ms: None,
        };
    }

    let table_for_predicate = table_name.trim().to_string();
    let predicate_registered = app_state
        .cache
        .invalidate_entries_if(move |key, _value| {
            gateway_cache_entry_matches_table_invalidation(key.as_str(), &table_for_predicate)
        })
        .is_ok();
    if !predicate_registered {
        // Predicate-based invalidation was refused (moka does this when the
        // cache was built without invalidation-closure support). Dropping
        // everything is safe; silently keeping stale entries is not.
        app_state.cache.invalidate_all();
    }

    app_state.cache.run_pending_tasks().await;
    let after = app_state.cache.entry_count() as usize;
    CacheInvalidationSummary {
        did_run: true,
        invalidated_entries: before.saturating_sub(after),
        remaining_entries: after,
        debounce_delay_ms: None,
    }
}
/// Invalidates gateway cache entries for `table_name`, debounced per
/// normalized `(client_name, table_name)` pair.
///
/// - Pending cache maintenance is flushed first; if the cache is then empty,
///   nothing happens (`did_run: false`, no delay).
/// - A window of `0` ms (see `invalidation_window_ms`) disables debouncing:
///   every call invalidates immediately.
/// - Otherwise a call runs immediately when the key is new or its window has
///   elapsed (leading edge). Calls inside the window are coalesced into one
///   trailing run, spawned on the tokio runtime and fired when the window
///   closes. Those calls return `did_run: false` and `debounce_delay_ms`.
/// - If the debounce table's lock is poisoned, calls run immediately.
///
/// The trailing task restarts the window *before* it invalidates. A call that
/// arrives while that run is in progress therefore schedules a new trailing
/// run instead of being absorbed by a pass that may already have scanned the
/// cache.
pub async fn invalidate_scoped_gateway_cache(
    app_state: Data<AppState>,
    client_name: &str,
    table_name: &str,
) -> CacheInvalidationSummary {
    /// What this call should do, decided under the debounce lock.
    enum Decision {
        /// Invalidate now.
        RunNow,
        /// Skip for now; the window closes in `delay_ms`. `spawn_trailing` is
        /// `true` only for the call responsible for starting the trailing task.
        Defer { delay_ms: u64, spawn_trailing: bool },
    }

    /// Reads and updates the debounce slot for `key`. The lock lives only in
    /// this synchronous helper, so it is never held across an `.await`.
    fn decide(key: &str, window: Duration) -> Decision {
        let mut slots = match CACHE_INVALIDATION_DEBOUNCE.lock() {
            Ok(guard) => guard,
            // Never drop an invalidation because of a poisoned lock.
            Err(_) => return Decision::RunNow,
        };
        let now = Instant::now();
        match slots.get_mut(key) {
            None => {
                // First sighting of this key: run now and open a window. This
                // avoids `Instant::checked_sub`, which is platform-dependent and
                // could otherwise make the very first call look "too recent".
                slots.insert(
                    key.to_owned(),
                    DebounceState {
                        last_run: now,
                        scheduled: false,
                    },
                );
                Decision::RunNow
            }
            Some(state) => {
                let elapsed = now.saturating_duration_since(state.last_run);
                if elapsed >= window {
                    state.last_run = now;
                    Decision::RunNow
                } else {
                    let spawn_trailing = !state.scheduled;
                    state.scheduled = true;
                    Decision::Defer {
                        // `elapsed < window`, so this never exceeds the window in ms.
                        delay_ms: (window - elapsed).as_millis() as u64,
                        spawn_trailing,
                    }
                }
            }
        }
    }

    /// Run by the trailing task when it wakes: clears the pending flag and
    /// restarts the window before the invalidation executes.
    fn claim_trailing_run(key: String) {
        if let Ok(mut slots) = CACHE_INVALIDATION_DEBOUNCE.lock() {
            slots.insert(
                key,
                DebounceState {
                    last_run: Instant::now(),
                    scheduled: false,
                },
            );
        }
    }

    // `entry_count()` can read 0 while recent inserts are still queued;
    // flush pending maintenance so this shortcut cannot skip fresh entries.
    app_state.cache.run_pending_tasks().await;
    if app_state.cache.entry_count() == 0 {
        return CacheInvalidationSummary {
            did_run: false,
            invalidated_entries: 0,
            remaining_entries: 0,
            debounce_delay_ms: None,
        };
    }

    let window_ms = invalidation_window_ms();
    if window_ms == 0 {
        return run_invalidation_now(app_state, client_name, table_name).await;
    }

    let key = debounce_key(client_name, table_name);
    let (delay_ms, spawn_trailing) = match decide(&key, Duration::from_millis(window_ms)) {
        Decision::RunNow => {
            return run_invalidation_now(app_state, client_name, table_name).await;
        }
        Decision::Defer {
            delay_ms,
            spawn_trailing,
        } => (delay_ms, spawn_trailing),
    };

    if spawn_trailing {
        let task_state = app_state.clone();
        let task_client = client_name.to_owned();
        let task_table = table_name.to_owned();
        tokio::spawn(async move {
            sleep(Duration::from_millis(delay_ms)).await;
            // Claim first, then invalidate (see the function docs).
            claim_trailing_run(key);
            // Detached task: nobody consumes the summary.
            let _ = run_invalidation_now(task_state, &task_client, &task_table).await;
        });
    }

    CacheInvalidationSummary {
        did_run: false,
        invalidated_entries: 0,
        remaining_entries: app_state.cache.entry_count() as usize,
        debounce_delay_ms: Some(delay_ms),
    }
}
#[cfg(test)]
mod tests {
    use super::gateway_cache_entry_matches_table_invalidation;

    #[test]
    fn gateway_cache_match_post_fetch_style_key() {
        assert!(gateway_cache_entry_matches_table_invalidation(
            "http_request_log:id,col:limit:10:false:deadbeef",
            "http_request_log"
        ));
    }

    #[test]
    fn gateway_cache_match_post_fetch_case_insensitive_table() {
        assert!(gateway_cache_entry_matches_table_invalidation(
            "HTTP_REQUEST_LOG:id:col:1:false:abc",
            "http_request_log"
        ));
    }

    #[test]
    fn gateway_cache_match_get_data_route_key() {
        assert!(gateway_cache_entry_matches_table_invalidation(
            "get_data_route:http_request_log:eq:val:::::::::client",
            "http_request_log"
        ));
    }

    #[test]
    fn gateway_cache_match_get_data_hashed_suffix() {
        let key: &str =
            "get_data_route:http_request_log:a:b:c:d:e:f:g:h:i:j:k:l:client-hdeadbeef12345678";
        assert!(gateway_cache_entry_matches_table_invalidation(
            key,
            "http_request_log"
        ));
    }

    #[test]
    fn gateway_cache_match_raw_json_suffix() {
        let base: &str = "http_request_log:id,col:1:false:ab";
        let raw: String = format!("{base}:__raw_json");
        assert!(gateway_cache_entry_matches_table_invalidation(
            &raw,
            "http_request_log"
        ));
    }

    #[test]
    fn gateway_cache_match_query_count_family() {
        assert!(gateway_cache_entry_matches_table_invalidation(
            "query_count:deadbeef",
            "any_table"
        ));
    }

    #[test]
    fn gateway_cache_no_match_other_table_post_key() {
        assert!(!gateway_cache_entry_matches_table_invalidation(
            "other_table:id:col:1:false:ab",
            "http_request_log"
        ));
    }

    #[test]
    fn gateway_cache_no_match_client_substring_without_table() {
        assert!(!gateway_cache_entry_matches_table_invalidation(
            "unrelated:neon_dexter_sentinel:extra",
            "http_request_log"
        ));
    }

    /// The table name argument is trimmed and compared case-insensitively.
    #[test]
    fn gateway_cache_match_trims_and_ignores_case_of_table_name() {
        assert!(gateway_cache_entry_matches_table_invalidation(
            "http_request_log:id:col:1:false:ab",
            "  HTTP_Request_Log  "
        ));
    }

    /// A blank table name matches nothing, not even the `query_count:` family.
    #[test]
    fn gateway_cache_blank_table_name_never_matches() {
        assert!(!gateway_cache_entry_matches_table_invalidation(
            "query_count:deadbeef",
            ""
        ));
        assert!(!gateway_cache_entry_matches_table_invalidation(
            "http_request_log:id:1",
            "   "
        ));
    }

    /// A table whose name merely starts with the target table is not matched.
    #[test]
    fn gateway_cache_no_match_table_name_prefix_of_other_table() {
        assert!(!gateway_cache_entry_matches_table_invalidation(
            "http_request_log_archive:id:col:1:false:ab",
            "http_request_log"
        ));
        assert!(!gateway_cache_entry_matches_table_invalidation(
            "get_data_route:http_request_log_archive:eq:val:client",
            "http_request_log"
        ));
    }

    /// Raw-JSON variants of another table's key stay cached.
    #[test]
    fn gateway_cache_no_match_raw_json_of_other_table() {
        assert!(!gateway_cache_entry_matches_table_invalidation(
            "other_table:id:col:1:false:ab:__raw_json",
            "http_request_log"
        ));
    }
}