mzsignal/
average.rs

//! Re-bin a single spectrum, or average together multiple spectra using
//! interpolation.
//!
4use std::borrow::Cow;
5use std::cmp;
6use std::collections::VecDeque;
7
8#[cfg(target_arch = "x86")]
9use std::arch::x86::__m256d;
10#[cfg(target_arch = "x86_64")]
11use std::arch::x86_64::__m256d;
/// Stand-in for the AVX `__m256d` register type on architectures that lack it, so the name
/// still resolves there. It carries no data.
#[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))]
#[allow(dead_code, non_camel_case_types)] // mirrors the intrinsic's name; may be unreferenced
struct __m256d();
14
15#[cfg(feature = "parallelism")]
16use rayon::prelude::*;
17#[cfg(feature = "parallelism")]
18use std::sync::Mutex;
19
20use cfg_if;
21
22use mzpeaks::coordinate::{CoordinateLike, Time};
23
24use crate::arrayops::{gridspace, ArrayPair, ArrayPairIter, ArrayPairLike, ArrayPairSplit, MZGrid};
25use num_traits::Float;
26
/// Linear interpolation between two control points, shared by the averagers in this module.
trait MZInterpolator {
    /// Linear interpolation between two control points to find the intensity
    /// at a third point between them.
    ///
    /// Evaluates `(inten_j * (mz_j1 - mz_x) + inten_j1 * (mz_x - mz_j)) / (mz_j1 - mz_j)`,
    /// folding the first product into a fused multiply-add.
    ///
    /// # Arguments
    /// - `mz_j` - The first control point's m/z
    /// - `mz_x` - The interpolated m/z
    /// - `mz_j1` - The second control point's m/z
    /// - `inten_j` - The first control point's intensity
    /// - `inten_j1` - The second control point's intensity
    ///
    /// When `mz_j == mz_j1` the denominator is zero and the result is not finite.
    #[inline]
    fn interpolate_point(
        &self,
        mz_j: f64,
        mz_x: f64,
        mz_j1: f64,
        inten_j: f64,
        inten_j1: f64,
    ) -> f64 {
        let right_weighted = inten_j1 * (mz_x - mz_j);
        let numerator = inten_j.mul_add(mz_j1 - mz_x, right_weighted);
        numerator / (mz_j1 - mz_j)
    }

