use std::path::{Path, PathBuf};
use zenith_core::{
BrandContract, DiagnosticPolicy, PolicyEntry, PolicyVerb, parse_brand_contract,
parse_diagnostic_policy,
};
/// File name of the per-project config. `find_local_policy` and
/// `find_local_brand` look for it in a start directory and each of its
/// ancestors.
const LOCAL_CONFIG_NAME: &str = ".zenith.kdl";
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct CliPolicyFlags {
pub allow: Vec<String>,
pub warn: Vec<String>,
pub deny: Vec<String>,
}
impl CliPolicyFlags {
pub fn is_empty(&self) -> bool {
self.allow.is_empty() && self.warn.is_empty() && self.deny.is_empty()
}
fn entries(&self) -> Vec<PolicyEntry> {
let mut entries = Vec::with_capacity(self.allow.len() + self.warn.len() + self.deny.len());
for code in &self.allow {
entries.push(PolicyEntry {
verb: PolicyVerb::Allow,
code: code.clone(),
subjects: Vec::new(),
source_span: None,
});
}
for code in &self.warn {
entries.push(PolicyEntry {
verb: PolicyVerb::Warn,
code: code.clone(),
subjects: Vec::new(),
source_span: None,
});
}
for code in &self.deny {
entries.push(PolicyEntry {
verb: PolicyVerb::Deny,
code: code.clone(),
subjects: Vec::new(),
source_span: None,
});
}
entries
}
}
/// Concatenates all policy layers into a single policy.
///
/// Entries are appended in this order: `global`, `local`, `in_file`, then
/// the CLI `flags`. The merge itself does no conflict resolution. Presumably
/// `DiagnosticPolicy` lets later entries win when resolving a code, and the
/// precedence tests in this module rely on that.
pub fn merge_policy(
    global: &DiagnosticPolicy,
    local: &DiagnosticPolicy,
    in_file: &DiagnosticPolicy,
    flags: &CliPolicyFlags,
) -> DiagnosticPolicy {
    let entries = global
        .entries
        .iter()
        .chain(&local.entries)
        .chain(&in_file.entries)
        .cloned()
        .chain(flags.entries())
        .collect();
    DiagnosticPolicy { entries }
}
/// Reads and parses a diagnostic policy from the config file at `path`.
///
/// A file that does not exist is not an error. It yields
/// `DiagnosticPolicy::default()`, so optional layers can be loaded
/// unconditionally.
///
/// # Errors
///
/// Returns a human-readable message when the file exists but cannot be read,
/// or when its contents fail to parse.
pub fn load_policy_file(path: &Path) -> Result<DiagnosticPolicy, String> {
    match std::fs::read(path) {
        Ok(raw) => parse_diagnostic_policy(&raw)
            .map_err(|e| format!("invalid config '{}': {}", path.display(), e.message)),
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(DiagnosticPolicy::default()),
        Err(e) => Err(format!("cannot read config '{}': {e}", path.display())),
    }
}
/// Finds and loads the nearest local `.zenith.kdl` as a diagnostic policy.
///
/// The search starts at `start_dir` and walks up through every ancestor to
/// the filesystem root. The first directory containing a regular file with
/// that name wins.
///
/// A relative `start_dir` is first resolved against the current working
/// directory. Without this, the walk would stop at the CWD, because
/// `Path::parent` on a relative path bottoms out at `""` instead of reaching
/// the real ancestors. If the CWD cannot be determined, `start_dir` is used
/// as given.
///
/// Returns an empty policy when no candidate file is found.
///
/// # Errors
///
/// Propagates the message from [`load_policy_file`] when the nearest
/// candidate exists but cannot be read or parsed.
pub fn find_local_policy(start_dir: &Path) -> Result<DiagnosticPolicy, String> {
    let start = if start_dir.is_relative() {
        std::env::current_dir()
            .map(|cwd| cwd.join(start_dir))
            .unwrap_or_else(|_| start_dir.to_path_buf())
    } else {
        start_dir.to_path_buf()
    };
    for dir in start.ancestors() {
        let candidate = dir.join(LOCAL_CONFIG_NAME);
        if candidate.is_file() {
            return load_policy_file(&candidate);
        }
    }
    Ok(DiagnosticPolicy::default())
}
/// Loads the global diagnostic policy from `<config_dir>/zenith/config.kdl`.
///
/// A missing file yields an empty policy, following [`load_policy_file`].
///
/// # Errors
///
/// Propagates read/parse errors from [`load_policy_file`].
pub fn load_global_policy_in(config_dir: &Path) -> Result<DiagnosticPolicy, String> {
    let mut config_file = config_dir.to_path_buf();
    config_file.push("zenith");
    config_file.push("config.kdl");
    load_policy_file(&config_file)
}
/// Loads the global diagnostic policy from `$HOME/.config/zenith/config.kdl`.
///
/// When `HOME` is unset or empty, no global policy exists and an empty one is
/// returned. An empty `HOME` is treated like an unset one. Otherwise the path
/// would become the relative `.config/zenith/config.kdl` and be silently
/// resolved against the current directory.
///
/// # Errors
///
/// Propagates read/parse errors from [`load_global_policy_in`].
pub fn load_global_policy() -> Result<DiagnosticPolicy, String> {
    match std::env::var_os("HOME")
        .filter(|home| !home.is_empty())
        .map(PathBuf::from)
    {
        Some(home) => load_global_policy_in(&home.join(".config")),
        None => Ok(DiagnosticPolicy::default()),
    }
}
/// Reads and parses a brand contract from the config file at `path`.
///
/// A file that does not exist is not an error. It yields
/// `BrandContract::default()`, so optional layers can be loaded
/// unconditionally.
///
/// # Errors
///
/// Returns a human-readable message when the file exists but cannot be read,
/// or when its contents fail to parse.
pub fn load_brand_file(path: &Path) -> Result<BrandContract, String> {
    match std::fs::read(path) {
        Ok(raw) => parse_brand_contract(&raw)
            .map_err(|e| format!("invalid config '{}': {}", path.display(), e.message)),
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(BrandContract::default()),
        Err(e) => Err(format!("cannot read config '{}': {e}", path.display())),
    }
}
/// Finds and loads the nearest local `.zenith.kdl` as a brand contract.
///
/// The search starts at `start_dir` and walks up through every ancestor to
/// the filesystem root. The first directory containing a regular file with
/// that name wins.
///
/// A relative `start_dir` is first resolved against the current working
/// directory. Without this, the walk would stop at the CWD, because
/// `Path::parent` on a relative path bottoms out at `""` instead of reaching
/// the real ancestors. If the CWD cannot be determined, `start_dir` is used
/// as given.
///
/// Returns an empty contract when no candidate file is found.
///
/// # Errors
///
/// Propagates the message from [`load_brand_file`] when the nearest
/// candidate exists but cannot be read or parsed.
pub fn find_local_brand(start_dir: &Path) -> Result<BrandContract, String> {
    let start = if start_dir.is_relative() {
        std::env::current_dir()
            .map(|cwd| cwd.join(start_dir))
            .unwrap_or_else(|_| start_dir.to_path_buf())
    } else {
        start_dir.to_path_buf()
    };
    for dir in start.ancestors() {
        let candidate = dir.join(LOCAL_CONFIG_NAME);
        if candidate.is_file() {
            return load_brand_file(&candidate);
        }
    }
    Ok(BrandContract::default())
}
/// Loads the global brand contract from `<config_dir>/zenith/config.kdl`.
///
/// A missing file yields an empty contract, following [`load_brand_file`].
///
/// # Errors
///
/// Propagates read/parse errors from [`load_brand_file`].
pub fn load_global_brand_in(config_dir: &Path) -> Result<BrandContract, String> {
    let mut config_file = config_dir.to_path_buf();
    config_file.push("zenith");
    config_file.push("config.kdl");
    load_brand_file(&config_file)
}
/// Loads the global brand contract from `$HOME/.config/zenith/config.kdl`.
///
/// When `HOME` is unset or empty, no global contract exists and an empty one
/// is returned. An empty `HOME` is treated like an unset one. Otherwise the
/// path would become the relative `.config/zenith/config.kdl` and be silently
/// resolved against the current directory.
///
/// # Errors
///
/// Propagates read/parse errors from [`load_global_brand_in`].
pub fn load_global_brand() -> Result<BrandContract, String> {
    match std::env::var_os("HOME")
        .filter(|home| !home.is_empty())
        .map(PathBuf::from)
    {
        Some(home) => load_global_brand_in(&home.join(".config")),
        None => Ok(BrandContract::default()),
    }
}
/// Loads the global and local config layers in a single call.
///
/// Returns `(global policy, local policy, global brand, local brand)`. The
/// local layers are searched upward from `start_dir`. When `start_dir` is
/// `None`, they are empty defaults.
///
/// Loading happens in this order: global policy, global brand, local policy,
/// local brand. The first failure is returned.
///
/// # Errors
///
/// Propagates the first read/parse error from any layer.
pub fn load_global_and_local(
    start_dir: Option<&Path>,
) -> Result<
    (
        DiagnosticPolicy,
        DiagnosticPolicy,
        BrandContract,
        BrandContract,
    ),
    String,
