1use std::collections::HashMap;
8
9#[cfg(feature = "axum")]
10use axum::{
11 http::StatusCode,
12 response::{IntoResponse, Response},
13 Json,
14};
15use serde::{Deserialize, Deserializer, Serialize, Serializer};
16#[cfg(feature = "axum")]
17use serde_json::{json, Value};
18
19use super::model_card::ModelCard;
20
/// Priority given to a worker when none is supplied.
///
/// Returned by `default_priority` (the serde default for
/// `WorkerSpec::priority`) and used directly by `WorkerSpec::new`.
pub const DEFAULT_WORKER_PRIORITY: u32 = 50;
/// Cost given to a worker when none is supplied.
///
/// Returned by `default_cost` (the serde default for `WorkerSpec::cost`)
/// and used directly by `WorkerSpec::new`.
pub const DEFAULT_WORKER_COST: f32 = 1.0;
25
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
/// Kind of worker described by a `WorkerSpec`.
///
/// Serialized in lowercase (`"regular"`, `"prefill"`, `"decode"`). The same
/// spellings are produced by the `Display` impl and accepted, ignoring ASCII
/// case, by the `FromStr` impl.
#[serde(rename_all = "lowercase")]
pub enum WorkerType {
    /// The default worker type.
    #[default]
    Regular,
    /// Prefill worker.
    /// NOTE(review): presumably the prefill side of a prefill/decode split — confirm.
    Prefill,
    /// Decode worker.
    /// NOTE(review): presumably the decode side of a prefill/decode split — confirm.
    Decode,
}
40
41impl std::fmt::Display for WorkerType {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 match self {
44 WorkerType::Regular => write!(f, "regular"),
45 WorkerType::Prefill => write!(f, "prefill"),
46 WorkerType::Decode => write!(f, "decode"),
47 }
48 }
49}
50
51impl std::str::FromStr for WorkerType {
52 type Err = String;
53
54 fn from_str(s: &str) -> Result<Self, Self::Err> {
55 if s.eq_ignore_ascii_case("regular") {
56 Ok(WorkerType::Regular)
57 } else if s.eq_ignore_ascii_case("prefill") {
58 Ok(WorkerType::Prefill)
59 } else if s.eq_ignore_ascii_case("decode") {
60 Ok(WorkerType::Decode)
61 } else {
62 Err(format!("Unknown worker type: {s}"))
63 }
64 }
65}
66
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
/// Connection mode recorded for a worker (`WorkerSpec::connection_mode`).
///
/// Serialized in lowercase (`"http"`, `"grpc"`), matching the `Display` output.
#[serde(rename_all = "lowercase")]
pub enum ConnectionMode {
    /// HTTP; the default mode.
    #[default]
    Http,
    /// gRPC.
    Grpc,
}
77
78impl std::fmt::Display for ConnectionMode {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 match self {
81 ConnectionMode::Http => write!(f, "http"),
82 ConnectionMode::Grpc => write!(f, "grpc"),
83 }
84 }
85}
86
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
/// Runtime recorded for a worker (`WorkerSpec::runtime_type`).
///
/// Serialized in lowercase (`"sglang"`, `"vllm"`, `"trtllm"`, `"external"`),
/// matching the `Display` output. The `FromStr` impl ignores ASCII case and
/// additionally accepts `"tensorrt-llm"` for [`RuntimeType::Trtllm`].
#[serde(rename_all = "lowercase")]
pub enum RuntimeType {
    /// SGLang; the default runtime.
    #[default]
    Sglang,
    /// vLLM.
    Vllm,
    /// TensorRT-LLM.
    Trtllm,
    /// `external`.
    /// NOTE(review): exact meaning is not visible in this file — confirm with callers.
    External,
}
101
102impl std::fmt::Display for RuntimeType {
103 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104 match self {
105 RuntimeType::Sglang => write!(f, "sglang"),
106 RuntimeType::Vllm => write!(f, "vllm"),
107 RuntimeType::Trtllm => write!(f, "trtllm"),
108 RuntimeType::External => write!(f, "external"),
109 }
110 }
111}
112
113impl std::str::FromStr for RuntimeType {
114 type Err = String;
115
116 fn from_str(s: &str) -> Result<Self, Self::Err> {
117 if s.eq_ignore_ascii_case("sglang") {
118 Ok(RuntimeType::Sglang)
119 } else if s.eq_ignore_ascii_case("vllm") {
120 Ok(RuntimeType::Vllm)
121 } else if s.eq_ignore_ascii_case("trtllm") || s.eq_ignore_ascii_case("tensorrt-llm") {
122 Ok(RuntimeType::Trtllm)
123 } else if s.eq_ignore_ascii_case("external") {
124 Ok(RuntimeType::External)
125 } else {
126 Err(format!("Unknown runtime type: {s}"))
127 }
128 }
129}
130
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
/// Provider associated with a worker (`WorkerSpec::provider`).
///
/// Variant names are (de)serialized in lowercase. Several variants also
/// accept extra spellings through `serde(alias = ...)`. Any string that
/// matches no named variant falls through to [`ProviderType::Custom`],
/// which is marked `untagged`.
#[serde(rename_all = "lowercase")]
pub enum ProviderType {
    /// OpenAI; `as_str` yields `"openai"`.
    #[serde(alias = "openai")]
    OpenAI,
    /// xAI; `as_str` yields `"xai"`. Also deserializes from `"grok"`.
    #[serde(alias = "xai", alias = "grok")]
    #[expect(
        clippy::upper_case_acronyms,
        reason = "xAI is a proper company name; XAI matches industry convention and existing serde aliases"
    )]
    XAI,
    /// Anthropic; `as_str` yields `"anthropic"`. Also deserializes from `"claude"`.
    #[serde(alias = "anthropic", alias = "claude")]
    Anthropic,
    /// Gemini; `as_str` yields `"gemini"`. Also deserializes from `"google"`.
    #[serde(alias = "gemini", alias = "google")]
    Gemini,
    /// Any other provider, carrying its name verbatim.
    #[serde(untagged)]
    Custom(String),
}
159
160impl ProviderType {
161 pub fn as_str(&self) -> &str {
163 match self {
164 Self::OpenAI => "openai",
165 Self::XAI => "xai",
166 Self::Anthropic => "anthropic",
167 Self::Gemini => "gemini",
168 Self::Custom(s) => s.as_str(),
169 }
170 }
171
172 pub fn from_model_name(model: &str) -> Option<Self> {
175 let model_lower = model.to_lowercase();
176 if model_lower.starts_with("grok") {
177 Some(Self::XAI)
178 } else if model_lower.starts_with("gemini") {
179 Some(Self::Gemini)
180 } else if model_lower.starts_with("claude") {
181 Some(Self::Anthropic)
182 } else if model_lower.starts_with("gpt")
183 || model_lower.starts_with("o1")
184 || model_lower.starts_with("o3")
185 {
186 Some(Self::OpenAI)
187 } else {
188 None
189 }
190 }
191}
192
193impl std::fmt::Display for ProviderType {
194 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
195 write!(f, "{}", self.as_str())
196 }
197}
198
/// Serde default for `WorkerSpec::priority`; returns [`DEFAULT_WORKER_PRIORITY`].
fn default_priority() -> u32 {
    DEFAULT_WORKER_PRIORITY
}

