use super::prometheus::ParsedMetrics;
use std::collections::HashMap;
/// Name of the counter that `WarmupTracker::observe` reads as the running
/// request total driving the warmup window (a missing counter reads as 0).
///
/// NOTE(review): presumably the `_count` series of vLLM's time-to-first-token
/// histogram, i.e. one increment per request that produced a first token —
/// confirm against the exporter.
const TRIGGER_COUNT_KEY: &str = "vllm_time_to_first_token_seconds_count";
/// Snapshot of counter values and histogram buckets captured when warmup
/// ends; every later scrape is reported as a delta against it.
#[derive(Debug, Default, Clone)]
pub struct HistogramBaseline {
    /// Counter values keyed by metric name.
    pub counters: HashMap<String, f64>,
    /// Histogram buckets keyed by metric name, stored as `(le, count)` pairs
    /// that are compared position by position against later scrapes.
    pub histograms: HashMap<String, Vec<(f64, f64)>>,
}
/// Phase of a `WarmupTracker`.
#[derive(Debug, Clone)]
enum WarmupState {
    /// Still skipping requests. `initial_total` is `None` until the first
    /// observation; afterwards it is the trigger-counter total at which the
    /// current warmup window started.
    Warming { initial_total: Option<u64> },
    /// Warmup finished; every later scrape is reported relative to `baseline`.
    Active { baseline: HistogramBaseline },
}
/// Warmup state machine: passes metrics through untouched until a configured
/// number of requests has been observed, then reports every later scrape
/// relative to a baseline snapshot.
///
/// Derives `Debug` and `Clone` (both field types already implement them) so
/// the public tracker can be logged and inspected.
#[derive(Clone, Debug)]
pub struct WarmupTracker {
    /// Number of trigger-counter increments to skip before capturing the
    /// baseline; `0` captures the baseline on the first observation.
    skip_requests: u64,
    /// Current phase and its associated data.
    state: WarmupState,
}
/// Result of one call to `WarmupTracker::observe`.
pub struct WarmupOutput {
    /// `true` while the tracker is inside a warmup window, including the
    /// scrape on which a counter regression reset the window.
    pub warming_up: bool,
    /// `true` only for the scrape on which the baseline was captured.
    pub just_transitioned: bool,
    /// Metrics to report: an unmodified copy of the input while warming,
    /// otherwise per-series deltas against the baseline.
    pub adjusted: ParsedMetrics,
}
impl WarmupTracker {
    /// Creates a tracker in the warming phase that waits for `skip_requests`
    /// increments of the trigger counter before capturing a baseline.
    ///
    /// With `skip_requests == 0` the very first observation becomes the
    /// baseline.
    pub fn new(skip_requests: u64) -> Self {
        let state = WarmupState::Warming {
            initial_total: None,
        };
        Self {
            skip_requests,
            state,
        }
    }

    /// Feeds one scrape of metrics through the warmup state machine.
    ///
    /// The request total is the `TRIGGER_COUNT_KEY` counter converted with
    /// `as u64`. A missing key reads as 0, and because the cast saturates,
    /// negative or NaN values also read as 0.
    ///
    /// While warming:
    /// * the first call records the current total as the window start;
    /// * a total below the recorded start moves the start down to it and
    ///   reports warming for this call, with no threshold check;
    /// * once `total - start >= skip_requests`, this scrape becomes the
    ///   baseline and the tracker turns active. The output then has
    ///   `just_transitioned == true` and `adjusted` computed against that
    ///   fresh baseline.
    ///
    /// While active, if any baselined counter is now lower than its baseline
    /// value, the tracker returns to warming with the current total as the new
    /// start. Otherwise `adjusted` holds deltas against the baseline.
    ///
    /// Whenever `warming_up` is `true`, `adjusted` is an unmodified copy of
    /// `parsed`.
    pub fn observe(&mut self, parsed: &ParsedMetrics) -> WarmupOutput {
        let seen = parsed
            .counters
            .get(TRIGGER_COUNT_KEY)
            .map_or(0, |count| *count as u64);

        match &self.state {
            WarmupState::Warming { initial_total } => {
                let initial_total = *initial_total;
                self.observe_warming(parsed, seen, initial_total)
            }
            WarmupState::Active { baseline } => {
                if !counter_regression(parsed, baseline) {
                    return WarmupOutput {
                        warming_up: false,
                        just_transitioned: false,
                        adjusted: subtract_baseline(parsed, baseline),
                    };
                }
                // A counter went backwards (presumably a server restart), so
                // the baseline is stale: open a new warmup window from here.
                self.state = WarmupState::Warming {
                    initial_total: Some(seen),
                };
                Self::warming_output(parsed)
            }
        }
    }

