#![forbid(unsafe_code)]
use crate::cdt::ergodic_moves::MoveType;
use crate::errors::{CdtError, CdtResult, CheckpointResumeFailure};
use serde::de::Error as DeError;
use serde::{Deserialize, Deserializer, Serialize};
use std::error::Error;
use std::fmt;
use std::num::NonZeroU32;
use super::helpers::actions_match;
/// One recorded Monte Carlo step: its nonzero step index, the proposed move type, the action
/// before the step, and what happened to the proposal.
///
/// Serialization is derived. Deserialization goes through `MonteCarloStepWire` and the
/// validating constructor [`MonteCarloStep::new`], so decoded values satisfy the same
/// invariants as values built in code.
#[derive(Debug, Clone, Serialize)]
pub struct MonteCarloStep {
    /// Step index (never zero).
    step: NonZeroU32,
    /// Kind of ergodic move proposed at this step.
    move_type: MoveType,
    /// Action before the step; [`MonteCarloStep::new`] requires it to be finite.
    action_before: f64,
    /// Outcome of the proposal, validated against `step` and `action_before`.
    outcome: MonteCarloStepOutcome,
}
/// Unvalidated serde mirror of [`MonteCarloStep`].
///
/// The `Deserialize` impl for `MonteCarloStep` decodes into this shape first and then runs
/// the same checks as the in-code constructors before producing a `MonteCarloStep`.
#[derive(Deserialize)]
struct MonteCarloStepWire {
    step: NonZeroU32,
    move_type: MoveType,
    action_before: f64,
    // Decoded into a `MonteCarloStepOutcome` via `MonteCarloStepOutcome::from_wire`.
    outcome: MonteCarloStepOutcomeWire,
}
impl<'de> Deserialize<'de> for MonteCarloStep {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let wire = MonteCarloStepWire::deserialize(deserializer)?;
let outcome = MonteCarloStepOutcome::from_wire(wire.step, wire.action_before, wire.outcome)
.map_err(DeError::custom)?;
Self::new(wire.step, wire.move_type, wire.action_before, outcome).map_err(DeError::custom)
}
}
impl MonteCarloStep {
pub fn new(
step: NonZeroU32,
move_type: MoveType,
action_before: f64,
outcome: MonteCarloStepOutcome,
) -> CdtResult<Self> {
validate_action_before(step, action_before)?;
outcome.validate_for_step(step, action_before)?;
Ok(Self {
step,
move_type,
action_before,
outcome,
})
}
pub fn accepted_step(
step: NonZeroU32,
move_type: MoveType,
action_before: f64,
action_after: f64,
delta_action: f64,
) -> CdtResult<Self> {
Self::new(
step,
move_type,
action_before,
MonteCarloStepOutcome::accepted_transition(
step,
action_before,
action_after,
delta_action,
)?,
)
}
pub fn rejected_proposal(
step: NonZeroU32,
move_type: MoveType,
action_before: f64,
delta_action: Option<f64>,
) -> CdtResult<Self> {
Self::new(
step,
move_type,
action_before,
MonteCarloStepOutcome::rejected_proposal(step, delta_action)?,
)
}
pub fn no_proposal(
step: NonZeroU32,
move_type: MoveType,
action_before: f64,
) -> CdtResult<Self> {
Self::new(
step,
move_type,
action_before,
MonteCarloStepOutcome::NoProposal,
)
}
#[must_use]
pub const fn step(&self) -> NonZeroU32 {
self.step
}
#[must_use]
pub const fn move_type(&self) -> MoveType {
self.move_type
}
#[must_use]
pub const fn action_before(&self) -> f64 {
self.action_before
}
#[must_use]
pub const fn outcome(&self) -> &MonteCarloStepOutcome {
&self.outcome
}
#[must_use]
pub const fn accepted(&self) -> bool {
matches!(self.outcome, MonteCarloStepOutcome::Accepted(_))
}
#[must_use]
pub const fn action_after(&self) -> Option<f64> {
self.outcome.action_after()
}
#[must_use]
pub const fn delta_action(&self) -> Option<f64> {
self.outcome.delta_action()
}
}
/// Payload of an accepted step: the resulting action and the change that produced it.
///
/// Fields are private. Within this file, values are built by
/// `MonteCarloStepOutcome::accepted_transition` (which checks both are finite and consistent
/// with the prior action) or directly by the `MonteCarloStep` constructors before
/// validation.
#[derive(Debug, Clone, Copy, PartialEq, Serialize)]
pub struct AcceptedStepTelemetry {
    /// Action after the accepted move.
    action_after: f64,
    /// Change in action caused by the accepted move.
    delta_action: f64,
}
impl AcceptedStepTelemetry {
    /// Returns the action recorded after the accepted move.
    #[must_use]
    pub const fn action_after(self) -> f64 {
        let Self { action_after, .. } = self;
        action_after
    }

