use crate::invariant::{Category, Context, Invariant, Outcome};
use serde::Deserialize;
use std::fs;
use std::path::{Path, PathBuf};
/// Directory, relative to the repository root, that `load_all` scans for
/// `.toml` files containing user-defined invariant rules.
const USER_DIR: &str = ".koala/invariants";
/// Deserialized shape of a single rules file.
///
/// Every `[[rule]]` table in the TOML document becomes one `UserRule`.
/// Because of `serde(default)`, a document without any `rule` entry
/// deserializes to an empty list instead of failing.
#[derive(Debug, Deserialize)]
struct File {
#[serde(default)]
rule: Vec<UserRule>,
}
/// One rule exactly as written by the user, before validation.
///
/// `load_all` turns this into a `UserDefinedInvariant` after checking the
/// `category` string with `parse_category`.
#[derive(Debug, Deserialize, Clone)]
struct UserRule {
// Identifier of the rule; `load_all` sorts the loaded rules by this value.
id: String,
// Category name as free text; validated later by `parse_category`.
category: String,
// Human-readable statement of what the rule enforces.
intent: String,
// Optional ADR reference; absent keys deserialize to `None`.
#[serde(default)]
adr: Option<String>,
// Stored under the TOML key `match` (a Rust keyword, hence the rename).
#[serde(rename = "match")]
match_: MatchSpec,
}
/// What a user rule checks, selected by the `kind` key of the `match` table.
///
/// In both variants `glob` selects files relative to the repository root and
/// `needle` is a plain substring searched for in each selected file.
#[derive(Debug, Deserialize, Clone)]
#[serde(tag = "kind")]
enum MatchSpec {
// Fails when any file matching `glob` contains `needle`.
#[serde(rename = "forbid-substring")]
ForbidSubstring { glob: String, needle: String },
// Passes when at least one file matching `glob` contains `needle`.
#[serde(rename = "require-substring")]
RequireSubstring { glob: String, needle: String },
}
/// Reasons why loading user-defined invariants can fail.
#[derive(Debug)]
pub enum LoadError {
// An I/O operation on `path` failed.
Io { path: PathBuf, err: std::io::Error },
// The contents of `path` are not valid TOML for the expected rule schema.
Parse { path: PathBuf, err: toml::de::Error },
// Rule `id` names a category (`value`) that `parse_category` does not know.
BadCategory { id: String, value: String },
}
impl std::fmt::Display for LoadError {
    /// Writes a one-line description of the failure, including the offending
    /// path (for I/O and parse errors) or the rule id and rejected value
    /// (for an unknown category).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Io { path, err } => {
                write!(f, "io ({}): {}", path.display(), err)
            }
            Self::Parse { path, err } => {
                write!(f, "parse ({}): {}", path.display(), err)
            }
            Self::BadCategory { id, value } => {
                // The trailing `\` joins the two source lines into one
                // message without inserting a newline or extra spaces.
                write!(
                    f,
                    "rule `{}`: unknown category `{}` (expected arch / deps / docs / \
                     governance / health / security)",
                    id, value
                )
            }
        }
    }
}
// Uses the trait's default methods only; the underlying `err` of the `Io`
// and `Parse` variants is rendered inline by the `Display` impl instead of
// being exposed through `source()`.
impl std::error::Error for LoadError {}
/// A validated user rule, ready to be evaluated through the `Invariant` trait.
///
/// Instances are produced by `load_all`; the fields mirror `UserRule` except
/// that `category` has already been parsed into a `Category`.
#[derive(Debug)]
pub struct UserDefinedInvariant {
// Rule identifier as written in the rules file.
id: String,
// Parsed category of the rule.
category: Category,
// Human-readable statement of what the rule enforces.
intent: String,
// Optional ADR reference supplied by the user.
adr: Option<String>,
// The check to run when the invariant is evaluated.
spec: MatchSpec,
}
impl UserDefinedInvariant {
    /// Returns the textual label of this rule's category, as produced by
    /// `Category::as_str`.
    pub fn category_label(&self) -> &str {
        let category = &self.category;
        category.as_str()
    }
}
impl Invariant for UserDefinedInvariant {
    /// Rule identifier as a `'static` string.
    ///
    /// The trait demands `'static`, but the id is owned data loaded at
    /// runtime, so it is interned: the first call leaks one copy and every
    /// later call returns that same copy.
    fn id(&self) -> &'static str {
        intern_static(&self.id)
    }

    /// Category the rule was declared under.
    fn category(&self) -> Category {
        self.category
    }

    /// Human-readable intent of the rule, interned like `id`.
    fn intent(&self) -> &'static str {
        intern_static(&self.intent)
    }

    /// Optional ADR reference of the rule, interned like `id`.
    fn adr(&self) -> Option<&'static str> {
        self.adr.as_deref().map(intern_static)
    }

    /// Runs the rule's check against the repository described by `ctx`.
    fn evaluate(&self, ctx: &Context) -> Outcome {
        match &self.spec {
            MatchSpec::ForbidSubstring { glob, needle } => evaluate_forbid(ctx, glob, needle),
            MatchSpec::RequireSubstring { glob, needle } => evaluate_require(ctx, glob, needle),
        }
    }
}

