use std::sync::OnceLock;
use crate::error::Result;
use regex::Regex;
use serde_derive::{Deserialize, Serialize};
use serde_regex;
use tracing::{debug, warn};
use crate::{
blast_radius::{self, BlastRadiusInfo, BlastScope},
config::{Challenge, Settings},
context::{self, RuntimeContext},
env::Environment,
policy::{self, MergedPolicy},
prompt::{AlternativeInfo, ChallengeResult, DisplayContext, Prompter},
};
/// YAML text of the built-in check definitions, embedded at compile time from
/// `all-checks.yaml` inside Cargo's `OUT_DIR`.
///
/// NOTE(review): the file is presumably written by the crate's build script —
/// confirm against `build.rs`.
const ALL_CHECKS: &str = include_str!(concat!(env!("OUT_DIR"), "/all-checks.yaml"));
/// Severity rating attached to a [`Check`].
///
/// Variants are declared from least to most severe, so the derived
/// `PartialOrd`/`Ord` rank `Info` lowest and `Critical` highest. Reordering
/// the variants would change every severity comparison in this module.
#[derive(Debug, Default, Deserialize, Serialize, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Severity {
Info,
Low,
// Used when a check definition omits `severity`.
#[default]
Medium,
High,
Critical,
}
/// Renders a severity as its upper-case label (`INFO`, `LOW`, `MEDIUM`,
/// `HIGH`, `CRITICAL`).
impl std::fmt::Display for Severity {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Info => "INFO",
            Self::Low => "LOW",
            Self::Medium => "MEDIUM",
            Self::High => "HIGH",
            Self::Critical => "CRITICAL",
        };
        // Written verbatim: width/fill flags are ignored, as with a literal `write!`.
        f.write_str(label)
    }
}
/// Extra condition a command must satisfy, after the check's regex matched,
/// for the check to be reported.
///
/// Serialized in serde's adjacently tagged form: the variant name is stored
/// under `type` and its payload under `value`.
#[derive(Debug, Deserialize, Serialize, Clone)]
#[serde(tag = "type", content = "value")]
pub enum Filter {
// Index of the regex capture group holding a path; the filter passes when
// that path is considered to exist.
PathExists(usize),
// Passes when the command does NOT contain this substring.
NotContains(String),
// Passes when the command contains this substring.
Contains(String),
}
/// A single command check: a regex plus the metadata used when it matches.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct Check {
/// Identifier of the check; looked up in the deny list, the merged policy
/// and the per-check escalation table.
pub id: String,
/// Pattern tested against the command text (whole command and segments).
#[serde(with = "serde_regex")]
pub test: Regex,
/// Human-readable text collected into the prompt shown to the user.
pub description: String,
/// Name of the group this check belongs to; looked up in the per-group
/// escalation table and used to select the relevant runtime context.
pub from: String,
/// Challenge attached to the check; falls back to `Challenge::default()`
/// when omitted. NOTE(review): not read anywhere in this part of the file.
#[serde(default)]
pub challenge: Challenge,
/// Additional conditions; every filter must pass for the check to match.
#[serde(default)]
pub filters: Vec<Filter>,
/// Optional suggestion offered to the user instead of the matched command.
#[serde(default)]
pub alternative: Option<String>,
/// Optional explanation displayed together with `alternative`.
#[serde(default)]
pub alternative_info: Option<String>,
/// Severity of the check; `Severity::Medium` when omitted.
#[serde(default)]
pub severity: Severity,
}
/// Returns the built-in checks, parsing the embedded YAML on first use and
/// reusing the parsed list on every later call.
///
/// # Panics
///
/// Panics on the first call if the embedded YAML cannot be deserialized into
/// a list of [`Check`]s.
pub(crate) fn all_checks_cached() -> &'static [Check] {
    static PARSED: OnceLock<Vec<Check>> = OnceLock::new();
    let parsed = PARSED.get_or_init(|| {
        serde_yaml::from_str::<Vec<Check>>(ALL_CHECKS).expect("built-in checks are valid YAML")
    });
    parsed.as_slice()
}
/// Returns an owned copy of every built-in check.
///
/// # Errors
///
/// Never fails as written; the `Result` is kept for the existing signature.
pub fn get_all() -> Result<Vec<Check>> {
    let builtin = all_checks_cached();
    Ok(Vec::from(builtin))
}
/// Loads user-defined checks from every `*.yaml` / `*.yml` file directly
/// inside `checks_dir`.
///
/// Files are loaded in sorted path order so the resulting list is the same on
/// every platform (`read_dir` makes no ordering guarantee). Entries that are
/// not regular files — for example a directory that happens to be named
/// `foo.yaml` — are skipped instead of aborting the whole load.
///
/// Returns an empty list when `checks_dir` is not a directory.
///
/// # Errors
///
/// Returns an error if the directory cannot be listed, a file cannot be read,
/// or a file does not deserialize into a list of [`Check`]s.
pub fn load_custom_checks(checks_dir: &std::path::Path) -> Result<Vec<Check>> {
    if !checks_dir.is_dir() {
        return Ok(Vec::new());
    }

    let mut yaml_files = Vec::new();
    for entry in std::fs::read_dir(checks_dir)? {
        let path = entry?.path();
        let has_yaml_extension = path.extension().is_some_and(|e| e == "yaml" || e == "yml");
        if has_yaml_extension && path.is_file() {
            yaml_files.push(path);
        }
    }
    // Directory iteration order is platform dependent; sort for stable output.
    yaml_files.sort();

    let mut custom_checks = Vec::new();
    for path in &yaml_files {
        let content = std::fs::read_to_string(path)?;
        let checks: Vec<Check> = serde_yaml::from_str(&content)?;
        custom_checks.extend(checks);
    }
    Ok(custom_checks)
}
/// Statically validates check definitions and returns one warning per problem.
///
/// Currently this reports every `PathExists` filter whose capture-group index
/// is not provided by the check's regex. An empty result means no problems
/// were found.
#[must_use]
pub fn validate_checks(checks: &[Check]) -> Vec<String> {
    checks
        .iter()
        .flat_map(|check| {
            // `captures_len` counts the implicit whole-match group 0 as well.
            let group_count = check.test.captures_len();
            check.filters.iter().filter_map(move |filter| match filter {
                Filter::PathExists(group_idx) if *group_idx >= group_count => Some(format!(
                    "check {:?}: PathExists({}) references a capture group that \
                     does not exist (regex has {} groups including group 0)",
                    check.id, group_idx, group_count
                )),
                Filter::PathExists(_) | Filter::NotContains(_) | Filter::Contains(_) => None,
            })
        })
        .collect()
}
pub fn challenge_with_context(
settings: &Settings,
checks: &[&Check],
context: &RuntimeContext,
merged_policy: &MergedPolicy,
prompter: &dyn Prompter,
blast_radii: &[(String, BlastRadiusInfo)],
) -> Result<ChallengeResult> {
let mut descriptions: Vec<String> = Vec::new();
let mut alternatives: Vec<AlternativeInfo> = Vec::new();
let mut should_deny_command = false;
debug!(
"list of denied pattern ids {:?}",
settings.deny_patterns_ids
);
for check in checks {
if !descriptions.contains(&check.description) {
descriptions.push(check.description.clone());
}
if !should_deny_command && settings.deny_patterns_ids.contains(&check.id) {
should_deny_command = true;
}
if !should_deny_command && merged_policy.is_denied(&check.id) {
should_deny_command = true;
}
if let Some(ref alt) = check.alternative {
let already_has = alternatives.iter().any(|a| a.suggestion == *alt);
if !already_has {
alternatives.push(AlternativeInfo {
suggestion: alt.clone(),
explanation: check.alternative_info.clone(),
});
}
}
}
let base_challenge = settings.challenge;
let mut effective = base_challenge;
let max_severity = checks.iter().map(|c| c.severity).max();
if let Some(sev) = max_severity {
if let Some(sev_floor) = settings.severity_escalation.challenge_for_severity(sev) {
effective = max_challenge(effective, sev_floor);
}
}
for check in checks {
if let Some(&group_floor) = settings.group_escalation.get(&check.from) {
effective = max_challenge(effective, group_floor);
}
}
for check in checks {
if let Some(&check_floor) = settings.check_escalation.get(&check.id) {
effective = max_challenge(effective, check_floor);
}
}
effective = max_challenge(
effective,
context::escalate_challenge(
&base_challenge,
context.risk_level,
&settings.context.escalation,
),
);
for check in checks {
let policy_effective = merged_policy.effective_challenge(&check.id, &effective);
effective = max_challenge(effective, policy_effective);
}
let escalation_note = if effective == base_challenge {
None
} else {
Some(format!("{base_challenge} -> {effective}"))
};
let severity_label = max_severity.map(|s| format!("{s}"));
let blast_radius_label = blast_radii
.iter()
.max_by_key(|(_, br)| br.scope)
.map(|(_, br)| format!("[{}] — {}", br.scope, br.description));
let display = DisplayContext {
is_denied: should_deny_command,
descriptions,
alternatives,
context_labels: context.labels.clone(),
effective_challenge: effective,
escalation_note,
severity_label,
blast_radius_label,
};
Ok(prompter.run_challenge(&display))
}
/// Returns the checks from `checks` that match `command`, resolving
/// filesystem-dependent filters against `crate::env::RealEnvironment`.
///
/// Thin wrapper around [`run_check_on_command_with_env`].
#[must_use]
pub fn run_check_on_command<'a>(checks: &'a [Check], command: &str) -> Vec<&'a Check> {
run_check_on_command_with_env(checks, command, &crate::env::RealEnvironment)
}
/// Returns, in their original order, the checks whose regex matches `command`
/// and whose filters all pass when evaluated against `env`.
#[must_use]
pub fn run_check_on_command_with_env<'a>(
    checks: &'a [Check],
    command: &str,
    env: &dyn Environment,
) -> Vec<&'a Check> {
    let mut matched = Vec::new();
    for candidate in checks {
        // Filters are only evaluated once the regex itself has matched.
        if candidate.test.is_match(command)
            && check_custom_filter_with_env(candidate, command, env)
        {
            matched.push(candidate);
        }
    }
    matched
}
/// Evaluates every filter of `check` against `command`.
///
/// Returns `true` when the check has no filters or when all of them pass.
/// Evaluation stops at the first failing filter.
///
/// A `PathExists` filter fails when its capture group is missing (a warning
/// is logged) or captured an empty string; otherwise the captured path is
/// resolved through `env`.
fn check_custom_filter_with_env(check: &Check, command: &str, env: &dyn Environment) -> bool {
    if check.filters.is_empty() {
        return true;
    }

    let captures = check.test.captures(command);
    check.filters.iter().all(|filter| {
        debug!("filter information: command {command:?} filter: {filter:?}");
        match filter {
            Filter::PathExists(group_idx) => {
                let captured = captures.as_ref().and_then(|c| c.get(*group_idx));
                match captured {
                    Some(found) => {
                        let path = found.as_str();
                        !path.is_empty() && filter_path_exists_with_env(path, env)
                    }
                    None => {
                        warn!(
                            "check {:?}: PathExists references capture group {} which does not exist in regex",
                            check.id, group_idx
                        );
                        false
                    }
                }
            }
            Filter::NotContains(needle) => !command.contains(needle.as_str()),
            Filter::Contains(needle) => command.contains(needle.as_str()),
        }
    })
}
/// Decides whether `file_path` (as captured from a command) should be treated
/// as an existing path.
///
/// The path is trimmed, a leading `~` or `~/` is expanded to the home
/// directory reported by `env`, and the result is joined onto the current
/// directory before `env.path_exists` is consulted.
///
/// Whenever the path cannot be resolved reliably the function answers `true`,
/// so the owning check is kept rather than silently dropped. This applies to:
/// * an unknown home directory,
/// * any other tilde form such as `~user/...` — only `~` and `~/...` refer to
///   the current user's home, so substituting it would build a wrong path,
/// * paths containing a `*` glob,
/// * an unavailable current directory.
fn filter_path_exists_with_env(file_path: &str, env: &dyn Environment) -> bool {
    use std::borrow::Cow;

    let trimmed = file_path.trim();
    let file_path: Cow<'_, str> = if trimmed == "~" || trimmed.starts_with("~/") {
        match env.home_dir() {
            Some(home) => Cow::Owned(trimmed.replacen('~', &home.display().to_string(), 1)),
            None => return true,
        }
    } else if trimmed.starts_with('~') {
        // `~name...` is not the current user's home; it cannot be resolved here.
        return true;
    } else {
        Cow::Borrowed(trimmed)
    };

    // Globs may expand to anything; assume at least one target exists.
    if file_path.contains('*') {
        return true;
    }

    let full_path = match env.current_dir() {
        // `join` keeps absolute paths as they are and anchors relative ones at cwd.
        Ok(cwd) => cwd.join(&*file_path),
        Err(err) => {
            debug!("could not get current dir. err: {err:?}");
            return true;
        }
    };
    debug!("check if {} path exists", full_path.display());
    env.path_exists(&full_path)
}
/// Combines two challenge levels by delegating to `a.stricter(b)`.
///
/// NOTE(review): presumably yields the more demanding of the two levels —
/// confirm against `Challenge::stricter`, which is defined outside this file.
pub(crate) fn max_challenge(a: Challenge, b: Challenge) -> Challenge {
a.stricter(b)
}
/// Concatenates `primary` and `secondary`, keeping only the first occurrence
/// of each check id and preserving the order in which checks were first seen.
fn dedup_check_matches<'a>(primary: Vec<&'a Check>, secondary: Vec<&'a Check>) -> Vec<&'a Check> {
    let mut known_ids = std::collections::HashSet::new();
    primary
        .into_iter()
        .chain(secondary)
        // `insert` returns false for an id that was already recorded.
        .filter(|candidate| known_ids.insert(candidate.id.as_str()))
        .collect()
}
/// Everything `analyze_command` learned about a single command.
#[derive(Debug, Clone)]
pub struct PipelineResult {
/// The command after every match of the quote-stripping regex was removed.
pub stripped_command: String,
/// Segments of `stripped_command` as produced by `split_command`.
pub command_parts: Vec<String>,
/// Matched checks at or above `settings.min_severity` (all matched checks
/// when no minimum is configured).
pub active_matches: Vec<Check>,
/// Matched checks that fell below `settings.min_severity`.
pub skipped_matches: Vec<Check>,
/// Detected runtime context; left at its default when nothing matched.
pub context: RuntimeContext,
/// `context` filtered down to the groups of the active matches.
pub relevant_context: RuntimeContext,
/// Highest severity among active AND skipped matches; the default severity
/// when there are none.
pub max_severity: Severity,
/// True when any active match is on the settings deny list or is denied by
/// the merged policy.
pub is_denied: bool,
/// Distinct alternatives suggested by the active matches.
pub alternatives: Vec<AlternativeInfo>,
/// Project policy merged with the settings; default when no policy was found.
pub merged_policy: MergedPolicy,
/// Blast-radius results for the active matches; empty when
/// `settings.blast_radius` is off.
pub blast_radii: Vec<(String, BlastRadiusInfo)>,
/// Largest scope found in `blast_radii`, if any.
pub max_blast_scope: Option<BlastScope>,
}
/// Runs the full analysis pipeline for `command` without prompting the user.
///
/// Steps:
/// 1. remove everything matched by `strip_quotes_re` and split the remainder
///    into segments;
/// 2. match the built-in `checks` against the segments (and against the whole
///    command when there are several segments);
/// 3. discover the project policy; when nothing matched and the policy adds
///    no checks, return early with an empty result and skip context detection;
/// 4. detect the runtime context, merge the policy, and match the policy's
///    extra checks the same way;
/// 5. partition matches by `settings.min_severity`, then derive severity,
///    deny status, alternatives, blast radius and the relevant context.
///
/// # Errors
///
/// Always returns `Ok` as written; the `Result` is kept for the existing
/// signature.
#[allow(clippy::too_many_lines)]
pub fn analyze_command(
    command: &str,
    settings: &Settings,
    checks: &[Check],
    env: &dyn Environment,
    strip_quotes_re: &Regex,
) -> Result<PipelineResult> {
    let stripped = strip_quotes_re.replace_all(command, "").into_owned();
    let command_parts = split_command(&stripped);
    debug!("analyze_command: parts={command_parts:?}");

    let matches = collect_segment_and_full_matches(checks, &command_parts, &stripped, env);
    debug!("analyze_command: {} matches found", matches.len());

    let cwd = env.current_dir().unwrap_or_default();
    let project_policy = policy::discover(env, &cwd);
    let policy_has_extra_checks = project_policy
        .as_ref()
        .is_some_and(|pp| !pp.checks.is_empty());

    if matches.is_empty() && !policy_has_extra_checks {
        debug!("analyze_command: no matches and no policy checks, skipping context detection");
        return Ok(PipelineResult {
            stripped_command: stripped,
            command_parts,
            active_matches: Vec::new(),
            skipped_matches: Vec::new(),
            context: RuntimeContext::default(),
            relevant_context: RuntimeContext::default(),
            max_severity: Severity::default(),
            is_denied: false,
            alternatives: Vec::new(),
            merged_policy: MergedPolicy::default(),
            blast_radii: Vec::new(),
            max_blast_scope: None,
        });
    }

    let runtime_context = context::detect(env, &settings.context);
    let merged_policy = match project_policy.as_ref() {
        Some(pp) => policy::merge_into_settings(settings, pp, runtime_context.git_branch.as_deref()),
        None => MergedPolicy::default(),
    };

    // Built-in matches first, followed by matches of the policy's own checks.
    let mut all_matches: Vec<&Check> = matches;
    if !merged_policy.extra_checks.is_empty() {
        all_matches.extend(collect_segment_and_full_matches(
            &merged_policy.extra_checks,
            &command_parts,
            &stripped,
            env,
        ));
    }

    let (active_refs, skipped_refs): (Vec<&Check>, Vec<&Check>) = match settings.min_severity {
        Some(threshold) => all_matches
            .into_iter()
            .partition(|check| check.severity >= threshold),
        None => (all_matches, Vec::new()),
    };
    let active_matches: Vec<Check> = active_refs.into_iter().cloned().collect();
    let skipped_matches: Vec<Check> = skipped_refs.into_iter().cloned().collect();

    // Skipped matches still count towards the reported maximum severity.
    let max_severity = active_matches
        .iter()
        .chain(skipped_matches.iter())
        .map(|check| check.severity)
        .max()
        .unwrap_or_default();

    let is_denied = active_matches.iter().any(|check| {
        settings.deny_patterns_ids.contains(&check.id) || merged_policy.is_denied(&check.id)
    });

    let mut alternatives: Vec<AlternativeInfo> = Vec::new();
    for check in &active_matches {
        if let Some(suggestion) = check.alternative.as_ref() {
            let is_new = alternatives
                .iter()
                .all(|known| known.suggestion != *suggestion);
            if is_new {
                alternatives.push(AlternativeInfo {
                    suggestion: suggestion.clone(),
                    explanation: check.alternative_info.clone(),
                });
            }
        }
    }

    let blast_radii = if settings.blast_radius {
        blast_radius::compute_for_matches(&active_matches, &command_parts, &stripped, env)
    } else {
        Vec::new()
    };
    let max_blast_scope = blast_radii.iter().map(|(_, radius)| radius.scope).max();

    let matched_groups: std::collections::HashSet<&str> = active_matches
        .iter()
        .map(|check| check.from.as_str())
        .collect();
    let relevant_context = runtime_context.filter_for_groups(&matched_groups, &settings.context);

    Ok(PipelineResult {
        stripped_command: stripped,
        command_parts,
        active_matches,
        skipped_matches,
        context: runtime_context,
        relevant_context,
        max_severity,
        is_denied,
        alternatives,
        merged_policy,
        blast_radii,
        max_blast_scope,
    })
}

