// nuts_rs/cpu_math.rs

1use std::{error::Error, fmt::Debug, mem::replace};
2
3use faer::{Col, Mat};
4use itertools::{izip, Itertools};
5use thiserror::Error;
6
7use crate::{
8    math::{axpy, axpy_out, multiply, scalar_prods2, scalar_prods3, vector_dot},
9    math_base::{LogpError, Math},
10};
11
12#[derive(Debug)]
13pub struct CpuMath<F: CpuLogpFunc> {
14    logp_func: F,
15    arch: pulp::Arch,
16}
17
18impl<F: CpuLogpFunc> CpuMath<F> {
19    pub fn new(logp_func: F) -> Self {
20        let arch = pulp::Arch::new();
21        Self { logp_func, arch }
22    }
23}
24
25#[non_exhaustive]
26#[derive(Error, Debug)]
27pub enum CpuMathError {
28    #[error("Error during array operation")]
29    ArrayError(),
30}
31
32impl<F: CpuLogpFunc> Math for CpuMath<F> {
33    type Vector = Col<f64>;
34    type EigVectors = Mat<f64>;
35    type EigValues = Col<f64>;
36    type LogpErr = F::LogpError;
37    type Err = CpuMathError;
38    type TransformParams = F::TransformParams;
39
40    fn new_array(&mut self) -> Self::Vector {
41        Col::zeros(self.dim())
42    }
43
44    fn new_eig_vectors<'a>(
45        &'a mut self,
46        vals: impl ExactSizeIterator<Item = &'a [f64]>,
47    ) -> Self::EigVectors {
48        let ndim = self.dim();
49        let nvecs = vals.len();
50
51        let mut vectors: Mat<f64> = Mat::zeros(ndim, nvecs);
52        vectors.col_iter_mut().zip_eq(vals).for_each(|(col, vals)| {
53            col.try_as_col_major_mut()
54                .expect("Array is not contiguous")
55                .as_slice_mut()
56                .copy_from_slice(vals)
57        });
58
59        vectors
60    }
61
62    fn new_eig_values(&mut self, vals: &[f64]) -> Self::EigValues {
63        let mut values: Col<f64> = Col::zeros(vals.len());
64        values
65            .try_as_col_major_mut()
66            .expect("Array is not contiguous")
67            .as_slice_mut()
68            .copy_from_slice(vals);
69        values
70    }
71
72    fn logp_array(
73        &mut self,
74        position: &Self::Vector,
75        gradient: &mut Self::Vector,
76    ) -> Result<f64, Self::LogpErr> {
77        self.logp_func.logp(
78            position
79                .try_as_col_major()
80                .expect("Array is not contiguous")
81                .as_slice(),
82            gradient
83                .try_as_col_major_mut()
84                .expect("Array is not contiguous")
85                .as_slice_mut(),
86        )
87    }
88
89    fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, Self::LogpErr> {
90        self.logp_func.logp(position, gradient)
91    }
92
93    fn dim(&self) -> usize {
94        self.logp_func.dim()
95    }
96
97    fn scalar_prods3(
98        &mut self,
99        positive1: &Self::Vector,
100        negative1: &Self::Vector,
101        positive2: &Self::Vector,
102        x: &Self::Vector,
103        y: &Self::Vector,
104    ) -> (f64, f64) {
105        scalar_prods3(
106            positive1.try_as_col_major().unwrap().as_slice(),
107            negative1.try_as_col_major().unwrap().as_slice(),
108            positive2.try_as_col_major().unwrap().as_slice(),
109            x.try_as_col_major().unwrap().as_slice(),
110            y.try_as_col_major().unwrap().as_slice(),
111        )
112    }
113
114    fn scalar_prods2(
115        &mut self,
116        positive1: &Self::Vector,
117        positive2: &Self::Vector,
118        x: &Self::Vector,
119        y: &Self::Vector,
120    ) -> (f64, f64) {
121        scalar_prods2(
122            positive1.try_as_col_major().unwrap().as_slice(),
123            positive2.try_as_col_major().unwrap().as_slice(),
124            x.try_as_col_major().unwrap().as_slice(),
125            y.try_as_col_major().unwrap().as_slice(),
126        )
127    }
128
129    fn sq_norm_sum(&mut self, x: &Self::Vector, y: &Self::Vector) -> f64 {
130        x.try_as_col_major()
131            .unwrap()
132            .as_slice()
133            .iter()
134            .zip(y.try_as_col_major().unwrap().as_slice())
135            .map(|(&x, &y)| (x + y) * (x + y))
136            .sum()
137    }
138
139    fn read_from_slice(&mut self, dest: &mut Self::Vector, source: &[f64]) {
140        dest.try_as_col_major_mut()
141            .unwrap()
142            .as_slice_mut()
143            .copy_from_slice(source);
144    }
145
146    fn write_to_slice(&mut self, source: &Self::Vector, dest: &mut [f64]) {
147        dest.copy_from_slice(source.try_as_col_major().unwrap().as_slice())
148    }
149
150    fn copy_into(&mut self, array: &Self::Vector, dest: &mut Self::Vector) {
151        dest.clone_from(array)
152    }
153
154    fn axpy_out(&mut self, x: &Self::Vector, y: &Self::Vector, a: f64, out: &mut Self::Vector) {
155        axpy_out(
156            x.try_as_col_major().unwrap().as_slice(),
157            y.try_as_col_major().unwrap().as_slice(),
158            a,
159            out.try_as_col_major_mut().unwrap().as_slice_mut(),
160        );
161    }
162
163    fn axpy(&mut self, x: &Self::Vector, y: &mut Self::Vector, a: f64) {
164        axpy(
165            x.try_as_col_major().unwrap().as_slice(),
166            y.try_as_col_major_mut().unwrap().as_slice_mut(),
167            a,
168        );
169    }
170
171    fn fill_array(&mut self, array: &mut Self::Vector, val: f64) {
172        faer::zip!(array).for_each(|faer::unzip!(pos)| *pos = val);
173    }
174
175    fn array_all_finite(&mut self, array: &Self::Vector) -> bool {
176        let mut ok = true;
177        faer::zip!(array).for_each(|faer::unzip!(val)| ok = ok & val.is_finite());
178        ok
179    }
180
181    fn array_all_finite_and_nonzero(&mut self, array: &Self::Vector) -> bool {
182        self.arch.dispatch(|| {
183            array
184                .try_as_col_major()
185                .unwrap()
186                .as_slice()
187                .iter()
188                .all(|&x| x.is_finite() & (x != 0f64))
189        })
190    }
191
192    fn array_mult(
193        &mut self,
194        array1: &Self::Vector,
195        array2: &Self::Vector,
196        dest: &mut Self::Vector,
197    ) {
198        multiply(
199            array1.try_as_col_major().unwrap().as_slice(),
200            array2.try_as_col_major().unwrap().as_slice(),
201            dest.try_as_col_major_mut().unwrap().as_slice_mut(),
202        )
203    }
204
205    fn array_mult_eigs(
206        &mut self,
207        stds: &Self::Vector,
208        rhs: &Self::Vector,
209        dest: &mut Self::Vector,
210        vecs: &Self::EigVectors,
211        vals: &Self::EigValues,
212    ) {
213        let rhs = stds.as_diagonal() * rhs;
214        let trafo = vecs.transpose() * (&rhs);
215        let inner_prod = vecs * (vals.as_diagonal() * (&trafo) - (&trafo)) + rhs;
216        let scaled = stds.as_diagonal() * inner_prod;
217
218        let _ = replace(dest, scaled);
219    }
220
221    fn array_vector_dot(&mut self, array1: &Self::Vector, array2: &Self::Vector) -> f64 {
222        vector_dot(
223            array1.try_as_col_major().unwrap().as_slice(),
224            array2.try_as_col_major().unwrap().as_slice(),
225        )
226    }
227
228    fn array_gaussian<R: rand::Rng + ?Sized>(
229        &mut self,
230        rng: &mut R,
231        dest: &mut Self::Vector,
232        stds: &Self::Vector,
233    ) {
234        let dist = rand_distr::StandardNormal;
235        dest.try_as_col_major_mut()
236            .unwrap()
237            .as_slice_mut()
238            .iter_mut()
239            .zip(stds.try_as_col_major().unwrap().as_slice().iter())
240            .for_each(|(p, &s)| {
241                let norm: f64 = rng.sample(dist);
242                *p = s * norm;
243            });
244    }
245
246    fn array_gaussian_eigs<R: rand::Rng + ?Sized>(
247        &mut self,
248        rng: &mut R,
249        dest: &mut Self::Vector,
250        scale: &Self::Vector,
251        vals: &Self::EigValues,
252        vecs: &Self::EigVectors,
253    ) {
254        let mut draw: Col<f64> = Col::zeros(self.dim());
255        let dist = rand_distr::StandardNormal;
256        draw.try_as_col_major_mut()
257            .unwrap()
258            .as_slice_mut()
259            .iter_mut()
260            .for_each(|p| {
261                *p = rng.sample(dist);
262            });
263
264        let trafo = vecs.transpose() * (&draw);
265        let inner_prod = vecs * (vals.as_diagonal() * (&trafo) - (&trafo)) + draw;
266
267        let scaled = scale.as_diagonal() * inner_prod;
268
269        let _ = replace(dest, scaled);
270    }
271
272    fn array_update_variance(
273        &mut self,
274        mean: &mut Self::Vector,
275        variance: &mut Self::Vector,
276        value: &Self::Vector,
277        diff_scale: f64, // 1 / self.count
278    ) {
279        self.arch.dispatch(|| {
280            izip!(
281                mean.try_as_col_major_mut()
282                    .unwrap()
283                    .as_slice_mut()
284                    .iter_mut(),
285                variance
286                    .try_as_col_major_mut()
287                    .unwrap()
288                    .as_slice_mut()
289                    .iter_mut(),
290                value.try_as_col_major().unwrap().as_slice()
291            )
292            .for_each(|(mean, var, x)| {
293                let diff = x - *mean;
294                *mean += diff * diff_scale;
295                *var += diff * diff;
296            });
297        })
298    }
299
300    fn array_update_var_inv_std_draw(
301        &mut self,
302        variance_out: &mut Self::Vector,
303        inv_std: &mut Self::Vector,
304        draw_var: &Self::Vector,
305        scale: f64,
306        fill_invalid: Option<f64>,
307        clamp: (f64, f64),
308    ) {
309        self.arch.dispatch(|| {
310            izip!(
311                variance_out
312                    .try_as_col_major_mut()
313                    .unwrap()
314                    .as_slice_mut()
315                    .iter_mut(),
316                inv_std
317                    .try_as_col_major_mut()
318                    .unwrap()
319                    .as_slice_mut()
320                    .iter_mut(),
321                draw_var.try_as_col_major().unwrap().as_slice().iter(),
322            )
323            .for_each(|(var_out, inv_std_out, &draw_var)| {
324                let draw_var = draw_var * scale;
325                if (!draw_var.is_finite()) | (draw_var == 0f64) {
326                    if let Some(fill_val) = fill_invalid {
327                        *var_out = fill_val;
328                        *inv_std_out = fill_val.recip().sqrt();
329                    }
330                } else {
331                    let val = draw_var.clamp(clamp.0, clamp.1);
332                    *var_out = val;
333                    *inv_std_out = val.recip().sqrt();
334                }
335            });
336        });
337    }
338
339    fn array_update_var_inv_std_draw_grad(
340        &mut self,
341        variance_out: &mut Self::Vector,
342        inv_std: &mut Self::Vector,
343        draw_var: &Self::Vector,
344        grad_var: &Self::Vector,
345        fill_invalid: Option<f64>,
346        clamp: (f64, f64),
347    ) {
348        self.arch.dispatch(|| {
349            izip!(
350                variance_out
351                    .try_as_col_major_mut()
352                    .unwrap()
353                    .as_slice_mut()
354                    .iter_mut(),
355                inv_std
356                    .try_as_col_major_mut()
357                    .unwrap()
358                    .as_slice_mut()
359                    .iter_mut(),
360                draw_var.try_as_col_major().unwrap().as_slice().iter(),
361                grad_var.try_as_col_major().unwrap().as_slice().iter(),
362            )
363            .for_each(|(var_out, inv_std_out, &draw_var, &grad_var)| {
364                let val = (draw_var / grad_var).sqrt();
365                if (!val.is_finite()) | (val == 0f64) {
366                    if let Some(fill_val) = fill_invalid {
367                        *var_out = fill_val;
368                        *inv_std_out = fill_val.recip().sqrt();
369                    }
370                } else {
371                    let val = val.clamp(clamp.0, clamp.1);
372                    *var_out = val;
373                    *inv_std_out = val.recip().sqrt();
374                }
375            });
376        });
377    }
378
379    fn array_update_var_inv_std_grad(
380        &mut self,
381        variance_out: &mut Self::Vector,
382        inv_std: &mut Self::Vector,
383        gradient: &Self::Vector,
384        fill_invalid: f64,
385        clamp: (f64, f64),
386    ) {
387        self.arch.dispatch(|| {
388            izip!(
389                variance_out
390                    .try_as_col_major_mut()
391                    .unwrap()
392                    .as_slice_mut()
393                    .iter_mut(),
394                inv_std
395                    .try_as_col_major_mut()
396                    .unwrap()
397                    .as_slice_mut()
398                    .iter_mut(),
399                gradient.try_as_col_major().unwrap().as_slice().iter(),
400            )
401            .for_each(|(var_out, inv_std_out, &grad_var)| {
402                let val = grad_var.abs().clamp(clamp.0, clamp.1).recip();
403                let val = if val.is_finite() { val } else { fill_invalid };
404                *var_out = val;
405                *inv_std_out = val.recip().sqrt();
406            });
407        });
408    }
409
410    fn eigs_as_array(&mut self, source: &Self::EigValues) -> Box<[f64]> {
411        source
412            .try_as_col_major()
413            .unwrap()
414            .as_slice()
415            .to_vec()
416            .into()
417    }
418
419    fn inv_transform_normalize(
420        &mut self,
421        params: &Self::TransformParams,
422        untransformed_position: &Self::Vector,
423        untransofrmed_gradient: &Self::Vector,
424        transformed_position: &mut Self::Vector,
425        transformed_gradient: &mut Self::Vector,
426    ) -> Result<f64, Self::LogpErr> {
427        self.logp_func.inv_transform_normalize(
428            params,
429            untransformed_position
430                .try_as_col_major()
431                .unwrap()
432                .as_slice(),
433            untransofrmed_gradient
434                .try_as_col_major()
435                .unwrap()
436                .as_slice(),
437            transformed_position
438                .try_as_col_major_mut()
439                .unwrap()
440                .as_slice_mut(),
441            transformed_gradient
442                .try_as_col_major_mut()
443                .unwrap()
444                .as_slice_mut(),
445        )
446    }
447
448    fn init_from_untransformed_position(
449        &mut self,
450        params: &Self::TransformParams,
451        untransformed_position: &Self::Vector,
452        untransformed_gradient: &mut Self::Vector,
453        transformed_position: &mut Self::Vector,
454        transformed_gradient: &mut Self::Vector,
455    ) -> Result<(f64, f64), Self::LogpErr> {
456        self.logp_func.init_from_untransformed_position(
457            params,
458            untransformed_position
459                .try_as_col_major()
460                .unwrap()
461                .as_slice(),
462            untransformed_gradient
463                .try_as_col_major_mut()
464                .unwrap()
465                .as_slice_mut(),
466            transformed_position
467                .try_as_col_major_mut()
468                .unwrap()
469                .as_slice_mut(),
470            transformed_gradient
471                .try_as_col_major_mut()
472                .unwrap()
473                .as_slice_mut(),
474        )
475    }
476
477    fn init_from_transformed_position(
478        &mut self,
479        params: &Self::TransformParams,
480        untransformed_position: &mut Self::Vector,
481        untransformed_gradient: &mut Self::Vector,
482        transformed_position: &Self::Vector,
483        transformed_gradient: &mut Self::Vector,
484    ) -> Result<(f64, f64), Self::LogpErr> {
485        self.logp_func.init_from_transformed_position(
486            params,
487            untransformed_position
488                .try_as_col_major_mut()
489                .unwrap()
490                .as_slice_mut(),
491            untransformed_gradient
492                .try_as_col_major_mut()
493                .unwrap()
494                .as_slice_mut(),
495            transformed_position.try_as_col_major().unwrap().as_slice(),
496            transformed_gradient
497                .try_as_col_major_mut()
498                .unwrap()
499                .as_slice_mut(),
500        )
501    }
502
503    fn update_transformation<'a, R: rand::Rng + ?Sized>(
504        &'a mut self,
505        rng: &mut R,
506        untransformed_positions: impl ExactSizeIterator<Item = &'a Self::Vector>,
507        untransformed_gradients: impl ExactSizeIterator<Item = &'a Self::Vector>,
508        untransformed_logp: impl ExactSizeIterator<Item = &'a f64>,
509        params: &'a mut Self::TransformParams,
510    ) -> Result<(), Self::LogpErr> {
511        self.logp_func.update_transformation(
512            rng,
513            untransformed_positions.map(|x| x.try_as_col_major().unwrap().as_slice()),
514            untransformed_gradients.map(|x| x.try_as_col_major().unwrap().as_slice()),
515            untransformed_logp,
516            params,
517        )
518    }
519
520    fn new_transformation<R: rand::Rng + ?Sized>(
521        &mut self,
522        rng: &mut R,
523        untransformed_position: &Self::Vector,
524        untransfogmed_gradient: &Self::Vector,
525        chain: u64,
526    ) -> Result<Self::TransformParams, Self::LogpErr> {
527        self.logp_func.new_transformation(
528            rng,
529            untransformed_position
530                .try_as_col_major()
531                .unwrap()
532                .as_slice(),
533            untransfogmed_gradient
534                .try_as_col_major()
535                .unwrap()
536                .as_slice(),
537            chain,
538        )
539    }
540
541    fn transformation_id(&self, params: &Self::TransformParams) -> Result<i64, Self::LogpErr> {
542        self.logp_func.transformation_id(params)
543    }
544}
545
546pub trait CpuLogpFunc {
547    type LogpError: Debug + Send + Sync + Error + LogpError + 'static;
548    type TransformParams;
549
550    fn dim(&self) -> usize;
551    fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, Self::LogpError>;
552
553    fn inv_transform_normalize(
554        &mut self,
555        _params: &Self::TransformParams,
556        _untransformed_position: &[f64],
557        _untransformed_gradient: &[f64],
558        _transformed_position: &mut [f64],
559        _transformed_gradient: &mut [f64],
560    ) -> Result<f64, Self::LogpError> {
561        unimplemented!()
562    }
563
564    fn init_from_untransformed_position(
565        &mut self,
566        _params: &Self::TransformParams,
567        _untransformed_position: &[f64],
568        _untransformed_gradient: &mut [f64],
569        _transformed_position: &mut [f64],
570        _transformed_gradient: &mut [f64],
571    ) -> Result<(f64, f64), Self::LogpError> {
572        unimplemented!()
573    }
574
575    fn init_from_transformed_position(
576        &mut self,
577        _params: &Self::TransformParams,
578        _untransformed_position: &mut [f64],
579        _untransformed_gradient: &mut [f64],
580        _transformed_position: &[f64],
581        _transformed_gradient: &mut [f64],
582    ) -> Result<(f64, f64), Self::LogpError> {
583        unimplemented!()
584    }
585
586    fn update_transformation<'a, R: rand::Rng + ?Sized>(
587        &'a mut self,
588        _rng: &mut R,
589        _untransformed_positions: impl ExactSizeIterator<Item = &'a [f64]>,
590        _untransformed_gradients: impl ExactSizeIterator<Item = &'a [f64]>,
591        _untransformed_logp: impl ExactSizeIterator<Item = &'a f64>,
592        _params: &'a mut Self::TransformParams,
593    ) -> Result<(), Self::LogpError> {
594        unimplemented!()
595    }
596
597    fn new_transformation<R: rand::Rng + ?Sized>(
598        &mut self,
599        _rng: &mut R,
600        _untransformed_position: &[f64],
601        _untransformed_gradient: &[f64],
602        _chain: u64,
603    ) -> Result<Self::TransformParams, Self::LogpError> {
604        unimplemented!()
605    }
606
607    fn transformation_id(&self, _params: &Self::TransformParams) -> Result<i64, Self::LogpError> {
608        unimplemented!()
609    }
610}