1use std::{error::Error, fmt::Debug};
4
5use nuts_storable::{HasDims, Storable, Value};
6use rand::Rng;
7
/// Error type usable as the failure value of a log-density evaluation.
///
/// Implementors must be ordinary [`Error`] values that can be sent across
/// threads.
pub trait LogpError: Error + Send {
    /// Reports whether this error is recoverable.
    ///
    /// NOTE(review): how callers react to a recoverable versus an
    /// unrecoverable error is decided outside this file — confirm against
    /// the call sites.
    fn is_recoverable(&self) -> bool;
}
14
15pub trait Math: HasDims {
16 type Vector: Debug;
17 type EigVectors: Debug;
18 type EigValues: Debug;
19 type LogpErr: Debug + Send + Sync + LogpError + Sized + 'static;
20 type Err: Debug + Send + Sync + Error + 'static;
21 type FlowParameters;
22 type ExpandedVector: Storable<Self>;
23
24 fn new_array(&mut self) -> Self::Vector;
25
26 fn copy_array(&mut self, array: &Self::Vector) -> Self::Vector {
27 let mut copy = self.new_array();
28 self.copy_into(array, &mut copy);
29 copy
30 }
31
32 fn new_eig_vectors<'a>(
33 &'a mut self,
34 vals: impl ExactSizeIterator<Item = &'a [f64]>,
35 ) -> Self::EigVectors;
36 fn new_eig_values(&mut self, vals: &[f64]) -> Self::EigValues;
37
38 fn logp_array(
47 &mut self,
48 position: &Self::Vector,
49 gradient: &mut Self::Vector,
50 ) -> Result<f64, Self::LogpErr>;
51
52 fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, Self::LogpErr>;
53
54 fn init_position<R: Rng + ?Sized>(
55 &mut self,
56 rng: &mut R,
57 position: &mut Self::Vector,
58 gradient: &mut Self::Vector,
59 ) -> Result<f64, Self::LogpErr>;
60
61 fn expand_vector<R: Rng + ?Sized>(
64 &mut self,
65 rng: &mut R,
66 array: &Self::Vector,
67 ) -> Result<Self::ExpandedVector, Self::Err>;
68
69 fn dim(&self) -> usize;
70
71 fn vector_coord(&self) -> Option<Value> {
72 None
73 }
74
75 fn scalar_prods3(
76 &mut self,
77 positive1: &Self::Vector,
78 negative1: &Self::Vector,
79 positive2: &Self::Vector,
80 x: &Self::Vector,
81 y: &Self::Vector,
82 ) -> (f64, f64);
83
84 fn scalar_prods2(
85 &mut self,
86 positive1: &Self::Vector,
87 positive2: &Self::Vector,
88 x: &Self::Vector,
89 y: &Self::Vector,
90 ) -> (f64, f64);
91
92 fn sq_norm_sum(&mut self, x: &Self::Vector, y: &Self::Vector) -> f64;
93
94 fn read_from_slice(&mut self, dest: &mut Self::Vector, source: &[f64]);
95 fn write_to_slice(&mut self, source: &Self::Vector, dest: &mut [f64]);
96 fn eigs_as_array(&mut self, source: &Self::EigValues) -> Box<[f64]>;
97 fn copy_into(&mut self, array: &Self::Vector, dest: &mut Self::Vector);
98 fn axpy_out(&mut self, x: &Self::Vector, y: &Self::Vector, a: f64, out: &mut Self::Vector);
99 fn axpy(&mut self, x: &Self::Vector, y: &mut Self::Vector, a: f64);
100
101 fn box_array(&mut self, array: &Self::Vector) -> Box<[f64]> {
102 let mut data = vec![0f64; self.dim()];
103 self.write_to_slice(array, &mut data);
104 data.into()
105 }
106
107 fn array_sum_ln(&mut self, array: &Self::Vector) -> f64 {
114 let mut data = vec![0f64; self.dim()];
115 self.write_to_slice(array, &mut data);
116 data.iter().map(|x| x.ln()).sum()
117 }
118
119 fn fill_array(&mut self, array: &mut Self::Vector, val: f64);
120
121 fn array_all_finite(&mut self, array: &Self::Vector) -> bool;
122 fn array_all_finite_and_nonzero(&mut self, array: &Self::Vector) -> bool;
123 fn array_mult(&mut self, array1: &Self::Vector, array2: &Self::Vector, dest: &mut Self::Vector);
124 fn array_mult_inplace(&mut self, array1: &mut Self::Vector, array2: &Self::Vector);
125 fn array_recip(&mut self, array: &Self::Vector, dest: &mut Self::Vector);
126
127 fn apply_lowrank_transform(
132 &mut self,
133 vecs: &Self::EigVectors,
134 vals: &Self::EigValues,
135 rhs: &Self::Vector,
136 dest: &mut Self::Vector,
137 );
138
139 fn apply_lowrank_transform_inplace(
140 &mut self,
141 vecs: &Self::EigVectors,
142 vals: &Self::EigValues,
143 rhs_and_dest: &mut Self::Vector,
144 );
145
146 fn array_mult_eigs(
147 &mut self,
148 stds: &Self::Vector,
149 rhs: &Self::Vector,
150 dest: &mut Self::Vector,
151 vecs: &Self::EigVectors,
152 vals: &Self::EigValues,
153 );
154
155 fn std_norm_flow(
156 &mut self,
157 pos: &Self::Vector,
158 pos_out: &mut Self::Vector,
159 vel: &mut Self::Vector,
160 epsilon: f64,
161 );
162 fn std_norm_grad_flow(
163 &mut self,
164 pos: &Self::Vector,
165 grad: &Self::Vector,
166 vel: &Self::Vector,
167 vel_out: &mut Self::Vector,
168 epsilon: f64,
169 );
170 fn std_norm_grad_flow_inplace(
171 &mut self,
172 pos: &Self::Vector,
173 grad: &Self::Vector,
174 vel: &mut Self::Vector,
175 epsilon: f64,
176 );
177
178 fn array_normalize(&mut self, v: &mut Self::Vector);
182
183 fn esh_momentum_update(
206 &mut self,
207 grad: &Self::Vector,
208 mom: &mut Self::Vector,
209 step: f64,
210 ) -> f64;
211
212 fn array_vector_dot(&mut self, array1: &Self::Vector, array2: &Self::Vector) -> f64;
213 fn array_gaussian<R: rand::Rng + ?Sized>(
214 &mut self,
215 rng: &mut R,
216 dest: &mut Self::Vector,
217 stds: &Self::Vector,
218 );
219 fn array_gaussian_eigs<R: rand::Rng + ?Sized>(
220 &mut self,
221 rng: &mut R,
222 dest: &mut Self::Vector,
223 scale: &Self::Vector,
224 vals: &Self::EigValues,
225 vecs: &Self::EigVectors,
226 );
227 fn array_update_variance(
228 &mut self,
229 mean: &mut Self::Vector,
230 variance: &mut Self::Vector,
231 value: &Self::Vector,
232 diff_scale: f64,
233 );
234 fn array_update_var_inv_std_draw(
235 &mut self,
236 inv_std: &mut Self::Vector,
237 std: &mut Self::Vector,
238 draw_var: &Self::Vector,
239 scale: f64,
240 fill_invalid: Option<f64>,
241 clamp: (f64, f64),
242 );
243 fn array_update_var_inv_std_draw_grad(
244 &mut self,
245 inv_std: &mut Self::Vector,
246 std: &mut Self::Vector,
247 draw_var: &Self::Vector,
248 grad_var: &Self::Vector,
249 fill_invalid: Option<f64>,
250 clamp: (f64, f64),
251 );
252
253 fn array_update_var_inv_std_grad(
254 &mut self,
255 inv_std: &mut Self::Vector,
256 std: &mut Self::Vector,
257 gradient: &Self::Vector,
258 fill_invalid: f64,
259 clamp: (f64, f64),
260 );
261
262 fn inv_transform_normalize(
263 &mut self,
264 params: &Self::FlowParameters,
265 untransformed_position: &Self::Vector,
266 untransofrmed_gradient: &Self::Vector,
267 transformed_position: &mut Self::Vector,
268 transformed_gradient: &mut Self::Vector,
269 ) -> Result<f64, Self::LogpErr>;
270
271 fn init_from_untransformed_position(
272 &mut self,
273 params: &Self::FlowParameters,
274 untransformed_position: &Self::Vector,
275 untransformed_gradient: &mut Self::Vector,
276 transformed_position: &mut Self::Vector,
277 transformed_gradient: &mut Self::Vector,
278 ) -> Result<(f64, f64), Self::LogpErr>;
279
280 fn init_from_transformed_position(
281 &mut self,
282 params: &Self::FlowParameters,
283 untransformed_position: &mut Self::Vector,
284 untransformed_gradient: &mut Self::Vector,
285 transformed_position: &Self::Vector,
286 transformed_gradient: &mut Self::Vector,
287 ) -> Result<(f64, f64), Self::LogpErr>;
288
289 fn update_transformation<'a, R: rand::Rng + ?Sized>(
290 &'a mut self,
291 rng: &mut R,
292 untransformed_positions: impl ExactSizeIterator<Item = &'a Self::Vector>,
293 untransformed_gradients: impl ExactSizeIterator<Item = &'a Self::Vector>,
294 untransformed_logps: impl ExactSizeIterator<Item = &'a f64>,
295 params: &'a mut Self::FlowParameters,
296 ) -> Result<(), Self::LogpErr>;
297
298 fn new_transformation<R: rand::Rng + ?Sized>(
299 &mut self,
300 rng: &mut R,
301 dim: usize,
302 chain: u64,
303 ) -> Result<Self::FlowParameters, Self::LogpErr>;
304
305 fn init_transformation<R: rand::Rng + ?Sized>(
306 &mut self,
307 rng: &mut R,
308 untransformed_position: &Self::Vector,
309 untransfogmed_gradient: &Self::Vector,
310 chain: u64,
311 ) -> Result<Self::FlowParameters, Self::LogpErr>;
312
313 fn transformation_id(&self, params: &Self::FlowParameters) -> Result<i64, Self::LogpErr>;
314}