1use crate::config::CodeModeConfig;
11#[cfg(feature = "openapi-code-mode")]
12use crate::config::OperationRegistry;
13use crate::explanation::{ExplanationGenerator, TemplateExplanationGenerator};
14use crate::graphql::{GraphQLQueryInfo, GraphQLValidator};
15use crate::policy::{OperationEntity, PolicyEvaluator};
16use crate::token::{compute_context_hash, HmacTokenGenerator, TokenGenerator, TokenSecret};
17use crate::types::{
18 PolicyViolation, TokenError, UnifiedAction, ValidationError, ValidationMetadata,
19 ValidationResult,
20};
21use std::sync::atomic::{AtomicBool, Ordering};
22use std::sync::Arc;
23use std::time::Instant;
24
25#[cfg(feature = "openapi-code-mode")]
26use crate::javascript::{JavaScriptCodeInfo, JavaScriptValidator};
27
/// Flipped to `true` by the first call to `warn_no_policy_configured`, so the
/// "no policy evaluator" security warning is logged at most once per process.
static NO_POLICY_WARNING_LOGGED: AtomicBool = AtomicBool::new(false);
30
31fn build_policy_violations(
46 decision: &crate::policy::AuthorizationDecision,
47 server_id: &str,
48 action: impl std::fmt::Display,
49 denied_subject: &str,
50) -> Vec<PolicyViolation> {
51 let capacity = decision.determining_policies.len() + decision.errors.len() + 1;
52 let mut violations: Vec<PolicyViolation> = Vec::with_capacity(capacity);
53
54 for policy_id in &decision.determining_policies {
55 violations.push(PolicyViolation::new(
56 "policy",
57 policy_id.clone(),
58 format!("Policy denied the {}", denied_subject),
59 ));
60 }
61
62 for err in &decision.errors {
63 violations.push(PolicyViolation::new(
64 "policy_error",
65 "evaluation_error",
66 err.clone(),
67 ));
68 }
69
70 if violations.is_empty() {
71 violations.push(PolicyViolation::new(
72 "policy",
73 "default_deny",
74 format!(
75 "Authorization default-deny: no Permit policy matched for \
76 server_id={server_id} action={action}. Check that Cedar \
77 policies exist for this server and that server_id is set correctly."
78 ),
79 ));
80 }
81
82 violations
83}
84
85fn warn_no_policy_configured() {
87 if !NO_POLICY_WARNING_LOGGED.swap(true, Ordering::SeqCst) {
88 tracing::warn!(
89 target: "code_mode",
90 "CODE MODE SECURITY WARNING: Code Mode is enabled but no policy evaluator \
91 is configured. Only basic config checks (allow_mutations, max_depth, etc.) will be \
92 performed. This provides NO real authorization policy evaluation. \
93 For production deployments, configure a policy evaluator (AVP or local Cedar)."
94 );
95 }
96}
97
/// Caller identity and schema/permission fingerprints that a validation and
/// its approval token are bound to.
///
/// `user_id` and `session_id` are passed to the token generator directly;
/// `schema_hash` and `permissions_hash` are combined by
/// [`ValidationContext::context_hash`] before being passed to it.
#[derive(Debug, Clone)]
pub struct ValidationContext {
    /// Identifier of the user requesting validation.
    pub user_id: String,

    /// Identifier of the requesting session.
    pub session_id: String,

    /// Schema fingerprint; one of the two inputs to `context_hash`.
    pub schema_hash: String,

    /// Permissions fingerprint; the other input to `context_hash`.
    pub permissions_hash: String,
}
113
114impl ValidationContext {
115 pub fn new(
117 user_id: impl Into<String>,
118 session_id: impl Into<String>,
119 schema_hash: impl Into<String>,
120 permissions_hash: impl Into<String>,
121 ) -> Self {
122 Self {
123 user_id: user_id.into(),
124 session_id: session_id.into(),
125 schema_hash: schema_hash.into(),
126 permissions_hash: permissions_hash.into(),
127 }
128 }
129
130 pub fn context_hash(&self) -> String {
132 compute_context_hash(&self.schema_hash, &self.permissions_hash)
133 }
134}
135
/// Validates Code Mode submissions (GraphQL; JavaScript and SQL behind their
/// cargo features) and issues approval tokens for those that pass.
///
/// `T` produces the approval tokens and `E` the human-readable explanations.
/// When `policy_evaluator` is set, it is consulted by the `*_async`
/// validators; the synchronous validators rely on the static checks derived
/// from `config`.
pub struct ValidationPipeline<
    T: TokenGenerator = HmacTokenGenerator,
    E: ExplanationGenerator = TemplateExplanationGenerator,
