1use crate::{Model, Result};
31use std::time::{Duration, Instant};
32
/// Settings for a model warm-up run.
///
/// Start from [`Default`] and adjust with the `with_*` builder methods or
/// [`quiet`](Self::quiet).
#[derive(Debug, Clone)]
pub struct WarmupConfig {
    /// How many passes are made over the full list of sample texts.
    pub iterations: usize,
    /// Texts fed to the model, in order, on every pass.
    pub sample_texts: Vec<String>,
    /// Whether the warm-up functions that honour it emit progress via the `log` crate.
    pub verbose: bool,
    /// Optional wall-clock budget. It is checked before each inference, so a run can
    /// finish somewhat past it (the call in flight is not interrupted).
    pub max_duration: Option<Duration>,
}

impl Default for WarmupConfig {
    /// Three passes over three sample texts of increasing length, verbose output on,
    /// and a 30-second time budget.
    fn default() -> Self {
        const SAMPLES: [&str; 3] = [
            "John Smith",
            "Sophie Wilson designed the ARM processor. She changed computing.",
            "Apple Inc. was founded by Steve Jobs and Steve Wozniak in Cupertino, \
             California on April 1, 1976. The company went public on December 12, 1980.",
        ];

        Self {
            iterations: 3,
            sample_texts: SAMPLES.iter().copied().map(String::from).collect(),
            verbose: true,
            max_duration: Some(Duration::from_secs(30)),
        }
    }
}

impl WarmupConfig {
    /// Replaces the number of passes over the sample texts.
    #[must_use]
    pub fn with_iterations(self, n: usize) -> Self {
        Self {
            iterations: n,
            ..self
        }
    }

    /// Appends one more sample text after the existing ones (defaults are kept).
    #[must_use]
    pub fn with_sample(mut self, text: impl Into<String>) -> Self {
        let owned: String = text.into();
        self.sample_texts.push(owned);
        self
    }

    /// Sets (or replaces) the wall-clock budget for the run.
    #[must_use]
    pub fn with_max_duration(self, duration: Duration) -> Self {
        Self {
            max_duration: Some(duration),
            ..self
        }
    }

    /// Turns off progress output.
    #[must_use]
    pub fn quiet(self) -> Self {
        Self {
            verbose: false,
            ..self
        }
    }
}
95
/// Timing statistics produced by a warm-up run.
///
/// As filled in by `warmup_model` / `warmup_with_callback` in this module.
#[derive(Debug, Clone)]
pub struct WarmupResult {
    /// Wall-clock time from the start of the run until the last inference finished.
    pub total_duration: Duration,
    /// Number of inferences that completed.
    pub inference_count: usize,
    /// Latency of the very first inference (zero if none ran).
    pub first_duration: Duration,
    /// Latency of the final inference (zero if none ran).
    pub last_duration: Duration,
    /// Mean latency of the later half (rounded down) of the inferences; equals
    /// `last_duration` when fewer than two inferences ran.
    pub average_warm: Duration,
    /// `first_duration / last_duration`, or `1.0` when `last_duration` is zero.
    pub speedup: f64,
}

impl WarmupResult {
    /// Minimum (exclusive) first-to-last speedup for warm-up to count as effective.
    const EFFECTIVE_SPEEDUP: f64 = 1.5;

