Skip to main content

openai_protocol/
worker.rs

1//! Canonical worker types and identity.
2//!
3//! This module defines the single source of truth for worker identity, type
4//! enums, and core configuration. These types are shared across API
5//! request/response boundaries and internal runtime state.
6
7use std::collections::HashMap;
8
9#[cfg(feature = "axum")]
10use axum::{
11    http::StatusCode,
12    response::{IntoResponse, Response},
13    Json,
14};
15use schemars::JsonSchema;
16use serde::{Deserialize, Deserializer, Serialize, Serializer};
17#[cfg(feature = "axum")]
18use serde_json::{json, Value};
19
20use super::model_card::ModelCard;
21
22// ── Default value constants ──────────────────────────────────────────
23
/// Priority given to a worker that does not configure one (higher = preferred).
pub const DEFAULT_WORKER_PRIORITY: u32 = 50;
/// Cost factor given to a worker that does not configure one (baseline = 1.0).
pub const DEFAULT_WORKER_COST: f32 = 1.0;
26
27// ── Enums ────────────────────────────────────────────────────────────
28
/// Worker type classification.
// serde (de)serializes the exact lowercase names "regular" / "prefill" /
// "decode" (`rename_all = "lowercase"`); `Display` for this type emits the same
// strings, while `FromStr` also accepts them in any ASCII case.
// The `///` docs here feed the derived JsonSchema descriptions.
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default, schemars::JsonSchema,
)]
#[serde(rename_all = "lowercase")]
pub enum WorkerType {
    /// Regular worker for standard routing.
    #[default]
    Regular,
    /// Prefill worker for PD disaggregated mode.
    Prefill,
    /// Decode worker for PD disaggregated mode.
    Decode,
}
43
44impl std::fmt::Display for WorkerType {
45    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46        match self {
47            WorkerType::Regular => write!(f, "regular"),
48            WorkerType::Prefill => write!(f, "prefill"),
49            WorkerType::Decode => write!(f, "decode"),
50        }
51    }
52}
53
54impl std::str::FromStr for WorkerType {
55    type Err = String;
56
57    fn from_str(s: &str) -> Result<Self, Self::Err> {
58        if s.eq_ignore_ascii_case("regular") {
59            Ok(WorkerType::Regular)
60        } else if s.eq_ignore_ascii_case("prefill") {
61            Ok(WorkerType::Prefill)
62        } else if s.eq_ignore_ascii_case("decode") {
63            Ok(WorkerType::Decode)
64        } else {
65            Err(format!("Unknown worker type: {s}"))
66        }
67    }
68}
69
/// Connection mode for worker communication.
// serde uses the exact lowercase names "http" / "grpc"; `Display` for this
// type emits the same strings. `Http` is the `Default`.
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default, schemars::JsonSchema,
)]
#[serde(rename_all = "lowercase")]
pub enum ConnectionMode {
    /// HTTP/REST connection.
    #[default]
    Http,
    /// gRPC connection.
    Grpc,
}
82
83impl std::fmt::Display for ConnectionMode {
84    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85        match self {
86            ConnectionMode::Http => write!(f, "http"),
87            ConnectionMode::Grpc => write!(f, "grpc"),
88        }
89    }
90}
91
92/// Composite key identifying a group of workers with the same characteristics.
93///
94/// Groups workers by `(model_id, worker_type, connection_mode)` — the natural
95/// partitioning used for metrics, load monitoring, and policy management.
96#[derive(Debug, Clone, PartialEq, Eq, Hash)]
97pub struct WorkerGroupKey {
98    pub model_id: String,
99    pub worker_type: WorkerType,
100    pub connection_mode: ConnectionMode,
101}
102
103impl std::fmt::Display for WorkerGroupKey {
104    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105        write!(
106            f,
107            "{}:{}:{}",
108            self.model_id, self.worker_type, self.connection_mode
109        )
110    }
111}
112
/// Runtime implementation type for workers.
// serde accepts only the exact lowercase names ("sglang", "vllm", "trtllm",
// "external"); `FromStr` for this type ignores ASCII case and additionally
// accepts "tensorrt-llm" for `Trtllm`. `Display` emits the serde names.
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default, schemars::JsonSchema,
)]
#[serde(rename_all = "lowercase")]
pub enum RuntimeType {
    /// SGLang runtime (default).
    #[default]
    Sglang,
    /// vLLM runtime.
    Vllm,
    /// TensorRT-LLM runtime.
    Trtllm,
    /// External OpenAI-compatible API (not local inference).
    External,
}
129
130impl std::fmt::Display for RuntimeType {
131    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
132        match self {
133            RuntimeType::Sglang => write!(f, "sglang"),
134            RuntimeType::Vllm => write!(f, "vllm"),
135            RuntimeType::Trtllm => write!(f, "trtllm"),
136            RuntimeType::External => write!(f, "external"),
137        }
138    }
139}
140
141impl std::str::FromStr for RuntimeType {
142    type Err = String;
143
144    fn from_str(s: &str) -> Result<Self, Self::Err> {
145        if s.eq_ignore_ascii_case("sglang") {
146            Ok(RuntimeType::Sglang)
147        } else if s.eq_ignore_ascii_case("vllm") {
148            Ok(RuntimeType::Vllm)
149        } else if s.eq_ignore_ascii_case("trtllm") || s.eq_ignore_ascii_case("tensorrt-llm") {
150            Ok(RuntimeType::Trtllm)
151        } else if s.eq_ignore_ascii_case("external") {
152            Ok(RuntimeType::External)
153        } else {
154            Err(format!("Unknown runtime type: {s}"))
155        }
156    }
157}
158
/// Provider type for external API transformations.
///
/// Different providers have different API formats and requirements.
/// `None` (when used as `Option<ProviderType>`) means native/passthrough —
/// no transformation needed (local SGLang backends).
// Wire format:
// - The built-in variants serialize as "openai", "xai", "anthropic" and
//   "gemini" (`rename_all = "lowercase"`).
// - The `openai`/`xai` aliases repeat those names; "grok", "claude" and
//   "google" are extra input-only aliases.
// - Any other string, including differently-cased names such as "OpenAI",
//   deserializes into `Custom` through the untagged variant.
// - `Custom` serializes as its bare string, so e.g. `Custom("openai")`
//   reads back as `OpenAI`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum ProviderType {
    /// OpenAI API — strip SGLang-specific fields.
    #[serde(alias = "openai")]
    OpenAI,
    /// xAI/Grok — special handling for input items.
    #[serde(alias = "xai", alias = "grok")]
    #[expect(
        clippy::upper_case_acronyms,
        reason = "xAI is a proper company name; XAI matches industry convention and existing serde aliases"
    )]
    XAI,
    /// Anthropic Claude — different API format.
    #[serde(alias = "anthropic", alias = "claude")]
    Anthropic,
    /// Google Gemini — special logprobs handling.
    #[serde(alias = "gemini", alias = "google")]
    Gemini,
    /// Custom provider with string identifier.
    // Must stay the last variant: serde only tries untagged variants after
    // every named variant/alias has failed to match.
    #[serde(untagged)]
    Custom(String),
}
187
188impl ProviderType {
189    /// Get provider name as string.
190    pub fn as_str(&self) -> &str {
191        match self {
192            Self::OpenAI => "openai",
193            Self::XAI => "xai",
194            Self::Anthropic => "anthropic",
195            Self::Gemini => "gemini",
196            Self::Custom(s) => s.as_str(),
197        }
198    }
199
200    /// Detect provider from URL host.
201    /// Returns `None` for URLs that don't match known providers or can't be parsed.
202    pub fn from_url(url: &str) -> Option<Self> {
203        let host = url::Url::parse(url).ok()?.host_str()?.to_lowercase();
204
205        if host.ends_with("openai.com") {
206            Some(Self::OpenAI)
207        } else if host.ends_with("x.ai") {
208            Some(Self::XAI)
209        } else if host.ends_with("anthropic.com") {
210            Some(Self::Anthropic)
211        } else if host.ends_with("googleapis.com") {
212            Some(Self::Gemini)
213        } else {
214            None
215        }
216    }
217
218    /// Environment variable name for per-provider admin API key (model discovery).
219    /// Returns `None` for `Custom` providers since there's no known env var.
220    pub fn admin_key_env_var(&self) -> Option<&'static str> {
221        match self {
222            Self::OpenAI => Some("OPENAI_ADMIN_KEY"),
223            Self::XAI => Some("XAI_ADMIN_KEY"),
224            Self::Anthropic => Some("ANTHROPIC_ADMIN_KEY"),
225            Self::Gemini => Some("GEMINI_ADMIN_KEY"),
226            Self::Custom(_) => None,
227        }
228    }
229
230    /// Whether this provider uses `x-api-key` header instead of `Authorization: Bearer`.
231    pub fn uses_x_api_key(&self) -> bool {
232        matches!(self, Self::Anthropic)
233    }
234
235    /// Detect provider from model name (heuristic fallback).
236    /// Returns `None` for models that don't match known external providers.
237    pub fn from_model_name(model: &str) -> Option<Self> {
238        let model_lower = model.to_lowercase();
239        if model_lower.starts_with("grok") {
240            Some(Self::XAI)
241        } else if model_lower.starts_with("gemini") {
242            Some(Self::Gemini)
243        } else if model_lower.starts_with("claude") {
244            Some(Self::Anthropic)
245        } else if model_lower.starts_with("gpt")
246            || model_lower.starts_with("o1")
247            || model_lower.starts_with("o3")
248        {
249            Some(Self::OpenAI)
250        } else {
251            None
252        }
253    }
254}
255
256impl std::fmt::Display for ProviderType {
257    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
258        write!(f, "{}", self.as_str())
259    }
260}
261
262// ── Serde default helpers ────────────────────────────────────────────
263
264fn default_priority() -> u32 {
265    DEFAULT_WORKER_PRIORITY
266}
267
268fn default_cost() -> f32 {
269    DEFAULT_WORKER_COST
270}
271
/// Serde default for `HealthCheckConfig::timeout_secs` (seconds).
fn default_health_check_timeout() -> u64 {
    30_u64
}