> {
    /// Code Mode settings: limits, allow/block lists, token TTL, server id.
    config: CodeModeConfig,
    /// Parses and analyses GraphQL queries.
    graphql_validator: GraphQLValidator,
    /// Parses and analyses JavaScript code (seeded with `config.sdk_operations`).
    #[cfg(feature = "openapi-code-mode")]
    javascript_validator: JavaScriptValidator,
    /// Built from `config.operations`; handed to script policy evaluation when non-empty.
    #[cfg(feature = "openapi-code-mode")]
    operation_registry: OperationRegistry,
    /// Produces approval tokens for successful validations.
    token_generator: T,
    /// Produces the explanation text for GraphQL validations.
    explanation_generator: E,
    /// Optional authorization backend used by the `*_async` validators.
    policy_evaluator: Option<Arc<dyn PolicyEvaluator>>,
}
151
152impl ValidationPipeline<HmacTokenGenerator, TemplateExplanationGenerator> {
153 pub fn new(
163 mut config: CodeModeConfig,
164 token_secret: impl Into<Vec<u8>>,
165 ) -> Result<Self, TokenError> {
166 if config.enabled {
167 warn_no_policy_configured();
168 }
169
170 config.resolve_server_id();
171
172 #[cfg(feature = "openapi-code-mode")]
173 let operation_registry = OperationRegistry::from_entries(&config.operations);
174
175 Ok(Self {
176 graphql_validator: GraphQLValidator::default(),
177 #[cfg(feature = "openapi-code-mode")]
178 javascript_validator: JavaScriptValidator::default()
179 .with_sdk_operations(config.sdk_operations.clone()),
180 #[cfg(feature = "openapi-code-mode")]
181 operation_registry,
182 token_generator: HmacTokenGenerator::new_from_bytes(token_secret)?,
183 explanation_generator: TemplateExplanationGenerator::new(),
184 policy_evaluator: None,
185 config,
186 })
187 }
188
189 pub fn from_token_secret(
208 config: CodeModeConfig,
209 secret: &TokenSecret,
210 ) -> Result<Self, TokenError> {
211 Self::new(config, secret.expose_secret().to_vec())
212 }
213
214 pub fn with_policy_evaluator(
221 mut config: CodeModeConfig,
222 token_secret: impl Into<Vec<u8>>,
223 evaluator: Arc<dyn PolicyEvaluator>,
224 ) -> Result<Self, TokenError> {
225 config.resolve_server_id();
226 if config.server_id.is_none() {
227 tracing::warn!(
228 target: "code_mode",
229 "CodeModeConfig.server_id is not set — AVP/Cedar authorization will use 'unknown' \
230 as the resource entity ID and will likely default-deny silently. \
231 Set server_id in config.toml, or the PMCP_SERVER_ID or AWS_LAMBDA_FUNCTION_NAME env var."
232 );
233 }
234
235 #[cfg(feature = "openapi-code-mode")]
236 let operation_registry = OperationRegistry::from_entries(&config.operations);
237
238 Ok(Self {
239 graphql_validator: GraphQLValidator::default(),
240 #[cfg(feature = "openapi-code-mode")]
241 javascript_validator: JavaScriptValidator::default()
242 .with_sdk_operations(config.sdk_operations.clone()),
243 #[cfg(feature = "openapi-code-mode")]
244 operation_registry,
245 token_generator: HmacTokenGenerator::new_from_bytes(token_secret)?,
246 explanation_generator: TemplateExplanationGenerator::new(),
247 policy_evaluator: Some(evaluator),
248 config,
249 })
250 }
251
252 pub fn from_token_secret_with_policy(
262 config: CodeModeConfig,
263 secret: &TokenSecret,
264 evaluator: Arc<dyn PolicyEvaluator>,
265 ) -> Result<Self, TokenError> {
266 Self::with_policy_evaluator(config, secret.expose_secret().to_vec(), evaluator)
267 }
268}
269
270impl<T: TokenGenerator, E: ExplanationGenerator> ValidationPipeline<T, E> {
271 pub fn with_generators(
273 mut config: CodeModeConfig,
274 token_generator: T,
275 explanation_generator: E,
276 ) -> Self {
277 config.resolve_server_id();
278
279 #[cfg(feature = "openapi-code-mode")]
280 let operation_registry = OperationRegistry::from_entries(&config.operations);
281
282 Self {
283 graphql_validator: GraphQLValidator::default(),
284 #[cfg(feature = "openapi-code-mode")]
285 javascript_validator: JavaScriptValidator::default()
286 .with_sdk_operations(config.sdk_operations.clone()),
287 #[cfg(feature = "openapi-code-mode")]
288 operation_registry,
289 token_generator,
290 explanation_generator,
291 policy_evaluator: None,
292 config,
293 }
294 }
295
296 pub fn set_policy_evaluator(&mut self, evaluator: Arc<dyn PolicyEvaluator>) {
298 self.policy_evaluator = Some(evaluator);
299 }
300
301 pub fn has_policy_evaluator(&self) -> bool {
303 self.policy_evaluator.is_some()
304 }
305
306 fn check_config_authorization(
311 &self,
312 query_info: &GraphQLQueryInfo,
313 start: Instant,
314 ) -> Option<ValidationResult> {
315 if !query_info.operation_type.is_read_only() {
317 let mutation_name = query_info.root_fields.first().cloned().unwrap_or_default();
318
319 if !self.config.blocked_mutations.is_empty()
320 && self.config.blocked_mutations.contains(&mutation_name)
321 {
322 return Some(ValidationResult::failure(
323 vec![PolicyViolation::new(
324 "code_mode",
325 "blocked_mutation",
326 &format!("Mutation '{}' is blocked for this server", mutation_name),
327 )
328 .with_suggestion("This mutation is in the blocklist and cannot be executed")],
329 self.build_metadata(query_info, start.elapsed().as_millis() as u64),
330 ));
331 }
332
333 if !self.config.allowed_mutations.is_empty() {
334 if !self.config.allowed_mutations.contains(&mutation_name) {
335 return Some(ValidationResult::failure(
336 vec![PolicyViolation::new(
337 "code_mode",
338 "mutation_not_allowed",
339 &format!("Mutation '{}' is not in the allowlist", mutation_name),
340 )
341 .with_suggestion(&format!(
342 "Only these mutations are allowed: {}",
343 self.config
344 .allowed_mutations
345 .iter()
346 .cloned()
347 .collect::<Vec<_>>()
348 .join(", ")
349 ))],
350 self.build_metadata(query_info, start.elapsed().as_millis() as u64),
351 ));
352 }
353 } else if !self.config.allow_mutations {
354 return Some(ValidationResult::failure(
355 vec![PolicyViolation::new(
356 "code_mode",
357 "allow_mutations",
358 "Mutations are not enabled for this server",
359 )
360 .with_suggestion("Only read-only queries are allowed")],
361 self.build_metadata(query_info, start.elapsed().as_millis() as u64),
362 ));
363 }
364 }
365
366 if query_info.operation_type.is_read_only() {
368 let query_name = query_info.root_fields.first().cloned().unwrap_or_default();
369
370 if !self.config.blocked_queries.is_empty()
371 && self.config.blocked_queries.contains(&query_name)
372 {
373 return Some(ValidationResult::failure(
374 vec![PolicyViolation::new(
375 "code_mode",
376 "blocked_query",
377 &format!("Query '{}' is blocked for this server", query_name),
378 )
379 .with_suggestion("This query is in the blocklist and cannot be executed")],
380 self.build_metadata(query_info, start.elapsed().as_millis() as u64),
381 ));
382 }
383
384 if !self.config.allowed_queries.is_empty()
385 && !self.config.allowed_queries.contains(&query_name)
386 {
387 return Some(ValidationResult::failure(
388 vec![PolicyViolation::new(
389 "code_mode",
390 "query_not_allowed",
391 &format!("Query '{}' is not in the allowlist", query_name),
392 )
393 .with_suggestion(&format!(
394 "Only these queries are allowed: {}",
395 self.config
396 .allowed_queries
397 .iter()
398 .cloned()
399 .collect::<Vec<_>>()
400 .join(", ")
401 ))],
402 self.build_metadata(query_info, start.elapsed().as_millis() as u64),
403 ));
404 }
405 }
406
407 None
408 }
409
410 pub fn validate_graphql_query(
412 &self,
413 query: &str,
414 context: &ValidationContext,
415 ) -> Result<ValidationResult, ValidationError> {
416 let start = Instant::now();
417
418 if !self.config.enabled {
419 return Err(ValidationError::ConfigError(
420 "Code Mode is not enabled for this server".into(),
421 ));
422 }
423
424 if query.len() > self.config.max_query_length {
425 return Err(ValidationError::SecurityError {
426 message: format!(
427 "Query length {} exceeds maximum {}",
428 query.len(),
429 self.config.max_query_length
430 ),
431 issue: crate::types::SecurityIssueType::HighComplexity,
432 });
433 }
434
435 let query_info = self.graphql_validator.validate(query)?;
436
437 if let Some(failure) = self.check_config_authorization(&query_info, start) {
439 return Ok(failure);
440 }
441
442 self.complete_validation(query, &query_info, context, start)
443 }
444
445 pub async fn validate_graphql_query_async(
447 &self,
448 query: &str,
449 context: &ValidationContext,
450 ) -> Result<ValidationResult, ValidationError> {
451 let start = Instant::now();
452
453 if !self.config.enabled {
454 return Err(ValidationError::ConfigError(
455 "Code Mode is not enabled for this server".into(),
456 ));
457 }
458
459 if query.len() > self.config.max_query_length {
460 return Err(ValidationError::SecurityError {
461 message: format!(
462 "Query length {} exceeds maximum {}",
463 query.len(),
464 self.config.max_query_length
465 ),
466 issue: crate::types::SecurityIssueType::HighComplexity,
467 });
468 }
469
470 let query_info = self.graphql_validator.validate(query)?;
471
472 if let Some(ref evaluator) = self.policy_evaluator {
474 let operation_entity = OperationEntity::from_query_info(&query_info);
475 let server_config = self.config.to_server_config_entity();
476
477 let decision = evaluator
478 .evaluate_operation(&operation_entity, &server_config)
479 .await
480 .map_err(|e| {
481 ValidationError::InternalError(format!("Policy evaluation error: {}", e))
482 })?;
483
484 if !decision.allowed {
485 let op_type_str = format!("{:?}", query_info.operation_type);
486 let action =
487 UnifiedAction::from_graphql(&op_type_str, query_info.operation_name.as_deref());
488 let violations = build_policy_violations(
489 &decision,
490 self.config.server_id(),
491 action,
492 "operation",
493 );
494
495 return Ok(ValidationResult::failure(
496 violations,
497 self.build_metadata(&query_info, start.elapsed().as_millis() as u64),
498 ));
499 }
500 } else {
501 warn_no_policy_configured();
502 tracing::debug!(
503 target: "code_mode",
504 "Falling back to basic config checks (no policy evaluator configured)"
505 );
506 if let Some(failure) = self.check_config_authorization(&query_info, start) {
508 return Ok(failure);
509 }
510 }
511
512 self.complete_validation(query, &query_info, context, start)
513 }
514
515 fn complete_validation(
517 &self,
518 query: &str,
519 query_info: &GraphQLQueryInfo,
520 context: &ValidationContext,
521 start: Instant,
522 ) -> Result<ValidationResult, ValidationError> {
523 let security_analysis = self.graphql_validator.analyze_security(query_info);
524 let risk_level = security_analysis.assess_risk();
525
526 if security_analysis
527 .potential_issues
528 .iter()
529 .any(|i| i.is_critical())
530 {
531 let violations: Vec<PolicyViolation> = security_analysis
532 .potential_issues
533 .iter()
534 .filter(|i| i.is_critical())
535 .map(|i| {
536 PolicyViolation::new("security", format!("{:?}", i.issue_type), &i.message)
537 })
538 .collect();
539
540 return Ok(ValidationResult::failure(
541 violations,
542 self.build_metadata(query_info, start.elapsed().as_millis() as u64),
543 ));
544 }
545
546 let explanation = self
547 .explanation_generator
548 .explain_graphql(query_info, &security_analysis);
549
550 let context_hash = context.context_hash();
551 let token = self.token_generator.generate(
552 query,
553 &context.user_id,
554 &context.session_id,
555 self.config.server_id(),
556 &context_hash,
557 risk_level,
558 self.config.token_ttl_seconds,
559 );
560
561 let token_string = token.encode().map_err(|e| {
562 ValidationError::InternalError(format!("Failed to encode token: {}", e))
563 })?;
564
565 let operation_type_str = format!("{:?}", query_info.operation_type).to_lowercase();
566 let mutation_name = query_info.operation_name.as_deref();
567 let inferred_action = UnifiedAction::from_graphql(&operation_type_str, mutation_name);
568 let action = UnifiedAction::resolve(
569 inferred_action,
570 &self.config.action_tags,
571 query_info.operation_name.as_deref().unwrap_or(""),
572 );
573
574 let metadata = ValidationMetadata {
575 is_read_only: query_info.operation_type.is_read_only(),
576 estimated_rows: security_analysis.estimated_rows,
577 accessed_types: security_analysis.tables_accessed.iter().cloned().collect(),
578 accessed_fields: security_analysis.fields_accessed.iter().cloned().collect(),
579 has_aggregation: security_analysis.has_aggregation,
580 code_type: Some(self.graphql_validator.to_code_type(query_info)),
581 action: Some(action),
582 validation_time_ms: start.elapsed().as_millis() as u64,
583 };
584
585 let mut result = ValidationResult::success(explanation, risk_level, token_string, metadata);
586
587 for issue in &security_analysis.potential_issues {
588 if !issue.is_critical() {
589 result.warnings.push(issue.message.clone());
590 }
591 }
592
593 Ok(result)
594 }
595
596 fn build_metadata(
598 &self,
599 query_info: &GraphQLQueryInfo,
600 validation_time_ms: u64,
601 ) -> ValidationMetadata {
602 let operation_type_str = format!("{:?}", query_info.operation_type).to_lowercase();
603 let mutation_name = query_info.operation_name.as_deref();
604 let inferred_action = UnifiedAction::from_graphql(&operation_type_str, mutation_name);
605 let action = UnifiedAction::resolve(
606 inferred_action,
607 &self.config.action_tags,
608 query_info.operation_name.as_deref().unwrap_or(""),
609 );
610
611 ValidationMetadata {
612 is_read_only: query_info.operation_type.is_read_only(),
613 estimated_rows: None,
614 accessed_types: query_info.types_accessed.iter().cloned().collect(),
615 accessed_fields: query_info.fields_accessed.iter().cloned().collect(),
616 has_aggregation: false,
617 code_type: Some(self.graphql_validator.to_code_type(query_info)),
618 action: Some(action),
619 validation_time_ms,
620 }
621 }
622
/// Synchronously validates JavaScript `code` against the static config checks
/// and, on success, issues an approval token bound to `context`.
///
/// No policy evaluator is consulted on this path.
///
/// # Errors
///
/// Propagates preamble errors (Code Mode disabled, code too long, invalid
/// JavaScript) and token-encoding failures.
#[cfg(feature = "openapi-code-mode")]
pub fn validate_javascript_code(
    &self,
    code: &str,
    context: &ValidationContext,
) -> Result<ValidationResult, ValidationError> {
    let start = Instant::now();
    let code_info = self.validate_js_preamble(code)?;

    match self.check_js_config_authorization(&code_info, start) {
        Some(rejection) => Ok(rejection),
        None => self.complete_js_validation(code, &code_info, context, start),
    }
}
641
/// Validates JavaScript `code` against the static config checks and, when one
/// is attached, the policy evaluator. On success it issues an approval token
/// bound to `context`.
///
/// The config checks always run first. Without an evaluator, the one-time
/// "no policy evaluator" warning is logged, as in the GraphQL and SQL async
/// validators.
///
/// # Errors
///
/// * Preamble errors (Code Mode disabled, code too long, invalid JavaScript).
/// * [`ValidationError::InternalError`] when policy evaluation fails or the
///   approval token cannot be encoded.
#[cfg(feature = "openapi-code-mode")]
pub async fn validate_javascript_code_async(
    &self,
    code: &str,
    context: &ValidationContext,
) -> Result<ValidationResult, ValidationError> {
    use crate::policy::types::ScriptEntity;

    let start = Instant::now();
    let code_info = self.validate_js_preamble(code)?;
    if let Some(failure) = self.check_js_config_authorization(&code_info, start) {
        return Ok(failure);
    }

    if let Some(ref evaluator) = self.policy_evaluator {
        let sensitive_patterns: Vec<String> =
            self.config.openapi_blocked_paths.iter().cloned().collect();
        // Only hand the registry to the policy entity when it has entries.
        let registry_ref = if self.operation_registry.is_empty() {
            None
        } else {
            Some(&self.operation_registry)
        };
        let script_entity =
            ScriptEntity::from_javascript_info(&code_info, &sensitive_patterns, registry_ref);
        let server_entity = self.config.to_openapi_server_entity();

        let decision = evaluator
            .evaluate_script(&script_entity, &server_entity)
            .await
            .map_err(|e| {
                ValidationError::InternalError(format!("Policy evaluation error: {}", e))
            })?;

        if !decision.allowed {
            let violations = build_policy_violations(
                &decision,
                self.config.server_id(),
                script_entity.action(),
                "script",
            );

            return Ok(ValidationResult::failure(
                violations,
                self.build_js_metadata(&code_info, start.elapsed().as_millis() as u64),
            ));
        }
    } else {
        // Surface the missing evaluator the same way the other async
        // validators do, instead of silently relying on config checks.
        warn_no_policy_configured();
    }

    self.complete_js_validation(code, &code_info, context, start)
}
701
/// Gate shared by the JavaScript validators. It requires Code Mode to be
/// enabled and `code` to fit within `max_query_length`, then parses it.
///
/// # Errors
///
/// * [`ValidationError::ConfigError`] when Code Mode is disabled.
/// * [`ValidationError::SecurityError`] when `code` is too long.
/// * Any error from the JavaScript validator.
#[cfg(feature = "openapi-code-mode")]
fn validate_js_preamble(&self, code: &str) -> Result<JavaScriptCodeInfo, ValidationError> {
    if !self.config.enabled {
        return Err(ValidationError::ConfigError(
            "Code Mode is not enabled for this server".into(),
        ));
    }

    let (length, limit) = (code.len(), self.config.max_query_length);
    if length <= limit {
        return self.javascript_validator.validate(code);
    }

    Err(ValidationError::SecurityError {
        message: format!("Code length {} exceeds maximum {}", length, limit),
        issue: crate::types::SecurityIssueType::HighComplexity,
    })
}
724
/// Applies the static HTTP-method rules from the config to a non-read-only
/// script.
///
/// The checks, in order:
/// * Read-only scripts pass unconditionally.
/// * Any method listed in `openapi_blocked_writes` is rejected.
/// * With a non-empty `openapi_allowed_writes`, no further check is made here.
/// * Otherwise writes require `openapi_allow_writes`.
///
/// NOTE(review): the `openapi_allowed_writes` allowlist is never enforced in
/// this function. The synchronous `validate_javascript_code` never consults a
/// policy evaluator, so on that path a non-empty allowlist lets every
/// non-blocked write through — confirm this is intended.
#[cfg(feature = "openapi-code-mode")]
fn check_js_config_authorization(
    &self,
    code_info: &JavaScriptCodeInfo,
    start: Instant,
) -> Option<ValidationResult> {
    if code_info.is_read_only {
        return None;
    }

    let blocked_writes = &self.config.openapi_blocked_writes;
    let first_blocked = code_info
        .methods_used
        .iter()
        .find(|method| blocked_writes.contains(*method));
    if let Some(method) = first_blocked {
        return Some(ValidationResult::failure(
            vec![PolicyViolation::new(
                "code_mode",
                "blocked_method",
                &format!("HTTP method '{}' is blocked for this server", method),
            )
            .with_suggestion("This method is in the blocklist and cannot be used")],
            self.build_js_metadata(code_info, start.elapsed().as_millis() as u64),
        ));
    }

    if !self.config.openapi_allowed_writes.is_empty() {
        tracing::debug!(
            target: "code_mode",
            "Skipping method-level check - policy evaluator will check operation allowlist ({} entries)",
            self.config.openapi_allowed_writes.len()
        );
        return None;
    }

    if self.config.openapi_allow_writes {
        return None;
    }

    Some(ValidationResult::failure(
        vec![PolicyViolation::new(
            "code_mode",
            "allow_mutations",
            "Write HTTP methods (POST, PUT, DELETE, PATCH) are not enabled for this server",
        )
        .with_suggestion("Only read-only methods (GET, HEAD, OPTIONS) are allowed. Contact your administrator to enable write operations.")],
        self.build_js_metadata(code_info, start.elapsed().as_millis() as u64),
    ))
}
776
/// Final stage shared by both JavaScript validators. It runs the security
/// analysis and, if that passes, builds the explanation, approval token and
/// metadata.
///
/// Critical issues become `security` violations and fail the validation.
/// Non-critical issues are attached to the successful result as warnings.
///
/// # Errors
///
/// [`ValidationError::InternalError`] when the approval token cannot be encoded.
#[cfg(feature = "openapi-code-mode")]
fn complete_js_validation(
    &self,
    code: &str,
    code_info: &JavaScriptCodeInfo,
    context: &ValidationContext,
    start: Instant,
) -> Result<ValidationResult, ValidationError> {
    let analysis = self.javascript_validator.analyze_security(code_info);
    let risk_level = analysis.assess_risk();

    let critical_violations: Vec<PolicyViolation> = analysis
        .potential_issues
        .iter()
        .filter(|issue| issue.is_critical())
        .map(|issue| {
            PolicyViolation::new("security", format!("{:?}", issue.issue_type), &issue.message)
        })
        .collect();
    if !critical_violations.is_empty() {
        let metadata = self.build_js_metadata(code_info, start.elapsed().as_millis() as u64);
        return Ok(ValidationResult::failure(critical_violations, metadata));
    }

    let explanation = self.generate_js_explanation(code_info, &analysis);

    let context_hash = context.context_hash();
    let token = self.token_generator.generate(
        code,
        &context.user_id,
        &context.session_id,
        self.config.server_id(),
        &context_hash,
        risk_level,
        self.config.token_ttl_seconds,
    );
    let token_string = token
        .encode()
        .map_err(|e| ValidationError::InternalError(format!("Failed to encode token: {}", e)))?;

    let metadata = self.build_js_metadata(code_info, start.elapsed().as_millis() as u64);
    let mut result = ValidationResult::success(explanation, risk_level, token_string, metadata);
    for issue in analysis
        .potential_issues
        .iter()
        .filter(|issue| !issue.is_critical())
    {
        result.warnings.push(issue.message.clone());
    }

    Ok(result)
}
838
/// Metadata for a JavaScript validation result.
///
/// The action is the strongest action inferred from the script's API calls,
/// escalating along Read → Write → Delete → Admin. With no recorded API calls
/// it falls back to `Read` for read-only code and `Write` otherwise.
#[cfg(feature = "openapi-code-mode")]
fn build_js_metadata(
    &self,
    code_info: &JavaScriptCodeInfo,
    validation_time_ms: u64,
) -> ValidationMetadata {
    // Keep `current` unless `candidate` is a strictly stronger action.
    let escalate = |current: UnifiedAction, candidate: UnifiedAction| match (&current, &candidate) {
        (UnifiedAction::Read, _)
        | (UnifiedAction::Write, UnifiedAction::Delete | UnifiedAction::Admin)
        | (UnifiedAction::Delete, UnifiedAction::Admin) => candidate,
        _ => current,
    };

    let action = if !code_info.api_calls.is_empty() {
        code_info
            .api_calls
            .iter()
            .map(|call| UnifiedAction::from_http_method(&format!("{:?}", call.method)))
            .fold(UnifiedAction::Read, escalate)
    } else if code_info.is_read_only {
        UnifiedAction::Read
    } else {
        UnifiedAction::Write
    };

    ValidationMetadata {
        is_read_only: code_info.is_read_only,
        estimated_rows: None,
        accessed_types: code_info.endpoints_accessed.iter().cloned().collect(),
        accessed_fields: code_info.methods_used.iter().cloned().collect(),
        has_aggregation: false,
        code_type: Some(self.javascript_validator.to_code_type(code_info)),
        action: Some(action),
        validation_time_ms,
    }
}
878
/// Builds a one-paragraph, human-readable summary of a script.
///
/// The summary covers read-only vs. mutating, the API calls made (all of them
/// when there are at most three, otherwise the first two plus a count), loop
/// usage, and the assessed risk.
#[cfg(feature = "openapi-code-mode")]
fn generate_js_explanation(
    &self,
    code_info: &JavaScriptCodeInfo,
    security_analysis: &crate::types::SecurityAnalysis,
) -> String {
    let opening = if code_info.is_read_only {
        "This code will perform read-only API requests."
    } else {
        "This code will perform API requests that may modify data."
    };
    let mut sentences: Vec<String> = vec![opening.to_string()];

    let calls: Vec<String> = code_info
        .api_calls
        .iter()
        .map(|call| format!("{:?} {}", call.method, call.path))
        .collect();
    match calls.len() {
        0 => {},
        1..=3 => sentences.push(format!("API calls: {}", calls.join(", "))),
        total => sentences.push(format!(
            "API calls: {} and {} more",
            calls[..2].join(", "),
            total - 2
        )),
    }

    if code_info.loop_count > 0 {
        let loop_note = if code_info.all_loops_bounded {
            format!("Contains {} bounded loop(s).", code_info.loop_count)
        } else {
            format!(
                "Contains {} loop(s) - ensure they are properly bounded.",
                code_info.loop_count
            )
        };
        sentences.push(loop_note);
    }

    sentences.push(format!("Risk: {}", security_analysis.assess_risk()));
    sentences.join(" ")
}
931
/// Synchronously validates a SQL statement against the static config checks
/// and, on success, issues an approval token bound to `context`.
///
/// No policy evaluator is consulted on this path.
///
/// # Errors
///
/// Propagates preamble errors (Code Mode disabled, SQL too long, unparsable
/// SQL) and errors from the completion stage.
#[cfg(feature = "sql-code-mode")]
pub fn validate_sql_query(
    &self,
    sql: &str,
    context: &ValidationContext,
) -> Result<ValidationResult, ValidationError> {
    let start = Instant::now();
    let statement_info = self.validate_sql_preamble(sql)?;

    match self.check_sql_config_authorization(&statement_info, start) {
        Some(rejection) => Ok(rejection),
        None => self.complete_sql_validation(sql, &statement_info, context, start),
    }
}
948
/// Validates a SQL statement against the static config checks and, when one
/// is attached, the policy evaluator. On success it issues an approval token.
///
/// Without an evaluator, the one-time "no policy evaluator" warning is
/// logged and only the config checks apply.
///
/// # Errors
///
/// * Preamble errors (Code Mode disabled, SQL too long, unparsable SQL).
/// * [`ValidationError::InternalError`] when policy evaluation fails.
/// * Errors from the completion stage.
#[cfg(feature = "sql-code-mode")]
pub async fn validate_sql_query_async(
    &self,
    sql: &str,
    context: &ValidationContext,
) -> Result<ValidationResult, ValidationError> {
    use crate::policy::StatementEntity;

    let start = Instant::now();
    let statement_info = self.validate_sql_preamble(sql)?;
    if let Some(rejection) = self.check_sql_config_authorization(&statement_info, start) {
        return Ok(rejection);
    }

    match &self.policy_evaluator {
        None => warn_no_policy_configured(),
        Some(evaluator) => {
            let statement_entity = StatementEntity::from_sql_info(&statement_info);
            let server_entity = self.config.to_sql_server_entity();

            let decision = evaluator
                .evaluate_statement(&statement_entity, &server_entity)
                .await
                .map_err(|e| {
                    ValidationError::InternalError(format!("Policy evaluation error: {}", e))
                })?;

            if !decision.allowed {
                let violations = build_policy_violations(
                    &decision,
                    self.config.server_id(),
                    statement_entity.action(),
                    "SQL statement",
                );
                let metadata =
                    self.build_sql_metadata(&statement_info, start.elapsed().as_millis() as u64);
                return Ok(ValidationResult::failure(violations, metadata));
            }
        },
    }

    self.complete_sql_validation(sql, &statement_info, context, start)
}
1001
/// Gate shared by the SQL validators. It requires Code Mode to be enabled and
/// `sql` to fit within `max_query_length`, then parses the statement.
///
/// # Errors
///
/// * [`ValidationError::ConfigError`] when Code Mode is disabled.
/// * [`ValidationError::SecurityError`] when `sql` is too long.
/// * Any error from the SQL validator.
#[cfg(feature = "sql-code-mode")]
fn validate_sql_preamble(
    &self,
    sql: &str,
) -> Result<crate::sql::SqlStatementInfo, ValidationError> {
    if !self.config.enabled {
        return Err(ValidationError::ConfigError(
            "Code Mode is not enabled for this server".into(),
        ));
    }

    let (length, limit) = (sql.len(), self.config.max_query_length);
    if length > limit {
        return Err(ValidationError::SecurityError {
            message: format!("SQL length {} exceeds maximum {}", length, limit),
            issue: crate::types::SecurityIssueType::HighComplexity,
        });
    }

    crate::sql::SqlValidator::new().validate(sql)
}
1028
/// Enforces the SQL-specific static rules from the config.
///
/// Checks run in a fixed order, and the first failure wins:
/// 1. statement-type block list, then allow list (when non-empty);
/// 2. per-kind switches (reads, writes, deletes, DDL) and the required
///    `LIMIT` / `WHERE` rules; unsupported statement kinds are always rejected;
/// 3. table block list, table allow list (when non-empty), column block list;
/// 4. join-count and estimated-row limits.
///
/// NOTE(review): table/column lookups are exact `contains` checks. Whether
/// `SqlValidator` normalises identifier case is not visible here — confirm,
/// otherwise `USERS` may slip past a `users` block-list entry.
///
/// Returns `Some(failure)` describing the first violation, or `None`.
#[cfg(feature = "sql-code-mode")]
fn check_sql_config_authorization(
    &self,
    info: &crate::sql::SqlStatementInfo,
    start: Instant,
) -> Option<ValidationResult> {
    use crate::sql::SqlStatementType;

    // Packs one violation into a failure result; the metadata is stamped with
    // the elapsed time at the moment of rejection.
    let deny = |violation: PolicyViolation| {
        Some(ValidationResult::failure(
            vec![violation],
            self.build_sql_metadata(info, start.elapsed().as_millis() as u64),
        ))
    };

    let cfg = &self.config;
    let stype = info.statement_type.as_str();

    if cfg.sql_blocked_statements.contains(stype) {
        return deny(PolicyViolation::new(
            "code_mode",
            "blocked_statement",
            format!("Statement type '{}' is blocked for this server", stype),
        ));
    }
    if !cfg.sql_allowed_statements.is_empty() && !cfg.sql_allowed_statements.contains(stype) {
        return deny(PolicyViolation::new(
            "code_mode",
            "statement_not_allowed",
            format!("Statement type '{}' is not in the allowlist", stype),
        ));
    }

    let kind_violation = match info.statement_type {
        SqlStatementType::Select => {
            if !cfg.sql_reads_enabled {
                Some(PolicyViolation::new(
                    "code_mode",
                    "reads_disabled",
                    "SELECT statements are not enabled for this server",
                ))
            } else if cfg.sql_require_limit && !info.has_limit {
                Some(
                    PolicyViolation::new(
                        "code_mode",
                        "missing_limit",
                        "SELECT statements must declare a LIMIT for this server",
                    )
                    .with_suggestion("Add a LIMIT clause (e.g. `LIMIT 100`)."),
                )
            } else {
                None
            }
        },
        SqlStatementType::Insert | SqlStatementType::Update => {
            let is_update = matches!(info.statement_type, SqlStatementType::Update);
            if !cfg.sql_allow_writes {
                Some(
                    PolicyViolation::new(
                        "code_mode",
                        "writes_disabled",
                        "INSERT/UPDATE statements are not enabled for this server",
                    )
                    .with_suggestion("Contact your administrator to enable sql_allow_writes."),
                )
            } else if is_update && cfg.sql_require_where_on_writes && !info.has_where {
                // Only UPDATE needs a WHERE clause; INSERT has none to require.
                Some(PolicyViolation::new(
                    "code_mode",
                    "missing_where",
                    format!("{} without WHERE clause is not allowed", info.verb),
                ))
            } else {
                None
            }
        },
        SqlStatementType::Delete => {
            if !cfg.sql_allow_deletes {
                Some(PolicyViolation::new(
                    "code_mode",
                    "deletes_disabled",
                    "DELETE statements are not enabled for this server",
                ))
            } else if cfg.sql_require_where_on_writes && !info.has_where {
                Some(PolicyViolation::new(
                    "code_mode",
                    "missing_where",
                    "DELETE without WHERE clause is not allowed",
                ))
            } else {
                None
            }
        },
        SqlStatementType::Ddl => (!cfg.sql_allow_ddl).then(|| {
            PolicyViolation::new(
                "code_mode",
                "ddl_disabled",
                "DDL (CREATE/ALTER/DROP/GRANT/REVOKE) is not enabled for this server",
            )
        }),
        SqlStatementType::Other => Some(PolicyViolation::new(
            "code_mode",
            "unsupported_statement",
            format!("Statement type '{}' is not supported", info.verb),
        )),
    };
    if let Some(violation) = kind_violation {
        return deny(violation);
    }

    if let Some(table) = info
        .tables
        .iter()
        .find(|table| cfg.sql_blocked_tables.contains(*table))
    {
        return deny(PolicyViolation::new(
            "code_mode",
            "blocked_table",
            format!("Table '{}' is blocked for this server", table),
        ));
    }

    if !cfg.sql_allowed_tables.is_empty() {
        if let Some(table) = info
            .tables
            .iter()
            .find(|table| !cfg.sql_allowed_tables.contains(*table))
        {
            return deny(PolicyViolation::new(
                "code_mode",
                "table_not_allowed",
                format!("Table '{}' is not in the allowlist", table),
            ));
        }
    }

    if let Some(column) = info
        .columns
        .iter()
        .find(|column| cfg.sql_blocked_columns.contains(*column))
    {
        return deny(PolicyViolation::new(
            "code_mode",
            "blocked_column",
            format!("Column '{}' is blocked for this server", column),
        ));
    }

    if info.join_count > cfg.sql_max_joins {
        return deny(PolicyViolation::new(
            "code_mode",
            "excessive_joins",
            format!(
                "Query has {} JOINs, exceeds limit of {}",
                info.join_count, cfg.sql_max_joins
            ),
        ));
    }

    if info.estimated_rows > cfg.sql_max_rows {
        return deny(PolicyViolation::new(
            "code_mode",
            "excessive_rows",
            format!(
                "Estimated rows {} exceeds limit of {}",
                info.estimated_rows, cfg.sql_max_rows
            ),
        ));
    }

    None
}
1249
1250 #[cfg(feature = "sql-code-mode")]
1252 fn complete_sql_validation(
1253 &self,
1254 sql: &str,
1255 info: &crate::sql::SqlStatementInfo,
1256 context: &ValidationContext,
1257 start: Instant,
1258 ) -> Result<ValidationResult, ValidationError> {
1259 let validator = crate::sql::SqlValidator::new();
1260 let security_analysis = validator.analyze_security(info);
1261 let risk_level = security_analysis.assess_risk();
1262
1263 if security_analysis
1264 .potential_issues
1265 .iter()
1266 .any(|i| i.is_critical())
1267 {
1268 let violations: Vec<PolicyViolation> = security_analysis
1269 .potential_issues
1270 .iter()
1271 .filter(|i| i.is_critical())
1272 .map(|i| {
1273 PolicyViolation::new("security", format!("{:?}", i.issue_type), &i.message)
1274 })
1275 .collect();
1276
1277 return Ok(ValidationResult::failure(
1278 violations,
1279 self.build_sql_metadata(info, start.elapsed().as_millis() as u64),
1280 ));
1281 }
1282
1283 let context_hash = context.context_hash();
1284 let token = self.token_generator.generate(
1285 sql,
1286 &context.user_id,
1287 &context.session_id,
1288 self.config.server_id(),
1289 &context_hash,
1290 risk_level,
1291 self.config.token_ttl_seconds,
1292 );
1293
1294 let token_string = token.encode().map_err(|e| {
1295 ValidationError::InternalError(format!("Failed to encode token: {}", e))
1296 })?;
1297
1298 let explanation = self.generate_sql_explanation(info, &security_analysis);
1299 let metadata = self.build_sql_metadata(info, start.elapsed().as_millis() as u64);
1300
1301 let mut result = ValidationResult::success(explanation, risk_level, token_string, metadata);
1302
1303 for issue in &security_analysis.potential_issues {
1304 if !issue.is_critical() {
1305 result.warnings.push(issue.message.clone());
1306 }
1307 }
1308
1309 Ok(result)
1310 }
1311
1312 #[cfg(feature = "sql-code-mode")]
1314 fn build_sql_metadata(
1315 &self,
1316 info: &crate::sql::SqlStatementInfo,
1317 validation_time_ms: u64,
1318 ) -> ValidationMetadata {
1319 let inferred = UnifiedAction::from_sql(info.statement_type.as_str());
1320 let action = UnifiedAction::resolve(inferred, &self.config.action_tags, &info.verb);
1321
1322 ValidationMetadata {
1323 is_read_only: info.statement_type.is_read_only(),
1324 estimated_rows: Some(info.estimated_rows),
1325 accessed_types: info.tables.iter().cloned().collect(),
1326 accessed_fields: info.columns.iter().cloned().collect(),
1327 has_aggregation: info.has_aggregation,
1328 code_type: Some(if info.statement_type.is_read_only() {
1329 crate::types::CodeType::SqlQuery
1330 } else {
1331 crate::types::CodeType::SqlMutation
1332 }),
1333 action: Some(action),
1334 validation_time_ms,
1335 }
1336 }
1337
1338 #[cfg(feature = "sql-code-mode")]
1340 fn generate_sql_explanation(
1341 &self,
1342 info: &crate::sql::SqlStatementInfo,
1343 security_analysis: &crate::types::SecurityAnalysis,
1344 ) -> String {
1345 let mut parts = Vec::new();
1346
1347 let verb_phrase = match info.statement_type.as_str() {
1348 "SELECT" => "This query reads data",
1349 "INSERT" => "This statement inserts rows",
1350 "UPDATE" => "This statement updates rows",
1351 "DELETE" => "This statement deletes rows",
1352 "DDL" => "This statement changes schema or permissions",
1353 _ => "This statement",
1354 };
1355
1356 let tables_phrase = if info.tables.is_empty() {
1357 String::new()
1358 } else {
1359 let mut ts: Vec<&String> = info.tables.iter().collect();
1360 ts.sort();
1361 format!(
1362 " in table(s): {}",
1363 ts.into_iter().cloned().collect::<Vec<_>>().join(", ")
1364 )
1365 };
1366
1367 parts.push(format!("{}{}.", verb_phrase, tables_phrase));
1368
1369 if info.has_where {
1370 parts.push("Filtered with WHERE clause.".to_string());
1371 }
1372 if info.has_limit {
1373 parts.push(format!("Limited to {} rows.", info.estimated_rows));
1374 }
1375 if info.join_count > 0 {
1376 parts.push(format!("Uses {} JOIN(s).", info.join_count));
1377 }
1378 if info.subquery_count > 0 {
1379 parts.push(format!("Contains {} subquer(ies).", info.subquery_count));
1380 }
1381
1382 let risk = security_analysis.assess_risk();
1383 parts.push(format!("Risk: {}", risk));
1384
1385 parts.join(" ")
1386 }
1387
1388 pub fn should_auto_approve(&self, result: &ValidationResult) -> bool {
1390 result.is_valid && self.config.should_auto_approve(result.risk_level)
1391 }
1392
    /// Returns the Code Mode configuration this pipeline validates against.
    pub fn config(&self) -> &CodeModeConfig {
        &self.config
    }
1397
    /// Returns the generator this pipeline uses to mint approval tokens for
    /// successful validations.
    pub fn token_generator(&self) -> &T {
        &self.token_generator
    }
1402}
1403
#[cfg(test)]
mod tests {
    // Unit tests for `ValidationPipeline`. The top-level tests cover GraphQL
    // validation, auto-approval and context hashing. `sql_tests` is only built with
    // the `sql-code-mode` feature and covers the SQL checks and policy denials.

