nuts-rs 0.18.0

Sample from unnormalized densities using Hamiltonian MCMC
Documentation
//! Adaptation strategy for when the coordinate transformation is learned from data rather than computed analytically.

use nuts_derive::Storable;
use nuts_storable::{HasDims, Storable};
use serde::{Deserialize, Serialize};

use crate::adapt_strategy::CombinedCollector;
use crate::chain::AdaptStrategy;
use crate::dynamics::{Hamiltonian, Point, State, TransformedHamiltonian, TransformedPoint};
use crate::nuts::{Collector, NutsOptions, SampleInfo};
use crate::sampler_stats::{SamplerStats, StatsDims};
use crate::stepsize::AcceptanceRateCollector;
use crate::stepsize::{StepSizeSettings, Strategy as StepSizeStrategy};
use crate::transform::ExternalTransformation;
use crate::{Math, NutsError};

/// Tuning options for [`ExternalTransformAdaptation`], the adaptation strategy that
/// adapts the step size and periodically retrains a learned coordinate
/// transformation during the tuning phase.
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub struct FlowSettings {
    /// Fraction of the tuning draws, taken from the end of tuning, during which the
    /// transformation is no longer retrained and only the step size adapts. The
    /// early (retraining) window ends at draw `floor(num_tune * (1 - step_size_window))`.
    pub step_size_window: f64,
    /// Retraining interval, in draws, for the transformation once the first 100
    /// tuning draws have passed (before that it is retrained every 10 draws).
    /// Retraining only happens inside the early window. A value of `0` disables
    /// retraining after draw 100, because `n.is_multiple_of(0)` is `false` for `n > 0`.
    pub transform_update_freq: u64,
    /// If `true`, training data is taken from the end point of every non-divergent
    /// leapfrog step; if `false`, only from each draw's state as passed to
    /// `register_draw`.
    pub use_orbit_for_training: bool,
    /// Settings forwarded to the step-size adaptation strategy.
    pub step_size_settings: StepSizeSettings,
    /// Points whose energy error is non-finite or greater than this value are
    /// excluded from the transformation's training data.
    pub transform_train_max_energy_error: f64,
}

/// Backwards-compatible alias for [`FlowSettings`], kept so that code using the
/// old name keeps compiling (with a deprecation warning).
// NOTE(review): `since = "0.0.0"` looks like a placeholder — confirm the crate
// version in which this alias was actually deprecated.
#[deprecated(since = "0.0.0", note = "Use FlowSettings instead")]
pub type TransformedSettings = FlowSettings;

impl Default for FlowSettings {
    fn default() -> Self {
        Self {
            step_size_window: 0.07f64,
            transform_update_freq: 128,
            use_orbit_for_training: false,
            transform_train_max_energy_error: 20f64,
            step_size_settings: Default::default(),
        }
    }
}

/// Adaptation strategy that tunes the step size and, during the early part of
/// the tuning phase, periodically retrains an [`ExternalTransformation`] from
/// collected draws.
pub struct ExternalTransformAdaptation {
    /// Step-size adaptation state.
    step_size: StepSizeStrategy,
    /// Tuning options supplied at construction.
    options: FlowSettings,
    /// Number of tuning draws; once `draw >= num_tune`, tuning is switched off.
    num_tune: u64,
    /// First draw index of the final, step-size-only window, computed as
    /// `floor(num_tune * (1 - options.step_size_window))`. Despite the name this is
    /// a draw index (the boundary), not the length of the window.
    final_window_size: u64,
    /// Whether the strategy is still in the tuning phase.
    tuning: bool,
    /// Chain index, forwarded to `init_transformation`.
    chain: u64,
}

/// Statistics reported by [`ExternalTransformAdaptation`]: the tuning flag plus
/// the step-size strategy's own statistics.
#[derive(Debug, Storable)]
pub struct Stats<P: HasDims, S: Storable<P>> {
    // Value of the strategy's `tuning` flag at the time the stats were extracted.
    tuning: bool,
    // Step-size statistics; `#[storable(flatten)]` presumably stores their fields
    // inline rather than as a nested group — confirm against `nuts_derive`.
    #[storable(flatten)]
    pub step_size: S,
    // Type marker for the dimension provider `P`; excluded from storage.
    #[storable(ignore)]
    _phantom: std::marker::PhantomData<fn() -> P>,
}

