Skip to main content

nuts_rs/math/
math.rs

1//! Define the backend interface that decouples the sampler from any particular hardware or logp implementation.
2
3use std::{error::Error, fmt::Debug};
4
5use nuts_storable::{HasDims, Storable, Value};
6use rand::Rng;
7
8/// Errors that happen when we evaluate the logp and gradient function
/// Failure raised while evaluating the log density and its gradient.
///
/// Implementors must also be standard [`std::error::Error`] values that can be
/// sent to other threads.
pub trait LogpError
where
    Self: std::error::Error + Send,
{
    /// Classify this error.
    ///
    /// Returning `true` means sampling may continue and the failed evaluation
    /// is seen as a divergence; returning `false` means the error is
    /// unrecoverable and stops sampling.
    fn is_recoverable(&self) -> bool;
}
14
15pub trait Math: HasDims {
16    type Vector: Debug;
17    type EigVectors: Debug;
18    type EigValues: Debug;
19    type LogpErr: Debug + Send + Sync + LogpError + Sized + 'static;
20    type Err: Debug + Send + Sync + Error + 'static;
21    type FlowParameters;
22    type ExpandedVector: Storable<Self>;
23
24    fn new_array(&mut self) -> Self::Vector;
25
26    fn copy_array(&mut self, array: &Self::Vector) -> Self::Vector {
27        let mut copy = self.new_array();
28        self.copy_into(array, &mut copy);
29        copy
30    }
31
32    fn new_eig_vectors<'a>(
33        &'a mut self,
34        vals: impl ExactSizeIterator<Item = &'a [f64]>,
35    ) -> Self::EigVectors;
36    fn new_eig_values(&mut self, vals: &[f64]) -> Self::EigValues;
37
38    /// Compute the unnormalized log probability density of the posterior
39    ///
40    /// This needs to be implemnted by users of the library to define
41    /// what distribution the users wants to sample from.
42    ///
43    /// Errors during that computation can be recoverable or non-recoverable.
44    /// If a non-recoverable error occurs during sampling, the sampler will
45    /// stop and return an error.
46    fn logp_array(
47        &mut self,
48        position: &Self::Vector,
49        gradient: &mut Self::Vector,
50    ) -> Result<f64, Self::LogpErr>;
51
52    fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, Self::LogpErr>;
53
54    fn init_position<R: Rng + ?Sized>(
55        &mut self,
56        rng: &mut R,
57        position: &mut Self::Vector,
58        gradient: &mut Self::Vector,
59    ) -> Result<f64, Self::LogpErr>;
60
61    /// Expand a vector into a larger representation, to for instance
62    /// compute deterministic values that are to be stored in the trace.
63    fn expand_vector<R: Rng + ?Sized>(
64        &mut self,
65        rng: &mut R,
66        array: &Self::Vector,
67    ) -> Result<Self::ExpandedVector, Self::Err>;
68
69    fn dim(&self) -> usize;
70
71    fn vector_coord(&self) -> Option<Value> {
72        None
73    }
74
75    fn scalar_prods3(
76        &mut self,
77        positive1: &Self::Vector,
78        negative1: &Self::Vector,
79        positive2: &Self::Vector,
80        x: &Self::Vector,
81        y: &Self::Vector,
82    ) -> (f64, f64);
83
84    fn scalar_prods2(
85        &mut self,
86        positive1: &Self::Vector,
87        positive2: &Self::Vector,
88        x: &Self::Vector,
89        y: &Self::Vector,
90    ) -> (f64, f64);
91
92    fn sq_norm_sum(&mut self, x: &Self::Vector, y: &Self::Vector) -> f64;
93
94    fn read_from_slice(&mut self, dest: &mut Self::Vector, source: &[f64]);
95    fn write_to_slice(&mut self, source: &Self::Vector, dest: &mut [f64]);
96    fn eigs_as_array(&mut self, source: &Self::EigValues) -> Box<[f64]>;
97    fn copy_into(&mut self, array: &Self::Vector, dest: &mut Self::Vector);
98    fn axpy_out(&mut self, x: &Self::Vector, y: &Self::Vector, a: f64, out: &mut Self::Vector);
99    fn axpy(&mut self, x: &Self::Vector, y: &mut Self::Vector, a: f64);
100
101    fn box_array(&mut self, array: &Self::Vector) -> Box<[f64]> {
102        let mut data = vec![0f64; self.dim()];
103        self.write_to_slice(array, &mut data);
104        data.into()
105    }
106
107    /// Compute the sum of the natural logarithms of all elements in `array`,
108    /// i.e. `Σ ln(array[i])`.
109    ///
110    /// The default implementation copies into a temporary allocation via
111    /// [`write_to_slice`]; backends may override this with a zero-allocation
112    /// version.
113    fn array_sum_ln(&mut self, array: &Self::Vector) -> f64 {
114        let mut data = vec![0f64; self.dim()];
115        self.write_to_slice(array, &mut data);
116        data.iter().map(|x| x.ln()).sum()
117    }
118
119    fn fill_array(&mut self, array: &mut Self::Vector, val: f64);
120
121    fn array_all_finite(&mut self, array: &Self::Vector) -> bool;
122    fn array_all_finite_and_nonzero(&mut self, array: &Self::Vector) -> bool;
123    fn array_mult(&mut self, array1: &Self::Vector, array2: &Self::Vector, dest: &mut Self::Vector);
124    fn array_mult_inplace(&mut self, array1: &mut Self::Vector, array2: &Self::Vector);
125    fn array_recip(&mut self, array: &Self::Vector, dest: &mut Self::Vector);
126
127    /// Apply the low-rank linear map `(I + U * (diag(vals) - I) * U^T) * rhs` into `dest`.
128    ///
129    /// `vecs` is `U` (d × r, orthonormal columns), `vals` is the diagonal vector (length r).
130    /// When `vecs` has zero columns the result is just a copy of `rhs`.
131    fn apply_lowrank_transform(
132        &mut self,
133        vecs: &Self::EigVectors,
134        vals: &Self::EigValues,
135        rhs: &Self::Vector,
136        dest: &mut Self::Vector,
137    );
138
139    fn apply_lowrank_transform_inplace(
140        &mut self,
141        vecs: &Self::EigVectors,
142        vals: &Self::EigValues,
143        rhs_and_dest: &mut Self::Vector,
144    );
145
146    fn array_mult_eigs(
147        &mut self,
148        stds: &Self::Vector,
149        rhs: &Self::Vector,
150        dest: &mut Self::Vector,
151        vecs: &Self::EigVectors,
152        vals: &Self::EigValues,
153    );
154
155    fn std_norm_flow(
156        &mut self,
157        pos: &Self::Vector,
158        pos_out: &mut Self::Vector,
159        vel: &mut Self::Vector,
160        epsilon: f64,
161    );
162    fn std_norm_grad_flow(
163        &mut self,
164        pos: &Self::Vector,
165        grad: &Self::Vector,
166        vel: &Self::Vector,
167        vel_out: &mut Self::Vector,
168        epsilon: f64,
169    );
170    fn std_norm_grad_flow_inplace(
171        &mut self,
172        pos: &Self::Vector,
173        grad: &Self::Vector,
174        vel: &mut Self::Vector,
175        epsilon: f64,
176    );
177
178    /// Normalise `v` to unit length in-place: `v := v / ‖v‖`.
179    ///
180    /// If `‖v‖ < 1e-300` the vector is left unchanged.
181    fn array_normalize(&mut self, v: &mut Self::Vector);
182
183    /// Perform one ESH (Extended Stochastic Hamiltonian) momentum half-step.
184    ///
185    /// Updates `mom` in-place so that it remains on the unit sphere, and
186    /// returns the new cumulative kinetic-energy change `prev_delta_ke + ΔKE`.
187    ///
188    /// # Algorithm
189    ///
190    /// Given momentum `p` on the unit sphere, log-density gradient `g`,
191    /// half-step size `step`, and dimension `n`:
192    ///
193    /// ```text
194    /// ĝ      = g / ‖g‖
195    /// α      = p · ĝ
196    /// Δ      = step · ‖g‖ / (n − 1)
197    /// ζ      = exp(−Δ)
198    /// p_raw  = ĝ · (1 − ζ)(1 + ζ + α(1 − ζ))  +  2ζ p
199    /// p'     = p_raw / ‖p_raw‖
200    /// ΔKE    = (Δ − log 2 + log(1 + α + (1 − α)ζ²)) · (n − 1)
201    /// ```
202    ///
203    /// Reference: Steeg & Gallagher, arXiv:2111.02434 (2021), ported from the
204    /// [BlackJAX implementation](https://github.com/blackjax-devs/blackjax/blob/main/blackjax/mcmc/integrators.py#L314).
205    fn esh_momentum_update(
206        &mut self,
207        grad: &Self::Vector,
208        mom: &mut Self::Vector,
209        step: f64,
210    ) -> f64;
211
212    fn array_vector_dot(&mut self, array1: &Self::Vector, array2: &Self::Vector) -> f64;
213    fn array_gaussian<R: rand::Rng + ?Sized>(
214        &mut self,
215        rng: &mut R,
216        dest: &mut Self::Vector,
217        stds: &Self::Vector,
218    );
219    fn array_gaussian_eigs<R: rand::Rng + ?Sized>(
220        &mut self,
221        rng: &mut R,
222        dest: &mut Self::Vector,
223        scale: &Self::Vector,
224        vals: &Self::EigValues,
225        vecs: &Self::EigVectors,
226    );
227    fn array_update_variance(
228        &mut self,
229        mean: &mut Self::Vector,
230        variance: &mut Self::Vector,
231        value: &Self::Vector,
232        diff_scale: f64,
233    );
234    fn array_update_var_inv_std_draw(
235        &mut self,
236        inv_std: &mut Self::Vector,
237        std: &mut Self::Vector,
238        draw_var: &Self::Vector,
239        scale: f64,
240        fill_invalid: Option<f64>,
241        clamp: (f64, f64),
242    );
243    fn array_update_var_inv_std_draw_grad(
244        &mut self,
245        inv_std: &mut Self::Vector,
246        std: &mut Self::Vector,
247        draw_var: &Self::Vector,
248        grad_var: &Self::Vector,
249        fill_invalid: Option<f64>,
250        clamp: (f64, f64),
251    );
252
253    fn array_update_var_inv_std_grad(
254        &mut self,
255        inv_std: &mut Self::Vector,
256        std: &mut Self::Vector,
257        gradient: &Self::Vector,
258        fill_invalid: f64,
259        clamp: (f64, f64),
260    );
261
262    fn inv_transform_normalize(
263        &mut self,
264        params: &Self::FlowParameters,
265        untransformed_position: &Self::Vector,
266        untransofrmed_gradient: &Self::Vector,
267        transformed_position: &mut Self::Vector,
268        transformed_gradient: &mut Self::Vector,
269    ) -> Result<f64, Self::LogpErr>;
270
271    fn init_from_untransformed_position(
272        &mut self,
273        params: &Self::FlowParameters,
274        untransformed_position: &Self::Vector,
275        untransformed_gradient: &mut Self::Vector,
276        transformed_position: &mut Self::Vector,
277        transformed_gradient: &mut Self::Vector,
278    ) -> Result<(f64, f64), Self::LogpErr>;
279
280    fn init_from_transformed_position(
281        &mut self,
282        params: &Self::FlowParameters,
283        untransformed_position: &mut Self::Vector,
284        untransformed_gradient: &mut Self::Vector,
285        transformed_position: &Self::Vector,
286        transformed_gradient: &mut Self::Vector,
287    ) -> Result<(f64, f64), Self::LogpErr>;
288
289    fn update_transformation<'a, R: rand::Rng + ?Sized>(
290        &'a mut self,
291        rng: &mut R,
292        untransformed_positions: impl ExactSizeIterator<Item = &'a Self::Vector>,
293        untransformed_gradients: impl ExactSizeIterator<Item = &'a Self::Vector>,
294        untransformed_logps: impl ExactSizeIterator<Item = &'a f64>,
295        params: &'a mut Self::FlowParameters,
296    ) -> Result<(), Self::LogpErr>;
297
298    fn new_transformation<R: rand::Rng + ?Sized>(
299        &mut self,
300        rng: &mut R,
301        dim: usize,
302        chain: u64,
303    ) -> Result<Self::FlowParameters, Self::LogpErr>;
304
305    fn init_transformation<R: rand::Rng + ?Sized>(
306        &mut self,
307        rng: &mut R,
308        untransformed_position: &Self::Vector,
309        untransfogmed_gradient: &Self::Vector,
310        chain: u64,
311    ) -> Result<Self::FlowParameters, Self::LogpErr>;
312
313    fn transformation_id(&self, params: &Self::FlowParameters) -> Result<i64, Self::LogpErr>;
314}