use chrono::NaiveDate;
/// How the [`Block`]s of a [`Stage`] are arranged (stored in [`Stage::block_mode`]).
///
/// The exact modelling meaning of each variant is decided by the code that
/// consumes a [`Stage`]; it is not visible from this definition.
///
/// With the `serde` feature enabled this type also derives `Serialize` and
/// `Deserialize`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum BlockMode {
    /// Blocks treated as parallel. Presumably they are non-sequential slices of
    /// the stage; confirm against the consumer.
    Parallel,
    /// Blocks treated as chronological. Presumably they are time-ordered within
    /// the stage; confirm against the consumer.
    Chronological,
}
/// Kind of cycle that a [`SeasonMap`] describes (its `cycle_type` field).
///
/// With the `serde` feature enabled this type also derives `Serialize` and
/// `Deserialize`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum SeasonCycleType {
    /// Seasons follow calendar months. The test fixture builds 12 seasons with
    /// `month_start` values 1..=12 for this variant.
    Monthly,
    /// Seasons follow weeks (semantics defined by the consumer).
    Weekly,
    /// User-defined season boundaries. Presumably these use the optional
    /// day/month fields of [`SeasonDefinition`]; confirm.
    Custom,
}
/// Method used to generate scenario noise for a stage (see
/// [`ScenarioSourceConfig::noise_method`]).
///
/// The expansions below are the usual meanings of these acronyms. They are
/// NOT established by this file, so confirm them against the sampling code.
///
/// With the `serde` feature enabled this type also derives `Serialize` and
/// `Deserialize`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum NoiseMethod {
    /// Presumably Sample Average Approximation (plain random sampling).
    Saa,
    /// Presumably Latin Hypercube Sampling.
    Lhs,
    /// Presumably quasi-Monte Carlo using a Sobol sequence.
    QmcSobol,
    /// Presumably quasi-Monte Carlo using a Halton sequence.
    QmcHalton,
    /// Selective sampling; semantics are defined by the consumer.
    Selective,
}
/// Topology of a [`PolicyGraph`].
///
/// `PolicyGraph::default()` uses `FiniteHorizon`.
///
/// With the `serde` feature enabled this type also derives `Serialize` and
/// `Deserialize`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum PolicyGraphType {
    /// A graph with a finite horizon. The test fixtures build it as a linear
    /// chain of transitions (1→2→3…).
    FiniteHorizon,
    /// A graph whose transitions may loop back. Presumably this means an
    /// infinite-horizon policy; confirm.
    Cyclic,
}
/// One block (sub-period) of a [`Stage`].
///
/// With the `serde` feature enabled this type also derives `Serialize` and
/// `Deserialize`.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Block {
    /// Position of this block within its stage. The test fixture uses `0` for
    /// a single block.
    pub index: usize,
    /// Human-readable block label, e.g. `"SINGLE"` in the test fixture.
    pub name: String,
    /// Length of the block in hours. The one-block, 31-day test stage uses
    /// `744.0`.
    pub duration_hours: f64,
}
/// Flags selecting which quantities a [`Stage`] carries in its state
/// (see [`Stage::state_config`]).
///
/// How each flag is used is decided by the consumer of [`Stage`].
///
/// With the `serde` feature enabled this type also derives `Serialize` and
/// `Deserialize`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct StageStateConfig {
    /// Whether storage is part of the stage state.
    pub storage: bool,
    /// Whether lagged inflows are part of the stage state.
    pub inflow_lags: bool,
}
/// Risk measure applied at a [`Stage`] (see [`Stage::risk_config`]).
///
/// Only `PartialEq` (not `Eq`) is derived because the `CVaR` variant holds
/// `f64` values.
///
/// With the `serde` feature enabled this type also derives `Serialize` and
/// `Deserialize`.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum StageRiskConfig {
    /// Risk-neutral expectation.
    Expectation,
    /// A conditional value-at-risk measure parameterised by `alpha` and
    /// `lambda`.
    ///
    /// NOTE(review): the convention is not fixed by this file. `alpha` could be
    /// a tail fraction or a confidence level, and `lambda` presumably weights
    /// CVaR against the expectation. Confirm both against the consumer. No
    /// range validation happens here.
    CVaR {
        alpha: f64,
        lambda: f64,
    },
}
/// Per-stage scenario generation settings (see [`Stage::scenario_config`]).
///
/// With the `serde` feature enabled this type also derives `Serialize` and
/// `Deserialize`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ScenarioSourceConfig {
    /// Presumably the number of scenario realisations (branches) per stage.
    /// The test fixture uses `50`. No validation happens here.
    pub branching_factor: usize,
    /// How the noise for those realisations is generated.
    pub noise_method: NoiseMethod,
}
/// One stage of the study horizon, together with its per-stage configuration.
///
/// With the `serde` feature enabled this type also derives `Serialize` and
/// `Deserialize`.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Stage {
    /// Position of the stage, as a `usize` index (distinct from `id`).
    pub index: usize,
    /// Stage identifier. It has the same type (`i32`) as
    /// [`Transition::source_id`] / [`Transition::target_id`], which
    /// presumably refer to it.
    pub id: i32,
    /// First date of the stage.
    pub start_date: NaiveDate,
    /// End date of the stage. The test fixture pairs 2024-01-01..2024-02-01
    /// with a 744 h (31-day) block, which suggests the end is exclusive.
    /// Confirm.
    pub end_date: NaiveDate,
    /// Optional season this stage belongs to. Presumably this matches
    /// [`SeasonDefinition::id`], which has the same `usize` type.
    pub season_id: Option<usize>,
    /// The stage's blocks.
    pub blocks: Vec<Block>,
    /// How `blocks` are arranged.
    pub block_mode: BlockMode,
    /// Which quantities are carried in the stage state.
    pub state_config: StageStateConfig,
    /// Risk measure applied at this stage.
    pub risk_config: StageRiskConfig,
    /// Scenario generation settings for this stage.
    pub scenario_config: ScenarioSourceConfig,
}
/// One season entry of a [`SeasonMap`].
///
/// With the `serde` feature enabled this type also derives `Serialize` and
/// `Deserialize`.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SeasonDefinition {
    /// Season identifier. The monthly test fixture uses `0..=11`.
    pub id: usize,
    /// Human-readable label, e.g. `"January"`.
    pub label: String,
    /// Starting month. The monthly test fixture uses 1-based values (1..=12).
    /// No range validation happens here.
    pub month_start: u32,
    /// Optional starting day. This is `None` in the monthly test fixture.
    pub day_start: Option<u32>,
    /// Optional ending month. This is `None` in the monthly test fixture.
    pub month_end: Option<u32>,
    /// Optional ending day. This is `None` in the monthly test fixture.
    pub day_end: Option<u32>,
}
/// The set of seasons attached to a [`PolicyGraph`], plus the kind of cycle
/// they describe.
///
/// With the `serde` feature enabled this type also derives `Serialize` and
/// `Deserialize`.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SeasonMap {
    /// Kind of cycle these seasons describe.
    pub cycle_type: SeasonCycleType,
    /// The season entries. The monthly test fixture stores them in `id` order.
    pub seasons: Vec<SeasonDefinition>,
}
/// A directed edge of a [`PolicyGraph`].
///
/// With the `serde` feature enabled this type also derives `Serialize` and
/// `Deserialize`.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Transition {
    /// Origin node. Presumably a [`Stage::id`] (same `i32` type).
    pub source_id: i32,
    /// Destination node. Presumably a [`Stage::id`] (same `i32` type).
    pub target_id: i32,
    /// Probability of taking this edge. No validation (range or per-source
    /// sum) happens here.
    pub probability: f64,
    /// Optional per-edge discount rate. Presumably it replaces
    /// [`PolicyGraph::annual_discount_rate`] for this transition; confirm.
    pub annual_discount_rate_override: Option<f64>,
}
/// Graph of stage-to-stage transitions, with a graph-wide discount rate and an
/// optional season map.
///
/// A manual `Default` impl provides an empty finite-horizon graph.
///
/// With the `serde` feature enabled this type also derives `Serialize` and
/// `Deserialize`.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct PolicyGraph {
    /// Topology of the graph.
    pub graph_type: PolicyGraphType,
    /// Graph-wide annual discount rate. Tests use `0.06`, which presumably
    /// means a fraction (6%) rather than a percentage; confirm.
    pub annual_discount_rate: f64,
    /// The graph's edges.
    pub transitions: Vec<Transition>,
    /// Optional seasonal structure.
    pub season_map: Option<SeasonMap>,
}
impl Default for PolicyGraph {
fn default() -> Self {
Self {
graph_type: PolicyGraphType::FiniteHorizon,
annual_discount_rate: 0.0,
transitions: Vec::new(),
season_map: None,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `BlockMode` is `Copy`: the original binding stays usable after it is
    /// assigned to another binding.
    #[test]
    fn test_block_mode_copy() {
        let original = BlockMode::Parallel;
        let copied = original;
        assert_eq!(original, BlockMode::Parallel);
        assert_eq!(copied, BlockMode::Parallel);
        // Named `chronological` (not `chrono`) so it is not mistaken for the
        // `chrono` crate used elsewhere in this module.
        let chronological = BlockMode::Chronological;
        let copied_chronological = chronological;
        assert_eq!(chronological, BlockMode::Chronological);
        assert_eq!(copied_chronological, BlockMode::Chronological);
    }

    /// A one-block January stage spans 31 days, and its single block covers
    /// exactly that span in hours.
    #[test]
    fn test_stage_duration() {
        let stage = Stage {
            index: 0,
            id: 1,
            start_date: NaiveDate::from_ymd_opt(2024, 1, 1).unwrap(),
            end_date: NaiveDate::from_ymd_opt(2024, 2, 1).unwrap(),
            season_id: Some(0),
            blocks: vec![Block {
                index: 0,
                name: "SINGLE".to_string(),
                duration_hours: 744.0,
            }],
            block_mode: BlockMode::Parallel,
            state_config: StageStateConfig {
                storage: true,
                inflow_lags: false,
            },
            risk_config: StageRiskConfig::Expectation,
            scenario_config: ScenarioSourceConfig {
                branching_factor: 50,
                noise_method: NoiseMethod::Saa,
            },
        };
        let span = stage.end_date - stage.start_date;
        assert_eq!(span, chrono::TimeDelta::days(31));
        // Also check that the block hours add up to the date span
        // (31 days * 24 h = 744 h).
        assert_eq!(span.num_hours(), 744);
        let block_hours: f64 = stage
            .blocks
            .iter()
            .map(|block| block.duration_hours)
            .sum();
        assert!((block_hours - 744.0).abs() < f64::EPSILON);
    }

    /// `CVaR` keeps its parameters and differs from `Expectation`.
    #[test]
    fn test_stage_risk_config_cvar_fields() {
        let risk = StageRiskConfig::CVaR {
            alpha: 0.5,
            lambda: 0.25,
        };
        assert_ne!(risk, StageRiskConfig::Expectation);
        match risk {
            StageRiskConfig::CVaR { alpha, lambda } => {
                assert!((alpha - 0.5).abs() < f64::EPSILON);
                assert!((lambda - 0.25).abs() < f64::EPSILON);
            }
            StageRiskConfig::Expectation => panic!("expected CVaR variant"),
        }
    }

    /// `PolicyGraph::default()` yields an empty finite-horizon graph.
    #[test]
    fn test_policy_graph_default() {
        let graph = PolicyGraph::default();
        assert_eq!(graph.graph_type, PolicyGraphType::FiniteHorizon);
        assert!(graph.annual_discount_rate.abs() < f64::EPSILON);
        assert!(graph.transitions.is_empty());
        assert!(graph.season_map.is_none());
    }

    /// A three-edge chain keeps its graph-wide rate and per-edge overrides.
    #[test]
    fn test_policy_graph_construction() {
        let transitions = vec![
            Transition {
                source_id: 1,
                target_id: 2,
                probability: 1.0,
                annual_discount_rate_override: None,
            },
            Transition {
                source_id: 2,
                target_id: 3,
                probability: 1.0,
                annual_discount_rate_override: Some(0.08),
            },
            Transition {
                source_id: 3,
                target_id: 4,
                probability: 1.0,
                annual_discount_rate_override: None,
            },
        ];
        let graph = PolicyGraph {
            graph_type: PolicyGraphType::FiniteHorizon,
            annual_discount_rate: 0.06,
            transitions,
            season_map: None,
        };
        assert_eq!(graph.graph_type, PolicyGraphType::FiniteHorizon);
        assert!((graph.annual_discount_rate - 0.06).abs() < f64::EPSILON);
        assert_eq!(graph.transitions.len(), 3);
        assert_eq!(
            graph.transitions[1].annual_discount_rate_override,
            Some(0.08)
        );
        assert!(graph.season_map.is_none());
    }

    /// A 12-season monthly map preserves labels and 1-based start months.
    #[test]
    fn test_season_map_construction() {
        let months = [
            "January",
            "February",
            "March",
            "April",
            "May",
            "June",
            "July",
            "August",
            "September",
            "October",
            "November",
            "December",
        ];
        let seasons: Vec<SeasonDefinition> = months
            .iter()
            .enumerate()
            .map(|(i, &label)| SeasonDefinition {
                id: i,
                label: label.to_string(),
                month_start: u32::try_from(i + 1).unwrap(),
                day_start: None,
                month_end: None,
                day_end: None,
            })
            .collect();
        let season_map = SeasonMap {
            cycle_type: SeasonCycleType::Monthly,
            seasons,
        };
        assert_eq!(season_map.cycle_type, SeasonCycleType::Monthly);
        assert_eq!(season_map.seasons.len(), 12);
        assert_eq!(season_map.seasons[0].label, "January");
        assert_eq!(season_map.seasons[11].label, "December");
        assert_eq!(season_map.seasons[0].month_start, 1);
        assert_eq!(season_map.seasons[11].month_start, 12);
    }

    /// JSON round-trip preserves the graph exactly. Unit enum variants
    /// serialize by name.
    #[cfg(feature = "serde")]
    #[test]
    fn test_policy_graph_serde_roundtrip() {
        let graph = PolicyGraph {
            graph_type: PolicyGraphType::FiniteHorizon,
            annual_discount_rate: 0.06,
            transitions: vec![
                Transition {
                    source_id: 1,
                    target_id: 2,
                    probability: 1.0,
                    annual_discount_rate_override: None,
                },
                Transition {
                    source_id: 2,
                    target_id: 3,
                    probability: 1.0,
                    annual_discount_rate_override: None,
                },
            ],
            season_map: None,
        };
        let json = serde_json::to_string(&graph).unwrap();
        assert!(
            json.contains("\"graph_type\":\"FiniteHorizon\""),
            "JSON did not contain expected graph_type: {json}"
        );
        assert!(
            json.contains("\"annual_discount_rate\":0.06"),
            "JSON did not contain expected annual_discount_rate: {json}"
        );
        let deserialized: PolicyGraph = serde_json::from_str(&json).unwrap();
        assert_eq!(graph, deserialized);
    }
}