1use std::{collections::HashMap, error::Error, fmt::Debug, mem::replace};
2
3use faer::{Col, Mat};
4use itertools::{Itertools, izip};
5use nuts_storable::{HasDims, Storable, Value};
6use thiserror::Error;
7
8use crate::{
9 math::{axpy, axpy_out, multiply, scalar_prods2, scalar_prods3, vector_dot},
10 math_base::{LogpError, Math},
11};
12
13#[derive(Debug)]
14pub struct CpuMath<F: CpuLogpFunc> {
15 logp_func: F,
16 arch: pulp::Arch,
17}
18
19impl<F: CpuLogpFunc> CpuMath<F> {
20 pub fn new(logp_func: F) -> Self {
21 let arch = pulp::Arch::new();
22 Self { logp_func, arch }
23 }
24}
25
26#[non_exhaustive]
27#[derive(Error, Debug)]
28pub enum CpuMathError {
29 #[error("Error during array operation")]
30 ArrayError(),
31 #[error("Error during point expansion: {0}")]
32 ExpandError(String),
33}
34
35impl<F: CpuLogpFunc> HasDims for CpuMath<F> {
36 fn dim_sizes(&self) -> HashMap<String, u64> {
37 self.logp_func.dim_sizes()
38 }
39
40 fn coords(&self) -> HashMap<String, nuts_storable::Value> {
41 self.logp_func.coords()
42 }
43}
44
45pub struct ExpandedVectorWrapper<F: CpuLogpFunc>(F::ExpandedVector);
46
47impl<F: CpuLogpFunc> Storable<CpuMath<F>> for ExpandedVectorWrapper<F> {
48 fn names(parent: &CpuMath<F>) -> Vec<&str> {
49 F::ExpandedVector::names(&parent.logp_func)
50 }
51
52 fn item_type(parent: &CpuMath<F>, item: &str) -> nuts_storable::ItemType {
53 F::ExpandedVector::item_type(&parent.logp_func, item)
54 }
55
56 fn dims<'a>(parent: &'a CpuMath<F>, item: &str) -> Vec<&'a str> {
57 F::ExpandedVector::dims(&parent.logp_func, item)
58 }
59
60 fn get_all<'a>(
61 &'a mut self,
62 parent: &'a CpuMath<F>,
63 ) -> Vec<(&'a str, Option<nuts_storable::Value>)> {
64 self.0.get_all(&parent.logp_func)
65 }
66}
67
68impl<F: CpuLogpFunc> Math for CpuMath<F> {
69 type Vector = Col<f64>;
70 type EigVectors = Mat<f64>;
71 type EigValues = Col<f64>;
72 type LogpErr = F::LogpError;
73 type Err = CpuMathError;
74 type FlowParameters = F::FlowParameters;
75 type ExpandedVector = ExpandedVectorWrapper<F>;
76
77 fn new_array(&mut self) -> Self::Vector {
78 Col::zeros(self.dim())
79 }
80
81 fn new_eig_vectors<'a>(
82 &'a mut self,
83 vals: impl ExactSizeIterator<Item = &'a [f64]>,
84 ) -> Self::EigVectors {
85 let ndim = self.dim();
86 let nvecs = vals.len();
87
88 let mut vectors: Mat<f64> = Mat::zeros(ndim, nvecs);
89 vectors.col_iter_mut().zip_eq(vals).for_each(|(col, vals)| {
90 col.try_as_col_major_mut()
91 .expect("Array is not contiguous")
92 .as_slice_mut()
93 .copy_from_slice(vals)
94 });
95
96 vectors
97 }
98
99 fn new_eig_values(&mut self, vals: &[f64]) -> Self::EigValues {
100 let mut values: Col<f64> = Col::zeros(vals.len());
101 values
102 .try_as_col_major_mut()
103 .expect("Array is not contiguous")
104 .as_slice_mut()
105 .copy_from_slice(vals);
106 values
107 }
108
109 fn logp_array(
110 &mut self,
111 position: &Self::Vector,
112 gradient: &mut Self::Vector,
113 ) -> Result<f64, Self::LogpErr> {
114 self.logp_func.logp(
115 position
116 .try_as_col_major()
117 .expect("Array is not contiguous")
118 .as_slice(),
119 gradient
120 .try_as_col_major_mut()
121 .expect("Array is not contiguous")
122 .as_slice_mut(),
123 )
124 }
125
126 fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, Self::LogpErr> {
127 self.logp_func.logp(position, gradient)
128 }
129
130 fn dim(&self) -> usize {
131 self.logp_func.dim()
132 }
133
134 fn expand_vector<R: rand::Rng + ?Sized>(
135 &mut self,
136 rng: &mut R,
137 array: &Self::Vector,
138 ) -> Result<Self::ExpandedVector, Self::Err> {
139 Ok(ExpandedVectorWrapper(
140 self.logp_func.expand_vector(
141 rng,
142 array
143 .try_as_col_major()
144 .ok_or_else(|| {
145 CpuMathError::ExpandError("Internal vector was not col major".into())
146 })?
147 .as_slice(),
148 )?,
149 ))
150 }
151
152 fn vector_coord(&self) -> Option<Value> {
153 self.logp_func.vector_coord()
154 }
155
156 fn init_position<R: rand::Rng + ?Sized>(
157 &mut self,
158 rng: &mut R,
159 position: &mut Self::Vector,
160 gradient: &mut Self::Vector,
161 ) -> Result<f64, Self::LogpErr> {
162 let pos = position
163 .try_as_col_major_mut()
164 .expect("Array is not contiguous")
165 .as_slice_mut();
166
167 pos.iter_mut().for_each(|x| {
168 let val: f64 = rng.random();
169 *x = val * 2f64 - 1f64
170 });
171
172 self.logp_func.logp(
173 position
174 .try_as_col_major()
175 .expect("Array is not contiguous")
176 .as_slice(),
177 gradient
178 .try_as_col_major_mut()
179 .expect("Array is not contiguous")
180 .as_slice_mut(),
181 )
182 }
183
184 fn scalar_prods3(
185 &mut self,
186 positive1: &Self::Vector,
187 negative1: &Self::Vector,
188 positive2: &Self::Vector,
189 x: &Self::Vector,
190 y: &Self::Vector,
191 ) -> (f64, f64) {
192 scalar_prods3(
193 self.arch,
194 positive1.try_as_col_major().unwrap().as_slice(),
195 negative1.try_as_col_major().unwrap().as_slice(),
196 positive2.try_as_col_major().unwrap().as_slice(),
197 x.try_as_col_major().unwrap().as_slice(),
198 y.try_as_col_major().unwrap().as_slice(),
199 )
200 }
201
202 fn scalar_prods2(
203 &mut self,
204 positive1: &Self::Vector,
205 positive2: &Self::Vector,
206 x: &Self::Vector,
207 y: &Self::Vector,
208 ) -> (f64, f64) {
209 scalar_prods2(
210 self.arch,
211 positive1.try_as_col_major().unwrap().as_slice(),
212 positive2.try_as_col_major().unwrap().as_slice(),
213 x.try_as_col_major().unwrap().as_slice(),
214 y.try_as_col_major().unwrap().as_slice(),
215 )
216 }
217
218 fn sq_norm_sum(&mut self, x: &Self::Vector, y: &Self::Vector) -> f64 {
219 x.try_as_col_major()
220 .unwrap()
221 .as_slice()
222 .iter()
223 .zip(y.try_as_col_major().unwrap().as_slice())
224 .map(|(&x, &y)| (x + y) * (x + y))
225 .sum()
226 }
227
228 fn read_from_slice(&mut self, dest: &mut Self::Vector, source: &[f64]) {
229 dest.try_as_col_major_mut()
230 .unwrap()
231 .as_slice_mut()
232 .copy_from_slice(source);
233 }
234
235 fn write_to_slice(&mut self, source: &Self::Vector, dest: &mut [f64]) {
236 dest.copy_from_slice(source.try_as_col_major().unwrap().as_slice())
237 }
238
239 fn copy_into(&mut self, array: &Self::Vector, dest: &mut Self::Vector) {
240 dest.clone_from(array)
241 }
242
243 fn axpy_out(&mut self, x: &Self::Vector, y: &Self::Vector, a: f64, out: &mut Self::Vector) {
244 axpy_out(
245 self.arch,
246 x.try_as_col_major().unwrap().as_slice(),
247 y.try_as_col_major().unwrap().as_slice(),
248 a,
249 out.try_as_col_major_mut().unwrap().as_slice_mut(),
250 );
251 }
252
253 fn axpy(&mut self, x: &Self::Vector, y: &mut Self::Vector, a: f64) {
254 axpy(
255 self.arch,
256 x.try_as_col_major().unwrap().as_slice(),
257 y.try_as_col_major_mut().unwrap().as_slice_mut(),
258 a,
259 );
260 }
261
262 fn fill_array(&mut self, array: &mut Self::Vector, val: f64) {
263 faer::zip!(array).for_each(|faer::unzip!(pos)| *pos = val);
264 }
265
266 fn array_all_finite(&mut self, array: &Self::Vector) -> bool {
267 let mut ok = true;
268 faer::zip!(array).for_each(|faer::unzip!(val)| ok &= val.is_finite());
269 ok
270 }
271
272 fn array_all_finite_and_nonzero(&mut self, array: &Self::Vector) -> bool {
273 self.arch.dispatch(|| {
274 array
275 .try_as_col_major()
276 .unwrap()
277 .as_slice()
278 .iter()
279 .all(|&x| x.is_finite() & (x != 0f64))
280 })
281 }
282
283 fn array_mult(
284 &mut self,
285 array1: &Self::Vector,
286 array2: &Self::Vector,
287 dest: &mut Self::Vector,
288 ) {
289 multiply(
290 self.arch,
291 array1.try_as_col_major().unwrap().as_slice(),
292 array2.try_as_col_major().unwrap().as_slice(),
293 dest.try_as_col_major_mut().unwrap().as_slice_mut(),
294 )
295 }
296
297 fn array_mult_eigs(
298 &mut self,
299 stds: &Self::Vector,
300 rhs: &Self::Vector,
301 dest: &mut Self::Vector,
302 vecs: &Self::EigVectors,
303 vals: &Self::EigValues,
304 ) {
305 let rhs = stds.as_diagonal() * rhs;
306 let trafo = vecs.transpose() * (&rhs);
307 let inner_prod = vecs * (vals.as_diagonal() * (&trafo) - (&trafo)) + rhs;
308 let scaled = stds.as_diagonal() * inner_prod;
309
310 let _ = replace(dest, scaled);
311 }
312
313 fn array_vector_dot(&mut self, array1: &Self::Vector, array2: &Self::Vector) -> f64 {
314 vector_dot(
315 self.arch,
316 array1.try_as_col_major().unwrap().as_slice(),
317 array2.try_as_col_major().unwrap().as_slice(),
318 )
319 }
320
321 fn array_gaussian<R: rand::Rng + ?Sized>(
322 &mut self,
323 rng: &mut R,
324 dest: &mut Self::Vector,
325 stds: &Self::Vector,
326 ) {
327 let dist = rand_distr::StandardNormal;
328 dest.try_as_col_major_mut()
329 .unwrap()
330 .as_slice_mut()
331 .iter_mut()
332 .zip(stds.try_as_col_major().unwrap().as_slice().iter())
333 .for_each(|(p, &s)| {
334 let norm: f64 = rng.sample(dist);
335 *p = s * norm;
336 });
337 }
338
339 fn array_gaussian_eigs<R: rand::Rng + ?Sized>(
340 &mut self,
341 rng: &mut R,
342 dest: &mut Self::Vector,
343 scale: &Self::Vector,
344 vals: &Self::EigValues,
345 vecs: &Self::EigVectors,
346 ) {
347 let mut draw: Col<f64> = Col::zeros(self.dim());
348 let dist = rand_distr::StandardNormal;
349 draw.try_as_col_major_mut()
350 .unwrap()
351 .as_slice_mut()
352 .iter_mut()
353 .for_each(|p| {
354 *p = rng.sample(dist);
355 });
356
357 let trafo = vecs.transpose() * (&draw);
358 let inner_prod = vecs * (vals.as_diagonal() * (&trafo) - (&trafo)) + draw;
359
360 let scaled = scale.as_diagonal() * inner_prod;
361
362 let _ = replace(dest, scaled);
363 }
364
365 fn array_update_variance(
366 &mut self,
367 mean: &mut Self::Vector,
368 variance: &mut Self::Vector,
369 value: &Self::Vector,
370 diff_scale: f64, ) {
372 self.arch.dispatch(|| {
373 izip!(
374 mean.try_as_col_major_mut()
375 .unwrap()
376 .as_slice_mut()
377 .iter_mut(),
378 variance
379 .try_as_col_major_mut()
380 .unwrap()
381 .as_slice_mut()
382 .iter_mut(),
383 value.try_as_col_major().unwrap().as_slice()
384 )
385 .for_each(|(mean, var, x)| {
386 let diff = x - *mean;
387 *mean += diff * diff_scale;
388 *var += diff * diff;
389 });
390 })
391 }
392
393 fn array_update_var_inv_std_draw(
394 &mut self,
395 variance_out: &mut Self::Vector,
396 inv_std: &mut Self::Vector,
397 draw_var: &Self::Vector,
398 scale: f64,
399 fill_invalid: Option<f64>,
400 clamp: (f64, f64),
401 ) {
402 self.arch.dispatch(|| {
403 izip!(
404 variance_out
405 .try_as_col_major_mut()
406 .unwrap()
407 .as_slice_mut()
408 .iter_mut(),
409 inv_std
410 .try_as_col_major_mut()
411 .unwrap()
412 .as_slice_mut()
413 .iter_mut(),
414 draw_var.try_as_col_major().unwrap().as_slice().iter(),
415 )
416 .for_each(|(var_out, inv_std_out, &draw_var)| {
417 let draw_var = draw_var * scale;
418 if (!draw_var.is_finite()) | (draw_var == 0f64) {
419 if let Some(fill_val) = fill_invalid {
420 *var_out = fill_val;
421 *inv_std_out = fill_val.recip().sqrt();
422 }
423 } else {
424 let val = draw_var.clamp(clamp.0, clamp.1);
425 *var_out = val;
426 *inv_std_out = val.recip().sqrt();
427 }
428 });
429 });
430 }
431
432 fn array_update_var_inv_std_draw_grad(
433 &mut self,
434 variance_out: &mut Self::Vector,
435 inv_std: &mut Self::Vector,
436 draw_var: &Self::Vector,
437 grad_var: &Self::Vector,
438 fill_invalid: Option<f64>,
439 clamp: (f64, f64),
440 ) {
441 self.arch.dispatch(|| {
442 izip!(
443 variance_out
444 .try_as_col_major_mut()
445 .unwrap()
446 .as_slice_mut()
447 .iter_mut(),
448 inv_std
449 .try_as_col_major_mut()
450 .unwrap()
451 .as_slice_mut()
452 .iter_mut(),
453 draw_var.try_as_col_major().unwrap().as_slice().iter(),
454 grad_var.try_as_col_major().unwrap().as_slice().iter(),
455 )
456 .for_each(|(var_out, inv_std_out, &draw_var, &grad_var)| {
457 let val = (draw_var / grad_var).sqrt();
458 if (!val.is_finite()) | (val == 0f64) {
459 if let Some(fill_val) = fill_invalid {
460 *var_out = fill_val;
461 *inv_std_out = fill_val.recip().sqrt();
462 }
463 } else {
464 let val = val.clamp(clamp.0, clamp.1);
465 *var_out = val;
466 *inv_std_out = val.recip().sqrt();
467 }
468 });
469 });
470 }
471
472 fn array_update_var_inv_std_grad(
473 &mut self,
474 variance_out: &mut Self::Vector,
475 inv_std: &mut Self::Vector,
476 gradient: &Self::Vector,
477 fill_invalid: f64,
478 clamp: (f64, f64),
479 ) {
480 self.arch.dispatch(|| {
481 izip!(
482 variance_out
483 .try_as_col_major_mut()
484 .unwrap()
485 .as_slice_mut()
486 .iter_mut(),
487 inv_std
488 .try_as_col_major_mut()
489 .unwrap()
490 .as_slice_mut()
491 .iter_mut(),
492 gradient.try_as_col_major().unwrap().as_slice().iter(),
493 )
494 .for_each(|(var_out, inv_std_out, &grad_var)| {
495 let val = grad_var.abs().clamp(clamp.0, clamp.1).recip();
496 let val = if val.is_finite() { val } else { fill_invalid };
497 *var_out = val;
498 *inv_std_out = val.recip().sqrt();
499 });
500 });
501 }
502
503 fn eigs_as_array(&mut self, source: &Self::EigValues) -> Box<[f64]> {
504 source
505 .try_as_col_major()
506 .unwrap()
507 .as_slice()
508 .to_vec()
509 .into()
510 }
511
512 fn inv_transform_normalize(
513 &mut self,
514 params: &Self::FlowParameters,
515 untransformed_position: &Self::Vector,
516 untransofrmed_gradient: &Self::Vector,
517 transformed_position: &mut Self::Vector,
518 transformed_gradient: &mut Self::Vector,
519 ) -> Result<f64, Self::LogpErr> {
520 self.logp_func.inv_transform_normalize(
521 params,
522 untransformed_position
523 .try_as_col_major()
524 .unwrap()
525 .as_slice(),
526 untransofrmed_gradient
527 .try_as_col_major()
528 .unwrap()
529 .as_slice(),
530 transformed_position
531 .try_as_col_major_mut()
532 .unwrap()
533 .as_slice_mut(),
534 transformed_gradient
535 .try_as_col_major_mut()
536 .unwrap()
537 .as_slice_mut(),
538 )
539 }
540
541 fn init_from_untransformed_position(
542 &mut self,
543 params: &Self::FlowParameters,
544 untransformed_position: &Self::Vector,
545 untransformed_gradient: &mut Self::Vector,
546 transformed_position: &mut Self::Vector,
547 transformed_gradient: &mut Self::Vector,
548 ) -> Result<(f64, f64), Self::LogpErr> {
549 self.logp_func.init_from_untransformed_position(
550 params,
551 untransformed_position
552 .try_as_col_major()
553 .unwrap()
554 .as_slice(),
555 untransformed_gradient
556 .try_as_col_major_mut()
557 .unwrap()
558 .as_slice_mut(),
559 transformed_position
560 .try_as_col_major_mut()
561 .unwrap()
562 .as_slice_mut(),
563 transformed_gradient
564 .try_as_col_major_mut()
565 .unwrap()
566 .as_slice_mut(),
567 )
568 }
569
570 fn init_from_transformed_position(
571 &mut self,
572 params: &Self::FlowParameters,
573 untransformed_position: &mut Self::Vector,
574 untransformed_gradient: &mut Self::Vector,
575 transformed_position: &Self::Vector,
576 transformed_gradient: &mut Self::Vector,
577 ) -> Result<(f64, f64), Self::LogpErr> {
578 self.logp_func.init_from_transformed_position(
579 params,
580 untransformed_position
581 .try_as_col_major_mut()
582 .unwrap()
583 .as_slice_mut(),
584 untransformed_gradient
585 .try_as_col_major_mut()
586 .unwrap()
587 .as_slice_mut(),
588 transformed_position.try_as_col_major().unwrap().as_slice(),
589 transformed_gradient
590 .try_as_col_major_mut()
591 .unwrap()
592 .as_slice_mut(),
593 )
594 }
595
596 fn update_transformation<'a, R: rand::Rng + ?Sized>(
597 &'a mut self,
598 rng: &mut R,
599 untransformed_positions: impl ExactSizeIterator<Item = &'a Self::Vector>,
600 untransformed_gradients: impl ExactSizeIterator<Item = &'a Self::Vector>,
601 untransformed_logp: impl ExactSizeIterator<Item = &'a f64>,
602 params: &'a mut Self::FlowParameters,
603 ) -> Result<(), Self::LogpErr> {
604 self.logp_func.update_transformation(
605 rng,
606 untransformed_positions.map(|x| x.try_as_col_major().unwrap().as_slice()),
607 untransformed_gradients.map(|x| x.try_as_col_major().unwrap().as_slice()),
608 untransformed_logp,
609 params,
610 )
611 }
612
613 fn new_transformation<R: rand::Rng + ?Sized>(
614 &mut self,
615 rng: &mut R,
616 untransformed_position: &Self::Vector,
617 untransfogmed_gradient: &Self::Vector,
618 chain: u64,
619 ) -> Result<Self::FlowParameters, Self::LogpErr> {
620 self.logp_func.new_transformation(
621 rng,
622 untransformed_position
623 .try_as_col_major()
624 .unwrap()
625 .as_slice(),
626 untransfogmed_gradient
627 .try_as_col_major()
628 .unwrap()
629 .as_slice(),
630 chain,
631 )
632 }
633
634 fn transformation_id(&self, params: &Self::FlowParameters) -> Result<i64, Self::LogpErr> {
635 self.logp_func.transformation_id(params)
636 }
637}
638
639pub trait CpuLogpFunc: HasDims {
640 type LogpError: Debug + Send + Sync + Error + LogpError + 'static;
641 type FlowParameters;
642 type ExpandedVector: Storable<Self>;
643
644 fn dim(&self) -> usize;
645 fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, Self::LogpError>;
646 fn expand_vector<R>(
647 &mut self,
648 rng: &mut R,
649 array: &[f64],
650 ) -> Result<Self::ExpandedVector, CpuMathError>
651 where
652 R: rand::Rng + ?Sized;
653
654 fn vector_coord(&self) -> Option<Value> {
655 None
656 }
657
658 fn inv_transform_normalize(
659 &mut self,
660 _params: &Self::FlowParameters,
661 _untransformed_position: &[f64],
662 _untransformed_gradient: &[f64],
663 _transformed_position: &mut [f64],
664 _transformed_gradient: &mut [f64],
665 ) -> Result<f64, Self::LogpError> {
666 unimplemented!()
667 }
668
669 fn init_from_untransformed_position(
670 &mut self,
671 _params: &Self::FlowParameters,
672 _untransformed_position: &[f64],
673 _untransformed_gradient: &mut [f64],
674 _transformed_position: &mut [f64],
675 _transformed_gradient: &mut [f64],
676 ) -> Result<(f64, f64), Self::LogpError> {
677 unimplemented!()
678 }
679
680 fn init_from_transformed_position(
681 &mut self,
682 _params: &Self::FlowParameters,
683 _untransformed_position: &mut [f64],
684 _untransformed_gradient: &mut [f64],
685 _transformed_position: &[f64],
686 _transformed_gradient: &mut [f64],
687 ) -> Result<(f64, f64), Self::LogpError> {
688 unimplemented!()
689 }
690
691 fn update_transformation<'a, R: rand::Rng + ?Sized>(
692 &'a mut self,
693 _rng: &mut R,
694 _untransformed_positions: impl ExactSizeIterator<Item = &'a [f64]>,
695 _untransformed_gradients: impl ExactSizeIterator<Item = &'a [f64]>,
696 _untransformed_logp: impl ExactSizeIterator<Item = &'a f64>,
697 _params: &'a mut Self::FlowParameters,
698 ) -> Result<(), Self::LogpError> {
699 unimplemented!()
700 }
701
702 fn new_transformation<R: rand::Rng + ?Sized>(
703 &mut self,
704 _rng: &mut R,
705 _untransformed_position: &[f64],
706 _untransformed_gradient: &[f64],
707 _chain: u64,
708 ) -> Result<Self::FlowParameters, Self::LogpError> {
709 unimplemented!()
710 }
711
712 fn transformation_id(&self, _params: &Self::FlowParameters) -> Result<i64, Self::LogpError> {
713 unimplemented!()
714 }
715}
716
717impl<M: CpuLogpFunc + Clone> Clone for CpuMath<M> {
718 fn clone(&self) -> Self {
719 Self {
720 logp_func: self.logp_func.clone(),
721 arch: self.arch,
722 }
723 }
724}