use std::collections::HashMap;
#[cfg(feature = "axum")]
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
Json,
};
use schemars::JsonSchema;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
#[cfg(feature = "axum")]
use serde_json::{json, Value};
use super::model_card::ModelCard;
/// Priority given to a worker when none is supplied.
pub const DEFAULT_WORKER_PRIORITY: u32 = 50;
/// Cost given to a worker when none is supplied.
pub const DEFAULT_WORKER_COST: f32 = 1.0;
/// Kind of worker. Serialized with lowercase variant names; defaults to `Regular`.
#[derive(
Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default, schemars::JsonSchema,
)]
#[serde(rename_all = "lowercase")]
pub enum WorkerType {
/// Default worker kind (`"regular"`).
#[default]
Regular,
/// Serialized as `"prefill"`.
Prefill,
/// Serialized as `"decode"`.
Decode,
}
/// Writes the lowercase name of the worker type (`regular`, `prefill`, `decode`).
impl std::fmt::Display for WorkerType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            WorkerType::Regular => "regular",
            WorkerType::Prefill => "prefill",
            WorkerType::Decode => "decode",
        };
        f.write_str(name)
    }
}
/// Parses a worker type from its name, ignoring ASCII case.
impl std::str::FromStr for WorkerType {
    type Err = String;

    /// Returns `Err("Unknown worker type: …")` for any unrecognized name.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        const NAMES: [(&str, WorkerType); 3] = [
            ("regular", WorkerType::Regular),
            ("prefill", WorkerType::Prefill),
            ("decode", WorkerType::Decode),
        ];
        NAMES
            .iter()
            .find(|(name, _)| s.eq_ignore_ascii_case(name))
            .map(|&(_, worker_type)| worker_type)
            .ok_or_else(|| format!("Unknown worker type: {s}"))
    }
}
/// Protocol used to talk to a worker. Serialized lowercase; defaults to `Http`.
#[derive(
Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default, schemars::JsonSchema,
)]
#[serde(rename_all = "lowercase")]
pub enum ConnectionMode {
/// Default mode (`"http"`).
#[default]
Http,
/// Serialized as `"grpc"`.
Grpc,
}
/// Writes the lowercase name of the connection mode (`http` or `grpc`).
impl std::fmt::Display for ConnectionMode {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            ConnectionMode::Http => "http",
            ConnectionMode::Grpc => "grpc",
        };
        f.write_str(name)
    }
}
/// Hashable key combining a model id with a worker type and connection mode.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct WorkerGroupKey {
/// Model identifier component of the key.
pub model_id: String,
/// Worker type component of the key.
pub worker_type: WorkerType,
/// Connection mode component of the key.
pub connection_mode: ConnectionMode,
}
/// Formats the key as `model_id:worker_type:connection_mode`.
impl std::fmt::Display for WorkerGroupKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            model_id,
            worker_type,
            connection_mode,
        } = self;
        write!(f, "{model_id}:{worker_type}:{connection_mode}")
    }
}
/// Inference runtime behind a worker. Serialized lowercase; defaults to `Unspecified`.
#[derive(
Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default, schemars::JsonSchema,
)]
#[serde(rename_all = "lowercase")]
pub enum RuntimeType {
/// No runtime was declared (the default).
#[default]
Unspecified,
/// Serialized as `"sglang"`.
Sglang,
/// Serialized as `"vllm"`.
Vllm,
/// Serialized as `"trtllm"`.
Trtllm,
/// Serialized as `"external"`.
External,
}
impl RuntimeType {
    /// Returns `true` for every variant except [`RuntimeType::Unspecified`].
    pub fn is_specified(self) -> bool {
        self != RuntimeType::Unspecified
    }
}
/// Writes the lowercase name of the runtime type.
impl std::fmt::Display for RuntimeType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            RuntimeType::Unspecified => "unspecified",
            RuntimeType::Sglang => "sglang",
            RuntimeType::Vllm => "vllm",
            RuntimeType::Trtllm => "trtllm",
            RuntimeType::External => "external",
        };
        f.write_str(name)
    }
}
/// Parses a runtime type from its name, ignoring ASCII case.
///
/// `tensorrt-llm` is accepted as an alternate spelling of `trtllm`.
impl std::str::FromStr for RuntimeType {
    type Err = String;

