1use std::collections::HashMap;
8
9#[cfg(feature = "axum")]
10use axum::{
11 http::StatusCode,
12 response::{IntoResponse, Response},
13 Json,
14};
15use schemars::JsonSchema;
16use serde::{Deserialize, Deserializer, Serialize, Serializer};
17#[cfg(feature = "axum")]
18use serde_json::{json, Value};
19
20use super::model_card::ModelCard;
21
/// Priority given to a worker when its spec does not set one.
pub const DEFAULT_WORKER_PRIORITY: u32 = 50;
/// Cost given to a worker when its spec does not set one.
pub const DEFAULT_WORKER_COST: f32 = 1.0_f32;
26
27#[derive(
31 Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default, schemars::JsonSchema,
32)]
33#[serde(rename_all = "lowercase")]
34pub enum WorkerType {
35 #[default]
37 Regular,
38 Prefill,
40 Decode,
42}
43
44impl std::fmt::Display for WorkerType {
45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46 match self {
47 WorkerType::Regular => write!(f, "regular"),
48 WorkerType::Prefill => write!(f, "prefill"),
49 WorkerType::Decode => write!(f, "decode"),
50 }
51 }
52}
53
54impl std::str::FromStr for WorkerType {
55 type Err = String;
56
57 fn from_str(s: &str) -> Result<Self, Self::Err> {
58 if s.eq_ignore_ascii_case("regular") {
59 Ok(WorkerType::Regular)
60 } else if s.eq_ignore_ascii_case("prefill") {
61 Ok(WorkerType::Prefill)
62 } else if s.eq_ignore_ascii_case("decode") {
63 Ok(WorkerType::Decode)
64 } else {
65 Err(format!("Unknown worker type: {s}"))
66 }
67 }
68}
69
70#[derive(
72 Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default, schemars::JsonSchema,
73)]
74#[serde(rename_all = "lowercase")]
75pub enum ConnectionMode {
76 #[default]
78 Http,
79 Grpc,
81}
82
83impl std::fmt::Display for ConnectionMode {
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 match self {
86 ConnectionMode::Http => write!(f, "http"),
87 ConnectionMode::Grpc => write!(f, "grpc"),
88 }
89 }
90}
91
92#[derive(Debug, Clone, PartialEq, Eq, Hash)]
97pub struct WorkerGroupKey {
98 pub model_id: String,
99 pub worker_type: WorkerType,
100 pub connection_mode: ConnectionMode,
101}
102
103impl std::fmt::Display for WorkerGroupKey {
104 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105 write!(
106 f,
107 "{}:{}:{}",
108 self.model_id, self.worker_type, self.connection_mode
109 )
110 }
111}
112
113#[derive(
115 Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default, schemars::JsonSchema,
116)]
117#[serde(rename_all = "lowercase")]
118pub enum RuntimeType {
119 #[default]
121 Unspecified,
122 Sglang,
124 Vllm,
126 Trtllm,
128 External,
130}
131
132impl RuntimeType {
133 pub fn is_specified(self) -> bool {
135 !matches!(self, RuntimeType::Unspecified)
136 }
137}
138
139impl std::fmt::Display for RuntimeType {
140 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141 match self {
142 RuntimeType::Unspecified => write!(f, "unspecified"),
143 RuntimeType::Sglang => write!(f, "sglang"),
144 RuntimeType::Vllm => write!(f, "vllm"),
145 RuntimeType::Trtllm => write!(f, "trtllm"),
146 RuntimeType::External => write!(f, "external"),
147 }
148 }
149}
150
151impl std::str::FromStr for RuntimeType {
152 type Err = String;
153
154 fn from_str(s: &str) -> Result<Self, Self::Err> {
155 if s.eq_ignore_ascii_case("unspecified") {
156 Ok(RuntimeType::Unspecified)
157 } else if s.eq_ignore_ascii_case("sglang") {
158 Ok(RuntimeType::Sglang)
159 } else if s.eq_ignore_ascii_case("vllm") {
160 Ok(RuntimeType::Vllm)
161 } else if s.eq_ignore_ascii_case("trtllm") || s.eq_ignore_ascii_case("tensorrt-llm") {
162 Ok(RuntimeType::Trtllm)
163 } else if s.eq_ignore_ascii_case("external") {
164 Ok(RuntimeType::External)
165 } else {
166 Err(format!("Unknown runtime type: {s}"))
167 }
168 }
169}
170
171#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, schemars::JsonSchema)]
177#[serde(rename_all = "lowercase")]
178pub enum ProviderType {
179 #[serde(alias = "openai")]
181 OpenAI,
182 #[serde(alias = "xai", alias = "grok")]
184 #[expect(
185 clippy::upper_case_acronyms,
186 reason = "xAI is a proper company name; XAI matches industry convention and existing serde aliases"
187 )]
188 XAI,
189 #[serde(alias = "anthropic", alias = "claude")]
191 Anthropic,
192 #[serde(alias = "gemini", alias = "google")]
194 Gemini,
195 #[serde(untagged)]
197 Custom(String),
198}
199
200impl ProviderType {
201 pub fn as_str(&self) -> &str {
203 match self {
204 Self::OpenAI => "openai",
205 Self::XAI => "xai",
206 Self::Anthropic => "anthropic",
207 Self::Gemini => "gemini",
208 Self::Custom(s) => s.as_str(),
209 }
210 }
211
212 pub fn from_url(url: &str) -> Option<Self> {
215 let host = url::Url::parse(url).ok()?.host_str()?.to_lowercase();
216
217 if host.ends_with("openai.com") {
218 Some(Self::OpenAI)
219 } else if host.ends_with("x.ai") {
220 Some(Self::XAI)
221 } else if host.ends_with("anthropic.com") {
222 Some(Self::Anthropic)
223 } else if host.ends_with("googleapis.com") {
224 Some(Self::Gemini)
225 } else {
226 None
227 }
228 }
229
230 pub fn admin_key_env_var(&self) -> Option<&'static str> {
233 match self {
234 Self::OpenAI => Some("OPENAI_ADMIN_KEY"),
235 Self::XAI => Some("XAI_ADMIN_KEY"),
236 Self::Anthropic => Some("ANTHROPIC_ADMIN_KEY"),
237 Self::Gemini => Some("GEMINI_ADMIN_KEY"),
238 Self::Custom(_) => None,
239 }
240 }
241
242 pub fn uses_x_api_key(&self) -> bool {
244 matches!(self, Self::Anthropic)
245 }
246
247 pub fn from_model_name(model: &str) -> Option<Self> {
250 let model_lower = model.to_lowercase();
251 if model_lower.starts_with("grok") {
252 Some(Self::XAI)
253 } else if model_lower.starts_with("gemini") {
254 Some(Self::Gemini)
255 } else if model_lower.starts_with("claude") {
256 Some(Self::Anthropic)
257 } else if model_lower.starts_with("gpt")
258 || model_lower.starts_with("o1")
259 || model_lower.starts_with("o3")
260 {
261 Some(Self::OpenAI)
262 } else {
263 None
264 }
265 }
266}
267
268impl std::fmt::Display for ProviderType {
269 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
270 write!(f, "{}", self.as_str())
271 }
272}
273
274fn default_priority() -> u32 {
277 DEFAULT_WORKER_PRIORITY
278}
279
280fn default_cost() -> f32 {
281 DEFAULT_WORKER_COST
282}
283
284fn default_health_check_timeout() -> u64 {
285 30
286}
287
288fn default_health_check_interval() -> u64 {
289 60
290}
291
292fn default_health_success_threshold() -> u32 {
293 2
294}
295
296fn default_health_failure_threshold() -> u32 {
297 3
298}
299
300fn default_max_connection_attempts() -> u32 {
301 20
302}
303
304#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
308pub struct HealthCheckConfig {
309 #[serde(default = "default_health_check_timeout")]
311 pub timeout_secs: u64,
312
313 #[serde(default = "default_health_check_interval")]
315 pub check_interval_secs: u64,
316
317 #[serde(default = "default_health_success_threshold")]
319 pub success_threshold: u32,
320
321 #[serde(default = "default_health_failure_threshold")]
323 pub failure_threshold: u32,
324
325 #[serde(default)]
327 pub disable_health_check: bool,
328}
329
330impl Default for HealthCheckConfig {
331 fn default() -> Self {
332 Self {
333 timeout_secs: default_health_check_timeout(),
334 check_interval_secs: default_health_check_interval(),
335 success_threshold: default_health_success_threshold(),
336 failure_threshold: default_health_failure_threshold(),
337 disable_health_check: false,
338 }
339 }
340}
341
342#[derive(Debug, Clone, Default)]
351pub enum WorkerModels {
352 #[default]
354 Wildcard,
355 Single(Box<ModelCard>),
357 Multi(Vec<ModelCard>),
359}
360
361impl WorkerModels {
362 pub fn is_wildcard(&self) -> bool {
364 matches!(self, Self::Wildcard)
365 }
366
367 pub fn primary(&self) -> Option<&ModelCard> {
369 match self {
370 Self::Wildcard => None,
371 Self::Single(card) => Some(card.as_ref()),
372 Self::Multi(cards) => cards.first(),
373 }
374 }
375
376 pub fn all(&self) -> &[ModelCard] {
378 match self {
379 Self::Wildcard => &[],
380 Self::Single(card) => std::slice::from_ref(card.as_ref()),
381 Self::Multi(cards) => cards,
382 }
383 }
384
385 pub fn find(&self, id: &str) -> Option<&ModelCard> {
387 match self {
388 Self::Wildcard => None,
389 Self::Single(card) => card.matches(id).then_some(card.as_ref()),
390 Self::Multi(cards) => cards.iter().find(|m| m.matches(id)),
391 }
392 }
393
394 pub fn supports(&self, id: &str) -> bool {
397 match self {
398 Self::Wildcard => true,
399 _ => self.find(id).is_some(),
400 }
401 }
402
403 pub fn iter(&self) -> impl Iterator<Item = &ModelCard> {
405 self.all().iter()
406 }
407}
408
409impl From<Vec<ModelCard>> for WorkerModels {
410 fn from(models: Vec<ModelCard>) -> Self {
411 match models.len() {
412 0 => Self::Wildcard,
413 1 => {
414 let Some(model) = models.into_iter().next() else {
415 return Self::Wildcard;
416 };
417 Self::Single(Box::new(model))
418 }
419 _ => Self::Multi(models),
420 }
421 }
422}
423
424impl Serialize for WorkerModels {
426 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
427 self.all().serialize(serializer)
428 }
429}
430
431impl<'de> Deserialize<'de> for WorkerModels {
433 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
434 let models = Vec::<ModelCard>::deserialize(deserializer)?;
435 Ok(Self::from(models))
436 }
437}
438
439impl JsonSchema for WorkerModels {
441 fn schema_name() -> String {
442 "WorkerModels".to_string()
443 }
444
445 fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema {
446 Vec::<ModelCard>::json_schema(gen)
447 }
448}
449
450#[serde_with::skip_serializing_none]
461#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
462pub struct WorkerSpec {
463 pub url: String,
465
466 #[serde(default, skip_serializing_if = "WorkerModels::is_wildcard")]
468 pub models: WorkerModels,
469
470 #[serde(default)]
472 pub worker_type: WorkerType,
473
474 #[serde(default)]
476 pub connection_mode: ConnectionMode,
477
478 #[serde(default, alias = "runtime")]
480 pub runtime_type: RuntimeType,
481
482 pub provider: Option<ProviderType>,
485
486 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
488 pub labels: HashMap<String, String>,
489
490 #[serde(default = "default_priority")]
492 pub priority: u32,
493
494 #[serde(default = "default_cost")]
496 pub cost: f32,
497
498 #[serde(default, skip_serializing)]
500 pub api_key: Option<String>,
501
502 #[serde(default, skip_serializing_if = "Option::is_none")]
504 pub bootstrap_port: Option<u16>,
505
506 #[serde(default, skip)]
508 pub bootstrap_host: String,
509
510 #[serde(default, skip_serializing_if = "Option::is_none")]
513 pub dp_base_url: Option<String>,
514
515 #[serde(default, skip_serializing_if = "Option::is_none")]
517 pub dp_rank: Option<usize>,
518
519 #[serde(default, skip_serializing_if = "Option::is_none")]
521 pub dp_size: Option<usize>,
522
523 #[serde(default, skip_serializing_if = "Option::is_none")]
525 pub kv_connector: Option<String>,
526
527 #[serde(default, skip_serializing_if = "Option::is_none")]
529 pub kv_role: Option<String>,
530
531 #[serde(default, skip_serializing_if = "Option::is_none")]
535 pub kv_block_size: Option<usize>,
536
537 #[serde(default, skip_serializing_if = "HealthCheckUpdate::is_empty")]
539 pub health: HealthCheckUpdate,
540
541 #[serde(default, skip_serializing_if = "HttpPoolConfig::is_empty")]
543 pub http_pool: HttpPoolConfig,
544
545 #[serde(default, skip_serializing_if = "ResilienceUpdate::is_empty")]
547 pub resilience: ResilienceUpdate,
548
549 #[serde(default = "default_max_connection_attempts")]
551 pub max_connection_attempts: u32,
552
553 #[serde(default, skip_serializing_if = "Option::is_none")]
557 pub load_monitor_interval_secs: Option<u64>,
558}
559
560impl WorkerSpec {
561 pub fn new(url: impl Into<String>) -> Self {
563 Self {
564 url: url.into(),
565 models: WorkerModels::Wildcard,
566 worker_type: WorkerType::default(),
567 connection_mode: ConnectionMode::default(),
568 runtime_type: RuntimeType::default(),
569 provider: None,
570 labels: HashMap::new(),
571 priority: DEFAULT_WORKER_PRIORITY,
572 cost: DEFAULT_WORKER_COST,
573 api_key: None,
574 bootstrap_port: None,
575 bootstrap_host: String::new(),
576 dp_base_url: None,
577 dp_rank: None,
578 dp_size: None,
579 kv_connector: None,
580 kv_role: None,
581 kv_block_size: None,
582 health: HealthCheckUpdate::default(),
583 http_pool: HttpPoolConfig::default(),
584 resilience: ResilienceUpdate::default(),
585 max_connection_attempts: default_max_connection_attempts(),
586 load_monitor_interval_secs: None,
587 }
588 }
589}
590
591#[serde_with::skip_serializing_none]
595#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
596pub struct WorkerInfo {
597 pub id: String,
599
600 #[serde(default, skip_serializing_if = "Option::is_none")]
603 pub model_id: Option<String>,
604
605 #[serde(flatten)]
607 pub spec: WorkerSpec,
608
609 pub is_healthy: bool,
611
612 pub load: usize,
614
615 pub job_status: Option<JobStatus>,
617}
618
619impl WorkerInfo {
620 pub fn pending(worker_id: &str, url: String, job_status: Option<JobStatus>) -> Self {
622 Self {
623 id: worker_id.to_string(),
624 model_id: None,
625 spec: WorkerSpec::new(url),
626 is_healthy: false,
627 load: 0,
628 job_status,
629 }
630 }
631}
632
633#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
635pub struct JobStatus {
636 pub job_type: String,
637 pub worker_url: String,
638 pub status: String,
639 pub message: Option<String>,
640 pub timestamp: u64,
641}
642
643impl JobStatus {
644 pub fn pending(job_type: &str, worker_url: &str) -> Self {
646 Self {
647 job_type: job_type.to_string(),
648 worker_url: worker_url.to_string(),
649 status: "pending".to_string(),
650 message: None,
651 timestamp: std::time::SystemTime::now()
652 .duration_since(std::time::SystemTime::UNIX_EPOCH)
653 .unwrap_or_default()
654 .as_secs(),
655 }
656 }
657
658 pub fn processing(job_type: &str, worker_url: &str) -> Self {
660 Self {
661 job_type: job_type.to_string(),
662 worker_url: worker_url.to_string(),
663 status: "processing".to_string(),
664 message: None,
665 timestamp: std::time::SystemTime::now()
666 .duration_since(std::time::SystemTime::UNIX_EPOCH)
667 .unwrap_or_default()
668 .as_secs(),
669 }
670 }
671
672 pub fn failed(job_type: &str, worker_url: &str, error: String) -> Self {
674 Self {
675 job_type: job_type.to_string(),
676 worker_url: worker_url.to_string(),
677 status: "failed".to_string(),
678 message: Some(error),
679 timestamp: std::time::SystemTime::now()
680 .duration_since(std::time::SystemTime::UNIX_EPOCH)
681 .unwrap_or_default()
682 .as_secs(),
683 }
684 }
685}
686
687#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
689pub struct WorkerListResponse {
690 pub workers: Vec<WorkerInfo>,
691 pub total: usize,
692 pub stats: WorkerStats,
693}
694
695#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
697pub struct WorkerStats {
698 pub total_workers: usize,
699 pub healthy_workers: usize,
700 pub total_models: usize,
701 pub total_load: usize,
702 pub by_type: WorkerTypeStats,
703}
704
705#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
707pub struct WorkerTypeStats {
708 pub regular: usize,
709 pub prefill: usize,
710 pub decode: usize,
711}
712
713#[serde_with::skip_serializing_none]
721#[derive(Debug, Clone, Default, Serialize, Deserialize, schemars::JsonSchema)]
722pub struct HealthCheckUpdate {
723 pub timeout_secs: Option<u64>,
724 pub check_interval_secs: Option<u64>,
725 pub success_threshold: Option<u32>,
726 pub failure_threshold: Option<u32>,
727 pub disable_health_check: Option<bool>,
728}
729
730impl HealthCheckUpdate {
731 pub fn is_empty(&self) -> bool {
733 self.timeout_secs.is_none()
734 && self.check_interval_secs.is_none()
735 && self.success_threshold.is_none()
736 && self.failure_threshold.is_none()
737 && self.disable_health_check.is_none()
738 }
739}
740
741impl HealthCheckUpdate {
742 pub fn apply_to(&self, existing: &HealthCheckConfig) -> HealthCheckConfig {
745 HealthCheckConfig {
746 timeout_secs: self.timeout_secs.unwrap_or(existing.timeout_secs),
747 check_interval_secs: self
748 .check_interval_secs
749 .unwrap_or(existing.check_interval_secs),
750 success_threshold: self.success_threshold.unwrap_or(existing.success_threshold),
751 failure_threshold: self.failure_threshold.unwrap_or(existing.failure_threshold),
752 disable_health_check: self
753 .disable_health_check
754 .unwrap_or(existing.disable_health_check),
755 }
756 }
757}
758
759#[serde_with::skip_serializing_none]
762#[derive(Debug, Clone, Default, Serialize, Deserialize, schemars::JsonSchema)]
763pub struct HttpPoolConfig {
764 pub pool_max_idle_per_host: Option<usize>,
766 pub pool_idle_timeout_secs: Option<u64>,
768 pub timeout_secs: Option<u64>,
770 pub connect_timeout_secs: Option<u64>,
772}
773
774impl HttpPoolConfig {
775 pub fn is_empty(&self) -> bool {
777 self.pool_max_idle_per_host.is_none()
778 && self.pool_idle_timeout_secs.is_none()
779 && self.timeout_secs.is_none()
780 && self.connect_timeout_secs.is_none()
781 }
782}
783
784#[serde_with::skip_serializing_none]
788#[derive(Debug, Clone, Default, Serialize, Deserialize, schemars::JsonSchema)]
789pub struct ResilienceUpdate {
790 pub max_retries: Option<u32>,
793 pub initial_backoff_ms: Option<u64>,
795 pub max_backoff_ms: Option<u64>,
797 pub backoff_multiplier: Option<f32>,
799 pub jitter_factor: Option<f32>,
801 pub disable_retry: Option<bool>,
803
804 pub cb_failure_threshold: Option<u32>,
807 pub cb_success_threshold: Option<u32>,
809 pub cb_timeout_secs: Option<u64>,
811 pub cb_window_secs: Option<u64>,
813 pub disable_circuit_breaker: Option<bool>,
815
816 pub retryable_status_codes: Option<Vec<u16>>,
820}
821
822impl ResilienceUpdate {
823 pub fn is_empty(&self) -> bool {
825 self.max_retries.is_none()
826 && self.initial_backoff_ms.is_none()
827 && self.max_backoff_ms.is_none()
828 && self.backoff_multiplier.is_none()
829 && self.jitter_factor.is_none()
830 && self.disable_retry.is_none()
831 && self.cb_failure_threshold.is_none()
832 && self.cb_success_threshold.is_none()
833 && self.cb_timeout_secs.is_none()
834 && self.cb_window_secs.is_none()
835 && self.disable_circuit_breaker.is_none()
836 && self.retryable_status_codes.is_none()
837 }
838}
839
840#[serde_with::skip_serializing_none]
842#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
843pub struct WorkerUpdateRequest {
844 pub priority: Option<u32>,
846
847 pub cost: Option<f32>,
849
850 pub labels: Option<HashMap<String, String>>,
852
853 pub api_key: Option<String>,
855
856 pub health: Option<HealthCheckUpdate>,
858}
859
860#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
864pub struct WorkerApiResponse {
865 pub success: bool,
866 pub message: String,
867
868 #[serde(skip_serializing_if = "Option::is_none")]
869 pub worker: Option<WorkerInfo>,
870}
871
872#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
874pub struct WorkerErrorResponse {
875 pub error: String,
876 pub code: String,
877}
878
879#[derive(Debug, Clone, Deserialize, Serialize)]
881pub struct FlushCacheResult {
882 pub successful: Vec<String>,
883 pub failed: Vec<(String, String)>,
884 pub total_workers: usize,
885 pub http_workers: usize,
886 pub message: String,
887}
888
889#[derive(Debug, Clone, Deserialize, Serialize)]
891pub struct WorkerLoadsResult {
892 pub loads: Vec<WorkerLoadInfo>,
893 pub total_workers: usize,
894 pub successful: usize,
895 pub failed: usize,
896}
897
898#[derive(Debug, Clone, Default, Serialize, Deserialize)]
903#[serde(default)]
904pub struct SchedulerLoadSnapshot {
905 pub dp_rank: i32,
906 pub num_running_reqs: i32,
907 pub num_waiting_reqs: i32,
908 pub num_total_reqs: i32,
909 pub num_used_tokens: i32,
910 pub max_total_num_tokens: i32,
911 pub token_usage: f64,
913 pub gen_throughput: f64,
914 pub cache_hit_rate: f64,
915 pub utilization: f64,
916 pub max_running_requests: i32,
917}
918
919#[derive(Debug, Clone, Default, Serialize, Deserialize)]
921#[serde(default)]
922pub struct WorkerLoadResponse {
923 pub timestamp: String,
924 pub dp_rank_count: i32,
925 pub loads: Vec<SchedulerLoadSnapshot>,
926}
927
928impl WorkerLoadResponse {
929 pub fn effective_token_usage(&self) -> f64 {
931 if self.loads.is_empty() {
932 return 0.0;
933 }
934 self.loads.iter().map(|l| l.token_usage).sum::<f64>() / self.loads.len() as f64
935 }
936
937 pub fn total_used_tokens(&self) -> i64 {
939 self.loads.iter().map(|l| l.num_used_tokens as i64).sum()
940 }
941}
942
943#[derive(Debug, Clone, Deserialize, Serialize)]
945pub struct WorkerLoadInfo {
946 pub worker: String,
947 #[serde(skip_serializing_if = "Option::is_none")]
948 pub worker_type: Option<String>,
949 pub load: isize,
950 #[serde(skip_serializing_if = "Option::is_none")]
951 pub details: Option<WorkerLoadResponse>,
952}
953
#[cfg(feature = "axum")]
impl IntoResponse for FlushCacheResult {
    /// Renders the flush outcome as JSON.
    ///
    /// When every worker succeeded the response is `200 OK` with status `"success"`.
    /// Otherwise it is `206 Partial Content` with status `"partial_success"`, and the
    /// body adds the `successful` list plus a `failed` list of
    /// `{"worker", "error"}` objects.
    ///
    /// NOTE(review): 206 is normally reserved for HTTP range requests; it is kept here
    /// because clients may depend on it.
    fn into_response(self) -> Response {
        let all_ok = self.failed.is_empty();
        let (status, outcome) = if all_ok {
            (StatusCode::OK, "success")
        } else {
            (StatusCode::PARTIAL_CONTENT, "partial_success")
        };

        let mut body = json!({
            "status": outcome,
            "message": self.message,
            "workers_flushed": self.successful.len(),
            "total_http_workers": self.http_workers,
            "total_workers": self.total_workers
        });

        if !all_ok {
            let failures: Vec<Value> = self
                .failed
                .into_iter()
                .map(|(worker, error)| json!({"worker": worker, "error": error}))
                .collect();
            body["successful"] = json!(self.successful);
            body["failed"] = Value::Array(failures);
        }

        (status, Json(body)).into_response()
    }
}
983
#[cfg(feature = "axum")]
impl IntoResponse for WorkerLoadsResult {
    /// Renders `{"workers": [...]}`.
    ///
    /// Each entry carries `worker` and `load`, plus `details` when present. The
    /// entries' `worker_type` and the result's counters are not included in the body.
    fn into_response(self) -> Response {
        let mut workers: Vec<Value> = Vec::with_capacity(self.loads.len());
        for info in self.loads {
            let mut entry = json!({ "worker": info.worker, "load": info.load });
            if let Some(details) = info.details {
                entry["details"] = json!(details);
            }
            workers.push(entry);
        }
        Json(json!({ "workers": workers })).into_response()
    }
}