    /// Warming-phase step. It records or resets the window start, and
    /// captures the baseline once `skip_requests` increments have been seen.
    fn observe_warming(
        &mut self,
        parsed: &ParsedMetrics,
        seen: u64,
        initial_total: Option<u64>,
    ) -> WarmupOutput {
        let start = match initial_total {
            Some(prev) if seen < prev => {
                // The total went backwards: restart the window at the new
                // total and skip the threshold check for this scrape.
                self.state = WarmupState::Warming {
                    initial_total: Some(seen),
                };
                return Self::warming_output(parsed);
            }
            Some(prev) => prev,
            None => {
                self.state = WarmupState::Warming {
                    initial_total: Some(seen),
                };
                seen
            }
        };

        // `start <= seen` holds on every path that reaches here.
        if seen - start < self.skip_requests {
            return Self::warming_output(parsed);
        }

        let baseline = HistogramBaseline {
            counters: parsed.counters.clone(),
            histograms: parsed.histograms.clone(),
        };
        let adjusted = subtract_baseline(parsed, &baseline);
        self.state = WarmupState::Active { baseline };
        WarmupOutput {
            warming_up: false,
            just_transitioned: true,
            adjusted,
        }
    }

    /// Output for a scrape taken while warming. It sets the warming flags and
    /// passes the metrics through unadjusted.
    fn warming_output(parsed: &ParsedMetrics) -> WarmupOutput {
        WarmupOutput {
            warming_up: true,
            just_transitioned: false,
            adjusted: clone_passthrough(parsed),
        }
    }
}
/// Returns an independent copy of `parsed` with no baseline applied.
///
/// The copy is built field by field rather than via `.clone()`.
/// NOTE(review): presumably this is because `ParsedMetrics` does not derive
/// `Clone`; confirm before simplifying.
fn clone_passthrough(parsed: &ParsedMetrics) -> ParsedMetrics {
    let counters = parsed.counters.clone();
    let histograms = parsed.histograms.clone();
    let gauges = parsed.gauges.clone();
    ParsedMetrics {
        gauges,
        counters,
        histograms,
    }
}
/// Reports whether any counter recorded in `baseline` is now strictly lower
/// than its baseline value.
///
/// Counters absent from `parsed` are not treated as a regression, and
/// comparisons involving NaN evaluate to `false`.
fn counter_regression(parsed: &ParsedMetrics, baseline: &HistogramBaseline) -> bool {
    baseline.counters.iter().any(|(name, &before)| {
        matches!(parsed.counters.get(name), Some(&now) if now < before)
    })
}
/// Rebases `parsed` against `baseline`.
///
/// * Counters: `current - baseline`, clamped at zero so a decrease never
///   yields a negative delta. Counters absent from the baseline pass through
///   unchanged. Counters present only in the baseline are not emitted.
/// * Histograms: when the bucket bounds match the baseline's (`same_schema`),
///   each bucket becomes `current - baseline` clamped at zero, keeping the
///   current `le`. Histograms absent from the baseline pass through unchanged.
///   Histograms whose bounds changed are omitted, because a per-bucket delta
///   would be meaningless; nothing here updates the baseline, so they stay
///   omitted until a new baseline is captured.
/// * Gauges: copied unchanged.
fn subtract_baseline(parsed: &ParsedMetrics, baseline: &HistogramBaseline) -> ParsedMetrics {
    let counters: HashMap<String, f64> = parsed
        .counters
        .iter()
        .map(|(name, &now)| {
            let delta = baseline
                .counters
                .get(name)
                .map_or(now, |&before| (now - before).max(0.0));
            (name.clone(), delta)
        })
        .collect();

    let histograms: HashMap<String, Vec<(f64, f64)>> = parsed
        .histograms
        .iter()
        .filter_map(|(name, now_buckets)| match baseline.histograms.get(name) {
            None => Some((name.clone(), now_buckets.clone())),
            Some(before_buckets) if same_schema(now_buckets, before_buckets) => {
                let delta: Vec<(f64, f64)> = now_buckets
                    .iter()
                    .zip(before_buckets)
                    .map(|(&(le, now), &(_, before))| (le, (now - before).max(0.0)))
                    .collect();
                Some((name.clone(), delta))
            }
            // Bucket bounds drifted since the baseline: drop this histogram.
            Some(_) => None,
        })
        .collect();

    ParsedMetrics {
        gauges: parsed.gauges.clone(),
        counters,
        histograms,
    }
}
/// Returns `true` when `a` and `b` have the same number of buckets and
/// identical `le` bounds at every position; bucket counts are ignored.
///
/// Bounds are compared with `==`, so infinities of the same sign match and
/// opposite-signed infinities do not. A NaN bound never matches anything,
/// not even another NaN.
fn same_schema(a: &[(f64, f64)], b: &[(f64, f64)]) -> bool {
    a.len() == b.len()
        && a
            .iter()
            .map(|&(le, _)| le)
            .eq(b.iter().map(|&(le, _)| le))
}
// Unit tests for `WarmupTracker` and its helpers. They inspect the private
// `state` field directly, which is possible because this is a child module.
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a `ParsedMetrics` from literal slices, converting every key to
    // an owned `String`.
    fn metrics(
        counters: &[(&str, f64)],
        histograms: &[(&str, &[(f64, f64)])],
        gauges: &[(&str, f64)],
    ) -> ParsedMetrics {
        ParsedMetrics {
            counters: counters
                .iter()
                .map(|(k, v)| ((*k).to_string(), *v))
                .collect(),
            histograms: histograms
                .iter()
                .map(|(k, v)| ((*k).to_string(), v.to_vec()))
                .collect(),
            gauges: gauges.iter().map(|(k, v)| ((*k).to_string(), *v)).collect(),
        }
    }

    // First scrape: the total is recorded as the window start. With skip = 1
    // and a total of 0, the tracker must still be warming.
    #[test]
    fn first_poll_captures_initial_total_and_warms_up() {
        let mut t = WarmupTracker::new(1);
        let m = metrics(&[(TRIGGER_COUNT_KEY, 0.0)], &[], &[]);
        let out = t.observe(&m);
        assert!(out.warming_up);
        assert!(!out.just_transitioned);
        assert!(matches!(
            t.state,
            WarmupState::Warming {
                initial_total: Some(0)
            }
        ));
    }

    // Reaching the skip threshold makes that scrape the baseline, so every
    // adjusted value on the transition scrape is zero.
    #[test]
    fn transitions_to_active_when_skip_threshold_reached() {
        let mut t = WarmupTracker::new(1);
        let _ = t.observe(&metrics(&[(TRIGGER_COUNT_KEY, 0.0)], &[], &[]));
        let m2 = metrics(
            &[
                (TRIGGER_COUNT_KEY, 1.0),
                ("vllm_time_to_first_token_seconds_sum", 2.5),
                ("vllm_generation_tokens_total", 100.0),
            ],
            &[(
                "vllm_time_to_first_token_seconds",
                &[(0.5, 0.0), (1.0, 0.0), (5.0, 1.0), (f64::INFINITY, 1.0)],
            )],
            &[],
        );
        let out = t.observe(&m2);
        assert!(!out.warming_up);
        assert!(out.just_transitioned);
        let buckets = out
            .adjusted
            .histograms
            .get("vllm_time_to_first_token_seconds")
            .expect("histogram present");
        assert!(buckets.iter().all(|(_, c)| *c == 0.0));
        assert_eq!(
            out.adjusted
                .counters
                .get("vllm_time_to_first_token_seconds_sum"),
            Some(&0.0)
        );
    }

    // After the baseline, counters and histogram buckets are reported as
    // (current - baseline), while gauges pass through unchanged.
    #[test]
    fn post_baseline_yields_correct_deltas() {
        let mut t = WarmupTracker::new(1);
        let _ = t.observe(&metrics(&[(TRIGGER_COUNT_KEY, 0.0)], &[], &[]));
        // This scrape reaches the threshold and becomes the baseline.
        let _ = t.observe(&metrics(
            &[
                (TRIGGER_COUNT_KEY, 1.0),
                ("vllm_time_to_first_token_seconds_sum", 4.0),
                ("vllm_generation_tokens_total", 50.0),
            ],
            &[(
                "vllm_time_to_first_token_seconds",
                &[(0.05, 0.0), (1.0, 0.0), (5.0, 1.0), (f64::INFINITY, 1.0)],
            )],
            &[("vllm_num_requests_running", 1.0)],
        ));
        let m = metrics(
            &[
                (TRIGGER_COUNT_KEY, 101.0),
                ("vllm_time_to_first_token_seconds_sum", 5.0),
                ("vllm_generation_tokens_total", 10050.0),
            ],
            &[(
                "vllm_time_to_first_token_seconds",
                &[
                    (0.05, 100.0),
                    (1.0, 100.0),
                    (5.0, 101.0),
                    (f64::INFINITY, 101.0),
                ],
            )],
            &[("vllm_num_requests_running", 2.0)],
        );
        let out = t.observe(&m);
        assert!(!out.warming_up);
        assert!(!out.just_transitioned);
        let buckets = out
            .adjusted
            .histograms
            .get("vllm_time_to_first_token_seconds")
            .expect("histogram present");
        assert_eq!(buckets[0], (0.05, 100.0));
        assert_eq!(buckets[1], (1.0, 100.0));
        assert_eq!(buckets[2], (5.0, 100.0));
        assert!(buckets[3].0.is_infinite());
        assert_eq!(buckets[3].1, 100.0);
        assert_eq!(
            out.adjusted.counters.get("vllm_generation_tokens_total"),
            Some(&10000.0)
        );
        let sum = out
            .adjusted
            .counters
            .get("vllm_time_to_first_token_seconds_sum")
            .copied()
            .expect("sum");
        // Float subtraction: compare with a tolerance rather than exactly.
        assert!((sum - 1.0).abs() < 1e-9, "sum delta {sum}");
        // Gauges are not rebased; the latest value is reported as-is.
        assert_eq!(
            out.adjusted.gauges.get("vllm_num_requests_running"),
            Some(&2.0)
        );
    }

    // A trigger total that drops below the recorded start while warming moves
    // the start down to the new total instead of transitioning.
    #[test]
    fn warming_regression_resets_initial_cursor() {
        let mut t = WarmupTracker::new(3);
        let _ = t.observe(&metrics(&[(TRIGGER_COUNT_KEY, 10.0)], &[], &[]));
        let _ = t.observe(&metrics(&[(TRIGGER_COUNT_KEY, 0.0)], &[], &[]));
        match &t.state {
            WarmupState::Warming { initial_total } => assert_eq!(*initial_total, Some(0)),
            _ => panic!("should still be warming after regression"),
        }
    }

    // A baselined counter decreasing while active sends the tracker back to
    // warming, even with skip = 0, with the current total as the new start.
    #[test]
    fn active_regression_resets_to_warming() {
        let mut t = WarmupTracker::new(0);
        let _ = t.observe(&metrics(
            &[
                (TRIGGER_COUNT_KEY, 5.0),
                ("vllm_generation_tokens_total", 500.0),
            ],
            &[],
            &[],
        ));
        assert!(matches!(t.state, WarmupState::Active { .. }));
        let out = t.observe(&metrics(
            &[
                (TRIGGER_COUNT_KEY, 1.0),
                ("vllm_generation_tokens_total", 10.0),
            ],
            &[],
            &[],
        ));
        assert!(out.warming_up);
        match &t.state {
            WarmupState::Warming { initial_total } => assert_eq!(*initial_total, Some(1)),
            _ => panic!("should be warming again"),
        }
    }

    // If one histogram's bucket bounds change after the baseline, only that
    // histogram is omitted; others are still rebased.
    #[test]
    fn schema_drift_drops_only_affected_histogram() {
        let mut t = WarmupTracker::new(1);
        let _ = t.observe(&metrics(&[(TRIGGER_COUNT_KEY, 0.0)], &[], &[]));
        let _ = t.observe(&metrics(
            &[(TRIGGER_COUNT_KEY, 1.0)],
            &[
                ("histA", &[(0.1, 0.0), (1.0, 1.0), (f64::INFINITY, 1.0)]),
                ("histB", &[(0.5, 1.0), (f64::INFINITY, 1.0)]),
            ],
            &[],
        ));
        // histA gains a 0.5 bucket (schema drift); histB keeps its bounds.
        let out = t.observe(&metrics(
            &[(TRIGGER_COUNT_KEY, 2.0)],
            &[
                (
                    "histA",
                    &[(0.1, 0.0), (0.5, 1.0), (1.0, 2.0), (f64::INFINITY, 2.0)],
                ),
                ("histB", &[(0.5, 2.0), (f64::INFINITY, 2.0)]),
            ],
            &[],
        ));
        assert!(!out.adjusted.histograms.contains_key("histA"));
        let b = out.adjusted.histograms.get("histB").expect("histB kept");
        assert_eq!(b[0], (0.5, 1.0));
    }

    // With skip = 0 the first scrape both records the start and reaches the
    // threshold, so the tracker becomes active immediately.
    #[test]
    fn skip_zero_baselines_on_first_poll() {
        let mut t = WarmupTracker::new(0);
        let m = metrics(
            &[
                (TRIGGER_COUNT_KEY, 5.0),
                ("vllm_generation_tokens_total", 500.0),
            ],
            &[],
            &[],
        );
        let out = t.observe(&m);
        assert!(!out.warming_up);
        assert!(out.just_transitioned);
        assert!(matches!(t.state, WarmupState::Active { .. }));
    }
}