1use burn::prelude::{Backend, Tensor};
5use core::fmt;
6use ndarray::{concatenate, prelude::*, stack};
7use ndarray_stats::QuantileExt;
8use num_traits::{Num, ToPrimitive};
9use rayon::prelude::*;
10use rustfft::{num_complex::Complex, FftPlanner};
11use std::{cmp::Ordering, error::Error};
12
/// Smoothing factor of the exponential moving average used for the acceptance-rate estimates:
/// each update is `(1 - ALPHA) * previous + ALPHA * accepted`.
const ALPHA: f32 = 1e-2;
14
15#[derive(Debug, Clone, PartialEq)]
26pub struct ChainTracker {
27 n_params: usize,
28 n: u64,
29 p_accept: f32,
30 last_state: Array1<f32>,
31 mean: Array1<f32>, mean_sq: Array1<f32>, }
34
35#[derive(Debug, Clone, PartialEq)]
43pub struct ChainStats {
44 pub n: u64,
45 pub p_accept: f32,
46 pub mean: Array1<f32>, pub sm2: Array1<f32>, }
49
50impl ChainTracker {
51 pub fn new<T>(n_params: usize, initial_state: &[T]) -> Self
60 where
61 T: num_traits::ToPrimitive + Clone,
62 {
63 let mean_sq = Array1::<f32>::zeros(n_params);
64 let mean = Array1::<f32>::zeros(n_params);
65 let last_state = ArrayView1::from_shape(n_params, initial_state)
66 .expect("Expected being able to convert initial state to a NdArray")
67 .mapv(|x| {
68 x.to_f32()
69 .expect("Expected conversion of elements to f32's to succeed")
70 });
71
72 Self {
73 n_params,
74 n: 0,
75 p_accept: -1.0,
76 last_state,
77 mean,
78 mean_sq,
79 }
80 }
81
82 pub fn step<T>(&mut self, x: &[T]) -> Result<(), Box<dyn Error>>
90 where
91 T: num_traits::ToPrimitive + Clone,
92 {
93 self.n += 1;
94
95 let n = self.n as f32;
96 let x_arr =
97 ndarray::ArrayView1::<T>::from_shape(self.n_params, x)?.mapv(|x| x.to_f32().unwrap());
98
99 self.mean = (self.mean.clone() * (n - 1.0) + x_arr.clone()) / n;
100 if self.n == 1 {
101 self.mean_sq = x_arr.pow2();
102 } else {
103 self.mean_sq = (self.mean_sq.clone() * (n - 1.0) + (x_arr.pow2())) / n;
104 };
105
106 let p_start = if self.p_accept >= 0.0 {
111 self.p_accept
112 } else {
113 x_arr
114 .index_axis(Axis(0), 0)
115 .ne(&self.last_state.index_axis(Axis(0), 0)) as i32 as f32
116 };
117 self.p_accept = ndarray::Zip::from(x_arr.rows())
118 .and(self.last_state.rows())
119 .fold(p_start, |p_accept, a, b| {
120 let accepted = (a.ne(&b) as i32) as f32;
121 (1.0 - ALPHA) * p_accept + ALPHA * accepted
122 });
123 self.last_state = x_arr;
124
125 Ok(())
126 }
127
128 pub fn stats(&self) -> ChainStats {
133 let n = self.n as f32;
134 ChainStats {
135 n: self.n,
136 p_accept: self.p_accept,
137 mean: self.mean.clone(),
138 sm2: (self.mean_sq.clone() - self.mean.pow2()) * n / (n - 1.0),
139 }
140 }
141}
142
143pub fn collect_rhat(chain_stats: &[&ChainStats]) -> Array1<f32> {
151 let (within, var) = withinvar_from_cs(chain_stats);
152 (var / within).sqrt()
153}
154
155fn withinvar_from_cs(chain_stats: &[&ChainStats]) -> (Array1<f32>, Array1<f32>) {
156 let means: Vec<ArrayView1<f32>> = chain_stats.iter().map(|x| x.mean.view()).collect();
157 let means = ndarray::stack(Axis(0), &means).expect("Expected stacking means to succeed");
158 let sm2s: Vec<ArrayView1<f32>> = chain_stats.iter().map(|x| x.sm2.view()).collect();
159 let sm2s = ndarray::stack(Axis(0), &sm2s).expect("Expected stacking sm2 arrays to succeed");
160
161 let within = sm2s
162 .mean_axis(Axis(0))
163 .expect("Expected computing within-chain variances to succeed");
164 let global_means = means
165 .mean_axis(Axis(0))
166 .expect("Expected computing global means to succeed");
167 let diffs: Array2<f32> = (means.clone()
168 - global_means
169 .broadcast(means.shape())
170 .expect("Expected broadcasting to succeed"))
171 .into_dimensionality()
172 .expect("Expected casting dimensionality to Array1 to succeed");
173 let between = diffs.pow2().sum_axis(Axis(0)) / (diffs.len() - 1) as f32;
174
175 let n: f32 = chain_stats.iter().map(|x| x.n as f32).sum::<f32>() / chain_stats.len() as f32;
176 let var = between + within.clone() * ((n - 1.0) / n);
177 (within, var)
178}
179
180#[derive(Debug, Clone, PartialEq)]
189pub struct MultiChainTracker {
190 n: usize,
191 pub p_accept: f32,
192 last_state: Array2<f32>,
193 mean: Array2<f32>, mean_sq: Array2<f32>, n_chains: usize,
196 n_params: usize,
197}
198
199impl MultiChainTracker {
200 pub fn new(n_chains: usize, n_params: usize) -> Self {
209 let mean_sq = Array2::<f32>::zeros((n_chains, n_params));
210 Self {
211 n: 0,
212 p_accept: 0.0,
213 last_state: Array2::<f32>::zeros((n_chains, n_params)),
214 mean: Array2::<f32>::zeros((n_chains, n_params)),
215 mean_sq,
216 n_chains,
217 n_params,
218 }
219 }
220
221 pub fn step<T>(&mut self, x: &[T]) -> Result<(), Box<dyn Error>>
229 where
230 T: Num
231 + num_traits::ToPrimitive
232 + num_traits::FromPrimitive
233 + std::clone::Clone
234 + std::cmp::PartialOrd,
235 {
236 self.n += 1;
237
238 let n = self.n as f32;
239 let x_arr = ndarray::ArrayView2::<T>::from_shape((self.n_chains, self.n_params), x)?
240 .mapv(|x| x.to_f32().unwrap());
241
242 self.mean = (self.mean.clone() * (n - 1.0) + x_arr.clone()) / n;
243 if self.n == 1 {
244 self.mean_sq = x_arr.pow2();
245 } else {
246 self.mean_sq = (self.mean_sq.clone() * (n - 1.0) + (x_arr.pow2())) / n;
247 };
248
249 self.p_accept = ndarray::Zip::from(x_arr.rows())
251 .and(self.last_state.rows())
252 .fold(self.p_accept, |p_accept, a, b| {
253 let accepted = (a.ne(&b) as i32) as f32;
254 (1.0 - ALPHA) * p_accept + ALPHA * accepted
255 });
256 self.last_state = x_arr;
257
258 Ok(())
259 }
260
261 pub fn stats<B: Backend>(&self, sample: Tensor<B, 3>) -> Result<RunStats, Box<dyn Error>> {
262 let sample_data = sample.to_data();
263 let sample_ndarray =
264 ArrayView3::from_shape(sample.dims(), sample_data.as_slice().unwrap())?;
265 Ok(RunStats::from_f32_view(sample_ndarray))
266 }
267
268 pub fn max_rhat(&self) -> Result<f32, Box<dyn Error>> {
273 let all: Array1<f32> = self.rhat()?;
274 let max = *all.max()?;
275 Ok(max)
276 }
277
278 pub fn rhat(&self) -> Result<Array1<f32>, Box<dyn Error>> {
283 let (within, var) = self.within_and_var()?;
284 let rhat = (var / within).sqrt();
285 Ok(rhat)
286 }
287
288 fn within_and_var(&self) -> Result<(Array1<f32>, Array1<f32>), Box<dyn Error>> {
289 let mean_chain = self
290 .mean
291 .mean_axis(Axis(0))
292 .ok_or("Mean reduction across chains for mean failed.")?;
293 let n_chains = self.mean.shape()[0] as f32;
294 let n = self.n as f32;
295 let fac = n / (n_chains - 1.0);
296 let between = (self.mean.clone() - mean_chain.insert_axis(Axis(0)))
297 .pow2()
298 .sum_axis(Axis(0))
299 * fac;
300 let sm2 = (self.mean_sq.clone() - self.mean.pow2()) * n / (n - 1.0);
301 let within = sm2
302 .mean_axis(Axis(0))
303 .ok_or("Mean reduction across chains for mean of squares failed.")?;
304 let var = within.clone() * ((n - 1.0) / n) + between * (1.0 / n);
305 Ok((within, var))
306 }
307}
308
309pub fn basic_stats(name: &str, mut data: Array1<f32>) -> BasicStats {
311 data.as_slice_mut()
312 .unwrap()
313 .sort_by(|a, b| match b.partial_cmp(a) {
314 Some(x) => x,
315 None => Ordering::Equal,
316 });
317 let (min, median, max) = (
318 *data
319 .last()
320 .expect("Expected getting first element from ess array succeed"),
321 data[data.len() / 2],
322 *data
323 .first()
324 .expect("Expected getting last element from ess array succeed"),
325 );
326 let mean = data.mean().expect("Expected computing mean ess to succeed");
327 let std = data.std(1.0);
328 BasicStats {
329 name: name.to_string(),
330 min,
331 median,
332 max,
333 mean,
334 std,
335 }
336}
337
338#[derive(Clone, Debug, PartialEq, PartialOrd)]
339pub struct RunStats {
340 pub ess: BasicStats,
341 pub rhat: BasicStats,
342}
343
344impl RunStats {
345 fn from_f32_view(sample: ArrayView3<f32>) -> Self {
346 let (rhat, ess) = split_rhat_mean_ess(sample);
347 let ess = basic_stats("ESS", ess);
348 let rhat = basic_stats("Split R-hat", rhat);
349 RunStats { ess, rhat }
350 }
351}
352
353impl fmt::Display for RunStats {
354 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
355 write!(f, "{}\n{}", self.ess, self.rhat)
357 }
358}
359
360impl<T> From<ArrayView3<'_, T>> for RunStats
361where
362 T: ToPrimitive + std::clone::Clone,
363{
364 fn from(sample: ArrayView3<T>) -> Self {
365 let f32_sample = sample.mapv(|x| x.to_f32().unwrap());
366 let (rhat, ess) = split_rhat_mean_ess(f32_sample.view());
367 let ess = basic_stats("ESS", ess);
368 let rhat = basic_stats("Split R-hat", rhat);
369 RunStats { ess, rhat }
370 }
371}
372
/// Named summary of a set of values: range, median, mean and standard deviation.
#[derive(Clone, Debug, PartialEq, PartialOrd)]
pub struct BasicStats {
    /// Label printed in front of the summary (e.g. the diagnostic's name).
    pub name: String,
    /// Smallest value.
    pub min: f32,
    /// Median value.
    pub median: f32,
    /// Largest value.
    pub max: f32,
    /// Arithmetic mean.
    pub mean: f32,
    /// Standard deviation.
    pub std: f32,
}