    /// Returns `Err("Unknown runtime type: …")` for any unrecognized name.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        const NAMES: [(&str, RuntimeType); 6] = [
            ("unspecified", RuntimeType::Unspecified),
            ("sglang", RuntimeType::Sglang),
            ("vllm", RuntimeType::Vllm),
            ("trtllm", RuntimeType::Trtllm),
            ("tensorrt-llm", RuntimeType::Trtllm),
            ("external", RuntimeType::External),
        ];
        NAMES
            .iter()
            .find(|(name, _)| s.eq_ignore_ascii_case(name))
            .map(|&(_, runtime)| runtime)
            .ok_or_else(|| format!("Unknown runtime type: {s}"))
    }
}
/// Upstream API provider: a set of built-in providers plus a free-form fallback.
///
/// Built-in variants use lowercase names on the wire (`rename_all = "lowercase"`); the
/// listed aliases are also accepted when deserializing.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum ProviderType {
/// OpenAI (`"openai"`).
#[serde(alias = "openai")]
OpenAI,
/// xAI (`"xai"`, also accepts `"grok"`).
#[serde(alias = "xai", alias = "grok")]
#[expect(
clippy::upper_case_acronyms,
reason = "xAI is a proper company name; XAI matches industry convention and existing serde aliases"
)]
XAI,
/// Anthropic (`"anthropic"`, also accepts `"claude"`).
#[serde(alias = "anthropic", alias = "claude")]
Anthropic,
/// Gemini (`"gemini"`, also accepts `"google"`).
#[serde(alias = "gemini", alias = "google")]
Gemini,
/// Any other provider name; untagged, so it is (de)serialized as the bare string.
#[serde(untagged)]
Custom(String),
}
impl ProviderType {
    /// Returns the canonical lowercase provider name; for [`ProviderType::Custom`] this is
    /// the stored string unchanged.
    pub fn as_str(&self) -> &str {
        match self {
            Self::OpenAI => "openai",
            Self::XAI => "xai",
            Self::Anthropic => "anthropic",
            Self::Gemini => "gemini",
            Self::Custom(s) => s.as_str(),
        }
    }

    /// Infers a built-in provider from the host of `url`.
    ///
    /// Returns `None` when `url` does not parse, has no host, or the host is not one of
    /// the known provider domains. A host matches a domain only when it equals the domain
    /// or is a subdomain of it, so `api.openai.com` matches `openai.com` while
    /// `evilopenai.com` does not.
    pub fn from_url(url: &str) -> Option<Self> {
        let host = url::Url::parse(url).ok()?.host_str()?.to_lowercase();
        let in_domain = |domain: &str| Self::host_in_domain(&host, domain);
        if in_domain("openai.com") {
            Some(Self::OpenAI)
        } else if in_domain("x.ai") {
            Some(Self::XAI)
        } else if in_domain("anthropic.com") {
            Some(Self::Anthropic)
        } else if in_domain("googleapis.com") {
            Some(Self::Gemini)
        } else {
            None
        }
    }

    /// Returns `true` when `host` is `domain` itself or a dot-separated subdomain of it.
    ///
    /// A bare suffix check is not enough: `max.ai` ends with `x.ai` and
    /// `evilopenai.com` ends with `openai.com`, yet neither belongs to those domains.
    fn host_in_domain(host: &str, domain: &str) -> bool {
        matches!(
            host.strip_suffix(domain),
            Some(prefix) if prefix.is_empty() || prefix.ends_with('.')
        )
    }

