use std::{
sync::{
atomic::{AtomicBool, AtomicU64, Ordering},
Arc,
},
time::{Duration, Instant},
};
use serde::{Deserialize, Serialize};
/// Configuration controlling how a model warm-up run is performed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WarmupConfig {
    /// Number of warm-up inference iterations (the builder clamps this to >= 1).
    pub warmup_iterations: usize,
    /// Maximum wall-clock time allowed for the whole warm-up.
    pub timeout: Duration,
    /// Prompt text sent to the model during warm-up requests.
    pub sample_prompt: String,
    /// Token-generation cap for each warm-up request.
    pub sample_max_tokens: usize,
    /// Whether warm-up output should be validated (defaults to true).
    /// NOTE(review): not consulted by `simulate_warmup`; presumably read by a
    /// real executor — confirm against the caller.
    pub validate_output: bool,
    /// Whether to run a garbage-collection pass once warm-up completes.
    /// NOTE(review): not consulted anywhere in this file — confirm consumer.
    pub gc_after_warmup: bool,
    /// Whether warm-up should emit verbose diagnostics (consumer-defined).
    pub verbose: bool,
}
impl Default for WarmupConfig {
fn default() -> Self {
Self {
warmup_iterations: 3,
timeout: Duration::from_secs(60),
sample_prompt: "Hello, world!".to_string(),
sample_max_tokens: 10,
validate_output: true,
gc_after_warmup: true,
verbose: false,
}
}
}
impl WarmupConfig {
    /// Equivalent to [`WarmupConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the iteration count, clamped to at least 1.
    #[must_use]
    pub fn with_warmup_iterations(self, n: usize) -> Self {
        Self {
            warmup_iterations: n.max(1),
            ..self
        }
    }

    /// Sets the overall warm-up timeout.
    #[must_use]
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Self { timeout, ..self }
    }

    /// Sets the prompt used for warm-up requests.
    #[must_use]
    pub fn with_sample_prompt(self, prompt: impl Into<String>) -> Self {
        Self {
            sample_prompt: prompt.into(),
            ..self
        }
    }

    /// Sets the token-generation cap per warm-up request.
    #[must_use]
    pub fn with_sample_max_tokens(self, n: usize) -> Self {
        Self {
            sample_max_tokens: n,
            ..self
        }
    }

    /// Enables or disables output validation.
    #[must_use]
    pub fn with_validate_output(self, validate: bool) -> Self {
        Self {
            validate_output: validate,
            ..self
        }
    }

    /// Enables or disables the post-warm-up GC pass.
    #[must_use]
    pub fn with_gc_after_warmup(self, gc: bool) -> Self {
        Self {
            gc_after_warmup: gc,
            ..self
        }
    }

    /// Enables or disables verbose diagnostics.
    #[must_use]
    pub fn with_verbose(self, verbose: bool) -> Self {
        Self { verbose, ..self }
    }
}
/// Lifecycle state of a model's warm-up.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum WarmupStatus {
    /// Warm-up has not begun yet.
    NotStarted,
    /// Warm-up is currently running.
    InProgress,
    /// Warm-up finished successfully; the model is ready to serve.
    Ready,
    /// Warm-up aborted with an error.
    Failed,
    /// Warm-up exceeded the configured timeout.
    TimedOut,
}
impl WarmupStatus {
    /// True only for [`WarmupStatus::Ready`].
    #[must_use]
    pub fn is_ready(&self) -> bool {
        *self == Self::Ready
    }

    /// True only for [`WarmupStatus::InProgress`].
    #[must_use]
    pub fn is_in_progress(&self) -> bool {
        *self == Self::InProgress
    }

