use std::sync::Arc;
use std::sync::OnceLock;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use utoipa::ToSchema;
use crate::http::service::metrics::{
WORKER_LAST_INPUT_SEQUENCE_TOKENS_GAUGE, WORKER_LAST_INTER_TOKEN_LATENCY_GAUGE,
WORKER_LAST_TIME_TO_FIRST_TOKEN_GAUGE,
};
use crate::protocols::openai::nvext::WorkerIdInfo;
/// Metric label value identifying a dedicated prefill worker.
pub const WORKER_TYPE_PREFILL: &str = "prefill";
/// Metric label value identifying a dedicated decode worker.
pub const WORKER_TYPE_DECODE: &str = "decode";
/// Metric label used when a worker's data-parallel rank was never recorded.
const UNSET_DP_RANK_LABEL: &str = "none";
/// Lifecycle phase of a request in a (possibly disaggregated) deployment.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum RequestPhase {
/// Prompt processing on a dedicated prefill worker.
Prefill,
/// Token generation on a dedicated decode worker.
Decode,
/// A single worker handles both prefill and decode (the default).
#[default]
Aggregated,
}
impl std::fmt::Display for RequestPhase {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
RequestPhase::Prefill => write!(f, "prefill"),
RequestPhase::Decode => write!(f, "decode"),
RequestPhase::Aggregated => write!(f, "aggregated"),
}
}
}
/// Per-request timing and worker-attribution tracker.
///
/// Shared across concurrent tasks, so write-once timestamps/values use
/// `OnceLock` (first write wins), counters use atomics, and the few
/// re-writable fields sit behind a `parking_lot::Mutex`.
#[derive(Debug)]
pub struct RequestTracker {
/// Monotonic instant when the request arrived; basis for all durations.
request_received: Instant,
/// Wall-clock arrival time, ms since the Unix epoch (0 if clock pre-epoch).
request_received_epoch_ms: u64,
/// When prefill work started (set once).
prefill_start_time: OnceLock<Instant>,
/// When the first token was produced (set once).
first_token_time: OnceLock<Instant>,
/// First token produced by the decode worker in disaggregated mode (set once).
decode_first_token_time: OnceLock<Instant>,
/// Request completion time; re-writable (last write wins).
request_finish_time: Mutex<Option<Instant>>,
/// KV-cache overlap, in blocks (set once).
kv_overlap_blocks: OnceLock<u32>,
/// Input sequence length, in blocks (set once).
isl_blocks: OnceLock<usize>,
/// Input sequence length, in tokens (set once).
isl_tokens: OnceLock<usize>,
/// Number of input tokens served from cache (set once, optional).
cached_tokens: OnceLock<usize>,
/// Output sequence length, in tokens; updated as generation proceeds.
osl_tokens: AtomicU64,
/// Prefill worker attribution (set once each; conflicts are logged).
prefill_worker_id: OnceLock<u64>,
prefill_dp_rank: OnceLock<u32>,
/// Decode worker attribution (set once each; conflicts are logged).
decode_worker_id: OnceLock<u64>,
decode_dp_rank: OnceLock<u32>,
prefill_worker_type: OnceLock<&'static str>,
decode_worker_type: OnceLock<&'static str>,
/// Current request phase (see [`RequestPhase`]).
phase: Mutex<RequestPhase>,
/// Single-permit semaphore serializing phase transitions; the permit
/// returned by `set_phase` is held by the caller for the phase's duration.
phase_semaphore: Arc<Semaphore>,
/// Tokenization latency (set once).
tokenize_latency: OnceLock<Duration>,
/// Cumulative detokenization time in nanoseconds (saturating add).
detokenize_total_ns: AtomicU64,
/// Number of detokenization samples accumulated.
detokenize_count: AtomicU64,
/// Router queue depth observed when the request was routed (set once).
router_queue_depth: OnceLock<usize>,
/// When prefill finished, used to estimate KV-transfer latency (set once).
prefill_complete_time: OnceLock<Instant>,
}
impl RequestTracker {
    /// Creates a tracker anchored at "now": `request_received` is the
    /// monotonic arrival instant, `request_received_epoch_ms` the wall-clock
    /// arrival time. All other timestamps and counters start unset / zero.
    pub fn new() -> Self {
        let now = Instant::now();
        // Wall-clock ms since the Unix epoch; 0 if the system clock is
        // somehow before UNIX_EPOCH rather than panicking.
        let epoch_ms = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_millis() as u64)
            .unwrap_or(0);
        RequestTracker {
            request_received: now,
            request_received_epoch_ms: epoch_ms,
            prefill_start_time: OnceLock::new(),
            first_token_time: OnceLock::new(),
            decode_first_token_time: OnceLock::new(),
            request_finish_time: Mutex::new(None),
            kv_overlap_blocks: OnceLock::new(),
            isl_blocks: OnceLock::new(),
            isl_tokens: OnceLock::new(),
            cached_tokens: OnceLock::new(),
            osl_tokens: AtomicU64::new(0),
            prefill_worker_id: OnceLock::new(),
            prefill_dp_rank: OnceLock::new(),
            decode_worker_id: OnceLock::new(),
            decode_dp_rank: OnceLock::new(),
            prefill_worker_type: OnceLock::new(),
            decode_worker_type: OnceLock::new(),
            phase: Mutex::new(RequestPhase::Aggregated),
            phase_semaphore: Arc::new(Semaphore::new(1)),
            tokenize_latency: OnceLock::new(),
            detokenize_total_ns: AtomicU64::new(0),
            detokenize_count: AtomicU64::new(0),
            router_queue_depth: OnceLock::new(),
            prefill_complete_time: OnceLock::new(),
        }
    }
    /// Marks the start of prefill. Returns `true` only on the first call;
    /// subsequent calls leave the original timestamp intact.
    pub fn record_prefill_start(&self) -> bool {
        self.prefill_start_time.set(Instant::now()).is_ok()
    }
    /// Records the time of the first generated token (first write wins).
    pub fn record_first_token(&self) {
        let _ = self.first_token_time.set(Instant::now());
    }
    /// Records the first token produced by the decode worker (first write
    /// wins). Kept separate from `first_token_time` so a disaggregated
    /// decode phase is not blocked by the prefill worker's write.
    pub fn record_decode_first_token(&self) {
        let _ = self.decode_first_token_time.set(Instant::now());
    }
    /// Records request completion. Unlike the `OnceLock` timestamps this is
    /// re-writable: a later call overwrites the finish time.
    pub fn record_finish(&self) {
        *self.request_finish_time.lock() = Some(Instant::now());
    }
    /// Records KV-cache overlap statistics. Returns `true` only if both
    /// values were stored for the first time.
    pub fn record_kv_hit(&self, overlap_blocks: u32, isl_blocks: usize) -> bool {
        let overlap_set = self.kv_overlap_blocks.set(overlap_blocks).is_ok();
        let isl_set = self.isl_blocks.set(isl_blocks).is_ok();
        overlap_set && isl_set
    }
    /// Records the input sequence length in tokens and, when known, the
    /// number of tokens served from cache (first write wins for each).
    pub fn record_isl(&self, isl_tokens: usize, cached_tokens: Option<usize>) {
        let _ = self.isl_tokens.set(isl_tokens);
        if let Some(cached_tokens) = cached_tokens {
            let _ = self.cached_tokens.set(cached_tokens);
        }
    }
    /// Input sequence length in tokens, if recorded.
    pub fn isl_tokens(&self) -> Option<usize> {
        self.isl_tokens.get().copied()
    }
    /// Number of cached input tokens, if recorded.
    pub fn cached_tokens(&self) -> Option<usize> {
        self.cached_tokens.get().copied()
    }
    /// Records the current output sequence length (overwrites prior value).
    pub fn record_osl(&self, osl: usize) {
        self.osl_tokens.store(osl as u64, Ordering::Relaxed);
    }
    /// Output sequence length in tokens (0 until recorded).
    pub fn osl_tokens(&self) -> u64 {
        self.osl_tokens.load(Ordering::Relaxed)
    }
    /// Time from request arrival to prefill start, in ms.
    pub fn prefill_wait_time_ms(&self) -> Option<f64> {
        self.prefill_start_time
            .get()
            .map(|t| t.duration_since(self.request_received).as_secs_f64() * 1000.0)
    }
    /// Time from prefill start to the first token, in ms.
    pub fn prefill_time_ms(&self) -> Option<f64> {
        let prefill_start = self.prefill_start_time.get()?;
        let first_token = self.first_token_time.get()?;
        Some(first_token.duration_since(*prefill_start).as_secs_f64() * 1000.0)
    }
    /// Time to first token: request arrival to first token, in ms.
    pub fn ttft_ms(&self) -> Option<f64> {
        let first_token = self.first_token_time.get()?;
        Some(
            first_token
                .duration_since(self.request_received)
                .as_secs_f64()
                * 1000.0,
        )
    }
    /// Total request duration: arrival to finish, in ms.
    pub fn total_time_ms(&self) -> Option<f64> {
        let finish = (*self.request_finish_time.lock())?;
        Some(finish.duration_since(self.request_received).as_secs_f64() * 1000.0)
    }
    /// Average inter-token latency in ms: decode duration divided by the
    /// number of inter-token gaps (`osl - 1`). `None` with fewer than two
    /// output tokens, since no gap exists.
    pub fn avg_itl_ms(&self) -> Option<f64> {
        let first_token = *self.first_token_time.get()?;
        let finish = (*self.request_finish_time.lock())?;
        let osl = self.osl_tokens.load(Ordering::Relaxed);
        if osl < 2 {
            return None;
        }
        let decode_duration = finish.duration_since(first_token).as_secs_f64() * 1000.0;
        Some(decode_duration / (osl - 1) as f64)
    }
    /// Wall-clock arrival time, ms since the Unix epoch.
    pub fn request_received_epoch_ms(&self) -> u64 {
        self.request_received_epoch_ms
    }
    /// KV-cache hit rate as `overlap_blocks / isl_blocks`. `None` when
    /// either value is unset or the ISL is zero blocks (avoid div by zero).
    pub fn kv_hit_rate(&self) -> Option<f64> {
        let overlap = *self.kv_overlap_blocks.get()?;
        let isl = *self.isl_blocks.get()?;
        if isl == 0 {
            return None;
        }
        Some(overlap as f64 / isl as f64)
    }
    /// Switches the tracker to `phase`, first acquiring the single-permit
    /// phase semaphore so transitions are serialized. The returned permit is
    /// held by the caller; dropping it allows the next transition.
    pub async fn set_phase(&self, phase: RequestPhase) -> OwnedSemaphorePermit {
        let permit = self
            .phase_semaphore
            .clone()
            .acquire_owned()
            .await
            .expect("phase semaphore should never be closed");
        *self.phase.lock() = phase;
        permit
    }
    /// Current request phase.
    pub fn phase(&self) -> RequestPhase {
        *self.phase.lock()
    }
    /// Writes `value` into `slot` exactly once. A repeated write with the
    /// same value is silently ignored; a conflicting write with a different
    /// value is dropped and logged as an error. Replaces the former
    /// per-type `record_once_u64`/`_u32`/`_worker_type` duplicates.
    fn record_once<T>(slot: &OnceLock<T>, value: T, field_name: &'static str)
    where
        T: Copy + PartialEq + std::fmt::Display,
    {
        if let Some(existing) = slot.get() {
            if *existing != value {
                tracing::error!(
                    field = field_name,
                    existing = %existing,
                    new = %value,
                    "Conflicting request tracker write"
                );
            }
            return;
        }
        let _ = slot.set(value);
    }
    /// Records prefill-side worker attribution (id, optional DP rank, type).
    fn record_prefill_worker(
        &self,
        instance_id: u64,
        dp_rank: Option<u32>,
        worker_type: &'static str,
    ) {
        Self::record_once(&self.prefill_worker_id, instance_id, "prefill_worker_id");
        if let Some(rank) = dp_rank {
            Self::record_once(&self.prefill_dp_rank, rank, "prefill_dp_rank");
        }
        Self::record_once(
            &self.prefill_worker_type,
            worker_type,
            "prefill_worker_type",
        );
    }
    /// Records decode-side worker attribution (id, optional DP rank, type).
    fn record_decode_worker(
        &self,
        instance_id: u64,
        dp_rank: Option<u32>,
        worker_type: &'static str,
    ) {
        Self::record_once(&self.decode_worker_id, instance_id, "decode_worker_id");
        if let Some(rank) = dp_rank {
            Self::record_once(&self.decode_dp_rank, rank, "decode_dp_rank");
        }
        Self::record_once(&self.decode_worker_type, worker_type, "decode_worker_type");
    }
    /// Records worker attribution for the current phase. In `Aggregated`
    /// mode the same worker is recorded for both prefill and decode.
    pub fn record_worker(&self, instance_id: u64, dp_rank: Option<u32>, worker_type: &'static str) {
        match self.phase() {
            RequestPhase::Prefill => self.record_prefill_worker(instance_id, dp_rank, worker_type),
            RequestPhase::Decode => self.record_decode_worker(instance_id, dp_rank, worker_type),
            RequestPhase::Aggregated => {
                self.record_prefill_worker(instance_id, dp_rank, worker_type);
                self.record_decode_worker(instance_id, dp_rank, worker_type);
            }
        }
    }
    /// Records tokenization latency (first write wins).
    pub fn record_tokenize_latency(&self, l: Duration) {
        let _ = self.tokenize_latency.set(l);
    }
    /// Tokenization latency, if recorded.
    pub fn tokenize_latency(&self) -> Option<Duration> {
        self.tokenize_latency.get().copied()
    }
    /// Accumulates one detokenization sample. The running total saturates
    /// at `u64::MAX` nanoseconds rather than wrapping.
    pub fn record_detokenize_latency(&self, l: Duration) {
        let delta_ns = u64::try_from(l.as_nanos()).unwrap_or(u64::MAX);
        let _ = self.detokenize_total_ns.fetch_update(
            Ordering::Relaxed,
            Ordering::Relaxed,
            |current| Some(current.saturating_add(delta_ns)),
        );
        self.detokenize_count.fetch_add(1, Ordering::Relaxed);
    }
    /// Cumulative detokenization time; `None` if no samples were recorded.
    pub fn detokenize_total_latency(&self) -> Option<Duration> {
        let total_ns = self.detokenize_total_ns.load(Ordering::Relaxed);
        let count = self.detokenize_count.load(Ordering::Relaxed);
        if count == 0 {
            None
        } else {
            Some(Duration::from_nanos(total_ns))
        }
    }
    /// Number of detokenization samples recorded.
    pub fn detokenize_count(&self) -> u64 {
        self.detokenize_count.load(Ordering::Relaxed)
    }
    /// Records the router queue depth observed at routing time (first write
    /// wins; later depths are ignored).
    pub fn record_router_queue_depth(&self, depth: usize) {
        let _ = self.router_queue_depth.set(depth);
    }
    /// Router queue depth at routing time, if recorded.
    pub fn router_queue_depth(&self) -> Option<usize> {
        self.router_queue_depth.get().copied()
    }
    /// Marks prefill as complete. Returns `true` only on the first call.
    pub fn record_prefill_complete(&self) -> bool {
        self.prefill_complete_time.set(Instant::now()).is_ok()
    }
    /// Estimated KV-transfer latency: gap between prefill completion and
    /// the decode worker's first token, in seconds. Saturates to 0 if the
    /// decode token (somehow) precedes prefill completion; `None` until
    /// both timestamps exist.
    pub fn kv_transfer_estimated_latency_secs(&self) -> Option<f64> {
        let complete = *self.prefill_complete_time.get()?;
        let first_tok = *self.decode_first_token_time.get()?;
        Some(first_tok.saturating_duration_since(complete).as_secs_f64())
    }
    /// Worker attribution for responses; `None` when no worker (prefill or
    /// decode) has been recorded at all.
    pub fn get_worker_info(&self) -> Option<WorkerIdInfo> {
        let prefill = self.prefill_worker_id();
        let decode = self.decode_worker_id();
        if prefill.is_none() && decode.is_none() {
            return None;
        }
        Some(WorkerIdInfo {
            prefill_worker_id: prefill,
            prefill_dp_rank: self.prefill_dp_rank(),
            decode_worker_id: decode,
            decode_dp_rank: self.decode_dp_rank(),
        })
    }
    /// Decode worker instance id, if recorded.
    pub fn decode_worker_id(&self) -> Option<u64> {
        self.decode_worker_id.get().copied()
    }
    /// Decode worker data-parallel rank, if recorded.
    pub fn decode_dp_rank(&self) -> Option<u32> {
        self.decode_dp_rank.get().copied()
    }
    /// Prefill worker instance id, if recorded.
    pub fn prefill_worker_id(&self) -> Option<u64> {
        self.prefill_worker_id.get().copied()
    }
    /// Prefill worker data-parallel rank, if recorded.
    pub fn prefill_dp_rank(&self) -> Option<u32> {
        self.prefill_dp_rank.get().copied()
    }
    /// Prefill worker type label, if recorded.
    pub fn prefill_worker_type(&self) -> Option<&'static str> {
        self.prefill_worker_type.get().copied()
    }
    /// Decode worker type label, if recorded.
    pub fn decode_worker_type(&self) -> Option<&'static str> {
        self.decode_worker_type.get().copied()
    }
    /// Publishes TTFT and ISL gauges labelled by the prefill worker. No-op
    /// until a prefill worker id has been recorded.
    pub fn observe_first_token_gauges(&self) {
        let Some(worker_id) = self.prefill_worker_id() else {
            return;
        };
        let worker_id_str = worker_id.to_string();
        let dp_rank_str = self
            .prefill_dp_rank()
            .map_or(UNSET_DP_RANK_LABEL.to_string(), |r| r.to_string());
        let worker_type = self.prefill_worker_type().unwrap_or(WORKER_TYPE_PREFILL);
        let labels = &[worker_id_str.as_str(), dp_rank_str.as_str(), worker_type];
        if let Some(ttft) = self.ttft_ms() {
            // Gauge is in seconds; ttft_ms is milliseconds.
            WORKER_LAST_TIME_TO_FIRST_TOKEN_GAUGE
                .with_label_values(labels)
                .set(ttft / 1000.0);
        }
        if let Some(isl) = self.isl_tokens() {
            WORKER_LAST_INPUT_SEQUENCE_TOKENS_GAUGE
                .with_label_values(labels)
                .set(isl as i64);
        }
    }
    /// Publishes the inter-token-latency gauge labelled by the decode
    /// worker. No-op until a decode worker id has been recorded.
    pub fn observe_finish_gauges(&self) {
        let Some(worker_id) = self.decode_worker_id() else {
            return;
        };
        let worker_id_str = worker_id.to_string();
        let dp_rank_str = self
            .decode_dp_rank()
            .map_or(UNSET_DP_RANK_LABEL.to_string(), |r| r.to_string());
        let worker_type = self.decode_worker_type().unwrap_or(WORKER_TYPE_DECODE);
        let labels = &[worker_id_str.as_str(), dp_rank_str.as_str(), worker_type];
        if let Some(avg_itl) = self.avg_itl_ms() {
            // Gauge is in seconds; avg_itl_ms is milliseconds.
            WORKER_LAST_INTER_TOKEN_LATENCY_GAUGE
                .with_label_values(labels)
                .set(avg_itl / 1000.0);
        }
    }
    /// Snapshot of all derived timings for serialization to the client.
    pub fn get_timing_info(&self) -> TimingInfo {
        TimingInfo {
            request_received_ms: self.request_received_epoch_ms,
            prefill_wait_time_ms: self.prefill_wait_time_ms(),
            prefill_time_ms: self.prefill_time_ms(),
            ttft_ms: self.ttft_ms(),
            total_time_ms: self.total_time_ms(),
            kv_hit_rate: self.kv_hit_rate(),
            router_queue_depth: self.router_queue_depth(),
            kv_transfer_estimated_latency_ms: self
                .kv_transfer_estimated_latency_secs()
                .map(|s| s * 1000.0),
        }
    }
}
impl Default for RequestTracker {
fn default() -> Self {
Self::new()
}
}
/// Serializable per-request timing summary; unset optional fields are
/// omitted from the JSON output.
#[derive(ToSchema, Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct TimingInfo {
/// Wall-clock request arrival, ms since the Unix epoch.
pub request_received_ms: u64,
/// Arrival-to-prefill-start wait, in ms.
#[serde(skip_serializing_if = "Option::is_none")]
pub prefill_wait_time_ms: Option<f64>,
/// Prefill-start-to-first-token duration, in ms.
#[serde(skip_serializing_if = "Option::is_none")]
pub prefill_time_ms: Option<f64>,
/// Time to first token from arrival, in ms.
#[serde(skip_serializing_if = "Option::is_none")]
pub ttft_ms: Option<f64>,
/// Total arrival-to-finish duration, in ms.
#[serde(skip_serializing_if = "Option::is_none")]
pub total_time_ms: Option<f64>,
/// KV-cache hit rate in [0, 1].
#[serde(skip_serializing_if = "Option::is_none")]
pub kv_hit_rate: Option<f64>,
/// Router queue depth observed at routing time.
#[serde(skip_serializing_if = "Option::is_none")]
pub router_queue_depth: Option<usize>,
/// Estimated KV-transfer latency (prefill complete to decode first token), ms.
#[serde(skip_serializing_if = "Option::is_none")]
pub kv_transfer_estimated_latency_ms: Option<f64>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;
    // ISL/cached/OSL recording round-trips through the getters.
    #[test]
    fn test_record_isl_osl() {
        let tracker = RequestTracker::new();
        tracker.record_isl(512, Some(256));
        assert_eq!(tracker.isl_tokens(), Some(512));
        assert_eq!(tracker.cached_tokens(), Some(256));
        tracker.record_osl(100);
        assert_eq!(tracker.osl_tokens(), 100);
    }
    // Threshold is half the sleep to tolerate coarse timers in CI.
    #[test]
    fn test_ttft_ms() {
        let tracker = RequestTracker::new();
        thread::sleep(Duration::from_millis(10));
        tracker.record_first_token();
        let ttft = tracker.ttft_ms().unwrap();
        assert!(ttft >= 5.0, "TTFT should be at least 5ms, got {ttft}");
    }
    #[test]
    fn test_ttft_ms_none_before_first_token() {
        let tracker = RequestTracker::new();
        assert!(tracker.ttft_ms().is_none());
    }
    #[test]
    fn test_avg_itl_ms() {
        let tracker = RequestTracker::new();
        tracker.record_first_token();
        thread::sleep(Duration::from_millis(20));
        tracker.record_osl(11);
        tracker.record_finish();
        let itl = tracker.avg_itl_ms().unwrap();
        assert!(itl > 0.0, "avg ITL should be positive, got {itl}");
    }
    // With a single output token there is no inter-token gap to average.
    #[test]
    fn test_avg_itl_ms_none_with_single_token() {
        let tracker = RequestTracker::new();
        tracker.record_first_token();
        tracker.record_osl(1);
        tracker.record_finish();
        assert!(
            tracker.avg_itl_ms().is_none(),
            "avg ITL should be None with < 2 output tokens"
        );
    }
    #[test]
    fn test_kv_hit_rate() {
        let tracker = RequestTracker::new();
        tracker.record_kv_hit(3, 10);
        let rate = tracker.kv_hit_rate().unwrap();
        assert!(
            (rate - 0.3).abs() < f64::EPSILON,
            "KV hit rate should be 0.3, got {rate}"
        );
    }
    // Zero ISL blocks must not divide by zero.
    #[test]
    fn test_kv_hit_rate_zero_isl() {
        let tracker = RequestTracker::new();
        tracker.record_kv_hit(0, 0);
        assert!(
            tracker.kv_hit_rate().is_none(),
            "KV hit rate should be None when isl_blocks is 0"
        );
    }
    #[test]
    fn test_total_time_ms() {
        let tracker = RequestTracker::new();
        thread::sleep(Duration::from_millis(10));
        tracker.record_finish();
        let total = tracker.total_time_ms().unwrap();
        assert!(
            total >= 5.0,
            "total time should be at least 5ms, got {total}"
        );
    }
    // Queue depth is write-once: a second record is ignored.
    #[test]
    fn test_router_queue_depth() {
        let tracker = RequestTracker::new();
        assert!(tracker.router_queue_depth().is_none());
        tracker.record_router_queue_depth(42);
        assert_eq!(tracker.router_queue_depth(), Some(42));
        tracker.record_router_queue_depth(99);
        assert_eq!(tracker.router_queue_depth(), Some(42));
        let timing = tracker.get_timing_info();
        assert_eq!(timing.router_queue_depth, Some(42));
    }
    // Gauge observers must no-op (not panic) without worker attribution.
    #[test]
    fn test_observe_first_token_gauges_no_panic_without_worker() {
        let tracker = RequestTracker::new();
        tracker.record_first_token();
        tracker.record_isl(100, Some(50));
        tracker.observe_first_token_gauges();
    }
    #[test]
    fn test_observe_finish_gauges_no_panic_without_worker() {
        let tracker = RequestTracker::new();
        tracker.record_first_token();
        tracker.record_osl(10);
        tracker.record_finish();
        tracker.observe_finish_gauges();
    }
    #[test]
    fn test_observe_first_token_gauges_with_worker() {
        let tracker = RequestTracker::new();
        tracker.record_worker(42, Some(0), WORKER_TYPE_PREFILL);
        thread::sleep(Duration::from_millis(5));
        tracker.record_first_token();
        tracker.record_isl(256, Some(128));
        tracker.observe_first_token_gauges();
        let labels = &["42", "0", WORKER_TYPE_PREFILL];
        let ttft_val = WORKER_LAST_TIME_TO_FIRST_TOKEN_GAUGE
            .with_label_values(labels)
            .get();
        assert!(
            ttft_val > 0.0,
            "TTFT gauge should be positive after observe, got {ttft_val}"
        );
        let isl_val = WORKER_LAST_INPUT_SEQUENCE_TOKENS_GAUGE
            .with_label_values(labels)
            .get();
        assert_eq!(isl_val, 256, "ISL gauge should be 256, got {isl_val}");
    }
    #[test]
    fn test_observe_finish_gauges_with_worker() {
        let tracker = RequestTracker::new();
        tracker.record_worker(99, Some(1), WORKER_TYPE_DECODE);
        tracker.record_first_token();
        thread::sleep(Duration::from_millis(10));
        tracker.record_osl(5);
        tracker.record_finish();
        tracker.observe_finish_gauges();
        let labels = &["99", "1", WORKER_TYPE_DECODE];
        let itl_val = WORKER_LAST_INTER_TOKEN_LATENCY_GAUGE
            .with_label_values(labels)
            .get();
        assert!(
            itl_val > 0.0,
            "ITL gauge should be positive after observe, got {itl_val}"
        );
    }
    #[test]
    fn test_kv_transfer_estimated_latency() {
        let tracker = RequestTracker::new();
        assert!(tracker.kv_transfer_estimated_latency_secs().is_none());
        tracker.record_prefill_complete();
        thread::sleep(Duration::from_millis(10));
        tracker.record_decode_first_token();
        let latency = tracker.kv_transfer_estimated_latency_secs().unwrap();
        assert!(
            latency >= 0.005,
            "latency should be at least 5ms, got {latency}"
        );
    }
    #[test]
    fn test_kv_transfer_estimated_latency_none_without_first_token() {
        let tracker = RequestTracker::new();
        tracker.record_prefill_complete();
        assert!(
            tracker.kv_transfer_estimated_latency_secs().is_none(),
            "Should return None when decode_first_token_time is not set"
        );
    }
    #[test]
    fn test_kv_transfer_estimated_latency_none_without_prefill_complete() {
        let tracker = RequestTracker::new();
        tracker.record_decode_first_token();
        assert!(
            tracker.kv_transfer_estimated_latency_secs().is_none(),
            "Should return None when prefill_complete_time is not set"
        );
    }
    // OnceLock semantics: only the first record succeeds.
    #[test]
    fn test_kv_transfer_estimated_latency_oncelock_first_write_wins() {
        let tracker = RequestTracker::new();
        assert!(tracker.record_prefill_complete());
        assert!(!tracker.record_prefill_complete());
    }
    #[test]
    fn test_timing_info_includes_kv_transfer_estimated_latency() {
        let tracker = RequestTracker::new();
        tracker.record_prefill_complete();
        thread::sleep(Duration::from_millis(10));
        tracker.record_decode_first_token();
        let info = tracker.get_timing_info();
        let latency_ms = info
            .kv_transfer_estimated_latency_ms
            .expect("should be Some");
        assert!(
            latency_ms >= 5.0,
            "latency should be at least 5ms, got {latency_ms}"
        );
    }
    #[test]
    fn test_timing_info_kv_transfer_estimated_latency_none_in_aggregated() {
        let tracker = RequestTracker::new();
        let info = tracker.get_timing_info();
        assert!(
            info.kv_transfer_estimated_latency_ms.is_none(),
            "Should be None in aggregated mode (no timestamps recorded)"
        );
    }
    // Regression test: using first_token_time (set by prefill) instead of
    // decode_first_token_time made the KV-transfer estimate collapse to 0.
    #[test]
    fn test_kv_transfer_latency_bug_prefill_timestamps_are_zero() {
        let tracker = RequestTracker::new();
        tracker.record_first_token();
        tracker.record_prefill_complete();
        let first_tok = *tracker.first_token_time.get().unwrap();
        let complete = *tracker.prefill_complete_time.get().unwrap();
        let old_latency = first_tok.saturating_duration_since(complete).as_secs_f64();
        assert_eq!(
            old_latency, 0.0,
            "Old computation should produce exactly 0.0 (the bug), got {old_latency}"
        );
        assert!(
            tracker.kv_transfer_estimated_latency_secs().is_none(),
            "Fixed metric should be None when decode hasn't started"
        );
        thread::sleep(Duration::from_millis(10));
        tracker.record_decode_first_token();
        let fixed_latency = tracker.kv_transfer_estimated_latency_secs().unwrap();
        assert!(
            fixed_latency >= 0.005,
            "Fixed latency should be >= 5ms (actual KV transfer time), got {fixed_latency}"
        );
    }
    // decode_first_token_time has its own OnceLock so the prefill worker's
    // first-token write cannot block the decode-side timestamp.
    #[test]
    fn test_decode_first_token_not_blocked_by_prefill_oncelock() {
        let tracker = RequestTracker::new();
        tracker.record_first_token();
        let prefill_first_tok = *tracker.first_token_time.get().unwrap();
        thread::sleep(Duration::from_millis(5));
        tracker.record_first_token();
        let still_prefill_tok = *tracker.first_token_time.get().unwrap();
        assert_eq!(
            prefill_first_tok, still_prefill_tok,
            "first_token_time should be unchanged (OnceLock rejected decode's write)"
        );
        tracker.record_decode_first_token();
        let decode_tok = *tracker.decode_first_token_time.get().unwrap();
        assert!(
            decode_tok > prefill_first_tok,
            "decode_first_token_time should be later than first_token_time"
        );
    }
    // serde skip_serializing_if: unset field absent, set field present.
    #[test]
    fn test_timing_info_kv_transfer_estimated_latency_serialization() {
        let tracker = RequestTracker::new();
        let info = tracker.get_timing_info();
        let json = serde_json::to_string(&info).unwrap();
        assert!(
            !json.contains("kv_transfer_estimated_latency_ms"),
            "None field should be omitted from JSON, got: {json}"
        );
        let tracker2 = RequestTracker::new();
        tracker2.record_prefill_complete();
        tracker2.record_decode_first_token();
        let info2 = tracker2.get_timing_info();
        let json2 = serde_json::to_string(&info2).unwrap();
        assert!(
            json2.contains("kv_transfer_estimated_latency_ms"),
            "Set field should appear in JSON, got: {json2}"
        );
    }
}