#![cfg(feature = "code-mode")]
pub use pmcp_code_mode::{
canonicalize_code, compute_context_hash, hash_code, ApprovalToken, AuthorizationDecision,
CodeExecutor, CodeModeConfig, ExecutionError, HmacTokenGenerator, NoopPolicyEvaluator,
PolicyEvaluator, TokenGenerator, TokenSecret, ValidationContext, ValidationPipeline,
};
#[cfg(feature = "avp")]
pub use pmcp_code_mode::{AvpClient, AvpConfig, AvpPolicyEvaluator};
#[cfg(feature = "openapi-code-mode")]
pub use pmcp_code_mode::{ExecutionConfig, HttpExecutor, JsCodeExecutor};
use std::sync::Arc;
use crate::config::{CodeModeSection, ServerConfig};
use crate::error::{ConfigValidationError, Result, ToolkitError};
use crate::secrets::SecretValue;
use crate::sql::{Dialect, SqlConnector};
/// Selects which validation path and which advertised code format a Code Mode
/// tool pair uses.
#[cfg(feature = "code-mode")]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ValidationFlavor {
/// Code is validated with `validate_sql_query`; the code format is `"sql"`.
Sql,
/// Code is validated with `validate_javascript_code` (only when the
/// `openapi-code-mode` feature is enabled); the code format is `"openapi"`.
OpenApi,
}
#[cfg(feature = "code-mode")]
impl ValidationFlavor {
    /// Returns the code-format identifier passed to `CodeModeToolBuilder::new`
    /// when building tool metadata for this flavor.
    fn code_format(self) -> &'static str {
        let label = match self {
            ValidationFlavor::OpenApi => "openapi",
            ValidationFlavor::Sql => "sql",
        };
        label
    }
}
/// Clones `base` and attaches the token found in the request's auth context
/// (or `None` when the request has no auth context or no token).
#[cfg(feature = "openapi-code-mode")]
#[must_use]
pub fn request_executor_from_extra(
    base: &HttpCodeExecutor,
    extra: &pmcp::RequestHandlerExtra,
) -> HttpCodeExecutor {
    let inbound = match extra.auth_context() {
        Some(ctx) => ctx.token.clone(),
        None => None,
    };
    base.clone().with_inbound_token(inbound)
}
/// Builds a [`ValidationPipeline`] from the `[code_mode]` block of `config`.
///
/// # Errors
///
/// Returns `ToolkitError::CodeMode` when the config has no `[code_mode]`
/// block or when pipeline construction fails, and propagates any error from
/// resolving the token secret.
pub fn validation_pipeline_from_config(config: &ServerConfig) -> Result<ValidationPipeline> {
    let Some(section) = config.code_mode.as_ref() else {
        return Err(ToolkitError::CodeMode(
            "ServerConfig has no [code_mode] block".to_string(),
        ));
    };

    // The config is built before the secret is resolved, as in every other
    // construction path in this file.
    let cm_config = build_cm_config(section);
    let token_secret: TokenSecret = resolve_token_secret(section)?.into();

    match ValidationPipeline::from_token_secret(cm_config, &token_secret) {
        Ok(pipeline) => Ok(pipeline),
        Err(e) => Err(ToolkitError::CodeMode(format!(
            "ValidationPipeline construction failed: {e}"
        ))),
    }
}
/// Registers the `validate_code` and `execute_code` tools on `builder`, with
/// execution delegated to the shared `executor`.
///
/// When `config` has no `[code_mode]` block the builder is returned untouched.
///
/// # Errors
///
/// Propagates token-secret resolution errors and returns
/// `ToolkitError::CodeMode` when the validation pipeline cannot be built.
pub fn code_mode_tools_from_executor(
    builder: pmcp::ServerBuilder,
    config: &ServerConfig,
    executor: Arc<dyn CodeExecutor>,
    flavor: ValidationFlavor,
) -> Result<pmcp::ServerBuilder> {
    let section = match config.code_mode.as_ref() {
        Some(section) => section,
        None => return Ok(builder),
    };

    let cm_config = build_cm_config(section);
    let token_secret: TokenSecret = resolve_token_secret(section)?.into();
    let policy: Arc<dyn PolicyEvaluator> = Arc::new(NoopPolicyEvaluator::new());

    // One pipeline instance is shared by both handlers.
    let shared_pipeline = ValidationPipeline::from_token_secret_with_policy(
        cm_config.clone(),
        &token_secret,
        policy,
    )
    .map(Arc::new)
    .map_err(|e| ToolkitError::CodeMode(format!("ValidationPipeline construction failed: {e}")))?;

    let validate_handler = tool_handlers::ValidateCodeHandler {
        pipeline: Arc::clone(&shared_pipeline),
        config: cm_config,
        flavor,
    };
    let execute_handler = tool_handlers::ExecuteCodeHandler {
        pipeline: shared_pipeline,
        source: tool_handlers::ExecSource::Static(executor),
        flavor,
    };

    let builder = builder.tool_arc("validate_code", Arc::new(validate_handler));
    Ok(builder.tool_arc("execute_code", Arc::new(execute_handler)))
}
/// Registers the `validate_code` and `execute_code` tools on `builder`, with
/// execution going through a per-request HTTP executor derived from `base`
/// and `exec_config`.
///
/// When `config` has no `[code_mode]` block the builder is returned untouched.
///
/// # Errors
///
/// Propagates token-secret resolution errors and returns
/// `ToolkitError::CodeMode` when the validation pipeline cannot be built.
#[cfg(feature = "openapi-code-mode")]
pub fn code_mode_http_tools_from_executor(
    builder: pmcp::ServerBuilder,
    config: &ServerConfig,
    base: HttpCodeExecutor,
    exec_config: ExecutionConfig,
    flavor: ValidationFlavor,
) -> Result<pmcp::ServerBuilder> {
    let section = match config.code_mode.as_ref() {
        Some(section) => section,
        None => return Ok(builder),
    };

    let cm_config = build_cm_config(section);
    let token_secret: TokenSecret = resolve_token_secret(section)?.into();
    let policy: Arc<dyn PolicyEvaluator> = Arc::new(NoopPolicyEvaluator::new());

    // One pipeline instance is shared by both handlers.
    let shared_pipeline = ValidationPipeline::from_token_secret_with_policy(
        cm_config.clone(),
        &token_secret,
        policy,
    )
    .map(Arc::new)
    .map_err(|e| ToolkitError::CodeMode(format!("ValidationPipeline construction failed: {e}")))?;

    let validate_handler = tool_handlers::ValidateCodeHandler {
        pipeline: Arc::clone(&shared_pipeline),
        config: cm_config,
        flavor,
    };
    let execute_handler = tool_handlers::ExecuteCodeHandler {
        pipeline: shared_pipeline,
        source: tool_handlers::ExecSource::PerRequestHttp { base, exec_config },
        flavor,
    };

    let builder = builder.tool_arc("validate_code", Arc::new(validate_handler));
    Ok(builder.tool_arc("execute_code", Arc::new(execute_handler)))
}
/// Checks that a validation pipeline can be built from `config` and returns
/// `builder` unchanged.
///
/// No tools are registered here: when a `[code_mode]` block is present the
/// pipeline is constructed only so that configuration errors surface, then
/// discarded.
///
/// # Errors
///
/// Propagates any error from [`validation_pipeline_from_config`].
pub fn register_code_mode_tools(
    builder: pmcp::ServerBuilder,
    config: &ServerConfig,
) -> Result<pmcp::ServerBuilder> {
    if config.code_mode.is_some() {
        let _ = validation_pipeline_from_config(config)?;
    }
    Ok(builder)
}
mod tool_handlers {
use std::sync::Arc;
use super::ValidationFlavor;
use pmcp_code_mode::TokenGenerator as _;
/// Validates `code` with the pipeline method that matches `flavor`.
///
/// Pipeline errors are rendered as `"Validation error: …"` strings. The
/// `OpenApi` flavor always fails with an explanatory message when the
/// `openapi-code-mode` feature is disabled.
fn run_flavored_validation(
    pipeline: &pmcp_code_mode::ValidationPipeline,
    flavor: ValidationFlavor,
    code: &str,
    context: &pmcp_code_mode::ValidationContext,
) -> std::result::Result<pmcp_code_mode::ValidationResult, String> {
    /// Formats any displayable pipeline error the same way for every flavor.
    fn describe<E: std::fmt::Display>(e: E) -> String {
        format!("Validation error: {e}")
    }

    match flavor {
        #[cfg(feature = "openapi-code-mode")]
        ValidationFlavor::OpenApi => pipeline
            .validate_javascript_code(code, context)
            .map_err(describe),
        #[cfg(not(feature = "openapi-code-mode"))]
        ValidationFlavor::OpenApi => Err(
            "OpenAPI Code Mode validation requires the `openapi-code-mode` feature".to_string(),
        ),
        ValidationFlavor::Sql => pipeline
            .validate_sql_query(code, context)
            .map_err(describe),
    }
}
/// State for the tool handler that validates submitted code and reports the
/// validation response.
pub(super) struct ValidateCodeHandler {
/// Shared pipeline used to validate the submitted code.
pub(super) pipeline: Arc<pmcp_code_mode::ValidationPipeline>,
/// Consulted via `should_auto_approve` once validation succeeds.
pub(super) config: pmcp_code_mode::CodeModeConfig,
/// Selects the validation path and the code format used for tool metadata.
pub(super) flavor: ValidationFlavor,
}
#[pmcp_code_mode::async_trait]
impl pmcp::ToolHandler for ValidateCodeHandler {
async fn handle(
&self,
args: serde_json::Value,
_extra: pmcp::RequestHandlerExtra,
) -> pmcp::Result<serde_json::Value> {
let input: pmcp_code_mode::ValidateCodeInput = serde_json::from_value(args)
.map_err(|e| pmcp::Error::Internal(format!("Invalid arguments: {e}")))?;
let code = input.code.trim();
let dry_run = input.dry_run.unwrap_or(false);
let context = pmcp_code_mode::ValidationContext::new(
"code-mode-config",
"code-mode-session",
"schema-hash",
"perms-hash",
);
let result = run_flavored_validation(&self.pipeline, self.flavor, code, &context)
.map_err(pmcp::Error::Internal)?;
let mut response = pmcp_code_mode::ValidationResponse::from_result(result);
if response.result.is_valid {
if dry_run {
response.result.approval_token = None;
}
let risk = response.result.risk_level;
response = response.with_auto_approved(self.config.should_auto_approve(risk));
}
let (json, is_error) = response.to_json_response();
if is_error {
let message = response
.result
.violations
.first()
.map(ToString::to_string)
.unwrap_or_else(|| {
"Code Mode rejected the query (policy validation failed)".to_string()
});
return Err(pmcp::Error::tool_rejected(message, Some(json)));
}
Ok(json)
}
fn metadata(&self) -> Option<pmcp::types::ToolInfo> {
Some(
pmcp_code_mode::CodeModeToolBuilder::new(self.flavor.code_format())
.build_validate_tool(),
)
}
}
/// Where `ExecuteCodeHandler` obtains the executor that runs approved code.
pub(super) enum ExecSource {
/// One shared executor, reused for every request.
Static(Arc<dyn pmcp_code_mode::CodeExecutor>),
/// A `JsCodeExecutor` is built for each request from a copy of `base`
/// (derived from the request's extra data) and a clone of `exec_config`.
#[cfg(feature = "openapi-code-mode")]
PerRequestHttp {
base: super::HttpCodeExecutor,
exec_config: super::ExecutionConfig,
},
}
/// State for the tool handler that verifies an approval token and then runs
/// the approved code.
pub(super) struct ExecuteCodeHandler {
/// Pipeline whose token generator verifies approval tokens.
pub(super) pipeline: Arc<pmcp_code_mode::ValidationPipeline>,
/// Determines which executor runs the code.
pub(super) source: ExecSource,
/// Selects the code format used for tool metadata.
pub(super) flavor: ValidationFlavor,
}
impl ExecuteCodeHandler {
    /// Runs `code` on the executor selected by `self.source`.
    ///
    /// `extra` is only read by the per-request HTTP source, hence the
    /// conditional `allow(unused_variables)` when that feature is off.
    async fn run_code(
        &self,
        code: &str,
        variables: Option<&serde_json::Value>,
        #[cfg_attr(not(feature = "openapi-code-mode"), allow(unused_variables))]
        extra: &pmcp::RequestHandlerExtra,
    ) -> std::result::Result<serde_json::Value, pmcp_code_mode::ExecutionError> {
        use pmcp_code_mode::CodeExecutor as _;

        match &self.source {
            #[cfg(feature = "openapi-code-mode")]
            ExecSource::PerRequestHttp { base, exec_config } => {
                // Build a request-scoped executor from the base one.
                let per_request = super::request_executor_from_extra(base, extra);
                let js_executor = super::JsCodeExecutor::new(per_request, exec_config.clone());
                js_executor.execute(code, variables).await
            },
            ExecSource::Static(shared) => shared.execute(code, variables).await,
        }
    }
}
#[pmcp_code_mode::async_trait]
impl pmcp::ToolHandler for ExecuteCodeHandler {
/// Verifies the supplied approval token against the submitted code and, only
/// if every check passes, runs the code and returns the executor's result.
///
/// Checks run in this order: decode the token, verify the token, verify that
/// the trimmed code matches the token. Each failure is reported as a
/// `tool_rejected` error; argument deserialization and execution failures are
/// reported as `pmcp::Error::Internal`.
async fn handle(
&self,
args: serde_json::Value,
extra: pmcp::RequestHandlerExtra,
) -> pmcp::Result<serde_json::Value> {
let input: pmcp_code_mode::ExecuteCodeInput = serde_json::from_value(args)
.map_err(|e| pmcp::Error::Internal(format!("Invalid arguments: {e}")))?;
// The code is trimmed before it is checked against the token.
let code = input.code.trim();
let token_gen = self.pipeline.token_generator();
// Step 1: the token string must decode into an ApprovalToken.
let token =
pmcp_code_mode::ApprovalToken::decode(&input.approval_token).map_err(|e| {
pmcp::Error::tool_rejected(
format!(
"Invalid approval_token: {e}. Call validate_code to obtain a valid token."
),
None,
)
})?;
// Step 2: the token itself must pass the generator's verification.
token_gen.verify(&token).map_err(|e| {
pmcp::Error::tool_rejected(
format!(
"Approval token is invalid or expired: {e}. \
Call validate_code again to obtain a fresh token."
),
None,
)
})?;
// Step 3: the token must have been issued for exactly this code.
token_gen.verify_code(code, &token).map_err(|e| {
pmcp::Error::tool_rejected(
format!(
"Code does not match the validated code: {e}. execute_code must use the \
exact code string that was passed to validate_code."
),
None,
)
})?;
// Only reached once all three checks have passed.
let result = self
.run_code(code, input.variables.as_ref(), &extra)
.await
.map_err(|e| pmcp::Error::Internal(format!("Execution error: {e}")))?;
Ok(result)
}
/// Describes the execute tool for this handler's code format.
fn metadata(&self) -> Option<pmcp::types::ToolInfo> {
Some(
pmcp_code_mode::CodeModeToolBuilder::new(self.flavor.code_format())
.build_execute_tool(),
)
}
}
}
/// A `CodeExecutor` for SQL that re-validates each statement against a
/// validation pipeline before handing it to a `SqlConnector`.
pub struct SqlCodeExecutor {
// Backend that actually runs the SQL.
connector: Arc<dyn SqlConnector>,
// Pipeline built once in `new` and reused for every re-validation.
pipeline: Arc<ValidationPipeline>,
}
impl SqlCodeExecutor {
    /// Creates an executor for `connector`, building the validation pipeline
    /// from `config` once, up front.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`validation_pipeline_from_config`].
    pub fn new(connector: Arc<dyn SqlConnector>, config: ServerConfig) -> Result<Self> {
        validation_pipeline_from_config(&config).map(|pipeline| Self {
            connector,
            pipeline: Arc::new(pipeline),
        })
    }

