sassy 0.2.1

Approximate string matching using SIMD
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
use crate::pattern_tiling::backend::SimdBackend;
use crate::pattern_tiling::minima::{TracePostProcess, local_minima_indices};
use crate::pattern_tiling::search::{HitRange, Myers};
use crate::pattern_tiling::tqueries::TQueries;
use crate::profiles::Profile;
use crate::search::Match;
use crate::search::Strand;
use crate::search::get_overhang_steps;
use crate::trace::{CostLookup, get_trace};
use pa_types::Cost;

/// Recorded Myers state for one lane: one entry per text column processed for that lane.
///
/// Each entry stores the full SIMD vectors (all lanes); the lane's own word is extracted at
/// lookup time.
pub struct PatternHistory<S>
where
    S: Copy,
{
    pub steps: Vec<SimdHistoryStep<S>>,
}

impl<S> Default for PatternHistory<S>
where
    S: Copy,
{
    /// An empty history with no recorded steps.
    fn default() -> Self {
        PatternHistory {
            steps: Vec::default(),
        }
    }
}

/// The `vp` / `vn` bit-vectors produced by one Myers column step.
///
/// The cost lookup uses them as positive/negative vertical deltas:
/// `popcount(vp & mask) - popcount(vn & mask)`.
pub struct SimdHistoryStep<S>
where
    S: Copy,
{
    pub vp: S,
    pub vn: S,
}

/// Reusable scratch space for tracing one batch of hit ranges, one slot per SIMD lane.
///
/// After [`TraceBuffer::populate`], slots `0..filled_till` describe the current batch. The
/// remaining vectors are scratch storage reused across lanes and batches to avoid reallocating.
pub struct TraceBuffer {
    /// Query (pattern) index handled by each slot.
    pub pattern_indices: Vec<usize>,
    /// Text window `(start, end)` per slot: the hit range widened to the left by `left_buffer`
    /// and clamped at 0.
    pub approx_slices: Vec<(isize, isize)>,
    /// Original `(start, end)` of each hit range; `start` can be -1 for a prefix overhang.
    pub range_bounds: Vec<(isize, isize)>,
    /// Per-slot alignment lists.
    pub per_range_alignments: Vec<Vec<Match>>,
    /// Alignments produced for the current batch.
    pub filtered_alignments: Vec<Match>,
    /// Scratch: subset of `(position, cost)` pairs selected for tracing.
    pub temp_pos_cost: Vec<(isize, isize)>,
    /// Number of slots filled by the last `populate` call.
    pub filled_till: usize,
    /// Scratch: passing `(position, cost)` pairs for the lane being processed.
    pub pos_cost_buffer: Vec<(isize, isize)>,
    /// Scratch: indices into `pos_cost_buffer` selected as local minima.
    pub minima_indices_buffer: Vec<usize>,
}

impl TraceBuffer {
    /// Creates a buffer with `lanes` pre-sized slots and empty scratch vectors.
    pub fn new(lanes: usize) -> Self {
        let per_range_alignments = (0..lanes).map(|_| Vec::new()).collect();
        Self {
            pattern_indices: vec![0; lanes],
            approx_slices: vec![(0, 0); lanes],
            range_bounds: vec![(0, 0); lanes],
            per_range_alignments,
            filtered_alignments: Vec::with_capacity(10),
            temp_pos_cost: Vec::new(),
            filled_till: 0,
            pos_cost_buffer: Vec::new(),
            minima_indices_buffer: Vec::new(),
        }
    }

    /// Empties every per-range alignment list, `filtered_alignments` and `temp_pos_cost`, and
    /// marks no slots as filled.
    #[inline(always)]
    pub fn clear_alns(&mut self) {
        self.per_range_alignments.iter_mut().for_each(Vec::clear);
        self.filtered_alignments.clear();
        self.temp_pos_cost.clear();
        self.filled_till = 0;
    }

