use crate::Result;
use serde::{Deserialize, Serialize};
use std::time::{Duration, Instant};
/// Tunable knobs controlling how aggressively an instance is warmed up
/// before serving its first request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WarmupConfig {
    /// Desired upper bound, in milliseconds, for the first inference after
    /// startup (compared against measurements via `WarmupStats::met_target`).
    pub target_cold_start_ms: u64,
    /// Emit a warning when the full warmup sequence takes longer than this
    /// many milliseconds.
    pub max_warmup_ms: u64,
    /// Run the weight-preload phase during warmup.
    pub preload_weights: bool,
    /// Run the memory-pool pre-allocation phase during warmup.
    pub preallocate_pools: bool,
    /// Run the kernel-precompilation phase during warmup.
    pub precompile_kernels: bool,
    /// Number of dummy inference iterations to run during warmup.
    pub warmup_batch_size: usize,
    /// NOTE(review): never read anywhere in this file — presumably consumed
    /// by another module; confirm before relying on or removing it.
    pub lazy_loading: bool,
}
impl Default for WarmupConfig {
fn default() -> Self {
Self {
target_cold_start_ms: 100,
max_warmup_ms: 5000,
preload_weights: true,
preallocate_pools: true,
precompile_kernels: true,
warmup_batch_size: 1,
lazy_loading: true,
}
}
}
/// Timings (all in milliseconds) captured during a single warmup run.
/// All fields default to zero; phase fields stay zero when a phase is
/// disabled in the config.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct WarmupStats {
    /// Wall-clock duration of the entire warmup sequence.
    pub total_warmup_ms: u64,
    /// Duration of the weight-preload phase.
    pub weight_load_ms: u64,
    /// Duration of the pool pre-allocation phase.
    pub pool_alloc_ms: u64,
    /// Duration of the kernel-precompilation phase.
    pub kernel_compile_ms: u64,
    /// Duration of the warmup inference pass.
    pub first_inference_ms: u64,
    /// NOTE(review): never written in this file — presumably filled in by a
    /// caller elsewhere; confirm.
    pub warmup_iterations: u32,
    /// NOTE(review): never written in this file — presumably filled in by a
    /// caller elsewhere; confirm.
    pub memory_after_warmup: u64,
}
impl WarmupStats {
    /// True when the measured first-inference latency did not exceed
    /// `target_ms`.
    pub fn met_target(&self, target_ms: u64) -> bool {
        let measured = self.first_inference_ms;
        measured <= target_ms
    }
}
/// Drives the warmup sequence for one instance and records its timings.
#[derive(Debug)]
pub struct ColdStartOptimizer {
    // Which phases run and their time budgets.
    config: WarmupConfig,
    // Timings collected by the most recent warmup run.
    stats: WarmupStats,
    // Set to true once a warmup run completes successfully.
    warmed_up: bool,
    // Instant captured at the start of the warmup run, if one has started.
    warmup_start: Option<Instant>,
    // Completed phases, in execution order.
    phases: Vec<WarmupPhase>,
}
/// Record of one completed warmup phase.
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct WarmupPhase {
    // Phase identifier, e.g. "weight_load".
    name: String,
    // Measured wall-clock duration of the phase in milliseconds.
    duration_ms: u64,
    // Always set to true where constructed in this file.
    completed: bool,
}
impl ColdStartOptimizer {
    /// Creates an optimizer with the given config; no warmup has run yet.
    pub fn new(config: WarmupConfig) -> Self {
        Self {
            config,
            stats: WarmupStats::default(),
            warmed_up: false,
            warmup_start: None,
            phases: Vec::new(),
        }
    }

    /// Runs the enabled warmup phases in order — weight preload, pool
    /// pre-allocation, kernel precompilation — then the caller-supplied
    /// `init_fn`, and finally a dummy warmup inference. Per-phase timings
    /// are stored in `stats` and `phases`.
    ///
    /// Logs a warning (but still returns `Ok`) when the total warmup time
    /// exceeds `config.max_warmup_ms`.
    ///
    /// # Errors
    /// Propagates the first error from any phase or from `init_fn`; later
    /// phases are skipped in that case.
    pub async fn warmup<F>(&mut self, init_fn: F) -> Result<()>
    where
        F: FnOnce() -> Result<()>,
    {
        // Keep a local copy of the start time so the total-time computation
        // below does not need to `unwrap()` the stored Option.
        let overall_start = Instant::now();
        self.warmup_start = Some(overall_start);

        if self.config.preload_weights {
            let start = Instant::now();
            self.load_weights().await?;
            self.stats.weight_load_ms = start.elapsed().as_millis() as u64;
            self.record_phase("weight_load", self.stats.weight_load_ms);
        }

        if self.config.preallocate_pools {
            let start = Instant::now();
            self.preallocate_pools().await?;
            self.stats.pool_alloc_ms = start.elapsed().as_millis() as u64;
            self.record_phase("pool_alloc", self.stats.pool_alloc_ms);
        }

        if self.config.precompile_kernels {
            let start = Instant::now();
            self.compile_kernels().await?;
            self.stats.kernel_compile_ms = start.elapsed().as_millis() as u64;
            self.record_phase("kernel_compile", self.stats.kernel_compile_ms);
        }

        let start = Instant::now();
        init_fn()?;
        self.record_phase("custom_init", start.elapsed().as_millis() as u64);

        let start = Instant::now();
        self.warmup_inference().await?;
        self.stats.first_inference_ms = start.elapsed().as_millis() as u64;
        self.record_phase("warmup_inference", self.stats.first_inference_ms);

        self.stats.total_warmup_ms = overall_start.elapsed().as_millis() as u64;
        self.warmed_up = true;

        if self.stats.total_warmup_ms > self.config.max_warmup_ms {
            tracing::warn!(
                "Warmup exceeded max time: {}ms > {}ms",
                self.stats.total_warmup_ms,
                self.config.max_warmup_ms
            );
        }
        Ok(())
    }

