1use std::{error::Error, fmt::Debug, mem::replace};
2
3use faer::{Col, Mat};
4use itertools::{izip, Itertools};
5use thiserror::Error;
6
7use crate::{
8 math::{axpy, axpy_out, multiply, scalar_prods2, scalar_prods3, vector_dot},
9 math_base::{LogpError, Math},
10};
11
12#[derive(Debug)]
13pub struct CpuMath<F: CpuLogpFunc> {
14 logp_func: F,
15 arch: pulp::Arch,
16}
17
18impl<F: CpuLogpFunc> CpuMath<F> {
19 pub fn new(logp_func: F) -> Self {
20 let arch = pulp::Arch::new();
21 Self { logp_func, arch }
22 }
23}
24
25#[non_exhaustive]
26#[derive(Error, Debug)]
27pub enum CpuMathError {
28 #[error("Error during array operation")]
29 ArrayError(),
30}
31
32impl<F: CpuLogpFunc> Math for CpuMath<F> {
33 type Vector = Col<f64>;
34 type EigVectors = Mat<f64>;
35 type EigValues = Col<f64>;
36 type LogpErr = F::LogpError;
37 type Err = CpuMathError;
38 type TransformParams = F::TransformParams;
39
40 fn new_array(&mut self) -> Self::Vector {
41 Col::zeros(self.dim())
42 }
43
44 fn new_eig_vectors<'a>(
45 &'a mut self,
46 vals: impl ExactSizeIterator<Item = &'a [f64]>,
47 ) -> Self::EigVectors {
48 let ndim = self.dim();
49 let nvecs = vals.len();
50
51 let mut vectors: Mat<f64> = Mat::zeros(ndim, nvecs);
52 vectors.col_iter_mut().zip_eq(vals).for_each(|(col, vals)| {
53 col.try_as_col_major_mut()
54 .expect("Array is not contiguous")
55 .as_slice_mut()
56 .copy_from_slice(vals)
57 });
58
59 vectors
60 }
61
62 fn new_eig_values(&mut self, vals: &[f64]) -> Self::EigValues {
63 let mut values: Col<f64> = Col::zeros(vals.len());
64 values
65 .try_as_col_major_mut()
66 .expect("Array is not contiguous")
67 .as_slice_mut()
68 .copy_from_slice(vals);
69 values
70 }
71
72 fn logp_array(
73 &mut self,
74 position: &Self::Vector,
75 gradient: &mut Self::Vector,
76 ) -> Result<f64, Self::LogpErr> {
77 self.logp_func.logp(
78 position
79 .try_as_col_major()
80 .expect("Array is not contiguous")
81 .as_slice(),
82 gradient
83 .try_as_col_major_mut()
84 .expect("Array is not contiguous")
85 .as_slice_mut(),
86 )
87 }
88
89 fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, Self::LogpErr> {
90 self.logp_func.logp(position, gradient)
91 }
92
93 fn dim(&self) -> usize {
94 self.logp_func.dim()
95 }
96
97 fn scalar_prods3(
98 &mut self,
99 positive1: &Self::Vector,
100 negative1: &Self::Vector,
101 positive2: &Self::Vector,
102 x: &Self::Vector,
103 y: &Self::Vector,
104 ) -> (f64, f64) {
105 scalar_prods3(
106 positive1.try_as_col_major().unwrap().as_slice(),
107 negative1.try_as_col_major().unwrap().as_slice(),
108 positive2.try_as_col_major().unwrap().as_slice(),
109 x.try_as_col_major().unwrap().as_slice(),
110 y.try_as_col_major().unwrap().as_slice(),
111 )
112 }
113
114 fn scalar_prods2(
115 &mut self,
116 positive1: &Self::Vector,
117 positive2: &Self::Vector,
118 x: &Self::Vector,
119 y: &Self::Vector,
120 ) -> (f64, f64) {
121 scalar_prods2(
122 positive1.try_as_col_major().unwrap().as_slice(),
123 positive2.try_as_col_major().unwrap().as_slice(),
124 x.try_as_col_major().unwrap().as_slice(),
125 y.try_as_col_major().unwrap().as_slice(),
126 )
127 }
128
129 fn sq_norm_sum(&mut self, x: &Self::Vector, y: &Self::Vector) -> f64 {
130 x.try_as_col_major()
131 .unwrap()
132 .as_slice()
133 .iter()
134 .zip(y.try_as_col_major().unwrap().as_slice())
135 .map(|(&x, &y)| (x + y) * (x + y))
136 .sum()
137 }
138
139 fn read_from_slice(&mut self, dest: &mut Self::Vector, source: &[f64]) {
140 dest.try_as_col_major_mut()
141 .unwrap()
142 .as_slice_mut()
143 .copy_from_slice(source);
144 }
145
146 fn write_to_slice(&mut self, source: &Self::Vector, dest: &mut [f64]) {
147 dest.copy_from_slice(source.try_as_col_major().unwrap().as_slice())
148 }
149
150 fn copy_into(&mut self, array: &Self::Vector, dest: &mut Self::Vector) {
151 dest.clone_from(array)
152 }
153
154 fn axpy_out(&mut self, x: &Self::Vector, y: &Self::Vector, a: f64, out: &mut Self::Vector) {
155 axpy_out(
156 x.try_as_col_major().unwrap().as_slice(),
157 y.try_as_col_major().unwrap().as_slice(),
158 a,
159 out.try_as_col_major_mut().unwrap().as_slice_mut(),
160 );
161 }
162
163 fn axpy(&mut self, x: &Self::Vector, y: &mut Self::Vector, a: f64) {
164 axpy(
165 x.try_as_col_major().unwrap().as_slice(),
166 y.try_as_col_major_mut().unwrap().as_slice_mut(),
167 a,
168 );
169 }
170
171 fn fill_array(&mut self, array: &mut Self::Vector, val: f64) {
172 faer::zip!(array).for_each(|faer::unzip!(pos)| *pos = val);
173 }
174
175 fn array_all_finite(&mut self, array: &Self::Vector) -> bool {
176 let mut ok = true;
177 faer::zip!(array).for_each(|faer::unzip!(val)| ok = ok & val.is_finite());
178 ok
179 }
180
181 fn array_all_finite_and_nonzero(&mut self, array: &Self::Vector) -> bool {
182 self.arch.dispatch(|| {
183 array
184 .try_as_col_major()
185 .unwrap()
186 .as_slice()
187 .iter()
188 .all(|&x| x.is_finite() & (x != 0f64))
189 })
190 }
191
192 fn array_mult(
193 &mut self,
194 array1: &Self::Vector,
195 array2: &Self::Vector,
196 dest: &mut Self::Vector,
197 ) {
198 multiply(
199 array1.try_as_col_major().unwrap().as_slice(),
200 array2.try_as_col_major().unwrap().as_slice(),
201 dest.try_as_col_major_mut().unwrap().as_slice_mut(),
202 )
203 }
204
205 fn array_mult_eigs(
206 &mut self,
207 stds: &Self::Vector,
208 rhs: &Self::Vector,
209 dest: &mut Self::Vector,
210 vecs: &Self::EigVectors,
211 vals: &Self::EigValues,
212 ) {
213 let rhs = stds.as_diagonal() * rhs;
214 let trafo = vecs.transpose() * (&rhs);
215 let inner_prod = vecs * (vals.as_diagonal() * (&trafo) - (&trafo)) + rhs;
216 let scaled = stds.as_diagonal() * inner_prod;
217
218 let _ = replace(dest, scaled);
219 }
220
221 fn array_vector_dot(&mut self, array1: &Self::Vector, array2: &Self::Vector) -> f64 {
222 vector_dot(
223 array1.try_as_col_major().unwrap().as_slice(),
224 array2.try_as_col_major().unwrap().as_slice(),
225 )
226 }
227
228 fn array_gaussian<R: rand::Rng + ?Sized>(
229 &mut self,
230 rng: &mut R,
231 dest: &mut Self::Vector,
232 stds: &Self::Vector,
233 ) {
234 let dist = rand_distr::StandardNormal;
235 dest.try_as_col_major_mut()
236 .unwrap()
237 .as_slice_mut()
238 .iter_mut()
239 .zip(stds.try_as_col_major().unwrap().as_slice().iter())
240 .for_each(|(p, &s)| {
241 let norm: f64 = rng.sample(dist);
242 *p = s * norm;
243 });
244 }
245
246 fn array_gaussian_eigs<R: rand::Rng + ?Sized>(
247 &mut self,
248 rng: &mut R,
249 dest: &mut Self::Vector,
250 scale: &Self::Vector,
251 vals: &Self::EigValues,
252 vecs: &Self::EigVectors,
253 ) {
254 let mut draw: Col<f64> = Col::zeros(self.dim());
255 let dist = rand_distr::StandardNormal;
256 draw.try_as_col_major_mut()
257 .unwrap()
258 .as_slice_mut()
259 .iter_mut()
260 .for_each(|p| {
261 *p = rng.sample(dist);
262 });
263
264 let trafo = vecs.transpose() * (&draw);
265 let inner_prod = vecs * (vals.as_diagonal() * (&trafo) - (&trafo)) + draw;
266
267 let scaled = scale.as_diagonal() * inner_prod;
268
269 let _ = replace(dest, scaled);
270 }
271
272 fn array_update_variance(
273 &mut self,
274 mean: &mut Self::Vector,
275 variance: &mut Self::Vector,
276 value: &Self::Vector,
277 diff_scale: f64, ) {
279 self.arch.dispatch(|| {
280 izip!(
281 mean.try_as_col_major_mut()
282 .unwrap()
283 .as_slice_mut()
284 .iter_mut(),
285 variance
286 .try_as_col_major_mut()
287 .unwrap()
288 .as_slice_mut()
289 .iter_mut(),
290 value.try_as_col_major().unwrap().as_slice()
291 )
292 .for_each(|(mean, var, x)| {
293 let diff = x - *mean;
294 *mean += diff * diff_scale;
295 *var += diff * diff;
296 });
297 })
298 }
299
300 fn array_update_var_inv_std_draw(
301 &mut self,
302 variance_out: &mut Self::Vector,
303 inv_std: &mut Self::Vector,
304 draw_var: &Self::Vector,
305 scale: f64,
306 fill_invalid: Option<f64>,
307 clamp: (f64, f64),
308 ) {
309 self.arch.dispatch(|| {
310 izip!(
311 variance_out
312 .try_as_col_major_mut()
313 .unwrap()
314 .as_slice_mut()
315 .iter_mut(),
316 inv_std
317 .try_as_col_major_mut()
318 .unwrap()
319 .as_slice_mut()
320 .iter_mut(),
321 draw_var.try_as_col_major().unwrap().as_slice().iter(),
322 )
323 .for_each(|(var_out, inv_std_out, &draw_var)| {
324 let draw_var = draw_var * scale;
325 if (!draw_var.is_finite()) | (draw_var == 0f64) {
326 if let Some(fill_val) = fill_invalid {
327 *var_out = fill_val;
328 *inv_std_out = fill_val.recip().sqrt();
329 }
330 } else {
331 let val = draw_var.clamp(clamp.0, clamp.1);
332 *var_out = val;
333 *inv_std_out = val.recip().sqrt();
334 }
335 });
336 });
337 }
338
339 fn array_update_var_inv_std_draw_grad(
340 &mut self,
341 variance_out: &mut Self::Vector,
342 inv_std: &mut Self::Vector,
343 draw_var: &Self::Vector,
344 grad_var: &Self::Vector,
345 fill_invalid: Option<f64>,
346 clamp: (f64, f64),
347 ) {
348 self.arch.dispatch(|| {
349 izip!(
350 variance_out
351 .try_as_col_major_mut()
352 .unwrap()
353 .as_slice_mut()
354 .iter_mut(),
355 inv_std
356 .try_as_col_major_mut()
357 .unwrap()
358 .as_slice_mut()
359 .iter_mut(),
360 draw_var.try_as_col_major().unwrap().as_slice().iter(),
361 grad_var.try_as_col_major().unwrap().as_slice().iter(),
362 )
363 .for_each(|(var_out, inv_std_out, &draw_var, &grad_var)| {
364 let val = (draw_var / grad_var).sqrt();
365 if (!val.is_finite()) | (val == 0f64) {
366 if let Some(fill_val) = fill_invalid {
367 *var_out = fill_val;
368 *inv_std_out = fill_val.recip().sqrt();
369 }
370 } else {
371 let val = val.clamp(clamp.0, clamp.1);
372 *var_out = val;
373 *inv_std_out = val.recip().sqrt();
374 }
375 });
376 });
377 }
378
379 fn array_update_var_inv_std_grad(
380 &mut self,
381 variance_out: &mut Self::Vector,
382 inv_std: &mut Self::Vector,
383 gradient: &Self::Vector,
384 fill_invalid: f64,
385 clamp: (f64, f64),
386 ) {
387 self.arch.dispatch(|| {
388 izip!(
389 variance_out
390 .try_as_col_major_mut()
391 .unwrap()
392 .as_slice_mut()
393 .iter_mut(),
394 inv_std
395 .try_as_col_major_mut()
396 .unwrap()
397 .as_slice_mut()
398 .iter_mut(),
399 gradient.try_as_col_major().unwrap().as_slice().iter(),
400 )
401 .for_each(|(var_out, inv_std_out, &grad_var)| {
402 let val = grad_var.abs().clamp(clamp.0, clamp.1).recip();
403 let val = if val.is_finite() { val } else { fill_invalid };
404 *var_out = val;
405 *inv_std_out = val.recip().sqrt();
406 });
407 });
408 }
409
410 fn eigs_as_array(&mut self, source: &Self::EigValues) -> Box<[f64]> {
411 source
412 .try_as_col_major()
413 .unwrap()
414 .as_slice()
415 .to_vec()
416 .into()
417 }
418
419 fn inv_transform_normalize(
420 &mut self,
421 params: &Self::TransformParams,
422 untransformed_position: &Self::Vector,
423 untransofrmed_gradient: &Self::Vector,
424 transformed_position: &mut Self::Vector,
425 transformed_gradient: &mut Self::Vector,
426 ) -> Result<f64, Self::LogpErr> {
427 self.logp_func.inv_transform_normalize(
428 params,
429 untransformed_position
430 .try_as_col_major()
431 .unwrap()
432 .as_slice(),
433 untransofrmed_gradient
434 .try_as_col_major()
435 .unwrap()
436 .as_slice(),
437 transformed_position
438 .try_as_col_major_mut()
439 .unwrap()
440 .as_slice_mut(),
441 transformed_gradient
442 .try_as_col_major_mut()
443 .unwrap()
444 .as_slice_mut(),
445 )
446 }
447
448 fn init_from_untransformed_position(
449 &mut self,
450 params: &Self::TransformParams,
451 untransformed_position: &Self::Vector,
452 untransformed_gradient: &mut Self::Vector,
453 transformed_position: &mut Self::Vector,
454 transformed_gradient: &mut Self::Vector,
455 ) -> Result<(f64, f64), Self::LogpErr> {
456 self.logp_func.init_from_untransformed_position(
457 params,
458 untransformed_position
459 .try_as_col_major()
460 .unwrap()
461 .as_slice(),
462 untransformed_gradient
463 .try_as_col_major_mut()
464 .unwrap()
465 .as_slice_mut(),
466 transformed_position
467 .try_as_col_major_mut()
468 .unwrap()
469 .as_slice_mut(),
470 transformed_gradient
471 .try_as_col_major_mut()
472 .unwrap()
473 .as_slice_mut(),
474 )
475 }
476
477 fn init_from_transformed_position(
478 &mut self,
479 params: &Self::TransformParams,
480 untransformed_position: &mut Self::Vector,
481 untransformed_gradient: &mut Self::Vector,
482 transformed_position: &Self::Vector,
483 transformed_gradient: &mut Self::Vector,
484 ) -> Result<(f64, f64), Self::LogpErr> {
485 self.logp_func.init_from_transformed_position(
486 params,
487 untransformed_position
488 .try_as_col_major_mut()
489 .unwrap()
490 .as_slice_mut(),
491 untransformed_gradient
492 .try_as_col_major_mut()
493 .unwrap()
494 .as_slice_mut(),
495 transformed_position.try_as_col_major().unwrap().as_slice(),
496 transformed_gradient
497 .try_as_col_major_mut()
498 .unwrap()
499 .as_slice_mut(),
500 )
501 }
502
503 fn update_transformation<'a, R: rand::Rng + ?Sized>(
504 &'a mut self,
505 rng: &mut R,
506 untransformed_positions: impl ExactSizeIterator<Item = &'a Self::Vector>,
507 untransformed_gradients: impl ExactSizeIterator<Item = &'a Self::Vector>,
508 untransformed_logp: impl ExactSizeIterator<Item = &'a f64>,
509 params: &'a mut Self::TransformParams,
510 ) -> Result<(), Self::LogpErr> {
511 self.logp_func.update_transformation(
512 rng,
513 untransformed_positions.map(|x| x.try_as_col_major().unwrap().as_slice()),
514 untransformed_gradients.map(|x| x.try_as_col_major().unwrap().as_slice()),
515 untransformed_logp,
516 params,
517 )
518 }
519
520 fn new_transformation<R: rand::Rng + ?Sized>(
521 &mut self,
522 rng: &mut R,
523 untransformed_position: &Self::Vector,
524 untransfogmed_gradient: &Self::Vector,
525 chain: u64,
526 ) -> Result<Self::TransformParams, Self::LogpErr> {
527 self.logp_func.new_transformation(
528 rng,
529 untransformed_position
530 .try_as_col_major()
531 .unwrap()
532 .as_slice(),
533 untransfogmed_gradient
534 .try_as_col_major()
535 .unwrap()
536 .as_slice(),
537 chain,
538 )
539 }
540
541 fn transformation_id(&self, params: &Self::TransformParams) -> Result<i64, Self::LogpErr> {
542 self.logp_func.transformation_id(params)
543 }
544}
545
546pub trait CpuLogpFunc {
547 type LogpError: Debug + Send + Sync + Error + LogpError + 'static;
548 type TransformParams;
549
550 fn dim(&self) -> usize;
551 fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, Self::LogpError>;
552
553 fn inv_transform_normalize(
554 &mut self,
555 _params: &Self::TransformParams,
556 _untransformed_position: &[f64],
557 _untransformed_gradient: &[f64],
558 _transformed_position: &mut [f64],
559 _transformed_gradient: &mut [f64],
560 ) -> Result<f64, Self::LogpError> {
561 unimplemented!()
562 }
563
564 fn init_from_untransformed_position(
565 &mut self,
566 _params: &Self::TransformParams,
567 _untransformed_position: &[f64],
568 _untransformed_gradient: &mut [f64],
569 _transformed_position: &mut [f64],
570 _transformed_gradient: &mut [f64],
571 ) -> Result<(f64, f64), Self::LogpError> {
572 unimplemented!()
573 }
574
575 fn init_from_transformed_position(
576 &mut self,
577 _params: &Self::TransformParams,
578 _untransformed_position: &mut [f64],
579 _untransformed_gradient: &mut [f64],
580 _transformed_position: &[f64],
581 _transformed_gradient: &mut [f64],
582 ) -> Result<(f64, f64), Self::LogpError> {
583 unimplemented!()
584 }
585
586 fn update_transformation<'a, R: rand::Rng + ?Sized>(
587 &'a mut self,
588 _rng: &mut R,
589 _untransformed_positions: impl ExactSizeIterator<Item = &'a [f64]>,
590 _untransformed_gradients: impl ExactSizeIterator<Item = &'a [f64]>,
591 _untransformed_logp: impl ExactSizeIterator<Item = &'a f64>,
592 _params: &'a mut Self::TransformParams,
593 ) -> Result<(), Self::LogpError> {
594 unimplemented!()
595 }
596
597 fn new_transformation<R: rand::Rng + ?Sized>(
598 &mut self,
599 _rng: &mut R,
600 _untransformed_position: &[f64],
601 _untransformed_gradient: &[f64],
602 _chain: u64,
603 ) -> Result<Self::TransformParams, Self::LogpError> {
604 unimplemented!()
605 }
606
607 fn transformation_id(&self, _params: &Self::TransformParams) -> Result<i64, Self::LogpError> {
608 unimplemented!()
609 }
610}