use axum::{
Router,
extract::State,
http::StatusCode,
response::{IntoResponse, sse::Event},
routing::get,
};
use dynamo_runtime::{
config::environment_names::llm::metrics as env_metrics,
metrics::prometheus_names::{
frontend_service, name_prefix, sanitize_frontend_prometheus_prefix,
},
};
use prometheus::{
Encoder, GaugeVec, HistogramOpts, HistogramVec, IntCounterVec, IntGaugeVec, Opts,
};
use serde::Serialize;
use std::{
sync::{Arc, LazyLock},
time::{Duration, Instant},
};
use crate::local_model::runtime_config::ModelRuntimeConfig;
use crate::model_card::ModelDeploymentCard;
use dynamo_runtime::metrics::prometheus_names::clamp_u64_to_i64;
pub use prometheus::Registry;
use super::RouteDoc;
pub use crate::discovery::{WORKER_TYPE_DECODE, WORKER_TYPE_PREFILL};
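/// Last observed time to first token per worker, in seconds. The per-worker
/// timing gauges below are process-wide `LazyLock` statics labelled by
/// `worker_id`, `dp_rank`, and `worker_type`, rather than fields on
/// [`Metrics`]; register them via [`register_worker_timing_metrics`].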
pub static WORKER_LAST_TIME_TO_FIRST_TOKEN_GAUGE: LazyLock<GaugeVec> = LazyLock::new(|| {
GaugeVec::new(
Opts::new(
format!(
"{}_{}",
name_prefix::FRONTEND,
frontend_service::WORKER_LAST_TIME_TO_FIRST_TOKEN_SECONDS
),
"Last observed time to first token per worker (seconds)",
),
&["worker_id", "dp_rank", "worker_type"],
)
.expect("Failed to create worker_last_time_to_first_token gauge")
});
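/// Last observed input sequence length per worker, in tokens.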
pub static WORKER_LAST_INPUT_SEQUENCE_TOKENS_GAUGE: LazyLock<IntGaugeVec> = LazyLock::new(|| {
IntGaugeVec::new(
Opts::new(
format!(
"{}_{}",
name_prefix::FRONTEND,
frontend_service::WORKER_LAST_INPUT_SEQUENCE_TOKENS
),
"Last observed input sequence tokens per worker",
),
&["worker_id", "dp_rank", "worker_type"],
)
.expect("Failed to create worker_last_input_sequence_tokens gauge")
});
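/// Last observed inter-token latency per worker, in seconds.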
pub static WORKER_LAST_INTER_TOKEN_LATENCY_GAUGE: LazyLock<GaugeVec> = LazyLock::new(|| {
GaugeVec::new(
Opts::new(
format!(
"{}_{}",
name_prefix::FRONTEND,
frontend_service::WORKER_LAST_INTER_TOKEN_LATENCY_SECONDS
),
"Last observed inter-token latency per worker (seconds)",
),
&["worker_id", "dp_rank", "worker_type"],
)
.expect("Failed to create worker_last_inter_token_latency gauge")
});
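/// Registers the three per-worker timing gauges on `registry`. Call at most
/// once per registry: the `prometheus` crate rejects duplicate collector
/// registration with `Error::AlreadyReg`.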
pub fn register_worker_timing_metrics(registry: &Registry) -> Result<(), prometheus::Error> {
registry.register(Box::new(WORKER_LAST_TIME_TO_FIRST_TOKEN_GAUGE.clone()))?;
registry.register(Box::new(WORKER_LAST_INPUT_SEQUENCE_TOKENS_GAUGE.clone()))?;
registry.register(Box::new(WORKER_LAST_INTER_TOKEN_LATENCY_GAUGE.clone()))?;
Ok(())
}
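/// Generates histogram bucket boundaries: a leading `0.0` bucket followed by
/// `count - 1` logarithmically spaced values ending at `max` (the first
/// nonzero bucket lies one log step above `min`, so `min` itself is never
/// emitted). Values are rounded to 2 significant figures; consecutive
/// duplicates created by that rounding are removed, so the result may hold
/// fewer than `count` entries. For example, `generate_log_buckets(1.0, 100.0, 5)`
/// yields `[0.0, 3.2, 10.0, 32.0, 100.0]`.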
pub fn generate_log_buckets(min: f64, max: f64, count: usize) -> Vec<f64> {
if count == 0 {
return vec![];
}
if count == 1 {
return vec![0.0];
}
let requested_count = count;
let mut buckets = Vec::with_capacity(count);
buckets.push(0.0);
for i in 1..count {
let log_min = min.ln();
let log_max = max.ln();
let log_value = log_min + (log_max - log_min) * (i as f64) / ((count - 1) as f64);
let value = log_value.exp();
buckets.push(round_to_sig_figs(value, 2));
}
let original_len = buckets.len();
buckets.dedup();
if buckets.len() < original_len && (original_len - buckets.len()) > original_len / 10 {
tracing::warn!(
requested = requested_count,
unique = buckets.len(),
duplicates = original_len - buckets.len(),
min = min,
max = max,
"Histogram bucket generation: Significant duplicate values after rounding to 2 sig figs. \
Consider reducing bucket count or increasing range."
);
}
buckets
}
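/// Rounds `value` to `sig_figs` significant figures, e.g. at 2 sig figs
/// `0.2356 -> 0.24`, `1234.0 -> 1200.0`, and `0.999` rounds up to `1.0`.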
pub fn round_to_sig_figs(value: f64, sig_figs: u32) -> f64 {
if value == 0.0 {
return 0.0;
}
let magnitude = value.abs().log10().floor();
let scale = 10_f64.powf(sig_figs as f64 - 1.0 - magnitude);
(value * scale).round() / scale
}
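/// Upper bound on the configurable histogram bucket count, guarding against
/// pathological environment overrides.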
const MAX_BUCKET_COUNT: usize = 512;
fn validate_bucket_config(min: f64, max: f64, count: usize) -> bool {
min.is_finite()
&& max.is_finite()
&& min > 0.0
&& min < max
&& count > 0
&& count <= MAX_BUCKET_COUNT
}
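/// Resolves a histogram's `(min, max, count)` from environment variables,
/// falling back to the provided defaults. The variables consulted are
/// `{HISTOGRAM_PREFIX}{env_prefix}_MIN`, `..._MAX`, and `..._COUNT`;
/// unparseable or invalid overrides fall back to the defaults with a warning,
/// and invalid defaults fall back to `(1.0, 10.0, 10)`.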
fn parse_bucket_config(
env_prefix: &str,
default_min: f64,
default_max: f64,
default_count: usize,
) -> (f64, f64, usize) {
if !validate_bucket_config(default_min, default_max, default_count) {
tracing::error!(
default_min,
default_max,
default_count,
"Invalid default histogram configuration"
);
return (1.0, 10.0, 10);
}
let env_prefix = format!("{}{}", env_metrics::HISTOGRAM_PREFIX, env_prefix);
let mut min = std::env::var(format!("{env_prefix}_MIN"))
.ok()
.and_then(|s| s.parse::<f64>().ok())
.unwrap_or(default_min);
let mut max = std::env::var(format!("{env_prefix}_MAX"))
.ok()
.and_then(|s| s.parse::<f64>().ok())
.unwrap_or(default_max);
let mut count = std::env::var(format!("{env_prefix}_COUNT"))
.ok()
.and_then(|s| s.parse::<usize>().ok())
.unwrap_or(default_count);
if !validate_bucket_config(min, max, count) {
tracing::warn!(
min=%min,
max=%max,
count=%count,
"Invalid histogram configuration given, using defaults"
);
min = default_min;
max = default_max;
count = default_count;
}
(min, max, count)
}
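/// State shared with the `/metrics` handler: the local HTTP registry plus an
/// optional distributed-runtime registry whose output is appended to the
/// scrape response.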
struct MetricsHandlerState {
registry: Arc<Registry>,
drt_metrics: Option<dynamo_runtime::metrics::MetricsRegistry>,
}
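/// Prometheus collectors for the HTTP frontend: request counters, inflight
/// and queue gauges, latency and sequence-length histograms, and per-model
/// runtime-configuration gauges. Construct once, register on a [`Registry`],
/// and share via `Arc`.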
pub struct Metrics {
request_counter: IntCounterVec,
inflight_gauge: IntGaugeVec,
client_disconnect_gauge: prometheus::IntGauge,
http_queue_gauge: IntGaugeVec,
request_duration: HistogramVec,
input_sequence_length: HistogramVec,
output_sequence_length: HistogramVec,
cached_tokens: HistogramVec,
tokenizer_latency: HistogramVec,
output_tokens_counter: IntCounterVec,
time_to_first_token: HistogramVec,
inter_token_latency: HistogramVec,
model_total_kv_blocks: IntGaugeVec,
model_max_num_seqs: IntGaugeVec,
model_max_num_batched_tokens: IntGaugeVec,
model_context_length: IntGaugeVec,
model_kv_cache_block_size: IntGaugeVec,
model_migration_limit: IntGaugeVec,
model_migration_total: IntCounterVec,
}
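/// RAII guard for the HTTP queue gauge: increments the per-model gauge on
/// creation and decrements it on drop.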
pub struct HttpQueueGuard {
metrics: Arc<Metrics>,
model: String,
}
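/// RAII guard for an in-flight request. Creation increments the inflight
/// gauge; dropping decrements it, bumps the request counter with the final
/// status, and observes the request duration. The status defaults to
/// `Status::Error` / `ErrorType::Internal`, so a guard dropped without
/// `mark_ok` or `mark_error` counts as an internal error.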
pub struct InflightGuard {
metrics: Arc<Metrics>,
model: String,
endpoint: Endpoint,
request_type: RequestType,
status: Status,
error_type: ErrorType,
timer: Instant,
}
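/// The API endpoint that served a request; used as the `endpoint` label.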
#[derive(Clone, Copy)]
pub enum Endpoint {
Completions,
ChatCompletions,
Embeddings,
Images,
Videos,
Responses,
AnthropicMessages,
Tensor,
}
pub enum RequestType {
Unary,
Stream,
}
#[derive(PartialEq)]
pub enum Status {
Success,
Error,
}
#[derive(PartialEq, Clone, Debug)]
pub enum ErrorType {
None,
Validation,
NotFound,
Overload,
Cancelled,
Internal,
NotImplemented,
}
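/// Per-request collector for streaming response metrics: TTFT, inter-token
/// latency, sequence lengths, cached tokens, and tokenizer latencies. The
/// output sequence length and average detokenize latency are observed once,
/// on drop.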
pub struct ResponseMetricCollector {
metrics: Arc<Metrics>,
model: String,
start_time: Instant,
is_first_token: bool,
last_response_time: Option<Duration>,
osl: usize,
cached_tokens_observed: bool,
tokenize_latency_observed: bool,
detokenize_latency_total: Duration,
detokenize_count_total: u64,
prefill_worker_id: Option<u64>,
prefill_dp_rank: Option<u32>,
prefill_worker_type: Option<String>,
decode_worker_id: Option<u64>,
decode_dp_rank: Option<u32>,
decode_worker_type: Option<String>,
}
impl Default for Metrics {
fn default() -> Self {
Self::new()
}
}
impl Metrics {
pub fn new() -> Self {
let raw_prefix = std::env::var(env_metrics::DYN_METRICS_PREFIX)
.unwrap_or_else(|_| name_prefix::FRONTEND.to_string());
let prefix = sanitize_frontend_prometheus_prefix(&raw_prefix);
if prefix != raw_prefix {
tracing::warn!(
raw=%raw_prefix,
sanitized=%prefix,
env=%frontend_service::METRICS_PREFIX_ENV,
"Sanitized HTTP metrics prefix"
);
}
let frontend_metric_name = |suffix: &str| format!("{}_{}", &prefix, suffix);
let request_counter = IntCounterVec::new(
Opts::new(
frontend_metric_name(frontend_service::REQUESTS_TOTAL),
"Total number of LLM requests processed",
),
&["model", "endpoint", "request_type", "status", "error_type"],
)
.unwrap();
let inflight_gauge = IntGaugeVec::new(
Opts::new(
frontend_metric_name(frontend_service::INFLIGHT_REQUESTS),
"Number of inflight requests",
),
&["model"],
)
.unwrap();
let client_disconnect_gauge = prometheus::IntGauge::new(
frontend_metric_name(frontend_service::DISCONNECTED_CLIENTS),
"Number of disconnected clients",
)
.unwrap();
let http_queue_gauge = IntGaugeVec::new(
Opts::new(
frontend_metric_name(frontend_service::QUEUED_REQUESTS),
"Number of requests in HTTP processing queue",
),
&["model"],
)
.unwrap();
let (req_dur_min, req_dur_max, req_dur_count) =
parse_bucket_config("DYN_METRICS_REQUEST_DURATION", 1.0, 256.0, 10);
let request_duration_buckets =
generate_log_buckets(req_dur_min, req_dur_max, req_dur_count);
let request_duration = HistogramVec::new(
HistogramOpts::new(
frontend_metric_name(frontend_service::REQUEST_DURATION_SECONDS),
"Duration of LLM requests",
)
.buckets(request_duration_buckets),
&["model"],
)
.unwrap();
let (isl_min, isl_max, isl_count) =
parse_bucket_config("DYN_METRICS_INPUT_SEQUENCE", 50.0, 128000.0, 12);
let input_sequence_buckets = generate_log_buckets(isl_min, isl_max, isl_count);
let input_sequence_length = HistogramVec::new(
HistogramOpts::new(
frontend_metric_name(frontend_service::INPUT_SEQUENCE_TOKENS),
"Input sequence length in tokens",
)
.buckets(input_sequence_buckets.clone()),
&["model"],
)
.unwrap();
let (osl_min, osl_max, osl_count) =
parse_bucket_config("DYN_METRICS_OUTPUT_SEQUENCE", 50.0, 32000.0, 10);
let output_sequence_buckets = generate_log_buckets(osl_min, osl_max, osl_count);
let output_sequence_length = HistogramVec::new(
HistogramOpts::new(
frontend_metric_name(frontend_service::OUTPUT_SEQUENCE_TOKENS),
"Output sequence length in tokens",
)
.buckets(output_sequence_buckets),
&["model"],
)
.unwrap();
let output_tokens_counter = IntCounterVec::new(
Opts::new(
frontend_metric_name(frontend_service::OUTPUT_TOKENS_TOTAL),
"Total number of output tokens generated (updates in real-time)",
),
&["model"],
)
.unwrap();
let (ttft_min, ttft_max, ttft_count) =
parse_bucket_config("DYN_METRICS_TTFT", 0.001, 480.0, 18);
let time_to_first_token_buckets = generate_log_buckets(ttft_min, ttft_max, ttft_count);
let time_to_first_token = HistogramVec::new(
HistogramOpts::new(
frontend_metric_name(frontend_service::TIME_TO_FIRST_TOKEN_SECONDS),
"Time to first token in seconds",
)
.buckets(time_to_first_token_buckets),
&["model"],
)
.unwrap();
let (itl_min, itl_max, itl_count) = parse_bucket_config("DYN_METRICS_ITL", 0.001, 2.0, 13);
let inter_token_latency_buckets = generate_log_buckets(itl_min, itl_max, itl_count);
let inter_token_latency = HistogramVec::new(
HistogramOpts::new(
frontend_metric_name(frontend_service::INTER_TOKEN_LATENCY_SECONDS),
"Inter-token latency in seconds",
)
.buckets(inter_token_latency_buckets),
&["model"],
)
.unwrap();
let cached_tokens = HistogramVec::new(
HistogramOpts::new(
frontend_metric_name(frontend_service::CACHED_TOKENS),
"Number of cached tokens (prefix cache hits) per request",
)
.buckets(input_sequence_buckets.clone()),
&["model"],
)
.unwrap();
let tokenizer_latency = HistogramVec::new(
HistogramOpts::new(
frontend_metric_name(frontend_service::TOKENIZER_LATENCY_MS),
"Tokenizer latency in milliseconds",
)
.buckets(vec![
0.5, 1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 64.0, 128.0, 256.0, 512.0,
]),
&[frontend_service::OPERATION_LABEL],
)
.unwrap();
let model_total_kv_blocks = IntGaugeVec::new(
Opts::new(
frontend_metric_name(frontend_service::MODEL_TOTAL_KV_BLOCKS),
"Total KV cache blocks available for a worker serving the model",
),
&["model"],
)
.unwrap();
let model_max_num_seqs = IntGaugeVec::new(
Opts::new(
frontend_metric_name(frontend_service::MODEL_MAX_NUM_SEQS),
"Maximum number of sequences for a worker serving the model",
),
&["model"],
)
.unwrap();
let model_max_num_batched_tokens = IntGaugeVec::new(
Opts::new(
frontend_metric_name(frontend_service::MODEL_MAX_NUM_BATCHED_TOKENS),
"Maximum number of batched tokens for a worker serving the model",
),
&["model"],
)
.unwrap();
let model_context_length = IntGaugeVec::new(
Opts::new(
frontend_metric_name(frontend_service::MODEL_CONTEXT_LENGTH),
"Maximum context length in tokens for a worker serving the model",
),
&["model"],
)
.unwrap();
let model_kv_cache_block_size = IntGaugeVec::new(
Opts::new(
frontend_metric_name(frontend_service::MODEL_KV_CACHE_BLOCK_SIZE),
"KV cache block size in tokens for a worker serving the model",
),
&["model"],
)
.unwrap();
let model_migration_limit = IntGaugeVec::new(
Opts::new(
frontend_metric_name(frontend_service::MODEL_MIGRATION_LIMIT),
"Maximum number of request migrations allowed for the model",
),
&["model"],
)
.unwrap();
let model_migration_total = IntCounterVec::new(
Opts::new(
frontend_metric_name(frontend_service::MODEL_MIGRATION_TOTAL),
"Total number of request migrations due to worker unavailability",
),
&["model", frontend_service::MIGRATION_TYPE_LABEL],
)
.unwrap();
Metrics {
request_counter,
inflight_gauge,
client_disconnect_gauge,
http_queue_gauge,
request_duration,
input_sequence_length,
output_sequence_length,
cached_tokens,
tokenizer_latency,
output_tokens_counter,
time_to_first_token,
inter_token_latency,
model_total_kv_blocks,
model_max_num_seqs,
model_max_num_batched_tokens,
model_context_length,
model_kv_cache_block_size,
model_migration_limit,
model_migration_total,
}
}
pub fn get_request_counter(
&self,
model: &str,
endpoint: &Endpoint,
request_type: &RequestType,
status: &Status,
error_type: &ErrorType,
) -> u64 {
self.request_counter
.with_label_values(&[
model,
endpoint.as_str(),
request_type.as_str(),
status.as_str(),
error_type.as_str(),
])
.get()
}
fn inc_request_counter(
&self,
model: &str,
endpoint: &Endpoint,
request_type: &RequestType,
status: &Status,
error_type: &ErrorType,
) {
self.request_counter
.with_label_values(&[
model,
endpoint.as_str(),
request_type.as_str(),
status.as_str(),
error_type.as_str(),
])
.inc()
}
pub fn get_inflight_count(&self, model: &str) -> i64 {
self.inflight_gauge.with_label_values(&[model]).get()
}
fn inc_inflight_gauge(&self, model: &str) {
self.inflight_gauge.with_label_values(&[model]).inc()
}
fn dec_inflight_gauge(&self, model: &str) {
self.inflight_gauge.with_label_values(&[model]).dec()
}
pub fn inc_client_disconnect(&self) {
self.client_disconnect_gauge.inc();
}
pub fn get_client_disconnect_count(&self) -> i64 {
self.client_disconnect_gauge.get()
}
fn inc_http_queue_gauge(&self, model: &str) {
self.http_queue_gauge.with_label_values(&[model]).inc()
}
fn dec_http_queue_gauge(&self, model: &str) {
self.http_queue_gauge.with_label_values(&[model]).dec()
}
pub fn register(&self, registry: &Registry) -> Result<(), prometheus::Error> {
registry.register(Box::new(self.request_counter.clone()))?;
registry.register(Box::new(self.inflight_gauge.clone()))?;
registry.register(Box::new(self.client_disconnect_gauge.clone()))?;
registry.register(Box::new(self.http_queue_gauge.clone()))?;
registry.register(Box::new(self.request_duration.clone()))?;
registry.register(Box::new(self.input_sequence_length.clone()))?;
registry.register(Box::new(self.output_sequence_length.clone()))?;
registry.register(Box::new(self.cached_tokens.clone()))?;
registry.register(Box::new(self.tokenizer_latency.clone()))?;
registry.register(Box::new(self.output_tokens_counter.clone()))?;
registry.register(Box::new(self.time_to_first_token.clone()))?;
registry.register(Box::new(self.inter_token_latency.clone()))?;
registry.register(Box::new(self.model_total_kv_blocks.clone()))?;
registry.register(Box::new(self.model_max_num_seqs.clone()))?;
registry.register(Box::new(self.model_max_num_batched_tokens.clone()))?;
registry.register(Box::new(self.model_context_length.clone()))?;
registry.register(Box::new(self.model_kv_cache_block_size.clone()))?;
registry.register(Box::new(self.model_migration_limit.clone()))?;
registry.register(Box::new(self.model_migration_total.clone()))?;
Ok(())
}
pub fn update_runtime_config_metrics(
&self,
model_name: &str,
runtime_config: &ModelRuntimeConfig,
) {
if let Some(total_kv_blocks) = runtime_config.total_kv_blocks {
self.model_total_kv_blocks
.with_label_values(&[model_name])
.set(clamp_u64_to_i64(total_kv_blocks));
}
if let Some(max_num_seqs) = runtime_config.max_num_seqs {
self.model_max_num_seqs
.with_label_values(&[model_name])
.set(clamp_u64_to_i64(max_num_seqs));
}
if let Some(max_batched_tokens) = runtime_config.max_num_batched_tokens {
self.model_max_num_batched_tokens
.with_label_values(&[model_name])
.set(clamp_u64_to_i64(max_batched_tokens));
}
}
pub fn update_metrics_from_mdc(&self, card: &ModelDeploymentCard) -> anyhow::Result<()> {
self.update_runtime_config_metrics(&card.display_name, &card.runtime_config);
self.model_context_length
.with_label_values(&[&card.display_name])
.set(card.context_length as i64);
self.model_kv_cache_block_size
.with_label_values(&[&card.display_name])
.set(card.kv_cache_block_size as i64);
self.model_migration_limit
.with_label_values(&[&card.display_name])
.set(card.migration_limit as i64);
tracing::debug!(
model = %card.display_name,
"Successfully updated MDC metrics"
);
Ok(())
}
pub fn inc_migration_new_request(&self, model: &str) {
self.model_migration_total
.with_label_values(&[model, frontend_service::migration_type::NEW_REQUEST])
.inc();
}
pub fn inc_migration_ongoing_request(&self, model: &str) {
self.model_migration_total
.with_label_values(&[model, frontend_service::migration_type::ONGOING_REQUEST])
.inc();
}
pub fn get_migration_new_request_count(&self, model: &str) -> u64 {
self.model_migration_total
.with_label_values(&[model, frontend_service::migration_type::NEW_REQUEST])
.get()
}
pub fn get_migration_ongoing_request_count(&self, model: &str) -> u64 {
self.model_migration_total
.with_label_values(&[model, frontend_service::migration_type::ONGOING_REQUEST])
.get()
}
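    /// Creates an [`InflightGuard`] for a request against `model`
    /// (lowercased for label consistency). A minimal sketch of the intended
    /// call pattern, with hypothetical handler names:
    ///
    /// ```ignore
    /// let mut guard =
    ///     metrics.clone().create_inflight_guard("my-model", Endpoint::ChatCompletions, false);
    /// match handle_request().await {
    ///     Ok(_) => guard.mark_ok(),
    ///     Err(_) => guard.mark_error(ErrorType::Internal),
    /// }
    /// // Dropping the guard records the counter and request duration.
    /// ```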
pub fn create_inflight_guard(
self: Arc<Self>,
model: &str,
endpoint: Endpoint,
streaming: bool,
) -> InflightGuard {
let request_type = if streaming {
RequestType::Stream
} else {
RequestType::Unary
};
InflightGuard::new(
self.clone(),
            model.to_lowercase(),
endpoint,
request_type,
)
}
pub fn create_response_collector(self: Arc<Self>, model: &str) -> ResponseMetricCollector {
        ResponseMetricCollector::new(self, model.to_lowercase())
}
pub fn create_http_queue_guard(self: Arc<Self>, model: &str) -> HttpQueueGuard {
        HttpQueueGuard::new(self, model.to_lowercase())
}
}
impl HttpQueueGuard {
fn new(metrics: Arc<Metrics>, model: String) -> Self {
metrics.inc_http_queue_gauge(&model);
HttpQueueGuard { metrics, model }
}
}
impl Drop for HttpQueueGuard {
fn drop(&mut self) {
self.metrics.dec_http_queue_gauge(&self.model);
}
}
impl InflightGuard {
fn new(
metrics: Arc<Metrics>,
model: String,
endpoint: Endpoint,
request_type: RequestType,
) -> Self {
let timer = Instant::now();
metrics.inc_inflight_gauge(&model);
InflightGuard {
metrics,
model,
endpoint,
request_type,
status: Status::Error,
error_type: ErrorType::Internal,
timer,
}
}
pub(crate) fn mark_ok(&mut self) {
self.status = Status::Success;
self.error_type = ErrorType::None;
}
pub(crate) fn mark_error(&mut self, error_type: ErrorType) {
self.status = Status::Error;
self.error_type = error_type;
}
}
impl Drop for InflightGuard {
fn drop(&mut self) {
let duration = self.timer.elapsed().as_secs_f64();
self.metrics.dec_inflight_gauge(&self.model);
self.metrics.inc_request_counter(
&self.model,
&self.endpoint,
&self.request_type,
&self.status,
&self.error_type,
);
self.metrics
.request_duration
.with_label_values(&[&self.model])
.observe(duration);
}
}
impl std::fmt::Display for Endpoint {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Reuse the canonical label values from `as_str` instead of
        // duplicating the string literals.
        f.write_str(self.as_str())
    }
}
impl Endpoint {
pub fn as_str(&self) -> &'static str {
match self {
Endpoint::Completions => "completions",
Endpoint::ChatCompletions => "chat_completions",
Endpoint::Embeddings => "embeddings",
Endpoint::Images => "images",
Endpoint::Videos => "videos",
Endpoint::Responses => "responses",
Endpoint::AnthropicMessages => "anthropic_messages",
Endpoint::Tensor => "tensor",
}
}
}
impl RequestType {
pub fn as_str(&self) -> &'static str {
match self {
RequestType::Unary => frontend_service::request_type::UNARY,
RequestType::Stream => frontend_service::request_type::STREAM,
}
}
}
impl Status {
pub fn as_str(&self) -> &'static str {
match self {
Status::Success => frontend_service::status::SUCCESS,
Status::Error => frontend_service::status::ERROR,
}
}
}
impl ErrorType {
pub fn as_str(&self) -> &'static str {
match self {
ErrorType::None => frontend_service::error_type::NONE,
ErrorType::Validation => frontend_service::error_type::VALIDATION,
ErrorType::NotFound => frontend_service::error_type::NOT_FOUND,
ErrorType::Overload => frontend_service::error_type::OVERLOAD,
ErrorType::Cancelled => frontend_service::error_type::CANCELLED,
ErrorType::Internal => frontend_service::error_type::INTERNAL,
ErrorType::NotImplemented => frontend_service::error_type::NOT_IMPLEMENTED,
}
}
}
impl ResponseMetricCollector {
fn new(metrics: Arc<Metrics>, model: String) -> Self {
ResponseMetricCollector {
metrics,
model,
is_first_token: true,
last_response_time: None,
start_time: Instant::now(),
osl: 0,
cached_tokens_observed: false,
tokenize_latency_observed: false,
detokenize_latency_total: Duration::ZERO,
detokenize_count_total: 0,
prefill_worker_id: None,
prefill_dp_rank: None,
prefill_worker_type: None,
decode_worker_id: None,
decode_dp_rank: None,
decode_worker_type: None,
}
}
pub fn set_worker_info(
&mut self,
prefill_worker_id: Option<u64>,
prefill_dp_rank: Option<u32>,
prefill_worker_type: Option<String>,
decode_worker_id: Option<u64>,
decode_dp_rank: Option<u32>,
decode_worker_type: Option<String>,
) {
if self.prefill_worker_id.is_none() {
self.prefill_worker_id = prefill_worker_id;
}
if self.prefill_dp_rank.is_none() {
self.prefill_dp_rank = prefill_dp_rank;
}
if self.prefill_worker_type.is_none() {
self.prefill_worker_type = prefill_worker_type;
}
if self.decode_worker_id.is_none() {
self.decode_worker_id = decode_worker_id;
}
if self.decode_dp_rank.is_none() {
self.decode_dp_rank = decode_dp_rank;
}
if self.decode_worker_type.is_none() {
self.decode_worker_type = decode_worker_type;
}
}
pub fn observe_current_osl(&mut self, osl: usize) {
self.osl = osl;
}
pub fn is_first_token(&self) -> bool {
self.is_first_token
}
pub fn observe_cached_tokens(&mut self, cached_tokens: Option<usize>) {
if let Some(tokens) = cached_tokens
&& !self.cached_tokens_observed
{
self.cached_tokens_observed = true;
self.metrics
.cached_tokens
.with_label_values(&[&self.model])
.observe(tokens as f64);
}
}
pub fn observe_tokenize_latencies(
&mut self,
tokenize_latency: Option<Duration>,
detokenize_latency: Option<Duration>,
detokenize_count: Option<u64>,
) {
if let Some(latency) = tokenize_latency
&& !self.tokenize_latency_observed
{
self.tokenize_latency_observed = true;
self.metrics
.tokenizer_latency
.with_label_values(&[frontend_service::operation::TOKENIZE])
.observe(latency.as_secs_f64() * 1000.0);
}
if let Some(latency) = detokenize_latency {
self.detokenize_latency_total = latency;
}
if let Some(count) = detokenize_count {
self.detokenize_count_total = count;
}
}
pub fn observe_response(&mut self, isl: usize, num_tokens: usize) {
if num_tokens == 0 {
return;
}
self.metrics
.output_tokens_counter
.with_label_values(&[&self.model])
.inc_by(num_tokens as u64);
if self.is_first_token {
self.is_first_token = false;
let ttft = self.start_time.elapsed().as_secs_f64();
self.metrics
.time_to_first_token
.with_label_values(&[&self.model])
.observe(ttft);
if let Some(worker_id) = self.prefill_worker_id {
let worker_id_str = worker_id.to_string();
let dp_rank_str = self
.prefill_dp_rank
.map_or("0".to_string(), |r| r.to_string());
let worker_type = self
.prefill_worker_type
.as_deref()
.unwrap_or(WORKER_TYPE_PREFILL);
let labels = &[worker_id_str.as_str(), dp_rank_str.as_str(), worker_type];
WORKER_LAST_TIME_TO_FIRST_TOKEN_GAUGE
.with_label_values(labels)
.set(ttft);
WORKER_LAST_INPUT_SEQUENCE_TOKENS_GAUGE
.with_label_values(labels)
.set(isl as i64);
}
self.metrics
.input_sequence_length
.with_label_values(&[&self.model])
.observe(isl as f64);
}
let current_duration = self.start_time.elapsed();
if let Some(last_response_time) = self.last_response_time {
let response_duration = current_duration - last_response_time;
let itl = response_duration.as_secs_f64() / num_tokens as f64;
for _ in 0..num_tokens {
self.metrics
.inter_token_latency
.with_label_values(&[&self.model])
.observe(itl);
}
if let Some(worker_id) = self.decode_worker_id {
let worker_id_str = worker_id.to_string();
let dp_rank_str = self
.decode_dp_rank
.map_or("0".to_string(), |r| r.to_string());
let worker_type = self
.decode_worker_type
.as_deref()
.unwrap_or(WORKER_TYPE_DECODE);
WORKER_LAST_INTER_TOKEN_LATENCY_GAUGE
.with_label_values(&[worker_id_str.as_str(), dp_rank_str.as_str(), worker_type])
.set(itl);
}
}
self.last_response_time = Some(current_duration);
}
}
impl Drop for ResponseMetricCollector {
fn drop(&mut self) {
if !self.detokenize_latency_total.is_zero() && self.detokenize_count_total > 0 {
let avg_detokenize_latency_ms = (self.detokenize_latency_total.as_secs_f64() * 1000.0)
/ self.detokenize_count_total as f64;
self.metrics
.tokenizer_latency
.with_label_values(&[frontend_service::operation::DETOKENIZE])
.observe(avg_detokenize_latency_ms);
}
self.metrics
.output_sequence_length
.with_label_values(&[&self.model])
.observe(self.osl as f64);
}
}
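/// Extracts an `LLMMetricAnnotation` from `annotated`, if present, and feeds
/// it to `response_collector`. The [`HttpQueueGuard`] is dropped once the
/// first nonzero-token chunk arrives, marking the request as having left the
/// HTTP queue. Used on the non-streaming path.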
pub fn process_response_and_observe_metrics<T>(
annotated: &crate::types::Annotated<T>,
response_collector: &mut ResponseMetricCollector,
http_queue_guard: &mut Option<HttpQueueGuard>,
) {
use crate::preprocessor::LLMMetricAnnotation;
if let Ok(Some(metrics)) = LLMMetricAnnotation::from_annotation(annotated) {
response_collector.observe_current_osl(metrics.output_tokens);
response_collector.observe_cached_tokens(metrics.cached_tokens);
response_collector.observe_tokenize_latencies(
metrics.tokenize_latency,
metrics.detokenize_total_latency,
metrics.detokenize_count,
);
response_collector.set_worker_info(
metrics.prefill_worker_id,
metrics.prefill_dp_rank,
metrics.prefill_worker_type,
metrics.decode_worker_id,
metrics.decode_dp_rank,
metrics.decode_worker_type,
);
if response_collector.is_first_token()
&& metrics.chunk_tokens > 0
&& let Some(guard) = http_queue_guard.take()
{
drop(guard);
}
response_collector.observe_response(metrics.input_tokens, metrics.chunk_tokens);
}
}
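/// Newtype that adapts an `Annotated` response for conversion into an SSE
/// [`Event`].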
pub struct EventConverter<T>(pub crate::types::Annotated<T>);
impl<T> From<crate::types::Annotated<T>> for EventConverter<T> {
fn from(annotated: crate::types::Annotated<T>) -> Self {
EventConverter(annotated)
}
}
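/// Streaming counterpart of [`process_response_and_observe_metrics`]: records
/// metrics, strips internal metrics annotations from the outgoing stream, and
/// converts the response into an SSE [`Event`]. Returns `Ok(None)` when there
/// is nothing to emit (e.g. a metrics-only chunk) and `Err` when the upstream
/// reported an `error` event.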
pub fn process_response_using_event_converter_and_observe_metrics<T: Serialize>(
annotated: EventConverter<T>,
response_collector: &mut ResponseMetricCollector,
http_queue_guard: &mut Option<HttpQueueGuard>,
) -> Result<Option<Event>, axum::Error> {
use crate::preprocessor::LLMMetricAnnotation;
let mut annotated = annotated.0;
if let Ok(Some(metrics)) = LLMMetricAnnotation::from_annotation(&annotated) {
response_collector.observe_current_osl(metrics.output_tokens);
response_collector.observe_cached_tokens(metrics.cached_tokens);
response_collector.observe_tokenize_latencies(
metrics.tokenize_latency,
metrics.detokenize_total_latency,
metrics.detokenize_count,
);
response_collector.set_worker_info(
metrics.prefill_worker_id,
metrics.prefill_dp_rank,
metrics.prefill_worker_type,
metrics.decode_worker_id,
metrics.decode_dp_rank,
metrics.decode_worker_type,
);
if response_collector.is_first_token()
&& metrics.chunk_tokens > 0
&& let Some(guard) = http_queue_guard.take()
{
drop(guard);
}
response_collector.observe_response(metrics.input_tokens, metrics.chunk_tokens);
if annotated.event.as_deref() == Some(crate::preprocessor::ANNOTATION_LLM_METRICS) {
annotated.event = None;
annotated.comment = None;
}
}
let mut event = Event::default();
if let Some(ref data) = annotated.data {
event = event.json_data(data)?;
}
if let Some(ref msg) = annotated.event {
if msg == "error" {
let msgs = annotated
.comment
.unwrap_or_else(|| vec!["unspecified error".to_string()]);
return Err(axum::Error::new(msgs.join(" -- ")));
}
event = event.event(msg);
}
if let Some(comments) = annotated.comment {
for comment in comments {
event = event.comment(comment);
}
}
if annotated.data.is_none() && annotated.event.is_none() {
Ok(None)
} else {
Ok(Some(event))
}
}
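/// Builds the Prometheus scrape route, defaulting to `/metrics`. A minimal
/// sketch of mounting it on an axum application, assuming a registry that has
/// already had collectors registered:
///
/// ```ignore
/// let (docs, metrics_router) = router(registry, None, None);
/// let app = axum::Router::new().merge(metrics_router);
/// ```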
pub fn router(
registry: Registry,
path: Option<String>,
drt_metrics: Option<dynamo_runtime::metrics::MetricsRegistry>,
) -> (Vec<RouteDoc>, Router) {
let path = path.unwrap_or_else(|| "/metrics".to_string());
let doc = RouteDoc::new(axum::http::Method::GET, &path);
let metrics_state = MetricsHandlerState {
registry: Arc::new(registry),
drt_metrics,
};
let route = Router::new()
.route(&path, get(handler_metrics))
.with_state(Arc::new(metrics_state));
(vec![doc], route)
}
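/// Encodes the local registry in Prometheus text format, appends the
/// distributed-runtime metrics when available, and returns 500 if encoding
/// fails.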
async fn handler_metrics(State(state): State<Arc<MetricsHandlerState>>) -> impl IntoResponse {
let encoder = prometheus::TextEncoder::new();
let metric_families = state.registry.gather();
let mut buffer = vec![];
if encoder.encode(&metric_families, &mut buffer).is_err() {
return (
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to encode metrics",
)
.into_response();
}
let mut metrics = match String::from_utf8(buffer) {
Ok(metrics) => metrics,
Err(_) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to encode metrics",
)
.into_response();
}
};
if let Some(ref drt_metrics) = state.drt_metrics {
match drt_metrics.prometheus_expfmt_combined() {
Ok(drt_text) => {
if !drt_text.is_empty() {
if !metrics.is_empty() && !metrics.ends_with('\n') {
metrics.push('\n');
}
metrics.push_str(&drt_text);
}
}
Err(e) => {
tracing::warn!("Failed to gather DRT metrics: {}", e);
}
}
}
(StatusCode::OK, metrics).into_response()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_round_to_sig_figs() {
assert_eq!(round_to_sig_figs(0.0026, 2), 0.0026);
assert_eq!(round_to_sig_figs(0.26, 2), 0.26);
assert_eq!(round_to_sig_figs(0.2356, 2), 0.24);
assert_eq!(round_to_sig_figs(1.234, 2), 1.2);
assert_eq!(round_to_sig_figs(12.34, 2), 12.0);
assert_eq!(round_to_sig_figs(123.4, 2), 120.0);
assert_eq!(round_to_sig_figs(1234.0, 2), 1200.0);
assert_eq!(round_to_sig_figs(0.0, 2), 0.0);
assert_eq!(round_to_sig_figs(0.999, 2), 1.0);
assert_eq!(round_to_sig_figs(9.99, 2), 10.0);
assert_eq!(round_to_sig_figs(99.9, 2), 100.0);
}
#[test]
fn test_generate_log_buckets_basic() {
let buckets = generate_log_buckets(1.0, 100.0, 5);
assert_eq!(buckets.len(), 5);
assert_eq!(buckets[0], 0.0);
assert_eq!(buckets[buckets.len() - 1], 100.0);
for i in 1..buckets.len() {
assert!(
buckets[i] > buckets[i - 1],
"Bucket values should be increasing: {} <= {}",
buckets[i - 1],
buckets[i]
);
}
}
#[test]
fn test_generate_log_buckets_edge_cases() {
let buckets = generate_log_buckets(1.0, 100.0, 0);
assert_eq!(buckets.len(), 0);
let buckets = generate_log_buckets(1.0, 100.0, 1);
assert_eq!(buckets.len(), 1);
assert_eq!(buckets[0], 0.0);
let buckets = generate_log_buckets(1.0, 100.0, 2);
assert_eq!(buckets.len(), 2);
assert_eq!(buckets[0], 0.0);
assert_eq!(buckets[1], 100.0);
}
#[test]
fn test_generate_log_buckets_always_includes_zero() {
for count in 1..=20 {
let buckets = generate_log_buckets(0.1, 1000.0, count);
assert_eq!(
buckets[0], 0.0,
"First bucket should always be 0.0 for count={}",
count
);
}
}
#[test]
fn test_all_buckets_are_two_sig_figs() {
let test_cases = vec![
(1.0, 256.0, 10),
(50.0, 128000.0, 12),
(50.0, 32000.0, 10),
(0.001, 480.0, 18),
(0.001, 2.0, 13),
];
for (min, max, count) in test_cases {
let buckets = generate_log_buckets(min, max, count);
for &value in buckets.iter().skip(1) {
let rounded = round_to_sig_figs(value, 2);
assert_eq!(
value, rounded,
"Value {} should be rounded to 2 sig figs (min={}, max={}, count={})",
value, min, max, count
);
}
}
}
#[test]
fn test_sig_fig_limitation_with_many_buckets() {
let buckets = generate_log_buckets(0.0001, 1.0, 1000);
println!(
"Requested 1000 buckets, got {} total values (including 0.0)",
buckets.len()
);
assert!(
buckets.len() < 500,
"Expected fewer than 500 unique buckets due to 2 sig fig limitation, got {}",
buckets.len()
);
let mut sorted_buckets = buckets.clone();
sorted_buckets.sort_by(|a, b| a.partial_cmp(b).unwrap());
sorted_buckets.dedup();
assert_eq!(
buckets.len(),
sorted_buckets.len(),
"All buckets should be unique after deduplication"
);
assert_eq!(buckets[0], 0.0);
for i in 1..buckets.len() {
assert!(
buckets[i] > buckets[i - 1],
"Buckets should be in increasing order"
);
}
}
#[test]
fn test_deduplication_preserves_order() {
let buckets = generate_log_buckets(0.01, 1.0, 50);
let mut unique_check = std::collections::HashSet::new();
for &bucket in &buckets {
assert!(
unique_check.insert(bucket.to_bits()),
"Duplicate value {} found after deduplication",
bucket
);
}
for i in 1..buckets.len() {
assert!(
buckets[i] > buckets[i - 1],
"Bucket values should be in increasing order after deduplication"
);
}
}
#[test]
fn test_output_tokens_counter_increments() {
let metrics = Arc::new(Metrics::new());
let registry = prometheus::Registry::new();
        metrics.register(&registry).unwrap();
let model = "test-model";
let mut collector = metrics.clone().create_response_collector(model);
collector.observe_response(100, 5);
let counter_value = metrics
.output_tokens_counter
.with_label_values(&[model])
.get();
assert_eq!(counter_value, 5);
collector.observe_response(100, 10);
let counter_value = metrics
.output_tokens_counter
.with_label_values(&[model])
.get();
assert_eq!(counter_value, 15);
collector.observe_response(100, 7);
let counter_value = metrics
.output_tokens_counter
.with_label_values(&[model])
.get();
assert_eq!(counter_value, 22);
}
#[test]
fn test_output_tokens_counter_zero_tokens() {
let metrics = Arc::new(Metrics::new());
let registry = prometheus::Registry::new();
        metrics.register(&registry).unwrap();
let model = "test-model";
let mut collector = metrics.clone().create_response_collector(model);
collector.observe_response(100, 0);
let counter_value = metrics
.output_tokens_counter
.with_label_values(&[model])
.get();
assert_eq!(counter_value, 0);
collector.observe_response(100, 5);
assert_eq!(
metrics
.output_tokens_counter
.with_label_values(&[model])
.get(),
5
);
collector.observe_response(100, 0);
assert_eq!(
metrics
.output_tokens_counter
.with_label_values(&[model])
.get(),
5
);
}
#[test]
fn test_output_tokens_counter_multiple_models() {
let metrics = Arc::new(Metrics::new());
let registry = prometheus::Registry::new();
        metrics.register(&registry).unwrap();
let model1 = "model-1";
let model2 = "model-2";
let mut collector1 = metrics.clone().create_response_collector(model1);
let mut collector2 = metrics.clone().create_response_collector(model2);
collector1.observe_response(100, 10);
assert_eq!(
metrics
.output_tokens_counter
.with_label_values(&[model1])
.get(),
10
);
assert_eq!(
metrics
.output_tokens_counter
.with_label_values(&[model2])
.get(),
0
);
collector2.observe_response(200, 20);
assert_eq!(
metrics
.output_tokens_counter
.with_label_values(&[model1])
.get(),
10
);
assert_eq!(
metrics
.output_tokens_counter
.with_label_values(&[model2])
.get(),
20
);
collector1.observe_response(100, 5);
assert_eq!(
metrics
.output_tokens_counter
.with_label_values(&[model1])
.get(),
15
);
assert_eq!(
metrics
.output_tokens_counter
.with_label_values(&[model2])
.get(),
20
);
}
#[test]
fn test_cached_tokens_once_per_request() {
let metrics = Arc::new(Metrics::new());
let registry = prometheus::Registry::new();
        metrics.register(&registry).unwrap();
let model = "test-model";
let expected_metric_name = "dynamo_frontend_cached_tokens";
let mut collector = metrics.clone().create_response_collector(model);
let _histogram = metrics.cached_tokens.with_label_values(&[model]);
collector.observe_cached_tokens(Some(100));
let metric_families = registry.gather();
let histogram_family = metric_families
.iter()
.find(|mf| mf.name() == expected_metric_name)
.expect("histogram should be registered");
assert_eq!(
histogram_family.get_metric()[0]
.get_histogram()
.get_sample_count(),
1
);
collector.observe_cached_tokens(Some(50));
let metric_families = registry.gather();
let histogram_family = metric_families
.iter()
.find(|mf| mf.name() == expected_metric_name)
.expect("histogram should be registered");
assert_eq!(
histogram_family.get_metric()[0]
.get_histogram()
.get_sample_count(),
1
);
collector.observe_cached_tokens(Some(75));
let metric_families = registry.gather();
let histogram_family = metric_families
.iter()
.find(|mf| mf.name() == expected_metric_name)
.expect("histogram should be registered");
assert_eq!(
histogram_family.get_metric()[0]
.get_histogram()
.get_sample_count(),
1
);
}
#[test]
fn test_metrics_annotation_event_handling() {
use crate::preprocessor::LLMMetricAnnotation;
use crate::types::Annotated;
let metrics = Arc::new(Metrics::new());
let registry = prometheus::Registry::new();
        metrics.register(&registry).unwrap();
let model = "test-model";
let expected_metric_name = "dynamo_frontend_cached_tokens";
let expected_tokenizer_metric_name = "dynamo_frontend_tokenizer_latency_ms";
let mut collector = metrics.clone().create_response_collector(model);
let mut annotated = Annotated::<
crate::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse,
> {
id: None,
data: None,
event: Some(crate::preprocessor::ANNOTATION_LLM_METRICS.to_string()),
comment: None,
error: None,
};
let llm_metrics = LLMMetricAnnotation {
input_tokens: 10,
output_tokens: 20,
chunk_tokens: 5,
cached_tokens: Some(15),
prefill_worker_id: None,
prefill_dp_rank: None,
prefill_worker_type: None,
decode_worker_id: None,
decode_dp_rank: None,
decode_worker_type: None,
tokenize_latency: Some(Duration::from_millis(8)),
detokenize_total_latency: Some(Duration::from_micros(100)),
detokenize_count: Some(2),
};
let annotation = llm_metrics.to_annotation::<()>().unwrap();
annotated.event = annotation.event;
annotated.comment = annotation.comment;
let mut http_queue_guard = None;
let result = process_response_using_event_converter_and_observe_metrics(
EventConverter::from(annotated),
&mut collector,
&mut http_queue_guard,
);
assert!(matches!(result, Ok(None)));
drop(collector);
let metric_families = registry.gather();
let histogram_family = metric_families
.iter()
.find(|mf| mf.name() == expected_metric_name)
.expect("histogram should be registered");
assert_eq!(
histogram_family.get_metric()[0]
.get_histogram()
.get_sample_count(),
1
);
let histogram_family = metric_families
.iter()
.find(|mf| mf.name() == expected_tokenizer_metric_name)
.expect("histogram should be registered");
let tokenize_metric = histogram_family
.get_metric()
.iter()
.find(|m| m.get_label().iter().any(|l| l.value() == "tokenize"))
.expect("tokenize metric should exist");
assert_eq!(tokenize_metric.get_histogram().get_sample_count(), 1);
assert!(
(tokenize_metric.get_histogram().get_sample_sum() - 8.0).abs() < 0.001,
"tokenize latency should be 8.0ms"
);
let detokenize_metric = histogram_family
.get_metric()
.iter()
.find(|m| m.get_label().iter().any(|l| l.value() == "detokenize"))
.expect("detokenize metric should exist");
assert_eq!(detokenize_metric.get_histogram().get_sample_count(), 1);
assert!(
(detokenize_metric.get_histogram().get_sample_sum() - 0.05).abs() < 0.001,
"detokenize average latency should be 0.05ms, got {}",
detokenize_metric.get_histogram().get_sample_sum()
);
}
#[test]
fn test_non_streaming_path_observes_cached_tokens() {
use crate::preprocessor::LLMMetricAnnotation;
use crate::types::Annotated;
let metrics = Arc::new(Metrics::new());
let registry = prometheus::Registry::new();
        metrics.register(&registry).unwrap();
let model = "test-model";
let expected_metric_name = "dynamo_frontend_cached_tokens";
let expected_tokenizer_metric_name = "dynamo_frontend_tokenizer_latency_ms";
let mut collector = metrics.clone().create_response_collector(model);
let mut annotated = Annotated::<
crate::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse,
> {
id: None,
data: None,
event: Some(crate::preprocessor::ANNOTATION_LLM_METRICS.to_string()),
comment: None,
error: None,
};
let llm_metrics = LLMMetricAnnotation {
input_tokens: 10,
output_tokens: 20,
chunk_tokens: 5,
cached_tokens: Some(15),
prefill_worker_id: None,
prefill_dp_rank: None,
prefill_worker_type: None,
decode_worker_id: None,
decode_dp_rank: None,
decode_worker_type: None,
tokenize_latency: Some(Duration::from_millis(8)),
detokenize_total_latency: Some(Duration::from_micros(100)),
detokenize_count: Some(2),
};
let annotation = llm_metrics.to_annotation::<()>().unwrap();
annotated.event = annotation.event;
annotated.comment = annotation.comment;
let mut http_queue_guard = None;
process_response_and_observe_metrics(&annotated, &mut collector, &mut http_queue_guard);
drop(collector);
let metric_families = registry.gather();
let histogram_family = metric_families
.iter()
.find(|mf| mf.name() == expected_metric_name)
.expect("histogram should be registered");
assert_eq!(
histogram_family.get_metric()[0]
.get_histogram()
.get_sample_count(),
1
);
let histogram_family = metric_families
.iter()
.find(|mf| mf.name() == expected_tokenizer_metric_name)
.expect("histogram should be registered");
let tokenize_metric = histogram_family
.get_metric()
.iter()
.find(|m| m.get_label().iter().any(|l| l.value() == "tokenize"))
.expect("tokenize metric should exist");
assert_eq!(tokenize_metric.get_histogram().get_sample_count(), 1);
let detokenize_metric = histogram_family
.get_metric()
.iter()
.find(|m| m.get_label().iter().any(|l| l.value() == "detokenize"))
.expect("detokenize metric should exist");
assert_eq!(detokenize_metric.get_histogram().get_sample_count(), 1);
assert!(
(detokenize_metric.get_histogram().get_sample_sum() - 0.05).abs() < 0.001,
"detokenize average latency should be 0.05ms, got {}",
detokenize_metric.get_histogram().get_sample_sum()
);
}
#[test]
fn test_error_type_as_str() {
assert_eq!(ErrorType::None.as_str(), "");
assert_eq!(ErrorType::Validation.as_str(), "validation");
assert_eq!(ErrorType::NotFound.as_str(), "not_found");
assert_eq!(ErrorType::Overload.as_str(), "overload");
assert_eq!(ErrorType::Cancelled.as_str(), "cancelled");
assert_eq!(ErrorType::Internal.as_str(), "internal");
assert_eq!(ErrorType::NotImplemented.as_str(), "not_implemented");
}
#[test]
fn test_inflight_guard_marks_success_with_correct_error_type() {
let metrics = Arc::new(Metrics::new());
let registry = prometheus::Registry::new();
        metrics.register(&registry).unwrap();
let model = "test-model";
{
let mut guard =
metrics
.clone()
.create_inflight_guard(model, Endpoint::ChatCompletions, false);
guard.mark_ok();
}
let counter_value = metrics
.request_counter
.with_label_values(&[
model,
Endpoint::ChatCompletions.as_str(),
RequestType::Unary.as_str(),
Status::Success.as_str(),
ErrorType::None.as_str(),
])
.get();
assert_eq!(counter_value, 1);
}
#[test]
fn test_inflight_guard_marks_validation_error() {
let metrics = Arc::new(Metrics::new());
let registry = prometheus::Registry::new();
        metrics.register(&registry).unwrap();
let model = "test-model";
{
let mut guard =
metrics
.clone()
.create_inflight_guard(model, Endpoint::ChatCompletions, false);
guard.mark_error(ErrorType::Validation);
}
let counter_value = metrics
.request_counter
.with_label_values(&[
model,
Endpoint::ChatCompletions.as_str(),
RequestType::Unary.as_str(),
Status::Error.as_str(),
ErrorType::Validation.as_str(),
])
.get();
assert_eq!(counter_value, 1);
}
#[test]
fn test_inflight_guard_defaults_to_internal_error_on_drop() {
let metrics = Arc::new(Metrics::new());
let registry = prometheus::Registry::new();
        metrics.register(&registry).unwrap();
let model = "test-model";
{
let _guard =
metrics
.clone()
.create_inflight_guard(model, Endpoint::ChatCompletions, false);
}
let counter_value = metrics
.request_counter
.with_label_values(&[
model,
Endpoint::ChatCompletions.as_str(),
RequestType::Unary.as_str(),
Status::Error.as_str(),
ErrorType::Internal.as_str(),
])
.get();
assert_eq!(counter_value, 1);
}
#[test]
fn test_all_error_types_recorded_correctly() {
let metrics = Arc::new(Metrics::new());
let registry = prometheus::Registry::new();
        metrics.register(&registry).unwrap();
let model = "test-model";
let endpoint = Endpoint::ChatCompletions;
let error_types = vec![
ErrorType::Validation,
ErrorType::NotFound,
ErrorType::Overload,
ErrorType::Cancelled,
ErrorType::Internal,
ErrorType::NotImplemented,
];
for error_type in &error_types {
let mut guard = metrics
.clone()
.create_inflight_guard(model, endpoint, false);
guard.mark_error(error_type.clone());
drop(guard);
}
for error_type in &error_types {
let counter_value = metrics
.request_counter
.with_label_values(&[
model,
endpoint.as_str(),
RequestType::Unary.as_str(),
Status::Error.as_str(),
error_type.as_str(),
])
.get();
assert_eq!(
counter_value,
1,
"Should have 1 request for error_type={}",
error_type.as_str()
);
}
}
#[test]
fn test_multiple_requests_different_error_types() {
let metrics = Arc::new(Metrics::new());
let registry = prometheus::Registry::new();
        metrics.register(&registry).unwrap();
let model = "test-model";
for _ in 0..2 {
let mut guard =
metrics
.clone()
.create_inflight_guard(model, Endpoint::ChatCompletions, false);
guard.mark_error(ErrorType::Validation);
drop(guard);
}
for _ in 0..3 {
let mut guard =
metrics
.clone()
.create_inflight_guard(model, Endpoint::Completions, false);
guard.mark_error(ErrorType::Internal);
drop(guard);
}
{
let mut guard =
metrics
.clone()
.create_inflight_guard(model, Endpoint::Embeddings, false);
guard.mark_ok();
drop(guard);
}
let validation_count = metrics
.request_counter
.with_label_values(&[
model,
Endpoint::ChatCompletions.as_str(),
RequestType::Unary.as_str(),
Status::Error.as_str(),
ErrorType::Validation.as_str(),
])
.get();
assert_eq!(validation_count, 2);
let internal_count = metrics
.request_counter
.with_label_values(&[
model,
Endpoint::Completions.as_str(),
RequestType::Unary.as_str(),
Status::Error.as_str(),
ErrorType::Internal.as_str(),
])
.get();
assert_eq!(internal_count, 3);
let success_count = metrics
.request_counter
.with_label_values(&[
model,
Endpoint::Embeddings.as_str(),
RequestType::Unary.as_str(),
Status::Success.as_str(),
ErrorType::None.as_str(),
])
.get();
assert_eq!(success_count, 1);
}
}