/// Serde default for `HealthCheckConfig::check_interval_secs` (seconds).
fn default_health_check_interval() -> u64 {
    60_u64
}

/// Serde default for `HealthCheckConfig::success_threshold`.
fn default_health_success_threshold() -> u32 {
    2_u32
}

/// Serde default for `HealthCheckConfig::failure_threshold`.
fn default_health_failure_threshold() -> u32 {
    3_u32
}

/// Serde default for `WorkerSpec::max_connection_attempts`.
fn default_max_connection_attempts() -> u32 {
    20_u32
}
291
292// ── Health check config ─────────────────────────────────────────────
293
/// Health check configuration shared across protocol and runtime layers.
// Deserializing fills every missing field from its serde default, so a partial
// object resets unspecified fields. Use `HealthCheckUpdate` when only some
// fields should change. `Default` for this type uses the same helper functions.
// The `///` docs below feed the derived JsonSchema descriptions.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct HealthCheckConfig {
    /// Health check timeout in seconds (default: 30).
    #[serde(default = "default_health_check_timeout")]
    pub timeout_secs: u64,

    /// Health check interval in seconds (default: 60).
    #[serde(default = "default_health_check_interval")]
    pub check_interval_secs: u64,

    /// Number of successful health checks needed to mark worker as healthy (default: 2).
    #[serde(default = "default_health_success_threshold")]
    pub success_threshold: u32,

    /// Number of failed health checks before marking worker as unhealthy (default: 3).
    #[serde(default = "default_health_failure_threshold")]
    pub failure_threshold: u32,

    /// Disable periodic health checks for this worker (default: false).
    #[serde(default)]
    pub disable_health_check: bool,
}
317
318impl Default for HealthCheckConfig {
319    fn default() -> Self {
320        Self {
321            timeout_secs: default_health_check_timeout(),
322            check_interval_secs: default_health_check_interval(),
323            success_threshold: default_health_success_threshold(),
324            failure_threshold: default_health_failure_threshold(),
325            disable_health_check: false,
326        }
327    }
328}
329
330// ── Worker models ───────────────────────────────────────────────────
331
332/// Models configuration for a worker.
333///
334/// Encodes the three real cases instead of relying on `Vec` semantics:
335/// - `Wildcard` — accepts any model (empty models list on the wire)
336/// - `Single` — serves exactly one model
337/// - `Multi` — serves multiple distinct models (len >= 2)
338#[derive(Debug, Clone, Default)]
339pub enum WorkerModels {
340    /// Worker accepts any model (e.g., external API without discovery).
341    #[default]
342    Wildcard,
343    /// Worker serves exactly one model (most common for local inference).
344    Single(Box<ModelCard>),
345    /// Worker serves multiple distinct models (len >= 2).
346    Multi(Vec<ModelCard>),
347}
348
349impl WorkerModels {
350    /// Returns `true` if this is a wildcard (accepts any model).
351    pub fn is_wildcard(&self) -> bool {
352        matches!(self, Self::Wildcard)
353    }
354
355    /// Returns the primary model: `Single` → `Some`, `Multi` → first, `Wildcard` → `None`.
356    pub fn primary(&self) -> Option<&ModelCard> {
357        match self {
358            Self::Wildcard => None,
359            Self::Single(card) => Some(card.as_ref()),
360            Self::Multi(cards) => cards.first(),
361        }
362    }
363
364    /// Returns all models as a slice (empty for `Wildcard`).
365    pub fn all(&self) -> &[ModelCard] {
366        match self {
367            Self::Wildcard => &[],
368            Self::Single(card) => std::slice::from_ref(card.as_ref()),
369            Self::Multi(cards) => cards,
370        }
371    }
372
373    /// Find a model by ID (checks aliases via `ModelCard::matches`).
374    pub fn find(&self, id: &str) -> Option<&ModelCard> {
375        match self {
376            Self::Wildcard => None,
377            Self::Single(card) => card.matches(id).then_some(card.as_ref()),
378            Self::Multi(cards) => cards.iter().find(|m| m.matches(id)),
379        }
380    }
381
382    /// Returns `true` if the worker supports the given model ID.
383    /// Wildcard workers always return `true`.
384    pub fn supports(&self, id: &str) -> bool {
385        match self {
386            Self::Wildcard => true,
387            _ => self.find(id).is_some(),
388        }
389    }
390
391    /// Iterate over all models. Empty iterator for `Wildcard`.
392    pub fn iter(&self) -> impl Iterator<Item = &ModelCard> {
393        self.all().iter()
394    }
395}
396
397impl From<Vec<ModelCard>> for WorkerModels {
398    fn from(models: Vec<ModelCard>) -> Self {
399        match models.len() {
400            0 => Self::Wildcard,
401            1 => {
402                let Some(model) = models.into_iter().next() else {
403                    return Self::Wildcard;
404                };
405                Self::Single(Box::new(model))
406            }
407            _ => Self::Multi(models),
408        }
409    }
410}
411
412/// Serialize as `Vec<ModelCard>` for wire compatibility.
413impl Serialize for WorkerModels {
414    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
415        self.all().serialize(serializer)
416    }
417}
418
419/// Deserialize from `Vec<ModelCard>` for wire compatibility.
420impl<'de> Deserialize<'de> for WorkerModels {
421    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
422        let models = Vec::<ModelCard>::deserialize(deserializer)?;
423        Ok(Self::from(models))
424    }
425}
426
427/// JsonSchema: wire format is `Vec<ModelCard>`.
428impl JsonSchema for WorkerModels {
429    fn schema_name() -> String {
430        "WorkerModels".to_string()
431    }
432
433    fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema {
434        Vec::<ModelCard>::json_schema(gen)
435    }
436}
437
438// ── Core identity ────────────────────────────────────────────────────
439
/// Core worker identity and configuration.
///
/// The single canonical representation of "what is a worker". Used as the
/// shared sub-struct across API requests, API responses, and internal runtime
/// state via `#[serde(flatten)]`.
///
/// Fields use `#[serde(default)]` so the same struct works for both input
/// (partial config from user) and output (fully resolved state).
// Review notes:
// - `url` is the only field that must be present on input; every other field
//   has a serde default, is `skip`ped, or is an `Option` (missing → `None`).
// - Field declaration order is the serialized field order.
// - The `///` docs on this type and its fields feed the derived JsonSchema
//   descriptions, so editing them changes the published schema.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct WorkerSpec {
    /// Worker URL.
    pub url: String,

    /// Models this worker can serve.
    // Omitted from output when `Wildcard` (which is an empty list on the wire).
    #[serde(default, skip_serializing_if = "WorkerModels::is_wildcard")]
    pub models: WorkerModels,

    /// Worker type: regular, prefill, or decode.
    #[serde(default)]
    pub worker_type: WorkerType,

    /// Connection mode: http or grpc.
    #[serde(default)]
    pub connection_mode: ConnectionMode,

    /// Runtime type: sglang, vllm, trtllm, or external.
    // `runtime` is accepted as an input alias; output always uses `runtime_type`.
    #[serde(default, alias = "runtime")]
    pub runtime_type: RuntimeType,

    /// External provider for API transformations.
    /// `None` means native/passthrough.
    pub provider: Option<ProviderType>,

    /// Additional labels/tags.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub labels: HashMap<String, String>,

    /// Worker priority (higher = preferred).
    #[serde(default = "default_priority")]
    pub priority: u32,

    /// Worker cost factor (baseline = 1.0).
    #[serde(default = "default_cost")]
    pub cost: f32,

    /// Worker API key. Accepted on input, never included in responses.
    // NOTE(review): `skip_serializing` keeps the key out of JSON, but this
    // struct's derived `Debug` still prints it in clear text; redact before
    // logging a spec with `{:?}`.
    #[serde(default, skip_serializing)]
    pub api_key: Option<String>,

    /// Bootstrap port for prefill workers in PD disaggregated mode.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub bootstrap_port: Option<u16>,

    /// Bootstrap hostname (derived from URL at construction time).
    // Never read from or written to the wire (`skip`). `WorkerSpec::new` sets
    // it to an empty string; presumably the caller fills it in from `url`
    // — TODO confirm where.
    #[serde(default, skip)]
    pub bootstrap_host: String,

    /// Base URL without DP rank suffix (for DP-aware workers).
    /// When set, `url` contains the rank-suffixed form (`{base}@{rank}`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub dp_base_url: Option<String>,

    /// Data-parallel rank (None = not DP-aware).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub dp_rank: Option<usize>,

    /// Total data-parallel group size (None = not DP-aware).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub dp_size: Option<usize>,

    /// KV connector type (e.g. "MooncakeConnector", "NixlConnector").
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub kv_connector: Option<String>,

    /// KV role (e.g. "kv_producer", "kv_consumer", "kv_both").
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub kv_role: Option<String>,

    /// KV cache block size (tokens per block) for event-driven routing.
    /// When set, overrides the router-level default for this worker's model.
    /// Typically matches the backend engine's page size (e.g. 16 for SGLang).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub kv_block_size: Option<usize>,

    /// Per-worker health check overrides (partial — only `Some` fields override router defaults).
    // `HealthCheckUpdate::apply_to` is the helper that merges these overrides
    // onto a full `HealthCheckConfig`.
    #[serde(default, skip_serializing_if = "HealthCheckUpdate::is_empty")]
    pub health: HealthCheckUpdate,

    /// Maximum connection attempts during worker registration (default: 20).
    #[serde(default = "default_max_connection_attempts")]
    pub max_connection_attempts: u32,

    /// Per-worker load monitor interval override (seconds).
    /// When set, workers in the same group use this interval for load polling.
    /// Falls back to the global `load_monitor_interval_secs` from router config.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub load_monitor_interval_secs: Option<u64>,
}
539
540impl WorkerSpec {
541    /// Create a new `WorkerSpec` with the given URL and sensible defaults.
542    pub fn new(url: impl Into<String>) -> Self {
543        Self {
544            url: url.into(),
545            models: WorkerModels::Wildcard,
546            worker_type: WorkerType::default(),
547            connection_mode: ConnectionMode::default(),
548            runtime_type: RuntimeType::default(),
549            provider: None,
550            labels: HashMap::new(),
551            priority: DEFAULT_WORKER_PRIORITY,
552            cost: DEFAULT_WORKER_COST,
553            api_key: None,
554            bootstrap_port: None,
555            bootstrap_host: String::new(),
556            dp_base_url: None,
557            dp_rank: None,
558            dp_size: None,
559            kv_connector: None,
560            kv_role: None,
561            kv_block_size: None,
562            health: HealthCheckUpdate::default(),
563            max_connection_attempts: default_max_connection_attempts(),
564            load_monitor_interval_secs: None,
565        }
566    }
567}
568
569// ── API types ───────────────────────────────────────────────────────
570
/// Worker information for API responses.
// `spec` is flattened, so its fields appear at the top level of the JSON
// object next to `id`, `is_healthy`, `load` and `job_status`. Because of
// `WorkerSpec`'s `skip_serializing`, its `api_key` never reaches the output.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct WorkerInfo {
    /// Worker unique identifier.
    pub id: String,

    /// Worker identity and configuration.
    #[serde(flatten)]
    pub spec: WorkerSpec,

    /// Whether the worker is healthy.
    pub is_healthy: bool,

    /// Current load on the worker.
    pub load: usize,

    /// Job status for async operations (if available).
    // Omitted from output when `None` (via `skip_serializing_none`).
    pub job_status: Option<JobStatus>,
}
591
592impl WorkerInfo {
593    /// Create a partial WorkerInfo for pending workers (not yet registered).
594    pub fn pending(worker_id: &str, url: String, job_status: Option<JobStatus>) -> Self {
595        Self {
596            id: worker_id.to_string(),
597            spec: WorkerSpec::new(url),
598            is_healthy: false,
599            load: 0,
600            job_status,
601        }
602    }
603}
604
/// Job status for async control plane operations
// What the constructors in this file produce:
// - `status` is "pending", "processing" or "failed".
// - `message` is filled in only for failures.
// - `timestamp` is whole seconds since the Unix epoch.
// There is no `skip_serializing_none` here, so a `None` message serializes
// as `null`. Plain `//` comments are used because `///` docs on fields would
// be emitted into the derived JSON schema.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct JobStatus {
    pub job_type: String,
    pub worker_url: String,
    pub status: String,
    pub message: Option<String>,
    pub timestamp: u64,
}
614
615impl JobStatus {
616    /// Create a pending job status
617    pub fn pending(job_type: &str, worker_url: &str) -> Self {
618        Self {
619            job_type: job_type.to_string(),
620            worker_url: worker_url.to_string(),
621            status: "pending".to_string(),
622            message: None,
623            timestamp: std::time::SystemTime::now()
624                .duration_since(std::time::SystemTime::UNIX_EPOCH)
625                .unwrap_or_default()
626                .as_secs(),
627        }
628    }
629
630    /// Create a processing job status
631    pub fn processing(job_type: &str, worker_url: &str) -> Self {
632        Self {
633            job_type: job_type.to_string(),
634            worker_url: worker_url.to_string(),
635            status: "processing".to_string(),
636            message: None,
637            timestamp: std::time::SystemTime::now()
638                .duration_since(std::time::SystemTime::UNIX_EPOCH)
639                .unwrap_or_default()
640                .as_secs(),
641        }
642    }
643
644    /// Create a failed job status
645    pub fn failed(job_type: &str, worker_url: &str, error: String) -> Self {
646        Self {
647            job_type: job_type.to_string(),
648            worker_url: worker_url.to_string(),
649            status: "failed".to_string(),
650            message: Some(error),
651            timestamp: std::time::SystemTime::now()
652                .duration_since(std::time::SystemTime::UNIX_EPOCH)
653                .unwrap_or_default()
654                .as_secs(),
655        }
656    }
657}
658
/// Worker list response
// NOTE(review): `total` is presumably `workers.len()` — TODO confirm with the
// handler that builds this response.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct WorkerListResponse {
    pub workers: Vec<WorkerInfo>,
    pub total: usize,
    pub stats: WorkerStats,
}

