use super::histogram::{fraction_le, percentile};
use super::prometheus::parse_prometheus_text;
use super::warmup::WarmupTracker;
use super::{
EngineAdapter, EngineMetrics, EngineStatus, EngineType, LatencyPercentiles, ModelInfo,
E2E_SLO_MS, ITL_SLO_MS, TTFT_SLO_MS,
};
use async_trait::async_trait;
use serde::Deserialize;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
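/// Number of initial requests treated as warmup noise and excluded from
/// steady-state metrics. Overridable via the SPARK_WARMUP_SKIP_REQUESTS
/// environment variable; a missing or unparsable value falls back to this default.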
const DEFAULT_WARMUP_SKIP_REQUESTS: u64 = 1;
fn warmup_skip_from_env() -> u64 {
std::env::var("SPARK_WARMUP_SKIP_REQUESTS")
.ok()
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or(DEFAULT_WARMUP_SKIP_REQUESTS)
}
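/// Formats a raw parameter count as a human-readable size label,
/// e.g. 11_823_991_872 becomes "11.8B params".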
fn format_param_size(count: u64) -> String {
if count >= 1_000_000_000 {
format!("{:.1}B params", count as f64 / 1_000_000_000.0)
} else if count >= 1_000_000 {
format!("{:.1}M params", count as f64 / 1_000_000.0)
} else {
format!("{} params", count)
}
}
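/// Maps HuggingFace quantization method identifiers to their display casing;
/// unrecognized methods pass through unchanged.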
fn format_quant_method(method: &str) -> String {
match method {
"auto-round" => "AutoRound".into(),
"gptq" => "GPTQ".into(),
"awq" => "AWQ".into(),
"bitsandbytes" => "BitsAndBytes".into(),
"fp8" => "FP8".into(),
other => other.to_string(),
}
}
fn format_precision(bits: u64) -> String {
format!("{}-bit precision", bits)
}
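/// Picks a representative tensor dtype from the safetensors parameter map,
/// preferring float dtypes in a fixed order. The fallback returns an
/// arbitrary remaining key, since HashMap iteration order is unspecified.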
fn format_tensor_type(params: &std::collections::HashMap<String, u64>) -> Option<String> {
let float_keys: [&str; 5] = ["BF16", "F16", "F32", "F64", "FP8"];
for key in &float_keys {
if params.contains_key(*key) {
return Some(key.to_string());
}
}
params.keys().next().cloned()
}
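/// Adapter for a vLLM OpenAI-compatible server. Scrapes /metrics for
/// throughput and latency, enriches model metadata via the HuggingFace Hub,
/// and keeps warmup state plus per-scrape rate baselines behind async mutexes.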
pub struct VllmAdapter {
client: reqwest::Client,
endpoint: String,
served_model: Option<String>,
prev_gen_tokens: Mutex<Option<(f64, Instant)>>,
prev_prompt_tokens: Mutex<Option<(f64, Instant)>>,
avg_accum: Mutex<(f64, u64)>,
avg_prompt_accum: Mutex<(f64, u64)>,
warmup: Mutex<WarmupTracker>,
hf_model_cache: Mutex<Option<ModelInfo>>,
last_hf_error: Mutex<Option<Instant>>,
}
impl VllmAdapter {
pub fn new(client: reqwest::Client, endpoint: String, served_model: Option<String>) -> Self {
Self {
client,
endpoint,
served_model,
prev_gen_tokens: Mutex::new(None),
prev_prompt_tokens: Mutex::new(None),
avg_accum: Mutex::new((0.0, 0)),
avg_prompt_accum: Mutex::new((0.0, 0)),
warmup: Mutex::new(WarmupTracker::new(warmup_skip_from_env())),
hf_model_cache: Mutex::new(None),
last_hf_error: Mutex::new(None),
}
}
const HF_RETRY_COOLDOWN: Duration = Duration::from_secs(60);
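    /// Fetches model metadata from the HuggingFace Hub API. The first
    /// successful response is cached for the adapter's lifetime; after a
    /// failure, retries are suppressed for HF_RETRY_COOLDOWN.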
async fn fetch_hf_model_info(&self, model_id: &str) -> Option<ModelInfo> {
{
let cache = self.hf_model_cache.lock().await;
if let Some(cached) = cache.as_ref() {
return Some(cached.clone());
}
}
{
let last = self.last_hf_error.lock().await;
if let Some(when) = *last {
if when.elapsed() < Self::HF_RETRY_COOLDOWN {
return None;
}
}
}
if !model_id.contains('/') {
return None;
}
let url = format!("https://huggingface.co/api/models/{}", model_id);
let hf_client = reqwest::Client::builder()
.timeout(Duration::from_secs(10))
.build()
.ok()?;
let resp = match hf_client.get(&url).send().await {
Ok(r) if r.status().is_success() => r,
Ok(r) => {
tracing::warn!(
endpoint = %self.endpoint,
model_id = %model_id,
status = %r.status(),
"HF model info API returned non-success",
);
*self.last_hf_error.lock().await = Some(Instant::now());
return None;
}
Err(e) => {
tracing::warn!(
endpoint = %self.endpoint,
model_id = %model_id,
error = %e,
"HF model info API request failed",
);
*self.last_hf_error.lock().await = Some(Instant::now());
return None;
}
};
let hf: HfModelResponse = match resp.json().await {
Ok(v) => v,
Err(e) => {
tracing::warn!(
endpoint = %self.endpoint,
model_id = %model_id,
error = %e,
"HF model info response deserialization failed",
);
*self.last_hf_error.lock().await = Some(Instant::now());
return None;
}
};
let parameter_size = hf
.safetensors
.as_ref()
.and_then(|s| s.total)
.map(format_param_size);
let quant_config = hf
.config
.as_ref()
.and_then(|c| c.quantization_config.as_ref());
let quantization = quant_config
.and_then(|q| q.quant_method.as_deref())
.map(format_quant_method);
let precision = quant_config.and_then(|q| q.bits).map(format_precision);
let tensor_type = hf
.safetensors
.as_ref()
.and_then(|s| s.parameters.as_ref())
.and_then(format_tensor_type);
let model_type = hf.config.as_ref().and_then(|c| c.model_type.clone());
let result = ModelInfo {
name: model_id.to_string(),
parameter_size,
quantization,
precision,
tensor_type,
model_type,
pipeline_tag: hf.pipeline_tag,
};
*self.hf_model_cache.lock().await = Some(result.clone());
tracing::info!(
endpoint = %self.endpoint,
model_id = %model_id,
parameter_size = ?result.parameter_size,
quantization = ?result.quantization,
precision = ?result.precision,
tensor_type = ?result.tensor_type,
model_type = ?result.model_type,
pipeline_tag = ?result.pipeline_tag,
"Fetched model info from HuggingFace",
);
Some(result)
}
}
#[derive(Deserialize)]
struct OpenAIModelsResponse {
#[serde(default)]
data: Vec<OpenAIModel>,
}
#[derive(Deserialize)]
struct OpenAIModel {
id: String,
}
#[derive(Deserialize)]
struct HfModelResponse {
pipeline_tag: Option<String>,
safetensors: Option<HfSafetensors>,
#[serde(default)]
config: Option<HfConfig>,
}
#[derive(Deserialize)]
struct HfSafetensors {
total: Option<u64>,
#[serde(default)]
parameters: Option<std::collections::HashMap<String, u64>>,
}
#[derive(Deserialize)]
struct HfConfig {
#[serde(default)]
model_type: Option<String>,
quantization_config: Option<HfQuantizationConfig>,
}
#[derive(Deserialize)]
struct HfQuantizationConfig {
bits: Option<u64>,
quant_method: Option<String>,
}
#[async_trait]
impl EngineAdapter for VllmAdapter {
fn engine_type(&self) -> EngineType {
EngineType::Vllm
}
fn endpoint(&self) -> &str {
&self.endpoint
}
async fn health_check(&self) -> EngineStatus {
match self
.client
.get(format!("{}/health", self.endpoint))
.timeout(Duration::from_secs(2))
.send()
.await
{
Ok(r) if r.status().is_success() => EngineStatus::Running,
Ok(r) => EngineStatus::Error(format!("HTTP {}", r.status())),
Err(e) => EngineStatus::Error(e.to_string()),
}
}
async fn get_model_info(&self) -> Option<ModelInfo> {
let api_id: Option<String> = async {
let resp = self
.client
.get(format!("{}/v1/models", self.endpoint))
.timeout(Duration::from_secs(2))
.send()
.await
.ok()?;
let models: OpenAIModelsResponse = resp.json().await.ok()?;
models.data.first().map(|m| m.id.clone())
}
.await;
let name = match (&api_id, &self.served_model) {
(Some(id), _) if id.contains('/') => Some(id.clone()),
(_, Some(hint)) => Some(hint.clone()),
(Some(id), None) => Some(id.clone()),
(None, None) => None,
}?;
if let Some(enriched) = self.fetch_hf_model_info(&name).await {
return Some(enriched);
}
Some(ModelInfo {
name,
parameter_size: None,
quantization: None,
precision: None,
tensor_type: None,
model_type: None,
pipeline_tag: None,
})
}
async fn get_metrics(&self) -> Option<EngineMetrics> {
let body = self
.client
.get(format!("{}/metrics", self.endpoint))
.timeout(Duration::from_secs(2))
.send()
.await
.ok()?
.text()
.await
.ok()?;
let raw = parse_prometheus_text(&body)?;
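        // Warmup handling: the tracker subtracts a baseline captured when
        // warmup ends, so steady-state metrics exclude cold-start samples.
        // On the transition, per-scrape rate state and running averages are
        // reset below.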
let warmup_out = {
let mut tracker = self.warmup.lock().await;
tracker.observe(&raw)
};
if warmup_out.just_transitioned {
*self.prev_gen_tokens.lock().await = None;
*self.prev_prompt_tokens.lock().await = None;
*self.avg_accum.lock().await = (0.0, 0);
*self.avg_prompt_accum.lock().await = (0.0, 0);
tracing::info!(
endpoint = %self.endpoint,
"warmup complete — baseline captured, steady-state metrics now reported"
);
}
let parsed = &warmup_out.adjusted;
let warming_up = warmup_out.warming_up;
let active_requests = parsed
.gauges
.get("vllm_num_requests_running")
.map(|v| *v as u64);
let queued_requests = parsed
.gauges
.get("vllm_num_requests_waiting")
.map(|v| *v as u64);
let kv_cache_percent = parsed
.gauges
.get("vllm_kv_cache_usage_perc")
.or_else(|| parsed.gauges.get("vllm_gpu_cache_usage_perc"))
.map(|v| v * 100.0);
let ttft_count = parsed
.counters
.get("vllm_time_to_first_token_seconds_count");
let ttft_ms = {
let sum = parsed.counters.get("vllm_time_to_first_token_seconds_sum");
match (sum, ttft_count) {
(Some(&s), Some(&c)) if c > 0.0 => Some((s / c) * 1000.0),
_ => None,
}
};
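        // Lifetime request totals read the raw (un-adjusted) counters so the
        // warmup baseline subtraction does not hide early requests.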
let total_requests = raw
.counters
.get("vllm_time_to_first_token_seconds_count")
.map(|&c| c as u64);
let per_request_tps = {
let sum = parsed
.counters
.get("vllm_request_time_per_output_token_seconds_sum")
.or_else(|| {
parsed
.counters
.get("vllm_time_per_output_token_seconds_sum")
});
let count = parsed
.counters
.get("vllm_request_time_per_output_token_seconds_count")
.or_else(|| {
parsed
.counters
.get("vllm_time_per_output_token_seconds_count")
});
match (sum, count) {
(Some(&s), Some(&c)) if c > 0.0 && s > 0.0 => Some(c / s),
_ => None,
}
};
let current_gen = parsed.counters.get("vllm_generation_tokens_total").copied();
let now = Instant::now();
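        // Instantaneous throughput: delta of the cumulative token counter
        // over wall-clock time since the previous scrape.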
let tokens_per_sec = {
let mut prev_lock = self.prev_gen_tokens.lock().await;
let tps = match (current_gen, prev_lock.as_ref()) {
(Some(current), Some(&(prev_val, prev_time))) => {
let elapsed = now.duration_since(prev_time).as_secs_f64();
if elapsed > 0.0 {
Some((current - prev_val) / elapsed)
} else {
None
}
}
_ => None,
};
if let Some(val) = current_gen {
*prev_lock = Some((val, now));
}
tps
};
let current_prompt = parsed.counters.get("vllm_prompt_tokens_total").copied();
let prompt_tokens_per_sec = {
let mut prev_lock = self.prev_prompt_tokens.lock().await;
let tps = match (current_prompt, prev_lock.as_ref()) {
(Some(current), Some(&(prev_val, prev_time))) => {
let elapsed = now.duration_since(prev_time).as_secs_f64();
if elapsed > 0.0 {
Some((current - prev_val) / elapsed)
} else {
None
}
}
_ => None,
};
if let Some(val) = current_prompt {
*prev_lock = Some((val, now));
}
tps
};
let avg_tokens_per_sec = {
let mut accum = self.avg_accum.lock().await;
if let Some(tps) = tokens_per_sec {
if tps > 0.0 {
accum.0 += tps;
accum.1 += 1;
}
}
if accum.1 > 0 {
Some(accum.0 / accum.1 as f64)
} else {
None
}
};
let avg_prompt_tokens_per_sec = {
let mut accum = self.avg_prompt_accum.lock().await;
if let Some(tps) = prompt_tokens_per_sec {
if tps > 0.0 {
accum.0 += tps;
accum.1 += 1;
}
}
if accum.1 > 0 {
Some(accum.0 / accum.1 as f64)
} else {
None
}
};
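        // Approximates per-request prefill throughput: total prompt tokens
        // divided by summed TTFT, treating TTFT as prefill time.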
let per_request_prompt_tps = {
let prompt_total = parsed.counters.get("vllm_prompt_tokens_total");
let ttft_sum = parsed.counters.get("vllm_time_to_first_token_seconds_sum");
match (prompt_total, ttft_sum) {
(Some(&p), Some(&t)) if t > 0.0 => Some(p / t),
_ => None,
}
};
let e2e_latency_ms = {
let sum = parsed.counters.get("vllm_e2e_request_latency_seconds_sum");
let count = parsed
.counters
.get("vllm_e2e_request_latency_seconds_count");
match (sum, count) {
(Some(&s), Some(&c)) if c > 0.0 => Some((s / c) * 1000.0),
_ => None,
}
};
let swapped_requests = parsed
.gauges
.get("vllm_num_requests_swapped")
.map(|v| *v as u64);
let prefix_cache_hit_rate = {
let hits = parsed.counters.get("vllm_prefix_cache_hits_total");
let queries = parsed.counters.get("vllm_prefix_cache_queries_total");
match (hits, queries) {
(Some(&h), Some(&q)) if q > 0.0 => Some((h / q) * 100.0),
_ => None,
}
};
let queue_time_ms = {
let sum = parsed.counters.get("vllm_request_queue_time_seconds_sum");
let count = parsed.counters.get("vllm_request_queue_time_seconds_count");
match (sum, count) {
(Some(&s), Some(&c)) if c > 0.0 => Some((s / c) * 1000.0),
_ => None,
}
};
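        // Preemptions are a lifetime total, so this also reads the raw counters.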
let preemptions_total = raw
.counters
.get("vllm_num_preemptions_total")
.map(|v| *v as u64);
let avg_batch_size = {
let sum = parsed.counters.get("vllm_iteration_tokens_total_sum");
let count = parsed.counters.get("vllm_iteration_tokens_total_count");
match (sum, count) {
(Some(&s), Some(&c)) if c > 0.0 => Some(s / c),
_ => None,
}
};
let inter_token_latency_ms = {
let sum = parsed.counters.get("vllm_inter_token_latency_seconds_sum");
let count = parsed
.counters
.get("vllm_inter_token_latency_seconds_count");
match (sum, count) {
(Some(&s), Some(&c)) if c > 0.0 => Some((s / c) * 1000.0),
_ => None,
}
};
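        // Estimates latency percentiles from Prometheus histogram buckets,
        // converting seconds to milliseconds; None when no bucket yields an
        // estimate.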
let percentiles_ms = |metric: &str| -> Option<LatencyPercentiles> {
let buckets = parsed.histograms.get(metric)?;
let to_ms = |q: f64| percentile(buckets, q).map(|s| s * 1000.0);
let p = LatencyPercentiles {
p50_ms: to_ms(0.50),
p95_ms: to_ms(0.95),
p99_ms: to_ms(0.99),
};
if p.p50_ms.is_none() && p.p95_ms.is_none() && p.p99_ms.is_none() {
None
} else {
Some(p)
}
};
let ttft_percentiles = percentiles_ms("vllm_time_to_first_token_seconds");
let itl_percentiles = percentiles_ms("vllm_inter_token_latency_seconds");
let e2e_percentiles = percentiles_ms("vllm_e2e_request_latency_seconds");
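        // Goodput: the share of observations meeting the latency SLO, i.e.
        // the histogram mass at or below the threshold, as a percentage.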
let goodput_pct = |metric: &str, slo_ms: f64| -> Option<f64> {
let buckets = parsed.histograms.get(metric)?;
fraction_le(buckets, slo_ms / 1000.0).map(|f| f * 100.0)
};
let ttft_goodput_pct = goodput_pct("vllm_time_to_first_token_seconds", TTFT_SLO_MS);
let itl_goodput_pct = goodput_pct("vllm_inter_token_latency_seconds", ITL_SLO_MS);
let e2e_goodput_pct = goodput_pct("vllm_e2e_request_latency_seconds", E2E_SLO_MS);
let blank = warming_up;
Some(EngineMetrics {
tokens_per_sec: if blank { None } else { tokens_per_sec },
avg_tokens_per_sec: if blank { None } else { avg_tokens_per_sec },
per_request_tps: if blank { None } else { per_request_tps },
ttft_ms: if blank { None } else { ttft_ms },
active_requests,
queued_requests,
kv_cache_percent,
kv_cache_is_estimated: false,
total_requests,
e2e_latency_ms: if blank { None } else { e2e_latency_ms },
prompt_tokens_per_sec: if blank { None } else { prompt_tokens_per_sec },
avg_prompt_tokens_per_sec: if blank {
None
} else {
avg_prompt_tokens_per_sec
},
per_request_prompt_tps: if blank { None } else { per_request_prompt_tps },
swapped_requests,
prefix_cache_hit_rate,
queue_time_ms: if blank { None } else { queue_time_ms },
inter_token_latency_ms: if blank { None } else { inter_token_latency_ms },
preemptions_total,
avg_batch_size: if blank { None } else { avg_batch_size },
ttft_percentiles: if blank { None } else { ttft_percentiles },
itl_percentiles: if blank { None } else { itl_percentiles },
e2e_percentiles: if blank { None } else { e2e_percentiles },
ttft_goodput_pct: if blank { None } else { ttft_goodput_pct },
itl_goodput_pct: if blank { None } else { itl_goodput_pct },
e2e_goodput_pct: if blank { None } else { e2e_goodput_pct },
warming_up,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn ttft_percentiles_roundtrip_from_metrics_body() {
let body = "\
# HELP vllm:time_to_first_token_seconds TTFT histogram.
# TYPE vllm:time_to_first_token_seconds histogram
vllm:time_to_first_token_seconds_bucket{le=\"0.05\"} 50
vllm:time_to_first_token_seconds_bucket{le=\"0.1\"} 80
vllm:time_to_first_token_seconds_bucket{le=\"0.5\"} 95
vllm:time_to_first_token_seconds_bucket{le=\"1.0\"} 99
vllm:time_to_first_token_seconds_bucket{le=\"+Inf\"} 100
vllm:time_to_first_token_seconds_sum 12.0
vllm:time_to_first_token_seconds_count 100.0
";
let parsed = parse_prometheus_text(body).expect("parse");
let buckets = parsed
.histograms
.get("vllm_time_to_first_token_seconds")
.expect("histogram");
let p50 = percentile(buckets, 0.5).expect("p50") * 1000.0;
let p95 = percentile(buckets, 0.95).expect("p95") * 1000.0;
let p99 = percentile(buckets, 0.99).expect("p99") * 1000.0;
assert!(p50 < p95, "p50 {p50} < p95 {p95}");
assert!(p95 < p99, "p95 {p95} < p99 {p99}");
assert!((40.0..=60.0).contains(&p50), "p50 {p50} near 50ms");
assert!(p99 > 500.0 && p99 <= 1000.0, "p99 {p99} in (500, 1000]");
}
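    // Companion check for the goodput path in get_metrics. This assumes
    // fraction_le returns the cumulative fraction when the threshold falls
    // exactly on a bucket boundary (80 of 100 observations at le="0.1").
    #[test]
    fn goodput_fraction_at_bucket_boundary() {
        let body = "\
# HELP vllm:time_to_first_token_seconds TTFT histogram.
# TYPE vllm:time_to_first_token_seconds histogram
vllm:time_to_first_token_seconds_bucket{le=\"0.05\"} 50
vllm:time_to_first_token_seconds_bucket{le=\"0.1\"} 80
vllm:time_to_first_token_seconds_bucket{le=\"+Inf\"} 100
vllm:time_to_first_token_seconds_sum 12.0
vllm:time_to_first_token_seconds_count 100.0
";
        let parsed = parse_prometheus_text(body).expect("parse");
        let buckets = parsed
            .histograms
            .get("vllm_time_to_first_token_seconds")
            .expect("histogram");
        let pct = fraction_le(buckets, 0.1).expect("fraction") * 100.0;
        assert!((pct - 80.0).abs() < 1e-6, "goodput {pct} should be 80%");
    }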
#[test]
fn warmup_tracker_excludes_first_observation_from_percentiles() {
use super::super::warmup::WarmupTracker;
let body_warmup = "\
# HELP vllm:time_to_first_token_seconds TTFT histogram.
# TYPE vllm:time_to_first_token_seconds histogram
vllm:time_to_first_token_seconds_bucket{le=\"0.05\"} 0
vllm:time_to_first_token_seconds_bucket{le=\"0.1\"} 0
vllm:time_to_first_token_seconds_bucket{le=\"0.5\"} 0
vllm:time_to_first_token_seconds_bucket{le=\"1.0\"} 0
vllm:time_to_first_token_seconds_bucket{le=\"+Inf\"} 1
vllm:time_to_first_token_seconds_sum 8.0
vllm:time_to_first_token_seconds_count 1.0
";
let body_steady = "\
# HELP vllm:time_to_first_token_seconds TTFT histogram.
# TYPE vllm:time_to_first_token_seconds histogram
vllm:time_to_first_token_seconds_bucket{le=\"0.05\"} 100
vllm:time_to_first_token_seconds_bucket{le=\"0.1\"} 100
vllm:time_to_first_token_seconds_bucket{le=\"0.5\"} 100
vllm:time_to_first_token_seconds_bucket{le=\"1.0\"} 100
vllm:time_to_first_token_seconds_bucket{le=\"+Inf\"} 101
vllm:time_to_first_token_seconds_sum 9.0
vllm:time_to_first_token_seconds_count 101.0
";
let body_idle = "\
# HELP vllm:time_to_first_token_seconds TTFT histogram.
# TYPE vllm:time_to_first_token_seconds histogram
vllm:time_to_first_token_seconds_bucket{le=\"0.05\"} 0
vllm:time_to_first_token_seconds_bucket{le=\"0.1\"} 0
vllm:time_to_first_token_seconds_bucket{le=\"0.5\"} 0
vllm:time_to_first_token_seconds_bucket{le=\"1.0\"} 0
vllm:time_to_first_token_seconds_bucket{le=\"+Inf\"} 0
vllm:time_to_first_token_seconds_sum 0.0
vllm:time_to_first_token_seconds_count 0.0
";
let mut tracker = WarmupTracker::new(1);
let parsed_idle = parse_prometheus_text(body_idle).expect("parse idle");
let out_idle = tracker.observe(&parsed_idle);
assert!(out_idle.warming_up);
assert!(!out_idle.just_transitioned);
let parsed_warmup = parse_prometheus_text(body_warmup).expect("parse warmup");
let out_warmup = tracker.observe(&parsed_warmup);
assert!(!out_warmup.warming_up);
assert!(out_warmup.just_transitioned);
let parsed_steady = parse_prometheus_text(body_steady).expect("parse steady");
let out_steady = tracker.observe(&parsed_steady);
assert!(!out_steady.warming_up);
assert!(!out_steady.just_transitioned);
let buckets = out_steady
.adjusted
.histograms
.get("vllm_time_to_first_token_seconds")
.expect("histogram");
let p50 = percentile(buckets, 0.5).expect("p50") * 1000.0;
let p95 = percentile(buckets, 0.95).expect("p95") * 1000.0;
assert!(p50 <= 50.0, "p50 {p50} should be in fast bucket (<=50ms)");
assert!(p95 <= 50.0, "p95 {p95} should be in fast bucket (<=50ms)");
let sum = out_steady
.adjusted
.counters
.get("vllm_time_to_first_token_seconds_sum")
.copied()
.expect("sum delta");
let count = out_steady
.adjusted
.counters
.get("vllm_time_to_first_token_seconds_count")
.copied()
.expect("count delta");
assert!((sum - 1.0).abs() < 1e-9, "sum delta {sum}");
assert!((count - 100.0).abs() < 1e-9, "count delta {count}");
}
#[test]
fn format_param_size_formats_billions() {
assert_eq!(format_param_size(11_823_991_872), "11.8B params");
assert_eq!(format_param_size(7_000_000_000), "7.0B params");
}
#[test]
fn format_param_size_formats_millions() {
assert_eq!(format_param_size(500_000_000), "500.0M params");
assert_eq!(format_param_size(1_000_000), "1.0M params");
}
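    // Counts below one million fall through to the raw-count branch.
    #[test]
    fn format_param_size_formats_raw_counts() {
        assert_eq!(format_param_size(512), "512 params");
        assert_eq!(format_param_size(999_999), "999999 params");
    }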
#[test]
fn format_quant_method_formats_known_methods() {
assert_eq!(format_quant_method("auto-round"), "AutoRound");
assert_eq!(format_quant_method("gptq"), "GPTQ");
assert_eq!(format_quant_method("awq"), "AWQ");
assert_eq!(format_quant_method("bitsandbytes"), "BitsAndBytes");
assert_eq!(format_quant_method("fp8"), "FP8");
}
#[test]
fn format_quant_method_passes_through_unknown() {
assert_eq!(format_quant_method("some-new-method"), "some-new-method");
}
#[test]
fn format_precision_produces_label() {
assert_eq!(format_precision(4), "4-bit precision");
assert_eq!(format_precision(8), "8-bit precision");
}
#[test]
fn format_tensor_type_prefers_float_dtypes() {
let mut params = std::collections::HashMap::new();
params.insert("BF16".into(), 1000);
params.insert("I32".into(), 5000);
        assert_eq!(format_tensor_type(&params), Some("BF16".into()));
}
#[test]
fn format_tensor_type_falls_back_to_first_key() {
let mut params = std::collections::HashMap::new();
params.insert("I32".into(), 5000);
        assert_eq!(format_tensor_type(&params), Some("I32".into()));
}
#[test]
fn format_tensor_type_returns_none_for_empty() {
let params = std::collections::HashMap::new();
        assert_eq!(format_tensor_type(&params), None);
}
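    // When several float dtypes are present, the fixed preference order in
    // float_keys decides, independent of HashMap iteration order.
    #[test]
    fn format_tensor_type_prefers_earlier_float_dtype() {
        let mut params = std::collections::HashMap::new();
        params.insert("F32".into(), 1000);
        params.insert("F16".into(), 2000);
        assert_eq!(format_tensor_type(&params), Some("F16".into()));
    }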
}