use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::{Duration, Instant};
use super::config::QuerySecurityConfig;
/// Runtime guard that enforces the limits of a [`QuerySecurityConfig`] while a query runs.
///
/// The guard tracks wall-clock time since it was created, the number of results
/// recorded, and a running estimate of memory used by those results. Counters are
/// atomics, so recording and checking only require `&self`.
pub struct QueryGuard {
    /// Limits being enforced (timeout, result cap, memory limit).
    config: QuerySecurityConfig,
    /// Moment the guard was created; the timeout is measured from here.
    start_time: Instant,
    /// Number of results recorded via [`QueryGuard::record_result`].
    result_count: AtomicUsize,
    /// Sum of estimated result sizes in bytes; saturates at `usize::MAX`.
    memory_usage: AtomicUsize,
    /// Memory is only compared against its limit on every `check_interval`-th
    /// call to [`QueryGuard::should_continue`] that passes the earlier checks.
    check_interval: usize,
    /// Number of `should_continue` calls that reached the memory-sampling step.
    checks_performed: AtomicUsize,
}

impl QueryGuard {
    /// Creates a guard whose timer starts now and which samples memory every 100 checks.
    #[must_use]
    pub fn new(config: QuerySecurityConfig) -> Self {
        Self {
            config,
            start_time: Instant::now(),
            result_count: AtomicUsize::new(0),
            memory_usage: AtomicUsize::new(0),
            check_interval: 100,
            checks_performed: AtomicUsize::new(0),
        }
    }

    /// Creates a guard that samples the memory limit every `interval` checks.
    ///
    /// An `interval` of 0 is clamped to 1, i.e. memory is checked on every call.
    #[must_use]
    pub fn with_check_interval(config: QuerySecurityConfig, interval: usize) -> Self {
        Self {
            check_interval: interval.max(1),
            ..Self::new(config)
        }
    }

    /// Returns `Ok(())` while the query may keep producing results.
    ///
    /// Checks run in this order:
    /// 1. elapsed time must not be strictly greater than the timeout;
    /// 2. the recorded result count must be below the result cap;
    /// 3. on every `check_interval`-th call that gets this far (including the
    ///    first), estimated memory must be below the memory limit.
    ///
    /// # Errors
    ///
    /// Returns [`QuerySecurityError::Timeout`], [`QuerySecurityError::ResultCapExceeded`]
    /// or [`QuerySecurityError::MemoryLimitExceeded`] for the first violated limit.
    pub fn should_continue(&self) -> Result<(), QuerySecurityError> {
        let elapsed = self.start_time.elapsed();
        let timeout_limit = self.config.timeout();
        if elapsed > timeout_limit {
            return Err(QuerySecurityError::Timeout {
                elapsed,
                limit: timeout_limit,
            });
        }

        let count = self.result_count.load(Ordering::Relaxed);
        let result_limit = self.config.result_cap();
        if count >= result_limit {
            return Err(QuerySecurityError::ResultCapExceeded {
                count,
                limit: result_limit,
            });
        }

        // Memory is sampled rather than checked on every call; `checks` starts at 0,
        // so the very first call that reaches this point always checks.
        let checks = self.checks_performed.fetch_add(1, Ordering::Relaxed);
        if checks.is_multiple_of(self.check_interval) {
            let usage = self.memory_usage.load(Ordering::Relaxed);
            let memory_limit = self.config.memory_limit();
            if usage >= memory_limit {
                return Err(QuerySecurityError::MemoryLimitExceeded {
                    usage,
                    limit: memory_limit,
                });
            }
        }

        Ok(())
    }

