1use std::collections::HashMap;
8
9#[cfg(feature = "axum")]
10use axum::{
11 http::StatusCode,
12 response::{IntoResponse, Response},
13 Json,
14};
15use serde::{Deserialize, Deserializer, Serialize, Serializer};
16#[cfg(feature = "axum")]
17use serde_json::{json, Value};
18
19use super::model_card::ModelCard;
20
/// Priority given to a worker when none is supplied; returned by
/// `default_priority` and used by `WorkerSpec::new`.
pub const DEFAULT_WORKER_PRIORITY: u32 = 50;
/// Cost given to a worker when none is supplied; returned by `default_cost`
/// and used by `WorkerSpec::new`.
pub const DEFAULT_WORKER_COST: f32 = 1.0;
25
/// Kind of worker.
///
/// Serde (de)serializes each variant under its lowercase name
/// (`regular`, `prefill`, `decode`); `Regular` is the `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum WorkerType {
    /// Default worker kind.
    #[default]
    Regular,
    /// NOTE(review): presumably the prefill side of a prefill/decode split — confirm.
    Prefill,
    /// NOTE(review): presumably the decode side of a prefill/decode split — confirm.
    Decode,
}
40
41impl std::fmt::Display for WorkerType {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 match self {
44 WorkerType::Regular => write!(f, "regular"),
45 WorkerType::Prefill => write!(f, "prefill"),
46 WorkerType::Decode => write!(f, "decode"),
47 }
48 }
49}
50
51impl std::str::FromStr for WorkerType {
52 type Err = String;
53
54 fn from_str(s: &str) -> Result<Self, Self::Err> {
55 if s.eq_ignore_ascii_case("regular") {
56 Ok(WorkerType::Regular)
57 } else if s.eq_ignore_ascii_case("prefill") {
58 Ok(WorkerType::Prefill)
59 } else if s.eq_ignore_ascii_case("decode") {
60 Ok(WorkerType::Decode)
61 } else {
62 Err(format!("Unknown worker type: {}", s))
63 }
64 }
65}
66
/// Connection mode of a worker.
///
/// Serde (de)serializes the variants as `http` and `grpc`; `Http` is the
/// `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum ConnectionMode {
    /// Default mode.
    #[default]
    Http,
    /// gRPC mode.
    Grpc,
}
77
78impl std::fmt::Display for ConnectionMode {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 match self {
81 ConnectionMode::Http => write!(f, "http"),
82 ConnectionMode::Grpc => write!(f, "grpc"),
83 }
84 }
85}
86
/// Runtime behind a worker.
///
/// Serde (de)serializes each variant under its lowercase name (`sglang`,
/// `vllm`, `trtllm`, `external`); `Sglang` is the `Default`. The `FromStr`
/// impl additionally accepts `tensorrt-llm` for `Trtllm`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum RuntimeType {
    /// Default runtime.
    #[default]
    Sglang,
    /// Serialized as `vllm`.
    Vllm,
    /// Serialized as `trtllm`.
    Trtllm,
    /// Serialized as `external`.
    External,
}
101
102impl std::fmt::Display for RuntimeType {
103 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104 match self {
105 RuntimeType::Sglang => write!(f, "sglang"),
106 RuntimeType::Vllm => write!(f, "vllm"),
107 RuntimeType::Trtllm => write!(f, "trtllm"),
108 RuntimeType::External => write!(f, "external"),
109 }
110 }
111}
112
113impl std::str::FromStr for RuntimeType {
114 type Err = String;
115
116 fn from_str(s: &str) -> Result<Self, Self::Err> {
117 if s.eq_ignore_ascii_case("sglang") {
118 Ok(RuntimeType::Sglang)
119 } else if s.eq_ignore_ascii_case("vllm") {
120 Ok(RuntimeType::Vllm)
121 } else if s.eq_ignore_ascii_case("trtllm") || s.eq_ignore_ascii_case("tensorrt-llm") {
122 Ok(RuntimeType::Trtllm)
123 } else if s.eq_ignore_ascii_case("external") {
124 Ok(RuntimeType::External)
125 } else {
126 Err(format!("Unknown runtime type: {}", s))
127 }
128 }
129}
130
/// Provider behind a worker.
///
/// Named variants are (de)serialized under their lowercase names; some
/// accept extra aliases on input (listed per variant).
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ProviderType {
    /// Serialized as `openai`.
    #[serde(alias = "openai")]
    OpenAI,
    /// Serialized as `xai`; `grok` is also accepted on input.
    #[serde(alias = "xai", alias = "grok")]
    XAI,
    /// Serialized as `anthropic`; `claude` is also accepted on input.
    #[serde(alias = "anthropic", alias = "claude")]
    Anthropic,
    /// Serialized as `gemini`; `google` is also accepted on input.
    #[serde(alias = "gemini", alias = "google")]
    Gemini,
    /// Any other provider, carrying its name.
    /// NOTE(review): being `untagged`, this presumably captures every string
    /// that matches none of the variants above — confirm against serde's
    /// variant-level `untagged` semantics.
    #[serde(untagged)]
    Custom(String),
}
155
156impl ProviderType {
157 pub fn as_str(&self) -> &str {
159 match self {
160 Self::OpenAI => "openai",
161 Self::XAI => "xai",
162 Self::Anthropic => "anthropic",
163 Self::Gemini => "gemini",
164 Self::Custom(s) => s.as_str(),
165 }
166 }
167
168 pub fn from_model_name(model: &str) -> Option<Self> {
171 let model_lower = model.to_lowercase();
172 if model_lower.starts_with("grok") {
173 Some(Self::XAI)
174 } else if model_lower.starts_with("gemini") {
175 Some(Self::Gemini)
176 } else if model_lower.starts_with("claude") {
177 Some(Self::Anthropic)
178 } else if model_lower.starts_with("gpt")
179 || model_lower.starts_with("o1")
180 || model_lower.starts_with("o3")
181 {
182 Some(Self::OpenAI)
183 } else {
184 None
185 }
186 }
187}
188
189impl std::fmt::Display for ProviderType {
190 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
191 write!(f, "{}", self.as_str())
192 }
193}
194
/// Serde default for `WorkerSpec::priority`; returns `DEFAULT_WORKER_PRIORITY`.
fn default_priority() -> u32 {
    DEFAULT_WORKER_PRIORITY
}

