openai_protocol/
worker_spec.rs

1//! Worker management API specifications
2//!
3//! Defines the request/response structures for worker management endpoints
4
5use std::collections::HashMap;
6
7#[cfg(feature = "axum")]
8use axum::{
9    http::StatusCode,
10    response::{IntoResponse, Response},
11    Json,
12};
13use serde::{Deserialize, Serialize};
14#[cfg(feature = "axum")]
15use serde_json::{json, Value};
16
17use super::UNKNOWN_MODEL_ID;
18
19/// Worker configuration for API requests
20#[derive(Debug, Clone, Deserialize, Serialize)]
21pub struct WorkerConfigRequest {
22    /// Worker URL (required)
23    pub url: String,
24
25    /// Worker API key (optional)
26    #[serde(skip_serializing_if = "Option::is_none")]
27    pub api_key: Option<String>,
28
29    /// Model ID (optional, will query from server if not provided)
30    #[serde(skip_serializing_if = "Option::is_none")]
31    pub model_id: Option<String>,
32
33    /// Worker priority (optional, default: 50, higher = preferred)
34    #[serde(skip_serializing_if = "Option::is_none")]
35    pub priority: Option<u32>,
36
37    /// Worker cost factor (optional, default: 1.0)
38    #[serde(skip_serializing_if = "Option::is_none")]
39    pub cost: Option<f32>,
40
41    /// Worker type (optional: "regular", "prefill", "decode")
42    #[serde(skip_serializing_if = "Option::is_none")]
43    pub worker_type: Option<String>,
44
45    /// Bootstrap port for prefill workers (optional)
46    #[serde(skip_serializing_if = "Option::is_none")]
47    pub bootstrap_port: Option<u16>,
48
49    /// Runtime type (optional: "sglang", "vllm", default: "sglang")
50    /// Only relevant for gRPC workers
51    #[serde(skip_serializing_if = "Option::is_none")]
52    pub runtime: Option<String>,
53
54    // gRPC-specific configuration (optional, ignored in HTTP mode)
55    /// Tokenizer path for gRPC mode
56    #[serde(skip_serializing_if = "Option::is_none")]
57    pub tokenizer_path: Option<String>,
58
59    /// Reasoning parser type for gRPC mode
60    #[serde(skip_serializing_if = "Option::is_none")]
61    pub reasoning_parser: Option<String>,
62
63    /// Tool parser type for gRPC mode
64    #[serde(skip_serializing_if = "Option::is_none")]
65    pub tool_parser: Option<String>,
66
67    /// Chat template for gRPC mode
68    #[serde(skip_serializing_if = "Option::is_none")]
69    pub chat_template: Option<String>,
70
71    /// Additional labels (optional)
72    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
73    pub labels: HashMap<String, String>,
74
75    /// Health check timeout in seconds (default: 30)
76    #[serde(default = "default_health_check_timeout")]
77    pub health_check_timeout_secs: u64,
78
79    /// Health check interval in seconds (default: 60)
80    #[serde(default = "default_health_check_interval")]
81    pub health_check_interval_secs: u64,
82
83    /// Number of successful health checks needed to mark worker as healthy (default: 2)
84    #[serde(default = "default_health_success_threshold")]
85    pub health_success_threshold: u32,
86
87    /// Number of failed health checks before marking worker as unhealthy (default: 3)
88    #[serde(default = "default_health_failure_threshold")]
89    pub health_failure_threshold: u32,
90
91    /// Disable periodic health checks for this worker (default: false)
92    #[serde(default)]
93    pub disable_health_check: bool,
94
95    /// Maximum connection attempts during worker registration (default: 20)
96    #[serde(default = "default_max_connection_attempts")]
97    pub max_connection_attempts: u32,
98
99    /// Enable data parallelism aware scheduling (default: false)
100    #[serde(default)]
101    pub dp_aware: bool,
102}
103
// Serde fallbacks for `WorkerConfigRequest`. Each function is named inside a
// `#[serde(default = "...")]` attribute on that struct, so the function names are part of the
// contract and must not change. The values themselves live in the constants below.

const DEFAULT_HEALTH_CHECK_TIMEOUT_SECS: u64 = 30;
const DEFAULT_HEALTH_CHECK_INTERVAL_SECS: u64 = 60;
const DEFAULT_HEALTH_SUCCESS_THRESHOLD: u32 = 2;
const DEFAULT_HEALTH_FAILURE_THRESHOLD: u32 = 3;
const DEFAULT_MAX_CONNECTION_ATTEMPTS: u32 = 20;