    /// Name of the environment variable holding the admin key for a built-in provider;
    /// `None` for custom providers.
    pub fn admin_key_env_var(&self) -> Option<&'static str> {
        match self {
            Self::OpenAI => Some("OPENAI_ADMIN_KEY"),
            Self::XAI => Some("XAI_ADMIN_KEY"),
            Self::Anthropic => Some("ANTHROPIC_ADMIN_KEY"),
            Self::Gemini => Some("GEMINI_ADMIN_KEY"),
            Self::Custom(_) => None,
        }
    }

    /// Returns `true` only for [`ProviderType::Anthropic`].
    pub fn uses_x_api_key(&self) -> bool {
        matches!(self, Self::Anthropic)
    }

    /// Infers a built-in provider from a model name prefix (case-insensitive):
    /// `grok*` → xAI, `gemini*` → Gemini, `claude*` → Anthropic, `gpt*`/`o1*`/`o3*` → OpenAI.
    pub fn from_model_name(model: &str) -> Option<Self> {
        let model_lower = model.to_lowercase();
        if model_lower.starts_with("grok") {
            Some(Self::XAI)
        } else if model_lower.starts_with("gemini") {
            Some(Self::Gemini)
        } else if model_lower.starts_with("claude") {
            Some(Self::Anthropic)
        } else if model_lower.starts_with("gpt")
            || model_lower.starts_with("o1")
            || model_lower.starts_with("o3")
        {
            Some(Self::OpenAI)
        } else {
            None
        }
    }
}
/// Displays the provider using the same text as [`ProviderType::as_str`].
impl std::fmt::Display for ProviderType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = self.as_str();
        f.write_str(name)
    }
}
/// Serde default for `WorkerSpec::priority`.
fn default_priority() -> u32 {
DEFAULT_WORKER_PRIORITY
}
/// Serde default for `WorkerSpec::cost`.
fn default_cost() -> f32 {
DEFAULT_WORKER_COST
}
/// Default for `HealthCheckConfig::timeout_secs`.
fn default_health_check_timeout() -> u64 {
30
}
/// Default for `HealthCheckConfig::check_interval_secs`.
fn default_health_check_interval() -> u64 {
60
}
/// Default for `HealthCheckConfig::success_threshold`.
fn default_health_success_threshold() -> u32 {
2
}
/// Default for `HealthCheckConfig::failure_threshold`.
fn default_health_failure_threshold() -> u32 {
3
}
/// Default for `WorkerSpec::max_connection_attempts`.
fn default_max_connection_attempts() -> u32 {
20
}
/// Fully-resolved health check settings; every field falls back to a default when
/// absent from the input.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct HealthCheckConfig {
/// Health check timeout in seconds (default 30).
#[serde(default = "default_health_check_timeout")]
pub timeout_secs: u64,
/// Interval between health checks in seconds (default 60).
#[serde(default = "default_health_check_interval")]
pub check_interval_secs: u64,
/// Success threshold (default 2) — presumably consecutive successes before a worker
/// counts as healthy; confirm against the health checker.
#[serde(default = "default_health_success_threshold")]
pub success_threshold: u32,
/// Failure threshold (default 3) — presumably consecutive failures before a worker
/// counts as unhealthy; confirm against the health checker.
#[serde(default = "default_health_failure_threshold")]
pub failure_threshold: u32,
/// Flag for turning health checking off (default `false`).
#[serde(default)]
pub disable_health_check: bool,
}
/// Builds the configuration from the same default helpers serde uses for missing fields.
impl Default for HealthCheckConfig {
    fn default() -> Self {
        let timeout_secs = default_health_check_timeout();
        let check_interval_secs = default_health_check_interval();
        let success_threshold = default_health_success_threshold();
        let failure_threshold = default_health_failure_threshold();
        Self {
            timeout_secs,
            check_interval_secs,
            success_threshold,
            failure_threshold,
            disable_health_check: false,
        }
    }
}
/// Set of models a worker serves.
///
/// (De)serialized as a plain list of `ModelCard`; an empty list maps to `Wildcard`.
#[derive(Debug, Clone, Default)]
pub enum WorkerModels {
/// No explicit model list; `supports` accepts every model id.
#[default]
Wildcard,
/// Exactly one model (boxed).
Single(Box<ModelCard>),
/// A list of models.
Multi(Vec<ModelCard>),
}
impl WorkerModels {
    /// Returns `true` when no explicit model list is present.
    pub fn is_wildcard(&self) -> bool {
        matches!(self, Self::Wildcard)
    }

    /// First model card, or `None` when there are no explicit cards.
    pub fn primary(&self) -> Option<&ModelCard> {
        self.all().first()
    }

    /// All explicit model cards as a slice (empty for `Wildcard`).
    pub fn all(&self) -> &[ModelCard] {
        match self {
            Self::Wildcard => &[],
            Self::Single(single) => std::slice::from_ref(&**single),
            Self::Multi(many) => many.as_slice(),
        }
    }