    /// Runs `code` through SQL validation again; fails with a `BackendError`
    /// when the pipeline errors or marks the statement invalid.
    fn revalidate(&self, code: &str) -> std::result::Result<(), ExecutionError> {
        // The context identifiers are fixed placeholder strings.
        let ctx = ValidationContext::new(
            "code-mode-executor",
            "code-mode-session",
            "schema-hash",
            "perms-hash",
        );
        match self.pipeline.validate_sql_query(code, &ctx) {
            Err(e) => Err(ExecutionError::BackendError(format!(
                "SQL validation failed: {e}"
            ))),
            Ok(verdict) if verdict.is_valid => Ok(()),
            Ok(_) => Err(ExecutionError::BackendError(
                "SQL rejected by [code_mode] policy on re-validation".to_string(),
            )),
        }
    }
}
/// Converts a JSON object of variables into `(name, value)` parameter pairs.
///
/// A single leading `:` is stripped from each key. Anything other than a JSON
/// object (including `None`) yields an empty list.
fn variables_to_params(variables: Option<&serde_json::Value>) -> Vec<(String, serde_json::Value)> {
    let object = match variables {
        Some(serde_json::Value::Object(object)) => object,
        _ => return Vec::new(),
    };

    let mut params = Vec::with_capacity(object.len());
    for (name, value) in object {
        let bare = match name.strip_prefix(':') {
            Some(rest) => rest,
            None => name.as_str(),
        };
        params.push((bare.to_string(), value.clone()));
    }
    params
}
#[pmcp_code_mode::async_trait]
impl CodeExecutor for SqlCodeExecutor {
    /// Re-validates `code`, binds `variables` as parameters, runs the
    /// statement on the connector and returns `{"rows": …}`.
    ///
    /// # Errors
    ///
    /// Returns `ExecutionError::BackendError` when re-validation rejects the
    /// statement or when the connector fails.
    async fn execute(
        &self,
        code: &str,
        variables: Option<&serde_json::Value>,
    ) -> std::result::Result<serde_json::Value, ExecutionError> {
        // Validation runs before the connector is touched.
        self.revalidate(code)?;
        let params = variables_to_params(variables);
        let rows = self
            .connector
            .execute(code, &params)
            .await
            .map_err(|e| ExecutionError::BackendError(format!("connector execute failed: {e}")))?;
        Ok(serde_json::json!({ "rows": rows }))
    }
}
/// HTTP request executor used by OpenAPI Code Mode: sends requests relative
/// to `base_url` through `client`, with credentials applied by `auth`.
#[cfg(feature = "openapi-code-mode")]
#[derive(Clone)]
pub struct HttpCodeExecutor {
// HTTP client used to send every outgoing request.
client: reqwest::Client,
// Joined with the resolved request path to form the request URL.
base_url: String,
// Adds auth headers and/or query parameters to each outgoing request.
auth: Arc<dyn crate::http::auth::HttpAuthProvider>,
// Token handed to `auth.apply` for each request; `None` unless set via
// `with_inbound_token`.
inbound_token: Option<String>,
}
#[cfg(feature = "openapi-code-mode")]
impl HttpCodeExecutor {
    /// Creates an executor with no inbound token attached.
    #[must_use]
    pub fn new(
        client: reqwest::Client,
        base_url: String,
        auth: Arc<dyn crate::http::auth::HttpAuthProvider>,
    ) -> Self {
        Self {
            inbound_token: None,
            client,
            base_url,
            auth,
        }
    }

