use std::sync::Arc;
use std::sync::atomic::{AtomicU8, AtomicU64, Ordering};
use thiserror::Error;
/// Row limit that `QueryBudget::from_per_call_or_env` falls back to when neither
/// the per-call argument nor the environment supplies a positive value.
pub const DEFAULT_BUDGET_ROWS: u64 = 5_000_000;
/// Name of the environment variable that `QueryBudget::from_per_call_or_env`
/// reads for a row limit when no positive per-call budget is given.
pub const ENV_TOOL_BUDGET_ROWS: &str = "SQRY_TOOL_BUDGET_ROWS";
/// Value that `QueryBudget::new` stores in `QueryBudget::check_stride`.
///
/// NOTE(review): nothing in this file reads `check_stride`; presumably callers
/// use it to decide how often to poll the budget — confirm against call sites.
pub const DEFAULT_CHECK_STRIDE: u64 = 256;
/// Tag describing which party, if any, has been recorded as the cause of a
/// query's cancellation.
///
/// The discriminants are pinned with `#[repr(u8)]` so a tag can be stored in an
/// `AtomicU8` (as `QueryBudget::state` does) and decoded again with
/// [`CancellationSource::from_u8`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum CancellationSource {
    /// No cancellation source has been recorded.
    None = 0,
    /// Recorded by `QueryBudget::tick` when the row budget is reached.
    Budget = 1,
    /// Recorded through `QueryBudget::mark_external_cancel`.
    External = 2,
}

impl CancellationSource {
    /// Decodes a raw discriminant back into a tag.
    ///
    /// `1` and `2` map to [`CancellationSource::Budget`] and
    /// [`CancellationSource::External`]; every other byte, including `0`,
    /// decodes to [`CancellationSource::None`].
    #[inline]
    #[must_use]
    pub fn from_u8(v: u8) -> Self {
        // Derive the match arms from the enum itself so the decoder cannot
        // drift away from the declared discriminants.
        const BUDGET: u8 = CancellationSource::Budget as u8;
        const EXTERNAL: u8 = CancellationSource::External as u8;

        match v {
            BUDGET => Self::Budget,
            EXTERNAL => Self::External,
            _ => Self::None,
        }
    }
}
/// Row budget for a query, bundled with the cancellation token it trips and a
/// tag recording which party cancelled first.
///
/// `examined` and `state` live behind `Arc`s, so the derived `Clone` yields
/// handles that share one counter and one source tag.
#[derive(Debug, Clone)]
pub struct QueryBudget {
/// Limit that `tick` and `exceeded` compare the examined-row counter against.
pub max_rows: u64,
/// Number of rows counted so far; incremented by one on every `tick`.
pub examined: Arc<AtomicU64>,
/// Token that `tick` cancels once the counter reaches `max_rows`.
pub cancel: crate::query::cancellation::CancellationToken,
/// Raw `CancellationSource` discriminant. The methods in this file only move
/// it away from `None` through a compare-exchange, so the first tag written
/// is the one that sticks.
pub state: Arc<AtomicU8>,
/// Initialised to `DEFAULT_CHECK_STRIDE` by `new`.
///
/// NOTE(review): not read anywhere in this file; presumably consumed by
/// callers that poll the budget every N rows — confirm.
pub check_stride: u64,
}
impl QueryBudget {
    /// Creates a budget that allows `max_rows` examined rows and cancels
    /// `cancel` when that limit is reached.
    ///
    /// The counter starts at zero, no cancellation source is recorded, and
    /// `check_stride` is set to [`DEFAULT_CHECK_STRIDE`].
    #[must_use]
    pub fn new(max_rows: u64, cancel: crate::query::cancellation::CancellationToken) -> Self {
        let examined = Arc::new(AtomicU64::new(0));
        let state = Arc::new(AtomicU8::new(CancellationSource::None as u8));
        Self {
            max_rows,
            examined,
            cancel,
            state,
            check_stride: DEFAULT_CHECK_STRIDE,
        }
    }

    /// Creates a budget whose limit is `u64::MAX`.
    #[must_use]
    pub fn unbounded(cancel: crate::query::cancellation::CancellationToken) -> Self {
        Self::new(u64::MAX, cancel)
    }

