use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use astrid_core::PrincipalId;
use dashmap::DashMap;
/// Upper bound on how many principals a `MemoryLedger` tracks. Once reached, a new
/// principal can only be added by evicting the principal with the lowest recorded peak.
const MAX_PRINCIPALS: usize = 1 << 12;
/// Concurrent map from principal to the highest memory size recorded for it.
///
/// Cloning is cheap: every clone shares the same underlying map, so a peak recorded
/// through one handle is visible through all others. The number of tracked principals
/// is bounded by `MAX_PRINCIPALS` via lowest-peak eviction in `record_peak`.
#[derive(Default, Clone)]
pub struct MemoryLedger {
    /// Shared per-principal high-water marks, as passed to `record_peak`.
    inner: Arc<DashMap<PrincipalId, AtomicU64>>,
}
impl MemoryLedger {
    /// Raises `principal`'s recorded high-water mark to `bytes` if `bytes` is larger.
    ///
    /// A zero-byte sample is ignored, so a principal is never inserted with a peak of 0.
    ///
    /// When `principal` is not yet tracked and the ledger already holds
    /// `MAX_PRINCIPALS` entries, the principal with the lowest recorded peak is evicted
    /// to make room. This only happens if `bytes` is strictly greater than that lowest
    /// peak; otherwise the sample is dropped and the ledger is left unchanged.
    ///
    /// The capacity check and the insert are separate operations. Concurrent first-time
    /// recordings from several threads may therefore push the ledger somewhat past the
    /// bound.
    pub fn record_peak(&self, principal: &PrincipalId, bytes: u64) {
        if bytes == 0 {
            return;
        }
        // Fast path: the principal is already tracked, so just raise its counter.
        if let Some(counter) = self.inner.get(principal) {
            Self::raise_to(&counter, bytes);
            return;
        }
        if self.inner.len() >= MAX_PRINCIPALS && !self.evict_lowest_if_below(bytes) {
            return;
        }
        Self::raise_to(&self.inner.entry(principal.clone()).or_default(), bytes);
    }

    /// Evicts the principal with the lowest recorded peak if that peak is strictly below
    /// `threshold`.
    ///
    /// Returns `true` when the caller may insert a new principal. That is the case when
    /// an entry was evicted, or when the scan found no entries at all.
    ///
    /// Returns `false` in two situations:
    /// - every tracked peak is at least `threshold`;
    /// - between the scan and the removal, another thread either raised the chosen
    ///   victim to `threshold` or above, or removed it.
    ///
    /// This is a linear scan over every tracked principal.
    fn evict_lowest_if_below(&self, threshold: u64) -> bool {
        // Remember the candidate as (key, peak) and clone the key only when a strictly
        // lower peak shows up. Many principals sharing the lowest peak therefore do not
        // each cost a key clone. Using `Option` instead of a `u64::MAX` sentinel still
        // selects a victim when every peak is `u64::MAX`.
        let mut victim: Option<(PrincipalId, u64)> = None;
        for entry in self.inner.iter() {
            let peak = entry.value().load(Ordering::Relaxed);
            if victim.as_ref().is_none_or(|&(_, lowest)| peak < lowest) {
                victim = Some((entry.key().clone(), peak));
            }
        }
        let Some((key, lowest)) = victim else {
            return true;
        };
        if threshold <= lowest {
            return false;
        }
        // Re-check under the shard lock. The victim's own meter may have raised its peak
        // to `threshold` or above since the scan, and such a principal must not be
        // evicted in favour of the newcomer.
        self.inner
            .remove_if(&key, |_, peak| peak.load(Ordering::Relaxed) < threshold)
            .is_some()
    }

    /// Raises `counter` to `bytes` if `bytes` is larger; the counter is never lowered.
    fn raise_to(counter: &AtomicU64, bytes: u64) {
        // Relaxed is enough: the counter is a standalone statistic and no other data is
        // synchronized through it.
        counter.fetch_max(bytes, Ordering::Relaxed);
    }

    /// Returns the highest value recorded for `principal`.
    ///
    /// Returns 0 if the principal is not tracked. That covers principals never recorded,
    /// those that only reported zero-byte samples, and those that were evicted.
    #[must_use]
    pub fn peak(&self, principal: &PrincipalId) -> u64 {
        self.inner
            .get(principal)
            .map_or(0, |counter| counter.load(Ordering::Relaxed))
    }
}
/// A `wasmtime::ResourceLimiter` that caps linear-memory growth at a configurable
/// ceiling and reports each accepted growth to a shared `MemoryLedger` under the
/// current principal.
pub struct StoreMemoryMeter {
    /// Principal whose ledger entry receives accepted growth sizes.
    principal: PrincipalId,
    /// Ledger handle; clones share state with every other handle to the same ledger.
    ledger: MemoryLedger,
    /// Upper bound, in bytes, on the `desired` size `memory_growing` will accept.
    max_memory_bytes: usize,
}
impl StoreMemoryMeter {
    /// Builds a meter that accepts memory growth up to `max_memory_bytes` and records
    /// accepted sizes in `ledger` under `principal`.
    #[must_use]
    pub fn new(max_memory_bytes: usize, principal: PrincipalId, ledger: MemoryLedger) -> Self {
        Self {
            principal,
            ledger,
            max_memory_bytes,
        }
    }