impl<M: Math> SamplerStats<M> for ExternalTransformAdaptation {
    type Stats = Stats<StatsDims, <StepSizeStrategy as SamplerStats<M>>::Stats>;
    type StatsOptions = ();

    /// Snapshots the current tuning flag together with the step-size statistics.
    fn extract_stats(&self, math: &mut M, _opt: Self::StatsOptions) -> Self::Stats {
        let step_size = self.step_size.extract_stats(math, ());
        Stats {
            tuning: self.tuning,
            step_size,
            _phantom: std::marker::PhantomData,
        }
    }
}

/// Collector that accumulates positions, gradients and log densities used as
/// training data for the learned transformation.
pub struct DrawCollector<M: Math> {
    /// Positions of the points kept for training.
    draws: Vec<M::Vector>,
    /// Gradients at the corresponding entries of `draws`.
    grads: Vec<M::Vector>,
    /// Log densities at the corresponding entries of `draws`.
    logps: Vec<f64>,
    /// If `true`, record leapfrog end points; otherwise record the draws.
    collect_orbit: bool,
    /// Points whose energy error is larger than this (or non-finite) are skipped.
    max_energy_error: f64,
}

impl<M: Math> DrawCollector<M> {
    /// Creates an empty collector.
    ///
    /// `collect_orbit` selects whether leapfrog end points (`true`) or draws
    /// (`false`) are recorded; `max_energy_error` is the rejection threshold
    /// applied to each candidate point. `_math` is currently unused.
    fn new(_math: &mut M, collect_orbit: bool, max_energy_error: f64) -> Self {
        let (draws, grads, logps) = (Vec::new(), Vec::new(), Vec::new());
        Self {
            draws,
            grads,
            logps,
            collect_orbit,
            max_energy_error,
        }
    }
}

impl<M: Math> DrawCollector<M> {
    /// Appends a copy of `point`'s position, gradient and log density to the
    /// training buffers, unless the point is unsuitable for training.
    ///
    /// A point is rejected when its energy error is non-finite or exceeds
    /// `max_energy_error`, or when its position or gradient contains a
    /// non-finite entry.
    fn record_point<P: Point<M>>(&mut self, math: &mut M, point: &P) {
        let energy_error = point.energy_error();
        // The `is_finite` test also rejects NaN, which `>` alone would let through.
        if !energy_error.is_finite() || energy_error > self.max_energy_error {
            return;
        }
        if !math.array_all_finite(point.position()) || !math.array_all_finite(point.gradient()) {
            return;
        }

        self.draws.push(math.copy_array(point.position()));
        self.grads.push(math.copy_array(point.gradient()));
        self.logps.push(point.logp());
    }
}

impl<M: Math, P: Point<M>> Collector<M, P> for DrawCollector<M> {
    /// In orbit mode, records the end point of every non-divergent leapfrog step.
    fn register_leapfrog(
        &mut self,
        math: &mut M,
        _start: &State<M, P>,
        end: &State<M, P>,
        divergence_info: Option<&crate::DivergenceInfo>,
    ) {
        if self.collect_orbit && divergence_info.is_none() {
            self.record_point(math, end.point());
        }
    }

    /// Outside orbit mode, records the state reported for each draw.
    fn register_draw(&mut self, math: &mut M, state: &State<M, P>, _info: &SampleInfo) {
        if !self.collect_orbit {
            self.record_point(math, state.point());
        }
    }
}

impl<M: Math> AdaptStrategy<M> for ExternalTransformAdaptation {
    type Hamiltonian = TransformedHamiltonian<M, ExternalTransformation<M>>;

    type Collector =
        CombinedCollector<M, TransformedPoint<M>, AcceptanceRateCollector, DrawCollector<M>>;

    type Options = FlowSettings;