impl fmt::Display for BasicStats {
    /// Renders `<name> in [min, max], median: m, mean: mu ± sd`, each number with two decimals.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            name,
            min,
            median,
            max,
            mean,
            std: spread,
        } = self;
        write!(
            f,
            "{} in [{:.2}, {:.2}], median: {:.2}, mean: {:.2} ± {:.2}",
            name, min, max, median, mean, spread
        )
    }
}
393
394fn splitcat(sample: ArrayView3<f32>) -> Array3<f32> {
397 let n = sample.shape()[1];
398 let half = (n / 2) as i32;
399 let half_1 = sample.slice(s![.., ..half, ..]);
400 let half_2 = sample.slice(s![.., -half.., ..]);
401 concatenate(Axis(0), &[half_1, half_2]).expect("Expected stacking two halves to succeed")
402}
403
404pub fn split_rhat_mean_ess(sample: ArrayView3<f32>) -> (Array1<f32>, Array1<f32>) {
417 let splitted = splitcat(sample); let (within, var) = withinvar(splitted.view());
419 (
420 rhat(within.view(), var.view()),
421 ess(splitted.view(), within.view(), var.view()),
422 )
423}
424
425fn rhat(within: ArrayView1<f32>, var: ArrayView1<f32>) -> Array1<f32> {
426 (within.to_owned() / var).sqrt()
427}
428
429fn withinvar(sample: ArrayView3<f32>) -> (Array1<f32>, Array1<f32>) {
430 let c = sample.shape()[0];
431 let n = sample.shape()[1];
432 let p = sample.shape()[2];
433
434 let (within, var): (Vec<f32>, Vec<f32>) = (0..p)
437 .into_par_iter()
438 .map(|param_idx| {
439 let data_p = sample.slice(s![.., .., param_idx]);
440
441 let chain_means = data_p.mean_axis(Axis(1)).unwrap();
443 let overall_mean = chain_means.mean().unwrap();
444
445 let diff = &chain_means - overall_mean;
449 let b = diff.pow2().sum() * ((n as f32) / ((c - 1) as f32));
450
451 let mut squares = Vec::with_capacity(c);
456 for chain_i in 0..c {
457 let row = data_p.slice(s![chain_i, ..]);
458 let cm = chain_means[chain_i];
459 let sq = row.iter().map(|v| (v - cm) * (v - cm)).sum::<f32>() / (n as f32);
460 squares.push(sq);
461 }
462 let squares = Array1::from(squares); let w = squares.mean().unwrap(); let v = ((n as f32 - 1.0) / (n as f32)) * w + b / (n as f32);
466
467 (w, v)
468 })
469 .collect::<Vec<(f32, f32)>>()
470 .into_iter()
471 .fold((vec![], vec![]), |(mut within, mut var), (w, v)| {
472 within.push(w);
473 var.push(v);
474 (within, var)
475 });
476 (Array1::from_vec(within), Array1::from_vec(var))
477}
478
479fn ess(sample: ArrayView3<f32>, within: ArrayView1<f32>, var: ArrayView1<f32>) -> Array1<f32> {
497 let shape = sample.shape();
498 let (n_chains, n_steps, n_params) = (shape[0], shape[1], shape[2]);
499 let chain_rho: Vec<Array2<f32>> = (0..n_chains)
500 .map(|c| {
501 let chain_sample = sample.index_axis(Axis(0), c);
502 autocov(chain_sample)
503 })
504 .collect();
505 let chain_rho: Vec<ArrayView2<f32>> = chain_rho.iter().map(|x| x.view()).collect();
506 let chain_rho = stack(Axis(0), &chain_rho)
507 .expect("Expected stacking chain-specific autocovariance matrices to succeed");
508 let avg_rho = chain_rho.mean_axis(Axis(0)).unwrap();
509 let diff = -avg_rho
510 + within
511 .broadcast((n_steps, n_params))
512 .expect("Expected broadcasting to succeed");
513 let rho = -(diff
514 / var
515 .broadcast((n_steps, n_params))
516 .expect("Expected broadcasting to succeed"))
517 + 1.0;
518 let tau: Vec<f32> = (0..n_params)
519 .into_par_iter()
520 .map(|d| {
521 let rho_d = rho.index_axis(Axis(1), d).to_owned();
522
523 let mut min = if rho_d.len() >= 2 {
524 rho_d[[0]] + rho_d[[1]]
525 } else {
526 0.0
527 };
528
529 let mut out = 0.0;
530 for rho_t in rho_d.windows_with_stride(2, 2) {
531 let mut p_t = rho_t[0] + rho_t[1];
532 if p_t <= 0.0 {
533 break;
534 }
535 if p_t > min {
536 p_t = min;
537 }
538 min = p_t;
539 out += p_t;
540 }
541 -1.0 + 2.0 * out
542 })
543 .collect();
544 let tau = Array1::from_vec(tau);
545 tau.recip() * n_chains as f32 * n_steps as f32
546}
547
548fn autocov(sample: ArrayView2<f32>) -> Array2<f32> {
549 if sample.nrows() <= 100 {
550 autocov_bf(sample)
551 } else {
552 autocov_fft(sample)
553 }
554}
555
556fn autocov_fft(sample: ArrayView2<f32>) -> Array2<f32> {
577 let (n, d) = (sample.shape()[0], sample.shape()[1]);
578 let mut planner = FftPlanner::new();
579
580 let mut n_padded = 1;
582 while n_padded < 2 * n - 1 {
583 n_padded <<= 1;
584 }
585 let fft = planner.plan_fft_forward(n_padded);
586 let ffti = planner.plan_fft_inverse(n_padded);
587 let out: Vec<f32> = sample
588 .axis_iter(Axis(1))
589 .into_par_iter()
590 .map(|traj| {
591 let traj_mean = traj.sum() / traj.len() as f32;
592 let mut x: Vec<Complex<f32>> = traj
593 .iter()
594 .map(|xi| Complex {
595 re: (*xi - traj_mean),
596 im: 0.0f32,
597 })
598 .chain(
599 [Complex {
600 re: 0.0f32,
601 im: 0.0f32,
602 }]
603 .repeat(n_padded - n),
604 )
605 .collect();
606 fft.process(x.as_mut_slice());
607 x.iter_mut().for_each(|xi| {
608 *xi *= xi.conj();
609 });
610 ffti.process(x.as_mut_slice());
611 x.iter_mut()
612 .take(n)
613 .map(|xi| xi.re / n_padded as f32 / n as f32) .collect::<Vec<f32>>()
615 })
616 .flatten_iter()
617 .collect();
618 let out = Array2::from_shape_vec((d, n), out).expect("Expected creating dxn array to succeed");
619 out.t().to_owned()
620}
621
622fn autocov_bf(data: ArrayView2<f32>) -> Array2<f32> {
633 let (n, d) = data.dim();
634 let mut out = Array2::<f32>::zeros((n, d));
635
636 out.axis_iter_mut(Axis(1)) .into_par_iter() .enumerate() .for_each(|(col_idx, mut out_col)| {
640 let col_data = data.column(col_idx);
641 let col_data = col_data.to_owned() - col_data.mean().unwrap();
642
643 for lag in 0..n {
645 let mut sum_lag = 0.0;
646 for t in 0..(n - lag) {
647 sum_lag += col_data[t] * col_data[t + lag];
648 }
649 out_col[lag] = sum_lag / n as f32;
651 }
652 });
653 out
654}
655
656pub fn ess_from_chainstats(sample: ArrayView3<f32>, chain_stats: &[&ChainStats]) -> Array1<f32> {
669 let (within, var) = withinvar_from_cs(chain_stats);
670 ess(sample, within.view(), var.view())
671}
672
#[cfg(test)]
mod tests {
    use std::io::Write;
    use std::{f32, fs::File, time::Instant};

