1use std::sync::Arc;
8
9use schemars::JsonSchema;
10use serde::{Deserialize, Serialize};
11use tower_mcp::client::ChannelTransport;
12use tower_mcp::proxy::{AddBackendError, McpProxy};
13use tower_mcp::{CallToolResult, McpRouter, NoParams, SessionHandle, ToolBuilder};
14
15use crate::admin::AdminState;
16use crate::config::ProxyConfig;
17
/// State shared by every admin tool handler.
///
/// A clone is moved into each handler closure and cloned again on every
/// invocation. NOTE(review): this assumes `AdminState`, `SessionHandle` and
/// `McpProxy` are cheap, shared handles — confirm before adding heavier fields.
#[derive(Clone)]
struct AdminToolState {
    /// Source of backend health plus proxy name/version/backend count.
    admin_state: AdminState,
    /// Queried by the `session_count` tool for the active session count.
    session_handle: SessionHandle,
    /// Pretty-printed TOML of the proxy config, captured when the admin
    /// tools were registered (or an error message if rendering failed).
    config_snapshot: Arc<String>,
    /// The proxy the admin backend is mounted on; used by `add_backend` and
    /// to forward `call_tool` requests.
    proxy: McpProxy,
}
26
/// One backend's health entry, serialized into the JSON returned by the
/// `list_backends` and `health_check` admin tools.
///
/// Field order determines the key order of the emitted JSON.
#[derive(Serialize)]
struct BackendInfo {
    /// Namespace the backend is registered under.
    namespace: String,
    /// Health flag copied from the admin state's health record.
    healthy: bool,
    /// Time of the last health check, rendered as an RFC 3339 string;
    /// omitted from the JSON when unknown.
    #[serde(skip_serializing_if = "Option::is_none")]
    last_checked_at: Option<String>,
    /// Consecutive failure count copied from the health record.
    consecutive_failures: u32,
    /// Error message from the health record, if any; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<String>,
    /// Transport description for the backend, if known; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    transport: Option<String>,
}
39
/// JSON payload of the `list_backends` admin tool.
#[derive(Serialize)]
struct BackendsResult {
    /// Name of this proxy instance.
    proxy_name: String,
    /// Version string of this proxy instance.
    proxy_version: String,
    /// Backend count as reported by the admin state (taken independently of
    /// `backends`, so the two are not guaranteed to agree).
    backend_count: usize,
    /// Per-backend health entries.
    backends: Vec<BackendInfo>,
}
46}
47
/// JSON payload of the `session_count` admin tool.
#[derive(Serialize)]
struct SessionResult {
    /// Number of sessions reported by the session handle.
    active_sessions: usize,
}
52
53pub async fn register_admin_tools(
63 proxy: &McpProxy,
64 admin_state: AdminState,
65 session_handle: SessionHandle,
66 config: &ProxyConfig,
67 discovery_tools: Option<Vec<tower_mcp::Tool>>,
68) -> Result<(), AddBackendError> {
69 let config_toml =
70 toml::to_string_pretty(config).unwrap_or_else(|e| format!("error serializing: {e}"));
71
72 let search_mode = config.proxy.tool_exposure == crate::config::ToolExposure::Search;
73
74 let state = AdminToolState {
75 admin_state,
76 session_handle,
77 config_snapshot: Arc::new(config_toml),
78 proxy: proxy.clone(),
79 };
80
81 #[cfg(feature = "skills")]
83 let skills = crate::skills::build_skills(state.config_snapshot.clone());
84 #[cfg(not(feature = "skills"))]
85 let skills: Vec<tower_mcp::Prompt> = vec![];
86
87 let router = build_admin_router(state, discovery_tools, search_mode, skills);
88 let transport = ChannelTransport::new(router);
89
90 proxy.add_backend("proxy", transport).await
91}
92
93fn build_admin_router(
94 state: AdminToolState,
95 discovery_tools: Option<Vec<tower_mcp::Tool>>,
96 search_mode: bool,
97 skills: Vec<tower_mcp::Prompt>,
98) -> McpRouter {
99 let state_for_backends = state.clone();
100 let list_backends = ToolBuilder::new("list_backends")
101 .description("List all proxy backends with health status")
102 .handler(move |_: NoParams| {
103 let s = state_for_backends.clone();
104 async move {
105 let health = s.admin_state.health().await;
106 let backends: Vec<BackendInfo> = health
107 .iter()
108 .map(|b| BackendInfo {
109 namespace: b.namespace.clone(),
110 healthy: b.healthy,
111 last_checked_at: b.last_checked_at.map(|t| t.to_rfc3339()),
112 consecutive_failures: b.consecutive_failures,
113 error: b.error.clone(),
114 transport: b.transport.clone(),
115 })
116 .collect();
117
118 let result = BackendsResult {
119 proxy_name: s.admin_state.proxy_name().to_string(),
120 proxy_version: s.admin_state.proxy_version().to_string(),
121 backend_count: s.admin_state.backend_count(),
122 backends,
123 };
124
125 Ok(CallToolResult::text(
126 serde_json::to_string_pretty(&result).unwrap(),
127 ))
128 }
129 })
130 .build();
131
132 let state_for_sessions = state.clone();
133 let session_count = ToolBuilder::new("session_count")
134 .description("Get the number of active MCP sessions")
135 .handler(move |_: NoParams| {
136 let s = state_for_sessions.clone();
137 async move {
138 let count = s.session_handle.session_count().await;
139 let result = SessionResult {
140 active_sessions: count,
141 };
142 Ok(CallToolResult::text(
143 serde_json::to_string_pretty(&result).unwrap(),
144 ))
145 }
146 })
147 .build();
148
149 let config_snapshot = Arc::clone(&state.config_snapshot);
150 let config_tool = ToolBuilder::new("config")
151 .description("Dump the current proxy configuration")
152 .handler(move |_: NoParams| {
153 let config = Arc::clone(&config_snapshot);
154 async move { Ok(CallToolResult::text((*config).clone())) }
155 })
156 .build();
157
158 let state_for_health = state.clone();
159 let health_check = ToolBuilder::new("health_check")
160 .description("Get cached health check results for all backends")
161 .handler(move |_: NoParams| {
162 let s = state_for_health.clone();
163 async move {
164 let health = s.admin_state.health().await;
165 let backends: Vec<BackendInfo> = health
166 .iter()
167 .map(|b| BackendInfo {
168 namespace: b.namespace.clone(),
169 healthy: b.healthy,
170 last_checked_at: b.last_checked_at.map(|t| t.to_rfc3339()),
171 consecutive_failures: b.consecutive_failures,
172 error: b.error.clone(),
173 transport: b.transport.clone(),
174 })
175 .collect();
176 let healthy_count = backends.iter().filter(|b| b.healthy).count();
177 let total = backends.len();
178 let result = HealthCheckResult {
179 status: if healthy_count == total {
180 "healthy"
181 } else {
182 "degraded"
183 }
184 .to_string(),
185 healthy_count,
186 total_count: total,
187 backends,
188 };
189 Ok(CallToolResult::text(
190 serde_json::to_string_pretty(&result).unwrap(),
191 ))
192 }
193 })
194 .build();
195
196 let state_for_add = state.clone();
197 let add_backend = ToolBuilder::new("add_backend")
198 .description("Dynamically add an HTTP backend to the proxy")
199 .handler(move |input: AddBackendInput| {
200 let s = state_for_add.clone();
201 async move {
202 let transport = tower_mcp::client::HttpClientTransport::new(&input.url);
203 match s.proxy.add_backend(&input.name, transport).await {
204 Ok(()) => Ok(CallToolResult::text(format!(
205 "Backend '{}' added successfully at {}",
206 input.name, input.url
207 ))),
208 Err(e) => Ok(CallToolResult::text(format!(
209 "Failed to add backend '{}': {e}",
210 input.name
211 ))),
212 }
213 }
214 })
215 .build();
216
217 let mut router = McpRouter::new()
218 .server_info("mcp-proxy-admin", "0.1.0")
219 .tool(list_backends)
220 .tool(health_check)
221 .tool(session_count)
222 .tool(add_backend)
223 .tool(config_tool);
224
225 if search_mode {
226 let state_for_call = state.clone();
227 let call_tool = ToolBuilder::new("call_tool")
228 .description(
229 "Invoke any backend tool by its fully-qualified name. Use proxy/search_tools \
230 to discover available tools, then call them through this tool.",
231 )
232 .handler(move |input: CallToolInput| {
233 let s = state_for_call.clone();
234 async move {
235 use tower::Service;
236 use tower_mcp::protocol::{CallToolParams, McpRequest, McpResponse, RequestId};
237 use tower_mcp::router::{Extensions, RouterRequest};
238
239 let req = RouterRequest {
240 id: RequestId::Number(0),
241 inner: McpRequest::CallTool(CallToolParams {
242 name: input.name.clone(),
243 arguments: input.arguments.unwrap_or_default().into(),
244 meta: None,
245 task: None,
246 }),
247 extensions: Extensions::new(),
248 };
249
250 let mut proxy = s.proxy.clone();
251 match proxy.call(req).await {
252 Ok(resp) => match resp.inner {
253 Ok(McpResponse::CallTool(result)) => Ok(result),
254 Ok(_) => Ok(CallToolResult::text(format!(
255 "Unexpected response type for tool '{}'",
256 input.name
257 ))),
258 Err(e) => Ok(CallToolResult::text(format!(
259 "Error calling '{}': {}",
260 input.name, e.message
261 ))),
262 },
263 Err(_) => Ok(CallToolResult::text(format!(
264 "Internal error calling '{}'",
265 input.name
266 ))),
267 }
268 }
269 })
270 .build();
271 router = router.tool(call_tool);
272 }
273
274 if let Some(tools) = discovery_tools {
275 for tool in tools {
276 router = router.tool(tool);
277 }
278 }
279
280 for skill in skills {
282 router = router.prompt(skill);
283 }
284
285 router
286}
287
/// JSON payload of the `health_check` admin tool.
#[derive(Serialize)]
struct HealthCheckResult {
    /// `"healthy"` when every listed backend is healthy (including when there
    /// are none), otherwise `"degraded"`.
    status: String,
    /// Number of entries in `backends` whose `healthy` flag is set.
    healthy_count: usize,
    /// Total number of entries in `backends`.
    total_count: usize,
    /// Per-backend health entries.
    backends: Vec<BackendInfo>,
}
295
296#[derive(Debug, Deserialize, JsonSchema)]
297struct AddBackendInput {
298 name: String,
300 url: String,
302}
303
304#[derive(Debug, Deserialize, JsonSchema)]
306struct CallToolInput {
307 name: String,
309 arguments: Option<serde_json::Map<String, serde_json::Value>>,
311}
312
#[cfg(test)]
mod tests {
    use tower::Service;
    use tower_mcp::client::ChannelTransport;
    use tower_mcp::protocol::{
        CallToolParams, ListToolsParams, McpRequest, McpResponse, RequestId,
    };
    use tower_mcp::proxy::McpProxy;
    use tower_mcp::router::{Extensions, RouterRequest};
    use tower_mcp::{CallToolResult, McpRouter, SessionHandle, ToolBuilder};