    /// True for either terminal failure state: `Failed` or `TimedOut`.
    #[must_use]
    pub fn has_failed(&self) -> bool {
        *self == Self::Failed || *self == Self::TimedOut
    }
}
/// Outcome of a warm-up run, including latency statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WarmupResult {
    /// Final status (`Ready`, `Failed`, or `TimedOut`).
    pub status: WarmupStatus,
    /// Number of warm-up iterations that actually completed.
    pub iterations_completed: usize,
    /// Total wall-clock time spent warming up.
    pub total_duration: Duration,
    /// Arithmetic mean of the per-iteration latencies (zero on failure).
    pub avg_latency: Duration,
    /// Latency of the first iteration (typically the cold-start cost).
    pub first_latency: Duration,
    /// Latency of the last iteration.
    pub last_latency: Duration,
    /// `first_latency / last_latency`; 1.0 when the last latency is zero
    /// (e.g. no latencies recorded), 0.0 for failed/timed-out results.
    pub speedup_factor: f64,
    /// Error message for `Failed`/`TimedOut` results, `None` on success.
    pub error: Option<String>,
}
impl WarmupResult {
#[must_use]
pub fn success(iterations: usize, duration: Duration, latencies: &[Duration]) -> Self {
let first = latencies.first().copied().unwrap_or(Duration::ZERO);
let last = latencies.last().copied().unwrap_or(Duration::ZERO);
let avg = if latencies.is_empty() {
Duration::ZERO
} else {
Duration::from_nanos(
latencies.iter().map(|d| d.as_nanos() as u64).sum::<u64>() / latencies.len() as u64,
)
};
let speedup = if last.as_nanos() > 0 {
first.as_nanos() as f64 / last.as_nanos() as f64
} else {
1.0
};
Self {
status: WarmupStatus::Ready,
iterations_completed: iterations,
total_duration: duration,
avg_latency: avg,
first_latency: first,
last_latency: last,
speedup_factor: speedup,
error: None,
}
}
#[must_use]
pub fn failed(error: impl Into<String>, iterations: usize, duration: Duration) -> Self {
Self {
status: WarmupStatus::Failed,
iterations_completed: iterations,
total_duration: duration,
avg_latency: Duration::ZERO,
first_latency: Duration::ZERO,
last_latency: Duration::ZERO,
speedup_factor: 0.0,
error: Some(error.into()),
}
}
#[must_use]
pub fn timed_out(iterations: usize, duration: Duration) -> Self {
Self {
status: WarmupStatus::TimedOut,
iterations_completed: iterations,
total_duration: duration,
avg_latency: Duration::ZERO,
first_latency: Duration::ZERO,
last_latency: Duration::ZERO,
speedup_factor: 0.0,
error: Some("Warm-up timed out".to_string()),
}
}
}
/// Shared, thread-safe health state for a loaded model.
///
/// Every mutable field sits behind an `Arc`, so cloned handles observe and
/// mutate the same underlying state; `loaded_at` is a plain `Instant` and is
/// simply copied into each clone.
#[derive(Debug, Clone)]
pub struct ModelHealth {
    /// Fast-path readiness flag (raised when the status becomes `Ready`).
    ready: Arc<AtomicBool>,
    /// Detailed warm-up status.
    status: Arc<std::sync::RwLock<WarmupStatus>>,
    /// Count of successful requests.
    requests_served: Arc<AtomicU64>,
    /// Count of failed requests.
    requests_failed: Arc<AtomicU64>,
    /// Timestamp of the most recent health check (`touch`).
    last_health_check: Arc<std::sync::RwLock<Instant>>,
    /// When this tracker was created; used to compute uptime.
    loaded_at: Instant,
}
impl Default for ModelHealth {
fn default() -> Self {
Self::new()
}
}
impl ModelHealth {
    /// Creates a tracker in the `NotStarted` state with zeroed counters.
    #[must_use]
    pub fn new() -> Self {
        Self {
            ready: Arc::new(AtomicBool::new(false)),
            status: Arc::new(std::sync::RwLock::new(WarmupStatus::NotStarted)),
            requests_served: Arc::new(AtomicU64::new(0)),
            requests_failed: Arc::new(AtomicU64::new(0)),
            last_health_check: Arc::new(std::sync::RwLock::new(Instant::now())),
            loaded_at: Instant::now(),
        }
    }

    /// Whether the model has been marked ready to serve traffic.
    #[must_use]
    pub fn is_ready(&self) -> bool {
        self.ready.load(Ordering::Acquire)
    }