/// Worker statistics
// `by_type` holds one counter per `WorkerType` variant; presumably
// regular + prefill + decode == total_workers — TODO confirm.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct WorkerStats {
    pub total_workers: usize,
    pub healthy_workers: usize,
    pub total_models: usize,
    pub total_load: usize,
    pub by_type: WorkerTypeStats,
}

/// Worker statistics by type
// Field names match the lowercase wire names of `WorkerType`.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct WorkerTypeStats {
    pub regular: usize,
    pub prefill: usize,
    pub decode: usize,
}
684
685// ── Update types ────────────────────────────────────────────────────
686
/// Partial health check config for PATCH-style updates.
///
/// Each `None` field means "keep the existing value". This avoids the problem
/// where `#[serde(default)]` on [`HealthCheckConfig`] would silently reset
/// unspecified fields to defaults.
// Fields mirror `HealthCheckConfig` one-to-one. When adding a field, update
// `is_empty` and `apply_to` as well. `None` fields are omitted from output
// (via `skip_serializing_none`), and missing input fields deserialize as `None`.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Default, Serialize, Deserialize, schemars::JsonSchema)]
pub struct HealthCheckUpdate {
    pub timeout_secs: Option<u64>,
    pub check_interval_secs: Option<u64>,
    pub success_threshold: Option<u32>,
    pub failure_threshold: Option<u32>,
    pub disable_health_check: Option<bool>,
}
701
702impl HealthCheckUpdate {
703    /// Returns `true` if all fields are `None` (no overrides specified).
704    pub fn is_empty(&self) -> bool {
705        self.timeout_secs.is_none()
706            && self.check_interval_secs.is_none()
707            && self.success_threshold.is_none()
708            && self.failure_threshold.is_none()
709            && self.disable_health_check.is_none()
710    }
711}
712
713impl HealthCheckUpdate {
714    /// Merge this update into an existing [`HealthCheckConfig`], returning a new config.
715    /// Only `Some` fields are applied; `None` fields keep the existing value.
716    pub fn apply_to(&self, existing: &HealthCheckConfig) -> HealthCheckConfig {
717        HealthCheckConfig {
718            timeout_secs: self.timeout_secs.unwrap_or(existing.timeout_secs),
719            check_interval_secs: self
720                .check_interval_secs
721                .unwrap_or(existing.check_interval_secs),
722            success_threshold: self.success_threshold.unwrap_or(existing.success_threshold),
723            failure_threshold: self.failure_threshold.unwrap_or(existing.failure_threshold),
724            disable_health_check: self
725                .disable_health_check
726                .unwrap_or(existing.disable_health_check),
727        }
728    }
729}
730
/// Worker update request
// Every field is optional; presumably `None` means "leave unchanged" and a
// `Some(labels)` replaces the whole label map — TODO confirm in the handler.
// NOTE(review): unlike `WorkerSpec::api_key`, this `api_key` is not
// `skip_serializing`, and the derived `Debug` prints it too; avoid logging
// or echoing this request as-is.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct WorkerUpdateRequest {
    /// Update priority
    pub priority: Option<u32>,

    /// Update cost
    pub cost: Option<f32>,

    /// Update labels
    pub labels: Option<HashMap<String, String>>,

    /// Update API key (for key rotation)
    pub api_key: Option<String>,

    /// Update health check configuration (partial — only specified fields change)
    pub health: Option<HealthCheckUpdate>,
}
750
751// ── Response types ──────────────────────────────────────────────────
752
/// Generic API response
// `worker` is omitted from the JSON entirely when `None`.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct WorkerApiResponse {
    pub success: bool,
    pub message: String,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub worker: Option<WorkerInfo>,
}