    use super::*;

    fn make_session_handle() -> SessionHandle {
        let svc = tower::util::BoxCloneService::new(tower::service_fn(
            |_req: tower_mcp::RouterRequest| async {
                Ok::<_, std::convert::Infallible>(tower_mcp::RouterResponse {
                    id: RequestId::Number(1),
                    inner: Ok(McpResponse::Pong(Default::default())),
                })
            },
        ));
        let (_, handle) =
            tower_mcp::transport::http::HttpTransport::from_service(svc).into_router_with_handle();
        handle
    }

    fn make_admin_state() -> AdminState {
        crate::admin::test_admin_state("test-proxy", "0.1.0", 0, vec![])
    }

    /// A proxy with one backend, `test`, exposing a `ping` tool.
    async fn make_test_proxy() -> McpProxy {
        let router = McpRouter::new().server_info("test", "1.0.0").tool(
            ToolBuilder::new("ping")
                .description("Ping")
                .handler(|_: tower_mcp::NoParams| async move { Ok(CallToolResult::text("pong")) })
                .build(),
        );

        McpProxy::builder("test-proxy", "1.0.0")
            .backend("test", ChannelTransport::new(router))
            .await
            .build_strict()
            .await
            .unwrap()
    }

    /// Admin tool state wrapping `proxy`, with the given config snapshot.
    fn make_state(proxy: &McpProxy, config_snapshot: &str) -> AdminToolState {
        AdminToolState {
            admin_state: make_admin_state(),
            session_handle: make_session_handle(),
            config_snapshot: Arc::new(config_snapshot.to_string()),
            proxy: proxy.clone(),
        }
    }

