1use std::{collections::HashMap, error::Error, fmt::Debug, mem::replace};
4
5use faer::linalg::matmul::matmul;
6
7use faer::{Accum, Col, Mat, Par};
8use itertools::{Itertools, izip};
9use nuts_storable::{HasDims, Storable, Value};
10use rand::RngExt;
11use thiserror::Error;
12
13use crate::math::util::multiply_inplace;
14
15use super::{
16 math::{LogpError, Math},
17 util::{
18 axpy, axpy_out, multiply, scalar_prods2, scalar_prods3, std_norm_flow, std_norm_grad_flow,
19 std_norm_grad_flow_inplace, vector_dot,
20 },
21};
22
23#[derive(Debug)]
24pub struct CpuMath<F: CpuLogpFunc> {
25 logp_func: F,
26 arch: pulp::Arch,
27 lowrank_scratch: Col<f64>,
30}
31
32impl<F: CpuLogpFunc> CpuMath<F> {
33 pub fn new(logp_func: F) -> Self {
34 let arch = pulp::Arch::new();
35 Self {
36 logp_func,
37 arch,
38 lowrank_scratch: Col::zeros(0),
39 }
40 }
41}
42
43#[non_exhaustive]
44#[derive(Error, Debug)]
45pub enum CpuMathError {
46 #[error("Error during array operation")]
47 ArrayError(),
48 #[error("Error during point expansion: {0}")]
49 ExpandError(String),
50}
51
52impl<F: CpuLogpFunc> HasDims for CpuMath<F> {
53 fn dim_sizes(&self) -> HashMap<String, u64> {
54 self.logp_func.dim_sizes()
55 }
56
57 fn coords(&self) -> HashMap<String, nuts_storable::Value> {
58 self.logp_func.coords()
59 }
60}
61
62pub struct ExpandedVectorWrapper<F: CpuLogpFunc>(F::ExpandedVector);
63
64impl<F: CpuLogpFunc> Storable<CpuMath<F>> for ExpandedVectorWrapper<F> {
65 fn names(parent: &CpuMath<F>) -> Vec<&str> {
66 F::ExpandedVector::names(&parent.logp_func)
67 }
68
69 fn item_type(parent: &CpuMath<F>, item: &str) -> nuts_storable::ItemType {
70 F::ExpandedVector::item_type(&parent.logp_func, item)
71 }
72
73 fn dims<'a>(parent: &'a CpuMath<F>, item: &str) -> Vec<&'a str> {
74 F::ExpandedVector::dims(&parent.logp_func, item)
75 }
76
77 fn get_all<'a>(
78 &'a mut self,
79 parent: &'a CpuMath<F>,
80 ) -> Vec<(&'a str, Option<nuts_storable::Value>)> {
81 self.0.get_all(&parent.logp_func)
82 }
83}
84
85impl<F: CpuLogpFunc> Math for CpuMath<F> {
86 type Vector = Col<f64>;
87 type EigVectors = Mat<f64>;
88 type EigValues = Col<f64>;
89 type LogpErr = F::LogpError;
90 type Err = CpuMathError;
91 type FlowParameters = F::FlowParameters;
92 type ExpandedVector = ExpandedVectorWrapper<F>;
93
94 fn new_array(&mut self) -> Self::Vector {
95 Col::zeros(self.dim())
96 }
97
98 fn new_eig_vectors<'a>(
99 &'a mut self,
100 vals: impl ExactSizeIterator<Item = &'a [f64]>,
101 ) -> Self::EigVectors {
102 let ndim = self.dim();
103 let nvecs = vals.len();
104
105 let mut vectors: Mat<f64> = Mat::zeros(ndim, nvecs);
106 vectors.col_iter_mut().zip_eq(vals).for_each(|(col, vals)| {
107 col.try_as_col_major_mut()
108 .expect("Array is not contiguous")
109 .as_slice_mut()
110 .copy_from_slice(vals)
111 });
112
113 vectors
114 }
115
116 fn new_eig_values(&mut self, vals: &[f64]) -> Self::EigValues {
117 let mut values: Col<f64> = Col::zeros(vals.len());
118 values
119 .try_as_col_major_mut()
120 .expect("Array is not contiguous")
121 .as_slice_mut()
122 .copy_from_slice(vals);
123 values
124 }
125
126 fn logp_array(
127 &mut self,
128 position: &Self::Vector,
129 gradient: &mut Self::Vector,
130 ) -> Result<f64, Self::LogpErr> {
131 self.logp_func.logp(
132 position
133 .try_as_col_major()
134 .expect("Array is not contiguous")
135 .as_slice(),
136 gradient
137 .try_as_col_major_mut()
138 .expect("Array is not contiguous")
139 .as_slice_mut(),
140 )
141 }
142
143 fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, Self::LogpErr> {
144 self.logp_func.logp(position, gradient)
145 }
146
147 fn dim(&self) -> usize {
148 self.logp_func.dim()
149 }
150
151 fn expand_vector<R: rand::Rng + ?Sized>(
152 &mut self,
153 rng: &mut R,
154 array: &Self::Vector,
155 ) -> Result<Self::ExpandedVector, Self::Err> {
156 Ok(ExpandedVectorWrapper(
157 self.logp_func.expand_vector(
158 rng,
159 array
160 .try_as_col_major()
161 .ok_or_else(|| {
162 CpuMathError::ExpandError("Internal vector was not col major".into())
163 })?
164 .as_slice(),
165 )?,
166 ))
167 }
168
169 fn vector_coord(&self) -> Option<Value> {
170 self.logp_func.vector_coord()
171 }
172
173 fn init_position<R: rand::Rng + ?Sized>(
174 &mut self,
175 rng: &mut R,
176 position: &mut Self::Vector,
177 gradient: &mut Self::Vector,
178 ) -> Result<f64, Self::LogpErr> {
179 let pos = position
180 .try_as_col_major_mut()
181 .expect("Array is not contiguous")
182 .as_slice_mut();
183
184 pos.iter_mut().for_each(|x| {
185 let val: f64 = rng.random();
186 *x = val * 2f64 - 1f64
187 });
188
189 self.logp_func.logp(
190 position
191 .try_as_col_major()
192 .expect("Array is not contiguous")
193 .as_slice(),
194 gradient
195 .try_as_col_major_mut()
196 .expect("Array is not contiguous")
197 .as_slice_mut(),
198 )
199 }
200
201 fn scalar_prods3(
202 &mut self,
203 positive1: &Self::Vector,
204 negative1: &Self::Vector,
205 positive2: &Self::Vector,
206 x: &Self::Vector,
207 y: &Self::Vector,
208 ) -> (f64, f64) {
209 scalar_prods3(
210 self.arch,
211 positive1.try_as_col_major().unwrap().as_slice(),
212 negative1.try_as_col_major().unwrap().as_slice(),
213 positive2.try_as_col_major().unwrap().as_slice(),
214 x.try_as_col_major().unwrap().as_slice(),
215 y.try_as_col_major().unwrap().as_slice(),
216 )
217 }
218
219 fn scalar_prods2(
220 &mut self,
221 positive1: &Self::Vector,
222 positive2: &Self::Vector,
223 x: &Self::Vector,
224 y: &Self::Vector,
225 ) -> (f64, f64) {
226 scalar_prods2(
227 self.arch,
228 positive1.try_as_col_major().unwrap().as_slice(),
229 positive2.try_as_col_major().unwrap().as_slice(),
230 x.try_as_col_major().unwrap().as_slice(),
231 y.try_as_col_major().unwrap().as_slice(),
232 )
233 }
234
235 fn sq_norm_sum(&mut self, x: &Self::Vector, y: &Self::Vector) -> f64 {
236 x.try_as_col_major()
237 .unwrap()
238 .as_slice()
239 .iter()
240 .zip(y.try_as_col_major().unwrap().as_slice())
241 .map(|(&x, &y)| (x + y) * (x + y))
242 .sum()
243 }
244
245 fn read_from_slice(&mut self, dest: &mut Self::Vector, source: &[f64]) {
246 dest.try_as_col_major_mut()
247 .unwrap()
248 .as_slice_mut()
249 .copy_from_slice(source);
250 }
251
252 fn write_to_slice(&mut self, source: &Self::Vector, dest: &mut [f64]) {
253 dest.copy_from_slice(source.try_as_col_major().unwrap().as_slice())
254 }
255
256 fn copy_into(&mut self, array: &Self::Vector, dest: &mut Self::Vector) {
257 dest.clone_from(array)
258 }
259
260 fn axpy_out(&mut self, x: &Self::Vector, y: &Self::Vector, a: f64, out: &mut Self::Vector) {
261 axpy_out(
262 self.arch,
263 x.try_as_col_major().unwrap().as_slice(),
264 y.try_as_col_major().unwrap().as_slice(),
265 a,
266 out.try_as_col_major_mut().unwrap().as_slice_mut(),
267 );
268 }
269
270 fn axpy(&mut self, x: &Self::Vector, y: &mut Self::Vector, a: f64) {
271 axpy(
272 self.arch,
273 x.try_as_col_major().unwrap().as_slice(),
274 y.try_as_col_major_mut().unwrap().as_slice_mut(),
275 a,
276 );
277 }
278
279 fn fill_array(&mut self, array: &mut Self::Vector, val: f64) {
280 faer::zip!(array).for_each(|faer::unzip!(pos)| *pos = val);
281 }
282
283 fn array_all_finite(&mut self, array: &Self::Vector) -> bool {
284 let mut ok = true;
285 faer::zip!(array).for_each(|faer::unzip!(val)| ok &= val.is_finite());
286 ok
287 }
288
289 fn array_all_finite_and_nonzero(&mut self, array: &Self::Vector) -> bool {
290 self.arch.dispatch(|| {
291 array
292 .try_as_col_major()
293 .unwrap()
294 .as_slice()
295 .iter()
296 .all(|&x| x.is_finite() & (x != 0f64))
297 })
298 }
299
300 fn array_sum_ln(&mut self, array: &Self::Vector) -> f64 {
301 let mut sum = 0f64;
302 faer::zip!(array).for_each(|faer::unzip!(val)| sum += val.ln());
303 sum
304 }
305
306 fn array_mult(
307 &mut self,
308 array1: &Self::Vector,
309 array2: &Self::Vector,
310 dest: &mut Self::Vector,
311 ) {
312 multiply(
313 self.arch,
314 array1.try_as_col_major().unwrap().as_slice(),
315 array2.try_as_col_major().unwrap().as_slice(),
316 dest.try_as_col_major_mut().unwrap().as_slice_mut(),
317 )
318 }
319
320 fn array_mult_inplace(&mut self, array1: &mut Self::Vector, array2: &Self::Vector) {
321 multiply_inplace(
322 self.arch,
323 array1.try_as_col_major_mut().unwrap().as_slice_mut(),
324 array2.try_as_col_major().unwrap().as_slice(),
325 )
326 }
327
328 fn array_recip(&mut self, array: &Self::Vector, dest: &mut Self::Vector) {
329 faer::zip!(array, dest).for_each(|faer::unzip!(val, dest)| *dest = val.recip())
330 }
331
332 fn apply_lowrank_transform(
333 &mut self,
334 vecs: &Self::EigVectors,
335 vals: &Self::EigValues,
336 rhs: &Self::Vector,
337 dest: &mut Self::Vector,
338 ) {
339 if vecs.ncols() == 0 {
340 self.copy_into(rhs, dest);
341 return;
342 }
343 let rank = vecs.ncols();
347
348 if self.lowrank_scratch.nrows() != rank {
350 self.lowrank_scratch.resize_with(rank, |_| 0.0);
351 }
352
353 matmul(
355 self.lowrank_scratch.as_mut(),
356 Accum::Replace,
357 vecs.transpose(),
358 rhs.as_ref(),
359 1.0,
360 Par::Seq,
361 );
362
363 self.lowrank_scratch
365 .iter_mut()
366 .zip(vals.iter())
367 .for_each(|(s, &v)| *s *= v - 1.0);
368
369 dest.copy_from(rhs);
371 matmul(
372 dest.as_mut(),
373 Accum::Add,
374 vecs.as_ref(),
375 self.lowrank_scratch.as_ref(),
376 1.0,
377 Par::Seq,
378 );
379 }
380
381 fn apply_lowrank_transform_inplace(
382 &mut self,
383 vecs: &Self::EigVectors,
384 vals: &Self::EigValues,
385 rhs_and_dest: &mut Self::Vector,
386 ) {
387 if vecs.ncols() == 0 {
388 return;
389 }
390 let rank = vecs.ncols();
394
395 if self.lowrank_scratch.nrows() != rank {
397 self.lowrank_scratch.resize_with(rank, |_| 0.0);
398 }
399
400 matmul(
402 self.lowrank_scratch.as_mut(),
403 Accum::Replace,
404 vecs.transpose(),
405 rhs_and_dest.as_ref(),
406 1.0,
407 Par::Seq,
408 );
409
410 self.lowrank_scratch
412 .iter_mut()
413 .zip(vals.iter())
414 .for_each(|(s, &v)| *s *= v - 1.0);
415
416 matmul(
418 rhs_and_dest.as_mut(),
419 Accum::Add,
420 vecs.as_ref(),
421 self.lowrank_scratch.as_ref(),
422 1.0,
423 Par::Seq,
424 );
425 }
426
427 fn array_mult_eigs(
428 &mut self,
429 stds: &Self::Vector,
430 rhs: &Self::Vector,
431 dest: &mut Self::Vector,
432 vecs: &Self::EigVectors,
433 vals: &Self::EigValues,
434 ) {
435 let rhs = stds.as_diagonal() * rhs;
436 let trafo = vecs.transpose() * (&rhs);
437 let inner_prod = vecs * (vals.as_diagonal() * (&trafo) - (&trafo)) + rhs;
438 let scaled = stds.as_diagonal() * inner_prod;
439
440 let _ = replace(dest, scaled);
441 }
442
443 fn std_norm_flow(
447 &mut self,
448 pos: &Self::Vector,
449 pos_out: &mut Self::Vector,
450 vel: &mut Self::Vector,
451 epsilon: f64,
452 ) {
453 std_norm_flow(
454 self.arch,
455 pos.try_as_col_major().unwrap().as_slice(),
456 pos_out.try_as_col_major_mut().unwrap().as_slice_mut(),
457 vel.try_as_col_major_mut().unwrap().as_slice_mut(),
458 epsilon,
459 );
460 }
461
462 fn std_norm_grad_flow(
463 &mut self,
464 pos: &Self::Vector,
465 grad: &Self::Vector,
466 vel: &Self::Vector,
467 vel_out: &mut Self::Vector,
468 epsilon: f64,
469 ) {
470 std_norm_grad_flow(
471 self.arch,
472 pos.try_as_col_major().unwrap().as_slice(),
473 grad.try_as_col_major().unwrap().as_slice(),
474 vel.try_as_col_major().unwrap().as_slice(),
475 vel_out.try_as_col_major_mut().unwrap().as_slice_mut(),
476 epsilon,
477 );
478 }
479
480 fn std_norm_grad_flow_inplace(
481 &mut self,
482 pos: &Self::Vector,
483 grad: &Self::Vector,
484 vel: &mut Self::Vector,
485 epsilon: f64,
486 ) {
487 std_norm_grad_flow_inplace(
488 self.arch,
489 pos.try_as_col_major().unwrap().as_slice(),
490 grad.try_as_col_major().unwrap().as_slice(),
491 vel.try_as_col_major_mut().unwrap().as_slice_mut(),
492 epsilon,
493 );
494 }
495
496 fn array_normalize(&mut self, v: &mut Self::Vector) {
497 let v = v.try_as_col_major_mut().unwrap().as_slice_mut();
498 let norm: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt();
499 let inv = 1.0 / norm;
500 for x in v.iter_mut() {
501 *x *= inv;
502 }
503 }
504
505 fn esh_momentum_update(
506 &mut self,
507 gradient: &Self::Vector,
508 momentum: &mut Self::Vector,
509 step_size: f64,
510 ) -> f64 {
511 let gradient = gradient.try_as_col_major().unwrap().as_slice();
512 let momentum = momentum.try_as_col_major_mut().unwrap().as_slice_mut();
513 let n = gradient.len();
514 assert!(n >= 2, "ESH dynamics requires at least 2 dimensions");
515
516 let grad_norm: f64 = gradient.iter().map(|g| g * g).sum::<f64>().sqrt();
518
519 let inv_grad_norm = 1.0 / grad_norm;
520
521 let momentum_proj: f64 = momentum
523 .iter()
524 .zip(gradient.iter())
525 .map(|(p, g)| p * g * inv_grad_norm)
526 .sum();
527
528 let dims_m1 = (n - 1) as f64;
529 let delta = step_size * grad_norm / dims_m1;
530 let zeta = (-delta).exp();
531
532 let coeff_g = (1.0 - zeta) * (1.0 + zeta + momentum_proj * (1.0 - zeta));
534 let coeff_p = 2.0 * zeta;
535
536 for (p, g) in momentum.iter_mut().zip(gradient.iter()) {
537 *p = coeff_g * (g * inv_grad_norm) + coeff_p * *p;
538 }
539
540 let raw_norm: f64 = momentum.iter().map(|p| p * p).sum::<f64>().sqrt();
542 let inv = 1.0 / raw_norm;
543 for p in momentum.iter_mut() {
544 *p *= inv;
545 }
546
547 let arg = momentum_proj + (1.0 - momentum_proj) * zeta * zeta;
548 let kinetic_energy_change = (delta - std::f64::consts::LN_2 + arg.ln_1p()) * dims_m1;
549
550 kinetic_energy_change
551 }
552
553 fn array_vector_dot(&mut self, array1: &Self::Vector, array2: &Self::Vector) -> f64 {
554 vector_dot(
555 self.arch,
556 array1.try_as_col_major().unwrap().as_slice(),
557 array2.try_as_col_major().unwrap().as_slice(),
558 )
559 }
560
561 fn array_gaussian<R: rand::Rng + ?Sized>(
562 &mut self,
563 rng: &mut R,
564 dest: &mut Self::Vector,
565 stds: &Self::Vector,
566 ) {
567 let dist = rand_distr::StandardNormal;
568 dest.try_as_col_major_mut()
569 .unwrap()
570 .as_slice_mut()
571 .iter_mut()
572 .zip(stds.try_as_col_major().unwrap().as_slice().iter())
573 .for_each(|(p, &s)| {
574 let norm: f64 = rng.sample(dist);
575 *p = s * norm;
576 });
577 }
578
579 fn array_gaussian_eigs<R: rand::Rng + ?Sized>(
580 &mut self,
581 rng: &mut R,
582 dest: &mut Self::Vector,
583 scale: &Self::Vector,
584 vals: &Self::EigValues,
585 vecs: &Self::EigVectors,
586 ) {
587 let mut draw: Col<f64> = Col::zeros(self.dim());
588 let dist = rand_distr::StandardNormal;
589 draw.try_as_col_major_mut()
590 .unwrap()
591 .as_slice_mut()
592 .iter_mut()
593 .for_each(|p| {
594 *p = rng.sample(dist);
595 });
596
597 let trafo = vecs.transpose() * (&draw);
598 let inner_prod = vecs * (vals.as_diagonal() * (&trafo) - (&trafo)) + draw;
599
600 let scaled = scale.as_diagonal() * inner_prod;
601
602 let _ = replace(dest, scaled);
603 }
604
605 fn array_update_variance(
606 &mut self,
607 mean: &mut Self::Vector,
608 variance: &mut Self::Vector,
609 value: &Self::Vector,
610 diff_scale: f64, ) {
612 self.arch.dispatch(|| {
613 izip!(
614 mean.try_as_col_major_mut()
615 .unwrap()
616 .as_slice_mut()
617 .iter_mut(),
618 variance
619 .try_as_col_major_mut()
620 .unwrap()
621 .as_slice_mut()
622 .iter_mut(),
623 value.try_as_col_major().unwrap().as_slice()
624 )
625 .for_each(|(mean, var, x)| {
626 let diff = x - *mean;
627 *mean += diff * diff_scale;
628 *var += diff * diff;
629 });
630 })
631 }
632
633 fn array_update_var_inv_std_draw(
634 &mut self,
635 inv_std: &mut Self::Vector,
636 std: &mut Self::Vector,
637 draw_var: &Self::Vector,
638 scale: f64,
639 fill_invalid: Option<f64>,
640 clamp: (f64, f64),
641 ) {
642 self.arch.dispatch(|| {
643 izip!(
644 std.try_as_col_major_mut()
645 .unwrap()
646 .as_slice_mut()
647 .iter_mut(),
648 inv_std
649 .try_as_col_major_mut()
650 .unwrap()
651 .as_slice_mut()
652 .iter_mut(),
653 draw_var.try_as_col_major().unwrap().as_slice().iter(),
654 )
655 .for_each(|(std_out, inv_std_out, &draw_var)| {
656 let draw_var = draw_var * scale;
657 if (!draw_var.is_finite()) | (draw_var == 0f64) {
658 if let Some(fill_val) = fill_invalid {
659 *std_out = fill_val.sqrt();
660 *inv_std_out = fill_val.recip().sqrt();
661 }
662 } else {
663 let val = draw_var.clamp(clamp.0, clamp.1);
664 *std_out = val.sqrt();
665 *inv_std_out = val.recip().sqrt();
666 }
667 });
668 });
669 }
670
671 fn array_update_var_inv_std_draw_grad(
672 &mut self,
673 inv_std: &mut Self::Vector,
674 std: &mut Self::Vector,
675 draw_var: &Self::Vector,
676 grad_var: &Self::Vector,
677 fill_invalid: Option<f64>,
678 clamp: (f64, f64),
679 ) {
680 self.arch.dispatch(|| {
681 izip!(
682 std.try_as_col_major_mut()
683 .unwrap()
684 .as_slice_mut()
685 .iter_mut(),
686 inv_std
687 .try_as_col_major_mut()
688 .unwrap()
689 .as_slice_mut()
690 .iter_mut(),
691 draw_var.try_as_col_major().unwrap().as_slice().iter(),
692 grad_var.try_as_col_major().unwrap().as_slice().iter(),
693 )
694 .for_each(|(std_out, inv_std_out, &draw_var, &grad_var)| {
695 let val = (draw_var / grad_var).sqrt();
696 if (!val.is_finite()) | (val == 0f64) {
697 if let Some(fill_val) = fill_invalid {
698 *std_out = fill_val.sqrt();
699 *inv_std_out = fill_val.recip().sqrt();
700 }
701 } else {
702 let val = val.clamp(clamp.0, clamp.1);
703 *std_out = val.sqrt();
704 *inv_std_out = val.recip().sqrt();
705 }
706 });
707 });
708 }
709
710 fn array_update_var_inv_std_grad(
711 &mut self,
712 inv_std: &mut Self::Vector,
713 std: &mut Self::Vector,
714 gradient: &Self::Vector,
715 fill_invalid: f64,
716 clamp: (f64, f64),
717 ) {
718 self.arch.dispatch(|| {
719 izip!(
720 std.try_as_col_major_mut()
721 .unwrap()
722 .as_slice_mut()
723 .iter_mut(),
724 inv_std
725 .try_as_col_major_mut()
726 .unwrap()
727 .as_slice_mut()
728 .iter_mut(),
729 gradient.try_as_col_major().unwrap().as_slice().iter(),
730 )
731 .for_each(|(std_out, inv_std_out, &grad_var)| {
732 let val = grad_var.abs().clamp(clamp.0, clamp.1).recip();
733 let val = if val.is_finite() { val } else { fill_invalid };
734 *std_out = val.sqrt();
735 *inv_std_out = val.recip().sqrt();
736 });
737 });
738 }
739
740 fn eigs_as_array(&mut self, source: &Self::EigValues) -> Box<[f64]> {
741 source
742 .try_as_col_major()
743 .unwrap()
744 .as_slice()
745 .to_vec()
746 .into()
747 }
748
749 fn inv_transform_normalize(
750 &mut self,
751 params: &Self::FlowParameters,
752 untransformed_position: &Self::Vector,
753 untransofrmed_gradient: &Self::Vector,
754 transformed_position: &mut Self::Vector,
755 transformed_gradient: &mut Self::Vector,
756 ) -> Result<f64, Self::LogpErr> {
757 self.logp_func.inv_transform_normalize(
758 params,
759 untransformed_position
760 .try_as_col_major()
761 .unwrap()
762 .as_slice(),
763 untransofrmed_gradient
764 .try_as_col_major()
765 .unwrap()
766 .as_slice(),
767 transformed_position
768 .try_as_col_major_mut()
769 .unwrap()
770 .as_slice_mut(),
771 transformed_gradient
772 .try_as_col_major_mut()
773 .unwrap()
774 .as_slice_mut(),
775 )
776 }
777
778 fn init_from_untransformed_position(
779 &mut self,
780 params: &Self::FlowParameters,
781 untransformed_position: &Self::Vector,
782 untransformed_gradient: &mut Self::Vector,
783 transformed_position: &mut Self::Vector,
784 transformed_gradient: &mut Self::Vector,
785 ) -> Result<(f64, f64), Self::LogpErr> {
786 self.logp_func.init_from_untransformed_position(
787 params,
788 untransformed_position
789 .try_as_col_major()
790 .unwrap()
791 .as_slice(),
792 untransformed_gradient
793 .try_as_col_major_mut()
794 .unwrap()
795 .as_slice_mut(),
796 transformed_position
797 .try_as_col_major_mut()
798 .unwrap()
799 .as_slice_mut(),
800 transformed_gradient
801 .try_as_col_major_mut()
802 .unwrap()
803 .as_slice_mut(),
804 )
805 }
806
807 fn init_from_transformed_position(
808 &mut self,
809 params: &Self::FlowParameters,
810 untransformed_position: &mut Self::Vector,
811 untransformed_gradient: &mut Self::Vector,
812 transformed_position: &Self::Vector,
813 transformed_gradient: &mut Self::Vector,
814 ) -> Result<(f64, f64), Self::LogpErr> {
815 self.logp_func.init_from_transformed_position(
816 params,
817 untransformed_position
818 .try_as_col_major_mut()
819 .unwrap()
820 .as_slice_mut(),
821 untransformed_gradient
822 .try_as_col_major_mut()
823 .unwrap()
824 .as_slice_mut(),
825 transformed_position.try_as_col_major().unwrap().as_slice(),
826 transformed_gradient
827 .try_as_col_major_mut()
828 .unwrap()
829 .as_slice_mut(),
830 )
831 }
832
833 fn update_transformation<'a, R: rand::Rng + ?Sized>(
834 &'a mut self,
835 rng: &mut R,
836 untransformed_positions: impl ExactSizeIterator<Item = &'a Self::Vector>,
837 untransformed_gradients: impl ExactSizeIterator<Item = &'a Self::Vector>,
838 untransformed_logp: impl ExactSizeIterator<Item = &'a f64>,
839 params: &'a mut Self::FlowParameters,
840 ) -> Result<(), Self::LogpErr> {
841 self.logp_func.update_transformation(
842 rng,
843 untransformed_positions.map(|x| x.try_as_col_major().unwrap().as_slice()),
844 untransformed_gradients.map(|x| x.try_as_col_major().unwrap().as_slice()),
845 untransformed_logp,
846 params,
847 )
848 }
849
850 fn init_transformation<R: rand::Rng + ?Sized>(
851 &mut self,
852 rng: &mut R,
853 untransformed_position: &Self::Vector,
854 untransfogmed_gradient: &Self::Vector,
855 chain: u64,
856 ) -> Result<Self::FlowParameters, Self::LogpErr> {
857 self.logp_func.init_transformation(
858 rng,
859 untransformed_position
860 .try_as_col_major()
861 .unwrap()
862 .as_slice(),
863 untransfogmed_gradient
864 .try_as_col_major()
865 .unwrap()
866 .as_slice(),
867 chain,
868 )
869 }
870
871 fn new_transformation<R: rand::Rng + ?Sized>(
872 &mut self,
873 rng: &mut R,
874 dim: usize,
875 chain: u64,
876 ) -> Result<Self::FlowParameters, Self::LogpErr> {
877 self.logp_func.new_transformation(rng, dim, chain)
878 }
879
880 fn transformation_id(&self, params: &Self::FlowParameters) -> Result<i64, Self::LogpErr> {
881 self.logp_func.transformation_id(params)
882 }
883}
884
885pub trait CpuLogpFunc: HasDims {
886 type LogpError: Debug + Send + Sync + Error + LogpError + 'static;
887 type FlowParameters;
888 type ExpandedVector: Storable<Self>;
889
890 fn dim(&self) -> usize;
891 fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, Self::LogpError>;
892 fn expand_vector<R>(
893 &mut self,
894 rng: &mut R,
895 array: &[f64],
896 ) -> Result<Self::ExpandedVector, CpuMathError>
897 where
898 R: rand::Rng + ?Sized;
899
900 fn vector_coord(&self) -> Option<Value> {
901 None
902 }
903
904 fn inv_transform_normalize(
905 &mut self,
906 _params: &Self::FlowParameters,
907 _untransformed_position: &[f64],
908 _untransformed_gradient: &[f64],
909 _transformed_position: &mut [f64],
910 _transformed_gradient: &mut [f64],
911 ) -> Result<f64, Self::LogpError> {
912 unimplemented!()
913 }
914
915 fn init_from_untransformed_position(
916 &mut self,
917 _params: &Self::FlowParameters,
918 _untransformed_position: &[f64],
919 _untransformed_gradient: &mut [f64],
920 _transformed_position: &mut [f64],
921 _transformed_gradient: &mut [f64],
922 ) -> Result<(f64, f64), Self::LogpError> {
923 unimplemented!()
924 }
925
926 fn init_from_transformed_position(
927 &mut self,
928 _params: &Self::FlowParameters,
929 _untransformed_position: &mut [f64],
930 _untransformed_gradient: &mut [f64],
931 _transformed_position: &[f64],
932 _transformed_gradient: &mut [f64],
933 ) -> Result<(f64, f64), Self::LogpError> {
934 unimplemented!()
935 }
936
937 fn update_transformation<'a, R: rand::Rng + ?Sized>(
938 &'a mut self,
939 _rng: &mut R,
940 _untransformed_positions: impl ExactSizeIterator<Item = &'a [f64]>,
941 _untransformed_gradients: impl ExactSizeIterator<Item = &'a [f64]>,
942 _untransformed_logp: impl ExactSizeIterator<Item = &'a f64>,
943 _params: &'a mut Self::FlowParameters,
944 ) -> Result<(), Self::LogpError> {
945 unimplemented!()
946 }
947
948 fn init_transformation<R: rand::Rng + ?Sized>(
949 &mut self,
950 _rng: &mut R,
951 _untransformed_position: &[f64],
952 _untransformed_gradient: &[f64],
953 _chain: u64,
954 ) -> Result<Self::FlowParameters, Self::LogpError> {
955 unimplemented!()
956 }
957
958 fn new_transformation<R: rand::Rng + ?Sized>(
959 &mut self,
960 _rng: &mut R,
961 _dim: usize,
962 _chain: u64,
963 ) -> Result<Self::FlowParameters, Self::LogpError> {
964 unimplemented!()
965 }
966
967 fn transformation_id(&self, _params: &Self::FlowParameters) -> Result<i64, Self::LogpError> {
968 unimplemented!()
969 }
970}
971
972impl<M: CpuLogpFunc + Clone> Clone for CpuMath<M> {
973 fn clone(&self) -> Self {
974 Self {
975 logp_func: self.logp_func.clone(),
976 arch: self.arch,
977 lowrank_scratch: Col::zeros(self.lowrank_scratch.nrows()),
978 }
979 }
980}