    /// Records one produced result whose estimated size is `estimated_size` bytes.
    ///
    /// The memory total saturates at `usize::MAX` instead of wrapping, so a huge
    /// estimate can never reset the counter to a value below the memory limit.
    pub fn record_result(&self, estimated_size: usize) {
        self.result_count.fetch_add(1, Ordering::Relaxed);
        // `fetch_add` wraps on overflow, which would let usage slip back under the
        // limit. The closure always returns `Some`, so this update cannot fail and
        // the returned `Result` carries no information worth handling.
        let _ = self
            .memory_usage
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                Some(current.saturating_add(estimated_size))
            });
    }

    /// Wall-clock time elapsed since the guard was created.
    #[must_use]
    pub fn elapsed(&self) -> Duration {
        self.start_time.elapsed()
    }

    /// Number of results recorded so far.
    #[must_use]
    pub fn result_count(&self) -> usize {
        self.result_count.load(Ordering::Relaxed)
    }

    /// Estimated memory recorded so far, in bytes (saturating at `usize::MAX`).
    #[must_use]
    pub fn memory_usage(&self) -> usize {
        self.memory_usage.load(Ordering::Relaxed)
    }

    /// The configuration this guard enforces.
    #[must_use]
    pub fn config(&self) -> &QuerySecurityConfig {
        &self.config
    }
}
/// A limit violation that stops (or prevents) query execution.
#[derive(Debug, thiserror::Error)]
pub enum QuerySecurityError {
    /// Wall-clock time since the guard started went past the configured timeout.
    #[error("Query timeout: {elapsed:?} exceeded {limit:?}")]
    Timeout {
        elapsed: Duration,
        limit: Duration,
    },
    /// The number of recorded results reached the configured cap.
    #[error("Result cap exceeded: {count} >= {limit}")]
    ResultCapExceeded {
        count: usize,
        limit: usize,
    },
    /// Estimated memory for recorded results reached the configured limit (bytes).
    #[error("Memory limit exceeded: {usage} bytes >= {limit} bytes")]
    MemoryLimitExceeded {
        usage: usize,
        limit: usize,
    },
    /// The estimated cost of the query exceeds the allowed cost.
    ///
    /// NOTE(review): not produced by `QueryGuard::should_continue`; presumably
    /// raised by a pre-execution cost estimate elsewhere — confirm.
    #[error("Query cost exceeds limit: {estimated} > {limit}")]
    CostLimitExceeded {
        estimated: usize,
        limit: usize,
    },
}

impl QuerySecurityError {
    /// Translates this error into the status reported alongside whatever results
    /// were gathered before the limit was hit.
    ///
    /// `CostLimitExceeded` has no matching status variant and is reported as
    /// [`QueryCompletionStatus::Complete`].
    // NOTE(review): mapping a cost rejection to `Complete` looks intentional only if
    // the cost check happens before any results exist — confirm with callers.
    #[must_use]
    pub fn into_completion_status(self) -> QueryCompletionStatus {
        match self {
            Self::CostLimitExceeded { .. } => QueryCompletionStatus::Complete,
            Self::MemoryLimitExceeded {
                usage: usage_bytes,
                limit: limit_bytes,
            } => QueryCompletionStatus::MemoryLimitReached {
                usage_bytes,
                limit_bytes,
            },
            Self::ResultCapExceeded { count, limit } => {
                QueryCompletionStatus::ResultCapReached { count, limit }
            }
            Self::Timeout { elapsed, limit } => QueryCompletionStatus::TimedOut { elapsed, limit },
        }
    }
}
/// How a query finished: fully, or truncated by one of the guard's limits.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryCompletionStatus {
    /// The query ran to completion without hitting any limit.
    Complete,
    /// Results were truncated because the result cap was reached.
    ResultCapReached {
        count: usize,
        limit: usize,
    },
    /// Results were truncated because the memory limit (in bytes) was reached.
    MemoryLimitReached {
        usage_bytes: usize,
        limit_bytes: usize,
    },
    /// Results were truncated because the query ran past its timeout.
    TimedOut {
        elapsed: Duration,
        limit: Duration,
    },
}

impl QueryCompletionStatus {
    /// Returns `true` only for [`QueryCompletionStatus::Complete`].
    #[must_use]
    pub fn is_complete(&self) -> bool {
        matches!(self, Self::Complete)
    }