    use super::*;
    use crate::types::RiskLevel;

    /// HMAC secret used by every pipeline these tests build.
    const TEST_SECRET: &[u8] = b"test-secret-key!";

    /// Builds a pipeline over `config`, signing approval tokens with [`TEST_SECRET`].
    fn pipeline_with(config: CodeModeConfig) -> ValidationPipeline {
        ValidationPipeline::new(config, TEST_SECRET.to_vec()).unwrap()
    }

    /// Pipeline built from `CodeModeConfig::enabled()` with nothing else changed.
    fn test_pipeline() -> ValidationPipeline {
        pipeline_with(CodeModeConfig::enabled())
    }

    /// The fixed user/session/schema/permissions context shared by the tests.
    fn test_context() -> ValidationContext {
        ValidationContext::new("user-123", "session-456", "schema-hash", "perms-hash")
    }

    /// Validates a GraphQL document under [`test_context`], panicking on hard errors.
    fn validate_graphql(pipeline: &ValidationPipeline, query: &str) -> ValidationResult {
        pipeline.validate_graphql_query(query, &test_context()).unwrap()
    }

    /// Whether `result` contains at least one violation for `rule`.
    fn has_violation(result: &ValidationResult, rule: &str) -> bool {
        result.violations.iter().any(|v| v.rule == rule)
    }

