use pmcp::types::ToolInfo;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use crate::types::{
PolicyViolation, RiskLevel, UnifiedAction, ValidationMetadata, ValidationResult,
};
/// Outcome of a `validate_code` request: the core [`ValidationResult`] plus
/// Code Mode specific details (auto-approval, action category, code hash).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationResponse {
    /// Core validation outcome. `flatten` inlines its fields into this
    /// struct's serde representation instead of nesting them under `result`.
    #[serde(flatten)]
    pub result: ValidationResult,
    /// Whether the request was approved without manual review. The
    /// constructors in this file initialise it to `false`.
    pub auto_approved: bool,
    /// Action category associated with the validated code (e.g.
    /// `UnifiedAction::Read`), if one was determined.
    pub action: Option<UnifiedAction>,
    /// Hash of the code that was validated, if recorded. Omitted entirely from
    /// serde output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub validated_code_hash: Option<String>,
}
impl ValidationResponse {
    /// Creates a response for code that passed validation.
    ///
    /// `auto_approved` starts as `false`, and `action` and
    /// `validated_code_hash` start as `None`. Use the `with_*` builders to
    /// set them.
    pub fn success(
        explanation: String,
        risk_level: RiskLevel,
        approval_token: String,
        metadata: ValidationMetadata,
    ) -> Self {
        Self::from_result(ValidationResult::success(
            explanation,
            risk_level,
            approval_token,
            metadata,
        ))
    }

    /// Creates a response for code that was rejected with the given policy
    /// `violations`. The optional fields start empty, as in [`Self::success`].
    pub fn failure(violations: Vec<PolicyViolation>, metadata: ValidationMetadata) -> Self {
        Self::from_result(ValidationResult::failure(violations, metadata))
    }

    /// Wraps an existing [`ValidationResult`].
    ///
    /// This is the single place where the Code Mode specific fields get their
    /// defaults: not auto-approved, no action, no code hash.
    pub fn from_result(result: ValidationResult) -> Self {
        Self {
            result,
            auto_approved: false,
            action: None,
            validated_code_hash: None,
        }
    }

    /// Records the hash of the code that was validated.
    #[must_use = "builder methods return the updated response"]
    pub fn with_code_hash(mut self, hash: String) -> Self {
        self.validated_code_hash = Some(hash);
        self
    }

    /// Records the action category of the validated code.
    #[must_use = "builder methods return the updated response"]
    pub fn with_action(mut self, action: UnifiedAction) -> Self {
        self.action = Some(action);
        self
    }

    /// Sets whether the request was approved automatically.
    #[must_use = "builder methods return the updated response"]
    pub fn with_auto_approved(mut self, auto_approved: bool) -> Self {
        self.auto_approved = auto_approved;
        self
    }

    /// Replaces (rather than appends to) the result's warnings.
    #[must_use = "builder methods return the updated response"]
    pub fn with_warnings(mut self, warnings: Vec<String>) -> Self {
        self.result.warnings = warnings;
        self
    }

    /// Renders the response as the JSON body returned by the `validate_code`
    /// tool.
    ///
    /// Returns `(body, is_error)`. `is_error` is `true` exactly when
    /// validation failed. Unset optional fields (`action`,
    /// `validated_code_hash`) are emitted as JSON `null`.
    pub fn to_json_response(&self) -> (Value, bool) {
        let result = &self.result;
        let metadata = &result.metadata;

        let violations: Vec<Value> = result
            .violations
            .iter()
            .map(|v| {
                json!({
                    "policy": v.policy_name,
                    "rule": v.rule,
                    "message": v.message,
                    "suggestion": v.suggestion
                })
            })
            .collect();

        let response = json!({
            "valid": result.is_valid,
            "explanation": result.explanation,
            "risk_level": result.risk_level.to_string(),
            "approval_token": result.approval_token,
            "action": self.action.as_ref().map(|a| a.to_string()),
            "auto_approved": self.auto_approved,
            "warnings": result.warnings,
            "violations": violations,
            "validated_code_hash": self.validated_code_hash,
            "metadata": {
                "is_read_only": metadata.is_read_only,
                "accessed_types": metadata.accessed_types,
                "accessed_fields": metadata.accessed_fields,
                "validation_time_ms": metadata.validation_time_ms
            }
        });

        (response, !result.is_valid)
    }
}
/// Shared request handling for servers that expose the Code Mode tools
/// `validate_code` and `execute_code`.
///
/// Implementors supply the server-specific pieces: validation, execution and
/// configuration flags. The provided methods handle tool listing, argument
/// parsing and dispatch. Tool results are `(json_body, is_error)` pairs.
#[async_trait::async_trait]
pub trait CodeModeHandler: Send + Sync {
    /// Name of the server this handler is attached to. The provided methods of
    /// this trait do not call it.
    fn server_name(&self) -> &str;

