use crate::asset::check_capacity_valid_for_asset;
use crate::input::{
deserialise_proportion_nonzero, input_err_msg, is_sorted_and_unique, read_toml,
};
use crate::units::{Capacity, Dimensionless, Flow, MoneyPerFlow};
use anyhow::{Context, Result, ensure};
use log::warn;
use serde::Deserialize;
use std::path::Path;
use std::sync::OnceLock;
/// Name of the file, inside the model directory, that holds the model parameters.
const MODEL_PARAMETERS_FILE_NAME: &str = "model.toml";

/// Key in the model parameters file that unlocks options considered dangerous.
///
/// NOTE: the `serde(rename = ...)` on `ModelParameters::allow_dangerous_options` repeats this
/// string literally (serde attributes only accept literals), so the two must be kept in sync.
pub const ALLOW_DANGEROUS_OPTION_NAME: &str = "please_give_me_broken_results";

/// Process-wide record of whether dangerous options were enabled; written at most once.
static DANGEROUS_OPTIONS_ENABLED: OnceLock<bool> = OnceLock::new();

/// Report whether the user enabled dangerous model options.
///
/// # Panics
///
/// Panics if the flag has not been set yet (in this module it is set when model parameters
/// are loaded with `ModelParameters::from_path`).
pub fn dangerous_model_options_enabled() -> bool {
    match DANGEROUS_OPTIONS_ENABLED.get() {
        Some(&enabled) => enabled,
        None => panic!("Dangerous options flag not set"),
    }
}

