use crate::cdt::action::ActionConfig;
use crate::cdt::metropolis::MetropolisConfig;
use clap::Parser;
use dirs::home_dir;
use std::path::{Component, Path, PathBuf};
/// Configuration for a CDT simulation run.
///
/// Parsed from the command line with clap, and also built directly by
/// [`CdtConfig::new`] and the [`TestConfig`] presets.
#[derive(Parser, Debug, Clone)]
#[command(author, version, about, long_about = None)]
pub struct CdtConfig {
    /// Spacetime dimension (2 or 3); treated as 2 when omitted.
    #[arg(short, long, value_parser = clap::value_parser!(u8).range(2..4))]
    pub dimension: Option<u8>,
    /// Number of vertices (at least 3).
    #[arg(short, long, value_parser = clap::value_parser!(u32).range(3..))]
    pub vertices: u32,
    /// Number of timeslices (at least 1).
    #[arg(short, long, value_parser = clap::value_parser!(u32).range(1..))]
    pub timeslices: u32,
    /// Temperature passed to the Metropolis algorithm.
    #[arg(long, default_value = "1.0")]
    pub temperature: f64,
    /// Total number of Metropolis steps.
    #[arg(long, default_value = "1000")]
    pub steps: u32,
    /// Number of thermalization steps (must not exceed `steps`).
    #[arg(long, default_value = "100")]
    pub thermalization_steps: u32,
    /// Measurement frequency, in steps (at least 1).
    #[arg(long, default_value = "10", value_parser = clap::value_parser!(u32).range(1..))]
    pub measurement_frequency: u32,
    /// Coupling constant `coupling_0` of the action.
    #[arg(long, default_value = "1.0")]
    pub coupling_0: f64,
    /// Coupling constant `coupling_2` of the action.
    #[arg(long, default_value = "1.0")]
    pub coupling_2: f64,
    /// Cosmological constant of the action.
    #[arg(long, default_value = "0.1")]
    pub cosmological_constant: f64,
    /// Whether to run the simulation.
    ///
    /// Defaults to `true`. A bare `--simulate` also means `true`; pass
    /// `--simulate false` to disable it.
    // A plain `bool` field gets clap's implicit `SetTrue` action; combined with
    // `default_value_t = true` that made the option impossible to switch off.
    #[arg(
        long,
        action = clap::ArgAction::Set,
        num_args = 0..=1,
        default_value_t = true,
        default_missing_value = "true"
    )]
    pub simulate: bool,
}
/// How an override treats the `dimension` field of a [`CdtConfig`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DimensionOverride {
    /// Replace the dimension with this explicit value.
    Value(u8),
    /// Reset the dimension to `None`, so `CdtConfig::dimension()` falls back to 2.
    Clear,
}