    #[test]
    fn test_simple_query_validation() {
        let result = validate_graphql(&test_pipeline(), "query { users { id name } }");

        assert!(result.is_valid);
        assert!(result.approval_token.is_some());
        assert_eq!(result.risk_level, RiskLevel::Low);
        assert!(result.explanation.contains("read"));
    }

    #[test]
    fn test_mutation_blocked() {
        let mut config = CodeModeConfig::enabled();
        config.allow_mutations = false;

        let result = validate_graphql(
            &pipeline_with(config),
            "mutation { createUser(name: \"test\") { id } }",
        );

        assert!(!result.is_valid);
        assert!(has_violation(&result, "allow_mutations"));
    }

    #[test]
    fn test_disabled_code_mode() {
        // The default config is not the enabled one, so validation must refuse to run.
        let pipeline = pipeline_with(CodeModeConfig::default());

        let outcome = pipeline.validate_graphql_query("query { users { id } }", &test_context());

        assert!(matches!(outcome, Err(ValidationError::ConfigError(_))));
    }

    #[test]
    fn test_auto_approve_low_risk() {
        let pipeline = test_pipeline();
        let result = validate_graphql(&pipeline, "query { users { id } }");

        assert!(pipeline.should_auto_approve(&result));
    }

    #[test]
    fn test_context_hash() {
        let original = test_context();
        let other_schema =
            ValidationContext::new("user-123", "session-456", "different-schema", "perms-hash");

        // Changing only the schema hash must change the context hash.
        assert_ne!(original.context_hash(), other_schema.context_hash());
    }