    /// Builds the strategy. The transformation is only retrained before draw
    /// `floor(num_tune * (1 - options.step_size_window))`; after that, only the
    /// step size adapts until tuning ends.
    fn new(_math: &mut M, options: Self::Options, num_tune: u64, chain: u64) -> Self {
        let early_fraction = 1f64 - options.step_size_window;
        let final_window_size = ((num_tune as f64) * early_fraction).floor() as u64;
        Self {
            step_size: StepSizeStrategy::new(options.step_size_settings),
            options,
            num_tune,
            final_window_size,
            tuning: true,
            chain,
        }
    }

    /// Initialises the learned transformation for this chain at `position`,
    /// then initialises the step size using the same `hamiltonian`.
    ///
    /// # Errors
    /// Propagates errors from either initialisation step.
    fn init<R: rand::Rng + ?Sized>(
        &mut self,
        math: &mut M,
        options: &mut NutsOptions,
        hamiltonian: &mut Self::Hamiltonian,
        position: &[f64],
        rng: &mut R,
    ) -> Result<(), NutsError> {
        let chain = self.chain;
        hamiltonian.init_transformation(rng, math, position, chain)?;
        self.step_size.init(math, options, hamiltonian, position, rng)?;
        Ok(())
    }

    /// Adapts after draw number `draw`.
    ///
    /// Phases, by draw index:
    /// - `draw >= num_tune`: tuning is over; `update_stepsize` is called with
    ///   `true` and the `tuning` flag is cleared.
    /// - `final_window_size <= draw < num_tune`: only the step size adapts, using
    ///   the late estimator.
    /// - `draw < final_window_size`: the transformation is retrained every 10
    ///   draws before draw 100, and every `transform_update_freq` draws
    ///   afterwards; the step size adapts with the early estimator.
    ///
    /// # Errors
    /// Propagates any error returned by `update_params`.
    fn adapt<R: rand::Rng + ?Sized>(
        &mut self,
        math: &mut M,
        _options: &mut NutsOptions,
        hamiltonian: &mut Self::Hamiltonian,
        draw: u64,
        collector: &Self::Collector,
        _state: &State<M, <Self::Hamiltonian as Hamiltonian<M>>::Point>,
        rng: &mut R,
    ) -> Result<(), NutsError> {
        // Until this draw, the transformation is retrained on a short fixed interval.
        const FAST_RETRAIN_UNTIL: u64 = 100;
        const FAST_RETRAIN_EVERY: u64 = 10;

        self.step_size.update(&collector.collector1);

        if draw >= self.num_tune {
            // Needed for step size jitter
            self.step_size.update_stepsize(rng, hamiltonian, true);
            self.tuning = false;
            return Ok(());
        }

        if draw >= self.final_window_size {
            // Final window: the transformation is frozen, only the step size adapts.
            // `draw < num_tune` here, so `draw + 1` cannot overflow.
            self.step_size.update_estimator_late();
            let is_last = draw + 1 == self.num_tune;
            self.step_size.update_stepsize(rng, hamiltonian, is_last);
            return Ok(());
        }

        // Early window. Draw 0 never retrains; from FAST_RETRAIN_UNTIL on, `draw`
        // is always positive, so only the interval check matters.
        let retrain = if draw < FAST_RETRAIN_UNTIL {
            draw > 0 && draw.is_multiple_of(FAST_RETRAIN_EVERY)
        } else {
            draw.is_multiple_of(self.options.transform_update_freq)
        };
        if retrain {
            let training = &collector.collector2;
            hamiltonian.update_params(
                math,
                rng,
                training.draws.iter(),
                training.grads.iter(),
                training.logps.iter(),
            )?;
        }
        self.step_size.update_estimator_early();
        self.step_size.update_stepsize(rng, hamiltonian, false);
        Ok(())
    }

    /// Creates a collector that pairs the step size's acceptance-rate collector
    /// with a [`DrawCollector`] gathering transformation training data.
    fn new_collector(&self, math: &mut M) -> Self::Collector {
        let training = DrawCollector::new(
            math,
            self.options.use_orbit_for_training,
            self.options.transform_train_max_energy_error,
        );
        Self::Collector::new(self.step_size.new_collector(), training)
    }

    /// Returns whether the strategy is still in the tuning phase.
    fn is_tuning(&self) -> bool {
        self.tuning
    }

    /// Returns the step strategy's `last_n_steps` counter.
    fn last_num_steps(&self) -> u64 {
        self.step_size.last_n_steps
    }
}