use axum::{
extract::State,
http::StatusCode,
response::IntoResponse,
routing::{get, post},
Json, Router,
};
use serde::Serialize;
use std::sync::Arc;
use std::time::Instant;
use crate::metrics::InferenceMetrics;
/// JSON body returned by `GET /admin/status` (see [`get_status`]).
///
/// Field declaration order is also the serialized key order.
#[derive(Debug, Serialize)]
pub struct ServerStatus {
    /// Crate version baked in at compile time (`CARGO_PKG_VERSION`).
    pub version: &'static str,
    /// Whole seconds since the [`AdminState`] was created.
    pub uptime_secs: u64,
    /// Heuristic "has this server done any work yet" flag derived from the counters.
    pub model_loaded: bool,
    /// Value of the `requests_total` counter.
    pub requests_total: u64,
    /// Value of the `tokens_generated_total` counter.
    pub tokens_generated: u64,
    /// Value of the `active_requests` gauge, cast to an integer.
    pub active_connections: u64,
    /// Resident set size in bytes; serialized as `null` when no reading is available.
    pub memory_rss_bytes: Option<u64>,
}
/// JSON body returned by `GET /admin/config` (see [`get_config`]).
///
/// Field declaration order is also the serialized key order.
#[derive(Debug, Serialize)]
pub struct ConfigSnapshot {
    /// Default generation length reported to clients.
    pub max_tokens_default: usize,
    /// Default sampling temperature reported to clients.
    pub temperature_default: f32,
    /// Default nucleus (top-p) threshold reported to clients.
    pub top_p_default: f32,
    /// Crate version baked in at compile time (`CARGO_PKG_VERSION`).
    pub server_version: &'static str,
    /// Compile-time features/architecture tags, as produced by [`features_enabled`].
    pub features: Vec<String>,
}
/// Shared state handed to every admin handler via axum's `State` extractor.
pub struct AdminState {
    /// Instant at which this state was constructed; the origin for [`AdminState::uptime_secs`].
    pub started_at: Instant,
    /// Metrics the admin endpoints read from and reset.
    pub metrics: Arc<InferenceMetrics>,
}

impl AdminState {
    /// Builds admin state around `metrics`, stamping the start time with "now".
    pub fn new(metrics: Arc<InferenceMetrics>) -> Self {
        let started_at = Instant::now();
        Self {
            metrics,
            started_at,
        }
    }