    /// Returns the action change recorded for the accepted move.
    #[must_use]
    pub const fn delta_action(self) -> f64 {
        let Self { delta_action, .. } = self;
        delta_action
    }
}
/// Payload of a rejected proposal.
///
/// `delta_action` is optional: a rejection may or may not have an action change recorded.
/// When present, `MonteCarloStepOutcome::rejected_proposal` requires it to be finite.
#[derive(Debug, Clone, Copy, PartialEq, Serialize)]
pub struct RejectedProposalStepTelemetry {
    delta_action: Option<f64>,
}
impl RejectedProposalStepTelemetry {
    /// Returns the recorded action change, or `None` if the rejection did not record one.
    #[must_use]
    pub const fn delta_action(self) -> Option<f64> {
        let Self { delta_action } = self;
        delta_action
    }
}
/// What happened at a Monte Carlo step.
///
/// Deserialization is not derived. It goes through the private wire enum and the
/// validating constructors on this type.
#[derive(Debug, Clone, Copy, PartialEq, Serialize)]
pub enum MonteCarloStepOutcome {
    /// The proposal was accepted; carries the resulting action and its change.
    Accepted(AcceptedStepTelemetry),
    /// A proposal was made but rejected; may carry the proposed action change.
    RejectedProposal(RejectedProposalStepTelemetry),
    /// No proposal was produced at this step.
    NoProposal,
}
/// Unvalidated serde mirror of [`MonteCarloStepOutcome`].
///
/// It uses struct variants so the accepted JSON shape matches what the derived `Serialize`
/// on the public enum emits, e.g. `{"Accepted": {"action_after": .., "delta_action": ..}}`
/// or `"NoProposal"`. `rename_all = "PascalCase"` leaves these already-PascalCase variant
/// names unchanged.
#[derive(Clone, Copy, Deserialize)]
#[serde(rename_all = "PascalCase")]
enum MonteCarloStepOutcomeWire {
    Accepted {
        action_after: f64,
        delta_action: f64,
    },
    RejectedProposal {
        delta_action: Option<f64>,
    },
    NoProposal,
}
impl MonteCarloStepOutcome {
pub fn accepted_transition(
step: NonZeroU32,
action_before: f64,
action_after: f64,
delta_action: f64,
) -> CdtResult<Self> {
validate_action_after(step, action_after)?;
validate_delta_action(step, delta_action)?;
if !actions_match(action_after, action_before + delta_action) {
return Err(checkpoint_resume_failed(
CheckpointResumeFailure::StepActionAfterDeltaMismatch { step: step.get() },
));
}
Ok(Self::Accepted(AcceptedStepTelemetry {
action_after,
delta_action,
}))
}
pub fn rejected_proposal(step: NonZeroU32, delta_action: Option<f64>) -> CdtResult<Self> {
if let Some(delta_action) = delta_action {
validate_delta_action(step, delta_action)?;
}
Ok(Self::RejectedProposal(RejectedProposalStepTelemetry {
delta_action,
}))
}
#[must_use]
pub const fn accepted(self) -> bool {
matches!(self, Self::Accepted(_))
}
#[must_use]
pub const fn action_after(self) -> Option<f64> {
match self {
Self::Accepted(payload) => Some(payload.action_after()),
Self::RejectedProposal(_) | Self::NoProposal => None,
}
}
#[must_use]
pub const fn delta_action(self) -> Option<f64> {
match self {
Self::Accepted(payload) => Some(payload.delta_action()),
Self::RejectedProposal(payload) => payload.delta_action(),
Self::NoProposal => None,
}
}
fn validate_for_step(self, step: NonZeroU32, action_before: f64) -> CdtResult<()> {
match self {
Self::Accepted(payload) => {
Self::accepted_transition(
step,
action_before,
payload.action_after(),
payload.delta_action(),
)?;
}
Self::RejectedProposal(payload) => {
Self::rejected_proposal(step, payload.delta_action())?;
}
Self::NoProposal => {}
}
Ok(())
}
fn from_wire(
step: NonZeroU32,
action_before: f64,
wire: MonteCarloStepOutcomeWire,
) -> CdtResult<Self> {
match wire {
MonteCarloStepOutcomeWire::Accepted {
action_after,
delta_action,
} => Self::accepted_transition(step, action_before, action_after, delta_action),
MonteCarloStepOutcomeWire::RejectedProposal { delta_action } => {
Self::rejected_proposal(step, delta_action)
}
MonteCarloStepOutcomeWire::NoProposal => Ok(Self::NoProposal),
}
}
}
/// Wraps a checkpoint-resume failure reason in the crate-wide error type.
const fn checkpoint_resume_failed(reason: CheckpointResumeFailure) -> CdtError {
    CdtError::CheckpointResumeFailed { failure: reason }
}
/// Requires the pre-step action of `step` to be finite (not NaN or infinite).
///
/// # Errors
///
/// Returns `CdtError::CheckpointResumeFailed` with `NonFiniteStepActionBefore` for `step`.
const fn validate_action_before(step: NonZeroU32, action_before: f64) -> CdtResult<()> {
    if !action_before.is_finite() {
        return Err(checkpoint_resume_failed(
            CheckpointResumeFailure::NonFiniteStepActionBefore { step: step.get() },
        ));
    }
    Ok(())
}
/// Requires the post-step action of `step` to be finite (not NaN or infinite).
///
/// # Errors
///
/// Returns `CdtError::CheckpointResumeFailed` with `NonFiniteStepActionAfter` for `step`.
const fn validate_action_after(step: NonZeroU32, action_after: f64) -> CdtResult<()> {
    if !action_after.is_finite() {
        return Err(checkpoint_resume_failed(
            CheckpointResumeFailure::NonFiniteStepActionAfter { step: step.get() },
        ));
    }
    Ok(())
}
/// Requires the recorded action change of `step` to be finite (not NaN or infinite).
///
/// # Errors
///
/// Returns `CdtError::CheckpointResumeFailed` with `NonFiniteStepDeltaAction` for `step`.
const fn validate_delta_action(step: NonZeroU32, delta_action: f64) -> CdtResult<()> {
    if !delta_action.is_finite() {
        return Err(checkpoint_resume_failed(
            CheckpointResumeFailure::NonFiniteStepDeltaAction { step: step.get() },
        ));
    }
    Ok(())
}
/// Why a selected application site was rejected before a move could be applied.
///
/// `ProposalStatistics::record_site_rejection` counts each variant separately.
#[derive(Debug, Clone, PartialEq)]
pub(crate) enum CdtProposalSiteRejection {
    /// The site failed a causality check.
    CausalityViolation,
    /// The site failed a geometric check.
    GeometricViolation,
    /// The underlying kernel/backend reported an error for this site.
    Kernel(CdtError),
}
impl fmt::Display for CdtProposalSiteRejection {
    /// Writes a fixed message for the site-violation variants and forwards `Kernel` to the
    /// wrapped error's `Display`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::Kernel(err) => return fmt::Display::fmt(err, f),
            Self::CausalityViolation => "causality violation at selected application site",
            Self::GeometricViolation => "geometric violation at selected application site",
        };
        f.write_str(message)
    }
}
impl Error for CdtProposalSiteRejection {
    /// Exposes the wrapped kernel error as the source; the site-violation variants have none.
    // NOTE(review): `Display` for `Kernel` already prints the wrapped error's message, so
    // reporters that walk `source()` will show it twice. Consider making one of the two
    // non-redundant once callers of `source()` are audited.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            Self::CausalityViolation | Self::GeometricViolation => None,
            Self::Kernel(err) => Some(err as &(dyn Error + 'static)),
        }
    }
}
/// Counters describing how move-family proposals were resolved.
///
/// Deserialization enforces two invariants:
/// - The seven terminal-outcome counters (no-site, causality, geometric, backend, Metropolis,
///   accepted, hard failure) sum without overflow to exactly `move_family_proposals`.
/// - `observed_forward_sites` is zero whenever `move_family_proposals` is zero.
///
/// Counters saturate at `u64::MAX` when recorded or merged.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize)]
#[non_exhaustive]
pub struct ProposalStatistics {
    /// Number of move-family proposals recorded.
    move_family_proposals: u64,
    /// Sum of forward-site counts reported with each move-family proposal.
    observed_forward_sites: u64,
    /// Proposals that ended with no site.
    no_site_proposals: u64,
    /// Site rejections for causality violations.
    site_causality_rejections: u64,
    /// Site rejections for geometric violations.
    site_geometric_rejections: u64,
    /// Site rejections caused by kernel/backend errors.
    site_backend_rejections: u64,
    /// Proposals rejected by the Metropolis step.
    metropolis_rejections: u64,
    /// Proposals whose transition was accepted.
    accepted_transitions: u64,
    /// Proposals that ended in a hard failure.
    hard_failures: u64,
}
/// Unvalidated serde mirror of [`ProposalStatistics`]; checked by
/// `ProposalStatistics::from_wire` before becoming a `ProposalStatistics`.
#[derive(Deserialize)]
struct ProposalStatisticsWire {
    move_family_proposals: u64,
    observed_forward_sites: u64,
    no_site_proposals: u64,
    site_causality_rejections: u64,
    site_geometric_rejections: u64,
    site_backend_rejections: u64,
    metropolis_rejections: u64,
    accepted_transitions: u64,
    hard_failures: u64,
}
impl<'de> Deserialize<'de> for ProposalStatistics {
    /// Decodes the flat wire form, then rejects counter sets that fail the consistency
    /// checks in `ProposalStatistics::from_wire`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        ProposalStatisticsWire::deserialize(deserializer)
            .and_then(|wire| Self::from_wire(&wire).map_err(DeError::custom))
    }
}
impl ProposalStatistics {
#[must_use]
pub const fn new() -> Self {
Self {
move_family_proposals: 0,
observed_forward_sites: 0,
no_site_proposals: 0,
site_causality_rejections: 0,
site_geometric_rejections: 0,
site_backend_rejections: 0,
metropolis_rejections: 0,
accepted_transitions: 0,
hard_failures: 0,
}
}
#[cfg(test)]
#[expect(
clippy::too_many_arguments,
reason = "test and serde helpers need to preserve the flat telemetry wire shape"
)]
pub(crate) const fn from_validated_parts(
move_family_proposals: u64,
observed_forward_sites: u64,
no_site_proposals: u64,
site_causality_rejections: u64,
site_geometric_rejections: u64,
site_backend_rejections: u64,
metropolis_rejections: u64,
accepted_transitions: u64,
hard_failures: u64,
) -> Self {
Self {
move_family_proposals,
observed_forward_sites,
no_site_proposals,
site_causality_rejections,
site_geometric_rejections,
site_backend_rejections,
metropolis_rejections,
accepted_transitions,
hard_failures,
}
}
fn from_wire(wire: &ProposalStatisticsWire) -> Result<Self, String> {
let terminal_outcomes = [
wire.no_site_proposals,
wire.site_causality_rejections,
wire.site_geometric_rejections,
wire.site_backend_rejections,
wire.metropolis_rejections,
wire.accepted_transitions,
wire.hard_failures,
]
.into_iter()
.try_fold(0_u64, |total, count| {
total
.checked_add(count)
.ok_or_else(|| "proposal terminal outcome counters exceed u64::MAX".to_string())
})?;
if terminal_outcomes != wire.move_family_proposals {
return Err(format!(
"proposal terminal outcomes ({terminal_outcomes}) do not match move-family proposals ({})",
wire.move_family_proposals
));
}
if wire.move_family_proposals == 0 && wire.observed_forward_sites != 0 {
return Err(
"observed forward-site count must be zero when no move families were proposed"
.to_string(),
);
}
Ok(Self {
move_family_proposals: wire.move_family_proposals,
observed_forward_sites: wire.observed_forward_sites,
no_site_proposals: wire.no_site_proposals,
site_causality_rejections: wire.site_causality_rejections,
site_geometric_rejections: wire.site_geometric_rejections,
site_backend_rejections: wire.site_backend_rejections,
metropolis_rejections: wire.metropolis_rejections,
accepted_transitions: wire.accepted_transitions,
hard_failures: wire.hard_failures,
})
}
#[must_use]
pub const fn move_family_proposals(&self) -> u64 {
self.move_family_proposals
}
#[must_use]
pub const fn observed_forward_sites(&self) -> u64 {
self.observed_forward_sites
}
#[must_use]
pub const fn no_site_proposals(&self) -> u64 {
self.no_site_proposals
}
#[must_use]
pub const fn site_causality_rejections(&self) -> u64 {
self.site_causality_rejections
}
#[must_use]
pub const fn site_geometric_rejections(&self) -> u64 {
self.site_geometric_rejections
}
#[must_use]
pub const fn site_backend_rejections(&self) -> u64 {
self.site_backend_rejections
}
#[must_use]
pub const fn metropolis_rejections(&self) -> u64 {
self.metropolis_rejections
}
#[must_use]
pub const fn accepted_transitions(&self) -> u64 {
self.accepted_transitions
}
#[must_use]
pub const fn hard_failures(&self) -> u64 {
self.hard_failures
}
#[must_use]
pub const fn rejected_transitions(&self) -> u64 {
self.no_site_proposals
.saturating_add(self.site_causality_rejections)
.saturating_add(self.site_geometric_rejections)
.saturating_add(self.site_backend_rejections)
.saturating_add(self.metropolis_rejections)
}
pub(crate) fn record_move_family(&mut self, forward_sites: usize) {
self.move_family_proposals = self.move_family_proposals.saturating_add(1);
self.observed_forward_sites = self
.observed_forward_sites
.saturating_add(u64::try_from(forward_sites).unwrap_or(u64::MAX));
}
pub(crate) const fn record_no_site(&mut self) {
self.no_site_proposals = self.no_site_proposals.saturating_add(1);
}
pub(crate) const fn record_site_rejection(&mut self, rejection: &CdtProposalSiteRejection) {
match rejection {
CdtProposalSiteRejection::CausalityViolation => {
self.site_causality_rejections = self.site_causality_rejections.saturating_add(1);
}
CdtProposalSiteRejection::GeometricViolation => {
self.site_geometric_rejections = self.site_geometric_rejections.saturating_add(1);
}
CdtProposalSiteRejection::Kernel(_) => {
self.site_backend_rejections = self.site_backend_rejections.saturating_add(1);
}
}
}
pub(crate) const fn record_metropolis_rejection(&mut self) {
self.metropolis_rejections = self.metropolis_rejections.saturating_add(1);
}
pub(crate) const fn record_accepted_transition(&mut self) {
self.accepted_transitions = self.accepted_transitions.saturating_add(1);
}
pub(crate) const fn record_hard_failure(&mut self) {
self.hard_failures = self.hard_failures.saturating_add(1);
}
pub(crate) const fn extend(&mut self, other: &Self) {
self.move_family_proposals = self
.move_family_proposals
.saturating_add(other.move_family_proposals);
self.observed_forward_sites = self
.observed_forward_sites
.saturating_add(other.observed_forward_sites);
self.no_site_proposals = self
.no_site_proposals
.saturating_add(other.no_site_proposals);
self.site_causality_rejections = self
.site_causality_rejections
.saturating_add(other.site_causality_rejections);
self.site_geometric_rejections = self
.site_geometric_rejections
.saturating_add(other.site_geometric_rejections);
self.site_backend_rejections = self
.site_backend_rejections
.saturating_add(other.site_backend_rejections);
self.metropolis_rejections = self
.metropolis_rejections
.saturating_add(other.metropolis_rejections);
self.accepted_transitions = self
.accepted_transitions
.saturating_add(other.accepted_transitions);
self.hard_failures = self.hard_failures.saturating_add(other.hard_failures);
}
}
/// Unit tests for step-telemetry validation and proposal-statistics deserialization.
#[cfg(test)]
mod tests {
    use super::*;
    use std::assert_matches;