    #[test]
    fn test_blocked_query_rejected() {
        let mut config = CodeModeConfig::enabled();
        config.blocked_queries.insert("users".to_string());

        let result = validate_graphql(&pipeline_with(config), "query { users { id } }");

        assert!(!result.is_valid);
        assert!(has_violation(&result, "blocked_query"));
    }

    #[test]
    fn test_allowed_queries_enforced() {
        // Only `orders` is allowlisted, so a `users` query must be refused.
        let mut config = CodeModeConfig::enabled();
        config.allowed_queries.insert("orders".to_string());

        let result = validate_graphql(&pipeline_with(config), "query { users { id } }");

        assert!(!result.is_valid);
        assert!(has_violation(&result, "query_not_allowed"));
    }

    #[cfg(feature = "sql-code-mode")]
    mod sql_tests {
        use super::*;

        /// Validates `sql` under [`test_context`], panicking on hard errors.
        fn validate_sql(pipeline: &ValidationPipeline, sql: &str) -> ValidationResult {
            pipeline.validate_sql_query(sql, &test_context()).unwrap()
        }

        /// Enabled config with `sql_allow_writes` switched on.
        fn writable_config() -> CodeModeConfig {
            let mut config = CodeModeConfig::enabled();
            config.sql_allow_writes = true;
            config
        }