    /// One-line, human-readable summary of the status.
    ///
    /// Memory is shown in MiB (labelled "MB") and durations in seconds, each with
    /// one decimal place so fractional limits are not truncated (a 500 ms timeout
    /// reads "0.5s", not "0s").
    #[must_use]
    pub fn message(&self) -> String {
        // One MiB; the message labels it "MB".
        const BYTES_PER_MB: f64 = 1024.0 * 1024.0;

        match self {
            Self::Complete => "Query completed successfully".to_string(),
            Self::ResultCapReached { count, limit } => {
                format!(
                    "Results truncated: showing {count} of {limit}+ matches (result cap reached)"
                )
            }
            Self::MemoryLimitReached {
                usage_bytes,
                limit_bytes,
            } => {
                // Precision lost above 2^53 bytes is irrelevant at one decimal of MiB.
                #[allow(clippy::cast_precision_loss)]
                let (usage_mb, limit_mb) = (
                    *usage_bytes as f64 / BYTES_PER_MB,
                    *limit_bytes as f64 / BYTES_PER_MB,
                );
                format!(
                    "Results truncated: memory limit reached ({usage_mb:.1} of {limit_mb:.1} MB)"
                )
            }
            Self::TimedOut { elapsed, limit } => {
                format!(
                    "Results truncated: query timed out after {:.1}s (limit: {:.1}s)",
                    elapsed.as_secs_f64(),
                    limit.as_secs_f64()
                )
            }
        }
    }

    /// Stable machine-readable identifier for the status.
    #[must_use]
    pub fn status_field(&self) -> &'static str {
        match self {
            Self::Complete => "complete",
            Self::ResultCapReached { .. } => "result_cap_reached",
            Self::MemoryLimitReached { .. } => "memory_limit_reached",
            Self::TimedOut { .. } => "timed_out",
        }
    }

    /// Process exit code for this status: 0 when complete, 2 when truncated.
    #[must_use]
    pub fn exit_code(&self) -> i32 {
        // Variants are listed explicitly so adding a new status forces a decision here.
        match self {
            Self::Complete => 0,
            Self::ResultCapReached { .. }
            | Self::MemoryLimitReached { .. }
            | Self::TimedOut { .. } => 2,
        }
    }
}
/// Results gathered by a query, paired with how the query finished.
#[derive(Debug)]
pub struct QueryResultSet<T> {
    /// Results collected before the query completed or was stopped.
    pub results: Vec<T>,
    /// Whether the query completed, or which limit truncated it.
    pub status: QueryCompletionStatus,
    /// Memory attributed to the results in bytes, as supplied by the constructor's caller.
    pub memory_usage_bytes: usize,
    /// Query duration, as supplied by the constructor's caller.
    pub elapsed: Duration,
}

impl<T> QueryResultSet<T> {
    /// Wraps results from a query that ran to completion.
    #[must_use]
    pub fn complete(results: Vec<T>, memory_usage_bytes: usize, elapsed: Duration) -> Self {
        // `truncated` stores whatever status it is given, so it doubles as the
        // general-purpose constructor.
        Self::truncated(
            results,
            QueryCompletionStatus::Complete,
            memory_usage_bytes,
            elapsed,
        )
    }

    /// Wraps results together with the status describing why collection stopped.
    ///
    /// `status` is stored as given; it is not validated.
    #[must_use]
    pub fn truncated(
        results: Vec<T>,
        status: QueryCompletionStatus,
        memory_usage_bytes: usize,
        elapsed: Duration,
    ) -> Self {
        Self {
            elapsed,
            memory_usage_bytes,
            status,
            results,
        }
    }

    /// `true` when the stored status is [`QueryCompletionStatus::Complete`].
    #[must_use]
    pub fn is_complete(&self) -> bool {
        matches!(self.status, QueryCompletionStatus::Complete)
    }

    /// Number of results held.
    #[must_use]
    pub fn len(&self) -> usize {
        let QueryResultSet { results, .. } = self;
        results.len()
    }

    /// `true` when no results are held.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        let QueryResultSet { results, .. } = self;
        results.is_empty()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Guard over the default configuration, shared by the simple guard tests.
    fn default_guard() -> QueryGuard {
        QueryGuard::new(QuerySecurityConfig::default())
    }

    #[test]
    fn test_guard_initial_state() {
        let fresh = default_guard();
        assert_eq!(fresh.result_count(), 0);
        assert_eq!(fresh.memory_usage(), 0);
        assert!(fresh.should_continue().is_ok());
    }

