Skip to main content

openai_protocol/
worker.rs

1//! Canonical worker types and identity.
2//!
3//! This module defines the single source of truth for worker identity, type
4//! enums, and core configuration. These types are shared across API
5//! request/response boundaries and internal runtime state.
6
7use std::collections::HashMap;
8
9#[cfg(feature = "axum")]
10use axum::{
11    http::StatusCode,
12    response::{IntoResponse, Response},
13    Json,
14};
15use serde::{Deserialize, Deserializer, Serialize, Serializer};
16#[cfg(feature = "axum")]
17use serde_json::{json, Value};
18
19use super::model_card::ModelCard;
20
21// ── Default value constants ──────────────────────────────────────────
22
/// Default routing priority given to a worker when none is configured
/// (higher values are preferred).
pub const DEFAULT_WORKER_PRIORITY: u32 = 50;
/// Default cost factor given to a worker when none is configured
/// (1.0 is the baseline).
pub const DEFAULT_WORKER_COST: f32 = 1.0;
25
26// ── Enums ────────────────────────────────────────────────────────────
27
28/// Worker type classification.
29#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
30#[serde(rename_all = "lowercase")]
31pub enum WorkerType {
32    /// Regular worker for standard routing.
33    #[default]
34    Regular,
35    /// Prefill worker for PD disaggregated mode.
36    Prefill,
37    /// Decode worker for PD disaggregated mode.
38    Decode,
39}
40
41impl std::fmt::Display for WorkerType {
42    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43        match self {
44            WorkerType::Regular => write!(f, "regular"),
45            WorkerType::Prefill => write!(f, "prefill"),
46            WorkerType::Decode => write!(f, "decode"),
47        }
48    }
49}
50
51impl std::str::FromStr for WorkerType {
52    type Err = String;
53
54    fn from_str(s: &str) -> Result<Self, Self::Err> {
55        if s.eq_ignore_ascii_case("regular") {
56            Ok(WorkerType::Regular)
57        } else if s.eq_ignore_ascii_case("prefill") {
58            Ok(WorkerType::Prefill)
59        } else if s.eq_ignore_ascii_case("decode") {
60            Ok(WorkerType::Decode)
61        } else {
62            Err(format!("Unknown worker type: {}", s))
63        }
64    }
65}
66
67/// Connection mode for worker communication.
68#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
69#[serde(rename_all = "lowercase")]
70pub enum ConnectionMode {
71    /// HTTP/REST connection.
72    #[default]
73    Http,
74    /// gRPC connection.
75    Grpc,
76}
77
78impl std::fmt::Display for ConnectionMode {
79    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80        match self {
81            ConnectionMode::Http => write!(f, "http"),
82            ConnectionMode::Grpc => write!(f, "grpc"),
83        }
84    }
85}
86
87/// Runtime implementation type for workers.
88#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
89#[serde(rename_all = "lowercase")]
90pub enum RuntimeType {
91    /// SGLang runtime (default).
92    #[default]
93    Sglang,
94    /// vLLM runtime.
95    Vllm,
96    /// TensorRT-LLM runtime.
97    Trtllm,
98    /// External OpenAI-compatible API (not local inference).
99    External,
100}
101
102impl std::fmt::Display for RuntimeType {
103    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104        match self {
105            RuntimeType::Sglang => write!(f, "sglang"),
106            RuntimeType::Vllm => write!(f, "vllm"),
107            RuntimeType::Trtllm => write!(f, "trtllm"),
108            RuntimeType::External => write!(f, "external"),
109        }
110    }
111}
112
113impl std::str::FromStr for RuntimeType {
114    type Err = String;
115
116    fn from_str(s: &str) -> Result<Self, Self::Err> {
117        if s.eq_ignore_ascii_case("sglang") {
118            Ok(RuntimeType::Sglang)
119        } else if s.eq_ignore_ascii_case("vllm") {
120            Ok(RuntimeType::Vllm)
121        } else if s.eq_ignore_ascii_case("trtllm") || s.eq_ignore_ascii_case("tensorrt-llm") {
122            Ok(RuntimeType::Trtllm)
123        } else if s.eq_ignore_ascii_case("external") {
124            Ok(RuntimeType::External)
125        } else {
126            Err(format!("Unknown runtime type: {}", s))
127        }
128    }
129}
130
131/// Provider type for external API transformations.
132///
133/// Different providers have different API formats and requirements.
134/// `None` (when used as `Option<ProviderType>`) means native/passthrough —
135/// no transformation needed (local SGLang backends).
136#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
137#[serde(rename_all = "lowercase")]
138pub enum ProviderType {
139    /// OpenAI API — strip SGLang-specific fields.
140    #[serde(alias = "openai")]
141    OpenAI,
142    /// xAI/Grok — special handling for input items.
143    #[serde(alias = "xai", alias = "grok")]
144    XAI,
145    /// Anthropic Claude — different API format.
146    #[serde(alias = "anthropic", alias = "claude")]
147    Anthropic,
148    /// Google Gemini — special logprobs handling.
149    #[serde(alias = "gemini", alias = "google")]
150    Gemini,
151    /// Custom provider with string identifier.
152    #[serde(untagged)]
153    Custom(String),
154}
155
156impl ProviderType {
157    /// Get provider name as string.
158    pub fn as_str(&self) -> &str {
159        match self {
160            Self::OpenAI => "openai",
161            Self::XAI => "xai",
162            Self::Anthropic => "anthropic",
163            Self::Gemini => "gemini",
164            Self::Custom(s) => s.as_str(),
165        }
166    }
167
168    /// Detect provider from model name (heuristic fallback).
169    /// Returns `None` for models that don't match known external providers.
170    pub fn from_model_name(model: &str) -> Option<Self> {
171        let model_lower = model.to_lowercase();
172        if model_lower.starts_with("grok") {
173            Some(Self::XAI)
174        } else if model_lower.starts_with("gemini") {
175            Some(Self::Gemini)
176        } else if model_lower.starts_with("claude") {
177            Some(Self::Anthropic)
178        } else if model_lower.starts_with("gpt")
179            || model_lower.starts_with("o1")
180            || model_lower.starts_with("o3")
181        {
182            Some(Self::OpenAI)
183        } else {
184            None
185        }
186    }
187}
188
189impl std::fmt::Display for ProviderType {
190    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
191        write!(f, "{}", self.as_str())
192    }
193}
194
195// ── Serde default helpers ────────────────────────────────────────────
196
197fn default_priority() -> u32 {
198    DEFAULT_WORKER_PRIORITY
199}
200
201fn default_cost() -> f32 {
202    DEFAULT_WORKER_COST
203}
204
205fn default_health_check_timeout() -> u64 {
206    30
207}
208
209fn default_health_check_interval() -> u64 {
210    60
211}
212
213fn default_health_success_threshold() -> u32 {
214    2
215}
216
217fn default_health_failure_threshold() -> u32 {
218    3
219}
220
221fn default_max_connection_attempts() -> u32 {
222    20
223}
224
225// ── Health check config ─────────────────────────────────────────────
226
227/// Health check configuration shared across protocol and runtime layers.
228#[derive(Debug, Clone, Serialize, Deserialize)]
229pub struct HealthCheckConfig {
230    /// Health check timeout in seconds (default: 30).
231    #[serde(default = "default_health_check_timeout")]
232    pub timeout_secs: u64,
233
234    /// Health check interval in seconds (default: 60).
235    #[serde(default = "default_health_check_interval")]
236    pub check_interval_secs: u64,
237
238    /// Number of successful health checks needed to mark worker as healthy (default: 2).
239    #[serde(default = "default_health_success_threshold")]
240    pub success_threshold: u32,
241
242    /// Number of failed health checks before marking worker as unhealthy (default: 3).
243    #[serde(default = "default_health_failure_threshold")]
244    pub failure_threshold: u32,
245
246    /// Disable periodic health checks for this worker (default: false).
247    #[serde(default)]
248    pub disable_health_check: bool,
249}
250
251impl Default for HealthCheckConfig {
252    fn default() -> Self {
253        Self {
254            timeout_secs: default_health_check_timeout(),
255            check_interval_secs: default_health_check_interval(),
256            success_threshold: default_health_success_threshold(),
257            failure_threshold: default_health_failure_threshold(),
258            disable_health_check: false,
259        }
260    }
261}
262
263// ── Worker models ───────────────────────────────────────────────────
264
265/// Models configuration for a worker.
266///
267/// Encodes the three real cases instead of relying on `Vec` semantics:
268/// - `Wildcard` — accepts any model (empty models list on the wire)
269/// - `Single` — serves exactly one model
270/// - `Multi` — serves multiple distinct models (len >= 2)
271#[derive(Debug, Clone, Default)]
272pub enum WorkerModels {
273    /// Worker accepts any model (e.g., external API without discovery).
274    #[default]
275    Wildcard,
276    /// Worker serves exactly one model (most common for local inference).
277    Single(Box<ModelCard>),
278    /// Worker serves multiple distinct models (len >= 2).
279    Multi(Vec<ModelCard>),
280}
281
282impl WorkerModels {
283    /// Returns `true` if this is a wildcard (accepts any model).
284    pub fn is_wildcard(&self) -> bool {
285        matches!(self, Self::Wildcard)
286    }
287
288    /// Returns the primary model: `Single` → `Some`, `Multi` → first, `Wildcard` → `None`.
289    pub fn primary(&self) -> Option<&ModelCard> {
290        match self {
291            Self::Wildcard => None,
292            Self::Single(card) => Some(card.as_ref()),
293            Self::Multi(cards) => cards.first(),
294        }
295    }
296
297    /// Returns all models as a slice (empty for `Wildcard`).
298    pub fn all(&self) -> &[ModelCard] {
299        match self {
300            Self::Wildcard => &[],
301            Self::Single(card) => std::slice::from_ref(card.as_ref()),
302            Self::Multi(cards) => cards,
303        }
304    }
305
306    /// Find a model by ID (checks aliases via `ModelCard::matches`).
307    pub fn find(&self, id: &str) -> Option<&ModelCard> {
308        match self {
309            Self::Wildcard => None,
310            Self::Single(card) => card.matches(id).then_some(card.as_ref()),
311            Self::Multi(cards) => cards.iter().find(|m| m.matches(id)),
312        }
313    }
314
315    /// Returns `true` if the worker supports the given model ID.
316    /// Wildcard workers always return `true`.
317    pub fn supports(&self, id: &str) -> bool {
318        match self {
319            Self::Wildcard => true,
320            _ => self.find(id).is_some(),
321        }
322    }
323
324    /// Iterate over all models. Empty iterator for `Wildcard`.
325    pub fn iter(&self) -> impl Iterator<Item = &ModelCard> {
326        self.all().iter()
327    }
328}
329
330impl From<Vec<ModelCard>> for WorkerModels {
331    fn from(models: Vec<ModelCard>) -> Self {
332        match models.len() {
333            0 => Self::Wildcard,
334            1 => Self::Single(Box::new(models.into_iter().next().unwrap())),
335            _ => Self::Multi(models),
336        }
337    }
338}
339
340/// Serialize as `Vec<ModelCard>` for wire compatibility.
341impl Serialize for WorkerModels {
342    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
343        self.all().serialize(serializer)
344    }
345}
346
347/// Deserialize from `Vec<ModelCard>` for wire compatibility.
348impl<'de> Deserialize<'de> for WorkerModels {
349    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
350        let models = Vec::<ModelCard>::deserialize(deserializer)?;
351        Ok(Self::from(models))
352    }
353}
354
355// ── Core identity ────────────────────────────────────────────────────
356
357/// Core worker identity and configuration.
358///
359/// The single canonical representation of "what is a worker". Used as the
360/// shared sub-struct across API requests, API responses, and internal runtime
361/// state via `#[serde(flatten)]`.
362///
363/// Fields use `#[serde(default)]` so the same struct works for both input
364/// (partial config from user) and output (fully resolved state).
365#[serde_with::skip_serializing_none]
366#[derive(Debug, Clone, Serialize, Deserialize)]
367pub struct WorkerSpec {
368    /// Worker URL.
369    pub url: String,
370
371    /// Models this worker can serve.
372    #[serde(default, skip_serializing_if = "WorkerModels::is_wildcard")]
373    pub models: WorkerModels,
374
375    /// Worker type: regular, prefill, or decode.
376    #[serde(default)]
377    pub worker_type: WorkerType,
378
379    /// Connection mode: http or grpc.
380    #[serde(default)]
381    pub connection_mode: ConnectionMode,
382
383    /// Runtime type: sglang, vllm, trtllm, or external.
384    #[serde(default, alias = "runtime")]
385    pub runtime_type: RuntimeType,
386
387    /// External provider for API transformations.
388    /// `None` means native/passthrough.
389    pub provider: Option<ProviderType>,
390
391    /// Additional labels/tags.
392    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
393    pub labels: HashMap<String, String>,
394
395    /// Worker priority (higher = preferred).
396    #[serde(default = "default_priority")]
397    pub priority: u32,
398
399    /// Worker cost factor (baseline = 1.0).
400    #[serde(default = "default_cost")]
401    pub cost: f32,
402
403    /// Worker API key. Accepted on input, never included in responses.
404    #[serde(default, skip_serializing)]
405    pub api_key: Option<String>,
406
407    /// Bootstrap port for prefill workers in PD disaggregated mode.
408    #[serde(default, skip_serializing_if = "Option::is_none")]
409    pub bootstrap_port: Option<u16>,
410
411    /// Bootstrap hostname (derived from URL at construction time).
412    #[serde(default, skip)]
413    pub bootstrap_host: String,
414
415    /// KV connector type (e.g. "MooncakeConnector", "NixlConnector").
416    #[serde(default, skip_serializing_if = "Option::is_none")]
417    pub kv_connector: Option<String>,
418
419    /// KV role (e.g. "kv_producer", "kv_consumer", "kv_both").
420    #[serde(default, skip_serializing_if = "Option::is_none")]
421    pub kv_role: Option<String>,
422
423    /// Health check configuration.
424    #[serde(default)]
425    pub health: HealthCheckConfig,
426
427    /// Maximum connection attempts during worker registration (default: 20).
428    #[serde(default = "default_max_connection_attempts")]
429    pub max_connection_attempts: u32,
430}
431
432impl WorkerSpec {
433    /// Create a new `WorkerSpec` with the given URL and sensible defaults.
434    pub fn new(url: impl Into<String>) -> Self {
435        Self {
436            url: url.into(),
437            models: WorkerModels::Wildcard,
438            worker_type: WorkerType::default(),
439            connection_mode: ConnectionMode::default(),
440            runtime_type: RuntimeType::default(),
441            provider: None,
442            labels: HashMap::new(),
443            priority: DEFAULT_WORKER_PRIORITY,
444            cost: DEFAULT_WORKER_COST,
445            api_key: None,
446            bootstrap_port: None,
447            bootstrap_host: String::new(),
448            kv_connector: None,
449            kv_role: None,
450            health: HealthCheckConfig::default(),
451            max_connection_attempts: default_max_connection_attempts(),
452        }
453    }
454}
455
456// ── API types ───────────────────────────────────────────────────────
457
458/// Worker information for API responses.
459#[serde_with::skip_serializing_none]
460#[derive(Debug, Clone, Serialize)]
461pub struct WorkerInfo {
462    /// Worker unique identifier.
463    pub id: String,
464
465    /// Worker identity and configuration.
466    #[serde(flatten)]
467    pub spec: WorkerSpec,
468
469    /// Whether the worker is healthy.
470    pub is_healthy: bool,
471
472    /// Current load on the worker.
473    pub load: usize,
474
475    /// Job status for async operations (if available).
476    pub job_status: Option<JobStatus>,
477}
478
479impl WorkerInfo {
480    /// Create a partial WorkerInfo for pending workers (not yet registered).
481    pub fn pending(worker_id: &str, url: String, job_status: Option<JobStatus>) -> Self {
482        Self {
483            id: worker_id.to_string(),
484            spec: WorkerSpec::new(url),
485            is_healthy: false,
486            load: 0,
487            job_status,
488        }
489    }
490}
491
492/// Job status for async control plane operations
493#[derive(Debug, Clone, Serialize, Deserialize)]
494pub struct JobStatus {
495    pub job_type: String,
496    pub worker_url: String,
497    pub status: String,
498    pub message: Option<String>,
499    pub timestamp: u64,
500}
501
502impl JobStatus {
503    /// Create a pending job status
504    pub fn pending(job_type: &str, worker_url: &str) -> Self {
505        Self {
506            job_type: job_type.to_string(),
507            worker_url: worker_url.to_string(),
508            status: "pending".to_string(),
509            message: None,
510            timestamp: std::time::SystemTime::now()
511                .duration_since(std::time::SystemTime::UNIX_EPOCH)
512                .unwrap_or_default()
513                .as_secs(),
514        }
515    }
516
517    /// Create a processing job status
518    pub fn processing(job_type: &str, worker_url: &str) -> Self {
519        Self {
520            job_type: job_type.to_string(),
521            worker_url: worker_url.to_string(),
522            status: "processing".to_string(),
523            message: None,
524            timestamp: std::time::SystemTime::now()
525                .duration_since(std::time::SystemTime::UNIX_EPOCH)
526                .unwrap_or_default()
527                .as_secs(),
528        }
529    }
530
531    /// Create a failed job status
532    pub fn failed(job_type: &str, worker_url: &str, error: String) -> Self {
533        Self {
534            job_type: job_type.to_string(),
535            worker_url: worker_url.to_string(),
536            status: "failed".to_string(),
537            message: Some(error),
538            timestamp: std::time::SystemTime::now()
539                .duration_since(std::time::SystemTime::UNIX_EPOCH)
540                .unwrap_or_default()
541                .as_secs(),
542        }
543    }
544}
545
546/// Worker list response
547#[derive(Debug, Clone, Serialize)]
548pub struct WorkerListResponse {
549    pub workers: Vec<WorkerInfo>,
550    pub total: usize,
551    pub stats: WorkerStats,
552}
553
554/// Worker statistics
555#[derive(Debug, Clone, Serialize)]
556pub struct WorkerStats {
557    pub total_workers: usize,
558    pub healthy_workers: usize,
559    pub total_models: usize,
560    pub total_load: usize,
561    pub by_type: WorkerTypeStats,
562}
563
564/// Worker statistics by type
565#[derive(Debug, Clone, Serialize)]
566pub struct WorkerTypeStats {
567    pub regular: usize,
568    pub prefill: usize,
569    pub decode: usize,
570}
571
572// ── Update types ────────────────────────────────────────────────────
573
574/// Partial health check config for PATCH-style updates.
575///
576/// Each `None` field means "keep the existing value". This avoids the problem
577/// where `#[serde(default)]` on [`HealthCheckConfig`] would silently reset
578/// unspecified fields to defaults.
579#[serde_with::skip_serializing_none]
580#[derive(Debug, Clone, Serialize, Deserialize)]
581pub struct HealthCheckUpdate {
582    pub timeout_secs: Option<u64>,
583    pub check_interval_secs: Option<u64>,
584    pub success_threshold: Option<u32>,
585    pub failure_threshold: Option<u32>,
586    pub disable_health_check: Option<bool>,
587}
588
589impl HealthCheckUpdate {
590    /// Merge this update into an existing [`HealthCheckConfig`], returning a new config.
591    /// Only `Some` fields are applied; `None` fields keep the existing value.
592    pub fn apply_to(&self, existing: &HealthCheckConfig) -> HealthCheckConfig {
593        HealthCheckConfig {
594            timeout_secs: self.timeout_secs.unwrap_or(existing.timeout_secs),
595            check_interval_secs: self
596                .check_interval_secs
597                .unwrap_or(existing.check_interval_secs),
598            success_threshold: self.success_threshold.unwrap_or(existing.success_threshold),
599            failure_threshold: self.failure_threshold.unwrap_or(existing.failure_threshold),
600            disable_health_check: self
601                .disable_health_check
602                .unwrap_or(existing.disable_health_check),
603        }
604    }
605}
606
607/// Worker update request
608#[serde_with::skip_serializing_none]
609#[derive(Debug, Clone, Serialize, Deserialize)]
610pub struct WorkerUpdateRequest {
611    /// Update priority
612    pub priority: Option<u32>,
613
614    /// Update cost
615    pub cost: Option<f32>,
616
617    /// Update labels
618    pub labels: Option<HashMap<String, String>>,
619
620    /// Update API key (for key rotation)
621    pub api_key: Option<String>,
622
623    /// Update health check configuration (partial — only specified fields change)
624    pub health: Option<HealthCheckUpdate>,
625}
626
627// ── Response types ──────────────────────────────────────────────────
628
629/// Generic API response
630#[derive(Debug, Clone, Serialize)]
631pub struct WorkerApiResponse {
632    pub success: bool,
633    pub message: String,
634
635    #[serde(skip_serializing_if = "Option::is_none")]
636    pub worker: Option<WorkerInfo>,
637}
638
639/// Error response
640#[derive(Debug, Clone, Serialize)]
641pub struct WorkerErrorResponse {
642    pub error: String,
643    pub code: String,
644}
645
646/// Result from flush cache operations across workers
647#[derive(Debug, Clone, Deserialize, Serialize)]
648pub struct FlushCacheResult {
649    pub successful: Vec<String>,
650    pub failed: Vec<(String, String)>,
651    pub total_workers: usize,
652    pub http_workers: usize,
653    pub message: String,
654}
655
656/// Result from getting worker loads
657#[derive(Debug, Clone, Deserialize, Serialize)]
658pub struct WorkerLoadsResult {
659    pub loads: Vec<WorkerLoadInfo>,
660    pub total_workers: usize,
661    pub successful: usize,
662    pub failed: usize,
663}
664
665/// Individual worker load information
666#[derive(Debug, Clone, Deserialize, Serialize)]
667pub struct WorkerLoadInfo {
668    pub worker: String,
669    #[serde(skip_serializing_if = "Option::is_none")]
670    pub worker_type: Option<String>,
671    pub load: isize,
672}
673
#[cfg(feature = "axum")]
impl IntoResponse for FlushCacheResult {
    /// Renders the flush outcome as a JSON response.
    ///
    /// If every worker flushed, the status is `200 OK` with
    /// `"status": "success"`. Otherwise the status is `206 Partial Content`
    /// with `"status": "partial_success"`. In that case the body also carries
    /// the `successful` URL list and a `failed` list of `{worker, error}`
    /// objects.
    fn into_response(self) -> Response {
        let fully_succeeded = self.failed.is_empty();
        let (status, outcome) = if fully_succeeded {
            (StatusCode::OK, "success")
        } else {
            (StatusCode::PARTIAL_CONTENT, "partial_success")
        };

        let mut body = json!({
            "status": outcome,
            "message": self.message,
            "workers_flushed": self.successful.len(),
            "total_http_workers": self.http_workers,
            "total_workers": self.total_workers
        });

        if !fully_succeeded {
            // The per-worker breakdown is only included when something failed.
            let failures: Vec<Value> = self
                .failed
                .into_iter()
                .map(|(worker, error)| json!({ "worker": worker, "error": error }))
                .collect();
            body["successful"] = json!(self.successful);
            body["failed"] = Value::Array(failures);
        }

        (status, Json(body)).into_response()
    }
}
703
#[cfg(feature = "axum")]
impl IntoResponse for WorkerLoadsResult {
    /// Renders `{"workers": [{"worker": <id>, "load": <n>}, ...]}` as JSON.
    ///
    /// Only `worker` and `load` are emitted for each entry. `worker_type` and
    /// the aggregate counters are not included.
    fn into_response(self) -> Response {
        let workers: Vec<Value> = self
            .loads
            .into_iter()
            .map(|entry| json!({ "worker": entry.worker, "load": entry.load }))
            .collect();
        Json(json!({ "workers": workers })).into_response()
    }
}