/// Response for `POST /v1/explain`: the model's prediction plus a SHAP
/// attribution and a human-readable summary.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExplainResponse {
    pub request_id: String,
    pub model: String,
    pub prediction: serde_json::Value,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub confidence: Option<f32>,
    pub explanation: ShapExplanation,
    pub summary: String,
    pub latency_ms: f64,
}
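
/// Response for `GET /v1/audit/:request_id`: the stored audit record for a
/// previous prediction.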
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditResponse {
    pub record: AuditRecord,
}
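
/// Feature toggles for optional route groups.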
#[derive(Debug, Clone)]
pub struct RouterConfig {
    /// Mount the OpenAI-compatible `/v1/*` routes (enabled by default).
    pub openai_api: bool,
}

impl Default for RouterConfig {
    fn default() -> Self {
        Self { openai_api: true }
    }
}
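
/// Build the full router with the default configuration.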
pub fn create_router(state: AppState) -> Router {
    create_router_with_config(state, RouterConfig::default())
}
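
/// Build the router, mounting the OpenAI-compatible `/v1/*` routes only when
/// `config.openai_api` is set.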
pub fn create_router_with_config(state: AppState, config: RouterConfig) -> Router {
    let mut router = Router::new()
        .route("/health", get(health_handler))
        .route("/health/live", get(health_live_handler))
        .route("/health/ready", get(health_ready_handler))
        .route("/metrics", get(metrics_handler))
        .route("/metrics/dispatch", get(dispatch_metrics_handler))
        .route("/metrics/dispatch/reset", post(dispatch_reset_handler))
        .route("/models", get(models_handler))
        .route("/tokenize", post(tokenize_handler))
        .route("/generate", post(generate_handler))
        .route("/batch/tokenize", post(batch_tokenize_handler))
        .route("/batch/generate", post(batch_generate_handler))
        .route("/stream/generate", post(stream_generate_handler))
        // `/realize/generate` and `/realize/batch` are aliases that reuse the
        // streaming and batch handlers above.
        .route("/realize/generate", post(stream_generate_handler))
        .route("/realize/batch", post(batch_generate_handler))
        .route("/realize/embed", post(realize_embed_handler))
        .route("/realize/model", get(realize_model_handler))
        .route("/realize/reload", post(realize_reload_handler));
    if config.openai_api {
        router = router
            .route("/v1/models", get(openai_models_handler))
            .route("/v1/completions", post(openai_completions_handler))
            .route(
                "/v1/chat/completions",
                post(openai_chat_completions_handler),
            )
            .route(
                "/v1/chat/completions/stream",
                post(openai_chat_completions_stream_handler),
            )
            .route("/v1/embeddings", post(openai_embeddings_handler))
            .route("/v1/predict", post(apr_predict_handler))
            .route("/v1/explain", post(apr_explain_handler))
            .route("/v1/audit/:request_id", get(apr_audit_handler))
            .route("/v1/gpu/warmup", post(gpu_warmup_handler))
            .route("/v1/gpu/status", get(gpu_status_handler))
            .route("/v1/batch/completions", post(gpu_batch_completions_handler))
            .route("/v1/metrics", get(server_metrics_handler));
    }
    // Extra diagnostics that require a CUDA build.
    #[cfg(feature = "cuda")]
    {
        router = router
            .route("/v1/logprobs", post(logprobs_handler))
            .route("/v1/perplexity", post(perplexity_handler));
    }

    // Return a structured JSON 404 instead of axum's empty-body default.
    router = router.fallback(|| async {
        (
            axum::http::StatusCode::NOT_FOUND,
            Json(serde_json::json!({
                "error": "not_found",
                "message": "Route not found. See /health for available endpoints."
            })),
        )
    });
    router = router.layer(axum::middleware::from_fn(sanitize_json_rejection));
    let cors = tower_http::cors::CorsLayer::permissive();
    router.layer(cors).with_state(state)
}
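
/// Middleware that rewrites axum's default `422 Unprocessable Entity`
/// rejection (emitted when JSON extraction fails) into this server's
/// structured `ErrorResponse` shape.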
async fn sanitize_json_rejection(
    request: axum::http::Request<axum::body::Body>,
    next: axum::middleware::Next,
) -> axum::response::Response {
    use axum::response::IntoResponse;
    let response = next.run(request).await;
    if response.status() == StatusCode::UNPROCESSABLE_ENTITY {
        return (
            StatusCode::UNPROCESSABLE_ENTITY,
            Json(ErrorResponse {
                error: "Invalid request body. Check that the JSON structure matches the expected schema."
                    .to_string(),
            }),
        )
            .into_response();
    }
    response
}
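
/// Seconds elapsed since the first call to this function; in practice this
/// approximates process uptime, since the health handlers call it on every
/// request.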
fn server_uptime_sec() -> f64 {
    static SERVER_START: std::sync::OnceLock<std::time::Instant> = std::sync::OnceLock::new();
    SERVER_START
        .get_or_init(std::time::Instant::now)
        .elapsed()
        .as_secs_f64()
}
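
/// Test hook: setting `APR_TEST_FORCE_LOADING=1` forces health checks to
/// report `"loading"` regardless of model state.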
fn force_loading() -> bool {
    std::env::var("APR_TEST_FORCE_LOADING").is_ok_and(|v| v == "1")
}
fn build_health_response(state: &AppState) -> HealthResponse {
    let mut compute_mode = "cpu";
    #[cfg(feature = "gpu")]
    if state.has_gpu_model() || state.has_cached_model() {
        compute_mode = "gpu";
    }
    #[cfg(feature = "cuda")]
    if state.has_cuda_model() {
        compute_mode = "gpu";
    }
    let model_loaded = state.model_loaded();
    let status = if force_loading() || !model_loaded {
        "loading"
    } else {
        "ok"
    };
    HealthResponse {
        status: status.to_string(),
        version: crate::VERSION.to_string(),
        compute_mode: compute_mode.to_string(),
        model_loaded,
        uptime_sec: server_uptime_sec(),
    }
}
fn health_status_code(body: &HealthResponse) -> StatusCode {
    if body.status == "ok" {
        StatusCode::OK
    } else {
        StatusCode::SERVICE_UNAVAILABLE
    }
}
async fn health_handler(State(state): State<AppState>) -> (StatusCode, Json<HealthResponse>) {
    if state.is_verbose() {
        eprintln!("[VERBOSE] GET /health");
    }
    let body = build_health_response(&state);
    let code = health_status_code(&body);
    if state.is_verbose() {
        eprintln!("[VERBOSE] GET /health -> {} status={}", code, body.status);
    }
    (code, Json(body))
}
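
/// Liveness probe: returns 200 whenever the process can answer, even while a
/// model is still loading.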
async fn health_live_handler(State(state): State<AppState>) -> (StatusCode, Json<HealthResponse>) {
    if state.is_verbose() {
        eprintln!("[VERBOSE] GET /health/live");
    }
    (StatusCode::OK, Json(build_health_response(&state)))
}
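
/// Readiness probe: returns 200 only once a model is loaded and the overall
/// status is "ok"; otherwise 503 so load balancers hold traffic.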
async fn health_ready_handler(State(state): State<AppState>) -> (StatusCode, Json<HealthResponse>) {
    if state.is_verbose() {
        eprintln!("[VERBOSE] GET /health/ready");
    }
    let body = build_health_response(&state);
    let code = if body.status == "ok" && body.model_loaded {
        StatusCode::OK
    } else {
        StatusCode::SERVICE_UNAVAILABLE
    };
    (code, Json(body))
}
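
/// Prometheus text-format metrics.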
async fn metrics_handler(State(state): State<AppState>) -> String {
    state.metrics.to_prometheus()
}
/// Snapshot of CPU/GPU dispatch counters and latency statistics for
/// `/metrics/dispatch`. All latency fields are in microseconds.
#[derive(Debug, Clone, serde::Serialize)]
pub struct DispatchMetricsResponse {
    pub cpu_dispatches: usize,
    pub gpu_dispatches: usize,
    pub total_dispatches: usize,
    pub gpu_ratio: f64,
    pub cpu_latency_p50_us: f64,
    pub cpu_latency_p95_us: f64,
    pub cpu_latency_p99_us: f64,
    pub gpu_latency_p50_us: f64,
    pub gpu_latency_p95_us: f64,
    pub gpu_latency_p99_us: f64,
    pub cpu_latency_mean_us: f64,
    pub gpu_latency_mean_us: f64,
    pub cpu_latency_min_us: u64,
    pub cpu_latency_max_us: u64,
    pub gpu_latency_min_us: u64,
    pub gpu_latency_max_us: u64,
    pub cpu_latency_variance_us: f64,
    pub cpu_latency_stddev_us: f64,
    pub gpu_latency_variance_us: f64,
    pub gpu_latency_stddev_us: f64,
    pub bucket_boundaries_us: Vec<String>,
    pub cpu_latency_bucket_counts: Vec<usize>,
    pub gpu_latency_bucket_counts: Vec<usize>,
    pub throughput_rps: f64,
    pub elapsed_seconds: f64,
}
/// Aggregated server metrics returned by `/v1/metrics`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ServerMetricsResponse {
    pub throughput_tok_per_sec: f64,
    pub latency_p50_ms: f64,
    pub latency_p95_ms: f64,
    pub latency_p99_ms: f64,
    pub gpu_memory_used_bytes: u64,
    pub gpu_memory_total_bytes: u64,
    pub gpu_utilization_percent: u32,
    pub cuda_path_active: bool,
    pub batch_size: usize,
    pub queue_depth: usize,
    pub total_tokens: u64,
    pub total_requests: u64,
    pub uptime_secs: u64,
    pub model_name: String,
}
#[derive(Debug, Clone, serde::Deserialize)]
pub struct DispatchMetricsQuery {
    #[serde(default)]
    pub format: Option<String>,
}

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct DispatchResetResponse {
    pub success: bool,
    pub message: String,
}
#[cfg(feature = "gpu")]
async fn dispatch_reset_handler(State(state): State<AppState>) -> axum::response::Response {
    use axum::response::IntoResponse;
    if let Some(metrics) = state.dispatch_metrics() {
        metrics.reset();
        Json(DispatchResetResponse {
            success: true,
            message: "Metrics reset successfully".to_string(),
        })
        .into_response()
    } else {
        (
            StatusCode::SERVICE_UNAVAILABLE,
            Json(ErrorResponse {
                error: "Dispatch metrics not available. No GPU model configured.".to_string(),
            }),
        )
            .into_response()
    }
}
#[cfg(not(feature = "gpu"))]
async fn dispatch_reset_handler(State(_state): State<AppState>) -> axum::response::Response {
    use axum::response::IntoResponse;
    (
        StatusCode::SERVICE_UNAVAILABLE,
        Json(ErrorResponse {
            error: "Dispatch metrics not available. GPU feature not enabled.".to_string(),
        }),
    )
        .into_response()
}
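
/// `/v1/metrics` with the GPU feature enabled: latency percentiles come from
/// the dispatch metrics, preferring GPU latencies whenever at least one GPU
/// dispatch has occurred.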
#[cfg(feature = "gpu")]
async fn server_metrics_handler(State(state): State<AppState>) -> Json<ServerMetricsResponse> {
    let snapshot = state.metrics.snapshot();
    let (latency_p50_ms, latency_p95_ms, latency_p99_ms, gpu_dispatches, cuda_path_active) =
        if let Some(dispatch) = state.dispatch_metrics() {
            let gpu_p50 = dispatch.gpu_latency_p50_us();
            let gpu_p95 = dispatch.gpu_latency_p95_us();
            let gpu_p99 = dispatch.gpu_latency_p99_us();
            let gpu_count = dispatch.gpu_dispatches();
            if gpu_count > 0 {
                // Convert microseconds to milliseconds.
                (gpu_p50 / 1000.0, gpu_p95 / 1000.0, gpu_p99 / 1000.0, gpu_count, true)
            } else {
                // No GPU dispatches yet: fall back to CPU latencies.
                let cpu_p50 = dispatch.cpu_latency_p50_us();
                let cpu_p95 = dispatch.cpu_latency_p95_us();
                let cpu_p99 = dispatch.cpu_latency_p99_us();
                (cpu_p50 / 1000.0, cpu_p95 / 1000.0, cpu_p99 / 1000.0, 0, false)
            }
        } else {
            (0.0, 0.0, 0.0, 0, false)
        };
    let (gpu_memory_used_bytes, gpu_memory_total_bytes): (u64, u64) =
        if let Some(model) = state.cached_model() {
            let used = model.gpu_cache_memory() as u64;
            // NOTE: total VRAM is hardcoded to 24 GiB rather than queried from
            // the device.
            let total = 24 * 1024 * 1024 * 1024u64;
            (used, total)
        } else {
            (0, 0)
        };
    // "Utilization" here is the share of dispatches that went to the GPU, not
    // the device utilization a tool like nvidia-smi would report.
    let gpu_utilization_percent = if let Some(dispatch) = state.dispatch_metrics() {
        let total = dispatch.total_dispatches();
        if total > 0 {
            ((gpu_dispatches as f64 / total as f64) * 100.0) as u32
        } else {
            0
        }
    } else {
        0
    };
    let (batch_size, queue_depth) = if let Some(config) = state.batch_config() {
        (config.optimal_batch, config.queue_size)
    } else {
        (1, 0)
    };
    // NOTE: the model name is hardcoded; it does not reflect the checkpoint
    // actually loaded.
    let model_name = if state.cached_model().is_some() {
        "phi-2-q4_k_m".to_string()
    } else {
        "N/A".to_string()
    };
    Json(ServerMetricsResponse {
        throughput_tok_per_sec: snapshot.tokens_per_sec,
        latency_p50_ms,
        latency_p95_ms,
        latency_p99_ms,
        gpu_memory_used_bytes,
        gpu_memory_total_bytes,
        gpu_utilization_percent,
        cuda_path_active,
        batch_size,
        queue_depth,
        total_tokens: snapshot.total_tokens as u64,
        total_requests: snapshot.total_requests as u64,
        uptime_secs: snapshot.uptime_secs,
        model_name,
    })
}
#[cfg(not(feature = "gpu"))]
async fn server_metrics_handler(State(state): State<AppState>) -> Json<ServerMetricsResponse> {
    let snapshot = state.metrics.snapshot();
    Json(ServerMetricsResponse {
        throughput_tok_per_sec: snapshot.tokens_per_sec,
        // Without dispatch metrics, p95/p99 are rough estimates scaled from
        // the mean latency rather than true percentiles.
        latency_p50_ms: snapshot.avg_latency_ms,
        latency_p95_ms: snapshot.avg_latency_ms * 1.5,
        latency_p99_ms: snapshot.avg_latency_ms * 2.0,
        gpu_memory_used_bytes: 0,
        gpu_memory_total_bytes: 0,
        gpu_utilization_percent: 0,
        cuda_path_active: false,
        batch_size: 1,
        queue_depth: 0,
        total_tokens: snapshot.total_tokens as u64,
        total_requests: snapshot.total_requests as u64,
        uptime_secs: snapshot.uptime_secs,
        model_name: "N/A".to_string(),
    })
}
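
// A minimal smoke-test sketch for the JSON 404 fallback, assuming a
// test-friendly `AppState::default()` constructor exists and that `tokio` and
// `tower` (for `ServiceExt::oneshot`) are available as dev-dependencies.
#[cfg(test)]
mod router_smoke_tests {
    use super::*;
    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use tower::ServiceExt;

    #[tokio::test]
    async fn unknown_route_returns_json_404() {
        let app = create_router(AppState::default());
        let response = app
            .oneshot(
                Request::builder()
                    .uri("/no/such/route")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        // The fallback installed in `create_router_with_config` should answer
        // with 404 and a JSON body instead of axum's empty default.
        assert_eq!(response.status(), StatusCode::NOT_FOUND);
    }
}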