    /// Whole seconds elapsed since [`AdminState::new`]; the sub-second remainder is dropped.
    pub fn uptime_secs(&self) -> u64 {
        let since_start = Instant::now().duration_since(self.started_at);
        since_start.as_secs()
    }
}
/// `GET /admin/status`: a point-in-time snapshot of server health and counters.
///
/// Always responds `200 OK` with a JSON-encoded [`ServerStatus`].
pub async fn get_status(State(state): State<Arc<AdminState>>) -> impl IntoResponse {
    let metrics = &state.metrics;

    // Read each counter exactly once. Derived fields (`model_loaded`) then agree
    // with the raw counts in the same response, even if requests land mid-handler.
    let requests_total = metrics.requests_total.get();
    let tokens_generated = metrics.tokens_generated_total.get();

    // A zero RSS reading is treated as "unavailable" and serialized as `null`.
    let memory_rss_bytes = Some(crate::memory::get_rss_bytes()).filter(|&bytes| bytes != 0);

    let status = ServerStatus {
        version: env!("CARGO_PKG_VERSION"),
        uptime_secs: state.uptime_secs(),
        // Heuristic only: true once any request or generated token has been
        // counted. It reads `false` on a freshly started server and again after
        // the counters are reset, whether or not a model is actually loaded.
        model_loaded: requests_total > 0 || tokens_generated > 0,
        requests_total,
        tokens_generated,
        // If the gauge is float-valued, `as` saturates: negative values and NaN
        // become 0 and any fractional part is truncated.
        active_connections: metrics.active_requests.get() as u64,
        memory_rss_bytes,
    };
    (StatusCode::OK, Json(status))
}
/// `GET /admin/config`: reports the server's default sampling parameters,
/// its version and its compiled-in feature tags.
///
/// Always responds `200 OK` with a JSON-encoded [`ConfigSnapshot`]. The admin
/// state is accepted for routing uniformity but not read.
pub async fn get_config(_state: State<Arc<AdminState>>) -> impl IntoResponse {
    // NOTE(review): these defaults are literals, not read from any configuration
    // visible here. Confirm they match what the generation path actually uses.
    const MAX_TOKENS_DEFAULT: usize = 256;
    const TEMPERATURE_DEFAULT: f32 = 0.7;
    const TOP_P_DEFAULT: f32 = 0.9;

    let features = features_enabled();
    let body = Json(ConfigSnapshot {
        max_tokens_default: MAX_TOKENS_DEFAULT,
        temperature_default: TEMPERATURE_DEFAULT,
        top_p_default: TOP_P_DEFAULT,
        server_version: env!("CARGO_PKG_VERSION"),
        features,
    });
    (StatusCode::OK, body)
}
/// `POST /admin/reset-metrics`: zeroes the cumulative counters
/// (`requests_total`, `tokens_generated_total`, `errors_total`,
/// `prompt_tokens_total`).
///
/// Each counter is reset independently; the four resets are not atomic as a
/// group. Gauges (`active_requests`, `kv_cache_utilization`) are left untouched.
/// They describe current state rather than accumulated history, and zeroing a
/// gauge that in-flight work later decrements would drive it negative.
///
/// Always responds `200 OK` with `{"reset": true, "timestamp": <unix seconds>}`.
pub async fn reset_metrics(State(state): State<Arc<AdminState>>) -> impl IntoResponse {
    let metrics = &state.metrics;

    // Counters are only driven through `get`/`inc_by` here, so a reset adds the
    // two's-complement negation of the observed value. With wrapping u64
    // arithmetic, `current + snapshot.wrapping_neg() == current - snapshot`.
    // Increments that race in between the read and the add are therefore kept
    // rather than lost.
    // NOTE(review): relies on the counter's `inc_by` wrapping on overflow
    // instead of panicking or saturating. Confirm for the counter type in use.
    macro_rules! zero_counter {
        ($counter:expr) => {{
            let snapshot = $counter.get();
            $counter.inc_by(snapshot.wrapping_neg());
        }};
    }

    zero_counter!(metrics.requests_total);
    zero_counter!(metrics.tokens_generated_total);
    zero_counter!(metrics.errors_total);
    zero_counter!(metrics.prompt_tokens_total);

    // A system clock set before the Unix epoch reports timestamp 0 instead of failing.
    let ts = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();
    let body = serde_json::json!({
        "reset": true,
        "timestamp": ts,
    });
    (StatusCode::OK, Json(body))
}
/// `GET /admin/cache-stats`: KV-cache and prefix-cache statistics.
///
/// Every numeric value is currently a hard-coded placeholder (zero), and no
/// admin state is read. Despite `"status": "ok"`, the figures do not reflect
/// live cache contents.
/// TODO(review): wire these fields to the real cache metrics.
pub async fn get_cache_stats(_state: State<Arc<AdminState>>) -> impl IntoResponse {
    let kv_cache = serde_json::json!({
        "capacity_blocks": 0,
        "used_blocks": 0,
        "utilization": 0.0,
        "evictions_total": 0,
    });
    let prefix_cache = serde_json::json!({
        "entries": 0,
        "hit_rate": 0.0,
    });
    let payload = serde_json::json!({
        "kv_cache": kv_cache,
        "prefix_cache": prefix_cache,
        "status": "ok",
    });
    (StatusCode::OK, Json(payload))
}
/// Builds the `/admin/*` router and binds `state` to its handlers.
///
/// Routes:
/// - `GET  /admin/status`: [`get_status`]
/// - `GET  /admin/config`: [`get_config`]
/// - `POST /admin/reset-metrics`: [`reset_metrics`]
/// - `GET  /admin/cache-stats`: [`get_cache_stats`]
///
/// NOTE(review): in axum, `Router<S>` names the state the router is still
/// *missing*. Because the return type is `Router<Arc<AdminState>>`, callers must
/// supply an `Arc<AdminState>` again (via `with_state` or by merging into a
/// router of that state type), even though the handlers are already bound to
/// `state`. Consider returning `Router<()>` or a generic `Router<S>` instead.
pub fn create_admin_router(state: Arc<AdminState>) -> Router<Arc<AdminState>> {
    let admin_routes = Router::new()
        .route("/admin/status", get(get_status))
        .route("/admin/config", get(get_config))
        .route("/admin/reset-metrics", post(reset_metrics))
        .route("/admin/cache-stats", get(get_cache_stats));
    admin_routes.with_state(state)
}
/// Lists the Cargo features and target-architecture tags this binary was compiled with.
///
/// Entries appear in a fixed order: `server`, `rag`, `wasm` (each only when that
/// Cargo feature is enabled), then whichever of `wasm32`, `x86_64` or `aarch64`
/// matches the target. The list always ends with `"runtime"`, so it is never empty.
pub fn features_enabled() -> Vec<String> {
    // (compiled in?, label) pairs; each `cfg!` is resolved at compile time.
    let table: [(bool, &str); 7] = [
        (cfg!(feature = "server"), "server"),
        (cfg!(feature = "rag"), "rag"),
        (cfg!(feature = "wasm"), "wasm"),
        (cfg!(target_arch = "wasm32"), "wasm32"),
        (cfg!(target_arch = "x86_64"), "x86_64"),
        (cfg!(target_arch = "aarch64"), "aarch64"),
        (true, "runtime"),
    ];
    table
        .iter()
        .filter_map(|&(enabled, label)| enabled.then(|| label.to_owned()))
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built state should report (almost) zero uptime.
    #[test]
    fn test_admin_state_uptime() {
        let state = AdminState::new(Arc::new(InferenceMetrics::new()));
        let uptime = state.uptime_secs();
        assert!(uptime < 5, "uptime should be nearly 0 at creation; got {uptime}");
    }

    /// The feature list is never empty because `"runtime"` is unconditional.
    #[test]
    fn test_features_enabled_non_empty() {
        let features = features_enabled();
        assert!(!features.is_empty(), "features list should not be empty");
        let has_runtime = features.iter().any(|tag| tag == "runtime");
        assert!(has_runtime, "should always include 'runtime'");
    }

    /// The compile-time version string used by the status/config endpoints is populated.
    #[test]
    fn test_server_version_non_empty() {
        let version: &'static str = env!("CARGO_PKG_VERSION");
        assert!(!version.is_empty(), "CARGO_PKG_VERSION should not be empty");
    }
}