    /// A four-lane version of [`MZInterpolator::interpolate_point`] operating on AVX 256-bit
    /// registers.
    ///
    /// The body issues both AVX and FMA (`_mm256_fmadd_pd`) instructions, so callers must
    /// confirm that the CPU supports *both* features before calling it.
    #[cfg(feature = "avx")]
    #[cfg(target_arch = "x86_64")]
    fn interpolate_avx(
        &self,
        mz_j: __m256d,
        mz_x: __m256d,
        mz_j1: __m256d,
        inten_j: __m256d,
        inten_j1: __m256d,
    ) -> __m256d {
        use std::arch::x86_64::*;
        unsafe {
            let right_weighted = _mm256_mul_pd(inten_j1, _mm256_sub_pd(mz_x, mz_j));
            let numerator = _mm256_fmadd_pd(inten_j, _mm256_sub_pd(mz_j1, mz_x), right_weighted);
            _mm256_div_pd(numerator, _mm256_sub_pd(mz_j1, mz_j))
        }
    }
}
77
78
79struct Interpolator {}
80impl MZInterpolator for Interpolator {}
81
82pub fn interpolate(xj: f64, x: f64, xj1: f64, yj: f64, yj1: f64) -> f64 {
83    Interpolator{}.interpolate_point(xj, x, xj1, yj, yj1)
84}
85
86struct MonotonicBlockSearcher<'a> {
87    data: &'a ArrayPair<'a>,
88    next_value: Option<f64>,
89    last_index: usize,
90}
91
92impl<'a> MonotonicBlockSearcher<'a> {
93    fn new(data: &'a ArrayPair<'a>) -> Self {
94        Self {
95            data,
96            next_value: None,
97            last_index: 0,
98        }
99    }
100
101    fn find_update(&mut self, mz: f64) -> usize {
102        let i = self.data.find(mz);
103        self.last_index = i;
104        self.next_value = self.data.mz_array.get(i).copied();
105        i
106    }
107
108    /// This assumes that the next value will be suitable, but this is not actually
109    /// true. The algorithm this component is used in though does not make the distinction
110    #[allow(unused)]
111    fn peek(&self, mz: f64) -> usize {
112        if let Some(next_value) = self.next_value {
113            if mz < next_value {
114                self.last_index
115            } else {
116                (self.last_index + 1).min(self.data.len().saturating_sub(1))
117            }
118        } else {
119            self.last_index
120        }
121    }
122
123    fn find(&mut self, mz: f64) -> usize {
124        if let Some(next_value) = self.next_value {
125            if mz < next_value {
126                self.last_index
127            } else {
128                self.find_update(mz)
129            }
130        } else {
131            self.find_update(mz)
132        }
133    }
134}
135
136#[allow(unused)]
137struct MonotonicBlockedIterator<'a, 'b: 'a, T: Iterator<Item = (f64, &'b mut f32)>> {
138    block: std::iter::Enumerate<ArrayPairIter<'a>>,
139    last_value: (usize, (f64, f64)),
140    current_value: (usize, (f64, f64)),
141    next_value: Option<(usize, (f64, f64))>,
142    block_n: usize,
143    it: T,
144}
145
146impl<'b, T: Iterator<Item = (f64, &'b mut f32)>> MZInterpolator
147    for MonotonicBlockedIterator<'_, 'b, T>
148{
149}
150
151type BlockIteratorPoint = (usize, (f64, f64));
152
153impl<'a, 'b: 'a, T: Iterator<Item = (f64, &'b mut f32)>> MonotonicBlockedIterator<'a, 'b, T> {
154    fn new(block: &'a ArrayPair<'a>, it: T) -> Self {
155        let mut source = block.iter().enumerate();
156        let current_value = source.next().map(|(i, (x, y))| (i, (x, y as f64))).unwrap();
157        let next_value = source.next().map(|(i, (x, y))| (i, (x, y as f64)));
158        let block_n = block.len();
159        Self {
160            block: source,
161            last_value: current_value,
162            current_value,
163            next_value,
164            block_n,
165            it,
166        }
167    }
168
169    fn next_value_from_source(&mut self) -> Option<BlockIteratorPoint> {
170        self.block.next().map(|(i, (x, y))| (i, (x, y as f64)))
171    }
172
173    fn step(&mut self) -> Option<(f64, &'b mut f32, BlockIteratorPoint)> {
174        if let Some((x, o)) = self.it.next() {
175            if let Some((vi, (vmz, vint))) = self.next_value.as_ref() {
176                if x >= *vmz {
177                    self.last_value = self.current_value;
178                    self.current_value = (*vi, (*vmz, *vint));
179                    self.next_value = self.next_value_from_source();
180                }
181                Some((x, o, self.current_value))
182            } else {
183                Some((x, o, self.current_value))
184            }
185        } else {
186            None
187        }
188    }
189
190    fn interpolant_step(&mut self) -> Option<(f64, &'b mut f32)> {
191        if let Some((mz, o, (_, (mz_j, inten_j)))) = self.step() {
192            if mz_j <= mz {
193                if let Some((_, (mz_j1, inten_j1))) = self.next_value {
194                    let inten = self.interpolate_point(mz_j, mz, mz_j1, inten_j, inten_j1);
195                    *o += inten as f32;
196                    Some((mz, o))
197                } else {
198                    let (mz_j1, inten_j1) = (mz_j, inten_j);
199                    let (_, (mz_j, inten_j)) = self.last_value;
200
201                    let inten = self.interpolate_point(mz_j, mz, mz_j1, inten_j, inten_j1);
202                    *o += inten as f32;
203                    Some((mz, o))
204                }
205            } else {
206                let (mz_j1, inten_j1) = (mz_j, inten_j);
207                let (_, (mz_j, inten_j)) = self.last_value;
208
209                let inten = self.interpolate_point(mz_j, mz, mz_j1, inten_j, inten_j1);
210                *o += inten as f32;
211                Some((mz, o))
212            }
213        } else {
214            None
215        }
216    }
217}
218
219impl<'a, 'b: 'a, T: Iterator<Item = (f64, &'b mut f32)>> Iterator
220    for MonotonicBlockedIterator<'a, 'b, T>
221{
222    type Item = (f64, &'b mut f32);
223
224    fn next(&mut self) -> Option<Self::Item> {
225        self.interpolant_step()
226    }
227}
228
229/// A linear interpolation spectrum intensity averager over a shared m/z axis.
230#[derive(Debug, Default, Clone)]
231pub struct SignalAverager<'lifespan> {
232    /// The evenly spaced m/z axis over which spectra are averaged.
233    pub mz_grid: Vec<f64>,
234    /// The lowest m/z in the spectrum. If an input spectrum has lower m/z values, they will be ignored.
235    pub mz_start: f64,
236    /// The highest m/z in the spectrum. If an input spectrum has higher m/z values, they will be ignored.
237    pub mz_end: f64,
238    /// The spacing between m/z values in `mz_grid`. This value should be chosen relative to the sharpness
239    /// of the peak shape of the mass analyzer used, but the smaller it is, the more computationally intensive
240    /// the averaging process is, and the more memory it consumes.
241    pub dx: f64,
242    /// The current set of spectra to be averaged together. This uses a deque because the usecase of
243    /// pushing spectra into an averaging window while removing them from the other side fits with the
244    /// way one might average spectra over time.
245    pub array_pairs: VecDeque<ArrayPair<'lifespan>>,
246}
247
248impl<'a, 'b: 'a> SignalAverager<'a> {
249    pub fn new(mz_start: f64, mz_end: f64, dx: f64) -> SignalAverager<'a> {
250        SignalAverager {
251            mz_grid: gridspace(mz_start, mz_end, dx),
252            mz_start,
253            mz_end,
254            dx,
255            array_pairs: VecDeque::new(),
256        }
257    }
258
259    /// Put `pair` into the queue of arrays being averaged together.
260    pub fn push(&mut self, pair: ArrayPair<'b>) {
261        self.array_pairs.push_back(pair)
262    }
263
264    /// Remove the least recently added array pair from the queue.
265    pub fn pop(&mut self) -> Option<ArrayPair<'a>> {
266        self.array_pairs.pop_front()
267    }
268
269    pub fn len(&self) -> usize {
270        self.array_pairs.len()
271    }
272
273    pub fn is_empty(&self) -> bool {
274        self.array_pairs.is_empty()
275    }
276
277    /// A linear interpolation across all spectra between `start_mz` and `end_mz`, with
278    /// their intensities written into `out`.
279    pub(crate) fn interpolate_into_iter(
280        &self,
281        out: &mut [f32],
282        start_mz: f64,
283        end_mz: f64,
284    ) -> usize {
285        let offset = self.find_offset(start_mz);
286        let stop_index = self.find_offset(end_mz);
287
288        let grid_size = self.mz_grid.len();
289        assert!(offset < grid_size || grid_size == 0);
290        assert!(stop_index <= grid_size);
291        assert!((stop_index - offset) == out.len());
292
293        let grid_slice = &self.mz_grid[offset..stop_index];
294        for block in self.array_pairs.iter() {
295            if block.is_empty() {
296                continue;
297            }
298
299            let start_idx = block.find(start_mz).saturating_sub(1);
300            let block_slice = block.slice(start_idx, block.len());
301
302            let it = MonotonicBlockedIterator::new(
303                &block_slice,
304                grid_slice.iter().copied().zip(out.iter_mut()),
305            );
306            let _traveled = it.count();
307        }
308        if self.array_pairs.len() > 1 {
309            let normalizer = self.array_pairs.len() as f32;
310            out.iter_mut().for_each(|y| *y /= normalizer);
311        }
312        stop_index - offset
313    }
314
315    #[inline(always)]
316    /// Get the first and second control points' m/z and intensity values,
317    /// (mz, inten, mz1, inten1), in ascendng m/z order around `x`
318    fn get_interpolation_values(
319        &self,
320        x: f64,
321        j: usize,
322        mz_j: f64,
323        block_n: usize,
324        block_mz_array: &[f64],
325        block_intensity_array: &[f32],
326    ) -> Option<(f64, f64, f64, f64)> {
327        let js1 = j + 1;
328        if (mz_j <= x) && (js1 < block_n) {
329            Some((
330                mz_j,
331                block_intensity_array[j] as f64,
332                block_mz_array[js1],
333                block_intensity_array[js1] as f64,
334            ))
335        } else if mz_j > x && j > 0 {
336            let js1 = j - 1;
337            Some((
338                block_mz_array[js1],
339                block_intensity_array[js1] as f64,
340                block_mz_array[j],
341                block_intensity_array[j] as f64,
342            ))
343        } else {
344            None
345        }
346    }
347
348    #[inline(always)]
349    fn interpolate_into_idx_seq(
350        &self,
351        grid_mzs: &[f64],
352        out: &mut [f32],
353        block_mz_array: &[f64],
354        block_intensity_array: &[f32],
355        block_n: usize,
356        block_searcher: &mut MonotonicBlockSearcher,
357    ) {
358        for (x, o) in grid_mzs.iter().copied().zip(out.iter_mut()) {
359            let j = block_searcher.find(x);
360            let mz_j = block_mz_array[j];
361
362            if let Some((mz_j, inten_j, mz_j1, inten_j1)) = self.get_interpolation_values(
363                x,
364                j,
365                mz_j,
366                block_n,
367                block_mz_array,
368                block_intensity_array,
369            ) {
370                let interp = self.interpolate_point(mz_j, x, mz_j1, inten_j, inten_j1);
371                *o += interp as f32;
372            }
373        }
374    }
375
376    #[inline(always)]
377    fn interpolate_into_idx_lanes_fallback<const LANES: usize>(
378        &self,
379        grid_mz_block: &[f64],
380        output_intensity_block: &mut [f32],
381        block_mz_array: &[f64],
382        block_intensity_array: &[f32],
383        block_n: usize,
384        block_searcher: &mut MonotonicBlockSearcher,
385    ) {
386        assert_eq!(grid_mz_block.len(), LANES);
387        assert_eq!(output_intensity_block.len(), LANES);
388        for lane_i in 0..LANES {
389            let grid_mz = grid_mz_block[lane_i];
390            let output_intensity = &mut output_intensity_block[lane_i];
391            let mz_index_of_x = block_searcher.find(grid_mz);
392            let mz_j = block_mz_array[mz_index_of_x];
393
394            if let Some((mz_j, inten_j, mz_j1, inten_j1)) = self.get_interpolation_values(
395                grid_mz,
396                mz_index_of_x,
397                mz_j,
398                block_n,
399                block_mz_array,
400                block_intensity_array,
401            ) {
402                let interp = self.interpolate_point(mz_j, grid_mz, mz_j1, inten_j, inten_j1);
403                *output_intensity += interp as f32;
404            }
405        }
406    }
407
408    #[cfg(feature = "avx")]
409    fn normalize_intensity_by_scan_count_avx(&self, out: &mut [f32]) {
410        #[cfg(target_arch = "x86_64")]
411        if std::arch::is_x86_feature_detected!("avx") {
412            // Use AVX SIMD instructions available on x86_64 CPUs to process up to eight steps at a time.
413            unsafe {
414                use std::arch::x86_64::*;
415                const LANES: usize = 8;
416                let normalizer = self.array_pairs.len() as f32;
417                let normalizer_v8 = _mm256_broadcast_ss(&normalizer);
418                let mut chunks_it = out.chunks_exact_mut(LANES);
419                for chunk in chunks_it.by_ref() {
420                    let o_v8: __m256 = _mm256_loadu_ps(chunk.as_ptr());
421                    let o_normalized_v8 = _mm256_div_ps(o_v8, normalizer_v8);
422                    _mm256_storeu_ps(chunk.as_mut_ptr(), o_normalized_v8);
423                }
424                for o in chunks_it.into_remainder() {
425                    *o /= normalizer;
426                }
427            }
428        } else {
429            self.normalize_intensity_by_scan_count_fallback(out)
430        }
431        #[cfg(not(target_arch = "x86_64"))]
432        self.normalize_intensity_by_scan_count_fallback(out);
433    }
434
435    fn normalize_intensity_by_scan_count_fallback(&self, out: &mut [f32]) {
436        let normalizer = self.array_pairs.len() as f32;
437
438        const LANES: usize = 8;
439        let mut it = out.chunks_exact_mut(LANES);
440
441        for chunk in it.by_ref() {
442            // Make it obvious to the compiler to vectorize
443            #[allow(clippy::needless_range_loop)]
444            for i in 0..LANES {
445                chunk[i] /= normalizer;
446            }
447        }
448        it.into_remainder()
449            .iter_mut()
450            .for_each(|y| *y /= normalizer);
451    }
452
453    fn normalize_intensity_by_scan_count(&self, out: &mut [f32]) {
454        #[cfg(target_arch = "x86_64")]
455        if std::arch::is_x86_feature_detected!("avx") {
456            #[cfg(feature = "avx")]
457            self.normalize_intensity_by_scan_count_avx(out);
458            #[cfg(not(feature = "avx"))]
459            self.normalize_intensity_by_scan_count_fallback(out);
460        } else {
461            self.normalize_intensity_by_scan_count_fallback(out);
462        }
463        #[cfg(not(target_arch = "x86_64"))]
464        self.normalize_intensity_by_scan_count_fallback(out);
465    }
466
467    pub(crate) fn interpolate_into_idx(
468        &self,
469        out: &mut [f32],
470        start_mz: f64,
471        end_mz: f64,
472    ) -> usize {
473        let offset = self.find_offset(start_mz);
474        let stop_index = self.find_offset(end_mz);
475
476        let grid_size = self.mz_grid.len();
477        {
478            assert!(offset < grid_size || grid_size == 0);
479            assert!(stop_index <= grid_size);
480            assert!((stop_index - offset) == out.len());
481        }
482
483        let grid_slice = &self.mz_grid[offset..stop_index];
484
485        for block in self.array_pairs.iter() {
486            if block.is_empty() {
487                continue;
488            }
489            let mut block_searcher = MonotonicBlockSearcher::new(block);
490            let block_n = block.len();
491            let block_mz_array = block.mz_array.as_ref();
492            let block_intensity_array = block.intensity_array.as_ref();
493            assert_eq!(block_mz_array.len(), block_n);
494            assert_eq!(block_intensity_array.len(), block_n);
495
496            const LANES: usize = 4;
497            let mut grid_chunks = grid_slice.chunks_exact(LANES);
498            let mut out_chunks = out.chunks_exact_mut(LANES);
499
500            while let (Some(grid_mz_block), Some(output_intensity_block)) =
501                (grid_chunks.next(), out_chunks.next())
502            {
503                #[cfg(not(target_arch = "x86_64"))]
504                let did_vector = false;
505                #[cfg(target_arch = "x86_64")]
506                let did_vector = if std::arch::is_x86_feature_detected!("avx") {
507                    #[cfg(not(feature = "avx"))]
508                    {
509                        false
510                    }
511                    #[cfg(feature = "avx")]
512                    // Use AVX SIMD instructions available on x86_64 CPUs to process up to four steps at a time.
513                    unsafe {
514                        use std::arch::x86_64::*;
515                        let grid_mz_first = *grid_mz_block.get_unchecked(0);
516                        let grid_mz_last = *grid_mz_block.get_unchecked(3);
517                        let j_first = block_searcher.find(grid_mz_first);
518                        let j_last = block_searcher.peek(grid_mz_last);
519                        let mz_j_first = *block_mz_array.get_unchecked(j_first);
520
521                        // If the solution uses the same two control points for every comparison, as given by both
522                        // using the same first point in the block, then we can take this fast path that performs
523                        // the interpolation operation using AVX and 256-bit vector instructions.
524                        //
525                        // This could also be done with the AVX 512-bit vectors but they are not available on most
526                        // machines yet.
527                        if j_first == j_last {
528                            if let Some((mz_j, inten_j, mz_j1, inten_j1)) = self
529                                .get_interpolation_values(
530                                    grid_mz_first,
531                                    j_first,
532                                    mz_j_first,
533                                    block_n,
534                                    block_mz_array,
535                                    block_intensity_array,
536                                )
537                            {
538                                // Populate the vectors going into `interpolate_avx`
539                                let mz_x_v4: __m256d = _mm256_loadu_pd(grid_mz_block.as_ptr());
540                                let mz_j_v4: __m256d = _mm256_broadcast_sd(&mz_j);
541                                let mz_j1_v4: __m256d = _mm256_broadcast_sd(&mz_j1);
542                                let inten_j_v4: __m256d = _mm256_broadcast_sd(&inten_j);
543                                let inten_j1_v4: __m256d = _mm256_broadcast_sd(&inten_j1);
544
545                                // Perform the interpolation on the vector registers
546                                let result_v4 = self.interpolate_avx(
547                                    mz_j_v4,
548                                    mz_x_v4,
549                                    mz_j1_v4,
550                                    inten_j_v4,
551                                    inten_j1_v4,
552                                );
553
554                                // Cast down from f64 to f32 registers
555                                let result_v4_f32 = _mm256_cvtpd_ps(result_v4);
556                                // Load the accumulator from the output array of f32
557                                let acc_v4 = _mm_loadu_ps(output_intensity_block.as_ptr());
558                                // Add the result to the accumulator
559                                let total_v4 = _mm_add_ps(result_v4_f32, acc_v4);
560                                // Store the accumulator back to the array of f32
561                                _mm_storeu_ps(output_intensity_block.as_mut_ptr(), total_v4);
562                            }
563                            true
564                        } else {
565                            false
566                        }
567                    }
568                } else {
569                    false
570                };
571                if !did_vector {
572                    #[allow(clippy::if_same_then_else)] // hint to the compiler that a hardware feature will be available
573                    #[cfg(target_arch = "x86_64")]
574                    if std::arch::is_x86_feature_detected!("avx") {
575                        self.interpolate_into_idx_lanes_fallback::<LANES>(
576                            grid_mz_block,
577                            output_intensity_block,
578                            block_mz_array,
579                            block_intensity_array,
580                            block_n,
581                            &mut block_searcher,
582                        );
583                    } else {
584                        self.interpolate_into_idx_lanes_fallback::<LANES>(
585                            grid_mz_block,
586                            output_intensity_block,
587                            block_mz_array,
588                            block_intensity_array,
589                            block_n,
590                            &mut block_searcher,
591                        );
592                    }
593                    #[cfg(not(target_arch = "x86_64"))]
594                    self.interpolate_into_idx_lanes_fallback::<LANES>(
595                        grid_mz_block,
596                        output_intensity_block,
597                        block_mz_array,
598                        block_intensity_array,
599                        block_n,
600                        &mut block_searcher,
601                    );
602                }
603            }
604
605            // Clean up remainder
606            self.interpolate_into_idx_seq(
607                grid_chunks.remainder(),
608                out_chunks.into_remainder(),
609                block_mz_array,
610                block_intensity_array,
611                block_n,
612                &mut block_searcher,
613            );
614        }
615        if self.array_pairs.len() > 1 {
616            self.normalize_intensity_by_scan_count(out);
617        }
618        stop_index - offset
619    }
620
621    pub fn interpolate_chunks(&self, n_chunks: usize) -> Vec<f32> {
622        let mut result = self.create_intensity_array();
623        if self.array_pairs.is_empty() {
624            return result;
625        }
626        let n_points = self.points_between(self.mz_start, self.mz_end);
627
628        let points_per_chunk = n_points / n_chunks;
629        for i in 0..n_chunks {
630            let offset = i * points_per_chunk;
631            let (size, start_mz, end_mz) = if i == n_chunks - 1 {
632                (n_points - offset, self.mz_grid[offset], self.mz_end)
633            } else {
634                (
635                    points_per_chunk,
636                    self.mz_grid[offset],
637                    self.mz_grid[offset + points_per_chunk],
638                )
639            };
640            let mut sub = self.create_intensity_array_of_size(size);
641            self.interpolate_into_iter(&mut sub, start_mz, end_mz);
642            (result[offset..offset + size]).copy_from_slice(&sub);
643        }
644        result
645    }
646
647    #[cfg(feature = "parallelism")]
648    #[allow(unused)]
649    pub(crate) fn interpolate_chunks_parallel_locked(&'a self, n_chunks: usize) -> Vec<f32> {
650        let result = self.create_intensity_array();
651        if self.array_pairs.is_empty() {
652            return result;
653        }
654        let n_points = self.points_between(self.mz_start, self.mz_end);
655        let locked_result = Mutex::new(result);
656        let points_per_chunk = n_points / n_chunks;
657        (0..n_chunks).into_par_iter().for_each(|i| {
658            let offset = i * points_per_chunk;
659            let (size, start_mz, end_mz) = if i == n_chunks - 1 {
660                (n_points - offset, self.mz_grid[offset], self.mz_end)
661            } else {
662                (
663                    points_per_chunk,
664                    self.mz_grid[offset],
665                    self.mz_grid[offset + points_per_chunk],
666                )
667            };
668            let mut sub = self.create_intensity_array_of_size(size);
669            self.interpolate_into_iter(&mut sub, start_mz, end_mz);
670
671            let mut out = locked_result.lock().unwrap();
672            (out[offset..offset + size]).copy_from_slice(&sub);
673        });
674        locked_result.into_inner().unwrap()
675    }
676
677    #[cfg(feature = "parallelism")]
678    #[allow(unused)]
679    pub(crate) fn interpolate_chunks_parallel(&'a self, n_chunks: usize) -> Vec<f32> {
680        let mut result = self.create_intensity_array();
681        if self.array_pairs.is_empty() {
682            return result;
683        }
684        let n_points = self.points_between(self.mz_start, self.mz_end);
685        let points_per_chunk = n_points / n_chunks;
686        let mz_chunks: Vec<&[f64]> = self.mz_grid.chunks(points_per_chunk).collect();
687        let mut intensity_chunks: Vec<&mut [f32]> = result.chunks_mut(points_per_chunk).collect();
688
689        intensity_chunks[..]
690            .par_iter_mut()
691            .zip(mz_chunks[..].par_iter())
692            .for_each(|(mut intensity_chunk, mz_chunk)| {
693                let start_mz = mz_chunk.first().unwrap();
694                // The + 1e-6 is just a gentle push to get interpolate_into to roll over to the last position in the chunk
695                let end_mz = mz_chunk.last().unwrap() + 1e-6;
696                self.interpolate_into_iter(intensity_chunk, *start_mz, end_mz);
697            });
698        result
699    }
700
701    pub fn interpolate_between(&'a self, mz_start: f64, mz_end: f64) -> (Vec<f32>, (usize, usize)) {
702        let (n_points, (start, end)) = self.points_between_with_indices(mz_start, mz_end);
703        let mut result = self.create_intensity_array_of_size(n_points);
704        self.interpolate_into_iter(&mut result, mz_start, mz_end);
705        (result, (start, end))
706    }
707
708    /// Allocate a new intensity array and interpolate the averaged representation of the collected spectra
709    /// and return it.
710    ///
711    /// ```math
712    /// y_z = \frac{y_{j} \times (x_{j} - x_{i}) + y_{i} \times (x_z - x_j)}{x_j - x_i}
713    /// ```
714    pub fn interpolate(&'a self) -> Vec<f32> {
715        let mut result = self.create_intensity_array();
716        self.interpolate_into_idx(&mut result, self.mz_start, self.mz_end);
717        result
718    }
719
720    #[allow(unused)]
721    pub fn interpolate_iter(&'a self) -> Vec<f32> {
722        let mut result = self.create_intensity_array();
723        self.interpolate_into_iter(&mut result, self.mz_start, self.mz_end);
724        result
725    }
726}
727
728impl MZGrid for SignalAverager<'_> {
729    fn mz_grid(&self) -> &[f64] {
730        &self.mz_grid
731    }
732}
733impl MZInterpolator for SignalAverager<'_> {}
734
735impl<'lifespan> Extend<ArrayPair<'lifespan>> for SignalAverager<'lifespan> {
736    fn extend<T: IntoIterator<Item = ArrayPair<'lifespan>>>(&mut self, iter: T) {
737        self.array_pairs.extend(iter)
738    }
739}
740
741// Can't inline cfg-if
742cfg_if::cfg_if! {
743    if #[cfg(feature = "parallelism")] {
744        fn average_signal_inner(averager: &SignalAverager, n: usize) -> Vec<f32> {
745            averager.interpolate_chunks_parallel(3 + n)
746        }
747    } else {
748        fn average_signal_inner(averager: &SignalAverager, _n: usize) -> Vec<f32> {
749            averager.interpolate()
750        }
751    }
752}
753
754/// Average together signal from the slice of `ArrayPair`s with spacing `dx` and create
755/// a new `ArrayPair` from it
756pub fn average_signal<'lifespan, 'owned: 'lifespan>(
757    signal: &[ArrayPair<'lifespan>],
758    dx: f64,
759) -> ArrayPair<'owned> {
760    let (mz_min, mz_max) = signal.iter().fold((f64::infinity(), 0.0), |acc, x| {
761        (
762            if acc.0 < x.min_mz { acc.0 } else { x.min_mz },
763            if acc.1 > x.max_mz { acc.1 } else { x.max_mz },
764        )
765    });
766    let mut averager = SignalAverager::new(mz_min, mz_max, dx);
767    averager
768        .array_pairs
769        .extend(signal.iter().map(|a| a.borrow()));
770    let signal = average_signal_inner(&averager, signal.len());
771    ArrayPair::new(Cow::Owned(averager.copy_mz_array()), Cow::Owned(signal))
772}
773
774#[inline(never)]
775pub fn rebin<'transient, 'lifespan: 'transient>(
776    mz_array: &'lifespan [f64],
777    intensity_array: &'lifespan [f32],
778    dx: f64,
779) -> ArrayPair<'transient> {
780    let pair = [ArrayPair::from((mz_array, intensity_array))];
781    average_signal(&pair, dx)
782}
783
/// A segment over a signal array pair, delimited by two positions.
#[derive(Debug, Default, Clone, Copy)]
pub struct Segment {
    /// Where the segment begins (presumably an index; confirm inclusivity at use sites).
    start: usize,
    /// Where the segment ends (presumably an index; confirm inclusivity at use sites).
    end: usize,
}
790
791/// An [`ArrayPair`] with an associated list of [`Segment`], an associated cached interpolated intensity array
792/// and a time index
793#[derive(Debug, Default, Clone)]
794pub struct ArrayPairWithSegments<'a> {
795    /// The original signal
796    pub array_pair: ArrayPair<'a>,
797    /// The segments over the interpolated coordinate system that there was signal for
798    pub segments: Vec<Segment>,
799    /// The interpolated signal
800    pub intensity_array: Vec<f32>,
801    /// The time point associated with this signal
802    pub time: f64,
803}
804
805impl PartialEq for ArrayPairWithSegments<'_> {
806    fn eq(&self, other: &Self) -> bool {
807        self.time == other.time && self.intensity_array == other.intensity_array
808    }
809}
810
811impl PartialOrd for ArrayPairWithSegments<'_> {
812    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
813        Some(self.cmp(other))
814    }
815}
816
817impl Eq for ArrayPairWithSegments<'_> {}
818
819impl Ord for ArrayPairWithSegments<'_> {
820    fn cmp(&self, other: &Self) -> cmp::Ordering {
821        self.time.total_cmp(&other.time)
822    }
823}
824
825impl CoordinateLike<Time> for ArrayPairWithSegments<'_> {
826    fn coordinate(&self) -> f64 {
827        self.time
828    }
829}
830
831/// A linear interpolation spectrum intensity averager over a shared m/z axis that pre-computes
832/// a segmented intensity grid.
833#[derive(Debug, Default, Clone)]
834pub struct SegmentGridSignalAverager<'lifespan> {
835    /// The evenly spaced m/z axis over which spectra are averaged.
836    pub mz_grid: Vec<f64>,
837    /// The lowest m/z in the spectrum. If an input spectrum has lower m/z values, they will be ignored.
838    pub mz_start: f64,
839    /// The highest m/z in the spectrum. If an input spectrum has higher m/z values, they will be ignored.
840    pub mz_end: f64,
841    /// The spacing between m/z values in `mz_grid`. This value should be chosen relative to the sharpness
842    /// of the peak shape of the mass analyzer used, but the smaller it is, the more computationally intensive
843    /// the averaging process is, and the more memory it consumes.
844    pub dx: f64,
845    /// The current set of spectra to be averaged over
846    pub array_pairs: Vec<ArrayPairWithSegments<'lifespan>>,
847}
848
849impl<'lifespan> Extend<(f64, ArrayPair<'lifespan>)> for SegmentGridSignalAverager<'lifespan> {
850    fn extend<T: IntoIterator<Item = (f64, ArrayPair<'lifespan>)>>(&mut self, iter: T) {
851        for (time, block) in iter {
852            self.push(time, block)
853        }
854    }
855}
856
857impl<'a, 'lifespan: 'a> SegmentGridSignalAverager<'lifespan> {
858    pub fn new(mz_start: f64, mz_end: f64, dx: f64) -> Self {
859        Self {
860            mz_grid: gridspace(mz_start, mz_end, dx),
861            mz_start,
862            mz_end,
863            dx,
864            array_pairs: Vec::new(),
865        }
866    }
867
868    pub fn from_iter<I: Iterator<Item = (f64, ArrayPair<'lifespan>)>>(
869        mz_start: f64,
870        mz_end: f64,
871        dx: f64,
872        iter: I,
873    ) -> Self {
874        let mut inst = Self::new(mz_start, mz_end, dx);
875        inst.extend(iter);
876        inst
877    }
878
879    pub fn find_time(&self, time: f64) -> Option<usize> {
880        let i = self
881            .array_pairs
882            .binary_search_by(|block| block.time.total_cmp(&time));
883        match i {
884            Ok(i) => Some(i),
885            Err(i) => (i.saturating_sub(2)..(i + 2).min(self.array_pairs.len()))
886                .min_by(|i, j| {
887                    let err_i = self
888                        .array_pairs
889                        .get(*i)
890                        .map(|block| (block.time - time).abs())
891                        .unwrap_or(f64::INFINITY);
892                    let err_j = self
893                        .array_pairs
894                        .get(*j)
895                        .map(|block| (block.time - time).abs())
896                        .unwrap_or(f64::INFINITY);
897                    err_i.total_cmp(&err_j)
898                }),
899        }
900    }
901
902    pub fn time_at(&self, index: usize) -> Option<f64> {
903        self.array_pairs.get(index).map(|block| block.time)
904    }
905
906    pub fn push(&mut self, time: f64, block: ArrayPair<'lifespan>) {
907        let block = self.populate_intensity_axis(time, block);
908        self.push_block(block);
909    }
910
911    fn push_block(&mut self, block: ArrayPairWithSegments<'lifespan>) {
912        if let Some(time) = self.array_pairs.last().map(|block| block.time) {
913            if time < block.time {
914                self.array_pairs.push(block)
915            } else {
916                self.array_pairs.push(block);
917                self.array_pairs.sort();
918            }
919        } else {
920            self.array_pairs.push(block)
921        }
922    }
923
924    pub fn len(&self) -> usize {
925        self.array_pairs.len()
926    }
927
928    pub fn is_empty(&self) -> bool {
929        self.array_pairs.is_empty()
930    }
931
932    fn populate_intensity_axis(
933        &self,
934        time: f64,
935        block: ArrayPair<'lifespan>,
936    ) -> ArrayPairWithSegments<'lifespan> {
937        let mut segments = Vec::default();
938        if block.is_empty() {
939            return ArrayPairWithSegments {
940                array_pair: block,
941                segments,
942                intensity_array: Vec::new(),
943                time,
944            };
945        }
946        let mut segment = Segment::default();
947        let mut opened = false;
948
949        let n = self.mz_grid.len();
950
951        let mut intensity_axis_ = self.create_intensity_array();
952        let intensity_axis = &mut intensity_axis_[0..n];
953
954        let mut block_searcher = MonotonicBlockSearcher::new(&block);
955        for (i, x) in self.mz_grid.iter().copied().enumerate() {
956            let j = block_searcher.find(x);
957            let mz_j = block.mz_array[j];
958
959            let (mz_j, inten_j, mz_j1, inten_j1) = if (mz_j <= x) && ((j + 1) < block.len()) {
960                (
961                    mz_j,
962                    block.intensity_array[j],
963                    block.mz_array[j + 1],
964                    block.intensity_array[j + 1],
965                )
966            } else if mz_j > x && j > 0 {
967                (
968                    block.mz_array[j - 1],
969                    block.intensity_array[j - 1],
970                    mz_j,
971                    block.intensity_array[j],
972                )
973            } else {
974                continue;
975            };
976            let interp = self.interpolate_point(mz_j, x, mz_j1, inten_j as f64, inten_j1 as f64);
977            intensity_axis[i] = interp as f32;
978            if interp > 0.0 {
979                if opened {
980                    segment.end = i;
981                } else {
982                    segment.start = i;
983                    opened = true;
984                }
985            } else if opened {
986                segment.end = i;
987                opened = false;
988                segments.push(segment);
989                segment = Segment::default();
990            }
991        }
992        if opened {
993            segment.end = n;
994            segments.push(segment);
995        }
996        ArrayPairWithSegments {
997            array_pair: block,
998            segments,
999            intensity_array: intensity_axis_,
1000            time,
1001        }
1002    }
1003
1004    pub fn iter(&'a self, width: usize) -> SegmentGridSignalAveragerIter<'a> {
1005        SegmentGridSignalAveragerIter {
1006            averager: self,
1007            index: 0,
1008            width,
1009        }
1010    }
1011
1012    pub fn average_over(&'a self, time: f64, width: usize) -> ArrayPairSplit<'a, 'static> {
1013        if let Some(i) = self.find_time(time) {
1014            self.average_over_index(i, width)
1015        } else {
1016            ArrayPairSplit::default()
1017        }
1018    }
1019
1020    pub fn average_over_index(&'a self, index: usize, width: usize) -> ArrayPairSplit<'a, 'static> {
1021        let blocks = &self.array_pairs
1022            [index.saturating_sub(width)..(index + width).min(self.array_pairs.len())];
1023        self.average_segments(blocks)
1024    }
1025
1026    pub fn average_segments(
1027        &'a self,
1028        segments: &[ArrayPairWithSegments],
1029    ) -> ArrayPairSplit<'a, 'static> {
1030        let (offset, end) = segments.iter().fold((usize::MAX, 0), |(start, end), seg| {
1031            let start = seg
1032                .segments
1033                .first()
1034                .map(|s| start.min(s.start))
1035                .unwrap_or(start);
1036            let end = seg.segments.last().map(|s| end.max(s.end)).unwrap_or(end);
1037            (start, end)
1038        });
1039
1040        if offset >= end {
1041            return (Vec::new(), Vec::new()).into();
1042        }
1043
1044        let mut intensity_array = self.create_intensity_array_of_size(end - offset);
1045
1046        for seg in segments.iter() {
1047            for sg in seg.segments.iter() {
1048                for i in sg.start..sg.end {
1049                    intensity_array[i.saturating_sub(offset)] += seg.intensity_array[i];
1050                }
1051            }
1052        }
1053
1054        let mz_array = Cow::Borrowed(&self.mz_grid[offset..end]);
1055
1056        ArrayPairSplit::new(mz_array, Cow::Owned(intensity_array))
1057    }
1058}
1059
1060impl MZGrid for SegmentGridSignalAverager<'_> {
1061    fn mz_grid(&self) -> &[f64] {
1062        &self.mz_grid
1063    }
1064}
1065impl MZInterpolator for SegmentGridSignalAverager<'_> {}
1066
1067#[derive(Debug)]
1068pub struct SegmentGridSignalAveragerIter<'lifespan> {
1069    averager: &'lifespan SegmentGridSignalAverager<'lifespan>,
1070    width: usize,
1071    index: usize,
1072}
1073
1074impl<'lifespan> Iterator for SegmentGridSignalAveragerIter<'lifespan> {
1075    type Item = (f64, ArrayPairSplit<'lifespan, 'static>);
1076
1077    fn next(&mut self) -> Option<Self::Item> {
1078        if self.index < self.averager.len() {
1079            let block = self.averager.average_over_index(self.index, self.width);
1080            let time = self.averager.time_at(self.index).unwrap();
1081            self.index += 1;
1082            Some((time, block))
1083        } else {
1084            None
1085        }
1086    }
1087}
1088
1089#[cfg(test)]
1090mod test {
1091    use std::io;
1092
1093    use mzpeaks::MZPeakSetType;
1094
1095    use super::*;
1096    use crate::peak_picker::PeakPicker;
1097    use crate::test_data::{X, Y};
1098    #[allow(unused)]
1099    use crate::text;
1100    use crate::FittedPeak;
1101
1102    #[test]
1103    fn test_rebin_one() {
1104        let mut averager = SignalAverager::new(X[0], X[X.len() - 1], 0.001);
1105        averager.push(ArrayPair::wrap(&X, &Y));
1106        let yhat = averager.interpolate();
1107        // text::arrays_to_file(ArrayPair::wrap(&averager.mz_grid, &yhat), "interpolate_avx.txt").unwrap();
1108        let picker = PeakPicker::default();
1109        let mut acc = Vec::new();
1110        picker
1111            .discover_peaks(&averager.mz_grid, &yhat, &mut acc)
1112            .expect("Signal can be picked");
1113        let mzs = [180.0633881, 181.06387399204235, 182.06404644991485];
1114        for (i, (peak, mz)) in acc.iter().zip(mzs.iter()).enumerate() {
1115            let diff = peak.mz - mz;
1116            assert!((peak.mz - mz).abs() < 1e-4, "Diff {} on peak {i}", diff);
1117            assert!(peak.intensity > 0.0);
1118        }
1119    }
1120
1121    #[test]
1122    fn test_averaging() -> io::Result<()> {
1123        let scans = text::arrays_over_time_from_file("./test/data/profiles.txt")?;
1124        let scans: Vec<_> = scans
1125            .into_iter()
1126            .skip(3)
1127            .take(3)
1128            .map(|(_, arrays)| arrays)
1129            .collect();
1130
1131        let low_mz = scans
1132            .iter()
1133            .map(|s| s.min_mz)
1134            .min_by(|a, b| a.total_cmp(b))
1135            .unwrap();
1136        let high_mz = scans
1137            .iter()
1138            .map(|s| s.max_mz)
1139            .max_by(|a, b| a.total_cmp(b))
1140            .unwrap();
1141
1142        let mut averager = SignalAverager::new(low_mz, high_mz, 0.001);
1143        averager.extend(scans.clone());
1144
1145        let _yhat = averager.interpolate();
1146        Ok(())
1147    }
1148
1149    #[test]
1150    fn test_rebin_chunked() {
1151        let mut averager = SignalAverager::new(X[0], X[X.len() - 1], 0.00001);
1152        averager.push(ArrayPair::wrap(&X, &Y));
1153        let yhat = averager.interpolate_chunks(3);
1154        // text::arrays_to_file(ArrayPair::wrap(&averager.mz_grid, &yhat), "chunked_iter.txt").unwrap();
1155        let picker = PeakPicker::default();
1156        let mut acc = Vec::new();
1157        picker
1158            .discover_peaks(&averager.mz_grid, &yhat, &mut acc)
1159            .expect("Signal can be picked");
1160        let mzs = [180.0633881, 181.06387399204235, 182.06404644991485];
1161        for (i, (peak, mz)) in acc.iter().zip(mzs.iter()).enumerate() {
1162            let diff = peak.mz - mz;
1163            assert!((peak.mz - mz).abs() < 1e-4, "Diff {} on peak {i}", diff);
1164            assert!(peak.intensity > 0.0);
1165        }
1166    }
1167
1168    #[test]
1169    #[cfg(feature = "parallelism")]
1170    fn test_rebin_parallel_locked() {
1171        let mut averager = SignalAverager::new(X[0], X[X.len() - 1], 0.00001);
1172        averager.push(ArrayPair::wrap(&X, &Y));
1173        let yhat = averager.interpolate_chunks_parallel_locked(6);
1174        let picker = PeakPicker::default();
1175        let mut acc = Vec::new();
1176        picker
1177            .discover_peaks(&averager.mz_grid, &yhat, &mut acc)
1178            .expect("Signal can be picked");
1179        let mzs = [180.0633881, 181.06387399204235, 182.06404644991485];
1180        for (i, (peak, mz)) in acc.iter().zip(mzs.iter()).enumerate() {
1181            let diff = peak.mz - mz;
1182            assert!((peak.mz - mz).abs() < 1e-4, "Diff {} on peak {i}", diff);
1183            assert!(peak.intensity > 0.0);
1184        }
1185    }
1186
1187    #[test]
1188    #[cfg(feature = "parallelism")]
1189    fn test_rebin_parallel() {
1190        let mut averager = SignalAverager::new(X[0], X[X.len() - 1], 0.001);
1191        averager.push(ArrayPair::wrap(&X, &Y));
1192        let yhat = averager.interpolate_chunks_parallel(6);
1193        let picker = PeakPicker::new(0.0, 0.0, 1.0, Default::default());
1194        let mut acc = Vec::new();
1195        picker
1196            .discover_peaks(&averager.mz_grid, &yhat, &mut acc)
1197            .expect("Signal can be picked");
1198        let mzs = [180.0633881, 181.06387399204235, 182.06404644991485];
1199        for (i, (peak, mz)) in acc.iter().zip(mzs.iter()).enumerate() {
1200            let diff = peak.mz - mz;
1201            assert!((peak.mz - mz).abs() < 1e-4, "Diff {} on peak {i}", diff);
1202            assert!(peak.intensity > 0.0);
1203        }
1204    }
1205
1206    #[test]
1207    fn test_rebin() {
1208        let pair = rebin(&X, &Y, 0.001);
1209        let (acc, _, n) = pair.mz_array().iter().copied().fold((0.0, pair.min_mz, 0), |(acc, last, n), mz| {
1210            (acc + (mz - last), mz, n + 1)
1211        });
1212        let avg = acc / (n as f64);
1213        assert!((avg - 0.0009998319327731112).abs() < 1e-6);
1214    }
1215
1216    #[test_log::test]
1217    fn test_segment_grid() -> io::Result<()> {
1218        use crate::text::arrays_over_time_from_file;
1219        let time_arrays = arrays_over_time_from_file("./test/data/peaks_over_time.txt")?;
1220
1221        let reprofiler = crate::reprofile::PeakSetReprofiler::new(200.0, 2000.0, 0.001);
1222
1223        let prepare_block = |t: f64, row: ArrayPair| {
1224            // log::info!("{i}: {t} with {} peaks", row.len());
1225            let peaks: MZPeakSetType<FittedPeak> = row
1226                .mz_array
1227                .iter()
1228                .zip(row.intensity_array.iter())
1229                .map(|(mz, i)| FittedPeak::new(*mz, *i, 0, *i, 0.005))
1230                .collect();
1231
1232            // log::info!("Reprofiling");
1233            let peak_models = reprofiler
1234                .build_peak_shape_models(peaks.as_slice(), crate::reprofile::PeakShape::Gaussian);
1235            let block = reprofiler.reprofile_from_models(&peak_models);
1236            (t, block)
1237        };
1238
1239        let mut t_blocks: Vec<(f64, ArrayPair<'_>)> = time_arrays
1240            .into_iter()
1241            .take(5)
1242            .map(|(t, row)| prepare_block(t, row))
1243            .collect();
1244
1245        t_blocks.sort_by(|a, b| a.0.total_cmp(&b.0));
1246        let mut averager = SegmentGridSignalAverager::from_iter(200.0, 2000.0, 0.001, t_blocks.into_iter());
1247        averager.array_pairs.sort();
1248
1249        // log::info!("Start averaging");
1250        let views: Vec<_> = averager.iter(1).collect();
1251        assert_eq!(views.len(), 5);
1252
1253        views.iter().for_each(|(_, block)| {
1254            assert!(block.intensity_array.iter().all(|i| *i >= 0.0));
1255        });
1256        Ok(())
1257    }
1258}