/// Serde default for `WorkerSpec::cost`; returns [`DEFAULT_WORKER_COST`].
fn default_cost() -> f32 {
    DEFAULT_WORKER_COST
}

/// Default for `HealthCheckConfig::timeout_secs` (30).
fn default_health_check_timeout() -> u64 {
    30
}

/// Default for `HealthCheckConfig::check_interval_secs` (60).
fn default_health_check_interval() -> u64 {
    60
}

/// Default for `HealthCheckConfig::success_threshold` (2).
fn default_health_success_threshold() -> u32 {
    2
}

/// Default for `HealthCheckConfig::failure_threshold` (3).
fn default_health_failure_threshold() -> u32 {
    3
}

/// Default for `WorkerSpec::max_connection_attempts` (20); used both as the
/// serde default and by `WorkerSpec::new`.
fn default_max_connection_attempts() -> u32 {
    20
}
228
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Fully resolved health-check settings.
///
/// Every field has a serde default, so an empty object deserializes to the
/// same values as `HealthCheckConfig::default()`. A `HealthCheckUpdate` can
/// be layered on top with `HealthCheckUpdate::apply_to`.
pub struct HealthCheckConfig {
    /// Health-check timeout, in seconds per the field name. Defaults to 30.
    #[serde(default = "default_health_check_timeout")]
    pub timeout_secs: u64,

    /// Interval between checks, in seconds per the field name. Defaults to 60.
    #[serde(default = "default_health_check_interval")]
    pub check_interval_secs: u64,