    /// Whether Code Mode is enabled. When `false`,
    /// [`get_tools`](Self::get_tools) advertises no tools.
    fn is_enabled(&self) -> bool;

    /// Code format accepted by this server. It is advertised as the only
    /// allowed value of the `validate_code` tool's `format` parameter.
    fn code_format(&self) -> &str;

    /// Validates `code` (already trimmed of surrounding whitespace).
    ///
    /// `dry_run` is `false` when the caller did not specify it.
    ///
    /// # Errors
    /// The error string is propagated unchanged as the tool call's error.
    async fn validate_code_impl(
        &self,
        code: &str,
        variables: Option<&Value>,
        dry_run: bool,
        user_id: &str,
        session_id: &str,
    ) -> Result<ValidationResponse, String>;

    /// Executes `code` (already trimmed) using an `approval_token` obtained
    /// from `validate_code`.
    ///
    /// # Errors
    /// The error string is propagated unchanged as the tool call's error.
    async fn execute_code_impl(
        &self,
        code: &str,
        approval_token: &str,
        variables: Option<&Value>,
    ) -> Result<Value, String>;

    /// Whether a policy evaluator is configured. The default is `false`, which
    /// makes [`handle_tool`](Self::handle_tool) refuse every Code Mode request.
    fn is_policy_configured(&self) -> bool {
        false
    }

    /// Whether AVP is configured. By default this mirrors
    /// [`is_policy_configured`](Self::is_policy_configured).
    fn is_avp_configured(&self) -> bool {
        self.is_policy_configured()
    }

    /// Hook run after the policy check and before a recognised tool is
    /// dispatched.
    ///
    /// Returning `Ok(Some(response))` short-circuits with that response. The
    /// default implementation does nothing.
    async fn pre_handle_hook(&self) -> Result<Option<(Value, bool)>, String> {
        Ok(None)
    }

    /// Returns `true` for the tool names handled by this trait.
    fn is_code_mode_tool(&self, name: &str) -> bool {
        name == "validate_code" || name == "execute_code"
    }

    /// Tool descriptors to advertise. The list is empty when Code Mode is
    /// disabled.
    fn get_tools(&self) -> Vec<ToolInfo> {
        if !self.is_enabled() {
            return vec![];
        }
        CodeModeToolBuilder::new(self.code_format()).build_tools()
    }

    /// Dispatches a Code Mode tool call.
    ///
    /// The tool name is resolved first, so an unknown name always yields
    /// `Err("Unknown Code Mode tool: …")`. It is never masked by the policy
    /// error or a hook response, and the hook never runs for it. For known
    /// tools:
    /// - a missing policy evaluator yields an error body with `is_error = true`;
    /// - [`pre_handle_hook`](Self::pre_handle_hook) may short-circuit;
    /// - otherwise the request is forwarded to the matching handler.
    ///
    /// NOTE(review): this does not consult [`is_enabled`](Self::is_enabled);
    /// callers presumably gate on it before dispatching. Confirm this.
    async fn handle_tool(
        &self,
        name: &str,
        arguments: Value,
        user_id: &str,
        session_id: &str,
    ) -> Result<(Value, bool), String> {
        let is_validate = match name {
            "validate_code" => true,
            "execute_code" => false,
            _ => return Err(format!("Unknown Code Mode tool: {}", name)),
        };

        if !self.is_policy_configured() {
            return Ok((
                json!({
                    "error": "Code Mode requires a policy evaluator to be configured. \
                              Configure AVP, local Cedar, or another policy backend.",
                    "valid": false
                }),
                true,
            ));
        }

        if let Some(response) = self.pre_handle_hook().await? {
            return Ok(response);
        }

        if is_validate {
            self.handle_validate_code(arguments, user_id, session_id)
                .await
        } else {
            self.handle_execute_code(arguments).await
        }
    }