    /// Returns `true` when `speedup` is strictly greater than 1.5, i.e. the last
    /// inference ran more than 1.5x faster than the first. A NaN speedup yields `false`.
    #[must_use]
    pub fn is_effective(&self) -> bool {
        Self::EFFECTIVE_SPEEDUP < self.speedup
    }
}
120
121pub fn warmup_model<M: Model>(model: &M, config: WarmupConfig) -> Result<WarmupResult> {
143 let start = Instant::now();
144 let mut durations: Vec<Duration> = Vec::new();
145 let mut first_duration = Duration::ZERO;
146 let mut inference_count = 0;
147
148 if config.verbose {
149 log::info!(
150 "[warmup] Starting warmup: {} iterations, {} sample texts",
151 config.iterations,
152 config.sample_texts.len()
153 );
154 }
155
156 'outer: for iter in 0..config.iterations {
157 for text in &config.sample_texts {
158 if let Some(max) = config.max_duration {
160 if start.elapsed() > max {
161 if config.verbose {
162 log::info!("[warmup] Reached max duration, stopping early");
163 }
164 break 'outer;
165 }
166 }
167
168 let call_start = Instant::now();
169 let _ = model.extract_entities(text, None)?;
170 let call_duration = call_start.elapsed();
171
172 if inference_count == 0 {
173 first_duration = call_duration;
174 }
175 durations.push(call_duration);
176 inference_count += 1;
177
178 if config.verbose && iter == 0 {
179 log::debug!(
180 "[warmup] Sample {}: {:?} (text len: {})",
181 inference_count,
182 call_duration,
183 text.len()
184 );
185 }
186 }
187 }
188
189 let total_duration = start.elapsed();
190 let last_duration = durations.last().copied().unwrap_or(Duration::ZERO);
191
192 let warm_count = durations.len() / 2;
194 let average_warm = if warm_count > 0 {
195 let warm_sum: Duration = durations.iter().skip(durations.len() - warm_count).sum();
196 warm_sum / warm_count as u32
197 } else {
198 last_duration
199 };
200
201 let speedup = if last_duration.as_nanos() > 0 {
202 first_duration.as_secs_f64() / last_duration.as_secs_f64()
203 } else {
204 1.0
205 };
206
207 let result = WarmupResult {
208 total_duration,
209 inference_count,
210 first_duration,
211 last_duration,
212 average_warm,
213 speedup,
214 };
215
216 if config.verbose {
217 log::info!(
218 "[warmup] Complete: {} inferences in {:?}",
219 inference_count,
220 total_duration
221 );
222 log::info!(
223 "[warmup] First: {:?}, Last: {:?}, Speedup: {:.2}x",
224 first_duration,
225 last_duration,
226 speedup
227 );
228 }
229
230 Ok(result)
231}
232
233pub fn warmup_with_callback<M: Model, F>(
237 model: &M,
238 config: WarmupConfig,
239 mut callback: F,
240) -> Result<WarmupResult>
241where
242 F: FnMut(usize, usize, Duration),
243{
244 let start = Instant::now();
245 let total_calls = config.iterations * config.sample_texts.len();
246 let mut durations: Vec<Duration> = Vec::new();
247 let mut first_duration = Duration::ZERO;
248 let mut inference_count = 0;
249
250 'outer: for _iter in 0..config.iterations {
251 for text in &config.sample_texts {
252 if let Some(max) = config.max_duration {
253 if start.elapsed() > max {
254 break 'outer;
255 }
256 }
257
258 let call_start = Instant::now();
259 let _ = model.extract_entities(text, None)?;
260 let call_duration = call_start.elapsed();
261
262 if inference_count == 0 {
263 first_duration = call_duration;
264 }
265 durations.push(call_duration);
266 inference_count += 1;
267
268 callback(inference_count, total_calls, call_duration);
270 }
271 }
272
273 let total_duration = start.elapsed();
274 let last_duration = durations.last().copied().unwrap_or(Duration::ZERO);
275 let warm_count = durations.len() / 2;
276 let average_warm = if warm_count > 0 {
277 let warm_sum: Duration = durations.iter().skip(durations.len() - warm_count).sum();
278 warm_sum / warm_count as u32
279 } else {
280 last_duration
281 };
282
283 let speedup = if last_duration.as_nanos() > 0 {
284 first_duration.as_secs_f64() / last_duration.as_secs_f64()
285 } else {
286 1.0
287 };
288
289 Ok(WarmupResult {
290 total_duration,
291 inference_count,
292 first_duration,
293 last_duration,
294 average_warm,
295 speedup,
296 })
297}
298
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_warmup_config_default() {
        let defaults = WarmupConfig::default();

        assert_eq!(defaults.iterations, 3);
        assert!(!defaults.sample_texts.is_empty());
        assert!(defaults.verbose);
    }

    #[test]
    fn test_warmup_config_builder() {
        let budget = Duration::from_secs(10);
        let built = WarmupConfig::default()
            .with_iterations(5)
            .with_sample("Custom text")
            .with_max_duration(budget)
            .quiet();

        assert_eq!(built.iterations, 5);
        assert!(built.sample_texts.contains(&String::from("Custom text")));
        assert_eq!(built.max_duration, Some(budget));
        assert!(!built.verbose);
    }

    #[test]
    fn test_warmup_result_effective() {
        let fast_warm = WarmupResult {
            total_duration: Duration::from_secs(1),
            inference_count: 9,
            first_duration: Duration::from_millis(300),
            last_duration: Duration::from_millis(100),
            average_warm: Duration::from_millis(110),
            speedup: 3.0,
        };
        assert!(fast_warm.is_effective());

        // Every field is `Copy`, so struct-update syntax copies them without a clone.
        let barely_faster = WarmupResult {
            speedup: 1.1,
            ..fast_warm
        };
        assert!(!barely_faster.is_effective());
    }
}