/// Serde default for `WorkerSpec::cost`; returns `DEFAULT_WORKER_COST`.
fn default_cost() -> f32 {
    DEFAULT_WORKER_COST
}
204
/// Serde default for `HealthCheckConfig::timeout_secs`.
fn default_health_check_timeout() -> u64 {
    const TIMEOUT_SECS: u64 = 30;
    TIMEOUT_SECS
}

/// Serde default for `HealthCheckConfig::check_interval_secs`.
fn default_health_check_interval() -> u64 {
    const CHECK_INTERVAL_SECS: u64 = 60;
    CHECK_INTERVAL_SECS
}

/// Serde default for `HealthCheckConfig::success_threshold`.
fn default_health_success_threshold() -> u32 {
    const SUCCESS_THRESHOLD: u32 = 2;
    SUCCESS_THRESHOLD
}

/// Serde default for `HealthCheckConfig::failure_threshold`.
fn default_health_failure_threshold() -> u32 {
    const FAILURE_THRESHOLD: u32 = 3;
    FAILURE_THRESHOLD
}

/// Serde default for `WorkerSpec::max_connection_attempts`.
fn default_max_connection_attempts() -> u32 {
    const MAX_CONNECTION_ATTEMPTS: u32 = 20;
    MAX_CONNECTION_ATTEMPTS
}
224
/// Health-check settings for a worker.
///
/// Every field is optional on input; a missing field takes the default
/// noted on it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthCheckConfig {
    /// Health-check timeout, in seconds per the field name. Defaults to 30.
    #[serde(default = "default_health_check_timeout")]
    pub timeout_secs: u64,

    /// Interval between checks, in seconds per the field name. Defaults to 60.
    #[serde(default = "default_health_check_interval")]
    pub check_interval_secs: u64,

    /// Defaults to 2.
    /// NOTE(review): presumably the number of successes needed before a
    /// worker is considered healthy — confirm against the health checker.
    #[serde(default = "default_health_success_threshold")]
    pub success_threshold: u32,

    /// Defaults to 3.
    /// NOTE(review): presumably the number of failures needed before a
    /// worker is considered unhealthy — confirm against the health checker.
    #[serde(default = "default_health_failure_threshold")]
    pub failure_threshold: u32,

    /// Flag for turning health checking off. Defaults to `false`.
    #[serde(default)]
    pub disable_health_check: bool,
}
250
251impl Default for HealthCheckConfig {
252 fn default() -> Self {
253 Self {
254 timeout_secs: default_health_check_timeout(),
255 check_interval_secs: default_health_check_interval(),
256 success_threshold: default_health_success_threshold(),
257 failure_threshold: default_health_failure_threshold(),
258 disable_health_check: false,
259 }
260 }
261}
262
/// Set of models a worker is associated with.
///
/// (De)serialized as a plain list of `ModelCard`s via the manual
/// `Serialize` / `Deserialize` impls; an empty list corresponds to
/// `Wildcard`.
#[derive(Debug, Clone, Default)]
pub enum WorkerModels {
    /// No explicit model list; `supports` returns `true` for every id.
    #[default]
    Wildcard,
    /// Exactly one model, boxed.
    Single(Box<ModelCard>),
    /// A list of models. `From<Vec<ModelCard>>` only produces this variant
    /// for two or more entries.
    Multi(Vec<ModelCard>),
}
281
282impl WorkerModels {
283 pub fn is_wildcard(&self) -> bool {
285 matches!(self, Self::Wildcard)
286 }
287
288 pub fn primary(&self) -> Option<&ModelCard> {
290 match self {
291 Self::Wildcard => None,
292 Self::Single(card) => Some(card.as_ref()),
293 Self::Multi(cards) => cards.first(),
294 }
295 }
296
297 pub fn all(&self) -> &[ModelCard] {
299 match self {
300 Self::Wildcard => &[],
301 Self::Single(card) => std::slice::from_ref(card.as_ref()),
302 Self::Multi(cards) => cards,
303 }
304 }
305
306 pub fn find(&self, id: &str) -> Option<&ModelCard> {
308 match self {
309 Self::Wildcard => None,
310 Self::Single(card) => card.matches(id).then_some(card.as_ref()),
311 Self::Multi(cards) => cards.iter().find(|m| m.matches(id)),
312 }
313 }
314
315 pub fn supports(&self, id: &str) -> bool {
318 match self {
319 Self::Wildcard => true,
320 _ => self.find(id).is_some(),
321 }
322 }
323
324 pub fn iter(&self) -> impl Iterator<Item = &ModelCard> {
326 self.all().iter()
327 }
328}
329
330impl From<Vec<ModelCard>> for WorkerModels {
331 fn from(models: Vec<ModelCard>) -> Self {
332 match models.len() {
333 0 => Self::Wildcard,
334 1 => Self::Single(Box::new(models.into_iter().next().unwrap())),
335 _ => Self::Multi(models),
336 }
337 }
338}
339
340impl Serialize for WorkerModels {
342 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
343 self.all().serialize(serializer)
344 }
345}
346
347impl<'de> Deserialize<'de> for WorkerModels {
349 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
350 let models = Vec::<ModelCard>::deserialize(deserializer)?;
351 Ok(Self::from(models))
352 }
353}
354
/// Description of a worker: its URL plus type, routing and health settings.
///
/// Only `url` is required on input; every other field has a serde default.
/// `#[serde_with::skip_serializing_none]` omits `Option` fields that are
/// `None` from the serialized output.
///
/// NOTE(review): `Debug` is derived, so `{:?}` output includes `api_key`
/// even though serialization skips it — confirm this is acceptable wherever
/// specs are logged.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerSpec {
    /// Worker URL.
    pub url: String,

    /// Models for this worker; defaults to `Wildcard` and is omitted from
    /// output while it is the wildcard.
    #[serde(default, skip_serializing_if = "WorkerModels::is_wildcard")]
    pub models: WorkerModels,

    /// Worker kind; defaults to `WorkerType::Regular`.
    #[serde(default)]
    pub worker_type: WorkerType,

    /// Connection mode; defaults to `ConnectionMode::Http`.
    #[serde(default)]
    pub connection_mode: ConnectionMode,

    /// Runtime; defaults to `RuntimeType::Sglang`. Also accepted on input
    /// under the key `runtime`.
    #[serde(default, alias = "runtime")]
    pub runtime_type: RuntimeType,

    /// Optional provider.
    pub provider: Option<ProviderType>,

    /// Free-form string labels; omitted from output when empty.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub labels: HashMap<String, String>,

    /// Priority; defaults to `DEFAULT_WORKER_PRIORITY`.
    #[serde(default = "default_priority")]
    pub priority: u32,

    /// Cost; defaults to `DEFAULT_WORKER_COST`.
    #[serde(default = "default_cost")]
    pub cost: f32,

    /// Optional API key. Accepted on input but never serialized.
    #[serde(default, skip_serializing)]
    pub api_key: Option<String>,

    /// Optional bootstrap port.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub bootstrap_port: Option<u16>,

    /// Bootstrap host. Skipped by serde in both directions, so it is never
    /// read from input and starts out as an empty string.
    #[serde(default, skip)]
    pub bootstrap_host: String,

    /// Optional KV connector name.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub kv_connector: Option<String>,

    /// Optional KV role.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub kv_role: Option<String>,

    /// Health-check settings; defaults to `HealthCheckConfig::default()`.
    #[serde(default)]
    pub health: HealthCheckConfig,

    /// Defaults to 20 via `default_max_connection_attempts`.
    #[serde(default = "default_max_connection_attempts")]
    pub max_connection_attempts: u32,
}
431
432impl WorkerSpec {
433 pub fn new(url: impl Into<String>) -> Self {
435 Self {
436 url: url.into(),
437 models: WorkerModels::Wildcard,
438 worker_type: WorkerType::default(),
439 connection_mode: ConnectionMode::default(),
440 runtime_type: RuntimeType::default(),
441 provider: None,
442 labels: HashMap::new(),
443 priority: DEFAULT_WORKER_PRIORITY,
444 cost: DEFAULT_WORKER_COST,
445 api_key: None,
446 bootstrap_port: None,
447 bootstrap_host: String::new(),
448 kv_connector: None,
449 kv_role: None,
450 health: HealthCheckConfig::default(),
451 max_connection_attempts: default_max_connection_attempts(),
452 }
453 }
454}
455
/// Serializable view of a worker: its spec plus current state.
///
/// `spec` is flattened, so its fields appear at the same level as `id`.
/// `job_status` is omitted from output when `None`.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Serialize)]
pub struct WorkerInfo {
    /// Worker identifier.
    pub id: String,

