Skip to main content

haagenti_serverless/
cold_start.rs

1//! Cold start optimization for serverless functions
2
3use crate::Result;
4use serde::{Deserialize, Serialize};
5use std::time::{Duration, Instant};
6
7/// Warmup configuration
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct WarmupConfig {
10    /// Target cold start time in ms
11    pub target_cold_start_ms: u64,
12    /// Maximum warmup time in ms
13    pub max_warmup_ms: u64,
14    /// Pre-load model weights
15    pub preload_weights: bool,
16    /// Pre-allocate memory pools
17    pub preallocate_pools: bool,
18    /// Pre-compile shaders/kernels
19    pub precompile_kernels: bool,
20    /// Warmup batch size
21    pub warmup_batch_size: usize,
22    /// Enable lazy loading after warmup
23    pub lazy_loading: bool,
24}
25
26impl Default for WarmupConfig {
27    fn default() -> Self {
28        Self {
29            target_cold_start_ms: 100,
30            max_warmup_ms: 5000,
31            preload_weights: true,
32            preallocate_pools: true,
33            precompile_kernels: true,
34            warmup_batch_size: 1,
35            lazy_loading: true,
36        }
37    }
38}
39
40/// Warmup statistics
41#[derive(Debug, Clone, Default, Serialize, Deserialize)]
42pub struct WarmupStats {
43    /// Total warmup time in ms
44    pub total_warmup_ms: u64,
45    /// Weight loading time in ms
46    pub weight_load_ms: u64,
47    /// Pool allocation time in ms
48    pub pool_alloc_ms: u64,
49    /// Kernel compilation time in ms
50    pub kernel_compile_ms: u64,
51    /// First inference time in ms
52    pub first_inference_ms: u64,
53    /// Number of warmup iterations
54    pub warmup_iterations: u32,
55    /// Memory used after warmup in bytes
56    pub memory_after_warmup: u64,
57}
58
59impl WarmupStats {
60    /// Check if warmup met target
61    pub fn met_target(&self, target_ms: u64) -> bool {
62        self.first_inference_ms <= target_ms
63    }
64}
65
66/// Cold start optimizer
67#[derive(Debug)]
68pub struct ColdStartOptimizer {
69    /// Configuration
70    config: WarmupConfig,
71    /// Warmup stats
72    stats: WarmupStats,
73    /// Whether warmup is complete
74    warmed_up: bool,
75    /// Warmup start time
76    warmup_start: Option<Instant>,
77    /// Warmup phases
78    phases: Vec<WarmupPhase>,
79}
80
81/// Warmup phase (internal tracking)
82#[derive(Debug, Clone)]
83#[allow(dead_code)]
84struct WarmupPhase {
85    name: String,
86    duration_ms: u64,
87    completed: bool,
88}
89
90impl ColdStartOptimizer {
91    /// Create new optimizer
92    pub fn new(config: WarmupConfig) -> Self {
93        Self {
94            config,
95            stats: WarmupStats::default(),
96            warmed_up: false,
97            warmup_start: None,
98            phases: Vec::new(),
99        }
100    }
101
102    /// Start warmup process
103    pub async fn warmup<F>(&mut self, init_fn: F) -> Result<()>
104    where
105        F: FnOnce() -> Result<()>,
106    {
107        self.warmup_start = Some(Instant::now());
108
109        // Phase 1: Weight loading
110        if self.config.preload_weights {
111            let start = Instant::now();
112            self.load_weights().await?;
113            self.stats.weight_load_ms = start.elapsed().as_millis() as u64;
114            self.phases.push(WarmupPhase {
115                name: "weight_load".into(),
116                duration_ms: self.stats.weight_load_ms,
117                completed: true,
118            });
119        }
120
121        // Phase 2: Pool allocation
122        if self.config.preallocate_pools {
123            let start = Instant::now();
124            self.preallocate_pools().await?;
125            self.stats.pool_alloc_ms = start.elapsed().as_millis() as u64;
126            self.phases.push(WarmupPhase {
127                name: "pool_alloc".into(),
128                duration_ms: self.stats.pool_alloc_ms,
129                completed: true,
130            });
131        }
132
133        // Phase 3: Kernel compilation
134        if self.config.precompile_kernels {
135            let start = Instant::now();
136            self.compile_kernels().await?;
137            self.stats.kernel_compile_ms = start.elapsed().as_millis() as u64;
138            self.phases.push(WarmupPhase {
139                name: "kernel_compile".into(),
140                duration_ms: self.stats.kernel_compile_ms,
141                completed: true,
142            });
143        }
144
145        // Phase 4: Custom initialization
146        let start = Instant::now();
147        init_fn()?;
148        self.phases.push(WarmupPhase {
149            name: "custom_init".into(),
150            duration_ms: start.elapsed().as_millis() as u64,
151            completed: true,
152        });
153
154        // Phase 5: Warmup inference
155        let start = Instant::now();
156        self.warmup_inference().await?;
157        self.stats.first_inference_ms = start.elapsed().as_millis() as u64;
158        self.phases.push(WarmupPhase {
159            name: "warmup_inference".into(),
160            duration_ms: self.stats.first_inference_ms,
161            completed: true,
162        });
163
164        self.stats.total_warmup_ms = self.warmup_start.unwrap().elapsed().as_millis() as u64;
165        self.warmed_up = true;
166
167        // Check if we exceeded max warmup time
168        if self.stats.total_warmup_ms > self.config.max_warmup_ms {
169            tracing::warn!(
170                "Warmup exceeded max time: {}ms > {}ms",
171                self.stats.total_warmup_ms,
172                self.config.max_warmup_ms
173            );
174        }
175
176        Ok(())
177    }
178
179    /// Load model weights
180    async fn load_weights(&self) -> Result<()> {
181        // In a real implementation, this would:
182        // 1. Load weights from pre-warmed cache or storage
183        // 2. Initialize model tensors
184        // 3. Transfer to GPU if available
185        Ok(())
186    }
187
188    /// Pre-allocate memory pools
189    async fn preallocate_pools(&self) -> Result<()> {
190        // Pre-allocate:
191        // - Input buffer pools
192        // - Output buffer pools
193        // - Intermediate activation buffers
194        // - KV cache for transformers
195        Ok(())
196    }
197
198    /// Pre-compile compute kernels
199    async fn compile_kernels(&self) -> Result<()> {
200        // Pre-compile:
201        // - Matrix multiplication kernels
202        // - Attention kernels
203        // - Activation functions
204        // - Normalization layers
205        Ok(())
206    }
207
208    /// Run warmup inference
209    async fn warmup_inference(&self) -> Result<()> {
210        // Run inference with dummy data to:
211        // - Warm up JIT compilation
212        // - Populate caches
213        // - Trigger GPU memory allocation
214        for _ in 0..self.config.warmup_batch_size {
215            // Simulate inference
216            tokio::time::sleep(Duration::from_micros(100)).await;
217        }
218        // Warmup iterations tracked in caller (warmup method)
219        Ok(())
220    }
221
222    /// Check if warmed up
223    pub fn is_warmed_up(&self) -> bool {
224        self.warmed_up
225    }
226
227    /// Get warmup stats
228    pub fn stats(&self) -> &WarmupStats {
229        &self.stats
230    }
231
232    /// Get configuration
233    pub fn config(&self) -> &WarmupConfig {
234        &self.config
235    }
236
237    /// Get warmup phases
238    pub fn phases(&self) -> Vec<(String, u64)> {
239        self.phases
240            .iter()
241            .map(|p| (p.name.clone(), p.duration_ms))
242            .collect()
243    }
244}
245
/// Warmup scheduler for pre-warming instances.
///
/// Pending warmups are handed out highest priority first, and in scheduling
/// order among equal priorities. At most `max_concurrent` warmups are active
/// (handed out by `next_warmup` and not yet completed) at any time.
#[derive(Debug)]
pub struct WarmupScheduler {
    /// Pending warmups, sorted by descending priority (FIFO within a priority)
    schedule: Vec<ScheduledWarmup>,
    /// IDs handed out by `next_warmup` that have not been completed yet
    active: Vec<String>,
    /// Maximum concurrent warmups
    max_concurrent: usize,
}