    use approx::assert_abs_diff_eq;
    use rand::rngs::SmallRng;
    use rand::{Rng, SeedableRng};

    use super::*;

    /// Feeds two (3 chains × 4 params) draws into a `MultiChainTracker` and checks that its R-hat
    /// matches `expected` to within `tol` (maximum absolute difference).
    fn assert_rhat_after_two_steps<T>(
        first: Array2<T>,
        second: Array2<T>,
        expected: Array1<f32>,
        tol: f32,
    ) where
        T: ndarray::NdFloat + num_traits::FromPrimitive,
    {
        let mut tracker = MultiChainTracker::new(3, 4);
        for draw in [&first, &second] {
            tracker.step(draw.as_slice().unwrap()).unwrap();
        }
        let rhat = tracker.rhat().unwrap();
        let worst = *(&rhat - &expected).abs().max().unwrap();
        assert!(
            worst < tol,
            "Mismatch in Rhat. Got {:?}, expected {:?}, diff = {:?}",
            rhat,
            expected,
            worst
        );
    }

    #[test]
    fn test_rhat_f32_1() {
        let step_0: Array2<f32> = arr2(&[
            [0.0, 1.0, 0.0, 1.0],
            [1.0, 2.0, 0.0, 2.0],
            [0.0, 0.0, 0.0, 2.0],
        ]);
        let step_1: Array2<f32> = arr2(&[
            [1.0, 2.0, 2.0, 0.0],
            [1.0, 1.0, 1.0, 1.0],
            [0.0, 1.0, 0.0, 0.0],
        ]);
        let expected = array![f32::consts::SQRT_2, 1.080_123_4, 0.894_427_3, 0.8660254];
        assert_rhat_after_two_steps(step_0, step_1, expected, f32::EPSILON * 10.0);
    }

