#![forbid(unsafe_code)]
use crate::cdt::action::ActionConfig;
use crate::cdt::ergodic_moves::{ErgodicsSystem, MoveResult, MoveType, proposal_site_count};
use crate::errors::{CdtError, CdtResult, MetropolisMoveApplicationFailure};
use crate::geometry::CdtTriangulation2D;
use markov_chain_monte_carlo::{
Chain, ChainCheckpoint, DelayedProposal, DiscreteProposalRatio, McmcError, Target,
};
use rand::Rng;
use std::error::Error;
use std::fmt;
use std::hint::cold_path;
use super::helpers::{action_for, proposed_delta_action, simplex_counts, validate_temperature};
use super::telemetry::{CdtProposalSiteRejection, ProposalStatistics};
/// Target distribution over 2D CDT triangulations.
///
/// The unnormalised log-probability of a state is its action (evaluated with
/// `action_config`) divided by `temperature` and negated.
pub struct CdtTarget {
    /// Configuration used to evaluate the action of a triangulation.
    action_config: ActionConfig,
    /// Sampling temperature; checked by `validate_temperature` in [`CdtTarget::new`].
    temperature: f64,
}

impl CdtTarget {
    /// Creates a target from an action configuration and a sampling temperature.
    ///
    /// `action_config.validate()` is called before the temperature is checked;
    /// its failure behaviour is defined by `ActionConfig`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `validate_temperature` when `temperature`
    /// is rejected.
    pub fn new(action_config: ActionConfig, temperature: f64) -> CdtResult<Self> {
        action_config.validate();
        validate_temperature(temperature)?;
        let target = Self {
            temperature,
            action_config,
        };
        Ok(target)
    }
}
impl Target<CdtTriangulation2D> for CdtTarget {
    /// Returns `-S / temperature`, where `S` is the action computed from the
    /// state's vertex, edge and triangle counts.
    fn log_prob(&self, state: &CdtTriangulation2D) -> f64 {
        let simplices = simplex_counts(state);
        let action = self.action_config.calculate_action(
            simplices.vertices,
            simplices.edges,
            simplices.triangles,
        );
        -action / self.temperature
    }
}
/// A fully evaluated candidate move, such as the one built by
/// `propose_concrete_plan`.
///
/// It owns an already-mutated copy of the triangulation plus the bookkeeping
/// used by the acceptance step (actions and forward/reverse site counts).
#[derive(Debug, Clone)]
pub struct CdtProposalPlan {
    /// Move family applied to produce `proposed_state`.
    pub(crate) move_type: MoveType,
    /// Action of the state the move was proposed from.
    pub(crate) action_before: f64,
    /// Action of `proposed_state`, if it was evaluated.
    pub(crate) action_after: Option<f64>,
    /// `action_after - action_before`, if `action_after` was evaluated.
    pub(crate) delta_action: Option<f64>,
    /// Number of candidate sites for `move_type` in the starting state.
    pub(crate) forward_site_count: usize,
    /// Number of candidate sites for the reverse move in `proposed_state`.
    pub(crate) reverse_site_count: usize,
    /// Copy of the starting state with the move applied.
    pub(crate) proposed_state: CdtTriangulation2D,
}

impl CdtProposalPlan {
    /// Move family this plan applies.
    #[must_use]
    pub const fn move_type(&self) -> MoveType {
        self.move_type
    }

    /// Action of the state the plan was built from.
    #[must_use]
    pub const fn action_before(&self) -> f64 {
        self.action_before
    }

    /// Action of the proposed state, when it was evaluated.
    #[must_use]
    pub const fn action_after(&self) -> Option<f64> {
        self.action_after
    }

