1use crate::config::CodeModeConfig;
11#[cfg(feature = "openapi-code-mode")]
12use crate::config::OperationRegistry;
13use crate::explanation::{ExplanationGenerator, TemplateExplanationGenerator};
14use crate::graphql::{GraphQLQueryInfo, GraphQLValidator};
15use crate::policy::{OperationEntity, PolicyEvaluator};
16use crate::token::{compute_context_hash, HmacTokenGenerator, TokenGenerator, TokenSecret};
17use crate::types::{
18 PolicyViolation, TokenError, UnifiedAction, ValidationError, ValidationMetadata,
19 ValidationResult,
20};
21use std::sync::atomic::{AtomicBool, Ordering};
22use std::sync::Arc;
23use std::time::Instant;
24
25#[cfg(feature = "openapi-code-mode")]
26use crate::javascript::{JavaScriptCodeInfo, JavaScriptValidator};
27
/// One-shot latch for the "no policy evaluator configured" warning.
///
/// `warn_no_policy_configured` flips this flag with an atomic swap and only
/// logs when the previous value was `false`, so the warning is emitted at most
/// once for the lifetime of the process.
static NO_POLICY_WARNING_LOGGED: AtomicBool = AtomicBool::new(false);
30
31fn build_policy_violations(
46 decision: &crate::policy::AuthorizationDecision,
47 server_id: &str,
48 action: impl std::fmt::Display,
49 denied_subject: &str,
50) -> Vec<PolicyViolation> {
51 let capacity = decision.determining_policies.len() + decision.errors.len() + 1;
52 let mut violations: Vec<PolicyViolation> = Vec::with_capacity(capacity);
53
54 for policy_id in &decision.determining_policies {
55 violations.push(PolicyViolation::new(
56 "policy",
57 policy_id.clone(),
58 format!("Policy denied the {}", denied_subject),
59 ));
60 }
61
62 for err in &decision.errors {
63 violations.push(PolicyViolation::new(
64 "policy_error",
65 "evaluation_error",
66 err.clone(),
67 ));
68 }
69
70 if violations.is_empty() {
71 violations.push(PolicyViolation::new(
72 "policy",
73 "default_deny",
74 format!(
75 "Authorization default-deny: no Permit policy matched for \
76 server_id={server_id} action={action}. Check that Cedar \
77 policies exist for this server and that server_id is set correctly."
78 ),
79 ));
80 }
81
82 violations
83}
84
85fn warn_no_policy_configured() {
87 if !NO_POLICY_WARNING_LOGGED.swap(true, Ordering::SeqCst) {
88 tracing::warn!(
89 target: "code_mode",
90 "CODE MODE SECURITY WARNING: Code Mode is enabled but no policy evaluator \
91 is configured. Only basic config checks (allow_mutations, max_depth, etc.) will be \
92 performed. This provides NO real authorization policy evaluation. \
93 For production deployments, configure a policy evaluator (AVP or local Cedar)."
94 );
95 }
96}
97
98#[derive(Debug, Clone)]
100pub struct ValidationContext {
101 pub user_id: String,
103
104 pub session_id: String,
106
107 pub schema_hash: String,
109
110 pub permissions_hash: String,
112}
113
114impl ValidationContext {
115 pub fn new(
117 user_id: impl Into<String>,
118 session_id: impl Into<String>,
119 schema_hash: impl Into<String>,
120 permissions_hash: impl Into<String>,
121 ) -> Self {
122 Self {
123 user_id: user_id.into(),
124 session_id: session_id.into(),
125 schema_hash: schema_hash.into(),
126 permissions_hash: permissions_hash.into(),
127 }
128 }
129
130 pub fn context_hash(&self) -> String {
132 compute_context_hash(&self.schema_hash, &self.permissions_hash)
133 }
134}
135
/// Validation pipeline for Code Mode requests.
///
/// Generic over the token generator `T` and the explanation generator `E`;
/// both default to the HMAC / template implementations.
pub struct ValidationPipeline<
    T: TokenGenerator = HmacTokenGenerator,
    E: ExplanationGenerator = TemplateExplanationGenerator,
