use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use cobre_core::TrainingEvent;
use crate::cut_selection::CutSelectionStrategy;
use crate::risk_measure::RiskMeasure;
use crate::stopping_rule::{StoppingMode, StoppingRule, StoppingRuleSet};
/// Parameters describing a training loop: a seed, pass and iteration counts,
/// a block cap, and the stopping rules.
///
/// `seed`, `forward_passes`, and `max_iterations` are public; the remaining
/// fields are visible only inside this crate.
///
/// NOTE(review): this looks like the crate-internal counterpart of
/// [`LoopConfig`] (the two share most field names) — confirm against the code
/// that constructs it.
#[derive(Debug)]
pub struct LoopParams {
    /// Seed value. Presumably drives random sampling — TODO confirm at the use site.
    pub seed: u64,
    /// Forward-pass count. Presumably per iteration — TODO confirm.
    pub forward_passes: u32,
    /// Iteration cap for the loop.
    pub max_iterations: u64,
    /// Iteration index the loop starts from (crate-visible only).
    pub(crate) start_iteration: u64,
    /// Block-count cap (crate-visible only). Exact meaning of "block" is not
    /// shown here — TODO confirm.
    pub(crate) max_blocks: usize,
    /// Stopping rules together with how they are combined (crate-visible only).
    pub(crate) stopping_rules: StoppingRuleSet,
}
/// User-facing settings for the training loop.
///
/// All fields are public. A [`Default`] implementation exists, so callers can
/// override a subset of fields with struct-update syntax
/// (`LoopConfig { forward_passes: 10, ..LoopConfig::default() }`).
#[derive(Debug)]
pub struct LoopConfig {
    /// Forward-pass count. Presumably per iteration — TODO confirm.
    pub forward_passes: u32,
    /// Iteration cap for the loop.
    pub max_iterations: u64,
    /// Iteration index the loop starts from.
    pub start_iteration: u64,
    /// Thread count. Presumably for the forward pass, judging by the name —
    /// TODO confirm.
    pub n_fwd_threads: usize,
    /// Block-count cap. Exact meaning of "block" is not shown here — TODO confirm.
    pub max_blocks: usize,
    /// Stopping rules together with how they are combined.
    pub stopping_rules: StoppingRuleSet,
}
impl Default for LoopConfig {
fn default() -> Self {
Self {
forward_passes: 1,
max_iterations: 1,
start_iteration: 0,
n_fwd_threads: 1,
max_blocks: 1,
stopping_rules: StoppingRuleSet {
rules: vec![StoppingRule::IterationLimit { limit: 1 }],
mode: StoppingMode::Any,
},
}
}
}
/// Settings related to cut handling: selection strategy, budget, activity
/// tolerance, warm-start cut count, and risk measures.
///
/// A [`Default`] implementation exists, so callers can override a subset of
/// fields with struct-update syntax.
#[derive(Debug)]
pub struct CutManagementConfig {
    /// Optional cut-selection strategy; `None` means no strategy is configured.
    pub cut_selection: Option<CutSelectionStrategy>,
    /// Optional budget. Presumably a cap on retained cuts — TODO confirm.
    pub budget: Option<u32>,
    /// Tolerance used for cut activity. Presumably the threshold below which a
    /// cut counts as inactive — TODO confirm.
    pub cut_activity_tolerance: f64,
    /// Warm-start cut count.
    pub warm_start_cuts: u32,
    /// Risk measures. Presumably one entry per stage — TODO confirm.
    pub risk_measures: Vec<RiskMeasure>,
}
impl Default for CutManagementConfig {
fn default() -> Self {
Self {
cut_selection: None,
budget: None,
cut_activity_tolerance: 1e-6,
warm_start_cuts: 0,
risk_measures: vec![RiskMeasure::Expectation],
}
}
}
/// Settings for events, checkpointing, shutdown signalling, and state export.
///
/// The derived [`Default`] leaves every optional field as `None` and
/// `export_states` as `false`.
#[derive(Debug, Default)]
pub struct EventConfig {
    /// Optional channel sender for [`TrainingEvent`] values; `None` means no
    /// sender is attached.
    pub event_sender: Option<std::sync::mpsc::Sender<TrainingEvent>>,
    /// Optional checkpoint interval. Presumably counted in iterations — TODO confirm.
    pub checkpoint_interval: Option<u64>,
    /// Optional shared boolean flag. Presumably polled to request an early
    /// stop — TODO confirm which side sets it and when it is read.
    pub shutdown_flag: Option<Arc<AtomicBool>>,
    /// Whether states are exported.
    pub export_states: bool,
}
/// Crate-internal event parameters; currently carries only the state-export flag.
#[derive(Debug)]
pub(crate) struct EventParams {
    /// Whether states are exported.
    pub(crate) export_states: bool,
}
/// Top-level training configuration, grouping the loop, cut-management, and
/// event settings into one value.
#[derive(Debug)]
pub struct TrainingConfig {
    /// Training-loop settings.
    pub loop_config: LoopConfig,
    /// Cut-management settings.
    pub cut_management: CutManagementConfig,
    /// Event, checkpoint, and shutdown settings.
    pub events: EventConfig,
}
#[cfg(test)]
mod tests {
    use super::{CutManagementConfig, EventConfig, LoopConfig, TrainingConfig};
    use cobre_core::TrainingEvent;

