nuts_rs/
math_base.rs

1use std::{error::Error, fmt::Debug};
2
3use nuts_storable::{HasDims, Storable, Value};
4use rand::Rng;
5
/// Failure raised while evaluating the log density and its gradient.
///
/// Each error value itself reports whether sampling can carry on after it
/// occurred.
pub trait LogpError: Error + Send {
    /// Returns `true` for a recoverable error, which the sampler treats as a
    /// divergence, and `false` for an unrecoverable error, which stops
    /// sampling.
    fn is_recoverable(&self) -> bool;
}
12
13pub trait Math: HasDims {
14    type Vector: Debug;
15    type EigVectors: Debug;
16    type EigValues: Debug;
17    type LogpErr: Debug + Send + Sync + LogpError + Sized + 'static;
18    type Err: Debug + Send + Sync + Error + 'static;
19    type FlowParameters;
20    type ExpandedVector: Storable<Self>;
21
22    fn new_array(&mut self) -> Self::Vector;
23
24    fn copy_array(&mut self, array: &Self::Vector) -> Self::Vector {
25        let mut copy = self.new_array();
26        self.copy_into(array, &mut copy);
27        copy
28    }
29
30    fn new_eig_vectors<'a>(
31        &'a mut self,
32        vals: impl ExactSizeIterator<Item = &'a [f64]>,
33    ) -> Self::EigVectors;
34    fn new_eig_values(&mut self, vals: &[f64]) -> Self::EigValues;
35
36    /// Compute the unnormalized log probability density of the posterior
37    ///
38    /// This needs to be implemnted by users of the library to define
39    /// what distribution the users wants to sample from.
40    ///
41    /// Errors during that computation can be recoverable or non-recoverable.
42    /// If a non-recoverable error occurs during sampling, the sampler will
43    /// stop and return an error.
44    fn logp_array(
45        &mut self,
46        position: &Self::Vector,
47        gradient: &mut Self::Vector,
48    ) -> Result<f64, Self::LogpErr>;
49
50    fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, Self::LogpErr>;
51
52    fn init_position<R: Rng + ?Sized>(
53        &mut self,
54        rng: &mut R,
55        position: &mut Self::Vector,
56        gradient: &mut Self::Vector,
57    ) -> Result<f64, Self::LogpErr>;
58
59    /// Expand a vector into a larger representation, to for instance
60    /// compute deterministic values that are to be stored in the trace.
61    fn expand_vector<R: Rng + ?Sized>(
62        &mut self,
63        rng: &mut R,
64        array: &Self::Vector,
65    ) -> Result<Self::ExpandedVector, Self::Err>;
66
67    fn dim(&self) -> usize;
68
69    fn vector_coord(&self) -> Option<Value> {
70        None
71    }
72
73    fn scalar_prods3(
74        &mut self,
75        positive1: &Self::Vector,
76        negative1: &Self::Vector,
77        positive2: &Self::Vector,
78        x: &Self::Vector,
79        y: &Self::Vector,
80    ) -> (f64, f64);
81
82    fn scalar_prods2(
83        &mut self,
84        positive1: &Self::Vector,
85        positive2: &Self::Vector,
86        x: &Self::Vector,
87        y: &Self::Vector,
88    ) -> (f64, f64);
89
90    fn sq_norm_sum(&mut self, x: &Self::Vector, y: &Self::Vector) -> f64;
91
92    fn read_from_slice(&mut self, dest: &mut Self::Vector, source: &[f64]);
93    fn write_to_slice(&mut self, source: &Self::Vector, dest: &mut [f64]);
94    fn eigs_as_array(&mut self, source: &Self::EigValues) -> Box<[f64]>;
95    fn copy_into(&mut self, array: &Self::Vector, dest: &mut Self::Vector);
96    fn axpy_out(&mut self, x: &Self::Vector, y: &Self::Vector, a: f64, out: &mut Self::Vector);
97    fn axpy(&mut self, x: &Self::Vector, y: &mut Self::Vector, a: f64);
98
99    fn box_array(&mut self, array: &Self::Vector) -> Box<[f64]> {
100        let mut data = vec![0f64; self.dim()];
101        self.write_to_slice(array, &mut data);
102        data.into()
103    }
104
105    fn fill_array(&mut self, array: &mut Self::Vector, val: f64);
106
107    fn array_all_finite(&mut self, array: &Self::Vector) -> bool;
108    fn array_all_finite_and_nonzero(&mut self, array: &Self::Vector) -> bool;
109    fn array_mult(&mut self, array1: &Self::Vector, array2: &Self::Vector, dest: &mut Self::Vector);
110    fn array_mult_eigs(
111        &mut self,
112        stds: &Self::Vector,
113        rhs: &Self::Vector,
114        dest: &mut Self::Vector,
115        vecs: &Self::EigVectors,
116        vals: &Self::EigValues,
117    );
118    fn array_vector_dot(&mut self, array1: &Self::Vector, array2: &Self::Vector) -> f64;
119    fn array_gaussian<R: rand::Rng + ?Sized>(
120        &mut self,
121        rng: &mut R,
122        dest: &mut Self::Vector,
123        stds: &Self::Vector,
124    );
125    fn array_gaussian_eigs<R: rand::Rng + ?Sized>(
126        &mut self,
127        rng: &mut R,
128        dest: &mut Self::Vector,
129        scale: &Self::Vector,
130        vals: &Self::EigValues,
131        vecs: &Self::EigVectors,
132    );
133    fn array_update_variance(
134        &mut self,
135        mean: &mut Self::Vector,
136        variance: &mut Self::Vector,
137        value: &Self::Vector,
138        diff_scale: f64,
139    );
140    fn array_update_var_inv_std_draw(
141        &mut self,
142        variance_out: &mut Self::Vector,
143        inv_std: &mut Self::Vector,
144        draw_var: &Self::Vector,
145        scale: f64,
146        fill_invalid: Option<f64>,
147        clamp: (f64, f64),
148    );
149    fn array_update_var_inv_std_draw_grad(
150        &mut self,
151        variance_out: &mut Self::Vector,
152        inv_std: &mut Self::Vector,
153        draw_var: &Self::Vector,
154        grad_var: &Self::Vector,
155        fill_invalid: Option<f64>,
156        clamp: (f64, f64),
157    );
158
159    fn array_update_var_inv_std_grad(
160        &mut self,
161        variance_out: &mut Self::Vector,
162        inv_std: &mut Self::Vector,
163        gradient: &Self::Vector,
164        fill_invalid: f64,
165        clamp: (f64, f64),
166    );
167
168    fn inv_transform_normalize(
169        &mut self,
170        params: &Self::FlowParameters,
171        untransformed_position: &Self::Vector,
172        untransofrmed_gradient: &Self::Vector,
173        transformed_position: &mut Self::Vector,
174        transformed_gradient: &mut Self::Vector,
175    ) -> Result<f64, Self::LogpErr>;
176
177    fn init_from_untransformed_position(
178        &mut self,
179        params: &Self::FlowParameters,
180        untransformed_position: &Self::Vector,
181        untransformed_gradient: &mut Self::Vector,
182        transformed_position: &mut Self::Vector,
183        transformed_gradient: &mut Self::Vector,
184    ) -> Result<(f64, f64), Self::LogpErr>;
185
186    fn init_from_transformed_position(
187        &mut self,
188        params: &Self::FlowParameters,
189        untransformed_position: &mut Self::Vector,
190        untransformed_gradient: &mut Self::Vector,
191        transformed_position: &Self::Vector,
192        transformed_gradient: &mut Self::Vector,
193    ) -> Result<(f64, f64), Self::LogpErr>;
194
195    fn update_transformation<'a, R: rand::Rng + ?Sized>(
196        &'a mut self,
197        rng: &mut R,
198        untransformed_positions: impl ExactSizeIterator<Item = &'a Self::Vector>,
199        untransformed_gradients: impl ExactSizeIterator<Item = &'a Self::Vector>,
200        untransformed_logps: impl ExactSizeIterator<Item = &'a f64>,
201        params: &'a mut Self::FlowParameters,
202    ) -> Result<(), Self::LogpErr>;
203
204    fn new_transformation<R: rand::Rng + ?Sized>(
205        &mut self,
206        rng: &mut R,
207        untransformed_position: &Self::Vector,
208        untransfogmed_gradient: &Self::Vector,
209        chain: u64,
210    ) -> Result<Self::FlowParameters, Self::LogpErr>;
211
212    fn transformation_id(&self, params: &Self::FlowParameters) -> Result<i64, Self::LogpErr>;
213}