> {
    /// Code Mode configuration consulted by every check in the pipeline.
    config: CodeModeConfig,
    /// Parses GraphQL queries and produces their security analysis.
    graphql_validator: GraphQLValidator,
    /// Parses JavaScript code and produces its security analysis.
    #[cfg(feature = "openapi-code-mode")]
    javascript_validator: JavaScriptValidator,
    /// Registry built from `config.operations`; handed to script entity
    /// construction when it is non-empty.
    #[cfg(feature = "openapi-code-mode")]
    operation_registry: OperationRegistry,
    /// Issues the approval token for requests that pass validation.
    token_generator: T,
    /// Produces the human-readable explanation returned on success.
    explanation_generator: E,
    /// Optional policy evaluator used by the async validation entry points.
    policy_evaluator: Option<Arc<dyn PolicyEvaluator>>,
}
151
152impl ValidationPipeline<HmacTokenGenerator, TemplateExplanationGenerator> {
153 pub fn new(
163 mut config: CodeModeConfig,
164 token_secret: impl Into<Vec<u8>>,
165 ) -> Result<Self, TokenError> {
166 if config.enabled {
167 warn_no_policy_configured();
168 }
169
170 config.resolve_server_id();
171
172 #[cfg(feature = "openapi-code-mode")]
173 let operation_registry = OperationRegistry::from_entries(&config.operations);
174
175 Ok(Self {
176 graphql_validator: GraphQLValidator::default(),
177 #[cfg(feature = "openapi-code-mode")]
178 javascript_validator: JavaScriptValidator::default()
179 .with_sdk_operations(config.sdk_operations.clone()),
180 #[cfg(feature = "openapi-code-mode")]
181 operation_registry,
182 token_generator: HmacTokenGenerator::new_from_bytes(token_secret)?,
183 explanation_generator: TemplateExplanationGenerator::new(),
184 policy_evaluator: None,
185 config,
186 })
187 }
188
189 pub fn from_token_secret(
208 config: CodeModeConfig,
209 secret: &TokenSecret,
210 ) -> Result<Self, TokenError> {
211 Self::new(config, secret.expose_secret().to_vec())
212 }
213
214 pub fn with_policy_evaluator(
221 mut config: CodeModeConfig,
222 token_secret: impl Into<Vec<u8>>,
223 evaluator: Arc<dyn PolicyEvaluator>,
224 ) -> Result<Self, TokenError> {
225 config.resolve_server_id();
226 if config.server_id.is_none() {
227 tracing::warn!(
228 target: "code_mode",
229 "CodeModeConfig.server_id is not set — AVP/Cedar authorization will use 'unknown' \
230 as the resource entity ID and will likely default-deny silently. \
231 Set server_id in config.toml, or the PMCP_SERVER_ID or AWS_LAMBDA_FUNCTION_NAME env var."
232 );
233 }
234
235 #[cfg(feature = "openapi-code-mode")]
236 let operation_registry = OperationRegistry::from_entries(&config.operations);
237
238 Ok(Self {
239 graphql_validator: GraphQLValidator::default(),
240 #[cfg(feature = "openapi-code-mode")]
241 javascript_validator: JavaScriptValidator::default()
242 .with_sdk_operations(config.sdk_operations.clone()),
243 #[cfg(feature = "openapi-code-mode")]
244 operation_registry,
245 token_generator: HmacTokenGenerator::new_from_bytes(token_secret)?,
246 explanation_generator: TemplateExplanationGenerator::new(),
247 policy_evaluator: Some(evaluator),
248 config,
249 })
250 }
251
252 pub fn from_token_secret_with_policy(
262 config: CodeModeConfig,
263 secret: &TokenSecret,
264 evaluator: Arc<dyn PolicyEvaluator>,
265 ) -> Result<Self, TokenError> {
266 Self::with_policy_evaluator(config, secret.expose_secret().to_vec(), evaluator)
267 }
268}
269
270impl<T: TokenGenerator, E: ExplanationGenerator> ValidationPipeline<T, E> {
271 pub fn with_generators(
273 mut config: CodeModeConfig,
274 token_generator: T,
275 explanation_generator: E,
276 ) -> Self {
277 config.resolve_server_id();
278
279 #[cfg(feature = "openapi-code-mode")]
280 let operation_registry = OperationRegistry::from_entries(&config.operations);
281
282 Self {
283 graphql_validator: GraphQLValidator::default(),
284 #[cfg(feature = "openapi-code-mode")]
285 javascript_validator: JavaScriptValidator::default()
286 .with_sdk_operations(config.sdk_operations.clone()),
287 #[cfg(feature = "openapi-code-mode")]
288 operation_registry,
289 token_generator,
290 explanation_generator,
291 policy_evaluator: None,
292 config,
293 }
294 }
295
296 pub fn set_policy_evaluator(&mut self, evaluator: Arc<dyn PolicyEvaluator>) {
298 self.policy_evaluator = Some(evaluator);
299 }
300
301 pub fn has_policy_evaluator(&self) -> bool {
303 self.policy_evaluator.is_some()
304 }
305
306 fn check_config_authorization(
311 &self,
312 query_info: &GraphQLQueryInfo,
313 start: Instant,
314 ) -> Option<ValidationResult> {
315 if !query_info.operation_type.is_read_only() {
317 let mutation_name = query_info.root_fields.first().cloned().unwrap_or_default();
318
319 if !self.config.blocked_mutations.is_empty()
320 && self.config.blocked_mutations.contains(&mutation_name)
321 {
322 return Some(ValidationResult::failure(
323 vec![PolicyViolation::new(
324 "code_mode",
325 "blocked_mutation",
326 &format!("Mutation '{}' is blocked for this server", mutation_name),
327 )
328 .with_suggestion("This mutation is in the blocklist and cannot be executed")],
329 self.build_metadata(query_info, start.elapsed().as_millis() as u64),
330 ));
331 }
332
333 if !self.config.allowed_mutations.is_empty() {
334 if !self.config.allowed_mutations.contains(&mutation_name) {
335 return Some(ValidationResult::failure(
336 vec![PolicyViolation::new(
337 "code_mode",
338 "mutation_not_allowed",
339 &format!("Mutation '{}' is not in the allowlist", mutation_name),
340 )
341 .with_suggestion(&format!(
342 "Only these mutations are allowed: {}",
343 self.config
344 .allowed_mutations
345 .iter()
346 .cloned()
347 .collect::<Vec<_>>()
348 .join(", ")
349 ))],
350 self.build_metadata(query_info, start.elapsed().as_millis() as u64),
351 ));
352 }
353 } else if !self.config.allow_mutations {
354 return Some(ValidationResult::failure(
355 vec![PolicyViolation::new(
356 "code_mode",
357 "allow_mutations",
358 "Mutations are not enabled for this server",
359 )
360 .with_suggestion("Only read-only queries are allowed")],
361 self.build_metadata(query_info, start.elapsed().as_millis() as u64),
362 ));
363 }
364 }
365
366 if query_info.operation_type.is_read_only() {
368 let query_name = query_info.root_fields.first().cloned().unwrap_or_default();
369
370 if !self.config.blocked_queries.is_empty()
371 && self.config.blocked_queries.contains(&query_name)
372 {
373 return Some(ValidationResult::failure(
374 vec![PolicyViolation::new(
375 "code_mode",
376 "blocked_query",
377 &format!("Query '{}' is blocked for this server", query_name),
378 )
379 .with_suggestion("This query is in the blocklist and cannot be executed")],
380 self.build_metadata(query_info, start.elapsed().as_millis() as u64),
381 ));
382 }
383
384 if !self.config.allowed_queries.is_empty()
385 && !self.config.allowed_queries.contains(&query_name)
386 {
387 return Some(ValidationResult::failure(
388 vec![PolicyViolation::new(
389 "code_mode",
390 "query_not_allowed",
391 &format!("Query '{}' is not in the allowlist", query_name),
392 )
393 .with_suggestion(&format!(
394 "Only these queries are allowed: {}",
395 self.config
396 .allowed_queries
397 .iter()
398 .cloned()
399 .collect::<Vec<_>>()
400 .join(", ")
401 ))],
402 self.build_metadata(query_info, start.elapsed().as_millis() as u64),
403 ));
404 }
405 }
406
407 None
408 }
409
410 pub fn validate_graphql_query(
412 &self,
413 query: &str,
414 context: &ValidationContext,
415 ) -> Result<ValidationResult, ValidationError> {
416 let start = Instant::now();
417
418 if !self.config.enabled {
419 return Err(ValidationError::ConfigError(
420 "Code Mode is not enabled for this server".into(),
421 ));
422 }
423
424 if query.len() > self.config.max_query_length {
425 return Err(ValidationError::SecurityError {
426 message: format!(
427 "Query length {} exceeds maximum {}",
428 query.len(),
429 self.config.max_query_length
430 ),
431 issue: crate::types::SecurityIssueType::HighComplexity,
432 });
433 }
434
435 let query_info = self.graphql_validator.validate(query)?;
436
437 if let Some(failure) = self.check_config_authorization(&query_info, start) {
439 return Ok(failure);
440 }
441
442 self.complete_validation(query, &query_info, context, start)
443 }
444
445 pub async fn validate_graphql_query_async(
447 &self,
448 query: &str,
449 context: &ValidationContext,
450 ) -> Result<ValidationResult, ValidationError> {
451 let start = Instant::now();
452
453 if !self.config.enabled {
454 return Err(ValidationError::ConfigError(
455 "Code Mode is not enabled for this server".into(),
456 ));
457 }
458
459 if query.len() > self.config.max_query_length {
460 return Err(ValidationError::SecurityError {
461 message: format!(
462 "Query length {} exceeds maximum {}",
463 query.len(),
464 self.config.max_query_length
465 ),
466 issue: crate::types::SecurityIssueType::HighComplexity,
467 });
468 }
469
470 let query_info = self.graphql_validator.validate(query)?;
471
472 if let Some(ref evaluator) = self.policy_evaluator {
474 let operation_entity = OperationEntity::from_query_info(&query_info);
475 let server_config = self.config.to_server_config_entity();
476
477 let decision = evaluator
478 .evaluate_operation(&operation_entity, &server_config)
479 .await
480 .map_err(|e| {
481 ValidationError::InternalError(format!("Policy evaluation error: {}", e))
482 })?;
483
484 if !decision.allowed {
485 let op_type_str = format!("{:?}", query_info.operation_type);
486 let action =
487 UnifiedAction::from_graphql(&op_type_str, query_info.operation_name.as_deref());
488 let violations = build_policy_violations(
489 &decision,
490 self.config.server_id(),
491 action,
492 "operation",
493 );
494
495 return Ok(ValidationResult::failure(
496 violations,
497 self.build_metadata(&query_info, start.elapsed().as_millis() as u64),
498 ));
499 }
500 } else {
501 warn_no_policy_configured();
502 tracing::debug!(
503 target: "code_mode",
504 "Falling back to basic config checks (no policy evaluator configured)"
505 );
506 if let Some(failure) = self.check_config_authorization(&query_info, start) {
508 return Ok(failure);
509 }
510 }
511
512 self.complete_validation(query, &query_info, context, start)
513 }
514
515 fn complete_validation(
517 &self,
518 query: &str,
519 query_info: &GraphQLQueryInfo,
520 context: &ValidationContext,
521 start: Instant,
522 ) -> Result<ValidationResult, ValidationError> {
523 let security_analysis = self.graphql_validator.analyze_security(query_info);
524 let risk_level = security_analysis.assess_risk();
525
526 if security_analysis
527 .potential_issues
528 .iter()
529 .any(|i| i.is_critical())
530 {
531 let violations: Vec<PolicyViolation> = security_analysis
532 .potential_issues
533 .iter()
534 .filter(|i| i.is_critical())
535 .map(|i| {
536 PolicyViolation::new("security", format!("{:?}", i.issue_type), &i.message)
537 })
538 .collect();
539
540 return Ok(ValidationResult::failure(
541 violations,
542 self.build_metadata(query_info, start.elapsed().as_millis() as u64),
543 ));
544 }
545
546 let explanation = self
547 .explanation_generator
548 .explain_graphql(query_info, &security_analysis);
549
550 let context_hash = context.context_hash();
551 let token = self.token_generator.generate(
552 query,
553 &context.user_id,
554 &context.session_id,
555 self.config.server_id(),
556 &context_hash,
557 risk_level,
558 self.config.token_ttl_seconds,
559 );
560
561 let token_string = token.encode().map_err(|e| {
562 ValidationError::InternalError(format!("Failed to encode token: {}", e))
563 })?;
564
565 let operation_type_str = format!("{:?}", query_info.operation_type).to_lowercase();
566 let mutation_name = query_info.operation_name.as_deref();
567 let inferred_action = UnifiedAction::from_graphql(&operation_type_str, mutation_name);
568 let action = UnifiedAction::resolve(
569 inferred_action,
570 &self.config.action_tags,
571 query_info.operation_name.as_deref().unwrap_or(""),
572 );
573
574 let metadata = ValidationMetadata {
575 is_read_only: query_info.operation_type.is_read_only(),
576 estimated_rows: security_analysis.estimated_rows,
577 accessed_types: security_analysis.tables_accessed.iter().cloned().collect(),
578 accessed_fields: security_analysis.fields_accessed.iter().cloned().collect(),
579 has_aggregation: security_analysis.has_aggregation,
580 code_type: Some(self.graphql_validator.to_code_type(query_info)),
581 action: Some(action),
582 validation_time_ms: start.elapsed().as_millis() as u64,
583 };
584
585 let mut result = ValidationResult::success(explanation, risk_level, token_string, metadata);
586
587 for issue in &security_analysis.potential_issues {
588 if !issue.is_critical() {
589 result.warnings.push(issue.message.clone());
590 }
591 }
592
593 Ok(result)
594 }
595
596 fn build_metadata(
598 &self,
599 query_info: &GraphQLQueryInfo,
600 validation_time_ms: u64,
601 ) -> ValidationMetadata {
602 let operation_type_str = format!("{:?}", query_info.operation_type).to_lowercase();
603 let mutation_name = query_info.operation_name.as_deref();
604 let inferred_action = UnifiedAction::from_graphql(&operation_type_str, mutation_name);
605 let action = UnifiedAction::resolve(
606 inferred_action,
607 &self.config.action_tags,
608 query_info.operation_name.as_deref().unwrap_or(""),
609 );
610
611 ValidationMetadata {
612 is_read_only: query_info.operation_type.is_read_only(),
613 estimated_rows: None,
614 accessed_types: query_info.types_accessed.iter().cloned().collect(),
615 accessed_fields: query_info.fields_accessed.iter().cloned().collect(),
616 has_aggregation: false,
617 code_type: Some(self.graphql_validator.to_code_type(query_info)),
618 action: Some(action),
619 validation_time_ms,
620 }
621 }
622
/// Validates JavaScript code synchronously using config-level checks only.
///
/// Runs the preamble checks and parsing, applies the config write rules,
/// then completes validation (security analysis and token issuance).
///
/// # Errors
///
/// Propagates errors from the preamble and from the completion step. Rule
/// violations are returned as `Ok` with a failure result.
#[cfg(feature = "openapi-code-mode")]
pub fn validate_javascript_code(
    &self,
    code: &str,
    context: &ValidationContext,
) -> Result<ValidationResult, ValidationError> {
    let start = Instant::now();
    let parsed = self.validate_js_preamble(code)?;

    match self.check_js_config_authorization(&parsed, start) {
        Some(rejection) => Ok(rejection),
        None => self.complete_js_validation(code, &parsed, context, start),
    }
}
641
/// Validates JavaScript code, additionally authorizing through the policy
/// evaluator when one is configured.
///
/// The preamble and config write rules run first. With an evaluator, a script
/// entity is built from the parsed code (using the blocked paths as sensitive
/// patterns and the operation registry when it is non-empty) and evaluated
/// against the OpenAPI server entity; a denial is returned as a failure
/// result. Without an evaluator the one-time "no policy evaluator" warning is
/// logged, as the GraphQL and SQL async validators already do.
///
/// # Errors
///
/// Returns `ValidationError::InternalError` when policy evaluation itself
/// fails, and propagates errors from the preamble and the completion step.
#[cfg(feature = "openapi-code-mode")]
pub async fn validate_javascript_code_async(
    &self,
    code: &str,
    context: &ValidationContext,
) -> Result<ValidationResult, ValidationError> {
    use crate::policy::types::ScriptEntity;

    let start = Instant::now();
    let code_info = self.validate_js_preamble(code)?;
    if let Some(failure) = self.check_js_config_authorization(&code_info, start) {
        return Ok(failure);
    }

    if let Some(ref evaluator) = self.policy_evaluator {
        let sensitive_patterns: Vec<String> =
            self.config.openapi_blocked_paths.iter().cloned().collect();
        // An empty registry is passed as "no registry".
        let registry_ref = if self.operation_registry.is_empty() {
            None
        } else {
            Some(&self.operation_registry)
        };
        let script_entity =
            ScriptEntity::from_javascript_info(&code_info, &sensitive_patterns, registry_ref);
        let server_entity = self.config.to_openapi_server_entity();

        let decision = evaluator
            .evaluate_script(&script_entity, &server_entity)
            .await
            .map_err(|e| {
                ValidationError::InternalError(format!("Policy evaluation error: {}", e))
            })?;

        if !decision.allowed {
            let violations = build_policy_violations(
                &decision,
                self.config.server_id(),
                script_entity.action(),
                "script",
            );

            return Ok(ValidationResult::failure(
                violations,
                self.build_js_metadata(&code_info, start.elapsed().as_millis() as u64),
            ));
        }
    } else {
        // Only config checks ran; surface that once, like the sibling validators.
        warn_no_policy_configured();
    }

    self.complete_js_validation(code, &code_info, context, start)
}
701
/// Runs the checks shared by both JavaScript entry points and parses the code.
///
/// # Errors
///
/// Returns `ValidationError::ConfigError` when Code Mode is disabled,
/// `ValidationError::SecurityError` when the code exceeds
/// `max_query_length`, and the converted error of the JavaScript validator
/// when parsing or validation fails.
#[cfg(feature = "openapi-code-mode")]
fn validate_js_preamble(&self, code: &str) -> Result<JavaScriptCodeInfo, ValidationError> {
    if !self.config.enabled {
        return Err(ValidationError::ConfigError(
            "Code Mode is not enabled for this server".into(),
        ));
    }

    let (actual_len, limit) = (code.len(), self.config.max_query_length);
    if actual_len > limit {
        return Err(ValidationError::SecurityError {
            message: format!("Code length {} exceeds maximum {}", actual_len, limit),
            issue: crate::types::SecurityIssueType::HighComplexity,
        });
    }

    let parsed = self.javascript_validator.validate(code)?;
    Ok(parsed)
}
724
/// Applies the config-level write rules to parsed JavaScript code.
///
/// Read-only code always passes. Otherwise the first used HTTP method found
/// in `openapi_blocked_writes` causes a rejection. If a write allowlist is
/// configured, the method-level check is skipped (only a debug line is
/// logged); if not, writes are rejected unless `openapi_allow_writes` is set.
///
/// Returns `Some(failure)` on rejection and `None` when the code may proceed.
#[cfg(feature = "openapi-code-mode")]
fn check_js_config_authorization(
    &self,
    code_info: &JavaScriptCodeInfo,
    start: Instant,
) -> Option<ValidationResult> {
    if code_info.is_read_only {
        return None;
    }

    // Elapsed time is sampled when the rejection is built.
    let reject = |violation: PolicyViolation| {
        Some(ValidationResult::failure(
            vec![violation],
            self.build_js_metadata(code_info, start.elapsed().as_millis() as u64),
        ))
    };

    let blocked_method = code_info
        .methods_used
        .iter()
        .find(|candidate| self.config.openapi_blocked_writes.contains(*candidate));
    if let Some(method) = blocked_method {
        return reject(
            PolicyViolation::new(
                "code_mode",
                "blocked_method",
                format!("HTTP method '{}' is blocked for this server", method),
            )
            .with_suggestion("This method is in the blocklist and cannot be used"),
        );
    }

    let allowlist_len = self.config.openapi_allowed_writes.len();
    if allowlist_len > 0 {
        tracing::debug!(
            target: "code_mode",
            "Skipping method-level check - policy evaluator will check operation allowlist ({} entries)",
            allowlist_len
        );
        return None;
    }

    if self.config.openapi_allow_writes {
        return None;
    }

    reject(
        PolicyViolation::new(
            "code_mode",
            "allow_mutations",
            "Write HTTP methods (POST, PUT, DELETE, PATCH) are not enabled for this server",
        )
        .with_suggestion("Only read-only methods (GET, HEAD, OPTIONS) are allowed. Contact your administrator to enable write operations."),
    )
}
776
/// Finishes JavaScript validation after authorization has passed.
///
/// Runs the security analysis; if any issue is critical, returns a failure
/// result listing the critical issues. Otherwise builds the explanation,
/// issues and encodes the approval token, and returns a success result whose
/// warnings are the non-critical issue messages.
///
/// # Errors
///
/// Returns `ValidationError::InternalError` when the token cannot be encoded.
#[cfg(feature = "openapi-code-mode")]
fn complete_js_validation(
    &self,
    code: &str,
    code_info: &JavaScriptCodeInfo,
    context: &ValidationContext,
    start: Instant,
) -> Result<ValidationResult, ValidationError> {
    let analysis = self.javascript_validator.analyze_security(code_info);
    let risk_level = analysis.assess_risk();

    let critical: Vec<PolicyViolation> = analysis
        .potential_issues
        .iter()
        .filter(|issue| issue.is_critical())
        .map(|issue| {
            PolicyViolation::new(
                "security",
                format!("{:?}", issue.issue_type),
                &issue.message,
            )
        })
        .collect();

    if !critical.is_empty() {
        return Ok(ValidationResult::failure(
            critical,
            self.build_js_metadata(code_info, start.elapsed().as_millis() as u64),
        ));
    }

    let explanation = self.generate_js_explanation(code_info, &analysis);

    let context_hash = context.context_hash();
    let token_string = self
        .token_generator
        .generate(
            code,
            &context.user_id,
            &context.session_id,
            self.config.server_id(),
            &context_hash,
            risk_level,
            self.config.token_ttl_seconds,
        )
        .encode()
        .map_err(|e| ValidationError::InternalError(format!("Failed to encode token: {}", e)))?;

    let metadata = self.build_js_metadata(code_info, start.elapsed().as_millis() as u64);

    let mut outcome = ValidationResult::success(explanation, risk_level, token_string, metadata);

    outcome.warnings.extend(
        analysis
            .potential_issues
            .iter()
            .filter(|issue| !issue.is_critical())
            .map(|issue| issue.message.clone()),
    );

    Ok(outcome)
}
838
/// Builds the metadata for a JavaScript validation result.
///
/// When API calls were detected, the reported action is the result of
/// folding over them: starting from `Read`, the running action is replaced
/// by a call's inferred action when the running one is `Read`, when it is
/// `Write` and the call is `Delete` or `Admin`, or when it is `Delete` and
/// the call is `Admin`. With no API calls, the action is `Read` for
/// read-only code and `Write` otherwise.
#[cfg(feature = "openapi-code-mode")]
fn build_js_metadata(
    &self,
    code_info: &JavaScriptCodeInfo,
    validation_time_ms: u64,
) -> ValidationMetadata {
    let action = if code_info.api_calls.is_empty() {
        if code_info.is_read_only {
            UnifiedAction::Read
        } else {
            UnifiedAction::Write
        }
    } else {
        code_info
            .api_calls
            .iter()
            .fold(UnifiedAction::Read, |current, call| {
                let candidate = UnifiedAction::from_http_method(&format!("{:?}", call.method));
                let escalates = matches!(
                    (&current, &candidate),
                    (UnifiedAction::Read, _)
                        | (UnifiedAction::Write, UnifiedAction::Delete | UnifiedAction::Admin)
                        | (UnifiedAction::Delete, UnifiedAction::Admin)
                );
                if escalates {
                    candidate
                } else {
                    current
                }
            })
    };

    ValidationMetadata {
        is_read_only: code_info.is_read_only,
        estimated_rows: None,
        accessed_types: code_info.endpoints_accessed.iter().cloned().collect(),
        accessed_fields: code_info.methods_used.iter().cloned().collect(),
        has_aggregation: false,
        code_type: Some(self.javascript_validator.to_code_type(code_info)),
        action: Some(action),
        validation_time_ms,
    }
}
878
/// Builds the human-readable explanation for validated JavaScript code.
///
/// The sentences, joined by single spaces, are: whether the code is
/// read-only, the API calls (all of them for up to three calls, otherwise the
/// first two plus a count of the rest), a note about loops when any exist,
/// and the assessed risk.
#[cfg(feature = "openapi-code-mode")]
fn generate_js_explanation(
    &self,
    code_info: &JavaScriptCodeInfo,
    security_analysis: &crate::types::SecurityAnalysis,
) -> String {
    let headline = if code_info.is_read_only {
        "This code will perform read-only API requests."
    } else {
        "This code will perform API requests that may modify data."
    };
    let mut sentences = vec![headline.to_string()];

    let rendered_calls: Vec<String> = code_info
        .api_calls
        .iter()
        .map(|call| format!("{:?} {}", call.method, call.path))
        .collect();
    match rendered_calls.len() {
        0 => {},
        1..=3 => sentences.push(format!("API calls: {}", rendered_calls.join(", "))),
        total => sentences.push(format!(
            "API calls: {} and {} more",
            rendered_calls[..2].join(", "),
            total - 2
        )),
    }

    if code_info.loop_count > 0 {
        let loop_note = if code_info.all_loops_bounded {
            format!("Contains {} bounded loop(s).", code_info.loop_count)
        } else {
            format!(
                "Contains {} loop(s) - ensure they are properly bounded.",
                code_info.loop_count
            )
        };
        sentences.push(loop_note);
    }

    sentences.push(format!("Risk: {}", security_analysis.assess_risk()));

    sentences.join(" ")
}
931
/// Validates a SQL statement synchronously using config-level checks only.
///
/// Runs the preamble checks and parsing, applies the config SQL rules, then
/// completes validation.
///
/// # Errors
///
/// Propagates errors from the preamble and from the completion step. Rule
/// violations are returned as `Ok` with a failure result.
#[cfg(feature = "sql-code-mode")]
pub fn validate_sql_query(
    &self,
    sql: &str,
    context: &ValidationContext,
) -> Result<ValidationResult, ValidationError> {
    let start = Instant::now();
    let statement = self.validate_sql_preamble(sql)?;

    match self.check_sql_config_authorization(&statement, start) {
        Some(rejection) => Ok(rejection),
        None => self.complete_sql_validation(sql, &statement, context, start),
    }
}
948
/// Validates a SQL statement, additionally authorizing through the policy
/// evaluator when one is configured.
///
/// The preamble and config SQL rules run first. With an evaluator, the
/// statement entity is evaluated against the SQL server entity and a denial
/// is returned as a failure result. Without one, the one-time "no policy
/// evaluator" warning is logged.
///
/// # Errors
///
/// Returns `ValidationError::InternalError` when policy evaluation itself
/// fails, and propagates errors from the preamble and the completion step.
#[cfg(feature = "sql-code-mode")]
pub async fn validate_sql_query_async(
    &self,
    sql: &str,
    context: &ValidationContext,
) -> Result<ValidationResult, ValidationError> {
    use crate::policy::StatementEntity;

    let start = Instant::now();
    let info = self.validate_sql_preamble(sql)?;
    if let Some(rejection) = self.check_sql_config_authorization(&info, start) {
        return Ok(rejection);
    }

    match self.policy_evaluator.as_ref() {
        None => warn_no_policy_configured(),
        Some(evaluator) => {
            let statement_entity = StatementEntity::from_sql_info(&info);
            let server_entity = self.config.to_sql_server_entity();

            let decision = evaluator
                .evaluate_statement(&statement_entity, &server_entity)
                .await
                .map_err(|e| {
                    ValidationError::InternalError(format!("Policy evaluation error: {}", e))
                })?;

            if !decision.allowed {
                let violations = build_policy_violations(
                    &decision,
                    self.config.server_id(),
                    statement_entity.action(),
                    "SQL statement",
                );

                return Ok(ValidationResult::failure(
                    violations,
                    self.build_sql_metadata(&info, start.elapsed().as_millis() as u64),
                ));
            }
        },
    }

    self.complete_sql_validation(sql, &info, context, start)
}
1001
/// Runs the checks shared by both SQL entry points and parses the statement.
///
/// # Errors
///
/// Returns `ValidationError::ConfigError` when Code Mode is disabled,
/// `ValidationError::SecurityError` when the SQL text exceeds
/// `max_query_length`, and the SQL validator's error when validation fails.
#[cfg(feature = "sql-code-mode")]
fn validate_sql_preamble(
    &self,
    sql: &str,
) -> Result<crate::sql::SqlStatementInfo, ValidationError> {
    if !self.config.enabled {
        return Err(ValidationError::ConfigError(
            "Code Mode is not enabled for this server".into(),
        ));
    }

    let (actual_len, limit) = (sql.len(), self.config.max_query_length);
    if actual_len > limit {
        return Err(ValidationError::SecurityError {
            message: format!("SQL length {} exceeds maximum {}", actual_len, limit),
            issue: crate::types::SecurityIssueType::HighComplexity,
        });
    }

    crate::sql::SqlValidator::new().validate(sql)
}
1028
/// Applies the config-level SQL rules to a parsed statement.
///
/// Rules are evaluated in this order and the first rejection wins:
/// statement-type blocklist, statement-type allowlist (when non-empty),
/// per-type switches (reads, writes, deletes, DDL; `Other` is always
/// rejected) including the WHERE requirement for UPDATE and DELETE, table
/// blocklist, table allowlist (when non-empty), column blocklist, JOIN limit
/// and estimated-row limit.
///
/// Returns `Some(failure)` on rejection and `None` when the statement passes.
#[cfg(feature = "sql-code-mode")]
fn check_sql_config_authorization(
    &self,
    info: &crate::sql::SqlStatementInfo,
    start: Instant,
) -> Option<ValidationResult> {
    use crate::sql::SqlStatementType;

    let settings = &self.config;

    // Elapsed time is sampled when the rejection is built.
    let deny = |violation: PolicyViolation| {
        Some(ValidationResult::failure(
            vec![violation],
            self.build_sql_metadata(info, start.elapsed().as_millis() as u64),
        ))
    };

    let statement_kind = info.statement_type.as_str();

    if settings.sql_blocked_statements.contains(statement_kind) {
        return deny(PolicyViolation::new(
            "code_mode",
            "blocked_statement",
            format!(
                "Statement type '{}' is blocked for this server",
                statement_kind
            ),
        ));
    }

    let allowlist_active = !settings.sql_allowed_statements.is_empty();
    if allowlist_active && !settings.sql_allowed_statements.contains(statement_kind) {
        return deny(PolicyViolation::new(
            "code_mode",
            "statement_not_allowed",
            format!(
                "Statement type '{}' is not in the allowlist",
                statement_kind
            ),
        ));
    }

    let where_missing = settings.sql_require_where_on_writes && !info.has_where;

    match info.statement_type {
        SqlStatementType::Select => {
            if !settings.sql_reads_enabled {
                return deny(PolicyViolation::new(
                    "code_mode",
                    "reads_disabled",
                    "SELECT statements are not enabled for this server",
                ));
            }
        },
        SqlStatementType::Insert | SqlStatementType::Update => {
            if !settings.sql_allow_writes {
                return deny(
                    PolicyViolation::new(
                        "code_mode",
                        "writes_disabled",
                        "INSERT/UPDATE statements are not enabled for this server",
                    )
                    .with_suggestion("Contact your administrator to enable sql_allow_writes."),
                );
            }
            // Only UPDATE is subject to the WHERE requirement here.
            if matches!(info.statement_type, SqlStatementType::Update) && where_missing {
                return deny(PolicyViolation::new(
                    "code_mode",
                    "missing_where",
                    format!("{} without WHERE clause is not allowed", info.verb),
                ));
            }
        },
        SqlStatementType::Delete => {
            if !settings.sql_allow_deletes {
                return deny(PolicyViolation::new(
                    "code_mode",
                    "deletes_disabled",
                    "DELETE statements are not enabled for this server",
                ));
            }
            if where_missing {
                return deny(PolicyViolation::new(
                    "code_mode",
                    "missing_where",
                    "DELETE without WHERE clause is not allowed",
                ));
            }
        },
        SqlStatementType::Ddl => {
            if !settings.sql_allow_ddl {
                return deny(PolicyViolation::new(
                    "code_mode",
                    "ddl_disabled",
                    "DDL (CREATE/ALTER/DROP/GRANT/REVOKE) is not enabled for this server",
                ));
            }
        },
        SqlStatementType::Other => {
            return deny(PolicyViolation::new(
                "code_mode",
                "unsupported_statement",
                format!("Statement type '{}' is not supported", info.verb),
            ));
        },
    }

    let blocked_table = info
        .tables
        .iter()
        .find(|table| settings.sql_blocked_tables.contains(*table));
    if let Some(table) = blocked_table {
        return deny(PolicyViolation::new(
            "code_mode",
            "blocked_table",
            format!("Table '{}' is blocked for this server", table),
        ));
    }

    if !settings.sql_allowed_tables.is_empty() {
        let outside_allowlist = info
            .tables
            .iter()
            .find(|table| !settings.sql_allowed_tables.contains(*table));
        if let Some(table) = outside_allowlist {
            return deny(PolicyViolation::new(
                "code_mode",
                "table_not_allowed",
                format!("Table '{}' is not in the allowlist", table),
            ));
        }
    }

    let blocked_column = info
        .columns
        .iter()
        .find(|column| settings.sql_blocked_columns.contains(*column));
    if let Some(col) = blocked_column {
        return deny(PolicyViolation::new(
            "code_mode",
            "blocked_column",
            format!("Column '{}' is blocked for this server", col),
        ));
    }

    if info.join_count > settings.sql_max_joins {
        return deny(PolicyViolation::new(
            "code_mode",
            "excessive_joins",
            format!(
                "Query has {} JOINs, exceeds limit of {}",
                info.join_count, settings.sql_max_joins
            ),
        ));
    }

    if info.estimated_rows > settings.sql_max_rows {
        return deny(PolicyViolation::new(
            "code_mode",
            "excessive_rows",
            format!(
                "Estimated rows {} exceeds limit of {}",
                info.estimated_rows, settings.sql_max_rows
            ),
        ));
    }

    None
}
1235
/// Final stage of SQL validation: security analysis followed by approval-token issuance.
///
/// Critical findings from the security analysis short-circuit into a failed
/// [`ValidationResult`] carrying one `"security"` violation per finding. Otherwise an
/// approval token is generated for `sql` and the caller's context, and every
/// non-critical finding is attached to the successful result as a warning.
///
/// `start` is the instant validation began; it is only used to compute the elapsed
/// milliseconds recorded in the result metadata.
///
/// # Errors
///
/// Returns [`ValidationError::InternalError`] when the generated token cannot be encoded.
#[cfg(feature = "sql-code-mode")]
fn complete_sql_validation(
    &self,
    sql: &str,
    info: &crate::sql::SqlStatementInfo,
    context: &ValidationContext,
    start: Instant,
) -> Result<ValidationResult, ValidationError> {
    let sql_validator = crate::sql::SqlValidator::new();
    let analysis = sql_validator.analyze_security(info);
    let risk = analysis.assess_risk();

    // Gather every critical finding in one pass; a non-empty list means rejection.
    let critical: Vec<PolicyViolation> = analysis
        .potential_issues
        .iter()
        .filter(|finding| finding.is_critical())
        .map(|finding| {
            PolicyViolation::new(
                "security",
                format!("{:?}", finding.issue_type),
                &finding.message,
            )
        })
        .collect();

    if !critical.is_empty() {
        let elapsed_ms = start.elapsed().as_millis() as u64;
        return Ok(ValidationResult::failure(
            critical,
            self.build_sql_metadata(info, elapsed_ms),
        ));
    }

    // The token is issued for this SQL text, user, session, server and context hash.
    let context_hash = context.context_hash();
    let approval = self.token_generator.generate(
        sql,
        &context.user_id,
        &context.session_id,
        self.config.server_id(),
        &context_hash,
        risk,
        self.config.token_ttl_seconds,
    );
    let encoded = approval.encode().map_err(|err| {
        ValidationError::InternalError(format!("Failed to encode token: {}", err))
    })?;

    let explanation = self.generate_sql_explanation(info, &analysis);
    let metadata = self.build_sql_metadata(info, start.elapsed().as_millis() as u64);
    let mut outcome = ValidationResult::success(explanation, risk, encoded, metadata);

    // Non-critical findings do not block approval; they are surfaced as warnings.
    outcome.warnings.extend(
        analysis
            .potential_issues
            .iter()
            .filter(|finding| !finding.is_critical())
            .map(|finding| finding.message.clone()),
    );

    Ok(outcome)
}
1297
/// Assembles the [`ValidationMetadata`] attached to every SQL validation result.
///
/// The action is inferred from the statement type and then resolved against the
/// configured `action_tags` using the statement's verb. `validation_time_ms` is
/// copied through unchanged.
#[cfg(feature = "sql-code-mode")]
fn build_sql_metadata(
    &self,
    info: &crate::sql::SqlStatementInfo,
    validation_time_ms: u64,
) -> ValidationMetadata {
    let resolved_action = UnifiedAction::resolve(
        UnifiedAction::from_sql(info.statement_type.as_str()),
        &self.config.action_tags,
        &info.verb,
    );

    // Read-only statements are reported as queries, everything else as mutations.
    let read_only = info.statement_type.is_read_only();
    let code_type = if read_only {
        crate::types::CodeType::SqlQuery
    } else {
        crate::types::CodeType::SqlMutation
    };

    ValidationMetadata {
        is_read_only: read_only,
        estimated_rows: Some(info.estimated_rows),
        accessed_types: info.tables.iter().cloned().collect(),
        accessed_fields: info.columns.iter().cloned().collect(),
        has_aggregation: info.has_aggregation,
        code_type: Some(code_type),
        action: Some(resolved_action),
        validation_time_ms,
    }
}
1323
/// Produces the human-readable, single-line explanation of a SQL statement.
///
/// The text is a space-separated sequence of sentences: what the statement does
/// (with the touched tables in sorted order, if any), optional notes about
/// `WHERE`, `LIMIT`, joins and subqueries, and finally the assessed risk.
#[cfg(feature = "sql-code-mode")]
fn generate_sql_explanation(
    &self,
    info: &crate::sql::SqlStatementInfo,
    security_analysis: &crate::types::SecurityAnalysis,
) -> String {
    let subject = match info.statement_type.as_str() {
        "SELECT" => "This query reads data",
        "INSERT" => "This statement inserts rows",
        "UPDATE" => "This statement updates rows",
        "DELETE" => "This statement deletes rows",
        "DDL" => "This statement changes schema or permissions",
        _ => "This statement",
    };

    // Opening sentence; table names are sorted so the wording is deterministic.
    let lead = if info.tables.is_empty() {
        format!("{}.", subject)
    } else {
        let mut names: Vec<&str> = info.tables.iter().map(String::as_str).collect();
        names.sort();
        format!("{} in table(s): {}.", subject, names.join(", "))
    };

    // Each optional sentence is `Some` only when the matching trait is present.
    let where_note = info
        .has_where
        .then(|| "Filtered with WHERE clause.".to_string());
    let limit_note = info
        .has_limit
        .then(|| format!("Limited to {} rows.", info.estimated_rows));
    let join_note =
        (info.join_count > 0).then(|| format!("Uses {} JOIN(s).", info.join_count));
    let subquery_note = (info.subquery_count > 0)
        .then(|| format!("Contains {} subquer(ies).", info.subquery_count));

    let risk_note = format!("Risk: {}", security_analysis.assess_risk());

    std::iter::once(lead)
        .chain(where_note)
        .chain(limit_note)
        .chain(join_note)
        .chain(subquery_note)
        .chain(std::iter::once(risk_note))
        .collect::<Vec<String>>()
        .join(" ")
}
1373
1374 pub fn should_auto_approve(&self, result: &ValidationResult) -> bool {
1376 result.is_valid && self.config.should_auto_approve(result.risk_level)
1377 }
1378
1379 pub fn config(&self) -> &CodeModeConfig {
1381 &self.config
1382 }
1383
1384 pub fn token_generator(&self) -> &T {
1386 &self.token_generator
1387 }
1388}
1389
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::RiskLevel;

    /// Signing secret shared by every pipeline constructed in this module.
    const TEST_SECRET: &[u8] = b"test-secret-key!";

    /// Wraps `config` in a pipeline signed with `TEST_SECRET`; panics if construction fails.
    fn pipeline_with(config: CodeModeConfig) -> ValidationPipeline {
        ValidationPipeline::new(config, TEST_SECRET.to_vec()).unwrap()
    }

    /// Starts from `CodeModeConfig::enabled()`, lets `adjust` mutate it, then builds a pipeline.
    fn pipeline_where(adjust: impl FnOnce(&mut CodeModeConfig)) -> ValidationPipeline {
        let mut config = CodeModeConfig::enabled();
        adjust(&mut config);
        pipeline_with(config)
    }

    /// Pipeline built from the unmodified `CodeModeConfig::enabled()` configuration.
    fn test_pipeline() -> ValidationPipeline {
        pipeline_with(CodeModeConfig::enabled())
    }

    /// Fixed validation context used by every test.
    fn test_context() -> ValidationContext {
        ValidationContext::new("user-123", "session-456", "schema-hash", "perms-hash")
    }

    /// True when `result` carries at least one violation whose rule equals `rule`.
    fn has_rule(result: &ValidationResult, rule: &str) -> bool {
        result.violations.iter().any(|v| v.rule == rule)
    }

    /// Validates `query` against `pipeline` with the shared context, panicking on `Err`.
    fn run_graphql(pipeline: &ValidationPipeline, query: &str) -> ValidationResult {
        pipeline
            .validate_graphql_query(query, &test_context())
            .unwrap()
    }

    /// A plain read query is valid, low risk, tokenised, and explained as a read.
    #[test]
    fn test_simple_query_validation() {
        let outcome = run_graphql(&test_pipeline(), "query { users { id name } }");

        assert!(outcome.is_valid);
        assert!(outcome.approval_token.is_some());
        assert_eq!(outcome.risk_level, RiskLevel::Low);
        assert!(outcome.explanation.contains("read"));
    }

    /// With `allow_mutations` off, a mutation fails with the `allow_mutations` rule.
    #[test]
    fn test_mutation_blocked() {
        let pipeline = pipeline_where(|cfg| cfg.allow_mutations = false);

        let outcome = run_graphql(&pipeline, "mutation { createUser(name: \"test\") { id } }");

        assert!(!outcome.is_valid);
        assert!(has_rule(&outcome, "allow_mutations"));
    }

    /// A pipeline built from `CodeModeConfig::default()` answers with a `ConfigError`.
    #[test]
    fn test_disabled_code_mode() {
        let pipeline = pipeline_with(CodeModeConfig::default());
        let ctx = test_context();

        let outcome = pipeline.validate_graphql_query("query { users { id } }", &ctx);

        assert!(matches!(outcome, Err(ValidationError::ConfigError(_))));
    }

    /// A simple read query validated under the enabled config is auto-approved.
    #[test]
    fn test_auto_approve_low_risk() {
        let pipeline = test_pipeline();

        let outcome = run_graphql(&pipeline, "query { users { id } }");

        assert!(pipeline.should_auto_approve(&outcome));
    }

    /// Changing only the schema hash must change the context hash.
    #[test]
    fn test_context_hash() {
        let original = test_context();
        let other_schema =
            ValidationContext::new("user-123", "session-456", "different-schema", "perms-hash");

        assert_ne!(original.context_hash(), other_schema.context_hash());
    }

    /// A query naming an entry of `blocked_queries` fails with `blocked_query`.
    #[test]
    fn test_blocked_query_rejected() {
        let pipeline = pipeline_where(|cfg| {
            cfg.blocked_queries.insert("users".to_string());
        });

        let outcome = run_graphql(&pipeline, "query { users { id } }");

        assert!(!outcome.is_valid);
        assert!(has_rule(&outcome, "blocked_query"));
    }

    /// With a non-empty `allowed_queries`, a query outside it fails with `query_not_allowed`.
    #[test]
    fn test_allowed_queries_enforced() {
        let pipeline = pipeline_where(|cfg| {
            cfg.allowed_queries.insert("orders".to_string());
        });

        let outcome = run_graphql(&pipeline, "query { users { id } }");

        assert!(!outcome.is_valid);
        assert!(has_rule(&outcome, "query_not_allowed"));
    }

    #[cfg(feature = "sql-code-mode")]
    mod sql_tests {
        use super::*;

        /// Pipeline built from the unmodified `CodeModeConfig::enabled()` configuration.
        fn sql_pipeline() -> ValidationPipeline {
            pipeline_with(CodeModeConfig::enabled())
        }

        /// Validates `sql` against `pipeline` with the shared context, panicking on `Err`.
        fn run_sql(pipeline: &ValidationPipeline, sql: &str) -> ValidationResult {
            pipeline.validate_sql_query(sql, &test_context()).unwrap()
        }

        /// A bounded SELECT is valid and receives an approval token.
        #[test]
        fn validates_select() {
            let outcome = run_sql(&sql_pipeline(), "SELECT id, name FROM users LIMIT 10");

            assert!(outcome.is_valid);
            assert!(outcome.approval_token.is_some());
        }

        /// Under the enabled config, INSERT fails with `writes_disabled`.
        #[test]
        fn rejects_insert_when_writes_disabled() {
            let outcome = run_sql(
                &sql_pipeline(),
                "INSERT INTO users (id, name) VALUES (1, 'Alice')",
            );

            assert!(!outcome.is_valid);
            assert!(has_rule(&outcome, "writes_disabled"));
        }

        /// With `sql_allow_writes` set, the same INSERT validates.
        #[test]
        fn permits_insert_when_writes_enabled() {
            let pipeline = pipeline_where(|cfg| cfg.sql_allow_writes = true);

            let outcome = run_sql(&pipeline, "INSERT INTO users (id, name) VALUES (1, 'Alice')");

            assert!(outcome.is_valid);
        }

        /// Even with writes allowed, an UPDATE lacking WHERE fails with `missing_where`.
        #[test]
        fn rejects_update_without_where_by_default() {
            let pipeline = pipeline_where(|cfg| cfg.sql_allow_writes = true);

            let outcome = run_sql(&pipeline, "UPDATE users SET active = 0");

            assert!(!outcome.is_valid);
            assert!(has_rule(&outcome, "missing_where"));
        }

        /// Selecting from a table listed in `sql_blocked_tables` fails with `blocked_table`.
        #[test]
        fn rejects_blocked_table() {
            let pipeline = pipeline_where(|cfg| {
                cfg.sql_blocked_tables.insert("secrets".to_string());
            });

            let outcome = run_sql(&pipeline, "SELECT * FROM secrets LIMIT 10");

            assert!(!outcome.is_valid);
            assert!(has_rule(&outcome, "blocked_table"));
        }

        /// With a non-empty `sql_allowed_tables`, other tables fail with `table_not_allowed`.
        #[test]
        fn rejects_non_allowlisted_table() {
            let pipeline = pipeline_where(|cfg| {
                cfg.sql_allowed_tables.insert("users".to_string());
            });

            let outcome = run_sql(&pipeline, "SELECT id FROM orders LIMIT 10");

            assert!(!outcome.is_valid);
            assert!(has_rule(&outcome, "table_not_allowed"));
        }

        /// Selecting a column listed in `sql_blocked_columns` fails with `blocked_column`.
        #[test]
        fn rejects_blocked_column() {
            let pipeline = pipeline_where(|cfg| {
                cfg.sql_blocked_columns.insert("password".to_string());
            });

            let outcome = run_sql(&pipeline, "SELECT id, password FROM users LIMIT 10");

            assert!(!outcome.is_valid);
            assert!(has_rule(&outcome, "blocked_column"));
        }

        /// Under the enabled config, CREATE TABLE fails with `ddl_disabled`.
        #[test]
        fn rejects_ddl_by_default() {
            let outcome = run_sql(&sql_pipeline(), "CREATE TABLE foo (id INT)");

            assert!(!outcome.is_valid);
            assert!(has_rule(&outcome, "ddl_disabled"));
        }

        /// Malformed SQL surfaces as `ValidationError::ParseError`, not as a failed result.
        #[test]
        fn rejects_syntax_error() {
            let pipeline = sql_pipeline();
            let ctx = test_context();

            let outcome = pipeline.validate_sql_query("SELEC id FRM users", &ctx);

            assert!(matches!(outcome, Err(ValidationError::ParseError { .. })));
        }

        /// Policy evaluator stub that always denies, echoing back a fixed list of errors.
        struct FixedDenyEvaluator {
            errors: Vec<String>,
        }

        impl FixedDenyEvaluator {
            /// The decision returned for every request: denied, no determining policies.
            fn denial(&self) -> crate::policy::AuthorizationDecision {
                crate::policy::AuthorizationDecision {
                    allowed: false,
                    determining_policies: Vec::new(),
                    errors: self.errors.clone(),
                }
            }
        }

        #[async_trait::async_trait]
        impl PolicyEvaluator for FixedDenyEvaluator {
            async fn evaluate_operation(
                &self,
                _op: &crate::policy::OperationEntity,
                _cfg: &crate::policy::ServerConfigEntity,
            ) -> Result<crate::policy::AuthorizationDecision, crate::policy::PolicyEvaluationError>
            {
                Ok(self.denial())
            }

            #[cfg(feature = "sql-code-mode")]
            async fn evaluate_statement(
                &self,
                _stmt: &crate::policy::StatementEntity,
                _server: &crate::policy::SqlServerEntity,
            ) -> Result<crate::policy::AuthorizationDecision, crate::policy::PolicyEvaluationError>
            {
                Ok(self.denial())
            }

            fn name(&self) -> &str {
                "fixed-deny-test"
            }
        }

        /// Enabled config with `server_id` set to `test-server`, wired to `evaluator`.
        fn sql_pipeline_with_evaluator(evaluator: Arc<dyn PolicyEvaluator>) -> ValidationPipeline {
            let mut config = CodeModeConfig::enabled();
            config.server_id = Some(String::from("test-server"));

            ValidationPipeline::with_policy_evaluator(config, TEST_SECRET.to_vec(), evaluator)
                .unwrap()
        }

        /// Pipeline whose evaluator always denies and reports `errors`.
        fn denying_pipeline(errors: Vec<String>) -> ValidationPipeline {
            let evaluator: Arc<dyn PolicyEvaluator> = Arc::new(FixedDenyEvaluator { errors });
            sql_pipeline_with_evaluator(evaluator)
        }

        /// A denial with no policies and no errors yields a `default_deny` violation
        /// whose message mentions the server id and the action.
        #[tokio::test]
        async fn default_deny_produces_synthetic_violation() {
            let pipeline = denying_pipeline(Vec::new());
            let ctx = test_context();

            let outcome = pipeline
                .validate_sql_query_async("SELECT id FROM users LIMIT 10", &ctx)
                .await
                .unwrap();

            assert!(!outcome.is_valid);
            let synthetic = outcome
                .violations
                .iter()
                .find(|v| v.rule == "default_deny")
                .expect("expected a synthetic default_deny violation");
            assert!(synthetic.message.contains("test-server"));
            assert!(synthetic.message.contains("Read"));
        }

        /// Errors reported by the evaluator surface as `evaluation_error` violations.
        #[tokio::test]
        async fn policy_errors_flow_to_violations() {
            let pipeline = denying_pipeline(vec![
                "schema validation: missing required attribute X".to_string(),
            ]);
            let ctx = test_context();

            let outcome = pipeline
                .validate_sql_query_async("SELECT id FROM users LIMIT 10", &ctx)
                .await
                .unwrap();

            assert!(!outcome.is_valid);
            let reported = outcome
                .violations
                .iter()
                .find(|v| v.rule == "evaluation_error")
                .expect("expected a policy_error violation");
            assert!(reported.message.contains("schema validation"));
        }

        /// Two JOINs against a limit of one fail with `excessive_joins`.
        #[test]
        fn rejects_excessive_joins() {
            let pipeline = pipeline_where(|cfg| cfg.sql_max_joins = 1);

            let outcome = run_sql(
                &pipeline,
                "SELECT u.id FROM users u \
                 JOIN orders o ON u.id = o.user_id \
                 JOIN items i ON o.id = i.order_id LIMIT 10",
            );

            assert!(!outcome.is_valid);
            assert!(has_rule(&outcome, "excessive_joins"));
        }
    }
}