use crate::error::{PolicyError, Result};
use crate::expr::{self, Expr};
use serde::Deserialize;
use std::collections::{HashMap, HashSet};
/// A validated policy: named requirements, per-subject overrides, and
/// canonical-name aliases.
///
/// When produced by `parse`/`parse_str`, `requirements` is sorted by name,
/// `aliases` by canonical name, and `overrides` keeps document order.
#[derive(Clone, Debug)]
pub struct Policy {
    /// Every declared requirement.
    pub requirements: Vec<Requirement>,
    /// Overrides in the order they appeared in the input.
    pub overrides: Vec<Override>,
    /// Alias groups, one per canonical name.
    pub aliases: Vec<Alias>,
}
/// A single named requirement together with its parsed condition.
#[derive(Clone, Debug)]
pub struct Requirement {
    /// The key under `[requirement]` in the policy file.
    pub name: String,
    /// The requirement's condition, as parsed by `expr::parse`.
    pub expr: Expr,
    /// Whether the requirement is on by default: `true` for the bare string form,
    /// and for the table form unless it says `default = false`.
    /// NOTE(review): how this interacts with overrides is decided outside this file.
    pub default: bool,
}
/// One `[[override]]` entry: a subject matcher plus the change to apply to the
/// requirement set (presumably for matching subjects; applied outside this file).
#[derive(Clone, Debug)]
pub struct Override {
    /// Which subjects this override selects.
    pub matcher: SubjectMatcher,
    /// How the requirement set is changed.
    pub op: OverrideOp,
}
/// Optional constraints copied verbatim from an `[[override]]` entry.
///
/// A `None` field was simply absent from the input; presumably it places no
/// constraint on the subject (the matching logic is not in this file — confirm).
#[derive(Clone, Debug, Default)]
pub struct SubjectMatcher {
    /// The `registry` key, if given.
    pub registry: Option<String>,
    /// The `package` key, if given.
    pub package: Option<String>,
    /// The `version` key, if given.
    pub version: Option<String>,
    /// The `variant` key, if given.
    pub variant: Option<String>,
    /// The `hash` key, if given.
    pub hash: Option<String>,
}
/// The change an override makes to the requirement set.
#[derive(Clone, Debug)]
pub enum OverrideOp {
    /// Use exactly these requirement names (from an array in the input).
    Replace(Vec<String>),
    /// Add and/or remove requirement names (from a `{ add, remove }` table).
    Patch {
        /// Requirement names to add.
        add: Vec<String>,
        /// Requirement names to remove.
        remove: Vec<String>,
    },
}
/// A canonical name and the `(log, name)` pairs that map onto it.
///
/// Each pair comes from an `"log:name"` string split at its first `:`.
#[derive(Clone, Debug)]
pub struct Alias {
    /// The key under `[alias]`.
    pub canonical: String,
    /// `(log, name)` pairs in input order.
    pub mappings: Vec<(String, String)>,
}
/// Raw TOML shape of a policy document, before validation by `build`.
///
/// Only `requirement`, `override` and `alias` are accepted at the top level;
/// each may be omitted and then defaults to empty.
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
struct WirePolicy {
    /// `[requirement]` table: requirement name -> bare condition or table form.
    #[serde(default)]
    requirement: HashMap<String, WireRequirement>,
    /// `[[override]]` array of tables, in document order.
    #[serde(rename = "override", default)]
    overrides: Vec<WireOverride>,
    /// `[alias]` table: canonical name -> list of `"log:name"` strings.
    #[serde(default)]
    alias: HashMap<String, Vec<String>>,
}
/// Wire form of one requirement. Two forms are accepted:
///
/// - a bare condition string (`name = "expr"`), which is enabled by default;
/// - a table (`name = { condition = "expr", default = false }`).
///
/// `deny_unknown_fields` applies to the table form. A misspelled key such as
/// `defualt = false` is therefore rejected. Without it, the key would be ignored
/// and the requirement would silently stay enabled.
#[derive(Deserialize)]
#[serde(untagged, deny_unknown_fields)]
enum WireRequirement {
    /// Shorthand: just the condition expression; `default` is `true`.
    Bare(String),
    /// Explicit table form; `default` falls back to `true` when omitted.
    Full {
        condition: String,
        #[serde(default = "default_true")]
        default: bool,
    },
}
/// Serde default for the `default` key of a table-form requirement, so such a
/// requirement is enabled unless the file explicitly turns it off.
fn default_true() -> bool {
    true
}
/// Wire form of one `[[override]]` entry.
///
/// The five optional keys become the `SubjectMatcher`. `requirements` says how
/// to change the requirement set.
///
/// Unknown keys are rejected, matching `WirePolicy`. Without this, a misspelled
/// matcher key (e.g. `packge = "serde"`) would be dropped, and the corresponding
/// matcher field would stay `None`. That silently changes which subjects the
/// override targets (presumably all of them for that field — confirm against
/// the matcher logic).
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
struct WireOverride {
    #[serde(default)]
    registry: Option<String>,
    #[serde(default)]
    package: Option<String>,
    #[serde(default)]
    version: Option<String>,
    #[serde(default)]
    variant: Option<String>,
    #[serde(default)]
    hash: Option<String>,
    /// Required: either a replacement list or an add/remove patch.
    requirements: WireRequirementsOp,
}
/// Wire form of an override's `requirements` key:
///
/// - an array replaces the requirement set;
/// - a table `{ add = [...], remove = [...] }` patches it.
///
/// Both `add` and `remove` are optional. Without `deny_unknown_fields`, a typo
/// such as `{ ad = ["x"] }` would therefore deserialize as an empty, no-op patch
/// instead of an error.
#[derive(Deserialize)]
#[serde(untagged, deny_unknown_fields)]
enum WireRequirementsOp {
    /// Array form: the complete new list of requirement names.
    Replace(Vec<String>),
    /// Table form: names to add and names to remove.
    Patch {
        #[serde(default)]
        add: Vec<String>,
        #[serde(default)]
        remove: Vec<String>,
    },
}
/// Parses and validates a policy from TOML source text.
///
/// # Errors
///
/// Fails if `toml_str` does not deserialize into the policy schema, or if
/// validation (expression parsing, override and alias checks) fails.
pub fn parse_str(toml_str: &str) -> Result<Policy> {
    build(toml::from_str(toml_str)?)
}
impl<'de> serde::Deserialize<'de> for Policy {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let wire = WirePolicy::deserialize(deserializer)?;
build(wire).map_err(serde::de::Error::custom)
}
}
pub fn parse(path: &std::path::Path) -> Result<Policy> {
let s = std::fs::read_to_string(path)?;
parse_str(&s)
}
fn build(wire: WirePolicy) -> Result<Policy> {
let mut requirements: Vec<Requirement> = Vec::new();
for (name, wr) in wire.requirement {
let (condition, default) = match wr {
WireRequirement::Bare(c) => (c, true),
WireRequirement::Full { condition, default } => (condition, default),
};
let expr =
expr::parse(&condition).map_err(|source| PolicyError::RequirementExpression {
name: name.clone(),
source: Box::new(source),
})?;
requirements.push(Requirement {
name,
expr,
default,
});
}
requirements.sort_by(|a, b| a.name.cmp(&b.name));
let known: HashSet<&str> = requirements.iter().map(|r| r.name.as_str()).collect();
let mut overrides: Vec<Override> = Vec::with_capacity(wire.overrides.len());
for o in wire.overrides {
let op = match o.requirements {
WireRequirementsOp::Replace(names) => {
check_names_known(&names, &known)?;
OverrideOp::Replace(names)
}
WireRequirementsOp::Patch { add, remove } => {
check_names_known(&add, &known)?;
check_names_known(&remove, &known)?;
OverrideOp::Patch { add, remove }
}
};
overrides.push(Override {
matcher: SubjectMatcher {
registry: o.registry,
package: o.package,
version: o.version,
variant: o.variant,
hash: o.hash,
},
op,
});
}
let mut aliases: Vec<Alias> = Vec::with_capacity(wire.alias.len());
for (canonical, mappings) in wire.alias {
let mut parsed = Vec::with_capacity(mappings.len());
for entry in mappings {
let (log, name) =
entry
.split_once(':')
.ok_or_else(|| PolicyError::InvalidAliasEntry {
canonical: canonical.clone(),
entry: entry.clone(),
})?;
parsed.push((log.to_string(), name.to_string()));
}
aliases.push(Alias {
canonical,
mappings: parsed,
});
}
aliases.sort_by(|a, b| a.canonical.cmp(&b.canonical));
Ok(Policy {
requirements,
overrides,
aliases,
})
}
/// Ensures every entry of `names` is a declared requirement.
///
/// # Errors
///
/// Returns `PolicyError::UnknownRequirement` naming the first entry, in slice
/// order, that is absent from `known`.
fn check_names_known(names: &[String], known: &HashSet<&str>) -> Result<()> {
    match names.iter().find(|candidate| !known.contains(candidate.as_str())) {
        Some(missing) => Err(PolicyError::UnknownRequirement {
            name: missing.clone(),
        }),
        None => Ok(()),
    }
}
impl Policy {
pub fn requirement(&self, name: &str) -> Option<&Requirement> {
self.requirements.iter().find(|r| r.name == name)
}
pub fn alias_for<'a>(&'a self, canonical: &str) -> Option<&'a Alias> {
self.aliases.iter().find(|a| a.canonical == canonical)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `src`; the test fails if the policy is rejected.
    fn accept(src: &str) -> Policy {
        parse_str(src).unwrap()
    }

    /// Parses `src`; the test fails if the policy is accepted.
    fn reject(src: &str) -> PolicyError {
        parse_str(src).unwrap_err()
    }

    #[test]
    fn empty_policy_parses() {
        let policy = accept("");
        assert!(policy.requirements.is_empty());
        assert!(policy.overrides.is_empty());
        assert!(policy.aliases.is_empty());
    }

    #[test]
    fn bare_requirement_defaults_to_on() {
        let policy = accept(
            r#"
[requirement]
safe-to-deploy = "safe-to-deploy"
"#,
        );
        assert_eq!(policy.requirements.len(), 1);
        let req = &policy.requirements[0];
        assert_eq!(req.name, "safe-to-deploy");
        assert!(req.default);
    }

    #[test]
    fn full_form_can_disable_default() {
        let policy = accept(
            r#"
[requirement]
sandbox = { condition = "sandboxed", default = false }
"#,
        );
        assert!(!policy.requirements[0].default);
    }

    #[test]
    fn full_form_default_true_is_kept() {
        let policy = accept(
            r#"
[requirement]
x = { condition = "a" }
"#,
        );
        assert!(policy.requirements[0].default);
    }

    #[test]
    fn override_replace_form() {
        let policy = accept(
            r#"
[requirement]
r1 = "a"
r2 = "b"
[[override]]
package = "serde"
requirements = ["r2"]
"#,
        );
        assert_eq!(policy.overrides.len(), 1);
        let op = &policy.overrides[0].op;
        assert!(matches!(op, OverrideOp::Replace(v) if v == &["r2"]));
    }

    #[test]
    fn override_patch_form() {
        let policy = accept(
            r#"
[requirement]
r1 = "a"
r2 = "b"
[[override]]
package = "libc"
requirements = { add = ["r2"], remove = ["r1"] }
"#,
        );
        if let OverrideOp::Patch { add, remove } = &policy.overrides[0].op {
            assert_eq!(add, &["r2"]);
            assert_eq!(remove, &["r1"]);
        } else {
            panic!("expected Patch");
        }
    }

    #[test]
    fn override_to_unknown_requirement_errors() {
        let err = reject(
            r#"
[requirement]
r1 = "a"
[[override]]
package = "x"
requirements = ["nope"]
"#,
        );
        assert!(matches!(err, PolicyError::UnknownRequirement { .. }));
    }

    #[test]
    fn aliases_parse() {
        let policy = accept(
            r#"
[alias]
safe-to-run = ["google:safe-to-run", "mozilla:runtime-safe"]
"#,
        );
        assert_eq!(policy.aliases.len(), 1);
        let alias = &policy.aliases[0];
        assert_eq!(alias.canonical, "safe-to-run");
        assert_eq!(alias.mappings.len(), 2);
        assert_eq!(alias.mappings[0], ("google".into(), "safe-to-run".into()));
    }

    #[test]
    fn alias_without_colon_errors() {
        let err = reject(
            r#"
[alias]
x = ["nocolonhere"]
"#,
        );
        assert!(matches!(err, PolicyError::InvalidAliasEntry { .. }));
    }

    #[test]
    fn bad_expression_surfaces_requirement_name() {
        let err = reject(
            r#"
[requirement]
broken = "a and"
"#,
        );
        match err {
            PolicyError::RequirementExpression { name, .. } => assert_eq!(name, "broken"),
            _ => panic!("expected RequirementExpression"),
        }
    }
}