/// Error response
// `code` is a free-form string here; any fixed set of codes is defined by the
// callers that build this response.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct WorkerErrorResponse {
    pub error: String,
    pub code: String,
}
769
/// Result from flush cache operations across workers
///
/// With the `axum` feature this renders as an HTTP response whose status
/// depends on whether `failed` is empty.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FlushCacheResult {
    /// Workers whose flush succeeded (reported as `workers_flushed` by count).
    pub successful: Vec<String>,
    /// `(worker, error message)` pairs for workers whose flush failed.
    pub failed: Vec<(String, String)>,
    /// Total number of workers considered.
    pub total_workers: usize,
    /// Number of HTTP workers (reported as `total_http_workers`).
    pub http_workers: usize,
    /// Summary message echoed in the response body.
    pub message: String,
}

/// Result from getting worker loads
// NOTE(review): the axum `IntoResponse` impl only emits `loads` (as
// `workers`); `total_workers`, `successful` and `failed` are not included.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct WorkerLoadsResult {
    /// Per-worker load entries.
    pub loads: Vec<WorkerLoadInfo>,
    /// Total number of workers queried.
    pub total_workers: usize,
    /// Count of successful load queries.
    pub successful: usize,
    /// Count of failed load queries.
    pub failed: usize,
}
788
/// Per-DP-rank load snapshot from a backend.
///
/// Contains core metrics from the sglang `/v1/loads` endpoint or `GetLoads` gRPC RPC.
/// Each snapshot represents one data-parallel rank's scheduler state.
///
/// The container-level `#[serde(default)]` makes every field optional on input:
/// missing integers become `0` and missing floats `0.0`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct SchedulerLoadSnapshot {
    /// Data-parallel rank this snapshot describes.
    pub dp_rank: i32,
    pub num_running_reqs: i32,
    pub num_waiting_reqs: i32,
    pub num_total_reqs: i32,
    /// Summed across ranks (as `i64`) by `WorkerLoadResponse::total_used_tokens`.
    pub num_used_tokens: i32,
    pub max_total_num_tokens: i32,
    /// Token usage ratio (0.0–1.0).
    pub token_usage: f64,
    pub gen_throughput: f64,
    pub cache_hit_rate: f64,
    pub utilization: f64,
    pub max_running_requests: i32,
}
809
/// Full load response for a single worker across all DP ranks.
///
/// Every field defaults on input (`#[serde(default)]`), so an empty object
/// deserializes to an empty `loads` list.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct WorkerLoadResponse {
    /// Timestamp as reported by the backend (kept as an opaque string).
    pub timestamp: String,
    /// Number of DP ranks the backend reports; presumably equals `loads.len()`
    /// — TODO confirm.
    pub dp_rank_count: i32,
    /// One snapshot per data-parallel rank.
    pub loads: Vec<SchedulerLoadSnapshot>,
}
818
819impl WorkerLoadResponse {
820    /// Average token usage ratio across DP ranks. Returns 0.0 if empty.
821    pub fn effective_token_usage(&self) -> f64 {
822        if self.loads.is_empty() {
823            return 0.0;
824        }
825        self.loads.iter().map(|l| l.token_usage).sum::<f64>() / self.loads.len() as f64
826    }
827
828    /// Total used tokens summed across all DP ranks.
829    pub fn total_used_tokens(&self) -> i64 {
830        self.loads.iter().map(|l| l.num_used_tokens as i64).sum()
831    }
832}
833
/// Individual worker load information
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct WorkerLoadInfo {
    /// Worker identifier (emitted as `worker` in the axum response).
    pub worker: String,
    /// Optional worker type label. Omitted from serde output when `None`, and
    /// never included by the axum `IntoResponse` for `WorkerLoadsResult`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub worker_type: Option<String>,
    /// Load value; signed, presumably so a negative value can flag a failed
    /// query — TODO confirm with the producer.
    pub load: isize,
    /// Full per-rank breakdown, when available.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub details: Option<WorkerLoadResponse>,
}
844
#[cfg(feature = "axum")]
impl IntoResponse for FlushCacheResult {
    /// Renders the flush outcome as JSON.
    ///
    /// - No failures: `200 OK` with `"status": "success"`.
    /// - Otherwise: `206 Partial Content` with `"status": "partial_success"`,
    ///   and the body also lists the `successful` workers plus each failure
    ///   as `{"worker", "error"}`.
    fn into_response(self) -> Response {
        let clean = self.failed.is_empty();
        let (status, outcome) = if clean {
            (StatusCode::OK, "success")
        } else {
            (StatusCode::PARTIAL_CONTENT, "partial_success")
        };

        let mut body = json!({
            "status": outcome,
            "message": self.message,
            "workers_flushed": self.successful.len(),
            "total_http_workers": self.http_workers,
            "total_workers": self.total_workers
        });

        if !clean {
            let failures: Vec<Value> = self
                .failed
                .into_iter()
                .map(|(worker, error)| json!({ "worker": worker, "error": error }))
                .collect();
            // Insertion order ("successful" before "failed") is kept, in case
            // serde_json's `preserve_order` feature is enabled.
            body["successful"] = json!(self.successful);
            body["failed"] = Value::Array(failures);
        }

        (status, Json(body)).into_response()
    }
}
874
#[cfg(feature = "axum")]
impl IntoResponse for WorkerLoadsResult {
    /// Renders `{"workers": [...]}`. Each entry has `worker` and `load`, plus
    /// `details` when per-rank data is present.
    fn into_response(self) -> Response {
        let mut workers: Vec<Value> = Vec::with_capacity(self.loads.len());
        for info in &self.loads {
            let mut entry = json!({ "worker": &info.worker, "load": info.load });
            if let Some(details) = &info.details {
                entry["details"] = json!(details);
            }
            workers.push(entry);
        }
        Json(json!({ "workers": workers })).into_response()
    }
}