/// Scheduled warmup entry (internal to WarmupScheduler)
#[derive(Debug, Clone)]
// `scheduled_at` is only surfaced through `Debug`; no code reads it yet.
#[allow(dead_code)]
struct ScheduledWarmup {
    /// Instance ID
    instance_id: String,
    /// Time the warmup was scheduled
    scheduled_at: Instant,
    /// Priority (higher is served first)
    priority: u32,
}

impl WarmupScheduler {
    /// Create a scheduler allowing `max_concurrent` simultaneous warmups.
    pub fn new(max_concurrent: usize) -> Self {
        Self {
            schedule: Vec::new(),
            active: Vec::new(),
            max_concurrent,
        }
    }

    /// Queue a warmup for `instance_id` with the given priority.
    pub fn schedule(&mut self, instance_id: impl Into<String>, priority: u32) {
        // Insert after every entry with priority >= this one. This keeps the
        // queue sorted descending and FIFO among equal priorities (the order a
        // push + stable sort would give) without re-sorting the whole queue.
        let pos = self.schedule.partition_point(|s| s.priority >= priority);
        self.schedule.insert(
            pos,
            ScheduledWarmup {
                instance_id: instance_id.into(),
                scheduled_at: Instant::now(),
                priority,
            },
        );
    }

    /// Take the next pending warmup and mark it active.
    ///
    /// Returns `None` when the concurrency limit is reached or nothing is
    /// pending.
    pub fn next_warmup(&mut self) -> Option<String> {
        if self.active.len() >= self.max_concurrent || self.schedule.is_empty() {
            return None;
        }

        // Front of the queue is the highest-priority entry.
        let next = self.schedule.remove(0);
        self.active.push(next.instance_id.clone());
        Some(next.instance_id)
    }

    /// Mark the warmup of `instance_id` as complete, freeing its slot.
    ///
    /// Only a currently active instance frees a slot. Unknown or
    /// already-completed IDs are ignored, so duplicate completions cannot push
    /// concurrency above `max_concurrent`.
    pub fn complete(&mut self, instance_id: &str) {
        if let Some(pos) = self.active.iter().position(|id| id == instance_id) {
            self.active.swap_remove(pos);
        }
    }

