1use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::Duration;
8
9use axum::Router;
10use axum::extract::Extension;
11use axum::response::{IntoResponse, Json};
12use axum::routing::get;
13use chrono::{DateTime, Utc};
14use serde::Serialize;
15use tokio::sync::RwLock;
16use tower_mcp::SessionHandle;
17use tower_mcp::proxy::McpProxy;
18
/// Shared state for the admin API: the latest backend health snapshot
/// plus static proxy metadata. Cheap to clone — the snapshot lives
/// behind an `Arc<RwLock<..>>` shared with the background checker.
#[derive(Clone)]
pub struct AdminState {
    // Latest health snapshot; overwritten wholesale by the checker loop.
    health: Arc<RwLock<Vec<BackendStatus>>>,
    proxy_name: String,
    proxy_version: String,
    // Configured backend count (static; may differ from the snapshot
    // length until the first health check completes).
    backend_count: usize,
}
27
28impl AdminState {
29 pub async fn health(&self) -> Vec<BackendStatus> {
31 self.health.read().await.clone()
32 }
33
34 pub fn proxy_name(&self) -> &str {
36 &self.proxy_name
37 }
38
39 pub fn proxy_version(&self) -> &str {
41 &self.proxy_version
42 }
43
44 pub fn backend_count(&self) -> usize {
46 self.backend_count
47 }
48}
49
/// Health status of a single backend as serialized by the admin API.
#[derive(Serialize, Clone)]
pub struct BackendStatus {
    /// Namespace prefix identifying the backend (e.g. "db/").
    pub namespace: String,
    /// Whether the most recent health check succeeded.
    pub healthy: bool,
    /// Timestamp of the last completed check; `None` before the first run.
    pub last_checked_at: Option<DateTime<Utc>>,
    /// Consecutive failed checks; reset to 0 on the next success.
    pub consecutive_failures: u32,
    /// Error description when unhealthy, `None` otherwise.
    pub error: Option<String>,
    /// Transport label ("http", "stdio", ...) when metadata is known.
    pub transport: Option<String>,
}
66
/// Response body for `GET /backends`.
#[derive(Serialize)]
struct AdminBackendsResponse {
    proxy: ProxyInfo,
    backends: Vec<BackendStatus>,
}
72
/// Proxy-level metadata embedded in the `GET /backends` response.
#[derive(Serialize)]
struct ProxyInfo {
    name: String,
    version: String,
    backend_count: usize,
    // Live session count taken from the transport's SessionHandle.
    active_sessions: usize,
}
80
/// Static per-backend metadata supplied at startup, keyed by namespace
/// in `spawn_health_checker`.
#[derive(Clone)]
pub struct BackendMeta {
    /// Transport label (e.g. "http" or "stdio").
    pub transport: String,
}
87
88pub fn spawn_health_checker(
91 proxy: McpProxy,
92 proxy_name: String,
93 proxy_version: String,
94 backend_count: usize,
95 backend_meta: HashMap<String, BackendMeta>,
96) -> AdminState {
97 let health: Arc<RwLock<Vec<BackendStatus>>> = Arc::new(RwLock::new(Vec::new()));
98 let health_writer = Arc::clone(&health);
99
100 std::thread::spawn(move || {
105 let rt = tokio::runtime::Builder::new_current_thread()
106 .enable_all()
107 .build()
108 .expect("admin health check runtime");
109
110 let mut failure_counts: HashMap<String, u32> = HashMap::new();
112
113 rt.block_on(async move {
114 loop {
115 let results = proxy.health_check().await;
116 let now = Utc::now();
117 let statuses: Vec<BackendStatus> = results
118 .into_iter()
119 .map(|h| {
120 let count = failure_counts.entry(h.namespace.clone()).or_insert(0);
121 if h.healthy {
122 *count = 0;
123 } else {
124 *count += 1;
125 }
126 let meta = backend_meta.get(&h.namespace);
127 BackendStatus {
128 namespace: h.namespace,
129 healthy: h.healthy,
130 last_checked_at: Some(now),
131 consecutive_failures: *count,
132 error: if h.healthy {
133 None
134 } else {
135 Some("ping failed".to_string())
136 },
137 transport: meta.map(|m| m.transport.clone()),
138 }
139 })
140 .collect();
141 *health_writer.write().await = statuses;
142 tokio::time::sleep(Duration::from_secs(10)).await;
143 }
144 });
145 });
146
147 AdminState {
148 health,
149 proxy_name,
150 proxy_version,
151 backend_count,
152 }
153}
154
155async fn handle_backends(
156 Extension(state): Extension<AdminState>,
157 Extension(session_handle): Extension<SessionHandle>,
158) -> Json<AdminBackendsResponse> {
159 let backends = state.health.read().await.clone();
160 let active_sessions = session_handle.session_count().await;
161
162 Json(AdminBackendsResponse {
163 proxy: ProxyInfo {
164 name: state.proxy_name,
165 version: state.proxy_version,
166 backend_count: state.backend_count,
167 active_sessions,
168 },
169 backends,
170 })
171}
172
173async fn handle_health(Extension(state): Extension<AdminState>) -> Json<HealthResponse> {
174 let backends = state.health.read().await;
175 let all_healthy = backends.iter().all(|b| b.healthy);
176 let unhealthy: Vec<String> = backends
177 .iter()
178 .filter(|b| !b.healthy)
179 .map(|b| b.namespace.clone())
180 .collect();
181 Json(HealthResponse {
182 status: if all_healthy { "healthy" } else { "degraded" }.to_string(),
183 unhealthy_backends: unhealthy,
184 })
185}
186
/// Response body for `GET /health`.
#[derive(Serialize)]
struct HealthResponse {
    // "healthy" or "degraded".
    status: String,
    unhealthy_backends: Vec<String>,
}
192
/// `GET /metrics` (with the `metrics` feature) — renders the Prometheus
/// exposition text, or an empty body when no recorder handle was provided.
#[cfg(feature = "metrics")]
async fn handle_metrics(
    Extension(handle): Extension<Option<metrics_exporter_prometheus::PrometheusHandle>>,
) -> impl IntoResponse {
    handle.map(|h| h.render()).unwrap_or_default()
}
202
203#[cfg(not(feature = "metrics"))]
204async fn handle_metrics() -> impl IntoResponse {
205 String::new()
206}
207
208async fn handle_cache_stats(
209 Extension(cache_handle): Extension<Option<crate::cache::CacheHandle>>,
210) -> Json<Vec<crate::cache::CacheStatsSnapshot>> {
211 match cache_handle {
212 Some(h) => Json(h.stats()),
213 None => Json(vec![]),
214 }
215}
216
217async fn handle_cache_clear(
218 Extension(cache_handle): Extension<Option<crate::cache::CacheHandle>>,
219) -> &'static str {
220 if let Some(h) = cache_handle {
221 h.clear();
222 "caches cleared"
223 } else {
224 "no caches configured"
225 }
226}
227
/// Test-only constructor: builds an `AdminState` from a fixed health
/// snapshot without spawning the background checker thread.
#[cfg(test)]
fn test_admin_state(
    proxy_name: &str,
    proxy_version: &str,
    backend_count: usize,
    statuses: Vec<BackendStatus>,
) -> AdminState {
    let health = Arc::new(RwLock::new(statuses));
    AdminState {
        health,
        proxy_name: proxy_name.to_owned(),
        proxy_version: proxy_version.to_owned(),
        backend_count,
    }
}
243
/// Prometheus recorder handle forwarded to the `/metrics` route
/// (present only when the `metrics` feature is enabled).
#[cfg(feature = "metrics")]
pub type MetricsHandle = Option<metrics_exporter_prometheus::PrometheusHandle>;
/// Stand-in for the Prometheus handle when the `metrics` feature is
/// disabled, keeping `admin_router`'s signature uniform across builds.
#[cfg(not(feature = "metrics"))]
pub type MetricsHandle = Option<()>;
250
251pub fn admin_router(
253 state: AdminState,
254 metrics_handle: MetricsHandle,
255 session_handle: SessionHandle,
256 cache_handle: Option<crate::cache::CacheHandle>,
257) -> Router {
258 let router = Router::new()
259 .route("/backends", get(handle_backends))
260 .route("/health", get(handle_health))
261 .route("/cache/stats", get(handle_cache_stats))
262 .route("/cache/clear", axum::routing::post(handle_cache_clear))
263 .route("/metrics", get(handle_metrics))
264 .layer(Extension(state))
265 .layer(Extension(session_handle))
266 .layer(Extension(cache_handle));
267
268 #[cfg(feature = "metrics")]
269 let router = router.layer(Extension(metrics_handle));
270 #[cfg(not(feature = "metrics"))]
271 let _ = metrics_handle;
272
273 router
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::Request;
    use tower::ServiceExt;

    // Convenience wrapper: fixed proxy name/version, backend_count
    // derived from the snapshot length.
    fn make_state(statuses: Vec<BackendStatus>) -> AdminState {
        test_admin_state("test-gw", "1.0.0", statuses.len(), statuses)
    }

    // Fixture: a backend reporting healthy over "http" transport.
    fn healthy_backend(name: &str) -> BackendStatus {
        BackendStatus {
            namespace: name.to_string(),
            healthy: true,
            last_checked_at: Some(Utc::now()),
            consecutive_failures: 0,
            error: None,
            transport: Some("http".to_string()),
        }
    }

    // Fixture: a backend reporting failed over "stdio" transport.
    fn unhealthy_backend(name: &str) -> BackendStatus {
        BackendStatus {
            namespace: name.to_string(),
            healthy: false,
            last_checked_at: Some(Utc::now()),
            consecutive_failures: 3,
            error: Some("ping failed".to_string()),
            transport: Some("stdio".to_string()),
        }
    }

    // Builds a SessionHandle backed by a stub MCP service that answers
    // every request with a Pong; only the handle side is exercised here.
    fn make_session_handle() -> SessionHandle {
        let svc = tower::util::BoxCloneService::new(tower::service_fn(
            |_req: tower_mcp::RouterRequest| async {
                Ok::<_, std::convert::Infallible>(tower_mcp::RouterResponse {
                    id: tower_mcp::protocol::RequestId::Number(1),
                    inner: Ok(tower_mcp::protocol::McpResponse::Pong(Default::default())),
                })
            },
        ));
        let (_, handle) =
            tower_mcp::transport::http::HttpTransport::from_service(svc).into_router_with_handle();
        handle
    }

    // Issues a GET against a clone of the router and parses the
    // response body as JSON.
    async fn get_json(router: &Router, path: &str) -> serde_json::Value {
        let resp = router
            .clone()
            .oneshot(Request::builder().uri(path).body(Body::empty()).unwrap())
            .await
            .unwrap();

        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        serde_json::from_slice(&body).unwrap()
    }

    // AdminState getters reflect the values the state was built with.
    #[tokio::test]
    async fn test_admin_state_accessors() {
        let state = make_state(vec![healthy_backend("db/")]);
        assert_eq!(state.proxy_name(), "test-gw");
        assert_eq!(state.proxy_version(), "1.0.0");
        assert_eq!(state.backend_count(), 1);

        let health = state.health().await;
        assert_eq!(health.len(), 1);
        assert!(health[0].healthy);
    }

    // /health reports "healthy" with no unhealthy backends listed.
    #[tokio::test]
    async fn test_health_endpoint_all_healthy() {
        let state = make_state(vec![healthy_backend("db/"), healthy_backend("api/")]);
        let session_handle = make_session_handle();
        let router = admin_router(state, None, session_handle, None);

        let json = get_json(&router, "/health").await;
        assert_eq!(json["status"], "healthy");
        assert!(json["unhealthy_backends"].as_array().unwrap().is_empty());
    }

    // One failing backend flips status to "degraded" and is listed.
    #[tokio::test]
    async fn test_health_endpoint_degraded() {
        let state = make_state(vec![healthy_backend("db/"), unhealthy_backend("flaky/")]);
        let session_handle = make_session_handle();
        let router = admin_router(state, None, session_handle, None);

        let json = get_json(&router, "/health").await;
        assert_eq!(json["status"], "degraded");
        let unhealthy = json["unhealthy_backends"].as_array().unwrap();
        assert_eq!(unhealthy.len(), 1);
        assert_eq!(unhealthy[0], "flaky/");
    }

    // /backends includes proxy metadata and the per-backend snapshot.
    #[tokio::test]
    async fn test_backends_endpoint() {
        let state = make_state(vec![healthy_backend("db/")]);
        let session_handle = make_session_handle();
        let router = admin_router(state, None, session_handle, None);

        let json = get_json(&router, "/backends").await;
        assert_eq!(json["proxy"]["name"], "test-gw");
        assert_eq!(json["proxy"]["version"], "1.0.0");
        assert_eq!(json["proxy"]["backend_count"], 1);
        assert_eq!(json["backends"].as_array().unwrap().len(), 1);
        assert_eq!(json["backends"][0]["namespace"], "db/");
        assert!(json["backends"][0]["healthy"].as_bool().unwrap());
    }

    // Without a cache handle, /cache/stats returns an empty array.
    #[tokio::test]
    async fn test_cache_stats_no_cache() {
        let state = make_state(vec![]);
        let session_handle = make_session_handle();
        let router = admin_router(state, None, session_handle, None);

        let json = get_json(&router, "/cache/stats").await;
        assert!(json.as_array().unwrap().is_empty());
    }

    // Without a cache handle, POST /cache/clear reports no caches.
    #[tokio::test]
    async fn test_cache_clear_no_cache() {
        let state = make_state(vec![]);
        let session_handle = make_session_handle();
        let router = admin_router(state, None, session_handle, None);

        let resp = router
            .clone()
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/cache/clear")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        assert_eq!(body.as_ref(), b"no caches configured");
    }

    // Without a recorder handle, /metrics renders an empty body.
    #[tokio::test]
    async fn test_metrics_endpoint_no_recorder() {
        let state = make_state(vec![]);
        let session_handle = make_session_handle();
        let router = admin_router(state, None, session_handle, None);

        let resp = router
            .clone()
            .oneshot(
                Request::builder()
                    .uri("/metrics")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        assert!(body.is_empty());
    }
}