#![forbid(unsafe_code)]
use crate::cdt::action::ActionConfig;
use crate::cdt::ergodic_moves::{ErgodicsSystem, MoveStatistics, MoveType};
use crate::cdt::results::{
CdtScalarTraceRow, Measurement, SimulationResultsBackend, SimulationResultsParts,
validate_scalar_trace_rows,
};
use crate::errors::{
CdtError, CdtResult, CheckpointMoveCounter, CheckpointResumeFailure, ProposalTelemetryCounter,
};
use crate::geometry::CdtTriangulation2D;
use markov_chain_monte_carlo::ChainCheckpoint;
use rand::rngs::Xoshiro256PlusPlus;
use serde::de::Error as DeError;
use serde::{Deserialize, Deserializer, Serialize};
use std::num::NonZeroU32;
use std::time::Duration;
use super::helpers::{
action_for, actions_match, expected_measurement_count, expected_measurement_step,
};
use super::runner::{MetropolisAlgorithm, MetropolisConfig};
use super::telemetry::{MonteCarloStep, ProposalStatistics};
/// Owned sampler state consumed by `CdtMcmcCheckpoint::from_parts`.
///
/// `from_parts` rejects a non-finite `current_action` and validates the
/// assembled checkpoint, so inconsistent combinations surface as errors there.
pub(crate) struct CdtMcmcCheckpointParts {
    // Run configuration.
    pub(crate) config: MetropolisConfig,
    pub(crate) action_config: ActionConfig,
    // Chain state; `triangulation`, `accepted` and `rejected` seed the
    // checkpoint's `ChainCheckpoint`.
    pub(crate) triangulation: CdtTriangulation2D,
    pub(crate) accepted: usize,
    pub(crate) rejected: usize,
    pub(crate) current_step: NonZeroU32,
    pub(crate) current_action: f64,
    // Accumulated statistics and telemetry.
    pub(crate) move_stats: MoveStatistics,
    pub(crate) proposal_stats: ProposalStatistics,
    pub(crate) steps: Vec<MonteCarloStep>,
    pub(crate) measurements: Vec<Measurement>,
    pub(crate) scalar_trace_rows: Vec<CdtScalarTraceRow>,
    pub(crate) elapsed_time: Duration,
    // Random-number and ergodic-move state carried into the checkpoint.
    pub(crate) acceptance_rng: Xoshiro256PlusPlus,
    pub(crate) ergodics: ErgodicsSystem,
}
#[derive(Clone, Copy, Debug, PartialEq, Serialize)]
struct CheckpointAction(f64);
impl CheckpointAction {
const fn new(value: f64) -> CdtResult<Self> {
if value.is_finite() {
Ok(Self(value))
} else {
Err(checkpoint_resume_failed(
CheckpointResumeFailure::NonFiniteCheckpointAction { stored: value },
))
}
}
const fn get(self) -> f64 {
self.0
}
}
/// Snapshot of a Metropolis CDT run, used to resume it later.
///
/// Both construction paths (`CdtMcmcCheckpoint::from_parts` and
/// deserialization) reject a non-finite `current_action` and run
/// `validate_checkpoint_counters` before returning a value.
///
/// Only `Serialize` is derived here. Deserialization is hand-written and goes
/// through `CdtMcmcCheckpointWire`, which also supplies the default for a
/// missing `proposal_stats`. Keep field names and order in sync with that
/// struct so serialized checkpoints read back in order-sensitive formats.
#[derive(Clone, Serialize)]
pub struct CdtMcmcCheckpoint {
    /// Current triangulation together with the accepted/rejected counters.
    pub(crate) chain: ChainCheckpoint<CdtTriangulation2D>,
    pub(crate) config: MetropolisConfig,
    pub(crate) action_config: ActionConfig,
    pub(crate) current_step: NonZeroU32,
    /// Private so it can only hold a finite value.
    current_action: CheckpointAction,
    pub(crate) move_stats: MoveStatistics,
    pub(crate) proposal_stats: ProposalStatistics,
    pub(crate) steps: Vec<MonteCarloStep>,
    pub(crate) measurements: Vec<Measurement>,
    pub(crate) scalar_trace_rows: Vec<CdtScalarTraceRow>,
    pub(crate) elapsed_time: Duration,
    pub(crate) acceptance_rng: Xoshiro256PlusPlus,
    pub(crate) ergodics: ErgodicsSystem,
}
/// Raw deserialization target for `CdtMcmcCheckpoint`.
///
/// Fields are read without invariants: `current_step` is a plain `u32` and
/// `current_action` a plain `f64`. The `Deserialize` impl for
/// `CdtMcmcCheckpoint` converts them and validates the result. Field names and
/// order mirror `CdtMcmcCheckpoint` so that its serialized output reads back
/// here, including in formats that encode structs as sequences.
#[derive(Deserialize)]
struct CdtMcmcCheckpointWire {
    chain: ChainCheckpoint<CdtTriangulation2D>,
    config: MetropolisConfig,
    action_config: ActionConfig,
    // Converted to `NonZeroU32`; zero is rejected during deserialization.
    current_step: u32,
    // Converted to `CheckpointAction`; NaN/infinite values are rejected.
    current_action: f64,
    move_stats: MoveStatistics,
    // A missing field falls back to `ProposalStatistics::default()`.
    // NOTE(review): the defaulted value must still pass
    // `validate_checkpoint_proposal_stats`, which requires its counters to
    // equal the step/accepted/rejected totals. If the default reports zero
    // counters, this cannot succeed for any nonzero step count. Confirm
    // whether this default actually lets older checkpoints load.
    #[serde(default)]
    proposal_stats: ProposalStatistics,
    steps: Vec<MonteCarloStep>,
    measurements: Vec<Measurement>,
    scalar_trace_rows: Vec<CdtScalarTraceRow>,
    elapsed_time: Duration,
    acceptance_rng: Xoshiro256PlusPlus,
    ergodics: ErgodicsSystem,
}
impl<'de> Deserialize<'de> for CdtMcmcCheckpoint {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let wire = CdtMcmcCheckpointWire::deserialize(deserializer)?;
let current_step = NonZeroU32::new(wire.current_step)
.ok_or_else(|| DeError::custom("checkpoint current_step must be nonzero"))?;
let current_action = CheckpointAction::new(wire.current_action).map_err(DeError::custom)?;
let checkpoint = Self {
chain: wire.chain,
config: wire.config,
action_config: wire.action_config,
current_step,
current_action,
move_stats: wire.move_stats,
proposal_stats: wire.proposal_stats,
steps: wire.steps,
measurements: wire.measurements,
scalar_trace_rows: wire.scalar_trace_rows,
elapsed_time: wire.elapsed_time,
acceptance_rng: wire.acceptance_rng,
ergodics: wire.ergodics,
};
validate_checkpoint_counters(&checkpoint).map_err(DeError::custom)?;
Ok(checkpoint)
}
}
impl CdtMcmcCheckpoint {
    /// Assembles a checkpoint from owned sampler state and validates it.
    ///
    /// `parts.triangulation`, `parts.accepted` and `parts.rejected` seed the
    /// embedded [`ChainCheckpoint`].
    ///
    /// # Errors
    ///
    /// Returns an error if `parts.current_action` is NaN or infinite, or if the
    /// assembled checkpoint fails `validate_checkpoint_counters`. That covers
    /// inconsistent chain counters, step telemetry, proposal statistics,
    /// measurements, and scalar trace rows.
    pub(crate) fn from_parts(parts: CdtMcmcCheckpointParts) -> CdtResult<Self> {
        let current_action = CheckpointAction::new(parts.current_action)?;
        let checkpoint = Self {
            chain: ChainCheckpoint::new(parts.triangulation, parts.accepted, parts.rejected),
            config: parts.config,
            action_config: parts.action_config,
            current_step: parts.current_step,
            current_action,
            move_stats: parts.move_stats,
            proposal_stats: parts.proposal_stats,
            steps: parts.steps,
            measurements: parts.measurements,
            scalar_trace_rows: parts.scalar_trace_rows,
            elapsed_time: parts.elapsed_time,
            acceptance_rng: parts.acceptance_rng,
            ergodics: parts.ergodics,
        };
        validate_checkpoint_counters(&checkpoint)?;
        Ok(checkpoint)
    }
    /// Returns the chain checkpoint (triangulation plus accepted/rejected counters).
    #[must_use]
    pub const fn chain(&self) -> &ChainCheckpoint<CdtTriangulation2D> {
        &self.chain
    }
    /// Returns the triangulation stored in the chain checkpoint.
    #[must_use]
    pub const fn triangulation(&self) -> &CdtTriangulation2D {
        self.chain.state()
    }
    /// Returns the Metropolis configuration the run used.
    #[must_use]
    pub const fn config(&self) -> &MetropolisConfig {
        &self.config
    }
    /// Returns the action configuration the run used.
    #[must_use]
    pub const fn action_config(&self) -> &ActionConfig {
        &self.action_config
    }
    /// Returns the step counter recorded at checkpoint time.
    #[must_use]
    pub const fn current_step(&self) -> NonZeroU32 {
        self.current_step
    }
    /// Returns the recorded action, which is always finite.
    #[must_use]
    pub const fn current_action(&self) -> f64 {
        self.current_action.get()
    }
    /// Returns the per-move-type statistics.
    #[must_use]
    pub const fn move_stats(&self) -> &MoveStatistics {
        &self.move_stats
    }
    /// Returns the proposal telemetry.
    #[must_use]
    pub const fn proposal_stats(&self) -> &ProposalStatistics {
        &self.proposal_stats
    }
    /// Returns the per-step telemetry, one entry per chain step.
    #[must_use]
    pub fn steps(&self) -> &[MonteCarloStep] {
        &self.steps
    }
    /// Returns the measurements recorded so far.
    #[must_use]
    pub fn measurements(&self) -> &[Measurement] {
        &self.measurements
    }
    /// Converts the checkpoint into simulation results.
    ///
    /// The chain's accepted/rejected counters, the current step and action,
    /// the acceptance RNG, and the ergodics system are not carried over.
    #[must_use]
    pub fn into_results(self) -> SimulationResultsBackend {
        let (triangulation, _, _) = self.chain.into_parts();
        SimulationResultsBackend::from_parts(SimulationResultsParts {
            config: self.config,
            action_config: self.action_config,
            move_stats: self.move_stats,
            proposal_stats: self.proposal_stats,
            steps: self.steps,
            measurements: self.measurements,
            scalar_trace_rows: self.scalar_trace_rows,
            elapsed_time: self.elapsed_time,
            triangulation,
        })
    }
}
/// Wraps a [`CheckpointResumeFailure`] in `CdtError::CheckpointResumeFailed`.
///
/// This only builds the error value; dropping the result would silently lose
/// the failure, hence `#[must_use]`.
#[must_use]
pub(crate) const fn checkpoint_resume_failed(failure: CheckpointResumeFailure) -> CdtError {
    CdtError::CheckpointResumeFailed { failure }
}
/// Checks that `checkpoint` can be resumed by `algorithm`.
///
/// The following must agree between the live algorithm and the checkpoint,
/// checked in this order, with the first mismatch reported:
/// - action couplings and cosmological constant (via `actions_match`);
/// - temperature (bit-for-bit);
/// - thermalization schedule;
/// - measurement frequency.
///
/// The checkpoint's internal consistency is then re-validated with
/// `validate_checkpoint_counters`.
pub(crate) fn validate_resume_compatible(
    algorithm: &MetropolisAlgorithm,
    checkpoint: &CdtMcmcCheckpoint,
) -> CdtResult<()> {
    let live = algorithm.config();
    let stored = &checkpoint.config;

    let incompatibility =
        if !action_configs_match(algorithm.action_config(), &checkpoint.action_config) {
            Some(CheckpointResumeFailure::IncompatibleActionConfiguration)
        } else if live.temperature().to_bits() != stored.temperature().to_bits() {
            Some(CheckpointResumeFailure::IncompatibleTemperature)
        } else if live.thermalization_steps() != stored.thermalization_steps() {
            Some(CheckpointResumeFailure::IncompatibleThermalizationSchedule)
        } else if live.measurement_frequency() != stored.measurement_frequency() {
            Some(CheckpointResumeFailure::IncompatibleMeasurementFrequency)
        } else {
            None
        };

    match incompatibility {
        Some(failure) => Err(checkpoint_resume_failed(failure)),
        None => validate_checkpoint_counters(checkpoint),
    }
}
/// Returns `true` when both configurations agree, per `actions_match`, on
/// `coupling_0`, `coupling_2` and the cosmological constant.
fn action_configs_match(left: &ActionConfig, right: &ActionConfig) -> bool {
    [
        (left.coupling_0(), right.coupling_0()),
        (left.coupling_2(), right.coupling_2()),
        (left.cosmological_constant(), right.cosmological_constant()),
    ]
    .into_iter()
    .all(|(lhs, rhs)| actions_match(lhs, rhs))
}
/// Validates the internal consistency of a checkpoint and returns the first
/// failure found.
///
/// In order, it checks:
/// - the stored action against a recomputation;
/// - the chain's accepted/rejected counters against the per-move statistics;
/// - the chain's step total against `current_step`;
/// - the length of the step telemetry;
/// - the proposal statistics;
/// - the step telemetry contents;
/// - the measurement schedule;
/// - the scalar trace rows.
///
/// NOTE(review): any return value of `config.validate()` and
/// `action_config.validate()` is discarded here. Confirm these are
/// assertion-style checks rather than fallible validations whose errors should
/// be propagated with `?`. If they panic on invalid input, deserializing a
/// corrupted checkpoint could panic instead of returning an error.
pub(crate) fn validate_checkpoint_counters(checkpoint: &CdtMcmcCheckpoint) -> CdtResult<()> {
    checkpoint.config.validate();
    checkpoint.action_config.validate();
    validate_checkpoint_current_action(checkpoint)?;

    let (accepted, rejected) = chain_counters(&checkpoint.move_stats)?;
    let chain = &checkpoint.chain;
    if chain.accepted() != accepted || chain.rejected() != rejected {
        return Err(checkpoint_resume_failed(
            CheckpointResumeFailure::ChainCounterMismatch {
                chain_accepted: chain.accepted(),
                chain_rejected: chain.rejected(),
                move_accepted: accepted,
                move_rejected: rejected,
            },
        ));
    }

    let current_step = checkpoint.current_step.get();
    let chain_steps = chain.total_steps();
    // Compare through `Option` instead of a `usize::MAX` sentinel. On targets
    // where `usize` cannot hold every `u32`, a failed conversion is always a
    // mismatch and can never coincide with a real step total.
    if usize::try_from(current_step).ok() != Some(chain_steps) {
        return Err(checkpoint_resume_failed(
            CheckpointResumeFailure::ChainStepMismatch {
                chain_steps,
                checkpoint_step: current_step,
            },
        ));
    }
    if checkpoint.steps.len() != chain_steps {
        return Err(checkpoint_resume_failed(
            CheckpointResumeFailure::StepTelemetryLengthMismatch {
                actual: checkpoint.steps.len(),
                expected: chain_steps,
            },
        ));
    }

    validate_checkpoint_proposal_stats(checkpoint, accepted, rejected)?;
    validate_checkpoint_steps(checkpoint)?;
    validate_checkpoint_measurements(checkpoint)?;
    validate_scalar_trace_rows(
        &checkpoint.config,
        &checkpoint.proposal_stats,
        &checkpoint.steps,
        &checkpoint.scalar_trace_rows,
    )?;
    Ok(())
}
/// Recomputes the action of the stored triangulation under the stored action
/// configuration. Fails with `ActionMismatch` unless it agrees, per
/// `actions_match`, with the recorded current action.
fn validate_checkpoint_current_action(checkpoint: &CdtMcmcCheckpoint) -> CdtResult<()> {
    let recomputed = action_for(&checkpoint.action_config, checkpoint.triangulation());
    let stored = checkpoint.current_action();
    if actions_match(stored, recomputed) {
        Ok(())
    } else {
        Err(checkpoint_resume_failed(
            CheckpointResumeFailure::ActionMismatch { stored, recomputed },
        ))
    }
}
fn validate_checkpoint_proposal_stats(
checkpoint: &CdtMcmcCheckpoint,
accepted: usize,
rejected: usize,
) -> CdtResult<()> {
if checkpoint.proposal_stats.hard_failures() != 0 {
return Err(checkpoint_resume_failed(
CheckpointResumeFailure::ProposalHardFailures {
actual: checkpoint.proposal_stats.hard_failures(),
},
));
}
let steps = u64::try_from(checkpoint.steps.len()).map_err(|_| {
checkpoint_resume_failed(CheckpointResumeFailure::ProposalCounterOverflow {
counter: ProposalTelemetryCounter::MoveFamilyProposals,
})
})?;
if checkpoint.proposal_stats.move_family_proposals() != steps {
return Err(checkpoint_resume_failed(
CheckpointResumeFailure::ProposalMoveFamilyCountMismatch {
actual: checkpoint.proposal_stats.move_family_proposals(),
expected: steps,
},
));
}
let accepted = u64::try_from(accepted).map_err(|_| {
checkpoint_resume_failed(CheckpointResumeFailure::ProposalCounterOverflow {
counter: ProposalTelemetryCounter::AcceptedTransitions,
})
})?;
if checkpoint.proposal_stats.accepted_transitions() != accepted {
return Err(checkpoint_resume_failed(
CheckpointResumeFailure::ProposalAcceptedCountMismatch {
actual: checkpoint.proposal_stats.accepted_transitions(),
expected: accepted,
},
));
}
let rejected = u64::try_from(rejected).map_err(|_| {
checkpoint_resume_failed(CheckpointResumeFailure::ProposalCounterOverflow {
counter: ProposalTelemetryCounter::RejectedTransitions,
})
})?;
let actual_rejected = checkpoint.proposal_stats.rejected_transitions();
if actual_rejected != rejected {
return Err(checkpoint_resume_failed(
CheckpointResumeFailure::ProposalRejectedCountMismatch {
actual: actual_rejected,
expected: rejected,
},
));
}
Ok(())
}
/// Verifies the per-step telemetry.
///
/// The number of entries flagged as accepted must equal the chain's accepted
/// count, and entries must be numbered consecutively starting at 1.
fn validate_checkpoint_steps(checkpoint: &CdtMcmcCheckpoint) -> CdtResult<()> {
    let flagged_accepted = checkpoint
        .steps
        .iter()
        .filter(|entry| entry.accepted())
        .count();
    let chain_accepted = checkpoint.chain.accepted();
    if flagged_accepted != chain_accepted {
        return Err(checkpoint_resume_failed(
            CheckpointResumeFailure::StepTelemetryAcceptedCountMismatch {
                actual: flagged_accepted,
                expected: chain_accepted,
            },
        ));
    }

    checkpoint
        .steps
        .iter()
        .zip(1_usize..)
        .try_for_each(|(entry, position)| {
            // Step numbers are `u32`; a position beyond that range cannot be represented.
            let expected = u32::try_from(position).map_err(|_| {
                checkpoint_resume_failed(CheckpointResumeFailure::StepTelemetryIndexOverflow)
            })?;
            let actual = entry.step().get();
            if actual == expected {
                Ok(())
            } else {
                Err(checkpoint_resume_failed(
                    CheckpointResumeFailure::StepTelemetrySequenceMismatch { actual, expected },
                ))
            }
        })
}
/// Verifies that the recorded measurements follow the schedule implied by
/// `current_step`, the thermalization length and the measurement frequency.
///
/// Both the number of measurements and the step of each one are checked.
fn validate_checkpoint_measurements(checkpoint: &CdtMcmcCheckpoint) -> CdtResult<()> {
    let config = &checkpoint.config;
    let recorded = &checkpoint.measurements;

    let Some(expected_len) = expected_measurement_count(
        checkpoint.current_step.get(),
        config.thermalization_steps(),
        config.measurement_frequency(),
    ) else {
        return Err(checkpoint_resume_failed(
            CheckpointResumeFailure::MeasurementCountOverflow,
        ));
    };
    if recorded.len() != expected_len {
        return Err(checkpoint_resume_failed(
            CheckpointResumeFailure::MeasurementCountMismatch {
                actual: recorded.len(),
                expected: expected_len,
            },
        ));
    }

    for (ordinal, measurement) in recorded.iter().enumerate() {
        let Some(expected_step) = expected_measurement_step(
            ordinal,
            config.thermalization_steps(),
            config.measurement_frequency(),
        ) else {
            return Err(checkpoint_resume_failed(
                CheckpointResumeFailure::MeasurementStepOverflow,
            ));
        };
        let actual_step = measurement.step();
        if actual_step != expected_step {
            return Err(checkpoint_resume_failed(
                CheckpointResumeFailure::MeasurementStepMismatch {
                    actual: actual_step,
                    expected: expected_step,
                },
            ));
        }
    }
    Ok(())
}
/// Sums four move counters. Fails with `MoveCounterOverflow`, tagged with
/// `counter`, if the total does not fit in a `u64`.
fn checked_move_counter_sum(counter: CheckpointMoveCounter, counters: [u64; 4]) -> CdtResult<u64> {
    let mut total = 0_u64;
    for count in counters {
        total = match total.checked_add(count) {
            Some(sum) => sum,
            None => {
                return Err(checkpoint_resume_failed(
                    CheckpointResumeFailure::MoveCounterOverflow { counter },
                ));
            }
        };
    }
    Ok(total)
}
/// Rejects move statistics in which a move type reports hard failures or more
/// accepted than attempted moves.
///
/// Move types are checked in the order `Move22`, `Move13Add`, `Move31Remove`,
/// `EdgeFlip`. For each type, hard failures are checked before the
/// accepted/attempted bound, and the first violation is returned.
fn validate_move_counter_bounds(move_stats: &MoveStatistics) -> CdtResult<()> {
    let per_move = [
        (
            MoveType::Move22,
            move_stats.attempted(MoveType::Move22),
            move_stats.accepted(MoveType::Move22),
            move_stats.hard_failed(MoveType::Move22),
        ),
        (
            MoveType::Move13Add,
            move_stats.attempted(MoveType::Move13Add),
            move_stats.accepted(MoveType::Move13Add),
            move_stats.hard_failed(MoveType::Move13Add),
        ),
        (
            MoveType::Move31Remove,
            move_stats.attempted(MoveType::Move31Remove),
            move_stats.accepted(MoveType::Move31Remove),
            move_stats.hard_failed(MoveType::Move31Remove),
        ),
        (
            MoveType::EdgeFlip,
            move_stats.attempted(MoveType::EdgeFlip),
            move_stats.accepted(MoveType::EdgeFlip),
            move_stats.hard_failed(MoveType::EdgeFlip),
        ),
    ];
    per_move
        .into_iter()
        .try_for_each(|(move_type, attempted, accepted, hard_failed)| {
            let failure = if hard_failed != 0 {
                CheckpointResumeFailure::MoveHardFailures { move_type }
            } else if accepted > attempted {
                CheckpointResumeFailure::MoveAcceptedExceedsAttempted { move_type }
            } else {
                return Ok(());
            };
            Err(checkpoint_resume_failed(failure))
        })
}
pub(crate) fn chain_counters(move_stats: &MoveStatistics) -> CdtResult<(usize, usize)> {
validate_move_counter_bounds(move_stats)?;
let attempted = checked_move_counter_sum(
CheckpointMoveCounter::Attempted,
[
move_stats.attempted(MoveType::Move22),
move_stats.attempted(MoveType::Move13Add),
move_stats.attempted(MoveType::Move31Remove),
move_stats.attempted(MoveType::EdgeFlip),
],
)?;
let accepted = checked_move_counter_sum(
CheckpointMoveCounter::Accepted,
[
move_stats.accepted(MoveType::Move22),
move_stats.accepted(MoveType::Move13Add),
move_stats.accepted(MoveType::Move31Remove),
move_stats.accepted(MoveType::EdgeFlip),
],
)?;
let rejected = attempted.checked_sub(accepted).ok_or_else(|| {
checkpoint_resume_failed(CheckpointResumeFailure::TotalAcceptedExceedsAttempted)
})?;
Ok((
usize::try_from(accepted).map_err(|_| {
checkpoint_resume_failed(CheckpointResumeFailure::CounterConversionOverflow {
counter: CheckpointMoveCounter::Accepted,
})
})?,
usize::try_from(rejected).map_err(|_| {
checkpoint_resume_failed(CheckpointResumeFailure::CounterConversionOverflow {
counter: CheckpointMoveCounter::Rejected,
})
})?,
))
}