/// Optional per-field replacements consumed by `CdtConfig::merge_with_override`.
///
/// Each field left as `None` keeps the value from the base configuration.
#[derive(Clone, Copy, Debug, Default)]
pub struct CdtConfigOverrides {
    /// Change to the dimension, if any.
    pub dimension: Option<DimensionOverride>,
    /// Replacement vertex count.
    pub vertices: Option<u32>,
    /// Replacement timeslice count.
    pub timeslices: Option<u32>,
    /// Replacement temperature.
    pub temperature: Option<f64>,
    /// Replacement total step count.
    pub steps: Option<u32>,
    /// Replacement thermalization step count.
    pub thermalization_steps: Option<u32>,
    /// Replacement measurement frequency.
    pub measurement_frequency: Option<u32>,
    /// Replacement `coupling_0`.
    pub coupling_0: Option<f64>,
    /// Replacement `coupling_2`.
    pub coupling_2: Option<f64>,
    /// Replacement cosmological constant.
    pub cosmological_constant: Option<f64>,
    /// Replacement `simulate` flag.
    pub simulate: Option<bool>,
}
impl CdtConfig {
    /// Returns a copy of `self` in which every field set in `overrides` is replaced.
    ///
    /// Fields whose override is `None` keep their current value. The dimension
    /// override can either set an explicit value or clear it back to `None`.
    #[must_use]
    pub fn merge_with_override(&self, overrides: &CdtConfigOverrides) -> Self {
        let dimension = match overrides.dimension {
            None => self.dimension,
            Some(DimensionOverride::Clear) => None,
            Some(DimensionOverride::Value(explicit)) => Some(explicit),
        };
        Self {
            dimension,
            vertices: overrides.vertices.unwrap_or(self.vertices),
            timeslices: overrides.timeslices.unwrap_or(self.timeslices),
            temperature: overrides.temperature.unwrap_or(self.temperature),
            steps: overrides.steps.unwrap_or(self.steps),
            thermalization_steps: overrides
                .thermalization_steps
                .unwrap_or(self.thermalization_steps),
            measurement_frequency: overrides
                .measurement_frequency
                .unwrap_or(self.measurement_frequency),
            coupling_0: overrides.coupling_0.unwrap_or(self.coupling_0),
            coupling_2: overrides.coupling_2.unwrap_or(self.coupling_2),
            cosmological_constant: overrides
                .cosmological_constant
                .unwrap_or(self.cosmological_constant),
            simulate: overrides.simulate.unwrap_or(self.simulate),
        }
    }

    /// Resolves `candidate` against `base_dir` and lexically normalizes the result.
    ///
    /// * Absolute candidates ignore `base_dir`.
    /// * `~`, `~/...` and `~\...` expand to the home directory when it can be
    ///   determined; otherwise the candidate is handled like any relative path.
    /// * Everything else is joined onto `base_dir`.
    ///
    /// Normalization works purely on path components and never touches the
    /// filesystem, so symlinks are not resolved.
    #[must_use]
    pub fn resolve_path(base_dir: impl AsRef<Path>, candidate: impl AsRef<Path>) -> PathBuf {
        // Expands a leading `~` (alone or followed by a separator) to the home
        // directory. Returns `None` when no expansion applies or home is unknown.
        fn expand_home(candidate: &Path) -> Option<PathBuf> {
            let rest = candidate.to_str()?.strip_prefix('~')?;
            let tail = if rest.is_empty() {
                rest
            } else if rest.starts_with(['/', '\\']) {
                rest.trim_start_matches(['/', '\\'])
            } else {
                // `~user`-style paths are left untouched.
                return None;
            };
            let home = home_dir()?;
            Some(if tail.is_empty() { home } else { home.join(tail) })
        }

        let candidate = candidate.as_ref();
        let target = if candidate.is_absolute() {
            candidate.to_path_buf()
        } else if let Some(expanded) = expand_home(candidate) {
            expanded
        } else {
            base_dir.as_ref().join(candidate)
        };
        normalize_components(&target)
    }
}
/// Lexically normalizes `path` without touching the filesystem.
///
/// `.` components are dropped, and each `..` cancels the preceding normal
/// component. Two special cases:
///
/// * A `..` directly after the root (e.g. `/..`) is discarded, because nothing
///   lies above the root.
/// * A `..` with no normal component to cancel in a relative path (e.g. `../x`
///   or `a/../../x`) is kept, so the result still names the same location.
///
/// An empty result is returned as `.`.
fn normalize_components(path: &Path) -> PathBuf {
    let mut normalized = PathBuf::new();
    for component in path.components() {
        match component {
            Component::CurDir => {}
            Component::ParentDir => match normalized.components().next_back() {
                // Cancel out the previous directory name.
                Some(Component::Normal(_)) => {
                    normalized.pop();
                }
                // Cannot climb above the filesystem root.
                Some(Component::RootDir) => {}
                // Empty so far, already a leading `..`, or a bare drive prefix
                // (`C:`): the parent reference is meaningful, so keep it.
                None | Some(Component::ParentDir | Component::CurDir | Component::Prefix(_)) => {
                    normalized.push(Component::ParentDir.as_os_str());
                }
            },
            Component::RootDir | Component::Prefix(_) => {
                normalized.push(component.as_os_str());
            }
            Component::Normal(segment) => {
                normalized.push(segment);
            }
        }
    }
    if normalized.as_os_str().is_empty() {
        PathBuf::from(Component::CurDir.as_os_str())
    } else {
        normalized
    }
}
impl CdtConfig {
    /// Parses a configuration from the process command-line arguments.
    ///
    /// Delegates to clap's `Parser::parse`, which prints a usage message and
    /// exits the process when the arguments are invalid.
    #[must_use]
    pub fn from_args() -> Self {
        Self::parse()
    }