    // Builds a `NonZeroU32` step index for fixtures; a zero argument is a test bug.
    fn step_number(step: u32) -> NonZeroU32 {
        NonZeroU32::new(step).expect("test step number should be nonzero")
    }

    // Asserts that `error` is a checkpoint-resume failure whose reason satisfies
    // `matches_failure`; panics with the actual error otherwise.
    fn assert_checkpoint_failure(
        error: CdtError,
        matches_failure: impl FnOnce(&CheckpointResumeFailure) -> bool,
    ) {
        match error {
            CdtError::CheckpointResumeFailed { failure } => assert!(
                matches_failure(&failure),
                "unexpected checkpoint failure: {failure:?}"
            ),
            other => panic!("expected checkpoint resume failure, got {other:?}"),
        }
    }

    // Compares optional actions with the crate's `actions_match` tolerance; both sides must
    // agree on `Some`/`None`.
    fn assert_optional_actions_match(actual: Option<f64>, expected: Option<f64>) {
        match (actual, expected) {
            (Some(actual), Some(expected)) => assert!(
                actions_match(actual, expected),
                "expected {actual} to match {expected}"
            ),
            (None, None) => {}
            (actual, expected) => panic!("expected {actual:?} to match {expected:?}"),
        }
    }

    // An outcome valid for action_before = 4.0 must be rejected when attached to a step whose
    // action_before is 10.0.
    #[test]
    fn monte_carlo_step_new_revalidates_outcome_against_action_before() {
        let outcome = MonteCarloStepOutcome::accepted_transition(step_number(1), 4.0, 3.5, -0.5)
            .expect("test outcome should satisfy its original action context");
        let error = MonteCarloStep::new(step_number(1), MoveType::Move22, 10.0, outcome)
            .expect_err("step constructor should reject outcome inconsistent with action_before");
        assert_checkpoint_failure(error, |failure| {
            matches!(
                failure,
                CheckpointResumeFailure::StepActionAfterDeltaMismatch { step: 1 }
            )
        });
    }