/// Record whether dangerous model options are enabled.
///
/// The flag can only be written once. In test builds a repeat call is tolerated provided it
/// passes the same value as the first call; in all other builds a repeat call panics.
fn set_dangerous_model_options_flag(enabled: bool) {
    if DANGEROUS_OPTIONS_ENABLED.set(enabled).is_ok() {
        return;
    }

    // Reaching here means the flag had already been initialised.
    if !cfg!(test) {
        panic!("Attempted to set DANGEROUS_OPTIONS_ENABLED twice");
    }
    assert_eq!(enabled, dangerous_model_options_enabled());
}
/// Define a zero-argument function `$fn_name` that returns `<$unit>::new($raw)`.
///
/// Used to create serde `default = "..."` functions for fields whose type is a unit type
/// constructed from a raw number.
macro_rules! define_unit_param_default {
    ($fn_name:ident, $unit:ty, $raw:expr) => {
        fn $fn_name() -> $unit {
            <$unit>::new($raw)
        }
    };
}
/// Define a zero-argument function `$fn_name` that returns `$default` as a `$ty`, unchanged.
///
/// Counterpart of `define_unit_param_default` for fields with plain (non-unit) types.
macro_rules! define_param_default {
    ($fn_name:ident, $ty:ty, $default:expr) => {
        fn $fn_name() -> $ty {
            $default
        }
    };
}
// Serde default functions for the optional fields of `ModelParameters`, listed in the same
// order as the fields. Unit-typed fields go through `<Type>::new`; plain numbers are used as-is.
define_unit_param_default!(default_candidate_asset_capacity, Capacity, 0.0001);
define_unit_param_default!(default_capacity_limit_factor, Dimensionless, 0.1);
define_unit_param_default!(default_value_of_lost_load, MoneyPerFlow, 1e9);
define_param_default!(default_max_ironing_out_iterations, u32, 1);
define_unit_param_default!(default_price_tolerance, Dimensionless, 1e-6);
define_param_default!(default_capacity_margin, f64, 0.2);
define_param_default!(default_mothball_years, u32, 0);
define_unit_param_default!(default_remaining_demand_absolute_tolerance, Flow, 1e-12);
/// Model-wide parameters, deserialised from the model parameters file (`model.toml`).
///
/// Every field other than `milestone_years` is optional in the file and falls back to the
/// matching `default_*` function. Values are checked after loading by
/// `ModelParameters::validate`.
#[derive(Debug, Deserialize, PartialEq)]
pub struct ModelParameters {
    /// Milestone years; must be non-empty and strictly increasing (sorted, no duplicates).
    pub milestone_years: Vec<u32>,
    /// Whether options considered dangerous may be used (default `false`). Spelled
    /// `please_give_me_broken_results` in the input file.
    // serde attributes only accept string literals, so this key has to be kept in sync with
    // `ALLOW_DANGEROUS_OPTION_NAME` by hand.
    #[serde(default, rename = "please_give_me_broken_results")]
    pub allow_dangerous_options: bool,
    /// Candidate asset capacity (default `0.0001`); validated with
    /// `check_capacity_valid_for_asset`.
    #[serde(default = "default_candidate_asset_capacity")]
    pub candidate_asset_capacity: Capacity,
    /// Capacity limit factor (default `0.1`). Parsed with `deserialise_proportion_nonzero`,
    /// which presumably restricts it to a non-zero proportion — confirm in `crate::input`.
    #[serde(
        default = "default_capacity_limit_factor",
        deserialize_with = "deserialise_proportion_nonzero"
    )]
    pub capacity_limit_factor: Dimensionless,
    /// Value of lost load (default `1e9`); must be finite and greater than zero.
    #[serde(default = "default_value_of_lost_load")]
    pub value_of_lost_load: MoneyPerFlow,
    /// Maximum number of ironing-out iterations (default `1`); must be non-zero.
    #[serde(default = "default_max_ironing_out_iterations")]
    pub max_ironing_out_iterations: u32,
    /// Price tolerance (default `1e-6`); must be finite and non-negative.
    #[serde(default = "default_price_tolerance")]
    pub price_tolerance: Dimensionless,
    /// Capacity margin (default `0.2`); must be finite and non-negative.
    #[serde(default = "default_capacity_margin")]
    pub capacity_margin: f64,
    /// Number of mothball years (default `0`); not range-checked here.
    #[serde(default = "default_mothball_years")]
    pub mothball_years: u32,
    /// Absolute tolerance for remaining demand (default `1e-12`); must be finite and
    /// non-negative, and may only differ from the default when dangerous options are enabled.
    #[serde(default = "default_remaining_demand_absolute_tolerance")]
    pub remaining_demand_absolute_tolerance: Flow,
}
fn check_milestone_years(years: &[u32]) -> Result<()> {
ensure!(!years.is_empty(), "`milestone_years` is empty");
ensure!(
is_sorted_and_unique(years),
"`milestone_years` must be composed of unique values in order"
);
Ok(())
}
fn check_value_of_lost_load(value: MoneyPerFlow) -> Result<()> {
ensure!(
value.is_finite() && value > MoneyPerFlow(0.0),
"value_of_lost_load must be a finite number greater than zero"
);
Ok(())
}
/// Check that at least one ironing-out iteration is allowed.
///
/// # Errors
///
/// Fails if `value` is zero.
fn check_max_ironing_out_iterations(value: u32) -> Result<()> {
    ensure!(value != 0, "max_ironing_out_iterations cannot be zero");
    Ok(())
}
fn check_price_tolerance(value: Dimensionless) -> Result<()> {
ensure!(
value.is_finite() && value >= Dimensionless(0.0),
"price_tolerance must be a finite number greater than or equal to zero"
);
Ok(())
}
/// Check the remaining-demand absolute tolerance.
///
/// The value must always be finite and non-negative. Unless `dangerous_options_enabled` is
/// set, it must also equal `default_remaining_demand_absolute_tolerance()` exactly.
///
/// # Errors
///
/// Fails if the value is negative, infinite or NaN. Also fails if it differs from the default
/// while dangerous options are disabled.
fn check_remaining_demand_absolute_tolerance(
    dangerous_options_enabled: bool,
    value: Flow,
) -> Result<()> {
    ensure!(
        value.is_finite() && value >= Flow(0.0),
        "remaining_demand_absolute_tolerance must be a finite number greater than or equal to zero"
    );

    // With dangerous options on, any finite non-negative tolerance is accepted.
    if dangerous_options_enabled {
        return Ok(());
    }

    let default_tolerance = default_remaining_demand_absolute_tolerance();
    let default_raw = default_tolerance.0;
    ensure!(
        value == default_tolerance,
        "Setting a remaining_demand_absolute_tolerance different from the default value of \
         {default_raw:e} is potentially dangerous, set {ALLOW_DANGEROUS_OPTION_NAME} to true if \
         you want to allow this."
    );
    Ok(())
}
/// Check that the capacity margin is finite and non-negative.
///
/// # Errors
///
/// Fails for negative, infinite or NaN values.
fn check_capacity_margin(value: f64) -> Result<()> {
    let is_valid = value.is_finite() && value >= 0.0;
    ensure!(
        is_valid,
        "capacity_margin must be a finite number greater than or equal to zero"
    );
    Ok(())
}
impl ModelParameters {
    /// Load the model parameters from the `model.toml` file in `model_dir` and validate them.
    ///
    /// As a side effect, the `please_give_me_broken_results` setting is stored in the
    /// process-wide flag read by [`dangerous_model_options_enabled`]. This happens before
    /// validation.
    ///
    /// # Errors
    ///
    /// Propagates any error from `read_toml`. Validation failures are returned with
    /// `input_err_msg(file_path)` attached as context.
    ///
    /// # Panics
    ///
    /// Outside test builds, panics if the dangerous-options flag was already set, i.e. if this
    /// is called more than once in the same process.
    pub fn from_path<P: AsRef<Path>>(model_dir: P) -> Result<Self> {
        let file_path = model_dir.as_ref().join(MODEL_PARAMETERS_FILE_NAME);
        let params: Self = read_toml(&file_path)?;

        // NOTE(review): the flag is recorded before validating. Helpers called from
        // `validate` (e.g. in other modules) may read it, so confirm before reordering.
        set_dangerous_model_options_flag(params.allow_dangerous_options);

        params
            .validate()
            .with_context(|| input_err_msg(file_path))
            .map(|()| params)
    }