        #[test]
        fn validates_select() {
            let result = validate_sql(&test_pipeline(), "SELECT id, name FROM users LIMIT 10");

            assert!(result.is_valid);
            assert!(result.approval_token.is_some());
        }

        #[test]
        fn rejects_insert_when_writes_disabled() {
            let result = validate_sql(
                &test_pipeline(),
                "INSERT INTO users (id, name) VALUES (1, 'Alice')",
            );

            assert!(!result.is_valid);
            assert!(has_violation(&result, "writes_disabled"));
        }

        #[test]
        fn permits_insert_when_writes_enabled() {
            let result = validate_sql(
                &pipeline_with(writable_config()),
                "INSERT INTO users (id, name) VALUES (1, 'Alice')",
            );

            assert!(result.is_valid);
        }

        #[test]
        fn rejects_update_without_where_by_default() {
            let result = validate_sql(
                &pipeline_with(writable_config()),
                "UPDATE users SET active = 0",
            );

            assert!(!result.is_valid);
            assert!(has_violation(&result, "missing_where"));
        }

        #[test]
        fn rejects_blocked_table() {
            let mut config = CodeModeConfig::enabled();
            config.sql_blocked_tables.insert("secrets".to_string());

            let result = validate_sql(&pipeline_with(config), "SELECT * FROM secrets LIMIT 10");

            assert!(!result.is_valid);
            assert!(has_violation(&result, "blocked_table"));
        }