    #[test]
    fn test_rhat_f64_1() {
        let step_0: Array2<f64> = arr2(&[
            [0.0, 1.0, 0.0, 1.0],
            [1.0, 2.0, 0.0, 2.0],
            [0.0, 0.0, 0.0, 2.0],
        ]);
        let step_1: Array2<f64> = arr2(&[
            [1.0, 2.0, 2.0, 0.0],
            [1.0, 1.0, 1.0, 1.0],
            [0.0, 1.0, 0.0, 0.0],
        ]);
        let expected = array![f32::consts::SQRT_2, 1.0801234, 0.8944271, 0.8660254];
        assert_rhat_after_two_steps(step_0, step_1, expected, f32::EPSILON * 10.0);
    }

    #[test]
    fn test_rhat_f64_2() {
        let step_0 = arr2(&[
            [1.0, 0.0, 0.0, 1.0],
            [1.0, 0.0, 0.0, 1.0],
            [0.0, 1.0, 0.0, 2.0],
        ]);
        let step_1 = arr2(&[
            [1.0, 2.0, 0.0, 2.0],
            [1.0, 2.0, 0.0, 0.0],
            [2.0, 0.0, 1.0, 2.0],
        ]);
        let expected = array![f32::consts::FRAC_1_SQRT_2, 0.74535599, 1.0, 1.5];
        assert_rhat_after_two_steps(step_0, step_1, expected, f32::EPSILON * 10.0);
    }