    /// Mounts `router` as backend `admin` on a fresh proxy so tools are seen
    /// exactly as clients would see them (prefixed `admin_`).
    async fn mount_admin(router: McpRouter) -> McpProxy {
        McpProxy::builder("verify", "1.0.0")
            .backend("admin", ChannelTransport::new(router))
            .await
            .build_strict()
            .await
            .unwrap()
    }

    async fn list_tools(proxy: &mut McpProxy) -> Vec<String> {
        let req = RouterRequest {
            id: RequestId::Number(1),
            inner: McpRequest::ListTools(ListToolsParams {
                cursor: None,
                meta: None,
            }),
            extensions: Extensions::new(),
        };
        let resp = proxy.call(req).await.expect("infallible");
        match resp.inner.unwrap() {
            McpResponse::ListTools(result) => result.tools.into_iter().map(|t| t.name).collect(),
            other => panic!("expected ListTools, got: {other:?}"),
        }
    }

    /// Calls tool `name` with `arguments` and returns all text in the result.
    async fn call_tool_text(
        proxy: &mut McpProxy,
        name: &str,
        arguments: serde_json::Value,
    ) -> String {
        let req = RouterRequest {
            id: RequestId::Number(1),
            inner: McpRequest::CallTool(CallToolParams {
                name: name.to_string(),
                arguments,
                meta: None,
                task: None,
            }),
            extensions: Extensions::new(),
        };
        let resp = proxy.call(req).await.expect("infallible");
        match resp.inner.unwrap() {
            McpResponse::CallTool(result) => result.all_text(),
            other => panic!("expected CallTool, got: {other:?}"),
        }
    }