        #[test]
        fn rejects_non_allowlisted_table() {
            // Only `users` is allowlisted; `orders` must be refused.
            let mut config = CodeModeConfig::enabled();
            config.sql_allowed_tables.insert("users".to_string());

            let result = validate_sql(&pipeline_with(config), "SELECT id FROM orders LIMIT 10");

            assert!(!result.is_valid);
            assert!(has_violation(&result, "table_not_allowed"));
        }

        #[test]
        fn rejects_blocked_column() {
            let mut config = CodeModeConfig::enabled();
            config.sql_blocked_columns.insert("password".to_string());

            let result = validate_sql(
                &pipeline_with(config),
                "SELECT id, password FROM users LIMIT 10",
            );

            assert!(!result.is_valid);
            assert!(has_violation(&result, "blocked_column"));
        }

        #[test]
        fn rejects_ddl_by_default() {
            let result = validate_sql(&test_pipeline(), "CREATE TABLE foo (id INT)");

            assert!(!result.is_valid);
            assert!(has_violation(&result, "ddl_disabled"));
        }

        #[test]
        fn rejects_syntax_error() {
            let pipeline = test_pipeline();

            let outcome = pipeline.validate_sql_query("SELEC id FRM users", &test_context());

            assert!(matches!(outcome, Err(ValidationError::ParseError { .. })));
        }

