1use std::collections::HashMap;
8
9#[cfg(feature = "axum")]
10use axum::{
11 http::StatusCode,
12 response::{IntoResponse, Response},
13 Json,
14};
15use schemars::JsonSchema;
16use serde::{Deserialize, Deserializer, Serialize, Serializer};
17#[cfg(feature = "axum")]
18use serde_json::{json, Value};
19
20use super::model_card::ModelCard;
21
/// Priority given to a worker when none is supplied; used as the serde default
/// for `WorkerSpec::priority` and by `WorkerSpec::new`.
pub const DEFAULT_WORKER_PRIORITY: u32 = 50;
/// Cost given to a worker when none is supplied; used as the serde default
/// for `WorkerSpec::cost` and by `WorkerSpec::new`.
pub const DEFAULT_WORKER_COST: f32 = 1.0;
26
/// Kind of worker. Serialized and deserialized as a lowercase string
/// (`"regular"`, `"prefill"`, `"decode"`).
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default, schemars::JsonSchema,
)]
#[serde(rename_all = "lowercase")]
pub enum WorkerType {
    /// Default variant.
    #[default]
    Regular,
    /// NOTE(review): presumably the prefill stage of a disaggregated
    /// prefill/decode deployment — confirm.
    Prefill,
    /// NOTE(review): presumably the decode stage of a disaggregated
    /// prefill/decode deployment — confirm.
    Decode,
}
43
44impl std::fmt::Display for WorkerType {
45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46 match self {
47 WorkerType::Regular => write!(f, "regular"),
48 WorkerType::Prefill => write!(f, "prefill"),
49 WorkerType::Decode => write!(f, "decode"),
50 }
51 }
52}
53
54impl std::str::FromStr for WorkerType {
55 type Err = String;
56
57 fn from_str(s: &str) -> Result<Self, Self::Err> {
58 if s.eq_ignore_ascii_case("regular") {
59 Ok(WorkerType::Regular)
60 } else if s.eq_ignore_ascii_case("prefill") {
61 Ok(WorkerType::Prefill)
62 } else if s.eq_ignore_ascii_case("decode") {
63 Ok(WorkerType::Decode)
64 } else {
65 Err(format!("Unknown worker type: {s}"))
66 }
67 }
68}
69
/// Protocol used to talk to a worker. Serialized and deserialized as a
/// lowercase string (`"http"`, `"grpc"`).
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default, schemars::JsonSchema,
)]
#[serde(rename_all = "lowercase")]
pub enum ConnectionMode {
    /// Default variant.
    #[default]
    Http,
    /// gRPC connection.
    Grpc,
}
82
83impl std::fmt::Display for ConnectionMode {
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 match self {
86 ConnectionMode::Http => write!(f, "http"),
87 ConnectionMode::Grpc => write!(f, "grpc"),
88 }
89 }
90}
91
/// Hashable key combining a model id, worker type and connection mode.
/// Its `Display` form is `model_id:worker_type:connection_mode`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct WorkerGroupKey {
    /// Identifier of the model.
    pub model_id: String,
    /// Worker type component of the key.
    pub worker_type: WorkerType,
    /// Connection mode component of the key.
    pub connection_mode: ConnectionMode,
}
102
103impl std::fmt::Display for WorkerGroupKey {
104 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105 write!(
106 f,
107 "{}:{}:{}",
108 self.model_id, self.worker_type, self.connection_mode
109 )
110 }
111}
112
/// Inference runtime behind a worker. Serialized and deserialized as a
/// lowercase string (`"sglang"`, `"vllm"`, `"trtllm"`, `"external"`).
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default, schemars::JsonSchema,
)]
#[serde(rename_all = "lowercase")]
pub enum RuntimeType {
    /// Default variant.
    #[default]
    Sglang,
    /// vLLM runtime.
    Vllm,
    /// TensorRT-LLM runtime (`FromStr` also accepts `tensorrt-llm`).
    Trtllm,
    /// NOTE(review): presumably an externally hosted API rather than a local
    /// runtime — confirm.
    External,
}
129
130impl std::fmt::Display for RuntimeType {
131 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
132 match self {
133 RuntimeType::Sglang => write!(f, "sglang"),
134 RuntimeType::Vllm => write!(f, "vllm"),
135 RuntimeType::Trtllm => write!(f, "trtllm"),
136 RuntimeType::External => write!(f, "external"),
137 }
138 }
139}
140
141impl std::str::FromStr for RuntimeType {
142 type Err = String;
143
144 fn from_str(s: &str) -> Result<Self, Self::Err> {
145 if s.eq_ignore_ascii_case("sglang") {
146 Ok(RuntimeType::Sglang)
147 } else if s.eq_ignore_ascii_case("vllm") {
148 Ok(RuntimeType::Vllm)
149 } else if s.eq_ignore_ascii_case("trtllm") || s.eq_ignore_ascii_case("tensorrt-llm") {
150 Ok(RuntimeType::Trtllm)
151 } else if s.eq_ignore_ascii_case("external") {
152 Ok(RuntimeType::External)
153 } else {
154 Err(format!("Unknown runtime type: {s}"))
155 }
156 }
157}
158
/// API provider behind a worker.
///
/// Known providers are (de)serialized as lowercase strings; the aliases listed
/// on each variant are additionally accepted on input. Any other string
/// deserializes into [`ProviderType::Custom`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum ProviderType {
    /// OpenAI.
    #[serde(alias = "openai")]
    OpenAI,
    /// xAI; `grok` is accepted as an alias.
    #[serde(alias = "xai", alias = "grok")]
    #[expect(
        clippy::upper_case_acronyms,
        reason = "xAI is a proper company name; XAI matches industry convention and existing serde aliases"
    )]
    XAI,
    /// Anthropic; `claude` is accepted as an alias.
    #[serde(alias = "anthropic", alias = "claude")]
    Anthropic,
    /// Gemini; `google` is accepted as an alias.
    #[serde(alias = "gemini", alias = "google")]
    Gemini,
    /// Any other provider, carried as the raw string (untagged fallback).
    #[serde(untagged)]
    Custom(String),
}
187
188impl ProviderType {
189 pub fn as_str(&self) -> &str {
191 match self {
192 Self::OpenAI => "openai",
193 Self::XAI => "xai",
194 Self::Anthropic => "anthropic",
195 Self::Gemini => "gemini",
196 Self::Custom(s) => s.as_str(),
197 }
198 }
199
200 pub fn from_url(url: &str) -> Option<Self> {
203 let host = url::Url::parse(url).ok()?.host_str()?.to_lowercase();
204
205 if host.ends_with("openai.com") {
206 Some(Self::OpenAI)
207 } else if host.ends_with("x.ai") {
208 Some(Self::XAI)
209 } else if host.ends_with("anthropic.com") {
210 Some(Self::Anthropic)
211 } else if host.ends_with("googleapis.com") {
212 Some(Self::Gemini)
213 } else {
214 None
215 }
216 }
217
218 pub fn admin_key_env_var(&self) -> Option<&'static str> {
221 match self {
222 Self::OpenAI => Some("OPENAI_ADMIN_KEY"),
223 Self::XAI => Some("XAI_ADMIN_KEY"),
224 Self::Anthropic => Some("ANTHROPIC_ADMIN_KEY"),
225 Self::Gemini => Some("GEMINI_ADMIN_KEY"),
226 Self::Custom(_) => None,
227 }
228 }
229
230 pub fn uses_x_api_key(&self) -> bool {
232 matches!(self, Self::Anthropic)
233 }
234
235 pub fn from_model_name(model: &str) -> Option<Self> {
238 let model_lower = model.to_lowercase();
239 if model_lower.starts_with("grok") {
240 Some(Self::XAI)
241 } else if model_lower.starts_with("gemini") {
242 Some(Self::Gemini)
243 } else if model_lower.starts_with("claude") {
244 Some(Self::Anthropic)
245 } else if model_lower.starts_with("gpt")
246 || model_lower.starts_with("o1")
247 || model_lower.starts_with("o3")
248 {
249 Some(Self::OpenAI)
250 } else {
251 None
252 }
253 }
254}
255
256impl std::fmt::Display for ProviderType {
257 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
258 write!(f, "{}", self.as_str())
259 }
260}
261
/// Serde default for `WorkerSpec::priority`; returns [`DEFAULT_WORKER_PRIORITY`].
fn default_priority() -> u32 {
    DEFAULT_WORKER_PRIORITY
}
267
/// Serde default for `WorkerSpec::cost`; returns [`DEFAULT_WORKER_COST`].
fn default_cost() -> f32 {
    DEFAULT_WORKER_COST
}
271
/// Serde default for `HealthCheckConfig::timeout_secs`.
fn default_health_check_timeout() -> u64 {
    const TIMEOUT_SECS: u64 = 30;
    TIMEOUT_SECS
}
275
/// Serde default for `HealthCheckConfig::check_interval_secs`.
fn default_health_check_interval() -> u64 {
    const INTERVAL_SECS: u64 = 60;
    INTERVAL_SECS
}
279
/// Serde default for `HealthCheckConfig::success_threshold`.
fn default_health_success_threshold() -> u32 {
    const SUCCESS_THRESHOLD: u32 = 2;
    SUCCESS_THRESHOLD
}
283
/// Serde default for `HealthCheckConfig::failure_threshold`.
fn default_health_failure_threshold() -> u32 {
    const FAILURE_THRESHOLD: u32 = 3;
    FAILURE_THRESHOLD
}
287
/// Serde default for `WorkerSpec::max_connection_attempts`.
fn default_max_connection_attempts() -> u32 {
    const MAX_ATTEMPTS: u32 = 20;
    MAX_ATTEMPTS
}
291
/// Fully resolved health-check settings. Every field has a serde default, so
/// any subset may be supplied on input.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct HealthCheckConfig {
    /// Health-check timeout in seconds (default 30).
    #[serde(default = "default_health_check_timeout")]
    pub timeout_secs: u64,

    /// Interval between health checks in seconds (default 60).
    #[serde(default = "default_health_check_interval")]
    pub check_interval_secs: u64,

    /// Success threshold (default 2).
    /// NOTE(review): presumably consecutive successes needed to mark a worker
    /// healthy — confirm.
    #[serde(default = "default_health_success_threshold")]
    pub success_threshold: u32,

    /// Failure threshold (default 3).
    /// NOTE(review): presumably consecutive failures needed to mark a worker
    /// unhealthy — confirm.
    #[serde(default = "default_health_failure_threshold")]
    pub failure_threshold: u32,

    /// Whether health checking is disabled (default `false`).
    #[serde(default)]
    pub disable_health_check: bool,
}
317
318impl Default for HealthCheckConfig {
319 fn default() -> Self {
320 Self {
321 timeout_secs: default_health_check_timeout(),
322 check_interval_secs: default_health_check_interval(),
323 success_threshold: default_health_success_threshold(),
324 failure_threshold: default_health_failure_threshold(),
325 disable_health_check: false,
326 }
327 }
328}
329
/// Set of models a worker serves.
///
/// Serialized as a plain list of `ModelCard`s (an empty list for `Wildcard`);
/// deserialized from a list via `From<Vec<ModelCard>>`.
#[derive(Debug, Clone, Default)]
pub enum WorkerModels {
    /// No explicit model list; `supports` returns `true` for every id.
    #[default]
    Wildcard,
    /// Exactly one model (boxed).
    Single(Box<ModelCard>),
    /// Several models; `primary` returns the first one.
    Multi(Vec<ModelCard>),
}
348
349impl WorkerModels {
350 pub fn is_wildcard(&self) -> bool {
352 matches!(self, Self::Wildcard)
353 }
354
355 pub fn primary(&self) -> Option<&ModelCard> {
357 match self {
358 Self::Wildcard => None,
359 Self::Single(card) => Some(card.as_ref()),
360 Self::Multi(cards) => cards.first(),
361 }
362 }
363
364 pub fn all(&self) -> &[ModelCard] {
366 match self {
367 Self::Wildcard => &[],
368 Self::Single(card) => std::slice::from_ref(card.as_ref()),
369 Self::Multi(cards) => cards,
370 }
371 }
372
373 pub fn find(&self, id: &str) -> Option<&ModelCard> {
375 match self {
376 Self::Wildcard => None,
377 Self::Single(card) => card.matches(id).then_some(card.as_ref()),
378 Self::Multi(cards) => cards.iter().find(|m| m.matches(id)),
379 }
380 }
381
382 pub fn supports(&self, id: &str) -> bool {
385 match self {
386 Self::Wildcard => true,
387 _ => self.find(id).is_some(),
388 }
389 }
390
391 pub fn iter(&self) -> impl Iterator<Item = &ModelCard> {
393 self.all().iter()
394 }
395}
396
397impl From<Vec<ModelCard>> for WorkerModels {
398 fn from(models: Vec<ModelCard>) -> Self {
399 match models.len() {
400 0 => Self::Wildcard,
401 1 => {
402 let Some(model) = models.into_iter().next() else {
403 return Self::Wildcard;
404 };
405 Self::Single(Box::new(model))
406 }
407 _ => Self::Multi(models),
408 }
409 }
410}
411
412impl Serialize for WorkerModels {
414 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
415 self.all().serialize(serializer)
416 }
417}
418
419impl<'de> Deserialize<'de> for WorkerModels {
421 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
422 let models = Vec::<ModelCard>::deserialize(deserializer)?;
423 Ok(Self::from(models))
424 }
425}
426
427impl JsonSchema for WorkerModels {
429 fn schema_name() -> String {
430 "WorkerModels".to_string()
431 }
432
433 fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema {
434 Vec::<ModelCard>::json_schema(gen)
435 }
436}
437
/// Specification of a worker. Only `url` is required on input; every other
/// field has a serde default. `None` options are omitted from serialized
/// output.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct WorkerSpec {
    /// Worker URL (required).
    pub url: String,

    /// Models served by the worker; omitted from output when wildcard.
    #[serde(default, skip_serializing_if = "WorkerModels::is_wildcard")]
    pub models: WorkerModels,

    /// Worker type (defaults to `regular`).
    #[serde(default)]
    pub worker_type: WorkerType,

    /// Connection mode (defaults to `http`).
    #[serde(default)]
    pub connection_mode: ConnectionMode,

    /// Runtime type (defaults to `sglang`); also accepted as `runtime` on input.
    #[serde(default, alias = "runtime")]
    pub runtime_type: RuntimeType,

    /// Optional API provider.
    pub provider: Option<ProviderType>,

    /// Free-form labels; omitted from output when empty.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub labels: HashMap<String, String>,

    /// Priority (defaults to `DEFAULT_WORKER_PRIORITY`).
    #[serde(default = "default_priority")]
    pub priority: u32,

    /// Cost (defaults to `DEFAULT_WORKER_COST`).
    #[serde(default = "default_cost")]
    pub cost: f32,

    /// API key; accepted on input but never serialized.
    #[serde(default, skip_serializing)]
    pub api_key: Option<String>,

    /// Optional bootstrap port.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub bootstrap_port: Option<u16>,

    /// Bootstrap host; skipped by serde in both directions, so it is only ever
    /// set in code.
    #[serde(default, skip)]
    pub bootstrap_host: String,

    /// Optional base URL. NOTE(review): `dp_` presumably stands for
    /// data-parallel — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub dp_base_url: Option<String>,

    /// Optional rank (see the `dp_` note on `dp_base_url`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub dp_rank: Option<usize>,

    /// Optional size (see the `dp_` note on `dp_base_url`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub dp_size: Option<usize>,

    /// Optional KV connector name.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub kv_connector: Option<String>,

    /// Optional KV role.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub kv_role: Option<String>,

    /// Optional KV block size.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub kv_block_size: Option<usize>,

    /// Partial health-check overrides; omitted from output when no field is set.
    #[serde(default, skip_serializing_if = "HealthCheckUpdate::is_empty")]
    pub health: HealthCheckUpdate,

    /// Maximum number of connection attempts (default 20).
    #[serde(default = "default_max_connection_attempts")]
    pub max_connection_attempts: u32,

    /// Optional load-monitor interval in seconds.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub load_monitor_interval_secs: Option<u64>,
}
539
540impl WorkerSpec {
541 pub fn new(url: impl Into<String>) -> Self {
543 Self {
544 url: url.into(),
545 models: WorkerModels::Wildcard,
546 worker_type: WorkerType::default(),
547 connection_mode: ConnectionMode::default(),
548 runtime_type: RuntimeType::default(),
549 provider: None,
550 labels: HashMap::new(),
551 priority: DEFAULT_WORKER_PRIORITY,
552 cost: DEFAULT_WORKER_COST,
553 api_key: None,
554 bootstrap_port: None,
555 bootstrap_host: String::new(),
556 dp_base_url: None,
557 dp_rank: None,
558 dp_size: None,
559 kv_connector: None,
560 kv_role: None,
561 kv_block_size: None,
562 health: HealthCheckUpdate::default(),
563 max_connection_attempts: default_max_connection_attempts(),
564 load_monitor_interval_secs: None,
565 }
566 }
567}
568
/// A worker's spec plus its runtime state. The spec fields are flattened into
/// the same JSON object; `None` options are omitted from output.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct WorkerInfo {
    /// Worker identifier.
    pub id: String,

    /// Optional model identifier.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub model_id: Option<String>,

    /// Worker specification, flattened into this object.
    #[serde(flatten)]
    pub spec: WorkerSpec,

    /// Whether the worker is currently healthy.
    pub is_healthy: bool,

    /// Current load of the worker.
    pub load: usize,

    /// Status of a job associated with this worker, if any.
    pub job_status: Option<JobStatus>,
}
596
597impl WorkerInfo {
598 pub fn pending(worker_id: &str, url: String, job_status: Option<JobStatus>) -> Self {
600 Self {
601 id: worker_id.to_string(),
602 model_id: None,
603 spec: WorkerSpec::new(url),
604 is_healthy: false,
605 load: 0,
606 job_status,
607 }
608 }
609}
610
/// Status record of a job tied to a worker.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct JobStatus {
    /// Kind of job.
    pub job_type: String,
    /// URL of the worker the job concerns.
    pub worker_url: String,
    /// Status string; the constructors in this file use `"pending"`,
    /// `"processing"` and `"failed"`.
    pub status: String,
    /// Optional detail; the `failed` constructor stores the error text here.
    pub message: Option<String>,
    /// Seconds since the Unix epoch at which this record was created.
    pub timestamp: u64,
}
620
621impl JobStatus {
622 pub fn pending(job_type: &str, worker_url: &str) -> Self {
624 Self {
625 job_type: job_type.to_string(),
626 worker_url: worker_url.to_string(),
627 status: "pending".to_string(),
628 message: None,
629 timestamp: std::time::SystemTime::now()
630 .duration_since(std::time::SystemTime::UNIX_EPOCH)
631 .unwrap_or_default()
632 .as_secs(),
633 }
634 }
635
636 pub fn processing(job_type: &str, worker_url: &str) -> Self {
638 Self {
639 job_type: job_type.to_string(),
640 worker_url: worker_url.to_string(),
641 status: "processing".to_string(),
642 message: None,
643 timestamp: std::time::SystemTime::now()
644 .duration_since(std::time::SystemTime::UNIX_EPOCH)
645 .unwrap_or_default()
646 .as_secs(),
647 }
648 }
649
650 pub fn failed(job_type: &str, worker_url: &str, error: String) -> Self {
652 Self {
653 job_type: job_type.to_string(),
654 worker_url: worker_url.to_string(),
655 status: "failed".to_string(),
656 message: Some(error),
657 timestamp: std::time::SystemTime::now()
658 .duration_since(std::time::SystemTime::UNIX_EPOCH)
659 .unwrap_or_default()
660 .as_secs(),
661 }
662 }
663}
664
/// Response body listing workers together with aggregate statistics.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct WorkerListResponse {
    /// The listed workers.
    pub workers: Vec<WorkerInfo>,
    /// Total number of workers.
    pub total: usize,
    /// Aggregate statistics.
    pub stats: WorkerStats,
}
672
/// Aggregate worker statistics.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct WorkerStats {
    /// Number of workers.
    pub total_workers: usize,
    /// Number of healthy workers.
    pub healthy_workers: usize,
    /// Number of models.
    pub total_models: usize,
    /// Sum of worker loads.
    pub total_load: usize,
    /// Worker counts broken down by worker type.
    pub by_type: WorkerTypeStats,
}
682
/// Worker counts per worker type.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct WorkerTypeStats {
    /// Number of regular workers.
    pub regular: usize,
    /// Number of prefill workers.
    pub prefill: usize,
    /// Number of decode workers.
    pub decode: usize,
}
690
/// Partial health-check settings. Each `Some` field overrides the matching
/// `HealthCheckConfig` field in `apply_to`; `None` keeps the existing value
/// and is omitted from serialized output.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Default, Serialize, Deserialize, schemars::JsonSchema)]
pub struct HealthCheckUpdate {
    /// Override for `HealthCheckConfig::timeout_secs`.
    pub timeout_secs: Option<u64>,
    /// Override for `HealthCheckConfig::check_interval_secs`.
    pub check_interval_secs: Option<u64>,
    /// Override for `HealthCheckConfig::success_threshold`.
    pub success_threshold: Option<u32>,
    /// Override for `HealthCheckConfig::failure_threshold`.
    pub failure_threshold: Option<u32>,
    /// Override for `HealthCheckConfig::disable_health_check`.
    pub disable_health_check: Option<bool>,
}
707
708impl HealthCheckUpdate {
709 pub fn is_empty(&self) -> bool {
711 self.timeout_secs.is_none()
712 && self.check_interval_secs.is_none()
713 && self.success_threshold.is_none()
714 && self.failure_threshold.is_none()
715 && self.disable_health_check.is_none()
716 }
717}
718
719impl HealthCheckUpdate {
720 pub fn apply_to(&self, existing: &HealthCheckConfig) -> HealthCheckConfig {
723 HealthCheckConfig {
724 timeout_secs: self.timeout_secs.unwrap_or(existing.timeout_secs),
725 check_interval_secs: self
726 .check_interval_secs
727 .unwrap_or(existing.check_interval_secs),
728 success_threshold: self.success_threshold.unwrap_or(existing.success_threshold),
729 failure_threshold: self.failure_threshold.unwrap_or(existing.failure_threshold),
730 disable_health_check: self
731 .disable_health_check
732 .unwrap_or(existing.disable_health_check),
733 }
734 }
735}
736
/// Partial update of a worker. Every field is optional and `None` fields are
/// omitted from serialized output.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct WorkerUpdateRequest {
    /// New priority, if provided.
    pub priority: Option<u32>,

    /// New cost, if provided.
    pub cost: Option<f32>,

    /// New labels, if provided.
    pub labels: Option<HashMap<String, String>>,

    /// New API key, if provided.
    pub api_key: Option<String>,

    /// Health-check overrides, if provided.
    pub health: Option<HealthCheckUpdate>,
}
756
/// Generic response for worker operations.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct WorkerApiResponse {
    /// Whether the operation succeeded.
    pub success: bool,
    /// Human-readable message.
    pub message: String,

    /// The affected worker, if any; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub worker: Option<WorkerInfo>,
}
768
/// Error response for worker operations.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct WorkerErrorResponse {
    /// Error message.
    pub error: String,
    /// Error code.
    pub code: String,
}
775
/// Outcome of a cache flush across workers.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FlushCacheResult {
    /// Workers that were flushed successfully.
    pub successful: Vec<String>,
    /// Failures as `(worker, error)` pairs.
    pub failed: Vec<(String, String)>,
    /// Total number of workers.
    pub total_workers: usize,
    /// Number of HTTP workers.
    pub http_workers: usize,
    /// Summary message.
    pub message: String,
}
785
/// Load information collected from workers.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct WorkerLoadsResult {
    /// One entry per worker.
    pub loads: Vec<WorkerLoadInfo>,
    /// Total number of workers.
    pub total_workers: usize,
    /// Number of successful load queries.
    pub successful: usize,
    /// Number of failed load queries.
    pub failed: usize,
}
794
/// Load figures for one scheduler. Because of the struct-level
/// `#[serde(default)]`, any field missing on input takes its `Default` value
/// (zero).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct SchedulerLoadSnapshot {
    /// Rank this snapshot belongs to.
    pub dp_rank: i32,
    /// Number of running requests.
    pub num_running_reqs: i32,
    /// Number of waiting requests.
    pub num_waiting_reqs: i32,
    /// Total number of requests.
    pub num_total_reqs: i32,
    /// Number of tokens in use.
    pub num_used_tokens: i32,
    /// Maximum total number of tokens.
    pub max_total_num_tokens: i32,
    /// Token usage; averaged across snapshots by
    /// `WorkerLoadResponse::effective_token_usage`.
    pub token_usage: f64,
    /// Generation throughput.
    pub gen_throughput: f64,
    /// Cache hit rate.
    pub cache_hit_rate: f64,
    /// Utilization.
    pub utilization: f64,
    /// Maximum number of running requests.
    pub max_running_requests: i32,
}
815
/// Load report from a worker. Because of the struct-level `#[serde(default)]`,
/// missing fields take their `Default` value on input.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct WorkerLoadResponse {
    /// Timestamp of the report, as a string.
    pub timestamp: String,
    /// Number of ranks reported.
    pub dp_rank_count: i32,
    /// One load snapshot per entry.
    pub loads: Vec<SchedulerLoadSnapshot>,
}
824
825impl WorkerLoadResponse {
826 pub fn effective_token_usage(&self) -> f64 {
828 if self.loads.is_empty() {
829 return 0.0;
830 }
831 self.loads.iter().map(|l| l.token_usage).sum::<f64>() / self.loads.len() as f64
832 }
833
834 pub fn total_used_tokens(&self) -> i64 {
836 self.loads.iter().map(|l| l.num_used_tokens as i64).sum()
837 }
838}
839
/// Load entry for a single worker.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct WorkerLoadInfo {
    /// Worker identifier string.
    pub worker: String,
    /// Worker type as a string; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub worker_type: Option<String>,
    /// Load value (signed).
    pub load: isize,
    /// Detailed load report; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub details: Option<WorkerLoadResponse>,
}
850
/// Converts the flush outcome into an HTTP response.
///
/// With no failures the status is `200 OK` and `"status"` is `"success"`.
/// Otherwise the status is `206 Partial Content`, `"status"` is
/// `"partial_success"`, and the body additionally lists the `"successful"`
/// workers and the `"failed"` `{worker, error}` entries.
#[cfg(feature = "axum")]
impl IntoResponse for FlushCacheResult {
    fn into_response(self) -> Response {
        let all_flushed = self.failed.is_empty();
        let (status, outcome) = if all_flushed {
            (StatusCode::OK, "success")
        } else {
            (StatusCode::PARTIAL_CONTENT, "partial_success")
        };

        let mut body = json!({
            "status": outcome,
            "message": self.message,
            "workers_flushed": self.successful.len(),
            "total_http_workers": self.http_workers,
            "total_workers": self.total_workers
        });

        // Per-worker detail is only included when something went wrong.
        if !all_flushed {
            let failures: Vec<Value> = self
                .failed
                .iter()
                .map(|(url, err)| json!({"worker": url, "error": err}))
                .collect();
            body["successful"] = json!(self.successful);
            body["failed"] = Value::Array(failures);
        }

        (status, Json(body)).into_response()
    }
}
880
/// Converts the collected loads into a JSON response of the form
/// `{"workers": [{"worker", "load", "details"?}, ...]}`.
///
/// Only `worker`, `load` and (when present) `details` of each entry are
/// emitted; `worker_type` and the `total_workers` / `successful` / `failed`
/// counters are not part of the body.
#[cfg(feature = "axum")]
impl IntoResponse for WorkerLoadsResult {
    fn into_response(self) -> Response {
        let mut workers: Vec<Value> = Vec::with_capacity(self.loads.len());

        for info in &self.loads {
            let mut entry = json!({"worker": &info.worker, "load": info.load});
            if let Some(details) = info.details.as_ref() {
                entry["details"] = json!(details);
            }
            workers.push(entry);
        }

        Json(json!({"workers": workers})).into_response()
    }
}