    /// Raises or lowers the readiness flag.
    pub fn set_ready(&self, ready: bool) {
        self.ready.store(ready, Ordering::Release);
    }

    /// Current warm-up status. Panics if the status lock is poisoned.
    #[must_use]
    pub fn status(&self) -> WarmupStatus {
        *self.status.read().unwrap()
    }

    /// Updates the warm-up status. Transitioning to `Ready` also raises the
    /// readiness flag; other statuses never lower it here.
    pub fn set_status(&self, status: WarmupStatus) {
        {
            // Scope the write guard so it is released before set_ready runs.
            let mut current = self.status.write().unwrap();
            *current = status;
        }
        if status == WarmupStatus::Ready {
            self.set_ready(true);
        }
    }

    /// Increments the successful-request counter.
    pub fn record_success(&self) {
        self.requests_served.fetch_add(1, Ordering::Relaxed);
    }

    /// Increments the failed-request counter.
    pub fn record_failure(&self) {
        self.requests_failed.fetch_add(1, Ordering::Relaxed);
    }

    /// Number of successful requests.
    ///
    /// NOTE: despite the name, this reads only the `requests_served`
    /// counter — failures are tracked separately and not included.
    #[must_use]
    pub fn total_requests(&self) -> u64 {
        self.requests_served.load(Ordering::Relaxed)
    }

    /// Number of failed requests.
    #[must_use]
    pub fn failed_requests(&self) -> u64 {
        self.requests_failed.load(Ordering::Relaxed)
    }

    /// Ratio of failed to successful requests (0.0 when no successes have
    /// been recorded); may exceed 1.0 if failures outnumber successes.
    #[must_use]
    pub fn error_rate(&self) -> f64 {
        match self.total_requests() {
            0 => 0.0,
            served => self.failed_requests() as f64 / served as f64,
        }
    }

    /// Records that a health check happened now.
    pub fn touch(&self) {
        let mut checked = self.last_health_check.write().unwrap();
        *checked = Instant::now();
    }

    /// Time elapsed since this tracker was constructed.
    #[must_use]
    pub fn uptime(&self) -> Duration {
        self.loaded_at.elapsed()
    }

    /// Time elapsed since the most recent `touch` (or construction).
    #[must_use]
    pub fn time_since_last_check(&self) -> Duration {
        self.last_health_check.read().unwrap().elapsed()
    }

    /// Snapshots every health field into a serializable report.
    #[must_use]
    pub fn report(&self) -> HealthReport {
        HealthReport {
            ready: self.is_ready(),
            status: self.status(),
            uptime_secs: self.uptime().as_secs_f64(),
            total_requests: self.total_requests(),
            failed_requests: self.failed_requests(),
            error_rate: self.error_rate(),
            time_since_last_check_secs: self.time_since_last_check().as_secs_f64(),
        }
    }
}
/// Point-in-time, serializable snapshot of a [`ModelHealth`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthReport {
    /// Whether the model is ready to serve.
    pub ready: bool,
    /// Warm-up status at snapshot time.
    pub status: WarmupStatus,
    /// Seconds since the model health tracker was created.
    pub uptime_secs: f64,
    /// Successful request count (failures are counted separately).
    pub total_requests: u64,
    /// Failed request count.
    pub failed_requests: u64,
    /// failed_requests / total_requests (0.0 when total_requests is zero).
    pub error_rate: f64,
    /// Seconds since the last recorded health check.
    pub time_since_last_check_secs: f64,
}
/// Runs (currently: simulates) warm-up passes according to a [`WarmupConfig`].
#[derive(Debug, Clone)]
pub struct WarmupExecutor {
    /// Warm-up parameters used by this executor.
    config: WarmupConfig,
}
impl WarmupExecutor {
    /// Wraps the given configuration.
    #[must_use]
    pub fn new(config: WarmupConfig) -> Self {
        Self { config }
    }

    /// Read access to this executor's configuration.
    #[must_use]
    pub fn config(&self) -> &WarmupConfig {
        &self.config
    }