    /// First explicit card for which `ModelCard::matches(id)` is true.
    ///
    /// Always `None` for `Wildcard`, even though `supports` returns `true` there.
    pub fn find(&self, id: &str) -> Option<&ModelCard> {
        self.all().iter().find(|candidate| candidate.matches(id))
    }

    /// Whether `id` is served: always for `Wildcard`, otherwise only if a card matches.
    pub fn supports(&self, id: &str) -> bool {
        self.is_wildcard() || self.find(id).is_some()
    }

    /// Iterates over the explicit model cards.
    pub fn iter(&self) -> impl Iterator<Item = &ModelCard> {
        self.all().iter()
    }
}
/// Picks the variant by list length: empty → `Wildcard`, one → `Single`, more → `Multi`.
impl From<Vec<ModelCard>> for WorkerModels {
    fn from(mut models: Vec<ModelCard>) -> Self {
        if models.len() > 1 {
            return Self::Multi(models);
        }
        // At most one element is left, so `pop` yields it or signals an empty list.
        match models.pop() {
            Some(only) => Self::Single(Box::new(only)),
            None => Self::Wildcard,
        }
    }
}
impl Serialize for WorkerModels {
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
self.all().serialize(serializer)
}
}
impl<'de> Deserialize<'de> for WorkerModels {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let models = Vec::<ModelCard>::deserialize(deserializer)?;
Ok(Self::from(models))
}
}
/// Reuses the schema of `Vec<ModelCard>`, matching the serialized form, under the
/// name `WorkerModels`.
impl JsonSchema for WorkerModels {
    fn schema_name() -> String {
        String::from("WorkerModels")
    }

    fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema {
        <Vec<ModelCard> as JsonSchema>::json_schema(gen)
    }
}
/// Declarative description of a worker: where it lives and how it should be treated.
///
/// Only `url` is required on input; every other field has a serde default.
/// `Option` fields that are `None` are omitted when serializing.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct WorkerSpec {
/// Worker URL (required).
pub url: String,
/// Models served by the worker; omitted from output when it is the wildcard.
#[serde(default, skip_serializing_if = "WorkerModels::is_wildcard")]
pub models: WorkerModels,
/// Worker type (defaults to `WorkerType::Regular`).
#[serde(default)]
pub worker_type: WorkerType,
/// Connection mode (defaults to `ConnectionMode::Http`).
#[serde(default)]
pub connection_mode: ConnectionMode,
/// Runtime type; also accepted under the key `runtime` on input.
#[serde(default, alias = "runtime")]
pub runtime_type: RuntimeType,
/// Optional provider.
pub provider: Option<ProviderType>,
/// Free-form string labels; omitted from output when empty.
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub labels: HashMap<String, String>,
/// Priority (defaults to `DEFAULT_WORKER_PRIORITY`).
#[serde(default = "default_priority")]
pub priority: u32,
/// Cost (defaults to `DEFAULT_WORKER_COST`).
#[serde(default = "default_cost")]
pub cost: f32,
/// API key; accepted on input but never serialized.
#[serde(default, skip_serializing)]
pub api_key: Option<String>,
/// Optional bootstrap port.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub bootstrap_port: Option<u16>,
/// Bootstrap host; skipped by serde in both directions, so it starts out empty
/// after deserialization.
#[serde(default, skip)]
pub bootstrap_host: String,
/// Optional `dp_base_url` — presumably the base URL shared by data-parallel ranks;
/// confirm against the caller.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub dp_base_url: Option<String>,
/// Optional `dp_rank` — presumably this worker's data-parallel rank; confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub dp_rank: Option<usize>,
/// Optional `dp_size` — presumably the data-parallel group size; confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub dp_size: Option<usize>,
/// Optional KV connector name.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub kv_connector: Option<String>,
/// Optional KV role.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub kv_role: Option<String>,
/// Optional KV block size.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub kv_block_size: Option<usize>,
/// Partial health check overrides; omitted from output when nothing is set.
#[serde(default, skip_serializing_if = "HealthCheckUpdate::is_empty")]
pub health: HealthCheckUpdate,
/// HTTP pool overrides; omitted from output when nothing is set.
#[serde(default, skip_serializing_if = "HttpPoolConfig::is_empty")]
pub http_pool: HttpPoolConfig,
/// Resilience overrides; omitted from output when nothing is set.
#[serde(default, skip_serializing_if = "ResilienceUpdate::is_empty")]
pub resilience: ResilienceUpdate,
/// Maximum number of connection attempts (default 20).
#[serde(default = "default_max_connection_attempts")]
pub max_connection_attempts: u32,
/// Optional load monitor interval in seconds.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub load_monitor_interval_secs: Option<u64>,
}
impl WorkerSpec {
pub fn new(url: impl Into<String>) -> Self {
Self {
url: url.into(),
models: WorkerModels::Wildcard,
worker_type: WorkerType::default(),
connection_mode: ConnectionMode::default(),
runtime_type: RuntimeType::default(),
provider: None,
labels: HashMap::new(),
priority: DEFAULT_WORKER_PRIORITY,
cost: DEFAULT_WORKER_COST,
api_key: None,
bootstrap_port: None,
bootstrap_host: String::new(),
dp_base_url: None,
dp_rank: None,
dp_size: None,
kv_connector: None,
kv_role: None,
kv_block_size: None,
health: HealthCheckUpdate::default(),
http_pool: HttpPoolConfig::default(),
resilience: ResilienceUpdate::default(),
max_connection_attempts: default_max_connection_attempts(),
load_monitor_interval_secs: None,
}
}
}
/// A worker's spec together with its id and runtime state.
///
/// The spec is flattened, so its fields appear at the same level as `id` in the
/// serialized form. `None` options are omitted on output.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct WorkerInfo {
/// Worker identifier.
pub id: String,
/// Optional model id.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub model_id: Option<String>,
/// The worker's spec (flattened into this object).
#[serde(flatten)]
pub spec: WorkerSpec,
/// Whether the worker is currently considered healthy.
pub is_healthy: bool,
/// Current load figure for the worker.
pub load: usize,
/// Status of a job associated with this worker, if any.
pub job_status: Option<JobStatus>,
}
impl WorkerInfo {
    /// Builds a placeholder entry for a worker known only by id and URL: default spec,
    /// no model id, unhealthy, zero load, carrying the given `job_status`.
    pub fn pending(worker_id: &str, url: String, job_status: Option<JobStatus>) -> Self {
        let spec = WorkerSpec::new(url);
        let id = worker_id.to_owned();
        Self {
            id,
            model_id: None,
            spec,
            is_healthy: false,
            load: 0,
            job_status,
        }
    }
}
/// Snapshot of a job's state for a given worker URL.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct JobStatus {
/// Kind of job.
pub job_type: String,
/// URL of the worker the job concerns.
pub worker_url: String,
/// Status label; the constructors here use `"pending"`, `"processing"` and `"failed"`.
pub status: String,
/// Optional message; the `failed` constructor stores the error text here.
pub message: Option<String>,
/// Creation time as seconds since the Unix epoch.
pub timestamp: u64,
}
impl JobStatus {
    /// Creates a `"pending"` status with no message, stamped with the current time.
    pub fn pending(job_type: &str, worker_url: &str) -> Self {
        Self::stamped_now(job_type, worker_url, "pending", None)
    }

    /// Creates a `"processing"` status with no message, stamped with the current time.
    pub fn processing(job_type: &str, worker_url: &str) -> Self {
        Self::stamped_now(job_type, worker_url, "processing", None)
    }

    /// Creates a `"failed"` status carrying `error` as its message, stamped with the
    /// current time.
    pub fn failed(job_type: &str, worker_url: &str, error: String) -> Self {
        Self::stamped_now(job_type, worker_url, "failed", Some(error))
    }