    /// Success threshold. Defaults to 2.
    /// NOTE(review): presumably consecutive successes needed to mark a worker healthy — confirm.
    #[serde(default = "default_health_success_threshold")]
    pub success_threshold: u32,

    /// Failure threshold. Defaults to 3.
    /// NOTE(review): presumably consecutive failures needed to mark a worker unhealthy — confirm.
    #[serde(default = "default_health_failure_threshold")]
    pub failure_threshold: u32,

    /// When `true`, health checking is disabled. Defaults to `false`.
    #[serde(default)]
    pub disable_health_check: bool,
}
254
255impl Default for HealthCheckConfig {
256 fn default() -> Self {
257 Self {
258 timeout_secs: default_health_check_timeout(),
259 check_interval_secs: default_health_check_interval(),
260 success_threshold: default_health_success_threshold(),
261 failure_threshold: default_health_failure_threshold(),
262 disable_health_check: false,
263 }
264 }
265}
266
#[derive(Debug, Clone, Default)]
/// Set of models attached to a worker.
///
/// Serialized as a plain list of `ModelCard`s (see the manual `Serialize` /
/// `Deserialize` impls). The `From<Vec<ModelCard>>` conversion picks the
/// variant from the list length: empty gives `Wildcard`, one element gives
/// `Single`, more gives `Multi`.
pub enum WorkerModels {
    /// No explicit model list. `supports` returns `true` for every id, while
    /// `find` and `primary` return `None`.
    #[default]
    Wildcard,
    /// Exactly one model, boxed.
    Single(Box<ModelCard>),
    /// A list of models; `primary` returns the first one.
    Multi(Vec<ModelCard>),
}
285
286impl WorkerModels {
287 pub fn is_wildcard(&self) -> bool {
289 matches!(self, Self::Wildcard)
290 }
291
292 pub fn primary(&self) -> Option<&ModelCard> {
294 match self {
295 Self::Wildcard => None,
296 Self::Single(card) => Some(card.as_ref()),
297 Self::Multi(cards) => cards.first(),
298 }
299 }
300
301 pub fn all(&self) -> &[ModelCard] {
303 match self {
304 Self::Wildcard => &[],
305 Self::Single(card) => std::slice::from_ref(card.as_ref()),
306 Self::Multi(cards) => cards,
307 }
308 }
309
310 pub fn find(&self, id: &str) -> Option<&ModelCard> {
312 match self {
313 Self::Wildcard => None,
314 Self::Single(card) => card.matches(id).then_some(card.as_ref()),
315 Self::Multi(cards) => cards.iter().find(|m| m.matches(id)),
316 }
317 }
318
319 pub fn supports(&self, id: &str) -> bool {
322 match self {
323 Self::Wildcard => true,
324 _ => self.find(id).is_some(),
325 }
326 }
327
328 pub fn iter(&self) -> impl Iterator<Item = &ModelCard> {
330 self.all().iter()
331 }
332}
333
334impl From<Vec<ModelCard>> for WorkerModels {
335 fn from(models: Vec<ModelCard>) -> Self {
336 match models.len() {
337 0 => Self::Wildcard,
338 1 => {
339 let Some(model) = models.into_iter().next() else {
340 return Self::Wildcard;
341 };
342 Self::Single(Box::new(model))
343 }
344 _ => Self::Multi(models),
345 }
346 }
347}
348
349impl Serialize for WorkerModels {
351 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
352 self.all().serialize(serializer)
353 }
354}
355
356impl<'de> Deserialize<'de> for WorkerModels {
358 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
359 let models = Vec::<ModelCard>::deserialize(deserializer)?;
360 Ok(Self::from(models))
361 }
362}
363
#[serde_with::skip_serializing_none]
/// Description of a worker: its URL plus classification, weighting,
/// connection and health-check settings.
///
/// `url` is the only field without a serde default or `Option` type.
/// `WorkerSpec::new` builds the same defaults in code.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerSpec {
    /// URL of the worker.
    pub url: String,

    /// Models attached to the worker. Defaults to wildcard, and is omitted
    /// from serialized output while it is a wildcard.
    #[serde(default, skip_serializing_if = "WorkerModels::is_wildcard")]
    pub models: WorkerModels,

    /// Worker type; defaults to `WorkerType::Regular`.
    #[serde(default)]
    pub worker_type: WorkerType,

    /// Connection mode; defaults to `ConnectionMode::Http`.
    #[serde(default)]
    pub connection_mode: ConnectionMode,

    /// Runtime; defaults to `RuntimeType::Sglang`. Also accepted under the
    /// key `runtime` when deserializing.
    #[serde(default, alias = "runtime")]
    pub runtime_type: RuntimeType,

    /// Optional provider.
    pub provider: Option<ProviderType>,

    /// String key/value labels; omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub labels: HashMap<String, String>,

    /// Priority; defaults to `DEFAULT_WORKER_PRIORITY`.
    #[serde(default = "default_priority")]
    pub priority: u32,

    /// Cost; defaults to `DEFAULT_WORKER_COST`.
    #[serde(default = "default_cost")]
    pub cost: f32,

    /// API key. Accepted on input but never serialized.
    #[serde(default, skip_serializing)]
    pub api_key: Option<String>,

    /// Optional bootstrap port.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub bootstrap_port: Option<u16>,

    /// Bootstrap host. Skipped by serde in both directions, so it starts as
    /// the empty string after deserialization.
    /// NOTE(review): presumably filled in at runtime by other code — confirm.
    #[serde(default, skip)]
    pub bootstrap_host: String,

    /// Optional base URL.
    /// NOTE(review): "dp" presumably means data-parallel — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub dp_base_url: Option<String>,

    /// Optional rank (see the note on `dp_base_url`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub dp_rank: Option<usize>,

    /// Optional size (see the note on `dp_base_url`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub dp_size: Option<usize>,

    /// Optional KV connector name.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub kv_connector: Option<String>,

    /// Optional KV role.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub kv_role: Option<String>,

    /// Per-worker health-check overrides; omitted from serialized output
    /// when no override is set.
    #[serde(default, skip_serializing_if = "HealthCheckUpdate::is_empty")]
    pub health: HealthCheckUpdate,

    /// Maximum number of connection attempts; defaults to 20.
    #[serde(default = "default_max_connection_attempts")]
    pub max_connection_attempts: u32,
}
453
454impl WorkerSpec {
455 pub fn new(url: impl Into<String>) -> Self {
457 Self {
458 url: url.into(),
459 models: WorkerModels::Wildcard,
460 worker_type: WorkerType::default(),
461 connection_mode: ConnectionMode::default(),
462 runtime_type: RuntimeType::default(),
463 provider: None,
464 labels: HashMap::new(),
465 priority: DEFAULT_WORKER_PRIORITY,
466 cost: DEFAULT_WORKER_COST,
467 api_key: None,
468 bootstrap_port: None,
469 bootstrap_host: String::new(),
470 dp_base_url: None,
471 dp_rank: None,
472 dp_size: None,
473 kv_connector: None,
474 kv_role: None,
475 health: HealthCheckUpdate::default(),
476 max_connection_attempts: default_max_connection_attempts(),
477 }
478 }
479}
480
#[serde_with::skip_serializing_none]
/// Serializable view of a worker: its spec plus an id, health flag, load
/// and optional job status.
#[derive(Debug, Clone, Serialize)]
pub struct WorkerInfo {
    /// Worker identifier.
    pub id: String,