/// Matches `checks` against each command segment and, when the command has
/// more than one segment, against the whole command too; the two result lists
/// are then merged with duplicate check ids removed.
fn collect_segment_and_full_matches<'a>(
    checks: &'a [Check],
    command_parts: &[String],
    full_command: &str,
    env: &dyn Environment,
) -> Vec<&'a Check> {
    let per_segment: Vec<&Check> = command_parts
        .iter()
        .flat_map(|part| run_check_on_command_with_env(checks, part, env))
        .collect();
    if command_parts.len() > 1 {
        let whole = run_check_on_command_with_env(checks, full_command, env);
        dedup_check_matches(per_segment, whole)
    } else {
        per_segment
    }
}
/// Splits a shell command line into segments at `&&`, `||`, `|` and `;`.
///
/// Operators inside single or double quotes are ignored. The operators are
/// dropped, while the surrounding whitespace and the quotes themselves are
/// kept in the segments. Consecutive operators produce empty segments; a
/// trailing operator does not add a final empty segment. A lone `&` is not a
/// separator.
#[must_use]
pub fn split_command(command: &str) -> Vec<String> {
    let bytes = command.as_bytes();
    let mut segments = Vec::new();
    // Quote character currently open, if any (`'` or `"`).
    let mut open_quote: Option<u8> = None;
    let mut segment_start = 0;
    let mut cursor = 0;

    while let Some(&byte) = bytes.get(cursor) {
        // Width of the separator at `cursor`; 0 means "ordinary byte".
        let operator_width = match (open_quote, byte) {
            (None, b'\'' | b'"') => {
                open_quote = Some(byte);
                0
            }
            (Some(quote), _) => {
                // Only the matching quote closes; everything else is literal.
                if quote == byte {
                    open_quote = None;
                }
                0
            }
            (None, b'&' | b'|') if bytes.get(cursor + 1) == Some(&byte) => 2,
            (None, b'|' | b';') => 1,
            (None, _) => 0,
        };

        if operator_width == 0 {
            cursor += 1;
        } else {
            // Separators are ASCII, so these offsets are valid char boundaries.
            segments.push(command[segment_start..cursor].to_string());
            cursor += operator_width;
            segment_start = cursor;
        }
    }

    if segment_start < bytes.len() {
        segments.push(command[segment_start..].to_string());
    }
    segments
}
// Unit tests for check matching, filter evaluation, command splitting and
// check validation. Several tests compare against insta snapshots.
#[cfg(test)]
mod test_checks {
use insta::{assert_debug_snapshot, with_settings};
use super::*;
// Minimal check definitions used by the matching test. The `enable` key is
// not a field of `Check`.
const CHECKS: &str = r###"
- from: test-1
test: test-(1)
enable: true
description: ""
id: ""
- from: test-2
test: test-(1|2)
enable: true
description: ""
id: ""
- from: test-disabled
test: test-disabled
enable: true
description: ""
id: ""
"###;
// Snapshots the matches for a matching and a non-matching command; the insta
// filter rewrites the regex debug output to a compact `test: <pattern>` form.
#[test]
fn can_run_check_on_command() {
let checks: Vec<Check> = serde_yaml::from_str(CHECKS).unwrap();
with_settings!({filters => vec![
(r#"(?s)test:\s*Regex\(\s*"([^"]+)",?\s*\)"#, "test: $1"),
]}, {
assert_debug_snapshot!(run_check_on_command(&checks, "test-1"));
assert_debug_snapshot!(run_check_on_command(&checks, "unknown command"));
});
}
// PathExists: the same redirect target is evaluated first against a mock
// environment without the file, then against one that contains it.
#[test]
fn can_check_custom_filter_with_file_exists() {
use std::collections::HashSet;
let filters = vec![Filter::PathExists(1)];
let check = Check {
id: "id".to_string(),
test: Regex::new("(?:^|[^>])>([^>].*)").unwrap(),
description: "some description".to_string(),
from: "test".to_string(),
challenge: Challenge::default(),
filters,
alternative: None,
alternative_info: None,
severity: Severity::default(),
};
let env_no_file = crate::env::MockEnvironment {
cwd: "/mock".into(),
..Default::default()
};
let command = "cat 'write message' > /mock/app/message.txt";
assert_debug_snapshot!(check_custom_filter_with_env(&check, command, &env_no_file));
let mut existing = HashSet::new();
existing.insert(std::path::PathBuf::from("/mock/app/message.txt"));
let env_with_file = crate::env::MockEnvironment {
cwd: "/mock".into(),
existing_paths: existing,
..Default::default()
};
assert_debug_snapshot!(check_custom_filter_with_env(
&check,
command,
&env_with_file
));
}
// NotContains: snapshots the result with and without the excluded flag.
#[test]
fn can_check_custom_filter_with_str_contains() {
let filters = vec![Filter::NotContains("--dry-run".to_string())];
let check = Check {
id: "id".to_string(),
test: Regex::new("(delete)").unwrap(),
description: "some description".to_string(),
from: "test".to_string(),
challenge: Challenge::default(),
filters,
alternative: None,
alternative_info: None,
severity: Severity::default(),
};
let env = crate::env::MockEnvironment::default();
assert_debug_snapshot!(check_custom_filter_with_env(&check, "delete", &env));
assert_debug_snapshot!(check_custom_filter_with_env(
&check,
"delete --dry-run",
&env
));
}
// Contains: the filter passes only when the required substring is present.
#[test]
fn can_check_custom_filter_with_contains() {
let filters = vec![Filter::Contains("--force".to_string())];
let check = Check {
id: "id".to_string(),
test: Regex::new("(push)").unwrap(),
description: "some description".to_string(),
from: "test".to_string(),
challenge: Challenge::default(),
filters,
alternative: None,
alternative_info: None,
severity: Severity::default(),
};
let env = crate::env::MockEnvironment::default();
assert!(!check_custom_filter_with_env(
&check,
"git push origin main",
&env
));
assert!(check_custom_filter_with_env(
&check,
"git push --force origin main",
&env
));
}
// PathExists pointing at a capture group the regex does not have must fail
// the filter rather than panic.
#[test]
fn can_check_custom_filter_with_missing_capture_group() {
let filters = vec![Filter::PathExists(5)];
let check = Check {
id: "test-missing-group".to_string(),
test: Regex::new("rm (.*)").unwrap(),
description: "some description".to_string(),
from: "test".to_string(),
challenge: Challenge::default(),
filters,
alternative: None,
alternative_info: None,
severity: Severity::default(),
};
let env = crate::env::MockEnvironment::default();
assert!(!check_custom_filter_with_env(&check, "rm /tmp/foo", &env));
}
// With several filters, a single failing filter rejects the command.
#[test]
fn can_check_multiple_filters_all_must_pass() {
let filters = vec![
Filter::NotContains("--dry-run".to_string()),
Filter::NotContains("--check".to_string()),
];
let check = Check {
id: "id".to_string(),
test: Regex::new("(delete)").unwrap(),
description: "some description".to_string(),
from: "test".to_string(),
challenge: Challenge::default(),
filters,
alternative: None,
alternative_info: None,
severity: Severity::default(),
};
let env = crate::env::MockEnvironment::default();
assert!(check_custom_filter_with_env(&check, "delete", &env));
assert!(!check_custom_filter_with_env(
&check,
"delete --dry-run",
&env
));
assert!(!check_custom_filter_with_env(
&check,
"delete --check",
&env
));
assert!(!check_custom_filter_with_env(
&check,
"delete --dry-run --check",
&env
));
}
// The embedded built-in checks must load successfully.
#[test]
fn can_get_all_checks() {
assert_debug_snapshot!(get_all().is_ok());
}
// split_command: operators are removed, surrounding whitespace is kept.
#[test]
fn test_split_command_and_and() {
let parts = split_command("ls && rm -rf /");
assert_eq!(parts, vec!["ls ", " rm -rf /"]);
}
#[test]
fn test_split_command_pipe() {
let parts = split_command("cat foo | grep bar");
assert_eq!(parts, vec!["cat foo ", " grep bar"]);
}
#[test]
fn test_split_command_mixed() {
let parts = split_command("a && b || c; d");
assert_eq!(parts, vec!["a ", " b ", " c", " d"]);
}
#[test]
fn test_split_command_single() {
let parts = split_command("git push -f");
assert_eq!(parts, vec!["git push -f"]);
}
// split_command: operators inside quotes do not split the command.
#[test]
fn test_split_command_double_quoted_operator() {
let parts = split_command(r#"echo "hello && world""#);
assert_eq!(parts, vec![r#"echo "hello && world""#]);
}
#[test]
fn test_split_command_single_quoted_pipe() {
let parts = split_command("echo 'a | b'");
assert_eq!(parts, vec!["echo 'a | b'"]);
}
// An operator that follows a closed quote splits again.
#[test]
fn test_split_command_quoted_then_operator() {
let parts = split_command(r#"echo "safe" && rm -rf /"#);
assert_eq!(parts, vec![r#"echo "safe" "#, " rm -rf /"]);
}
// Every built-in check must reference only capture groups that exist.
#[test]
fn test_all_builtin_checks_pass_validation() {
let checks = get_all().unwrap();
let warnings = validate_checks(&checks);
assert!(
warnings.is_empty(),
"Built-in checks have validation warnings:\n{}",
warnings.join("\n")
);
}
// validate_checks reports a PathExists index beyond the regex's groups.
#[test]
fn test_validate_catches_bad_capture_group() {
let checks = vec![Check {
id: "bad".to_string(),
test: Regex::new("rm (.*)").unwrap(),
description: "test".to_string(),
from: "test".to_string(),
challenge: Challenge::default(),
filters: vec![Filter::PathExists(5)], alternative: None,
alternative_info: None,
severity: Severity::default(),
}];
let warnings = validate_checks(&checks);
assert_eq!(warnings.len(), 1);
assert!(warnings[0].contains("PathExists(5)"));
}
}