/// Health check timeout (seconds) used when the request leaves it out.
fn default_health_check_timeout() -> u64 {
    DEFAULT_HEALTH_CHECK_TIMEOUT_SECS
}

/// Health check interval (seconds) used when the request leaves it out.
fn default_health_check_interval() -> u64 {
    DEFAULT_HEALTH_CHECK_INTERVAL_SECS
}

/// Successful checks needed to mark a worker healthy, when not specified.
fn default_health_success_threshold() -> u32 {
    DEFAULT_HEALTH_SUCCESS_THRESHOLD
}

/// Failed checks needed to mark a worker unhealthy, when not specified.
fn default_health_failure_threshold() -> u32 {
    DEFAULT_HEALTH_FAILURE_THRESHOLD
}

/// Connection attempts allowed during registration, when not specified.
fn default_max_connection_attempts() -> u32 {
    DEFAULT_MAX_CONNECTION_ATTEMPTS
}
124
125/// Worker information for API responses
126#[derive(Debug, Clone, Serialize)]
127pub struct WorkerInfo {
128    /// Worker unique identifier
129    pub id: String,
130
131    /// Worker URL
132    pub url: String,
133
134    /// Model ID this worker serves
135    pub model_id: String,
136
137    /// Worker priority
138    pub priority: u32,
139
140    /// Worker cost factor
141    pub cost: f32,
142
143    /// Worker type
144    pub worker_type: String,
145
146    /// Whether the worker is healthy
147    pub is_healthy: bool,
148
149    /// Current load on the worker
150    pub load: usize,
151
152    /// Connection mode (http or grpc)
153    pub connection_mode: String,
154
155    /// Runtime type (sglang or vllm, for gRPC workers)
156    #[serde(skip_serializing_if = "Option::is_none")]
157    pub runtime_type: Option<String>,
158
159    // gRPC-specific fields (None for HTTP workers)
160    #[serde(skip_serializing_if = "Option::is_none")]
161    pub tokenizer_path: Option<String>,
162
163    #[serde(skip_serializing_if = "Option::is_none")]
164    pub reasoning_parser: Option<String>,
165
166    #[serde(skip_serializing_if = "Option::is_none")]
167    pub tool_parser: Option<String>,
168
169    #[serde(skip_serializing_if = "Option::is_none")]
170    pub chat_template: Option<String>,
171
172    /// Bootstrap port for prefill workers
173    #[serde(skip_serializing_if = "Option::is_none")]
174    pub bootstrap_port: Option<u16>,
175
176    /// Additional metadata
177    #[serde(skip_serializing_if = "HashMap::is_empty")]
178    pub metadata: HashMap<String, String>,
179
180    /// Whether health checks are disabled for this worker
181    pub disable_health_check: bool,
182
183    /// Job status for async operations (if available)
184    #[serde(skip_serializing_if = "Option::is_none")]
185    pub job_status: Option<JobStatus>,
186}
187
188impl WorkerInfo {
189    /// Create a partial WorkerInfo for pending workers (not yet registered).
190    /// Used when a worker ID maps to a URL but the worker is still being registered.
191    pub fn pending(worker_id: &str, url: String, job_status: Option<JobStatus>) -> Self {
192        Self {
193            id: worker_id.to_string(),
194            url,
195            model_id: UNKNOWN_MODEL_ID.to_string(),
196            priority: 0,
197            cost: 1.0,
198            worker_type: UNKNOWN_MODEL_ID.to_string(),
199            is_healthy: false,
200            load: 0,
201            connection_mode: UNKNOWN_MODEL_ID.to_string(),
202            runtime_type: None,
203            tokenizer_path: None,
204            reasoning_parser: None,
205            tool_parser: None,
206            chat_template: None,
207            bootstrap_port: None,
208            metadata: HashMap::new(),
209            disable_health_check: false,
210            job_status,
211        }
212    }
213}
214
215/// Job status for async control plane operations
216#[derive(Debug, Clone, Serialize, Deserialize)]
217pub struct JobStatus {
218    pub job_type: String,
219    pub worker_url: String,
220    pub status: String,
221    pub message: Option<String>,
222    pub timestamp: u64,
223}
224
225impl JobStatus {
226    /// Create a pending job status
227    pub fn pending(job_type: &str, worker_url: &str) -> Self {
228        Self {
229            job_type: job_type.to_string(),
230            worker_url: worker_url.to_string(),
231            status: "pending".to_string(),
232            message: None,
233            timestamp: std::time::SystemTime::now()
234                .duration_since(std::time::SystemTime::UNIX_EPOCH)
235                .unwrap()
236                .as_secs(),
237        }
238    }
239
240    /// Create a processing job status
241    pub fn processing(job_type: &str, worker_url: &str) -> Self {
242        Self {
243            job_type: job_type.to_string(),
244            worker_url: worker_url.to_string(),
245            status: "processing".to_string(),
246            message: None,
247            timestamp: std::time::SystemTime::now()
248                .duration_since(std::time::SystemTime::UNIX_EPOCH)
249                .unwrap()
250                .as_secs(),
251        }
252    }
253
254    /// Create a failed job status
255    pub fn failed(job_type: &str, worker_url: &str, error: String) -> Self {
256        Self {
257            job_type: job_type.to_string(),
258            worker_url: worker_url.to_string(),
259            status: "failed".to_string(),
260            message: Some(error),
261            timestamp: std::time::SystemTime::now()
262                .duration_since(std::time::SystemTime::UNIX_EPOCH)
263                .unwrap()
264                .as_secs(),
265        }
266    }
267}
268
269/// Worker list response
270#[derive(Debug, Clone, Serialize)]
271pub struct WorkerListResponse {
272    /// List of workers
273    pub workers: Vec<WorkerInfo>,
274
275    /// Total count
276    pub total: usize,
277
278    /// Statistics
279    pub stats: WorkerStats,
280}
281
282/// Worker statistics
283#[derive(Debug, Clone, Serialize)]
284pub struct WorkerStats {
285    pub total_workers: usize,
286    pub healthy_workers: usize,
287    pub total_models: usize,
288    pub total_load: usize,
289    pub by_type: WorkerTypeStats,
290}
291
292/// Worker statistics by type
293#[derive(Debug, Clone, Serialize)]
294pub struct WorkerTypeStats {
295    pub regular: usize,
296    pub prefill: usize,
297    pub decode: usize,
298}
299
300/// Worker update request
301#[derive(Debug, Clone, Serialize, Deserialize)]
302pub struct WorkerUpdateRequest {
303    /// Update priority
304    #[serde(skip_serializing_if = "Option::is_none")]
305    pub priority: Option<u32>,
306
307    /// Update cost
308    #[serde(skip_serializing_if = "Option::is_none")]
309    pub cost: Option<f32>,
310
311    /// Update labels
312    #[serde(skip_serializing_if = "Option::is_none")]
313    pub labels: Option<HashMap<String, String>>,
314
315    /// Update API key (for key rotation)
316    #[serde(skip_serializing_if = "Option::is_none")]
317    pub api_key: Option<String>,
318
319    /// Update health check timeout in seconds
320    #[serde(skip_serializing_if = "Option::is_none")]
321    pub health_check_timeout_secs: Option<u64>,
322
323    /// Update health check interval in seconds
324    #[serde(skip_serializing_if = "Option::is_none")]
325    pub health_check_interval_secs: Option<u64>,
326
327    /// Update health success threshold
328    #[serde(skip_serializing_if = "Option::is_none")]
329    pub health_success_threshold: Option<u32>,
330
331    /// Update health failure threshold
332    #[serde(skip_serializing_if = "Option::is_none")]
333    pub health_failure_threshold: Option<u32>,
334
335    /// Disable periodic health checks for this worker
336    #[serde(skip_serializing_if = "Option::is_none")]
337    pub disable_health_check: Option<bool>,
338}
339
340/// Generic API response
341#[derive(Debug, Clone, Serialize)]
342pub struct WorkerApiResponse {
343    pub success: bool,
344    pub message: String,
345
346    #[serde(skip_serializing_if = "Option::is_none")]
347    pub worker: Option<WorkerInfo>,
348}
349
350/// Error response
351#[derive(Debug, Clone, Serialize)]
352pub struct WorkerErrorResponse {
353    pub error: String,
354    pub code: String,
355}
356
357/// Server info response from /get_server_info endpoint
358#[derive(Debug, Clone, Deserialize)]
359pub struct ServerInfo {
360    #[serde(skip_serializing_if = "Option::is_none")]
361    pub model_id: Option<String>,
362
363    #[serde(skip_serializing_if = "Option::is_none")]
364    pub model_path: Option<String>,
365
366    #[serde(skip_serializing_if = "Option::is_none")]
367    pub priority: Option<u32>,
368
369    #[serde(skip_serializing_if = "Option::is_none")]
370    pub cost: Option<f32>,
371
372    #[serde(skip_serializing_if = "Option::is_none")]
373    pub worker_type: Option<String>,
374
375    // gRPC-specific
376    #[serde(skip_serializing_if = "Option::is_none")]
377    pub tokenizer_path: Option<String>,
378
379    #[serde(skip_serializing_if = "Option::is_none")]
380    pub reasoning_parser: Option<String>,
381
382    #[serde(skip_serializing_if = "Option::is_none")]
383    pub tool_parser: Option<String>,
384
385    #[serde(skip_serializing_if = "Option::is_none")]
386    pub chat_template: Option<String>,
387}
388
389/// Result from flush cache operations across workers
390#[derive(Debug, Clone, Deserialize, Serialize)]
391pub struct FlushCacheResult {
392    /// URLs of workers where cache flush succeeded
393    pub successful: Vec<String>,
394    /// URLs and error messages for workers where cache flush failed
395    pub failed: Vec<(String, String)>,
396    /// Total number of workers attempted
397    pub total_workers: usize,
398    /// Number of HTTP workers (gRPC workers don't support flush cache)
399    pub http_workers: usize,
400    /// Human-readable summary message
401    pub message: String,
402}
403
404/// Result from getting worker loads
405#[derive(Debug, Clone, Deserialize, Serialize)]
406pub struct WorkerLoadsResult {
407    /// Worker URL and load pairs
408    pub loads: Vec<WorkerLoadInfo>,
409    /// Total number of workers
410    pub total_workers: usize,
411    /// Number of workers with successful load fetches
412    pub successful: usize,
413    /// Number of workers with failed load fetches
414    pub failed: usize,
415}
416
417/// Individual worker load information
418#[derive(Debug, Clone, Deserialize, Serialize)]
419pub struct WorkerLoadInfo {
420    /// Worker URL
421    pub worker: String,
422    /// Worker type (regular, prefill, decode)
423    #[serde(skip_serializing_if = "Option::is_none")]
424    pub worker_type: Option<String>,
425    /// Current load (-1 indicates failure to fetch)
426    pub load: isize,
427}
428
#[cfg(feature = "axum")]
impl IntoResponse for FlushCacheResult {
    /// Render the flush outcome as JSON.
    ///
    /// With no failures, the response is `200 OK` with `"status": "success"`. Otherwise it is
    /// `206 Partial Content` with `"status": "partial_success"`, plus the list of successful
    /// URLs and a `{"worker", "error"}` object per failure.
    ///
    /// NOTE(review): 206 is normally reserved for HTTP range requests. Also, a result in which
    /// every worker failed is still labelled "partial_success".
    fn into_response(self) -> Response {
        let FlushCacheResult {
            successful,
            failed,
            total_workers,
            http_workers,
            message,
        } = self;

        let all_succeeded = failed.is_empty();
        let (status, outcome) = if all_succeeded {
            (StatusCode::OK, "success")
        } else {
            (StatusCode::PARTIAL_CONTENT, "partial_success")
        };

        let mut body = json!({
            "status": outcome,
            "message": message,
            "workers_flushed": successful.len(),
            "total_http_workers": http_workers,
            "total_workers": total_workers
        });

        if !all_succeeded {
            let failures: Vec<Value> = failed
                .into_iter()
                .map(|(worker, error)| json!({ "worker": worker, "error": error }))
                .collect();
            // Insert "successful" before "failed" to keep the original key order.
            body["successful"] = json!(successful);
            body["failed"] = Value::Array(failures);
        }

        (status, Json(body)).into_response()
    }
}
458
#[cfg(feature = "axum")]
impl IntoResponse for WorkerLoadsResult {
    /// Render the loads as `{"workers": [{"worker": <url>, "load": <load>}, ...]}` with 200 OK.
    ///
    /// Only `worker` and `load` are emitted for each entry. `worker_type` and the aggregate
    /// counters (`total_workers`, `successful`, `failed`) are not part of the body.
    fn into_response(self) -> Response {
        let mut workers: Vec<Value> = Vec::with_capacity(self.loads.len());
        for entry in self.loads {
            workers.push(json!({ "worker": entry.worker, "load": entry.load }));
        }
        Json(json!({ "workers": workers })).into_response()
    }
}
467        Json(json!({"workers": loads})).into_response()
468    }
469}