use std::time::{Duration, Instant};
use super::error::GpuError;
use super::ledger::VramLedger;
use super::profiler::GpuProfiler;
/// Polling configuration for [`wait_for_vram`].
///
/// A waiter retries its reservation with exponential backoff: the pause after failed
/// attempt `n` (zero-based) is `base_interval * 2^n`, capped at `max_interval`, and the
/// whole wait is abandoned once `timeout` has elapsed.
#[derive(Debug, Clone)]
pub struct WaitConfig {
    /// Total time a waiter is willing to spend before giving up.
    pub timeout: Duration,
    /// Pause after the first failed attempt; doubled for every further attempt.
    pub base_interval: Duration,
    /// Upper bound on the pause between two attempts.
    pub max_interval: Duration,
}

impl Default for WaitConfig {
    /// One-hour timeout, first pause of 30 s, pauses capped at 5 min.
    fn default() -> Self {
        Self {
            timeout: Duration::from_secs(3600),
            base_interval: Duration::from_secs(30),
            max_interval: Duration::from_secs(300),
        }
    }
}

impl WaitConfig {
    /// Builds a configuration with the given timeout (in seconds) and the default intervals.
    pub fn with_timeout_secs(secs: u64) -> Self {
        Self { timeout: Duration::from_secs(secs), ..Default::default() }
    }