    /// Shared constructor: fills in the fields and the current Unix time in seconds.
    ///
    /// A system clock set before the epoch yields a timestamp of 0 rather than an error.
    fn stamped_now(
        job_type: &str,
        worker_url: &str,
        status: &str,
        message: Option<String>,
    ) -> Self {
        let timestamp = std::time::SystemTime::now()
            .duration_since(std::time::SystemTime::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        Self {
            job_type: job_type.to_owned(),
            worker_url: worker_url.to_owned(),
            status: status.to_owned(),
            message,
            timestamp,
        }
    }
}
/// Response body listing workers together with aggregate statistics.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct WorkerListResponse {
/// The listed workers.
pub workers: Vec<WorkerInfo>,
/// Total count reported alongside the list.
pub total: usize,
/// Aggregate statistics.
pub stats: WorkerStats,
}
/// Aggregate counters over a set of workers.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct WorkerStats {
/// Number of workers.
pub total_workers: usize,
/// Number of healthy workers.
pub healthy_workers: usize,
/// Number of models.
pub total_models: usize,
/// Sum of worker loads.
pub total_load: usize,
/// Worker counts broken down by worker type.
pub by_type: WorkerTypeStats,
}
/// Worker counts per worker type; fields mirror the `WorkerType` variants.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct WorkerTypeStats {
/// Count of regular workers.
pub regular: usize,
/// Count of prefill workers.
pub prefill: usize,
/// Count of decode workers.
pub decode: usize,
}
/// Partial health check settings: each `Some` field overrides the corresponding
/// `HealthCheckConfig` field, each `None` leaves it unchanged (see `apply_to`).
/// `None` fields are omitted when serializing.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Default, Serialize, Deserialize, schemars::JsonSchema)]
pub struct HealthCheckUpdate {
/// Override for `HealthCheckConfig::timeout_secs`.
pub timeout_secs: Option<u64>,
/// Override for `HealthCheckConfig::check_interval_secs`.
pub check_interval_secs: Option<u64>,
/// Override for `HealthCheckConfig::success_threshold`.
pub success_threshold: Option<u32>,
/// Override for `HealthCheckConfig::failure_threshold`.
pub failure_threshold: Option<u32>,
/// Override for `HealthCheckConfig::disable_health_check`.
pub disable_health_check: Option<bool>,
}
impl HealthCheckUpdate {
    /// Returns `true` when no field carries an override.
    pub fn is_empty(&self) -> bool {
        // Destructuring without `..` makes a newly added field a compile error here.
        let Self {
            timeout_secs,
            check_interval_secs,
            success_threshold,
            failure_threshold,
            disable_health_check,
        } = self;
        [
            timeout_secs.is_none(),
            check_interval_secs.is_none(),
            success_threshold.is_none(),
            failure_threshold.is_none(),
            disable_health_check.is_none(),
        ]
        .iter()
        .all(|&unset| unset)
    }