    /// Loads one slot per entry of `ranges`, growing the slot vectors if needed.
    ///
    /// Each slot's scan window starts `left_buffer` characters before the range start
    /// (never below 0) and ends at the range end.
    #[inline(always)]
    pub fn populate(&mut self, ranges: &[HitRange], left_buffer: usize) {
        let needed = ranges.len();
        if needed > self.pattern_indices.len() {
            self.pattern_indices.resize(needed, 0);
            self.approx_slices.resize(needed, (0, 0));
            self.range_bounds.resize(needed, (0, 0));
            self.per_range_alignments.resize_with(needed, Vec::new);
        }

        let pad = left_buffer as isize;
        for (slot, hit) in ranges.iter().enumerate() {
            self.pattern_indices[slot] = hit.pattern_idx;
            // Window that is guaranteed to contain the whole alignment, clamped to the text start.
            self.approx_slices[slot] = (hit.start.saturating_sub(pad).max(0), hit.end);
            // Unclamped bounds (start may be -1 for a prefix overhang).
            self.range_bounds[slot] = (hit.start, hit.end);
            self.per_range_alignments[slot].clear();
        }
        self.filled_till = needed;
    }
}

/// [`CostLookup`] backed by the Myers history recorded for a single SIMD lane.
pub(crate) struct V2CostLookup<'a, B: SimdBackend, P: Profile> {
    searcher: &'a Myers<B, P>,
    lane_idx: usize,
}

impl<B: SimdBackend, P: Profile> CostLookup for V2CostLookup<'_, B, P> {
    /// Cost for the first `j` pattern characters after `i` text columns.
    ///
    /// - `j == 0` costs 0.
    /// - Column `i == 0` uses the popcount of the low `j` bits of `alpha_pattern`.
    /// - Otherwise, history step `i - 1` gives `popcount(vp) - popcount(vn)` over those bits.
    #[inline(always)]
    fn get(&self, i: usize, j: usize) -> Cost {
        if j == 0 {
            return 0;
        }

        // Select the lowest `j` bits (one per pattern row); saturate at a full word.
        let rows_mask = if j >= 64 { !0u64 } else { (1u64 << j) - 1 };

        match i.checked_sub(1) {
            None => (self.searcher.alpha_pattern & rows_mask).count_ones() as Cost,
            Some(step) => {
                let recorded = &self.searcher.history[self.lane_idx].steps[step];
                let plus = extract_simd_lane::<B>(recorded.vp, self.lane_idx) & rows_mask;
                let minus = extract_simd_lane::<B>(recorded.vn, self.lane_idx) & rows_mask;
                plus.count_ones() as Cost - minus.count_ones() as Cost
            }
        }
    }
}

/// Advances block 0 past the end of the text by `overhang_steps` columns.
///
/// Each step uses an all-ones equality vector: every pattern character "matches" the virtual
/// text beyond the end. The resulting `vp`/`vn` vectors are appended to the history of each of
/// the first `batch_size` lanes, so costs can be looked up for end positions beyond the text.
/// The block's `cost` field is not updated, because the step's cost output is discarded.
///
/// # Panics
/// Panics if `searcher.blocks` is empty, or if `searcher.history` has fewer than `batch_size`
/// entries.
#[inline(always)]
fn handle_suffix_overhangs<B: SimdBackend, P: Profile>(
    searcher: &mut Myers<B, P>,
    last_bit_shift: u32,
    last_bit_mask: B::Simd,
    batch_size: usize,
    overhang_steps: usize,
) {
    let all_ones = searcher.all_ones;
    // Borrow the block and the lane histories as disjoint fields. Indexing replaces the previous
    // raw-pointer deref, which was UB on an empty `blocks` vector.
    let block = &mut searcher.blocks[0];
    let history = &mut searcher.history;

    for _ in 0..overhang_steps {
        let (vp_out, vn_out, _cost_out) = Myers::<B, P>::myers_step(
            block.vp,
            block.vn,
            block.cost,
            all_ones, // eq = 1111..
            all_ones,
            last_bit_shift,
            last_bit_mask,
        );
        for lane_history in &mut history[..batch_size] {
            lane_history.steps.push(SimdHistoryStep {
                vp: vp_out,
                vn: vn_out,
            });
        }
        block.vp = vp_out;
        block.vn = vn_out;
    }
}

