1use crate::Result;
4use serde::{Deserialize, Serialize};
5use std::time::{Duration, Instant};
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct WarmupConfig {
10 pub target_cold_start_ms: u64,
12 pub max_warmup_ms: u64,
14 pub preload_weights: bool,
16 pub preallocate_pools: bool,
18 pub precompile_kernels: bool,
20 pub warmup_batch_size: usize,
22 pub lazy_loading: bool,
24}
25
26impl Default for WarmupConfig {
27 fn default() -> Self {
28 Self {
29 target_cold_start_ms: 100,
30 max_warmup_ms: 5000,
31 preload_weights: true,
32 preallocate_pools: true,
33 precompile_kernels: true,
34 warmup_batch_size: 1,
35 lazy_loading: true,
36 }
37 }
38}
39
40#[derive(Debug, Clone, Default, Serialize, Deserialize)]
42pub struct WarmupStats {
43 pub total_warmup_ms: u64,
45 pub weight_load_ms: u64,
47 pub pool_alloc_ms: u64,
49 pub kernel_compile_ms: u64,
51 pub first_inference_ms: u64,
53 pub warmup_iterations: u32,
55 pub memory_after_warmup: u64,
57}
58
59impl WarmupStats {
60 pub fn met_target(&self, target_ms: u64) -> bool {
62 self.first_inference_ms <= target_ms
63 }
64}
65
66#[derive(Debug)]
68pub struct ColdStartOptimizer {
69 config: WarmupConfig,
71 stats: WarmupStats,
73 warmed_up: bool,
75 warmup_start: Option<Instant>,
77 phases: Vec<WarmupPhase>,
79}
80
81#[derive(Debug, Clone)]
83#[allow(dead_code)]
84struct WarmupPhase {
85 name: String,
86 duration_ms: u64,
87 completed: bool,
88}
89
90impl ColdStartOptimizer {
91 pub fn new(config: WarmupConfig) -> Self {
93 Self {
94 config,
95 stats: WarmupStats::default(),
96 warmed_up: false,
97 warmup_start: None,
98 phases: Vec::new(),
99 }
100 }
101
102 pub async fn warmup<F>(&mut self, init_fn: F) -> Result<()>
104 where
105 F: FnOnce() -> Result<()>,
106 {
107 self.warmup_start = Some(Instant::now());
108
109 if self.config.preload_weights {
111 let start = Instant::now();
112 self.load_weights().await?;
113 self.stats.weight_load_ms = start.elapsed().as_millis() as u64;
114 self.phases.push(WarmupPhase {
115 name: "weight_load".into(),
116 duration_ms: self.stats.weight_load_ms,
117 completed: true,
118 });
119 }
120
121 if self.config.preallocate_pools {
123 let start = Instant::now();
124 self.preallocate_pools().await?;
125 self.stats.pool_alloc_ms = start.elapsed().as_millis() as u64;
126 self.phases.push(WarmupPhase {
127 name: "pool_alloc".into(),
128 duration_ms: self.stats.pool_alloc_ms,
129 completed: true,
130 });
131 }
132
133 if self.config.precompile_kernels {
135 let start = Instant::now();
136 self.compile_kernels().await?;
137 self.stats.kernel_compile_ms = start.elapsed().as_millis() as u64;
138 self.phases.push(WarmupPhase {
139 name: "kernel_compile".into(),
140 duration_ms: self.stats.kernel_compile_ms,
141 completed: true,
142 });
143 }
144
145 let start = Instant::now();
147 init_fn()?;
148 self.phases.push(WarmupPhase {
149 name: "custom_init".into(),
150 duration_ms: start.elapsed().as_millis() as u64,
151 completed: true,
152 });
153
154 let start = Instant::now();
156 self.warmup_inference().await?;
157 self.stats.first_inference_ms = start.elapsed().as_millis() as u64;
158 self.phases.push(WarmupPhase {
159 name: "warmup_inference".into(),
160 duration_ms: self.stats.first_inference_ms,
161 completed: true,
162 });
163
164 self.stats.total_warmup_ms = self.warmup_start.unwrap().elapsed().as_millis() as u64;
165 self.warmed_up = true;
166
167 if self.stats.total_warmup_ms > self.config.max_warmup_ms {
169 tracing::warn!(
170 "Warmup exceeded max time: {}ms > {}ms",
171 self.stats.total_warmup_ms,
172 self.config.max_warmup_ms
173 );
174 }
175
176 Ok(())
177 }
178
179 async fn load_weights(&self) -> Result<()> {
181 Ok(())
186 }
187
188 async fn preallocate_pools(&self) -> Result<()> {
190 Ok(())
196 }
197
198 async fn compile_kernels(&self) -> Result<()> {
200 Ok(())
206 }
207
208 async fn warmup_inference(&self) -> Result<()> {
210 for _ in 0..self.config.warmup_batch_size {
215 tokio::time::sleep(Duration::from_micros(100)).await;
217 }
218 Ok(())
220 }
221
222 pub fn is_warmed_up(&self) -> bool {
224 self.warmed_up
225 }
226
227 pub fn stats(&self) -> &WarmupStats {
229 &self.stats
230 }
231
232 pub fn config(&self) -> &WarmupConfig {
234 &self.config
235 }
236
237 pub fn phases(&self) -> Vec<(String, u64)> {
239 self.phases
240 .iter()
241 .map(|p| (p.name.clone(), p.duration_ms))
242 .collect()
243 }
244}
245
/// Priority queue of pending instance warmups with a cap on how many may run at once.
#[derive(Debug)]
pub struct WarmupScheduler {
    /// Pending warmups, highest priority first; FIFO among equal priorities.
    schedule: Vec<ScheduledWarmup>,
    /// IDs handed out by `next_warmup` that have not yet been passed to `complete`.
    active: Vec<String>,
    /// Maximum number of warmups that may be active at the same time.
    max_concurrent: usize,
}