    /// Creates a 2D configuration of the given size.
    ///
    /// All other fields use the same defaults as the command-line interface.
    #[must_use]
    pub const fn new(vertices: u32, timeslices: u32) -> Self {
        Self {
            dimension: Some(2),
            vertices,
            timeslices,
            temperature: 1.0,
            steps: 1000,
            thermalization_steps: 100,
            measurement_frequency: 10,
            coupling_0: 1.0,
            coupling_2: 1.0,
            cosmological_constant: 0.1,
            simulate: true,
        }
    }

    /// Extracts the Metropolis parameters: temperature, steps, thermalization
    /// steps and measurement frequency.
    #[must_use]
    pub const fn to_metropolis_config(&self) -> MetropolisConfig {
        MetropolisConfig::new(
            self.temperature,
            self.steps,
            self.thermalization_steps,
            self.measurement_frequency,
        )
    }

    /// Extracts the action parameters: both couplings and the cosmological constant.
    #[must_use]
    pub const fn to_action_config(&self) -> ActionConfig {
        ActionConfig::new(self.coupling_0, self.coupling_2, self.cosmological_constant)
    }

    /// Returns the configured dimension, or 2 when none was set.
    #[must_use]
    pub const fn dimension(&self) -> u8 {
        match self.dimension {
            Some(d) => d,
            None => 2,
        }
    }

    /// Checks that the configuration describes a runnable simulation.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message for the first problem found. Possible
    /// problems are:
    ///
    /// * fewer than 3 vertices or zero timeslices;
    /// * an unsupported dimension;
    /// * a temperature that is not a positive finite number;
    /// * a non-finite coupling or cosmological constant;
    /// * zero steps or zero measurement frequency;
    /// * a step budget that is inconsistent with the thermalization and
    ///   measurement settings.
    pub fn validate(&self) -> Result<(), String> {
        if self.vertices < 3 {
            return Err("Number of vertices must be at least 3".to_string());
        }
        if self.timeslices == 0 {
            return Err("Number of timeslices must be at least 1".to_string());
        }
        if let Some(dim) = self.dimension
            && !(2..=3).contains(&dim)
        {
            return Err(format!(
                "Unsupported dimension: {dim}. Only 2D and 3D are supported."
            ));
        }
        // `NaN <= 0.0` is false, so NaN has to be rejected explicitly.
        if self.temperature.is_nan() || self.temperature <= 0.0 {
            return Err("Temperature must be positive".to_string());
        }
        if self.temperature.is_infinite() {
            return Err("Temperature must be finite".to_string());
        }
        // Non-finite couplings would poison every action evaluation.
        for (name, value) in [
            ("coupling_0", self.coupling_0),
            ("coupling_2", self.coupling_2),
            ("cosmological_constant", self.cosmological_constant),
        ] {
            if !value.is_finite() {
                return Err(format!("{name} must be a finite number, got {value}"));
            }
        }
        if self.steps == 0 {
            return Err("Number of steps must be positive".to_string());
        }
        if self.measurement_frequency == 0 {
            return Err("Measurement frequency must be positive".to_string());
        }
        if self.measurement_frequency > self.steps {
            return Err("Measurement frequency cannot be greater than total steps".to_string());
        }
        if self.thermalization_steps > self.steps {
            return Err("Thermalization steps cannot exceed total steps".to_string());
        }
        let measurement_steps = self.steps.saturating_sub(self.thermalization_steps);
        if measurement_steps > 0 && measurement_steps < self.measurement_frequency {
            return Err(format!(
                "Only {measurement_steps} steps remain after thermalization, but measurement_frequency is {}. \
                 No measurements will be taken. Increase steps or decrease measurement_frequency.",
                self.measurement_frequency
            ));
        }
        Ok(())
    }
}
/// Ready-made, valid configurations of increasing size for tests and benchmarks.
#[derive(Debug, Clone)]
pub struct TestConfig;