    /// Produces a synthetic warm-up result without touching a real model:
    /// the first iteration gets a large (1000 us) latency and later ones a
    /// small (100 us) latency minus a jitter capped at 50 us, mimicking a
    /// cold start followed by warmed-up calls.
    #[must_use]
    pub fn simulate_warmup(&self) -> WarmupResult {
        let started = Instant::now();
        let iterations = self.config.warmup_iterations;
        let latencies: Vec<Duration> = (0..iterations)
            .map(|i| {
                let base_us: u64 = if i == 0 { 1000 } else { 100 };
                let jitter = (i * 10) as u64;
                Duration::from_micros(base_us - jitter.min(50))
            })
            .collect();
        WarmupResult::success(iterations, started.elapsed(), &latencies)
    }

    /// Returns a timed-out result when `start` is older than the configured
    /// timeout, `None` otherwise.
    #[allow(dead_code)]
    fn check_timeout(&self, start: Instant, iterations: usize) -> Option<WarmupResult> {
        let elapsed = start.elapsed();
        if elapsed > self.config.timeout {
            Some(WarmupResult::timed_out(iterations, elapsed))
        } else {
            None
        }
    }
}
impl Default for WarmupExecutor {
    /// Executor backed by the default [`WarmupConfig`].
    fn default() -> Self {
        Self {
            config: WarmupConfig::default(),
        }
    }
}
/// Configuration for preloading a set of models at startup.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PreloadConfig {
    /// Models to preload.
    pub models: Vec<PreloadModelConfig>,
    /// Load models concurrently rather than one at a time.
    pub parallel_loading: bool,
    /// Upper bound on concurrent loads (the builder clamps this to >= 1).
    pub max_concurrent: usize,
    /// Whether the first load failure should abort the whole preload —
    /// NOTE(review): enforced by the loader, not by code in this file.
    pub fail_fast: bool,
}
impl Default for PreloadConfig {
fn default() -> Self {
Self {
models: Vec::new(),
parallel_loading: true,
max_concurrent: 4,
fail_fast: false,
}
}
}
impl PreloadConfig {
    /// Equivalent to [`PreloadConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends a model to the preload list.
    #[must_use]
    pub fn with_model(self, model: PreloadModelConfig) -> Self {
        let mut models = self.models;
        models.push(model);
        Self { models, ..self }
    }

    /// Toggles parallel versus sequential loading.
    #[must_use]
    pub fn with_parallel_loading(self, parallel: bool) -> Self {
        Self {
            parallel_loading: parallel,
            ..self
        }
    }

    /// Sets the concurrency cap, clamped to at least 1.
    #[must_use]
    pub fn with_max_concurrent(self, max: usize) -> Self {
        Self {
            max_concurrent: max.max(1),
            ..self
        }
    }

    /// Sets whether the first failure aborts the whole preload.
    #[must_use]
    pub fn with_fail_fast(self, fail_fast: bool) -> Self {
        Self { fail_fast, ..self }
    }
}
/// One model entry in a [`PreloadConfig`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PreloadModelConfig {
    /// Identifier the model is registered under.
    pub model_id: String,
    /// Source URI of the model (tests use forms like "hf://gpt2",
    /// "file://model.gguf").
    pub uri: String,
    /// Relative load priority (default 100). NOTE(review): ordering
    /// semantics (lower vs. higher first) are decided by the loader,
    /// not visible here.
    pub priority: u32,
    /// Whether to warm the model up after loading (default true).
    pub warmup: bool,
    /// Optional per-model warm-up configuration; presumably defaults are
    /// used when `None` — confirm in the loader.
    pub warmup_config: Option<WarmupConfig>,
}
impl PreloadModelConfig {
    /// Creates an entry for `model_id` at `uri` with priority 100 and
    /// warm-up enabled.
    #[must_use]
    pub fn new(model_id: impl Into<String>, uri: impl Into<String>) -> Self {
        Self {
            model_id: model_id.into(),
            uri: uri.into(),
            priority: 100,
            warmup: true,
            warmup_config: None,
        }
    }

