#![allow(clippy::needless_continue)]
use darling::FromDeriveInput;
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, DeriveInput};
/// Field names that `#[derive(CodeMode)]` requires on the annotated struct.
///
/// The generated `register_code_mode_tools` reads each of these through `self`.
/// The order here is the order used in the "Required fields" diagnostic.
const REQUIRED_FIELDS: &[&str] =
    &["code_mode_config", "token_secret", "policy_evaluator", "code_executor"];
/// Options parsed by darling from the input of `#[derive(CodeMode)]`,
/// including any struct-level `#[code_mode(...)]` attribute arguments.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(code_mode))]
struct CodeModeOpts {
    /// Name of the annotated type; used as the target of the generated `impl`.
    ident: syn::Ident,
    /// Body of the annotated type. Enum variants are not inspected (`()`);
    /// struct fields are parsed as `CodeModeField`.
    data: darling::ast::Data<(), CodeModeField>,
    /// `context_from = "method"`: name of a method on the struct that supplies the
    /// validation context. When absent, the generated handler uses a placeholder context.
    #[darling(default)]
    context_from: Option<String>,
    /// `language = "..."`: which validation routine to generate; `"graphql"` when absent.
    #[darling(default)]
    language: Option<String>,
}
/// A single field of the annotated struct; only its name is needed.
#[derive(Debug, Clone, darling::FromField)]
#[darling(attributes(code_mode))]
struct CodeModeField {
    /// Field name; absent for unnamed (tuple) fields.
    ident: Option<syn::Ident>,
}
#[proc_macro_derive(CodeMode, attributes(code_mode))]
pub fn code_mode_derive(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
expand_code_mode(&input)
.unwrap_or_else(|err| err.to_compile_error())
.into()
}
fn expand_code_mode(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
let opts = CodeModeOpts::from_derive_input(input)
.map_err(|e| syn::Error::new_spanned(input, e.to_string()))?;
let struct_name = &opts.ident;
let fields = match &opts.data {
darling::ast::Data::Struct(ref fields) => &fields.fields,
darling::ast::Data::Enum(_) => {
return Err(syn::Error::new_spanned(
input,
"#[derive(CodeMode)] can only be applied to structs with named fields",
));
},
};
let field_names: Vec<String> = fields
.iter()
.filter_map(|f| f.ident.as_ref().map(|i| i.to_string()))
.collect();
let missing: Vec<&str> = REQUIRED_FIELDS
.iter()
.filter(|&&name| !field_names.contains(&name.to_string()))
.copied()
.collect();
if !missing.is_empty() {
let all_required = REQUIRED_FIELDS.join(", ");
let missing_msgs: Vec<String> = missing
.iter()
.map(|&name| {
let type_hint = match name {
"code_mode_config" => "CodeModeConfig",
"token_secret" => "TokenSecret",
"policy_evaluator" => "Arc<dyn PolicyEvaluator>",
"code_executor" => "Arc<dyn CodeExecutor>",
_ => "unknown",
};
format!(
"#[derive(CodeMode)] requires field `{name}` (type: {type_hint}).\n\
Required fields: {all_required}"
)
})
.collect();
let msg = missing_msgs.join("\n\n");
return Err(syn::Error::new_spanned(&input.ident, msg));
}
let mod_name = syn::Ident::new(
&format!(
"__code_mode_impl_{}",
struct_name.to_string().to_lowercase()
),
struct_name.span(),
);
let language = opts.language.as_deref().unwrap_or("graphql");
let language_lit = syn::LitStr::new(language, struct_name.span());
let validation_call = gen_validation_call(language, &input.ident)?;
let expanded = if let Some(ref method_name) = opts.context_from {
if method_name.is_empty() || syn::parse_str::<syn::Ident>(method_name).is_err() {
return Err(syn::Error::new_spanned(
&input.ident,
format!("`context_from = \"{method_name}\"` is not a valid Rust identifier"),
));
}
let method_ident = syn::Ident::new(method_name, struct_name.span());
expand_with_context_from(
struct_name,
&mod_name,
&language_lit,
&method_ident,
&validation_call,
)
} else {
expand_without_context_from(struct_name, &mod_name, &language_lit, &validation_call)
};
Ok(expanded)
}
fn gen_validation_call(
language: &str,
error_span: &syn::Ident,
) -> Result<proc_macro2::TokenStream, syn::Error> {
let map_err = quote! {
.map_err(|e| pmcp::Error::Internal(format!("Validation error: {}", e)))?
};
match language {
"graphql" => Ok(quote! {
self.pipeline.validate_graphql_query_async(code, &context).await #map_err
}),
"javascript" | "js" => Ok(quote! {
self.pipeline.validate_javascript_code(code, &context) #map_err
}),
"sql" => Ok(quote! {
self.pipeline.validate_sql_query(code, &context) #map_err
}),
"mcp" => Ok(quote! {
self.pipeline.validate_mcp_composition(code, &context).await #map_err
}),
other => Err(syn::Error::new_spanned(
error_span,
format!(
"`language = \"{other}\"` is not a supported language. \
Supported values: \"graphql\" (default), \"javascript\" (requires `openapi-code-mode`), \
\"sql\" (requires `sql-code-mode`), \"mcp\" (requires `mcp-code-mode`)"
),
)),
}
}
/// Generates the `CodeMode` expansion for structs that set
/// `#[code_mode(context_from = "...")]`.
///
/// The emitted code contains:
/// - a compile-time assertion that `#struct_name` is `Send + Sync`;
/// - a hidden module `#mod_name` containing:
///   - `ValidateCodeHandler`, which gets its validation context by calling
///     `self.parent.#method_ident(&extra)`;
///   - `ExecuteCodeHandler`, which decodes and verifies the approval token before
///     handing the code to the executor;
/// - an inherent `register_code_mode_tools(self: &Arc<Self>, builder)`, which registers
///   both handlers as the `validate_code` and `execute_code` tools.
///
/// `ExecuteCodeHandler`'s executor parameter is `?Sized`. This lets a
/// `code_executor: Arc<dyn CodeExecutor>` field (the type suggested by the missing-field
/// diagnostic) work, as well as concrete executor types.
fn expand_with_context_from(
    struct_name: &syn::Ident,
    mod_name: &syn::Ident,
    language_lit: &syn::LitStr,
    method_ident: &syn::Ident,
    validation_call: &proc_macro2::TokenStream,
) -> proc_macro2::TokenStream {
    quote! {
        // Compile-time check that the annotated struct is `Send + Sync`.
        const _: fn() = || {
            fn assert_send_sync<T: Send + Sync>() {}
            assert_send_sync::<#struct_name>();
        };

        #[doc(hidden)]
        #[allow(non_snake_case)]
        mod #mod_name {
            use super::*;
            use std::sync::Arc;
            use pmcp_code_mode::TokenGenerator as _;

            pub(super) struct ValidateCodeHandler {
                pub(super) pipeline: Arc<pmcp_code_mode::ValidationPipeline>,
                pub(super) config: pmcp_code_mode::CodeModeConfig,
                pub(super) parent: Arc<#struct_name>,
            }

            #[pmcp_code_mode::async_trait]
            impl pmcp::ToolHandler for ValidateCodeHandler {
                async fn handle(
                    &self,
                    args: serde_json::Value,
                    extra: pmcp::RequestHandlerExtra,
                ) -> pmcp::Result<serde_json::Value> {
                    let input: pmcp_code_mode::ValidateCodeInput =
                        serde_json::from_value(args).map_err(|e| {
                            pmcp::Error::Internal(format!("Invalid arguments: {}", e))
                        })?;
                    let code = input.code.trim();
                    let dry_run = input.dry_run.unwrap_or(false);
                    // Context comes from the user-supplied `context_from` method.
                    let context = self.parent.#method_ident(&extra);
                    let result = #validation_call;
                    let response = pmcp_code_mode::ValidationResponse::success(
                        result.explanation.clone(),
                        result.risk_level,
                        // A dry run returns an empty approval token.
                        if dry_run {
                            String::new()
                        } else {
                            result.approval_token.clone().unwrap_or_default()
                        },
                        result.metadata.clone(),
                    )
                    .with_warnings(result.warnings.clone())
                    .with_auto_approved(self.config.should_auto_approve(result.risk_level));
                    let (json, _is_error) = response.to_json_response();
                    Ok(json)
                }

                fn metadata(&self) -> Option<pmcp::types::ToolInfo> {
                    Some(pmcp_code_mode::CodeModeToolBuilder::new(#language_lit).build_validate_tool())
                }
            }

            // `?Sized` allows `E = dyn CodeExecutor`, i.e. an `Arc<dyn CodeExecutor>` field.
            pub(super) struct ExecuteCodeHandler<E: pmcp_code_mode::CodeExecutor + ?Sized + 'static> {
                pub(super) pipeline: Arc<pmcp_code_mode::ValidationPipeline>,
                pub(super) executor: Arc<E>,
            }

            #[pmcp_code_mode::async_trait]
            impl<E: pmcp_code_mode::CodeExecutor + ?Sized + 'static> pmcp::ToolHandler
                for ExecuteCodeHandler<E>
            {
                async fn handle(
                    &self,
                    args: serde_json::Value,
                    _extra: pmcp::RequestHandlerExtra,
                ) -> pmcp::Result<serde_json::Value> {
                    let input: pmcp_code_mode::ExecuteCodeInput =
                        serde_json::from_value(args).map_err(|e| {
                            pmcp::Error::Internal(format!("Invalid arguments: {}", e))
                        })?;
                    let code = input.code.trim();
                    // Refuse to execute unless the token decodes, passes `verify`,
                    // and passes `verify_code` for this exact (trimmed) code.
                    let token_gen = self.pipeline.token_generator();
                    let token = pmcp_code_mode::ApprovalToken::decode(&input.approval_token)
                        .map_err(|e| pmcp::Error::Internal(
                            format!("Invalid approval token: {}", e),
                        ))?;
                    token_gen.verify(&token)
                        .map_err(|e| pmcp::Error::Internal(
                            format!("Token verification failed: {}", e),
                        ))?;
                    token_gen.verify_code(code, &token)
                        .map_err(|e| pmcp::Error::Internal(
                            format!("Code verification failed: {}", e),
                        ))?;
                    let result = self.executor.execute(code, input.variables.as_ref()).await
                        .map_err(|e| pmcp::Error::Internal(
                            format!("Execution error: {}", e),
                        ))?;
                    Ok(result)
                }

                fn metadata(&self) -> Option<pmcp::types::ToolInfo> {
                    Some(pmcp_code_mode::CodeModeToolBuilder::new(#language_lit).build_execute_tool())
                }
            }
        }

        impl #struct_name {
            pub fn register_code_mode_tools(
                self: &std::sync::Arc<Self>,
                builder: pmcp::ServerBuilder,
            ) -> Result<pmcp::ServerBuilder, pmcp_code_mode::TokenError> {
                let pipeline = std::sync::Arc::new(
                    pmcp_code_mode::ValidationPipeline::from_token_secret_with_policy(
                        self.code_mode_config.clone(),
                        &self.token_secret,
                        std::sync::Arc::clone(&self.policy_evaluator) as std::sync::Arc<dyn pmcp_code_mode::PolicyEvaluator>,
                    )?
                );
                let validate_handler = #mod_name::ValidateCodeHandler {
                    pipeline: std::sync::Arc::clone(&pipeline),
                    config: self.code_mode_config.clone(),
                    parent: std::sync::Arc::clone(self),
                };
                let execute_handler = #mod_name::ExecuteCodeHandler {
                    pipeline,
                    executor: std::sync::Arc::clone(&self.code_executor),
                };
                Ok(builder
                    .tool("validate_code", validate_handler)
                    .tool("execute_code", execute_handler))
            }
        }
    }
}
/// Generates the `CodeMode` expansion for structs without `context_from`.
///
/// The output has the same shape as `expand_with_context_from`, with two differences:
/// - `ValidateCodeHandler` uses a fixed placeholder `ValidationContext`
///   (`"anonymous"`, `"session"`, `"schema"`, `"perms"`) instead of per-request data;
/// - `register_code_mode_tools` takes `&self` and is marked `#[deprecated]`, steering
///   callers toward `context_from`.
///
/// `ExecuteCodeHandler`'s executor parameter is `?Sized`. This lets a
/// `code_executor: Arc<dyn CodeExecutor>` field (the type suggested by the missing-field
/// diagnostic) work, as well as concrete executor types.
fn expand_without_context_from(
    struct_name: &syn::Ident,
    mod_name: &syn::Ident,
    language_lit: &syn::LitStr,
    validation_call: &proc_macro2::TokenStream,
) -> proc_macro2::TokenStream {
    quote! {
        // Compile-time check that the annotated struct is `Send + Sync`.
        const _: fn() = || {
            fn assert_send_sync<T: Send + Sync>() {}
            assert_send_sync::<#struct_name>();
        };

        #[doc(hidden)]
        #[allow(non_snake_case)]
        mod #mod_name {
            use super::*;
            use std::sync::Arc;
            use pmcp_code_mode::TokenGenerator as _;

            pub(super) struct ValidateCodeHandler {
                pub(super) pipeline: Arc<pmcp_code_mode::ValidationPipeline>,
                pub(super) config: pmcp_code_mode::CodeModeConfig,
            }

            #[pmcp_code_mode::async_trait]
            impl pmcp::ToolHandler for ValidateCodeHandler {
                async fn handle(
                    &self,
                    args: serde_json::Value,
                    _extra: pmcp::RequestHandlerExtra,
                ) -> pmcp::Result<serde_json::Value> {
                    let input: pmcp_code_mode::ValidateCodeInput =
                        serde_json::from_value(args).map_err(|e| {
                            pmcp::Error::Internal(format!("Invalid arguments: {}", e))
                        })?;
                    let code = input.code.trim();
                    let dry_run = input.dry_run.unwrap_or(false);
                    // Placeholder context with fixed strings (no per-request identity);
                    // this is why registration is deprecated in this variant.
                    let context = pmcp_code_mode::ValidationContext::new(
                        "anonymous",
                        "session",
                        "schema",
                        "perms",
                    );
                    let result = #validation_call;
                    let response = pmcp_code_mode::ValidationResponse::success(
                        result.explanation.clone(),
                        result.risk_level,
                        // A dry run returns an empty approval token.
                        if dry_run {
                            String::new()
                        } else {
                            result.approval_token.clone().unwrap_or_default()
                        },
                        result.metadata.clone(),
                    )
                    .with_warnings(result.warnings.clone())
                    .with_auto_approved(self.config.should_auto_approve(result.risk_level));
                    let (json, _is_error) = response.to_json_response();
                    Ok(json)
                }

                fn metadata(&self) -> Option<pmcp::types::ToolInfo> {
                    Some(pmcp_code_mode::CodeModeToolBuilder::new(#language_lit).build_validate_tool())
                }
            }

            // `?Sized` allows `E = dyn CodeExecutor`, i.e. an `Arc<dyn CodeExecutor>` field.
            pub(super) struct ExecuteCodeHandler<E: pmcp_code_mode::CodeExecutor + ?Sized + 'static> {
                pub(super) pipeline: Arc<pmcp_code_mode::ValidationPipeline>,
                pub(super) executor: Arc<E>,
            }

            #[pmcp_code_mode::async_trait]
            impl<E: pmcp_code_mode::CodeExecutor + ?Sized + 'static> pmcp::ToolHandler
                for ExecuteCodeHandler<E>
            {
                async fn handle(
                    &self,
                    args: serde_json::Value,
                    _extra: pmcp::RequestHandlerExtra,
                ) -> pmcp::Result<serde_json::Value> {
                    let input: pmcp_code_mode::ExecuteCodeInput =
                        serde_json::from_value(args).map_err(|e| {
                            pmcp::Error::Internal(format!("Invalid arguments: {}", e))
                        })?;
                    let code = input.code.trim();
                    // Refuse to execute unless the token decodes, passes `verify`,
                    // and passes `verify_code` for this exact (trimmed) code.
                    let token_gen = self.pipeline.token_generator();
                    let token = pmcp_code_mode::ApprovalToken::decode(&input.approval_token)
                        .map_err(|e| pmcp::Error::Internal(
                            format!("Invalid approval token: {}", e),
                        ))?;
                    token_gen.verify(&token)
                        .map_err(|e| pmcp::Error::Internal(
                            format!("Token verification failed: {}", e),
                        ))?;
                    token_gen.verify_code(code, &token)
                        .map_err(|e| pmcp::Error::Internal(
                            format!("Code verification failed: {}", e),
                        ))?;
                    let result = self.executor.execute(code, input.variables.as_ref()).await
                        .map_err(|e| pmcp::Error::Internal(
                            format!("Execution error: {}", e),
                        ))?;
                    Ok(result)
                }

                fn metadata(&self) -> Option<pmcp::types::ToolInfo> {
                    Some(pmcp_code_mode::CodeModeToolBuilder::new(#language_lit).build_execute_tool())
                }
            }
        }

        impl #struct_name {
            #[deprecated(note = "Use #[code_mode(context_from = \"method_name\")] for production. This uses placeholder ValidationContext.")]
            pub fn register_code_mode_tools(
                &self,
                builder: pmcp::ServerBuilder,
            ) -> Result<pmcp::ServerBuilder, pmcp_code_mode::TokenError> {
                let pipeline = std::sync::Arc::new(
                    pmcp_code_mode::ValidationPipeline::from_token_secret_with_policy(
                        self.code_mode_config.clone(),
                        &self.token_secret,
                        std::sync::Arc::clone(&self.policy_evaluator) as std::sync::Arc<dyn pmcp_code_mode::PolicyEvaluator>,
                    )?
                );
                let validate_handler = #mod_name::ValidateCodeHandler {
                    pipeline: std::sync::Arc::clone(&pipeline),
                    config: self.code_mode_config.clone(),
                };
                let execute_handler = #mod_name::ExecuteCodeHandler {
                    pipeline,
                    executor: std::sync::Arc::clone(&self.code_executor),
                };
                Ok(builder
                    .tool("validate_code", validate_handler)
                    .tool("execute_code", execute_handler))
            }
        }
    }
}