use crate::Result;
use crate::expr::Expression;
use camino::{Utf8Path, Utf8PathBuf};
use core::time::Duration;
use ohno::{IntoAppError, app_err};
use semver::{Version, VersionReq};
use serde::{Deserialize, Serialize};
use std::fs;
use std::io;
/// The bundled default configuration (`default_config.toml`), embedded at compile time.
///
/// [`Config::default`] deserializes this text, and [`Config::save_default`] writes it out
/// verbatim.
pub const DEFAULT_CONFIG_TOML: &str = include_str!("../../default_config.toml");
/// One `[[allow_list]]` entry: a crate name paired with a semver requirement.
///
/// Unknown keys inside an entry are rejected during deserialization (`deny_unknown_fields`).
#[derive(Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
pub struct AllowListEntry {
    /// Crate name; compared for exact string equality by [`AllowListEntry::matches`].
    pub name: String,
    /// Semver requirement (e.g. `=1.2.3`, `^2.0`, `*`) that a crate version must satisfy.
    pub version: VersionReq,
}
impl AllowListEntry {
    /// Reports whether this entry covers the given crate.
    ///
    /// The crate name must be exactly equal to `self.name`. Only then is `version`
    /// checked against the entry's semver requirement.
    #[must_use]
    pub fn matches(&self, name: &str, version: &Version) -> bool {
        if self.name != name {
            return false;
        }
        self.version.matches(version)
    }
}
/// Top-level `cargo-aprz` configuration.
///
/// It is normally read from `aprz.toml` at the workspace root or from an explicit path
/// (see [`Config::load`]). Unknown top-level keys are rejected (`deny_unknown_fields`).
/// Any key omitted from the file falls back to the per-field serde default noted below.
#[derive(Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
pub struct Config {
    /// Crates (name + version requirement) for which [`Config::is_allowed`] returns `true`.
    /// Empty when omitted.
    #[serde(default)]
    pub allow_list: Vec<AllowListEntry>,
    /// Expressions listed under `high_risk`. Empty when omitted.
    /// NOTE(review): how these are evaluated lives outside this module — presumably
    /// conditions that mark a crate as high risk; confirm against the `expr` consumer.
    #[serde(default)]
    pub high_risk: Vec<Expression>,
    /// Expressions listed under `eval`. Empty when omitted.
    /// NOTE(review): semantics are defined by the caller, not here — confirm.
    #[serde(default)]
    pub eval: Vec<Expression>,
    /// Must lie in `0.0..=100.0` and be strictly less than `low_risk_threshold`
    /// (enforced when loading). Defaults to `30.0` when omitted.
    #[serde(default = "default_medium_risk_threshold")]
    pub medium_risk_threshold: f64,
    /// Must lie in `0.0..=100.0` and be strictly greater than `medium_risk_threshold`
    /// (enforced when loading). Defaults to `70.0` when omitted.
    #[serde(default = "default_low_risk_threshold")]
    pub low_risk_threshold: f64,
    // The `*_cache_ttl` fields below are written as human-readable durations
    // (via `humantime_serde`, e.g. "24h") and default to one week when omitted.
    // NOTE(review): what each cache holds is defined by its consumer, not in this module.
    /// Time-to-live for the `crates` cache.
    #[serde(default = "default_cache_ttl", with = "humantime_serde")]
    pub crates_cache_ttl: Duration,
    /// Time-to-live for the `hosting` cache.
    #[serde(default = "default_cache_ttl", with = "humantime_serde")]
    pub hosting_cache_ttl: Duration,
    /// Time-to-live for the `codebase` cache.
    #[serde(default = "default_cache_ttl", with = "humantime_serde")]
    pub codebase_cache_ttl: Duration,
    /// Time-to-live for the `coverage` cache.
    #[serde(default = "default_cache_ttl", with = "humantime_serde")]
    pub coverage_cache_ttl: Duration,
    /// Time-to-live for the `advisories` cache.
    #[serde(default = "default_cache_ttl", with = "humantime_serde")]
    pub advisories_cache_ttl: Duration,
}
/// Serde fallback for `medium_risk_threshold` when the key is absent from the file.
const fn default_medium_risk_threshold() -> f64 {
    30_f64
}

