use std::sync::Arc;
use std::sync::OnceLock;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use utoipa::ToSchema;
use crate::http::service::metrics::{
WORKER_LAST_INPUT_SEQUENCE_TOKENS_GAUGE, WORKER_LAST_INTER_TOKEN_LATENCY_GAUGE,
WORKER_LAST_TIME_TO_FIRST_TOKEN_GAUGE,
};
use crate::protocols::openai::nvext::WorkerIdInfo;
/// Sentinel stored in the worker-id atomics meaning "no worker recorded yet".
const NO_WORKER_ID: u64 = 0;
/// Sentinel stored in the dp-rank atomics meaning "no rank recorded yet".
const NO_DP_RANK: u32 = u32::MAX;
/// Worker-type label value used for prefill workers in metric labels.
pub const WORKER_TYPE_PREFILL: &str = "prefill";
/// Worker-type label value used for decode workers in metric labels.
pub const WORKER_TYPE_DECODE: &str = "decode";
/// Phase a request is currently in, used to decide which worker slot
/// (prefill / decode / both) a worker assignment is recorded into.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum RequestPhase {
    /// Request is being handled by a dedicated prefill worker.
    Prefill,
    /// Request is being handled by a dedicated decode worker.
    Decode,
    /// Single worker handles both prefill and decode (non-disaggregated).
    #[default]
    Aggregated,
}
impl std::fmt::Display for RequestPhase {
    /// Writes the phase as its lowercase label ("prefill", "decode",
    /// "aggregated").
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            RequestPhase::Prefill => "prefill",
            RequestPhase::Decode => "decode",
            RequestPhase::Aggregated => "aggregated",
        };
        f.write_str(label)
    }
}
/// Per-request lifecycle tracker: timestamps, token counts, worker
/// identity, and (de)tokenization latencies. Fields use atomics and
/// `OnceLock`s so the tracker can be shared and updated concurrently.
#[derive(Debug)]
pub struct RequestTracker {
    // Monotonic arrival time; baseline for all latency computations.
    request_received: Instant,
    // Wall-clock arrival time in ms since the Unix epoch, for reporting.
    request_received_epoch_ms: u64,
    // First-write-wins timestamps: set once, later writes are ignored.
    prefill_start_time: OnceLock<Instant>,
    first_token_time: OnceLock<Instant>,
    // Last-write-wins: a later finish overwrites an earlier one.
    request_finish_time: Mutex<Option<Instant>>,
    // KV-cache reuse: overlapping blocks vs. total input blocks.
    kv_overlap_blocks: OnceLock<u32>,
    isl_blocks: OnceLock<usize>,
    // Input sequence length and prefix-cache-hit tokens.
    isl_tokens: OnceLock<usize>,
    cached_tokens: OnceLock<usize>,
    // Latest recorded output sequence length, in tokens.
    osl_tokens: AtomicU64,
    // Worker identity per role; NO_WORKER_ID / NO_DP_RANK mean "unset".
    prefill_worker_id: AtomicU64,
    prefill_dp_rank: AtomicU32,
    decode_worker_id: AtomicU64,
    decode_dp_rank: AtomicU32,
    // Worker-type labels (e.g. WORKER_TYPE_PREFILL); first-write-wins.
    prefill_worker_type: OnceLock<&'static str>,
    decode_worker_type: OnceLock<&'static str>,
    // Current phase; transitions are serialized via `phase_semaphore`.
    phase: Mutex<RequestPhase>,
    phase_semaphore: Arc<Semaphore>,
    // One-shot tokenize latency; detokenize is accumulated (sum + count).
    tokenize_latency: OnceLock<Duration>,
    detokenize_total_ns: AtomicU64,
    detokenize_count: AtomicU64,
}
impl RequestTracker {
    /// Creates a tracker stamped with "now" on both the monotonic clock
    /// (used for all latency math) and the wall clock (epoch ms, used
    /// for reporting in [`TimingInfo`]).
    pub fn new() -> Self {
        let now = Instant::now();
        let epoch_ms = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_millis() as u64)
            // A clock set before the Unix epoch yields Err; report 0
            // instead of panicking on the request path.
            .unwrap_or(0);
        RequestTracker {
            request_received: now,
            request_received_epoch_ms: epoch_ms,
            prefill_start_time: OnceLock::new(),
            first_token_time: OnceLock::new(),
            request_finish_time: Mutex::new(None),
            kv_overlap_blocks: OnceLock::new(),
            isl_blocks: OnceLock::new(),
            isl_tokens: OnceLock::new(),
            cached_tokens: OnceLock::new(),
            osl_tokens: AtomicU64::new(0),
            prefill_worker_id: AtomicU64::new(NO_WORKER_ID),
            prefill_dp_rank: AtomicU32::new(NO_DP_RANK),
            decode_worker_id: AtomicU64::new(NO_WORKER_ID),
            decode_dp_rank: AtomicU32::new(NO_DP_RANK),
            prefill_worker_type: OnceLock::new(),
            decode_worker_type: OnceLock::new(),
            phase: Mutex::new(RequestPhase::Aggregated),
            phase_semaphore: Arc::new(Semaphore::new(1)),
            tokenize_latency: OnceLock::new(),
            detokenize_total_ns: AtomicU64::new(0),
            detokenize_count: AtomicU64::new(0),
        }
    }

    /// Records the prefill start time. First-write-wins: returns `true`
    /// on the first call, `false` (and leaves the timestamp alone) after.
    pub fn record_prefill_start(&self) -> bool {
        self.prefill_start_time.set(Instant::now()).is_ok()
    }

    /// Records the first-token time. First-write-wins; later calls are
    /// silently ignored.
    pub fn record_first_token(&self) {
        let _ = self.first_token_time.set(Instant::now());
    }

    /// Records request completion. Last-write-wins, so a later finish
    /// overwrites an earlier one.
    pub fn record_finish(&self) {
        *self.request_finish_time.lock() = Some(Instant::now());
    }

    /// Records KV-cache overlap (in blocks) and total input blocks.
    /// Returns `true` only if both values were stored for the first time.
    pub fn record_kv_hit(&self, overlap_blocks: u32, isl_blocks: usize) -> bool {
        let overlap_set = self.kv_overlap_blocks.set(overlap_blocks).is_ok();
        let isl_set = self.isl_blocks.set(isl_blocks).is_ok();
        overlap_set && isl_set
    }

    /// Records input sequence length and prefix-cached token counts.
    /// First-write-wins for each value.
    pub fn record_isl(&self, isl_tokens: usize, cached_tokens: usize) {
        let _ = self.isl_tokens.set(isl_tokens);
        let _ = self.cached_tokens.set(cached_tokens);
    }

    /// Input sequence length in tokens, if recorded.
    pub fn isl_tokens(&self) -> Option<usize> {
        self.isl_tokens.get().copied()
    }

    /// Prefix-cached token count, if recorded.
    pub fn cached_tokens(&self) -> Option<usize> {
        self.cached_tokens.get().copied()
    }

    /// Records the current output sequence length (overwrites any
    /// previous value).
    pub fn record_osl(&self, osl: usize) {
        self.osl_tokens.store(osl as u64, Ordering::Relaxed);
    }

    /// Latest recorded output sequence length (0 if never recorded).
    pub fn osl_tokens(&self) -> u64 {
        self.osl_tokens.load(Ordering::Relaxed)
    }

    /// Time from request arrival to prefill start, in ms.
    pub fn prefill_wait_time_ms(&self) -> Option<f64> {
        self.prefill_start_time
            .get()
            .map(|t| t.duration_since(self.request_received).as_secs_f64() * 1000.0)
    }

    /// Time from prefill start to first token, in ms. Requires both
    /// timestamps to have been recorded.
    pub fn prefill_time_ms(&self) -> Option<f64> {
        let prefill_start = self.prefill_start_time.get()?;
        let first_token = self.first_token_time.get()?;
        Some(first_token.duration_since(*prefill_start).as_secs_f64() * 1000.0)
    }

    /// Time to first token: arrival to first token, in ms.
    pub fn ttft_ms(&self) -> Option<f64> {
        let first_token = self.first_token_time.get()?;
        Some(
            first_token
                .duration_since(self.request_received)
                .as_secs_f64()
                * 1000.0,
        )
    }

    /// Total request time: arrival to finish, in ms.
    pub fn total_time_ms(&self) -> Option<f64> {
        let finish = (*self.request_finish_time.lock())?;
        Some(finish.duration_since(self.request_received).as_secs_f64() * 1000.0)
    }

    /// Average inter-token latency in ms: decode duration (first token
    /// to finish) divided by the `osl - 1` token intervals. `None` when
    /// fewer than two output tokens were produced.
    pub fn avg_itl_ms(&self) -> Option<f64> {
        let first_token = *self.first_token_time.get()?;
        let finish = (*self.request_finish_time.lock())?;
        let osl = self.osl_tokens.load(Ordering::Relaxed);
        if osl < 2 {
            return None;
        }
        let decode_duration = finish.duration_since(first_token).as_secs_f64() * 1000.0;
        Some(decode_duration / (osl - 1) as f64)
    }

    /// Wall-clock arrival time, ms since the Unix epoch.
    pub fn request_received_epoch_ms(&self) -> u64 {
        self.request_received_epoch_ms
    }

    /// KV-cache hit rate: overlap blocks / total input blocks. `None`
    /// if either value is unrecorded or the input has zero blocks.
    pub fn kv_hit_rate(&self) -> Option<f64> {
        let overlap = *self.kv_overlap_blocks.get()?;
        let isl = *self.isl_blocks.get()?;
        if isl == 0 {
            return None;
        }
        Some(overlap as f64 / isl as f64)
    }

    /// Switches the request phase, returning a permit that serializes
    /// transitions: the next `set_phase` caller waits until the returned
    /// permit is dropped.
    pub async fn set_phase(&self, phase: RequestPhase) -> OwnedSemaphorePermit {
        let permit = self
            .phase_semaphore
            .clone()
            .acquire_owned()
            .await
            // The semaphore is owned by `self` and never closed.
            .expect("phase semaphore should never be closed");
        *self.phase.lock() = phase;
        permit
    }

    /// Current request phase.
    pub fn phase(&self) -> RequestPhase {
        *self.phase.lock()
    }

    /// Records the worker chosen for the current phase. In the
    /// `Aggregated` phase a single worker serves both roles, so it is
    /// recorded as both prefill and decode worker.
    pub fn record_worker_full(&self, instance_id: u64, dp_rank: u32, worker_type: &'static str) {
        match self.phase() {
            RequestPhase::Prefill => self.set_prefill_worker(instance_id, dp_rank, worker_type),
            RequestPhase::Decode => self.set_decode_worker(instance_id, dp_rank, worker_type),
            RequestPhase::Aggregated => {
                self.set_prefill_worker(instance_id, dp_rank, worker_type);
                self.set_decode_worker(instance_id, dp_rank, worker_type);
            }
        }
    }

    /// Stores prefill-worker identity; the type label is first-write-wins.
    fn set_prefill_worker(&self, instance_id: u64, dp_rank: u32, worker_type: &'static str) {
        self.prefill_worker_id.store(instance_id, Ordering::Relaxed);
        self.prefill_dp_rank.store(dp_rank, Ordering::Relaxed);
        let _ = self.prefill_worker_type.set(worker_type);
    }

    /// Stores decode-worker identity; the type label is first-write-wins.
    fn set_decode_worker(&self, instance_id: u64, dp_rank: u32, worker_type: &'static str) {
        self.decode_worker_id.store(instance_id, Ordering::Relaxed);
        self.decode_dp_rank.store(dp_rank, Ordering::Relaxed);
        let _ = self.decode_worker_type.set(worker_type);
    }

    /// Records one-shot tokenize latency. First-write-wins.
    pub fn record_tokenize_latency(&self, l: Duration) {
        let _ = self.tokenize_latency.set(l);
    }

    /// Tokenize latency, if recorded.
    pub fn tokenize_latency(&self) -> Option<Duration> {
        self.tokenize_latency.get().copied()
    }

    /// Accumulates one detokenize call's latency into the running total
    /// (saturating at u64::MAX nanoseconds) and bumps the call count.
    pub fn record_detokenize_latency(&self, l: Duration) {
        let delta_ns = u64::try_from(l.as_nanos()).unwrap_or(u64::MAX);
        let _ = self.detokenize_total_ns.fetch_update(
            Ordering::Relaxed,
            Ordering::Relaxed,
            |current| Some(current.saturating_add(delta_ns)),
        );
        self.detokenize_count.fetch_add(1, Ordering::Relaxed);
    }

    /// Total accumulated detokenize latency, or `None` if no detokenize
    /// call was ever recorded.
    pub fn detokenize_total_latency(&self) -> Option<Duration> {
        let total_ns = self.detokenize_total_ns.load(Ordering::Relaxed);
        let count = self.detokenize_count.load(Ordering::Relaxed);
        if count == 0 {
            None
        } else {
            Some(Duration::from_nanos(total_ns))
        }
    }

    /// Number of recorded detokenize calls.
    pub fn detokenize_count(&self) -> u64 {
        self.detokenize_count.load(Ordering::Relaxed)
    }

    /// Worker identity for response metadata, or `None` if no worker
    /// of either role was ever recorded.
    pub fn get_worker_info(&self) -> Option<WorkerIdInfo> {
        let prefill = self.prefill_worker_id();
        let decode = self.decode_worker_id();
        if prefill.is_none() && decode.is_none() {
            return None;
        }
        Some(WorkerIdInfo {
            prefill_worker_id: prefill,
            prefill_dp_rank: self.prefill_dp_rank(),
            decode_worker_id: decode,
            decode_dp_rank: self.decode_dp_rank(),
        })
    }

    /// Decode worker id, or `None` while unset.
    // NOTE: loads use Relaxed to match the Relaxed stores; the original
    // SeqCst loads provided no extra guarantee against Relaxed stores.
    pub fn decode_worker_id(&self) -> Option<u64> {
        let id = self.decode_worker_id.load(Ordering::Relaxed);
        if id == NO_WORKER_ID { None } else { Some(id) }
    }

    /// Decode worker data-parallel rank, or `None` while unset.
    pub fn decode_dp_rank(&self) -> Option<u32> {
        let rank = self.decode_dp_rank.load(Ordering::Relaxed);
        if rank == NO_DP_RANK { None } else { Some(rank) }
    }

    /// Prefill worker id, or `None` while unset.
    pub fn prefill_worker_id(&self) -> Option<u64> {
        let id = self.prefill_worker_id.load(Ordering::Relaxed);
        if id == NO_WORKER_ID { None } else { Some(id) }
    }

    /// Prefill worker data-parallel rank, or `None` while unset.
    pub fn prefill_dp_rank(&self) -> Option<u32> {
        let rank = self.prefill_dp_rank.load(Ordering::Relaxed);
        if rank == NO_DP_RANK { None } else { Some(rank) }
    }

    /// Prefill worker-type label, or `None` while unset.
    pub fn prefill_worker_type(&self) -> Option<&'static str> {
        self.prefill_worker_type.get().copied()
    }

    /// Decode worker-type label, or `None` while unset.
    pub fn decode_worker_type(&self) -> Option<&'static str> {
        self.decode_worker_type.get().copied()
    }

    /// Publishes TTFT (seconds) and ISL gauges labeled with the prefill
    /// worker. No-op if no prefill worker was recorded.
    pub fn observe_first_token_gauges(&self) {
        let Some(worker_id) = self.prefill_worker_id() else {
            return;
        };
        let worker_id_str = worker_id.to_string();
        let dp_rank_str = self
            .prefill_dp_rank()
            .map_or("0".to_string(), |r| r.to_string());
        let worker_type = self.prefill_worker_type().unwrap_or(WORKER_TYPE_PREFILL);
        let labels = &[worker_id_str.as_str(), dp_rank_str.as_str(), worker_type];
        if let Some(ttft) = self.ttft_ms() {
            WORKER_LAST_TIME_TO_FIRST_TOKEN_GAUGE
                .with_label_values(labels)
                // Gauge is in seconds; ttft is in ms.
                .set(ttft / 1000.0);
        }
        if let Some(isl) = self.isl_tokens() {
            WORKER_LAST_INPUT_SEQUENCE_TOKENS_GAUGE
                .with_label_values(labels)
                .set(isl as i64);
        }
    }

    /// Publishes the average ITL gauge (seconds) labeled with the decode
    /// worker. No-op if no decode worker was recorded or ITL is unknown.
    pub fn observe_finish_gauges(&self) {
        let Some(worker_id) = self.decode_worker_id() else {
            return;
        };
        let worker_id_str = worker_id.to_string();
        let dp_rank_str = self
            .decode_dp_rank()
            .map_or("0".to_string(), |r| r.to_string());
        let worker_type = self.decode_worker_type().unwrap_or(WORKER_TYPE_DECODE);
        let labels = &[worker_id_str.as_str(), dp_rank_str.as_str(), worker_type];
        if let Some(avg_itl) = self.avg_itl_ms() {
            WORKER_LAST_INTER_TOKEN_LATENCY_GAUGE
                .with_label_values(labels)
                // Gauge is in seconds; avg_itl is in ms.
                .set(avg_itl / 1000.0);
        }
    }

    /// Snapshot of all timing-related values for response metadata.
    pub fn get_timing_info(&self) -> TimingInfo {
        TimingInfo {
            request_received_ms: self.request_received_epoch_ms,
            prefill_wait_time_ms: self.prefill_wait_time_ms(),
            prefill_time_ms: self.prefill_time_ms(),
            ttft_ms: self.ttft_ms(),
            total_time_ms: self.total_time_ms(),
            kv_hit_rate: self.kv_hit_rate(),
        }
    }
}
impl Default for RequestTracker {
    /// Equivalent to [`RequestTracker::new`]: a tracker stamped with "now".
    fn default() -> Self {
        Self::new()
    }
}
/// Serializable snapshot of request timing for API responses/metadata.
/// Optional fields are omitted from JSON when the corresponding event
/// was never recorded.
#[derive(ToSchema, Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct TimingInfo {
    /// Wall-clock arrival time, ms since the Unix epoch.
    pub request_received_ms: u64,
    /// Arrival to prefill start, in ms.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prefill_wait_time_ms: Option<f64>,
    /// Prefill start to first token, in ms.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prefill_time_ms: Option<f64>,
    /// Arrival to first token, in ms.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ttft_ms: Option<f64>,
    /// Arrival to finish, in ms.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub total_time_ms: Option<f64>,
    /// KV-cache overlap blocks / input blocks, in [0, 1].
    #[serde(skip_serializing_if = "Option::is_none")]
    pub kv_hit_rate: Option<f64>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    // NOTE: the gauge tests touch process-global metrics, so each uses a
    // distinct label set ("42"/"0" vs "99"/"1") to avoid interference.

    #[test]
    fn test_record_isl_osl() {
        let tracker = RequestTracker::new();
        tracker.record_isl(512, 256);
        assert_eq!(tracker.isl_tokens(), Some(512));
        assert_eq!(tracker.cached_tokens(), Some(256));
        tracker.record_osl(100);
        assert_eq!(tracker.osl_tokens(), 100);
    }

    #[test]
    fn test_ttft_ms() {
        let tracker = RequestTracker::new();
        thread::sleep(Duration::from_millis(10));
        tracker.record_first_token();
        let ttft = tracker.ttft_ms().unwrap();
        // Loose lower bound: sleep granularity varies across platforms.
        assert!(ttft >= 5.0, "TTFT should be at least 5ms, got {ttft}");
    }

    #[test]
    fn test_ttft_ms_none_before_first_token() {
        let tracker = RequestTracker::new();
        assert!(tracker.ttft_ms().is_none());
    }

    #[test]
    fn test_avg_itl_ms() {
        let tracker = RequestTracker::new();
        tracker.record_first_token();
        thread::sleep(Duration::from_millis(20));
        tracker.record_osl(11);
        tracker.record_finish();
        let itl = tracker.avg_itl_ms().unwrap();
        assert!(itl > 0.0, "avg ITL should be positive, got {itl}");
    }

    #[test]
    fn test_avg_itl_ms_none_with_single_token() {
        let tracker = RequestTracker::new();
        tracker.record_first_token();
        tracker.record_osl(1);
        tracker.record_finish();
        assert!(
            tracker.avg_itl_ms().is_none(),
            "avg ITL should be None with < 2 output tokens"
        );
    }

    #[test]
    fn test_kv_hit_rate() {
        let tracker = RequestTracker::new();
        tracker.record_kv_hit(3, 10);
        let rate = tracker.kv_hit_rate().unwrap();
        assert!(
            (rate - 0.3).abs() < f64::EPSILON,
            "KV hit rate should be 0.3, got {rate}"
        );
    }

    #[test]
    fn test_kv_hit_rate_zero_isl() {
        let tracker = RequestTracker::new();
        tracker.record_kv_hit(0, 0);
        assert!(
            tracker.kv_hit_rate().is_none(),
            "KV hit rate should be None when isl_blocks is 0"
        );
    }

    #[test]
    fn test_total_time_ms() {
        let tracker = RequestTracker::new();
        thread::sleep(Duration::from_millis(10));
        tracker.record_finish();
        let total = tracker.total_time_ms().unwrap();
        assert!(
            total >= 5.0,
            "total time should be at least 5ms, got {total}"
        );
    }

    #[test]
    fn test_observe_first_token_gauges_no_panic_without_worker() {
        let tracker = RequestTracker::new();
        tracker.record_first_token();
        tracker.record_isl(100, 50);
        // No worker recorded: must return early without panicking.
        tracker.observe_first_token_gauges();
    }

    #[test]
    fn test_observe_finish_gauges_no_panic_without_worker() {
        let tracker = RequestTracker::new();
        tracker.record_first_token();
        tracker.record_osl(10);
        tracker.record_finish();
        // No worker recorded: must return early without panicking.
        tracker.observe_finish_gauges();
    }

    #[test]
    fn test_observe_first_token_gauges_with_worker() {
        let tracker = RequestTracker::new();
        tracker.record_worker_full(42, 0, WORKER_TYPE_PREFILL);
        thread::sleep(Duration::from_millis(5));
        tracker.record_first_token();
        tracker.record_isl(256, 128);
        tracker.observe_first_token_gauges();
        let labels = &["42", "0", WORKER_TYPE_PREFILL];
        let ttft_val = WORKER_LAST_TIME_TO_FIRST_TOKEN_GAUGE
            .with_label_values(labels)
            .get();
        assert!(
            ttft_val > 0.0,
            "TTFT gauge should be positive after observe, got {ttft_val}"
        );
        let isl_val = WORKER_LAST_INPUT_SEQUENCE_TOKENS_GAUGE
            .with_label_values(labels)
            .get();
        assert_eq!(isl_val, 256, "ISL gauge should be 256, got {isl_val}");
    }

    #[test]
    fn test_observe_finish_gauges_with_worker() {
        let tracker = RequestTracker::new();
        tracker.record_worker_full(99, 1, WORKER_TYPE_DECODE);
        tracker.record_first_token();
        thread::sleep(Duration::from_millis(10));
        tracker.record_osl(5);
        tracker.record_finish();
        tracker.observe_finish_gauges();
        let labels = &["99", "1", WORKER_TYPE_DECODE];
        let itl_val = WORKER_LAST_INTER_TOKEN_LATENCY_GAUGE
            .with_label_values(labels)
            .get();
        assert!(
            itl_val > 0.0,
            "ITL gauge should be positive after observe, got {itl_val}"
        );
    }
}