Skip to main content

nuts_rs/
cpu_math.rs

1use std::{collections::HashMap, error::Error, fmt::Debug, mem::replace};
2
3use faer::{Col, Mat};
4use itertools::{Itertools, izip};
5use nuts_storable::{HasDims, Storable, Value};
6use rand::RngExt;
7use thiserror::Error;
8
9use crate::{
10    math::{axpy, axpy_out, multiply, scalar_prods2, scalar_prods3, vector_dot},
11    math_base::{LogpError, Math},
12};
13
14#[derive(Debug)]
15pub struct CpuMath<F: CpuLogpFunc> {
16    logp_func: F,
17    arch: pulp::Arch,
18}
19
20impl<F: CpuLogpFunc> CpuMath<F> {
21    pub fn new(logp_func: F) -> Self {
22        let arch = pulp::Arch::new();
23        Self { logp_func, arch }
24    }
25}
26
27#[non_exhaustive]
28#[derive(Error, Debug)]
29pub enum CpuMathError {
30    #[error("Error during array operation")]
31    ArrayError(),
32    #[error("Error during point expansion: {0}")]
33    ExpandError(String),
34}
35
36impl<F: CpuLogpFunc> HasDims for CpuMath<F> {
37    fn dim_sizes(&self) -> HashMap<String, u64> {
38        self.logp_func.dim_sizes()
39    }
40
41    fn coords(&self) -> HashMap<String, nuts_storable::Value> {
42        self.logp_func.coords()
43    }
44}
45
46pub struct ExpandedVectorWrapper<F: CpuLogpFunc>(F::ExpandedVector);
47
48impl<F: CpuLogpFunc> Storable<CpuMath<F>> for ExpandedVectorWrapper<F> {
49    fn names(parent: &CpuMath<F>) -> Vec<&str> {
50        F::ExpandedVector::names(&parent.logp_func)
51    }
52
53    fn item_type(parent: &CpuMath<F>, item: &str) -> nuts_storable::ItemType {
54        F::ExpandedVector::item_type(&parent.logp_func, item)
55    }
56
57    fn dims<'a>(parent: &'a CpuMath<F>, item: &str) -> Vec<&'a str> {
58        F::ExpandedVector::dims(&parent.logp_func, item)
59    }
60
61    fn get_all<'a>(
62        &'a mut self,
63        parent: &'a CpuMath<F>,
64    ) -> Vec<(&'a str, Option<nuts_storable::Value>)> {
65        self.0.get_all(&parent.logp_func)
66    }
67}
68
69impl<F: CpuLogpFunc> Math for CpuMath<F> {
70    type Vector = Col<f64>;
71    type EigVectors = Mat<f64>;
72    type EigValues = Col<f64>;
73    type LogpErr = F::LogpError;
74    type Err = CpuMathError;
75    type FlowParameters = F::FlowParameters;
76    type ExpandedVector = ExpandedVectorWrapper<F>;
77
78    fn new_array(&mut self) -> Self::Vector {
79        Col::zeros(self.dim())
80    }
81
82    fn new_eig_vectors<'a>(
83        &'a mut self,
84        vals: impl ExactSizeIterator<Item = &'a [f64]>,
85    ) -> Self::EigVectors {
86        let ndim = self.dim();
87        let nvecs = vals.len();
88
89        let mut vectors: Mat<f64> = Mat::zeros(ndim, nvecs);
90        vectors.col_iter_mut().zip_eq(vals).for_each(|(col, vals)| {
91            col.try_as_col_major_mut()
92                .expect("Array is not contiguous")
93                .as_slice_mut()
94                .copy_from_slice(vals)
95        });
96
97        vectors
98    }
99
100    fn new_eig_values(&mut self, vals: &[f64]) -> Self::EigValues {
101        let mut values: Col<f64> = Col::zeros(vals.len());
102        values
103            .try_as_col_major_mut()
104            .expect("Array is not contiguous")
105            .as_slice_mut()
106            .copy_from_slice(vals);
107        values
108    }
109
110    fn logp_array(
111        &mut self,
112        position: &Self::Vector,
113        gradient: &mut Self::Vector,
114    ) -> Result<f64, Self::LogpErr> {
115        self.logp_func.logp(
116            position
117                .try_as_col_major()
118                .expect("Array is not contiguous")
119                .as_slice(),
120            gradient
121                .try_as_col_major_mut()
122                .expect("Array is not contiguous")
123                .as_slice_mut(),
124        )
125    }
126
127    fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, Self::LogpErr> {
128        self.logp_func.logp(position, gradient)
129    }
130
131    fn dim(&self) -> usize {
132        self.logp_func.dim()
133    }
134
135    fn expand_vector<R: rand::Rng + ?Sized>(
136        &mut self,
137        rng: &mut R,
138        array: &Self::Vector,
139    ) -> Result<Self::ExpandedVector, Self::Err> {
140        Ok(ExpandedVectorWrapper(
141            self.logp_func.expand_vector(
142                rng,
143                array
144                    .try_as_col_major()
145                    .ok_or_else(|| {
146                        CpuMathError::ExpandError("Internal vector was not col major".into())
147                    })?
148                    .as_slice(),
149            )?,
150        ))
151    }
152
153    fn vector_coord(&self) -> Option<Value> {
154        self.logp_func.vector_coord()
155    }
156
157    fn init_position<R: rand::Rng + ?Sized>(
158        &mut self,
159        rng: &mut R,
160        position: &mut Self::Vector,
161        gradient: &mut Self::Vector,
162    ) -> Result<f64, Self::LogpErr> {
163        let pos = position
164            .try_as_col_major_mut()
165            .expect("Array is not contiguous")
166            .as_slice_mut();
167
168        pos.iter_mut().for_each(|x| {
169            let val: f64 = rng.random();
170            *x = val * 2f64 - 1f64
171        });
172
173        self.logp_func.logp(
174            position
175                .try_as_col_major()
176                .expect("Array is not contiguous")
177                .as_slice(),
178            gradient
179                .try_as_col_major_mut()
180                .expect("Array is not contiguous")
181                .as_slice_mut(),
182        )
183    }
184
185    fn scalar_prods3(
186        &mut self,
187        positive1: &Self::Vector,
188        negative1: &Self::Vector,
189        positive2: &Self::Vector,
190        x: &Self::Vector,
191        y: &Self::Vector,
192    ) -> (f64, f64) {
193        scalar_prods3(
194            self.arch,
195            positive1.try_as_col_major().unwrap().as_slice(),
196            negative1.try_as_col_major().unwrap().as_slice(),
197            positive2.try_as_col_major().unwrap().as_slice(),
198            x.try_as_col_major().unwrap().as_slice(),
199            y.try_as_col_major().unwrap().as_slice(),
200        )
201    }
202
203    fn scalar_prods2(
204        &mut self,
205        positive1: &Self::Vector,
206        positive2: &Self::Vector,
207        x: &Self::Vector,
208        y: &Self::Vector,
209    ) -> (f64, f64) {
210        scalar_prods2(
211            self.arch,
212            positive1.try_as_col_major().unwrap().as_slice(),
213            positive2.try_as_col_major().unwrap().as_slice(),
214            x.try_as_col_major().unwrap().as_slice(),
215            y.try_as_col_major().unwrap().as_slice(),
216        )
217    }
218
219    fn sq_norm_sum(&mut self, x: &Self::Vector, y: &Self::Vector) -> f64 {
220        x.try_as_col_major()
221            .unwrap()
222            .as_slice()
223            .iter()
224            .zip(y.try_as_col_major().unwrap().as_slice())
225            .map(|(&x, &y)| (x + y) * (x + y))
226            .sum()
227    }
228
229    fn read_from_slice(&mut self, dest: &mut Self::Vector, source: &[f64]) {
230        dest.try_as_col_major_mut()
231            .unwrap()
232            .as_slice_mut()
233            .copy_from_slice(source);
234    }
235
236    fn write_to_slice(&mut self, source: &Self::Vector, dest: &mut [f64]) {
237        dest.copy_from_slice(source.try_as_col_major().unwrap().as_slice())
238    }
239
240    fn copy_into(&mut self, array: &Self::Vector, dest: &mut Self::Vector) {
241        dest.clone_from(array)
242    }
243
244    fn axpy_out(&mut self, x: &Self::Vector, y: &Self::Vector, a: f64, out: &mut Self::Vector) {
245        axpy_out(
246            self.arch,
247            x.try_as_col_major().unwrap().as_slice(),
248            y.try_as_col_major().unwrap().as_slice(),
249            a,
250            out.try_as_col_major_mut().unwrap().as_slice_mut(),
251        );
252    }
253
254    fn axpy(&mut self, x: &Self::Vector, y: &mut Self::Vector, a: f64) {
255        axpy(
256            self.arch,
257            x.try_as_col_major().unwrap().as_slice(),
258            y.try_as_col_major_mut().unwrap().as_slice_mut(),
259            a,
260        );
261    }
262
263    fn fill_array(&mut self, array: &mut Self::Vector, val: f64) {
264        faer::zip!(array).for_each(|faer::unzip!(pos)| *pos = val);
265    }
266
267    fn array_all_finite(&mut self, array: &Self::Vector) -> bool {
268        let mut ok = true;
269        faer::zip!(array).for_each(|faer::unzip!(val)| ok &= val.is_finite());
270        ok
271    }
272
273    fn array_all_finite_and_nonzero(&mut self, array: &Self::Vector) -> bool {
274        self.arch.dispatch(|| {
275            array
276                .try_as_col_major()
277                .unwrap()
278                .as_slice()
279                .iter()
280                .all(|&x| x.is_finite() & (x != 0f64))
281        })
282    }
283
284    fn array_mult(
285        &mut self,
286        array1: &Self::Vector,
287        array2: &Self::Vector,
288        dest: &mut Self::Vector,
289    ) {
290        multiply(
291            self.arch,
292            array1.try_as_col_major().unwrap().as_slice(),
293            array2.try_as_col_major().unwrap().as_slice(),
294            dest.try_as_col_major_mut().unwrap().as_slice_mut(),
295        )
296    }
297
298    fn array_mult_eigs(
299        &mut self,
300        stds: &Self::Vector,
301        rhs: &Self::Vector,
302        dest: &mut Self::Vector,
303        vecs: &Self::EigVectors,
304        vals: &Self::EigValues,
305    ) {
306        let rhs = stds.as_diagonal() * rhs;
307        let trafo = vecs.transpose() * (&rhs);
308        let inner_prod = vecs * (vals.as_diagonal() * (&trafo) - (&trafo)) + rhs;
309        let scaled = stds.as_diagonal() * inner_prod;
310
311        let _ = replace(dest, scaled);
312    }
313
314    fn array_vector_dot(&mut self, array1: &Self::Vector, array2: &Self::Vector) -> f64 {
315        vector_dot(
316            self.arch,
317            array1.try_as_col_major().unwrap().as_slice(),
318            array2.try_as_col_major().unwrap().as_slice(),
319        )
320    }
321
322    fn array_gaussian<R: rand::Rng + ?Sized>(
323        &mut self,
324        rng: &mut R,
325        dest: &mut Self::Vector,
326        stds: &Self::Vector,
327    ) {
328        let dist = rand_distr::StandardNormal;
329        dest.try_as_col_major_mut()
330            .unwrap()
331            .as_slice_mut()
332            .iter_mut()
333            .zip(stds.try_as_col_major().unwrap().as_slice().iter())
334            .for_each(|(p, &s)| {
335                let norm: f64 = rng.sample(dist);
336                *p = s * norm;
337            });
338    }
339
340    fn array_gaussian_eigs<R: rand::Rng + ?Sized>(
341        &mut self,
342        rng: &mut R,
343        dest: &mut Self::Vector,
344        scale: &Self::Vector,
345        vals: &Self::EigValues,
346        vecs: &Self::EigVectors,
347    ) {
348        let mut draw: Col<f64> = Col::zeros(self.dim());
349        let dist = rand_distr::StandardNormal;
350        draw.try_as_col_major_mut()
351            .unwrap()
352            .as_slice_mut()
353            .iter_mut()
354            .for_each(|p| {
355                *p = rng.sample(dist);
356            });
357
358        let trafo = vecs.transpose() * (&draw);
359        let inner_prod = vecs * (vals.as_diagonal() * (&trafo) - (&trafo)) + draw;
360
361        let scaled = scale.as_diagonal() * inner_prod;
362
363        let _ = replace(dest, scaled);
364    }
365
366    fn array_update_variance(
367        &mut self,
368        mean: &mut Self::Vector,
369        variance: &mut Self::Vector,
370        value: &Self::Vector,
371        diff_scale: f64, // 1 / self.count
372    ) {
373        self.arch.dispatch(|| {
374            izip!(
375                mean.try_as_col_major_mut()
376                    .unwrap()
377                    .as_slice_mut()
378                    .iter_mut(),
379                variance
380                    .try_as_col_major_mut()
381                    .unwrap()
382                    .as_slice_mut()
383                    .iter_mut(),
384                value.try_as_col_major().unwrap().as_slice()
385            )
386            .for_each(|(mean, var, x)| {
387                let diff = x - *mean;
388                *mean += diff * diff_scale;
389                *var += diff * diff;
390            });
391        })
392    }
393
394    fn array_update_var_inv_std_draw(
395        &mut self,
396        variance_out: &mut Self::Vector,
397        inv_std: &mut Self::Vector,
398        draw_var: &Self::Vector,
399        scale: f64,
400        fill_invalid: Option<f64>,
401        clamp: (f64, f64),
402    ) {
403        self.arch.dispatch(|| {
404            izip!(
405                variance_out
406                    .try_as_col_major_mut()
407                    .unwrap()
408                    .as_slice_mut()
409                    .iter_mut(),
410                inv_std
411                    .try_as_col_major_mut()
412                    .unwrap()
413                    .as_slice_mut()
414                    .iter_mut(),
415                draw_var.try_as_col_major().unwrap().as_slice().iter(),
416            )
417            .for_each(|(var_out, inv_std_out, &draw_var)| {
418                let draw_var = draw_var * scale;
419                if (!draw_var.is_finite()) | (draw_var == 0f64) {
420                    if let Some(fill_val) = fill_invalid {
421                        *var_out = fill_val;
422                        *inv_std_out = fill_val.recip().sqrt();
423                    }
424                } else {
425                    let val = draw_var.clamp(clamp.0, clamp.1);
426                    *var_out = val;
427                    *inv_std_out = val.recip().sqrt();
428                }
429            });
430        });
431    }
432
433    fn array_update_var_inv_std_draw_grad(
434        &mut self,
435        variance_out: &mut Self::Vector,
436        inv_std: &mut Self::Vector,
437        draw_var: &Self::Vector,
438        grad_var: &Self::Vector,
439        fill_invalid: Option<f64>,
440        clamp: (f64, f64),
441    ) {
442        self.arch.dispatch(|| {
443            izip!(
444                variance_out
445                    .try_as_col_major_mut()
446                    .unwrap()
447                    .as_slice_mut()
448                    .iter_mut(),
449                inv_std
450                    .try_as_col_major_mut()
451                    .unwrap()
452                    .as_slice_mut()
453                    .iter_mut(),
454                draw_var.try_as_col_major().unwrap().as_slice().iter(),
455                grad_var.try_as_col_major().unwrap().as_slice().iter(),
456            )
457            .for_each(|(var_out, inv_std_out, &draw_var, &grad_var)| {
458                let val = (draw_var / grad_var).sqrt();
459                if (!val.is_finite()) | (val == 0f64) {
460                    if let Some(fill_val) = fill_invalid {
461                        *var_out = fill_val;
462                        *inv_std_out = fill_val.recip().sqrt();
463                    }
464                } else {
465                    let val = val.clamp(clamp.0, clamp.1);
466                    *var_out = val;
467                    *inv_std_out = val.recip().sqrt();
468                }
469            });
470        });
471    }
472
473    fn array_update_var_inv_std_grad(
474        &mut self,
475        variance_out: &mut Self::Vector,
476        inv_std: &mut Self::Vector,
477        gradient: &Self::Vector,
478        fill_invalid: f64,
479        clamp: (f64, f64),
480    ) {
481        self.arch.dispatch(|| {
482            izip!(
483                variance_out
484                    .try_as_col_major_mut()
485                    .unwrap()
486                    .as_slice_mut()
487                    .iter_mut(),
488                inv_std
489                    .try_as_col_major_mut()
490                    .unwrap()
491                    .as_slice_mut()
492                    .iter_mut(),
493                gradient.try_as_col_major().unwrap().as_slice().iter(),
494            )
495            .for_each(|(var_out, inv_std_out, &grad_var)| {
496                let val = grad_var.abs().clamp(clamp.0, clamp.1).recip();
497                let val = if val.is_finite() { val } else { fill_invalid };
498                *var_out = val;
499                *inv_std_out = val.recip().sqrt();
500            });
501        });
502    }
503
504    fn eigs_as_array(&mut self, source: &Self::EigValues) -> Box<[f64]> {
505        source
506            .try_as_col_major()
507            .unwrap()
508            .as_slice()
509            .to_vec()
510            .into()
511    }
512
513    fn inv_transform_normalize(
514        &mut self,
515        params: &Self::FlowParameters,
516        untransformed_position: &Self::Vector,
517        untransofrmed_gradient: &Self::Vector,
518        transformed_position: &mut Self::Vector,
519        transformed_gradient: &mut Self::Vector,
520    ) -> Result<f64, Self::LogpErr> {
521        self.logp_func.inv_transform_normalize(
522            params,
523            untransformed_position
524                .try_as_col_major()
525                .unwrap()
526                .as_slice(),
527            untransofrmed_gradient
528                .try_as_col_major()
529                .unwrap()
530                .as_slice(),
531            transformed_position
532                .try_as_col_major_mut()
533                .unwrap()
534                .as_slice_mut(),
535            transformed_gradient
536                .try_as_col_major_mut()
537                .unwrap()
538                .as_slice_mut(),
539        )
540    }
541
542    fn init_from_untransformed_position(
543        &mut self,
544        params: &Self::FlowParameters,
545        untransformed_position: &Self::Vector,
546        untransformed_gradient: &mut Self::Vector,
547        transformed_position: &mut Self::Vector,
548        transformed_gradient: &mut Self::Vector,
549    ) -> Result<(f64, f64), Self::LogpErr> {
550        self.logp_func.init_from_untransformed_position(
551            params,
552            untransformed_position
553                .try_as_col_major()
554                .unwrap()
555                .as_slice(),
556            untransformed_gradient
557                .try_as_col_major_mut()
558                .unwrap()
559                .as_slice_mut(),
560            transformed_position
561                .try_as_col_major_mut()
562                .unwrap()
563                .as_slice_mut(),
564            transformed_gradient
565                .try_as_col_major_mut()
566                .unwrap()
567                .as_slice_mut(),
568        )
569    }
570
571    fn init_from_transformed_position(
572        &mut self,
573        params: &Self::FlowParameters,
574        untransformed_position: &mut Self::Vector,
575        untransformed_gradient: &mut Self::Vector,
576        transformed_position: &Self::Vector,
577        transformed_gradient: &mut Self::Vector,
578    ) -> Result<(f64, f64), Self::LogpErr> {
579        self.logp_func.init_from_transformed_position(
580            params,
581            untransformed_position
582                .try_as_col_major_mut()
583                .unwrap()
584                .as_slice_mut(),
585            untransformed_gradient
586                .try_as_col_major_mut()
587                .unwrap()
588                .as_slice_mut(),
589            transformed_position.try_as_col_major().unwrap().as_slice(),
590            transformed_gradient
591                .try_as_col_major_mut()
592                .unwrap()
593                .as_slice_mut(),
594        )
595    }
596
597    fn update_transformation<'a, R: rand::Rng + ?Sized>(
598        &'a mut self,
599        rng: &mut R,
600        untransformed_positions: impl ExactSizeIterator<Item = &'a Self::Vector>,
601        untransformed_gradients: impl ExactSizeIterator<Item = &'a Self::Vector>,
602        untransformed_logp: impl ExactSizeIterator<Item = &'a f64>,
603        params: &'a mut Self::FlowParameters,
604    ) -> Result<(), Self::LogpErr> {
605        self.logp_func.update_transformation(
606            rng,
607            untransformed_positions.map(|x| x.try_as_col_major().unwrap().as_slice()),
608            untransformed_gradients.map(|x| x.try_as_col_major().unwrap().as_slice()),
609            untransformed_logp,
610            params,
611        )
612    }
613
614    fn new_transformation<R: rand::Rng + ?Sized>(
615        &mut self,
616        rng: &mut R,
617        untransformed_position: &Self::Vector,
618        untransfogmed_gradient: &Self::Vector,
619        chain: u64,
620    ) -> Result<Self::FlowParameters, Self::LogpErr> {
621        self.logp_func.new_transformation(
622            rng,
623            untransformed_position
624                .try_as_col_major()
625                .unwrap()
626                .as_slice(),
627            untransfogmed_gradient
628                .try_as_col_major()
629                .unwrap()
630                .as_slice(),
631            chain,
632        )
633    }
634
635    fn transformation_id(&self, params: &Self::FlowParameters) -> Result<i64, Self::LogpErr> {
636        self.logp_func.transformation_id(params)
637    }
638}
639
640pub trait CpuLogpFunc: HasDims {
641    type LogpError: Debug + Send + Sync + Error + LogpError + 'static;
642    type FlowParameters;
643    type ExpandedVector: Storable<Self>;
644
645    fn dim(&self) -> usize;
646    fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, Self::LogpError>;
647    fn expand_vector<R>(
648        &mut self,
649        rng: &mut R,
650        array: &[f64],
651    ) -> Result<Self::ExpandedVector, CpuMathError>
652    where
653        R: rand::Rng + ?Sized;
654
655    fn vector_coord(&self) -> Option<Value> {
656        None
657    }
658
659    fn inv_transform_normalize(
660        &mut self,
661        _params: &Self::FlowParameters,
662        _untransformed_position: &[f64],
663        _untransformed_gradient: &[f64],
664        _transformed_position: &mut [f64],
665        _transformed_gradient: &mut [f64],
666    ) -> Result<f64, Self::LogpError> {
667        unimplemented!()
668    }
669
670    fn init_from_untransformed_position(
671        &mut self,
672        _params: &Self::FlowParameters,
673        _untransformed_position: &[f64],
674        _untransformed_gradient: &mut [f64],
675        _transformed_position: &mut [f64],
676        _transformed_gradient: &mut [f64],
677    ) -> Result<(f64, f64), Self::LogpError> {
678        unimplemented!()
679    }
680
681    fn init_from_transformed_position(
682        &mut self,
683        _params: &Self::FlowParameters,
684        _untransformed_position: &mut [f64],
685        _untransformed_gradient: &mut [f64],
686        _transformed_position: &[f64],
687        _transformed_gradient: &mut [f64],
688    ) -> Result<(f64, f64), Self::LogpError> {
689        unimplemented!()
690    }
691
692    fn update_transformation<'a, R: rand::Rng + ?Sized>(
693        &'a mut self,
694        _rng: &mut R,
695        _untransformed_positions: impl ExactSizeIterator<Item = &'a [f64]>,
696        _untransformed_gradients: impl ExactSizeIterator<Item = &'a [f64]>,
697        _untransformed_logp: impl ExactSizeIterator<Item = &'a f64>,
698        _params: &'a mut Self::FlowParameters,
699    ) -> Result<(), Self::LogpError> {
700        unimplemented!()
701    }
702
703    fn new_transformation<R: rand::Rng + ?Sized>(
704        &mut self,
705        _rng: &mut R,
706        _untransformed_position: &[f64],
707        _untransformed_gradient: &[f64],
708        _chain: u64,
709    ) -> Result<Self::FlowParameters, Self::LogpError> {
710        unimplemented!()
711    }
712
713    fn transformation_id(&self, _params: &Self::FlowParameters) -> Result<i64, Self::LogpError> {
714        unimplemented!()
715    }
716}
717
718impl<M: CpuLogpFunc + Clone> Clone for CpuMath<M> {
719    fn clone(&self) -> Self {
720        Self {
721            logp_func: self.logp_func.clone(),
722            arch: self.arch,
723        }
724    }
725}