    /// Re-targets the meter at a new ceiling and principal while keeping the same ledger.
    ///
    /// Peaks already recorded for the previous principal remain in the ledger.
    pub fn set(&mut self, max_memory_bytes: usize, principal: PrincipalId) {
        (self.max_memory_bytes, self.principal) = (max_memory_bytes, principal);
    }
}
impl wasmtime::ResourceLimiter for StoreMemoryMeter {
    /// Accepts growth only when `desired` is within this meter's ceiling and within the
    /// `maximum` wasmtime passes in (if any).
    ///
    /// An accepted size is recorded as a peak for the current principal; a rejected
    /// one is not recorded.
    fn memory_growing(
        &mut self,
        _current: usize,
        desired: usize,
        maximum: Option<usize>,
    ) -> wasmtime::Result<bool> {
        let under_ceiling = desired <= self.max_memory_bytes;
        let under_declared_max = maximum.is_none_or(|max| desired <= max);
        if !(under_ceiling && under_declared_max) {
            return Ok(false);
        }
        // Saturate instead of failing on a platform where `usize` is wider than `u64`.
        let desired_bytes = u64::try_from(desired).unwrap_or(u64::MAX);
        self.ledger.record_peak(&self.principal, desired_bytes);
        Ok(true)
    }

    /// Table growth is not limited by this meter; every request is accepted.
    fn table_growing(
        &mut self,
        _current: usize,
        _desired: usize,
        _maximum: Option<usize>,
    ) -> wasmtime::Result<bool> {
        Ok(true)
    }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn record_peak_keeps_the_high_water_mark() {
let ledger = MemoryLedger::default();
let p = PrincipalId::default();
assert_eq!(ledger.peak(&p), 0);
ledger.record_peak(&p, 1000);
assert_eq!(ledger.peak(&p), 1000);
ledger.record_peak(&p, 500);
assert_eq!(ledger.peak(&p), 1000);
ledger.record_peak(&p, 4096);
assert_eq!(ledger.peak(&p), 4096);
ledger.record_peak(&p, 0);
assert_eq!(ledger.peak(&p), 4096);
}
#[test]
fn ledger_is_per_principal_and_shared_across_clones() {
let ledger = MemoryLedger::default();
let a = PrincipalId::new("alice").unwrap();
let b = PrincipalId::new("bob").unwrap();
ledger.record_peak(&a, 2048);
let clone = ledger.clone();
clone.record_peak(&b, 8192);
assert_eq!(ledger.peak(&a), 2048);
assert_eq!(ledger.peak(&b), 8192);
assert_eq!(clone.peak(&a), 2048);
}
#[test]
fn ledger_is_bounded_and_evicts_the_lowest_peak() {
let ledger = MemoryLedger::default();
for i in 0..MAX_PRINCIPALS {
let p = PrincipalId::new(format!("p{i}")).unwrap();
ledger.record_peak(&p, (i as u64) + 1);
}
assert_eq!(ledger.inner.len(), MAX_PRINCIPALS);
let lowest = PrincipalId::new("p0").unwrap();
assert_eq!(ledger.peak(&lowest), 1);
let newcomer = PrincipalId::new("newcomer").unwrap();
ledger.record_peak(&newcomer, 1_000_000);
assert!(ledger.inner.len() <= MAX_PRINCIPALS, "stays bounded");
assert_eq!(ledger.peak(&newcomer), 1_000_000, "newcomer recorded");
assert_eq!(ledger.peak(&lowest), 0, "lowest-peak principal evicted");
let p1 = PrincipalId::new("p1").unwrap();
assert_eq!(ledger.peak(&p1), 2, "p1 is now the lowest retained user");
let smaller = PrincipalId::new("smaller").unwrap();
ledger.record_peak(&smaller, 2);
assert_eq!(
ledger.peak(&smaller),
0,
"smaller newcomer dropped, not recorded"
);
assert_eq!(ledger.peak(&p1), 2, "existing bigger user retained");
}
#[test]
fn meter_enforces_ceiling_and_records_peak() {
use wasmtime::ResourceLimiter;
let ledger = MemoryLedger::default();
let p = PrincipalId::new("carol").unwrap();
let mut meter = StoreMemoryMeter::new(64 * 1024, p.clone(), ledger.clone());
assert!(meter.memory_growing(0, 16 * 1024, None).unwrap());
assert_eq!(ledger.peak(&p), 16 * 1024);
assert!(meter.memory_growing(16 * 1024, 48 * 1024, None).unwrap());
assert_eq!(ledger.peak(&p), 48 * 1024);
assert!(!meter.memory_growing(48 * 1024, 128 * 1024, None).unwrap());
assert_eq!(ledger.peak(&p), 48 * 1024);
let q = PrincipalId::new("dave").unwrap();
meter.set(256 * 1024, q.clone());
assert!(meter.memory_growing(0, 200 * 1024, None).unwrap());
assert_eq!(ledger.peak(&q), 200 * 1024);
assert_eq!(ledger.peak(&p), 48 * 1024);
}
}