/// Returns a `'static` string equal to `s`, leaking at most one allocation
/// per distinct string value for the lifetime of the process.
///
/// Leaking a fresh copy on every accessor call would grow memory without
/// bound when the accessors are called repeatedly; the pool caps the leak
/// at the set of distinct strings ever requested.
fn intern_static(s: &str) -> &'static str {
    use std::collections::HashSet;
    use std::sync::{Mutex, PoisonError};

    // `HashSet::new` is not `const`, so the set is created lazily on first use.
    static POOL: Mutex<Option<HashSet<&'static str>>> = Mutex::new(None);

    // A poisoned lock only means another thread panicked while holding it;
    // the set itself is still usable, so recover the guard.
    let mut guard = POOL.lock().unwrap_or_else(PoisonError::into_inner);
    let pool = guard.get_or_insert_with(HashSet::new);
    if let Some(existing) = pool.get(s).copied() {
        return existing;
    }
    let leaked: &'static str = Box::leak(s.to_owned().into_boxed_str());
    pool.insert(leaked);
    leaked
}
/// Fails when any file matching `glob` under the context root contains
/// `needle`; passes otherwise.
///
/// Files that cannot be read as UTF-8 text are skipped. On failure the
/// message lists every offending file relative to the root, and the repro
/// string is an `rg` command line.
fn evaluate_forbid(ctx: &Context, glob: &str, needle: &str) -> Outcome {
    let offenders: Vec<String> = walk_glob(ctx.root(), glob)
        .into_iter()
        .filter(|candidate| {
            // Unreadable files count as "does not contain the needle".
            fs::read_to_string(candidate).map_or(false, |text| text.contains(needle))
        })
        .map(|candidate| rel_display(&candidate, ctx.root()))
        .collect();

    if offenders.is_empty() {
        return Outcome::pass();
    }

    let summary = format!(
        "{n} file(s) contain forbidden substring `{needle}`:\n {body}",
        n = offenders.len(),
        body = offenders.join("\n ")
    );
    let repro = format!("rg -F '{needle}' {glob}");
    Outcome::fail_repro(summary, repro)
}
/// Passes when at least one file matching `glob` under the context root
/// contains `needle`; fails otherwise.
///
/// Files that cannot be read as UTF-8 text are treated as not containing
/// the needle. The search stops at the first file that satisfies the rule.
fn evaluate_require(ctx: &Context, glob: &str, needle: &str) -> Outcome {
    for candidate in walk_glob(ctx.root(), glob) {
        if let Ok(text) = fs::read_to_string(&candidate) {
            if text.contains(needle) {
                return Outcome::pass();
            }
        }
    }

    let summary = format!("no file matching `{glob}` contains required substring `{needle}`");
    let repro = format!("rg -F '{needle}' {glob}");
    Outcome::fail_repro(summary, repro)
}
/// Formats `p` for display: relative to `root` when `p` lies under it
/// (otherwise unchanged), with every backslash replaced by `/`.
fn rel_display(p: &Path, root: &Path) -> String {
    let shown = match p.strip_prefix(root) {
        Ok(relative) => relative,
        // Not under `root`: fall back to the path as given.
        Err(_) => p,
    };
    let rendered = shown.display().to_string();
    rendered.replace('\\', "/")
}
/// Collects every regular file under `root` whose root-relative path
/// (with `/` separators) matches `glob`, in sorted path order.
///
/// Entries the walker cannot read are skipped. The result is sorted because
/// directory iteration order is not specified; without sorting, the file
/// listing in failure messages would vary between runs and machines.
fn walk_glob(root: &Path, glob: &str) -> Vec<PathBuf> {
    let mut out = Vec::new();
    for entry in walkdir::WalkDir::new(root).into_iter().flatten() {
        if !entry.file_type().is_file() {
            continue;
        }
        let path = entry.path();
        let Ok(rel) = path.strip_prefix(root) else {
            continue;
        };
        // Normalize separators so one glob works on every platform.
        let rel = rel.to_string_lossy().replace('\\', "/");
        if glob_match(glob, &rel) {
            out.push(path.to_path_buf());
        }
    }
    out.sort();
    out
}
/// Tests whether the `/`-separated path `text` matches the `/`-separated
/// glob `pattern`, comparing segment by segment.
fn glob_match(pattern: &str, text: &str) -> bool {
    let pattern_segments = pattern.split('/').collect::<Vec<_>>();
    let path_segments = text.split('/').collect::<Vec<_>>();
    glob_segments(&pattern_segments, &path_segments)
}
/// Matches a list of pattern segments against a list of path segments.
///
/// A `**` segment stands for zero or more whole path segments; any other
/// segment must match exactly one path segment according to
/// `segment_match`. Both lists must be consumed completely for a match.
fn glob_segments(pat: &[&str], text: &[&str]) -> bool {
    let Some((&head, tail)) = pat.split_first() else {
        // Pattern exhausted: match only if the path is exhausted too.
        return text.is_empty();
    };

    if head == "**" {
        // Let `**` swallow 0, 1, ..., all remaining segments in turn.
        return (0..=text.len()).any(|skipped| glob_segments(tail, &text[skipped..]));
    }

    match text.split_first() {
        Some((&first, remaining)) => {
            segment_match(head, first) && glob_segments(tail, remaining)
        }
        None => false,
    }
}
/// Matches a single path segment against a pattern in which `*` stands for
/// any run of bytes (including none); every other byte must match literally.
///
/// Implemented as a two-cursor scan that remembers the most recent `*` and,
/// on a mismatch, retries with that `*` absorbing one more byte.
fn segment_match(pat: &str, text: &str) -> bool {
    let pattern = pat.as_bytes();
    let subject = text.as_bytes();
    let mut p = 0usize;
    let mut t = 0usize;
    // (pattern index just past the last `*`, subject index where that `*`
    // currently stops absorbing).
    let mut resume: Option<(usize, usize)> = None;

    while t < subject.len() {
        match pattern.get(p).copied() {
            Some(b'*') => {
                resume = Some((p + 1, t));
                p += 1;
            }
            Some(byte) if byte == subject[t] => {
                p += 1;
                t += 1;
            }
            _ => match resume {
                Some((after_star, absorbed_until)) => {
                    // Backtrack: the last `*` takes one more byte.
                    let next = absorbed_until + 1;
                    resume = Some((after_star, next));
                    p = after_star;
                    t = next;
                }
                None => return false,
            },
        }
    }

    // Subject consumed: only trailing `*`s may remain in the pattern.
    pattern[p..].iter().all(|&byte| byte == b'*')
}
/// Loads every user-defined invariant from the `.toml` files directly inside
/// `USER_DIR` under `repo_root`, sorted by rule id.
///
/// A missing directory is not an error and yields an empty list.
///
/// # Errors
///
/// * `LoadError::Io` when the directory exists but cannot be listed, or a
///   rules file cannot be read.
/// * `LoadError::Parse` when a rules file is not valid TOML for the rule schema.
/// * `LoadError::BadCategory` when a rule names an unknown category.
pub fn load_all(repo_root: &Path) -> Result<Vec<UserDefinedInvariant>, LoadError> {
    let dir = repo_root.join(USER_DIR);
    let entries = match fs::read_dir(&dir) {
        Ok(entries) => entries,
        // Only "directory does not exist" means "no user rules". Any other
        // failure (e.g. permission denied) must surface; otherwise every
        // user invariant would be silently disabled.
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(Vec::new()),
        Err(err) => return Err(LoadError::Io { path: dir, err }),
    };

    let mut rule_files = Vec::new();
    for entry in entries {
        let entry = entry.map_err(|err| LoadError::Io {
            path: dir.clone(),
            err,
        })?;
        let path = entry.path();
        if path.extension().and_then(|ext| ext.to_str()) == Some("toml") {
            rule_files.push(path);
        }
    }
    // Directory iteration order is unspecified; sorting makes both the
    // first reported error and the order of equal-id rules reproducible.
    rule_files.sort();

    let mut out = Vec::new();
    for path in rule_files {
        let text = fs::read_to_string(&path).map_err(|err| LoadError::Io {
            path: path.clone(),
            err,
        })?;
        let file: File = toml::from_str(&text).map_err(|err| LoadError::Parse {
            path: path.clone(),
            err,
        })?;
        for r in file.rule {
            let category = parse_category(&r.id, &r.category)?;
            out.push(UserDefinedInvariant {
                id: r.id,
                category,
                intent: r.intent,
                adr: r.adr,
                spec: r.match_,
            });
        }
    }
    out.sort_by(|a, b| a.id.cmp(&b.id));
    Ok(out)
}
/// Maps a category name from a rules file to its `Category` value.
///
/// # Errors
///
/// Returns `LoadError::BadCategory` carrying the rule `id` and the rejected
/// `value` when the name is not one of the six known categories.
fn parse_category(id: &str, value: &str) -> Result<Category, LoadError> {
    let category = match value {
        "arch" => Category::Arch,
        "deps" => Category::Deps,
        "docs" => Category::Docs,
        "governance" => Category::Governance,
        "health" => Category::Health,
        "security" => Category::Security,
        unknown => {
            return Err(LoadError::BadCategory {
                id: id.to_owned(),
                value: unknown.to_owned(),
            });
        }
    };
    Ok(category)
}
// Unit tests for the glob matcher and the TOML rule loader. Loader tests
// build a throwaway repository layout inside a temporary directory.
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
// Writes `body` to `root/rel`, creating parent directories as needed.
fn write(root: &Path, rel: &str, body: &str) {
let p = root.join(rel);
fs::create_dir_all(p.parent().unwrap()).unwrap();
fs::write(p, body).unwrap();
}
// `*` stays within one path segment, while `**` spans zero or more segments.
#[test]
fn glob_matches_simple_patterns() {
assert!(glob_match("crates/**/*.rs", "crates/koala-core/src/lib.rs"));
assert!(glob_match(
"crates/*/Cargo.toml",
"crates/koala-core/Cargo.toml"
));
assert!(!glob_match("crates/*/Cargo.toml", "crates/a/b/Cargo.toml"));
assert!(glob_match("**/README.md", "README.md"));
assert!(glob_match("**/README.md", "wiki/README.md"));
}
// A forbid-substring rule loads with its metadata intact, passes while no
// matching file exists, and fails once a matching file contains the needle.
#[test]
fn user_defined_toml_loaded() {
let tmp = TempDir::new().unwrap();
write(
tmp.path(),
".koala/invariants/biz.toml",
r#"
[[rule]]
id = "biz.no-fixme-in-src"
category = "health"
intent = "Code under crates/ must not ship FIXME markers."
adr = "ADR-0019"
[rule.match]
kind = "forbid-substring"
glob = "crates/**/*.rs"
needle = "FIXME"
"#,
);
let rules = load_all(tmp.path()).unwrap();
assert_eq!(rules.len(), 1);
let r = &rules[0];
assert_eq!(r.id(), "biz.no-fixme-in-src");
assert_eq!(r.category().as_str(), "health");
assert_eq!(r.adr(), Some("ADR-0019"));
let ctx = Context::new(tmp.path().to_path_buf());
// No file under crates/ yet, so nothing can violate the rule.
assert!(matches!(r.evaluate(&ctx), Outcome::Pass { .. }));
write(
tmp.path(),
"crates/x/src/lib.rs",
"// FIXME: rewrite\npub fn k() {}\n",
);
let out = r.evaluate(&ctx);
assert!(matches!(out, Outcome::Fail { .. }), "{out:?}");
}
// A require-substring rule fails while the file is missing and passes once
// a matching file contains the needle.
#[test]
fn require_substring_rule() {
let tmp = TempDir::new().unwrap();
write(
tmp.path(),
".koala/invariants/docs.toml",
r#"
[[rule]]
id = "biz.readme-mentions-license"
category = "docs"
intent = "README must mention the license."
[rule.match]
kind = "require-substring"
glob = "README.md"
needle = "Apache-2.0"
"#,
);
let rules = load_all(tmp.path()).unwrap();
let r = &rules[0];
let ctx = Context::new(tmp.path().to_path_buf());
// README.md does not exist yet, so the required text cannot be found.
assert!(matches!(r.evaluate(&ctx), Outcome::Fail { .. }));
write(
tmp.path(),
"README.md",
"# Project\n\nLicense: Apache-2.0\n",
);
assert!(matches!(r.evaluate(&ctx), Outcome::Pass { .. }));
}
// A repository without the user rules directory yields no rules and no error.
#[test]
fn missing_user_dir_returns_empty() {
let tmp = TempDir::new().unwrap();
let rules = load_all(tmp.path()).unwrap();
assert!(rules.is_empty());
}
// An unknown category name aborts loading with `LoadError::BadCategory`.
#[test]
fn bad_category_is_rejected() {
let tmp = TempDir::new().unwrap();
write(
tmp.path(),
".koala/invariants/bad.toml",
r#"
[[rule]]
id = "biz.x"
category = "nonsense"
intent = "x"
[rule.match]
kind = "forbid-substring"
glob = "**/*"
needle = "x"
"#,
);
let err = load_all(tmp.path()).unwrap_err();
assert!(matches!(err, LoadError::BadCategory { .. }));
}
}