    // Each outcome variant must survive a JSON round trip with all accessors unchanged.
    #[test]
    fn monte_carlo_step_serde_round_trips_valid_outcome_variants() {
        let steps = [
            MonteCarloStep::accepted_step(step_number(1), MoveType::Move22, 4.0, 3.5, -0.5)
                .expect("test accepted step should satisfy action invariants"),
            MonteCarloStep::rejected_proposal(step_number(2), MoveType::Move13Add, 3.5, Some(0.25))
                .expect("test rejected-proposal step should satisfy action invariants"),
            MonteCarloStep::no_proposal(step_number(3), MoveType::EdgeFlip, 3.5)
                .expect("test no-proposal step should satisfy action invariants"),
        ];
        for step in steps {
            let value = serde_json::to_value(&step).expect("step telemetry should serialize");
            let round_tripped: MonteCarloStep =
                serde_json::from_value(value).expect("valid step telemetry should deserialize");
            assert_eq!(round_tripped.step(), step.step());
            assert_eq!(round_tripped.move_type(), step.move_type());
            assert!(
                actions_match(round_tripped.action_before(), step.action_before()),
                "round-tripped action_before should match original"
            );
            assert_eq!(round_tripped.accepted(), step.accepted());
            assert_optional_actions_match(round_tripped.action_after(), step.action_after());
            assert_optional_actions_match(round_tripped.delta_action(), step.delta_action());
        }
    }