    /// Returns this executor with its inbound token replaced by `token`.
    #[must_use]
    pub fn with_inbound_token(self, token: Option<String>) -> Self {
        Self {
            inbound_token: token,
            ..self
        }
    }

    /// Test-only accessor for the currently attached inbound token.
    #[cfg(test)]
    pub(crate) fn inbound_token_for_test(&self) -> Option<&str> {
        self.inbound_token.as_deref()
    }

    /// Substitutes `{key}` placeholders in `path` with values from a JSON
    /// object `body`.
    ///
    /// Returns the resolved path plus the body fields that were not consumed
    /// by a placeholder (`None` when nothing is left). A non-object body is
    /// passed through unchanged.
    fn resolve_path(
        path: &str,
        body: &Option<serde_json::Value>,
    ) -> std::result::Result<(String, Option<serde_json::Value>), ExecutionError> {
        let Some(serde_json::Value::Object(fields)) = body else {
            return Ok((path.to_string(), body.clone()));
        };

        let mut resolved = path.to_string();
        let mut leftover = serde_json::Map::new();
        for (key, value) in fields {
            let placeholder = format!("{{{key}}}");
            if !resolved.contains(&placeholder) {
                leftover.insert(key.clone(), value.clone());
                continue;
            }
            let rendered = Self::scalar_str(key, value)?;
            resolved = resolved.replace(&placeholder, &rendered);
        }

        let leftover = if leftover.is_empty() {
            None
        } else {
            Some(serde_json::Value::Object(leftover))
        };
        Ok((resolved, leftover))
    }

