Skip to main content

pmcp_code_mode/
handler.rs

1//! Code Mode Handler trait for unified soft-disable and tool management.
2//!
3//! This module provides the `CodeModeHandler` trait that all Code Mode implementations
4//! should implement. It provides:
5//!
6//! - **Policy check**: Requires a policy evaluator to be configured
7//! - **Pre-handle hook**: Extensible hook for soft-disable and other checks
8//! - **Standard tool definitions**: Consistent `validate_code` and `execute_code` tools
9//! - **Response formatting**: Consistent JSON responses across server types
10
11use pmcp::types::ToolInfo;
12use serde::{Deserialize, Serialize};
13use serde_json::{json, Value};
14
15use crate::types::{
16    PolicyViolation, RiskLevel, UnifiedAction, ValidationMetadata, ValidationResult,
17};
18
/// Response from `validate_code_impl` containing validation results plus
/// handler-specific metadata.
///
/// Wraps [`ValidationResult`] from the validation pipeline and adds fields
/// for handler-level concerns (auto-approval, unified action, code hash).
///
/// Serde representation: the fields of `result` are flattened into the
/// top-level object; `action` is always serialized (as `null` when unset),
/// while `validated_code_hash` is omitted entirely when it is `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationResponse {
    /// Core validation result from the pipeline (flattened when serialized).
    #[serde(flatten)]
    pub result: ValidationResult,

    /// Whether this was auto-approved based on risk level.
    ///
    /// The constructors initialize this to `false`; set it with
    /// `with_auto_approved`.
    pub auto_approved: bool,

    /// Unified action (Read, Write, Delete, Admin), if one was determined.
    pub action: Option<UnifiedAction>,

    /// SHA-256 hash of the canonicalized code that was validated.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub validated_code_hash: Option<String>,
}
40
41impl ValidationResponse {
42    /// Create a successful validation response.
43    pub fn success(
44        explanation: String,
45        risk_level: RiskLevel,
46        approval_token: String,
47        metadata: ValidationMetadata,
48    ) -> Self {
49        Self {
50            result: ValidationResult::success(explanation, risk_level, approval_token, metadata),
51            auto_approved: false,
52            action: None,
53            validated_code_hash: None,
54        }
55    }
56
57    /// Create a failed validation response.
58    pub fn failure(violations: Vec<PolicyViolation>, metadata: ValidationMetadata) -> Self {
59        Self {
60            result: ValidationResult::failure(violations, metadata),
61            auto_approved: false,
62            action: None,
63            validated_code_hash: None,
64        }
65    }
66
67    /// Create from an existing `ValidationResult`.
68    pub fn from_result(result: ValidationResult) -> Self {
69        Self {
70            result,
71            auto_approved: false,
72            action: None,
73            validated_code_hash: None,
74        }
75    }
76
77    /// Set the validated code hash (SHA-256 of canonicalized code).
78    pub fn with_code_hash(mut self, hash: String) -> Self {
79        self.validated_code_hash = Some(hash);
80        self
81    }
82
83    /// Set the action for this response.
84    pub fn with_action(mut self, action: UnifiedAction) -> Self {
85        self.action = Some(action);
86        self
87    }
88
89    /// Set auto_approved flag.
90    pub fn with_auto_approved(mut self, auto_approved: bool) -> Self {
91        self.auto_approved = auto_approved;
92        self
93    }
94
95    /// Add warnings to the response.
96    pub fn with_warnings(mut self, warnings: Vec<String>) -> Self {
97        self.result.warnings = warnings;
98        self
99    }
100
101    /// Convert to JSON response format.
102    ///
103    /// Returns a tuple of (json_value, is_error).
104    pub fn to_json_response(&self) -> (Value, bool) {
105        let response = json!({
106            "valid": self.result.is_valid,
107            "explanation": self.result.explanation,
108            "risk_level": format!("{}", self.result.risk_level),
109            "approval_token": self.result.approval_token,
110            "action": self.action.as_ref().map(|a| a.to_string()),
111            "auto_approved": self.auto_approved,
112            "warnings": self.result.warnings,
113            "violations": self.result.violations.iter().map(|v| json!({
114                "policy": v.policy_name,
115                "rule": v.rule,
116                "message": v.message,
117                "suggestion": v.suggestion
118            })).collect::<Vec<_>>(),
119            "validated_code_hash": self.validated_code_hash,
120            "metadata": {
121                "is_read_only": self.result.metadata.is_read_only,
122                "accessed_types": self.result.metadata.accessed_types,
123                "accessed_fields": self.result.metadata.accessed_fields,
124                "validation_time_ms": self.result.metadata.validation_time_ms
125            }
126        });
127
128        (response, !self.result.is_valid)
129    }
130}
131
132/// Code Mode handler trait with policy check and standard tool handling.
133#[async_trait::async_trait]
134pub trait CodeModeHandler: Send + Sync {
135    /// Get the server name/ID for identification.
136    fn server_name(&self) -> &str;
137
138    /// Check if Code Mode is enabled in the configuration.
139    fn is_enabled(&self) -> bool;
140
141    /// Get the code format for this server (e.g., "graphql", "javascript", "sql").
142    fn code_format(&self) -> &str;
143
144    /// Validate code and return a validation response.
145    async fn validate_code_impl(
146        &self,
147        code: &str,
148        variables: Option<&Value>,
149        dry_run: bool,
150        user_id: &str,
151        session_id: &str,
152    ) -> Result<ValidationResponse, String>;
153
154    /// Execute validated code and return the result.
155    async fn execute_code_impl(
156        &self,
157        code: &str,
158        approval_token: &str,
159        variables: Option<&Value>,
160    ) -> Result<Value, String>;
161
162    /// Check if a policy evaluator is configured.
163    ///
164    /// Defaults to `false` (safe default). When `false`, `handle_tool` rejects
165    /// all requests with a "policy evaluator required" error. Implementations
166    /// that have configured a policy evaluator MUST override this to return `true`.
167    fn is_policy_configured(&self) -> bool {
168        false
169    }
170
171    /// Deprecated alias for `is_policy_configured()`.
172    fn is_avp_configured(&self) -> bool {
173        self.is_policy_configured()
174    }
175
176    /// Pre-handle hook for checks before tool execution.
177    ///
178    /// Override this to implement soft-disable checks (e.g., DynamoDB toggle).
179    /// Return `Ok(Some((response, is_error)))` to short-circuit with a response.
180    /// Return `Ok(None)` to proceed normally.
181    async fn pre_handle_hook(&self) -> Result<Option<(Value, bool)>, String> {
182        Ok(None)
183    }
184
185    // =========================================================================
186    // Provided methods with default implementations
187    // =========================================================================
188
189    /// Check if this is a Code Mode tool.
190    fn is_code_mode_tool(&self, name: &str) -> bool {
191        name == "validate_code" || name == "execute_code"
192    }
193
194    /// Get the standard Code Mode tool definitions.
195    fn get_tools(&self) -> Vec<ToolInfo> {
196        if !self.is_enabled() {
197            return vec![];
198        }
199
200        CodeModeToolBuilder::new(self.code_format()).build_tools()
201    }
202
203    /// Handle a Code Mode tool call with policy and pre-handle checks.
204    async fn handle_tool(
205        &self,
206        name: &str,
207        arguments: Value,
208        user_id: &str,
209        session_id: &str,
210    ) -> Result<(Value, bool), String> {
211        // Policy enforcement: require a policy evaluator to be configured
212        if !self.is_policy_configured() {
213            return Ok((
214                json!({
215                    "error": "Code Mode requires a policy evaluator to be configured. \
216                              Configure AVP, local Cedar, or another policy backend.",
217                    "valid": false
218                }),
219                true,
220            ));
221        }
222
223        // Pre-handle hook (soft-disable, etc.)
224        if let Some(response) = self.pre_handle_hook().await? {
225            return Ok(response);
226        }
227
228        match name {
229            "validate_code" => {
230                self.handle_validate_code(arguments, user_id, session_id)
231                    .await
232            },
233            "execute_code" => self.handle_execute_code(arguments).await,
234            _ => Err(format!("Unknown Code Mode tool: {}", name)),
235        }
236    }
237
238    /// Handle validate_code tool call.
239    async fn handle_validate_code(
240        &self,
241        arguments: Value,
242        user_id: &str,
243        session_id: &str,
244    ) -> Result<(Value, bool), String> {
245        let mut input: ValidateCodeInput =
246            serde_json::from_value(arguments).map_err(|e| format!("Invalid arguments: {}", e))?;
247
248        input.code = input.code.trim().to_string();
249
250        let response = self
251            .validate_code_impl(
252                &input.code,
253                input.variables.as_ref(),
254                input.dry_run.unwrap_or(false),
255                user_id,
256                session_id,
257            )
258            .await?;
259
260        Ok(response.to_json_response())
261    }
262
263    /// Handle execute_code tool call.
264    async fn handle_execute_code(&self, arguments: Value) -> Result<(Value, bool), String> {
265        let mut input: ExecuteCodeInput =
266            serde_json::from_value(arguments).map_err(|e| format!("Invalid arguments: {}", e))?;
267
268        input.code = input.code.trim().to_string();
269
270        let result = self
271            .execute_code_impl(&input.code, &input.approval_token, input.variables.as_ref())
272            .await?;
273
274        Ok((result, false))
275    }
276}
277
278/// Input for validate_code tool.
279#[derive(Debug, Deserialize)]
280pub struct ValidateCodeInput {
281    pub code: String,
282    #[serde(default)]
283    pub variables: Option<Value>,
284    #[serde(default)]
285    pub format: Option<String>,
286    #[serde(default)]
287    pub dry_run: Option<bool>,
288}
289
290/// Input for execute_code tool.
291#[derive(Debug, Deserialize)]
292pub struct ExecuteCodeInput {
293    pub code: String,
294    pub approval_token: String,
295    #[serde(default)]
296    pub variables: Option<Value>,
297}
298
/// Builder for standard Code Mode tool definitions.
#[derive(Debug, Clone)]
pub struct CodeModeToolBuilder {
    /// Code format advertised as the only allowed `format` in the schemas.
    code_format: String,
}
303
304impl CodeModeToolBuilder {
305    /// Create a new tool builder for the given code format.
306    pub fn new(code_format: &str) -> Self {
307        Self {
308            code_format: code_format.to_string(),
309        }
310    }
311
312    /// Build the standard Code Mode tools.
313    pub fn build_tools(&self) -> Vec<ToolInfo> {
314        vec![self.build_validate_tool(), self.build_execute_tool()]
315    }
316
317    /// Build the validate_code tool definition.
318    pub fn build_validate_tool(&self) -> ToolInfo {
319        ToolInfo::new(
320            "validate_code",
321            Some(
322                "Validates code and returns a business-language explanation with an approval token. \
323                 The code is analyzed for security, complexity, and data access patterns. \
324                 You MUST call this before execute_code."
325                    .to_string(),
326            ),
327            json!({
328                "type": "object",
329                "properties": {
330                    "code": {
331                        "type": "string",
332                        "description": "The code to validate"
333                    },
334                    "variables": {
335                        "type": "object",
336                        "description": "Optional variables for the query"
337                    },
338                    "format": {
339                        "type": "string",
340                        "enum": [&self.code_format],
341                        "description": format!("Code format. Defaults to '{}' for this server.", self.code_format)
342                    },
343                    "dry_run": {
344                        "type": "boolean",
345                        "description": "If true, validate without generating approval token"
346                    }
347                },
348                "required": ["code"]
349            }),
350        )
351    }
352
353    /// Build the execute_code tool definition.
354    pub fn build_execute_tool(&self) -> ToolInfo {
355        ToolInfo::new(
356            "execute_code",
357            Some(
358                "Executes validated code using an approval token. \
359                 The token must be obtained from validate_code and the code must match exactly."
360                    .into(),
361            ),
362            json!({
363                "type": "object",
364                "properties": {
365                    "code": {
366                        "type": "string",
367                        "description": "The code to execute (must match validated code)"
368                    },
369                    "approval_token": {
370                        "type": "string",
371                        "description": "The approval token from validate_code"
372                    },
373                    "variables": {
374                        "type": "object",
375                        "description": "Optional variables for the query"
376                    }
377                },
378                "required": ["code", "approval_token"]
379            }),
380        )
381    }
382}
383
384/// Format an error as a JSON response.
385pub fn format_error_response(error: &str) -> (Value, bool) {
386    (
387        json!({
388            "error": error,
389            "valid": false
390        }),
391        true,
392    )
393}
394
395/// Format an execution error as a JSON response.
396pub fn format_execution_error(error: &str) -> (Value, bool) {
397    (
398        json!({
399            "error": error
400        }),
401        true,
402    )
403}
404
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_validation_response_to_json() {
        let response = ValidationResponse::success(
            "Test explanation".into(),
            RiskLevel::Low,
            "token123".into(),
            ValidationMetadata::default(),
        )
        .with_action(UnifiedAction::Read)
        .with_auto_approved(true);

        let (json, is_error) = response.to_json_response();

        assert!(!is_error);
        assert_eq!(json["valid"], true);
        assert_eq!(json["explanation"], "Test explanation");
        assert_eq!(json["risk_level"], "LOW");
        assert_eq!(json["approval_token"], "token123");
        assert_eq!(json["action"], "Read");
        assert_eq!(json["auto_approved"], true);
        // No hash was set, so the key is present but null.
        assert!(json["validated_code_hash"].is_null());
    }

    #[test]
    fn test_validation_response_failure() {
        let violations = vec![PolicyViolation::new("policy", "rule", "message")];
        let response = ValidationResponse::failure(violations, ValidationMetadata::default());

        let (json, is_error) = response.to_json_response();

        assert!(is_error);
        assert_eq!(json["valid"], false);
        assert_eq!(json["auto_approved"], false);
        assert_eq!(json["violations"].as_array().map(Vec::len), Some(1));
    }

    #[test]
    fn test_validation_response_builders() {
        let response = ValidationResponse::success(
            "ok".into(),
            RiskLevel::Low,
            "token".into(),
            ValidationMetadata::default(),
        )
        .with_code_hash("abc123".into())
        .with_warnings(vec!["first".into()])
        .with_warnings(vec!["second".into()]);

        let (json, is_error) = response.to_json_response();

        assert!(!is_error);
        assert_eq!(json["validated_code_hash"], "abc123");
        // `with_warnings` replaces the list rather than appending to it.
        assert_eq!(json["warnings"], json!(["second"]));
        // No action was set.
        assert!(json["action"].is_null());
    }

    #[test]
    fn test_tool_builder() {
        let builder = CodeModeToolBuilder::new("graphql");
        let tools = builder.build_tools();

        assert_eq!(tools.len(), 2);
        assert_eq!(tools[0].name, "validate_code");
        assert_eq!(tools[1].name, "execute_code");
    }

    #[test]
    fn test_individual_tool_builders() {
        let builder = CodeModeToolBuilder::new("sql");

        assert_eq!(builder.build_validate_tool().name, "validate_code");
        assert_eq!(builder.build_execute_tool().name, "execute_code");
    }

    #[test]
    fn test_format_error_response() {
        let (json, is_error) = format_error_response("boom");

        assert!(is_error);
        assert_eq!(json["error"], "boom");
        assert_eq!(json["valid"], false);
    }

    #[test]
    fn test_format_execution_error() {
        let (json, is_error) = format_execution_error("boom");

        assert!(is_error);
        assert_eq!(json["error"], "boom");
        assert!(json.get("valid").is_none());
    }
}