nuts_rs/
cpu_math.rs

1use std::{error::Error, fmt::Debug, mem::replace};
2
3use faer::{Col, Mat};
4use itertools::{izip, Itertools};
5use thiserror::Error;
6
7use crate::{
8    math::{axpy, axpy_out, multiply, scalar_prods2, scalar_prods3, vector_dot},
9    math_base::{LogpError, Math},
10};
11
12#[derive(Debug)]
13pub struct CpuMath<F: CpuLogpFunc> {
14    logp_func: F,
15    arch: pulp::Arch,
16}
17
18impl<F: CpuLogpFunc> CpuMath<F> {
19    pub fn new(logp_func: F) -> Self {
20        let arch = pulp::Arch::new();
21        Self { logp_func, arch }
22    }
23}
24
25#[non_exhaustive]
26#[derive(Error, Debug)]
27pub enum CpuMathError {
28    #[error("Error during array operation")]
29    ArrayError(),
30}
31
32impl<F: CpuLogpFunc> Math for CpuMath<F> {
33    type Vector = Col<f64>;
34    type EigVectors = Mat<f64>;
35    type EigValues = Col<f64>;
36    type LogpErr = F::LogpError;
37    type Err = CpuMathError;
38    type TransformParams = F::TransformParams;
39
40    fn new_array(&mut self) -> Self::Vector {
41        Col::zeros(self.dim())
42    }
43
44    fn new_eig_vectors<'a>(
45        &'a mut self,
46        vals: impl ExactSizeIterator<Item = &'a [f64]>,
47    ) -> Self::EigVectors {
48        let ndim = self.dim();
49        let nvecs = vals.len();
50
51        let mut vectors: Mat<f64> = Mat::zeros(ndim, nvecs);
52        vectors.col_iter_mut().zip_eq(vals).for_each(|(col, vals)| {
53            col.try_as_col_major_mut()
54                .expect("Array is not contiguous")
55                .as_slice_mut()
56                .copy_from_slice(vals)
57        });
58
59        vectors
60    }
61
62    fn new_eig_values(&mut self, vals: &[f64]) -> Self::EigValues {
63        let mut values: Col<f64> = Col::zeros(vals.len());
64        values
65            .try_as_col_major_mut()
66            .expect("Array is not contiguous")
67            .as_slice_mut()
68            .copy_from_slice(vals);
69        values
70    }
71
72    fn logp_array(
73        &mut self,
74        position: &Self::Vector,
75        gradient: &mut Self::Vector,
76    ) -> Result<f64, Self::LogpErr> {
77        self.logp_func.logp(
78            position
79                .try_as_col_major()
80                .expect("Array is not contiguous")
81                .as_slice(),
82            gradient
83                .try_as_col_major_mut()
84                .expect("Array is not contiguous")
85                .as_slice_mut(),
86        )
87    }
88
89    fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, Self::LogpErr> {
90        self.logp_func.logp(position, gradient)
91    }
92
93    fn dim(&self) -> usize {
94        self.logp_func.dim()
95    }
96
97    fn scalar_prods3(
98        &mut self,
99        positive1: &Self::Vector,
100        negative1: &Self::Vector,
101        positive2: &Self::Vector,
102        x: &Self::Vector,
103        y: &Self::Vector,
104    ) -> (f64, f64) {
105        scalar_prods3(
106            self.arch,
107            positive1.try_as_col_major().unwrap().as_slice(),
108            negative1.try_as_col_major().unwrap().as_slice(),
109            positive2.try_as_col_major().unwrap().as_slice(),
110            x.try_as_col_major().unwrap().as_slice(),
111            y.try_as_col_major().unwrap().as_slice(),
112        )
113    }
114
115    fn scalar_prods2(
116        &mut self,
117        positive1: &Self::Vector,
118        positive2: &Self::Vector,
119        x: &Self::Vector,
120        y: &Self::Vector,
121    ) -> (f64, f64) {
122        scalar_prods2(
123            self.arch,
124            positive1.try_as_col_major().unwrap().as_slice(),
125            positive2.try_as_col_major().unwrap().as_slice(),
126            x.try_as_col_major().unwrap().as_slice(),
127            y.try_as_col_major().unwrap().as_slice(),
128        )
129    }
130
131    fn sq_norm_sum(&mut self, x: &Self::Vector, y: &Self::Vector) -> f64 {
132        x.try_as_col_major()
133            .unwrap()
134            .as_slice()
135            .iter()
136            .zip(y.try_as_col_major().unwrap().as_slice())
137            .map(|(&x, &y)| (x + y) * (x + y))
138            .sum()
139    }
140
141    fn read_from_slice(&mut self, dest: &mut Self::Vector, source: &[f64]) {
142        dest.try_as_col_major_mut()
143            .unwrap()
144            .as_slice_mut()
145            .copy_from_slice(source);
146    }
147
148    fn write_to_slice(&mut self, source: &Self::Vector, dest: &mut [f64]) {
149        dest.copy_from_slice(source.try_as_col_major().unwrap().as_slice())
150    }
151
152    fn copy_into(&mut self, array: &Self::Vector, dest: &mut Self::Vector) {
153        dest.clone_from(array)
154    }
155
156    fn axpy_out(&mut self, x: &Self::Vector, y: &Self::Vector, a: f64, out: &mut Self::Vector) {
157        axpy_out(
158            self.arch,
159            x.try_as_col_major().unwrap().as_slice(),
160            y.try_as_col_major().unwrap().as_slice(),
161            a,
162            out.try_as_col_major_mut().unwrap().as_slice_mut(),
163        );
164    }
165
166    fn axpy(&mut self, x: &Self::Vector, y: &mut Self::Vector, a: f64) {
167        axpy(
168            self.arch,
169            x.try_as_col_major().unwrap().as_slice(),
170            y.try_as_col_major_mut().unwrap().as_slice_mut(),
171            a,
172        );
173    }
174
175    fn fill_array(&mut self, array: &mut Self::Vector, val: f64) {
176        faer::zip!(array).for_each(|faer::unzip!(pos)| *pos = val);
177    }
178
179    fn array_all_finite(&mut self, array: &Self::Vector) -> bool {
180        let mut ok = true;
181        faer::zip!(array).for_each(|faer::unzip!(val)| ok &= val.is_finite());
182        ok
183    }
184
185    fn array_all_finite_and_nonzero(&mut self, array: &Self::Vector) -> bool {
186        self.arch.dispatch(|| {
187            array
188                .try_as_col_major()
189                .unwrap()
190                .as_slice()
191                .iter()
192                .all(|&x| x.is_finite() & (x != 0f64))
193        })
194    }
195
196    fn array_mult(
197        &mut self,
198        array1: &Self::Vector,
199        array2: &Self::Vector,
200        dest: &mut Self::Vector,
201    ) {
202        multiply(
203            self.arch,
204            array1.try_as_col_major().unwrap().as_slice(),
205            array2.try_as_col_major().unwrap().as_slice(),
206            dest.try_as_col_major_mut().unwrap().as_slice_mut(),
207        )
208    }
209
210    fn array_mult_eigs(
211        &mut self,
212        stds: &Self::Vector,
213        rhs: &Self::Vector,
214        dest: &mut Self::Vector,
215        vecs: &Self::EigVectors,
216        vals: &Self::EigValues,
217    ) {
218        let rhs = stds.as_diagonal() * rhs;
219        let trafo = vecs.transpose() * (&rhs);
220        let inner_prod = vecs * (vals.as_diagonal() * (&trafo) - (&trafo)) + rhs;
221        let scaled = stds.as_diagonal() * inner_prod;
222
223        let _ = replace(dest, scaled);
224    }
225
226    fn array_vector_dot(&mut self, array1: &Self::Vector, array2: &Self::Vector) -> f64 {
227        vector_dot(
228            self.arch,
229            array1.try_as_col_major().unwrap().as_slice(),
230            array2.try_as_col_major().unwrap().as_slice(),
231        )
232    }
233
234    fn array_gaussian<R: rand::Rng + ?Sized>(
235        &mut self,
236        rng: &mut R,
237        dest: &mut Self::Vector,
238        stds: &Self::Vector,
239    ) {
240        let dist = rand_distr::StandardNormal;
241        dest.try_as_col_major_mut()
242            .unwrap()
243            .as_slice_mut()
244            .iter_mut()
245            .zip(stds.try_as_col_major().unwrap().as_slice().iter())
246            .for_each(|(p, &s)| {
247                let norm: f64 = rng.sample(dist);
248                *p = s * norm;
249            });
250    }
251
252    fn array_gaussian_eigs<R: rand::Rng + ?Sized>(
253        &mut self,
254        rng: &mut R,
255        dest: &mut Self::Vector,
256        scale: &Self::Vector,
257        vals: &Self::EigValues,
258        vecs: &Self::EigVectors,
259    ) {
260        let mut draw: Col<f64> = Col::zeros(self.dim());
261        let dist = rand_distr::StandardNormal;
262        draw.try_as_col_major_mut()
263            .unwrap()
264            .as_slice_mut()
265            .iter_mut()
266            .for_each(|p| {
267                *p = rng.sample(dist);
268            });
269
270        let trafo = vecs.transpose() * (&draw);
271        let inner_prod = vecs * (vals.as_diagonal() * (&trafo) - (&trafo)) + draw;
272
273        let scaled = scale.as_diagonal() * inner_prod;
274
275        let _ = replace(dest, scaled);
276    }
277
278    fn array_update_variance(
279        &mut self,
280        mean: &mut Self::Vector,
281        variance: &mut Self::Vector,
282        value: &Self::Vector,
283        diff_scale: f64, // 1 / self.count
284    ) {
285        self.arch.dispatch(|| {
286            izip!(
287                mean.try_as_col_major_mut()
288                    .unwrap()
289                    .as_slice_mut()
290                    .iter_mut(),
291                variance
292                    .try_as_col_major_mut()
293                    .unwrap()
294                    .as_slice_mut()
295                    .iter_mut(),
296                value.try_as_col_major().unwrap().as_slice()
297            )
298            .for_each(|(mean, var, x)| {
299                let diff = x - *mean;
300                *mean += diff * diff_scale;
301                *var += diff * diff;
302            });
303        })
304    }
305
306    fn array_update_var_inv_std_draw(
307        &mut self,
308        variance_out: &mut Self::Vector,
309        inv_std: &mut Self::Vector,
310        draw_var: &Self::Vector,
311        scale: f64,
312        fill_invalid: Option<f64>,
313        clamp: (f64, f64),
314    ) {
315        self.arch.dispatch(|| {
316            izip!(
317                variance_out
318                    .try_as_col_major_mut()
319                    .unwrap()
320                    .as_slice_mut()
321                    .iter_mut(),
322                inv_std
323                    .try_as_col_major_mut()
324                    .unwrap()
325                    .as_slice_mut()
326                    .iter_mut(),
327                draw_var.try_as_col_major().unwrap().as_slice().iter(),
328            )
329            .for_each(|(var_out, inv_std_out, &draw_var)| {
330                let draw_var = draw_var * scale;
331                if (!draw_var.is_finite()) | (draw_var == 0f64) {
332                    if let Some(fill_val) = fill_invalid {
333                        *var_out = fill_val;
334                        *inv_std_out = fill_val.recip().sqrt();
335                    }
336                } else {
337                    let val = draw_var.clamp(clamp.0, clamp.1);
338                    *var_out = val;
339                    *inv_std_out = val.recip().sqrt();
340                }
341            });
342        });
343    }
344
345    fn array_update_var_inv_std_draw_grad(
346        &mut self,
347        variance_out: &mut Self::Vector,
348        inv_std: &mut Self::Vector,
349        draw_var: &Self::Vector,
350        grad_var: &Self::Vector,
351        fill_invalid: Option<f64>,
352        clamp: (f64, f64),
353    ) {
354        self.arch.dispatch(|| {
355            izip!(
356                variance_out
357                    .try_as_col_major_mut()
358                    .unwrap()
359                    .as_slice_mut()
360                    .iter_mut(),
361                inv_std
362                    .try_as_col_major_mut()
363                    .unwrap()
364                    .as_slice_mut()
365                    .iter_mut(),
366                draw_var.try_as_col_major().unwrap().as_slice().iter(),
367                grad_var.try_as_col_major().unwrap().as_slice().iter(),
368            )
369            .for_each(|(var_out, inv_std_out, &draw_var, &grad_var)| {
370                let val = (draw_var / grad_var).sqrt();
371                if (!val.is_finite()) | (val == 0f64) {
372                    if let Some(fill_val) = fill_invalid {
373                        *var_out = fill_val;
374                        *inv_std_out = fill_val.recip().sqrt();
375                    }
376                } else {
377                    let val = val.clamp(clamp.0, clamp.1);
378                    *var_out = val;
379                    *inv_std_out = val.recip().sqrt();
380                }
381            });
382        });
383    }
384
385    fn array_update_var_inv_std_grad(
386        &mut self,
387        variance_out: &mut Self::Vector,
388        inv_std: &mut Self::Vector,
389        gradient: &Self::Vector,
390        fill_invalid: f64,
391        clamp: (f64, f64),
392    ) {
393        self.arch.dispatch(|| {
394            izip!(
395                variance_out
396                    .try_as_col_major_mut()
397                    .unwrap()
398                    .as_slice_mut()
399                    .iter_mut(),
400                inv_std
401                    .try_as_col_major_mut()
402                    .unwrap()
403                    .as_slice_mut()
404                    .iter_mut(),
405                gradient.try_as_col_major().unwrap().as_slice().iter(),
406            )
407            .for_each(|(var_out, inv_std_out, &grad_var)| {
408                let val = grad_var.abs().clamp(clamp.0, clamp.1).recip();
409                let val = if val.is_finite() { val } else { fill_invalid };
410                *var_out = val;
411                *inv_std_out = val.recip().sqrt();
412            });
413        });
414    }
415
416    fn eigs_as_array(&mut self, source: &Self::EigValues) -> Box<[f64]> {
417        source
418            .try_as_col_major()
419            .unwrap()
420            .as_slice()
421            .to_vec()
422            .into()
423    }
424
425    fn inv_transform_normalize(
426        &mut self,
427        params: &Self::TransformParams,
428        untransformed_position: &Self::Vector,
429        untransofrmed_gradient: &Self::Vector,
430        transformed_position: &mut Self::Vector,
431        transformed_gradient: &mut Self::Vector,
432    ) -> Result<f64, Self::LogpErr> {
433        self.logp_func.inv_transform_normalize(
434            params,
435            untransformed_position
436                .try_as_col_major()
437                .unwrap()
438                .as_slice(),
439            untransofrmed_gradient
440                .try_as_col_major()
441                .unwrap()
442                .as_slice(),
443            transformed_position
444                .try_as_col_major_mut()
445                .unwrap()
446                .as_slice_mut(),
447            transformed_gradient
448                .try_as_col_major_mut()
449                .unwrap()
450                .as_slice_mut(),
451        )
452    }
453
454    fn init_from_untransformed_position(
455        &mut self,
456        params: &Self::TransformParams,
457        untransformed_position: &Self::Vector,
458        untransformed_gradient: &mut Self::Vector,
459        transformed_position: &mut Self::Vector,
460        transformed_gradient: &mut Self::Vector,
461    ) -> Result<(f64, f64), Self::LogpErr> {
462        self.logp_func.init_from_untransformed_position(
463            params,
464            untransformed_position
465                .try_as_col_major()
466                .unwrap()
467                .as_slice(),
468            untransformed_gradient
469                .try_as_col_major_mut()
470                .unwrap()
471                .as_slice_mut(),
472            transformed_position
473                .try_as_col_major_mut()
474                .unwrap()
475                .as_slice_mut(),
476            transformed_gradient
477                .try_as_col_major_mut()
478                .unwrap()
479                .as_slice_mut(),
480        )
481    }
482
483    fn init_from_transformed_position(
484        &mut self,
485        params: &Self::TransformParams,
486        untransformed_position: &mut Self::Vector,
487        untransformed_gradient: &mut Self::Vector,
488        transformed_position: &Self::Vector,
489        transformed_gradient: &mut Self::Vector,
490    ) -> Result<(f64, f64), Self::LogpErr> {
491        self.logp_func.init_from_transformed_position(
492            params,
493            untransformed_position
494                .try_as_col_major_mut()
495                .unwrap()
496                .as_slice_mut(),
497            untransformed_gradient
498                .try_as_col_major_mut()
499                .unwrap()
500                .as_slice_mut(),
501            transformed_position.try_as_col_major().unwrap().as_slice(),
502            transformed_gradient
503                .try_as_col_major_mut()
504                .unwrap()
505                .as_slice_mut(),
506        )
507    }
508
509    fn update_transformation<'a, R: rand::Rng + ?Sized>(
510        &'a mut self,
511        rng: &mut R,
512        untransformed_positions: impl ExactSizeIterator<Item = &'a Self::Vector>,
513        untransformed_gradients: impl ExactSizeIterator<Item = &'a Self::Vector>,
514        untransformed_logp: impl ExactSizeIterator<Item = &'a f64>,
515        params: &'a mut Self::TransformParams,
516    ) -> Result<(), Self::LogpErr> {
517        self.logp_func.update_transformation(
518            rng,
519            untransformed_positions.map(|x| x.try_as_col_major().unwrap().as_slice()),
520            untransformed_gradients.map(|x| x.try_as_col_major().unwrap().as_slice()),
521            untransformed_logp,
522            params,
523        )
524    }
525
526    fn new_transformation<R: rand::Rng + ?Sized>(
527        &mut self,
528        rng: &mut R,
529        untransformed_position: &Self::Vector,
530        untransfogmed_gradient: &Self::Vector,
531        chain: u64,
532    ) -> Result<Self::TransformParams, Self::LogpErr> {
533        self.logp_func.new_transformation(
534            rng,
535            untransformed_position
536                .try_as_col_major()
537                .unwrap()
538                .as_slice(),
539            untransfogmed_gradient
540                .try_as_col_major()
541                .unwrap()
542                .as_slice(),
543            chain,
544        )
545    }
546
547    fn transformation_id(&self, params: &Self::TransformParams) -> Result<i64, Self::LogpErr> {
548        self.logp_func.transformation_id(params)
549    }
550}
551
552pub trait CpuLogpFunc {
553    type LogpError: Debug + Send + Sync + Error + LogpError + 'static;
554    type TransformParams;
555
556    fn dim(&self) -> usize;
557    fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, Self::LogpError>;
558
559    fn inv_transform_normalize(
560        &mut self,
561        _params: &Self::TransformParams,
562        _untransformed_position: &[f64],
563        _untransformed_gradient: &[f64],
564        _transformed_position: &mut [f64],
565        _transformed_gradient: &mut [f64],
566    ) -> Result<f64, Self::LogpError> {
567        unimplemented!()
568    }
569
570    fn init_from_untransformed_position(
571        &mut self,
572        _params: &Self::TransformParams,
573        _untransformed_position: &[f64],
574        _untransformed_gradient: &mut [f64],
575        _transformed_position: &mut [f64],
576        _transformed_gradient: &mut [f64],
577    ) -> Result<(f64, f64), Self::LogpError> {
578        unimplemented!()
579    }
580
581    fn init_from_transformed_position(
582        &mut self,
583        _params: &Self::TransformParams,
584        _untransformed_position: &mut [f64],
585        _untransformed_gradient: &mut [f64],
586        _transformed_position: &[f64],
587        _transformed_gradient: &mut [f64],
588    ) -> Result<(f64, f64), Self::LogpError> {
589        unimplemented!()
590    }
591
592    fn update_transformation<'a, R: rand::Rng + ?Sized>(
593        &'a mut self,
594        _rng: &mut R,
595        _untransformed_positions: impl ExactSizeIterator<Item = &'a [f64]>,
596        _untransformed_gradients: impl ExactSizeIterator<Item = &'a [f64]>,
597        _untransformed_logp: impl ExactSizeIterator<Item = &'a f64>,
598        _params: &'a mut Self::TransformParams,
599    ) -> Result<(), Self::LogpError> {
600        unimplemented!()
601    }
602
603    fn new_transformation<R: rand::Rng + ?Sized>(
604        &mut self,
605        _rng: &mut R,
606        _untransformed_position: &[f64],
607        _untransformed_gradient: &[f64],
608        _chain: u64,
609    ) -> Result<Self::TransformParams, Self::LogpError> {
610        unimplemented!()
611    }
612
613    fn transformation_id(&self, _params: &Self::TransformParams) -> Result<i64, Self::LogpError> {
614        unimplemented!()
615    }
616}