oxibonsai_runtime/
admin.rs1use axum::{
27 extract::State,
28 http::StatusCode,
29 response::IntoResponse,
30 routing::{get, post},
31 Json, Router,
32};
33use serde::Serialize;
34use std::sync::Arc;
35use std::time::Instant;
36
37use crate::kv_cache_policy::KvCachePolicy;
38use crate::metrics::InferenceMetrics;
39use crate::request_metrics::RequestRateAggregator;
40
41#[derive(Debug, Serialize)]
45pub struct ServerStatus {
46 pub version: &'static str,
48 pub uptime_secs: u64,
50 pub model_loaded: bool,
52 pub requests_total: u64,
54 pub tokens_generated: u64,
56 pub active_connections: u64,
58 pub memory_rss_bytes: Option<u64>,
60}
61
62#[derive(Debug, Serialize)]
66pub struct ConfigSnapshot {
67 pub max_tokens_default: usize,
69 pub temperature_default: f32,
71 pub top_p_default: f32,
73 pub server_version: &'static str,
75 pub features: Vec<String>,
77}
78
79pub struct AdminState {
83 pub started_at: Instant,
85 pub metrics: Arc<InferenceMetrics>,
87 pub rate_aggregator: Option<Arc<RequestRateAggregator>>,
89 pub kv_cache_policy: Option<Arc<KvCachePolicy>>,
91}
92
93impl AdminState {
94 pub fn new(metrics: Arc<InferenceMetrics>) -> Self {
99 Self {
100 started_at: Instant::now(),
101 metrics,
102 rate_aggregator: None,
103 kv_cache_policy: None,
104 }
105 }
106
107 pub fn with_rate_aggregator(mut self, aggregator: Arc<RequestRateAggregator>) -> Self {
110 self.rate_aggregator = Some(aggregator);
111 self
112 }
113
114 pub fn with_kv_cache_policy(mut self, policy: Arc<KvCachePolicy>) -> Self {
117 self.kv_cache_policy = Some(policy);
118 self
119 }
120
121 pub fn uptime_secs(&self) -> u64 {
123 self.started_at.elapsed().as_secs()
124 }
125}
126
127pub async fn get_status(State(state): State<Arc<AdminState>>) -> impl IntoResponse {
131 let rss = {
132 let rss_raw = crate::memory::get_rss_bytes();
133 if rss_raw == 0 {
134 None
135 } else {
136 Some(rss_raw)
137 }
138 };
139
140 let status = ServerStatus {
141 version: env!("CARGO_PKG_VERSION"),
142 uptime_secs: state.uptime_secs(),
143 model_loaded: state.metrics.requests_total.get() > 0
147 || state.metrics.tokens_generated_total.get() > 0,
148 requests_total: state.metrics.requests_total.get(),
149 tokens_generated: state.metrics.tokens_generated_total.get(),
150 active_connections: state.metrics.active_requests.get() as u64,
151 memory_rss_bytes: rss,
152 };
153
154 (StatusCode::OK, Json(status))
155}
156
157pub async fn get_config(_state: State<Arc<AdminState>>) -> impl IntoResponse {
159 let snapshot = ConfigSnapshot {
160 max_tokens_default: 256,
161 temperature_default: 0.7,
162 top_p_default: 0.9,
163 server_version: env!("CARGO_PKG_VERSION"),
164 features: features_enabled(),
165 };
166
167 (StatusCode::OK, Json(snapshot))
168}
169
170pub async fn reset_metrics(State(state): State<Arc<AdminState>>) -> impl IntoResponse {
174 let requests = state.metrics.requests_total.get();
176 state.metrics.requests_total.inc_by(0); let tokens = state.metrics.tokens_generated_total.get();
185 let errors = state.metrics.errors_total.get();
186 let prompt = state.metrics.prompt_tokens_total.get();
187
188 state
190 .metrics
191 .requests_total
192 .inc_by(u64::MAX.wrapping_sub(requests).wrapping_add(1));
193 state
194 .metrics
195 .tokens_generated_total
196 .inc_by(u64::MAX.wrapping_sub(tokens).wrapping_add(1));
197 state
198 .metrics
199 .errors_total
200 .inc_by(u64::MAX.wrapping_sub(errors).wrapping_add(1));
201 state
202 .metrics
203 .prompt_tokens_total
204 .inc_by(u64::MAX.wrapping_sub(prompt).wrapping_add(1));
205
206 state.metrics.active_requests.set(0.0);
208 state.metrics.kv_cache_utilization.set(0.0);
209
210 let ts = std::time::SystemTime::now()
211 .duration_since(std::time::UNIX_EPOCH)
212 .unwrap_or_default()
213 .as_secs();
214
215 let body = serde_json::json!({
216 "reset": true,
217 "timestamp": ts,
218 });
219
220 (StatusCode::OK, Json(body))
221}
222
223pub async fn get_workload_stats(State(state): State<Arc<AdminState>>) -> impl IntoResponse {
232 let request_rate = state.rate_aggregator.as_ref().map(|agg| {
233 let snap = agg.snapshot();
234 serde_json::json!({
235 "completed_requests": snap.completed_requests,
236 "mean_tokens_per_second": snap.mean_tokens_per_second,
237 "tbt_p50_seconds": snap.tbt_p50_seconds,
238 "tbt_p95_seconds": snap.tbt_p95_seconds,
239 "mean_queue_wait_seconds": snap.mean_queue_wait_seconds,
240 })
241 });
242
243 let kv_cache = state.kv_cache_policy.as_ref().map(|policy| {
244 let level = policy.current_level();
245 serde_json::json!({
246 "level": level.tag(),
247 "memory_factor": level.memory_factor(),
248 "pressure_ewma": policy.pressure(),
249 "samples": policy.samples(),
250 "upgrades": policy.upgrades(),
251 "downgrades": policy.downgrades(),
252 })
253 });
254
255 let body = serde_json::json!({
256 "request_rate": request_rate,
257 "kv_cache": kv_cache,
258 "status": "ok",
259 });
260 (StatusCode::OK, Json(body))
261}
262
263pub async fn get_cache_stats(_state: State<Arc<AdminState>>) -> impl IntoResponse {
265 let body = serde_json::json!({
266 "kv_cache": {
267 "capacity_blocks": 0,
268 "used_blocks": 0,
269 "utilization": 0.0,
270 "evictions_total": 0,
271 },
272 "prefix_cache": {
273 "entries": 0,
274 "hit_rate": 0.0,
275 },
276 "status": "ok",
277 });
278
279 (StatusCode::OK, Json(body))
280}
281
282pub fn create_admin_router(state: Arc<AdminState>) -> Router<Arc<AdminState>> {
289 Router::new()
290 .route("/admin/status", get(get_status))
291 .route("/admin/config", get(get_config))
292 .route("/admin/reset-metrics", post(reset_metrics))
293 .route("/admin/cache-stats", get(get_cache_stats))
294 .route("/admin/workload-stats", get(get_workload_stats))
295 .with_state(state)
296}
297
/// Lists the Cargo features and target-architecture markers active in this
/// build.
///
/// Entries appear in a fixed order — `server`, `rag`, `wasm`, `wasm32`,
/// `x86_64`, `aarch64` — each present only when its `cfg` holds, followed by
/// `runtime`, which is always included.
pub fn features_enabled() -> Vec<String> {
    // Every flag is resolved at compile time; the table fixes the output order.
    let table: [(bool, &str); 7] = [
        (cfg!(feature = "server"), "server"),
        (cfg!(feature = "rag"), "rag"),
        (cfg!(feature = "wasm"), "wasm"),
        (cfg!(target_arch = "wasm32"), "wasm32"),
        (cfg!(target_arch = "x86_64"), "x86_64"),
        (cfg!(target_arch = "aarch64"), "aarch64"),
        (true, "runtime"),
    ];

    table
        .iter()
        .filter_map(|&(enabled, name)| enabled.then(|| name.to_owned()))
        .collect()
}
328
#[cfg(test)]
mod tests {
    use super::*;

