use cobre_core::ScalarParameter;
use cobre_io::config::StoppingRuleConfig;
use crate::{
InflowNonNegativityMethod, SddpError,
cut_selection::{CutSelectionStrategy, parse_cut_selection_config},
stopping_rule::{StoppingMode, StoppingRule, StoppingRuleSet},
};
/// Number of forward passes used when `training.forward_passes` is not set.
pub const DEFAULT_FORWARD_PASSES: u32 = 1;
/// Iteration limit of the fallback stopping rule that `StudyParams::from_config`
/// installs when `training.stopping_rules` is missing or empty.
pub const DEFAULT_MAX_ITERATIONS: u64 = 100;
/// Seed used when `training.tree_seed` is not set.
pub const DEFAULT_SEED: u64 = 42;
/// Study-level parameters resolved from a `cobre_io::Config`, with defaults
/// applied and stopping rules converted to their runtime representation.
///
/// Built by `StudyParams::from_config` and converted into a
/// `ConstructionConfig` via `StudyParams::into_construction_config`.
#[derive(Debug, Clone)]
pub struct StudyParams {
/// Random seed: the absolute value of `training.tree_seed`, or `DEFAULT_SEED`.
pub seed: u64,
/// `training.forward_passes`, or `DEFAULT_FORWARD_PASSES`.
pub forward_passes: u32,
/// Stopping rules and their combination mode (`training.stopping_rules` /
/// `training.stopping_mode`).
pub stopping_rule_set: StoppingRuleSet,
/// `simulation.num_scenarios` when simulation is enabled, otherwise `0`.
pub n_scenarios: u32,
/// `simulation.io_channel_capacity` converted to `usize`.
pub io_channel_capacity: usize,
/// Copied from `policy.path`.
pub policy_path: String,
/// Derived from `modeling.inflow_non_negativity`.
pub inflow_method: InflowNonNegativityMethod,
/// Parsed from `training.cut_selection`. NOTE(review): `None` presumably means
/// cut selection is disabled — confirm against `parse_cut_selection_config`.
pub cut_selection: Option<CutSelectionStrategy>,
/// `training.cut_selection.row_activity_tolerance`, defaulting to `0.0`.
pub cut_activity_tolerance: f64,
/// `training.cut_selection.max_active_per_stage`, if set.
pub budget: Option<u32>,
}
impl StudyParams {
pub fn from_config(config: &cobre_io::Config) -> Result<Self, SddpError> {
let seed = config
.training
.tree_seed
.map_or(DEFAULT_SEED, i64::unsigned_abs);
let forward_passes = config
.training
.forward_passes
.unwrap_or(DEFAULT_FORWARD_PASSES);
let rule_configs = match &config.training.stopping_rules {
Some(rules) if !rules.is_empty() => rules.clone(),
_ => vec![StoppingRuleConfig::IterationLimit {
limit: u32::try_from(DEFAULT_MAX_ITERATIONS).unwrap_or(u32::MAX),
}],
};
let stopping_rules: Vec<StoppingRule> = rule_configs
.into_iter()
.map(|c| match c {
StoppingRuleConfig::IterationLimit { limit } => Ok(StoppingRule::IterationLimit {
limit: u64::from(limit),
}),
StoppingRuleConfig::TimeLimit { seconds } => {
Ok(StoppingRule::TimeLimit { seconds })
}
StoppingRuleConfig::BoundStalling {
iterations,
tolerance,
} => Ok(StoppingRule::BoundStalling {
iterations: u64::from(iterations),
tolerance,
}),
StoppingRuleConfig::Simulation { .. } => Err(SddpError::Validation(
"simulation-based stopping rule is not yet implemented; \
use iteration_limit, time_limit, or bound_stalling"
.to_string(),
)),
})
.collect::<Result<Vec<_>, _>>()?;
let stopping_mode = if config.training.stopping_mode.eq_ignore_ascii_case("all") {
StoppingMode::All
} else {
StoppingMode::Any
};
let stopping_rule_set = StoppingRuleSet {
rules: stopping_rules,
mode: stopping_mode,
};
let n_scenarios = if config.simulation.enabled {
config.simulation.num_scenarios
} else {
0
};
let io_channel_capacity =
usize::try_from(config.simulation.io_channel_capacity).unwrap_or(64);
let policy_path = config.policy.path.clone();
let inflow_method = InflowNonNegativityMethod::from(&config.modeling.inflow_non_negativity);
let cut_selection = parse_cut_selection_config(&config.training.cut_selection)
.map_err(|msg| SddpError::Validation(format!("cut_selection config error: {msg}")))?;
let cut_activity_tolerance = config
.training
.cut_selection
.row_activity_tolerance
.unwrap_or(0.0);
let budget = config.training.cut_selection.max_active_per_stage;
if let Some(b) = budget {
if u64::from(b) < u64::from(forward_passes) {
tracing::warn!(
"max_active_per_stage ({b}) is less than forward_passes \
({forward_passes}); budget enforcement will evict all \
non-current-iteration cuts every iteration"
);
}
}
Ok(Self {
seed,
forward_passes,
stopping_rule_set,
n_scenarios,
io_channel_capacity,
policy_path,
inflow_method,
cut_selection,
cut_activity_tolerance,
budget,
})
}
#[must_use]
pub fn into_construction_config(self) -> ConstructionConfig {
ConstructionConfig {
seed: self.seed,
forward_passes: self.forward_passes,
stopping_rule_set: self.stopping_rule_set,
n_scenarios: self.n_scenarios,
io_channel_capacity: self.io_channel_capacity,
policy_path: self.policy_path,
inflow_method: self.inflow_method,
cut_selection: self.cut_selection,
cut_activity_tolerance: self.cut_activity_tolerance,
budget: self.budget,
export_states: false,
scalar_parameters: Vec::new(),
}
}
}
/// Flat configuration produced by `StudyParams::into_construction_config`.
///
/// Carries every `StudyParams` field unchanged, plus `export_states` and
/// `scalar_parameters`, which that conversion does not derive from the study
/// config.
#[derive(Debug, Clone)]
pub struct ConstructionConfig {
/// Random seed (see `StudyParams::seed`).
pub seed: u64,
/// Number of forward passes (see `StudyParams::forward_passes`).
pub forward_passes: u32,
/// Stopping rules and combination mode.
pub stopping_rule_set: StoppingRuleSet,
/// Simulation scenario count; `0` when simulation is disabled.
pub n_scenarios: u32,
/// Configured I/O channel capacity.
pub io_channel_capacity: usize,
/// Policy path copied from `policy.path`.
pub policy_path: String,
/// Inflow non-negativity handling method.
pub inflow_method: InflowNonNegativityMethod,
/// Parsed cut-selection strategy, if any.
pub cut_selection: Option<CutSelectionStrategy>,
/// Row-activity tolerance used with cut selection.
pub cut_activity_tolerance: f64,
/// Maximum active cuts per stage, if configured.
pub budget: Option<u32>,
/// Set to `false` by `into_construction_config`. NOTE(review): presumably
/// enables exporting visited states — confirm with the consumer.
pub export_states: bool,
/// Left empty by `into_construction_config`. NOTE(review): presumably filled
/// in by the caller afterwards — confirm.
pub scalar_parameters: Vec<ScalarParameter>,
}
// Unit tests for `StudyParams::from_config`. They check that the unimplemented
// simulation stopping rule is rejected, and that a warning is emitted when the
// per-stage cut budget is below the number of forward passes.
#[cfg(test)]
mod tests {
#![allow(clippy::unwrap_used, clippy::expect_used)]
// Panicking on `unwrap`/`expect` is the intended failure mode in these tests.
use std::sync::{Arc, Mutex};
use cobre_io::config::{
Config, EstimationConfig, ExportsConfig, InflowNonNegativityConfig,
InflowNonNegativityMethod as CfgInflowMethod, ModelingConfig, PolicyConfig,
RowSelectionConfig, SimulationConfig as IoSimulationConfig, StoppingRuleConfig,
TrainingConfig, TrainingSolverConfig, UpperBoundEvaluationConfig,
};
use tracing::{Event, Level, Metadata, Subscriber, span};
use super::StudyParams;
/// Minimal `tracing` subscriber that records the `message` field of every
/// WARN-level event into a shared buffer.
struct WarnRecorder {
messages: Arc<Mutex<Vec<String>>>,
}
impl WarnRecorder {
/// Returns the recorder together with a second handle to the same buffer.
/// The handle stays readable after the recorder is moved into
/// `tracing::subscriber::with_default`.
fn new() -> (Self, Arc<Mutex<Vec<String>>>) {
let messages = Arc::new(Mutex::new(Vec::new()));
(
Self {
messages: Arc::clone(&messages),
},
messages,
)
}
}
impl Subscriber for WarnRecorder {
// `tracing` orders more verbose levels as greater, so this admits WARN and ERROR.
fn enabled(&self, metadata: &Metadata<'_>) -> bool {
*metadata.level() <= Level::WARN
}
// Spans are not tracked; every span receives the same dummy id.
fn new_span(&self, _attrs: &span::Attributes<'_>) -> span::Id {
span::Id::from_u64(1)
}
fn record(&self, _span: &span::Id, _values: &span::Record<'_>) {}
fn record_follows_from(&self, _span: &span::Id, _follows: &span::Id) {}
// Only WARN events are stored, even though `enabled` also admits ERROR.
fn event(&self, event: &Event<'_>) {
if *event.metadata().level() == Level::WARN {
// Extracts the `message` field, whether it is recorded through the
// `Debug` path or the `str` path.
struct MessageVisitor(String);
impl tracing::field::Visit for MessageVisitor {
fn record_debug(
&mut self,
field: &tracing::field::Field,
value: &dyn std::fmt::Debug,
) {
if field.name() == "message" {
self.0 = format!("{value:?}");
}
}
fn record_str(&mut self, field: &tracing::field::Field, value: &str) {
if field.name() == "message" {
self.0 = value.to_string();
}
}
}
let mut visitor = MessageVisitor(String::new());
event.record(&mut visitor);
self.messages.lock().unwrap().push(visitor.0);
}
}
fn enter(&self, _span: &span::Id) {}
fn exit(&self, _span: &span::Id) {}
}
/// Valid config whose cut budget (`max_active_per_stage = 1`) is below
/// `forward_passes = 2`, which should trigger the budget warning.
fn config_with_budget_below_forward_passes() -> Config {
Config {
schema: None,
modeling: ModelingConfig {
inflow_non_negativity: InflowNonNegativityConfig {
method: CfgInflowMethod::Penalty,
},
},
training: TrainingConfig {
enabled: true,
tree_seed: Some(42),
forward_passes: Some(2),
stopping_rules: Some(vec![StoppingRuleConfig::IterationLimit { limit: 1 }]),
stopping_mode: "any".to_string(),
cut_selection: RowSelectionConfig {
max_active_per_stage: Some(1),
..RowSelectionConfig::default()
},
solver: TrainingSolverConfig::default(),
scenario_source: None,
},
upper_bound_evaluation: UpperBoundEvaluationConfig::default(),
policy: PolicyConfig::default(),
simulation: IoSimulationConfig::default(),
exports: ExportsConfig::default(),
estimation: EstimationConfig::default(),
}
}
/// Config whose only stopping rule is the not-yet-implemented simulation rule.
fn config_with_simulation_stopping_rule() -> Config {
Config {
schema: None,
modeling: ModelingConfig {
inflow_non_negativity: InflowNonNegativityConfig {
method: CfgInflowMethod::Penalty,
},
},
training: TrainingConfig {
enabled: true,
tree_seed: Some(42),
forward_passes: Some(1),
stopping_rules: Some(vec![StoppingRuleConfig::Simulation {
replications: 100,
period: 12,
bound_window: 10,
distance_tol: 0.05,
bound_tol: 0.01,
}]),
stopping_mode: "any".to_string(),
cut_selection: RowSelectionConfig::default(),
solver: TrainingSolverConfig::default(),
scenario_source: None,
},
upper_bound_evaluation: UpperBoundEvaluationConfig::default(),
policy: PolicyConfig::default(),
simulation: IoSimulationConfig::default(),
exports: ExportsConfig::default(),
estimation: EstimationConfig::default(),
}
}
/// The simulation stopping rule must fail with `SddpError::Validation`, and
/// the message must explain that it is not yet implemented.
#[test]
fn from_config_rejects_simulation_stopping_rule() {
use crate::SddpError;
let err = StudyParams::from_config(&config_with_simulation_stopping_rule())
.expect_err("Simulation stopping rule must be rejected");
assert!(
matches!(err, SddpError::Validation(_)),
"expected SddpError::Validation, got: {err:?}"
);
let msg = err.to_string();
assert!(
msg.contains("simulation-based stopping rule"),
"error message must mention 'simulation-based stopping rule'; got: {msg}"
);
assert!(
msg.contains("not yet implemented"),
"error message must say 'not yet implemented'; got: {msg}"
);
}
/// A budget below `forward_passes` must log a WARN mentioning
/// `max_active_per_stage` without failing construction. Other warnings are
/// ignored through the substring filter.
#[test]
fn study_params_warns_when_budget_below_forward_passes() {
let (subscriber, messages) = WarnRecorder::new();
// Installs the recorder for this thread only while the closure runs.
tracing::subscriber::with_default(subscriber, || {
let _params = StudyParams::from_config(&config_with_budget_below_forward_passes())
.expect("config is valid; warning must not prevent construction");
});
let recorded = messages.lock().unwrap();
let relevant: Vec<&str> = recorded
.iter()
.map(std::string::String::as_str)
.filter(|msg| msg.contains("max_active_per_stage"))
.collect();
assert!(
!relevant.is_empty(),
"expected at least one WARN event containing 'max_active_per_stage', got: {recorded:?}"
);
}
}