/// A queued warmup request.
#[derive(Debug, Clone)]
// `scheduled_at` is recorded on enqueue but nothing reads it yet.
#[allow(dead_code)]
struct ScheduledWarmup {
    instance_id: String,
    scheduled_at: Instant,
    priority: u32,
}

impl WarmupScheduler {
    /// Creates an empty scheduler that allows at most `max_concurrent` active warmups.
    pub fn new(max_concurrent: usize) -> Self {
        Self {
            schedule: Vec::new(),
            active: Vec::new(),
            max_concurrent,
        }
    }

    /// Queues a warmup for `instance_id`; higher `priority` values are dequeued first.
    pub fn schedule(&mut self, instance_id: impl Into<String>, priority: u32) {
        // Insert after every entry with priority >= `priority`. This keeps the queue sorted
        // highest-first and FIFO among equal priorities, without re-sorting on each call.
        let idx = self.schedule.partition_point(|s| s.priority >= priority);
        self.schedule.insert(
            idx,
            ScheduledWarmup {
                instance_id: instance_id.into(),
                scheduled_at: Instant::now(),
                priority,
            },
        );
    }

    /// Dequeues the highest-priority pending warmup and marks it active.
    ///
    /// Returns `None` when the concurrency cap is reached or nothing is pending.
    pub fn next_warmup(&mut self) -> Option<String> {
        if self.active.len() >= self.max_concurrent || self.schedule.is_empty() {
            return None;
        }

        let next = self.schedule.remove(0);
        self.active.push(next.instance_id.clone());
        Some(next.instance_id)
    }

    /// Marks the active warmup for `instance_id` as finished, freeing its slot.
    ///
    /// IDs that are not currently active are ignored, so a stray or duplicate completion
    /// cannot free a slot that is still in use.
    pub fn complete(&mut self, instance_id: &str) {
        if let Some(pos) = self.active.iter().position(|id| id == instance_id) {
            self.active.swap_remove(pos);
        }
    }

    /// Number of queued warmups not yet started.
    pub fn pending_count(&self) -> usize {
        self.schedule.len()
    }

    /// Number of warmups started but not yet completed.
    pub fn active_count(&self) -> usize {
        self.active.len()
    }
}
322
/// Accumulates observed cold- and warm-start latencies, in milliseconds.
#[derive(Debug, Default)]
pub struct ColdStartMetrics {
    /// Every recorded cold-start latency, in insertion order.
    cold_starts: Vec<u64>,
    /// Every recorded warm-start latency, in insertion order.
    warm_starts: Vec<u64>,
}

impl ColdStartMetrics {
    /// Records one cold-start latency sample.
    pub fn record_cold_start(&mut self, duration_ms: u64) {
        self.cold_starts.push(duration_ms);
    }

    /// Records one warm-start latency sample.
    pub fn record_warm_start(&mut self, duration_ms: u64) {
        self.warm_starts.push(duration_ms);
    }

    /// Mean cold-start latency, or `0.0` when no samples were recorded.
    pub fn avg_cold_start_ms(&self) -> f64 {
        Self::mean(&self.cold_starts)
    }

    /// Mean warm-start latency, or `0.0` when no samples were recorded.
    pub fn avg_warm_start_ms(&self) -> f64 {
        Self::mean(&self.warm_starts)
    }

