1use once_cell::sync::Lazy;
71use regex::Regex;
72use serde::{Deserialize, Serialize};
73use std::collections::HashSet;
74
75use crate::identity::Requirement as IdentityRequirement;
76use crate::predicates::{CommandPredicate, SensitivePath};
77
/// Risk level attached to a rule and to the outcome of an evaluation.
///
/// Variants are declared (and explicitly numbered) from least to most severe, so the
/// derived `Ord` ranks `Critical` highest. [`Severity::rank`] exposes the discriminant
/// (1..=4). Serialized names use PascalCase (`Low`, `Medium`, `High`, `Critical`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub enum Severity {
    Low = 1,
    Medium = 2,
    High = 3,
    Critical = 4,
}
90
91impl Severity {
92 pub fn rank(self) -> u8 {
93 self as u8
94 }
95
96 pub fn as_str(self) -> &'static str {
97 match self {
98 Severity::Critical => "Critical",
99 Severity::High => "High",
100 Severity::Medium => "Medium",
101 Severity::Low => "Low",
102 }
103 }
104
105 pub fn bumped(self) -> Self {
107 match self {
108 Severity::Low => Severity::Medium,
109 Severity::Medium => Severity::High,
110 Severity::High => Severity::Critical,
111 Severity::Critical => Severity::Critical,
112 }
113 }
114
115 pub fn demoted(self) -> Self {
117 match self {
118 Severity::Critical => Severity::High,
119 Severity::High => Severity::Medium,
120 Severity::Medium => Severity::Low,
121 Severity::Low => Severity::Low,
122 }
123 }
124}
125
/// Action the caller should take for an evaluated call or text (built by [`decide`]).
///
/// `contributing_rules` lists the ids of every other matched rule besides the primary
/// `rule_id`.
#[derive(Debug, Clone)]
pub enum Decision {
    /// Nothing matched, or the final severity was `Low`: proceed.
    Allow,
    /// Proceed, but surface `banner` (the primary rule's reason) to the user.
    Warn {
        rule_id: String,
        severity: Severity,
        banner: String,
        safer_alternative: Option<String>,
    },
    /// Blocking: requires an explicit approval before proceeding.
    Approval {
        rule_id: String,
        severity: Severity,
        reason: String,
        safer_alternative: Option<String>,
        contributing_rules: Vec<String>,
    },
    /// Blocking: the primary rule carries an identity `requirement` that must be met.
    IdentityVerification {
        rule_id: String,
        severity: Severity,
        reason: String,
        safer_alternative: Option<String>,
        contributing_rules: Vec<String>,
        requirement: IdentityRequirement,
    },
    /// Blocking: refuse outright.
    Block {
        rule_id: String,
        severity: Severity,
        reason: String,
        safer_alternative: Option<String>,
        contributing_rules: Vec<String>,
    },
}
163
164impl Decision {
165 pub fn is_blocking(&self) -> bool {
166 matches!(
167 self,
168 Decision::Block { .. } | Decision::Approval { .. } | Decision::IdentityVerification { .. }
169 )
170 }
171
172 pub fn label(&self) -> &'static str {
173 match self {
174 Decision::Allow => "allow",
175 Decision::Warn { .. } => "warn",
176 Decision::Approval { .. } => "approval",
177 Decision::IdentityVerification { .. } => "identity_verification",
178 Decision::Block { .. } => "block",
179 }
180 }
181}
182
/// Contextual signals supplied by the caller that shift the final severity of an
/// evaluation. They only take effect when at least one rule matched.
#[derive(Debug, Clone, Copy, Default)]
pub struct Adjustments {
    /// The workspace looks like production; escalates the final severity.
    pub workspace_is_prod: bool,
    /// The same fingerprint was denied recently; escalates by one level.
    pub fingerprint_recently_denied: bool,
    /// The same fingerprint was approved repeatedly; demotes by one level, but only
    /// when none of the three escalating signals is set.
    pub fingerprint_repeatedly_approved: bool,
    /// A burst of risky calls is in progress; escalates by one level.
    pub burst_in_progress: bool,
}
194
/// Top-level YAML document: everything lives under a single `shieldset:` key.
#[derive(Debug, Deserialize)]
pub struct Root {
    pub shieldset: Shieldset,
}
203
/// Contents of the `shieldset:` section: a format version, engine policy and rule list.
/// Every field is optional in YAML and falls back to its `Default`.
#[derive(Debug, Deserialize)]
pub struct Shieldset {
    /// Document format version (0 when absent). NOTE(review): nothing in this file
    /// validates it — confirm whether a loader elsewhere does.
    #[serde(default)]
    pub version: u32,
    #[serde(default)]
    pub policy: Policy,
    #[serde(default)]
    pub rules: Vec<YamlRule>,
}
213
/// Engine-wide tuning knobs; each sub-section falls back to its own `Default` when omitted.
#[derive(Debug, Default, Deserialize, Clone)]
pub struct Policy {
    #[serde(default)]
    pub workspace_probe: WorkspaceProbeCfg,
    #[serde(default)]
    pub decision_memory: DecisionMemoryCfg,
    #[serde(default)]
    pub burst_detector: BurstDetectorCfg,
    /// Controls how summed rule points map to a composite severity.
    #[serde(default)]
    pub composite_scoring: CompositeScoringCfg,
    #[serde(default)]
    pub supply_chain: SupplyChainCfg,
}
227
/// `policy.supply_chain` settings. They are consumed outside this file; the notes below
/// only describe the parsed values.
#[derive(Debug, Deserialize, Clone)]
pub struct SupplyChainCfg {
    /// Pinning switch (default `true`).
    #[serde(default = "default_true")]
    pub pinning: bool,
    /// Action name for a tool whose definition changed (default `"block"`); presumably a
    /// decision label — confirm against the consumer.
    #[serde(default = "default_on_changed")]
    pub on_changed_tool: String,
    /// Action name for a previously unseen tool (default `"warn"`).
    #[serde(default = "default_on_new")]
    pub on_new_tool: String,
}
246impl Default for SupplyChainCfg {
247 fn default() -> Self {
248 Self {
249 pinning: true,
250 on_changed_tool: default_on_changed(),
251 on_new_tool: default_on_new(),
252 }
253 }
254}
/// Serde default for `SupplyChainCfg::on_changed_tool`.
fn default_on_changed() -> String {
    String::from("block")
}
/// Serde default for `SupplyChainCfg::on_new_tool`.
fn default_on_new() -> String {
    String::from("warn")
}
257
/// `policy.workspace_probe` settings. The probe that detects production workspaces is not
/// in this file; presumably it reads these values and sets
/// `Adjustments::workspace_is_prod` — confirm with the caller.
#[derive(Debug, Deserialize, Clone)]
pub struct WorkspaceProbeCfg {
    /// Whether the probe runs (default `true`).
    #[serde(default = "default_true")]
    pub enabled: bool,
    /// File/directory names treated as production signals (see `default_prod_signals`).
    #[serde(default = "default_prod_signals")]
    pub prod_signals: Vec<String>,
    /// Configured escalation step for production workspaces (default 1).
    #[serde(default = "one")]
    pub severity_bump: u8,
}
267impl Default for WorkspaceProbeCfg {
268 fn default() -> Self {
269 Self {
270 enabled: true,
271 prod_signals: default_prod_signals(),
272 severity_bump: 1,
273 }
274 }
275}
276
/// `policy.decision_memory` settings, consumed outside this file (presumably by whatever
/// sets the fingerprint flags in `Adjustments`).
#[derive(Debug, Deserialize, Clone)]
pub struct DecisionMemoryCfg {
    /// Default `true`.
    #[serde(default = "default_true")]
    pub enabled: bool,
    /// Approval count after which a fingerprint counts as repeatedly approved (default 3).
    #[serde(default = "default_three")]
    pub demote_after_approvals: u32,
    /// Look-back window, in days, for recent denials (default 7).
    #[serde(default = "default_seven")]
    pub escalate_on_deny_days: u32,
}
286impl Default for DecisionMemoryCfg {
287 fn default() -> Self {
288 Self {
289 enabled: true,
290 demote_after_approvals: 3,
291 escalate_on_deny_days: 7,
292 }
293 }
294}
295
/// `policy.burst_detector` settings, consumed outside this file (presumably by whatever
/// sets `Adjustments::burst_in_progress`).
#[derive(Debug, Deserialize, Clone)]
pub struct BurstDetectorCfg {
    /// Default `true`.
    #[serde(default = "default_true")]
    pub enabled: bool,
    /// Sliding window length in seconds (default 300).
    #[serde(default = "default_300")]
    pub window_seconds: u32,
    /// Event count within the window that constitutes a burst (default 5).
    #[serde(default = "default_five")]
    pub threshold: u32,
}
305impl Default for BurstDetectorCfg {
306 fn default() -> Self {
307 Self {
308 enabled: true,
309 window_seconds: 300,
310 threshold: 5,
311 }
312 }
313}
314
/// `policy.composite_scoring`: maps the summed `points` of all matched rules onto a
/// severity that can raise the overall result above any single rule's severity.
#[derive(Debug, Deserialize, Clone)]
pub struct CompositeScoringCfg {
    /// When `false`, the composite severity is always `Low` (default `true`).
    #[serde(default = "default_true")]
    pub enabled: bool,
    #[serde(default)]
    pub thresholds: CompositeThresholds,
}
322impl Default for CompositeScoringCfg {
323 fn default() -> Self {
324 Self {
325 enabled: true,
326 thresholds: CompositeThresholds::default(),
327 }
328 }
329}
330
/// Inclusive lower bounds on composite points for each severity; fewer points than
/// `medium` means `Low`.
#[derive(Debug, Deserialize, Clone)]
pub struct CompositeThresholds {
    /// Points at or above this are at least `Medium` (default 2).
    #[serde(default = "default_two")]
    pub medium: u32,
    /// Points at or above this are at least `High` (default 5).
    #[serde(default = "default_five")]
    pub high: u32,
    /// Points at or above this are `Critical` (default 9).
    #[serde(default = "default_nine")]
    pub critical: u32,
}
340impl Default for CompositeThresholds {
341 fn default() -> Self {
342 Self { medium: 2, high: 5, critical: 9 }
343 }
344}
345
// Serde `default = "..."` helpers. They must be free functions because serde takes a path.