    /// The worker's spec; its fields are flattened into this object when
    /// serialized.
    #[serde(flatten)]
    pub spec: WorkerSpec,

    /// Whether the worker is considered healthy.
    pub is_healthy: bool,

    /// Load value reported for the worker.
    pub load: usize,

    /// Status of an associated job, if any.
    pub job_status: Option<JobStatus>,
}
503
504impl WorkerInfo {
505 pub fn pending(worker_id: &str, url: String, job_status: Option<JobStatus>) -> Self {
507 Self {
508 id: worker_id.to_string(),
509 spec: WorkerSpec::new(url),
510 is_healthy: false,
511 load: 0,
512 job_status,
513 }
514 }
515}
516
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Status record for a job tied to a worker URL.
///
/// The constructors in this file produce the statuses `"pending"`,
/// `"processing"` and `"failed"`.
pub struct JobStatus {
    /// Job type, as a free-form string.
    pub job_type: String,
    /// URL of the worker the job concerns.
    pub worker_url: String,
    /// Status string.
    pub status: String,
    /// Optional message; the `failed` constructor stores the error text here.
    pub message: Option<String>,
    /// Seconds since the Unix epoch at which the record was created.
    pub timestamp: u64,
}
526
527impl JobStatus {
528 pub fn pending(job_type: &str, worker_url: &str) -> Self {
530 Self {
531 job_type: job_type.to_string(),
532 worker_url: worker_url.to_string(),
533 status: "pending".to_string(),
534 message: None,
535 timestamp: std::time::SystemTime::now()
536 .duration_since(std::time::SystemTime::UNIX_EPOCH)
537 .unwrap_or_default()
538 .as_secs(),
539 }
540 }
541
542 pub fn processing(job_type: &str, worker_url: &str) -> Self {
544 Self {
545 job_type: job_type.to_string(),
546 worker_url: worker_url.to_string(),
547 status: "processing".to_string(),
548 message: None,
549 timestamp: std::time::SystemTime::now()
550 .duration_since(std::time::SystemTime::UNIX_EPOCH)
551 .unwrap_or_default()
552 .as_secs(),
553 }
554 }
555
556 pub fn failed(job_type: &str, worker_url: &str, error: String) -> Self {
558 Self {
559 job_type: job_type.to_string(),
560 worker_url: worker_url.to_string(),
561 status: "failed".to_string(),
562 message: Some(error),
563 timestamp: std::time::SystemTime::now()
564 .duration_since(std::time::SystemTime::UNIX_EPOCH)
565 .unwrap_or_default()
566 .as_secs(),
567 }
568 }
569}
570
#[derive(Debug, Clone, Serialize)]
/// Serializable response for listing workers.
pub struct WorkerListResponse {
    /// The workers being listed.
    pub workers: Vec<WorkerInfo>,
    /// Total count reported alongside the list.
    pub total: usize,
    /// Aggregate statistics.
    pub stats: WorkerStats,
}
578
#[derive(Debug, Clone, Serialize)]
/// Aggregate worker statistics included in `WorkerListResponse`.
pub struct WorkerStats {
    /// Number of workers.
    pub total_workers: usize,
    /// Number of healthy workers.
    pub healthy_workers: usize,
    /// Number of models.
    pub total_models: usize,
    /// Total load.
    pub total_load: usize,
    /// Worker counts broken down by `WorkerType`.
    pub by_type: WorkerTypeStats,
}
588
#[derive(Debug, Clone, Serialize)]
/// Worker counts with one field per `WorkerType` variant.
pub struct WorkerTypeStats {
    /// Count of regular workers.
    pub regular: usize,
    /// Count of prefill workers.
    pub prefill: usize,
    /// Count of decode workers.
    pub decode: usize,
}
596
#[serde_with::skip_serializing_none]
/// Partial override of a `HealthCheckConfig`.
///
/// Each field mirrors the `HealthCheckConfig` field of the same name. With
/// `apply_to`, a `Some` value replaces the existing setting and `None`
/// leaves it unchanged.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct HealthCheckUpdate {
    /// Override for `HealthCheckConfig::timeout_secs`.
    pub timeout_secs: Option<u64>,
    /// Override for `HealthCheckConfig::check_interval_secs`.
    pub check_interval_secs: Option<u64>,
    /// Override for `HealthCheckConfig::success_threshold`.
    pub success_threshold: Option<u32>,
    /// Override for `HealthCheckConfig::failure_threshold`.
    pub failure_threshold: Option<u32>,
    /// Override for `HealthCheckConfig::disable_health_check`.
    pub disable_health_check: Option<bool>,
}
613
614impl HealthCheckUpdate {
615 pub fn is_empty(&self) -> bool {
617 self.timeout_secs.is_none()
618 && self.check_interval_secs.is_none()
619 && self.success_threshold.is_none()
620 && self.failure_threshold.is_none()
621 && self.disable_health_check.is_none()
622 }
623}
624
625impl HealthCheckUpdate {
626 pub fn apply_to(&self, existing: &HealthCheckConfig) -> HealthCheckConfig {
629 HealthCheckConfig {
630 timeout_secs: self.timeout_secs.unwrap_or(existing.timeout_secs),
631 check_interval_secs: self
632 .check_interval_secs
633 .unwrap_or(existing.check_interval_secs),
634 success_threshold: self.success_threshold.unwrap_or(existing.success_threshold),
635 failure_threshold: self.failure_threshold.unwrap_or(existing.failure_threshold),
636 disable_health_check: self
637 .disable_health_check
638 .unwrap_or(existing.disable_health_check),
639 }
640 }
641}
642
#[serde_with::skip_serializing_none]
/// Request body for updating a worker; every field is optional.
///
/// NOTE(review): presumably a `None` field means "leave unchanged" — the
/// code that applies this request is not in this file; confirm there.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerUpdateRequest {
    /// New priority, if provided.
    pub priority: Option<u32>,

    /// New cost, if provided.
    pub cost: Option<f32>,

    /// New labels, if provided.
    pub labels: Option<HashMap<String, String>>,

    /// New API key, if provided.
    pub api_key: Option<String>,

    /// Health-check overrides, if provided.
    pub health: Option<HealthCheckUpdate>,
}
662
#[derive(Debug, Clone, Serialize)]
/// Generic response for worker API operations.
pub struct WorkerApiResponse {
    /// Whether the operation succeeded.
    pub success: bool,
    /// Message describing the outcome.
    pub message: String,