    /// Appends a completed phase record with its measured duration.
    fn record_phase(&mut self, name: &str, duration_ms: u64) {
        self.phases.push(WarmupPhase {
            name: name.into(),
            duration_ms,
            completed: true,
        });
    }

    /// Placeholder for eager weight loading (no-op in this build).
    async fn load_weights(&self) -> Result<()> {
        Ok(())
    }

    /// Placeholder for memory-pool pre-allocation (no-op in this build).
    async fn preallocate_pools(&self) -> Result<()> {
        Ok(())
    }

    /// Placeholder for kernel precompilation (no-op in this build).
    async fn compile_kernels(&self) -> Result<()> {
        Ok(())
    }

    /// Simulates `warmup_batch_size` inference iterations; each iteration
    /// just sleeps briefly in place of a real forward pass.
    async fn warmup_inference(&self) -> Result<()> {
        for _ in 0..self.config.warmup_batch_size {
            tokio::time::sleep(Duration::from_micros(100)).await;
        }
        Ok(())
    }

    /// True once `warmup` has completed successfully.
    pub fn is_warmed_up(&self) -> bool {
        self.warmed_up
    }

    /// Timing stats collected by the last `warmup` run.
    pub fn stats(&self) -> &WarmupStats {
        &self.stats
    }

    /// The configuration this optimizer was built with.
    pub fn config(&self) -> &WarmupConfig {
        &self.config
    }

    /// `(name, duration_ms)` pairs for each completed phase, in run order.
    pub fn phases(&self) -> Vec<(String, u64)> {
        self.phases
            .iter()
            .map(|p| (p.name.clone(), p.duration_ms))
            .collect()
    }
}
/// Bounds how many instances warm up concurrently and hands out pending
/// instances in priority order (FIFO among equal priorities).
#[derive(Debug)]
pub struct WarmupScheduler {
    // Pending entries, kept in ascending priority order with newer
    // equal-priority entries placed earlier, so the next pick is an O(1)
    // `pop()` from the back.
    schedule: Vec<ScheduledWarmup>,
    // Number of warmups currently in flight.
    active_count: usize,
    // Hard cap on concurrent warmups.
    max_concurrent: usize,
}

/// A queued warmup request.
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct ScheduledWarmup {
    instance_id: String,
    scheduled_at: Instant,
    priority: u32,
}

impl WarmupScheduler {
    /// Creates a scheduler allowing at most `max_concurrent` active warmups.
    pub fn new(max_concurrent: usize) -> Self {
        Self {
            schedule: Vec::new(),
            active_count: 0,
            max_concurrent,
        }
    }

    /// Queues `instance_id` with the given `priority` (higher runs first).
    ///
    /// Uses binary-search insertion (O(log n) search + O(n) shift) instead
    /// of the previous full re-sort on every insert. Inserting before all
    /// existing entries of the same priority keeps ties in FIFO order when
    /// popped from the back.
    pub fn schedule(&mut self, instance_id: impl Into<String>, priority: u32) {
        let idx = self.schedule.partition_point(|s| s.priority < priority);
        self.schedule.insert(
            idx,
            ScheduledWarmup {
                instance_id: instance_id.into(),
                scheduled_at: Instant::now(),
                priority,
            },
        );
    }

    /// Pops the highest-priority pending instance (earliest-scheduled among
    /// ties), or `None` when nothing is pending or `max_concurrent` warmups
    /// are already active.
    pub fn next_warmup(&mut self) -> Option<String> {
        if self.active_count >= self.max_concurrent {
            return None;
        }
        // Highest priority lives at the back: O(1) removal.
        let next = self.schedule.pop()?;
        self.active_count += 1;
        Some(next.instance_id)
    }