    /// Returns a copy of `existing` with every `Some` field of this update applied;
    /// fields left as `None` keep the value from `existing`.
    pub fn apply_to(&self, existing: &HealthCheckConfig) -> HealthCheckConfig {
        let timeout_secs = self.timeout_secs.unwrap_or(existing.timeout_secs);
        let check_interval_secs = self
            .check_interval_secs
            .unwrap_or(existing.check_interval_secs);
        let success_threshold = self.success_threshold.unwrap_or(existing.success_threshold);
        let failure_threshold = self.failure_threshold.unwrap_or(existing.failure_threshold);
        let disable_health_check = self
            .disable_health_check
            .unwrap_or(existing.disable_health_check);
        HealthCheckConfig {
            timeout_secs,
            check_interval_secs,
            success_threshold,
            failure_threshold,
            disable_health_check,
        }
    }
}
/// Optional HTTP client pool settings; `None` fields are omitted when serializing.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Default, Serialize, Deserialize, schemars::JsonSchema)]
pub struct HttpPoolConfig {
/// Maximum idle connections kept per host.
pub pool_max_idle_per_host: Option<usize>,
/// Idle timeout for pooled connections, in seconds.
pub pool_idle_timeout_secs: Option<u64>,
/// Timeout in seconds.
pub timeout_secs: Option<u64>,
/// Connect timeout in seconds.
pub connect_timeout_secs: Option<u64>,
}
impl HttpPoolConfig {
    /// Returns `true` when no pool setting is specified.
    pub fn is_empty(&self) -> bool {
        // Destructuring without `..` makes a newly added field a compile error here.
        let Self {
            pool_max_idle_per_host,
            pool_idle_timeout_secs,
            timeout_secs,
            connect_timeout_secs,
        } = self;
        [
            pool_max_idle_per_host.is_none(),
            pool_idle_timeout_secs.is_none(),
            timeout_secs.is_none(),
            connect_timeout_secs.is_none(),
        ]
        .iter()
        .all(|&unset| unset)
    }
}
/// Optional retry and circuit-breaker overrides; `None` fields are omitted when
/// serializing.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Default, Serialize, Deserialize, schemars::JsonSchema)]
pub struct ResilienceUpdate {
/// Maximum number of retries.
pub max_retries: Option<u32>,
/// Initial backoff in milliseconds.
pub initial_backoff_ms: Option<u64>,
/// Maximum backoff in milliseconds.
pub max_backoff_ms: Option<u64>,
/// Backoff multiplier.
pub backoff_multiplier: Option<f32>,
/// Jitter factor.
pub jitter_factor: Option<f32>,
/// Flag for turning retries off.
pub disable_retry: Option<bool>,
/// Circuit breaker failure threshold (`cb_` presumably stands for circuit breaker,
/// given `disable_circuit_breaker` below).
pub cb_failure_threshold: Option<u32>,
/// Circuit breaker success threshold.
pub cb_success_threshold: Option<u32>,
/// Circuit breaker timeout in seconds.
pub cb_timeout_secs: Option<u64>,
/// Circuit breaker window in seconds.
pub cb_window_secs: Option<u64>,
/// Flag for turning the circuit breaker off.
pub disable_circuit_breaker: Option<bool>,
/// Status codes treated as retryable.
pub retryable_status_codes: Option<Vec<u16>>,
}
impl ResilienceUpdate {
    /// Returns `true` when no override is specified. A `Some` holding an empty list
    /// still counts as specified.
    pub fn is_empty(&self) -> bool {
        // Destructuring without `..` makes a newly added field a compile error here.
        let Self {
            max_retries,
            initial_backoff_ms,
            max_backoff_ms,
            backoff_multiplier,
            jitter_factor,
            disable_retry,
            cb_failure_threshold,
            cb_success_threshold,
            cb_timeout_secs,
            cb_window_secs,
            disable_circuit_breaker,
            retryable_status_codes,
        } = self;
        [
            max_retries.is_none(),
            initial_backoff_ms.is_none(),
            max_backoff_ms.is_none(),
            backoff_multiplier.is_none(),
            jitter_factor.is_none(),
            disable_retry.is_none(),
            cb_failure_threshold.is_none(),
            cb_success_threshold.is_none(),
            cb_timeout_secs.is_none(),
            cb_window_secs.is_none(),
            disable_circuit_breaker.is_none(),
            retryable_status_codes.is_none(),
        ]
        .iter()
        .all(|&unset| unset)
    }
}
/// Request body for updating a worker; every field is optional and `None` fields are
/// omitted when serializing.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct WorkerUpdateRequest {
/// New priority, if any.
pub priority: Option<u32>,
/// New cost, if any.
pub cost: Option<f32>,
/// New labels, if any.
pub labels: Option<HashMap<String, String>>,
/// New API key, if any.
pub api_key: Option<String>,
/// Health check overrides, if any.
pub health: Option<HealthCheckUpdate>,
}
/// Generic response for worker API operations.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct WorkerApiResponse {
/// Whether the operation succeeded.
pub success: bool,
/// Human-readable message.
pub message: String,
/// The affected worker, when available; omitted from output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub worker: Option<WorkerInfo>,
}
/// Error response for worker API operations.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct WorkerErrorResponse {
/// Error description.
pub error: String,
/// Error code string.
pub code: String,
}
/// Outcome of a cache flush across workers.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FlushCacheResult {
/// Workers whose flush succeeded.
pub successful: Vec<String>,
/// Workers whose flush failed, as `(worker, error)` pairs.
pub failed: Vec<(String, String)>,
/// Total number of workers.
pub total_workers: usize,
/// Number of HTTP workers.
pub http_workers: usize,
/// Summary message.
pub message: String,
}
/// Outcome of collecting load information from workers.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct WorkerLoadsResult {
/// Per-worker load entries.
pub loads: Vec<WorkerLoadInfo>,
/// Total number of workers.
pub total_workers: usize,
/// Number of successful load queries.
pub successful: usize,
/// Number of failed load queries.
pub failed: usize,
}
/// Load metrics for one `dp_rank` as reported by a worker.
///
/// With the container-level `#[serde(default)]`, any field missing from the input takes
/// its `Default` value (zero).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct SchedulerLoadSnapshot {
/// Rank this snapshot belongs to — presumably the data-parallel rank; confirm.
pub dp_rank: i32,
/// Number of running requests.
pub num_running_reqs: i32,
/// Number of waiting requests.
pub num_waiting_reqs: i32,
/// Total number of requests.
pub num_total_reqs: i32,
/// Number of tokens in use.
pub num_used_tokens: i32,
/// Maximum total number of tokens.
pub max_total_num_tokens: i32,
/// Token usage figure — presumably a 0..1 fraction; confirm against the producer.
pub token_usage: f64,
/// Generation throughput.
pub gen_throughput: f64,
/// Cache hit rate.
pub cache_hit_rate: f64,
/// Utilization figure.
pub utilization: f64,
/// Maximum number of running requests.
pub max_running_requests: i32,
}
/// Load report from a worker: one snapshot per entry in `loads`.
///
/// Missing fields deserialize to their `Default` values.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct WorkerLoadResponse {
/// Timestamp string as provided by the worker.
pub timestamp: String,
/// Reported number of dp ranks.
pub dp_rank_count: i32,
/// Per-rank load snapshots.
pub loads: Vec<SchedulerLoadSnapshot>,
}
impl WorkerLoadResponse {
    /// Mean of `token_usage` across all snapshots, or `0.0` when there are none.
    pub fn effective_token_usage(&self) -> f64 {
        let count = self.loads.len();
        if count == 0 {
            return 0.0;
        }
        let total: f64 = self.loads.iter().map(|snapshot| snapshot.token_usage).sum();
        total / count as f64
    }