    /// Runs an autocovariance implementation on `data` and compares it with `expected`
    /// (shape first, then values within 1e-6).
    fn check_autocov<F>(
        autocov_func: F,
        data: &Array2<f32>,
        expected: &Array2<f32>,
        test_name: &str,
    ) where
        F: Fn(ArrayView2<f32>) -> Array2<f32>,
    {
        let result = autocov_func(data.view());
        assert_eq!(
            result.dim(),
            expected.dim(),
            "{}: shape mismatch; got {:?}, expected {:?}",
            test_name,
            result.dim(),
            expected.dim()
        );
        assert_abs_diff_eq!(result, *expected, epsilon = 1e-6);
        println!("Test: {test_name} succeeded");
    }

    #[test]
    fn test_single_param() {
        let data = array![[1.0], [2.0], [3.0], [4.0]];
        let expected = array![[1.25], [0.3125], [-0.375], [-0.5625]];

        check_autocov(autocov_bf, &data, &expected, "BF: single_param_small");

        println!("Doing FFT test");
        check_autocov(autocov_fft, &data, &expected, "FFT: single_param_small");
        println!("FFT test succeeded");
    }

    #[test]
    fn test_two_params_1() {
        let data = array![[1.0, 0.3], [2.0, 2.0], [3.0, -2.0], [4.0, 5.0]];
        let expected = array![
            [1.25, 6.516875],
            [0.3125, -3.7889063],
            [-0.375, 1.4721875],
            [-0.5625, -0.94171875],
        ];

        check_autocov(autocov_bf, &data, &expected, "BF: two_params_small");
        check_autocov(autocov_fft, &data, &expected, "FFT: two_params_small");
    }