    /// Fraction of all recorded starts that were cold, in `[0.0, 1.0]`.
    ///
    /// Despite the name, this is `cold / (cold + warm)`, not `cold / warm`. Returns `0.0` when
    /// nothing has been recorded.
    pub fn cold_warm_ratio(&self) -> f64 {
        let total = self.cold_starts.len() + self.warm_starts.len();
        if total == 0 {
            0.0
        } else {
            self.cold_starts.len() as f64 / total as f64
        }
    }

    /// 95th-percentile cold-start latency (nearest-rank method), or `None` with no samples.
    pub fn p95_cold_start_ms(&self) -> Option<u64> {
        if self.cold_starts.is_empty() {
            return None;
        }

        // Nearest rank is ceil(0.95 * n) (1-based), which equals n - floor(n / 20) exactly.
        // Integer math avoids float rounding; n >= 1 guarantees rank >= 1.
        let n = self.cold_starts.len();
        let rank = n - n / 20;

        let mut values = self.cold_starts.clone();
        let (_, p95, _) = values.select_nth_unstable(rank - 1);
        Some(*p95)
    }

    /// Arithmetic mean of `values`, or `0.0` for an empty slice.
    fn mean(values: &[u64]) -> f64 {
        if values.is_empty() {
            return 0.0;
        }
        // Sum in u128 so many large samples cannot overflow (a u64 sum panics in debug builds).
        let total: u128 = values.iter().map(|&v| u128::from(v)).sum();
        total as f64 / values.len() as f64
    }
}
383
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_default() {
        let config = WarmupConfig::default();
        assert_eq!(config.target_cold_start_ms, 100);
        assert!(config.preload_weights);
        assert!(config.preallocate_pools);
    }

    #[test]
    fn test_warmup_stats() {
        let stats = WarmupStats {
            first_inference_ms: 80,
            ..Default::default()
        };

        assert!(stats.met_target(100));
        assert!(stats.met_target(80));
        assert!(!stats.met_target(50));
    }

    #[test]
    fn test_optimizer_creation() {
        let config = WarmupConfig::default();
        let optimizer = ColdStartOptimizer::new(config);

        assert!(!optimizer.is_warmed_up());
        assert!(optimizer.phases().is_empty());
        assert_eq!(optimizer.stats().total_warmup_ms, 0);
    }

    #[test]
    fn test_warmup_scheduler() {
        let mut scheduler = WarmupScheduler::new(2);

        scheduler.schedule("instance1", 1);
        scheduler.schedule("instance2", 2);
        scheduler.schedule("instance3", 3);

        assert_eq!(scheduler.next_warmup(), Some("instance3".to_string()));
        assert_eq!(scheduler.next_warmup(), Some("instance2".to_string()));
        // Both slots are taken, so nothing more can start until one completes.
        assert_eq!(scheduler.next_warmup(), None);

        scheduler.complete("instance3");
        assert_eq!(scheduler.next_warmup(), Some("instance1".to_string()));
        assert_eq!(scheduler.pending_count(), 0);
        assert_eq!(scheduler.active_count(), 2);
    }

    #[test]
    fn test_scheduler_equal_priority_is_fifo() {
        let mut scheduler = WarmupScheduler::new(usize::MAX);

        scheduler.schedule("first", 5);
        scheduler.schedule("second", 5);
        scheduler.schedule("urgent", 9);

        assert_eq!(scheduler.next_warmup().as_deref(), Some("urgent"));
        assert_eq!(scheduler.next_warmup().as_deref(), Some("first"));
        assert_eq!(scheduler.next_warmup().as_deref(), Some("second"));
        assert_eq!(scheduler.next_warmup(), None);
    }

    #[test]
    fn test_cold_start_metrics() {
        let mut metrics = ColdStartMetrics::default();

        metrics.record_cold_start(100);
        metrics.record_cold_start(150);
        metrics.record_warm_start(10);
        metrics.record_warm_start(15);

        assert_eq!(metrics.avg_cold_start_ms(), 125.0);
        assert_eq!(metrics.avg_warm_start_ms(), 12.5);
        assert_eq!(metrics.cold_warm_ratio(), 0.5);
    }

    #[test]
    fn test_cold_start_metrics_empty() {
        let metrics = ColdStartMetrics::default();

        assert_eq!(metrics.avg_cold_start_ms(), 0.0);
        assert_eq!(metrics.avg_warm_start_ms(), 0.0);
        assert_eq!(metrics.cold_warm_ratio(), 0.0);
        assert_eq!(metrics.p95_cold_start_ms(), None);
    }

    #[test]
    fn test_p95_single_sample() {
        let mut metrics = ColdStartMetrics::default();
        metrics.record_cold_start(42);

        assert_eq!(metrics.p95_cold_start_ms(), Some(42));
    }
}