    /// Parses tool output as JSON, panicking with the raw text on failure.
    fn parse_json(text: &str) -> serde_json::Value {
        serde_json::from_str(text).unwrap_or_else(|e| panic!("expected JSON, got {text:?}: {e}"))
    }

    #[tokio::test]
    async fn test_build_admin_router_has_expected_tools() {
        let proxy = make_test_proxy().await;
        let router = build_admin_router(make_state(&proxy, "# empty config"), None, false, vec![]);
        let mut test_proxy = mount_admin(router).await;

        let tools = list_tools(&mut test_proxy).await;
        assert!(tools.contains(&"admin_list_backends".to_string()));
        assert!(tools.contains(&"admin_health_check".to_string()));
        assert!(tools.contains(&"admin_session_count".to_string()));
        assert!(tools.contains(&"admin_add_backend".to_string()));
        assert!(tools.contains(&"admin_config".to_string()));
        assert!(!tools.contains(&"admin_call_tool".to_string()));
    }

    #[tokio::test]
    async fn test_search_mode_adds_call_tool() {
        let proxy = make_test_proxy().await;
        let router = build_admin_router(make_state(&proxy, ""), None, true, vec![]);
        let mut test_proxy = mount_admin(router).await;

        let tools = list_tools(&mut test_proxy).await;
        assert!(
            tools.contains(&"admin_call_tool".to_string()),
            "search mode should add call_tool, got: {tools:?}"
        );
    }