    #[test]
    fn ess_1() {
        let (m, n) = (4, 1000);

        // Fill row by row (logical order) with seeded uniform draws.
        let mut draws = Array2::<f32>::zeros((m, n));
        let mut rng = SmallRng::seed_from_u64(42);
        draws
            .iter_mut()
            .for_each(|slot| *slot = rng.random::<f32>());

        let cube = draws.to_shape((m, n, 1)).unwrap();
        let run_stats = RunStats::from(cube.view());

        println!("Samples: {}\n{run_stats}", m * n);

        assert!(run_stats.ess.min > 3800.0);
        assert!(run_stats.rhat.max < 1.01);
    }

    #[test]
    #[ignore = "Benchmark test: run only when explicitly requested"]
    fn test_autocov_perf_comp() {
        let mut file =
            File::create("runtime_results.csv").expect("Unable to create runtime_results.csv");
        writeln!(file, "length,rep,time,algorithm").expect("Unable to write CSV header");

        let mut rng = rand::rng();

        let lengths = (0..10).map(|exp| 1usize << exp);
        for n in lengths {
            for rep in 1..=10 {
                let values: Vec<f32> = (0..n * 1000).map(|_| rng.random()).collect();
                let sample = Array2::from_shape_vec((n, 1000), values)
                    .expect("Failed to create Array2");

                let fft_time = {
                    let started = Instant::now();
                    autocov_fft(sample.view());
                    started.elapsed().as_nanos()
                };
                let brute_time = {
                    let started = Instant::now();
                    autocov_bf(sample.view());
                    started.elapsed().as_nanos()
                };

                writeln!(file, "{},{},{},fft", n, rep, fft_time)
                    .expect("Unable to write test results to CSV");
                writeln!(file, "{},{},{},brute force", n, rep, brute_time)
                    .expect("Unable to write test results to CSV");

                println!(
                    "Length: {} | Rep: {} | FFT: {} ns | Brute: {} ns",
                    n, rep, fft_time, brute_time
                );
            }
        }
    }
}