        /// Pipeline whose config sets `sql_require_limit`.
        fn require_limit_pipeline() -> ValidationPipeline {
            let mut config = CodeModeConfig::enabled();
            config.sql_require_limit = true;
            pipeline_with(config)
        }

        #[test]
        fn require_limit_rejects_select_without_limit() {
            let result = validate_sql(&require_limit_pipeline(), "SELECT * FROM Artist");

            assert!(!result.is_valid);
            assert!(has_violation(&result, "missing_limit"));
        }

        #[test]
        fn require_limit_accepts_select_with_limit() {
            let result = validate_sql(&require_limit_pipeline(), "SELECT * FROM Artist LIMIT 25");

            assert!(result.is_valid);
            assert!(!has_violation(&result, "missing_limit"));
        }

        #[test]
        fn require_limit_default_accepts_bare_select() {
            // Without opting in to `sql_require_limit`, a bare SELECT is accepted.
            let result = validate_sql(&test_pipeline(), "SELECT * FROM Artist");

            assert!(result.is_valid);
            assert!(!has_violation(&result, "missing_limit"));
        }

        #[test]
        fn require_limit_does_not_affect_writes() {
            let mut config = writable_config();
            config.sql_require_limit = true;

            let result = validate_sql(
                &pipeline_with(config),
                "INSERT INTO Artist (Name) VALUES ('AC/DC')",
            );

            assert!(result.is_valid);
            assert!(!has_violation(&result, "missing_limit"));
        }

        #[test]
        fn require_limit_serde_round_trip() {
            // The serialized key is `require_limit`; it maps onto `sql_require_limit`.
            let absent: CodeModeConfig =
                toml::from_str("enabled = true\n").expect("parse without require_limit");
            assert!(!absent.sql_require_limit);

            let present: CodeModeConfig = toml::from_str("enabled = true\nrequire_limit = true\n")
                .expect("parse with require_limit");
            assert!(present.sql_require_limit);
        }

        /// Policy evaluator that denies every request with no determining policies,
        /// reporting `errors` as evaluation errors.
        struct FixedDenyEvaluator {
            errors: Vec<String>,
        }

        impl FixedDenyEvaluator {
            /// The single deny decision returned for every evaluation.
            fn deny(&self) -> crate::policy::AuthorizationDecision {
                crate::policy::AuthorizationDecision {
                    allowed: false,
                    determining_policies: vec![],
                    errors: self.errors.clone(),
                }
            }
        }

        #[async_trait::async_trait]
        impl PolicyEvaluator for FixedDenyEvaluator {
            async fn evaluate_operation(
                &self,
                _op: &crate::policy::OperationEntity,
                _cfg: &crate::policy::ServerConfigEntity,
            ) -> Result<crate::policy::AuthorizationDecision, crate::policy::PolicyEvaluationError>
            {
                Ok(self.deny())
            }

            #[cfg(feature = "sql-code-mode")]
            async fn evaluate_statement(
                &self,
                _stmt: &crate::policy::StatementEntity,
                _server: &crate::policy::SqlServerEntity,
            ) -> Result<crate::policy::AuthorizationDecision, crate::policy::PolicyEvaluationError>
            {
                Ok(self.deny())
            }

            fn name(&self) -> &str {
                "fixed-deny-test"
            }
        }

        /// Pipeline with `server_id = "test-server"` whose authorization decisions
        /// come from `evaluator`.
        fn sql_pipeline_with_evaluator(evaluator: Arc<dyn PolicyEvaluator>) -> ValidationPipeline {
            let mut config = CodeModeConfig::enabled();
            config.server_id = Some("test-server".to_string());
            ValidationPipeline::with_policy_evaluator(config, TEST_SECRET.to_vec(), evaluator)
                .unwrap()
        }

        #[tokio::test]
        async fn default_deny_produces_synthetic_violation() {
            let evaluator: Arc<dyn PolicyEvaluator> =
                Arc::new(FixedDenyEvaluator { errors: vec![] });
            let pipeline = sql_pipeline_with_evaluator(evaluator);

            let result = pipeline
                .validate_sql_query_async("SELECT id FROM users LIMIT 10", &test_context())
                .await
                .unwrap();

            assert!(!result.is_valid);
            let default_deny = result
                .violations
                .iter()
                .find(|v| v.rule == "default_deny")
                .expect("expected a synthetic default_deny violation");
            // The synthetic message should name both the server and the action.
            assert!(default_deny.message.contains("test-server"));
            assert!(default_deny.message.contains("Read"));
        }

        #[tokio::test]
        async fn policy_errors_flow_to_violations() {
            let evaluator: Arc<dyn PolicyEvaluator> = Arc::new(FixedDenyEvaluator {
                errors: vec!["schema validation: missing required attribute X".to_string()],
            });
            let pipeline = sql_pipeline_with_evaluator(evaluator);

            let result = pipeline
                .validate_sql_query_async("SELECT id FROM users LIMIT 10", &test_context())
                .await
                .unwrap();

            assert!(!result.is_valid);
            let policy_error = result
                .violations
                .iter()
                .find(|v| v.rule == "evaluation_error")
                .expect("expected a policy_error violation");
            assert!(policy_error.message.contains("schema validation"));
        }

        #[test]
        fn rejects_excessive_joins() {
            // Two JOINs against a limit of one.
            let mut config = CodeModeConfig::enabled();
            config.sql_max_joins = 1;

            let result = validate_sql(
                &pipeline_with(config),
                "SELECT u.id FROM users u \
                 JOIN orders o ON u.id = o.user_id \
                 JOIN items i ON o.id = i.order_id LIMIT 10",
            );

            assert!(!result.is_valid);
            assert!(has_violation(&result, "excessive_joins"));
        }
    }
}