use std::sync::Arc;
use schemars::JsonSchema;
use serde::Deserialize;
use serde_json::Value;
use tower_mcp::extract::{Json, State};
use tower_mcp::{CallToolResult, Error as McpError, McpRouter, ResultExt, Tool, ToolBuilder};
use crate::state::AppState;
/// HTTP method accepted by the `enterprise_raw_api` tool.
///
/// Deserialized from the upper-case verb (`GET`, `POST`, `PUT`, `PATCH`,
/// `DELETE`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, JsonSchema)]
#[serde(rename_all = "UPPERCASE")]
pub enum HttpMethod {
    Get,
    Post,
    Put,
    Patch,
    Delete,
}
/// Arguments accepted by the `enterprise_raw_api` tool.
#[derive(Debug, Deserialize, JsonSchema)]
pub struct EnterpriseRawApiInput {
    /// HTTP method to use: GET, POST, PUT, PATCH or DELETE.
    pub method: HttpMethod,
    /// API path of the request. Must start with '/'.
    pub path: String,
    /// JSON request body. Sent with POST, PUT and PATCH (an empty object is
    /// sent when omitted); not sent with GET or DELETE.
    #[serde(default)]
    pub body: Option<Value>,
    /// Optional name of the profile used to obtain the Enterprise client.
    #[serde(default)]
    pub profile: Option<String>,
    /// When true, return a preview of the request instead of sending it.
    /// Defaults to false.
    #[serde(default)]
    pub dry_run: bool,
}
pub fn enterprise_raw_api(state: Arc<AppState>) -> Tool {
ToolBuilder::new("enterprise_raw_api")
.description(
"DANGEROUS: Execute a raw REST API request against the Enterprise API. \
Escape hatch for endpoints not covered by dedicated tools.",
)
.destructive()
.extractor_handler(
state,
|State(state): State<Arc<AppState>>,
Json(input): Json<EnterpriseRawApiInput>| async move {
match input.method {
HttpMethod::Get => {
if !state.is_write_allowed() {
return Err(McpError::tool(
"enterprise_raw_api GET requires at least read-write tier",
));
}
}
HttpMethod::Post
| HttpMethod::Put
| HttpMethod::Patch
| HttpMethod::Delete => {
if !state.is_destructive_allowed() {
return Err(McpError::tool(
"enterprise_raw_api mutating methods require full tier",
));
}
}
}
if !input.path.starts_with('/') {
return Err(McpError::tool("path must start with '/'"));
}
if input.dry_run {
let preview = serde_json::json!({
"dry_run": true,
"method": format!("{:?}", input.method).to_uppercase(),
"path": input.path,
"body": input.body,
"profile": input.profile,
});
return CallToolResult::from_serialize(&preview);
}
let client = state
.enterprise_client_for_profile(input.profile.as_deref())
.await
.map_err(|e| crate::tools::credential_error("enterprise", e))?;
let result: Value = match input.method {
HttpMethod::Get => client
.get_raw(&input.path)
.await
.tool_context("enterprise_raw_api GET failed")?,
HttpMethod::Post => client
.post_raw(
&input.path,
input.body.unwrap_or(Value::Object(Default::default())),
)
.await
.tool_context("enterprise_raw_api POST failed")?,
HttpMethod::Put => client
.put_raw(
&input.path,
input.body.unwrap_or(Value::Object(Default::default())),
)
.await
.tool_context("enterprise_raw_api PUT failed")?,
HttpMethod::Patch => client
.patch_raw(
&input.path,
input.body.unwrap_or(Value::Object(Default::default())),
)
.await
.tool_context("enterprise_raw_api PATCH failed")?,
HttpMethod::Delete => client
.delete_raw(&input.path)
.await
.tool_context("enterprise_raw_api DELETE failed")?,
};
CallToolResult::from_serialize(&result)
},
)
.build()
}
/// Tool names defined by this module; currently only `enterprise_raw_api`,
/// the name of the tool registered in `router`.
// NOTE(review): presumably consumed by the parent module (hence `pub(super)`);
// its exact use is not visible from this file.
pub(super) const TOOL_NAMES: &[&str] = &["enterprise_raw_api"];
/// Creates an `McpRouter` with the `enterprise_raw_api` tool registered on it.
///
/// `state` is handed to the tool's handler.
pub fn router(state: Arc<AppState>) -> McpRouter {
    let raw_api_tool = enterprise_raw_api(state);
    McpRouter::new().tool(raw_api_tool)
}