use std::path::Path;
use std::time::Duration;
use serde::{Deserialize, Serialize};
use crate::director::DirectorConfig;
use crate::error::ConfigError;
use crate::phase::PhaseConfig;
use crate::termination::TerminationConfig;
/// Top-level solver configuration.
///
/// Every key is optional when deserializing: a missing key takes the value
/// it has in [`SolverConfig::default`] (the container-level `serde(default)`
/// below). Because `Default` is derived, that is each field type's own
/// `Default` value.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
#[serde(default, rename_all = "snake_case")]
pub struct SolverConfig {
    /// Environment mode; falls back to [`EnvironmentMode`]'s default variant.
    pub environment_mode: EnvironmentMode,
    /// Fixed random seed, if one is configured (`None` when the key is absent).
    pub random_seed: Option<u64>,
    /// Move-thread setting; falls back to [`MoveThreadCount`]'s default variant.
    pub move_thread_count: MoveThreadCount,
    /// Termination settings, if any were configured.
    pub termination: Option<TerminationConfig>,
    /// Score director settings, if any were configured.
    pub score_director: Option<DirectorConfig>,
    /// Phase configurations, kept in the order they appear in the input.
    pub phases: Vec<PhaseConfig>,
}
impl SolverConfig {
    /// Creates a configuration with every field at its default value.
    ///
    /// Equivalent to [`SolverConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Loads a configuration file, choosing the parser from the file extension.
    ///
    /// Paths ending in `.yaml` or `.yml` (ASCII case-insensitive) are parsed as
    /// YAML via [`SolverConfig::from_yaml_file`]. Every other path, including one
    /// without an extension, is parsed as TOML via [`SolverConfig::from_toml_file`].
    ///
    /// # Errors
    ///
    /// Returns a [`ConfigError`] if the file cannot be read or if its contents
    /// fail to parse in the selected format.
    pub fn load(path: impl AsRef<Path>) -> Result<Self, ConfigError> {
        let path = path.as_ref();
        // Dispatch on the extension so a YAML file handed to the generic entry
        // point is not rejected with a misleading TOML parse error.
        let is_yaml = matches!(
            path.extension().and_then(|ext| ext.to_str()),
            Some(ext) if ext.eq_ignore_ascii_case("yaml") || ext.eq_ignore_ascii_case("yml")
        );
        if is_yaml {
            Self::from_yaml_file(path)
        } else {
            Self::from_toml_file(path)
        }
    }

    /// Reads the file at `path` and parses its contents as TOML.
    ///
    /// # Errors
    ///
    /// Returns a [`ConfigError`] if the file cannot be read or is not valid
    /// TOML for this structure.
    pub fn from_toml_file(path: impl AsRef<Path>) -> Result<Self, ConfigError> {
        let contents = std::fs::read_to_string(path)?;
        Self::from_toml_str(&contents)
    }

    /// Parses a configuration from a TOML document held in memory.
    ///
    /// # Errors
    ///
    /// Returns a [`ConfigError`] if `s` is not valid TOML for this structure.
    pub fn from_toml_str(s: &str) -> Result<Self, ConfigError> {
        Ok(toml::from_str(s)?)
    }

    /// Reads the file at `path` and parses its contents as YAML.
    ///
    /// # Errors
    ///
    /// Returns a [`ConfigError`] if the file cannot be read or is not valid
    /// YAML for this structure.
    pub fn from_yaml_file(path: impl AsRef<Path>) -> Result<Self, ConfigError> {
        let contents = std::fs::read_to_string(path)?;
        Self::from_yaml_str(&contents)
    }

    /// Parses a configuration from a YAML document held in memory.
    ///
    /// # Errors
    ///
    /// Returns a [`ConfigError`] if `s` is not valid YAML for this structure.
    pub fn from_yaml_str(s: &str) -> Result<Self, ConfigError> {
        Ok(serde_yaml::from_str(s)?)
    }

    /// Sets `termination.seconds_spent_limit` to `seconds`.
    ///
    /// Any other termination fields that were already configured are kept. If
    /// no termination config exists yet, one is created from
    /// `TerminationConfig::default()`.
    #[must_use]
    pub fn with_termination_seconds(mut self, seconds: u64) -> Self {
        self.termination = Some(TerminationConfig {
            seconds_spent_limit: Some(seconds),
            ..self.termination.unwrap_or_default()
        });
        self
    }

    /// Sets `random_seed` to `Some(seed)`.
    #[must_use]
    pub fn with_random_seed(mut self, seed: u64) -> Self {
        self.random_seed = Some(seed);
        self
    }

    /// Appends `phase` to the end of the phase list.
    #[must_use]
    pub fn with_phase(mut self, phase: PhaseConfig) -> Self {
        self.phases.push(phase);
        self
    }

    /// Returns the time limit reported by the termination config.
    ///
    /// Returns `None` when no termination config is set. Otherwise the result
    /// of `TerminationConfig::time_limit` is returned as-is.
    pub fn time_limit(&self) -> Option<Duration> {
        self.termination.as_ref().and_then(|t| t.time_limit())
    }
}
/// Environment mode of the solver.
///
/// Variant names are (de)serialized in snake_case, e.g. `"non_reproducible"`
/// or `"full_assert"`. The default is [`EnvironmentMode::NonReproducible`].
///
/// NOTE(review): the runtime effect of each mode is implemented by the solver,
/// not here; the per-variant notes below are inferred from the names. Confirm
/// them against the solver implementation.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum EnvironmentMode {
    /// Default mode; presumably no reproducibility guarantee.
    #[default]
    NonReproducible,
    /// Presumably makes runs repeatable (e.g. together with `random_seed`).
    Reproducible,
    /// Presumably enables inexpensive internal assertions.
    FastAssert,
    /// Presumably enables exhaustive internal assertions.
    FullAssert,
}
/// Move-thread setting of the solver.
///
/// Uses serde's default (externally tagged) enum representation with snake_case
/// names: `"auto"`, `"none"`, or a single-key `count` map such as
/// `{ count = 4 }` in TOML. The default is [`MoveThreadCount::Auto`].
///
/// Derives `Copy`, consistent with [`EnvironmentMode`]: every variant holds only
/// `Copy` data, so values can be passed around without explicit `.clone()`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum MoveThreadCount {
    /// Default. NOTE(review): presumably the solver chooses the count itself; confirm.
    #[default]
    Auto,
    /// NOTE(review): presumably no dedicated move threads; confirm with the solver.
    None,
    /// An explicit move-thread count. No range check is performed here.
    Count(usize),
}
/// A set of optional per-run overrides for a solver configuration.
///
/// A `None` field means "no override supplied". The default value carries no
/// overrides at all.
#[derive(Clone, Debug, Default)]
pub struct SolverConfigOverride {
    /// Replacement termination settings, if any.
    pub termination: Option<TerminationConfig>,
}
impl SolverConfigOverride {
pub fn with_termination(termination: TerminationConfig) -> Self {
SolverConfigOverride {
termination: Some(termination),
}
}
}