fn default_true() -> bool {
    true
}
fn one() -> u8 {
    1
}
fn default_three() -> u32 {
    3
}
fn default_seven() -> u32 {
    7
}
fn default_300() -> u32 {
    300
}
fn default_five() -> u32 {
    5
}
fn default_two() -> u32 {
    2
}
fn default_nine() -> u32 {
    9
}
/// Default production markers for `WorkspaceProbeCfg::prod_signals`, in a fixed order.
fn default_prod_signals() -> Vec<String> {
    const SIGNALS: [&str; 8] = [
        ".env.production",
        "prod/",
        "Procfile",
        ".terraform/terraform.tfstate",
        "kubeconfig",
        ".kube/config",
        "production.yml",
        "production.yaml",
    ];
    SIGNALS.iter().map(|s| (*s).to_owned()).collect()
}
366
/// One rule exactly as written in YAML; compiled into a [`CompiledRule`] by the engine.
#[derive(Debug, Deserialize)]
pub struct YamlRule {
    /// Rule identifier reported in matches and decisions.
    pub id: String,
    pub severity: Severity,
    /// Weight in the composite score; when absent the severity rank (1..=4) is used.
    #[serde(default)]
    pub points: Option<u32>,
    /// YAML key `where`: one of `tool_call`, `llm_response`, `tool_description`,
    /// `tool_result`. Anything else is rejected at compile time.
    #[serde(rename = "where")]
    pub where_: String,
    /// YAML key `match`; a rule without it never matches anything.
    #[serde(default)]
    pub r#match: Option<YamlMatch>,
    /// Explanation shown to the user (empty when absent).
    #[serde(default)]
    pub reason: String,
    #[serde(default)]
    pub safer_alternative: Option<String>,
    /// Identity requirement; when present, a High/Critical result on this rule becomes
    /// `Decision::IdentityVerification`.
    #[serde(default)]
    pub identity: Option<IdentityRequirement>,
}
389
/// The `match:` block of a YAML rule. Every list is optional; an empty list disables
/// that matcher family.
#[derive(Debug, Default, Deserialize)]
pub struct YamlMatch {
    /// Restrict the rule to these tool names.
    #[serde(default)]
    pub tool: Option<Vec<String>>,
    /// Regexes tried against every string value in the tool-call params.
    #[serde(default)]
    pub any_param_matches: Vec<String>,
    /// Regexes tried against SQL extracted from SQL-bearing param keys.
    #[serde(default)]
    pub sql_matches: Vec<String>,
    /// Named SQL predicates: `unscoped_update`, `unscoped_delete` (any ASCII case).
    #[serde(default)]
    pub sql_predicates: Vec<String>,
    /// Regexes tried against free text (non-tool-call scopes).
    #[serde(default)]
    pub text_matches: Vec<String>,
    /// Names accepted by `CommandPredicate::parse`.
    #[serde(default)]
    pub command_predicates: Vec<String>,
    /// Patterns compiled by `SensitivePath::compile`.
    #[serde(default)]
    pub sensitive_paths: Vec<String>,
}
407
/// What a rule inspects; parsed from the YAML `where` field.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Scope {
    /// `where: tool_call` — tool name and JSON params.
    ToolCall,
    /// `where: llm_response` — model output text.
    LlmResponse,
    /// `where: tool_description` — text describing a tool.
    ToolDescription,
    /// `where: tool_result` — text returned by a tool.
    ToolResult,
}
425
/// Structural SQL checks (YAML names `unscoped_update` / `unscoped_delete`): UPDATE or
/// DELETE statements that are not effectively restricted by a WHERE clause.
#[derive(Debug, Clone, Copy)]
pub enum SqlPredicate { UnscopedUpdate, UnscopedDelete }
428
/// Compiled form of a rule's `match:` block. Empty vectors disable that matcher family.
#[derive(Debug)]
struct Match {
    /// Tool names the rule applies to; `None` means any tool.
    tool_whitelist: Option<HashSet<String>>,
    /// Tried against every string value in the tool-call params.
    any_param_re: Vec<Regex>,
    /// Tried against SQL strings extracted from SQL-bearing keys (`extract_sql`).
    sql_re: Vec<Regex>,
    /// Structural checks on the same extracted SQL strings.
    sql_predicates: Vec<SqlPredicate>,
    /// Used by `CompiledRule::matches_text` for free-text scopes.
    text_re: Vec<Regex>,
    /// Checked against every string value in the params.
    command_predicates: Vec<CommandPredicate>,
    /// Checked only against strings for which `predicates::command_writes` is true.
    sensitive_paths: Vec<SensitivePath>,
}
439
/// A validated rule with regexes/predicates compiled, ready for matching.
#[derive(Debug)]
pub struct CompiledRule {
    pub id: String,
    pub severity: Severity,
    /// Composite-score weight (YAML `points`, or the severity rank when omitted).
    pub points: u32,
    pub scope: Scope,
    pub reason: String,
    pub safer_alternative: Option<String>,
    pub identity: Option<IdentityRequirement>,
    /// `None` when the YAML rule had no `match:` block; such a rule never matches.
    matcher: Option<Match>,
}
454
455impl CompiledRule {
456 pub fn matches_tool_call(&self, tool: &str, params: &serde_json::Value) -> bool {
457 let m = match &self.matcher { Some(m) => m, None => return false };
458 if let Some(allow) = &m.tool_whitelist {
459 if !allow.contains(tool) {
460 return false;
461 }
462 }
463
464 if !m.sql_re.is_empty() || !m.sql_predicates.is_empty() {
466 let sqls = extract_sql(params);
467 for s in &sqls {
468 for re in &m.sql_re {
469 if re.is_match(s) { return true; }
470 }
471 for p in &m.sql_predicates {
472 if matches_sql_predicate(*p, s) { return true; }
473 }
474 }
475 }
476
477 if !m.any_param_re.is_empty() {
479 let mut hit = false;
480 walk_strings(params, &mut |s| {
481 if hit { return; }
482 for re in &m.any_param_re {
483 if re.is_match(s) { hit = true; return; }
484 }
485 });
486 if hit { return true; }
487 }
488
489 if !m.command_predicates.is_empty() {
493 let mut hit = false;
494 walk_strings(params, &mut |s| {
495 if hit { return; }
496 for p in &m.command_predicates {
497 if p.matches(s) { hit = true; return; }
498 }
499 });
500 if hit { return true; }
501 }
502
503 if !m.sensitive_paths.is_empty() {
514 let mut hit = false;
515 walk_strings(params, &mut |s| {
516 if hit { return; }
517 if !crate::predicates::command_writes(s) { return; }
518 for sp in &m.sensitive_paths {
519 if sp.touches(s) { hit = true; return; }
520 }
521 });
522 if hit { return true; }
523 }
524
525 false
526 }
527
528 pub fn matches_text(&self, text: &str) -> bool {
529 let m = match &self.matcher { Some(m) => m, None => return false };
530 for re in &m.text_re {
531 if re.is_match(text) { return true; }
532 }
533 false
534 }
535
536 pub fn tool_whitelist(&self) -> Option<&HashSet<String>> {
540 self.matcher.as_ref().and_then(|m| m.tool_whitelist.as_ref())
541 }
542}
543
/// The rule engine: compiled rules plus the policy that tunes severity resolution.
#[derive(Debug)]
pub struct Engine {
    /// All loaded rules, in load order (base document first, then extension packs).
    pub rules: Vec<CompiledRule>,
    /// Policy of the base document; `extend_from_yaml` does not modify it.
    pub policy: Policy,
}
553
/// Result of evaluating one call or text against the rules; feed it to [`decide`].
#[derive(Debug, Clone)]
pub struct Evaluation {
    /// One entry per matching rule, in rule order.
    pub matches: Vec<MatchInfo>,
    /// Saturating sum of the matched rules' points.
    pub composite_points: u32,
    /// Highest severity among the matches (`Low` when none matched).
    pub raw_severity: Severity,
    /// Severity implied by `composite_points` (`Low` when composite scoring is disabled).
    pub composite_severity: Severity,
    /// Max of raw and composite severity, after contextual adjustments.
    pub final_severity: Severity,
    /// Names of the adjustments that were applied, in application order.
    pub adjustments_applied: Vec<&'static str>,
}
564
/// Snapshot of a matched rule's reportable fields.
#[derive(Debug, Clone)]
pub struct MatchInfo {
    pub rule_id: String,
    pub severity: Severity,
    pub points: u32,
    pub reason: String,
    pub safer_alternative: Option<String>,
    pub identity: Option<IdentityRequirement>,
}
577
578impl Engine {
579 pub fn from_yaml(raw: &str) -> anyhow::Result<Self> {
582 let root: Root = serde_yaml::from_str(raw)?;
583 let policy = root.shieldset.policy.clone();
584 let rules = Self::compile_yaml_rules(root.shieldset.rules)?;
585 Ok(Engine { rules, policy })
586 }
587
588 pub fn extend_from_yaml(&mut self, raw: &str) -> anyhow::Result<()> {
594 let root: Root = serde_yaml::from_str(raw)?;
595 let extra = Self::compile_yaml_rules(root.shieldset.rules)?;
596 for r in &extra {
597 if self.rules.iter().any(|e| e.id == r.id) {
598 anyhow::bail!("rule pack defines duplicate rule id '{}'", r.id);
599 }
600 }
601 self.rules.extend(extra);
602 Ok(())
603 }
604
605 fn compile_yaml_rules(yaml_rules: Vec<YamlRule>) -> anyhow::Result<Vec<CompiledRule>> {
606 let mut rules = Vec::with_capacity(yaml_rules.len());
607 for y in yaml_rules {
608 let scope = match y.where_.as_str() {
609 "tool_call" => Scope::ToolCall,
610 "llm_response" => Scope::LlmResponse,
611 "tool_description" => Scope::ToolDescription,
612 "tool_result" => Scope::ToolResult,
613 other => anyhow::bail!("rule '{}' has unknown where '{}'", y.id, other),
614 };
615 let matcher = if let Some(m) = y.r#match {
616 let mut sql_preds = Vec::new();
617 for n in m.sql_predicates {
618 let p = match n.to_ascii_lowercase().as_str() {
619 "unscoped_update" => SqlPredicate::UnscopedUpdate,
620 "unscoped_delete" => SqlPredicate::UnscopedDelete,
621 other => anyhow::bail!("rule '{}'.sql_predicates: unknown '{}'", y.id, other),
622 };
623 sql_preds.push(p);
624 }
625 let mut cmd_preds = Vec::new();
626 for n in m.command_predicates {
627 let p = CommandPredicate::parse(&n).ok_or_else(|| {
628 anyhow::anyhow!("rule '{}'.command_predicates: unknown '{}'", y.id, n)
629 })?;
630 cmd_preds.push(p);
631 }
632 let mut paths = Vec::new();
633 for n in m.sensitive_paths {
634 paths.push(SensitivePath::compile(&n)?);
635 }
636 Some(Match {
637 tool_whitelist: m.tool.map(|v| v.into_iter().collect()),
638 any_param_re: compile_regexes(&y.id, "any_param_matches", m.any_param_matches)?,
639 sql_re: compile_regexes(&y.id, "sql_matches", m.sql_matches)?,
640 sql_predicates: sql_preds,
641 text_re: compile_regexes(&y.id, "text_matches", m.text_matches)?,
642 command_predicates: cmd_preds,
643 sensitive_paths: paths,
644 })
645 } else {
646 None
647 };
648 let points = y.points.unwrap_or(y.severity.rank() as u32);
651 rules.push(CompiledRule {
652 id: y.id,
653 severity: y.severity,
654 points,
655 scope,
656 reason: y.reason,
657 safer_alternative: y.safer_alternative,
658 identity: y.identity,
659 matcher,
660 });
661 }
662 Ok(rules)
663 }
664
665 pub fn builtin_default() -> Self {
668 let yaml = include_str!("../config/shieldset.yaml");
669 Self::from_yaml(yaml).expect("bundled shieldset.yaml must parse")
670 }
671
672 pub fn evaluate(&self, tool: &str, params: &serde_json::Value, adj: Adjustments) -> Evaluation {
676 let mut matches = Vec::new();
677 let mut composite_points = 0u32;
678 for r in self.rules.iter().filter(|r| r.scope == Scope::ToolCall) {
679 if r.matches_tool_call(tool, params) {
680 composite_points = composite_points.saturating_add(r.points);
681 matches.push(MatchInfo {
682 rule_id: r.id.clone(),
683 severity: r.severity,
684 points: r.points,
685 reason: r.reason.clone(),
686 safer_alternative: r.safer_alternative.clone(),
687 identity: r.identity.clone(),
688 });
689 }
690 }
691 self.resolve(matches, composite_points, adj)
692 }
693
694 pub fn evaluate_text(&self, text: &str, adj: Adjustments) -> Evaluation {
696 self.evaluate_scoped_text(Scope::LlmResponse, None, text, adj)
697 }
698
699 pub fn evaluate_scoped_text(
704 &self,
705 scope: Scope,
706 tool: Option<&str>,
707 text: &str,
708 adj: Adjustments,
709 ) -> Evaluation {
710 let mut matches = Vec::new();
711 let mut composite_points = 0u32;
712 for r in self.rules.iter().filter(|r| r.scope == scope) {
713 if let (Some(t), Some(allow)) = (tool, r.tool_whitelist()) {
714 if !allow.contains(t) {
715 continue;
716 }
717 }
718 if r.matches_text(text) {
719 composite_points = composite_points.saturating_add(r.points);
720 matches.push(MatchInfo {
721 rule_id: r.id.clone(),
722 severity: r.severity,
723 points: r.points,
724 reason: r.reason.clone(),
725 safer_alternative: r.safer_alternative.clone(),
726 identity: r.identity.clone(),
727 });
728 }
729 }
730 self.resolve(matches, composite_points, adj)
731 }
732
733 fn resolve(&self, matches: Vec<MatchInfo>, composite_points: u32, adj: Adjustments) -> Evaluation {
734 let raw_severity = matches
735 .iter()
736 .map(|m| m.severity)
737 .max()
738 .unwrap_or(Severity::Low);
739
740 let composite_severity = if self.policy.composite_scoring.enabled {
741 severity_from_points(composite_points, &self.policy.composite_scoring.thresholds)
742 } else {
743 Severity::Low
744 };
745
746 let mut final_severity = raw_severity.max(composite_severity);
747 let mut adjustments_applied = Vec::new();
748
749 if adj.workspace_is_prod && !matches.is_empty() {
750 final_severity = final_severity.bumped();
751 adjustments_applied.push("workspace_is_prod");
752 }
753 if adj.fingerprint_recently_denied && !matches.is_empty() {
754 final_severity = final_severity.bumped();
755 adjustments_applied.push("fingerprint_recently_denied");
756 }
757 if adj.burst_in_progress && !matches.is_empty() {
758 final_severity = final_severity.bumped();
759 adjustments_applied.push("burst_in_progress");
760 }
761 if adj.fingerprint_repeatedly_approved
764 && !matches.is_empty()
765 && !adj.workspace_is_prod
766 && !adj.fingerprint_recently_denied
767 && !adj.burst_in_progress
768 {
769 final_severity = final_severity.demoted();
770 adjustments_applied.push("fingerprint_repeatedly_approved");
771 }
772
773 Evaluation {
774 matches,
775 composite_points,
776 raw_severity,
777 composite_severity,
778 final_severity,
779 adjustments_applied,
780 }
781 }
782}
783
784pub fn decide(eval: &Evaluation) -> Decision {
788 if eval.matches.is_empty() {
789 return Decision::Allow;
790 }
791 let primary = eval
792 .matches
793 .iter()
794 .max_by(|a, b| {
795 a.severity.cmp(&b.severity)
796 .then(a.points.cmp(&b.points))
797 .then(b.rule_id.cmp(&a.rule_id))
798 })
799 .expect("non-empty");
800
801 let contributing: Vec<String> = eval
802 .matches
803 .iter()
804 .filter(|m| m.rule_id != primary.rule_id)
805 .map(|m| m.rule_id.clone())
806 .collect();
807
808 match eval.final_severity {
809 Severity::Critical => {
810 if let Some(req) = primary.identity.clone() {
815 Decision::IdentityVerification {
816 rule_id: primary.rule_id.clone(),
817 severity: eval.final_severity,
818 reason: primary.reason.clone(),
819 safer_alternative: primary.safer_alternative.clone(),
820 contributing_rules: contributing,
821 requirement: req,
822 }
823 } else {
824 Decision::Block {
825 rule_id: primary.rule_id.clone(),
826 severity: eval.final_severity,
827 reason: primary.reason.clone(),
828 safer_alternative: primary.safer_alternative.clone(),
829 contributing_rules: contributing,
830 }
831 }
832 }
833 Severity::High => {
834 if let Some(req) = primary.identity.clone() {
835 Decision::IdentityVerification {
836 rule_id: primary.rule_id.clone(),
837 severity: eval.final_severity,
838 reason: primary.reason.clone(),
839 safer_alternative: primary.safer_alternative.clone(),
840 contributing_rules: contributing,
841 requirement: req,
842 }
843 } else {
844 Decision::Approval {
845 rule_id: primary.rule_id.clone(),
846 severity: eval.final_severity,
847 reason: primary.reason.clone(),
848 safer_alternative: primary.safer_alternative.clone(),
849 contributing_rules: contributing,
850 }
851 }
852 }
853 Severity::Medium => Decision::Warn {
854 rule_id: primary.rule_id.clone(),
855 severity: eval.final_severity,
856 banner: primary.reason.clone(),
857 safer_alternative: primary.safer_alternative.clone(),
858 },
859 Severity::Low => Decision::Allow,
860 }
861}
862
863fn severity_from_points(points: u32, t: &CompositeThresholds) -> Severity {
864 if points >= t.critical { Severity::Critical }
865 else if points >= t.high { Severity::High }
866 else if points >= t.medium { Severity::Medium }
867 else { Severity::Low }
868}
869
870fn compile_regexes(rule_id: &str, field: &str, ps: Vec<String>) -> anyhow::Result<Vec<Regex>> {
875 let mut out = Vec::with_capacity(ps.len());
876 for p in ps {
877 out.push(Regex::new(&p).map_err(|e| anyhow::anyhow!("rule '{}'.{}: bad regex '{}': {}", rule_id, field, p, e))?);
878 }
879 Ok(out)
880}
881
882const SQL_KEYS: &[&str] = &["query", "sql", "statement", "command", "stmt", "ddl", "dml"];
883
884fn extract_sql(v: &serde_json::Value) -> Vec<String> {
885 let mut out = Vec::new();
886 walk_sql(v, &mut out);
887 out
888}
889
890fn walk_sql(v: &serde_json::Value, out: &mut Vec<String>) {
891 match v {
892 serde_json::Value::Object(map) => {
893 for (k, val) in map {
894 if SQL_KEYS.iter().any(|sk| sk.eq_ignore_ascii_case(k)) {
895 if let Some(s) = val.as_str() { out.push(s.to_string()); }
896 }
897 walk_sql(val, out);
898 }
899 }
900 serde_json::Value::Array(arr) => {
901 for item in arr { walk_sql(item, out); }
902 }
903 _ => {}
904 }
905}
906
907pub(crate) fn walk_strings<F: FnMut(&str)>(v: &serde_json::Value, f: &mut F) {
908 match v {
909 serde_json::Value::String(s) => f(s),
910 serde_json::Value::Array(arr) => for item in arr { walk_strings(item, f); },
911 serde_json::Value::Object(map) => for (_, val) in map { walk_strings(val, f); },
912 _ => {}
913 }
914}
915
916static UPDATE_HEAD: Lazy<Regex> = Lazy::new(|| {
917 Regex::new(r"(?i)\bUPDATE\s+[A-Za-z_][A-Za-z0-9_\.]*\s+SET\b").expect("static")
918});
919static DELETE_HEAD: Lazy<Regex> = Lazy::new(|| {
920 Regex::new(r"(?i)\bDELETE\s+FROM\s+[A-Za-z_][A-Za-z0-9_\.]*").expect("static")
921});
922static WHERE_CLAUSE: Lazy<Regex> = Lazy::new(|| {
923 Regex::new(r"(?i)\bWHERE\b").expect("static")
924});
925
926fn matches_sql_predicate(p: SqlPredicate, sql: &str) -> bool {
927 for frag in sql.split(';') {
928 let f = frag.trim();
929 if f.is_empty() { continue; }
930 match p {
931 SqlPredicate::UnscopedUpdate => {
932 if !UPDATE_HEAD.is_match(f) { continue; }
933 if !WHERE_CLAUSE.is_match(f) { return true; }
935 if where_is_tautological_for_update(f) { return true; }
945 }
946 SqlPredicate::UnscopedDelete => {
947 if DELETE_HEAD.is_match(f) && !WHERE_CLAUSE.is_match(f) { return true; }
948 }
949 }
950 }
951 false
952}
953
/// Splits an UPDATE statement into its SET list and WHERE body.
///
/// Case-insensitive, with `.` spanning newlines (`(?is)`). Group 1 holds the assignments
/// after `SET`, up to the first `WHERE` that is preceded by whitespace. Group 2 holds the
/// WHERE body without a trailing `LIMIT`/`RETURNING`/`ORDER BY`/`GROUP BY` tail.
static SET_AND_WHERE_RE: Lazy<Regex> = Lazy::new(|| {
    Regex::new(r"(?is)\bSET\b\s+(.+?)\s+\bWHERE\b\s+(.+?)(?:\s+\b(?:LIMIT|RETURNING|ORDER\s+BY|GROUP\s+BY)\b.*)?$")
        .expect("static")
});
993
994fn where_is_tautological_for_update(sql: &str) -> bool {
995 let caps = match SET_AND_WHERE_RE.captures(sql) {
996 Some(c) => c,
997 None => return false,
998 };
999 let set_part = match caps.get(1) { Some(m) => m.as_str(), None => return false };
1000 let where_part = match caps.get(2) { Some(m) => m.as_str(), None => return false };
1001
1002 let set_pairs = parse_set_pairs(set_part);
1003 if set_pairs.is_empty() { return false; }
1004
1005 let conjuncts = split_where_on_and(where_part);
1006 if conjuncts.is_empty() { return false; }
1007
1008 for conjunct in &conjuncts {
1009 let trimmed = conjunct.trim_matches(|c: char| c.is_whitespace() || c == '(' || c == ')');
1010 if trimmed.is_empty() { continue; }
1011 let mut matched = false;
1012 for (col, val) in &set_pairs {
1013 if predicate_is_tautological(col, val, trimmed) {
1014 matched = true;
1015 break;
1016 }
1017 }
1018 if !matched { return false; }
1019 }
1020 true
1021}
1022
/// Parses a SET list such as `a = 1, "b" = 'x'` into `(column, value)` pairs.
///
/// Purely lexical: splits on `,` and the first `=`, trims whitespace, strips `"`/backtick
/// from column names and `'`/`"` from values. Pieces without `=` or with an empty side
/// are skipped.
fn parse_set_pairs(set_part: &str) -> Vec<(String, String)> {
    set_part
        .split(',')
        .filter_map(|assignment| {
            let (lhs, rhs) = assignment.split_once('=')?;
            let (lhs, rhs) = (lhs.trim(), rhs.trim());
            if lhs.is_empty() || rhs.is_empty() {
                return None;
            }
            let column = lhs.trim_matches(|c: char| c == '"' || c == '`');
            let value = rhs.trim_matches(|c: char| c == '\'' || c == '"');
            Some((column.to_string(), value.to_string()))
        })
        .collect()
}
1040
1041fn split_where_on_and(where_part: &str) -> Vec<&str> {
1046 static AND_SPLIT: Lazy<Regex> = Lazy::new(|| {
1047 Regex::new(r"(?i)\s+AND\s+").expect("static")
1048 });
1049 AND_SPLIT.split(where_part).collect()
1050}
1051
1052fn predicate_is_tautological(col: &str, set_val: &str, predicate: &str) -> bool {
1053 let col_esc = regex::escape(col);
1054 let set_val_lower = set_val.to_ascii_lowercase();
1055 let val_esc = regex::escape(set_val);
1056 let q = r#"['"]?"#;
1061
1062 if regex_match(
1064 &format!(r"(?i)^\s*{}\s*(?:!=|<>)\s*{}{}{}\s*$", col_esc, q, val_esc, q),
1065 predicate,
1066 ) {
1067 return true;
1068 }
1069
1070 if regex_match(
1072 &format!(r"(?i)^\s*{}\s+IS\s+(?:NOT|DISTINCT\s+FROM)\s+{}{}{}\s*$", col_esc, q, val_esc, q),
1073 predicate,
1074 ) {
1075 return true;
1076 }
1077
1078 if is_bool_literal(&set_val_lower) {
1082 let opposite_pat = bool_opposite_regex_alt(&set_val_lower);
1083 if regex_match(
1084 &format!(r"(?i)^\s*{}\s*=\s*{}(?:{}){}\s*$", col_esc, q, opposite_pat, q),
1085 predicate,
1086 ) {
1087 return true;
1088 }
1089 }
1090
1091 if set_val_lower == "true" || set_val_lower == "t" || set_val_lower == "1" {
1095 if regex_match(
1096 &format!(r"(?i)^\s*{}\s+IS\s+NULL\s*$", col_esc),
1097 predicate,
1098 ) {
1099 return true;
1100 }
1101 if regex_match(&format!(r"(?i)^\s*NOT\s+{}\s*$", col_esc), predicate) {
1104 return true;
1105 }
1106 if regex_match(
1108 &format!(r"(?i)^\s*{}\s+IS\s+NOT\s+TRUE\s*$", col_esc),
1109 predicate,
1110 ) {
1111 return true;
1112 }
1113 }
1114
1115 false
1116}
1117
/// `true` for the lowercase spellings treated as SQL booleans. Callers lowercase first;
/// `"TRUE"` itself is not accepted.
fn is_bool_literal(s: &str) -> bool {
    const BOOL_SPELLINGS: [&str; 6] = ["true", "false", "t", "f", "1", "0"];
    BOOL_SPELLINGS.contains(&s)
}
1121
/// Regex alternation matching every spelling of the opposite truth value of `lit`
/// (a lowercase boolean literal). Returns `""` for anything else.
fn bool_opposite_regex_alt(lit: &str) -> &'static str {
    const TRUTHY: &str = "true|t|1";
    const FALSY: &str = "false|f|0";
    if TRUTHY.split('|').any(|spelling| spelling == lit) {
        FALSY
    } else if FALSY.split('|').any(|spelling| spelling == lit) {
        TRUTHY
    } else {
        ""
    }
}
1133
1134fn regex_match(pattern: &str, haystack: &str) -> bool {
1135 Regex::new(pattern)
1136 .map(|re| re.is_match(haystack))
1137 .unwrap_or(false)
1138}
1139
1140pub fn fingerprint(rule_id: &str, params: &serde_json::Value) -> String {
1149 use sha2::{Digest, Sha256};
1150 let mut h = Sha256::new();
1151 h.update(rule_id.as_bytes());
1152 h.update(b"\x00");
1153 if let Ok(s) = serde_json::to_string(params) {
1158 h.update(s.as_bytes());
1159 }
1160 let out = h.finalize();
1161 let mut hex = String::with_capacity(16);
1162 for b in &out[..8] {
1163 hex.push_str(&format!("{:02x}", b));
1164 }
1165 hex
1166}
1167
// Unit tests for the rule engine. Most cases load the bundled `config/shieldset.yaml`
// through `Engine::builtin_default()`, so they also pin down the shipped rule set: rule
// ids such as `sql.drop_database` and `sql.unscoped_update` must exist there.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Fresh engine from the bundled YAML for each test.
    fn engine() -> Engine { Engine::builtin_default() }

    #[test]
    fn bundled_default_loads_with_many_rules() {
        let e = engine();
        assert!(e.rules.len() >= 30, "expected >= 30 default rules, got {}", e.rules.len());
    }

    // Ordering plus saturation at both ends of bumped()/demoted().
    #[test]
    fn severity_ord_is_monotonic() {
        assert!(Severity::Critical > Severity::High);
        assert!(Severity::High > Severity::Medium);
        assert!(Severity::Medium > Severity::Low);
        assert_eq!(Severity::Critical.bumped(), Severity::Critical);
        assert_eq!(Severity::Low.demoted(), Severity::Low);
        assert_eq!(Severity::Medium.bumped(), Severity::High);
        assert_eq!(Severity::High.demoted(), Severity::Medium);
    }

    #[test]
    fn drop_database_blocked() {
        let e = engine();
        let p = json!({"arguments": {"query": "DROP DATABASE prod;"}});
        let ev = e.evaluate("execute_sql", &p, Adjustments::default());
        assert!(ev.matches.iter().any(|m| m.rule_id == "sql.drop_database"));
        match decide(&ev) {
            Decision::Block { rule_id, .. } => assert_eq!(rule_id, "sql.drop_database"),
            other => panic!("expected Block, got {}", other.label()),
        }
    }

    // No WHERE at all.
    #[test]
    fn unscoped_update_approval() {
        let e = engine();
        let p = json!({"arguments": {"query": "UPDATE users SET banned = true"}});
        let ev = e.evaluate("execute_sql", &p, Adjustments::default());
        match decide(&ev) {
            Decision::Approval { rule_id, .. } => assert_eq!(rule_id, "sql.unscoped_update"),
            other => panic!("expected Approval, got {}", other.label()),
        }
    }

    // The following cases have a WHERE that only excludes rows already holding the SET
    // value, so the UPDATE is still treated as unscoped.
    #[test]
    fn tautological_where_email_verified_boolean_opposite() {
        let e = engine();
        let p = json!({"arguments": {"query":
            "UPDATE users SET email_verified = TRUE WHERE email_verified = FALSE"}});
        let ev = e.evaluate("execute_sql", &p, Adjustments::default());
        match decide(&ev) {
            Decision::Approval { rule_id, .. } => assert_eq!(rule_id, "sql.unscoped_update"),
            other => panic!("expected Approval on tautological WHERE, got {}", other.label()),
        }
    }

    #[test]
    fn tautological_where_inequality_fires() {
        let e = engine();
        let p = json!({"arguments": {"query":
            "UPDATE users SET status = 'active' WHERE status != 'active'"}});
        let ev = e.evaluate("execute_sql", &p, Adjustments::default());
        assert!(matches!(decide(&ev), Decision::Approval { .. }),
            "expected Approval on `WHERE col != X` tautology");
    }

    #[test]
    fn tautological_where_ne_operator_fires() {
        let e = engine();
        let p = json!({"arguments": {"query":
            "UPDATE users SET status = 'active' WHERE status <> 'active'"}});
        let ev = e.evaluate("execute_sql", &p, Adjustments::default());
        assert!(matches!(decide(&ev), Decision::Approval { .. }),
            "expected Approval on `WHERE col <> X` tautology");
    }

    #[test]
    fn tautological_where_is_null_with_set_true_fires() {
        let e = engine();
        let p = json!({"arguments": {"query":
            "UPDATE users SET verified = TRUE WHERE verified IS NULL"}});
        let ev = e.evaluate("execute_sql", &p, Adjustments::default());
        assert!(matches!(decide(&ev), Decision::Approval { .. }),
            "expected Approval on `WHERE col IS NULL` + `SET col = TRUE` tautology");
    }

    #[test]
    fn tautological_where_not_col_fires() {
        let e = engine();
        let p = json!({"arguments": {"query":
            "UPDATE users SET banned = TRUE WHERE NOT banned"}});
        let ev = e.evaluate("execute_sql", &p, Adjustments::default());
        assert!(matches!(decide(&ev), Decision::Approval { .. }),
            "expected Approval on `WHERE NOT col` + `SET col = TRUE` tautology");
    }

    #[test]
    fn tautological_where_is_not_true_fires() {
        let e = engine();
        let p = json!({"arguments": {"query":
            "UPDATE users SET email_verified = TRUE WHERE email_verified IS NOT TRUE"}});
        let ev = e.evaluate("execute_sql", &p, Adjustments::default());
        assert!(matches!(decide(&ev), Decision::Approval { .. }),
            "expected Approval on `WHERE col IS NOT TRUE` + `SET col = TRUE` tautology");
    }

    #[test]
    fn tautological_where_handles_1_0_spellings() {
        let e = engine();
        let p = json!({"arguments": {"query":
            "UPDATE users SET email_verified = 1 WHERE email_verified = 0"}});
        let ev = e.evaluate("execute_sql", &p, Adjustments::default());
        assert!(matches!(decide(&ev), Decision::Approval { .. }),
            "expected Approval on 1/0 boolean opposites");
    }

    // An extra non-tautological conjunct means the WHERE really narrows the update.
    #[test]
    fn real_scope_narrowing_with_and_does_not_fire() {
        let e = engine();
        let p = json!({"arguments": {"query":
            "UPDATE users SET email_verified = TRUE WHERE email_verified = FALSE AND created_at > NOW() - INTERVAL '7 days'"}});
        let ev = e.evaluate("execute_sql", &p, Adjustments::default());
        assert!(matches!(decide(&ev), Decision::Allow { .. } | Decision::Warn { .. }),
            "expected Allow/Warn on real time-window scope; got {}", decide(&ev).label());
    }

    #[test]
    fn scoped_update_by_id_does_not_fire() {
        let e = engine();
        let p = json!({"arguments": {"query":
            "UPDATE users SET email_verified = TRUE WHERE id = 7"}});
        let ev = e.evaluate("execute_sql", &p, Adjustments::default());
        assert!(matches!(decide(&ev), Decision::Allow { .. } | Decision::Warn { .. }),
            "expected Allow/Warn on scoped UPDATE by id; got {}", decide(&ev).label());
    }

    #[test]
    fn scoped_update_allow() {
        let e = engine();
        let p = json!({"arguments": {"query": "UPDATE users SET banned = true WHERE id = 7"}});
        let ev = e.evaluate("execute_sql", &p, Adjustments::default());
        assert!(matches!(decide(&ev), Decision::Allow));
    }

    // The next three tests together imply the bundled GRANT rule resolves to Medium by
    // default (one bump -> Approval, one demotion -> Allow). NOTE(review): inferred from
    // the expectations; confirm in shieldset.yaml.
    #[test]
    fn workspace_prod_bumps_severity() {
        let e = engine();
        let p = json!({"arguments": {"query": "GRANT ALL ON foo TO bar"}});
        let mut adj = Adjustments::default();
        adj.workspace_is_prod = true;
        let ev = e.evaluate("execute_sql", &p, adj);
        match decide(&ev) {
            Decision::Approval { .. } => {},
            other => panic!("expected Approval from prod bump, got {}", other.label()),
        }
    }

    #[test]
    fn repeated_approval_demotes() {
        let e = engine();
        let p = json!({"arguments": {"query": "GRANT ALL ON foo TO bar"}});
        let mut adj = Adjustments::default();
        adj.fingerprint_repeatedly_approved = true;
        let ev = e.evaluate("execute_sql", &p, adj);
        assert!(matches!(decide(&ev), Decision::Allow));
    }

    #[test]
    fn deny_history_escalates() {
        let e = engine();
        let p = json!({"arguments": {"query": "GRANT ALL ON foo TO bar"}});
        let mut adj = Adjustments::default();
        adj.fingerprint_recently_denied = true;
        let ev = e.evaluate("execute_sql", &p, adj);
        match decide(&ev) {
            Decision::Approval { .. } => {},
            other => panic!("expected Approval from deny escalation, got {}", other.label()),
        }
    }

    // NOTE(review): despite its name, this only checks that the composite total is at
    // least the first match's points; it does not assert that a promotion happened.
    #[test]
    fn composite_scoring_promotes_weak_signals() {
        let e = engine();
        let p = json!({"arguments": {
            "command": "git branch -D feature/legacy",
            "query": "GRANT ALL ON foo TO bar"
        }});
        let ev = e.evaluate("run_terminal", &p, Adjustments::default());
        assert!(ev.matches.len() >= 1);
        assert!(ev.composite_points >= ev.matches[0].points);
    }

    #[test]
    fn fingerprint_is_stable_for_same_input() {
        let p = json!({"arguments": {"query": "DROP DATABASE prod"}});
        let a = fingerprint("sql.drop_database", &p);
        let b = fingerprint("sql.drop_database", &p);
        assert_eq!(a, b);
        assert_eq!(a.len(), 16);
    }

    #[test]
    fn fingerprint_differs_per_rule() {
        let p = json!({"arguments": {"query": "DROP DATABASE prod"}});
        let a = fingerprint("sql.drop_database", &p);
        let b = fingerprint("sql.drop_table_or_schema", &p);
        assert_ne!(a, b);
    }
}