    // 4.0 + 0.0 != 3.5, so decoding must fail with the mismatch message.
    #[test]
    fn monte_carlo_step_deserialization_rejects_accepted_delta_mismatch() {
        let payload = r#"{
"step": 1,
"move_type": "Move22",
"action_before": 4.0,
"outcome": {
"Accepted": {
"action_after": 3.5,
"delta_action": 0.0
}
}
}"#;
        let error = serde_json::from_str::<MonteCarloStep>(payload)
            .expect_err("accepted action-after/delta mismatch should be rejected");
        assert!(
            error
                .to_string()
                .contains("action_after does not match delta_action"),
            "serde error should explain accepted-step action invariant, got {error}"
        );
    }

    // A rejected proposal with a null delta must decode as `RejectedProposal`, not collapse
    // into `NoProposal`.
    #[test]
    fn monte_carlo_step_deserialization_preserves_rejected_proposal_kind() {
        let payload = r#"{
"step": 2,
"move_type": "Move13Add",
"action_before": 3.5,
"outcome": {
"RejectedProposal": {
"delta_action": null
}
}
}"#;
        let step = serde_json::from_str::<MonteCarloStep>(payload)
            .expect("rejected proposal without delta should deserialize");
        assert_eq!(step.step().get(), 2);
        assert_eq!(step.move_type(), MoveType::Move13Add);
        assert_matches!(
            step.outcome(),
            MonteCarloStepOutcome::RejectedProposal(payload) if payload.delta_action().is_none()
        );
        assert!(!step.accepted());
        assert_eq!(step.action_after(), None);
        assert_eq!(step.delta_action(), None);
    }

    // Two terminal outcomes for one proposal: the sum (2) exceeds the proposal count (1).
    #[test]
    fn proposal_statistics_deserialization_rejects_terminal_outcomes_above_proposals() {
        let payload = r#"{
"move_family_proposals": 1,
"observed_forward_sites": 1,
"no_site_proposals": 1,
"site_causality_rejections": 0,
"site_geometric_rejections": 0,
"site_backend_rejections": 0,
"metropolis_rejections": 0,
"accepted_transitions": 1,
"hard_failures": 0
}"#;
        let error = serde_json::from_str::<ProposalStatistics>(payload)
            .expect_err("terminal outcomes above move-family proposals should be rejected");
        assert!(
            error.to_string().contains("terminal outcomes"),
            "serde error should explain proposal telemetry invariant, got {error}"
        );
    }

    // Two proposals but only one terminal outcome: the partition must be exact.
    #[test]
    fn proposal_statistics_deserialization_rejects_under_classified_proposals() {
        let payload = r#"{
"move_family_proposals": 2,
"observed_forward_sites": 1,
"no_site_proposals": 1,
"site_causality_rejections": 0,
"site_geometric_rejections": 0,
"site_backend_rejections": 0,
"metropolis_rejections": 0,
"accepted_transitions": 0,
"hard_failures": 0
}"#;
        let error = serde_json::from_str::<ProposalStatistics>(payload)
            .expect_err("under-classified move-family proposals should be rejected");
        assert!(
            error.to_string().contains("do not match"),
            "serde error should explain exact proposal telemetry invariant, got {error}"
        );
    }

    // Forward sites cannot be observed when no move family was proposed.
    #[test]
    fn proposal_statistics_deserialization_rejects_forward_sites_without_proposals() {
        let payload = r#"{
"move_family_proposals": 0,
"observed_forward_sites": 1,
"no_site_proposals": 0,
"site_causality_rejections": 0,
"site_geometric_rejections": 0,
"site_backend_rejections": 0,
"metropolis_rejections": 0,
"accepted_transitions": 0,
"hard_failures": 0
}"#;
        let error = serde_json::from_str::<ProposalStatistics>(payload)
            .expect_err("forward sites without proposals should be rejected");
        assert!(
            error.to_string().contains("forward-site"),
            "serde error should explain forward-site invariant, got {error}"
        );
    }

    // `extend` saturates move_family_proposals at u64::MAX while adding a new accepted
    // transition. The terminal outcomes (u64::MAX + 1) then overflow, and deserializing the
    // serialized result must report the overflow.
    #[test]
    fn proposal_statistics_deserialization_rejects_terminal_outcome_counter_overflow() {
        let mut stats = ProposalStatistics::from_validated_parts(
            u64::MAX,
            u64::MAX,
            u64::MAX,
            0,
            0,
            0,
            0,
            0,
            0,
        );
        let other = ProposalStatistics::from_validated_parts(1, 1, 0, 0, 0, 0, 0, 1, 0);
        stats.extend(&other);
        assert_eq!(stats.move_family_proposals(), u64::MAX);
        assert_eq!(stats.no_site_proposals(), u64::MAX);
        assert_eq!(stats.accepted_transitions(), 1);
        let serialized =
            serde_json::to_string(&stats).expect("saturated telemetry should serialize");
        let error = serde_json::from_str::<ProposalStatistics>(&serialized)
            .expect_err("overflowed terminal outcome partition should be rejected");
        assert!(
            error.to_string().contains("exceed u64::MAX"),
            "serde error should explain terminal outcome counter overflow, got {error}"
        );
    }
}