Skip to main content

nuts_rs/math/
cpu_math.rs

//! CPU backend that calls the user-supplied logp function and provides the required vector operations.
2
3use std::{collections::HashMap, error::Error, fmt::Debug, mem::replace};
4
5use faer::linalg::matmul::matmul;
6
7use faer::{Accum, Col, Mat, Par};
8use itertools::{Itertools, izip};
9use nuts_storable::{HasDims, Storable, Value};
10use rand::RngExt;
11use thiserror::Error;
12
13use crate::math::util::multiply_inplace;
14
15use super::{
16    math::{LogpError, Math},
17    util::{
18        axpy, axpy_out, multiply, scalar_prods2, scalar_prods3, std_norm_flow, std_norm_grad_flow,
19        std_norm_grad_flow_inplace, vector_dot,
20    },
21};
22
23#[derive(Debug)]
24pub struct CpuMath<F: CpuLogpFunc> {
25    logp_func: F,
26    arch: pulp::Arch,
27    /// Preallocated scratch buffer for the low-rank transform intermediate vector
28    /// (U^T * rhs), sized to the current rank (vecs.ncols()). Resized as needed.
29    lowrank_scratch: Col<f64>,
30}
31
32impl<F: CpuLogpFunc> CpuMath<F> {
33    pub fn new(logp_func: F) -> Self {
34        let arch = pulp::Arch::new();
35        Self {
36            logp_func,
37            arch,
38            lowrank_scratch: Col::zeros(0),
39        }
40    }
41}
42
43#[non_exhaustive]
44#[derive(Error, Debug)]
45pub enum CpuMathError {
46    #[error("Error during array operation")]
47    ArrayError(),
48    #[error("Error during point expansion: {0}")]
49    ExpandError(String),
50}
51
52impl<F: CpuLogpFunc> HasDims for CpuMath<F> {
53    fn dim_sizes(&self) -> HashMap<String, u64> {
54        self.logp_func.dim_sizes()
55    }
56
57    fn coords(&self) -> HashMap<String, nuts_storable::Value> {
58        self.logp_func.coords()
59    }
60}
61
62pub struct ExpandedVectorWrapper<F: CpuLogpFunc>(F::ExpandedVector);
63
64impl<F: CpuLogpFunc> Storable<CpuMath<F>> for ExpandedVectorWrapper<F> {
65    fn names(parent: &CpuMath<F>) -> Vec<&str> {
66        F::ExpandedVector::names(&parent.logp_func)
67    }
68
69    fn item_type(parent: &CpuMath<F>, item: &str) -> nuts_storable::ItemType {
70        F::ExpandedVector::item_type(&parent.logp_func, item)
71    }
72
73    fn dims<'a>(parent: &'a CpuMath<F>, item: &str) -> Vec<&'a str> {
74        F::ExpandedVector::dims(&parent.logp_func, item)
75    }
76
77    fn get_all<'a>(
78        &'a mut self,
79        parent: &'a CpuMath<F>,
80    ) -> Vec<(&'a str, Option<nuts_storable::Value>)> {
81        self.0.get_all(&parent.logp_func)
82    }
83}
84
85impl<F: CpuLogpFunc> Math for CpuMath<F> {
86    type Vector = Col<f64>;
87    type EigVectors = Mat<f64>;
88    type EigValues = Col<f64>;
89    type LogpErr = F::LogpError;
90    type Err = CpuMathError;
91    type FlowParameters = F::FlowParameters;
92    type ExpandedVector = ExpandedVectorWrapper<F>;
93
94    fn new_array(&mut self) -> Self::Vector {
95        Col::zeros(self.dim())
96    }
97
98    fn new_eig_vectors<'a>(
99        &'a mut self,
100        vals: impl ExactSizeIterator<Item = &'a [f64]>,
101    ) -> Self::EigVectors {
102        let ndim = self.dim();
103        let nvecs = vals.len();
104
105        let mut vectors: Mat<f64> = Mat::zeros(ndim, nvecs);
106        vectors.col_iter_mut().zip_eq(vals).for_each(|(col, vals)| {
107            col.try_as_col_major_mut()
108                .expect("Array is not contiguous")
109                .as_slice_mut()
110                .copy_from_slice(vals)
111        });
112
113        vectors
114    }
115
116    fn new_eig_values(&mut self, vals: &[f64]) -> Self::EigValues {
117        let mut values: Col<f64> = Col::zeros(vals.len());
118        values
119            .try_as_col_major_mut()
120            .expect("Array is not contiguous")
121            .as_slice_mut()
122            .copy_from_slice(vals);
123        values
124    }
125
126    fn logp_array(
127        &mut self,
128        position: &Self::Vector,
129        gradient: &mut Self::Vector,
130    ) -> Result<f64, Self::LogpErr> {
131        self.logp_func.logp(
132            position
133                .try_as_col_major()
134                .expect("Array is not contiguous")
135                .as_slice(),
136            gradient
137                .try_as_col_major_mut()
138                .expect("Array is not contiguous")
139                .as_slice_mut(),
140        )
141    }
142
143    fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, Self::LogpErr> {
144        self.logp_func.logp(position, gradient)
145    }
146
147    fn dim(&self) -> usize {
148        self.logp_func.dim()
149    }
150
151    fn expand_vector<R: rand::Rng + ?Sized>(
152        &mut self,
153        rng: &mut R,
154        array: &Self::Vector,
155    ) -> Result<Self::ExpandedVector, Self::Err> {
156        Ok(ExpandedVectorWrapper(
157            self.logp_func.expand_vector(
158                rng,
159                array
160                    .try_as_col_major()
161                    .ok_or_else(|| {
162                        CpuMathError::ExpandError("Internal vector was not col major".into())
163                    })?
164                    .as_slice(),
165            )?,
166        ))
167    }
168
169    fn vector_coord(&self) -> Option<Value> {
170        self.logp_func.vector_coord()
171    }
172
173    fn init_position<R: rand::Rng + ?Sized>(
174        &mut self,
175        rng: &mut R,
176        position: &mut Self::Vector,
177        gradient: &mut Self::Vector,
178    ) -> Result<f64, Self::LogpErr> {
179        let pos = position
180            .try_as_col_major_mut()
181            .expect("Array is not contiguous")
182            .as_slice_mut();
183
184        pos.iter_mut().for_each(|x| {
185            let val: f64 = rng.random();
186            *x = val * 2f64 - 1f64
187        });
188
189        self.logp_func.logp(
190            position
191                .try_as_col_major()
192                .expect("Array is not contiguous")
193                .as_slice(),
194            gradient
195                .try_as_col_major_mut()
196                .expect("Array is not contiguous")
197                .as_slice_mut(),
198        )
199    }
200
201    fn scalar_prods3(
202        &mut self,
203        positive1: &Self::Vector,
204        negative1: &Self::Vector,
205        positive2: &Self::Vector,
206        x: &Self::Vector,
207        y: &Self::Vector,
208    ) -> (f64, f64) {
209        scalar_prods3(
210            self.arch,
211            positive1.try_as_col_major().unwrap().as_slice(),
212            negative1.try_as_col_major().unwrap().as_slice(),
213            positive2.try_as_col_major().unwrap().as_slice(),
214            x.try_as_col_major().unwrap().as_slice(),
215            y.try_as_col_major().unwrap().as_slice(),
216        )
217    }
218
219    fn scalar_prods2(
220        &mut self,
221        positive1: &Self::Vector,
222        positive2: &Self::Vector,
223        x: &Self::Vector,
224        y: &Self::Vector,
225    ) -> (f64, f64) {
226        scalar_prods2(
227            self.arch,
228            positive1.try_as_col_major().unwrap().as_slice(),
229            positive2.try_as_col_major().unwrap().as_slice(),
230            x.try_as_col_major().unwrap().as_slice(),
231            y.try_as_col_major().unwrap().as_slice(),
232        )
233    }
234
235    fn sq_norm_sum(&mut self, x: &Self::Vector, y: &Self::Vector) -> f64 {
236        x.try_as_col_major()
237            .unwrap()
238            .as_slice()
239            .iter()
240            .zip(y.try_as_col_major().unwrap().as_slice())
241            .map(|(&x, &y)| (x + y) * (x + y))
242            .sum()
243    }
244
245    fn read_from_slice(&mut self, dest: &mut Self::Vector, source: &[f64]) {
246        dest.try_as_col_major_mut()
247            .unwrap()
248            .as_slice_mut()
249            .copy_from_slice(source);
250    }
251
252    fn write_to_slice(&mut self, source: &Self::Vector, dest: &mut [f64]) {
253        dest.copy_from_slice(source.try_as_col_major().unwrap().as_slice())
254    }
255
256    fn copy_into(&mut self, array: &Self::Vector, dest: &mut Self::Vector) {
257        dest.clone_from(array)
258    }
259
260    fn axpy_out(&mut self, x: &Self::Vector, y: &Self::Vector, a: f64, out: &mut Self::Vector) {
261        axpy_out(
262            self.arch,
263            x.try_as_col_major().unwrap().as_slice(),
264            y.try_as_col_major().unwrap().as_slice(),
265            a,
266            out.try_as_col_major_mut().unwrap().as_slice_mut(),
267        );
268    }
269
270    fn axpy(&mut self, x: &Self::Vector, y: &mut Self::Vector, a: f64) {
271        axpy(
272            self.arch,
273            x.try_as_col_major().unwrap().as_slice(),
274            y.try_as_col_major_mut().unwrap().as_slice_mut(),
275            a,
276        );
277    }
278
279    fn fill_array(&mut self, array: &mut Self::Vector, val: f64) {
280        faer::zip!(array).for_each(|faer::unzip!(pos)| *pos = val);
281    }
282
283    fn array_all_finite(&mut self, array: &Self::Vector) -> bool {
284        let mut ok = true;
285        faer::zip!(array).for_each(|faer::unzip!(val)| ok &= val.is_finite());
286        ok
287    }
288
289    fn array_all_finite_and_nonzero(&mut self, array: &Self::Vector) -> bool {
290        self.arch.dispatch(|| {
291            array
292                .try_as_col_major()
293                .unwrap()
294                .as_slice()
295                .iter()
296                .all(|&x| x.is_finite() & (x != 0f64))
297        })
298    }
299
300    fn array_sum_ln(&mut self, array: &Self::Vector) -> f64 {
301        let mut sum = 0f64;
302        faer::zip!(array).for_each(|faer::unzip!(val)| sum += val.ln());
303        sum
304    }
305
306    fn array_mult(
307        &mut self,
308        array1: &Self::Vector,
309        array2: &Self::Vector,
310        dest: &mut Self::Vector,
311    ) {
312        multiply(
313            self.arch,
314            array1.try_as_col_major().unwrap().as_slice(),
315            array2.try_as_col_major().unwrap().as_slice(),
316            dest.try_as_col_major_mut().unwrap().as_slice_mut(),
317        )
318    }
319
320    fn array_mult_inplace(&mut self, array1: &mut Self::Vector, array2: &Self::Vector) {
321        multiply_inplace(
322            self.arch,
323            array1.try_as_col_major_mut().unwrap().as_slice_mut(),
324            array2.try_as_col_major().unwrap().as_slice(),
325        )
326    }
327
328    fn array_recip(&mut self, array: &Self::Vector, dest: &mut Self::Vector) {
329        faer::zip!(array, dest).for_each(|faer::unzip!(val, dest)| *dest = val.recip())
330    }
331
332    fn apply_lowrank_transform(
333        &mut self,
334        vecs: &Self::EigVectors,
335        vals: &Self::EigValues,
336        rhs: &Self::Vector,
337        dest: &mut Self::Vector,
338    ) {
339        if vecs.ncols() == 0 {
340            self.copy_into(rhs, dest);
341            return;
342        }
343        // dest = (I + U * (diag(vals) - I) * U^T) * rhs
344        //      = rhs + U * (diag(vals) - I) * (U^T * rhs)
345
346        let rank = vecs.ncols();
347
348        // Resize scratch if needed (rank can change across calls)
349        if self.lowrank_scratch.nrows() != rank {
350            self.lowrank_scratch.resize_with(rank, |_| 0.0);
351        }
352
353        // scratch = U^T * rhs
354        matmul(
355            self.lowrank_scratch.as_mut(),
356            Accum::Replace,
357            vecs.transpose(),
358            rhs.as_ref(),
359            1.0,
360            Par::Seq,
361        );
362
363        // scratch = (diag(vals) - I) * scratch  (element-wise: scratch[i] *= vals[i] - 1)
364        self.lowrank_scratch
365            .iter_mut()
366            .zip(vals.iter())
367            .for_each(|(s, &v)| *s *= v - 1.0);
368
369        // dest = rhs + U * scratch
370        dest.copy_from(rhs);
371        matmul(
372            dest.as_mut(),
373            Accum::Add,
374            vecs.as_ref(),
375            self.lowrank_scratch.as_ref(),
376            1.0,
377            Par::Seq,
378        );
379    }
380
381    fn apply_lowrank_transform_inplace(
382        &mut self,
383        vecs: &Self::EigVectors,
384        vals: &Self::EigValues,
385        rhs_and_dest: &mut Self::Vector,
386    ) {
387        if vecs.ncols() == 0 {
388            return;
389        }
390        // rhs_and_dest = (I + U * (diag(vals) - I) * U^T) * rhs_and_dest
391        //              = rhs_and_dest + U * (diag(vals) - I) * (U^T * rhs_and_dest)
392
393        let rank = vecs.ncols();
394
395        // Resize scratch if needed
396        if self.lowrank_scratch.nrows() != rank {
397            self.lowrank_scratch.resize_with(rank, |_| 0.0);
398        }
399
400        // scratch = U^T * rhs_and_dest
401        matmul(
402            self.lowrank_scratch.as_mut(),
403            Accum::Replace,
404            vecs.transpose(),
405            rhs_and_dest.as_ref(),
406            1.0,
407            Par::Seq,
408        );
409
410        // scratch = (diag(vals) - I) * scratch  (element-wise: scratch[i] *= vals[i] - 1)
411        self.lowrank_scratch
412            .iter_mut()
413            .zip(vals.iter())
414            .for_each(|(s, &v)| *s *= v - 1.0);
415
416        // rhs_and_dest += U * scratch
417        matmul(
418            rhs_and_dest.as_mut(),
419            Accum::Add,
420            vecs.as_ref(),
421            self.lowrank_scratch.as_ref(),
422            1.0,
423            Par::Seq,
424        );
425    }
426
427    fn array_mult_eigs(
428        &mut self,
429        stds: &Self::Vector,
430        rhs: &Self::Vector,
431        dest: &mut Self::Vector,
432        vecs: &Self::EigVectors,
433        vals: &Self::EigValues,
434    ) {
435        let rhs = stds.as_diagonal() * rhs;
436        let trafo = vecs.transpose() * (&rhs);
437        let inner_prod = vecs * (vals.as_diagonal() * (&trafo) - (&trafo)) + rhs;
438        let scaled = stds.as_diagonal() * inner_prod;
439
440        let _ = replace(dest, scaled);
441    }
442
443    /// The exponential map of the Hamiltonian flow for the standard normal distribution.
444    ///
445    /// This is the harmonic oscillator with unit mass and unit frequency.
446    fn std_norm_flow(
447        &mut self,
448        pos: &Self::Vector,
449        pos_out: &mut Self::Vector,
450        vel: &mut Self::Vector,
451        epsilon: f64,
452    ) {
453        std_norm_flow(
454            self.arch,
455            pos.try_as_col_major().unwrap().as_slice(),
456            pos_out.try_as_col_major_mut().unwrap().as_slice_mut(),
457            vel.try_as_col_major_mut().unwrap().as_slice_mut(),
458            epsilon,
459        );
460    }
461
462    fn std_norm_grad_flow(
463        &mut self,
464        pos: &Self::Vector,
465        grad: &Self::Vector,
466        vel: &Self::Vector,
467        vel_out: &mut Self::Vector,
468        epsilon: f64,
469    ) {
470        std_norm_grad_flow(
471            self.arch,
472            pos.try_as_col_major().unwrap().as_slice(),
473            grad.try_as_col_major().unwrap().as_slice(),
474            vel.try_as_col_major().unwrap().as_slice(),
475            vel_out.try_as_col_major_mut().unwrap().as_slice_mut(),
476            epsilon,
477        );
478    }
479
480    fn std_norm_grad_flow_inplace(
481        &mut self,
482        pos: &Self::Vector,
483        grad: &Self::Vector,
484        vel: &mut Self::Vector,
485        epsilon: f64,
486    ) {
487        std_norm_grad_flow_inplace(
488            self.arch,
489            pos.try_as_col_major().unwrap().as_slice(),
490            grad.try_as_col_major().unwrap().as_slice(),
491            vel.try_as_col_major_mut().unwrap().as_slice_mut(),
492            epsilon,
493        );
494    }
495
496    fn array_normalize(&mut self, v: &mut Self::Vector) {
497        let v = v.try_as_col_major_mut().unwrap().as_slice_mut();
498        let norm: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt();
499        let inv = 1.0 / norm;
500        for x in v.iter_mut() {
501            *x *= inv;
502        }
503    }
504
505    fn esh_momentum_update(
506        &mut self,
507        gradient: &Self::Vector,
508        momentum: &mut Self::Vector,
509        step_size: f64,
510    ) -> f64 {
511        let gradient = gradient.try_as_col_major().unwrap().as_slice();
512        let momentum = momentum.try_as_col_major_mut().unwrap().as_slice_mut();
513        let n = gradient.len();
514        assert!(n >= 2, "ESH dynamics requires at least 2 dimensions");
515
516        // ‖g‖
517        let grad_norm: f64 = gradient.iter().map(|g| g * g).sum::<f64>().sqrt();
518
519        let inv_grad_norm = 1.0 / grad_norm;
520
521        // α = p · ĝ
522        let momentum_proj: f64 = momentum
523            .iter()
524            .zip(gradient.iter())
525            .map(|(p, g)| p * g * inv_grad_norm)
526            .sum();
527
528        let dims_m1 = (n - 1) as f64;
529        let delta = step_size * grad_norm / dims_m1;
530        let zeta = (-delta).exp();
531
532        // p_raw = ĝ · (1 − ζ)(1 + ζ + α(1 − ζ))  +  2ζ p
533        let coeff_g = (1.0 - zeta) * (1.0 + zeta + momentum_proj * (1.0 - zeta));
534        let coeff_p = 2.0 * zeta;
535
536        for (p, g) in momentum.iter_mut().zip(gradient.iter()) {
537            *p = coeff_g * (g * inv_grad_norm) + coeff_p * *p;
538        }
539
540        // Renormalise to unit sphere.
541        let raw_norm: f64 = momentum.iter().map(|p| p * p).sum::<f64>().sqrt();
542        let inv = 1.0 / raw_norm;
543        for p in momentum.iter_mut() {
544            *p *= inv;
545        }
546
547        let arg = momentum_proj + (1.0 - momentum_proj) * zeta * zeta;
548        let kinetic_energy_change = (delta - std::f64::consts::LN_2 + arg.ln_1p()) * dims_m1;
549
550        kinetic_energy_change
551    }
552
553    fn array_vector_dot(&mut self, array1: &Self::Vector, array2: &Self::Vector) -> f64 {
554        vector_dot(
555            self.arch,
556            array1.try_as_col_major().unwrap().as_slice(),
557            array2.try_as_col_major().unwrap().as_slice(),
558        )
559    }
560
561    fn array_gaussian<R: rand::Rng + ?Sized>(
562        &mut self,
563        rng: &mut R,
564        dest: &mut Self::Vector,
565        stds: &Self::Vector,
566    ) {
567        let dist = rand_distr::StandardNormal;
568        dest.try_as_col_major_mut()
569            .unwrap()
570            .as_slice_mut()
571            .iter_mut()
572            .zip(stds.try_as_col_major().unwrap().as_slice().iter())
573            .for_each(|(p, &s)| {
574                let norm: f64 = rng.sample(dist);
575                *p = s * norm;
576            });
577    }
578
579    fn array_gaussian_eigs<R: rand::Rng + ?Sized>(
580        &mut self,
581        rng: &mut R,
582        dest: &mut Self::Vector,
583        scale: &Self::Vector,
584        vals: &Self::EigValues,
585        vecs: &Self::EigVectors,
586    ) {
587        let mut draw: Col<f64> = Col::zeros(self.dim());
588        let dist = rand_distr::StandardNormal;
589        draw.try_as_col_major_mut()
590            .unwrap()
591            .as_slice_mut()
592            .iter_mut()
593            .for_each(|p| {
594                *p = rng.sample(dist);
595            });
596
597        let trafo = vecs.transpose() * (&draw);
598        let inner_prod = vecs * (vals.as_diagonal() * (&trafo) - (&trafo)) + draw;
599
600        let scaled = scale.as_diagonal() * inner_prod;
601
602        let _ = replace(dest, scaled);
603    }
604
605    fn array_update_variance(
606        &mut self,
607        mean: &mut Self::Vector,
608        variance: &mut Self::Vector,
609        value: &Self::Vector,
610        diff_scale: f64, // 1 / self.count
611    ) {
612        self.arch.dispatch(|| {
613            izip!(
614                mean.try_as_col_major_mut()
615                    .unwrap()
616                    .as_slice_mut()
617                    .iter_mut(),
618                variance
619                    .try_as_col_major_mut()
620                    .unwrap()
621                    .as_slice_mut()
622                    .iter_mut(),
623                value.try_as_col_major().unwrap().as_slice()
624            )
625            .for_each(|(mean, var, x)| {
626                let diff = x - *mean;
627                *mean += diff * diff_scale;
628                *var += diff * diff;
629            });
630        })
631    }
632
633    fn array_update_var_inv_std_draw(
634        &mut self,
635        inv_std: &mut Self::Vector,
636        std: &mut Self::Vector,
637        draw_var: &Self::Vector,
638        scale: f64,
639        fill_invalid: Option<f64>,
640        clamp: (f64, f64),
641    ) {
642        self.arch.dispatch(|| {
643            izip!(
644                std.try_as_col_major_mut()
645                    .unwrap()
646                    .as_slice_mut()
647                    .iter_mut(),
648                inv_std
649                    .try_as_col_major_mut()
650                    .unwrap()
651                    .as_slice_mut()
652                    .iter_mut(),
653                draw_var.try_as_col_major().unwrap().as_slice().iter(),
654            )
655            .for_each(|(std_out, inv_std_out, &draw_var)| {
656                let draw_var = draw_var * scale;
657                if (!draw_var.is_finite()) | (draw_var == 0f64) {
658                    if let Some(fill_val) = fill_invalid {
659                        *std_out = fill_val.sqrt();
660                        *inv_std_out = fill_val.recip().sqrt();
661                    }
662                } else {
663                    let val = draw_var.clamp(clamp.0, clamp.1);
664                    *std_out = val.sqrt();
665                    *inv_std_out = val.recip().sqrt();
666                }
667            });
668        });
669    }
670
671    fn array_update_var_inv_std_draw_grad(
672        &mut self,
673        inv_std: &mut Self::Vector,
674        std: &mut Self::Vector,
675        draw_var: &Self::Vector,
676        grad_var: &Self::Vector,
677        fill_invalid: Option<f64>,
678        clamp: (f64, f64),
679    ) {
680        self.arch.dispatch(|| {
681            izip!(
682                std.try_as_col_major_mut()
683                    .unwrap()
684                    .as_slice_mut()
685                    .iter_mut(),
686                inv_std
687                    .try_as_col_major_mut()
688                    .unwrap()
689                    .as_slice_mut()
690                    .iter_mut(),
691                draw_var.try_as_col_major().unwrap().as_slice().iter(),
692                grad_var.try_as_col_major().unwrap().as_slice().iter(),
693            )
694            .for_each(|(std_out, inv_std_out, &draw_var, &grad_var)| {
695                let val = (draw_var / grad_var).sqrt();
696                if (!val.is_finite()) | (val == 0f64) {
697                    if let Some(fill_val) = fill_invalid {
698                        *std_out = fill_val.sqrt();
699                        *inv_std_out = fill_val.recip().sqrt();
700                    }
701                } else {
702                    let val = val.clamp(clamp.0, clamp.1);
703                    *std_out = val.sqrt();
704                    *inv_std_out = val.recip().sqrt();
705                }
706            });
707        });
708    }
709
710    fn array_update_var_inv_std_grad(
711        &mut self,
712        inv_std: &mut Self::Vector,
713        std: &mut Self::Vector,
714        gradient: &Self::Vector,
715        fill_invalid: f64,
716        clamp: (f64, f64),
717    ) {
718        self.arch.dispatch(|| {
719            izip!(
720                std.try_as_col_major_mut()
721                    .unwrap()
722                    .as_slice_mut()
723                    .iter_mut(),
724                inv_std
725                    .try_as_col_major_mut()
726                    .unwrap()
727                    .as_slice_mut()
728                    .iter_mut(),
729                gradient.try_as_col_major().unwrap().as_slice().iter(),
730            )
731            .for_each(|(std_out, inv_std_out, &grad_var)| {
732                let val = grad_var.abs().clamp(clamp.0, clamp.1).recip();
733                let val = if val.is_finite() { val } else { fill_invalid };
734                *std_out = val.sqrt();
735                *inv_std_out = val.recip().sqrt();
736            });
737        });
738    }
739
740    fn eigs_as_array(&mut self, source: &Self::EigValues) -> Box<[f64]> {
741        source
742            .try_as_col_major()
743            .unwrap()
744            .as_slice()
745            .to_vec()
746            .into()
747    }
748
749    fn inv_transform_normalize(
750        &mut self,
751        params: &Self::FlowParameters,
752        untransformed_position: &Self::Vector,
753        untransofrmed_gradient: &Self::Vector,
754        transformed_position: &mut Self::Vector,
755        transformed_gradient: &mut Self::Vector,
756    ) -> Result<f64, Self::LogpErr> {
757        self.logp_func.inv_transform_normalize(
758            params,
759            untransformed_position
760                .try_as_col_major()
761                .unwrap()
762                .as_slice(),
763            untransofrmed_gradient
764                .try_as_col_major()
765                .unwrap()
766                .as_slice(),
767            transformed_position
768                .try_as_col_major_mut()
769                .unwrap()
770                .as_slice_mut(),
771            transformed_gradient
772                .try_as_col_major_mut()
773                .unwrap()
774                .as_slice_mut(),
775        )
776    }
777
778    fn init_from_untransformed_position(
779        &mut self,
780        params: &Self::FlowParameters,
781        untransformed_position: &Self::Vector,
782        untransformed_gradient: &mut Self::Vector,
783        transformed_position: &mut Self::Vector,
784        transformed_gradient: &mut Self::Vector,
785    ) -> Result<(f64, f64), Self::LogpErr> {
786        self.logp_func.init_from_untransformed_position(
787            params,
788            untransformed_position
789                .try_as_col_major()
790                .unwrap()
791                .as_slice(),
792            untransformed_gradient
793                .try_as_col_major_mut()
794                .unwrap()
795                .as_slice_mut(),
796            transformed_position
797                .try_as_col_major_mut()
798                .unwrap()
799                .as_slice_mut(),
800            transformed_gradient
801                .try_as_col_major_mut()
802                .unwrap()
803                .as_slice_mut(),
804        )
805    }
806
807    fn init_from_transformed_position(
808        &mut self,
809        params: &Self::FlowParameters,
810        untransformed_position: &mut Self::Vector,
811        untransformed_gradient: &mut Self::Vector,
812        transformed_position: &Self::Vector,
813        transformed_gradient: &mut Self::Vector,
814    ) -> Result<(f64, f64), Self::LogpErr> {
815        self.logp_func.init_from_transformed_position(
816            params,
817            untransformed_position
818                .try_as_col_major_mut()
819                .unwrap()
820                .as_slice_mut(),
821            untransformed_gradient
822                .try_as_col_major_mut()
823                .unwrap()
824                .as_slice_mut(),
825            transformed_position.try_as_col_major().unwrap().as_slice(),
826            transformed_gradient
827                .try_as_col_major_mut()
828                .unwrap()
829                .as_slice_mut(),
830        )
831    }
832
833    fn update_transformation<'a, R: rand::Rng + ?Sized>(
834        &'a mut self,
835        rng: &mut R,
836        untransformed_positions: impl ExactSizeIterator<Item = &'a Self::Vector>,
837        untransformed_gradients: impl ExactSizeIterator<Item = &'a Self::Vector>,
838        untransformed_logp: impl ExactSizeIterator<Item = &'a f64>,
839        params: &'a mut Self::FlowParameters,
840    ) -> Result<(), Self::LogpErr> {
841        self.logp_func.update_transformation(
842            rng,
843            untransformed_positions.map(|x| x.try_as_col_major().unwrap().as_slice()),
844            untransformed_gradients.map(|x| x.try_as_col_major().unwrap().as_slice()),
845            untransformed_logp,
846            params,
847        )
848    }
849
850    fn init_transformation<R: rand::Rng + ?Sized>(
851        &mut self,
852        rng: &mut R,
853        untransformed_position: &Self::Vector,
854        untransfogmed_gradient: &Self::Vector,
855        chain: u64,
856    ) -> Result<Self::FlowParameters, Self::LogpErr> {
857        self.logp_func.init_transformation(
858            rng,
859            untransformed_position
860                .try_as_col_major()
861                .unwrap()
862                .as_slice(),
863            untransfogmed_gradient
864                .try_as_col_major()
865                .unwrap()
866                .as_slice(),
867            chain,
868        )
869    }
870
871    fn new_transformation<R: rand::Rng + ?Sized>(
872        &mut self,
873        rng: &mut R,
874        dim: usize,
875        chain: u64,
876    ) -> Result<Self::FlowParameters, Self::LogpErr> {
877        self.logp_func.new_transformation(rng, dim, chain)
878    }
879
880    fn transformation_id(&self, params: &Self::FlowParameters) -> Result<i64, Self::LogpErr> {
881        self.logp_func.transformation_id(params)
882    }
883}
884
885pub trait CpuLogpFunc: HasDims {
886    type LogpError: Debug + Send + Sync + Error + LogpError + 'static;
887    type FlowParameters;
888    type ExpandedVector: Storable<Self>;
889
890    fn dim(&self) -> usize;
891    fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, Self::LogpError>;
892    fn expand_vector<R>(
893        &mut self,
894        rng: &mut R,
895        array: &[f64],
896    ) -> Result<Self::ExpandedVector, CpuMathError>
897    where
898        R: rand::Rng + ?Sized;
899
900    fn vector_coord(&self) -> Option<Value> {
901        None
902    }
903
904    fn inv_transform_normalize(
905        &mut self,
906        _params: &Self::FlowParameters,
907        _untransformed_position: &[f64],
908        _untransformed_gradient: &[f64],
909        _transformed_position: &mut [f64],
910        _transformed_gradient: &mut [f64],
911    ) -> Result<f64, Self::LogpError> {
912        unimplemented!()
913    }
914
915    fn init_from_untransformed_position(
916        &mut self,
917        _params: &Self::FlowParameters,
918        _untransformed_position: &[f64],
919        _untransformed_gradient: &mut [f64],
920        _transformed_position: &mut [f64],
921        _transformed_gradient: &mut [f64],
922    ) -> Result<(f64, f64), Self::LogpError> {
923        unimplemented!()
924    }
925
926    fn init_from_transformed_position(
927        &mut self,
928        _params: &Self::FlowParameters,
929        _untransformed_position: &mut [f64],
930        _untransformed_gradient: &mut [f64],
931        _transformed_position: &[f64],
932        _transformed_gradient: &mut [f64],
933    ) -> Result<(f64, f64), Self::LogpError> {
934        unimplemented!()
935    }
936
937    fn update_transformation<'a, R: rand::Rng + ?Sized>(
938        &'a mut self,
939        _rng: &mut R,
940        _untransformed_positions: impl ExactSizeIterator<Item = &'a [f64]>,
941        _untransformed_gradients: impl ExactSizeIterator<Item = &'a [f64]>,
942        _untransformed_logp: impl ExactSizeIterator<Item = &'a f64>,
943        _params: &'a mut Self::FlowParameters,
944    ) -> Result<(), Self::LogpError> {
945        unimplemented!()
946    }
947
948    fn init_transformation<R: rand::Rng + ?Sized>(
949        &mut self,
950        _rng: &mut R,
951        _untransformed_position: &[f64],
952        _untransformed_gradient: &[f64],
953        _chain: u64,
954    ) -> Result<Self::FlowParameters, Self::LogpError> {
955        unimplemented!()
956    }
957
958    fn new_transformation<R: rand::Rng + ?Sized>(
959        &mut self,
960        _rng: &mut R,
961        _dim: usize,
962        _chain: u64,
963    ) -> Result<Self::FlowParameters, Self::LogpError> {
964        unimplemented!()
965    }
966
967    fn transformation_id(&self, _params: &Self::FlowParameters) -> Result<i64, Self::LogpError> {
968        unimplemented!()
969    }
970}
971
972impl<M: CpuLogpFunc + Clone> Clone for CpuMath<M> {
973    fn clone(&self) -> Self {
974        Self {
975            logp_func: self.logp_func.clone(),
976            arch: self.arch,
977            lowrank_scratch: Col::zeros(self.lowrank_scratch.nrows()),
978        }
979    }
980}