    /// The worker's spec, flattened into this object when serialized.
    #[serde(flatten)]
    pub spec: WorkerSpec,

    /// Whether the worker is healthy.
    pub is_healthy: bool,

    /// Load value for the worker.
    pub load: usize,

    /// Optional status of a job associated with this worker.
    pub job_status: Option<JobStatus>,
}
478
479impl WorkerInfo {
480 pub fn pending(worker_id: &str, url: String, job_status: Option<JobStatus>) -> Self {
482 Self {
483 id: worker_id.to_string(),
484 spec: WorkerSpec::new(url),
485 is_healthy: false,
486 load: 0,
487 job_status,
488 }
489 }
490}
491
/// Status record for a job tied to a worker URL.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JobStatus {
    /// Job type name.
    pub job_type: String,
    /// URL of the worker the job concerns.
    pub worker_url: String,
    /// Status string; the constructors in this file set `pending`,
    /// `processing` or `failed`.
    pub status: String,
    /// Optional message; `JobStatus::failed` stores the error text here.
    pub message: Option<String>,
    /// Seconds since the Unix epoch, as set by the constructors.
    pub timestamp: u64,
}
501
502impl JobStatus {
503 pub fn pending(job_type: &str, worker_url: &str) -> Self {
505 Self {
506 job_type: job_type.to_string(),
507 worker_url: worker_url.to_string(),
508 status: "pending".to_string(),
509 message: None,
510 timestamp: std::time::SystemTime::now()
511 .duration_since(std::time::SystemTime::UNIX_EPOCH)
512 .unwrap_or_default()
513 .as_secs(),
514 }
515 }
516
517 pub fn processing(job_type: &str, worker_url: &str) -> Self {
519 Self {
520 job_type: job_type.to_string(),
521 worker_url: worker_url.to_string(),
522 status: "processing".to_string(),
523 message: None,
524 timestamp: std::time::SystemTime::now()
525 .duration_since(std::time::SystemTime::UNIX_EPOCH)
526 .unwrap_or_default()
527 .as_secs(),
528 }
529 }
530
531 pub fn failed(job_type: &str, worker_url: &str, error: String) -> Self {
533 Self {
534 job_type: job_type.to_string(),
535 worker_url: worker_url.to_string(),
536 status: "failed".to_string(),
537 message: Some(error),
538 timestamp: std::time::SystemTime::now()
539 .duration_since(std::time::SystemTime::UNIX_EPOCH)
540 .unwrap_or_default()
541 .as_secs(),
542 }
543 }
544}
545
/// Serializable payload listing workers together with summary statistics.
#[derive(Debug, Clone, Serialize)]
pub struct WorkerListResponse {
    /// The listed workers.
    pub workers: Vec<WorkerInfo>,
    /// Total count. NOTE(review): presumably equals `workers.len()` — confirm at the construction site.
    pub total: usize,
    /// Aggregate statistics.
    pub stats: WorkerStats,
}
553
/// Aggregate worker counters included in `WorkerListResponse`.
#[derive(Debug, Clone, Serialize)]
pub struct WorkerStats {
    /// Number of workers.
    pub total_workers: usize,
    /// Number of healthy workers.
    pub healthy_workers: usize,
    /// Number of models.
    pub total_models: usize,
    /// Summed load.
    pub total_load: usize,
    /// Worker counts split by `WorkerType`.
    pub by_type: WorkerTypeStats,
}
563
/// Worker counts with one field per `WorkerType` variant.
#[derive(Debug, Clone, Serialize)]
pub struct WorkerTypeStats {
    /// Count of regular workers.
    pub regular: usize,
    /// Count of prefill workers.
    pub prefill: usize,
    /// Count of decode workers.
    pub decode: usize,
}
571
/// Partial update for a `HealthCheckConfig`.
///
/// Each field mirrors the config field of the same name. `apply_to` keeps
/// the existing value for every field that is `None`, and `None` fields are
/// omitted when this struct is serialized.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthCheckUpdate {
    /// Replacement for `HealthCheckConfig::timeout_secs`, if set.
    pub timeout_secs: Option<u64>,
    /// Replacement for `HealthCheckConfig::check_interval_secs`, if set.
    pub check_interval_secs: Option<u64>,
    /// Replacement for `HealthCheckConfig::success_threshold`, if set.
    pub success_threshold: Option<u32>,
    /// Replacement for `HealthCheckConfig::failure_threshold`, if set.
    pub failure_threshold: Option<u32>,
    /// Replacement for `HealthCheckConfig::disable_health_check`, if set.
    pub disable_health_check: Option<bool>,
}
588
589impl HealthCheckUpdate {
590 pub fn apply_to(&self, existing: &HealthCheckConfig) -> HealthCheckConfig {
593 HealthCheckConfig {
594 timeout_secs: self.timeout_secs.unwrap_or(existing.timeout_secs),
595 check_interval_secs: self
596 .check_interval_secs
597 .unwrap_or(existing.check_interval_secs),
598 success_threshold: self.success_threshold.unwrap_or(existing.success_threshold),
599 failure_threshold: self.failure_threshold.unwrap_or(existing.failure_threshold),
600 disable_health_check: self
601 .disable_health_check
602 .unwrap_or(existing.disable_health_check),
603 }
604 }
605}
606
/// Request body for partially updating a worker; every field is optional
/// and `None` fields are omitted when serialized.
///
/// NOTE(review): `Debug` and `Serialize` are both derived, so `api_key`
/// appears in `{:?}` output and in serialized output when set — confirm
/// this is acceptable wherever requests are logged or re-emitted.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerUpdateRequest {
    /// New priority, if provided.
    pub priority: Option<u32>,

    /// New cost, if provided.
    pub cost: Option<f32>,

    /// New labels, if provided.
    /// NOTE(review): whether these replace or merge with existing labels is
    /// decided by the caller — confirm.
    pub labels: Option<HashMap<String, String>>,

    /// New API key, if provided.
    pub api_key: Option<String>,

    /// Partial health-check update, if provided.
    pub health: Option<HealthCheckUpdate>,
}
626
/// Generic serializable result of a worker API call.
#[derive(Debug, Clone, Serialize)]
pub struct WorkerApiResponse {
    /// Whether the operation succeeded.
    pub success: bool,
    /// Message describing the outcome.
    pub message: String,