    /// Check every parameter and return the first error found.
    ///
    /// If dangerous options are enabled, a prominent warning is logged first.
    fn validate(&self) -> Result<()> {
        // Exhaustive destructuring: adding a field forces a decision here about validating it.
        let Self {
            milestone_years,
            allow_dangerous_options,
            candidate_asset_capacity,
            // Checked during deserialisation by `deserialise_proportion_nonzero`.
            capacity_limit_factor: _,
            value_of_lost_load,
            max_ironing_out_iterations,
            price_tolerance,
            capacity_margin,
            // No constraints are enforced on this value.
            mothball_years: _,
            remaining_demand_absolute_tolerance,
        } = self;

        if *allow_dangerous_options {
            warn!(
                "!!! You've enabled the {ALLOW_DANGEROUS_OPTION_NAME} option. !!!\n\
                I see you like to live dangerously 😈. This option should ONLY be used by \
                developers as it can cause peculiar behaviour that breaks things. NEVER enable it \
                for results you actually care about or want to publish. You have been warned!"
            );
        }

        check_milestone_years(milestone_years)?;
        check_capacity_valid_for_asset(*candidate_asset_capacity)
            .context("Invalid value for candidate_asset_capacity")?;
        check_value_of_lost_load(*value_of_lost_load)?;
        check_max_ironing_out_iterations(*max_ironing_out_iterations)?;
        check_price_tolerance(*price_tolerance)?;
        check_capacity_margin(*capacity_margin)?;
        check_remaining_demand_absolute_tolerance(
            *allow_dangerous_options,
            *remaining_demand_absolute_tolerance,
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;
    use std::fmt::Display;
    use std::fs::File;
    use std::io::Write;
    use tempfile::tempdir;

    /// Assert that `result` is `Ok` exactly when `expected_valid` is true.
    ///
    /// For invalid values, also assert that the error message contains
    /// `expected_error_fragment`. `value` is only used in failure messages.
    fn assert_validation_result<T, U: Display>(
        result: Result<T>,
        expected_valid: bool,
        value: U,
        expected_error_fragment: &str,
    ) {
        if expected_valid {
            assert!(
                result.is_ok(),
                "Expected value {value} to be valid, but got error: {:?}",
                result.err()
            );
            return;
        }

        let Err(error) = result else {
            panic!("Expected value {value} to be invalid, but it was accepted");
        };
        let message = error.to_string();
        assert!(
            message.contains(expected_error_fragment),
            "Error message should mention the validation constraint, got: {message}",
        );
    }

    #[test]
    fn check_milestone_years_works() {
        let accepted: [&[u32]; 2] = [&[1], &[1, 2]];
        for years in accepted {
            check_milestone_years(years).unwrap();
        }

        let rejected: [&[u32]; 3] = [&[], &[1, 1], &[2, 1]];
        for years in rejected {
            assert!(check_milestone_years(years).is_err());
        }
    }

    #[test]
    fn model_params_from_path() {
        let dir = tempdir().unwrap();
        let params_file = dir.path().join(MODEL_PARAMETERS_FILE_NAME);
        let mut file = File::create(&params_file).unwrap();
        writeln!(file, "milestone_years = [2020, 2100]").unwrap();
        // Close the file before it is read back.
        drop(file);

        let loaded = ModelParameters::from_path(dir.path()).unwrap();
        assert_eq!(loaded.milestone_years, [2020, 2100]);
    }

    #[rstest]
    #[case(1.0, true)]
    #[case(1e-10, true)]
    #[case(1e9, true)]
    #[case(f64::MAX, true)]
    #[case(0.0, false)]
    #[case(-1.0, false)]
    #[case(-1e-10, false)]
    #[case(f64::INFINITY, false)]
    #[case(f64::NEG_INFINITY, false)]
    #[case(f64::NAN, false)]
    fn check_value_of_lost_load_works(#[case] value: f64, #[case] expected_valid: bool) {
        let outcome = check_value_of_lost_load(MoneyPerFlow::new(value));
        assert_validation_result(
            outcome,
            expected_valid,
            value,
            "value_of_lost_load must be a finite number greater than zero",
        );
    }

    #[rstest]
    #[case(1, true)]
    #[case(10, true)]
    #[case(100, true)]
    #[case(u32::MAX, true)]
    #[case(0, false)]
    fn check_max_ironing_out_iterations_works(#[case] value: u32, #[case] expected_valid: bool) {
        let outcome = check_max_ironing_out_iterations(value);
        assert_validation_result(
            outcome,
            expected_valid,
            value,
            "max_ironing_out_iterations cannot be zero",
        );
    }

    #[rstest]
    #[case(0.0, true)]
    #[case(1e-10, true)]
    #[case(1e-6, true)]
    #[case(1.0, true)]
    #[case(f64::MAX, true)]
    #[case(-1e-10, false)]
    #[case(-1.0, false)]
    #[case(f64::INFINITY, false)]
    #[case(f64::NEG_INFINITY, false)]
    #[case(f64::NAN, false)]
    fn check_price_tolerance_works(#[case] value: f64, #[case] expected_valid: bool) {
        let outcome = check_price_tolerance(Dimensionless::new(value));
        assert_validation_result(
            outcome,
            expected_valid,
            value,
            "price_tolerance must be a finite number greater than or equal to zero",
        );
    }

    #[rstest]
    #[case(true, 0.0, true)]
    #[case(true, 1e-10, true)]
    #[case(true, 1e-15, true)]
    #[case(false, 1e-12, true)]
    #[case(true, 1.0, true)]
    #[case(true, f64::MAX, true)]
    #[case(true, -1e-10, false)]
    #[case(true, f64::INFINITY, false)]
    #[case(true, f64::NEG_INFINITY, false)]
    #[case(true, f64::NAN, false)]
    #[case(false, -1e-10, false)]
    #[case(false, f64::INFINITY, false)]
    #[case(false, f64::NEG_INFINITY, false)]
    #[case(false, f64::NAN, false)]
    fn check_remaining_demand_absolute_tolerance_works(
        #[case] allow_dangerous_options: bool,
        #[case] value: f64,
        #[case] expected_valid: bool,
    ) {
        let tolerance = Flow::new(value);
        let outcome = check_remaining_demand_absolute_tolerance(allow_dangerous_options, tolerance);
        assert_validation_result(
            outcome,
            expected_valid,
            value,
            "remaining_demand_absolute_tolerance must be a finite number greater than or equal to zero",
        );
    }

    #[rstest]
    #[case(0.0)]
    #[case(1e-10)]
    #[case(1.0)]
    #[case(f64::MAX)]
    fn check_remaining_demand_absolute_tolerance_requires_dangerous_options_if_non_default(
        #[case] value: f64,
    ) {
        let tolerance = Flow::new(value);
        let outcome = check_remaining_demand_absolute_tolerance(false, tolerance);
        assert_validation_result(
            outcome,
            false,
            value,
            "Setting a remaining_demand_absolute_tolerance different from the default value of \
             1e-12 is potentially dangerous, set please_give_me_broken_results to true if you want \
             to allow this.",
        );
    }

    #[rstest]
    #[case(0.0, true)]
    #[case(0.2, true)]
    #[case(10.0, true)]
    #[case(-1e-6, false)]
    #[case(f64::INFINITY, false)]
    #[case(f64::NEG_INFINITY, false)]
    #[case(f64::NAN, false)]
    fn check_capacity_margin_works(#[case] value: f64, #[case] expected_valid: bool) {
        let outcome = check_capacity_margin(value);
        assert_validation_result(
            outcome,
            expected_valid,
            value,
            "capacity_margin must be a finite number greater than or equal to zero",
        );
    }
}