Skip to main content

pmcp_code_mode/
handler.rs

1//! Code Mode Handler trait for unified soft-disable and tool management.
2//!
3//! This module provides the `CodeModeHandler` trait that all Code Mode implementations
4//! should implement. It provides:
5//!
6//! - **Policy check**: Requires a policy evaluator to be configured
7//! - **Pre-handle hook**: Extensible hook for soft-disable and other checks
8//! - **Standard tool definitions**: Consistent `validate_code` and `execute_code` tools
9//! - **Response formatting**: Consistent JSON responses across server types
10
11use pmcp::types::ToolInfo;
12use serde::{Deserialize, Serialize};
13use serde_json::{json, Value};
14
15use crate::types::{PolicyViolation, RiskLevel, UnifiedAction, ValidationMetadata};
16
17/// Response from `validate_code_impl` containing all validation results.
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct ValidationResponse {
20    /// Whether the code is valid
21    pub is_valid: bool,
22
23    /// Human-readable explanation of what the code does
24    pub explanation: String,
25
26    /// Risk level (LOW, MEDIUM, HIGH, CRITICAL)
27    pub risk_level: RiskLevel,
28
29    /// Approval token for execution (None if invalid or dry_run)
30    pub approval_token: Option<String>,
31
32    /// Whether this was auto-approved based on risk level
33    pub auto_approved: bool,
34
35    /// Warnings (non-blocking issues)
36    pub warnings: Vec<String>,
37
38    /// Policy violations (blocking issues)
39    pub violations: Vec<PolicyViolation>,
40
41    /// Validation metadata
42    pub metadata: ValidationMetadata,
43
44    /// Unified action (Read, Write, Delete, Admin)
45    pub action: Option<UnifiedAction>,
46
47    /// SHA-256 hash of the canonicalized code that was validated.
48    #[serde(skip_serializing_if = "Option::is_none")]
49    pub validated_code_hash: Option<String>,
50}
51
52impl ValidationResponse {
53    /// Create a successful validation response.
54    pub fn success(
55        explanation: String,
56        risk_level: RiskLevel,
57        approval_token: String,
58        metadata: ValidationMetadata,
59    ) -> Self {
60        Self {
61            is_valid: true,
62            explanation,
63            risk_level,
64            approval_token: Some(approval_token),
65            auto_approved: false,
66            warnings: vec![],
67            violations: vec![],
68            metadata,
69            action: None,
70            validated_code_hash: None,
71        }
72    }
73
74    /// Create a failed validation response.
75    pub fn failure(violations: Vec<PolicyViolation>, metadata: ValidationMetadata) -> Self {
76        Self {
77            is_valid: false,
78            explanation: String::new(),
79            risk_level: RiskLevel::Critical,
80            approval_token: None,
81            auto_approved: false,
82            warnings: vec![],
83            violations,
84            metadata,
85            action: None,
86            validated_code_hash: None,
87        }
88    }
89
90    /// Set the validated code hash (SHA-256 of canonicalized code).
91    pub fn with_code_hash(mut self, hash: String) -> Self {
92        self.validated_code_hash = Some(hash);
93        self
94    }
95
96    /// Set the action for this response.
97    pub fn with_action(mut self, action: UnifiedAction) -> Self {
98        self.action = Some(action);
99        self
100    }
101
102    /// Set auto_approved flag.
103    pub fn with_auto_approved(mut self, auto_approved: bool) -> Self {
104        self.auto_approved = auto_approved;
105        self
106    }
107
108    /// Add warnings to the response.
109    pub fn with_warnings(mut self, warnings: Vec<String>) -> Self {
110        self.warnings = warnings;
111        self
112    }
113
114    /// Convert to JSON response format.
115    ///
116    /// Returns a tuple of (json_value, is_error).
117    pub fn to_json_response(&self) -> (Value, bool) {
118        let response = json!({
119            "valid": self.is_valid,
120            "explanation": self.explanation,
121            "risk_level": format!("{}", self.risk_level),
122            "approval_token": self.approval_token,
123            "action": self.action.as_ref().map(|a| a.to_string()),
124            "auto_approved": self.auto_approved,
125            "warnings": self.warnings,
126            "violations": self.violations.iter().map(|v| json!({
127                "policy": v.policy_name,
128                "rule": v.rule,
129                "message": v.message,
130                "suggestion": v.suggestion
131            })).collect::<Vec<_>>(),
132            "validated_code_hash": self.validated_code_hash,
133            "metadata": {
134                "is_read_only": self.metadata.is_read_only,
135                "accessed_types": self.metadata.accessed_types,
136                "accessed_fields": self.metadata.accessed_fields,
137                "validation_time_ms": self.metadata.validation_time_ms
138            }
139        });
140
141        (response, !self.is_valid)
142    }
143}
144
145/// Code Mode handler trait with policy check and standard tool handling.
146#[async_trait::async_trait]
147pub trait CodeModeHandler: Send + Sync {
148    /// Get the server name/ID for identification.
149    fn server_name(&self) -> &str;
150
151    /// Check if Code Mode is enabled in the configuration.
152    fn is_enabled(&self) -> bool;
153
154    /// Get the code format for this server (e.g., "graphql", "javascript", "sql").
155    fn code_format(&self) -> &str;
156
157    /// Validate code and return a validation response.
158    async fn validate_code_impl(
159        &self,
160        code: &str,
161        variables: Option<&Value>,
162        dry_run: bool,
163        user_id: &str,
164        session_id: &str,
165    ) -> Result<ValidationResponse, String>;
166
167    /// Execute validated code and return the result.
168    async fn execute_code_impl(
169        &self,
170        code: &str,
171        approval_token: &str,
172        variables: Option<&Value>,
173    ) -> Result<Value, String>;
174
175    /// Check if a policy evaluator is configured.
176    ///
177    /// The default returns `true` for backward compatibility with tests.
178    /// Production implementations MUST override this.
179    fn is_policy_configured(&self) -> bool {
180        true
181    }
182
183    /// Deprecated alias for `is_policy_configured()`.
184    fn is_avp_configured(&self) -> bool {
185        self.is_policy_configured()
186    }
187
188    /// Pre-handle hook for checks before tool execution.
189    ///
190    /// Override this to implement soft-disable checks (e.g., DynamoDB toggle).
191    /// Return `Ok(Some((response, is_error)))` to short-circuit with a response.
192    /// Return `Ok(None)` to proceed normally.
193    async fn pre_handle_hook(&self) -> Result<Option<(Value, bool)>, String> {
194        Ok(None)
195    }
196
197    // =========================================================================
198    // Provided methods with default implementations
199    // =========================================================================
200
201    /// Check if this is a Code Mode tool.
202    fn is_code_mode_tool(&self, name: &str) -> bool {
203        name == "validate_code" || name == "execute_code"
204    }
205
206    /// Get the standard Code Mode tool definitions.
207    fn get_tools(&self) -> Vec<ToolInfo> {
208        if !self.is_enabled() {
209            return vec![];
210        }
211
212        CodeModeToolBuilder::new(self.code_format()).build_tools()
213    }
214
215    /// Handle a Code Mode tool call with policy and pre-handle checks.
216    async fn handle_tool(
217        &self,
218        name: &str,
219        arguments: Value,
220        user_id: &str,
221        session_id: &str,
222    ) -> Result<(Value, bool), String> {
223        // Policy enforcement: require a policy evaluator to be configured
224        if !self.is_policy_configured() {
225            return Ok((
226                json!({
227                    "error": "Code Mode requires a policy evaluator to be configured. \
228                              Configure AVP, local Cedar, or another policy backend.",
229                    "valid": false
230                }),
231                true,
232            ));
233        }
234
235        // Pre-handle hook (soft-disable, etc.)
236        if let Some(response) = self.pre_handle_hook().await? {
237            return Ok(response);
238        }
239
240        match name {
241            "validate_code" => {
242                self.handle_validate_code(arguments, user_id, session_id)
243                    .await
244            }
245            "execute_code" => self.handle_execute_code(arguments).await,
246            _ => Err(format!("Unknown Code Mode tool: {}", name)),
247        }
248    }
249
250    /// Handle validate_code tool call.
251    async fn handle_validate_code(
252        &self,
253        arguments: Value,
254        user_id: &str,
255        session_id: &str,
256    ) -> Result<(Value, bool), String> {
257        let mut input: ValidateCodeInput =
258            serde_json::from_value(arguments).map_err(|e| format!("Invalid arguments: {}", e))?;
259
260        input.code = input.code.trim().to_string();
261
262        let response = self
263            .validate_code_impl(
264                &input.code,
265                input.variables.as_ref(),
266                input.dry_run.unwrap_or(false),
267                user_id,
268                session_id,
269            )
270            .await?;
271
272        Ok(response.to_json_response())
273    }
274
275    /// Handle execute_code tool call.
276    async fn handle_execute_code(&self, arguments: Value) -> Result<(Value, bool), String> {
277        let mut input: ExecuteCodeInput =
278            serde_json::from_value(arguments).map_err(|e| format!("Invalid arguments: {}", e))?;
279
280        input.code = input.code.trim().to_string();
281
282        let result = self
283            .execute_code_impl(&input.code, &input.approval_token, input.variables.as_ref())
284            .await?;
285
286        Ok((result, false))
287    }
288}
289
290/// Input for validate_code tool.
291#[derive(Debug, Deserialize)]
292pub struct ValidateCodeInput {
293    pub code: String,
294    #[serde(default)]
295    pub variables: Option<Value>,
296    #[serde(default)]
297    pub format: Option<String>,
298    #[serde(default)]
299    pub dry_run: Option<bool>,
300}
301
302/// Input for execute_code tool.
303#[derive(Debug, Deserialize)]
304pub struct ExecuteCodeInput {
305    pub code: String,
306    pub approval_token: String,
307    #[serde(default)]
308    pub variables: Option<Value>,
309}
310
/// Builder for standard Code Mode tool definitions.
///
/// Holds the server's code format, which is embedded in the `validate_code` input schema.
#[derive(Debug, Clone)]
pub struct CodeModeToolBuilder {
    /// Code format advertised as the only allowed `format` value (e.g. "graphql").
    code_format: String,
}
315
316impl CodeModeToolBuilder {
317    /// Create a new tool builder for the given code format.
318    pub fn new(code_format: &str) -> Self {
319        Self {
320            code_format: code_format.to_string(),
321        }
322    }
323
324    /// Build the standard Code Mode tools.
325    pub fn build_tools(&self) -> Vec<ToolInfo> {
326        vec![self.build_validate_tool(), self.build_execute_tool()]
327    }
328
329    /// Build the validate_code tool definition.
330    pub fn build_validate_tool(&self) -> ToolInfo {
331        ToolInfo::new(
332            "validate_code",
333            Some(format!(
334                "Validates code and returns a business-language explanation with an approval token. \
335                 The code is analyzed for security, complexity, and data access patterns. \
336                 You MUST call this before execute_code."
337            )),
338            json!({
339                "type": "object",
340                "properties": {
341                    "code": {
342                        "type": "string",
343                        "description": "The code to validate"
344                    },
345                    "variables": {
346                        "type": "object",
347                        "description": "Optional variables for the query"
348                    },
349                    "format": {
350                        "type": "string",
351                        "enum": [&self.code_format],
352                        "description": format!("Code format. Defaults to '{}' for this server.", self.code_format)
353                    },
354                    "dry_run": {
355                        "type": "boolean",
356                        "description": "If true, validate without generating approval token"
357                    }
358                },
359                "required": ["code"]
360            }),
361        )
362    }
363
364    /// Build the execute_code tool definition.
365    pub fn build_execute_tool(&self) -> ToolInfo {
366        ToolInfo::new(
367            "execute_code",
368            Some(
369                "Executes validated code using an approval token. \
370                 The token must be obtained from validate_code and the code must match exactly."
371                    .into(),
372            ),
373            json!({
374                "type": "object",
375                "properties": {
376                    "code": {
377                        "type": "string",
378                        "description": "The code to execute (must match validated code)"
379                    },
380                    "approval_token": {
381                        "type": "string",
382                        "description": "The approval token from validate_code"
383                    },
384                    "variables": {
385                        "type": "object",
386                        "description": "Optional variables for the query"
387                    }
388                },
389                "required": ["code", "approval_token"]
390            }),
391        )
392    }
393}
394
395/// Format an error as a JSON response.
396pub fn format_error_response(error: &str) -> (Value, bool) {
397    (
398        json!({
399            "error": error,
400            "valid": false
401        }),
402        true,
403    )
404}
405
406/// Format an execution error as a JSON response.
407pub fn format_execution_error(error: &str) -> (Value, bool) {
408    (
409        json!({
410            "error": error
411        }),
412        true,
413    )
414}
415
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_validation_response_to_json() {
        let response = ValidationResponse::success(
            "Test explanation".into(),
            RiskLevel::Low,
            "token123".into(),
            ValidationMetadata::default(),
        )
        .with_action(UnifiedAction::Read)
        .with_auto_approved(true);

        let (json, is_error) = response.to_json_response();

        assert!(!is_error);
        assert_eq!(json["valid"], true);
        assert_eq!(json["explanation"], "Test explanation");
        assert_eq!(json["risk_level"], "LOW");
        assert_eq!(json["approval_token"], "token123");
        assert_eq!(json["action"], "Read");
        assert_eq!(json["auto_approved"], true);
    }

    #[test]
    fn test_validation_response_failure() {
        let violations = vec![PolicyViolation::new("policy", "rule", "message")];
        let response = ValidationResponse::failure(violations, ValidationMetadata::default());

        let (json, is_error) = response.to_json_response();

        assert!(is_error);
        assert_eq!(json["valid"], false);
        // A failed validation must never hand out an approval token.
        assert!(json["approval_token"].is_null());
        assert_eq!(json["violations"].as_array().map(Vec::len), Some(1));
        assert_eq!(json["violations"][0]["policy"], "policy");
        assert_eq!(json["violations"][0]["rule"], "rule");
        assert_eq!(json["violations"][0]["message"], "message");
    }

    #[test]
    fn test_validation_response_optional_fields() {
        let response = ValidationResponse::success(
            "Explanation".into(),
            RiskLevel::Low,
            "token".into(),
            ValidationMetadata::default(),
        )
        .with_code_hash("abc123".into())
        .with_warnings(vec!["careful".into()]);

        let (json, is_error) = response.to_json_response();

        assert!(!is_error);
        assert_eq!(json["validated_code_hash"], "abc123");
        assert_eq!(json["warnings"][0], "careful");
        assert!(json["action"].is_null());
        assert_eq!(json["auto_approved"], false);
    }

    #[test]
    fn test_format_error_helpers() {
        let (validation_err, is_error) = format_error_response("boom");
        assert!(is_error);
        assert_eq!(validation_err["error"], "boom");
        assert_eq!(validation_err["valid"], false);

        let (execution_err, is_error) = format_execution_error("kaboom");
        assert!(is_error);
        assert_eq!(execution_err["error"], "kaboom");
        assert!(execution_err.get("valid").is_none());
    }

    #[test]
    fn test_input_deserialization() {
        let input: ValidateCodeInput = serde_json::from_value(json!({ "code": "query { a }" }))
            .expect("code alone is a valid validate_code input");
        assert_eq!(input.code, "query { a }");
        assert!(input.variables.is_none());
        assert!(input.format.is_none());
        assert!(input.dry_run.is_none());

        // approval_token is required for execute_code.
        let missing_token = serde_json::from_value::<ExecuteCodeInput>(json!({ "code": "x" }));
        assert!(missing_token.is_err());
    }

    #[test]
    fn test_tool_builder() {
        let builder = CodeModeToolBuilder::new("graphql");
        let tools = builder.build_tools();

        assert_eq!(tools.len(), 2);
        assert_eq!(tools[0].name, "validate_code");
        assert_eq!(tools[1].name, "execute_code");
    }
}