use std::{process::Command, sync::Arc};
use anyhow::{Context, Result};
use crate::{config::hooks::HookCondition, git::GitRepo};
/// Evaluates hook `skip` / `only` conditions against the current repository and
/// process environment.
///
/// Holds an optional handle to a [`GitRepo`]. Repository-dependent conditions
/// evaluate to `false` when no repository is available.
pub struct ConditionEvaluator {
    /// Repository used by repo-dependent checks; `None` when discovery failed.
    repo: Option<Arc<GitRepo>>,
}

impl Default for ConditionEvaluator {
    /// Same as [`ConditionEvaluator::new`]: attempts repository discovery.
    fn default() -> Self {
        ConditionEvaluator::new()
    }
}
impl ConditionEvaluator {
pub fn with_repo(repo: Arc<GitRepo>) -> Self {
Self { repo: Some(repo) }
}
pub fn new() -> Self {
Self {
repo: GitRepo::discover().ok().map(Arc::new),
}
}
pub fn evaluate(&self, condition: &HookCondition) -> Result<bool> {
match condition {
HookCondition::Bool(value) => Ok(*value),
HookCondition::Array(conditions) => {
for cond in conditions {
if self.evaluate_single_condition(cond)? {
return Ok(true);
}
}
Ok(false)
}
}
}
fn evaluate_single_condition(&self, condition: &str) -> Result<bool> {
let parts: Vec<&str> = condition.splitn(2, ':').collect();
if parts.len() != 2 {
return match condition {
"true" | "always" => Ok(true),
"false" | "never" => Ok(false),
_ => Ok(false), };
}
let (cond_type, cond_value) = (parts[0], parts[1]);
match cond_type {
"branch" => self.check_branch_condition(cond_value),
"ref" => self.check_ref_condition(cond_value),
"files" => self.check_files_condition(cond_value),
"merge" => self.check_merge_condition(),
"rebase" => self.check_rebase_condition(),
"ci" => self.check_ci_condition(),
_ => Ok(false), }
}
fn check_branch_condition(&self, pattern: &str) -> Result<bool> {
let Some(repo) = &self.repo else {
return Ok(false);
};
let current_branch = repo
.get_current_branch()
.context("Failed to get current branch")?;
let glob = globset::Glob::new(pattern).context("Invalid branch pattern")?;
let matcher = glob.compile_matcher();
Ok(matcher.is_match(¤t_branch))
}
fn check_ref_condition(&self, pattern: &str) -> Result<bool> {
let Some(_repo) = &self.repo else {
return Ok(false);
};
let output = Command::new("git")
.args(["symbolic-ref", "-q", "HEAD"])
.output();
let current_ref = if let Ok(output) = output {
if output.status.success() {
String::from_utf8_lossy(&output.stdout).trim().to_string()
} else {
let commit_output = Command::new("git")
.args(["rev-parse", "HEAD"])
.output()
.context("Failed to get current commit")?;
String::from_utf8_lossy(&commit_output.stdout)
.trim()
.to_string()
}
} else {
return Ok(false);
};
let glob = globset::Glob::new(pattern).context("Invalid ref pattern")?;
let matcher = glob.compile_matcher();
Ok(matcher.is_match(¤t_ref))
}
fn check_files_condition(&self, pattern: &str) -> Result<bool> {
let Some(repo) = &self.repo else {
return Ok(false);
};
let staged = repo.get_staged_files()?;
let modified = repo.get_modified_files()?;
let glob = globset::Glob::new(pattern).context("Invalid file pattern")?;
let matcher = glob.compile_matcher();
for file in staged.iter().chain(modified.iter()) {
if matcher.is_match(file) {
return Ok(true);
}
}
Ok(false)
}
fn check_merge_condition(&self) -> Result<bool> {
Ok(std::path::Path::new(".git/MERGE_HEAD").exists())
}
fn check_rebase_condition(&self) -> Result<bool> {
Ok(std::path::Path::new(".git/rebase-merge").exists()
|| std::path::Path::new(".git/rebase-apply").exists())
}
fn check_ci_condition(&self) -> Result<bool> {
Ok(std::env::var("CI").is_ok()
|| std::env::var("CONTINUOUS_INTEGRATION").is_ok()
|| std::env::var("GITHUB_ACTIONS").is_ok()
|| std::env::var("GITLAB_CI").is_ok()
|| std::env::var("CIRCLECI").is_ok()
|| std::env::var("TRAVIS").is_ok()
|| std::env::var("JENKINS_URL").is_ok())
}
}
/// Decides whether a hook should be skipped.
///
/// `skip` is checked first. If it evaluates to `true`, the hook is skipped.
///
/// `only` imposes a restriction only when it is a non-empty array. A `Bool` of
/// either value, or an empty array, means "no restriction". A non-empty array
/// skips the hook unless one of its entries evaluates to `true`.
///
/// # Errors
/// Propagates any error from evaluating `skip` or `only`.
pub fn should_skip(
    skip: &HookCondition,
    only: &HookCondition,
    evaluator: &ConditionEvaluator,
) -> Result<bool> {
    if evaluator.evaluate(skip)? {
        return Ok(true);
    }

    let unrestricted = match only {
        HookCondition::Bool(_) => true,
        HookCondition::Array(entries) => entries.is_empty(),
    };
    if unrestricted {
        Ok(false)
    } else {
        evaluator.evaluate(only).map(|matched| !matched)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `HookCondition::Array` from string slices.
    fn array(conditions: &[&str]) -> HookCondition {
        HookCondition::Array(conditions.iter().map(|c| c.to_string()).collect())
    }

    #[test]
    fn test_bool_conditions() {
        let evaluator = ConditionEvaluator::new();
        assert!(evaluator.evaluate(&HookCondition::Bool(true)).unwrap());
        assert!(!evaluator.evaluate(&HookCondition::Bool(false)).unwrap());
    }

    #[test]
    fn test_array_conditions() {
        let evaluator = ConditionEvaluator::new();
        assert!(evaluator.evaluate(&array(&["true"])).unwrap());
        assert!(!evaluator.evaluate(&array(&["false"])).unwrap());
        assert!(evaluator.evaluate(&array(&["false", "true"])).unwrap());
        // An empty array has nothing that can match.
        assert!(!evaluator.evaluate(&array(&[])).unwrap());
    }

    #[test]
    fn test_literal_aliases() {
        let evaluator = ConditionEvaluator::new();
        assert!(evaluator.evaluate(&array(&["always"])).unwrap());
        assert!(!evaluator.evaluate(&array(&["never"])).unwrap());
    }

    #[test]
    fn test_unknown_conditions_are_false() {
        let evaluator = ConditionEvaluator::new();
        assert!(!evaluator.evaluate(&array(&["bogus"])).unwrap());
        assert!(!evaluator.evaluate(&array(&["bogus:value"])).unwrap());
    }

    #[test]
    fn test_ci_condition() {
        let evaluator = ConditionEvaluator::new();
        let result = evaluator.evaluate(&array(&["ci"])).unwrap();
        // Only assert the negative case, and only when none of the CI variables
        // the evaluator recognizes are present in this environment.
        let any_ci_var = [
            "CI",
            "CONTINUOUS_INTEGRATION",
            "GITHUB_ACTIONS",
            "GITLAB_CI",
            "CIRCLECI",
            "TRAVIS",
            "JENKINS_URL",
        ]
        .iter()
        .any(|name| std::env::var_os(name).is_some());
        if !any_ci_var {
            assert!(!result);
        }
    }

    #[test]
    fn test_should_skip_logic() {
        let evaluator = ConditionEvaluator::new();
        let not_skipped = HookCondition::Bool(false);

        // A true `skip` always skips.
        assert!(should_skip(&HookCondition::Bool(true), &HookCondition::Bool(false), &evaluator).unwrap());

        // A bool or empty `only` imposes no restriction.
        assert!(!should_skip(&not_skipped, &HookCondition::Bool(false), &evaluator).unwrap());
        assert!(!should_skip(&not_skipped, &HookCondition::Bool(true), &evaluator).unwrap());
        assert!(!should_skip(&not_skipped, &array(&[]), &evaluator).unwrap());

        // A non-empty `only` skips unless one of its entries matches.
        assert!(should_skip(&not_skipped, &array(&["branch:nonexistent"]), &evaluator).unwrap());
        assert!(should_skip(&not_skipped, &array(&["never"]), &evaluator).unwrap());
        assert!(!should_skip(&not_skipped, &array(&["always"]), &evaluator).unwrap());
    }
}