use std::cell::Cell;
use powdb_storage::types::Value;
use crate::result::QueryError;
/// Default memory budget for a query, in bytes: 256 MiB (268,435,456 bytes).
///
/// NOTE(review): presumably passed as `limit_bytes` to `charge` when no
/// explicit limit is configured — confirm against callers.
pub const DEFAULT_QUERY_MEMORY_LIMIT: usize = 256 << 20;
thread_local! {
    /// How many guards returned by `enter()` are currently alive on this
    /// thread; 0 means none are.
    static DEPTH: Cell<u32> = const { Cell::new(0) };
    /// Bytes charged via `charge()` on this thread since the outermost
    /// `enter()` last zeroed the counter.
    static USED_BYTES: Cell<usize> = const { Cell::new(0) };
}
/// Marks the start of a (possibly nested) statement on the current thread and
/// returns a guard that marks its end when dropped.
///
/// Only the outermost call (when no other guard from `enter()` is alive on
/// this thread) zeroes the byte accumulator. Nested calls keep accumulating
/// into the outer statement's total, so nested work cannot discard bytes the
/// outer statement already charged.
#[inline]
#[must_use = "the guard must be held for the duration of the statement"]
pub fn enter() -> EnterGuard {
    DEPTH.with(|d| {
        let depth = d.get();
        if depth == 0 {
            USED_BYTES.with(|u| u.set(0));
        }
        d.set(depth + 1);
    });
    EnterGuard { _private: () }
}

/// RAII guard returned by [`enter`]; dropping it leaves the current nesting
/// level.
///
/// The private field stops code outside this module from forging a guard with
/// a bare `EnterGuard` literal. Dropping a forged guard would decrement the
/// depth without a matching `enter()`, and a later nested `enter()` would then
/// see depth 0 and wipe the outer statement's accounting.
///
/// NOTE(review): the depth counter is thread-local, so the guard should be
/// dropped on the thread that created it.
#[derive(Debug)]
#[must_use = "the guard must be held for the duration of the statement"]
pub struct EnterGuard {
    _private: (),
}

impl Drop for EnterGuard {
    #[inline]
    fn drop(&mut self) {
        // Saturating so an unbalanced drop can never wrap the depth to u32::MAX.
        DEPTH.with(|d| d.set(d.get().saturating_sub(1)));
    }
}
/// Test-only: zeroes both the nesting depth and the byte accumulator for the
/// current thread unconditionally.
#[cfg(test)]
pub fn reset() {
    DEPTH.with(|depth| depth.set(0));
    USED_BYTES.with(|bytes| bytes.set(0));
}
/// Adds `bytes` to the current thread's running total, provided the new total
/// stays within `limit_bytes` (the limit itself is allowed).
///
/// The addition saturates at `usize::MAX` instead of overflowing.
///
/// # Errors
///
/// Returns [`QueryError::MemoryLimitExceeded`] carrying `limit_bytes` and the
/// total that would have resulted. On error the running total is left
/// unchanged.
#[inline]
pub fn charge(bytes: usize, limit_bytes: usize) -> Result<(), QueryError> {
    USED_BYTES.with(|used| {
        let total = used.get().saturating_add(bytes);
        if total <= limit_bytes {
            used.set(total);
            Ok(())
        } else {
            Err(QueryError::MemoryLimitExceeded {
                limit_bytes,
                requested_bytes: total,
            })
        }
    })
}
/// Test-only: reports how many bytes have been charged on the current thread
/// since the accumulator was last zeroed.
#[cfg(test)]
pub fn used() -> usize {
    USED_BYTES.with(Cell::get)
}
/// Rough in-memory footprint of a single [`Value`], in bytes.
///
/// This is the inline size of `Value` plus the heap capacity owned by `Str`
/// and `Bytes` payloads. Capacity, not length, is used, so unused slack in
/// the allocation is counted too. Every other variant contributes no heap
/// bytes.
#[inline]
pub fn estimate_value_size(v: &Value) -> usize {
    let owned_heap = if let Value::Str(s) = v {
        s.capacity()
    } else if let Value::Bytes(b) = v {
        b.capacity()
    } else {
        0
    };
    std::mem::size_of::<Value>() + owned_heap
}
#[inline]
pub fn estimate_row_size(row: &[Value]) -> usize {
let mut total = std::mem::size_of::<Vec<Value>>();
for v in row {
total += estimate_value_size(v);
}
total
}
#[cfg(test)]
mod tests {
    // Accounting state is thread-local and may outlive a single test on a
    // reused thread, so every test that reads or writes it calls `reset()`
    // first.
    use super::*;

    #[test]
    fn charge_under_limit_succeeds() {
        reset();
        assert!(charge(512, 1024).is_ok());
        assert!(charge(512, 1024).is_ok());
        assert_eq!(used(), 1024);
    }

    #[test]
    fn charge_exactly_at_limit_succeeds() {
        reset();
        // The comparison is `requested > limit`, so the limit itself is allowed.
        assert!(charge(1024, 1024).is_ok(), "the limit is inclusive");
        assert_eq!(used(), 1024);
        assert!(charge(1, 1024).is_err(), "one byte past the limit fails");
        assert_eq!(used(), 1024, "a rejected charge is not recorded");
    }

    #[test]
    fn charge_over_limit_errors_without_charging() {
        reset();
        assert!(charge(512, 1024).is_ok());
        let err = charge(1024, 1024).unwrap_err();
        match err {
            QueryError::MemoryLimitExceeded {
                limit_bytes,
                requested_bytes,
            } => {
                assert_eq!(limit_bytes, 1024);
                assert_eq!(requested_bytes, 1536);
            }
            other => panic!("expected MemoryLimitExceeded, got {other:?}"),
        }
        assert_eq!(used(), 512);
    }

    #[test]
    fn nested_enter_preserves_outer_accounting() {
        reset();
        let outer = enter();
        assert_eq!(used(), 0, "outermost enter zeroes the accumulator");
        charge(1000, 1_000_000).unwrap();
        assert_eq!(used(), 1000);
        {
            let _inner = enter();
            assert_eq!(used(), 1000, "nested enter must not discard outer bytes");
            charge(500, 1_000_000).unwrap();
            assert_eq!(used(), 1500);
        }
        assert_eq!(used(), 1500, "outer accounting includes nested charges");
        drop(outer);
        let _next = enter();
        assert_eq!(used(), 0, "next outermost statement resets");
    }

    #[test]
    fn string_value_counts_heap_bytes() {
        let small = estimate_value_size(&Value::Int(1));
        let big = estimate_value_size(&Value::Str("x".repeat(10_000)));
        assert!(big >= small + 10_000);
    }

    #[test]
    fn row_size_is_header_plus_values() {
        let header = std::mem::size_of::<Vec<Value>>();
        assert_eq!(estimate_row_size(&[]), header, "empty row is just the header");
        let row = vec![Value::Int(1), Value::Str("x".repeat(100))];
        let expected = header + estimate_value_size(&row[0]) + estimate_value_size(&row[1]);
        assert_eq!(estimate_row_size(&row), expected);
    }
}