nuts_rs/
cpu_math.rs

1use std::{collections::HashMap, error::Error, fmt::Debug, mem::replace};
2
3use faer::{Col, Mat};
4use itertools::{Itertools, izip};
5use nuts_storable::{HasDims, Storable, Value};
6use thiserror::Error;
7
8use crate::{
9    math::{axpy, axpy_out, multiply, scalar_prods2, scalar_prods3, vector_dot},
10    math_base::{LogpError, Math},
11};
12
13#[derive(Debug)]
14pub struct CpuMath<F: CpuLogpFunc> {
15    logp_func: F,
16    arch: pulp::Arch,
17}
18
19impl<F: CpuLogpFunc> CpuMath<F> {
20    pub fn new(logp_func: F) -> Self {
21        let arch = pulp::Arch::new();
22        Self { logp_func, arch }
23    }
24}
25
26#[non_exhaustive]
27#[derive(Error, Debug)]
28pub enum CpuMathError {
29    #[error("Error during array operation")]
30    ArrayError(),
31    #[error("Error during point expansion: {0}")]
32    ExpandError(String),
33}
34
35impl<F: CpuLogpFunc> HasDims for CpuMath<F> {
36    fn dim_sizes(&self) -> HashMap<String, u64> {
37        self.logp_func.dim_sizes()
38    }
39
40    fn coords(&self) -> HashMap<String, nuts_storable::Value> {
41        self.logp_func.coords()
42    }
43}
44
45pub struct ExpandedVectorWrapper<F: CpuLogpFunc>(F::ExpandedVector);
46
47impl<F: CpuLogpFunc> Storable<CpuMath<F>> for ExpandedVectorWrapper<F> {
48    fn names(parent: &CpuMath<F>) -> Vec<&str> {
49        F::ExpandedVector::names(&parent.logp_func)
50    }
51
52    fn item_type(parent: &CpuMath<F>, item: &str) -> nuts_storable::ItemType {
53        F::ExpandedVector::item_type(&parent.logp_func, item)
54    }
55
56    fn dims<'a>(parent: &'a CpuMath<F>, item: &str) -> Vec<&'a str> {
57        F::ExpandedVector::dims(&parent.logp_func, item)
58    }
59
60    fn get_all<'a>(
61        &'a mut self,
62        parent: &'a CpuMath<F>,
63    ) -> Vec<(&'a str, Option<nuts_storable::Value>)> {
64        self.0.get_all(&parent.logp_func)
65    }
66}
67
68impl<F: CpuLogpFunc> Math for CpuMath<F> {
69    type Vector = Col<f64>;
70    type EigVectors = Mat<f64>;
71    type EigValues = Col<f64>;
72    type LogpErr = F::LogpError;
73    type Err = CpuMathError;
74    type FlowParameters = F::FlowParameters;
75    type ExpandedVector = ExpandedVectorWrapper<F>;
76
77    fn new_array(&mut self) -> Self::Vector {
78        Col::zeros(self.dim())
79    }
80
81    fn new_eig_vectors<'a>(
82        &'a mut self,
83        vals: impl ExactSizeIterator<Item = &'a [f64]>,
84    ) -> Self::EigVectors {
85        let ndim = self.dim();
86        let nvecs = vals.len();
87
88        let mut vectors: Mat<f64> = Mat::zeros(ndim, nvecs);
89        vectors.col_iter_mut().zip_eq(vals).for_each(|(col, vals)| {
90            col.try_as_col_major_mut()
91                .expect("Array is not contiguous")
92                .as_slice_mut()
93                .copy_from_slice(vals)
94        });
95
96        vectors
97    }
98
99    fn new_eig_values(&mut self, vals: &[f64]) -> Self::EigValues {
100        let mut values: Col<f64> = Col::zeros(vals.len());
101        values
102            .try_as_col_major_mut()
103            .expect("Array is not contiguous")
104            .as_slice_mut()
105            .copy_from_slice(vals);
106        values
107    }
108
109    fn logp_array(
110        &mut self,
111        position: &Self::Vector,
112        gradient: &mut Self::Vector,
113    ) -> Result<f64, Self::LogpErr> {
114        self.logp_func.logp(
115            position
116                .try_as_col_major()
117                .expect("Array is not contiguous")
118                .as_slice(),
119            gradient
120                .try_as_col_major_mut()
121                .expect("Array is not contiguous")
122                .as_slice_mut(),
123        )
124    }
125
126    fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, Self::LogpErr> {
127        self.logp_func.logp(position, gradient)
128    }
129
130    fn dim(&self) -> usize {
131        self.logp_func.dim()
132    }
133
134    fn expand_vector<R: rand::Rng + ?Sized>(
135        &mut self,
136        rng: &mut R,
137        array: &Self::Vector,
138    ) -> Result<Self::ExpandedVector, Self::Err> {
139        Ok(ExpandedVectorWrapper(
140            self.logp_func.expand_vector(
141                rng,
142                array
143                    .try_as_col_major()
144                    .ok_or_else(|| {
145                        CpuMathError::ExpandError("Internal vector was not col major".into())
146                    })?
147                    .as_slice(),
148            )?,
149        ))
150    }
151
152    fn vector_coord(&self) -> Option<Value> {
153        self.logp_func.vector_coord()
154    }
155
156    fn init_position<R: rand::Rng + ?Sized>(
157        &mut self,
158        rng: &mut R,
159        position: &mut Self::Vector,
160        gradient: &mut Self::Vector,
161    ) -> Result<f64, Self::LogpErr> {
162        let pos = position
163            .try_as_col_major_mut()
164            .expect("Array is not contiguous")
165            .as_slice_mut();
166
167        pos.iter_mut().for_each(|x| {
168            let val: f64 = rng.random();
169            *x = val * 2f64 - 1f64
170        });
171
172        self.logp_func.logp(
173            position
174                .try_as_col_major()
175                .expect("Array is not contiguous")
176                .as_slice(),
177            gradient
178                .try_as_col_major_mut()
179                .expect("Array is not contiguous")
180                .as_slice_mut(),
181        )
182    }
183
184    fn scalar_prods3(
185        &mut self,
186        positive1: &Self::Vector,
187        negative1: &Self::Vector,
188        positive2: &Self::Vector,
189        x: &Self::Vector,
190        y: &Self::Vector,
191    ) -> (f64, f64) {
192        scalar_prods3(
193            self.arch,
194            positive1.try_as_col_major().unwrap().as_slice(),
195            negative1.try_as_col_major().unwrap().as_slice(),
196            positive2.try_as_col_major().unwrap().as_slice(),
197            x.try_as_col_major().unwrap().as_slice(),
198            y.try_as_col_major().unwrap().as_slice(),
199        )
200    }
201
202    fn scalar_prods2(
203        &mut self,
204        positive1: &Self::Vector,
205        positive2: &Self::Vector,
206        x: &Self::Vector,
207        y: &Self::Vector,
208    ) -> (f64, f64) {
209        scalar_prods2(
210            self.arch,
211            positive1.try_as_col_major().unwrap().as_slice(),
212            positive2.try_as_col_major().unwrap().as_slice(),
213            x.try_as_col_major().unwrap().as_slice(),
214            y.try_as_col_major().unwrap().as_slice(),
215        )
216    }
217
218    fn sq_norm_sum(&mut self, x: &Self::Vector, y: &Self::Vector) -> f64 {
219        x.try_as_col_major()
220            .unwrap()
221            .as_slice()
222            .iter()
223            .zip(y.try_as_col_major().unwrap().as_slice())
224            .map(|(&x, &y)| (x + y) * (x + y))
225            .sum()
226    }
227
228    fn read_from_slice(&mut self, dest: &mut Self::Vector, source: &[f64]) {
229        dest.try_as_col_major_mut()
230            .unwrap()
231            .as_slice_mut()
232            .copy_from_slice(source);
233    }
234
235    fn write_to_slice(&mut self, source: &Self::Vector, dest: &mut [f64]) {
236        dest.copy_from_slice(source.try_as_col_major().unwrap().as_slice())
237    }
238
239    fn copy_into(&mut self, array: &Self::Vector, dest: &mut Self::Vector) {
240        dest.clone_from(array)
241    }
242
243    fn axpy_out(&mut self, x: &Self::Vector, y: &Self::Vector, a: f64, out: &mut Self::Vector) {
244        axpy_out(
245            self.arch,
246            x.try_as_col_major().unwrap().as_slice(),
247            y.try_as_col_major().unwrap().as_slice(),
248            a,
249            out.try_as_col_major_mut().unwrap().as_slice_mut(),
250        );
251    }
252
253    fn axpy(&mut self, x: &Self::Vector, y: &mut Self::Vector, a: f64) {
254        axpy(
255            self.arch,
256            x.try_as_col_major().unwrap().as_slice(),
257            y.try_as_col_major_mut().unwrap().as_slice_mut(),
258            a,
259        );
260    }
261
262    fn fill_array(&mut self, array: &mut Self::Vector, val: f64) {
263        faer::zip!(array).for_each(|faer::unzip!(pos)| *pos = val);
264    }
265
266    fn array_all_finite(&mut self, array: &Self::Vector) -> bool {
267        let mut ok = true;
268        faer::zip!(array).for_each(|faer::unzip!(val)| ok &= val.is_finite());
269        ok
270    }
271
272    fn array_all_finite_and_nonzero(&mut self, array: &Self::Vector) -> bool {
273        self.arch.dispatch(|| {
274            array
275                .try_as_col_major()
276                .unwrap()
277                .as_slice()
278                .iter()
279                .all(|&x| x.is_finite() & (x != 0f64))
280        })
281    }
282
283    fn array_mult(
284        &mut self,
285        array1: &Self::Vector,
286        array2: &Self::Vector,
287        dest: &mut Self::Vector,
288    ) {
289        multiply(
290            self.arch,
291            array1.try_as_col_major().unwrap().as_slice(),
292            array2.try_as_col_major().unwrap().as_slice(),
293            dest.try_as_col_major_mut().unwrap().as_slice_mut(),
294        )
295    }
296
297    fn array_mult_eigs(
298        &mut self,
299        stds: &Self::Vector,
300        rhs: &Self::Vector,
301        dest: &mut Self::Vector,
302        vecs: &Self::EigVectors,
303        vals: &Self::EigValues,
304    ) {
305        let rhs = stds.as_diagonal() * rhs;
306        let trafo = vecs.transpose() * (&rhs);
307        let inner_prod = vecs * (vals.as_diagonal() * (&trafo) - (&trafo)) + rhs;
308        let scaled = stds.as_diagonal() * inner_prod;
309
310        let _ = replace(dest, scaled);
311    }
312
313    fn array_vector_dot(&mut self, array1: &Self::Vector, array2: &Self::Vector) -> f64 {
314        vector_dot(
315            self.arch,
316            array1.try_as_col_major().unwrap().as_slice(),
317            array2.try_as_col_major().unwrap().as_slice(),
318        )
319    }
320
321    fn array_gaussian<R: rand::Rng + ?Sized>(
322        &mut self,
323        rng: &mut R,
324        dest: &mut Self::Vector,
325        stds: &Self::Vector,
326    ) {
327        let dist = rand_distr::StandardNormal;
328        dest.try_as_col_major_mut()
329            .unwrap()
330            .as_slice_mut()
331            .iter_mut()
332            .zip(stds.try_as_col_major().unwrap().as_slice().iter())
333            .for_each(|(p, &s)| {
334                let norm: f64 = rng.sample(dist);
335                *p = s * norm;
336            });
337    }
338
339    fn array_gaussian_eigs<R: rand::Rng + ?Sized>(
340        &mut self,
341        rng: &mut R,
342        dest: &mut Self::Vector,
343        scale: &Self::Vector,
344        vals: &Self::EigValues,
345        vecs: &Self::EigVectors,
346    ) {
347        let mut draw: Col<f64> = Col::zeros(self.dim());
348        let dist = rand_distr::StandardNormal;
349        draw.try_as_col_major_mut()
350            .unwrap()
351            .as_slice_mut()
352            .iter_mut()
353            .for_each(|p| {
354                *p = rng.sample(dist);
355            });
356
357        let trafo = vecs.transpose() * (&draw);
358        let inner_prod = vecs * (vals.as_diagonal() * (&trafo) - (&trafo)) + draw;
359
360        let scaled = scale.as_diagonal() * inner_prod;
361
362        let _ = replace(dest, scaled);
363    }
364
365    fn array_update_variance(
366        &mut self,
367        mean: &mut Self::Vector,
368        variance: &mut Self::Vector,
369        value: &Self::Vector,
370        diff_scale: f64, // 1 / self.count
371    ) {
372        self.arch.dispatch(|| {
373            izip!(
374                mean.try_as_col_major_mut()
375                    .unwrap()
376                    .as_slice_mut()
377                    .iter_mut(),
378                variance
379                    .try_as_col_major_mut()
380                    .unwrap()
381                    .as_slice_mut()
382                    .iter_mut(),
383                value.try_as_col_major().unwrap().as_slice()
384            )
385            .for_each(|(mean, var, x)| {
386                let diff = x - *mean;
387                *mean += diff * diff_scale;
388                *var += diff * diff;
389            });
390        })
391    }
392
393    fn array_update_var_inv_std_draw(
394        &mut self,
395        variance_out: &mut Self::Vector,
396        inv_std: &mut Self::Vector,
397        draw_var: &Self::Vector,
398        scale: f64,
399        fill_invalid: Option<f64>,
400        clamp: (f64, f64),
401    ) {
402        self.arch.dispatch(|| {
403            izip!(
404                variance_out
405                    .try_as_col_major_mut()
406                    .unwrap()
407                    .as_slice_mut()
408                    .iter_mut(),
409                inv_std
410                    .try_as_col_major_mut()
411                    .unwrap()
412                    .as_slice_mut()
413                    .iter_mut(),
414                draw_var.try_as_col_major().unwrap().as_slice().iter(),
415            )
416            .for_each(|(var_out, inv_std_out, &draw_var)| {
417                let draw_var = draw_var * scale;
418                if (!draw_var.is_finite()) | (draw_var == 0f64) {
419                    if let Some(fill_val) = fill_invalid {
420                        *var_out = fill_val;
421                        *inv_std_out = fill_val.recip().sqrt();
422                    }
423                } else {
424                    let val = draw_var.clamp(clamp.0, clamp.1);
425                    *var_out = val;
426                    *inv_std_out = val.recip().sqrt();
427                }
428            });
429        });
430    }
431
432    fn array_update_var_inv_std_draw_grad(
433        &mut self,
434        variance_out: &mut Self::Vector,
435        inv_std: &mut Self::Vector,
436        draw_var: &Self::Vector,
437        grad_var: &Self::Vector,
438        fill_invalid: Option<f64>,
439        clamp: (f64, f64),
440    ) {
441        self.arch.dispatch(|| {
442            izip!(
443                variance_out
444                    .try_as_col_major_mut()
445                    .unwrap()
446                    .as_slice_mut()
447                    .iter_mut(),
448                inv_std
449                    .try_as_col_major_mut()
450                    .unwrap()
451                    .as_slice_mut()
452                    .iter_mut(),
453                draw_var.try_as_col_major().unwrap().as_slice().iter(),
454                grad_var.try_as_col_major().unwrap().as_slice().iter(),
455            )
456            .for_each(|(var_out, inv_std_out, &draw_var, &grad_var)| {
457                let val = (draw_var / grad_var).sqrt();
458                if (!val.is_finite()) | (val == 0f64) {
459                    if let Some(fill_val) = fill_invalid {
460                        *var_out = fill_val;
461                        *inv_std_out = fill_val.recip().sqrt();
462                    }
463                } else {
464                    let val = val.clamp(clamp.0, clamp.1);
465                    *var_out = val;
466                    *inv_std_out = val.recip().sqrt();
467                }
468            });
469        });
470    }
471
472    fn array_update_var_inv_std_grad(
473        &mut self,
474        variance_out: &mut Self::Vector,
475        inv_std: &mut Self::Vector,
476        gradient: &Self::Vector,
477        fill_invalid: f64,
478        clamp: (f64, f64),
479    ) {
480        self.arch.dispatch(|| {
481            izip!(
482                variance_out
483                    .try_as_col_major_mut()
484                    .unwrap()
485                    .as_slice_mut()
486                    .iter_mut(),
487                inv_std
488                    .try_as_col_major_mut()
489                    .unwrap()
490                    .as_slice_mut()
491                    .iter_mut(),
492                gradient.try_as_col_major().unwrap().as_slice().iter(),
493            )
494            .for_each(|(var_out, inv_std_out, &grad_var)| {
495                let val = grad_var.abs().clamp(clamp.0, clamp.1).recip();
496                let val = if val.is_finite() { val } else { fill_invalid };
497                *var_out = val;
498                *inv_std_out = val.recip().sqrt();
499            });
500        });
501    }
502
503    fn eigs_as_array(&mut self, source: &Self::EigValues) -> Box<[f64]> {
504        source
505            .try_as_col_major()
506            .unwrap()
507            .as_slice()
508            .to_vec()
509            .into()
510    }
511
512    fn inv_transform_normalize(
513        &mut self,
514        params: &Self::FlowParameters,
515        untransformed_position: &Self::Vector,
516        untransofrmed_gradient: &Self::Vector,
517        transformed_position: &mut Self::Vector,
518        transformed_gradient: &mut Self::Vector,
519    ) -> Result<f64, Self::LogpErr> {
520        self.logp_func.inv_transform_normalize(
521            params,
522            untransformed_position
523                .try_as_col_major()
524                .unwrap()
525                .as_slice(),
526            untransofrmed_gradient
527                .try_as_col_major()
528                .unwrap()
529                .as_slice(),
530            transformed_position
531                .try_as_col_major_mut()
532                .unwrap()
533                .as_slice_mut(),
534            transformed_gradient
535                .try_as_col_major_mut()
536                .unwrap()
537                .as_slice_mut(),
538        )
539    }
540
541    fn init_from_untransformed_position(
542        &mut self,
543        params: &Self::FlowParameters,
544        untransformed_position: &Self::Vector,
545        untransformed_gradient: &mut Self::Vector,
546        transformed_position: &mut Self::Vector,
547        transformed_gradient: &mut Self::Vector,
548    ) -> Result<(f64, f64), Self::LogpErr> {
549        self.logp_func.init_from_untransformed_position(
550            params,
551            untransformed_position
552                .try_as_col_major()
553                .unwrap()
554                .as_slice(),
555            untransformed_gradient
556                .try_as_col_major_mut()
557                .unwrap()
558                .as_slice_mut(),
559            transformed_position
560                .try_as_col_major_mut()
561                .unwrap()
562                .as_slice_mut(),
563            transformed_gradient
564                .try_as_col_major_mut()
565                .unwrap()
566                .as_slice_mut(),
567        )
568    }
569
570    fn init_from_transformed_position(
571        &mut self,
572        params: &Self::FlowParameters,
573        untransformed_position: &mut Self::Vector,
574        untransformed_gradient: &mut Self::Vector,
575        transformed_position: &Self::Vector,
576        transformed_gradient: &mut Self::Vector,
577    ) -> Result<(f64, f64), Self::LogpErr> {
578        self.logp_func.init_from_transformed_position(
579            params,
580            untransformed_position
581                .try_as_col_major_mut()
582                .unwrap()
583                .as_slice_mut(),
584            untransformed_gradient
585                .try_as_col_major_mut()
586                .unwrap()
587                .as_slice_mut(),
588            transformed_position.try_as_col_major().unwrap().as_slice(),
589            transformed_gradient
590                .try_as_col_major_mut()
591                .unwrap()
592                .as_slice_mut(),
593        )
594    }
595
596    fn update_transformation<'a, R: rand::Rng + ?Sized>(
597        &'a mut self,
598        rng: &mut R,
599        untransformed_positions: impl ExactSizeIterator<Item = &'a Self::Vector>,
600        untransformed_gradients: impl ExactSizeIterator<Item = &'a Self::Vector>,
601        untransformed_logp: impl ExactSizeIterator<Item = &'a f64>,
602        params: &'a mut Self::FlowParameters,
603    ) -> Result<(), Self::LogpErr> {
604        self.logp_func.update_transformation(
605            rng,
606            untransformed_positions.map(|x| x.try_as_col_major().unwrap().as_slice()),
607            untransformed_gradients.map(|x| x.try_as_col_major().unwrap().as_slice()),
608            untransformed_logp,
609            params,
610        )
611    }
612
613    fn new_transformation<R: rand::Rng + ?Sized>(
614        &mut self,
615        rng: &mut R,
616        untransformed_position: &Self::Vector,
617        untransfogmed_gradient: &Self::Vector,
618        chain: u64,
619    ) -> Result<Self::FlowParameters, Self::LogpErr> {
620        self.logp_func.new_transformation(
621            rng,
622            untransformed_position
623                .try_as_col_major()
624                .unwrap()
625                .as_slice(),
626            untransfogmed_gradient
627                .try_as_col_major()
628                .unwrap()
629                .as_slice(),
630            chain,
631        )
632    }
633
634    fn transformation_id(&self, params: &Self::FlowParameters) -> Result<i64, Self::LogpErr> {
635        self.logp_func.transformation_id(params)
636    }
637}
638
639pub trait CpuLogpFunc: HasDims {
640    type LogpError: Debug + Send + Sync + Error + LogpError + 'static;
641    type FlowParameters;
642    type ExpandedVector: Storable<Self>;
643
644    fn dim(&self) -> usize;
645    fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, Self::LogpError>;
646    fn expand_vector<R>(
647        &mut self,
648        rng: &mut R,
649        array: &[f64],
650    ) -> Result<Self::ExpandedVector, CpuMathError>
651    where
652        R: rand::Rng + ?Sized;
653
654    fn vector_coord(&self) -> Option<Value> {
655        None
656    }
657
658    fn inv_transform_normalize(
659        &mut self,
660        _params: &Self::FlowParameters,
661        _untransformed_position: &[f64],
662        _untransformed_gradient: &[f64],
663        _transformed_position: &mut [f64],
664        _transformed_gradient: &mut [f64],
665    ) -> Result<f64, Self::LogpError> {
666        unimplemented!()
667    }
668
669    fn init_from_untransformed_position(
670        &mut self,
671        _params: &Self::FlowParameters,
672        _untransformed_position: &[f64],
673        _untransformed_gradient: &mut [f64],
674        _transformed_position: &mut [f64],
675        _transformed_gradient: &mut [f64],
676    ) -> Result<(f64, f64), Self::LogpError> {
677        unimplemented!()
678    }
679
680    fn init_from_transformed_position(
681        &mut self,
682        _params: &Self::FlowParameters,
683        _untransformed_position: &mut [f64],
684        _untransformed_gradient: &mut [f64],
685        _transformed_position: &[f64],
686        _transformed_gradient: &mut [f64],
687    ) -> Result<(f64, f64), Self::LogpError> {
688        unimplemented!()
689    }
690
691    fn update_transformation<'a, R: rand::Rng + ?Sized>(
692        &'a mut self,
693        _rng: &mut R,
694        _untransformed_positions: impl ExactSizeIterator<Item = &'a [f64]>,
695        _untransformed_gradients: impl ExactSizeIterator<Item = &'a [f64]>,
696        _untransformed_logp: impl ExactSizeIterator<Item = &'a f64>,
697        _params: &'a mut Self::FlowParameters,
698    ) -> Result<(), Self::LogpError> {
699        unimplemented!()
700    }
701
702    fn new_transformation<R: rand::Rng + ?Sized>(
703        &mut self,
704        _rng: &mut R,
705        _untransformed_position: &[f64],
706        _untransformed_gradient: &[f64],
707        _chain: u64,
708    ) -> Result<Self::FlowParameters, Self::LogpError> {
709        unimplemented!()
710    }
711
712    fn transformation_id(&self, _params: &Self::FlowParameters) -> Result<i64, Self::LogpError> {
713        unimplemented!()
714    }
715}
716
717impl<M: CpuLogpFunc + Clone> Clone for CpuMath<M> {
718    fn clone(&self) -> Self {
719        Self {
720            logp_func: self.logp_func.clone(),
721            arch: self.arch,
722        }
723    }
724}