use crate::trace::{TraceStep, TRACER};
use super::error::GpuError;
use super::ledger::VramLedger;
use super::profiler::GpuProfiler;
use super::wait::{self, WaitConfig};
pub struct VramGuard {
ledger: VramLedger,
budget_mb: usize,
}
impl VramGuard {
pub fn acquire(budget_mb: usize, task: &str) -> Result<Self, GpuError> {
TRACER.span(TraceStep::VramQuery, format!("guard_acquire budget={budget_mb}MB"), || {
let mut ledger = super::ledger::auto_ledger();
ledger.try_reserve(budget_mb, task)?;
Ok(Self { ledger, budget_mb })
})
}
pub fn acquire_wait(budget_mb: usize, task: &str, timeout_secs: u64) -> Result<Self, GpuError> {
TRACER.span(
TraceStep::WaitPoll,
format!("guard_wait budget={budget_mb}MB timeout={timeout_secs}s"),
|| {
let mut ledger = super::ledger::auto_ledger();
let config = WaitConfig::with_timeout_secs(timeout_secs);
let mut profiler = GpuProfiler::disabled();
wait::wait_for_vram(&mut ledger, budget_mb, task, &config, &mut profiler)?;
Ok(Self { ledger, budget_mb })
},
)
}
pub fn update_actual(&mut self, actual_mb: usize) -> Result<(), GpuError> {
self.ledger.update_actual(actual_mb)
}
pub fn budget_mb(&self) -> usize {
self.budget_mb
}
pub fn gpu_uuid(&self) -> &str {
&self.ledger.gpu_uuid
}
pub fn status(&self) -> Result<String, GpuError> {
super::ledger::gpu_status_display(&self.ledger)
}
pub fn check_overshoot(&self) -> Option<(usize, usize)> {
let actual = self
.ledger
.read_reservations()
.ok()?
.iter()
.filter(|r| r.pid == std::process::id())
.find_map(|r| r.actual_mb)?;
if actual > self.budget_mb {
Some((actual, self.budget_mb))
} else {
None
}
}
pub fn auto_estimate_budget(param_count: usize) -> usize {
let weight_mb = (param_count * 2) / (1024 * 1024);
let overhead = weight_mb / 5; weight_mb + overhead
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicU32, Ordering};
static TEST_COUNTER: AtomicU32 = AtomicU32::new(0);
fn test_guard_ledger(total_mb: usize) -> VramLedger {
let n = TEST_COUNTER.fetch_add(1, Ordering::Relaxed);
let dir = std::env::temp_dir().join("entrenar-guard-test");
std::fs::create_dir_all(&dir).expect("dir creation should succeed");
let path = dir.join(format!("guard-{n}-{}.json", std::process::id()));
VramLedger::new("GPU-test-guard".into(), total_mb, 0.85).with_path(path)
}
#[test]
fn test_guard_direct_acquire() {
let mut ledger = test_guard_ledger(24000);
ledger.try_reserve(5000, "guard-test").expect("should succeed");
assert_eq!(ledger.total_reserved().expect("should succeed"), 5000);
drop(ledger);
}
#[test]
fn test_guard_update_actual() {
let mut ledger = test_guard_ledger(24000);
ledger.try_reserve(8000, "guard-actual").expect("should succeed");
ledger.update_actual(7200).expect("should succeed");
assert_eq!(ledger.total_reserved().expect("should succeed"), 7200);
}
#[test]
fn test_guard_rejects_over_budget() {
let mut ledger = test_guard_ledger(10000);
let result = ledger.try_reserve(9000, "too-big");
assert!(result.is_err());
}
#[test]
fn test_guard_budget_mb() {
let ledger = test_guard_ledger(24000);
let guard = VramGuard { ledger, budget_mb: 8000 };
assert_eq!(guard.budget_mb(), 8000);
}
#[test]
fn test_guard_gpu_uuid() {
let ledger = test_guard_ledger(24000);
let guard = VramGuard { ledger, budget_mb: 5000 };
assert_eq!(guard.gpu_uuid(), "GPU-test-guard");
}
#[test]
fn test_guard_status() {
let ledger = test_guard_ledger(24000);
let guard = VramGuard { ledger, budget_mb: 5000 };
let result = guard.status();
let _ = result;
}
#[test]
fn test_guard_update_actual_without_reservation() {
let ledger = test_guard_ledger(24000);
let mut guard = VramGuard { ledger, budget_mb: 5000 };
let result = guard.update_actual(4000);
assert!(result.is_ok());
}
#[test]
fn test_guard_multiple_reservations_sequential() {
let mut ledger1 = test_guard_ledger(24000);
ledger1.try_reserve(3000, "task-1").expect("should succeed");
let reserved = ledger1.total_reserved().expect("should succeed");
assert_eq!(reserved, 3000);
drop(ledger1);
}
#[test]
fn test_guard_zero_budget() {
let mut ledger = test_guard_ledger(24000);
let result = ledger.try_reserve(0, "zero-budget");
assert!(result.is_ok());
}
#[test]
fn test_guard_exact_budget() {
let mut ledger = test_guard_ledger(10000);
let result = ledger.try_reserve(8000, "near-limit");
assert!(result.is_ok());
}
#[test]
fn test_guard_update_actual_reduces_reserved() {
let mut ledger = test_guard_ledger(24000);
ledger.try_reserve(8000, "actual-test").expect("should succeed");
assert_eq!(ledger.total_reserved().expect("should succeed"), 8000);
ledger.update_actual(6000).expect("should succeed");
assert_eq!(ledger.total_reserved().expect("should succeed"), 6000);
}
#[test]
fn test_guard_struct_fields() {
let ledger = test_guard_ledger(16000);
let guard = VramGuard { ledger, budget_mb: 4000 };
assert_eq!(guard.budget_mb(), 4000);
assert_eq!(guard.gpu_uuid(), "GPU-test-guard");
}
#[test]
fn test_guard_status_returns_string() {
let mut ledger = test_guard_ledger(24000);
ledger.try_reserve(5000, "status-test").expect("should succeed");
let guard = VramGuard { ledger, budget_mb: 5000 };
let status = guard.status();
assert!(status.is_ok());
let status_str = status.unwrap();
assert!(status_str.contains("GPU-test-guard"));
assert!(status_str.contains("5000 MB budget"));
}
#[test]
fn test_guard_status_empty_ledger() {
let ledger = test_guard_ledger(24000);
let guard = VramGuard { ledger, budget_mb: 0 };
let status = guard.status();
assert!(status.is_ok());
let s = status.unwrap();
assert!(s.contains("none") || s.contains("Reservations"));
}
#[test]
fn test_guard_update_actual_with_active_reservation() {
let mut ledger = test_guard_ledger(24000);
ledger.try_reserve(10000, "update-actual").expect("should succeed");
let mut guard = VramGuard { ledger, budget_mb: 10000 };
let result = guard.update_actual(9500);
assert!(result.is_ok());
}
#[test]
fn test_guard_small_gpu() {
let mut ledger = test_guard_ledger(2048);
let result = ledger.try_reserve(1740, "small-gpu");
assert!(result.is_ok());
let result2 = ledger.try_reserve(1, "overflow");
assert!(result2.is_err());
}
#[test]
fn test_guard_capacity_calculation() {
let ledger = test_guard_ledger(10000);
assert_eq!(ledger.capacity_mb(), 8500);
}
#[test]
fn test_guard_available_mb_after_reserve() {
let mut ledger = test_guard_ledger(20000);
ledger.try_reserve(7000, "test").expect("should succeed");
let available = ledger.available_mb().expect("should succeed");
assert_eq!(available, 10000);
}
#[test]
fn test_guard_multiple_sequential_reserve_release() {
let mut ledger = test_guard_ledger(24000);
ledger.try_reserve(5000, "first").expect("ok");
assert_eq!(ledger.total_reserved().expect("ok"), 5000);
ledger.release().expect("ok");
assert_eq!(ledger.total_reserved().expect("ok"), 0);
ledger.try_reserve(8000, "second").expect("ok");
assert_eq!(ledger.total_reserved().expect("ok"), 8000);
}
#[test]
fn test_guard_profiler_report_accessible() {
let ledger = test_guard_ledger(24000);
let guard = VramGuard { ledger, budget_mb: 0 };
let report = guard.ledger.profiler_report();
assert!(report.contains("No operations recorded"));
}
#[test]
fn test_guard_drop_does_not_panic() {
let ledger = test_guard_ledger(24000);
let guard = VramGuard { ledger, budget_mb: 3000 };
drop(guard);
}
#[test]
fn test_guard_drop_with_reservation_releases() {
let n = TEST_COUNTER.fetch_add(1, Ordering::Relaxed);
let dir = std::env::temp_dir().join("entrenar-guard-test");
std::fs::create_dir_all(&dir).expect("dir creation should succeed");
let path = dir.join(format!("guard-drop-{n}-{}.json", std::process::id()));
{
let mut ledger =
VramLedger::new("GPU-test-guard".into(), 24000, 0.85).with_path(path.clone());
ledger.try_reserve(5000, "drop-reserve").expect("ok");
let guard = VramGuard { ledger, budget_mb: 5000 };
drop(guard);
}
let check_ledger = VramLedger::new("GPU-test-guard".into(), 24000, 0.85).with_path(path);
let reserved = check_ledger.total_reserved().expect("ok");
assert_eq!(reserved, 0);
}
}