    /// Sets the load priority.
    #[must_use]
    pub fn with_priority(self, priority: u32) -> Self {
        Self { priority, ..self }
    }

    /// Enables or disables warm-up for this model.
    #[must_use]
    pub fn with_warmup(self, warmup: bool) -> Self {
        Self { warmup, ..self }
    }

    /// Attaches a warm-up configuration; also forces `warmup` on.
    #[must_use]
    pub fn with_warmup_config(self, config: WarmupConfig) -> Self {
        Self {
            warmup_config: Some(config),
            warmup: true,
            ..self
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- WarmupConfig ----

    #[test]
    fn test_warmup_config_default() {
        let config = WarmupConfig::default();
        assert_eq!(config.warmup_iterations, 3);
        assert_eq!(config.timeout, Duration::from_secs(60));
        assert_eq!(config.sample_max_tokens, 10);
        assert!(config.validate_output);
    }

    #[test]
    fn test_warmup_config_builder() {
        let config = WarmupConfig::new()
            .with_warmup_iterations(5)
            .with_timeout(Duration::from_secs(120))
            .with_sample_prompt("Test prompt")
            .with_sample_max_tokens(20)
            .with_validate_output(false)
            .with_verbose(true);
        assert_eq!(config.warmup_iterations, 5);
        assert_eq!(config.timeout, Duration::from_secs(120));
        assert_eq!(config.sample_prompt, "Test prompt");
        assert_eq!(config.sample_max_tokens, 20);
        assert!(!config.validate_output);
        assert!(config.verbose);
    }

    // Iteration count of 0 must be clamped up to 1.
    #[test]
    fn test_warmup_config_min_iterations() {
        let config = WarmupConfig::new().with_warmup_iterations(0);
        assert_eq!(config.warmup_iterations, 1);
    }

    // ---- WarmupStatus predicates ----

    #[test]
    fn test_warmup_status_is_ready() {
        assert!(!WarmupStatus::NotStarted.is_ready());
        assert!(!WarmupStatus::InProgress.is_ready());
        assert!(WarmupStatus::Ready.is_ready());
        assert!(!WarmupStatus::Failed.is_ready());
        assert!(!WarmupStatus::TimedOut.is_ready());
    }

    #[test]
    fn test_warmup_status_is_in_progress() {
        assert!(!WarmupStatus::NotStarted.is_in_progress());
        assert!(WarmupStatus::InProgress.is_in_progress());
        assert!(!WarmupStatus::Ready.is_in_progress());
    }

    // has_failed covers both terminal failure states.
    #[test]
    fn test_warmup_status_has_failed() {
        assert!(!WarmupStatus::NotStarted.has_failed());
        assert!(!WarmupStatus::InProgress.has_failed());
        assert!(!WarmupStatus::Ready.has_failed());
        assert!(WarmupStatus::Failed.has_failed());
        assert!(WarmupStatus::TimedOut.has_failed());
    }

    // ---- WarmupResult constructors ----

    #[test]
    fn test_warmup_result_success() {
        let latencies = vec![
            Duration::from_millis(100),
            Duration::from_millis(50),
            Duration::from_millis(25),
        ];
        let result = WarmupResult::success(3, Duration::from_millis(200), &latencies);
        assert_eq!(result.status, WarmupStatus::Ready);
        assert_eq!(result.iterations_completed, 3);
        assert_eq!(result.first_latency, Duration::from_millis(100));
        assert_eq!(result.last_latency, Duration::from_millis(25));
        // Decreasing latencies must yield a speedup factor above 1.
        assert!(result.speedup_factor > 1.0);
        assert!(result.error.is_none());
    }

    #[test]
    fn test_warmup_result_failed() {
        let result = WarmupResult::failed("Test error", 2, Duration::from_secs(5));
        assert_eq!(result.status, WarmupStatus::Failed);
        assert_eq!(result.iterations_completed, 2);
        assert_eq!(result.error, Some("Test error".to_string()));
    }

    #[test]
    fn test_warmup_result_timed_out() {
        let result = WarmupResult::timed_out(1, Duration::from_secs(60));
        assert_eq!(result.status, WarmupStatus::TimedOut);
        assert!(result.error.is_some());
        assert!(result.error.unwrap().contains("timed out"));
    }

    // Empty latency slice: zero latencies and a neutral speedup of 1.0.
    #[test]
    fn test_warmup_result_empty_latencies() {
        let result = WarmupResult::success(0, Duration::ZERO, &[]);
        assert_eq!(result.first_latency, Duration::ZERO);
        assert_eq!(result.last_latency, Duration::ZERO);
        assert!((result.speedup_factor - 1.0).abs() < f64::EPSILON);
    }

    // ---- ModelHealth ----

    #[test]
    fn test_model_health_new() {
        let health = ModelHealth::new();
        assert!(!health.is_ready());
        assert_eq!(health.status(), WarmupStatus::NotStarted);
        assert_eq!(health.total_requests(), 0);
    }

    #[test]
    fn test_model_health_set_ready() {
        let health = ModelHealth::new();
        health.set_ready(true);
        assert!(health.is_ready());
        health.set_ready(false);
        assert!(!health.is_ready());
    }

    // Only the Ready status implicitly raises the readiness flag.
    #[test]
    fn test_model_health_set_status() {
        let health = ModelHealth::new();
        health.set_status(WarmupStatus::InProgress);
        assert_eq!(health.status(), WarmupStatus::InProgress);
        assert!(!health.is_ready());
        health.set_status(WarmupStatus::Ready);
        assert_eq!(health.status(), WarmupStatus::Ready);
        assert!(health.is_ready());
    }

    // Note: total_requests() counts successes only, not failures.
    #[test]
    fn test_model_health_record_requests() {
        let health = ModelHealth::new();
        health.record_success();
        health.record_success();
        health.record_failure();
        assert_eq!(health.total_requests(), 2);
        assert_eq!(health.failed_requests(), 1);
    }

    // error_rate = failed / served (here 1 / 2 = 0.5).
    #[test]
    fn test_model_health_error_rate() {
        let health = ModelHealth::new();
        assert!((health.error_rate() - 0.0).abs() < f64::EPSILON);
        health.record_success();
        health.record_success();
        health.record_failure();
        assert!((health.error_rate() - 0.5).abs() < f64::EPSILON);
    }

    // touch() resets the time-since-last-check clock.
    #[test]
    fn test_model_health_touch() {
        let health = ModelHealth::new();
        std::thread::sleep(Duration::from_millis(10));
        let before = health.time_since_last_check();
        health.touch();
        let after = health.time_since_last_check();
        assert!(after < before);
    }

    #[test]
    fn test_model_health_report() {
        let health = ModelHealth::new();
        health.set_status(WarmupStatus::Ready);
        health.record_success();
        let report = health.report();
        assert!(report.ready);
        assert_eq!(report.status, WarmupStatus::Ready);
        assert_eq!(report.total_requests, 1);
        assert_eq!(report.failed_requests, 0);
        assert!(report.uptime_secs >= 0.0);
    }

    // ---- WarmupExecutor ----

    #[test]
    fn test_warmup_executor_new() {
        let config = WarmupConfig::new().with_warmup_iterations(5);
        let executor = WarmupExecutor::new(config.clone());
        assert_eq!(executor.config().warmup_iterations, 5);
    }

    // Simulated warm-up: first latency is largest, so speedup > 1.
    #[test]
    fn test_warmup_executor_simulate() {
        let config = WarmupConfig::new().with_warmup_iterations(3);
        let executor = WarmupExecutor::new(config);
        let result = executor.simulate_warmup();
        assert_eq!(result.status, WarmupStatus::Ready);
        assert_eq!(result.iterations_completed, 3);
        assert!(result.first_latency > Duration::ZERO);
        assert!(result.last_latency > Duration::ZERO);
        assert!(result.first_latency > result.last_latency);
        assert!(result.speedup_factor > 1.0);
    }

    #[test]
    fn test_warmup_executor_check_timeout() {
        let config = WarmupConfig::new().with_timeout(Duration::from_millis(1));
        let executor = WarmupExecutor::new(config);
        let start = Instant::now();
        std::thread::sleep(Duration::from_millis(10));
        let result = executor.check_timeout(start, 0);
        assert!(result.is_some());
        assert_eq!(result.unwrap().status, WarmupStatus::TimedOut);
    }

    #[test]
    fn test_warmup_executor_default() {
        let executor = WarmupExecutor::default();
        assert_eq!(executor.config().warmup_iterations, 3);
    }

    // ---- PreloadConfig / PreloadModelConfig ----

    #[test]
    fn test_preload_config_default() {
        let config = PreloadConfig::default();
        assert!(config.models.is_empty());
        assert!(config.parallel_loading);
        assert_eq!(config.max_concurrent, 4);
        assert!(!config.fail_fast);
    }

    #[test]
    fn test_preload_config_builder() {
        let model = PreloadModelConfig::new("llama", "pacha://llama:7b");
        let config = PreloadConfig::new()
            .with_model(model)
            .with_parallel_loading(false)
            .with_max_concurrent(2)
            .with_fail_fast(true);
        assert_eq!(config.models.len(), 1);
        assert!(!config.parallel_loading);
        assert_eq!(config.max_concurrent, 2);
        assert!(config.fail_fast);
    }

    // Concurrency of 0 must be clamped up to 1.
    #[test]
    fn test_preload_config_min_concurrent() {
        let config = PreloadConfig::new().with_max_concurrent(0);
        assert_eq!(config.max_concurrent, 1);
    }

    #[test]
    fn test_preload_model_config_new() {
        let model = PreloadModelConfig::new("gpt2", "hf://gpt2");
        assert_eq!(model.model_id, "gpt2");
        assert_eq!(model.uri, "hf://gpt2");
        assert_eq!(model.priority, 100);
        assert!(model.warmup);
        assert!(model.warmup_config.is_none());
    }

    #[test]
    fn test_preload_model_config_builder() {
        let warmup_config = WarmupConfig::new().with_warmup_iterations(5);
        let model = PreloadModelConfig::new("llama", "file://model.gguf")
            .with_priority(10)
            .with_warmup(true)
            .with_warmup_config(warmup_config);
        assert_eq!(model.priority, 10);
        assert!(model.warmup);
        assert!(model.warmup_config.is_some());
        assert_eq!(model.warmup_config.unwrap().warmup_iterations, 5);
    }

    // with_warmup_config must override an earlier with_warmup(false).
    #[test]
    fn test_preload_model_config_with_warmup_enables_warmup() {
        let model = PreloadModelConfig::new("test", "file://test.gguf")
            .with_warmup(false)
            .with_warmup_config(WarmupConfig::new());
        assert!(model.warmup);
    }

    // ---- serde round-trips ----

    #[test]
    fn test_warmup_config_serialization() {
        let config = WarmupConfig::new()
            .with_warmup_iterations(5)
            .with_sample_prompt("Hello");
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("5"));
        assert!(json.contains("Hello"));
        let deserialized: WarmupConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.warmup_iterations, 5);
    }

    #[test]
    fn test_warmup_status_serialization() {
        let status = WarmupStatus::Ready;
        let json = serde_json::to_string(&status).unwrap();
        let deserialized: WarmupStatus = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, WarmupStatus::Ready);
    }

    #[test]
    fn test_warmup_result_serialization() {
        let result =
            WarmupResult::success(3, Duration::from_millis(100), &[Duration::from_millis(50)]);
        let json = serde_json::to_string(&result).unwrap();
        assert!(json.contains("Ready"));
        assert!(json.contains('3'));
    }

    #[test]
    fn test_health_report_serialization() {
        let report = HealthReport {
            ready: true,
            status: WarmupStatus::Ready,
            uptime_secs: 100.0,
            total_requests: 1000,
            failed_requests: 5,
            error_rate: 0.005,
            time_since_last_check_secs: 1.5,
        };
        let json = serde_json::to_string(&report).unwrap();
        assert!(json.contains("true"));
        assert!(json.contains("1000"));
        assert!(json.contains("0.005"));
    }
}