    /// Number of warmups still waiting to be handed out.
    pub fn pending_count(&self) -> usize {
        self.schedule.len()
    }

    /// Number of warmups handed out and not yet completed.
    pub fn active_count(&self) -> usize {
        self.active.len()
    }
}
322
/// Cold start metrics collector.
///
/// Records cold and warm start durations (ms) and derives summary statistics.
#[derive(Debug, Default)]
pub struct ColdStartMetrics {
    /// Cold start durations (ms), in recording order
    cold_starts: Vec<u64>,
    /// Warm start durations (ms), in recording order
    warm_starts: Vec<u64>,
}

impl ColdStartMetrics {
    /// Record cold start
    pub fn record_cold_start(&mut self, duration_ms: u64) {
        self.cold_starts.push(duration_ms);
    }

    /// Record warm start
    pub fn record_warm_start(&mut self, duration_ms: u64) {
        self.warm_starts.push(duration_ms);
    }

    /// Arithmetic mean of `samples`, or `0.0` when empty.
    ///
    /// Sums in `u128` so large totals cannot overflow `u64`, which would panic
    /// in debug builds.
    fn mean_ms(samples: &[u64]) -> f64 {
        if samples.is_empty() {
            return 0.0;
        }
        let total: u128 = samples.iter().map(|&ms| u128::from(ms)).sum();
        total as f64 / samples.len() as f64
    }

    /// Mean cold start time (ms), or `0.0` when none were recorded.
    pub fn avg_cold_start_ms(&self) -> f64 {
        Self::mean_ms(&self.cold_starts)
    }

    /// Mean warm start time (ms), or `0.0` when none were recorded.
    pub fn avg_warm_start_ms(&self) -> f64 {
        Self::mean_ms(&self.warm_starts)
    }

    /// Fraction of all recorded starts that were cold, in `[0.0, 1.0]`.
    ///
    /// Despite the name this is `cold / (cold + warm)`, not `cold / warm`.
    /// Returns `0.0` when nothing has been recorded.
    pub fn cold_warm_ratio(&self) -> f64 {
        let total = self.cold_starts.len() + self.warm_starts.len();
        if total == 0 {
            0.0
        } else {
            self.cold_starts.len() as f64 / total as f64
        }
    }

    /// 95th-percentile cold start time (ms), or `None` when none were recorded.
    ///
    /// Uses the nearest-rank definition: the sample at 1-based rank
    /// `ceil(0.95 * n)` in ascending order. This is the smallest recorded value
    /// that at least 95% of samples do not exceed.
    pub fn p95_cold_start_ms(&self) -> Option<u64> {
        let n = self.cold_starts.len();
        if n == 0 {
            return None;
        }

        // ceil(0.95 * n) == n - floor(n / 20): exact, no float rounding or
        // overflow. Always in 1..=n for n >= 1.
        let rank = n - n / 20;
        let mut samples = self.cold_starts.clone();
        // One order statistic only needs an O(n) selection, not a full sort.
        let (_, p95, _) = samples.select_nth_unstable(rank - 1);
        Some(*p95)
    }
}
383
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_default() {
        let WarmupConfig {
            target_cold_start_ms,
            preload_weights,
            preallocate_pools,
            ..
        } = WarmupConfig::default();

        assert_eq!(target_cold_start_ms, 100);
        assert!(preload_weights);
        assert!(preallocate_pools);
    }

    #[test]
    fn test_warmup_stats() {
        let mut stats = WarmupStats::default();
        stats.first_inference_ms = 80;

        // Within target, then over a tighter target.
        assert!(stats.met_target(100));
        assert!(!stats.met_target(50));
    }

    #[test]
    fn test_optimizer_creation() {
        let fresh = ColdStartOptimizer::new(WarmupConfig::default());
        assert!(!fresh.is_warmed_up());
    }

    #[test]
    fn test_warmup_scheduler() {
        let mut sched = WarmupScheduler::new(2);
        for (id, prio) in [("instance1", 1), ("instance2", 2), ("instance3", 3)] {
            sched.schedule(id, prio);
        }

        // Highest priority first, until the concurrency limit is hit.
        for expected in ["instance3", "instance2"] {
            assert_eq!(sched.next_warmup().as_deref(), Some(expected));
        }
        assert_eq!(sched.next_warmup(), None);

        // Finishing one warmup frees a slot for the remaining instance.
        sched.complete("instance3");
        assert_eq!(sched.next_warmup().as_deref(), Some("instance1"));
    }

    #[test]
    fn test_cold_start_metrics() {
        let mut m = ColdStartMetrics::default();
        for cold in [100, 150] {
            m.record_cold_start(cold);
        }
        for warm in [10, 15] {
            m.record_warm_start(warm);
        }

        assert_eq!(m.avg_cold_start_ms(), 125.0);
        assert_eq!(m.avg_warm_start_ms(), 12.5);
        assert_eq!(m.cold_warm_ratio(), 0.5);
    }
}