    /// Change in action caused by the move, when it was evaluated.
    #[must_use]
    pub const fn delta_action(&self) -> Option<f64> {
        self.delta_action
    }
}
/// Copyable summary of one proposal attempt, exposed as the proposal's `Info`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct CdtProposalInfo {
    /// Move family that was drawn for the attempt.
    pub move_type: MoveType,
    /// Action of the state the attempt started from.
    pub action_before: f64,
    /// Action of the proposed state; `None` when it was not evaluated
    /// (for example when no plan could be built).
    pub action_after: Option<f64>,
    /// `action_after - action_before`; `None` whenever `action_after` is `None`
    /// for summaries produced in this module.
    pub delta_action: Option<f64>,
}
/// Error raised by [`CdtProposal`] while building a proposal plan.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum CdtProposalError {
    /// Applying the selected move to the cloned state reported a hard failure.
    ApplicationFailed {
        /// Move family that was being applied.
        move_type: MoveType,
        /// Attempt number reported by the move-application step.
        attempt: usize,
        /// Underlying error from the move system.
        source: CdtError,
    },
}

impl fmt::Display for CdtProposalError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `#[non_exhaustive]` does not apply inside the defining crate, so this
        // destructuring is irrefutable; adding a variant turns it into a compile error.
        let Self::ApplicationFailed {
            move_type,
            attempt,
            source,
        } = self;
        write!(
            f,
            "failed to apply {move_type:?} on attempt {attempt}: {source}"
        )
    }
}

impl Error for CdtProposalError {
    /// Exposes the wrapped move-system error.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        let Self::ApplicationFailed { source, .. } = self;
        Some(source)
    }
}

impl From<CdtProposalError> for CdtError {
    /// Maps a proposal error onto the crate-wide `ProposalApplicationFailed` variant.
    fn from(error: CdtProposalError) -> Self {
        let CdtProposalError::ApplicationFailed {
            move_type,
            attempt,
            source,
        } = error;
        Self::ProposalApplicationFailed {
            move_type,
            attempt,
            source: MetropolisMoveApplicationFailure::from(source),
        }
    }
}
/// Proposal kernel that drives the 2D CDT ergodic moves through
/// `DelayedProposal`.
///
/// Alongside plan generation it keeps telemetry about the most recent
/// `propose_plan` call, readable through the crate-internal accessors below.
pub struct CdtProposal {
    /// Configuration used to evaluate actions before and after a move.
    action_config: ActionConfig,
    /// Move system used for move-family and site selection and for applying moves.
    moves: ErgodicsSystem,
    /// Summary of the most recent proposal attempt, whatever its outcome.
    last_step_info: Option<CdtProposalInfo>,
    /// Summary of the most recent attempt that produced no plan; taken by
    /// `no_plan_info` and cleared by attempts that yield a plan or an error.
    last_no_plan_info: Option<CdtProposalInfo>,
    /// Statistics gathered during the most recent proposal attempt.
    last_proposal_stats: ProposalStatistics,
}

impl CdtProposal {
    /// Creates a kernel backed by `ErgodicsSystem::new()`.
    #[must_use]
    pub fn new(action_config: ActionConfig) -> Self {
        action_config.validate();
        Self::with_fresh_telemetry(action_config, ErgodicsSystem::new())
    }

    /// Creates a kernel backed by `ErgodicsSystem::with_seed(seed)`
    /// (presumably giving reproducible move selection — confirm in `ErgodicsSystem`).
    #[must_use]
    pub fn with_seed(action_config: ActionConfig, seed: u64) -> Self {
        action_config.validate();
        Self::with_fresh_telemetry(action_config, ErgodicsSystem::with_seed(seed))
    }

    /// Creates a kernel around an existing move system.
    pub(crate) fn from_ergodics(action_config: ActionConfig, moves: ErgodicsSystem) -> Self {
        action_config.validate();
        Self::with_fresh_telemetry(action_config, moves)
    }

    /// Assembles a kernel with empty telemetry.
    ///
    /// Every public constructor validates `action_config` before calling this.
    fn with_fresh_telemetry(action_config: ActionConfig, moves: ErgodicsSystem) -> Self {
        Self {
            action_config,
            moves,
            last_step_info: None,
            last_no_plan_info: None,
            last_proposal_stats: ProposalStatistics::new(),
        }
    }

    /// Consumes the kernel and hands back its move system; telemetry is dropped.
    pub(crate) fn into_ergodics(self) -> ErgodicsSystem {
        self.moves
    }

