use std::collections::HashMap;
use std::sync::Mutex;
/// Per-tenant map of collection name -> DML events recorded since the last reset.
///
/// Nested by tenant (rather than keyed by `(u64, String)`) so lookups can borrow
/// the collection name as `&str` instead of allocating an owned key per call.
type DmlCounts = HashMap<u64, HashMap<String, u64>>;

/// Thread-safe tally of DML operations per `(tenant_id, collection)` pair, used to
/// decide when a collection has changed enough to warrant re-analysis.
#[derive(Debug)]
pub struct DmlCounter {
    counts: Mutex<DmlCounts>,
}

impl DmlCounter {
    /// Floor for the re-analysis threshold, regardless of table size.
    const MIN_THRESHOLD: u64 = 1000;
    /// The threshold is `last_row_count / ROW_FRACTION_DIVISOR`, i.e. 10% of rows.
    const ROW_FRACTION_DIVISOR: u64 = 10;

    /// Creates an empty counter.
    #[must_use]
    pub fn new() -> Self {
        Self {
            counts: Mutex::new(HashMap::new()),
        }
    }

    /// Locks the count map, recovering the guard if a previous holder panicked.
    ///
    /// Poisoning is deliberately ignored (as before): the guarded data is only
    /// plain counters, so continuing with whatever values are present is acceptable.
    fn lock(&self) -> std::sync::MutexGuard<'_, DmlCounts> {
        self.counts.lock().unwrap_or_else(|p| p.into_inner())
    }

    /// Records one DML event against `collection` belonging to `tenant_id`.
    ///
    /// The count saturates at `u64::MAX` instead of overflowing.
    pub fn record_dml(&self, tenant_id: u64, collection: &str) {
        let mut counts = self.lock();
        let per_tenant = counts.entry(tenant_id).or_default();
        // Look up by `&str` first so the common "already tracked" path does not
        // allocate; only a first-seen collection pays for an owned key.
        if let Some(count) = per_tenant.get_mut(collection) {
            *count = count.saturating_add(1);
        } else {
            per_tenant.insert(collection.to_owned(), 1);
        }
    }

    /// Returns `true` once the DML events recorded for `(tenant_id, collection)`
    /// reach the re-analysis threshold.
    ///
    /// The threshold is `max(last_row_count / 10, 1000)`: 10% of the last known
    /// row count, but never fewer than 1000 events. The comparison is inclusive.
    /// A pair with no recorded events counts as zero.
    #[must_use]
    pub fn should_analyze(&self, tenant_id: u64, collection: &str, last_row_count: u64) -> bool {
        // Computed before locking to keep the critical section minimal.
        let threshold = (last_row_count / Self::ROW_FRACTION_DIVISOR).max(Self::MIN_THRESHOLD);
        let recorded = self
            .lock()
            .get(&tenant_id)
            .and_then(|per_tenant| per_tenant.get(collection))
            .copied()
            .unwrap_or(0);
        recorded >= threshold
    }

    /// Forgets all DML events recorded for `(tenant_id, collection)`.
    ///
    /// Resetting a pair that was never recorded is a no-op. When a tenant's last
    /// tracked collection is reset, the tenant's (now empty) inner map is dropped
    /// so the outer map does not accumulate empty entries.
    pub fn reset(&self, tenant_id: u64, collection: &str) {
        let mut counts = self.lock();
        if let Some(per_tenant) = counts.get_mut(&tenant_id) {
            per_tenant.remove(collection);
            if per_tenant.is_empty() {
                counts.remove(&tenant_id);
            }
        }
    }
}
impl Default for DmlCounter {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Records `n` DML events against `(tenant_id, collection)`.
    fn record_n(counter: &DmlCounter, tenant_id: u64, collection: &str, n: u64) {
        for _ in 0..n {
            counter.record_dml(tenant_id, collection);
        }
    }

    #[test]
    fn basic_counting() {
        let counter = DmlCounter::new();
        record_n(&counter, 1, "users", 3);
        assert!(!counter.should_analyze(1, "users", 0));
    }

    #[test]
    fn threshold_exceeded() {
        let counter = DmlCounter::new();
        record_n(&counter, 1, "users", 1001);
        assert!(counter.should_analyze(1, "users", 0));
    }

    #[test]
    fn minimum_threshold_boundary_is_inclusive() {
        let counter = DmlCounter::new();
        record_n(&counter, 1, "users", 999);
        assert!(!counter.should_analyze(1, "users", 0));
        counter.record_dml(1, "users");
        assert!(counter.should_analyze(1, "users", 0));
    }

    #[test]
    fn percentage_threshold() {
        let counter = DmlCounter::new();
        record_n(&counter, 1, "big_table", 10_001);
        assert!(counter.should_analyze(1, "big_table", 100_000));
    }

    #[test]
    fn percentage_threshold_not_yet_reached() {
        let counter = DmlCounter::new();
        // 10% of 100_000 rows is 10_000, which is above the 1_000 floor.
        record_n(&counter, 1, "big_table", 9_999);
        assert!(!counter.should_analyze(1, "big_table", 100_000));
    }

    #[test]
    fn counts_are_isolated_per_tenant_and_collection() {
        let counter = DmlCounter::new();
        record_n(&counter, 1, "users", 1000);
        assert!(counter.should_analyze(1, "users", 0));
        assert!(!counter.should_analyze(2, "users", 0));
        assert!(!counter.should_analyze(1, "orders", 0));
    }

    #[test]
    fn reset_clears() {
        let counter = DmlCounter::new();
        record_n(&counter, 1, "users", 2000);
        assert!(counter.should_analyze(1, "users", 0));
        counter.reset(1, "users");
        assert!(!counter.should_analyze(1, "users", 0));
    }

    #[test]
    fn reset_only_affects_target_and_unknown_is_noop() {
        let counter = DmlCounter::new();
        record_n(&counter, 1, "users", 1000);
        record_n(&counter, 1, "orders", 1000);
        counter.reset(7, "missing");
        counter.reset(1, "users");
        assert!(!counter.should_analyze(1, "users", 0));
        assert!(counter.should_analyze(1, "orders", 0));
    }
}