    /// Renders a scalar JSON value as the text used in a path or query
    /// parameter; objects and arrays are rejected with a `RuntimeError`
    /// naming `key`.
    fn scalar_str(
        key: &str,
        value: &serde_json::Value,
    ) -> std::result::Result<String, ExecutionError> {
        use serde_json::Value;

        let rendered = match value {
            Value::Array(_) | Value::Object(_) => {
                return Err(ExecutionError::RuntimeError {
                    message: format!("path/query param '{key}' must be a scalar"),
                });
            },
            Value::Null => "null".to_string(),
            Value::Bool(flag) => flag.to_string(),
            Value::Number(number) => number.to_string(),
            Value::String(text) => text.clone(),
        };
        Ok(rendered)
    }
}
#[cfg(feature = "openapi-code-mode")]
#[pmcp_code_mode::async_trait]
impl pmcp_code_mode::HttpExecutor for HttpCodeExecutor {
    /// Sends one HTTP request and returns the response body parsed as JSON
    /// (`Null` for an empty body).
    ///
    /// `{key}` placeholders in `path` are filled from `body`. For GET, HEAD
    /// and OPTIONS the remaining body fields become query parameters and no
    /// body is sent; for other methods the remaining fields are sent as JSON.
    ///
    /// # Errors
    ///
    /// Every failure is a `RuntimeError` with a fixed message; the underlying
    /// error detail is not included.
    async fn execute_request(
        &self,
        method: &str,
        path: &str,
        body: Option<serde_json::Value>,
    ) -> std::result::Result<serde_json::Value, ExecutionError> {
        /// Builds a `RuntimeError` from a fixed message.
        fn runtime_error(message: &str) -> ExecutionError {
            ExecutionError::RuntimeError {
                message: message.to_string(),
            }
        }

        let verb = method.to_uppercase();
        // NOTE(review): "OPTIONS" is treated as query-only here, but the
        // method dispatch below has no OPTIONS arm, so such requests end up
        // rejected as unsupported — confirm whether that is intended.
        let query_only = matches!(verb.as_str(), "GET" | "HEAD" | "OPTIONS");

        let (resolved_path, leftover) = Self::resolve_path(path, &body)?;
        let joined = crate::http::join_url(&self.base_url, &resolved_path);

        let mut headers = reqwest::header::HeaderMap::new();
        let mut auth_query = std::collections::HashMap::<String, String>::new();
        self.auth
            .apply(&mut headers, &mut auth_query, self.inbound_token.as_deref())
            .await
            .map_err(|_| runtime_error("authentication failed for outgoing request"))?;

        // Auth-supplied query parameters come first, then any body fields.
        let mut query_params: Vec<(String, String)> = auth_query.into_iter().collect();
        let request_body = match leftover {
            Some(serde_json::Value::Object(fields)) if query_only => {
                for (key, value) in &fields {
                    query_params.push((key.clone(), Self::scalar_str(key, value)?));
                }
                None
            },
            _ if query_only => None,
            other => other,
        };

        let final_url = if query_params.is_empty() {
            joined
        } else {
            let mut parsed = url::Url::parse(&joined)
                .map_err(|_| runtime_error("could not construct the request URL"))?;
            {
                let mut pairs = parsed.query_pairs_mut();
                for (name, value) in &query_params {
                    pairs.append_pair(name, value);
                }
            }
            parsed.to_string()
        };

        let client = &self.client;
        let builder = match verb.as_str() {
            "GET" => client.get(&final_url),
            "POST" => client.post(&final_url),
            "PUT" => client.put(&final_url),
            "DELETE" => client.delete(&final_url),
            "PATCH" => client.patch(&final_url),
            "HEAD" => client.head(&final_url),
            _ => return Err(runtime_error("unsupported HTTP method")),
        };

        let mut request = builder.headers(headers);
        if let Some(payload) = request_body {
            request = request
                .header("Content-Type", "application/json")
                .json(&payload);
        }

        let response = request
            .send()
            .await
            .map_err(|_| runtime_error("outgoing HTTP request failed"))?;
        let status = response.status();
        // The body is read before the status check, as in the original flow.
        let text = response
            .text()
            .await
            .map_err(|_| runtime_error("failed to read response body"))?;

        if !status.is_success() {
            return Err(ExecutionError::RuntimeError {
                message: format!("backend returned HTTP status {}", status.as_u16()),
            });
        }
        if text.is_empty() {
            return Ok(serde_json::Value::Null);
        }
        serde_json::from_str(&text)
            .map_err(|_| runtime_error("failed to parse response body as JSON"))
    }
}
/// Translates the toolkit's `[code_mode]` section into a `CodeModeConfig`.
///
/// Fields without a counterpart in `section` keep their
/// `CodeModeConfig::default()` values. Optional settings (`server_id`,
/// `token_ttl_seconds`, `max_limit`) only override the default when present.
fn build_cm_config(section: &CodeModeSection) -> CodeModeConfig {
    let mut cfg = CodeModeConfig {
        enabled: section.enabled,
        sql_allow_writes: section.allow_writes,
        sql_allow_deletes: section.allow_deletes,
        sql_allow_ddl: section.allow_ddl,
        sql_blocked_tables: section.blocked_tables.iter().cloned().collect(),
        sql_blocked_columns: section.sensitive_columns.iter().cloned().collect(),
        ..CodeModeConfig::default()
    };
    if let Some(ref sid) = section.server_id {
        cfg.server_id = Some(sid.clone());
    }
    if let Some(ttl) = section.token_ttl_seconds {
        // A TTL too large for i64 saturates instead of failing.
        cfg.token_ttl_seconds = i64::try_from(ttl).unwrap_or(i64::MAX);
    }
    map_auto_approve_levels(&section.auto_approve_levels, &mut cfg);
    if let Some(max) = section.max_limit {
        cfg.sql_max_rows = max;
    }
    cfg.sql_require_limit = section.require_limit;
    if let Some(ref limits) = section.limits {
        // NOTE(review): these limits are read but not mapped onto any
        // `CodeModeConfig` field here; the `_gap_` names suggest a known,
        // deliberate gap — confirm before removing.
        let _gap_max_tables = limits.max_tables_per_query;
        let _gap_max_join = limits.max_join_depth;
        let _gap_max_subquery = limits.max_subquery_depth;
    }
    cfg
}
/// Maps textual risk levels (case-insensitive `low`, `medium`, `high`,
/// `critical`) onto `cfg.auto_approve_levels`.
///
/// Unrecognised entries are skipped with a warning, so that a typo in the
/// configuration is visible at default log levels. When no entry is
/// recognised, `cfg.auto_approve_levels` is left as it was.
fn map_auto_approve_levels(levels: &[String], cfg: &mut CodeModeConfig) {
    use pmcp_code_mode::RiskLevel;

    let mut out = Vec::with_capacity(levels.len());
    for level in levels {
        match level.to_ascii_lowercase().as_str() {
            "low" => out.push(RiskLevel::Low),
            "medium" => out.push(RiskLevel::Medium),
            "high" => out.push(RiskLevel::High),
            "critical" => out.push(RiskLevel::Critical),
            _ => {
                // A dropped level changes what gets auto-approved, so this is
                // reported as a warning rather than a debug message.
                tracing::warn!(
                    target: "pmcp_server_toolkit::code_mode",
                    "[code_mode] auto_approve_levels: unrecognised level '{}' — skipping",
                    level
                );
            },
        }
    }
    if !out.is_empty() {
        cfg.auto_approve_levels = out;
    }
}
/// Extracts `NAME` from a string of the exact form `${NAME}`.
///
/// Returns `None` when the `${` prefix or `}` suffix is missing, or when the
/// name between them is empty.
fn expand_braced_var(raw: &str) -> Option<&str> {
    raw.strip_prefix("${")
        .and_then(|rest| rest.strip_suffix('}'))
        .filter(|name| !name.is_empty())
}
/// Reads the token secret from the environment variable `var`.
///
/// # Errors
///
/// Returns `ToolkitError::CodeMode` when the variable cannot be read or when
/// its value is empty or whitespace-only.
fn resolve_secret_env_var(var: &str) -> Result<SecretValue> {
    let value = match std::env::var(var) {
        Ok(value) => value,
        Err(_) => {
            return Err(ToolkitError::CodeMode(format!(
                "env var '{var}' not set for token_secret"
            )));
        },
    };

    if value.trim().is_empty() {
        Err(ToolkitError::CodeMode(format!(
            "env var '{var}' is set but empty for token_secret"
        )))
    } else {
        // The untrimmed value is what becomes the secret.
        Ok(SecretValue::new(value.into_bytes()))
    }
}
/// Resolves the `[code_mode]` token secret.
///
/// `env:NAME` and `${NAME}` are read from the environment. Any other value is
/// an inline literal, accepted only when
/// `allow_inline_token_secret_for_dev` is set (with a warning).
///
/// # Errors
///
/// Returns `ToolkitError::CodeMode` when no secret is configured or the
/// environment lookup fails, and
/// `ConfigValidationError::InlineSecretRejected` for an inline literal
/// without the dev flag.
fn resolve_token_secret(section: &CodeModeSection) -> Result<SecretValue> {
    let Some(raw) = section.token_secret.as_ref() else {
        return Err(ToolkitError::CodeMode(
            "[code_mode] token_secret is required when code-mode is enabled".to_string(),
        ));
    };

    // The `env:` prefix is checked before the `${…}` form.
    let env_name = raw
        .strip_prefix("env:")
        .or_else(|| expand_braced_var(raw));
    if let Some(var) = env_name {
        return resolve_secret_env_var(var);
    }

    if !section.allow_inline_token_secret_for_dev {
        return Err(ToolkitError::Validation(
            ConfigValidationError::InlineSecretRejected,
        ));
    }

    tracing::warn!(
        target: "pmcp_server_toolkit::code_mode",
        "[code_mode] token_secret is inline AND allow_inline_token_secret_for_dev=true; \
         accepting under dev/test exception — NEVER set this flag in a committed \
         production config"
    );
    Ok(SecretValue::new(raw.as_bytes().to_vec()))
}
/// Assembles the Code Mode prompt from the connector's dialect and schema
/// text, plus any curated table descriptions found in `config`.
///
/// # Errors
///
/// Returns `ToolkitError::CodeMode` when the connector cannot provide its
/// schema text.
pub async fn assemble_code_mode_prompt(
    connector: &(dyn SqlConnector + '_),
    config: &ServerConfig,
) -> Result<String> {
    let dialect = connector.dialect();
    let schema_text = match connector.schema_text().await {
        Ok(text) => text,
        Err(e) => return Err(ToolkitError::CodeMode(format!("schema_text failed: {e}"))),
    };
    let curated = format_curated_tables(config);

    let mut prompt = format!(
        "# Code Mode — {}\n\n{}\n\n## Schema\n\n{}",
        dialect.name(),
        dialect.placeholder_guidance(),
        schema_text
    );
    // The curated section is omitted entirely when there is nothing to list.
    if !curated.is_empty() {
        prompt.push_str("\n\n## Curated Tables\n\n");
        prompt.push_str(&curated);
    }
    prompt.push('\n');
    Ok(prompt)
}
/// Builds the Code Mode prompt for `connector` and `config`.
///
/// This simply delegates to [`assemble_code_mode_prompt`] and returns its
/// result unchanged, including any error.
pub async fn build_code_mode_prompt(
connector: &(dyn SqlConnector + '_),
config: &ServerConfig,
) -> Result<String> {
assemble_code_mode_prompt(connector, config).await
}
/// Assembles the Code Mode prompt from caller-supplied schema text instead of
/// querying a connector.
///
/// The schema text is placed under a fixed `# Database Schema` header inside
/// the `## Schema` section; curated table descriptions from `config` are
/// appended when there are any.
#[must_use]
pub fn assemble_code_mode_prompt_with_schema(
    schema_text: &str,
    dialect: Dialect,
    config: &ServerConfig,
) -> String {
    const SCHEMA_HEADER: &str = "# Database Schema\n\n";

    let curated = format_curated_tables(config);
    let mut prompt = format!(
        "# Code Mode — {}\n\n{}\n\n## Schema\n\n{}{}",
        dialect.name(),
        dialect.placeholder_guidance(),
        SCHEMA_HEADER,
        schema_text
    );
    // The curated section is omitted entirely when there is nothing to list.
    if !curated.is_empty() {
        prompt.push_str("\n\n## Curated Tables\n\n");
        prompt.push_str(&curated);
    }
    prompt.push('\n');
    prompt
}
/// Renders one markdown bullet per configured table that has a non-empty
/// description, joined by newlines. Returns an empty string when no table
/// qualifies.
fn format_curated_tables(config: &ServerConfig) -> String {
    let mut bullets: Vec<String> = Vec::new();
    for table in config.database.tables.iter() {
        match table.description.as_deref() {
            Some(text) if !text.is_empty() => {
                bullets.push(format!("- `{}`: {}", table.name, text));
            },
            _ => {},
        }
    }
    bullets.join("\n")
}
/// Process-wide lock for tests that mutate environment variables.
#[cfg(test)]
mod test_env_guard {
    use std::sync::{Mutex, MutexGuard};

    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires the environment lock. A lock poisoned by a panicking test is
    /// recovered rather than propagated, so later tests can still run.
    pub(super) fn lock() -> MutexGuard<'static, ()> {
        match ENV_LOCK.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        }
    }
}
/// Unit tests for config mapping, token-secret resolution and variable
/// binding.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{CodeModeLimits, CodeModeSection};

    // Compile-time check that the re-exported names stay nameable from here.
    #[allow(dead_code)]
    const _RE_EXPORTS_COMPILE: fn() = || {
        let _: Option<Box<dyn CodeExecutor>> = None;
        let _: Option<Box<dyn PolicyEvaluator>> = None;
        let _: Option<ApprovalToken> = None;
        let _: Option<HmacTokenGenerator> = None;
        let _: Option<TokenSecret> = None;
        let _: Option<NoopPolicyEvaluator> = None;
        let _: Option<ValidationPipeline> = None;
        let _: Option<ValidationContext> = None;
        let _: Option<CodeModeConfig> = None;
        let _: Option<AuthorizationDecision> = None;
        let _hash = canonicalize_code;
        let _ctx = compute_context_hash;
        let _h = hash_code;
    };

    /// A read-only section whose token secret points at env var `var`.
    fn env_section(var: &str) -> CodeModeSection {
        CodeModeSection {
            enabled: true,
            server_id: Some("test-server".to_string()),
            allow_writes: false,
            allow_deletes: false,
            allow_ddl: false,
            require_limit: false,
            max_limit: Some(1000),
            blocked_tables: vec![],
            sensitive_columns: vec![],
            auto_approve_levels: vec!["low".to_string()],
            token_ttl_seconds: Some(300),
            token_secret: Some(format!("env:{var}")),
            allow_inline_token_secret_for_dev: false,
            limits: Some(CodeModeLimits {
                max_tables_per_query: Some(5),
                max_join_depth: Some(3),
                max_subquery_depth: Some(2),
            }),
        }
    }

    #[test]
    fn build_cm_config_maps_allow_writes() {
        let mut section = env_section("UNUSED");
        section.allow_writes = true;
        let cfg = build_cm_config(&section);
        assert!(
            cfg.sql_allow_writes,
            "unprefixed allow_writes=true must map to sql_allow_writes=true"
        );
        assert!(cfg.enabled);
        assert_eq!(cfg.server_id.as_deref(), Some("test-server"));
        assert_eq!(cfg.sql_max_rows, 1000);
        assert_eq!(cfg.token_ttl_seconds, 300);
    }

    #[test]
    fn build_cm_config_maps_require_limit_true() {
        let mut section = env_section("UNUSED");
        section.require_limit = true;
        let cfg = build_cm_config(&section);
        assert!(
            cfg.sql_require_limit,
            "require_limit=true must map to sql_require_limit=true"
        );
    }

    #[test]
    fn build_cm_config_maps_require_limit_false() {
        let mut section = env_section("UNUSED");
        section.require_limit = false;
        let cfg = build_cm_config(&section);
        assert!(
            !cfg.sql_require_limit,
            "require_limit=false must map to sql_require_limit=false"
        );
    }

    #[test]
    fn build_cm_config_propagates_blocked_tables() {
        let mut section = env_section("UNUSED");
        section.blocked_tables = vec!["users".into(), "secrets".into()];
        section.sensitive_columns = vec!["users.password".into()];
        let cfg = build_cm_config(&section);
        assert!(cfg.sql_blocked_tables.contains("users"));
        assert!(cfg.sql_blocked_tables.contains("secrets"));
        assert!(cfg.sql_blocked_columns.contains("users.password"));
    }

    #[test]
    fn resolve_token_secret_env_reference_succeeds() {
        // Serialise env mutation with the other env-touching tests.
        let _env = super::test_env_guard::lock();
        const VAR: &str = "PMCP_TOOLKIT_CODE_MODE_TEST_RESOLVE_ENV";
        std::env::set_var(VAR, "a-test-secret-bytes-16-or-more");
        let section = env_section(VAR);
        let resolved = resolve_token_secret(&section).expect("env resolution must succeed");
        assert_eq!(resolved.expose_secret(), b"a-test-secret-bytes-16-or-more");
        std::env::remove_var(VAR);
    }

    #[test]
    fn resolve_token_secret_inline_without_dev_flag_rejected() {
        let mut section = env_section("UNUSED");
        section.token_secret = Some("raw-string-that-should-be-rejected".to_string());
        section.allow_inline_token_secret_for_dev = false;
        match resolve_token_secret(&section) {
            Ok(_) => panic!("must reject inline literal"),
            Err(ToolkitError::Validation(ConfigValidationError::InlineSecretRejected)) => {},
            Err(other) => panic!("expected InlineSecretRejected, got {other:?}"),
        }
    }

    #[test]
    fn resolve_token_secret_inline_with_dev_flag_accepted() {
        let mut section = env_section("UNUSED");
        section.token_secret = Some("a-test-secret-bytes-16-or-more".to_string());
        section.allow_inline_token_secret_for_dev = true;
        let resolved =
            resolve_token_secret(&section).expect("dev flag must permit inline literal");
        assert_eq!(resolved.expose_secret(), b"a-test-secret-bytes-16-or-more");
    }

    #[test]
    fn resolve_token_secret_empty_env_var_is_set_but_empty_error() {
        let _env = super::test_env_guard::lock();
        const VAR: &str = "PMCP_TOOLKIT_CODE_MODE_TEST_EMPTY_ENV";
        std::env::set_var(VAR, "");
        let section = env_section(VAR);
        let outcome = resolve_token_secret(&section);
        // Clean up before asserting so a failure does not leak the variable.
        std::env::remove_var(VAR);
        match outcome {
            Ok(_) => panic!("empty env var must error, not yield an empty secret"),
            Err(ToolkitError::CodeMode(msg)) => {
                assert!(
                    msg.contains(VAR) && msg.contains("set but empty"),
                    "error must name the var as set-but-empty, got: {msg}"
                );
            },
            Err(other) => panic!("expected CodeMode 'set but empty', got {other:?}"),
        }
    }

    #[test]
    fn resolve_token_secret_whitespace_env_var_is_set_but_empty_error() {
        let _env = super::test_env_guard::lock();
        const VAR: &str = "PMCP_TOOLKIT_CODE_MODE_TEST_WS_ENV";
        std::env::set_var(VAR, " ");
        let mut section = env_section("UNUSED");
        // Exercises the `${NAME}` reference form rather than `env:NAME`.
        section.token_secret = Some(format!("${{{VAR}}}"));
        let outcome = resolve_token_secret(&section);
        // Clean up before asserting so a failure does not leak the variable.
        std::env::remove_var(VAR);
        match outcome {
            Ok(_) => panic!("whitespace-only env var must error"),
            Err(ToolkitError::CodeMode(msg)) => {
                assert!(
                    msg.contains(VAR) && msg.contains("set but empty"),
                    "error must name the var as set-but-empty, got: {msg}"
                );
            },
            Err(other) => panic!("expected CodeMode 'set but empty', got {other:?}"),
        }
    }

    #[test]
    fn variables_to_params_maps_object_stripping_colon_prefix() {
        let vars = serde_json::json!({ ":name": "Rock", "limit": 5 });
        let mut params = variables_to_params(Some(&vars));
        // Sort so the assertion does not depend on map iteration order.
        params.sort_by(|a, b| a.0.cmp(&b.0));
        assert_eq!(
            params,
            vec![
                ("limit".to_string(), serde_json::json!(5)),
                ("name".to_string(), serde_json::json!("Rock")),
            ]
        );
    }

    #[test]
    fn variables_to_params_none_or_non_object_is_empty() {
        assert!(variables_to_params(None).is_empty());
        assert!(variables_to_params(Some(&serde_json::json!("not-an-object"))).is_empty());
        assert!(variables_to_params(Some(&serde_json::json!([1, 2, 3]))).is_empty());
    }

    #[test]
    fn resolve_token_secret_missing_env_var_surfaces_error() {
        let section = env_section("PMCP_TOOLKIT_DEFINITELY_NOT_SET_FOR_TEST");
        match resolve_token_secret(&section) {
            Ok(_) => panic!("missing env var must error"),
            Err(ToolkitError::CodeMode(msg)) => {
                assert!(
                    msg.contains("PMCP_TOOLKIT_DEFINITELY_NOT_SET_FOR_TEST"),
                    "error message must name the missing env var, got: {msg}"
                );
            },
            Err(other) => panic!("expected CodeMode error, got {other:?}"),
        }
    }
}
#[cfg(all(test, feature = "sqlite"))]
mod sql_code_executor_tests {
use super::*;
use crate::config::{CodeModeSection, ServerConfig, ServerSection};
use crate::sql::SqliteConnector;
const TEST_SECRET_VAR: &str = "PMCP_TOOLKIT_SQL_EXECUTOR_TEST_SECRET";
fn ensure_secret() {
std::env::set_var(TEST_SECRET_VAR, "executor-test-secret-16-or-more");
}
async fn read_only_executor() -> SqlCodeExecutor {
let connector = SqliteConnector::open_in_memory().expect("open in-memory sqlite");
connector
.execute(
"CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT)",
&[],
)
.await
.expect("create table");
connector
.execute(
"INSERT INTO Artist (ArtistId, Name) VALUES (1, 'AC/DC')",
&[],
)
.await
.expect("seed row");
let config = ServerConfig {
server: ServerSection {
name: "executor-test".to_string(),
version: "0.1.0".to_string(),
..Default::default()
},
code_mode: Some(CodeModeSection {
enabled: true,
server_id: Some("executor-test".to_string()),
allow_writes: false,
allow_deletes: false,
allow_ddl: false,
token_secret: Some(format!("env:{TEST_SECRET_VAR}")),
..Default::default()
}),
..Default::default()
};
let _env = super::test_env_guard::lock();
ensure_secret();
SqlCodeExecutor::new(Arc::new(connector), config).expect("build executor")
}
async fn read_only_executor_with_require_limit() -> SqlCodeExecutor {
let connector = SqliteConnector::open_in_memory().expect("open in-memory sqlite");
connector
.execute(
"CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT)",
&[],
)
.await
.expect("create table");
connector
.execute(
"INSERT INTO Artist (ArtistId, Name) VALUES (1, 'AC/DC')",
&[],
)
.await
.expect("seed row");
let config = ServerConfig {
server: ServerSection {
name: "executor-test".to_string(),
version: "0.1.0".to_string(),
..Default::default()
},
code_mode: Some(CodeModeSection {
enabled: true,
server_id: Some("executor-test".to_string()),
allow_writes: false,
allow_deletes: false,
allow_ddl: false,
require_limit: true,
token_secret: Some(format!("env:{TEST_SECRET_VAR}")),
..Default::default()
}),
..Default::default()
};
let _env = super::test_env_guard::lock();
ensure_secret();
SqlCodeExecutor::new(Arc::new(connector), config).expect("build executor")
}
#[tokio::test]
async fn read_only_select_returns_rows() {
let executor = read_only_executor().await;
let result = executor
.execute("SELECT ArtistId, Name FROM Artist", None)
.await
.expect("read-only SELECT must succeed under a read-only policy");
let rows = result.get("rows").expect("payload has a `rows` key");
let arr = rows.as_array().expect("`rows` is an array");
assert_eq!(arr.len(), 1, "one seeded row expected, got {arr:?}");
assert_eq!(arr[0]["Name"], "AC/DC");
}
#[tokio::test]
async fn require_limit_rejects_bare_select_before_connector() {
let executor = read_only_executor_with_require_limit().await;
let err = executor
.execute("SELECT * FROM Artist", None)
.await
.expect_err("bare SELECT must be rejected when require_limit=true");
assert!(
matches!(err, ExecutionError::BackendError(_)),
"expected a policy-rejection BackendError, got {err:?}"
);
let count = executor
.connector
.execute("SELECT COUNT(*) AS n FROM Artist", &[])
.await
.expect("count query");
assert_eq!(count[0]["n"], 1, "row count must be unchanged");
}
#[tokio::test]
async fn require_limit_allows_limited_select() {
let executor = read_only_executor_with_require_limit().await;
let result = executor
.execute("SELECT ArtistId, Name FROM Artist LIMIT 5", None)
.await
.expect("a LIMITed SELECT must succeed under require_limit=true");
let rows = result.get("rows").expect("payload has a `rows` key");
let arr = rows.as_array().expect("`rows` is an array");
assert_eq!(arr.len(), 1, "one seeded row expected, got {arr:?}");
}
#[tokio::test]
async fn delete_rejected_before_connector_under_read_only_policy() {
let executor = read_only_executor().await;
let err = executor
.execute("DELETE FROM Artist WHERE ArtistId = 1", None)
.await
.expect_err("DELETE must be rejected when allow_deletes=false");
assert!(
matches!(err, ExecutionError::BackendError(_)),
"expected a policy-rejection BackendError, got {err:?}"
);
let still_there = executor
.connector
.execute("SELECT COUNT(*) AS n FROM Artist", &[])
.await
.expect("count query");
assert_eq!(still_there[0]["n"], 1, "DELETE must not have run");
}
#[tokio::test]
async fn ddl_rejected_under_read_only_policy() {
let executor = read_only_executor().await;
let err = executor
.execute("DROP TABLE Artist", None)
.await
.expect_err("DROP must be rejected when allow_ddl=false");
assert!(matches!(err, ExecutionError::BackendError(_)));
}
#[tokio::test]
async fn malformed_sql_returns_err_never_panics() {
let executor = read_only_executor().await;
let result = executor.execute("SELEC nonsense FRM", None).await;
assert!(
result.is_err(),
"malformed SQL must surface an Err, never panic"
);
}
#[tokio::test]
async fn execute_binds_variables_input() {
let executor = read_only_executor().await;
let vars = serde_json::json!({ ":name": "AC/DC" });
let result = executor
.execute(
"SELECT ArtistId FROM Artist WHERE Name = :name",
Some(&vars),
)
.await
.expect("bound variable must resolve the WHERE clause");
let rows = result.get("rows").expect("payload has a `rows` key");
let arr = rows.as_array().expect("`rows` is an array");
assert_eq!(arr.len(), 1, "the bound :name must match the seeded row");
assert_eq!(arr[0]["ArtistId"], 1);
}
/// Passing an empty JSON object as variables must still let a plain SELECT
/// succeed and return the single `Artist` row.
#[tokio::test]
async fn execute_empty_variables_is_unaffected() {
    let sut = read_only_executor().await;
    let no_bindings = serde_json::json!({});

    let payload = sut
        .execute("SELECT ArtistId, Name FROM Artist", Some(&no_bindings))
        .await
        .expect("empty variables must behave exactly like None");

    let row_count = payload
        .get("rows")
        .and_then(|rows| rows.as_array())
        .expect("`rows` array")
        .len();
    assert_eq!(row_count, 1);
}
/// Checks that an already-constructed executor keeps working after
/// `TEST_SECRET_VAR` is removed from the process environment, i.e. that
/// `execute` does not need to read the variable again.
///
/// Statement order matters: the variable is removed only after a first
/// successful execute, and it is handed back to `ensure_secret()` before the
/// final assertion.
#[tokio::test]
async fn pipeline_cached_at_construction_not_reread_per_execute() {
let executor = read_only_executor().await;
// Baseline: the executor works while the secret is still in the environment.
executor
.execute("SELECT ArtistId FROM Artist LIMIT 1", None)
.await
.expect("first execute succeeds");
{
// The env lock is held only inside this block, so the guard is dropped
// before the next `.await`.
let _env = super::test_env_guard::lock();
std::env::remove_var(TEST_SECRET_VAR);
}
// With the variable gone, this call can only succeed if nothing on the
// execute path has to read it.
let result = executor
.execute("SELECT ArtistId FROM Artist LIMIT 1", None)
.await
.expect("second execute must succeed from the cached pipeline");
{
// Presumably re-establishes the secret for tests that run afterwards;
// NOTE(review): `ensure_secret` is defined outside this view — confirm.
// This is not reached if the `expect` above panics.
let _env = super::test_env_guard::lock();
ensure_secret();
}
assert!(result.get("rows").is_some());
}
}
#[cfg(test)]
mod tkit10_tests {
use super::*;
use crate::config::{DatabaseSection, DatabaseTableDecl, ServerConfig, ServerSection};
use crate::sql::{Dialect, MockSqlConnector};
/// Builds a `ServerConfig` named `"test"` at version `"0.1.0"` whose database
/// section declares `tables`; every other field takes its `Default` value.
fn make_cfg(tables: Vec<DatabaseTableDecl>) -> ServerConfig {
    let server = ServerSection {
        name: String::from("test"),
        version: String::from("0.1.0"),
        ..Default::default()
    };
    let database = DatabaseSection {
        tables,
        ..Default::default()
    };
    ServerConfig {
        server,
        database,
        ..Default::default()
    }
}
/// For a Postgres mock connector the assembled prompt must contain the
/// PostgreSQL header, a fragment of the mock's schema text, and `$1`.
#[tokio::test]
async fn assemble_includes_schema_text_and_dialect_name() {
    let cfg = make_cfg(Vec::new());
    let schema = String::from("CREATE TABLE users (id SERIAL PRIMARY KEY);");
    let connector = MockSqlConnector {
        dialect: Dialect::Postgres,
        schema,
    };

    let prompt = assemble_code_mode_prompt(&connector, &cfg).await.unwrap();

    let has_dialect_header = prompt.contains("# Code Mode — PostgreSQL");
    let has_schema_body = prompt.contains("CREATE TABLE users");
    let mentions_placeholder = prompt.contains("$1");
    assert!(
        has_dialect_header,
        "prompt missing dialect header: {prompt}"
    );
    assert!(has_schema_body, "prompt missing schema body: {prompt}");
    assert!(
        mentions_placeholder,
        "Postgres guidance should mention $1: {prompt}"
    );
}
/// With two described tables configured, the assembled prompt must contain
/// the curated-tables header, both table descriptions, and the Athena name.
#[tokio::test]
async fn assemble_includes_curated_descriptions() {
    // Local shorthand for a table declaration that carries a description.
    let described = |name: &str, description: &str| DatabaseTableDecl {
        name: name.to_string(),
        description: Some(description.to_string()),
    };
    let connector = MockSqlConnector {
        dialect: Dialect::Athena,
        schema: String::from("(see Glue catalog)"),
    };
    let tables = vec![
        described("users", "App users"),
        described("orders", "Customer orders"),
    ];
    let cfg = make_cfg(tables);

    let prompt = assemble_code_mode_prompt(&connector, &cfg).await.unwrap();

    let has_curated_header = prompt.contains("## Curated Tables");
    let has_users = prompt.contains("`users`: App users");
    let has_orders = prompt.contains("`orders`: Customer orders");
    let has_dialect_name = prompt.contains("Amazon Athena");
    assert!(
        has_curated_header,
        "prompt missing curated header: {prompt}"
    );
    assert!(has_users, "prompt missing users description: {prompt}");
    assert!(has_orders, "prompt missing orders description: {prompt}");
    assert!(
        has_dialect_name,
        "prompt missing Athena dialect name: {prompt}"
    );
}
/// With no tables configured, the assembled prompt must not contain the
/// curated-tables header but must still name the SQLite dialect.
#[tokio::test]
async fn assemble_omits_curated_section_when_tables_empty() {
    let cfg = make_cfg(Vec::new());
    let connector = MockSqlConnector {
        dialect: Dialect::Sqlite,
        schema: String::from("CREATE TABLE t (id INTEGER PRIMARY KEY);"),
    };

    let prompt = assemble_code_mode_prompt(&connector, &cfg).await.unwrap();

    let has_curated_header = prompt.contains("## Curated Tables");
    let has_dialect_name = prompt.contains("SQLite");
    assert!(
        !has_curated_header,
        "empty [[database.tables]] must omit curated section: {prompt}"
    );
    assert!(
        has_dialect_name,
        "prompt missing SQLite dialect name: {prompt}"
    );
}
/// A table declared without a description must not show up in the prompt,
/// while a described sibling table still does.
#[tokio::test]
async fn assemble_skips_tables_without_descriptions() {
    let documented = DatabaseTableDecl {
        name: String::from("with_desc"),
        description: Some(String::from("has description")),
    };
    let undocumented = DatabaseTableDecl {
        name: String::from("no_desc"),
        description: None,
    };
    let connector = MockSqlConnector {
        dialect: Dialect::MySql,
        schema: String::from("CREATE TABLE t (id INT);"),
    };
    let cfg = make_cfg(vec![documented, undocumented]);

    let prompt = assemble_code_mode_prompt(&connector, &cfg).await.unwrap();

    assert!(prompt.contains("`with_desc`: has description"));
    let lists_undocumented = prompt.contains("`no_desc`");
    assert!(
        !lists_undocumented,
        "undescribed table must not appear in curated section: {prompt}"
    );
}
/// The schema-taking assembler must emit the code-mode header, the dialect
/// name, the schema header, the schema text verbatim, and the description of
/// the configured table.
#[test]
fn with_schema_includes_header_dialect_schema_and_curated() {
    let schema = "CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT);";
    let artist = DatabaseTableDecl {
        name: String::from("Artist"),
        description: Some(String::from("Musical artists")),
    };
    let cfg = make_cfg(vec![artist]);

    let prompt = assemble_code_mode_prompt_with_schema(schema, Dialect::Sqlite, &cfg);

    let has_code_mode_header = prompt.contains("# Code Mode");
    let has_dialect_name = prompt.contains("SQLite");
    let has_schema_header = prompt.contains("# Database Schema");
    let has_schema_text = prompt.contains(schema);
    let has_description = prompt.contains("`Artist`: Musical artists");
    assert!(has_code_mode_header, "missing code-mode header: {prompt}");
    assert!(has_dialect_name, "missing dialect name: {prompt}");
    assert!(
        has_schema_header,
        "missing schema-resource header: {prompt}"
    );
    assert!(
        has_schema_text,
        "schema text must appear verbatim: {prompt}"
    );
    assert!(
        has_description,
        "curated table description must appear: {prompt}"
    );
}
/// Called from a plain (non-async) test, the schema-taking assembler must use
/// the dialect it is given: PostgreSQL header, `$1` guidance, and no curated
/// section when no tables are configured.
#[test]
fn with_schema_is_sync_and_uses_passed_dialect() {
    let cfg = make_cfg(Vec::new());
    let schema = "CREATE TABLE t (id INT);";

    let prompt = assemble_code_mode_prompt_with_schema(schema, Dialect::Postgres, &cfg);

    let has_dialect_header = prompt.contains("# Code Mode — PostgreSQL");
    let mentions_placeholder = prompt.contains("$1");
    let has_curated_header = prompt.contains("## Curated Tables");
    assert!(
        has_dialect_header,
        "passed dialect must drive the header: {prompt}"
    );
    assert!(mentions_placeholder, "Postgres guidance missing: {prompt}");
    assert!(
        !has_curated_header,
        "empty tables must omit curated section: {prompt}"
    );
}
/// An empty schema string must still yield a prompt with both the MySQL
/// code-mode header and the schema header.
#[test]
fn with_schema_empty_text_still_has_header() {
    let cfg = make_cfg(Vec::new());
    let empty_schema = "";

    let prompt = assemble_code_mode_prompt_with_schema(empty_schema, Dialect::MySql, &cfg);

    let has_dialect_header = prompt.contains("# Code Mode — MySQL");
    let has_schema_header = prompt.contains("# Database Schema");
    assert!(
        has_dialect_header,
        "empty schema must still produce a valid prompt with the header: {prompt}"
    );
    assert!(
        has_schema_header,
        "schema-resource header present even for empty schema: {prompt}"
    );
}
}
#[cfg(all(test, feature = "openapi-code-mode"))]
mod per_request_executor_tests {
use super::*;
use crate::config::{CodeModeSection, ServerConfig, ServerSection};
use crate::http::auth::{create_passthrough_auth_provider, AuthConfig};
use pmcp::server::auth::AuthContext;
/// Builds the base `HttpCodeExecutor` shared by these tests: a fresh
/// `reqwest` client, the `https://api.example` base URL, and a required
/// OAuth-passthrough auth provider targeting the `Authorization` header.
fn passthrough_base() -> HttpCodeExecutor {
    let auth_config = AuthConfig::OAuthPassthrough {
        target_header: String::from("Authorization"),
        required: true,
    };
    let auth =
        create_passthrough_auth_provider(&auth_config, None).expect("passthrough auth provider");
    let client = reqwest::Client::new();
    let base_url = String::from("https://api.example");
    HttpCodeExecutor::new(client, base_url, auth)
}
/// Builds a default `RequestHandlerExtra` carrying an `AuthContext` whose
/// `token` is the given value and whose `authenticated` flag is true exactly
/// when a token was supplied.
fn extra_with_token(token: Option<&str>) -> pmcp::RequestHandlerExtra {
    let authenticated = token.is_some();
    let auth_context = AuthContext {
        subject: String::from("s"),
        scopes: Vec::new(),
        claims: std::collections::HashMap::new(),
        token: token.map(String::from),
        client_id: None,
        expires_at: None,
        authenticated,
    };
    let extra = pmcp::RequestHandlerExtra::default();
    extra.with_auth_context(Some(auth_context))
}
/// The base executor starts with no inbound token; the executor derived from
/// an extra carrying a token must expose that exact token.
#[test]
fn request_executor_from_extra_threads_present_token() {
    let base = passthrough_base();
    let token_before = base.inbound_token_for_test();
    assert_eq!(
        token_before, None,
        "base executor starts with no inbound token"
    );

    let extra = extra_with_token(Some("Bearer client-tok"));
    let scoped = request_executor_from_extra(&base, &extra);
    let token_after = scoped.inbound_token_for_test();
    assert_eq!(
        token_after,
        Some("Bearer client-tok"),
        "the captured inbound token must be threaded into the per-request executor"
    );
}
/// Neither an extra whose auth context has no token nor a default extra may
/// produce an executor with an inbound token.
#[test]
fn request_executor_from_extra_no_token_yields_none() {
    let base = passthrough_base();

    // Auth context present, but without a token.
    let tokenless_extra = extra_with_token(None);
    let scoped = request_executor_from_extra(&base, &tokenless_extra);
    let scoped_token = scoped.inbound_token_for_test();
    assert_eq!(
        scoped_token, None,
        "an extra carrying no token must yield an executor with inbound_token None"
    );

    // Plain default extra.
    let default_extra = pmcp::RequestHandlerExtra::default();
    let bare = request_executor_from_extra(&base, &default_extra);
    assert_eq!(bare.inbound_token_for_test(), None);
}
/// Sets `PMCP_TOOLKIT_90_10_HTTP_SECRET` in the process environment and
/// returns a `ServerConfig` named `http-cm` whose enabled code-mode section
/// references that variable as its token secret.
///
/// This mutates process-wide state; the caller is responsible for any
/// serialisation and for removing the variable afterwards.
fn cfg_with_code_mode() -> ServerConfig {
    // Publish the secret first, before any part of the config is built.
    std::env::set_var(
        "PMCP_TOOLKIT_90_10_HTTP_SECRET",
        "per-request-test-secret-16-or-more",
    );

    let server = ServerSection {
        name: String::from("http-cm"),
        version: String::from("0.1.0"),
        ..Default::default()
    };
    let code_mode = CodeModeSection {
        enabled: true,
        server_id: Some(String::from("http-cm")),
        token_secret: Some(String::from("env:PMCP_TOOLKIT_90_10_HTTP_SECRET")),
        ..Default::default()
    };
    ServerConfig {
        server,
        code_mode: Some(code_mode),
        ..Default::default()
    }
}
/// Wiring the OpenAPI code-mode tools onto a builder with a `[code_mode]`
/// config must register both `validate_code` and `execute_code`.
///
/// The test secret is removed from the process environment when the test
/// ends, including when an `expect`/`assert!` below panics.
#[test]
fn http_tools_register_validate_and_execute_with_per_request_source() {
    /// Removes the secret set by `cfg_with_code_mode` when dropped, so a
    /// failing assertion cannot leave it behind in the process environment.
    struct SecretCleanup;
    impl Drop for SecretCleanup {
        fn drop(&mut self) {
            std::env::remove_var("PMCP_TOOLKIT_90_10_HTTP_SECRET");
        }
    }

    let _env = super::test_env_guard::lock();
    // Declared after `_env`, so it is dropped first: the variable is removed
    // while the env lock is still held.
    let _cleanup = SecretCleanup;

    let cfg = cfg_with_code_mode();
    let builder = pmcp::Server::builder().name("http-cm").version("0.1.0");
    let builder = code_mode_http_tools_from_executor(
        builder,
        &cfg,
        passthrough_base(),
        ExecutionConfig::default(),
        ValidationFlavor::OpenApi,
    )
    .expect("OpenAPI per-request code-mode wiring must build");
    let server = builder.build().expect("server builds");
    assert!(
        server.get_tool("validate_code").is_some(),
        "validate_code registered"
    );
    assert!(
        server.get_tool("execute_code").is_some(),
        "execute_code registered"
    );
}
/// Without a `[code_mode]` section the wiring call must succeed and the
/// built server must not expose `execute_code`.
#[test]
fn http_tools_no_op_when_code_mode_absent() {
    let server_section = ServerSection {
        name: String::from("no-cm"),
        version: String::from("0.1.0"),
        ..Default::default()
    };
    let cfg = ServerConfig {
        server: server_section,
        ..Default::default()
    };

    let plain_builder = pmcp::Server::builder().name("no-cm").version("0.1.0");
    let wired = code_mode_http_tools_from_executor(
        plain_builder,
        &cfg,
        passthrough_base(),
        ExecutionConfig::default(),
        ValidationFlavor::OpenApi,
    );
    let server = wired
        .expect("no-op when [code_mode] absent")
        .build()
        .expect("server builds");

    let execute_tool = server.get_tool("execute_code");
    assert!(execute_tool.is_none(), "no tools without [code_mode]");
}
}
#[cfg(all(test, feature = "sqlite", feature = "openapi-code-mode"))]
mod sql_static_source_tests {
use super::*;
use crate::config::{CodeModeSection, ServerConfig, ServerSection};
use crate::sql::SqliteConnector;
/// Wiring a `SqlCodeExecutor` through `code_mode_tools_from_executor` must
/// register both `validate_code` and `execute_code` on the built server.
///
/// The test secret is removed from the process environment when the test
/// ends, including when an `expect`/`assert!` below panics.
#[test]
fn sql_path_registers_static_source_unchanged() {
    /// Removes the test secret when dropped, so a failing assertion cannot
    /// leave it behind in the process environment.
    struct SecretCleanup;
    impl Drop for SecretCleanup {
        fn drop(&mut self) {
            std::env::remove_var("PMCP_TOOLKIT_90_10_SQL_SECRET");
        }
    }

    let _env = super::test_env_guard::lock();
    // Declared after `_env`, so it is dropped first: the variable is removed
    // while the env lock is still held.
    let _cleanup = SecretCleanup;
    std::env::set_var(
        "PMCP_TOOLKIT_90_10_SQL_SECRET",
        "sql-static-test-secret-16-or-more",
    );

    let connector = SqliteConnector::open_in_memory().expect("sqlite");
    let cfg = ServerConfig {
        server: ServerSection {
            name: "sql-cm".to_string(),
            version: "0.1.0".to_string(),
            ..Default::default()
        },
        code_mode: Some(CodeModeSection {
            enabled: true,
            server_id: Some("sql-cm".to_string()),
            token_secret: Some("env:PMCP_TOOLKIT_90_10_SQL_SECRET".to_string()),
            ..Default::default()
        }),
        ..Default::default()
    };
    let executor: Arc<dyn CodeExecutor> =
        Arc::new(SqlCodeExecutor::new(Arc::new(connector), cfg.clone()).expect("executor"));
    let builder = pmcp::Server::builder().name("sql-cm").version("0.1.0");
    let builder = code_mode_tools_from_executor(builder, &cfg, executor, ValidationFlavor::Sql)
        .expect("SQL code-mode wiring must build");
    let server = builder.build().expect("server builds");
    assert!(server.get_tool("validate_code").is_some());
    assert!(server.get_tool("execute_code").is_some());
}
}