Skip to main content

openai_protocol/
worker.rs

1//! Canonical worker types and identity.
2//!
3//! This module defines the single source of truth for worker identity, type
4//! enums, and core configuration. These types are shared across API
5//! request/response boundaries and internal runtime state.
6
7use std::collections::HashMap;
8
9#[cfg(feature = "axum")]
10use axum::{
11    http::StatusCode,
12    response::{IntoResponse, Response},
13    Json,
14};
15use schemars::JsonSchema;
16use serde::{Deserialize, Deserializer, Serialize, Serializer};
17#[cfg(feature = "axum")]
18use serde_json::{json, Value};
19
20use super::model_card::ModelCard;
21
22// ── Default value constants ──────────────────────────────────────────
23
/// Routing priority given to a worker that does not configure one
/// (higher = preferred).
pub const DEFAULT_WORKER_PRIORITY: u32 = 50;

/// Relative cost factor given to a worker that does not configure one
/// (baseline = 1.0).
pub const DEFAULT_WORKER_COST: f32 = 1.0;
26
27// ── Enums ────────────────────────────────────────────────────────────
28
29/// Worker type classification.
30#[derive(
31    Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default, schemars::JsonSchema,
32)]
33#[serde(rename_all = "lowercase")]
34pub enum WorkerType {
35    /// Regular worker for standard routing.
36    #[default]
37    Regular,
38    /// Prefill worker for PD disaggregated mode.
39    Prefill,
40    /// Decode worker for PD disaggregated mode.
41    Decode,
42}
43
44impl std::fmt::Display for WorkerType {
45    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46        match self {
47            WorkerType::Regular => write!(f, "regular"),
48            WorkerType::Prefill => write!(f, "prefill"),
49            WorkerType::Decode => write!(f, "decode"),
50        }
51    }
52}
53
54impl std::str::FromStr for WorkerType {
55    type Err = String;
56
57    fn from_str(s: &str) -> Result<Self, Self::Err> {
58        if s.eq_ignore_ascii_case("regular") {
59            Ok(WorkerType::Regular)
60        } else if s.eq_ignore_ascii_case("prefill") {
61            Ok(WorkerType::Prefill)
62        } else if s.eq_ignore_ascii_case("decode") {
63            Ok(WorkerType::Decode)
64        } else {
65            Err(format!("Unknown worker type: {s}"))
66        }
67    }
68}
69
70/// Connection mode for worker communication.
71#[derive(
72    Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default, schemars::JsonSchema,
73)]
74#[serde(rename_all = "lowercase")]
75pub enum ConnectionMode {
76    /// HTTP/REST connection.
77    #[default]
78    Http,
79    /// gRPC connection.
80    Grpc,
81}
82
83impl std::fmt::Display for ConnectionMode {
84    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85        match self {
86            ConnectionMode::Http => write!(f, "http"),
87            ConnectionMode::Grpc => write!(f, "grpc"),
88        }
89    }
90}
91
92/// Composite key identifying a group of workers with the same characteristics.
93///
94/// Groups workers by `(model_id, worker_type, connection_mode)` — the natural
95/// partitioning used for metrics, load monitoring, and policy management.
96#[derive(Debug, Clone, PartialEq, Eq, Hash)]
97pub struct WorkerGroupKey {
98    pub model_id: String,
99    pub worker_type: WorkerType,
100    pub connection_mode: ConnectionMode,
101}
102
103impl std::fmt::Display for WorkerGroupKey {
104    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105        write!(
106            f,
107            "{}:{}:{}",
108            self.model_id, self.worker_type, self.connection_mode
109        )
110    }
111}
112
113/// Runtime implementation type for workers.
114#[derive(
115    Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default, schemars::JsonSchema,
116)]
117#[serde(rename_all = "lowercase")]
118pub enum RuntimeType {
119    /// No runtime type specified — the backend will be auto-detected.
120    #[default]
121    Unspecified,
122    /// SGLang runtime.
123    Sglang,
124    /// vLLM runtime.
125    Vllm,
126    /// TensorRT-LLM runtime.
127    Trtllm,
128    /// External OpenAI-compatible API (not local inference).
129    External,
130}
131
132impl RuntimeType {
133    /// Returns `true` when the caller supplied an explicit runtime type.
134    pub fn is_specified(self) -> bool {
135        !matches!(self, RuntimeType::Unspecified)
136    }
137}
138
139impl std::fmt::Display for RuntimeType {
140    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141        match self {
142            RuntimeType::Unspecified => write!(f, "unspecified"),
143            RuntimeType::Sglang => write!(f, "sglang"),
144            RuntimeType::Vllm => write!(f, "vllm"),
145            RuntimeType::Trtllm => write!(f, "trtllm"),
146            RuntimeType::External => write!(f, "external"),
147        }
148    }
149}
150
151impl std::str::FromStr for RuntimeType {
152    type Err = String;
153
154    fn from_str(s: &str) -> Result<Self, Self::Err> {
155        if s.eq_ignore_ascii_case("unspecified") {
156            Ok(RuntimeType::Unspecified)
157        } else if s.eq_ignore_ascii_case("sglang") {
158            Ok(RuntimeType::Sglang)
159        } else if s.eq_ignore_ascii_case("vllm") {
160            Ok(RuntimeType::Vllm)
161        } else if s.eq_ignore_ascii_case("trtllm") || s.eq_ignore_ascii_case("tensorrt-llm") {
162            Ok(RuntimeType::Trtllm)
163        } else if s.eq_ignore_ascii_case("external") {
164            Ok(RuntimeType::External)
165        } else {
166            Err(format!("Unknown runtime type: {s}"))
167        }
168    }
169}
170
171/// Provider type for external API transformations.
172///
173/// Different providers have different API formats and requirements.
174/// `None` (when used as `Option<ProviderType>`) means native/passthrough —
175/// no transformation needed (local SGLang backends).
176#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, schemars::JsonSchema)]
177#[serde(rename_all = "lowercase")]
178pub enum ProviderType {
179    /// OpenAI API — strip SGLang-specific fields.
180    #[serde(alias = "openai")]
181    OpenAI,
182    /// xAI/Grok — special handling for input items.
183    #[serde(alias = "xai", alias = "grok")]
184    #[expect(
185        clippy::upper_case_acronyms,
186        reason = "xAI is a proper company name; XAI matches industry convention and existing serde aliases"
187    )]
188    XAI,
189    /// Anthropic Claude — different API format.
190    #[serde(alias = "anthropic", alias = "claude")]
191    Anthropic,
192    /// Google Gemini — special logprobs handling.
193    #[serde(alias = "gemini", alias = "google")]
194    Gemini,
195    /// Custom provider with string identifier.
196    #[serde(untagged)]
197    Custom(String),
198}
199
200impl ProviderType {
201    /// Get provider name as string.
202    pub fn as_str(&self) -> &str {
203        match self {
204            Self::OpenAI => "openai",
205            Self::XAI => "xai",
206            Self::Anthropic => "anthropic",
207            Self::Gemini => "gemini",
208            Self::Custom(s) => s.as_str(),
209        }
210    }
211
212    /// Detect provider from URL host.
213    /// Returns `None` for URLs that don't match known providers or can't be parsed.
214    pub fn from_url(url: &str) -> Option<Self> {
215        let host = url::Url::parse(url).ok()?.host_str()?.to_lowercase();
216
217        if host.ends_with("openai.com") {
218            Some(Self::OpenAI)
219        } else if host.ends_with("x.ai") {
220            Some(Self::XAI)
221        } else if host.ends_with("anthropic.com") {
222            Some(Self::Anthropic)
223        } else if host.ends_with("googleapis.com") {
224            Some(Self::Gemini)
225        } else {
226            None
227        }
228    }
229
230    /// Environment variable name for per-provider admin API key (model discovery).
231    /// Returns `None` for `Custom` providers since there's no known env var.
232    pub fn admin_key_env_var(&self) -> Option<&'static str> {
233        match self {
234            Self::OpenAI => Some("OPENAI_ADMIN_KEY"),
235            Self::XAI => Some("XAI_ADMIN_KEY"),
236            Self::Anthropic => Some("ANTHROPIC_ADMIN_KEY"),
237            Self::Gemini => Some("GEMINI_ADMIN_KEY"),
238            Self::Custom(_) => None,
239        }
240    }
241
242    /// Whether this provider uses `x-api-key` header instead of `Authorization: Bearer`.
243    pub fn uses_x_api_key(&self) -> bool {
244        matches!(self, Self::Anthropic)
245    }
246
247    /// Detect provider from model name (heuristic fallback).
248    /// Returns `None` for models that don't match known external providers.
249    pub fn from_model_name(model: &str) -> Option<Self> {
250        let model_lower = model.to_lowercase();
251        if model_lower.starts_with("grok") {
252            Some(Self::XAI)
253        } else if model_lower.starts_with("gemini") {
254            Some(Self::Gemini)
255        } else if model_lower.starts_with("claude") {
256            Some(Self::Anthropic)
257        } else if model_lower.starts_with("gpt")
258            || model_lower.starts_with("o1")
259            || model_lower.starts_with("o3")
260        {
261            Some(Self::OpenAI)
262        } else {
263            None
264        }
265    }
266}
267
268impl std::fmt::Display for ProviderType {
269    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
270        write!(f, "{}", self.as_str())
271    }
272}
273
274// ── Serde default helpers ────────────────────────────────────────────
275
276fn default_priority() -> u32 {
277    DEFAULT_WORKER_PRIORITY
278}
279
280fn default_cost() -> f32 {
281    DEFAULT_WORKER_COST
282}
283
284fn default_health_check_timeout() -> u64 {
285    30
286}
287
288fn default_health_check_interval() -> u64 {
289    60
290}
291
292fn default_health_success_threshold() -> u32 {
293    2
294}
295
296fn default_health_failure_threshold() -> u32 {
297    3
298}
299
300fn default_max_connection_attempts() -> u32 {
301    20
302}
303
304// ── Health check config ─────────────────────────────────────────────
305
306/// Health check configuration shared across protocol and runtime layers.
307#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
308pub struct HealthCheckConfig {
309    /// Health check timeout in seconds (default: 30).
310    #[serde(default = "default_health_check_timeout")]
311    pub timeout_secs: u64,
312
313    /// Health check interval in seconds (default: 60).
314    #[serde(default = "default_health_check_interval")]
315    pub check_interval_secs: u64,
316
317    /// Number of successful health checks needed to mark worker as healthy (default: 2).
318    #[serde(default = "default_health_success_threshold")]
319    pub success_threshold: u32,
320
321    /// Number of failed health checks before marking worker as unhealthy (default: 3).
322    #[serde(default = "default_health_failure_threshold")]
323    pub failure_threshold: u32,
324
325    /// Disable periodic health checks for this worker (default: false).
326    #[serde(default)]
327    pub disable_health_check: bool,
328}
329
330impl Default for HealthCheckConfig {
331    fn default() -> Self {
332        Self {
333            timeout_secs: default_health_check_timeout(),
334            check_interval_secs: default_health_check_interval(),
335            success_threshold: default_health_success_threshold(),
336            failure_threshold: default_health_failure_threshold(),
337            disable_health_check: false,
338        }
339    }
340}
341
342// ── Worker models ───────────────────────────────────────────────────
343
344/// Models configuration for a worker.
345///
346/// Encodes the three real cases instead of relying on `Vec` semantics:
347/// - `Wildcard` — accepts any model (empty models list on the wire)
348/// - `Single` — serves exactly one model
349/// - `Multi` — serves multiple distinct models (len >= 2)
350#[derive(Debug, Clone, Default)]
351pub enum WorkerModels {
352    /// Worker accepts any model (e.g., external API without discovery).
353    #[default]
354    Wildcard,
355    /// Worker serves exactly one model (most common for local inference).
356    Single(Box<ModelCard>),
357    /// Worker serves multiple distinct models (len >= 2).
358    Multi(Vec<ModelCard>),
359}
360
361impl WorkerModels {
362    /// Returns `true` if this is a wildcard (accepts any model).
363    pub fn is_wildcard(&self) -> bool {
364        matches!(self, Self::Wildcard)
365    }
366
367    /// Returns the primary model: `Single` → `Some`, `Multi` → first, `Wildcard` → `None`.
368    pub fn primary(&self) -> Option<&ModelCard> {
369        match self {
370            Self::Wildcard => None,
371            Self::Single(card) => Some(card.as_ref()),
372            Self::Multi(cards) => cards.first(),
373        }
374    }
375
376    /// Returns all models as a slice (empty for `Wildcard`).
377    pub fn all(&self) -> &[ModelCard] {
378        match self {
379            Self::Wildcard => &[],
380            Self::Single(card) => std::slice::from_ref(card.as_ref()),
381            Self::Multi(cards) => cards,
382        }
383    }
384
385    /// Find a model by ID (checks aliases via `ModelCard::matches`).
386    pub fn find(&self, id: &str) -> Option<&ModelCard> {
387        match self {
388            Self::Wildcard => None,
389            Self::Single(card) => card.matches(id).then_some(card.as_ref()),
390            Self::Multi(cards) => cards.iter().find(|m| m.matches(id)),
391        }
392    }
393
394    /// Returns `true` if the worker supports the given model ID.
395    /// Wildcard workers always return `true`.
396    pub fn supports(&self, id: &str) -> bool {
397        match self {
398            Self::Wildcard => true,
399            _ => self.find(id).is_some(),
400        }
401    }
402
403    /// Iterate over all models. Empty iterator for `Wildcard`.
404    pub fn iter(&self) -> impl Iterator<Item = &ModelCard> {
405        self.all().iter()
406    }
407}
408
409impl From<Vec<ModelCard>> for WorkerModels {
410    fn from(models: Vec<ModelCard>) -> Self {
411        match models.len() {
412            0 => Self::Wildcard,
413            1 => {
414                let Some(model) = models.into_iter().next() else {
415                    return Self::Wildcard;
416                };
417                Self::Single(Box::new(model))
418            }
419            _ => Self::Multi(models),
420        }
421    }
422}
423
424/// Serialize as `Vec<ModelCard>` for wire compatibility.
425impl Serialize for WorkerModels {
426    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
427        self.all().serialize(serializer)
428    }
429}
430
431/// Deserialize from `Vec<ModelCard>` for wire compatibility.
432impl<'de> Deserialize<'de> for WorkerModels {
433    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
434        let models = Vec::<ModelCard>::deserialize(deserializer)?;
435        Ok(Self::from(models))
436    }
437}
438
439/// JsonSchema: wire format is `Vec<ModelCard>`.
440impl JsonSchema for WorkerModels {
441    fn schema_name() -> String {
442        "WorkerModels".to_string()
443    }
444
445    fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema {
446        Vec::<ModelCard>::json_schema(gen)
447    }
448}
449
450// ── Core identity ────────────────────────────────────────────────────
451
452/// Core worker identity and configuration.
453///
454/// The single canonical representation of "what is a worker". Used as the
455/// shared sub-struct across API requests, API responses, and internal runtime
456/// state via `#[serde(flatten)]`.
457///
458/// Fields use `#[serde(default)]` so the same struct works for both input
459/// (partial config from user) and output (fully resolved state).
460#[serde_with::skip_serializing_none]
461#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
462pub struct WorkerSpec {
463    /// Worker URL.
464    pub url: String,
465
466    /// Models this worker can serve.
467    #[serde(default, skip_serializing_if = "WorkerModels::is_wildcard")]
468    pub models: WorkerModels,
469
470    /// Worker type: regular, prefill, or decode.
471    #[serde(default)]
472    pub worker_type: WorkerType,
473
474    /// Connection mode: http or grpc.
475    #[serde(default)]
476    pub connection_mode: ConnectionMode,
477
478    /// Runtime type: sglang, vllm, trtllm, or external.
479    #[serde(default, alias = "runtime")]
480    pub runtime_type: RuntimeType,
481
482    /// External provider for API transformations.
483    /// `None` means native/passthrough.
484    pub provider: Option<ProviderType>,
485
486    /// Additional labels/tags.
487    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
488    pub labels: HashMap<String, String>,
489
490    /// Worker priority (higher = preferred).
491    #[serde(default = "default_priority")]
492    pub priority: u32,
493
494    /// Worker cost factor (baseline = 1.0).
495    #[serde(default = "default_cost")]
496    pub cost: f32,
497
498    /// Worker API key. Accepted on input, never included in responses.
499    #[serde(default, skip_serializing)]
500    pub api_key: Option<String>,
501
502    /// Bootstrap port for prefill workers in PD disaggregated mode.
503    #[serde(default, skip_serializing_if = "Option::is_none")]
504    pub bootstrap_port: Option<u16>,
505
506    /// Bootstrap hostname (derived from URL at construction time).
507    #[serde(default, skip)]
508    pub bootstrap_host: String,
509
510    /// Base URL without DP rank suffix (for DP-aware workers).
511    /// When set, `url` contains the rank-suffixed form (`{base}@{rank}`).
512    #[serde(default, skip_serializing_if = "Option::is_none")]
513    pub dp_base_url: Option<String>,
514
515    /// Data-parallel rank (None = not DP-aware).
516    #[serde(default, skip_serializing_if = "Option::is_none")]
517    pub dp_rank: Option<usize>,
518
519    /// Total data-parallel group size (None = not DP-aware).
520    #[serde(default, skip_serializing_if = "Option::is_none")]
521    pub dp_size: Option<usize>,
522
523    /// KV connector type (e.g. "MooncakeConnector", "NixlConnector").
524    #[serde(default, skip_serializing_if = "Option::is_none")]
525    pub kv_connector: Option<String>,
526
527    /// KV role (e.g. "kv_producer", "kv_consumer", "kv_both").
528    #[serde(default, skip_serializing_if = "Option::is_none")]
529    pub kv_role: Option<String>,
530
531    /// KV cache block size (tokens per block) for event-driven routing.
532    /// When set, overrides the router-level default for this worker's model.
533    /// Typically matches the backend engine's page size (e.g. 16 for SGLang).
534    #[serde(default, skip_serializing_if = "Option::is_none")]
535    pub kv_block_size: Option<usize>,
536
537    /// Per-worker health check overrides (partial — only `Some` fields override router defaults).
538    #[serde(default, skip_serializing_if = "HealthCheckUpdate::is_empty")]
539    pub health: HealthCheckUpdate,
540
541    /// Per-worker HTTP connection pool overrides.
542    #[serde(default, skip_serializing_if = "HttpPoolConfig::is_empty")]
543    pub http_pool: HttpPoolConfig,
544
545    /// Per-worker resilience overrides (retry + circuit breaker).
546    #[serde(default, skip_serializing_if = "ResilienceUpdate::is_empty")]
547    pub resilience: ResilienceUpdate,
548
549    /// Maximum connection attempts during worker registration (default: 20).
550    #[serde(default = "default_max_connection_attempts")]
551    pub max_connection_attempts: u32,
552
553    /// Per-worker load monitor interval override (seconds).
554    /// When set, workers in the same group use this interval for load polling.
555    /// Falls back to the global `load_monitor_interval_secs` from router config.
556    #[serde(default, skip_serializing_if = "Option::is_none")]
557    pub load_monitor_interval_secs: Option<u64>,
558}
559
560impl WorkerSpec {
561    /// Create a new `WorkerSpec` with the given URL and sensible defaults.
562    pub fn new(url: impl Into<String>) -> Self {
563        Self {
564            url: url.into(),
565            models: WorkerModels::Wildcard,
566            worker_type: WorkerType::default(),
567            connection_mode: ConnectionMode::default(),
568            runtime_type: RuntimeType::default(),
569            provider: None,
570            labels: HashMap::new(),
571            priority: DEFAULT_WORKER_PRIORITY,
572            cost: DEFAULT_WORKER_COST,
573            api_key: None,
574            bootstrap_port: None,
575            bootstrap_host: String::new(),
576            dp_base_url: None,
577            dp_rank: None,
578            dp_size: None,
579            kv_connector: None,
580            kv_role: None,
581            kv_block_size: None,
582            health: HealthCheckUpdate::default(),
583            http_pool: HttpPoolConfig::default(),
584            resilience: ResilienceUpdate::default(),
585            max_connection_attempts: default_max_connection_attempts(),
586            load_monitor_interval_secs: None,
587        }
588    }
589}
590
591// ── API types ───────────────────────────────────────────────────────
592
593/// Worker information for API responses.
594#[serde_with::skip_serializing_none]
595#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
596pub struct WorkerInfo {
597    /// Worker unique identifier.
598    pub id: String,
599
600    /// Primary model ID for backwards compatibility.
601    /// Computed from `models[0].id` (single/multi) or `null` (wildcard).
602    #[serde(default, skip_serializing_if = "Option::is_none")]
603    pub model_id: Option<String>,
604
605    /// Worker identity and configuration.
606    #[serde(flatten)]
607    pub spec: WorkerSpec,
608
609    /// Whether the worker is healthy.
610    pub is_healthy: bool,
611
612    /// Current load on the worker.
613    pub load: usize,
614
615    /// Job status for async operations (if available).
616    pub job_status: Option<JobStatus>,
617}
618
619impl WorkerInfo {
620    /// Create a partial WorkerInfo for pending workers (not yet registered).
621    pub fn pending(worker_id: &str, url: String, job_status: Option<JobStatus>) -> Self {
622        Self {
623            id: worker_id.to_string(),
624            model_id: None,
625            spec: WorkerSpec::new(url),
626            is_healthy: false,
627            load: 0,
628            job_status,
629        }
630    }
631}
632
633/// Job status for async control plane operations
634#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
635pub struct JobStatus {
636    pub job_type: String,
637    pub worker_url: String,
638    pub status: String,
639    pub message: Option<String>,
640    pub timestamp: u64,
641}
642
643impl JobStatus {
644    /// Create a pending job status
645    pub fn pending(job_type: &str, worker_url: &str) -> Self {
646        Self {
647            job_type: job_type.to_string(),
648            worker_url: worker_url.to_string(),
649            status: "pending".to_string(),
650            message: None,
651            timestamp: std::time::SystemTime::now()
652                .duration_since(std::time::SystemTime::UNIX_EPOCH)
653                .unwrap_or_default()
654                .as_secs(),
655        }
656    }
657
658    /// Create a processing job status
659    pub fn processing(job_type: &str, worker_url: &str) -> Self {
660        Self {
661            job_type: job_type.to_string(),
662            worker_url: worker_url.to_string(),
663            status: "processing".to_string(),
664            message: None,
665            timestamp: std::time::SystemTime::now()
666                .duration_since(std::time::SystemTime::UNIX_EPOCH)
667                .unwrap_or_default()
668                .as_secs(),
669        }
670    }
671
672    /// Create a failed job status
673    pub fn failed(job_type: &str, worker_url: &str, error: String) -> Self {
674        Self {
675            job_type: job_type.to_string(),
676            worker_url: worker_url.to_string(),
677            status: "failed".to_string(),
678            message: Some(error),
679            timestamp: std::time::SystemTime::now()
680                .duration_since(std::time::SystemTime::UNIX_EPOCH)
681                .unwrap_or_default()
682                .as_secs(),
683        }
684    }
685}
686
687/// Worker list response
688#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
689pub struct WorkerListResponse {
690    pub workers: Vec<WorkerInfo>,
691    pub total: usize,
692    pub stats: WorkerStats,
693}
694
695/// Worker statistics
696#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
697pub struct WorkerStats {
698    pub total_workers: usize,
699    pub healthy_workers: usize,
700    pub total_models: usize,
701    pub total_load: usize,
702    pub by_type: WorkerTypeStats,
703}
704
705/// Worker statistics by type
706#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
707pub struct WorkerTypeStats {
708    pub regular: usize,
709    pub prefill: usize,
710    pub decode: usize,
711}
712
713// ── Update types ────────────────────────────────────────────────────
714
715/// Partial health check config for PATCH-style updates.
716///
717/// Each `None` field means "keep the existing value". This avoids the problem
718/// where `#[serde(default)]` on [`HealthCheckConfig`] would silently reset
719/// unspecified fields to defaults.
720#[serde_with::skip_serializing_none]
721#[derive(Debug, Clone, Default, Serialize, Deserialize, schemars::JsonSchema)]
722pub struct HealthCheckUpdate {
723    pub timeout_secs: Option<u64>,
724    pub check_interval_secs: Option<u64>,
725    pub success_threshold: Option<u32>,
726    pub failure_threshold: Option<u32>,
727    pub disable_health_check: Option<bool>,
728}
729
730impl HealthCheckUpdate {
731    /// Returns `true` if all fields are `None` (no overrides specified).
732    pub fn is_empty(&self) -> bool {
733        self.timeout_secs.is_none()
734            && self.check_interval_secs.is_none()
735            && self.success_threshold.is_none()
736            && self.failure_threshold.is_none()
737            && self.disable_health_check.is_none()
738    }
739}
740
741impl HealthCheckUpdate {
742    /// Merge this update into an existing [`HealthCheckConfig`], returning a new config.
743    /// Only `Some` fields are applied; `None` fields keep the existing value.
744    pub fn apply_to(&self, existing: &HealthCheckConfig) -> HealthCheckConfig {
745        HealthCheckConfig {
746            timeout_secs: self.timeout_secs.unwrap_or(existing.timeout_secs),
747            check_interval_secs: self
748                .check_interval_secs
749                .unwrap_or(existing.check_interval_secs),
750            success_threshold: self.success_threshold.unwrap_or(existing.success_threshold),
751            failure_threshold: self.failure_threshold.unwrap_or(existing.failure_threshold),
752            disable_health_check: self
753                .disable_health_check
754                .unwrap_or(existing.disable_health_check),
755        }
756    }
757}
758
759/// Per-worker HTTP connection pool configuration.
760/// All fields optional — `None` means "use router/global default".
761#[serde_with::skip_serializing_none]
762#[derive(Debug, Clone, Default, Serialize, Deserialize, schemars::JsonSchema)]
763pub struct HttpPoolConfig {
764    /// Max idle connections per host (default: 8).
765    pub pool_max_idle_per_host: Option<usize>,
766    /// Idle connection timeout in seconds (default: 50).
767    pub pool_idle_timeout_secs: Option<u64>,
768    /// Request timeout in seconds (default: 30).
769    pub timeout_secs: Option<u64>,
770    /// Connect timeout in seconds (default: 10).
771    pub connect_timeout_secs: Option<u64>,
772}
773
774impl HttpPoolConfig {
775    /// Returns `true` if all fields are `None` (no overrides).
776    pub fn is_empty(&self) -> bool {
777        self.pool_max_idle_per_host.is_none()
778            && self.pool_idle_timeout_secs.is_none()
779            && self.timeout_secs.is_none()
780            && self.connect_timeout_secs.is_none()
781    }
782}
783
784/// Per-worker resilience overrides (retry + circuit breaker).
785/// All fields optional — `None` means "use router default".
786/// Mirrors `HealthCheckUpdate` pattern for PATCH-style config.
787#[serde_with::skip_serializing_none]
788#[derive(Debug, Clone, Default, Serialize, Deserialize, schemars::JsonSchema)]
789pub struct ResilienceUpdate {
790    // ── Retry overrides ──
791    /// Max retry attempts (includes first attempt). 1 = no retries.
792    pub max_retries: Option<u32>,
793    /// Initial backoff delay in milliseconds.
794    pub initial_backoff_ms: Option<u64>,
795    /// Maximum backoff delay in milliseconds.
796    pub max_backoff_ms: Option<u64>,
797    /// Backoff multiplier for exponential backoff.
798    pub backoff_multiplier: Option<f32>,
799    /// Jitter factor (0.0–1.0) applied to backoff delay.
800    pub jitter_factor: Option<f32>,
801    /// Disable retries entirely for this worker.
802    pub disable_retry: Option<bool>,
803
804    // ── Circuit breaker overrides ──
805    /// Consecutive failures to open the circuit.
806    pub cb_failure_threshold: Option<u32>,
807    /// Consecutive successes to close the circuit from half-open.
808    pub cb_success_threshold: Option<u32>,
809    /// Seconds to wait before attempting half-open.
810    pub cb_timeout_secs: Option<u64>,
811    /// Time window in seconds for failure counting.
812    pub cb_window_secs: Option<u64>,
813    /// Disable circuit breaker entirely for this worker.
814    pub disable_circuit_breaker: Option<bool>,
815
816    // ── Retryable status codes ──
817    /// Custom retryable HTTP status codes.
818    /// When set, replaces the default set (408, 429, 500, 502, 503, 504).
819    pub retryable_status_codes: Option<Vec<u16>>,
820}
821
822impl ResilienceUpdate {
823    /// Returns `true` if all fields are `None` (no overrides).
824    pub fn is_empty(&self) -> bool {
825        self.max_retries.is_none()
826            && self.initial_backoff_ms.is_none()
827            && self.max_backoff_ms.is_none()
828            && self.backoff_multiplier.is_none()
829            && self.jitter_factor.is_none()
830            && self.disable_retry.is_none()
831            && self.cb_failure_threshold.is_none()
832            && self.cb_success_threshold.is_none()
833            && self.cb_timeout_secs.is_none()
834            && self.cb_window_secs.is_none()
835            && self.disable_circuit_breaker.is_none()
836            && self.retryable_status_codes.is_none()
837    }
838}
839
840/// Worker update request
841#[serde_with::skip_serializing_none]
842#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
843pub struct WorkerUpdateRequest {
844    /// Update priority
845    pub priority: Option<u32>,
846
847    /// Update cost
848    pub cost: Option<f32>,
849
850    /// Update labels
851    pub labels: Option<HashMap<String, String>>,
852
853    /// Update API key (for key rotation)
854    pub api_key: Option<String>,
855
856    /// Update health check configuration (partial — only specified fields change)
857    pub health: Option<HealthCheckUpdate>,
858}
859
860// ── Response types ──────────────────────────────────────────────────
861
862/// Generic API response
863#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
864pub struct WorkerApiResponse {
865    pub success: bool,
866    pub message: String,
867
868    #[serde(skip_serializing_if = "Option::is_none")]
869    pub worker: Option<WorkerInfo>,
870}
871
/// Error response
// Both fields are plain strings and are always serialized (no skip attributes).
// NOTE(review): presumably `error` is human-readable and `code` is a stable
// machine-readable identifier. Confirm against the producers of this type.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct WorkerErrorResponse {
    pub error: String,
    pub code: String,
}
878
/// Result from flush cache operations across workers
///
/// With the `axum` feature this converts into an HTTP response. The status is `200 OK`
/// when `failed` is empty and `206 Partial Content` otherwise.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FlushCacheResult {
    /// Workers whose flush succeeded. Their count is reported as `workers_flushed`.
    // NOTE(review): presumably worker URLs, matching the first element of `failed`.
    pub successful: Vec<String>,
    /// `(worker_url, error_message)` pairs for workers whose flush failed.
    pub failed: Vec<(String, String)>,
    /// Reported as `total_workers` in the HTTP response.
    pub total_workers: usize,
    /// Reported as `total_http_workers` in the HTTP response.
    pub http_workers: usize,
    /// Summary message, passed through verbatim as `message`.
    pub message: String,
}
888
/// Result from getting worker loads
///
/// Only `loads` is rendered by the axum `IntoResponse` impl (as `{"workers": [...]}`).
/// The counters below are not included in that HTTP body.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct WorkerLoadsResult {
    /// One entry per worker queried.
    pub loads: Vec<WorkerLoadInfo>,
    pub total_workers: usize,
    // NOTE(review): presumably counts of workers whose load query succeeded / failed.
    // Confirm at the producer.
    pub successful: usize,
    pub failed: usize,
}
897
/// Per-DP-rank load snapshot from a backend.
///
/// Contains core metrics from the sglang `/v1/loads` endpoint or `GetLoads` gRPC RPC.
/// Each snapshot represents one data-parallel rank's scheduler state.
///
/// The container-level `#[serde(default)]` makes every field optional on input. A field
/// missing from the payload deserializes to its `Default` value (`0` or `0.0`).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct SchedulerLoadSnapshot {
    /// Data-parallel rank this snapshot describes.
    pub dp_rank: i32,
    pub num_running_reqs: i32,
    pub num_waiting_reqs: i32,
    pub num_total_reqs: i32,
    /// Summed across ranks (as `i64`) by `WorkerLoadResponse::total_used_tokens`.
    pub num_used_tokens: i32,
    pub max_total_num_tokens: i32,
    /// Token usage ratio (0.0–1.0).
    // NOTE(review): the range is not validated on deserialization. This value is averaged
    // across ranks by `WorkerLoadResponse::effective_token_usage`.
    pub token_usage: f64,
    pub gen_throughput: f64,
    pub cache_hit_rate: f64,
    pub utilization: f64,
    pub max_running_requests: i32,
}
918
/// Full load response for a single worker across all DP ranks.
///
/// `#[serde(default)]` lets a payload omit any field. Missing fields deserialize as an
/// empty string, `0`, or an empty `loads` list.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct WorkerLoadResponse {
    /// Timestamp string as reported by the backend. The format is not parsed here.
    pub timestamp: String,
    // NOTE(review): nothing in this module checks `dp_rank_count == loads.len()`.
    // The aggregation helpers iterate over `loads` only.
    pub dp_rank_count: i32,
    /// One snapshot per data-parallel rank.
    pub loads: Vec<SchedulerLoadSnapshot>,
}
927
928impl WorkerLoadResponse {
929    /// Average token usage ratio across DP ranks. Returns 0.0 if empty.
930    pub fn effective_token_usage(&self) -> f64 {
931        if self.loads.is_empty() {
932            return 0.0;
933        }
934        self.loads.iter().map(|l| l.token_usage).sum::<f64>() / self.loads.len() as f64
935    }
936
937    /// Total used tokens summed across all DP ranks.
938    pub fn total_used_tokens(&self) -> i64 {
939        self.loads.iter().map(|l| l.num_used_tokens as i64).sum()
940    }
941}
942
/// Individual worker load information
///
/// One entry of `WorkerLoadsResult::loads`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct WorkerLoadInfo {
    /// Worker identifier, rendered as `"worker"` in the load report.
    // NOTE(review): presumably the worker URL. Confirm at the producer.
    pub worker: String,
    /// Optional worker-type label. It is omitted from serialized output when `None`.
    // NOTE(review): the axum `IntoResponse` for `WorkerLoadsResult` builds entries from
    // `worker`, `load` and `details` only, so this field never reaches that HTTP body.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub worker_type: Option<String>,
    /// Load value reported for the worker.
    // NOTE(review): signed type. Negative values may be used as a failure sentinel.
    // Confirm with the producer.
    pub load: isize,
    /// Per-DP-rank breakdown, when available. It is omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub details: Option<WorkerLoadResponse>,
}
953
#[cfg(feature = "axum")]
impl IntoResponse for FlushCacheResult {
    /// Renders the flush outcome as a JSON HTTP response.
    ///
    /// If every worker flushed, responds `200 OK` with `"status": "success"`.
    /// Otherwise responds `206 Partial Content` with `"status": "partial_success"`, plus a
    /// `successful` list and a `failed` list of `{"worker": <url>, "error": <message>}`
    /// objects.
    fn into_response(self) -> Response {
        let FlushCacheResult {
            successful,
            failed,
            total_workers,
            http_workers,
            message,
        } = self;
        let all_flushed = failed.is_empty();

        let (status, outcome) = if all_flushed {
            (StatusCode::OK, "success")
        } else {
            (StatusCode::PARTIAL_CONTENT, "partial_success")
        };

        // Key insertion order is kept identical to the original response shape.
        let mut body = json!({
            "status": outcome,
            "message": message,
            "workers_flushed": successful.len(),
            "total_http_workers": http_workers,
            "total_workers": total_workers
        });

        if !all_flushed {
            let failures: Vec<Value> = failed
                .into_iter()
                .map(|(worker_url, error)| json!({"worker": worker_url, "error": error}))
                .collect();
            body["successful"] = json!(successful);
            body["failed"] = Value::Array(failures);
        }

        (status, Json(body)).into_response()
    }
}
983
#[cfg(feature = "axum")]
impl IntoResponse for WorkerLoadsResult {
    /// Renders the collected loads as `{"workers": [...]}`.
    ///
    /// Each entry carries `worker` and `load`, plus `details` when a per-rank breakdown is
    /// present. `worker_type` and the summary counters (`total_workers`, `successful`,
    /// `failed`) are not emitted.
    fn into_response(self) -> Response {
        let mut workers = Vec::with_capacity(self.loads.len());
        for info in &self.loads {
            let mut entry = json!({"worker": &info.worker, "load": info.load});
            if let Some(details) = &info.details {
                entry["details"] = json!(details);
            }
            workers.push(entry);
        }
        Json(json!({"workers": workers})).into_response()
    }
}