1use std::time::{Duration, Instant};
13
14use super::error::GpuError;
15use super::ledger::VramLedger;
16use super::profiler::GpuProfiler;
17
/// Tuning knobs for the VRAM wait loop.
///
/// The wait loop polls the ledger and sleeps between polls. The sleep starts
/// at `base_interval`, doubles after every failed poll, and is capped at
/// `max_interval`; the loop gives up once `timeout` has elapsed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WaitConfig {
    /// Total time a caller is prepared to wait before giving up.
    pub timeout: Duration,
    /// Delay used after the first failed poll; doubled on every later attempt.
    pub base_interval: Duration,
    /// Upper limit applied to the doubled delay.
    pub max_interval: Duration,
}
27
28impl Default for WaitConfig {
29 fn default() -> Self {
30 Self {
31 timeout: Duration::from_secs(3600), base_interval: Duration::from_secs(30), max_interval: Duration::from_secs(300), }
35 }
36}
37
38impl WaitConfig {
39 pub fn with_timeout_secs(secs: u64) -> Self {
41 Self { timeout: Duration::from_secs(secs), ..Default::default() }
42 }
43
44 fn interval_for_attempt(&self, attempt: u32) -> Duration {
46 let multiplier = 2u64.saturating_pow(attempt);
47 let interval_secs = self.base_interval.as_secs().saturating_mul(multiplier);
48 Duration::from_secs(interval_secs.min(self.max_interval.as_secs()))
49 }
50}
51
52pub fn wait_for_vram(
57 ledger: &mut VramLedger,
58 budget_mb: usize,
59 task: &str,
60 config: &WaitConfig,
61 profiler: &mut GpuProfiler,
62) -> Result<u64, GpuError> {
63 let start = Instant::now();
64 let mut attempt: u32 = 0;
65
66 loop {
67 if start.elapsed() > config.timeout {
69 return Err(GpuError::Timeout { budget_mb, timeout_secs: config.timeout.as_secs() });
70 }
71
72 profiler.begin(GpuProfiler::WAIT_POLL);
74 let result = ledger.try_reserve(budget_mb, task);
75 profiler.end(GpuProfiler::WAIT_POLL);
76
77 match result {
78 Ok(reservation_id) => {
79 profiler.finish_op();
80 return Ok(reservation_id);
81 }
82 Err(GpuError::InsufficientMemory { available_mb, reserved_mb, .. }) => {
83 let elapsed = start.elapsed();
84 let remaining = config.timeout.saturating_sub(elapsed);
85
86 eprintln!(
87 "[GPU] Waiting for {} MB VRAM ({} MB available, {} MB reserved) \
88 [{:.0}s elapsed, {:.0}s remaining]",
89 budget_mb,
90 available_mb,
91 reserved_mb,
92 elapsed.as_secs_f64(),
93 remaining.as_secs_f64(),
94 );
95
96 let interval = config.interval_for_attempt(attempt);
97 let sleep_time = interval.min(remaining);
99 std::thread::sleep(sleep_time);
100 attempt = attempt.saturating_add(1);
101 }
102 Err(e) => return Err(e),
103 }
104 }
105}
106
107pub fn timeout_bound(config: &WaitConfig) -> Duration {
114 config.timeout + config.max_interval
115}
116
117pub fn fairness_via_expiry(ledger: &mut VramLedger) -> Vec<u32> {
124 ledger
125 .read_reservations()
126 .unwrap_or_default()
127 .iter()
128 .filter(|r| r.is_expired())
129 .map(|r| r.pid)
130 .collect()
131}
132
/// Point-in-time snapshot of a VRAM wait, suitable for logging or display.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WaitProgress {
    /// Poll attempt counter supplied by the caller.
    pub attempt: u32,
    /// Time spent waiting so far.
    pub elapsed: Duration,
    /// Time left before the configured timeout (zero once it has passed).
    pub remaining: Duration,
    /// Amount of VRAM requested, in MB.
    pub budget_mb: usize,
    /// VRAM reported as available, in MB.
    pub available_mb: usize,
    /// VRAM reported as reserved, in MB.
    pub reserved_mb: usize,
}
150
151pub fn progress_report(
155 config: &WaitConfig,
156 start: Instant,
157 attempt: u32,
158 budget_mb: usize,
159 available_mb: usize,
160 reserved_mb: usize,
161) -> WaitProgress {
162 let elapsed = start.elapsed();
163 let remaining = config.timeout.saturating_sub(elapsed);
164 WaitProgress { attempt, elapsed, remaining, budget_mb, available_mb, reserved_mb }
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::{Path, PathBuf};
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Distinguishes ledger files created by different tests in one process.
    static TEST_COUNTER: AtomicU32 = AtomicU32::new(0);

    /// Returns a ledger path under the system temp directory that is unique
    /// per call (counter) and per process (pid).
    fn test_ledger_path() -> PathBuf {
        let n = TEST_COUNTER.fetch_add(1, Ordering::Relaxed);
        let scratch_dir = std::env::temp_dir().join("entrenar-wait-test");
        std::fs::create_dir_all(&scratch_dir).unwrap();
        let file_name = format!("wait-ledger-{n}-{}.json", std::process::id());
        scratch_dir.join(file_name)
    }

    /// Best-effort removal of the ledger file and its `.tmp` sibling.
    fn cleanup(path: &Path) {
        let tmp_sibling = path.with_extension("tmp");
        let _ = std::fs::remove_file(path);
        let _ = std::fs::remove_file(tmp_sibling);
    }

    /// An empty ledger with ample capacity grants the reservation at once.
    #[test]
    fn test_immediate_success() {
        let ledger_file = test_ledger_path();
        let mut ledger =
            VramLedger::new("GPU-test".into(), 24000, 0.85).with_path(ledger_file.clone());
        let mut profiler = GpuProfiler::disabled();
        let config = WaitConfig::with_timeout_secs(5);

        let reservation =
            wait_for_vram(&mut ledger, 5000, "test", &config, &mut profiler).unwrap();
        assert_ne!(reservation, 0);

        cleanup(&ledger_file);
    }

    /// A request that cannot fit next to an existing reservation ends in
    /// `GpuError::Timeout` carrying the requested budget.
    #[test]
    fn test_timeout_when_full() {
        let ledger_file = test_ledger_path();
        let mut ledger =
            VramLedger::new("GPU-test".into(), 10000, 0.85).with_path(ledger_file.clone());

        ledger.try_reserve(8000, "blocker").unwrap();

        let mut profiler = GpuProfiler::disabled();
        let config = WaitConfig {
            timeout: Duration::from_millis(100),
            base_interval: Duration::from_millis(50),
            max_interval: Duration::from_millis(100),
        };

        let outcome = wait_for_vram(&mut ledger, 5000, "waiter", &config, &mut profiler);
        assert!(outcome.is_err());
        let failure = outcome.unwrap_err();
        match failure {
            GpuError::Timeout { budget_mb, .. } => assert_eq!(budget_mb, 5000),
            other => panic!("expected Timeout, got {other}"),
        }

        cleanup(&ledger_file);
    }

    /// The delay doubles per attempt and is capped at `max_interval`.
    #[test]
    fn test_interval_exponential_backoff() {
        let config = WaitConfig {
            base_interval: Duration::from_secs(30),
            max_interval: Duration::from_secs(300),
            ..Default::default()
        };

        // (attempt, expected delay in seconds)
        let expectations = [(0, 30), (1, 60), (2, 120), (3, 240), (4, 300), (10, 300)];
        for (attempt, secs) in expectations {
            assert_eq!(config.interval_for_attempt(attempt), Duration::from_secs(secs));
        }
    }

    /// A reservation made with a zero-hour lease no longer blocks a second
    /// ledger that shares the same file.
    #[test]
    fn test_expired_lease_unblocks_waiter() {
        let ledger_file = test_ledger_path();
        let mut blocker = VramLedger::new("GPU-test".into(), 10000, 0.85)
            .with_path(ledger_file.clone())
            .with_lease_hours(0);
        blocker.try_reserve(8000, "expiring").unwrap();
        blocker.our_reservation_id = None;

        std::thread::sleep(Duration::from_millis(10));

        let mut waiter =
            VramLedger::new("GPU-test".into(), 10000, 0.85).with_path(ledger_file.clone());
        let mut profiler = GpuProfiler::disabled();
        let config = WaitConfig::with_timeout_secs(5);

        let reservation =
            wait_for_vram(&mut waiter, 5000, "waiter", &config, &mut profiler).unwrap();
        assert_ne!(reservation, 0);

        cleanup(&ledger_file);
    }
}