// nuts_rs/math_base.rs

1use std::{error::Error, fmt::Debug};
2
/// Error type produced while evaluating the log density and its gradient.
///
/// Each error reports whether sampling can continue: unrecoverable errors
/// stop sampling, while recoverable ones are treated as divergences.
pub trait LogpError: Send + Error {
    /// Returns `true` if the error is recoverable (it is then treated as a
    /// divergence), or `false` if sampling must stop.
    fn is_recoverable(&self) -> bool;
}
9
10pub trait Math {
11    type Vector: Debug;
12    type EigVectors: Debug;
13    type EigValues: Debug;
14    type LogpErr: Debug + Send + Sync + LogpError + Sized + 'static;
15    type Err: Debug + Send + Sync + Error + 'static;
16    type TransformParams;
17
18    fn new_array(&mut self) -> Self::Vector;
19
20    fn copy_array(&mut self, array: &Self::Vector) -> Self::Vector {
21        let mut copy = self.new_array();
22        self.copy_into(array, &mut copy);
23        copy
24    }
25
26    fn new_eig_vectors<'a>(
27        &'a mut self,
28        vals: impl ExactSizeIterator<Item = &'a [f64]>,
29    ) -> Self::EigVectors;
30    fn new_eig_values(&mut self, vals: &[f64]) -> Self::EigValues;
31
32    /// Compute the unnormalized log probability density of the posterior
33    ///
34    /// This needs to be implemnted by users of the library to define
35    /// what distribution the users wants to sample from.
36    ///
37    /// Errors during that computation can be recoverable or non-recoverable.
38    /// If a non-recoverable error occurs during sampling, the sampler will
39    /// stop and return an error.
40    fn logp_array(
41        &mut self,
42        position: &Self::Vector,
43        gradient: &mut Self::Vector,
44    ) -> Result<f64, Self::LogpErr>;
45
46    fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, Self::LogpErr>;
47
48    fn dim(&self) -> usize;
49
50    fn scalar_prods3(
51        &mut self,
52        positive1: &Self::Vector,
53        negative1: &Self::Vector,
54        positive2: &Self::Vector,
55        x: &Self::Vector,
56        y: &Self::Vector,
57    ) -> (f64, f64);
58
59    fn scalar_prods2(
60        &mut self,
61        positive1: &Self::Vector,
62        positive2: &Self::Vector,
63        x: &Self::Vector,
64        y: &Self::Vector,
65    ) -> (f64, f64);
66
67    fn sq_norm_sum(&mut self, x: &Self::Vector, y: &Self::Vector) -> f64;
68
69    fn read_from_slice(&mut self, dest: &mut Self::Vector, source: &[f64]);
70    fn write_to_slice(&mut self, source: &Self::Vector, dest: &mut [f64]);
71    fn eigs_as_array(&mut self, source: &Self::EigValues) -> Box<[f64]>;
72    fn copy_into(&mut self, array: &Self::Vector, dest: &mut Self::Vector);
73    fn axpy_out(&mut self, x: &Self::Vector, y: &Self::Vector, a: f64, out: &mut Self::Vector);
74    fn axpy(&mut self, x: &Self::Vector, y: &mut Self::Vector, a: f64);
75
76    fn box_array(&mut self, array: &Self::Vector) -> Box<[f64]> {
77        let mut data = vec![0f64; self.dim()];
78        self.write_to_slice(array, &mut data);
79        data.into()
80    }
81
82    fn fill_array(&mut self, array: &mut Self::Vector, val: f64);
83
84    fn array_all_finite(&mut self, array: &Self::Vector) -> bool;
85    fn array_all_finite_and_nonzero(&mut self, array: &Self::Vector) -> bool;
86    fn array_mult(&mut self, array1: &Self::Vector, array2: &Self::Vector, dest: &mut Self::Vector);
87    fn array_mult_eigs(
88        &mut self,
89        stds: &Self::Vector,
90        rhs: &Self::Vector,
91        dest: &mut Self::Vector,
92        vecs: &Self::EigVectors,
93        vals: &Self::EigValues,
94    );
95    fn array_vector_dot(&mut self, array1: &Self::Vector, array2: &Self::Vector) -> f64;
96    fn array_gaussian<R: rand::Rng + ?Sized>(
97        &mut self,
98        rng: &mut R,
99        dest: &mut Self::Vector,
100        stds: &Self::Vector,
101    );
102    fn array_gaussian_eigs<R: rand::Rng + ?Sized>(
103        &mut self,
104        rng: &mut R,
105        dest: &mut Self::Vector,
106        scale: &Self::Vector,
107        vals: &Self::EigValues,
108        vecs: &Self::EigVectors,
109    );
110    fn array_update_variance(
111        &mut self,
112        mean: &mut Self::Vector,
113        variance: &mut Self::Vector,
114        value: &Self::Vector,
115        diff_scale: f64,
116    );
117    fn array_update_var_inv_std_draw(
118        &mut self,
119        variance_out: &mut Self::Vector,
120        inv_std: &mut Self::Vector,
121        draw_var: &Self::Vector,
122        scale: f64,
123        fill_invalid: Option<f64>,
124        clamp: (f64, f64),
125    );
126    fn array_update_var_inv_std_draw_grad(
127        &mut self,
128        variance_out: &mut Self::Vector,
129        inv_std: &mut Self::Vector,
130        draw_var: &Self::Vector,
131        grad_var: &Self::Vector,
132        fill_invalid: Option<f64>,
133        clamp: (f64, f64),
134    );
135
136    fn array_update_var_inv_std_grad(
137        &mut self,
138        variance_out: &mut Self::Vector,
139        inv_std: &mut Self::Vector,
140        gradient: &Self::Vector,
141        fill_invalid: f64,
142        clamp: (f64, f64),
143    );
144
145    fn inv_transform_normalize(
146        &mut self,
147        params: &Self::TransformParams,
148        untransformed_position: &Self::Vector,
149        untransofrmed_gradient: &Self::Vector,
150        transformed_position: &mut Self::Vector,
151        transformed_gradient: &mut Self::Vector,
152    ) -> Result<f64, Self::LogpErr>;
153
154    fn init_from_untransformed_position(
155        &mut self,
156        params: &Self::TransformParams,
157        untransformed_position: &Self::Vector,
158        untransformed_gradient: &mut Self::Vector,
159        transformed_position: &mut Self::Vector,
160        transformed_gradient: &mut Self::Vector,
161    ) -> Result<(f64, f64), Self::LogpErr>;
162
163    fn init_from_transformed_position(
164        &mut self,
165        params: &Self::TransformParams,
166        untransformed_position: &mut Self::Vector,
167        untransformed_gradient: &mut Self::Vector,
168        transformed_position: &Self::Vector,
169        transformed_gradient: &mut Self::Vector,
170    ) -> Result<(f64, f64), Self::LogpErr>;
171
172    fn update_transformation<'a, R: rand::Rng + ?Sized>(
173        &'a mut self,
174        rng: &mut R,
175        untransformed_positions: impl ExactSizeIterator<Item = &'a Self::Vector>,
176        untransformed_gradients: impl ExactSizeIterator<Item = &'a Self::Vector>,
177        untransformed_logps: impl ExactSizeIterator<Item = &'a f64>,
178        params: &'a mut Self::TransformParams,
179    ) -> Result<(), Self::LogpErr>;
180
181    fn new_transformation<R: rand::Rng + ?Sized>(
182        &mut self,
183        rng: &mut R,
184        untransformed_position: &Self::Vector,
185        untransfogmed_gradient: &Self::Vector,
186        chain: u64,
187    ) -> Result<Self::TransformParams, Self::LogpErr>;
188
189    fn transformation_id(&self, params: &Self::TransformParams) -> Result<i64, Self::LogpErr>;
190}