simd_minimizers/sliding_min.rs

//! Sliding window minimum over windows of size `w`.
//!
//! For each window, the absolute position of the minimum is returned.
//!
//! Each method takes a `LEFT: bool` const generic. Set to `true` to break ties
//! towards the leftmost minimum, and `false` for the rightmost minimum.
//!
//! All these methods take 32 bit input values, **but they only use the upper 16 bits!**
//!
//! Positions are returned as absolute `u32` indices.
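//!
//! A minimal usage sketch (not from the original docs; `w` and `vals` are placeholder
//! inputs, and the leftmost-minimum scalar mapper is used):
//!
//! ```ignore
//! let w = 5;
//! let vals: Vec<u32> = get_hashes(); // hypothetical source of 32-bit hash values
//! let mut cache = Cache::default();
//! let mut mapper = sliding_min_mapper_scalar::<true>(w, vals.len(), &mut cache);
//! // The i'th output is the absolute position of the minimum over the last `w`
//! // values; the first `w-1` outputs correspond to incomplete windows.
//! let min_positions: Vec<u32> = vals.iter().map(|&v| mapper(v)).collect();
//! ```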
use crate::S;
use core::array::from_fn;
use std::hint::assert_unchecked;

/// A custom RingBuf implementation that has a fixed size `w` and wraps around.
#[derive(Default, Clone)]
struct RingBuf<V> {
    w: usize,
    idx: usize,
    data: Vec<V>,
}

impl<V: Clone> RingBuf<V> {
    #[allow(unused)]
    #[inline(always)]
    fn new(w: usize, v: V) -> Self {
        assert!(w > 0);
        let data = vec![v; w];
        Self { w, idx: 0, data }
    }

    #[inline(always)]
    fn assign(&mut self, w: usize, v: V) -> &mut Self {
        assert!(w > 0);
        self.w = w;
        self.idx = 0;
        self.data.clear();
        self.data.resize(w, v);
        self
    }

    /// Returns the next index to be written.
    #[inline(always)]
    const fn idx(&self) -> usize {
        self.idx
    }

    #[inline(always)]
    fn push(&mut self, v: V) {
        *unsafe { self.data.get_unchecked_mut(self.idx) } = v;
        self.idx += 1;
        if self.idx == self.w {
            self.idx = 0;
        }
    }
}

/// A RingBuf can be used as a slice.
impl<V> std::ops::Deref for RingBuf<V> {
    type Target = [V];

    #[inline(always)]
    fn deref(&self) -> &[V] {
        &self.data
    }
}

/// A RingBuf can be used as a mutable slice.
impl<V> std::ops::DerefMut for RingBuf<V> {
    #[inline(always)]
    fn deref_mut(&mut self) -> &mut [V] {
        &mut self.data
    }
}
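
// A small worked example of the wrap-around behaviour (not from the original file):
//
//     let mut rb = RingBuf::new(3, 0u32);
//     rb.push(10); rb.push(20); rb.push(30); // slots [10, 20, 30], idx wraps back to 0
//     rb.push(40);                           // overwrites slot 0: [40, 20, 30]
//     assert_eq!(&*rb, &[40, 20, 30]);       // Deref exposes the raw slots, not insertion order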

#[derive(Default, Clone)]
pub struct Cache {
    scalar: RingBuf<u32>,
    scalar_lr: RingBuf<(u32, u32)>,
    simd: RingBuf<S>,
    simd_lr: RingBuf<(S, S)>,
}

/// Scalar version: returns a mapper that, applied to an iterator over values (e.g. via `map`), yields an iterator over positions.
#[inline(always)]
pub fn sliding_min_mapper_scalar<const LEFT: bool>(
    w: usize,
    len: usize,
    cache: &mut Cache,
) -> impl FnMut(u32) -> u32 {
    assert!(w > 0);
    assert!(
        w < (1 << 15),
        "sliding_min is not tested for windows of length > 2^15."
    );
    assert!(
        len < (1 << 32),
        "sliding_min returns 32bit indices. Try splitting the input into 4GB chunks first."
    );
    let mut prefix_min = u32::MAX;
    let ring_buf = cache.scalar.assign(w, prefix_min);
    // We only compare the upper 16 bits of each hash.
    // Ties are broken automatically: towards the lower pos when `LEFT`, and
    // towards the higher pos otherwise.
    let val_mask = 0xffff_0000;
    let pos_mask = 0x0000_ffff;
    let mut pos = 0;
    let max_pos = (1 << 16) - 1;
    let mut pos_offset = 0;

    fn min<const LEFT: bool>(a: u32, b: u32) -> u32 {
        if LEFT { a.min(b) } else { a.max(b) }
    }

    #[inline(always)]
    move |val| {
        // Make sure the position does not interfere with the hash value.
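        // `delta` below is chosen so that, after renormalization, the `w`
        // buffered positions map to 1..=w and the next `pos` becomes w+1
        // (e.g. for w = 5: positions 0xfffa..=0xfffe become 1..=5 and `pos`
        // becomes 6), so everything stays within the low 16 bits. `pos_offset`
        // grows by the same `delta`, so the absolute positions returned below
        // are unchanged.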
        if pos == max_pos {
            let delta = ((1 << 16) - 2 - w) as u32;
            pos -= delta;
            prefix_min -= delta;
            pos_offset += delta;
            for x in &mut **ring_buf {
                *x -= delta;
            }
        }
        let elem = (if LEFT { val } else { !val } & val_mask) | pos;
        pos += 1;
        ring_buf.push(elem);
        prefix_min = min::<LEFT>(prefix_min, elem);
        // After a chunk has been filled, compute suffix minima.
        if ring_buf.idx() == 0 {
            let mut suffix_min = ring_buf[w - 1];
            for i in (0..w - 1).rev() {
                suffix_min = min::<LEFT>(suffix_min, ring_buf[i]);
                ring_buf[i] = suffix_min;
            }
            prefix_min = elem; // slightly faster than assigning u32::MAX
        }
        let suffix_min = unsafe { *ring_buf.get_unchecked(ring_buf.idx()) };
        (min::<LEFT>(prefix_min, suffix_min) & pos_mask) + pos_offset
    }
}

#[inline(always)]
pub fn sliding_lr_min_mapper_scalar(
    w: usize,
    len: usize,
    cache: &mut Cache,
) -> impl FnMut(u32) -> (u32, u32) {
    assert!(w > 0);
    assert!(
        w < (1 << 15),
        "sliding_min is not tested for windows of length > 2^15."
    );
    assert!(
        len < (1 << 32),
        "sliding_min returns 32bit indices. Try splitting the input into 4GB chunks first."
    );
    let mut prefix_lr_min = (u32::MAX, u32::MAX);
    let ring_buf = cache.scalar_lr.assign(w, prefix_lr_min);
    // We only compare the upper 16 bits of each hash.
    // Ties are broken automatically: towards the lower pos for the leftmost
    // minimum, and towards the higher pos for the rightmost minimum.
    let val_mask = 0xffff_0000;
    let pos_mask = 0x0000_ffff;
    let mut pos = 0;
    let max_pos = (1 << 16) - 1;
    let mut pos_offset = 0;

    fn lr_min((al, ar): (u32, u32), (bl, br): (u32, u32)) -> (u32, u32) {
        (
            std::hint::select_unpredictable(al < bl, al, bl),
            std::hint::select_unpredictable(ar > br, ar, br),
        )
    }
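    // Note on `lr_min`: the right component stores `!val`, so taking the max of
    // inverted values selects the minimum `val`; when hashes tie, the inverted
    // upper bits tie as well, and the max is decided by the position in the low
    // bits, i.e. the rightmost occurrence wins.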

    #[inline(always)]
    move |val| {
        // Make sure the position does not interfere with the hash value.
        if pos == max_pos {
            let delta = ((1 << 16) - 2 - w) as u32;
            pos -= delta;
            prefix_lr_min.0 -= delta;
            prefix_lr_min.1 -= delta;
            pos_offset += delta;
            for x in &mut **ring_buf {
                x.0 -= delta;
                x.1 -= delta;
            }
        }
        let lelem = (val & val_mask) | pos;
        let relem = (!val & val_mask) | pos;
        let elem = (lelem, relem);
        pos += 1;
        ring_buf.push(elem);
        prefix_lr_min = lr_min(prefix_lr_min, elem);
        // After a chunk has been filled, compute suffix minima.
        if ring_buf.idx() == 0 {
            let mut suffix_min = ring_buf[w - 1];
            for i in (0..w - 1).rev() {
                suffix_min = lr_min(suffix_min, ring_buf[i]);
                ring_buf[i] = suffix_min;
            }
            prefix_lr_min = elem; // slightly faster than assigning (u32::MAX, u32::MAX)
        }
        let suffix_min = unsafe { *ring_buf.get_unchecked(ring_buf.idx()) };
        let (lmin, rmin) = lr_min(prefix_lr_min, suffix_min);
        (
            (lmin & pos_mask) + pos_offset,
            (rmin & pos_mask) + pos_offset,
        )
    }
}
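
// A minimal usage sketch (not from the original file; `w` and `vals` are placeholders):
//
//     let mut cache = Cache::default();
//     let mut mapper = sliding_lr_min_mapper_scalar(w, vals.len(), &mut cache);
//     // Each output is the pair (leftmost minimum position, rightmost minimum position).
//     let lr_positions: Vec<(u32, u32)> = vals.iter().map(|&v| mapper(v)).collect();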

fn simd_min<const LEFT: bool>(a: S, b: S) -> S {
    if LEFT { a.min(b) } else { a.max(b) }
}

/// Mapper version that returns a function that can be called with new inputs as needed.
/// Output values are offset by `-(k-1)`, so that the k'th returned value (the first kmer) is at position 0.
/// `len` is the number of values in each chunk. The SIMD lanes will be offset by `len-(k+w-2)`.
/// The first `k+w-2` returned values are bogus, since they correspond to incomplete windows.
pub fn sliding_min_mapper_simd<const LEFT: bool>(
    w: usize,
    len: usize,
    cache: &mut Cache,
) -> impl FnMut(S) -> S {
    assert!(w > 0);
    assert!(w < (1 << 15), "This method is not tested for large w.");
    assert!(len * 8 < (1 << 32));
    let mut prefix_min = S::splat(u32::MAX);
    let ring_buf = cache.simd.assign(w, prefix_min);
    // We only compare the upper 16 bits of each hash.
    // Ties are broken automatically: towards the lower pos when `LEFT`, and
    // towards the higher pos otherwise.
    let val_mask = S::splat(0xffff_0000);
    let pos_mask = S::splat(0x0000_ffff);
    let max_pos = S::splat((1 << 16) - 1);
    let mut pos = S::splat(0);
    // Sliding min is over w+k-1 characters, so chunks overlap w+k-2.
    // Thus, the true length of each lane is len-(k+w-2).
    //
    // The k-mer starting at position 0 is done after processing the char at
    // position k-1, so we compensate for that as well.
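    //
    // For example (hypothetical numbers), with 8 lanes (cf. the `len * 8` assert
    // above), len = 1000 and w = 5, the offsets below are l * (len - (w - 1)):
    // 0, 996, 1992, ..., 6972.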
    let mut pos_offset: S = from_fn(|l| (l * len.saturating_sub(w - 1)) as u32).into();
    let delta = S::splat((1 << 16) - 2 - w as u32);

    #[inline(always)]
    move |val| {
        // Make sure the position does not interfere with the hash value.
        if pos == max_pos {
            // Slow case extracted to a function to have better inlining here.
            reset_positions_offsets(delta, &mut pos, &mut prefix_min, &mut pos_offset, ring_buf);
        }
        let elem = (if LEFT { val } else { !val } & val_mask) | pos;
        pos += S::ONE;
        ring_buf.push(elem);
        prefix_min = simd_min::<LEFT>(prefix_min, elem);
        // After a chunk has been filled, compute suffix minima.
        if ring_buf.idx() == 0 {
            // Slow case extracted to a function to have better inlining here.
            suffix_minima::<LEFT>(ring_buf, w, &mut prefix_min, elem);
        }

        let suffix_min = unsafe { *ring_buf.get_unchecked(ring_buf.idx()) };
        (simd_min::<LEFT>(prefix_min, suffix_min) & pos_mask) + pos_offset
    }
}

fn suffix_minima<const LEFT: bool>(
    ring_buf: &mut RingBuf<S>,
    w: usize,
    prefix_min: &mut S,
    elem: S,
) {
    // Avoid some bounds checks when this function is not inlined.
    unsafe { assert_unchecked(ring_buf.len() == w) };
    unsafe { assert_unchecked(w > 0) };
    let mut suffix_min = ring_buf[w - 1];
    for i in (0..w - 1).rev() {
        suffix_min = simd_min::<LEFT>(suffix_min, ring_buf[i]);
        ring_buf[i] = suffix_min;
    }
    *prefix_min = elem; // slightly faster than assigning S::splat(u32::MAX)
}

fn reset_positions_offsets(
    delta: S,
    pos: &mut S,
    prefix_min: &mut S,
    pos_offset: &mut S,
    ring_buf: &mut RingBuf<S>,
) {
    *pos -= delta;
    *prefix_min -= delta;
    *pos_offset += delta;
    for x in &mut **ring_buf {
        *x -= delta;
    }
}

/// Like `sliding_min_mapper_simd`, but returns both the leftmost and the rightmost minimum.
pub fn sliding_lr_min_mapper_simd(
    w: usize,
    len: usize,
    cache: &mut Cache,
) -> impl FnMut(S) -> (S, S) {
    assert!(w > 0);
    assert!(w < (1 << 15), "This method is not tested for large w.");
    assert!(len * 8 < (1 << 32));
    let mut prefix_lr_min = (S::splat(u32::MAX), S::splat(u32::MAX));
    let ring_buf = cache.simd_lr.assign(w, prefix_lr_min);
    // We only compare the upper 16 bits of each hash.
    // Ties are broken automatically: towards the lower pos for the leftmost
    // minimum, and towards the higher pos for the rightmost minimum.
    let val_mask = S::splat(0xffff_0000);
    let pos_mask = S::splat(0x0000_ffff);
    let max_pos = S::splat((1 << 16) - 1);
    let mut pos = S::splat(0);
    let mut pos_offset: S = from_fn(|l| (l * len.saturating_sub(w - 1)) as u32).into();
    let delta = S::splat((1 << 16) - 2 - w as u32);

    #[inline(always)]
    move |val| {
        // Make sure the position does not interfere with the hash value.
        if pos == max_pos {
            // Slow case extracted to a function to have better inlining here.
            reset_positions_offsets_lr(
                delta,
                &mut pos,
                &mut prefix_lr_min,
                &mut pos_offset,
                ring_buf,
            );
        }
        let lelem = (val & val_mask) | pos;
        let relem = (!val & val_mask) | pos;
        let elem = (lelem, relem);
        pos += S::ONE;
        ring_buf.push(elem);
        prefix_lr_min = simd_lr_min(prefix_lr_min, elem);
        // After a chunk has been filled, compute suffix minima.
        if ring_buf.idx() == 0 {
            // Slow case extracted to a function to have better inlining here.
            suffix_lr_minima(ring_buf, w, &mut prefix_lr_min, elem);
        }

        let suffix_lr_min = unsafe { *ring_buf.get_unchecked(ring_buf.idx()) };
        let (lmin, rmin) = simd_lr_min(prefix_lr_min, suffix_lr_min);
        (
            (lmin & pos_mask) + pos_offset,
            (rmin & pos_mask) + pos_offset,
        )
    }
}

#[inline(always)]
fn simd_lr_min((al, ar): (S, S), (bl, br): (S, S)) -> (S, S) {
    (al.min(bl), ar.max(br))
}

#[inline(always)]
fn suffix_lr_minima(
    ring_buf: &mut RingBuf<(S, S)>,
    w: usize,
    prefix_min: &mut (S, S),
    elem: (S, S),
) {
    // Avoid some bounds checks when this function is not inlined.
    unsafe { assert_unchecked(ring_buf.len() == w) };
    unsafe { assert_unchecked(w > 0) };
    let mut suffix_min = ring_buf[w - 1];
    for i in (0..w - 1).rev() {
        suffix_min = simd_lr_min(suffix_min, ring_buf[i]);
        ring_buf[i] = suffix_min;
    }
    *prefix_min = elem; // slightly faster than assigning (S::splat(u32::MAX), S::splat(u32::MAX))
}

#[inline(always)]
fn reset_positions_offsets_lr(
    delta: S,
    pos: &mut S,
    prefix_min: &mut (S, S),
    pos_offset: &mut S,
    ring_buf: &mut RingBuf<(S, S)>,
) {
    *pos -= delta;
    *pos_offset += delta;
    prefix_min.0 -= delta;
    prefix_min.1 -= delta;
    for x in &mut **ring_buf {
        x.0 -= delta;
        x.1 -= delta;
    }
}
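
// A minimal test sketch (not part of the original file): checks the scalar
// leftmost-minimum mapper against a naive O(n*w) reference. The test data only
// uses the upper 16 bits, matching the comparison performed by the mappers.
#[cfg(test)]
mod sliding_min_sketch_tests {
    use super::*;

    #[test]
    fn scalar_matches_naive() {
        let w = 5;
        let vals: Vec<u32> = (0..100u32)
            .map(|i| i.wrapping_mul(0x9E37_79B9) & 0xffff_0000)
            .collect();
        let mut cache = Cache::default();
        let mut mapper = sliding_min_mapper_scalar::<true>(w, vals.len(), &mut cache);
        let got: Vec<u32> = vals.iter().map(|&v| mapper(v)).collect();
        // The first w-1 outputs correspond to incomplete windows; skip them.
        for i in (w - 1)..vals.len() {
            let window = &vals[i + 1 - w..=i];
            // Leftmost position of the minimum value in the window.
            let naive = (i + 1 - w)
                + window
                    .iter()
                    .enumerate()
                    .min_by_key(|&(j, &v)| (v, j))
                    .unwrap()
                    .0;
            assert_eq!(got[i] as usize, naive);
        }
    }
}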