use super::{RouteDoc, service_v2};
use crate::discovery::LoadThresholdConfig;
use axum::{
Json, Router,
extract::Request,
http::{Method, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
routing::{get, post},
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// JSON request body accepted by the busy-threshold POST handler.
///
/// `model` is required. The three threshold fields are optional: when all of
/// them are absent the handler passes no update to the manager, otherwise the
/// three values (including any that are `None`) are forwarded together as one
/// `LoadThresholdConfig`.
#[derive(Debug, Deserialize)]
pub struct BusyThresholdRequest {
/// Name of the model whose thresholds are queried or updated.
pub model: String,
/// Decode-blocks threshold; the POST handler rejects values outside
/// `0.0..=1.0` with `400 Bad Request`.
pub active_decode_blocks_threshold: Option<f64>,
/// Prefill-tokens threshold; forwarded unchanged (no range check in this file).
// NOTE(review): presumably an absolute token count — confirm against `LoadThresholdConfig`.
pub active_prefill_tokens_threshold: Option<u64>,
/// Prefill-tokens threshold expressed as a fraction; forwarded unchanged
/// (no range check in this file).
// NOTE(review): the base this fraction is relative to is not visible here — confirm.
pub active_prefill_tokens_threshold_frac: Option<f64>,
}
/// JSON view of one model's busy-threshold configuration.
///
/// Returned by the POST handler and used as the element type of
/// [`ListBusyThresholdsResponse`]. Each threshold is `None` when the
/// configuration returned by the manager has no value for it (or, in the POST
/// handler, when the manager returned no configuration at all).
#[derive(Debug, Serialize)]
pub struct BusyThresholdResponse {
/// Name of the model the thresholds belong to.
pub model: String,
/// Decode-blocks threshold copied from the manager's configuration.
pub active_decode_blocks_threshold: Option<f64>,
/// Prefill-tokens threshold copied from the manager's configuration.
pub active_prefill_tokens_threshold: Option<u64>,
/// Fractional prefill-tokens threshold copied from the manager's configuration.
pub active_prefill_tokens_threshold_frac: Option<f64>,
}
/// JSON body returned by the busy-threshold GET handler.
#[derive(Debug, Serialize)]
pub struct ListBusyThresholdsResponse {
/// One entry per `(model, config)` pair reported by the manager's
/// `list_busy_thresholds()`.
pub thresholds: Vec<BusyThresholdResponse>,
}
/// JSON error envelope used by this module; serializes as `{"error": "<message>"}`.
#[derive(Debug, Serialize)]
pub struct ErrorResponse {
/// Human-readable description of what went wrong.
pub error: String,
}
async fn json_error_middleware(request: Request, next: Next) -> Response {
let response = next.run(request).await;
if response.status() == StatusCode::UNPROCESSABLE_ENTITY {
let (_parts, body) = response.into_parts();
let body_bytes = axum::body::to_bytes(body, usize::MAX)
.await
.unwrap_or_default();
let error_message = String::from_utf8_lossy(&body_bytes).to_string();
(
StatusCode::UNPROCESSABLE_ENTITY,
Json(serde_json::json!(ErrorResponse {
error: error_message,
})),
)
.into_response()
} else {
response
}
}
/// Builds the busy-threshold router together with its route documentation.
///
/// # Arguments
/// * `state` - shared service state handed to both handlers.
/// * `path` - mount path for the endpoints; defaults to `/busy_threshold`
///   when `None`.
///
/// # Returns
/// A `(docs, router)` pair. `docs` lists the `POST` and `GET` routes (in that
/// order) on the chosen path; `router` serves `POST` via
/// `busy_threshold_handler` and `GET` via `list_busy_thresholds_handler`, with
/// `json_error_middleware` layered over both.
pub fn busy_threshold_router(
    state: Arc<service_v2::State>,
    path: Option<String>,
) -> (Vec<RouteDoc>, Router) {
    let base_path = match path {
        Some(custom) => custom,
        None => "/busy_threshold".to_string(),
    };

    let docs: Vec<RouteDoc> = [Method::POST, Method::GET]
        .into_iter()
        .map(|method| RouteDoc::new(method, &base_path))
        .collect();

    let with_routes = Router::new()
        .route(&base_path, post(busy_threshold_handler))
        .route(&base_path, get(list_busy_thresholds_handler));
    let router = with_routes
        .layer(axum::middleware::from_fn(json_error_middleware))
        .with_state(state);

    (docs, router)
}
async fn busy_threshold_handler(
axum::extract::State(state): axum::extract::State<Arc<service_v2::State>>,
Json(request): Json<BusyThresholdRequest>,
) -> impl IntoResponse {
if let Some(threshold) = request.active_decode_blocks_threshold
&& !(0.0..=1.0).contains(&threshold)
{
return (
StatusCode::BAD_REQUEST,
Json(serde_json::json!(ErrorResponse {
error: format!(
"active_decode_blocks_threshold must be between 0.0 and 1.0, got {}",
threshold
),
})),
);
}
let manager = state.manager();
let is_setting = request.active_decode_blocks_threshold.is_some()
|| request.active_prefill_tokens_threshold.is_some()
|| request.active_prefill_tokens_threshold_frac.is_some();
let update_config = if is_setting {
Some(LoadThresholdConfig {
active_decode_blocks_threshold: request.active_decode_blocks_threshold,
active_prefill_tokens_threshold: request.active_prefill_tokens_threshold,
active_prefill_tokens_threshold_frac: request.active_prefill_tokens_threshold_frac,
})
} else {
None
};
let config = manager.load_threshold_config(&request.model, update_config.as_ref());
if is_setting && config.is_none() {
return (
StatusCode::NOT_FOUND,
Json(serde_json::json!(ErrorResponse {
error: format!(
"Model '{}' not found. Thresholds can only be set for discovered models.",
request.model
),
})),
);
}
if is_setting {
tracing::info!(
model = %request.model,
config = ?config,
"Updated busy thresholds"
);
}
let (
active_decode_blocks_threshold,
active_prefill_tokens_threshold,
active_prefill_tokens_threshold_frac,
) = config.map_or((None, None, None), |c| {
(
c.active_decode_blocks_threshold,
c.active_prefill_tokens_threshold,
c.active_prefill_tokens_threshold_frac,
)
});
(
StatusCode::OK,
Json(serde_json::json!(BusyThresholdResponse {
model: request.model,
active_decode_blocks_threshold,
active_prefill_tokens_threshold,
active_prefill_tokens_threshold_frac,
})),
)
}
async fn list_busy_thresholds_handler(
axum::extract::State(state): axum::extract::State<Arc<service_v2::State>>,
) -> impl IntoResponse {
let manager = state.manager();
let thresholds = manager.list_busy_thresholds();
let response = ListBusyThresholdsResponse {
thresholds: thresholds
.into_iter()
.map(|(model, config)| BusyThresholdResponse {
model,
active_decode_blocks_threshold: config.active_decode_blocks_threshold,
active_prefill_tokens_threshold: config.active_prefill_tokens_threshold,
active_prefill_tokens_threshold_frac: config.active_prefill_tokens_threshold_frac,
})
.collect(),
};
Json(serde_json::json!(response))
}