// anno/backends/warmup.rs
//! Model warmup utilities for cold-start mitigation.
//!
//! In serverless environments (AWS Lambda, Cloud Functions), the first
//! inference call is significantly slower due to:
//!
//! 1. ONNX graph optimization (done on first inference)
//! 2. Memory allocation and page faults
//! 3. CPU cache warming
//!
//! This module provides utilities to "warm up" models before serving traffic.
//!
//! # Usage
//!
//! ```rust,ignore
//! use anno::backends::warmup::{warmup_model, WarmupConfig};
//!
//! // During initialization (before serving traffic)
//! let model = GLiNEROnnx::new("onnx-community/gliner_small-v2.1")?;
//! warmup_model(&model, WarmupConfig::default())?;
//!
//! // Now ready for production traffic
//! ```
//!
//! # Notes
//!
//! Warmup shifts "one-time" costs (graph optimization, allocations, cache effects) from the first
//! user-facing request into an explicit initialization step. Use [`WarmupResult`] to measure the
//! effect in your environment; do not treat static numbers as portable.
30use crate::{Model, Result};
31use std::time::{Duration, Instant};
32
/// Configuration for model warmup.
///
/// Construct with [`Default::default`] and refine with the builder methods
/// ([`WarmupConfig::with_iterations`], [`WarmupConfig::with_sample`],
/// [`WarmupConfig::with_max_duration`], [`WarmupConfig::quiet`]).
#[derive(Debug, Clone)]
pub struct WarmupConfig {
    /// Number of warmup passes; each pass runs every sample text once.
    pub iterations: usize,
    /// Sample texts for warmup (various lengths, to exercise different input sizes).
    pub sample_texts: Vec<String>,
    /// Whether to log warmup progress (via the `log` crate).
    pub verbose: bool,
    /// Upper bound on total warmup time; warmup stops early once it is
    /// exceeded. `None` means no time limit.
    pub max_duration: Option<Duration>,
}
45
46impl Default for WarmupConfig {
47    fn default() -> Self {
48        Self {
49            iterations: 3,
50            sample_texts: vec![
51                // Short text
52                "John Smith".to_string(),
53                // Medium text
54                "Sophie Wilson designed the ARM processor. She changed computing.".to_string(),
55                // Longer text with multiple entities
56                "Apple Inc. was founded by Steve Jobs and Steve Wozniak in Cupertino, \
57                 California on April 1, 1976. The company went public on December 12, 1980."
58                    .to_string(),
59            ],
60            verbose: true,
61            max_duration: Some(Duration::from_secs(30)),
62        }
63    }
64}
65
66impl WarmupConfig {
67    /// Create config with specific iteration count.
68    #[must_use]
69    pub fn with_iterations(mut self, n: usize) -> Self {
70        self.iterations = n;
71        self
72    }
73
74    /// Add a custom sample text.
75    #[must_use]
76    pub fn with_sample(mut self, text: impl Into<String>) -> Self {
77        self.sample_texts.push(text.into());
78        self
79    }
80
81    /// Set maximum warmup duration.
82    #[must_use]
83    pub fn with_max_duration(mut self, duration: Duration) -> Self {
84        self.max_duration = Some(duration);
85        self
86    }
87
88    /// Disable verbose logging.
89    #[must_use]
90    pub fn quiet(mut self) -> Self {
91        self.verbose = false;
92        self
93    }
94}
95
/// Warmup result with timing information.
///
/// Produced by [`warmup_model`] / [`warmup_with_callback`]; use it to judge
/// whether warmup was worthwhile in a given deployment environment.
#[derive(Debug, Clone)]
pub struct WarmupResult {
    /// Total wall-clock time spent in warmup (includes bookkeeping, not just inference).
    pub total_duration: Duration,
    /// Number of inference calls made (may be fewer than configured if the
    /// time budget was hit).
    pub inference_count: usize,
    /// First inference duration (coldest). Zero if no calls ran.
    pub first_duration: Duration,
    /// Last inference duration (warmest). Zero if no calls ran.
    pub last_duration: Duration,
    /// Average duration over the second half of the calls (the warmed-up portion).
    pub average_warm: Duration,
    /// Speedup ratio (first / last); 1.0 when the last call was unmeasurable.
    pub speedup: f64,
}
112
113impl WarmupResult {
114    /// Check if warmup achieved significant speedup.
115    #[must_use]
116    pub fn is_effective(&self) -> bool {
117        self.speedup > 1.5
118    }
119}
120
121/// Warm up a model by running sample inferences.
122///
123/// # Arguments
124///
125/// * `model` - The model to warm up
126/// * `config` - Warmup configuration
127///
128/// # Returns
129///
130/// `WarmupResult` with timing information.
131///
132/// # Example
133///
134/// ```rust,ignore
135/// use anno::{GLiNEROnnx, backends::warmup::{warmup_model, WarmupConfig}};
136///
137/// let model = GLiNEROnnx::new("onnx-community/gliner_small-v2.1")?;
138///
139/// let result = warmup_model(&model, WarmupConfig::default())?;
140/// println!("Warmup speedup: {:.2}x", result.speedup);
141/// ```
142pub fn warmup_model<M: Model>(model: &M, config: WarmupConfig) -> Result<WarmupResult> {
143    let start = Instant::now();
144    let mut durations: Vec<Duration> = Vec::new();
145    let mut first_duration = Duration::ZERO;
146    let mut inference_count = 0;
147
148    if config.verbose {
149        log::info!(
150            "[warmup] Starting warmup: {} iterations, {} sample texts",
151            config.iterations,
152            config.sample_texts.len()
153        );
154    }
155
156    'outer: for iter in 0..config.iterations {
157        for text in &config.sample_texts {
158            // Check timeout
159            if let Some(max) = config.max_duration {
160                if start.elapsed() > max {
161                    if config.verbose {
162                        log::info!("[warmup] Reached max duration, stopping early");
163                    }
164                    break 'outer;
165                }
166            }
167
168            let call_start = Instant::now();
169            let _ = model.extract_entities(text, None)?;
170            let call_duration = call_start.elapsed();
171
172            if inference_count == 0 {
173                first_duration = call_duration;
174            }
175            durations.push(call_duration);
176            inference_count += 1;
177
178            if config.verbose && iter == 0 {
179                log::debug!(
180                    "[warmup] Sample {}: {:?} (text len: {})",
181                    inference_count,
182                    call_duration,
183                    text.len()
184                );
185            }
186        }
187    }
188
189    let total_duration = start.elapsed();
190    let last_duration = durations.last().copied().unwrap_or(Duration::ZERO);
191
192    // Calculate average of last half (warmed up)
193    let warm_count = durations.len() / 2;
194    let average_warm = if warm_count > 0 {
195        let warm_sum: Duration = durations.iter().skip(durations.len() - warm_count).sum();
196        warm_sum / warm_count as u32
197    } else {
198        last_duration
199    };
200
201    let speedup = if last_duration.as_nanos() > 0 {
202        first_duration.as_secs_f64() / last_duration.as_secs_f64()
203    } else {
204        1.0
205    };
206
207    let result = WarmupResult {
208        total_duration,
209        inference_count,
210        first_duration,
211        last_duration,
212        average_warm,
213        speedup,
214    };
215
216    if config.verbose {
217        log::info!(
218            "[warmup] Complete: {} inferences in {:?}",
219            inference_count,
220            total_duration
221        );
222        log::info!(
223            "[warmup] First: {:?}, Last: {:?}, Speedup: {:.2}x",
224            first_duration,
225            last_duration,
226            speedup
227        );
228    }
229
230    Ok(result)
231}
232
233/// Warmup with progress callback.
234///
235/// Useful for showing progress in CLI tools or updating health checks.
236pub fn warmup_with_callback<M: Model, F>(
237    model: &M,
238    config: WarmupConfig,
239    mut callback: F,
240) -> Result<WarmupResult>
241where
242    F: FnMut(usize, usize, Duration),
243{
244    let start = Instant::now();
245    let total_calls = config.iterations * config.sample_texts.len();
246    let mut durations: Vec<Duration> = Vec::new();
247    let mut first_duration = Duration::ZERO;
248    let mut inference_count = 0;
249
250    'outer: for _iter in 0..config.iterations {
251        for text in &config.sample_texts {
252            if let Some(max) = config.max_duration {
253                if start.elapsed() > max {
254                    break 'outer;
255                }
256            }
257
258            let call_start = Instant::now();
259            let _ = model.extract_entities(text, None)?;
260            let call_duration = call_start.elapsed();
261
262            if inference_count == 0 {
263                first_duration = call_duration;
264            }
265            durations.push(call_duration);
266            inference_count += 1;
267
268            // Call progress callback
269            callback(inference_count, total_calls, call_duration);
270        }
271    }
272
273    let total_duration = start.elapsed();
274    let last_duration = durations.last().copied().unwrap_or(Duration::ZERO);
275    let warm_count = durations.len() / 2;
276    let average_warm = if warm_count > 0 {
277        let warm_sum: Duration = durations.iter().skip(durations.len() - warm_count).sum();
278        warm_sum / warm_count as u32
279    } else {
280        last_duration
281    };
282
283    let speedup = if last_duration.as_nanos() > 0 {
284        first_duration.as_secs_f64() / last_duration.as_secs_f64()
285    } else {
286        1.0
287    };
288
289    Ok(WarmupResult {
290        total_duration,
291        inference_count,
292        first_duration,
293        last_duration,
294        average_warm,
295        speedup,
296    })
297}
298
299// =============================================================================
300// Tests
301// =============================================================================
302
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config should have the documented shape.
    #[test]
    fn test_warmup_config_default() {
        let cfg = WarmupConfig::default();
        assert_eq!(cfg.iterations, 3);
        assert!(!cfg.sample_texts.is_empty());
        assert!(cfg.verbose);
    }

    /// Builder methods should each apply their change and compose.
    #[test]
    fn test_warmup_config_builder() {
        let cfg = WarmupConfig::default()
            .with_iterations(5)
            .with_sample("Custom text")
            .with_max_duration(Duration::from_secs(10))
            .quiet();

        assert_eq!(cfg.iterations, 5);
        assert!(cfg.sample_texts.iter().any(|t| t == "Custom text"));
        assert_eq!(cfg.max_duration, Some(Duration::from_secs(10)));
        assert!(!cfg.verbose);
    }

    /// `is_effective` is a pure function of `speedup` (threshold 1.5).
    #[test]
    fn test_warmup_result_effective() {
        let effective = WarmupResult {
            total_duration: Duration::from_secs(1),
            inference_count: 9,
            first_duration: Duration::from_millis(300),
            last_duration: Duration::from_millis(100),
            average_warm: Duration::from_millis(110),
            speedup: 3.0,
        };
        assert!(effective.is_effective());

        let not_effective = WarmupResult {
            speedup: 1.1,
            ..effective.clone()
        };
        assert!(!not_effective.is_effective());
    }
}
347}