// mcp_proxy/admin.rs
1//! Admin API for proxy introspection.
2//!
3//! Provides endpoints for checking backend health and proxy status.
4
5use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::Duration;
8
9use axum::Router;
10use axum::extract::Extension;
11use axum::response::{IntoResponse, Json};
12use axum::routing::get;
13use chrono::{DateTime, Utc};
14use serde::Serialize;
15use tokio::sync::RwLock;
16use tower_mcp::SessionHandle;
17use tower_mcp::proxy::McpProxy;
18
/// Cached health status, updated periodically by a background task.
///
/// Cloning is cheap: the health table sits behind an `Arc<RwLock<...>>`,
/// so all clones observe the snapshots written by the checker thread.
#[derive(Clone)]
pub struct AdminState {
    // Latest per-backend status; replaced wholesale each check cycle.
    health: Arc<RwLock<Vec<BackendStatus>>>,
    // Proxy identity from config, fixed at startup.
    proxy_name: String,
    proxy_version: String,
    // Number of configured backends, fixed at startup.
    backend_count: usize,
}
27
28impl AdminState {
29    /// Get a snapshot of backend health status.
30    pub async fn health(&self) -> Vec<BackendStatus> {
31        self.health.read().await.clone()
32    }
33
34    /// Proxy name from config.
35    pub fn proxy_name(&self) -> &str {
36        &self.proxy_name
37    }
38
39    /// Proxy version from config.
40    pub fn proxy_version(&self) -> &str {
41        &self.proxy_version
42    }
43
44    /// Number of configured backends.
45    pub fn backend_count(&self) -> usize {
46        self.backend_count
47    }
48}
49
/// Health status of a single backend, updated by the background health checker.
///
/// Serialized as-is into the admin JSON responses.
#[derive(Serialize, Clone)]
pub struct BackendStatus {
    /// Backend namespace (e.g. "db/").
    pub namespace: String,
    /// Whether the backend responded to the last health check.
    pub healthy: bool,
    /// Timestamp of the last health check.
    pub last_checked_at: Option<DateTime<Utc>>,
    /// Number of consecutive failed health checks (reset to 0 on success).
    pub consecutive_failures: u32,
    /// Last error message from a failed health check.
    pub error: Option<String>,
    /// Transport type (e.g. "stdio", "http").
    pub transport: Option<String>,
}
66
/// JSON body for `GET /backends`: proxy identity plus per-backend health.
#[derive(Serialize)]
struct AdminBackendsResponse {
    proxy: ProxyInfo,
    backends: Vec<BackendStatus>,
}
72
/// Proxy identity and aggregate counters embedded in admin responses.
#[derive(Serialize)]
struct ProxyInfo {
    name: String,
    version: String,
    backend_count: usize,
    // Live session count sampled from the HTTP transport at request time.
    active_sessions: usize,
}
80
/// Per-backend metadata passed in from config at startup.
///
/// Looked up by namespace when building `BackendStatus` snapshots.
#[derive(Clone)]
pub struct BackendMeta {
    /// Transport type string (e.g. "stdio", "http").
    pub transport: String,
}
87
88/// Spawn a background task that periodically health-checks backends.
89/// Returns the AdminState that admin endpoints read from.
90pub fn spawn_health_checker(
91    proxy: McpProxy,
92    proxy_name: String,
93    proxy_version: String,
94    backend_count: usize,
95    backend_meta: HashMap<String, BackendMeta>,
96) -> AdminState {
97    let health: Arc<RwLock<Vec<BackendStatus>>> = Arc::new(RwLock::new(Vec::new()));
98    let health_writer = Arc::clone(&health);
99
100    // McpProxy is Send+Clone but not Sync, so &McpProxy is not Send.
101    // health_check(&self) borrows across .await, making futures !Send.
102    // Workaround: run health checks on a dedicated single-threaded runtime
103    // where Send is not required.
104    std::thread::spawn(move || {
105        let rt = tokio::runtime::Builder::new_current_thread()
106            .enable_all()
107            .build()
108            .expect("admin health check runtime");
109
110        // Track consecutive failure counts across check cycles.
111        let mut failure_counts: HashMap<String, u32> = HashMap::new();
112
113        rt.block_on(async move {
114            loop {
115                let results = proxy.health_check().await;
116                let now = Utc::now();
117                let statuses: Vec<BackendStatus> = results
118                    .into_iter()
119                    .map(|h| {
120                        let count = failure_counts.entry(h.namespace.clone()).or_insert(0);
121                        if h.healthy {
122                            *count = 0;
123                        } else {
124                            *count += 1;
125                        }
126                        let meta = backend_meta.get(&h.namespace);
127                        BackendStatus {
128                            namespace: h.namespace,
129                            healthy: h.healthy,
130                            last_checked_at: Some(now),
131                            consecutive_failures: *count,
132                            error: if h.healthy {
133                                None
134                            } else {
135                                Some("ping failed".to_string())
136                            },
137                            transport: meta.map(|m| m.transport.clone()),
138                        }
139                    })
140                    .collect();
141                *health_writer.write().await = statuses;
142                tokio::time::sleep(Duration::from_secs(10)).await;
143            }
144        });
145    });
146
147    AdminState {
148        health,
149        proxy_name,
150        proxy_version,
151        backend_count,
152    }
153}
154
155async fn handle_backends(
156    Extension(state): Extension<AdminState>,
157    Extension(session_handle): Extension<SessionHandle>,
158) -> Json<AdminBackendsResponse> {
159    let backends = state.health.read().await.clone();
160    let active_sessions = session_handle.session_count().await;
161
162    Json(AdminBackendsResponse {
163        proxy: ProxyInfo {
164            name: state.proxy_name,
165            version: state.proxy_version,
166            backend_count: state.backend_count,
167            active_sessions,
168        },
169        backends,
170    })
171}
172
173async fn handle_health(Extension(state): Extension<AdminState>) -> Json<HealthResponse> {
174    let backends = state.health.read().await;
175    let all_healthy = backends.iter().all(|b| b.healthy);
176    let unhealthy: Vec<String> = backends
177        .iter()
178        .filter(|b| !b.healthy)
179        .map(|b| b.namespace.clone())
180        .collect();
181    Json(HealthResponse {
182        status: if all_healthy { "healthy" } else { "degraded" }.to_string(),
183        unhealthy_backends: unhealthy,
184    })
185}
186
/// JSON body for `GET /health`.
#[derive(Serialize)]
struct HealthResponse {
    // "healthy" when all backends passed their last check, else "degraded".
    status: String,
    unhealthy_backends: Vec<String>,
}
192
/// `GET /metrics` (metrics feature enabled): render the Prometheus
/// exposition text, or an empty body when no recorder was installed.
#[cfg(feature = "metrics")]
async fn handle_metrics(
    Extension(handle): Extension<Option<metrics_exporter_prometheus::PrometheusHandle>>,
) -> impl IntoResponse {
    handle.map_or_else(String::new, |h| h.render())
}
202
/// `GET /metrics` stub when the `metrics` feature is disabled: always
/// responds with an empty body so the route stays present either way.
#[cfg(not(feature = "metrics"))]
async fn handle_metrics() -> impl IntoResponse {
    String::new()
}
207
208async fn handle_cache_stats(
209    Extension(cache_handle): Extension<Option<crate::cache::CacheHandle>>,
210) -> Json<Vec<crate::cache::CacheStatsSnapshot>> {
211    match cache_handle {
212        Some(h) => Json(h.stats()),
213        None => Json(vec![]),
214    }
215}
216
217async fn handle_cache_clear(
218    Extension(cache_handle): Extension<Option<crate::cache::CacheHandle>>,
219) -> &'static str {
220    if let Some(h) = cache_handle {
221        h.clear();
222        "caches cleared"
223    } else {
224        "no caches configured"
225    }
226}
227
/// Create an `AdminState` directly for testing, bypassing the health
/// checker thread that `spawn_health_checker` would start.
#[cfg(test)]
fn test_admin_state(
    proxy_name: &str,
    proxy_version: &str,
    backend_count: usize,
    statuses: Vec<BackendStatus>,
) -> AdminState {
    let health = Arc::new(RwLock::new(statuses));
    AdminState {
        health,
        proxy_name: proxy_name.to_owned(),
        proxy_version: proxy_version.to_owned(),
        backend_count,
    }
}
243
/// Metrics handle type -- wraps the Prometheus handle when the feature is enabled.
#[cfg(feature = "metrics")]
pub type MetricsHandle = Option<metrics_exporter_prometheus::PrometheusHandle>;
/// Metrics handle type -- no-op when the metrics feature is disabled.
/// Callers pass `None`; `admin_router` discards the value in this build.
#[cfg(not(feature = "metrics"))]
pub type MetricsHandle = Option<()>;
250
251/// Build the admin API router.
252pub fn admin_router(
253    state: AdminState,
254    metrics_handle: MetricsHandle,
255    session_handle: SessionHandle,
256    cache_handle: Option<crate::cache::CacheHandle>,
257) -> Router {
258    let router = Router::new()
259        .route("/backends", get(handle_backends))
260        .route("/health", get(handle_health))
261        .route("/cache/stats", get(handle_cache_stats))
262        .route("/cache/clear", axum::routing::post(handle_cache_clear))
263        .route("/metrics", get(handle_metrics))
264        .layer(Extension(state))
265        .layer(Extension(session_handle))
266        .layer(Extension(cache_handle));
267
268    #[cfg(feature = "metrics")]
269    let router = router.layer(Extension(metrics_handle));
270    #[cfg(not(feature = "metrics"))]
271    let _ = metrics_handle;
272
273    router
274}
275
276#[cfg(test)]
277mod tests {
278    use super::*;
279    use axum::body::Body;
280    use axum::http::Request;
281    use tower::ServiceExt;
282
    /// Build an `AdminState` with fixed name/version and the given statuses.
    fn make_state(statuses: Vec<BackendStatus>) -> AdminState {
        test_admin_state("test-gw", "1.0.0", statuses.len(), statuses)
    }

    /// A backend status representing a passing health check.
    fn healthy_backend(name: &str) -> BackendStatus {
        BackendStatus {
            namespace: name.to_string(),
            healthy: true,
            last_checked_at: Some(Utc::now()),
            consecutive_failures: 0,
            error: None,
            transport: Some("http".to_string()),
        }
    }

    /// A backend status representing three consecutive failed checks.
    fn unhealthy_backend(name: &str) -> BackendStatus {
        BackendStatus {
            namespace: name.to_string(),
            healthy: false,
            last_checked_at: Some(Utc::now()),
            consecutive_failures: 3,
            error: Some("ping failed".to_string()),
            transport: Some("stdio".to_string()),
        }
    }

    /// Obtain a `SessionHandle` for router construction in tests.
    fn make_session_handle() -> SessionHandle {
        // Create a session handle via HttpTransport (the only public way).
        // The stub service answers every request with Pong; these tests
        // only exercise admin routes, not MCP traffic.
        let svc = tower::util::BoxCloneService::new(tower::service_fn(
            |_req: tower_mcp::RouterRequest| async {
                Ok::<_, std::convert::Infallible>(tower_mcp::RouterResponse {
                    id: tower_mcp::protocol::RequestId::Number(1),
                    inner: Ok(tower_mcp::protocol::McpResponse::Pong(Default::default())),
                })
            },
        ));
        let (_, handle) =
            tower_mcp::transport::http::HttpTransport::from_service(svc).into_router_with_handle();
        handle
    }

    /// GET `path` on a clone of `router` and parse the body as JSON.
    async fn get_json(router: &Router, path: &str) -> serde_json::Value {
        let resp = router
            .clone()
            .oneshot(Request::builder().uri(path).body(Body::empty()).unwrap())
            .await
            .unwrap();

        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        serde_json::from_slice(&body).unwrap()
    }
336
    #[tokio::test]
    async fn test_admin_state_accessors() {
        // Accessor methods mirror the constructor arguments.
        let state = make_state(vec![healthy_backend("db/")]);
        assert_eq!(state.proxy_name(), "test-gw");
        assert_eq!(state.proxy_version(), "1.0.0");
        assert_eq!(state.backend_count(), 1);

        let health = state.health().await;
        assert_eq!(health.len(), 1);
        assert!(health[0].healthy);
    }

    #[tokio::test]
    async fn test_health_endpoint_all_healthy() {
        // All backends healthy => status "healthy", empty unhealthy list.
        let state = make_state(vec![healthy_backend("db/"), healthy_backend("api/")]);
        let session_handle = make_session_handle();
        let router = admin_router(state, None, session_handle, None);

        let json = get_json(&router, "/health").await;
        assert_eq!(json["status"], "healthy");
        assert!(json["unhealthy_backends"].as_array().unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_health_endpoint_degraded() {
        // One failing backend flips status to "degraded" and gets listed.
        let state = make_state(vec![healthy_backend("db/"), unhealthy_backend("flaky/")]);
        let session_handle = make_session_handle();
        let router = admin_router(state, None, session_handle, None);

        let json = get_json(&router, "/health").await;
        assert_eq!(json["status"], "degraded");
        let unhealthy = json["unhealthy_backends"].as_array().unwrap();
        assert_eq!(unhealthy.len(), 1);
        assert_eq!(unhealthy[0], "flaky/");
    }

    #[tokio::test]
    async fn test_backends_endpoint() {
        // /backends returns proxy identity plus the per-backend snapshot.
        let state = make_state(vec![healthy_backend("db/")]);
        let session_handle = make_session_handle();
        let router = admin_router(state, None, session_handle, None);

        let json = get_json(&router, "/backends").await;
        assert_eq!(json["proxy"]["name"], "test-gw");
        assert_eq!(json["proxy"]["version"], "1.0.0");
        assert_eq!(json["proxy"]["backend_count"], 1);
        assert_eq!(json["backends"].as_array().unwrap().len(), 1);
        assert_eq!(json["backends"][0]["namespace"], "db/");
        assert!(json["backends"][0]["healthy"].as_bool().unwrap());
    }

    #[tokio::test]
    async fn test_cache_stats_no_cache() {
        // Without a cache handle, /cache/stats returns an empty JSON array.
        let state = make_state(vec![]);
        let session_handle = make_session_handle();
        let router = admin_router(state, None, session_handle, None);

        let json = get_json(&router, "/cache/stats").await;
        assert!(json.as_array().unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_cache_clear_no_cache() {
        // Without a cache handle, POST /cache/clear reports a no-op.
        let state = make_state(vec![]);
        let session_handle = make_session_handle();
        let router = admin_router(state, None, session_handle, None);

        let resp = router
            .clone()
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/cache/clear")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        assert_eq!(body.as_ref(), b"no caches configured");
    }

    #[tokio::test]
    async fn test_metrics_endpoint_no_recorder() {
        // With no Prometheus handle installed, /metrics yields an empty body.
        let state = make_state(vec![]);
        let session_handle = make_session_handle();
        let router = admin_router(state, None, session_handle, None);

        let resp = router
            .clone()
            .oneshot(
                Request::builder()
                    .uri("/metrics")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        assert!(body.is_empty());
    }
444}