    /// Picks the row limit from the first source that yields a positive value.
    ///
    /// Precedence:
    /// 1. `per_call_budget`, when it is `Some` and non-zero;
    /// 2. the [`ENV_TOOL_BUDGET_ROWS`] environment variable, when it is set,
    ///    parses as a `u64`, and is non-zero;
    /// 3. [`DEFAULT_BUDGET_ROWS`].
    ///
    /// The environment is only consulted when the per-call value is absent or
    /// zero. An unset, unparsable, or zero variable is ignored silently.
    #[must_use]
    pub fn from_per_call_or_env(
        per_call_budget: Option<u64>,
        cancel: crate::query::cancellation::CancellationToken,
    ) -> Self {
        /// A zero limit is treated as "not specified" by both sources.
        fn is_positive(rows: &u64) -> bool {
            *rows > 0
        }

        let limit = per_call_budget
            .filter(is_positive)
            .or_else(|| {
                let raw = std::env::var(ENV_TOOL_BUDGET_ROWS).ok()?;
                raw.parse::<u64>().ok().filter(is_positive)
            })
            .unwrap_or(DEFAULT_BUDGET_ROWS);

        Self::new(limit, cancel)
    }

    /// Reports whether the examined-row counter has reached `max_rows`.
    ///
    /// This only reads the counter; it never cancels the token or records a
    /// source.
    #[inline]
    #[must_use]
    pub fn exceeded(&self) -> bool {
        let seen = self.examined.load(Ordering::Relaxed);
        seen >= self.max_rows
    }

    /// Returns the cancellation source currently recorded in `state`.
    #[inline]
    #[must_use]
    pub fn source(&self) -> CancellationSource {
        let raw = self.state.load(Ordering::Acquire);
        CancellationSource::from_u8(raw)
    }

    /// Tries to record [`CancellationSource::External`] as the source.
    ///
    /// Returns `true` when this call performed the `None -> External`
    /// transition, and `false` when a source had already been recorded. The
    /// cancellation token itself is not touched here.
    #[inline]
    pub fn mark_external_cancel(&self) -> bool {
        self.claim_source(CancellationSource::External)
    }

    /// Tries to record [`CancellationSource::Budget`] as the source; same
    /// first-writer-wins contract as [`QueryBudget::mark_external_cancel`].
    #[inline]
    fn mark_budget_cancel(&self) -> bool {
        self.claim_source(CancellationSource::Budget)
    }

    /// Shared compare-exchange behind both `mark_*` methods: writes `tag` only
    /// if `state` still holds `None`, and reports whether the write happened.
    #[inline]
    fn claim_source(&self, tag: CancellationSource) -> bool {
        let unclaimed = CancellationSource::None as u8;
        self.state
            .compare_exchange(unclaimed, tag as u8, Ordering::AcqRel, Ordering::Acquire)
            .is_ok()
    }

    /// Counts one examined row and enforces the budget.
    ///
    /// While the new count stays below `max_rows` this returns `Ok(())`. Once
    /// the count reaches the limit, the call tries to record
    /// [`CancellationSource::Budget`] (a previously recorded source is left
    /// untouched), cancels the token, and returns the error.
    ///
    /// # Errors
    ///
    /// Returns [`BudgetExceeded`] carrying the new count and the limit when the
    /// count is at or above `max_rows`. `predicate_shape` is always `None`.
    #[inline]
    pub fn tick(&self) -> Result<(), BudgetExceeded> {
        // `fetch_add` hands back the value before the increment; add one to get
        // the count that includes this row.
        let examined = self.examined.fetch_add(1, Ordering::Relaxed) + 1;
        if examined < self.max_rows {
            return Ok(());
        }

        // The CAS result is deliberately ignored: losing the race just means
        // another source was recorded first.
        self.mark_budget_cancel();
        self.cancel.cancel();
        Err(BudgetExceeded {
            examined,
            limit: self.max_rows,
            predicate_shape: None,
        })
    }
}
/// Error returned by `QueryBudget::tick` once the examined-row counter reaches
/// the budget's limit.
///
/// The `Display` message reports `examined` and `limit` only.
#[derive(Debug, Clone, Error, PartialEq, Eq)]
#[error("query exceeded row budget: examined {examined} rows, limit {limit}")]
pub struct BudgetExceeded {
/// Counter value including the tick that produced this error.
pub examined: u64,
/// The budget's `max_rows` when the error was produced.
pub limit: u64,
/// Always `None` when the error comes from `QueryBudget::tick`.
///
/// NOTE(review): presumably filled in by callers that know the query's
/// predicate — confirm against call sites.
pub predicate_shape: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::cancellation::CancellationToken;
    use std::sync::{Mutex, PoisonError};

