1use std::{collections::HashMap, error::Error, fmt::Debug, mem::replace};
2
3use faer::{Col, Mat};
4use itertools::{Itertools, izip};
5use nuts_storable::{HasDims, Storable, Value};
6use rand::RngExt;
7use thiserror::Error;
8
9use crate::{
10 math::{axpy, axpy_out, multiply, scalar_prods2, scalar_prods3, vector_dot},
11 math_base::{LogpError, Math},
12};
13
14#[derive(Debug)]
15pub struct CpuMath<F: CpuLogpFunc> {
16 logp_func: F,
17 arch: pulp::Arch,
18}
19
20impl<F: CpuLogpFunc> CpuMath<F> {
21 pub fn new(logp_func: F) -> Self {
22 let arch = pulp::Arch::new();
23 Self { logp_func, arch }
24 }
25}
26
/// Errors produced by [`CpuMath`] itself.
///
/// Failures of the user's log-density are reported through
/// `CpuLogpFunc::LogpError` instead; this type is `Math::Err` for `CpuMath`.
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum CpuMathError {
    /// Generic failure of an array operation.
    #[error("Error during array operation")]
    ArrayError(),
    /// `expand_vector` could not expand a position; the payload says why
    /// (e.g. the internal vector was not stored column-major).
    #[error("Error during point expansion: {0}")]
    ExpandError(String),
}
35
36impl<F: CpuLogpFunc> HasDims for CpuMath<F> {
37 fn dim_sizes(&self) -> HashMap<String, u64> {
38 self.logp_func.dim_sizes()
39 }
40
41 fn coords(&self) -> HashMap<String, nuts_storable::Value> {
42 self.logp_func.coords()
43 }
44}
45
46pub struct ExpandedVectorWrapper<F: CpuLogpFunc>(F::ExpandedVector);
47
48impl<F: CpuLogpFunc> Storable<CpuMath<F>> for ExpandedVectorWrapper<F> {
49 fn names(parent: &CpuMath<F>) -> Vec<&str> {
50 F::ExpandedVector::names(&parent.logp_func)
51 }
52
53 fn item_type(parent: &CpuMath<F>, item: &str) -> nuts_storable::ItemType {
54 F::ExpandedVector::item_type(&parent.logp_func, item)
55 }
56
57 fn dims<'a>(parent: &'a CpuMath<F>, item: &str) -> Vec<&'a str> {
58 F::ExpandedVector::dims(&parent.logp_func, item)
59 }
60
61 fn get_all<'a>(
62 &'a mut self,
63 parent: &'a CpuMath<F>,
64 ) -> Vec<(&'a str, Option<nuts_storable::Value>)> {
65 self.0.get_all(&parent.logp_func)
66 }
67}
68
69impl<F: CpuLogpFunc> Math for CpuMath<F> {
70 type Vector = Col<f64>;
71 type EigVectors = Mat<f64>;
72 type EigValues = Col<f64>;
73 type LogpErr = F::LogpError;
74 type Err = CpuMathError;
75 type FlowParameters = F::FlowParameters;
76 type ExpandedVector = ExpandedVectorWrapper<F>;
77
78 fn new_array(&mut self) -> Self::Vector {
79 Col::zeros(self.dim())
80 }
81
82 fn new_eig_vectors<'a>(
83 &'a mut self,
84 vals: impl ExactSizeIterator<Item = &'a [f64]>,
85 ) -> Self::EigVectors {
86 let ndim = self.dim();
87 let nvecs = vals.len();
88
89 let mut vectors: Mat<f64> = Mat::zeros(ndim, nvecs);
90 vectors.col_iter_mut().zip_eq(vals).for_each(|(col, vals)| {
91 col.try_as_col_major_mut()
92 .expect("Array is not contiguous")
93 .as_slice_mut()
94 .copy_from_slice(vals)
95 });
96
97 vectors
98 }
99
100 fn new_eig_values(&mut self, vals: &[f64]) -> Self::EigValues {
101 let mut values: Col<f64> = Col::zeros(vals.len());
102 values
103 .try_as_col_major_mut()
104 .expect("Array is not contiguous")
105 .as_slice_mut()
106 .copy_from_slice(vals);
107 values
108 }
109
110 fn logp_array(
111 &mut self,
112 position: &Self::Vector,
113 gradient: &mut Self::Vector,
114 ) -> Result<f64, Self::LogpErr> {
115 self.logp_func.logp(
116 position
117 .try_as_col_major()
118 .expect("Array is not contiguous")
119 .as_slice(),
120 gradient
121 .try_as_col_major_mut()
122 .expect("Array is not contiguous")
123 .as_slice_mut(),
124 )
125 }
126
127 fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, Self::LogpErr> {
128 self.logp_func.logp(position, gradient)
129 }
130
131 fn dim(&self) -> usize {
132 self.logp_func.dim()
133 }
134
135 fn expand_vector<R: rand::Rng + ?Sized>(
136 &mut self,
137 rng: &mut R,
138 array: &Self::Vector,
139 ) -> Result<Self::ExpandedVector, Self::Err> {
140 Ok(ExpandedVectorWrapper(
141 self.logp_func.expand_vector(
142 rng,
143 array
144 .try_as_col_major()
145 .ok_or_else(|| {
146 CpuMathError::ExpandError("Internal vector was not col major".into())
147 })?
148 .as_slice(),
149 )?,
150 ))
151 }
152
153 fn vector_coord(&self) -> Option<Value> {
154 self.logp_func.vector_coord()
155 }
156
157 fn init_position<R: rand::Rng + ?Sized>(
158 &mut self,
159 rng: &mut R,
160 position: &mut Self::Vector,
161 gradient: &mut Self::Vector,
162 ) -> Result<f64, Self::LogpErr> {
163 let pos = position
164 .try_as_col_major_mut()
165 .expect("Array is not contiguous")
166 .as_slice_mut();
167
168 pos.iter_mut().for_each(|x| {
169 let val: f64 = rng.random();
170 *x = val * 2f64 - 1f64
171 });
172
173 self.logp_func.logp(
174 position
175 .try_as_col_major()
176 .expect("Array is not contiguous")
177 .as_slice(),
178 gradient
179 .try_as_col_major_mut()
180 .expect("Array is not contiguous")
181 .as_slice_mut(),
182 )
183 }
184
185 fn scalar_prods3(
186 &mut self,
187 positive1: &Self::Vector,
188 negative1: &Self::Vector,
189 positive2: &Self::Vector,
190 x: &Self::Vector,
191 y: &Self::Vector,
192 ) -> (f64, f64) {
193 scalar_prods3(
194 self.arch,
195 positive1.try_as_col_major().unwrap().as_slice(),
196 negative1.try_as_col_major().unwrap().as_slice(),
197 positive2.try_as_col_major().unwrap().as_slice(),
198 x.try_as_col_major().unwrap().as_slice(),
199 y.try_as_col_major().unwrap().as_slice(),
200 )
201 }
202
203 fn scalar_prods2(
204 &mut self,
205 positive1: &Self::Vector,
206 positive2: &Self::Vector,
207 x: &Self::Vector,
208 y: &Self::Vector,
209 ) -> (f64, f64) {
210 scalar_prods2(
211 self.arch,
212 positive1.try_as_col_major().unwrap().as_slice(),
213 positive2.try_as_col_major().unwrap().as_slice(),
214 x.try_as_col_major().unwrap().as_slice(),
215 y.try_as_col_major().unwrap().as_slice(),
216 )
217 }
218
219 fn sq_norm_sum(&mut self, x: &Self::Vector, y: &Self::Vector) -> f64 {
220 x.try_as_col_major()
221 .unwrap()
222 .as_slice()
223 .iter()
224 .zip(y.try_as_col_major().unwrap().as_slice())
225 .map(|(&x, &y)| (x + y) * (x + y))
226 .sum()
227 }
228
229 fn read_from_slice(&mut self, dest: &mut Self::Vector, source: &[f64]) {
230 dest.try_as_col_major_mut()
231 .unwrap()
232 .as_slice_mut()
233 .copy_from_slice(source);
234 }
235
236 fn write_to_slice(&mut self, source: &Self::Vector, dest: &mut [f64]) {
237 dest.copy_from_slice(source.try_as_col_major().unwrap().as_slice())
238 }
239
240 fn copy_into(&mut self, array: &Self::Vector, dest: &mut Self::Vector) {
241 dest.clone_from(array)
242 }
243
244 fn axpy_out(&mut self, x: &Self::Vector, y: &Self::Vector, a: f64, out: &mut Self::Vector) {
245 axpy_out(
246 self.arch,
247 x.try_as_col_major().unwrap().as_slice(),
248 y.try_as_col_major().unwrap().as_slice(),
249 a,
250 out.try_as_col_major_mut().unwrap().as_slice_mut(),
251 );
252 }
253
254 fn axpy(&mut self, x: &Self::Vector, y: &mut Self::Vector, a: f64) {
255 axpy(
256 self.arch,
257 x.try_as_col_major().unwrap().as_slice(),
258 y.try_as_col_major_mut().unwrap().as_slice_mut(),
259 a,
260 );
261 }
262
263 fn fill_array(&mut self, array: &mut Self::Vector, val: f64) {
264 faer::zip!(array).for_each(|faer::unzip!(pos)| *pos = val);
265 }
266
267 fn array_all_finite(&mut self, array: &Self::Vector) -> bool {
268 let mut ok = true;
269 faer::zip!(array).for_each(|faer::unzip!(val)| ok &= val.is_finite());
270 ok
271 }
272
273 fn array_all_finite_and_nonzero(&mut self, array: &Self::Vector) -> bool {
274 self.arch.dispatch(|| {
275 array
276 .try_as_col_major()
277 .unwrap()
278 .as_slice()
279 .iter()
280 .all(|&x| x.is_finite() & (x != 0f64))
281 })
282 }
283
284 fn array_mult(
285 &mut self,
286 array1: &Self::Vector,
287 array2: &Self::Vector,
288 dest: &mut Self::Vector,
289 ) {
290 multiply(
291 self.arch,
292 array1.try_as_col_major().unwrap().as_slice(),
293 array2.try_as_col_major().unwrap().as_slice(),
294 dest.try_as_col_major_mut().unwrap().as_slice_mut(),
295 )
296 }
297
298 fn array_mult_eigs(
299 &mut self,
300 stds: &Self::Vector,
301 rhs: &Self::Vector,
302 dest: &mut Self::Vector,
303 vecs: &Self::EigVectors,
304 vals: &Self::EigValues,
305 ) {
306 let rhs = stds.as_diagonal() * rhs;
307 let trafo = vecs.transpose() * (&rhs);
308 let inner_prod = vecs * (vals.as_diagonal() * (&trafo) - (&trafo)) + rhs;
309 let scaled = stds.as_diagonal() * inner_prod;
310
311 let _ = replace(dest, scaled);
312 }
313
314 fn array_vector_dot(&mut self, array1: &Self::Vector, array2: &Self::Vector) -> f64 {
315 vector_dot(
316 self.arch,
317 array1.try_as_col_major().unwrap().as_slice(),
318 array2.try_as_col_major().unwrap().as_slice(),
319 )
320 }
321
322 fn array_gaussian<R: rand::Rng + ?Sized>(
323 &mut self,
324 rng: &mut R,
325 dest: &mut Self::Vector,
326 stds: &Self::Vector,
327 ) {
328 let dist = rand_distr::StandardNormal;
329 dest.try_as_col_major_mut()
330 .unwrap()
331 .as_slice_mut()
332 .iter_mut()
333 .zip(stds.try_as_col_major().unwrap().as_slice().iter())
334 .for_each(|(p, &s)| {
335 let norm: f64 = rng.sample(dist);
336 *p = s * norm;
337 });
338 }
339
340 fn array_gaussian_eigs<R: rand::Rng + ?Sized>(
341 &mut self,
342 rng: &mut R,
343 dest: &mut Self::Vector,
344 scale: &Self::Vector,
345 vals: &Self::EigValues,
346 vecs: &Self::EigVectors,
347 ) {
348 let mut draw: Col<f64> = Col::zeros(self.dim());
349 let dist = rand_distr::StandardNormal;
350 draw.try_as_col_major_mut()
351 .unwrap()
352 .as_slice_mut()
353 .iter_mut()
354 .for_each(|p| {
355 *p = rng.sample(dist);
356 });
357
358 let trafo = vecs.transpose() * (&draw);
359 let inner_prod = vecs * (vals.as_diagonal() * (&trafo) - (&trafo)) + draw;
360
361 let scaled = scale.as_diagonal() * inner_prod;
362
363 let _ = replace(dest, scaled);
364 }
365
366 fn array_update_variance(
367 &mut self,
368 mean: &mut Self::Vector,
369 variance: &mut Self::Vector,
370 value: &Self::Vector,
371 diff_scale: f64, ) {
373 self.arch.dispatch(|| {
374 izip!(
375 mean.try_as_col_major_mut()
376 .unwrap()
377 .as_slice_mut()
378 .iter_mut(),
379 variance
380 .try_as_col_major_mut()
381 .unwrap()
382 .as_slice_mut()
383 .iter_mut(),
384 value.try_as_col_major().unwrap().as_slice()
385 )
386 .for_each(|(mean, var, x)| {
387 let diff = x - *mean;
388 *mean += diff * diff_scale;
389 *var += diff * diff;
390 });
391 })
392 }
393
394 fn array_update_var_inv_std_draw(
395 &mut self,
396 variance_out: &mut Self::Vector,
397 inv_std: &mut Self::Vector,
398 draw_var: &Self::Vector,
399 scale: f64,
400 fill_invalid: Option<f64>,
401 clamp: (f64, f64),
402 ) {
403 self.arch.dispatch(|| {
404 izip!(
405 variance_out
406 .try_as_col_major_mut()
407 .unwrap()
408 .as_slice_mut()
409 .iter_mut(),
410 inv_std
411 .try_as_col_major_mut()
412 .unwrap()
413 .as_slice_mut()
414 .iter_mut(),
415 draw_var.try_as_col_major().unwrap().as_slice().iter(),
416 )
417 .for_each(|(var_out, inv_std_out, &draw_var)| {
418 let draw_var = draw_var * scale;
419 if (!draw_var.is_finite()) | (draw_var == 0f64) {
420 if let Some(fill_val) = fill_invalid {
421 *var_out = fill_val;
422 *inv_std_out = fill_val.recip().sqrt();
423 }
424 } else {
425 let val = draw_var.clamp(clamp.0, clamp.1);
426 *var_out = val;
427 *inv_std_out = val.recip().sqrt();
428 }
429 });
430 });
431 }
432
433 fn array_update_var_inv_std_draw_grad(
434 &mut self,
435 variance_out: &mut Self::Vector,
436 inv_std: &mut Self::Vector,
437 draw_var: &Self::Vector,
438 grad_var: &Self::Vector,
439 fill_invalid: Option<f64>,
440 clamp: (f64, f64),
441 ) {
442 self.arch.dispatch(|| {
443 izip!(
444 variance_out
445 .try_as_col_major_mut()
446 .unwrap()
447 .as_slice_mut()
448 .iter_mut(),
449 inv_std
450 .try_as_col_major_mut()
451 .unwrap()
452 .as_slice_mut()
453 .iter_mut(),
454 draw_var.try_as_col_major().unwrap().as_slice().iter(),
455 grad_var.try_as_col_major().unwrap().as_slice().iter(),
456 )
457 .for_each(|(var_out, inv_std_out, &draw_var, &grad_var)| {
458 let val = (draw_var / grad_var).sqrt();
459 if (!val.is_finite()) | (val == 0f64) {
460 if let Some(fill_val) = fill_invalid {
461 *var_out = fill_val;
462 *inv_std_out = fill_val.recip().sqrt();
463 }
464 } else {
465 let val = val.clamp(clamp.0, clamp.1);
466 *var_out = val;
467 *inv_std_out = val.recip().sqrt();
468 }
469 });
470 });
471 }
472
473 fn array_update_var_inv_std_grad(
474 &mut self,
475 variance_out: &mut Self::Vector,
476 inv_std: &mut Self::Vector,
477 gradient: &Self::Vector,
478 fill_invalid: f64,
479 clamp: (f64, f64),
480 ) {
481 self.arch.dispatch(|| {
482 izip!(
483 variance_out
484 .try_as_col_major_mut()
485 .unwrap()
486 .as_slice_mut()
487 .iter_mut(),
488 inv_std
489 .try_as_col_major_mut()
490 .unwrap()
491 .as_slice_mut()
492 .iter_mut(),
493 gradient.try_as_col_major().unwrap().as_slice().iter(),
494 )
495 .for_each(|(var_out, inv_std_out, &grad_var)| {
496 let val = grad_var.abs().clamp(clamp.0, clamp.1).recip();
497 let val = if val.is_finite() { val } else { fill_invalid };
498 *var_out = val;
499 *inv_std_out = val.recip().sqrt();
500 });
501 });
502 }
503
504 fn eigs_as_array(&mut self, source: &Self::EigValues) -> Box<[f64]> {
505 source
506 .try_as_col_major()
507 .unwrap()
508 .as_slice()
509 .to_vec()
510 .into()
511 }
512
513 fn inv_transform_normalize(
514 &mut self,
515 params: &Self::FlowParameters,
516 untransformed_position: &Self::Vector,
517 untransofrmed_gradient: &Self::Vector,
518 transformed_position: &mut Self::Vector,
519 transformed_gradient: &mut Self::Vector,
520 ) -> Result<f64, Self::LogpErr> {
521 self.logp_func.inv_transform_normalize(
522 params,
523 untransformed_position
524 .try_as_col_major()
525 .unwrap()
526 .as_slice(),
527 untransofrmed_gradient
528 .try_as_col_major()
529 .unwrap()
530 .as_slice(),
531 transformed_position
532 .try_as_col_major_mut()
533 .unwrap()
534 .as_slice_mut(),
535 transformed_gradient
536 .try_as_col_major_mut()
537 .unwrap()
538 .as_slice_mut(),
539 )
540 }
541
542 fn init_from_untransformed_position(
543 &mut self,
544 params: &Self::FlowParameters,
545 untransformed_position: &Self::Vector,
546 untransformed_gradient: &mut Self::Vector,
547 transformed_position: &mut Self::Vector,
548 transformed_gradient: &mut Self::Vector,
549 ) -> Result<(f64, f64), Self::LogpErr> {
550 self.logp_func.init_from_untransformed_position(
551 params,
552 untransformed_position
553 .try_as_col_major()
554 .unwrap()
555 .as_slice(),
556 untransformed_gradient
557 .try_as_col_major_mut()
558 .unwrap()
559 .as_slice_mut(),
560 transformed_position
561 .try_as_col_major_mut()
562 .unwrap()
563 .as_slice_mut(),
564 transformed_gradient
565 .try_as_col_major_mut()
566 .unwrap()
567 .as_slice_mut(),
568 )
569 }
570
571 fn init_from_transformed_position(
572 &mut self,
573 params: &Self::FlowParameters,
574 untransformed_position: &mut Self::Vector,
575 untransformed_gradient: &mut Self::Vector,
576 transformed_position: &Self::Vector,
577 transformed_gradient: &mut Self::Vector,
578 ) -> Result<(f64, f64), Self::LogpErr> {
579 self.logp_func.init_from_transformed_position(
580 params,
581 untransformed_position
582 .try_as_col_major_mut()
583 .unwrap()
584 .as_slice_mut(),
585 untransformed_gradient
586 .try_as_col_major_mut()
587 .unwrap()
588 .as_slice_mut(),
589 transformed_position.try_as_col_major().unwrap().as_slice(),
590 transformed_gradient
591 .try_as_col_major_mut()
592 .unwrap()
593 .as_slice_mut(),
594 )
595 }
596
597 fn update_transformation<'a, R: rand::Rng + ?Sized>(
598 &'a mut self,
599 rng: &mut R,
600 untransformed_positions: impl ExactSizeIterator<Item = &'a Self::Vector>,
601 untransformed_gradients: impl ExactSizeIterator<Item = &'a Self::Vector>,
602 untransformed_logp: impl ExactSizeIterator<Item = &'a f64>,
603 params: &'a mut Self::FlowParameters,
604 ) -> Result<(), Self::LogpErr> {
605 self.logp_func.update_transformation(
606 rng,
607 untransformed_positions.map(|x| x.try_as_col_major().unwrap().as_slice()),
608 untransformed_gradients.map(|x| x.try_as_col_major().unwrap().as_slice()),
609 untransformed_logp,
610 params,
611 )
612 }
613
614 fn new_transformation<R: rand::Rng + ?Sized>(
615 &mut self,
616 rng: &mut R,
617 untransformed_position: &Self::Vector,
618 untransfogmed_gradient: &Self::Vector,
619 chain: u64,
620 ) -> Result<Self::FlowParameters, Self::LogpErr> {
621 self.logp_func.new_transformation(
622 rng,
623 untransformed_position
624 .try_as_col_major()
625 .unwrap()
626 .as_slice(),
627 untransfogmed_gradient
628 .try_as_col_major()
629 .unwrap()
630 .as_slice(),
631 chain,
632 )
633 }
634
635 fn transformation_id(&self, params: &Self::FlowParameters) -> Result<i64, Self::LogpErr> {
636 self.logp_func.transformation_id(params)
637 }
638}
639
640pub trait CpuLogpFunc: HasDims {
641 type LogpError: Debug + Send + Sync + Error + LogpError + 'static;
642 type FlowParameters;
643 type ExpandedVector: Storable<Self>;
644
645 fn dim(&self) -> usize;
646 fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, Self::LogpError>;
647 fn expand_vector<R>(
648 &mut self,
649 rng: &mut R,
650 array: &[f64],
651 ) -> Result<Self::ExpandedVector, CpuMathError>
652 where
653 R: rand::Rng + ?Sized;
654
655 fn vector_coord(&self) -> Option<Value> {
656 None
657 }
658
659 fn inv_transform_normalize(
660 &mut self,
661 _params: &Self::FlowParameters,
662 _untransformed_position: &[f64],
663 _untransformed_gradient: &[f64],
664 _transformed_position: &mut [f64],
665 _transformed_gradient: &mut [f64],
666 ) -> Result<f64, Self::LogpError> {
667 unimplemented!()
668 }
669
670 fn init_from_untransformed_position(
671 &mut self,
672 _params: &Self::FlowParameters,
673 _untransformed_position: &[f64],
674 _untransformed_gradient: &mut [f64],
675 _transformed_position: &mut [f64],
676 _transformed_gradient: &mut [f64],
677 ) -> Result<(f64, f64), Self::LogpError> {
678 unimplemented!()
679 }
680
681 fn init_from_transformed_position(
682 &mut self,
683 _params: &Self::FlowParameters,
684 _untransformed_position: &mut [f64],
685 _untransformed_gradient: &mut [f64],
686 _transformed_position: &[f64],
687 _transformed_gradient: &mut [f64],
688 ) -> Result<(f64, f64), Self::LogpError> {
689 unimplemented!()
690 }
691
692 fn update_transformation<'a, R: rand::Rng + ?Sized>(
693 &'a mut self,
694 _rng: &mut R,
695 _untransformed_positions: impl ExactSizeIterator<Item = &'a [f64]>,
696 _untransformed_gradients: impl ExactSizeIterator<Item = &'a [f64]>,
697 _untransformed_logp: impl ExactSizeIterator<Item = &'a f64>,
698 _params: &'a mut Self::FlowParameters,
699 ) -> Result<(), Self::LogpError> {
700 unimplemented!()
701 }
702
703 fn new_transformation<R: rand::Rng + ?Sized>(
704 &mut self,
705 _rng: &mut R,
706 _untransformed_position: &[f64],
707 _untransformed_gradient: &[f64],
708 _chain: u64,
709 ) -> Result<Self::FlowParameters, Self::LogpError> {
710 unimplemented!()
711 }
712
713 fn transformation_id(&self, _params: &Self::FlowParameters) -> Result<i64, Self::LogpError> {
714 unimplemented!()
715 }
716}
717
718impl<M: CpuLogpFunc + Clone> Clone for CpuMath<M> {
719 fn clone(&self) -> Self {
720 Self {
721 logp_func: self.logp_func.clone(),
722 arch: self.arch,
723 }
724 }
725}