impl TestConfig {
    /// 16 vertices, 2 timeslices, 10 steps (2 thermalization, measurement frequency 2).
    #[must_use]
    pub const fn small() -> CdtConfig {
        Self::preset(16, 2, 10, 2, 2)
    }

    /// 64 vertices, 4 timeslices, 100 steps (20 thermalization, measurement frequency 5).
    #[must_use]
    pub const fn medium() -> CdtConfig {
        Self::preset(64, 4, 100, 20, 5)
    }

    /// 256 vertices, 8 timeslices, 1000 steps (100 thermalization, measurement frequency 10).
    #[must_use]
    pub const fn large() -> CdtConfig {
        Self::preset(256, 8, 1000, 100, 10)
    }

    /// Builds a 2D preset with unit temperature and couplings and a cosmological
    /// constant of 0.1. Only the size and the step schedule differ between presets.
    const fn preset(
        vertices: u32,
        timeslices: u32,
        steps: u32,
        thermalization_steps: u32,
        measurement_frequency: u32,
    ) -> CdtConfig {
        CdtConfig {
            dimension: Some(2),
            vertices,
            timeslices,
            temperature: 1.0,
            steps,
            thermalization_steps,
            measurement_frequency,
            coupling_0: 1.0,
            coupling_2: 1.0,
            cosmological_constant: 0.1,
            simulate: true,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use dirs::home_dir;
    use std::path::PathBuf;

    /// Baseline configuration that the validation cases start from.
    fn baseline() -> CdtConfig {
        CdtConfig::new(32, 3)
    }

    #[test]
    fn test_config_new() {
        let cfg = CdtConfig::new(32, 3);
        assert_eq!((cfg.vertices, cfg.timeslices), (32, 3));
        assert_eq!(cfg.dimension(), 2);
        assert!(cfg.simulate);
    }

    #[test]
    fn test_config_conversions() {
        let cfg = CdtConfig::new(64, 4);

        let metropolis = cfg.to_metropolis_config();
        assert_relative_eq!(metropolis.temperature, 1.0);
        assert_eq!(metropolis.steps, 1000);

        let action = cfg.to_action_config();
        assert_relative_eq!(action.coupling_0, 1.0);
        assert_relative_eq!(action.coupling_2, 1.0);
        assert_relative_eq!(action.cosmological_constant, 0.1);
    }

    #[test]
    fn test_config_validation() {
        assert!(baseline().validate().is_ok());

        // Each of these differs from the baseline in a single invalid field.
        let rejected = [
            CdtConfig { vertices: 2, ..baseline() },
            CdtConfig { timeslices: 0, ..baseline() },
            CdtConfig { temperature: -1.0, ..baseline() },
            CdtConfig { measurement_frequency: 0, ..baseline() },
            CdtConfig { steps: 0, ..baseline() },
            CdtConfig { measurement_frequency: 2_000, ..baseline() },
        ];
        for cfg in &rejected {
            assert!(cfg.validate().is_err(), "expected rejection for {cfg:?}");
        }

        let err = CdtConfig { dimension: Some(4), ..baseline() }
            .validate()
            .unwrap_err();
        assert!(
            err.contains("Unsupported dimension"),
            "unexpected validation error: {err}"
        );

        let no_post_thermalization = CdtConfig {
            steps: 10,
            thermalization_steps: 10,
            measurement_frequency: 5,
            ..baseline()
        };
        assert!(
            no_post_thermalization.validate().is_ok(),
            "Configurations with zero post-thermalization steps but valid frequencies should pass validation"
        );

        let too_few_steps_left = CdtConfig {
            steps: 20,
            thermalization_steps: 15,
            measurement_frequency: 10,
            ..baseline()
        };
        let err = too_few_steps_left.validate().unwrap_err();
        assert!(
            err.contains("No measurements will be taken"),
            "Unexpected validation error: {err}"
        );
    }

    #[test]
    fn test_dimension_defaults_to_two_when_unspecified() {
        let cfg = CdtConfig { dimension: None, ..baseline() };
        assert_eq!(cfg.dimension(), 2);
    }

    #[test]
    fn test_preset_configs() {
        let presets: [(CdtConfig, u32, u32); 3] = [
            (TestConfig::small(), 16, 10),
            (TestConfig::medium(), 64, 100),
            (TestConfig::large(), 256, 1000),
        ];
        for (preset, vertices, steps) in presets {
            assert!(preset.validate().is_ok());
            assert_eq!(preset.vertices, vertices);
            assert_eq!(preset.steps, steps);
        }
    }

    #[test]
    fn test_merge_with_override_updates_specified_fields() {
        let base = CdtConfig::new(10, 2);
        let merged = base.merge_with_override(&CdtConfigOverrides {
            dimension: Some(DimensionOverride::Value(3)),
            vertices: Some(42),
            temperature: Some(2.5),
            simulate: Some(false),
            ..CdtConfigOverrides::default()
        });
        assert_eq!(merged.dimension(), 3);
        assert_eq!(merged.vertices, 42);
        assert_relative_eq!(merged.temperature, 2.5);
        assert!(!merged.simulate);
        // Fields without an override are carried over unchanged.
        assert_eq!(merged.timeslices, base.timeslices);
        assert_eq!(merged.steps, base.steps);
    }

    #[test]
    fn test_merge_with_override_can_clear_dimension() {
        let clear = CdtConfigOverrides {
            dimension: Some(DimensionOverride::Clear),
            ..CdtConfigOverrides::default()
        };
        let merged = CdtConfig::new(10, 2).merge_with_override(&clear);
        assert_eq!(merged.dimension, None);
        assert_eq!(merged.dimension(), 2);
    }

    #[test]
    fn test_resolve_path_with_absolute_path() {
        let absolute = PathBuf::from("/tmp/example");
        assert_eq!(
            CdtConfig::resolve_path("/does/not/matter", &absolute),
            PathBuf::from("/tmp/example")
        );
    }

    #[test]
    fn test_resolve_path_with_relative_path() {
        let resolved = CdtConfig::resolve_path(
            PathBuf::from("/tmp/base"),
            PathBuf::from("config/settings.toml"),
        );
        assert_eq!(resolved, PathBuf::from("/tmp/base/config/settings.toml"));
    }

    #[test]
    fn test_resolve_path_with_home_expansion() {
        let home = home_dir().expect("Home directory must be resolvable for this test");
        assert_eq!(
            CdtConfig::resolve_path("/tmp", PathBuf::from("~/config.toml")),
            home.join("config.toml")
        );
    }

    #[test]
    fn test_resolve_path_normalizes_navigation_components() {
        let resolved = CdtConfig::resolve_path(
            PathBuf::from("/tmp/base"),
            PathBuf::from("configs/../settings.toml"),
        );
        assert_eq!(resolved, PathBuf::from("/tmp/base/settings.toml"));
    }

    #[test]
    fn test_resolve_path_cannot_escape_root() {
        let resolved = CdtConfig::resolve_path("/tmp", PathBuf::from("/../etc/passwd"));
        assert_eq!(resolved, PathBuf::from("/etc/passwd"));
    }
}