    /// Parses `validate_code` arguments, validates the trimmed code and renders
    /// the result via [`ValidationResponse::to_json_response`].
    ///
    /// NOTE(review): the optional `format` argument is parsed but not checked
    /// against [`code_format`](Self::code_format). Confirm that is intended.
    ///
    /// # Errors
    /// Returns `"Invalid arguments: …"` when deserialization fails, or the
    /// error from [`validate_code_impl`](Self::validate_code_impl).
    async fn handle_validate_code(
        &self,
        arguments: Value,
        user_id: &str,
        session_id: &str,
    ) -> Result<(Value, bool), String> {
        let input: ValidateCodeInput =
            serde_json::from_value(arguments).map_err(|e| format!("Invalid arguments: {}", e))?;
        // Trimmed identically in `handle_execute_code`, so validation and
        // execution see the same text for the same submission.
        let response = self
            .validate_code_impl(
                input.code.trim(),
                input.variables.as_ref(),
                input.dry_run.unwrap_or(false),
                user_id,
                session_id,
            )
            .await?;
        Ok(response.to_json_response())
    }

    /// Parses `execute_code` arguments and executes the trimmed code. A
    /// successful execution is always reported with `is_error = false`.
    ///
    /// # Errors
    /// Returns `"Invalid arguments: …"` when deserialization fails, or the
    /// error from [`execute_code_impl`](Self::execute_code_impl).
    async fn handle_execute_code(&self, arguments: Value) -> Result<(Value, bool), String> {
        let input: ExecuteCodeInput =
            serde_json::from_value(arguments).map_err(|e| format!("Invalid arguments: {}", e))?;
        let result = self
            .execute_code_impl(
                input.code.trim(),
                &input.approval_token,
                input.variables.as_ref(),
            )
            .await?;
        Ok((result, false))
    }
}
/// Arguments accepted by the `validate_code` tool.
#[derive(Debug, Deserialize)]
pub struct ValidateCodeInput {
    /// Code to validate. `CodeModeHandler::handle_validate_code` trims
    /// surrounding whitespace before validating it.
    pub code: String,
    /// Optional variables supplied alongside the code.
    #[serde(default)]
    pub variables: Option<Value>,
    /// Requested code format.
    /// NOTE(review): deserialized but not read by
    /// `CodeModeHandler::handle_validate_code`. Confirm whether it should be
    /// checked against the server's `code_format`.
    #[serde(default)]
    pub format: Option<String>,
    /// Dry-run flag. The handler treats an absent value as `false`. Per the
    /// tool schema, a dry run validates without generating an approval token.
    #[serde(default)]
    pub dry_run: Option<bool>,
}
/// Arguments accepted by the `execute_code` tool.
#[derive(Debug, Deserialize)]
pub struct ExecuteCodeInput {
    /// Code to execute. `CodeModeHandler::handle_execute_code` trims
    /// surrounding whitespace, matching the trimming applied during
    /// validation.
    pub code: String,
    /// Approval token previously returned by `validate_code`.
    pub approval_token: String,
    /// Optional variables supplied alongside the code.
    #[serde(default)]
    pub variables: Option<Value>,
}
/// Produces the MCP tool descriptors for the Code Mode `validate_code` and
/// `execute_code` tools, parameterised by the server's code format.
pub struct CodeModeToolBuilder {
    /// Format name advertised as the single allowed `format` value.
    code_format: String,
}

impl CodeModeToolBuilder {
    /// Creates a builder for a server that accepts `code_format`.
    pub fn new(code_format: &str) -> Self {
        let code_format = String::from(code_format);
        Self { code_format }
    }

    /// Returns both descriptors, `validate_code` first and `execute_code`
    /// second.
    pub fn build_tools(&self) -> Vec<ToolInfo> {
        let validate = self.build_validate_tool();
        let execute = self.build_execute_tool();
        vec![validate, execute]
    }

    /// Descriptor for `validate_code`.
    ///
    /// Only `code` is required. `format` is restricted to this builder's
    /// format.
    pub fn build_validate_tool(&self) -> ToolInfo {
        let fmt_name = &self.code_format;
        let description = String::from(
            "Validates code and returns a business-language explanation with an approval token. \
             The code is analyzed for security, complexity, and data access patterns. \
             You MUST call this before execute_code.",
        );
        let format_hint = format!("Code format. Defaults to '{}' for this server.", fmt_name);
        let schema = json!({
            "type": "object",
            "properties": {
                "code": {
                    "type": "string",
                    "description": "The code to validate"
                },
                "variables": {
                    "type": "object",
                    "description": "Optional variables for the query"
                },
                "format": {
                    "type": "string",
                    "enum": [fmt_name],
                    "description": format_hint
                },
                "dry_run": {
                    "type": "boolean",
                    "description": "If true, validate without generating approval token"
                }
            },
            "required": ["code"]
        });
        ToolInfo::new("validate_code", Some(description), schema)
    }