    /// Default loop settings with only the pass and iteration counts overridden.
    fn loop_with(forward_passes: u32, max_iterations: u64) -> LoopConfig {
        LoopConfig {
            forward_passes,
            max_iterations,
            ..LoopConfig::default()
        }
    }

    /// Default event settings with only the checkpoint interval set.
    fn events_with_checkpoint(interval: u64) -> EventConfig {
        EventConfig {
            checkpoint_interval: Some(interval),
            ..EventConfig::default()
        }
    }

    /// Overridden loop fields are readable through the top-level config.
    #[test]
    fn field_access_forward_passes_and_max_iterations() {
        let config = TrainingConfig {
            loop_config: loop_with(10, 100),
            cut_management: CutManagementConfig::default(),
            events: EventConfig::default(),
        };

        let loop_settings = &config.loop_config;
        assert_eq!(loop_settings.forward_passes, 10);
        assert_eq!(loop_settings.max_iterations, 100);
    }

    /// The checkpoint interval is `None` by default and keeps an explicit value.
    #[test]
    fn checkpoint_interval_none_and_some() {
        let config_none = TrainingConfig {
            loop_config: loop_with(5, 50),
            cut_management: CutManagementConfig::default(),
            events: EventConfig::default(),
        };
        assert!(config_none.events.checkpoint_interval.is_none());

        let config_some = TrainingConfig {
            loop_config: loop_with(5, 50),
            cut_management: CutManagementConfig::default(),
            events: events_with_checkpoint(10),
        };
        assert_eq!(config_some.events.checkpoint_interval, Some(10));
    }

    /// An overridden warm-start cut count is readable through the config.
    #[test]
    fn warm_start_cuts_field_accessible() {
        let cut_management = CutManagementConfig {
            warm_start_cuts: 500,
            ..CutManagementConfig::default()
        };
        let config = TrainingConfig {
            loop_config: loop_with(1, 10),
            cut_management,
            events: EventConfig::default(),
        };

        assert_eq!(config.cut_management.warm_start_cuts, 500);
    }

    /// A fully defaulted config has no event sender attached.
    #[test]
    fn event_sender_none() {
        let config = TrainingConfig {
            loop_config: LoopConfig::default(),
            cut_management: CutManagementConfig::default(),
            events: EventConfig::default(),
        };

        assert!(config.events.event_sender.is_none());
    }

    /// An attached sender delivers a `TrainingEvent` to the paired receiver.
    #[test]
    fn event_sender_some_can_send_training_event() {
        let (tx, rx) = std::sync::mpsc::channel::<TrainingEvent>();

        let cut_management = CutManagementConfig {
            warm_start_cuts: 100,
            cut_activity_tolerance: 1e-6,
            ..CutManagementConfig::default()
        };
        let events = EventConfig {
            event_sender: Some(tx),
            checkpoint_interval: Some(50),
            ..EventConfig::default()
        };
        let config = TrainingConfig {
            loop_config: loop_with(4, 200),
            cut_management,
            events,
        };

        assert!(config.events.event_sender.is_some());
        if let Some(sender) = &config.events.event_sender {
            let finished = TrainingEvent::TrainingFinished {
                reason: "test".to_string(),
                iterations: 1,
                final_lb: 0.0,
                final_ub: 1.0,
                total_time_ms: 100,
                total_rows: 4,
            };
            sender.send(finished).unwrap();
        }

        let received = rx.recv().unwrap();
        assert!(matches!(received, TrainingEvent::TrainingFinished { .. }));
    }

    /// The `Debug` rendering is non-empty and mentions the loop field names.
    #[test]
    fn debug_output_non_empty() {
        let config = TrainingConfig {
            loop_config: loop_with(8, 500),
            cut_management: CutManagementConfig::default(),
            events: events_with_checkpoint(100),
        };

        let debug = format!("{config:?}");
        assert!(!debug.is_empty());
        for field_name in ["forward_passes", "max_iterations"] {
            assert!(
                debug.contains(field_name),
                "debug must contain field name: {debug}"
            );
        }
    }
}