    /// The worker concerned, if any; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub worker: Option<WorkerInfo>,
}
674
#[derive(Debug, Clone, Serialize)]
/// Error response for worker API operations.
pub struct WorkerErrorResponse {
    /// Error message.
    pub error: String,
    /// Error code, as a string.
    pub code: String,
}
681
#[derive(Debug, Clone, Deserialize, Serialize)]
/// Outcome of a cache flush across workers.
///
/// With the `axum` feature this type also implements `IntoResponse`.
pub struct FlushCacheResult {
    /// Workers that were flushed successfully.
    pub successful: Vec<String>,
    /// Failures as `(worker, error)` pairs, matching the keys used when the
    /// result is rendered as an HTTP response.
    pub failed: Vec<(String, String)>,
    /// Total number of workers.
    pub total_workers: usize,
    /// Number of HTTP workers.
    pub http_workers: usize,
    /// Summary message.
    pub message: String,
}
691
#[derive(Debug, Clone, Deserialize, Serialize)]
/// Outcome of querying worker loads.
///
/// With the `axum` feature this type also implements `IntoResponse`, which
/// emits only the per-worker `worker` and `load` values.
pub struct WorkerLoadsResult {
    /// Per-worker load entries.
    pub loads: Vec<WorkerLoadInfo>,
    /// Total number of workers.
    pub total_workers: usize,
    /// Number of successful queries.
    pub successful: usize,
    /// Number of failed queries.
    pub failed: usize,
}
700
#[derive(Debug, Clone, Deserialize, Serialize)]
/// Load entry for a single worker.
pub struct WorkerLoadInfo {
    /// Worker identifier string.
    pub worker: String,
    /// Worker type as a string, if known; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub worker_type: Option<String>,
    /// Load value. The type is signed.
    /// NOTE(review): a negative value presumably marks a failed query — confirm with the producer.
    pub load: isize,
}
709
#[cfg(feature = "axum")]
impl IntoResponse for FlushCacheResult {
    /// Renders the result as a JSON response.
    ///
    /// With no failures the status is `200 OK` and `"status"` is
    /// `"success"`. Otherwise the status is `206 Partial Content`,
    /// `"status"` is `"partial_success"`, and the body additionally lists
    /// the `successful` workers and the `failed` ones as
    /// `{"worker", "error"}` objects.
    fn into_response(self) -> Response {
        let Self {
            successful,
            failed,
            total_workers,
            http_workers,
            message,
        } = self;

        let all_flushed = failed.is_empty();
        let (status, label) = if all_flushed {
            (StatusCode::OK, "success")
        } else {
            (StatusCode::PARTIAL_CONTENT, "partial_success")
        };

        let mut body = json!({
            "status": label,
            "message": message,
            "workers_flushed": successful.len(),
            "total_http_workers": http_workers,
            "total_workers": total_workers
        });

        // Per-worker detail is only included when something went wrong.
        if !all_flushed {
            let failures: Vec<Value> = failed
                .into_iter()
                .map(|(worker, error)| json!({"worker": worker, "error": error}))
                .collect();
            body["successful"] = json!(successful);
            body["failed"] = Value::Array(failures);
        }

        (status, Json(body)).into_response()
    }
}
739
#[cfg(feature = "axum")]
impl IntoResponse for WorkerLoadsResult {
    /// Renders the loads as `{"workers": [{"worker": ..., "load": ...}, ...]}`.
    ///
    /// Only the `worker` and `load` of each entry are emitted; the counters
    /// on the result and each entry's `worker_type` are not included.
    fn into_response(self) -> Response {
        let mut workers: Vec<Value> = Vec::with_capacity(self.loads.len());
        for entry in &self.loads {
            workers.push(json!({"worker": &entry.worker, "load": entry.load}));
        }

        Json(json!({"workers": workers})).into_response()
    }
}