    /// Returns the pause to insert after failed attempt number `attempt` (zero-based):
    /// `base_interval * 2^attempt`, capped at `max_interval`.
    ///
    /// The computation is done at nanosecond resolution so that sub-second intervals are
    /// honoured; working in whole seconds would truncate them to zero and turn the wait
    /// loop into a busy spin. All intermediate arithmetic saturates, so arbitrarily large
    /// attempt numbers simply yield `max_interval`.
    fn interval_for_attempt(&self, attempt: u32) -> Duration {
        const NANOS_PER_SEC: u128 = 1_000_000_000;

        let multiplier = 2u128.saturating_pow(attempt);
        let scaled_nanos = self.base_interval.as_nanos().saturating_mul(multiplier);
        if scaled_nanos >= self.max_interval.as_nanos() {
            return self.max_interval;
        }

        // `scaled_nanos` is strictly below `max_interval`, which is itself a valid
        // `Duration`, so neither cast below can truncate.
        Duration::new(
            (scaled_nanos / NANOS_PER_SEC) as u64,
            (scaled_nanos % NANOS_PER_SEC) as u32,
        )
    }
}
/// Blocks the calling thread until `budget_mb` can be reserved in `ledger`, or until
/// `config.timeout` has elapsed.
///
/// Each poll calls `ledger.try_reserve(budget_mb, task)`, bracketed by
/// `profiler.begin` / `profiler.end` under `GpuProfiler::WAIT_POLL`. When the ledger
/// reports insufficient memory, a progress line is written to stderr and the thread
/// sleeps for `config.interval_for_attempt(attempt)`, clamped to the time remaining.
///
/// The deadline is checked only *after* a failed attempt. This guarantees that at least
/// one reservation is attempted even with a zero timeout, and that the ledger is polled
/// one final time after the last (clamped) sleep instead of giving up without looking.
///
/// # Returns
///
/// The reservation id produced by the successful `try_reserve` call.
/// `profiler.finish_op()` is invoked on this path only.
///
/// # Errors
///
/// * `GpuError::Timeout` when the timeout has elapsed and the reservation still does not
///   fit. `timeout_secs` is the configured timeout in whole seconds.
/// * Any error from `try_reserve` other than `GpuError::InsufficientMemory` is returned
///   unchanged and immediately.
pub fn wait_for_vram(
    ledger: &mut VramLedger,
    budget_mb: usize,
    task: &str,
    config: &WaitConfig,
    profiler: &mut GpuProfiler,
) -> Result<u64, GpuError> {
    let start = Instant::now();
    let mut attempt: u32 = 0;

    loop {
        profiler.begin(GpuProfiler::WAIT_POLL);
        let result = ledger.try_reserve(budget_mb, task);
        profiler.end(GpuProfiler::WAIT_POLL);

        // Only "not enough memory right now" is worth waiting for.
        let (available_mb, reserved_mb) = match result {
            Ok(reservation_id) => {
                profiler.finish_op();
                return Ok(reservation_id);
            }
            Err(GpuError::InsufficientMemory { available_mb, reserved_mb, .. }) => {
                (available_mb, reserved_mb)
            }
            Err(e) => return Err(e),
        };

        let elapsed = start.elapsed();
        if elapsed >= config.timeout {
            return Err(GpuError::Timeout { budget_mb, timeout_secs: config.timeout.as_secs() });
        }
        // Cannot underflow: `elapsed < config.timeout` was just established.
        let remaining = config.timeout - elapsed;

        eprintln!(
            "[GPU] Waiting for {} MB VRAM ({} MB available, {} MB reserved) \
             [{:.0}s elapsed, {:.0}s remaining]",
            budget_mb,
            available_mb,
            reserved_mb,
            elapsed.as_secs_f64(),
            remaining.as_secs_f64(),
        );

        // Never sleep past the deadline; the next iteration performs the final poll.
        let sleep_time = config.interval_for_attempt(attempt).min(remaining);
        std::thread::sleep(sleep_time);
        attempt = attempt.saturating_add(1);
    }
}
/// Returns `config.timeout + config.max_interval`, saturating at [`Duration::MAX`].
///
/// Presumably this is the worst-case time a waiter can block (the timeout plus one
/// full polling interval) — confirm against callers.
///
/// The addition saturates instead of panicking, so a very large "wait forever" timeout
/// such as `Duration::MAX` is handled gracefully.
pub fn timeout_bound(config: &WaitConfig) -> Duration {
    config.timeout.saturating_add(config.max_interval)
}
/// Collects the `pid` of every reservation in `ledger` whose `is_expired()` returns true.
///
/// If reading the reservations fails, the error is discarded and the ledger is treated
/// as having no reservations, so the result is an empty list.
pub fn fairness_via_expiry(ledger: &mut VramLedger) -> Vec<u32> {
    let reservations = ledger.read_reservations().unwrap_or_default();

    let mut expired_pids = Vec::new();
    for reservation in reservations.iter() {
        if reservation.is_expired() {
            expired_pids.push(reservation.pid);
        }
    }
    expired_pids
}
/// Snapshot of a waiter's progress, as produced by [`progress_report`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WaitProgress {
    /// Attempt counter supplied by the caller.
    pub attempt: u32,
    /// Time elapsed since the wait started.
    pub elapsed: Duration,
    /// Time left before the configured timeout; zero once the timeout has passed.
    pub remaining: Duration,
    /// Amount of VRAM being requested (megabytes, going by the field name).
    pub budget_mb: usize,
    /// Amount of VRAM reported as available (megabytes, going by the field name).
    pub available_mb: usize,
    /// Amount of VRAM reported as reserved (megabytes, going by the field name).
    pub reserved_mb: usize,
}
/// Builds a [`WaitProgress`] snapshot for a wait that began at `start`.
///
/// `elapsed` is measured from `start` at the moment of the call; `remaining` is
/// `config.timeout - elapsed`, clamped to zero once the timeout has passed. All other
/// arguments are copied into the snapshot unchanged.
pub fn progress_report(
    config: &WaitConfig,
    start: Instant,
    attempt: u32,
    budget_mb: usize,
    available_mb: usize,
    reserved_mb: usize,
) -> WaitProgress {
    let waited = start.elapsed();
    // `None` means the timeout is already behind us; the default `Duration` is zero.
    let time_left = config.timeout.checked_sub(waited).unwrap_or_default();

    WaitProgress {
        attempt,
        elapsed: waited,
        remaining: time_left,
        budget_mb,
        available_mb,
        reserved_mb,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::{Path, PathBuf};
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Monotonic counter that gives every test in this process its own ledger file name.
    static TEST_COUNTER: AtomicU32 = AtomicU32::new(0);

    /// Returns a ledger path under the system temp directory that is unique per call
    /// (via `TEST_COUNTER`) and per process (via the process id).
    fn test_ledger_path() -> PathBuf {
        let n = TEST_COUNTER.fetch_add(1, Ordering::Relaxed);
        let dir = std::env::temp_dir().join("entrenar-wait-test");
        std::fs::create_dir_all(&dir).unwrap();
        dir.join(format!("wait-ledger-{n}-{}.json", std::process::id()))
    }

    /// Best-effort removal of the ledger file and its `.tmp` sibling; errors are ignored.
    fn cleanup(path: &Path) {
        let _ = std::fs::remove_file(path);
        let _ = std::fs::remove_file(path.with_extension("tmp"));
    }

    /// Owns a test ledger path and removes the files when dropped.
    ///
    /// Using a guard instead of a trailing `cleanup` call means the files are removed
    /// even when an assertion or `unwrap` panics. Declare the guard *before* any ledger
    /// in a test: locals are dropped in reverse declaration order, so the files are
    /// removed only after every ledger using them has gone out of scope.
    struct TempLedger(PathBuf);

    impl TempLedger {
        /// Allocates a fresh, unique ledger path.
        fn new() -> Self {
            Self(test_ledger_path())
        }

        /// Returns an owned copy of the ledger path.
        fn path(&self) -> PathBuf {
            self.0.clone()
        }
    }

    impl Drop for TempLedger {
        fn drop(&mut self) {
            cleanup(&self.0);
        }
    }

    /// A request that fits right away returns a non-zero reservation id.
    #[test]
    fn test_immediate_success() {
        let file = TempLedger::new();
        let mut ledger = VramLedger::new("GPU-test".into(), 24000, 0.85).with_path(file.path());
        let mut profiler = GpuProfiler::disabled();
        let config = WaitConfig::with_timeout_secs(5);

        let id = wait_for_vram(&mut ledger, 5000, "test", &config, &mut profiler).unwrap();

        assert_ne!(id, 0);
    }

    /// A request that cannot fit while another reservation is held ends in `Timeout`.
    #[test]
    fn test_timeout_when_full() {
        let file = TempLedger::new();
        let mut ledger = VramLedger::new("GPU-test".into(), 10000, 0.85).with_path(file.path());
        ledger.try_reserve(8000, "blocker").unwrap();
        let mut profiler = GpuProfiler::disabled();
        let config = WaitConfig {
            timeout: Duration::from_millis(100),
            base_interval: Duration::from_millis(50),
            max_interval: Duration::from_millis(100),
        };

        let result = wait_for_vram(&mut ledger, 5000, "waiter", &config, &mut profiler);

        match result {
            Err(GpuError::Timeout { budget_mb, .. }) => assert_eq!(budget_mb, 5000),
            Err(other) => panic!("expected Timeout, got {other}"),
            Ok(id) => panic!("expected Timeout, got reservation {id}"),
        }
    }

    /// The pause doubles with every attempt and is capped at `max_interval`.
    #[test]
    fn test_interval_exponential_backoff() {
        let config = WaitConfig {
            base_interval: Duration::from_secs(30),
            max_interval: Duration::from_secs(300),
            ..Default::default()
        };

        assert_eq!(config.interval_for_attempt(0), Duration::from_secs(30));
        assert_eq!(config.interval_for_attempt(1), Duration::from_secs(60));
        assert_eq!(config.interval_for_attempt(2), Duration::from_secs(120));
        assert_eq!(config.interval_for_attempt(3), Duration::from_secs(240));
        // 30 s * 2^4 = 480 s exceeds the cap, as does every later attempt.
        assert_eq!(config.interval_for_attempt(4), Duration::from_secs(300));
        assert_eq!(config.interval_for_attempt(10), Duration::from_secs(300));
    }

    /// A reservation written with a zero-hour lease does not block a later waiter.
    #[test]
    fn test_expired_lease_unblocks_waiter() {
        let file = TempLedger::new();
        let mut blocker = VramLedger::new("GPU-test".into(), 10000, 0.85)
            .with_path(file.path())
            .with_lease_hours(0);
        blocker.try_reserve(8000, "expiring").unwrap();
        // NOTE(review): presumably keeps `blocker` from releasing the reservation
        // itself, so only lease expiry can free it — confirm against `VramLedger`.
        blocker.our_reservation_id = None;
        std::thread::sleep(Duration::from_millis(10));

        let mut waiter = VramLedger::new("GPU-test".into(), 10000, 0.85).with_path(file.path());
        let mut profiler = GpuProfiler::disabled();
        let config = WaitConfig::with_timeout_secs(5);

        let id = wait_for_vram(&mut waiter, 5000, "waiter", &config, &mut profiler).unwrap();

        assert_ne!(id, 0);
    }
}