    #[test]
    fn test_guard_record_result() {
        let tracked = default_guard();
        tracked.record_result(1024);
        assert_eq!((tracked.result_count(), tracked.memory_usage()), (1, 1024));
    }

    #[test]
    fn test_guard_result_cap() {
        let capped = QueryGuard::new(QuerySecurityConfig::default().with_result_cap(5));
        (0..5).for_each(|_| capped.record_result(100));
        let outcome = capped.should_continue();
        assert!(
            matches!(
                outcome,
                Err(QuerySecurityError::ResultCapExceeded { count: 5, limit: 5 })
            ),
            "unexpected outcome: {outcome:?}"
        );
    }

    #[test]
    fn test_guard_memory_limit() {
        let sampled_every_call = QueryGuard::with_check_interval(
            QuerySecurityConfig::default().with_memory_limit(1000),
            1,
        );
        sampled_every_call.record_result(500);
        assert!(sampled_every_call.should_continue().is_ok());
        sampled_every_call.record_result(600);
        let outcome = sampled_every_call.should_continue();
        assert!(
            matches!(outcome, Err(QuerySecurityError::MemoryLimitExceeded { .. })),
            "unexpected outcome: {outcome:?}"
        );
    }

    #[test]
    fn test_completion_status_messages() {
        assert_eq!(
            QueryCompletionStatus::Complete.message(),
            "Query completed successfully"
        );

        let ten_mb = 10 * 1024 * 1024;
        let five_secs = Duration::from_secs(5);
        let expectations = [
            (
                QueryCompletionStatus::ResultCapReached {
                    count: 100,
                    limit: 100,
                },
                "100",
            ),
            (
                QueryCompletionStatus::MemoryLimitReached {
                    usage_bytes: ten_mb,
                    limit_bytes: ten_mb,
                },
                "MB",
            ),
            (
                QueryCompletionStatus::TimedOut {
                    elapsed: five_secs,
                    limit: five_secs,
                },
                "timed out",
            ),
        ];
        for (status, needle) in expectations {
            let text = status.message();
            assert!(text.contains(needle), "{text:?} should mention {needle:?}");
        }
    }

    #[test]
    fn test_completion_status_is_complete() {
        assert!(QueryCompletionStatus::Complete.is_complete());
        let capped = QueryCompletionStatus::ResultCapReached {
            count: 10,
            limit: 10,
        };
        assert!(!capped.is_complete());
    }

    #[test]
    fn test_error_to_status_conversion() {
        let timed_out = QuerySecurityError::Timeout {
            elapsed: Duration::from_secs(10),
            limit: Duration::from_secs(5),
        }
        .into_completion_status();
        assert!(matches!(timed_out, QueryCompletionStatus::TimedOut { .. }));

        let capped = QuerySecurityError::ResultCapExceeded {
            count: 100,
            limit: 50,
        }
        .into_completion_status();
        assert!(matches!(
            capped,
            QueryCompletionStatus::ResultCapReached { .. }
        ));
    }

    #[test]
    fn test_result_set_complete() {
        let set = QueryResultSet::complete(vec![1, 2, 3], 100, Duration::from_millis(10));
        assert!(set.is_complete());
        assert_eq!(set.len(), 3);
        assert!(!set.is_empty());
    }

    #[test]
    fn test_result_set_truncated() {
        let reason = QueryCompletionStatus::ResultCapReached { count: 2, limit: 2 };
        let set = QueryResultSet::truncated(vec![1, 2], reason, 50, Duration::from_millis(5));
        assert!(!set.is_complete());
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn test_exit_codes() {
        assert_eq!(QueryCompletionStatus::Complete.exit_code(), 0);

        let truncated = [
            QueryCompletionStatus::ResultCapReached {
                count: 10,
                limit: 10,
            },
            QueryCompletionStatus::TimedOut {
                elapsed: Duration::from_secs(5),
                limit: Duration::from_secs(5),
            },
        ];
        for status in truncated {
            assert_eq!(status.exit_code(), 2, "{status:?}");
        }
    }
}