/// Serde fallback for `low_risk_threshold` when the key is absent from the file.
const fn default_low_risk_threshold() -> f64 {
    70_f64
}
/// Serde fallback shared by every `*_cache_ttl` field: one week.
const fn default_cache_ttl() -> Duration {
    const ONE_WEEK_IN_SECS: u64 = 7 * 24 * 60 * 60;
    Duration::from_secs(ONE_WEEK_IN_SECS)
}
impl Config {
#[must_use]
pub fn is_allowed(&self, name: &str, version: &Version) -> bool {
self.allow_list.iter().any(|entry| entry.matches(name, version))
}
pub fn load(workspace_root: &Utf8Path, config_path: Option<&Utf8PathBuf>) -> Result<Self> {
let (final_path, text) = if let Some(path) = config_path {
let text = fs::read_to_string(path).into_app_err_with(|| format!("reading cargo-aprz configuration file '{path}'"))?;
(path.clone(), text)
} else {
let path = workspace_root.join("aprz.toml");
match fs::read_to_string(&path) {
Ok(text) => (path, text),
Err(e) if e.kind() == io::ErrorKind::NotFound => {
return Ok(Self::default());
}
Err(e) => return Err(e).into_app_err_with(|| format!("reading cargo-aprz configuration file '{path}'")),
}
};
let config: Self = toml::from_str(&text).into_app_err_with(|| format!("parsing configuration file '{final_path}'"))?;
config.validate()?;
Ok(config)
}
pub fn save_default(output_path: &Utf8Path) -> Result<()> {
fs::write(output_path, DEFAULT_CONFIG_TOML).into_app_err_with(|| format!("writing default configuration to {output_path}"))?;
Ok(())
}
fn validate(&self) -> Result<()> {
if !(0.0..=100.0).contains(&self.medium_risk_threshold) {
return Err(app_err!(
"medium_risk_threshold must be between 0 and 100, got {}",
self.medium_risk_threshold
));
}
if !(0.0..=100.0).contains(&self.low_risk_threshold) {
return Err(app_err!(
"low_risk_threshold must be between 0 and 100, got {}",
self.low_risk_threshold
));
}
if self.medium_risk_threshold >= self.low_risk_threshold {
return Err(app_err!(
"medium_risk_threshold ({}) must be less than low_risk_threshold ({})",
self.medium_risk_threshold,
self.low_risk_threshold
));
}
Ok(())
}
}
impl Default for Config {
    /// Builds a `Config` by deserializing the embedded [`DEFAULT_CONFIG_TOML`].
    ///
    /// This does not run threshold validation; the bundled defaults are expected to
    /// pass it.
    ///
    /// # Panics
    ///
    /// Panics if the embedded TOML does not deserialize into `Config`. The text is
    /// compiled into the binary, so this indicates a build/packaging mistake rather than
    /// a user error.
    fn default() -> Self {
        let parsed = toml::from_str::<Self>(DEFAULT_CONFIG_TOML);
        parsed.expect("default_config.toml should be valid TOML that deserializes to Config")
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for configuration loading, validation, and allow-list matching.
    use super::*;

    #[test]
    fn test_default_config_is_valid() {
        let config = Config::default();
        config.validate().unwrap();
    }

    #[test]
    fn test_validate_medium_risk_out_of_range_low() {
        let config = Config { medium_risk_threshold: -1.0, ..Config::default() };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_validate_medium_risk_out_of_range_high() {
        let config = Config { medium_risk_threshold: 101.0, ..Config::default() };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_validate_low_risk_out_of_range_low() {
        let config = Config { low_risk_threshold: -1.0, ..Config::default() };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_validate_low_risk_out_of_range_high() {
        let config = Config { low_risk_threshold: 101.0, ..Config::default() };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_validate_medium_ge_low_risk() {
        let config = Config { medium_risk_threshold: 80.0, low_risk_threshold: 70.0, ..Config::default() };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_validate_medium_equals_low_risk() {
        let config = Config { medium_risk_threshold: 70.0, low_risk_threshold: 70.0, ..Config::default() };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_validate_valid_thresholds() {
        let config = Config { medium_risk_threshold: 30.0, low_risk_threshold: 70.0, ..Config::default() };
        config.validate().unwrap();
    }

    #[test]
    fn test_validate_boundary_values() {
        let config = Config { medium_risk_threshold: 0.0, low_risk_threshold: 100.0, ..Config::default() };
        config.validate().unwrap();
    }

    #[test]
    fn test_validate_rejects_non_finite_thresholds() {
        // TOML can spell `nan` / `inf`; none of them may slip through as a threshold.
        let config = Config { medium_risk_threshold: f64::NAN, ..Config::default() };
        assert!(config.validate().is_err());
        let config = Config { low_risk_threshold: f64::NAN, ..Config::default() };
        assert!(config.validate().is_err());
        let config = Config { low_risk_threshold: f64::INFINITY, ..Config::default() };
        assert!(config.validate().is_err());
        let config = Config { medium_risk_threshold: f64::NEG_INFINITY, ..Config::default() };
        assert!(config.validate().is_err());
    }

    #[test]
    #[cfg_attr(miri, ignore = "Miri cannot call GetTempPathW")]
    fn test_save_default_and_load() {
        let tmp = tempfile::tempdir().unwrap();
        let output_path = Utf8PathBuf::try_from(tmp.path().join("aprz.toml")).unwrap();
        Config::save_default(&output_path).unwrap();
        let loaded = Config::load(&Utf8PathBuf::try_from(tmp.path().to_path_buf()).unwrap(), Some(&output_path)).unwrap();
        loaded.validate().unwrap();
    }

    #[test]
    #[cfg_attr(miri, ignore = "Miri cannot call GetTempPathW")]
    fn test_load_missing_config_uses_defaults() {
        let tmp = tempfile::tempdir().unwrap();
        let workspace_root = Utf8PathBuf::try_from(tmp.path().to_path_buf()).unwrap();
        let config = Config::load(&workspace_root, None).unwrap();
        config.validate().unwrap();
    }

    #[test]
    #[cfg_attr(miri, ignore = "Miri cannot call GetTempPathW")]
    fn test_load_missing_explicit_config_is_error() {
        // Unlike the implicit `aprz.toml` lookup, an explicitly named file must exist.
        let tmp = tempfile::tempdir().unwrap();
        let workspace_root = Utf8PathBuf::try_from(tmp.path().to_path_buf()).unwrap();
        let missing = workspace_root.join("does-not-exist.toml");
        assert!(Config::load(&workspace_root, Some(&missing)).is_err());
    }

    #[test]
    #[cfg_attr(miri, ignore = "Miri cannot call GetTempPathW")]
    fn test_load_discovers_workspace_config() {
        // With no explicit path, `aprz.toml` at the workspace root is picked up.
        let tmp = tempfile::tempdir().unwrap();
        let workspace_root = Utf8PathBuf::try_from(tmp.path().to_path_buf()).unwrap();
        let toml_content = r#"
[[allow_list]]
name = "discovered"
version = "*"
"#;
        fs::write(workspace_root.join("aprz.toml"), toml_content).unwrap();
        let config = Config::load(&workspace_root, None).unwrap();
        assert!(config.is_allowed("discovered", &Version::new(0, 1, 0)));
    }

    #[test]
    #[cfg_attr(miri, ignore = "Miri cannot call GetTempPathW")]
    fn test_load_rejects_unknown_fields() {
        let tmp = tempfile::tempdir().unwrap();
        let config_path = Utf8PathBuf::try_from(tmp.path().join("aprz.toml")).unwrap();
        fs::write(&config_path, "not_a_real_option = true\n").unwrap();
        let workspace_root = Utf8PathBuf::try_from(tmp.path().to_path_buf()).unwrap();
        assert!(Config::load(&workspace_root, Some(&config_path)).is_err());
    }

    #[test]
    #[cfg_attr(miri, ignore = "Miri cannot call GetTempPathW")]
    fn test_load_rejects_invalid_thresholds() {
        // `load` must run validation, not just deserialization.
        let tmp = tempfile::tempdir().unwrap();
        let config_path = Utf8PathBuf::try_from(tmp.path().join("aprz.toml")).unwrap();
        fs::write(&config_path, "medium_risk_threshold = 80.0\nlow_risk_threshold = 70.0\n").unwrap();
        let workspace_root = Utf8PathBuf::try_from(tmp.path().to_path_buf()).unwrap();
        assert!(Config::load(&workspace_root, Some(&config_path)).is_err());
    }

    #[test]
    #[cfg_attr(miri, ignore = "Miri cannot call GetTempPathW")]
    fn test_load_parses_cache_ttl() {
        let tmp = tempfile::tempdir().unwrap();
        let config_path = Utf8PathBuf::try_from(tmp.path().join("aprz.toml")).unwrap();
        fs::write(&config_path, "crates_cache_ttl = \"1h\"\n").unwrap();
        let workspace_root = Utf8PathBuf::try_from(tmp.path().to_path_buf()).unwrap();
        let config = Config::load(&workspace_root, Some(&config_path)).unwrap();
        assert_eq!(config.crates_cache_ttl, Duration::from_secs(3600));
        // Omitted TTLs fall back to the serde default.
        assert_eq!(config.hosting_cache_ttl, default_cache_ttl());
    }

    #[test]
    fn test_default_config_toml_is_not_empty() {
        assert!(!DEFAULT_CONFIG_TOML.is_empty());
    }

    #[test]
    fn test_default_config_has_empty_allow_list() {
        let config = Config::default();
        assert!(config.allow_list.is_empty());
    }

    #[test]
    fn test_allow_list_entry_matches_exact_version() {
        let entry = AllowListEntry {
            name: "foo".to_string(),
            version: VersionReq::parse("=1.2.3").unwrap(),
        };
        assert!(entry.matches("foo", &Version::new(1, 2, 3)));
        assert!(!entry.matches("foo", &Version::new(1, 2, 4)));
        assert!(!entry.matches("bar", &Version::new(1, 2, 3)));
    }

    #[test]
    fn test_allow_list_entry_matches_caret_range() {
        let entry = AllowListEntry {
            name: "foo".to_string(),
            version: VersionReq::parse("^1.0").unwrap(),
        };
        assert!(entry.matches("foo", &Version::new(1, 0, 0)));
        assert!(entry.matches("foo", &Version::new(1, 9, 9)));
        assert!(!entry.matches("foo", &Version::new(2, 0, 0)));
    }

    #[test]
    fn test_allow_list_entry_matches_wildcard() {
        let entry = AllowListEntry {
            name: "foo".to_string(),
            version: VersionReq::parse("*").unwrap(),
        };
        assert!(entry.matches("foo", &Version::new(0, 0, 1)));
        assert!(entry.matches("foo", &Version::new(99, 99, 99)));
        assert!(!entry.matches("bar", &Version::new(1, 0, 0)));
    }

    #[test]
    fn test_is_allowed_matches() {
        let mut config = Config::default();
        config.allow_list.push(AllowListEntry {
            name: "foo".to_string(),
            version: VersionReq::parse("^1.0").unwrap(),
        });
        assert!(config.is_allowed("foo", &Version::new(1, 2, 3)));
        assert!(!config.is_allowed("foo", &Version::new(2, 0, 0)));
        assert!(!config.is_allowed("bar", &Version::new(1, 0, 0)));
    }

    #[test]
    fn test_is_allowed_empty_list() {
        let config = Config::default();
        assert!(!config.is_allowed("foo", &Version::new(1, 0, 0)));
    }

    #[test]
    #[cfg_attr(miri, ignore = "Miri cannot call GetTempPathW")]
    fn test_load_config_with_allow_list() {
        let tmp = tempfile::tempdir().unwrap();
        let config_path = Utf8PathBuf::try_from(tmp.path().join("aprz.toml")).unwrap();
        let toml_content = r#"
medium_risk_threshold = 30.0
low_risk_threshold = 70.0
[[allow_list]]
name = "some-crate"
version = "=1.2.3"
[[allow_list]]
name = "another-crate"
version = "^2.0"
"#;
        fs::write(&config_path, toml_content).unwrap();
        let workspace_root = Utf8PathBuf::try_from(tmp.path().to_path_buf()).unwrap();
        let config = Config::load(&workspace_root, Some(&config_path)).unwrap();
        assert_eq!(config.allow_list.len(), 2);
        assert!(config.is_allowed("some-crate", &Version::new(1, 2, 3)));
        assert!(!config.is_allowed("some-crate", &Version::new(1, 2, 4)));
        assert!(config.is_allowed("another-crate", &Version::new(2, 5, 0)));
        assert!(!config.is_allowed("another-crate", &Version::new(1, 0, 0)));
    }
}