1use std::{error::Error, fmt::Debug};
2
/// Error produced while evaluating `logp` and its gradient.
///
/// Every implementor is also a standard [`Error`] and is [`Send`], so a
/// value of this type can be handed to another thread.
pub trait LogpError: Send + Error {
    /// Reports whether this failure is recoverable.
    ///
    /// NOTE(review): how callers react to a recoverable vs. unrecoverable
    /// error is decided outside this file — confirm at the call sites.
    fn is_recoverable(&self) -> bool;
}
9
10pub trait Math {
11 type Vector: Debug;
12 type EigVectors: Debug;
13 type EigValues: Debug;
14 type LogpErr: Debug + Send + Sync + LogpError + Sized + 'static;
15 type Err: Debug + Send + Sync + Error + 'static;
16 type TransformParams;
17
18 fn new_array(&mut self) -> Self::Vector;
19
20 fn copy_array(&mut self, array: &Self::Vector) -> Self::Vector {
21 let mut copy = self.new_array();
22 self.copy_into(array, &mut copy);
23 copy
24 }
25
26 fn new_eig_vectors<'a>(
27 &'a mut self,
28 vals: impl ExactSizeIterator<Item = &'a [f64]>,
29 ) -> Self::EigVectors;
30 fn new_eig_values(&mut self, vals: &[f64]) -> Self::EigValues;
31
32 fn logp_array(
41 &mut self,
42 position: &Self::Vector,
43 gradient: &mut Self::Vector,
44 ) -> Result<f64, Self::LogpErr>;
45
46 fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, Self::LogpErr>;
47
48 fn dim(&self) -> usize;
49
50 fn scalar_prods3(
51 &mut self,
52 positive1: &Self::Vector,
53 negative1: &Self::Vector,
54 positive2: &Self::Vector,
55 x: &Self::Vector,
56 y: &Self::Vector,
57 ) -> (f64, f64);
58
59 fn scalar_prods2(
60 &mut self,
61 positive1: &Self::Vector,
62 positive2: &Self::Vector,
63 x: &Self::Vector,
64 y: &Self::Vector,
65 ) -> (f64, f64);
66
67 fn sq_norm_sum(&mut self, x: &Self::Vector, y: &Self::Vector) -> f64;
68
69 fn read_from_slice(&mut self, dest: &mut Self::Vector, source: &[f64]);
70 fn write_to_slice(&mut self, source: &Self::Vector, dest: &mut [f64]);
71 fn eigs_as_array(&mut self, source: &Self::EigValues) -> Box<[f64]>;
72 fn copy_into(&mut self, array: &Self::Vector, dest: &mut Self::Vector);
73 fn axpy_out(&mut self, x: &Self::Vector, y: &Self::Vector, a: f64, out: &mut Self::Vector);
74 fn axpy(&mut self, x: &Self::Vector, y: &mut Self::Vector, a: f64);
75
76 fn box_array(&mut self, array: &Self::Vector) -> Box<[f64]> {
77 let mut data = vec![0f64; self.dim()];
78 self.write_to_slice(array, &mut data);
79 data.into()
80 }
81
82 fn fill_array(&mut self, array: &mut Self::Vector, val: f64);
83
84 fn array_all_finite(&mut self, array: &Self::Vector) -> bool;
85 fn array_all_finite_and_nonzero(&mut self, array: &Self::Vector) -> bool;
86 fn array_mult(&mut self, array1: &Self::Vector, array2: &Self::Vector, dest: &mut Self::Vector);
87 fn array_mult_eigs(
88 &mut self,
89 stds: &Self::Vector,
90 rhs: &Self::Vector,
91 dest: &mut Self::Vector,
92 vecs: &Self::EigVectors,
93 vals: &Self::EigValues,
94 );
95 fn array_vector_dot(&mut self, array1: &Self::Vector, array2: &Self::Vector) -> f64;
96 fn array_gaussian<R: rand::Rng + ?Sized>(
97 &mut self,
98 rng: &mut R,
99 dest: &mut Self::Vector,
100 stds: &Self::Vector,
101 );
102 fn array_gaussian_eigs<R: rand::Rng + ?Sized>(
103 &mut self,
104 rng: &mut R,
105 dest: &mut Self::Vector,
106 scale: &Self::Vector,
107 vals: &Self::EigValues,
108 vecs: &Self::EigVectors,
109 );
110 fn array_update_variance(
111 &mut self,
112 mean: &mut Self::Vector,
113 variance: &mut Self::Vector,
114 value: &Self::Vector,
115 diff_scale: f64,
116 );
117 fn array_update_var_inv_std_draw(
118 &mut self,
119 variance_out: &mut Self::Vector,
120 inv_std: &mut Self::Vector,
121 draw_var: &Self::Vector,
122 scale: f64,
123 fill_invalid: Option<f64>,
124 clamp: (f64, f64),
125 );
126 fn array_update_var_inv_std_draw_grad(
127 &mut self,
128 variance_out: &mut Self::Vector,
129 inv_std: &mut Self::Vector,
130 draw_var: &Self::Vector,
131 grad_var: &Self::Vector,
132 fill_invalid: Option<f64>,
133 clamp: (f64, f64),
134 );
135
136 fn array_update_var_inv_std_grad(
137 &mut self,
138 variance_out: &mut Self::Vector,
139 inv_std: &mut Self::Vector,
140 gradient: &Self::Vector,
141 fill_invalid: f64,
142 clamp: (f64, f64),
143 );
144
145 fn inv_transform_normalize(
146 &mut self,
147 params: &Self::TransformParams,
148 untransformed_position: &Self::Vector,
149 untransofrmed_gradient: &Self::Vector,
150 transformed_position: &mut Self::Vector,
151 transformed_gradient: &mut Self::Vector,
152 ) -> Result<f64, Self::LogpErr>;
153
154 fn init_from_untransformed_position(
155 &mut self,
156 params: &Self::TransformParams,
157 untransformed_position: &Self::Vector,
158 untransformed_gradient: &mut Self::Vector,
159 transformed_position: &mut Self::Vector,
160 transformed_gradient: &mut Self::Vector,
161 ) -> Result<(f64, f64), Self::LogpErr>;
162
163 fn init_from_transformed_position(
164 &mut self,
165 params: &Self::TransformParams,
166 untransformed_position: &mut Self::Vector,
167 untransformed_gradient: &mut Self::Vector,
168 transformed_position: &Self::Vector,
169 transformed_gradient: &mut Self::Vector,
170 ) -> Result<(f64, f64), Self::LogpErr>;
171
172 fn update_transformation<'a, R: rand::Rng + ?Sized>(
173 &'a mut self,
174 rng: &mut R,
175 untransformed_positions: impl ExactSizeIterator<Item = &'a Self::Vector>,
176 untransformed_gradients: impl ExactSizeIterator<Item = &'a Self::Vector>,
177 untransformed_logps: impl ExactSizeIterator<Item = &'a f64>,
178 params: &'a mut Self::TransformParams,
179 ) -> Result<(), Self::LogpErr>;
180
181 fn new_transformation<R: rand::Rng + ?Sized>(
182 &mut self,
183 rng: &mut R,
184 untransformed_position: &Self::Vector,
185 untransfogmed_gradient: &Self::Vector,
186 chain: u64,
187 ) -> Result<Self::TransformParams, Self::LogpErr>;
188
189 fn transformation_id(&self, params: &Self::TransformParams) -> Result<i64, Self::LogpErr>;
190}