1use std::{error::Error, fmt::Debug, mem::replace};
2
3use faer::{Col, Mat};
4use itertools::{izip, Itertools};
5use thiserror::Error;
6
7use crate::{
8 math::{axpy, axpy_out, multiply, scalar_prods2, scalar_prods3, vector_dot},
9 math_base::{LogpError, Math},
10};
11
12#[derive(Debug)]
13pub struct CpuMath<F: CpuLogpFunc> {
14 logp_func: F,
15 arch: pulp::Arch,
16}
17
18impl<F: CpuLogpFunc> CpuMath<F> {
19 pub fn new(logp_func: F) -> Self {
20 let arch = pulp::Arch::new();
21 Self { logp_func, arch }
22 }
23}
24
25#[non_exhaustive]
26#[derive(Error, Debug)]
27pub enum CpuMathError {
28 #[error("Error during array operation")]
29 ArrayError(),
30}
31
32impl<F: CpuLogpFunc> Math for CpuMath<F> {
33 type Vector = Col<f64>;
34 type EigVectors = Mat<f64>;
35 type EigValues = Col<f64>;
36 type LogpErr = F::LogpError;
37 type Err = CpuMathError;
38 type TransformParams = F::TransformParams;
39
40 fn new_array(&mut self) -> Self::Vector {
41 Col::zeros(self.dim())
42 }
43
44 fn new_eig_vectors<'a>(
45 &'a mut self,
46 vals: impl ExactSizeIterator<Item = &'a [f64]>,
47 ) -> Self::EigVectors {
48 let ndim = self.dim();
49 let nvecs = vals.len();
50
51 let mut vectors: Mat<f64> = Mat::zeros(ndim, nvecs);
52 vectors.col_iter_mut().zip_eq(vals).for_each(|(col, vals)| {
53 col.try_as_col_major_mut()
54 .expect("Array is not contiguous")
55 .as_slice_mut()
56 .copy_from_slice(vals)
57 });
58
59 vectors
60 }
61
62 fn new_eig_values(&mut self, vals: &[f64]) -> Self::EigValues {
63 let mut values: Col<f64> = Col::zeros(vals.len());
64 values
65 .try_as_col_major_mut()
66 .expect("Array is not contiguous")
67 .as_slice_mut()
68 .copy_from_slice(vals);
69 values
70 }
71
72 fn logp_array(
73 &mut self,
74 position: &Self::Vector,
75 gradient: &mut Self::Vector,
76 ) -> Result<f64, Self::LogpErr> {
77 self.logp_func.logp(
78 position
79 .try_as_col_major()
80 .expect("Array is not contiguous")
81 .as_slice(),
82 gradient
83 .try_as_col_major_mut()
84 .expect("Array is not contiguous")
85 .as_slice_mut(),
86 )
87 }
88
89 fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, Self::LogpErr> {
90 self.logp_func.logp(position, gradient)
91 }
92
93 fn dim(&self) -> usize {
94 self.logp_func.dim()
95 }
96
97 fn scalar_prods3(
98 &mut self,
99 positive1: &Self::Vector,
100 negative1: &Self::Vector,
101 positive2: &Self::Vector,
102 x: &Self::Vector,
103 y: &Self::Vector,
104 ) -> (f64, f64) {
105 scalar_prods3(
106 self.arch,
107 positive1.try_as_col_major().unwrap().as_slice(),
108 negative1.try_as_col_major().unwrap().as_slice(),
109 positive2.try_as_col_major().unwrap().as_slice(),
110 x.try_as_col_major().unwrap().as_slice(),
111 y.try_as_col_major().unwrap().as_slice(),
112 )
113 }
114
115 fn scalar_prods2(
116 &mut self,
117 positive1: &Self::Vector,
118 positive2: &Self::Vector,
119 x: &Self::Vector,
120 y: &Self::Vector,
121 ) -> (f64, f64) {
122 scalar_prods2(
123 self.arch,
124 positive1.try_as_col_major().unwrap().as_slice(),
125 positive2.try_as_col_major().unwrap().as_slice(),
126 x.try_as_col_major().unwrap().as_slice(),
127 y.try_as_col_major().unwrap().as_slice(),
128 )
129 }
130
131 fn sq_norm_sum(&mut self, x: &Self::Vector, y: &Self::Vector) -> f64 {
132 x.try_as_col_major()
133 .unwrap()
134 .as_slice()
135 .iter()
136 .zip(y.try_as_col_major().unwrap().as_slice())
137 .map(|(&x, &y)| (x + y) * (x + y))
138 .sum()
139 }
140
141 fn read_from_slice(&mut self, dest: &mut Self::Vector, source: &[f64]) {
142 dest.try_as_col_major_mut()
143 .unwrap()
144 .as_slice_mut()
145 .copy_from_slice(source);
146 }
147
148 fn write_to_slice(&mut self, source: &Self::Vector, dest: &mut [f64]) {
149 dest.copy_from_slice(source.try_as_col_major().unwrap().as_slice())
150 }
151
152 fn copy_into(&mut self, array: &Self::Vector, dest: &mut Self::Vector) {
153 dest.clone_from(array)
154 }
155
156 fn axpy_out(&mut self, x: &Self::Vector, y: &Self::Vector, a: f64, out: &mut Self::Vector) {
157 axpy_out(
158 self.arch,
159 x.try_as_col_major().unwrap().as_slice(),
160 y.try_as_col_major().unwrap().as_slice(),
161 a,
162 out.try_as_col_major_mut().unwrap().as_slice_mut(),
163 );
164 }
165
166 fn axpy(&mut self, x: &Self::Vector, y: &mut Self::Vector, a: f64) {
167 axpy(
168 self.arch,
169 x.try_as_col_major().unwrap().as_slice(),
170 y.try_as_col_major_mut().unwrap().as_slice_mut(),
171 a,
172 );
173 }
174
175 fn fill_array(&mut self, array: &mut Self::Vector, val: f64) {
176 faer::zip!(array).for_each(|faer::unzip!(pos)| *pos = val);
177 }
178
179 fn array_all_finite(&mut self, array: &Self::Vector) -> bool {
180 let mut ok = true;
181 faer::zip!(array).for_each(|faer::unzip!(val)| ok &= val.is_finite());
182 ok
183 }
184
185 fn array_all_finite_and_nonzero(&mut self, array: &Self::Vector) -> bool {
186 self.arch.dispatch(|| {
187 array
188 .try_as_col_major()
189 .unwrap()
190 .as_slice()
191 .iter()
192 .all(|&x| x.is_finite() & (x != 0f64))
193 })
194 }
195
196 fn array_mult(
197 &mut self,
198 array1: &Self::Vector,
199 array2: &Self::Vector,
200 dest: &mut Self::Vector,
201 ) {
202 multiply(
203 self.arch,
204 array1.try_as_col_major().unwrap().as_slice(),
205 array2.try_as_col_major().unwrap().as_slice(),
206 dest.try_as_col_major_mut().unwrap().as_slice_mut(),
207 )
208 }
209
210 fn array_mult_eigs(
211 &mut self,
212 stds: &Self::Vector,
213 rhs: &Self::Vector,
214 dest: &mut Self::Vector,
215 vecs: &Self::EigVectors,
216 vals: &Self::EigValues,
217 ) {
218 let rhs = stds.as_diagonal() * rhs;
219 let trafo = vecs.transpose() * (&rhs);
220 let inner_prod = vecs * (vals.as_diagonal() * (&trafo) - (&trafo)) + rhs;
221 let scaled = stds.as_diagonal() * inner_prod;
222
223 let _ = replace(dest, scaled);
224 }
225
226 fn array_vector_dot(&mut self, array1: &Self::Vector, array2: &Self::Vector) -> f64 {
227 vector_dot(
228 self.arch,
229 array1.try_as_col_major().unwrap().as_slice(),
230 array2.try_as_col_major().unwrap().as_slice(),
231 )
232 }
233
234 fn array_gaussian<R: rand::Rng + ?Sized>(
235 &mut self,
236 rng: &mut R,
237 dest: &mut Self::Vector,
238 stds: &Self::Vector,
239 ) {
240 let dist = rand_distr::StandardNormal;
241 dest.try_as_col_major_mut()
242 .unwrap()
243 .as_slice_mut()
244 .iter_mut()
245 .zip(stds.try_as_col_major().unwrap().as_slice().iter())
246 .for_each(|(p, &s)| {
247 let norm: f64 = rng.sample(dist);
248 *p = s * norm;
249 });
250 }
251
252 fn array_gaussian_eigs<R: rand::Rng + ?Sized>(
253 &mut self,
254 rng: &mut R,
255 dest: &mut Self::Vector,
256 scale: &Self::Vector,
257 vals: &Self::EigValues,
258 vecs: &Self::EigVectors,
259 ) {
260 let mut draw: Col<f64> = Col::zeros(self.dim());
261 let dist = rand_distr::StandardNormal;
262 draw.try_as_col_major_mut()
263 .unwrap()
264 .as_slice_mut()
265 .iter_mut()
266 .for_each(|p| {
267 *p = rng.sample(dist);
268 });
269
270 let trafo = vecs.transpose() * (&draw);
271 let inner_prod = vecs * (vals.as_diagonal() * (&trafo) - (&trafo)) + draw;
272
273 let scaled = scale.as_diagonal() * inner_prod;
274
275 let _ = replace(dest, scaled);
276 }
277
278 fn array_update_variance(
279 &mut self,
280 mean: &mut Self::Vector,
281 variance: &mut Self::Vector,
282 value: &Self::Vector,
283 diff_scale: f64, ) {
285 self.arch.dispatch(|| {
286 izip!(
287 mean.try_as_col_major_mut()
288 .unwrap()
289 .as_slice_mut()
290 .iter_mut(),
291 variance
292 .try_as_col_major_mut()
293 .unwrap()
294 .as_slice_mut()
295 .iter_mut(),
296 value.try_as_col_major().unwrap().as_slice()
297 )
298 .for_each(|(mean, var, x)| {
299 let diff = x - *mean;
300 *mean += diff * diff_scale;
301 *var += diff * diff;
302 });
303 })
304 }
305
306 fn array_update_var_inv_std_draw(
307 &mut self,
308 variance_out: &mut Self::Vector,
309 inv_std: &mut Self::Vector,
310 draw_var: &Self::Vector,
311 scale: f64,
312 fill_invalid: Option<f64>,
313 clamp: (f64, f64),
314 ) {
315 self.arch.dispatch(|| {
316 izip!(
317 variance_out
318 .try_as_col_major_mut()
319 .unwrap()
320 .as_slice_mut()
321 .iter_mut(),
322 inv_std
323 .try_as_col_major_mut()
324 .unwrap()
325 .as_slice_mut()
326 .iter_mut(),
327 draw_var.try_as_col_major().unwrap().as_slice().iter(),
328 )
329 .for_each(|(var_out, inv_std_out, &draw_var)| {
330 let draw_var = draw_var * scale;
331 if (!draw_var.is_finite()) | (draw_var == 0f64) {
332 if let Some(fill_val) = fill_invalid {
333 *var_out = fill_val;
334 *inv_std_out = fill_val.recip().sqrt();
335 }
336 } else {
337 let val = draw_var.clamp(clamp.0, clamp.1);
338 *var_out = val;
339 *inv_std_out = val.recip().sqrt();
340 }
341 });
342 });
343 }
344
345 fn array_update_var_inv_std_draw_grad(
346 &mut self,
347 variance_out: &mut Self::Vector,
348 inv_std: &mut Self::Vector,
349 draw_var: &Self::Vector,
350 grad_var: &Self::Vector,
351 fill_invalid: Option<f64>,
352 clamp: (f64, f64),
353 ) {
354 self.arch.dispatch(|| {
355 izip!(
356 variance_out
357 .try_as_col_major_mut()
358 .unwrap()
359 .as_slice_mut()
360 .iter_mut(),
361 inv_std
362 .try_as_col_major_mut()
363 .unwrap()
364 .as_slice_mut()
365 .iter_mut(),
366 draw_var.try_as_col_major().unwrap().as_slice().iter(),
367 grad_var.try_as_col_major().unwrap().as_slice().iter(),
368 )
369 .for_each(|(var_out, inv_std_out, &draw_var, &grad_var)| {
370 let val = (draw_var / grad_var).sqrt();
371 if (!val.is_finite()) | (val == 0f64) {
372 if let Some(fill_val) = fill_invalid {
373 *var_out = fill_val;
374 *inv_std_out = fill_val.recip().sqrt();
375 }
376 } else {
377 let val = val.clamp(clamp.0, clamp.1);
378 *var_out = val;
379 *inv_std_out = val.recip().sqrt();
380 }
381 });
382 });
383 }
384
385 fn array_update_var_inv_std_grad(
386 &mut self,
387 variance_out: &mut Self::Vector,
388 inv_std: &mut Self::Vector,
389 gradient: &Self::Vector,
390 fill_invalid: f64,
391 clamp: (f64, f64),
392 ) {
393 self.arch.dispatch(|| {
394 izip!(
395 variance_out
396 .try_as_col_major_mut()
397 .unwrap()
398 .as_slice_mut()
399 .iter_mut(),
400 inv_std
401 .try_as_col_major_mut()
402 .unwrap()
403 .as_slice_mut()
404 .iter_mut(),
405 gradient.try_as_col_major().unwrap().as_slice().iter(),
406 )
407 .for_each(|(var_out, inv_std_out, &grad_var)| {
408 let val = grad_var.abs().clamp(clamp.0, clamp.1).recip();
409 let val = if val.is_finite() { val } else { fill_invalid };
410 *var_out = val;
411 *inv_std_out = val.recip().sqrt();
412 });
413 });
414 }
415
416 fn eigs_as_array(&mut self, source: &Self::EigValues) -> Box<[f64]> {
417 source
418 .try_as_col_major()
419 .unwrap()
420 .as_slice()
421 .to_vec()
422 .into()
423 }
424
425 fn inv_transform_normalize(
426 &mut self,
427 params: &Self::TransformParams,
428 untransformed_position: &Self::Vector,
429 untransofrmed_gradient: &Self::Vector,
430 transformed_position: &mut Self::Vector,
431 transformed_gradient: &mut Self::Vector,
432 ) -> Result<f64, Self::LogpErr> {
433 self.logp_func.inv_transform_normalize(
434 params,
435 untransformed_position
436 .try_as_col_major()
437 .unwrap()
438 .as_slice(),
439 untransofrmed_gradient
440 .try_as_col_major()
441 .unwrap()
442 .as_slice(),
443 transformed_position
444 .try_as_col_major_mut()
445 .unwrap()
446 .as_slice_mut(),
447 transformed_gradient
448 .try_as_col_major_mut()
449 .unwrap()
450 .as_slice_mut(),
451 )
452 }
453
454 fn init_from_untransformed_position(
455 &mut self,
456 params: &Self::TransformParams,
457 untransformed_position: &Self::Vector,
458 untransformed_gradient: &mut Self::Vector,
459 transformed_position: &mut Self::Vector,
460 transformed_gradient: &mut Self::Vector,
461 ) -> Result<(f64, f64), Self::LogpErr> {
462 self.logp_func.init_from_untransformed_position(
463 params,
464 untransformed_position
465 .try_as_col_major()
466 .unwrap()
467 .as_slice(),
468 untransformed_gradient
469 .try_as_col_major_mut()
470 .unwrap()
471 .as_slice_mut(),
472 transformed_position
473 .try_as_col_major_mut()
474 .unwrap()
475 .as_slice_mut(),
476 transformed_gradient
477 .try_as_col_major_mut()
478 .unwrap()
479 .as_slice_mut(),
480 )
481 }
482
483 fn init_from_transformed_position(
484 &mut self,
485 params: &Self::TransformParams,
486 untransformed_position: &mut Self::Vector,
487 untransformed_gradient: &mut Self::Vector,
488 transformed_position: &Self::Vector,
489 transformed_gradient: &mut Self::Vector,
490 ) -> Result<(f64, f64), Self::LogpErr> {
491 self.logp_func.init_from_transformed_position(
492 params,
493 untransformed_position
494 .try_as_col_major_mut()
495 .unwrap()
496 .as_slice_mut(),
497 untransformed_gradient
498 .try_as_col_major_mut()
499 .unwrap()
500 .as_slice_mut(),
501 transformed_position.try_as_col_major().unwrap().as_slice(),
502 transformed_gradient
503 .try_as_col_major_mut()
504 .unwrap()
505 .as_slice_mut(),
506 )
507 }
508
509 fn update_transformation<'a, R: rand::Rng + ?Sized>(
510 &'a mut self,
511 rng: &mut R,
512 untransformed_positions: impl ExactSizeIterator<Item = &'a Self::Vector>,
513 untransformed_gradients: impl ExactSizeIterator<Item = &'a Self::Vector>,
514 untransformed_logp: impl ExactSizeIterator<Item = &'a f64>,
515 params: &'a mut Self::TransformParams,
516 ) -> Result<(), Self::LogpErr> {
517 self.logp_func.update_transformation(
518 rng,
519 untransformed_positions.map(|x| x.try_as_col_major().unwrap().as_slice()),
520 untransformed_gradients.map(|x| x.try_as_col_major().unwrap().as_slice()),
521 untransformed_logp,
522 params,
523 )
524 }
525
526 fn new_transformation<R: rand::Rng + ?Sized>(
527 &mut self,
528 rng: &mut R,
529 untransformed_position: &Self::Vector,
530 untransfogmed_gradient: &Self::Vector,
531 chain: u64,
532 ) -> Result<Self::TransformParams, Self::LogpErr> {
533 self.logp_func.new_transformation(
534 rng,
535 untransformed_position
536 .try_as_col_major()
537 .unwrap()
538 .as_slice(),
539 untransfogmed_gradient
540 .try_as_col_major()
541 .unwrap()
542 .as_slice(),
543 chain,
544 )
545 }
546
547 fn transformation_id(&self, params: &Self::TransformParams) -> Result<i64, Self::LogpErr> {
548 self.logp_func.transformation_id(params)
549 }
550}
551
552pub trait CpuLogpFunc {
553 type LogpError: Debug + Send + Sync + Error + LogpError + 'static;
554 type TransformParams;
555
556 fn dim(&self) -> usize;
557 fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, Self::LogpError>;
558
559 fn inv_transform_normalize(
560 &mut self,
561 _params: &Self::TransformParams,
562 _untransformed_position: &[f64],
563 _untransformed_gradient: &[f64],
564 _transformed_position: &mut [f64],
565 _transformed_gradient: &mut [f64],
566 ) -> Result<f64, Self::LogpError> {
567 unimplemented!()
568 }
569
570 fn init_from_untransformed_position(
571 &mut self,
572 _params: &Self::TransformParams,
573 _untransformed_position: &[f64],
574 _untransformed_gradient: &mut [f64],
575 _transformed_position: &mut [f64],
576 _transformed_gradient: &mut [f64],
577 ) -> Result<(f64, f64), Self::LogpError> {
578 unimplemented!()
579 }
580
581 fn init_from_transformed_position(
582 &mut self,
583 _params: &Self::TransformParams,
584 _untransformed_position: &mut [f64],
585 _untransformed_gradient: &mut [f64],
586 _transformed_position: &[f64],
587 _transformed_gradient: &mut [f64],
588 ) -> Result<(f64, f64), Self::LogpError> {
589 unimplemented!()
590 }
591
592 fn update_transformation<'a, R: rand::Rng + ?Sized>(
593 &'a mut self,
594 _rng: &mut R,
595 _untransformed_positions: impl ExactSizeIterator<Item = &'a [f64]>,
596 _untransformed_gradients: impl ExactSizeIterator<Item = &'a [f64]>,
597 _untransformed_logp: impl ExactSizeIterator<Item = &'a f64>,
598 _params: &'a mut Self::TransformParams,
599 ) -> Result<(), Self::LogpError> {
600 unimplemented!()
601 }
602
603 fn new_transformation<R: rand::Rng + ?Sized>(
604 &mut self,
605 _rng: &mut R,
606 _untransformed_position: &[f64],
607 _untransformed_gradient: &[f64],
608 _chain: u64,
609 ) -> Result<Self::TransformParams, Self::LogpError> {
610 unimplemented!()
611 }
612
613 fn transformation_id(&self, _params: &Self::TransformParams) -> Result<i64, Self::LogpError> {
614 unimplemented!()
615 }
616}