    /// Descriptor for `execute_code`. Both `code` and `approval_token` are
    /// required.
    pub fn build_execute_tool(&self) -> ToolInfo {
        let description = String::from(
            "Executes validated code using an approval token. \
             The token must be obtained from validate_code and the code must match exactly.",
        );
        let schema = json!({
            "type": "object",
            "properties": {
                "code": {
                    "type": "string",
                    "description": "The code to execute (must match validated code)"
                },
                "approval_token": {
                    "type": "string",
                    "description": "The approval token from validate_code"
                },
                "variables": {
                    "type": "object",
                    "description": "Optional variables for the query"
                }
            },
            "required": ["code", "approval_token"]
        });
        ToolInfo::new("execute_code", Some(description), schema)
    }
}
/// Builds a validation-style error result: `{"error": error, "valid": false}`
/// paired with `is_error = true`.
pub fn format_error_response(error: &str) -> (Value, bool) {
    let body = json!({
        "error": error,
        "valid": false
    });
    let is_error = true;
    (body, is_error)
}
pub fn format_execution_error(error: &str) -> (Value, bool) {
(
json!({
"error": error
}),
true,
)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a low-risk successful response used by several tests.
    fn sample_success() -> ValidationResponse {
        ValidationResponse::success(
            "Test explanation".into(),
            RiskLevel::Low,
            "token123".into(),
            ValidationMetadata::default(),
        )
    }

    #[test]
    fn test_validation_response_to_json() {
        let response = sample_success()
            .with_action(UnifiedAction::Read)
            .with_auto_approved(true);
        let (json, is_error) = response.to_json_response();
        assert!(!is_error);
        assert_eq!(json["valid"], true);
        assert_eq!(json["explanation"], "Test explanation");
        assert_eq!(json["risk_level"], "LOW");
        assert_eq!(json["approval_token"], "token123");
        assert_eq!(json["action"], "Read");
        assert_eq!(json["auto_approved"], true);
    }

    #[test]
    fn test_validation_response_failure() {
        let violations = vec![PolicyViolation::new("policy", "rule", "message")];
        let response = ValidationResponse::failure(violations, ValidationMetadata::default());
        let (json, is_error) = response.to_json_response();
        assert!(is_error);
        assert_eq!(json["valid"], false);
        // Each violation is rendered as one JSON object.
        assert_eq!(json["violations"].as_array().map(Vec::len), Some(1));
    }

    #[test]
    fn test_unset_optional_fields_render_as_defaults() {
        let (json, is_error) = sample_success().to_json_response();
        assert!(!is_error);
        assert!(json["action"].is_null());
        assert!(json["validated_code_hash"].is_null());
        assert_eq!(json["auto_approved"], false);
    }

    #[test]
    fn test_code_hash_and_warnings_are_reported() {
        let response = sample_success()
            .with_code_hash("abc123".into())
            .with_warnings(vec!["careful".into()]);
        let (json, is_error) = response.to_json_response();
        assert!(!is_error);
        assert_eq!(json["validated_code_hash"], "abc123");
        assert_eq!(json["warnings"].as_array().map(Vec::len), Some(1));
        assert_eq!(json["warnings"][0], "careful");
    }

    #[test]
    fn test_error_formatters() {
        let (json, is_error) = format_error_response("boom");
        assert!(is_error);
        assert_eq!(json["error"], "boom");
        assert_eq!(json["valid"], false);

        let (json, is_error) = format_execution_error("boom");
        assert!(is_error);
        assert_eq!(json["error"], "boom");
        assert!(json.get("valid").is_none());
    }

    #[test]
    fn test_tool_builder() {
        let builder = CodeModeToolBuilder::new("graphql");
        let tools = builder.build_tools();
        assert_eq!(tools.len(), 2);
        assert_eq!(tools[0].name, "validate_code");
        assert_eq!(tools[1].name, "execute_code");
    }
}