    /// Statistics recorded during the most recent proposal attempt.
    pub(crate) const fn last_proposal_stats(&self) -> &ProposalStatistics {
        &self.last_proposal_stats
    }

    /// Summary of the most recent proposal attempt, if one was made.
    pub(crate) const fn last_step_info(&self) -> Option<CdtProposalInfo> {
        self.last_step_info
    }
}
impl DelayedProposal<CdtTriangulation2D> for CdtProposal {
type Plan = CdtProposalPlan;
type Info = CdtProposalInfo;
type Error = CdtProposalError;
fn propose_plan<R: Rng + ?Sized>(
&mut self,
state: &CdtTriangulation2D,
_rng: &mut R,
) -> Result<Option<Self::Plan>, Self::Error> {
let move_type = self.moves.select_random_move();
let action_before = action_for(&self.action_config, state);
let no_plan_info = CdtProposalInfo {
move_type,
action_before,
action_after: None,
delta_action: None,
};
let mut proposal_stats = ProposalStatistics::new();
let plan = match propose_concrete_plan(
state,
&mut self.moves,
&mut proposal_stats,
&self.action_config,
move_type,
action_before,
) {
Ok(Some(plan)) => plan,
Ok(None) => {
self.last_step_info = Some(no_plan_info);
self.last_no_plan_info = Some(no_plan_info);
self.last_proposal_stats = proposal_stats;
cold_path();
return Ok(None);
}
Err(err) => {
self.last_step_info = Some(no_plan_info);
self.last_no_plan_info = None;
proposal_stats.record_hard_failure();
self.last_proposal_stats = proposal_stats;
cold_path();
return Err(CdtProposalError::ApplicationFailed {
move_type,
attempt: err.attempt,
source: err.source,
});
}
};
self.last_step_info = Some(CdtProposalInfo {
move_type: plan.move_type,
action_before: plan.action_before,
action_after: plan.action_after,
delta_action: plan.delta_action,
});
self.last_no_plan_info = None;
self.last_proposal_stats = proposal_stats;
Ok(Some(plan))
}
fn no_plan_info(&mut self) -> Option<Self::Info> {
self.last_no_plan_info.take()
}
fn proposed_log_prob<T: Target<CdtTriangulation2D>>(
&self,
_state: &CdtTriangulation2D,
plan: &Self::Plan,
target: &T,
) -> Result<f64, Self::Error> {
Ok(plan
.action_after
.map_or(f64::NEG_INFINITY, |_| target.log_prob(&plan.proposed_state)))
}
fn log_q_ratio(
&self,
state: &CdtTriangulation2D,
plan: &Self::Plan,
) -> Result<f64, Self::Error> {
Ok(concrete_log_q_ratio(state, plan))
}
fn info(&self, plan: &Self::Plan) -> Self::Info {
CdtProposalInfo {
move_type: plan.move_type,
action_before: plan.action_before,
action_after: plan.action_after,
delta_action: plan.delta_action,
}
}
fn commit<R: Rng + ?Sized>(
&mut self,
state: &mut CdtTriangulation2D,
plan: Self::Plan,
_rng: &mut R,
) -> Result<(), Self::Error> {
*state = plan.proposed_state;
Ok(())
}
}
/// Hard failure raised while applying a move inside `propose_concrete_plan`.
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct MoveApplicationError {
    /// Attempt number on which the failure happened.
    pub(crate) attempt: usize,
    /// Error reported by the move system.
    pub(crate) source: CdtError,
}
/// Builds an evaluated plan for `move_type`, starting from `state`.
///
/// The steps are:
/// 1. Return `Ok(None)` if `proposed_delta_action` yields `None` for the
///    current simplex counts.
/// 2. Ask `moves` for a site, and return `Ok(None)` if none is available.
/// 3. Apply the move to a clone of `state`. Afterwards, `moves`' statistics
///    are restored to their pre-application snapshot, so this tentative
///    application leaves them unchanged.
/// 4. On success, evaluate the new action and count the reverse move's sites
///    in the proposed state.
///
/// Every `Ok(None)` exit records its reason in `proposal_stats`. Hard failures
/// are not recorded here and are left to the caller.
///
/// # Errors
///
/// Returns [`MoveApplicationError`] with `attempt: 1` when the move reports
/// `MoveResult::HardFailure`.
pub(crate) fn propose_concrete_plan(
    state: &CdtTriangulation2D,
    moves: &mut ErgodicsSystem,
    proposal_stats: &mut ProposalStatistics,
    action_config: &ActionConfig,
    move_type: MoveType,
    action_before: f64,
) -> Result<Option<CdtProposalPlan>, MoveApplicationError> {
    let counts = simplex_counts(state);
    if proposed_delta_action(action_config, counts, move_type).is_none() {
        proposal_stats.record_move_family(0);
        proposal_stats.record_no_site();
        return Ok(None);
    }

    let selection = moves.select_proposal_site(state, move_type);
    let forward_site_count = selection.site_count;
    proposal_stats.record_move_family(forward_site_count);
    let Some(site) = selection.site else {
        proposal_stats.record_no_site();
        return Ok(None);
    };

    let mut proposed_state = state.clone();
    let saved_move_stats = moves.stats().clone();
    let outcome = moves.apply_proposal_site(&mut proposed_state, move_type, site);
    moves.replace_stats(saved_move_stats);

    // `None` means the move succeeded; `Some` carries the reason it was rejected.
    let rejection = match outcome {
        MoveResult::Success => None,
        MoveResult::HardFailure(source) => {
            return Err(MoveApplicationError { attempt: 1, source });
        }
        MoveResult::CausalityViolation => Some(CdtProposalSiteRejection::CausalityViolation),
        MoveResult::GeometricViolation => Some(CdtProposalSiteRejection::GeometricViolation),
        MoveResult::Rejected(reason) => Some(CdtProposalSiteRejection::Kernel(reason)),
    };
    if let Some(reason) = rejection {
        proposal_stats.record_site_rejection(&reason);
        return Ok(None);
    }

    let action_after = action_for(action_config, &proposed_state);
    let reverse_site_count = proposal_site_count(&proposed_state, reverse_move_type(move_type));
    Ok(Some(CdtProposalPlan {
        move_type,
        action_before,
        action_after: Some(action_after),
        delta_action: Some(action_after - action_before),
        forward_site_count,
        reverse_site_count,
        proposed_state,
    }))
}
/// Log proposal ratio for `plan`, computed by `DiscreteProposalRatio` from the
/// forward and reverse site counts.
///
/// Returns `-inf` when `DiscreteProposalRatio::from_counts` produces no ratio
/// (presumably when a count is zero — confirm in `markov_chain_monte_carlo`).
/// The current state is not consulted.
pub(crate) fn concrete_log_q_ratio(_state: &CdtTriangulation2D, plan: &CdtProposalPlan) -> f64 {
    let (forward, reverse) = (plan.forward_site_count, plan.reverse_site_count);
    DiscreteProposalRatio::from_counts(forward, reverse)
        .map_or(f64::NEG_INFINITY, DiscreteProposalRatio::log_q_ratio)
}
/// Rebuilds a chain from `checkpoint` against `target` and keeps only its state.
///
/// # Errors
///
/// Propagates the [`McmcError`] returned by `Chain::from_checkpoint`.
pub(crate) fn restore_checkpoint_state(
    checkpoint: ChainCheckpoint<CdtTriangulation2D>,
    target: &CdtTarget,
) -> Result<CdtTriangulation2D, McmcError> {
    let chain = Chain::from_checkpoint(checkpoint, target)?;
    Ok(chain.into_state())
}
/// Move family that undoes `move_type`.
///
/// (1,3)-addition and (3,1)-removal undo each other. The (2,2) move and edge
/// flips are their own inverses.
const fn reverse_move_type(move_type: MoveType) -> MoveType {
    match move_type {
        MoveType::Move13Add => MoveType::Move31Remove,
        MoveType::Move31Remove => MoveType::Move13Add,
        self_inverse @ (MoveType::Move22 | MoveType::EdgeFlip) => self_inverse,
    }
}