non_convex_opt/algorithms/parallel_tempering/
swap_manager.rs

1use nalgebra::{allocator::Allocator, DefaultAllocator, Dim, DimSub, OMatrix, OVector, RealField};
2use rand::{rngs::StdRng, Rng, SeedableRng};
3
4use crate::algorithms::parallel_tempering::metropolis_hastings::MetropolisHastings;
5use crate::utils::opt_prob::FloatNumber as FloatNum;
6
7/// Manages replica exchange (swapping) operations in parallel tempering
8pub struct SwapManager<T, N, D>
9where
10    T: FloatNum + RealField + Send + Sync,
11    N: Dim + Send + Sync,
12    D: Dim + Send + Sync + DimSub<nalgebra::Const<1>>,
13    OVector<T, D>: Send + Sync,
14    OVector<T, N>: Send + Sync,
15    OVector<bool, N>: Send + Sync,
16    OMatrix<T, N, D>: Send + Sync,
17    OMatrix<T, D, D>: Send + Sync,
18    DefaultAllocator: Allocator<D>
19        + Allocator<N, D>
20        + Allocator<N>
21        + Allocator<D, D>
22        + Allocator<<D as DimSub<nalgebra::Const<1>>>::Output>,
23{
24    metropolis_hastings: MetropolisHastings<T, D>,
25    swap_acceptance_rates: Vec<f64>,
26    adaptive_swapping: bool,
27    random_swap_probability: f64,
28    swap_rate_smoothing: f64,
29    _phantom: std::marker::PhantomData<N>,
30    rng: StdRng,
31}
32
33impl<T, N, D> SwapManager<T, N, D>
34where
35    T: FloatNum + RealField + Send + Sync,
36    N: Dim + Send + Sync,
37    D: Dim + Send + Sync + DimSub<nalgebra::Const<1>>,
38    OVector<T, D>: Send + Sync,
39    OVector<T, N>: Send + Sync,
40    OVector<bool, N>: Send + Sync,
41    OMatrix<T, N, D>: Send + Sync,
42    OMatrix<T, D, D>: Send + Sync,
43    DefaultAllocator: Allocator<D>
44        + Allocator<N, D>
45        + Allocator<N>
46        + Allocator<D, D>
47        + Allocator<<D as DimSub<nalgebra::Const<1>>>::Output>,
48{
49    pub fn new(
50        metropolis_hastings: MetropolisHastings<T, D>,
51        num_replicas: usize,
52        adaptive_swapping: bool,
53        random_swap_probability: f64,
54        swap_rate_smoothing: f64,
55        seed: u64,
56    ) -> Self {
57        Self {
58            metropolis_hastings,
59            swap_acceptance_rates: vec![0.3; num_replicas.saturating_sub(1)],
60            adaptive_swapping,
61            random_swap_probability,
62            swap_rate_smoothing,
63            _phantom: std::marker::PhantomData,
64            rng: StdRng::seed_from_u64(seed),
65        }
66    }
67
68    pub fn swap_adjacent_replicas(
69        &mut self,
70        populations: &mut [OMatrix<T, N, D>],
71        fitnesses: &mut [OVector<T, N>],
72        constraints: &mut [OVector<bool, N>],
73        step_sizes: &mut [Vec<OMatrix<T, D, D>>],
74        temperatures: &[T],
75    ) {
76        let n = populations.len();
77        if n < 2 {
78            return;
79        }
80
81        for i in 0..n - 1 {
82            let t_i = temperatures[i];
83            let t_j = temperatures[i + 1];
84
85            let swap_accepted = self.metropolis_hastings.accept_replica_exchange::<N>(
86                &fitnesses[i],
87                &fitnesses[i + 1],
88                t_i,
89                t_j,
90            );
91
92            if swap_accepted {
93                populations.swap(i, i + 1);
94                fitnesses.swap(i, i + 1);
95                constraints.swap(i, i + 1);
96                step_sizes.swap(i, i + 1);
97            }
98
99            let current_success = if swap_accepted { 1.0 } else { 0.0 };
100            self.swap_acceptance_rates[i] = self.swap_rate_smoothing * current_success
101                + (1.0 - self.swap_rate_smoothing) * self.swap_acceptance_rates[i];
102        }
103
104        if self.adaptive_swapping && self.rng.random::<f64>() < self.random_swap_probability {
105            self.attempt_random_swap(
106                populations,
107                fitnesses,
108                constraints,
109                step_sizes,
110                temperatures,
111            );
112        }
113    }
114
115    /// Attempt random non-adjacent swaps
116    fn attempt_random_swap(
117        &mut self,
118        populations: &mut [OMatrix<T, N, D>],
119        fitnesses: &mut [OVector<T, N>],
120        constraints: &mut [OVector<bool, N>],
121        step_sizes: &mut [Vec<OMatrix<T, D, D>>],
122        temperatures: &[T],
123    ) {
124        let n = populations.len();
125        if n < 3 {
126            return;
127        }
128
129        let i = self.rng.random_range(0..n);
130        let mut j = self.rng.random_range(0..n);
131
132        while j == i || j == i.wrapping_sub(1) || j == i + 1 {
133            j = self.rng.random_range(0..n);
134        }
135
136        let t_i = temperatures[i];
137        let t_j = temperatures[j];
138
139        if self.metropolis_hastings.accept_replica_exchange::<N>(
140            &fitnesses[i],
141            &fitnesses[j],
142            t_i,
143            t_j,
144        ) {
145            populations.swap(i, j);
146            fitnesses.swap(i, j);
147            constraints.swap(i, j);
148            step_sizes.swap(i, j);
149        }
150    }
151
152    pub fn get_swap_acceptance_rates(&self) -> &[f64] {
153        &self.swap_acceptance_rates
154    }
155}