use cobre_core::TrainingEvent;
/// Parameters for a training run.
///
/// Every field is public, so the struct can be built with a plain struct
/// literal. `Debug` is derived, so formatted output lists every field,
/// including the optional event sender.
#[derive(Debug)]
pub struct TrainingConfig {
    /// Number of forward passes.
    ///
    /// NOTE(review): presumably the count per training iteration — confirm
    /// against the training loop that consumes this config.
    pub forward_passes: u32,
    /// Iteration limit for the run.
    ///
    /// NOTE(review): assumed to be a hard upper bound; confirm whether other
    /// stopping criteria can end training earlier.
    pub max_iterations: u64,
    /// Optional checkpointing period.
    ///
    /// NOTE(review): `None` presumably disables checkpointing, and the unit is
    /// assumed to be iterations — TODO confirm with the consumer.
    pub checkpoint_interval: Option<u64>,
    /// Number of warm-start cuts.
    ///
    /// NOTE(review): presumably cuts made available before the first
    /// iteration, with `0` meaning "no warm start" — confirm.
    pub warm_start_cuts: u32,
    /// Optional sending half of an `mpsc` channel that carries
    /// [`TrainingEvent`] values.
    ///
    /// NOTE(review): presumably the trainer reports progress through this and
    /// sends nothing when it is `None` — confirm against the consumer.
    pub event_sender: Option<std::sync::mpsc::Sender<TrainingEvent>>,
}
#[cfg(test)]
mod tests {
    use super::TrainingConfig;
    use cobre_core::TrainingEvent;

    /// Returns a config with neutral values: no checkpointing, no warm-start
    /// cuts and no event sender.
    ///
    /// Tests override only the fields they exercise via struct-update syntax,
    /// so adding a field to `TrainingConfig` only requires touching this
    /// literal instead of every test.
    fn base_config() -> TrainingConfig {
        TrainingConfig {
            forward_passes: 1,
            max_iterations: 1,
            checkpoint_interval: None,
            warm_start_cuts: 0,
            event_sender: None,
        }
    }

    #[test]
    fn field_access_forward_passes_and_max_iterations() {
        let config = TrainingConfig {
            forward_passes: 10,
            max_iterations: 100,
            ..base_config()
        };
        assert_eq!(config.forward_passes, 10);
        assert_eq!(config.max_iterations, 100);
    }

    #[test]
    fn checkpoint_interval_none_and_some() {
        let config_none = base_config();
        assert!(config_none.checkpoint_interval.is_none());

        let config_some = TrainingConfig {
            checkpoint_interval: Some(10),
            ..base_config()
        };
        assert_eq!(config_some.checkpoint_interval, Some(10));
    }

    #[test]
    fn warm_start_cuts_field_accessible() {
        let config = TrainingConfig {
            warm_start_cuts: 500,
            ..base_config()
        };
        assert_eq!(config.warm_start_cuts, 500);
    }

    #[test]
    fn event_sender_none() {
        let config = base_config();
        assert!(config.event_sender.is_none());
    }

    #[test]
    fn event_sender_some_can_send_training_event() {
        let (tx, rx) = std::sync::mpsc::channel::<TrainingEvent>();
        let config = TrainingConfig {
            forward_passes: 4,
            max_iterations: 200,
            checkpoint_interval: Some(50),
            warm_start_cuts: 100,
            event_sender: Some(tx),
        };

        // Fail loudly if the sender is missing rather than silently skipping
        // the send (which would only surface later as a RecvError).
        let sender = config
            .event_sender
            .as_ref()
            .expect("event_sender was constructed as Some");
        sender
            .send(TrainingEvent::TrainingFinished {
                reason: "test".to_string(),
                iterations: 1,
                final_lb: 0.0,
                final_ub: 1.0,
                total_time_ms: 100,
                total_cuts: 4,
            })
            .unwrap();

        // Verify the payload survives the channel, not just the variant tag.
        let received = rx.recv().unwrap();
        assert!(
            matches!(
                received,
                TrainingEvent::TrainingFinished {
                    ref reason,
                    iterations: 1,
                    total_time_ms: 100,
                    total_cuts: 4,
                    ..
                } if reason == "test"
            ),
            "received event does not carry the payload that was sent"
        );
    }

    #[test]
    fn debug_output_non_empty() {
        let config = TrainingConfig {
            forward_passes: 8,
            max_iterations: 500,
            checkpoint_interval: Some(100),
            ..base_config()
        };
        let debug = format!("{config:?}");
        assert!(!debug.is_empty());
        // The derived Debug must cover every field, not just the first two.
        for field in [
            "forward_passes",
            "max_iterations",
            "checkpoint_interval",
            "warm_start_cuts",
            "event_sender",
        ] {
            assert!(
                debug.contains(field),
                "debug must contain field name {field}: {debug}"
            );
        }
        assert!(
            debug.contains("Some(100)"),
            "debug must contain the checkpoint value: {debug}"
        );
    }
}