    /// Admin state over a brand-new metrics instance, with no optional
    /// sources attached.
    fn fresh_state() -> AdminState {
        AdminState::new(Arc::new(InferenceMetrics::new()))
    }

    /// A newly created state reports an uptime of (almost) zero.
    #[test]
    fn test_admin_state_uptime() {
        let uptime = fresh_state().uptime_secs();
        assert!(
            uptime < 5,
            "uptime should be nearly 0 at creation; got {uptime}"
        );
    }

    /// Attaching a rate aggregator leaves the KV-cache policy unset.
    #[test]
    fn test_admin_state_with_rate_aggregator() {
        let aggregator = Arc::new(RequestRateAggregator::new());
        let state = fresh_state().with_rate_aggregator(Arc::clone(&aggregator));
        assert!(state.rate_aggregator.is_some());
        assert!(state.kv_cache_policy.is_none());
    }

    /// Attaching a KV-cache policy leaves the rate aggregator unset.
    #[test]
    fn test_admin_state_with_kv_cache_policy() {
        let kv_policy = Arc::new(KvCachePolicy::default());
        let state = fresh_state().with_kv_cache_policy(Arc::clone(&kv_policy));
        assert!(state.kv_cache_policy.is_some());
        assert!(state.rate_aggregator.is_none());
    }

    /// The workload endpoint answers 200 even when no source is attached.
    #[tokio::test]
    async fn test_get_workload_stats_empty() {
        let shared = Arc::new(fresh_state());
        let reply = get_workload_stats(State(Arc::clone(&shared)))
            .await
            .into_response();
        assert_eq!(reply.status(), StatusCode::OK);
    }

    /// The workload endpoint answers 200 with both sources attached.
    #[tokio::test]
    async fn test_get_workload_stats_with_sources() {
        let aggregator = Arc::new(RequestRateAggregator::new());
        let kv_policy = Arc::new(KvCachePolicy::default());
        let shared = Arc::new(
            fresh_state()
                .with_rate_aggregator(Arc::clone(&aggregator))
                .with_kv_cache_policy(Arc::clone(&kv_policy)),
        );
        let reply = get_workload_stats(State(Arc::clone(&shared)))
            .await
            .into_response();
        assert_eq!(reply.status(), StatusCode::OK);
    }

    /// The feature list always carries at least the `runtime` marker.
    #[test]
    fn test_features_enabled_non_empty() {
        let listed = features_enabled();
        assert!(!listed.is_empty(), "features list should not be empty");
        assert!(
            listed.contains(&"runtime".to_owned()),
            "should always include 'runtime'"
        );
    }

    /// The compile-time package version is a non-empty string.
    #[test]
    fn test_server_version_non_empty() {
        let version: &'static str = env!("CARGO_PKG_VERSION");
        assert!(!version.is_empty(), "CARGO_PKG_VERSION should not be empty");
    }
}