    /// The affected worker, if any; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub worker: Option<WorkerInfo>,
}
638
/// Serializable error payload for worker API calls.
#[derive(Debug, Clone, Serialize)]
pub struct WorkerErrorResponse {
    /// Error message.
    pub error: String,
    /// Error code string.
    pub code: String,
}
645
/// Outcome of a cache flush across workers.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FlushCacheResult {
    /// Workers that were flushed successfully.
    pub successful: Vec<String>,
    /// Failures as pairs; the axum response impl in this file emits the
    /// first element as `worker` and the second as `error`.
    pub failed: Vec<(String, String)>,
    /// Total number of workers.
    pub total_workers: usize,
    /// Number of HTTP workers.
    pub http_workers: usize,
    /// Summary message.
    pub message: String,
}
655
/// Outcome of collecting load information from workers.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct WorkerLoadsResult {
    /// Per-worker load entries.
    pub loads: Vec<WorkerLoadInfo>,
    /// Total number of workers.
    pub total_workers: usize,
    /// Number of successful queries.
    pub successful: usize,
    /// Number of failed queries.
    pub failed: usize,
}
664
/// Load entry for a single worker.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct WorkerLoadInfo {
    /// Worker identifier string.
    /// NOTE(review): presumably the worker URL — confirm at the construction site.
    pub worker: String,
    /// Optional worker type name; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub worker_type: Option<String>,
    /// Load value. Signed.
    /// NOTE(review): a negative value may be a failure sentinel — confirm with the producer.
    pub load: isize,
}
673
#[cfg(feature = "axum")]
impl IntoResponse for FlushCacheResult {
    /// Converts the result into a JSON response.
    ///
    /// With no failures the status is 200 and `status` is `success`.
    /// Otherwise the status is 206, `status` is `partial_success`, and the
    /// body additionally lists `successful` workers and `failed` entries
    /// (each as `{"worker", "error"}`).
    fn into_response(self) -> Response {
        let all_succeeded = self.failed.is_empty();
        let (status, status_label) = if all_succeeded {
            (StatusCode::OK, "success")
        } else {
            (StatusCode::PARTIAL_CONTENT, "partial_success")
        };

        let mut body = json!({
            "status": status_label,
            "message": self.message,
            "workers_flushed": self.successful.len(),
            "total_http_workers": self.http_workers,
            "total_workers": self.total_workers
        });

        if !all_succeeded {
            let failures: Vec<_> = self
                .failed
                .into_iter()
                .map(|(worker, error)| json!({"worker": worker, "error": error}))
                .collect();

            // Same insertion order as before: `successful` first, then `failed`.
            body["successful"] = json!(self.successful);
            body["failed"] = json!(failures);
        }

        (status, Json(body)).into_response()
    }
}
703
#[cfg(feature = "axum")]
impl IntoResponse for WorkerLoadsResult {
    /// Converts the result into a JSON response of the form
    /// `{"workers": [{"worker": ..., "load": ...}, ...]}`.
    ///
    /// Only `worker` and `load` of each entry are emitted; the counters on
    /// `self` and each entry's `worker_type` are not part of the body.
    fn into_response(self) -> Response {
        let mut entries: Vec<Value> = Vec::with_capacity(self.loads.len());
        for info in &self.loads {
            entries.push(json!({"worker": &info.worker, "load": info.load}));
        }

        Json(json!({"workers": entries})).into_response()
    }
}
714}