1use std::{error::Error, fmt::Debug};
2
3use nuts_storable::{HasDims, Storable, Value};
4use rand::Rng;
5
/// Error raised while evaluating `logp`.
///
/// This trait is the bound on [`Math::LogpErr`], the error type returned by
/// [`Math::logp`], [`Math::logp_array`] and the transformation methods of
/// [`Math`]. Implementors must be standard errors that can be sent across
/// threads, and must report whether the failure is recoverable.
pub trait LogpError: Send + std::error::Error {
    /// Returns `true` if this error is recoverable.
    ///
    /// NOTE(review): how a recoverable error is handled, as opposed to a
    /// non-recoverable one, is decided by callers outside this trait.
    fn is_recoverable(&self) -> bool;
}
12
13pub trait Math: HasDims {
14 type Vector: Debug;
15 type EigVectors: Debug;
16 type EigValues: Debug;
17 type LogpErr: Debug + Send + Sync + LogpError + Sized + 'static;
18 type Err: Debug + Send + Sync + Error + 'static;
19 type FlowParameters;
20 type ExpandedVector: Storable<Self>;
21
22 fn new_array(&mut self) -> Self::Vector;
23
24 fn copy_array(&mut self, array: &Self::Vector) -> Self::Vector {
25 let mut copy = self.new_array();
26 self.copy_into(array, &mut copy);
27 copy
28 }
29
30 fn new_eig_vectors<'a>(
31 &'a mut self,
32 vals: impl ExactSizeIterator<Item = &'a [f64]>,
33 ) -> Self::EigVectors;
34 fn new_eig_values(&mut self, vals: &[f64]) -> Self::EigValues;
35
36 fn logp_array(
45 &mut self,
46 position: &Self::Vector,
47 gradient: &mut Self::Vector,
48 ) -> Result<f64, Self::LogpErr>;
49
50 fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, Self::LogpErr>;
51
52 fn init_position<R: Rng + ?Sized>(
53 &mut self,
54 rng: &mut R,
55 position: &mut Self::Vector,
56 gradient: &mut Self::Vector,
57 ) -> Result<f64, Self::LogpErr>;
58
59 fn expand_vector<R: Rng + ?Sized>(
62 &mut self,
63 rng: &mut R,
64 array: &Self::Vector,
65 ) -> Result<Self::ExpandedVector, Self::Err>;
66
67 fn dim(&self) -> usize;
68
69 fn vector_coord(&self) -> Option<Value> {
70 None
71 }
72
73 fn scalar_prods3(
74 &mut self,
75 positive1: &Self::Vector,
76 negative1: &Self::Vector,
77 positive2: &Self::Vector,
78 x: &Self::Vector,
79 y: &Self::Vector,
80 ) -> (f64, f64);
81
82 fn scalar_prods2(
83 &mut self,
84 positive1: &Self::Vector,
85 positive2: &Self::Vector,
86 x: &Self::Vector,
87 y: &Self::Vector,
88 ) -> (f64, f64);
89
90 fn sq_norm_sum(&mut self, x: &Self::Vector, y: &Self::Vector) -> f64;
91
92 fn read_from_slice(&mut self, dest: &mut Self::Vector, source: &[f64]);
93 fn write_to_slice(&mut self, source: &Self::Vector, dest: &mut [f64]);
94 fn eigs_as_array(&mut self, source: &Self::EigValues) -> Box<[f64]>;
95 fn copy_into(&mut self, array: &Self::Vector, dest: &mut Self::Vector);
96 fn axpy_out(&mut self, x: &Self::Vector, y: &Self::Vector, a: f64, out: &mut Self::Vector);
97 fn axpy(&mut self, x: &Self::Vector, y: &mut Self::Vector, a: f64);
98
99 fn box_array(&mut self, array: &Self::Vector) -> Box<[f64]> {
100 let mut data = vec![0f64; self.dim()];
101 self.write_to_slice(array, &mut data);
102 data.into()
103 }
104
105 fn fill_array(&mut self, array: &mut Self::Vector, val: f64);
106
107 fn array_all_finite(&mut self, array: &Self::Vector) -> bool;
108 fn array_all_finite_and_nonzero(&mut self, array: &Self::Vector) -> bool;
109 fn array_mult(&mut self, array1: &Self::Vector, array2: &Self::Vector, dest: &mut Self::Vector);
110 fn array_mult_eigs(
111 &mut self,
112 stds: &Self::Vector,
113 rhs: &Self::Vector,
114 dest: &mut Self::Vector,
115 vecs: &Self::EigVectors,
116 vals: &Self::EigValues,
117 );
118 fn array_vector_dot(&mut self, array1: &Self::Vector, array2: &Self::Vector) -> f64;
119 fn array_gaussian<R: rand::Rng + ?Sized>(
120 &mut self,
121 rng: &mut R,
122 dest: &mut Self::Vector,
123 stds: &Self::Vector,
124 );
125 fn array_gaussian_eigs<R: rand::Rng + ?Sized>(
126 &mut self,
127 rng: &mut R,
128 dest: &mut Self::Vector,
129 scale: &Self::Vector,
130 vals: &Self::EigValues,
131 vecs: &Self::EigVectors,
132 );
133 fn array_update_variance(
134 &mut self,
135 mean: &mut Self::Vector,
136 variance: &mut Self::Vector,
137 value: &Self::Vector,
138 diff_scale: f64,
139 );
140 fn array_update_var_inv_std_draw(
141 &mut self,
142 variance_out: &mut Self::Vector,
143 inv_std: &mut Self::Vector,
144 draw_var: &Self::Vector,
145 scale: f64,
146 fill_invalid: Option<f64>,
147 clamp: (f64, f64),
148 );
149 fn array_update_var_inv_std_draw_grad(
150 &mut self,
151 variance_out: &mut Self::Vector,
152 inv_std: &mut Self::Vector,
153 draw_var: &Self::Vector,
154 grad_var: &Self::Vector,
155 fill_invalid: Option<f64>,
156 clamp: (f64, f64),
157 );
158
159 fn array_update_var_inv_std_grad(
160 &mut self,
161 variance_out: &mut Self::Vector,
162 inv_std: &mut Self::Vector,
163 gradient: &Self::Vector,
164 fill_invalid: f64,
165 clamp: (f64, f64),
166 );
167
168 fn inv_transform_normalize(
169 &mut self,
170 params: &Self::FlowParameters,
171 untransformed_position: &Self::Vector,
172 untransofrmed_gradient: &Self::Vector,
173 transformed_position: &mut Self::Vector,
174 transformed_gradient: &mut Self::Vector,
175 ) -> Result<f64, Self::LogpErr>;
176
177 fn init_from_untransformed_position(
178 &mut self,
179 params: &Self::FlowParameters,
180 untransformed_position: &Self::Vector,
181 untransformed_gradient: &mut Self::Vector,
182 transformed_position: &mut Self::Vector,
183 transformed_gradient: &mut Self::Vector,
184 ) -> Result<(f64, f64), Self::LogpErr>;
185
186 fn init_from_transformed_position(
187 &mut self,
188 params: &Self::FlowParameters,
189 untransformed_position: &mut Self::Vector,
190 untransformed_gradient: &mut Self::Vector,
191 transformed_position: &Self::Vector,
192 transformed_gradient: &mut Self::Vector,
193 ) -> Result<(f64, f64), Self::LogpErr>;
194
195 fn update_transformation<'a, R: rand::Rng + ?Sized>(
196 &'a mut self,
197 rng: &mut R,
198 untransformed_positions: impl ExactSizeIterator<Item = &'a Self::Vector>,
199 untransformed_gradients: impl ExactSizeIterator<Item = &'a Self::Vector>,
200 untransformed_logps: impl ExactSizeIterator<Item = &'a f64>,
201 params: &'a mut Self::FlowParameters,
202 ) -> Result<(), Self::LogpErr>;
203
204 fn new_transformation<R: rand::Rng + ?Sized>(
205 &mut self,
206 rng: &mut R,
207 untransformed_position: &Self::Vector,
208 untransfogmed_gradient: &Self::Vector,
209 chain: u64,
210 ) -> Result<Self::FlowParameters, Self::LogpErr>;
211
212 fn transformation_id(&self, params: &Self::FlowParameters) -> Result<i64, Self::LogpErr>;
213}