/// Traces one alignment per `(end_position, cost)` entry and appends them to `out` in order.
///
/// Every trace uses the same lane, pattern and window start `approx_start`; only the end
/// position varies. The cost component is ignored here.
#[inline(always)]
fn traceback_positions<B: SimdBackend, P: Profile>(
    positions_and_costs: &[(isize, isize)],
    searcher: &Myers<B, P>,
    lane: usize,
    pattern_idx: usize,
    t_queries: &TQueries<B, P>,
    approx_start: isize,
    text: &[u8],
    out: &mut Vec<Match>,
) {
    out.extend(positions_and_costs.iter().map(|&(end_pos, _)| {
        traceback_single(
            searcher,
            lane,
            pattern_idx,
            t_queries,
            (approx_start, end_pos),
            text,
        )
    }));
}

/// Traces alignments for every lane's passing end positions and appends the results to
/// `buffer.filtered_alignments`.
///
/// Steps for each of the first `batch_size` lanes:
/// 1. Evaluate the cost of the full pattern at every end position in the lane's range
///    bounds, using the recorded history.
/// 2. End positions past the last text character get an extra `floor(overshoot * alpha)` when
///    `searcher.alpha != 1.0`.
/// 3. Keep positions with cost `<= k`.
/// 4. Trace either all kept positions or only the local minima, depending on `post`.
#[inline(always)]
fn trace_passing_alignments<B: SimdBackend, P: Profile>(
    batch_size: usize,
    buffer: &mut TraceBuffer,
    searcher: &Myers<B, P>,
    t_queries: &TQueries<B, P>,
    text: &[u8],
    k: u32,
    post: TracePostProcess,
) {
    let pattern_length = t_queries.pattern_length;
    let k_isize = k as isize;
    // Index of the last text character, or -1 for an empty text. Then every end position,
    // including 0, counts as overshooting it. (`saturating_sub(1)` gave 0 here, which
    // undercharged every overhang on empty text by one character.)
    let max_valid_text_pos = text.len() as isize - 1;

    for lane in 0..batch_size {
        // r_start can be -1 in case of full prefix overhang (before text)
        let (r_start, r_end) = buffer.range_bounds[lane];
        let approx_start = buffer.approx_slices[lane].0;
        let pattern_idx = buffer.pattern_indices[lane];

        let cost_lookup = V2CostLookup {
            searcher,
            lane_idx: lane,
        };

        // Collect all passing (position, cost) pairs for this lane.
        buffer.pos_cost_buffer.clear();
        for pos in r_start..=r_end {
            // Column index into the history: step `pos - approx_start` is column `+1`.
            let i = (pos - approx_start + 1).max(0) as usize;
            let mut cost = cost_lookup.get(i, pattern_length) as isize;

            // Apply overhang correction, fixme: can we do this neater in cost lookup?
            if pos > max_valid_text_pos && searcher.alpha != 1.0 {
                let overshoot = (pos - max_valid_text_pos) as usize;
                cost += (overshoot as f32 * searcher.alpha).floor() as isize;
            }

            if cost <= k_isize {
                buffer.pos_cost_buffer.push((pos, cost));
            }
        }

        // Pick which positions to trace; both branches feed the same traceback call.
        let selected: &[(isize, isize)] = match post {
            TracePostProcess::All => &buffer.pos_cost_buffer,
            TracePostProcess::LocalMinima => {
                local_minima_indices(&buffer.pos_cost_buffer, &mut buffer.minima_indices_buffer);
                buffer.temp_pos_cost.clear();
                buffer.temp_pos_cost.extend(
                    buffer
                        .minima_indices_buffer
                        .iter()
                        .map(|&idx| buffer.pos_cost_buffer[idx]),
                );
                &buffer.temp_pos_cost
            }
        };

        traceback_positions::<B, P>(
            selected,
            searcher,
            lane,
            pattern_idx,
            t_queries,
            approx_start,
            text,
            &mut buffer.filtered_alignments,
        );
    }
}