> {
    let global_policy = load_global_policy()?;
    let global_brand = load_global_brand()?;
    let (local_policy, local_brand) = if let Some(dir) = start_dir {
        let policy = find_local_policy(dir)?;
        let brand = find_local_brand(dir)?;
        (policy, brand)
    } else {
        (DiagnosticPolicy::default(), BrandContract::default())
    };
    Ok((global_policy, local_policy, global_brand, local_brand))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a policy holding one unscoped entry that applies `verb` to `code`.
    fn single(verb: PolicyVerb, code: &str) -> DiagnosticPolicy {
        DiagnosticPolicy {
            entries: vec![PolicyEntry {
                verb,
                code: code.to_owned(),
                subjects: Vec::new(),
                source_span: None,
            }],
        }
    }

    fn deny(code: &str) -> DiagnosticPolicy {
        single(PolicyVerb::Deny, code)
    }

    fn allow(code: &str) -> DiagnosticPolicy {
        single(PolicyVerb::Allow, code)
    }

    #[test]
    fn empty_everything_is_identity() {
        let empty = DiagnosticPolicy::default();
        let merged = merge_policy(&empty, &empty, &empty, &CliPolicyFlags::default());
        assert!(merged.entries.is_empty());
    }

    #[test]
    fn cli_beats_in_file_beats_local_beats_global() {
        let flags = CliPolicyFlags {
            deny: vec!["a".to_owned()],
            ..Default::default()
        };
        let merged = merge_policy(&allow("a"), &deny("a"), &allow("a"), &flags);
        assert_eq!(merged.verb_for("a", None), Some(&PolicyVerb::Deny));
    }

    #[test]
    fn in_file_beats_config_when_no_flag() {
        let merged = merge_policy(
            &deny("a"),
            &deny("a"),
            &allow("a"),
            &CliPolicyFlags::default(),
        );
        assert_eq!(merged.verb_for("a", None), Some(&PolicyVerb::Allow));
    }

    #[test]
    fn missing_file_is_default() {
        let missing = Path::new("/no/such/zenith/config.kdl");
        let policy = load_policy_file(missing).expect("missing file must be ok");
        assert!(policy.entries.is_empty());
    }

    #[test]
    fn find_local_handles_root_without_panic() {
        let policy = find_local_policy(Path::new("/")).expect("root walk must be ok");
        assert!(policy.entries.is_empty());
    }
}