    #[tokio::test]
    async fn test_call_tool_forwards_to_backend() {
        let proxy = make_test_proxy().await;
        let router = build_admin_router(make_state(&proxy, ""), None, true, vec![]);
        let mut test_proxy = mount_admin(router).await;

        let text = call_tool_text(
            &mut test_proxy,
            "admin_call_tool",
            serde_json::json!({ "name": "test_ping" }),
        )
        .await;
        assert!(
            text.contains("pong"),
            "call_tool should forward to test_ping, got: {text}"
        );
    }

    #[tokio::test]
    async fn test_discovery_tools_included() {
        let proxy = make_test_proxy().await;
        let extra_tool = ToolBuilder::new("search_tools")
            .description("Search for tools")
            .handler(
                |_: tower_mcp::NoParams| async move { Ok(CallToolResult::text("search results")) },
            )
            .build();

        let router =
            build_admin_router(make_state(&proxy, ""), Some(vec![extra_tool]), false, vec![]);
        let mut test_proxy = mount_admin(router).await;

        let tools = list_tools(&mut test_proxy).await;
        assert!(
            tools.contains(&"admin_search_tools".to_string()),
            "discovery tool should be included, got: {tools:?}"
        );
    }

    #[tokio::test]
    async fn test_config_tool_returns_snapshot() {
        let config_text = "[proxy]\nname = \"test\"\n";
        let proxy = make_test_proxy().await;
        let router = build_admin_router(make_state(&proxy, config_text), None, false, vec![]);
        let mut test_proxy = mount_admin(router).await;

        let text = call_tool_text(&mut test_proxy, "admin_config", serde_json::json!({})).await;
        assert!(
            text.contains("[proxy]"),
            "config tool should return the config snapshot, got: {text}"
        );
    }

    #[tokio::test]
    async fn test_health_check_summary_is_consistent() {
        let proxy = make_test_proxy().await;
        let router = build_admin_router(make_state(&proxy, ""), None, false, vec![]);
        let mut test_proxy = mount_admin(router).await;

        let text =
            call_tool_text(&mut test_proxy, "admin_health_check", serde_json::json!({})).await;
        let json = parse_json(&text);
        let healthy = json["healthy_count"].as_u64().expect("healthy_count");
        let total = json["total_count"].as_u64().expect("total_count");
        let backends = json["backends"].as_array().expect("backends array");

        assert_eq!(u64::try_from(backends.len()).unwrap(), total);
        let flagged_healthy = backends.iter().filter(|b| b["healthy"] == true).count();
        assert_eq!(u64::try_from(flagged_healthy).unwrap(), healthy);
        let expected_status = if healthy == total { "healthy" } else { "degraded" };
        assert_eq!(json["status"], expected_status);
    }

    #[tokio::test]
    async fn test_session_count_returns_json() {
        let proxy = make_test_proxy().await;
        let router = build_admin_router(make_state(&proxy, ""), None, false, vec![]);
        let mut test_proxy = mount_admin(router).await;

        let text =
            call_tool_text(&mut test_proxy, "admin_session_count", serde_json::json!({})).await;
        let json = parse_json(&text);
        assert!(
            json["active_sessions"].is_u64(),
            "expected numeric active_sessions, got: {text}"
        );
    }

    #[tokio::test]
    async fn test_list_backends_returns_json() {
        let proxy = make_test_proxy().await;
        let router = build_admin_router(make_state(&proxy, ""), None, false, vec![]);
        let mut test_proxy = mount_admin(router).await;

        let text =
            call_tool_text(&mut test_proxy, "admin_list_backends", serde_json::json!({})).await;
        let json = parse_json(&text);
        assert!(json["proxy_name"].is_string(), "got: {text}");
        assert!(json["proxy_version"].is_string(), "got: {text}");
        assert!(json["backend_count"].is_u64(), "got: {text}");
        assert!(json["backends"].is_array(), "got: {text}");
    }
}