    /// Marks one in-flight warmup as finished, freeing a concurrency slot.
    pub fn complete(&mut self, _instance_id: &str) {
        self.active_count = self.active_count.saturating_sub(1);
    }

    /// Number of instances still waiting to be warmed up.
    pub fn pending_count(&self) -> usize {
        self.schedule.len()
    }

    /// Number of warmups currently in flight.
    pub fn active_count(&self) -> usize {
        self.active_count
    }
}
/// Rolling record of cold- and warm-start durations (milliseconds).
#[derive(Debug, Default)]
pub struct ColdStartMetrics {
    cold_starts: Vec<u64>,
    warm_starts: Vec<u64>,
}

impl ColdStartMetrics {
    /// Records one cold-start duration in milliseconds.
    pub fn record_cold_start(&mut self, duration_ms: u64) {
        self.cold_starts.push(duration_ms);
    }

    /// Records one warm-start duration in milliseconds.
    pub fn record_warm_start(&mut self, duration_ms: u64) {
        self.warm_starts.push(duration_ms);
    }

    /// Mean cold-start time, or 0.0 when nothing has been recorded.
    pub fn avg_cold_start_ms(&self) -> f64 {
        Self::mean(&self.cold_starts)
    }

    /// Mean warm-start time, or 0.0 when nothing has been recorded.
    pub fn avg_warm_start_ms(&self) -> f64 {
        Self::mean(&self.warm_starts)
    }

    /// Fraction of all recorded starts that were cold (0.0 when empty).
    pub fn cold_warm_ratio(&self) -> f64 {
        let total = self.cold_starts.len() + self.warm_starts.len();
        if total == 0 {
            0.0
        } else {
            self.cold_starts.len() as f64 / total as f64
        }
    }

    /// 95th-percentile cold start using the nearest-rank method
    /// (the `ceil(0.95 * n)`-th smallest sample), or `None` when empty.
    ///
    /// Fixes an off-by-one in the previous truncation-based index
    /// (`(n * 0.95) as usize`), which e.g. for n = 20 returned the maximum
    /// (rank 20) instead of rank `ceil(0.95 * 20) = 19`.
    pub fn p95_cold_start_ms(&self) -> Option<u64> {
        if self.cold_starts.is_empty() {
            return None;
        }
        let mut sorted = self.cold_starts.clone();
        // Order among equal elements is irrelevant for a percentile, so the
        // faster, non-allocating unstable sort is fine.
        sorted.sort_unstable();
        let rank = (0.95_f64 * sorted.len() as f64).ceil() as usize;
        // rank >= 1 because the vec is non-empty; saturating_sub is belt
        // and braces.
        Some(sorted[rank.saturating_sub(1)])
    }

    /// Mean of `values`, or 0.0 for an empty slice.
    fn mean(values: &[u64]) -> f64 {
        if values.is_empty() {
            0.0
        } else {
            values.iter().sum::<u64>() as f64 / values.len() as f64
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_default() {
        // Defaults: 100 ms target with all eager phases enabled.
        let cfg = WarmupConfig::default();
        assert_eq!(cfg.target_cold_start_ms, 100);
        assert!(cfg.preload_weights && cfg.preallocate_pools);
    }

    #[test]
    fn test_warmup_stats() {
        let stats = WarmupStats {
            first_inference_ms: 80,
            ..WarmupStats::default()
        };
        // 80 ms meets a 100 ms target but misses a 50 ms one.
        assert!(stats.met_target(100));
        assert!(!stats.met_target(50));
    }

    #[test]
    fn test_optimizer_creation() {
        let optimizer = ColdStartOptimizer::new(WarmupConfig::default());
        assert!(!optimizer.is_warmed_up());
    }

    #[test]
    fn test_warmup_scheduler() {
        let mut scheduler = WarmupScheduler::new(2);
        for (id, prio) in [("instance1", 1), ("instance2", 2), ("instance3", 3)] {
            scheduler.schedule(id, prio);
        }
        // Highest priority first, capped at two concurrent warmups.
        assert_eq!(scheduler.next_warmup().as_deref(), Some("instance3"));
        assert_eq!(scheduler.next_warmup().as_deref(), Some("instance2"));
        assert_eq!(scheduler.next_warmup(), None);
        // Completing one frees a slot for the remaining instance.
        scheduler.complete("instance3");
        assert_eq!(scheduler.next_warmup().as_deref(), Some("instance1"));
    }

    #[test]
    fn test_cold_start_metrics() {
        let mut metrics = ColdStartMetrics::default();
        for cold in [100, 150] {
            metrics.record_cold_start(cold);
        }
        for warm in [10, 15] {
            metrics.record_warm_start(warm);
        }
        assert_eq!(metrics.avg_cold_start_ms(), 125.0);
        assert_eq!(metrics.avg_warm_start_ms(), 12.5);
        assert_eq!(metrics.cold_warm_ratio(), 0.5);
    }
}