use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use crate::error::{Error, Result};
use crate::security::{
MemoryLimits, CpuLimits, IoLimits, TimeLimits, ResourceLimits
};
#[derive(Debug, Clone)]
pub struct MemoryResourceTracker {
pub max_memory_pages: u32,
current_pages: Arc<AtomicU64>,
peak_pages: Arc<AtomicU64>,
growth_tracker: Arc<Mutex<MemoryGrowthTracker>>,
}
#[derive(Debug)]
struct MemoryGrowthTracker {
max_rate: Option<u32>,
last_size: u64,
last_check: Instant,
growth_events: Vec<(Instant, u64)>,
window: Duration,
}
impl MemoryResourceTracker {
pub fn new(limits: &MemoryLimits) -> Self {
let growth_tracker = MemoryGrowthTracker {
max_rate: limits.max_growth_rate,
last_size: 0,
last_check: Instant::now(),
growth_events: Vec::new(),
window: Duration::from_secs(1), };
Self {
max_memory_pages: limits.max_memory_pages,
current_pages: Arc::new(AtomicU64::new(limits.reserved_memory_pages as u64)),
peak_pages: Arc::new(AtomicU64::new(limits.reserved_memory_pages as u64)),
growth_tracker: Arc::new(Mutex::new(growth_tracker)),
}
}
pub fn check_allocation(&self, pages: u32) -> Result<()> {
let current = self.current_pages.load(Ordering::Acquire);
let requested = current + pages as u64;
if requested > self.max_memory_pages as u64 {
return Err(Error::ResourceLimit {
message: format!("Memory allocation of {} pages would exceed limit of {} pages",
pages, self.max_memory_pages)
});
}
if let Some(max_rate) = self.growth_tracker.lock().unwrap().max_rate {
let now = Instant::now();
let mut tracker = self.growth_tracker.lock().unwrap();
let cutoff = now - tracker.window;
tracker.growth_events.retain(|(time, _)| *time >= cutoff);
tracker.growth_events.push((now, pages as u64));
let total_growth: u64 = tracker.growth_events.iter().map(|(_, size)| *size).sum();
if total_growth > max_rate as u64 {
return Err(Error::ResourceLimit {
message: format!("Memory growth rate of {} pages/s exceeds limit of {} pages/s",
total_growth, max_rate)
});
}
tracker.last_size = requested;
tracker.last_check = now;
}
Ok(())
}
pub fn update(&self, pages: u32) {
let current = self.current_pages.fetch_add(pages as u64, Ordering::AcqRel) + pages as u64;
let mut peak = self.peak_pages.load(Ordering::Acquire);
while current > peak {
match self.peak_pages.compare_exchange_weak(
peak,
current,
Ordering::AcqRel,
Ordering::Acquire,
) {
Ok(_) => break,
Err(actual) => peak = actual,
}
}
}
pub fn current_pages(&self) -> u64 {
self.current_pages.load(Ordering::Acquire)
}
pub fn peak_pages(&self) -> u64 {
self.peak_pages.load(Ordering::Acquire)
}
pub fn reset_peak(&self) {
self.peak_pages.store(
self.current_pages.load(Ordering::Acquire),
Ordering::Release
);
}
}
#[derive(Debug, Clone)]
pub struct CpuResourceTracker {
pub max_execution_time: Duration,
pub cpu_usage_percentage: Option<u8>,
pub max_threads: Option<u32>,
start_time: Arc<Mutex<Option<Instant>>>,
total_time: Arc<AtomicU64>,
active_threads: Arc<AtomicU64>,
}
impl CpuResourceTracker {
pub fn new(limits: &CpuLimits) -> Self {
Self {
max_execution_time: Duration::from_millis(limits.max_execution_time_ms),
cpu_usage_percentage: limits.cpu_usage_percentage,
max_threads: limits.max_threads,
start_time: Arc::new(Mutex::new(None)),
total_time: Arc::new(AtomicU64::new(0)),
active_threads: Arc::new(AtomicU64::new(0)),
}
}
pub fn start_execution(&self) {
let mut start = self.start_time.lock().unwrap();
if start.is_none() {
*start = Some(Instant::now());
}
}
pub fn stop_execution(&self) {
let mut start_lock = self.start_time.lock().unwrap();
if let Some(start) = *start_lock {
let elapsed = start.elapsed();
self.total_time.fetch_add(elapsed.as_millis() as u64, Ordering::AcqRel);
*start_lock = None;
}
}
pub fn check_time_limit(&self) -> Result<()> {
let total = self.total_time.load(Ordering::Acquire);
let mut current_total = total;
let start_lock = self.start_time.lock().unwrap();
if let Some(start) = *start_lock {
current_total += start.elapsed().as_millis() as u64;
}
if current_total > self.max_execution_time.as_millis() as u64 {
return Err(Error::Timeout {
operation: "execution".to_string(),
duration: Duration::from_millis(current_total),
instance_id: None,
});
}
Ok(())
}
pub fn register_thread(&self) -> Result<()> {
if let Some(max) = self.max_threads {
let current = self.active_threads.fetch_add(1, Ordering::AcqRel) + 1;
if current > max as u64 {
self.active_threads.fetch_sub(1, Ordering::AcqRel);
return Err(Error::ResourceLimit {
message: format!("Thread limit of {} exceeded", max)
});
}
} else {
self.active_threads.fetch_add(1, Ordering::AcqRel);
}
Ok(())
}
pub fn unregister_thread(&self) {
self.active_threads.fetch_sub(1, Ordering::AcqRel);
}
pub fn total_time_ms(&self) -> u64 {
let total = self.total_time.load(Ordering::Acquire);
let mut current_total = total;
let start_lock = self.start_time.lock().unwrap();
if let Some(start) = *start_lock {
current_total += start.elapsed().as_millis() as u64;
}
current_total
}
pub fn active_threads(&self) -> u32 {
self.active_threads.load(Ordering::Acquire) as u32
}
pub fn apply_throttling(&self) {
if let Some(percentage) = self.cpu_usage_percentage {
if percentage >= 100 {
return; }
if percentage > 0 {
let sleep_time_ns = (100 - percentage) as u64 * 10_000; std::thread::sleep(Duration::from_nanos(sleep_time_ns));
}
}
}
}
#[derive(Debug, Clone)]
pub struct IoResourceTracker {
pub max_open_files: u32,
pub max_read_bytes_per_second: Option<u64>,
pub max_write_bytes_per_second: Option<u64>,
pub max_total_read_bytes: Option<u64>,
pub max_total_write_bytes: Option<u64>,
open_files: Arc<AtomicU64>,
total_read: Arc<AtomicU64>,
total_write: Arc<AtomicU64>,
rate_tracker: Arc<Mutex<IoRateTracker>>,
}
#[derive(Debug)]
struct IoRateTracker {
read_events: Vec<(Instant, u64)>,
write_events: Vec<(Instant, u64)>,
window: Duration,
}
impl IoResourceTracker {
pub fn new(limits: &IoLimits) -> Self {
let rate_tracker = IoRateTracker {
read_events: Vec::new(),
write_events: Vec::new(),
window: Duration::from_secs(1), };
Self {
max_open_files: limits.max_open_files,
max_read_bytes_per_second: limits.max_read_bytes_per_second,
max_write_bytes_per_second: limits.max_write_bytes_per_second,
max_total_read_bytes: limits.max_total_read_bytes,
max_total_write_bytes: limits.max_total_write_bytes,
open_files: Arc::new(AtomicU64::new(0)),
total_read: Arc::new(AtomicU64::new(0)),
total_write: Arc::new(AtomicU64::new(0)),
rate_tracker: Arc::new(Mutex::new(rate_tracker)),
}
}
pub fn register_open(&self) -> Result<()> {
let current = self.open_files.fetch_add(1, Ordering::AcqRel) + 1;
if current > self.max_open_files as u64 {
self.open_files.fetch_sub(1, Ordering::AcqRel);
return Err(Error::ResourceLimit {
message: format!("Open file limit of {} exceeded", self.max_open_files)
});
}
Ok(())
}
pub fn register_close(&self) {
self.open_files.fetch_sub(1, Ordering::AcqRel);
}
pub fn register_read(&self, bytes: u64) -> Result<()> {
let total = self.total_read.fetch_add(bytes, Ordering::AcqRel) + bytes;
if let Some(limit) = self.max_total_read_bytes {
if total > limit {
return Err(Error::ResourceLimit {
message: format!("Total read limit of {} bytes exceeded", limit)
});
}
}
if let Some(rate_limit) = self.max_read_bytes_per_second {
let now = Instant::now();
let mut tracker = self.rate_tracker.lock().unwrap();
let cutoff = now - tracker.window;
tracker.read_events.retain(|(time, _)| *time >= cutoff);
tracker.read_events.push((now, bytes));
let window_total: u64 = tracker.read_events.iter().map(|(_, size)| *size).sum();
if window_total > rate_limit {
return Err(Error::ResourceLimit {
message: format!("Read rate limit of {} bytes/s exceeded", rate_limit)
});
}
}
Ok(())
}
pub fn register_write(&self, bytes: u64) -> Result<()> {
let total = self.total_write.fetch_add(bytes, Ordering::AcqRel) + bytes;
if let Some(limit) = self.max_total_write_bytes {
if total > limit {
return Err(Error::ResourceLimit {
message: format!("Total write limit of {} bytes exceeded", limit)
});
}
}
if let Some(rate_limit) = self.max_write_bytes_per_second {
let now = Instant::now();
let mut tracker = self.rate_tracker.lock().unwrap();
let cutoff = now - tracker.window;
tracker.write_events.retain(|(time, _)| *time >= cutoff);
tracker.write_events.push((now, bytes));
let window_total: u64 = tracker.write_events.iter().map(|(_, size)| *size).sum();
if window_total > rate_limit {
return Err(Error::ResourceLimit {
message: format!("Write rate limit of {} bytes/s exceeded", rate_limit)
});
}
}
Ok(())
}
pub fn open_files(&self) -> u32 {
self.open_files.load(Ordering::Acquire) as u32
}
pub fn total_read(&self) -> u64 {
self.total_read.load(Ordering::Acquire)
}
pub fn total_write(&self) -> u64 {
self.total_write.load(Ordering::Acquire)
}
pub fn read_rate(&self) -> u64 {
let tracker = self.rate_tracker.lock().unwrap();
let now = Instant::now();
let cutoff = now - tracker.window;
tracker.read_events
.iter()
.filter(|(time, _)| *time >= cutoff)
.map(|(_, size)| *size)
.sum()
}
pub fn write_rate(&self) -> u64 {
let tracker = self.rate_tracker.lock().unwrap();
let now = Instant::now();
let cutoff = now - tracker.window;
tracker.write_events
.iter()
.filter(|(time, _)| *time >= cutoff)
.map(|(_, size)| *size)
.sum()
}
}
#[derive(Debug, Clone)]
pub struct TimeResourceTracker {
pub max_total_time: Duration,
pub max_idle_time: Option<Duration>,
start_time: Arc<Mutex<Instant>>,
last_activity: Arc<Mutex<Instant>>,
}
impl TimeResourceTracker {
pub fn new(limits: &TimeLimits) -> Self {
let now = Instant::now();
Self {
max_total_time: Duration::from_millis(limits.max_total_time_ms),
max_idle_time: limits.max_idle_time_ms.map(Duration::from_millis),
start_time: Arc::new(Mutex::new(now)),
last_activity: Arc::new(Mutex::new(now)),
}
}
pub fn register_activity(&self) {
*self.last_activity.lock().unwrap() = Instant::now();
}
pub fn check_limits(&self) -> Result<()> {
let now = Instant::now();
let elapsed = now.duration_since(*self.start_time.lock().unwrap());
if elapsed > self.max_total_time {
return Err(Error::Timeout {
operation: "total time".to_string(),
duration: elapsed,
instance_id: None,
});
}
if let Some(idle_limit) = self.max_idle_time {
let idle_time = now.duration_since(*self.last_activity.lock().unwrap());
if idle_time > idle_limit {
return Err(Error::ResourceLimit {
message: format!("Idle time limit of {}ms exceeded", idle_limit.as_millis())
});
}
}
Ok(())
}
pub fn elapsed_ms(&self) -> u64 {
let now = Instant::now();
now.duration_since(*self.start_time.lock().unwrap()).as_millis() as u64
}
pub fn idle_ms(&self) -> u64 {
let now = Instant::now();
now.duration_since(*self.last_activity.lock().unwrap()).as_millis() as u64
}
}
#[derive(Debug, Clone)]
pub struct ResourceLimitManager {
pub memory: MemoryResourceTracker,
pub cpu: CpuResourceTracker,
pub io: IoResourceTracker,
pub time: TimeResourceTracker,
pub fuel: Option<Arc<AtomicU64>>,
}
impl ResourceLimitManager {
pub fn new(limits: &ResourceLimits) -> Self {
let fuel = limits.fuel.map(|f| Arc::new(AtomicU64::new(f)));
Self {
memory: MemoryResourceTracker::new(&limits.memory),
cpu: CpuResourceTracker::new(&limits.cpu),
io: IoResourceTracker::new(&limits.io),
time: TimeResourceTracker::new(&limits.time),
fuel,
}
}
pub fn check_all_limits(&self) -> Result<()> {
self.time.check_limits()?;
self.cpu.check_time_limit()?;
if let Some(fuel) = &self.fuel {
if fuel.load(Ordering::Acquire) == 0 {
return Err(Error::ResourceLimit {
message: "Fuel limit exceeded".to_string()
});
}
}
Ok(())
}
pub fn consume_fuel(&self, amount: u64) -> Result<()> {
if let Some(fuel) = &self.fuel {
let current = fuel.load(Ordering::Acquire);
if current < amount {
return Err(Error::ResourceLimit {
message: format!("Not enough fuel: requested {}, available {}", amount, current)
});
}
fuel.fetch_sub(amount, Ordering::AcqRel);
}
Ok(())
}
pub fn add_fuel(&self, amount: u64) -> Result<()> {
if let Some(fuel) = &self.fuel {
fuel.fetch_add(amount, Ordering::AcqRel);
Ok(())
} else {
Err(Error::UnsupportedOperation {
message: "Fuel metering is not enabled".to_string()
})
}
}
pub fn reset_fuel(&self, amount: u64) -> Result<()> {
if let Some(fuel) = &self.fuel {
fuel.store(amount, Ordering::Release);
Ok(())
} else {
Err(Error::UnsupportedOperation {
message: "Fuel metering is not enabled".to_string()
})
}
}
pub fn get_remaining_fuel(&self) -> Option<u64> {
self.fuel.as_ref().map(|f| f.load(Ordering::Acquire))
}
pub fn start_monitor(&self) -> std::thread::JoinHandle<()> {
let cpu_tracker = self.cpu.clone();
let time_tracker = self.time.clone();
std::thread::spawn(move || {
let check_interval = Duration::from_millis(100);
loop {
std::thread::sleep(check_interval);
time_tracker.register_activity();
cpu_tracker.apply_throttling();
let _ = time_tracker.check_limits();
let _ = cpu_tracker.check_time_limit();
}
})
}
}