Skip to main content

pmcp_code_mode/
validation.rs

1//! Validation pipeline for Code Mode.
2//!
3//! The pipeline validates code through multiple stages:
4//! 1. Parse (syntax check)
5//! 2. Policy evaluation (PolicyEvaluator trait or basic config checks)
6//! 3. Security analysis
7//! 4. Explanation generation
8//! 5. Token generation
9
10use crate::config::CodeModeConfig;
11use crate::explanation::{ExplanationGenerator, TemplateExplanationGenerator};
12use crate::graphql::{GraphQLQueryInfo, GraphQLValidator};
13use crate::policy::{OperationEntity, PolicyEvaluator};
14use crate::token::{compute_context_hash, HmacTokenGenerator, TokenGenerator};
15use crate::types::{
16    PolicyViolation, UnifiedAction, ValidationError, ValidationMetadata, ValidationResult,
17};
18use std::sync::atomic::{AtomicBool, Ordering};
19use std::time::Instant;
20
21#[cfg(feature = "openapi-code-mode")]
22use crate::javascript::{JavaScriptCodeInfo, JavaScriptValidator};
23
24/// Static flag to ensure the "no policy evaluator" warning is only logged once per process.
25static NO_POLICY_WARNING_LOGGED: AtomicBool = AtomicBool::new(false);
26
27/// Log a warning when Code Mode is enabled without a policy evaluator.
28fn warn_no_policy_configured() {
29    if !NO_POLICY_WARNING_LOGGED.swap(true, Ordering::SeqCst) {
30        tracing::warn!(
31            target: "code_mode",
32            "CODE MODE SECURITY WARNING: Code Mode is enabled but no policy evaluator \
33            is configured. Only basic config checks (allow_mutations, max_depth, etc.) will be \
34            performed. This provides NO real authorization policy evaluation. \
35            For production deployments, configure a policy evaluator (AVP or local Cedar)."
36        );
37    }
38}
39
/// Caller-specific inputs to a validation run: who is asking, in which
/// session, and against which schema/permission snapshot.
#[derive(Debug, Clone)]
pub struct ValidationContext {
    /// Identifier of the user, taken from the access token.
    pub user_id: String,

    /// Identifier of the MCP session.
    pub session_id: String,

    /// Hash of the schema; one of the two inputs to the context hash.
    pub schema_hash: String,