    /// Sum of `num_used_tokens` across all snapshots, widened to `i64`.
    pub fn total_used_tokens(&self) -> i64 {
        self.loads
            .iter()
            .map(|snapshot| snapshot.num_used_tokens as i64)
            .sum()
    }

    /// Maps each `dp_rank` to its `num_used_tokens`; if a rank appears more than once,
    /// the last snapshot wins.
    pub fn dp_rank_loads(&self) -> HashMap<isize, isize> {
        self.loads
            .iter()
            .map(|snapshot| (snapshot.dp_rank as isize, snapshot.num_used_tokens as isize))
            .collect()
    }
}
/// Load entry for a single worker.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct WorkerLoadInfo {
/// Worker identifier string — presumably its URL; confirm against the producer.
pub worker: String,
/// Worker type as a string, when known; omitted from output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub worker_type: Option<String>,
/// Load value (signed).
pub load: isize,
/// Detailed load report, when available; omitted from output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub details: Option<WorkerLoadResponse>,
}
/// Converts the flush outcome into a JSON HTTP response.
///
/// With no failures the status is `200 OK` and `"status"` is `"success"`. Otherwise the
/// status is `206 Partial Content`, `"status"` is `"partial_success"`, and the body also
/// carries the `successful` list and a `failed` list of `{worker, error}` objects.
#[cfg(feature = "axum")]
impl IntoResponse for FlushCacheResult {
    fn into_response(self) -> Response {
        let Self {
            successful,
            failed,
            total_workers,
            http_workers,
            message,
        } = self;
        let all_flushed = failed.is_empty();
        let (status, label) = if all_flushed {
            (StatusCode::OK, "success")
        } else {
            (StatusCode::PARTIAL_CONTENT, "partial_success")
        };
        let mut body = json!({
            "status": label,
            "message": message,
            "workers_flushed": successful.len(),
            "total_http_workers": http_workers,
            "total_workers": total_workers
        });
        if !all_flushed {
            let failures: Vec<_> = failed
                .into_iter()
                .map(|(url, err)| json!({"worker": url, "error": err}))
                .collect();
            body["successful"] = json!(successful);
            body["failed"] = json!(failures);
        }
        (status, Json(body)).into_response()
    }
}
/// Converts the collected loads into a JSON response of the form
/// `{"workers": [{"worker": …, "load": …, "details": …}, …]}`.
///
/// `details` is present only for entries that have it; the summary counters and
/// `worker_type` are not included in the body.
#[cfg(feature = "axum")]
impl IntoResponse for WorkerLoadsResult {
    fn into_response(self) -> Response {
        let mut workers = Vec::with_capacity(self.loads.len());
        for info in &self.loads {
            let mut entry = json!({"worker": &info.worker, "load": info.load});
            if let Some(details) = &info.details {
                entry["details"] = json!(details);
            }
            workers.push(entry);
        }
        Json(json!({"workers": workers})).into_response()
    }
}