1use std::collections::HashMap;
8
9#[cfg(feature = "axum")]
10use axum::{
11 http::StatusCode,
12 response::{IntoResponse, Response},
13 Json,
14};
15use schemars::JsonSchema;
16use serde::{Deserialize, Deserializer, Serialize, Serializer};
17#[cfg(feature = "axum")]
18use serde_json::{json, Value};
19
20use super::model_card::ModelCard;
21
/// Priority given to a worker when none is supplied; used both by
/// `WorkerSpec::new` and as the serde default for `WorkerSpec::priority`.
pub const DEFAULT_WORKER_PRIORITY: u32 = 50;
/// Cost given to a worker when none is supplied; used both by
/// `WorkerSpec::new` and as the serde default for `WorkerSpec::cost`.
pub const DEFAULT_WORKER_COST: f32 = 1.0;
26
/// Kind of worker, as carried in `WorkerSpec::worker_type`.
///
/// Serialized and deserialized as a lowercase string (`"regular"`,
/// `"prefill"`, `"decode"`).
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default, schemars::JsonSchema,
)]
#[serde(rename_all = "lowercase")]
pub enum WorkerType {
    /// Default worker type.
    #[default]
    Regular,
    // NOTE(review): presumably the prefill half of a prefill/decode split — confirm.
    Prefill,
    // NOTE(review): presumably the decode half of a prefill/decode split — confirm.
    Decode,
}
43
44impl std::fmt::Display for WorkerType {
45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46 match self {
47 WorkerType::Regular => write!(f, "regular"),
48 WorkerType::Prefill => write!(f, "prefill"),
49 WorkerType::Decode => write!(f, "decode"),
50 }
51 }
52}
53
54impl std::str::FromStr for WorkerType {
55 type Err = String;
56
57 fn from_str(s: &str) -> Result<Self, Self::Err> {
58 if s.eq_ignore_ascii_case("regular") {
59 Ok(WorkerType::Regular)
60 } else if s.eq_ignore_ascii_case("prefill") {
61 Ok(WorkerType::Prefill)
62 } else if s.eq_ignore_ascii_case("decode") {
63 Ok(WorkerType::Decode)
64 } else {
65 Err(format!("Unknown worker type: {s}"))
66 }
67 }
68}
69
/// Connection mode recorded for a worker (`WorkerSpec::connection_mode`).
///
/// Serialized and deserialized as a lowercase string (`"http"`, `"grpc"`).
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default, schemars::JsonSchema,
)]
#[serde(rename_all = "lowercase")]
pub enum ConnectionMode {
    /// Default connection mode.
    #[default]
    Http,
    /// gRPC connection mode.
    Grpc,
}
82
83impl std::fmt::Display for ConnectionMode {
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 match self {
86 ConnectionMode::Http => write!(f, "http"),
87 ConnectionMode::Grpc => write!(f, "grpc"),
88 }
89 }
90}
91
/// Composite key made of a model id, a worker type and a connection mode.
///
/// Derives `Eq` and `Hash`, so it can be used as a map or set key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct WorkerGroupKey {
    /// Model identifier component of the key.
    pub model_id: String,
    /// Worker type component of the key.
    pub worker_type: WorkerType,
    /// Connection mode component of the key.
    pub connection_mode: ConnectionMode,
}
102
103impl std::fmt::Display for WorkerGroupKey {
104 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105 write!(
106 f,
107 "{}:{}:{}",
108 self.model_id, self.worker_type, self.connection_mode
109 )
110 }
111}
112
/// Runtime recorded for a worker (`WorkerSpec::runtime_type`).
///
/// Serialized and deserialized as a lowercase string (`"sglang"`, `"vllm"`,
/// `"trtllm"`, `"external"`).
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default, schemars::JsonSchema,
)]
#[serde(rename_all = "lowercase")]
pub enum RuntimeType {
    /// Default runtime.
    #[default]
    Sglang,
    /// vLLM runtime.
    Vllm,
    /// TensorRT-LLM runtime; the `FromStr` impl also accepts `tensorrt-llm`.
    Trtllm,
    // NOTE(review): presumably a runtime not managed here (e.g. a hosted API) — confirm.
    External,
}
129
130impl std::fmt::Display for RuntimeType {
131 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
132 match self {
133 RuntimeType::Sglang => write!(f, "sglang"),
134 RuntimeType::Vllm => write!(f, "vllm"),
135 RuntimeType::Trtllm => write!(f, "trtllm"),
136 RuntimeType::External => write!(f, "external"),
137 }
138 }
139}
140
141impl std::str::FromStr for RuntimeType {
142 type Err = String;
143
144 fn from_str(s: &str) -> Result<Self, Self::Err> {
145 if s.eq_ignore_ascii_case("sglang") {
146 Ok(RuntimeType::Sglang)
147 } else if s.eq_ignore_ascii_case("vllm") {
148 Ok(RuntimeType::Vllm)
149 } else if s.eq_ignore_ascii_case("trtllm") || s.eq_ignore_ascii_case("tensorrt-llm") {
150 Ok(RuntimeType::Trtllm)
151 } else if s.eq_ignore_ascii_case("external") {
152 Ok(RuntimeType::External)
153 } else {
154 Err(format!("Unknown runtime type: {s}"))
155 }
156 }
157}
158
/// API provider associated with a worker (`WorkerSpec::provider`).
///
/// Named variants use lowercase wire names plus the listed aliases; any other
/// value is captured verbatim by [`ProviderType::Custom`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum ProviderType {
    /// OpenAI.
    #[serde(alias = "openai")]
    OpenAI,
    /// xAI; also accepted as `grok`.
    #[serde(alias = "xai", alias = "grok")]
    #[expect(
        clippy::upper_case_acronyms,
        reason = "xAI is a proper company name; XAI matches industry convention and existing serde aliases"
    )]
    XAI,
    /// Anthropic; also accepted as `claude`.
    #[serde(alias = "anthropic", alias = "claude")]
    Anthropic,
    /// Gemini; also accepted as `google`.
    #[serde(alias = "gemini", alias = "google")]
    Gemini,
    /// Fallback variant: being `#[serde(untagged)]`, it holds a plain string
    /// that matched none of the named variants.
    #[serde(untagged)]
    Custom(String),
}
187
188impl ProviderType {
189 pub fn as_str(&self) -> &str {
191 match self {
192 Self::OpenAI => "openai",
193 Self::XAI => "xai",
194 Self::Anthropic => "anthropic",
195 Self::Gemini => "gemini",
196 Self::Custom(s) => s.as_str(),
197 }
198 }
199
200 pub fn from_url(url: &str) -> Option<Self> {
203 let host = url::Url::parse(url).ok()?.host_str()?.to_lowercase();
204
205 if host.ends_with("openai.com") {
206 Some(Self::OpenAI)
207 } else if host.ends_with("x.ai") {
208 Some(Self::XAI)
209 } else if host.ends_with("anthropic.com") {
210 Some(Self::Anthropic)
211 } else if host.ends_with("googleapis.com") {
212 Some(Self::Gemini)
213 } else {
214 None
215 }
216 }
217
218 pub fn admin_key_env_var(&self) -> Option<&'static str> {
221 match self {
222 Self::OpenAI => Some("OPENAI_ADMIN_KEY"),
223 Self::XAI => Some("XAI_ADMIN_KEY"),
224 Self::Anthropic => Some("ANTHROPIC_ADMIN_KEY"),
225 Self::Gemini => Some("GEMINI_ADMIN_KEY"),
226 Self::Custom(_) => None,
227 }
228 }
229
230 pub fn uses_x_api_key(&self) -> bool {
232 matches!(self, Self::Anthropic)
233 }
234
235 pub fn from_model_name(model: &str) -> Option<Self> {
238 let model_lower = model.to_lowercase();
239 if model_lower.starts_with("grok") {
240 Some(Self::XAI)
241 } else if model_lower.starts_with("gemini") {
242 Some(Self::Gemini)
243 } else if model_lower.starts_with("claude") {
244 Some(Self::Anthropic)
245 } else if model_lower.starts_with("gpt")
246 || model_lower.starts_with("o1")
247 || model_lower.starts_with("o3")
248 {
249 Some(Self::OpenAI)
250 } else {
251 None
252 }
253 }
254}
255
256impl std::fmt::Display for ProviderType {
257 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
258 write!(f, "{}", self.as_str())
259 }
260}
261
/// Serde default for `WorkerSpec::priority`; returns `DEFAULT_WORKER_PRIORITY`.
fn default_priority() -> u32 {
    DEFAULT_WORKER_PRIORITY
}
267
/// Serde default for `WorkerSpec::cost`; returns `DEFAULT_WORKER_COST`.
fn default_cost() -> f32 {
    DEFAULT_WORKER_COST
}
271
/// Default for `HealthCheckConfig::timeout_secs`.
fn default_health_check_timeout() -> u64 {
    const TIMEOUT_SECS: u64 = 30;
    TIMEOUT_SECS
}
275
/// Default for `HealthCheckConfig::check_interval_secs`.
fn default_health_check_interval() -> u64 {
    const INTERVAL_SECS: u64 = 60;
    INTERVAL_SECS
}
279
/// Default for `HealthCheckConfig::success_threshold`.
fn default_health_success_threshold() -> u32 {
    const SUCCESS_THRESHOLD: u32 = 2;
    SUCCESS_THRESHOLD
}
283
/// Default for `HealthCheckConfig::failure_threshold`.
fn default_health_failure_threshold() -> u32 {
    const FAILURE_THRESHOLD: u32 = 3;
    FAILURE_THRESHOLD
}
287
/// Default for `WorkerSpec::max_connection_attempts`.
fn default_max_connection_attempts() -> u32 {
    const MAX_ATTEMPTS: u32 = 20;
    MAX_ATTEMPTS
}
291
/// Fully-resolved health check settings.
///
/// Every field has a serde default, so an empty object deserializes to the
/// same values as `HealthCheckConfig::default()`.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct HealthCheckConfig {
    /// Health check timeout in seconds (default 30).
    #[serde(default = "default_health_check_timeout")]
    pub timeout_secs: u64,

    /// Interval between health checks in seconds (default 60).
    #[serde(default = "default_health_check_interval")]
    pub check_interval_secs: u64,

    /// Success threshold (default 2).
    // NOTE(review): presumably consecutive successes before a worker counts as healthy — confirm.
    #[serde(default = "default_health_success_threshold")]
    pub success_threshold: u32,

    /// Failure threshold (default 3).
    // NOTE(review): presumably consecutive failures before a worker counts as unhealthy — confirm.
    #[serde(default = "default_health_failure_threshold")]
    pub failure_threshold: u32,

    /// Flag to disable health checking (default `false`).
    #[serde(default)]
    pub disable_health_check: bool,
}
317
318impl Default for HealthCheckConfig {
319 fn default() -> Self {
320 Self {
321 timeout_secs: default_health_check_timeout(),
322 check_interval_secs: default_health_check_interval(),
323 success_threshold: default_health_success_threshold(),
324 failure_threshold: default_health_failure_threshold(),
325 disable_health_check: false,
326 }
327 }
328}
329
/// Set of models a worker serves.
///
/// On the wire this is a plain array of model cards: the hand-written serde
/// impls serialize `all()` and deserialize through `From<Vec<ModelCard>>`.
#[derive(Debug, Clone, Default)]
pub enum WorkerModels {
    /// No explicit model list; `supports` returns `true` for every id.
    #[default]
    Wildcard,
    /// Exactly one model (boxed).
    Single(Box<ModelCard>),
    /// A list of models; `From<Vec<ModelCard>>` produces this for two or more.
    Multi(Vec<ModelCard>),
}
348
349impl WorkerModels {
350 pub fn is_wildcard(&self) -> bool {
352 matches!(self, Self::Wildcard)
353 }
354
355 pub fn primary(&self) -> Option<&ModelCard> {
357 match self {
358 Self::Wildcard => None,
359 Self::Single(card) => Some(card.as_ref()),
360 Self::Multi(cards) => cards.first(),
361 }
362 }
363
364 pub fn all(&self) -> &[ModelCard] {
366 match self {
367 Self::Wildcard => &[],
368 Self::Single(card) => std::slice::from_ref(card.as_ref()),
369 Self::Multi(cards) => cards,
370 }
371 }
372
373 pub fn find(&self, id: &str) -> Option<&ModelCard> {
375 match self {
376 Self::Wildcard => None,
377 Self::Single(card) => card.matches(id).then_some(card.as_ref()),
378 Self::Multi(cards) => cards.iter().find(|m| m.matches(id)),
379 }
380 }
381
382 pub fn supports(&self, id: &str) -> bool {
385 match self {
386 Self::Wildcard => true,
387 _ => self.find(id).is_some(),
388 }
389 }
390
391 pub fn iter(&self) -> impl Iterator<Item = &ModelCard> {
393 self.all().iter()
394 }
395}
396
397impl From<Vec<ModelCard>> for WorkerModels {
398 fn from(models: Vec<ModelCard>) -> Self {
399 match models.len() {
400 0 => Self::Wildcard,
401 1 => {
402 let Some(model) = models.into_iter().next() else {
403 return Self::Wildcard;
404 };
405 Self::Single(Box::new(model))
406 }
407 _ => Self::Multi(models),
408 }
409 }
410}
411
412impl Serialize for WorkerModels {
414 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
415 self.all().serialize(serializer)
416 }
417}
418
419impl<'de> Deserialize<'de> for WorkerModels {
421 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
422 let models = Vec::<ModelCard>::deserialize(deserializer)?;
423 Ok(Self::from(models))
424 }
425}
426
427impl JsonSchema for WorkerModels {
429 fn schema_name() -> String {
430 "WorkerModels".to_string()
431 }
432
433 fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema {
434 Vec::<ModelCard>::json_schema(gen)
435 }
436}
437
/// Description of a worker: its URL, the models it serves and the settings
/// attached to it.
///
/// Only `url` is mandatory on deserialization; every other field is optional
/// or has a serde default. `Option` fields holding `None` are left out of the
/// serialized output.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct WorkerSpec {
    /// Worker URL.
    pub url: String,

    /// Models served by the worker; defaults to wildcard and is omitted from
    /// the output while it is a wildcard.
    #[serde(default, skip_serializing_if = "WorkerModels::is_wildcard")]
    pub models: WorkerModels,

    /// Worker type; defaults to `WorkerType::default()`.
    #[serde(default)]
    pub worker_type: WorkerType,

    /// Connection mode; defaults to `ConnectionMode::default()`.
    #[serde(default)]
    pub connection_mode: ConnectionMode,

    /// Runtime type; also accepted under the key `runtime`.
    #[serde(default, alias = "runtime")]
    pub runtime_type: RuntimeType,

    /// Optional provider.
    pub provider: Option<ProviderType>,

    /// Free-form string labels; omitted from the output when empty.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub labels: HashMap<String, String>,

    /// Priority; defaults to `DEFAULT_WORKER_PRIORITY`.
    #[serde(default = "default_priority")]
    pub priority: u32,

    /// Cost; defaults to `DEFAULT_WORKER_COST`.
    #[serde(default = "default_cost")]
    pub cost: f32,

    /// API key; accepted on input but never serialized.
    #[serde(default, skip_serializing)]
    pub api_key: Option<String>,

    /// Optional bootstrap port.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub bootstrap_port: Option<u16>,

    /// Bootstrap host; excluded from both serialization and deserialization,
    /// so it is an empty string unless set in code.
    #[serde(default, skip)]
    pub bootstrap_host: String,

    // NOTE(review): the `dp_*` fields presumably describe data-parallel placement — confirm.
    /// Optional `dp_base_url`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub dp_base_url: Option<String>,

    /// Optional `dp_rank`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub dp_rank: Option<usize>,

    /// Optional `dp_size`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub dp_size: Option<usize>,

    // NOTE(review): the `kv_*` fields presumably configure KV-cache transfer — confirm.
    /// Optional `kv_connector`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub kv_connector: Option<String>,

    /// Optional `kv_role`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub kv_role: Option<String>,

    /// Optional `kv_block_size`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub kv_block_size: Option<usize>,

    /// Partial health check overrides; omitted from the output when no
    /// override is set.
    #[serde(default, skip_serializing_if = "HealthCheckUpdate::is_empty")]
    pub health: HealthCheckUpdate,

    /// Maximum connection attempts (default 20).
    #[serde(default = "default_max_connection_attempts")]
    pub max_connection_attempts: u32,

    /// Optional load monitor interval in seconds.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub load_monitor_interval_secs: Option<u64>,
}
539
540impl WorkerSpec {
541 pub fn new(url: impl Into<String>) -> Self {
543 Self {
544 url: url.into(),
545 models: WorkerModels::Wildcard,
546 worker_type: WorkerType::default(),
547 connection_mode: ConnectionMode::default(),
548 runtime_type: RuntimeType::default(),
549 provider: None,
550 labels: HashMap::new(),
551 priority: DEFAULT_WORKER_PRIORITY,
552 cost: DEFAULT_WORKER_COST,
553 api_key: None,
554 bootstrap_port: None,
555 bootstrap_host: String::new(),
556 dp_base_url: None,
557 dp_rank: None,
558 dp_size: None,
559 kv_connector: None,
560 kv_role: None,
561 kv_block_size: None,
562 health: HealthCheckUpdate::default(),
563 max_connection_attempts: default_max_connection_attempts(),
564 load_monitor_interval_secs: None,
565 }
566 }
567}
568
/// A worker's spec together with its id and current state.
///
/// The spec is flattened, so its fields appear at the same level as `id`.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct WorkerInfo {
    /// Worker identifier.
    pub id: String,

    /// Worker spec; its fields are inlined into this object.
    #[serde(flatten)]
    pub spec: WorkerSpec,

    /// Whether the worker is currently considered healthy.
    pub is_healthy: bool,

    /// Current load value for the worker.
    pub load: usize,

    /// Status of an associated job, if any; omitted from the output when `None`.
    pub job_status: Option<JobStatus>,
}
591
592impl WorkerInfo {
593 pub fn pending(worker_id: &str, url: String, job_status: Option<JobStatus>) -> Self {
595 Self {
596 id: worker_id.to_string(),
597 spec: WorkerSpec::new(url),
598 is_healthy: false,
599 load: 0,
600 job_status,
601 }
602 }
603}
604
/// Status record for a job tied to a worker URL.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct JobStatus {
    /// Caller-supplied job type label.
    pub job_type: String,
    /// URL of the worker the job concerns.
    pub worker_url: String,
    /// Status string; the constructors in this file use `"pending"`,
    /// `"processing"` and `"failed"`.
    pub status: String,
    /// Optional message; the `failed` constructor stores the error text here.
    pub message: Option<String>,
    /// Seconds since the Unix epoch at which the record was created.
    pub timestamp: u64,
}
614
615impl JobStatus {
616 pub fn pending(job_type: &str, worker_url: &str) -> Self {
618 Self {
619 job_type: job_type.to_string(),
620 worker_url: worker_url.to_string(),
621 status: "pending".to_string(),
622 message: None,
623 timestamp: std::time::SystemTime::now()
624 .duration_since(std::time::SystemTime::UNIX_EPOCH)
625 .unwrap_or_default()
626 .as_secs(),
627 }
628 }
629
630 pub fn processing(job_type: &str, worker_url: &str) -> Self {
632 Self {
633 job_type: job_type.to_string(),
634 worker_url: worker_url.to_string(),
635 status: "processing".to_string(),
636 message: None,
637 timestamp: std::time::SystemTime::now()
638 .duration_since(std::time::SystemTime::UNIX_EPOCH)
639 .unwrap_or_default()
640 .as_secs(),
641 }
642 }
643
644 pub fn failed(job_type: &str, worker_url: &str, error: String) -> Self {
646 Self {
647 job_type: job_type.to_string(),
648 worker_url: worker_url.to_string(),
649 status: "failed".to_string(),
650 message: Some(error),
651 timestamp: std::time::SystemTime::now()
652 .duration_since(std::time::SystemTime::UNIX_EPOCH)
653 .unwrap_or_default()
654 .as_secs(),
655 }
656 }
657}
658
/// Response body listing workers together with summary statistics.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct WorkerListResponse {
    /// The listed workers.
    pub workers: Vec<WorkerInfo>,
    /// Total count reported alongside the list.
    pub total: usize,
    /// Aggregate statistics.
    pub stats: WorkerStats,
}
666
/// Aggregate counters over a set of workers.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct WorkerStats {
    /// Number of workers.
    pub total_workers: usize,
    /// Number of healthy workers.
    pub healthy_workers: usize,
    /// Number of models.
    pub total_models: usize,
    /// Total load across workers.
    pub total_load: usize,
    /// Worker counts broken down by worker type.
    pub by_type: WorkerTypeStats,
}
676
/// Per-type worker counts; the field names mirror the `WorkerType` variants.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct WorkerTypeStats {
    /// Count of regular workers.
    pub regular: usize,
    /// Count of prefill workers.
    pub prefill: usize,
    /// Count of decode workers.
    pub decode: usize,
}
684
/// Partial health check settings: each field is an optional override.
///
/// Unset (`None`) fields are omitted from serialized output; `apply_to`
/// merges the set fields onto an existing `HealthCheckConfig`.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Default, Serialize, Deserialize, schemars::JsonSchema)]
pub struct HealthCheckUpdate {
    /// Override for `HealthCheckConfig::timeout_secs`.
    pub timeout_secs: Option<u64>,
    /// Override for `HealthCheckConfig::check_interval_secs`.
    pub check_interval_secs: Option<u64>,
    /// Override for `HealthCheckConfig::success_threshold`.
    pub success_threshold: Option<u32>,
    /// Override for `HealthCheckConfig::failure_threshold`.
    pub failure_threshold: Option<u32>,
    /// Override for `HealthCheckConfig::disable_health_check`.
    pub disable_health_check: Option<bool>,
}
701
702impl HealthCheckUpdate {
703 pub fn is_empty(&self) -> bool {
705 self.timeout_secs.is_none()
706 && self.check_interval_secs.is_none()
707 && self.success_threshold.is_none()
708 && self.failure_threshold.is_none()
709 && self.disable_health_check.is_none()
710 }
711}
712
713impl HealthCheckUpdate {
714 pub fn apply_to(&self, existing: &HealthCheckConfig) -> HealthCheckConfig {
717 HealthCheckConfig {
718 timeout_secs: self.timeout_secs.unwrap_or(existing.timeout_secs),
719 check_interval_secs: self
720 .check_interval_secs
721 .unwrap_or(existing.check_interval_secs),
722 success_threshold: self.success_threshold.unwrap_or(existing.success_threshold),
723 failure_threshold: self.failure_threshold.unwrap_or(existing.failure_threshold),
724 disable_health_check: self
725 .disable_health_check
726 .unwrap_or(existing.disable_health_check),
727 }
728 }
729}
730
/// Request body for updating a worker; every field is optional and `None`
/// fields are omitted from serialized output.
// NOTE(review): presumably an absent field means "leave unchanged" — confirm against the handler.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct WorkerUpdateRequest {
    /// New priority, if provided.
    pub priority: Option<u32>,

    /// New cost, if provided.
    pub cost: Option<f32>,

    /// New label map, if provided.
    pub labels: Option<HashMap<String, String>>,

    /// New API key, if provided.
    pub api_key: Option<String>,

    /// Health check overrides, if provided.
    pub health: Option<HealthCheckUpdate>,
}
750
/// Generic response for worker API operations.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct WorkerApiResponse {
    /// Whether the operation succeeded.
    pub success: bool,
    /// Human-readable message.
    pub message: String,

    /// Worker details, when available; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub worker: Option<WorkerInfo>,
}
762
/// Error response for worker API operations.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct WorkerErrorResponse {
    /// Error message.
    pub error: String,
    /// Error code string.
    pub code: String,
}
769
/// Outcome of a cache flush across workers.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FlushCacheResult {
    /// Workers that were flushed successfully.
    pub successful: Vec<String>,
    /// Failed workers as `(worker, error)` pairs.
    pub failed: Vec<(String, String)>,
    /// Total number of workers.
    pub total_workers: usize,
    /// Number of HTTP workers.
    pub http_workers: usize,
    /// Summary message.
    pub message: String,
}
779
/// Outcome of collecting load information from workers.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct WorkerLoadsResult {
    /// Per-worker load entries.
    pub loads: Vec<WorkerLoadInfo>,
    /// Total number of workers.
    pub total_workers: usize,
    /// Number of successful load queries.
    pub successful: usize,
    /// Number of failed load queries.
    pub failed: usize,
}
788
/// One load snapshot, as listed in `WorkerLoadResponse::loads`.
///
/// `#[serde(default)]` on the struct means any field missing from the input
/// takes its `Default` value (zero).
// NOTE(review): presumably one snapshot per data-parallel rank, mirroring the
// worker's own load report — confirm field meanings against the worker API.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct SchedulerLoadSnapshot {
    pub dp_rank: i32,
    pub num_running_reqs: i32,
    pub num_waiting_reqs: i32,
    pub num_total_reqs: i32,
    /// Summed by `WorkerLoadResponse::total_used_tokens`.
    pub num_used_tokens: i32,
    pub max_total_num_tokens: i32,
    /// Averaged by `WorkerLoadResponse::effective_token_usage`.
    pub token_usage: f64,
    pub gen_throughput: f64,
    pub cache_hit_rate: f64,
    pub utilization: f64,
    pub max_running_requests: i32,
}
809
/// Load report for a worker: a timestamp, a rank count and a list of snapshots.
///
/// `#[serde(default)]` on the struct means missing fields take their
/// `Default` value.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct WorkerLoadResponse {
    /// Timestamp string as reported.
    pub timestamp: String,
    /// Reported `dp_rank_count`.
    pub dp_rank_count: i32,
    /// The load snapshots.
    pub loads: Vec<SchedulerLoadSnapshot>,
}
818
819impl WorkerLoadResponse {
820 pub fn effective_token_usage(&self) -> f64 {
822 if self.loads.is_empty() {
823 return 0.0;
824 }
825 self.loads.iter().map(|l| l.token_usage).sum::<f64>() / self.loads.len() as f64
826 }
827
828 pub fn total_used_tokens(&self) -> i64 {
830 self.loads.iter().map(|l| l.num_used_tokens as i64).sum()
831 }
832}
833
/// Load entry for a single worker.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct WorkerLoadInfo {
    /// Worker identifier string.
    pub worker: String,
    /// Worker type as a string, if known; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub worker_type: Option<String>,
    /// Load value; signed.
    // NOTE(review): presumably a negative value marks a failed query — confirm.
    pub load: isize,
    /// Detailed load report, if available; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub details: Option<WorkerLoadResponse>,
}
844
#[cfg(feature = "axum")]
impl IntoResponse for FlushCacheResult {
    /// Converts the result into a JSON response.
    ///
    /// With no failures the status is `200 OK` and `"status"` is `"success"`.
    /// Otherwise the status is `206 Partial Content`, `"status"` is
    /// `"partial_success"`, and the body also lists the `"successful"` workers
    /// and the `"failed"` ones as `{"worker", "error"}` objects.
    fn into_response(self) -> Response {
        let Self {
            successful,
            failed,
            total_workers,
            http_workers,
            message,
        } = self;

        // NOTE(review): 206 is normally tied to range requests; 207 Multi-Status
        // may describe this better, but clients may rely on 206 — confirm before
        // changing.
        let (status, outcome) = if failed.is_empty() {
            (StatusCode::OK, "success")
        } else {
            (StatusCode::PARTIAL_CONTENT, "partial_success")
        };

        let mut body = json!({
            "status": outcome,
            "message": message,
            "workers_flushed": successful.len(),
            "total_http_workers": http_workers,
            "total_workers": total_workers
        });

        if !failed.is_empty() {
            let failures: Vec<Value> = failed
                .into_iter()
                .map(|(worker, error)| json!({"worker": worker, "error": error}))
                .collect();
            body["successful"] = json!(successful);
            body["failed"] = json!(failures);
        }

        (status, Json(body)).into_response()
    }
}
874
#[cfg(feature = "axum")]
impl IntoResponse for WorkerLoadsResult {
    /// Converts the result into a JSON response of the form
    /// `{"workers": [{"worker", "load", "details"?}, ...]}`.
    ///
    /// `"details"` is present only for entries that carry a detailed report.
    /// The counters on `self` and each entry's `worker_type` are not included.
    fn into_response(self) -> Response {
        let workers: Vec<Value> = self
            .loads
            .iter()
            .map(|info| match &info.details {
                Some(details) => json!({
                    "worker": &info.worker,
                    "load": info.load,
                    "details": details
                }),
                None => json!({"worker": &info.worker, "load": info.load}),
            })
            .collect();

        Json(json!({"workers": workers})).into_response()
    }
}