    /// Serializes every test in this module that touches `ENV_TOOL_BUDGET_ROWS`.
    ///
    /// The test harness runs tests on multiple threads and the process
    /// environment is global, so without this lock one test's `set_var` can
    /// land between another test's `remove_var` and its read, making the
    /// default-value assertions flaky.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Builds a budget via `from_per_call_or_env` with `ENV_TOOL_BUDGET_ROWS`
    /// set to `env_value` (or removed when `None`).
    ///
    /// The variable is always removed again before returning, so assertions
    /// made by the caller cannot leak state into other tests when they fail.
    fn budget_with_env(env_value: Option<&str>, per_call: Option<u64>) -> QueryBudget {
        // A panic in another env test poisons the lock; the `()` payload has
        // no invariant to protect, so recover the guard and carry on.
        let _serialized = ENV_LOCK.lock().unwrap_or_else(PoisonError::into_inner);

        // SAFETY: all environment access made by this module's tests happens
        // while `ENV_LOCK` is held, so these tests never race each other.
        // NOTE(review): tests elsewhere in the crate that touch the
        // environment are not covered by this lock.
        unsafe {
            match env_value {
                Some(value) => std::env::set_var(ENV_TOOL_BUDGET_ROWS, value),
                None => std::env::remove_var(ENV_TOOL_BUDGET_ROWS),
            }
        }

        let budget = QueryBudget::from_per_call_or_env(per_call, CancellationToken::new());

        // SAFETY: still holding `ENV_LOCK`; see above.
        unsafe {
            std::env::remove_var(ENV_TOOL_BUDGET_ROWS);
        }

        budget
    }

    #[test]
    fn tick_below_max_returns_ok() {
        let token = CancellationToken::new();
        let budget = QueryBudget::new(10, token.clone());
        for _ in 0..9 {
            budget.tick().expect("first 9 ticks must succeed");
        }
        assert!(!budget.exceeded(), "9 ticks must not exceed budget of 10");
        assert!(!token.is_cancelled(), "token must remain uncancelled");
        assert_eq!(budget.source(), CancellationSource::None);
    }

    #[test]
    fn tick_at_max_trips_cancel_and_stamps_budget_source() {
        let token = CancellationToken::new();
        let budget = QueryBudget::new(3, token.clone());
        budget.tick().expect("tick 1 ok");
        budget.tick().expect("tick 2 ok");
        let err = budget.tick().expect_err("tick 3 must trip");
        assert_eq!(err.examined, 3);
        assert_eq!(err.limit, 3);
        assert!(token.is_cancelled(), "tick must flip the token");
        assert_eq!(
            budget.source(),
            CancellationSource::Budget,
            "budget overflow must stamp source = Budget"
        );
    }

    #[test]
    fn external_cancel_first_blocks_budget_tag() {
        let token = CancellationToken::new();
        let budget = QueryBudget::new(3, token.clone());
        assert!(budget.mark_external_cancel(), "External wins CAS");
        budget.tick().expect("first tick ok");
        budget.tick().expect("second tick ok");
        // The overflow is still reported; only the source tag is protected.
        budget
            .tick()
            .expect_err("third tick must still report the overflow");
        assert_eq!(
            budget.source(),
            CancellationSource::External,
            "external-first must keep tag = External even after budget overflow"
        );
    }

    #[test]
    fn budget_cancel_first_blocks_external_tag() {
        let token = CancellationToken::new();
        let budget = QueryBudget::new(2, token.clone());
        budget.tick().expect("first tick ok");
        let _ = budget.tick();
        assert_eq!(budget.source(), CancellationSource::Budget);
        assert!(
            !budget.mark_external_cancel(),
            "external-second CAS must fail"
        );
        assert_eq!(budget.source(), CancellationSource::Budget);
    }

    #[test]
    fn from_per_call_prefers_per_call_value_over_env() {
        let budget = budget_with_env(Some("999"), Some(42));
        assert_eq!(budget.max_rows, 42, "per-call value must override env var");
    }

    #[test]
    fn from_per_call_zero_falls_back_to_default() {
        let budget = budget_with_env(None, Some(0));
        assert_eq!(
            budget.max_rows, DEFAULT_BUDGET_ROWS,
            "per-call zero must map to the default rather than trip immediately"
        );
    }

    #[test]
    fn from_per_call_none_uses_default_when_env_unset() {
        let budget = budget_with_env(None, None);
        assert_eq!(budget.max_rows, DEFAULT_BUDGET_ROWS);
    }

    #[test]
    fn from_per_call_none_uses_env_value_when_set() {
        let budget = budget_with_env(Some("1234"), None);
        assert_eq!(
            budget.max_rows, 1234,
            "a positive env value must be used when no per-call budget is given"
        );
    }

    #[test]
    fn from_per_call_none_ignores_unusable_env_value() {
        for unusable in ["0", "not-a-number", ""] {
            let budget = budget_with_env(Some(unusable), None);
            assert_eq!(
                budget.max_rows, DEFAULT_BUDGET_ROWS,
                "env value {unusable:?} must fall back to the default"
            );
        }
    }

    #[test]
    fn unbounded_budget_never_trips_on_realistic_iteration_count() {
        let token = CancellationToken::new();
        let budget = QueryBudget::unbounded(token.clone());
        for _ in 0..1_000 {
            budget.tick().expect("unbounded must not trip");
        }
        assert!(!token.is_cancelled());
    }
}