    /// Hash of the permissions; the other input to the context hash.
    pub permissions_hash: String,
}
55
56impl ValidationContext {
57    /// Create a new validation context.
58    pub fn new(
59        user_id: impl Into<String>,
60        session_id: impl Into<String>,
61        schema_hash: impl Into<String>,
62        permissions_hash: impl Into<String>,
63    ) -> Self {
64        Self {
65            user_id: user_id.into(),
66            session_id: session_id.into(),
67            schema_hash: schema_hash.into(),
68            permissions_hash: permissions_hash.into(),
69        }
70    }
71
72    /// Compute the combined context hash.
73    pub fn context_hash(&self) -> String {
74        compute_context_hash(&self.schema_hash, &self.permissions_hash)
75    }
76}
77
78/// The validation pipeline that orchestrates all validation stages.
79pub struct ValidationPipeline<
80    T: TokenGenerator = HmacTokenGenerator,
81    E: ExplanationGenerator = TemplateExplanationGenerator,
82> {
83    config: CodeModeConfig,
84    graphql_validator: GraphQLValidator,
85    #[cfg(feature = "openapi-code-mode")]
86    javascript_validator: JavaScriptValidator,
87    token_generator: T,
88    explanation_generator: E,
89    policy_evaluator: Option<Box<dyn PolicyEvaluator>>,
90}
91
92impl ValidationPipeline<HmacTokenGenerator, TemplateExplanationGenerator> {
93    /// Create a new validation pipeline with default generators.
94    ///
95    /// **Warning**: This constructor does not configure a policy evaluator.
96    /// Only basic config checks will be performed.
97    pub fn new(config: CodeModeConfig, token_secret: impl Into<Vec<u8>>) -> Self {
98        if config.enabled {
99            warn_no_policy_configured();
100        }
101
102        Self {
103            graphql_validator: GraphQLValidator::default(),
104            #[cfg(feature = "openapi-code-mode")]
105            javascript_validator: JavaScriptValidator::default(),
106            token_generator: HmacTokenGenerator::new(token_secret),
107            explanation_generator: TemplateExplanationGenerator::new(),
108            policy_evaluator: None,
109            config,
110        }
111    }
112
113    /// Create a new validation pipeline with a policy evaluator.
114    pub fn with_policy_evaluator(
115        config: CodeModeConfig,
116        token_secret: impl Into<Vec<u8>>,
117        evaluator: Box<dyn PolicyEvaluator>,
118    ) -> Self {
119        Self {
120            graphql_validator: GraphQLValidator::default(),
121            #[cfg(feature = "openapi-code-mode")]
122            javascript_validator: JavaScriptValidator::default(),
123            token_generator: HmacTokenGenerator::new(token_secret),
124            explanation_generator: TemplateExplanationGenerator::new(),
125            policy_evaluator: Some(evaluator),
126            config,
127        }
128    }
129}
130
131impl<T: TokenGenerator, E: ExplanationGenerator> ValidationPipeline<T, E> {
132    /// Create a pipeline with custom generators.
133    pub fn with_generators(
134        config: CodeModeConfig,
135        token_generator: T,
136        explanation_generator: E,
137    ) -> Self {
138        Self {
139            graphql_validator: GraphQLValidator::default(),
140            #[cfg(feature = "openapi-code-mode")]
141            javascript_validator: JavaScriptValidator::default(),
142            token_generator,
143            explanation_generator,
144            policy_evaluator: None,
145            config,
146        }
147    }
148
149    /// Set the policy evaluator for this pipeline.
150    pub fn set_policy_evaluator(&mut self, evaluator: Box<dyn PolicyEvaluator>) {
151        self.policy_evaluator = Some(evaluator);
152    }
153
154    /// Check if a policy evaluator is configured.
155    pub fn has_policy_evaluator(&self) -> bool {
156        self.policy_evaluator.is_some()
157    }
158
159    /// Validate a GraphQL query using basic config checks only.
160    pub fn validate_graphql_query(
161        &self,
162        query: &str,
163        context: &ValidationContext,
164    ) -> Result<ValidationResult, ValidationError> {
165        let start = Instant::now();
166
167        if !self.config.enabled {
168            return Err(ValidationError::ConfigError(
169                "Code Mode is not enabled for this server".into(),
170            ));
171        }
172
173        if query.len() > self.config.max_query_length {
174            return Err(ValidationError::SecurityError {
175                message: format!(
176                    "Query length {} exceeds maximum {}",
177                    query.len(),
178                    self.config.max_query_length
179                ),
180                issue: crate::types::SecurityIssueType::HighComplexity,
181            });
182        }
183
184        let query_info = self.graphql_validator.validate(query)?;
185
186        // Mutation authorization checks
187        if !query_info.operation_type.is_read_only() {
188            let mutation_name = query_info.root_fields.first().cloned().unwrap_or_default();
189
190            if !self.config.blocked_mutations.is_empty()
191                && self.config.blocked_mutations.contains(&mutation_name)
192            {
193                return Ok(ValidationResult::failure(
194                    vec![PolicyViolation::new(
195                        "code_mode",
196                        "blocked_mutation",
197                        &format!("Mutation '{}' is blocked for this server", mutation_name),
198                    )
199                    .with_suggestion("This mutation is in the blocklist and cannot be executed")],
200                    self.build_metadata(&query_info, start.elapsed().as_millis() as u64),
201                ));
202            }
203
204            if !self.config.allowed_mutations.is_empty() {
205                if !self.config.allowed_mutations.contains(&mutation_name) {
206                    return Ok(ValidationResult::failure(
207                        vec![PolicyViolation::new(
208                            "code_mode",
209                            "mutation_not_allowed",
210                            &format!("Mutation '{}' is not in the allowlist", mutation_name),
211                        )
212                        .with_suggestion(&format!(
213                            "Only these mutations are allowed: {}",
214                            self.config
215                                .allowed_mutations
216                                .iter()
217                                .cloned()
218                                .collect::<Vec<_>>()
219                                .join(", ")
220                        ))],
221                        self.build_metadata(&query_info, start.elapsed().as_millis() as u64),
222                    ));
223                }
224            } else if !self.config.allow_mutations {
225                return Ok(ValidationResult::failure(
226                    vec![PolicyViolation::new(
227                        "code_mode",
228                        "allow_mutations",
229                        "Mutations are not enabled for this server",
230                    )
231                    .with_suggestion("Only read-only queries are allowed")],
232                    self.build_metadata(&query_info, start.elapsed().as_millis() as u64),
233                ));
234            }
235        }
236
237        self.complete_validation(query, &query_info, context, start)
238    }
239
240    /// Validate a GraphQL query using a policy evaluator (async).
241    pub async fn validate_graphql_query_async(
242        &self,
243        query: &str,
244        context: &ValidationContext,
245    ) -> Result<ValidationResult, ValidationError> {
246        let start = Instant::now();
247
248        if !self.config.enabled {
249            return Err(ValidationError::ConfigError(
250                "Code Mode is not enabled for this server".into(),
251            ));
252        }
253
254        if query.len() > self.config.max_query_length {
255            return Err(ValidationError::SecurityError {
256                message: format!(
257                    "Query length {} exceeds maximum {}",
258                    query.len(),
259                    self.config.max_query_length
260                ),
261                issue: crate::types::SecurityIssueType::HighComplexity,
262            });
263        }
264
265        let query_info = self.graphql_validator.validate(query)?;
266
267        // Policy evaluation via trait
268        if let Some(ref evaluator) = self.policy_evaluator {
269            let operation_entity = OperationEntity::from_query_info(&query_info);
270            let server_config = self.config.to_server_config_entity();
271
272            let decision = evaluator
273                .evaluate_operation(&operation_entity, &server_config)
274                .await
275                .map_err(|e| ValidationError::InternalError(format!("Policy evaluation error: {}", e)))?;
276
277            if !decision.allowed {
278                let violations: Vec<PolicyViolation> = decision
279                    .determining_policies
280                    .iter()
281                    .map(|policy_id| {
282                        PolicyViolation::new(
283                            "policy",
284                            policy_id.clone(),
285                            "Policy denied the operation",
286                        )
287                    })
288                    .collect();
289
290                return Ok(ValidationResult::failure(
291                    violations,
292                    self.build_metadata(&query_info, start.elapsed().as_millis() as u64),
293                ));
294            }
295        } else {
296            warn_no_policy_configured();
297            tracing::debug!(
298                target: "code_mode",
299                "Falling back to basic config checks (no policy evaluator configured)"
300            );
301            return self.validate_graphql_query(query, context);
302        }
303
304        self.complete_validation(query, &query_info, context, start)
305    }
306
307    /// Complete validation after policy check passes.
308    fn complete_validation(
309        &self,
310        query: &str,
311        query_info: &GraphQLQueryInfo,
312        context: &ValidationContext,
313        start: Instant,
314    ) -> Result<ValidationResult, ValidationError> {
315        let security_analysis = self.graphql_validator.analyze_security(query_info);
316        let risk_level = security_analysis.assess_risk();
317
318        if security_analysis
319            .potential_issues
320            .iter()
321            .any(|i| i.is_critical())
322        {
323            let violations: Vec<PolicyViolation> = security_analysis
324                .potential_issues
325                .iter()
326                .filter(|i| i.is_critical())
327                .map(|i| {
328                    PolicyViolation::new("security", format!("{:?}", i.issue_type), &i.message)
329                })
330                .collect();
331
332            return Ok(ValidationResult::failure(
333                violations,
334                self.build_metadata(query_info, start.elapsed().as_millis() as u64),
335            ));
336        }
337
338        let explanation = self
339            .explanation_generator
340            .explain_graphql(query_info, &security_analysis);
341
342        let context_hash = context.context_hash();
343        let token = self.token_generator.generate(
344            query,
345            &context.user_id,
346            &context.session_id,
347            self.config.server_id(),
348            &context_hash,
349            risk_level,
350            self.config.token_ttl_seconds,
351        );
352
353        let token_string = token.encode().map_err(|e| {
354            ValidationError::InternalError(format!("Failed to encode token: {}", e))
355        })?;
356
357        let operation_type_str = format!("{:?}", query_info.operation_type).to_lowercase();
358        let mutation_name = query_info.operation_name.as_deref();
359        let inferred_action = UnifiedAction::from_graphql(&operation_type_str, mutation_name);
360        let action = UnifiedAction::resolve(
361            inferred_action,
362            &self.config.action_tags,
363            query_info.operation_name.as_deref().unwrap_or(""),
364        );
365
366        let metadata = ValidationMetadata {
367            is_read_only: query_info.operation_type.is_read_only(),
368            estimated_rows: security_analysis.estimated_rows,
369            accessed_types: security_analysis.tables_accessed.iter().cloned().collect(),
370            accessed_fields: security_analysis.fields_accessed.iter().cloned().collect(),
371            has_aggregation: security_analysis.has_aggregation,
372            code_type: Some(self.graphql_validator.to_code_type(query_info)),
373            action: Some(action),
374            validation_time_ms: start.elapsed().as_millis() as u64,
375        };
376
377        let mut result = ValidationResult::success(explanation, risk_level, token_string, metadata);
378
379        for issue in &security_analysis.potential_issues {
380            if !issue.is_critical() {
381                result.warnings.push(issue.message.clone());
382            }
383        }
384
385        Ok(result)
386    }
387
388    /// Build metadata from query info.
389    fn build_metadata(
390        &self,
391        query_info: &GraphQLQueryInfo,
392        validation_time_ms: u64,
393    ) -> ValidationMetadata {
394        let operation_type_str = format!("{:?}", query_info.operation_type).to_lowercase();
395        let mutation_name = query_info.operation_name.as_deref();
396        let inferred_action = UnifiedAction::from_graphql(&operation_type_str, mutation_name);
397        let action = UnifiedAction::resolve(
398            inferred_action,
399            &self.config.action_tags,
400            query_info.operation_name.as_deref().unwrap_or(""),
401        );
402
403        ValidationMetadata {
404            is_read_only: query_info.operation_type.is_read_only(),
405            estimated_rows: None,
406            accessed_types: query_info.types_accessed.iter().cloned().collect(),
407            accessed_fields: query_info.fields_accessed.iter().cloned().collect(),
408            has_aggregation: false,
409            code_type: Some(self.graphql_validator.to_code_type(query_info)),
410            action: Some(action),
411            validation_time_ms,
412        }
413    }
414
415    /// Validate JavaScript code for OpenAPI Code Mode.
416    #[cfg(feature = "openapi-code-mode")]
417    pub fn validate_javascript_code(
418        &self,
419        code: &str,
420        context: &ValidationContext,
421    ) -> Result<ValidationResult, ValidationError> {
422        let start = Instant::now();
423
424        if !self.config.enabled {
425            return Err(ValidationError::ConfigError(
426                "Code Mode is not enabled for this server".into(),
427            ));
428        }
429
430        if code.len() > self.config.max_query_length {
431            return Err(ValidationError::SecurityError {
432                message: format!(
433                    "Code length {} exceeds maximum {}",
434                    code.len(),
435                    self.config.max_query_length
436                ),
437                issue: crate::types::SecurityIssueType::HighComplexity,
438            });
439        }
440
441        let code_info = self.javascript_validator.validate(code)?;
442
443        if !code_info.is_read_only {
444            for method in &code_info.methods_used {
445                if !self.config.openapi_blocked_writes.is_empty()
446                    && self.config.openapi_blocked_writes.contains(method)
447                {
448                    return Ok(ValidationResult::failure(
449                        vec![PolicyViolation::new(
450                            "code_mode",
451                            "blocked_method",
452                            &format!("HTTP method '{}' is blocked for this server", method),
453                        )
454                        .with_suggestion("This method is in the blocklist and cannot be used")],
455                        self.build_js_metadata(&code_info, start.elapsed().as_millis() as u64),
456                    ));
457                }
458            }
459
460            if !self.config.openapi_allowed_writes.is_empty() {
461                tracing::debug!(
462                    target: "code_mode",
463                    "Skipping method-level check - policy evaluator will check operation allowlist ({} entries)",
464                    self.config.openapi_allowed_writes.len()
465                );
466            } else if !self.config.openapi_allow_writes {
467                return Ok(ValidationResult::failure(
468                    vec![PolicyViolation::new(
469                        "code_mode",
470                        "allow_mutations",
471                        "Write HTTP methods (POST, PUT, DELETE, PATCH) are not enabled for this server",
472                    )
473                    .with_suggestion("Only read-only methods (GET, HEAD, OPTIONS) are allowed. Contact your administrator to enable write operations.")],
474                    self.build_js_metadata(&code_info, start.elapsed().as_millis() as u64),
475                ));
476            }
477        }
478
479        self.complete_js_validation(code, &code_info, context, start)
480    }
481
482    /// Complete JavaScript validation after policy checks pass.
483    #[cfg(feature = "openapi-code-mode")]
484    fn complete_js_validation(
485        &self,
486        code: &str,
487        code_info: &JavaScriptCodeInfo,
488        context: &ValidationContext,
489        start: Instant,
490    ) -> Result<ValidationResult, ValidationError> {
491        let security_analysis = self.javascript_validator.analyze_security(code_info);
492        let risk_level = security_analysis.assess_risk();
493
494        if security_analysis
495            .potential_issues
496            .iter()
497            .any(|i| i.is_critical())
498        {
499            let violations: Vec<PolicyViolation> = security_analysis
500                .potential_issues
501                .iter()
502                .filter(|i| i.is_critical())
503                .map(|i| {
504                    PolicyViolation::new("security", format!("{:?}", i.issue_type), &i.message)
505                })
506                .collect();
507
508            return Ok(ValidationResult::failure(
509                violations,
510                self.build_js_metadata(code_info, start.elapsed().as_millis() as u64),
511            ));
512        }
513
514        let explanation = self.generate_js_explanation(code_info, &security_analysis);
515
516        let context_hash = context.context_hash();
517        let token = self.token_generator.generate(
518            code,
519            &context.user_id,
520            &context.session_id,
521            self.config.server_id(),
522            &context_hash,
523            risk_level,
524            self.config.token_ttl_seconds,
525        );
526
527        let token_string = token.encode().map_err(|e| {
528            ValidationError::InternalError(format!("Failed to encode token: {}", e))
529        })?;
530
531        let metadata = self.build_js_metadata(code_info, start.elapsed().as_millis() as u64);
532
533        let mut result = ValidationResult::success(explanation, risk_level, token_string, metadata);
534
535        for issue in &security_analysis.potential_issues {
536            if !issue.is_critical() {
537                result.warnings.push(issue.message.clone());
538            }
539        }
540
541        Ok(result)
542    }
543
544    /// Build metadata from JavaScript code info.
545    #[cfg(feature = "openapi-code-mode")]
546    fn build_js_metadata(
547        &self,
548        code_info: &JavaScriptCodeInfo,
549        validation_time_ms: u64,
550    ) -> ValidationMetadata {
551        let action = if !code_info.api_calls.is_empty() {
552            let mut max_action = UnifiedAction::Read;
553            for call in &code_info.api_calls {
554                let method_str = format!("{:?}", call.method);
555                let inferred = UnifiedAction::from_http_method(&method_str);
556                match (&max_action, &inferred) {
557                    (UnifiedAction::Read, _) => max_action = inferred,
558                    (UnifiedAction::Write, UnifiedAction::Delete | UnifiedAction::Admin) => {
559                        max_action = inferred
560                    }
561                    (UnifiedAction::Delete, UnifiedAction::Admin) => max_action = inferred,
562                    _ => {}
563                }
564            }
565            Some(max_action)
566        } else if code_info.is_read_only {
567            Some(UnifiedAction::Read)
568        } else {
569            Some(UnifiedAction::Write)
570        };
571
572        ValidationMetadata {
573            is_read_only: code_info.is_read_only,
574            estimated_rows: None,
575            accessed_types: code_info.endpoints_accessed.iter().cloned().collect(),
576            accessed_fields: code_info.methods_used.iter().cloned().collect(),
577            has_aggregation: false,
578            code_type: Some(self.javascript_validator.to_code_type(code_info)),
579            action,
580            validation_time_ms,
581        }
582    }
583
584    /// Generate a human-readable explanation for JavaScript code.
585    #[cfg(feature = "openapi-code-mode")]
586    fn generate_js_explanation(
587        &self,
588        code_info: &JavaScriptCodeInfo,
589        security_analysis: &crate::types::SecurityAnalysis,
590    ) -> String {
591        let mut parts = Vec::new();
592
593        if code_info.is_read_only {
594            parts.push("This code will perform read-only API requests.".to_string());
595        } else {
596            parts.push("This code will perform API requests that may modify data.".to_string());
597        }
598
599        if !code_info.api_calls.is_empty() {
600            let call_descriptions: Vec<String> = code_info
601                .api_calls
602                .iter()
603                .map(|call| format!("{:?} {}", call.method, call.path))
604                .collect();
605
606            if call_descriptions.len() <= 3 {
607                parts.push(format!("API calls: {}", call_descriptions.join(", ")));
608            } else {
609                parts.push(format!(
610                    "API calls: {} and {} more",
611                    call_descriptions[..2].join(", "),
612                    call_descriptions.len() - 2
613                ));
614            }
615        }
616
617        if code_info.loop_count > 0 {
618            if code_info.all_loops_bounded {
619                parts.push(format!(
620                    "Contains {} bounded loop(s).",
621                    code_info.loop_count
622                ));
623            } else {
624                parts.push(format!(
625                    "Contains {} loop(s) - ensure they are properly bounded.",
626                    code_info.loop_count
627                ));
628            }
629        }
630
631        let risk = security_analysis.assess_risk();
632        parts.push(format!("Risk: {}", risk));
633
634        parts.join(" ")
635    }
636
637    /// Check if a validation result should be auto-approved.
638    pub fn should_auto_approve(&self, result: &ValidationResult) -> bool {
639        result.is_valid && self.config.should_auto_approve(result.risk_level)
640    }
641
642    /// Get the config.
643    pub fn config(&self) -> &CodeModeConfig {
644        &self.config
645    }
646
647    /// Get the token generator.
648    pub fn token_generator(&self) -> &T {
649        &self.token_generator
650    }
651}
652
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::RiskLevel;

    /// Pipeline with Code Mode enabled and a fixed test secret.
    fn test_pipeline() -> ValidationPipeline {
        let secret = b"test-secret".to_vec();
        ValidationPipeline::new(CodeModeConfig::enabled(), secret)
    }

    /// Context with fixed user, session and hash values.
    fn test_context() -> ValidationContext {
        ValidationContext::new("user-123", "session-456", "schema-hash", "perms-hash")
    }

    /// A plain read query validates, gets a token and is rated low risk.
    #[test]
    fn test_simple_query_validation() {
        let context = test_context();
        let outcome = test_pipeline()
            .validate_graphql_query("query { users { id name } }", &context)
            .unwrap();

        assert!(outcome.is_valid);
        assert!(outcome.approval_token.is_some());
        assert_eq!(outcome.risk_level, RiskLevel::Low);
        assert!(outcome.explanation.contains("read"));
    }

    /// With mutations disabled, a mutation is rejected under `allow_mutations`.
    #[test]
    fn test_mutation_blocked() {
        let mut config = CodeModeConfig::enabled();
        config.allow_mutations = false;
        let pipeline = ValidationPipeline::new(config, b"test-secret".to_vec());

        let context = test_context();
        let outcome = pipeline
            .validate_graphql_query("mutation { createUser(name: \"test\") { id } }", &context)
            .unwrap();

        assert!(!outcome.is_valid);
        let has_rule = outcome
            .violations
            .iter()
            .any(|violation| violation.rule == "allow_mutations");
        assert!(has_rule);
    }

    /// The default config leaves Code Mode off, which is a config error.
    #[test]
    fn test_disabled_code_mode() {
        let pipeline = ValidationPipeline::new(CodeModeConfig::default(), b"test-secret".to_vec());
        let context = test_context();

        let outcome = pipeline.validate_graphql_query("query { users { id } }", &context);

        assert!(matches!(outcome, Err(ValidationError::ConfigError(_))));
    }

    /// A valid low-risk query qualifies for auto-approval.
    #[test]
    fn test_auto_approve_low_risk() {
        let pipeline = test_pipeline();
        let context = test_context();

        let outcome = pipeline
            .validate_graphql_query("query { users { id } }", &context)
            .unwrap();

        assert!(pipeline.should_auto_approve(&outcome));
    }

    /// Changing only the schema hash changes the context hash.
    #[test]
    fn test_context_hash() {
        let original = test_context().context_hash();
        let with_other_schema =
            ValidationContext::new("user-123", "session-456", "different-schema", "perms-hash")
                .context_hash();

        assert_ne!(original, with_other_schema);
    }
}