/// Trace alignments for hit ranges using a single SIMD forward pass per range.
///
/// Each of the (at most `B::LANES`) `ranges` gets one SIMD lane of block 0.
///
/// - **Forward pass.** Every lane starts at its approximate slice start, which is
///   `pattern_length + k` characters left of the range start (clamped at 0), and advances one
///   text column per step for `max_len` steps. A lane whose current position lies outside the
///   text is frozen: its state is kept and no history entry is recorded. Non-frozen lanes get
///   the step's `vp`/`vn` appended to `searcher.history[lane]`.
/// - **Suffix overhangs.** When `alpha` is `Some`, the state is further extended beyond the
///   text end.
/// - **Tracing.** Passing end positions are traced, either all or only the local minima per
///   `post`. The matches end up in `buffer.filtered_alignments`.
///
/// # Panics
/// Panics if `ranges.len() > B::LANES`, or if `searcher.blocks` is empty after
/// `ensure_capacity` (presumably impossible — TODO confirm against `ensure_capacity`).
#[inline(always)]
pub fn trace_batch_ranges<B: SimdBackend, P: Profile>(
    searcher: &mut Myers<B, P>,
    t_queries: &TQueries<B, P>,
    text: &[u8],
    ranges: &[HitRange],
    k: u32,
    post: TracePostProcess,
    alpha: Option<f32>,
    max_overhang: Option<usize>,
    buffer: &mut TraceBuffer,
) {
    assert!(ranges.len() <= B::LANES, "Batch size must be <= LANES");

    if ranges.is_empty() {
        return;
    }

    // How far (at most) we have to move to the left to capture the full alignment
    let left_buffer = t_queries.pattern_length + k as usize;
    buffer.clear_alns();
    buffer.populate(ranges, left_buffer);
    let batch_size = buffer.filled_till;

    // We only have a single block with B::LANES items
    searcher.ensure_capacity(1, batch_size);

    // Prep allocs for single block search
    let length_mask = (!0u64) >> (64usize.saturating_sub(t_queries.pattern_length));
    searcher.search_prep(
        1,
        t_queries.n_queries,
        t_queries.pattern_length,
        searcher.alpha_pattern & length_mask,
    );

    // Number of forward steps: the longest approximate slice. A malformed slice (end before
    // start) counts as empty instead of wrapping around to a huge usize.
    let max_len = buffer
        .approx_slices
        .iter()
        .take(batch_size)
        .map(|&(start, end)| (end - start + 1).max(0) as usize)
        .max()
        .unwrap_or(0);

    let overhang_steps = get_overhang_steps(
        t_queries.pattern_length,
        k as usize,
        alpha.unwrap_or(1.0),
        max_overhang,
    );

    // Prep history for single block search. Reserve for the steps actually recorded
    // (forward pass plus optional overhang) rather than an unrelated constant.
    // fixme: lets move from searcher to traceBuffer
    let expected_steps = max_len + if alpha.is_some() { overhang_steps } else { 0 };
    for lane_history in &mut searcher.history[..batch_size] {
        lane_history.steps.clear();
        lane_history.steps.reserve(expected_steps);
    }

    let last_bit_shift = (t_queries.pattern_length - 1) as u32;
    let last_bit_mask = B::splat_one() << last_bit_shift;
    let all_ones = B::splat_all_ones();
    let zero_scalar = B::scalar_from_i64(0);
    let one_mask = <B as SimdBackend>::mask_word_to_scalar(!0u64);
    let text_len = text.len();

    let mut eq_arr = B::LaneArray::default();
    let mut keep_mask_arr = B::LaneArray::default();

    // Disjoint field borrows replace the previous raw-pointer access to the block.
    let block = &mut searcher.blocks[0];
    let history = &mut searcher.history;

    for i in 0..max_len {
        let eq_slice = eq_arr.as_mut();
        let keep_slice = keep_mask_arr.as_mut();

        for lane in 0..batch_size {
            let q_idx = buffer.pattern_indices[lane];
            let abs_pos = (i as isize) + buffer.approx_slices[lane].0;
            if abs_pos >= 0 && (abs_pos as usize) < text_len {
                let enc = P::encode_char(text[abs_pos as usize]) as usize;
                eq_slice[lane] = B::mask_word_to_scalar(t_queries.peq_masks[enc][q_idx]);
                keep_slice[lane] = one_mask;
            } else {
                // Outside the text: freeze this lane for the current column.
                eq_slice[lane] = zero_scalar;
                keep_slice[lane] = zero_scalar;
            }
        }

        let eq = B::from_array(eq_arr);
        let keep_mask = B::from_array(keep_mask_arr);
        let freeze_mask = all_ones ^ keep_mask;

        let (vp_new, vn_new, cost_new) = Myers::<B, P>::myers_step(
            block.vp,
            block.vn,
            block.cost,
            eq,
            all_ones,
            last_bit_shift,
            last_bit_mask,
        );

        // Frozen lanes keep their previous state.
        let vp_masked = (vp_new & keep_mask) | (block.vp & freeze_mask);
        let vn_masked = (vn_new & keep_mask) | (block.vn & freeze_mask);
        let cost_masked = (cost_new & keep_mask) | (block.cost & freeze_mask);

        let freeze_arr = B::to_array(freeze_mask);
        let freeze_slice = freeze_arr.as_ref();
        for lane in 0..batch_size {
            let is_frozen = B::scalar_to_u64(freeze_slice[lane]) != 0;
            if !is_frozen {
                history[lane].steps.push(SimdHistoryStep {
                    vp: vp_masked,
                    vn: vn_masked,
                });
            }
        }

        block.vp = vp_masked;
        block.vn = vn_masked;
        block.cost = cost_masked;
    }

    if alpha.is_some() {
        // If we have alpha, we should artificially extend beyond the text
        handle_suffix_overhangs(
            searcher,
            last_bit_shift,
            last_bit_mask,
            batch_size,
            overhang_steps,
        );
    }

    trace_passing_alignments(batch_size, buffer, searcher, t_queries, text, k, post);
}

