use serde::Deserialize;
use serde::Serialize;
use crate::cc::config::StateUpdateConfig;
use crate::cc::transition::StateTransition;
use crate::cc::transition::DEFAULT_STATE_TRANSITIONS;
use crate::metadata::SerializedType;
/// Settings describing where and how an engine is persisted.
///
/// Unknown keys are rejected when deserializing (`deny_unknown_fields`).
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct SaveEngineConfig {
    /// Destination path for the saved engine.
    /// NOTE(review): whether this is a file or a directory is decided by the saver — confirm.
    pub path: std::path::PathBuf,
    /// Serialization format to use when saving (see `SerializedType`).
    pub ser_type: SerializedType,
}
/// Configuration for running engine updates.
///
/// Unknown keys are rejected when deserializing (`deny_unknown_fields`).
/// `save_config` and `checkpoint` may be omitted from serialized input and
/// then default to `None`. `n_iters` and `transitions` have no serde default
/// and must be present.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct EngineUpdateConfig {
    /// Number of update iterations; forwarded as `StateUpdateConfig::n_iters`.
    pub n_iters: usize,
    /// Optional save settings; `None` when absent from serialized input.
    #[serde(default)]
    pub save_config: Option<SaveEngineConfig>,
    /// State transitions to apply; forwarded to `StateUpdateConfig::transitions`.
    pub transitions: Vec<StateTransition>,
    /// Optional checkpoint setting; `None` when absent from serialized input.
    /// NOTE(review): presumably an interval in iterations — confirm with the engine.
    #[serde(default)]
    pub checkpoint: Option<usize>,
}
impl EngineUpdateConfig {
    /// Creates a config with `n_iters = 1`, no transitions, no save
    /// settings, and no checkpoint.
    #[must_use]
    pub fn new() -> Self {
        Self {
            n_iters: 1,
            transitions: Vec::new(),
            save_config: None,
            checkpoint: None,
        }
    }

    /// Creates a config like [`EngineUpdateConfig::new`], but with the
    /// transition list set to `DEFAULT_STATE_TRANSITIONS`.
    #[must_use]
    pub fn with_default_transitions() -> Self {
        Self::new().default_transitions()
    }

    /// Replaces (does not extend) the current transitions with
    /// `DEFAULT_STATE_TRANSITIONS`.
    #[must_use = "builder methods return the updated config; dropping it discards the change"]
    pub fn default_transitions(mut self) -> Self {
        self.transitions = DEFAULT_STATE_TRANSITIONS.into();
        self
    }

    /// Appends `transitions` to the existing transition list.
    ///
    /// Note: this extends rather than replaces. Earlier transitions,
    /// including defaults, are kept.
    #[must_use = "builder methods return the updated config; dropping it discards the change"]
    pub fn transitions(mut self, transitions: Vec<StateTransition>) -> Self {
        self.transitions.extend(transitions);
        self
    }

    /// Appends a single `transition` to the end of the transition list.
    #[must_use = "builder methods return the updated config; dropping it discards the change"]
    pub fn transition(mut self, transition: StateTransition) -> Self {
        self.transitions.push(transition);
        self
    }

    /// Builds a `StateUpdateConfig` carrying this config's `n_iters` and a
    /// clone of its transitions.
    ///
    /// `save_config` and `checkpoint` are not part of the returned value.
    #[must_use]
    pub fn state_config(&self) -> StateUpdateConfig {
        StateUpdateConfig {
            n_iters: self.n_iters,
            // Cloned because `self` is only borrowed here.
            transitions: self.transitions.clone(),
        }
    }

    /// Sets the number of update iterations.
    #[must_use = "builder methods return the updated config; dropping it discards the change"]
    pub fn n_iters(mut self, n_iters: usize) -> Self {
        self.n_iters = n_iters;
        self
    }

    /// Sets (or clears, with `None`) the checkpoint setting.
    #[must_use = "builder methods return the updated config; dropping it discards the change"]
    pub fn checkpoint(mut self, checkpoint: Option<usize>) -> Self {
        self.checkpoint = checkpoint;
        self
    }
}
impl Default for EngineUpdateConfig {
fn default() -> Self {
Self::new()
}
}