/// Returns lane `lane` of a SIMD vector as a plain `u64` word.
#[inline(always)]
fn extract_simd_lane<B: SimdBackend>(simd_val: B::Simd, lane: usize) -> u64 {
    let lanes = B::to_array(simd_val);
    let word = lanes.as_ref()[lane];
    B::scalar_to_u64(word)
}

/// Traces one alignment of query `original_pattern_idx` ending at text position `slice.1`.
///
/// The text window is `text[slice.0 ..= slice.1]`, truncated to the text length. DP costs come
/// from lane `lane_idx`'s recorded history. In the returned [`Match`]:
/// - `pattern_idx` is reduced modulo `n_original_queries`;
/// - indices at or beyond `n_original_queries` are reported as [`Strand::Rc`].
#[inline(always)]
fn traceback_single<B: SimdBackend, P: Profile>(
    searcher: &Myers<B, P>,
    lane_idx: usize,
    original_pattern_idx: usize,
    t_queries: &TQueries<B, P>,
    slice: (isize, isize),
    text: &[u8],
) -> Match {
    let pattern = &t_queries.queries[original_pattern_idx];

    let (window_start, window_last) = slice;
    let start = window_start as usize;
    let end_exclusive = (window_last + 1).max(0) as usize;
    let window = &text[start..end_exclusive.min(text.len())];

    // An alpha of exactly 1.0 means "no overhang penalty"; pass None in that case.
    let alpha = (searcher.alpha != 1.0).then_some(searcher.alpha);

    // Borrowing the searcher keeps this wrapper cheap to build per trace.
    let lookup = V2CostLookup { searcher, lane_idx };

    let mut found = get_trace::<P>(
        pattern,
        start,
        end_exclusive,
        window,
        &lookup,
        alpha,
        searcher.max_overhang,
    );

    let n_original = t_queries.n_original_queries;
    found.pattern_idx = original_pattern_idx % n_original;
    found.strand = match original_pattern_idx < n_original {
        true => Strand::Fwd,
        false => Strand::Rc,
    };
    found
}