edit/simd/
lines_bwd.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4use std::ptr;
5
6use crate::helpers::CoordType;
7
8/// Starting from the `offset` in `haystack` with a current line index of
9/// `line`, this seeks backwards to the `line_stop`-nth line and returns the
10/// new offset and the line index at that point.
11///
12/// Note that this function differs from `lines_fwd` in that it
13/// seeks backwards even if the `line` is already at `line_stop`.
14/// This allows you to ensure (or test) whether `offset` is at a line start.
15///
16/// It returns an offset *past* a newline and thus at the start of a line.
17pub fn lines_bwd(
18    haystack: &[u8],
19    offset: usize,
20    line: CoordType,
21    line_stop: CoordType,
22) -> (usize, CoordType) {
23    unsafe {
24        let beg = haystack.as_ptr();
25        let it = beg.add(offset.min(haystack.len()));
26        let (it, line) = lines_bwd_raw(beg, it, line, line_stop);
27        (it.offset_from_unsigned(beg), line)
28    }
29}
30
31unsafe fn lines_bwd_raw(
32    beg: *const u8,
33    end: *const u8,
34    line: CoordType,
35    line_stop: CoordType,
36) -> (*const u8, CoordType) {
37    #[cfg(target_arch = "x86_64")]
38    return unsafe { LINES_BWD_DISPATCH(beg, end, line, line_stop) };
39
40    #[cfg(target_arch = "aarch64")]
41    return unsafe { lines_bwd_neon(beg, end, line, line_stop) };
42
43    #[allow(unreachable_code)]
44    return unsafe { lines_bwd_fallback(beg, end, line, line_stop) };
45}
46
47unsafe fn lines_bwd_fallback(
48    beg: *const u8,
49    mut end: *const u8,
50    mut line: CoordType,
51    line_stop: CoordType,
52) -> (*const u8, CoordType) {
53    unsafe {
54        while !ptr::eq(end, beg) {
55            let n = end.sub(1);
56            if *n == b'\n' {
57                if line <= line_stop {
58                    break;
59                }
60                line -= 1;
61            }
62            end = n;
63        }
64        (end, line)
65    }
66}
67
68#[cfg(target_arch = "x86_64")]
69static mut LINES_BWD_DISPATCH: unsafe fn(
70    beg: *const u8,
71    end: *const u8,
72    line: CoordType,
73    line_stop: CoordType,
74) -> (*const u8, CoordType) = lines_bwd_dispatch;
75
76#[cfg(target_arch = "x86_64")]
77unsafe fn lines_bwd_dispatch(
78    beg: *const u8,
79    end: *const u8,
80    line: CoordType,
81    line_stop: CoordType,
82) -> (*const u8, CoordType) {
83    let func = if is_x86_feature_detected!("avx2") { lines_bwd_avx2 } else { lines_bwd_fallback };
84    unsafe { LINES_BWD_DISPATCH = func };
85    unsafe { func(beg, end, line, line_stop) }
86}
87
88#[cfg(target_arch = "x86_64")]
89#[target_feature(enable = "avx2")]
90unsafe fn lines_bwd_avx2(
91    beg: *const u8,
92    mut end: *const u8,
93    mut line: CoordType,
94    line_stop: CoordType,
95) -> (*const u8, CoordType) {
96    unsafe {
97        use std::arch::x86_64::*;
98
99        #[inline(always)]
100        unsafe fn horizontal_sum_i64(v: __m256i) -> i64 {
101            unsafe {
102                let hi = _mm256_extracti128_si256::<1>(v);
103                let lo = _mm256_castsi256_si128(v);
104                let sum = _mm_add_epi64(lo, hi);
105                let shuf = _mm_shuffle_epi32::<0b11_10_11_10>(sum);
106                let sum = _mm_add_epi64(sum, shuf);
107                _mm_cvtsi128_si64(sum)
108            }
109        }
110
111        let lf = _mm256_set1_epi8(b'\n' as i8);
112        let line_stop = line_stop.min(line);
113        let mut remaining = end.offset_from_unsigned(beg);
114
115        while remaining >= 128 {
116            let chunk_start = end.sub(128);
117
118            let v1 = _mm256_loadu_si256(chunk_start.add(0) as *const _);
119            let v2 = _mm256_loadu_si256(chunk_start.add(32) as *const _);
120            let v3 = _mm256_loadu_si256(chunk_start.add(64) as *const _);
121            let v4 = _mm256_loadu_si256(chunk_start.add(96) as *const _);
122
123            let mut sum = _mm256_setzero_si256();
124            sum = _mm256_sub_epi8(sum, _mm256_cmpeq_epi8(v1, lf));
125            sum = _mm256_sub_epi8(sum, _mm256_cmpeq_epi8(v2, lf));
126            sum = _mm256_sub_epi8(sum, _mm256_cmpeq_epi8(v3, lf));
127            sum = _mm256_sub_epi8(sum, _mm256_cmpeq_epi8(v4, lf));
128
129            let sum = _mm256_sad_epu8(sum, _mm256_setzero_si256());
130            let sum = horizontal_sum_i64(sum);
131
132            let line_next = line - sum as CoordType;
133            if line_next <= line_stop {
134                break;
135            }
136
137            end = chunk_start;
138            remaining -= 128;
139            line = line_next;
140        }
141
142        while remaining >= 32 {
143            let chunk_start = end.sub(32);
144            let v = _mm256_loadu_si256(chunk_start as *const _);
145            let c = _mm256_cmpeq_epi8(v, lf);
146
147            let ones = _mm256_and_si256(c, _mm256_set1_epi8(0x01));
148            let sum = _mm256_sad_epu8(ones, _mm256_setzero_si256());
149            let sum = horizontal_sum_i64(sum);
150
151            let line_next = line - sum as CoordType;
152            if line_next <= line_stop {
153                break;
154            }
155
156            end = chunk_start;
157            remaining -= 32;
158            line = line_next;
159        }
160
161        lines_bwd_fallback(beg, end, line, line_stop)
162    }
163}
164
/// NEON implementation of the backwards line seek.
///
/// Newlines are counted in 64-byte and then 16-byte chunks, walking from
/// `end` towards `beg`. A chunk is skipped as a whole as long as doing so
/// cannot step over the newline the scan has to stop behind. The remaining
/// bytes are handed to `lines_bwd_fallback`, which finds the exact position.
///
/// # Safety
///
/// `beg..end` must denote a readable byte range inside a single allocation
/// with `beg <= end`.
#[cfg(target_arch = "aarch64")]
unsafe fn lines_bwd_neon(
    beg: *const u8,
    mut end: *const u8,
    mut line: CoordType,
    line_stop: CoordType,
) -> (*const u8, CoordType) {
    unsafe {
        use std::arch::aarch64::*;

        let lf = vdupq_n_u8(b'\n');
        // With `line` already at or below `line_stop` the scan merely rewinds
        // to the start of the current line; clamping expresses that as
        // "zero newlines left to pass".
        let line_stop = line_stop.min(line);
        let mut remaining = end.offset_from_unsigned(beg);

        while remaining >= 64 {
            let chunk_start = end.sub(64);

            let v1 = vld1q_u8(chunk_start.add(0));
            let v2 = vld1q_u8(chunk_start.add(16));
            let v3 = vld1q_u8(chunk_start.add(32));
            let v4 = vld1q_u8(chunk_start.add(48));

            // A match compares to 0xFF, so subtracting the mask adds 1 to
            // that byte lane. Four vectors => at most 4 per lane.
            let mut sum = vdupq_n_u8(0);
            sum = vsubq_u8(sum, vceqq_u8(v1, lf));
            sum = vsubq_u8(sum, vceqq_u8(v2, lf));
            sum = vsubq_u8(sum, vceqq_u8(v3, lf));
            sum = vsubq_u8(sum, vceqq_u8(v4, lf));

            // 16 lanes of at most 4 each => the total (<= 64) fits in a u8.
            let sum = vaddvq_u8(sum);

            // The scan steps over exactly `line - line_stop` newlines and
            // stops behind the next one. A chunk with at most that many
            // newlines can therefore be skipped entirely; only a chunk with
            // more of them (`line_next < line_stop`) contains the stop.
            let line_next = line - sum as CoordType;
            if line_next < line_stop {
                break;
            }

            end = chunk_start;
            remaining -= 64;
            line = line_next;
        }

        while remaining >= 16 {
            let chunk_start = end.sub(16);
            let v = vld1q_u8(chunk_start);
            let c = vceqq_u8(v, lf);
            // Turn the 0xFF match mask into 0x01 per lane before summing.
            let c = vandq_u8(c, vdupq_n_u8(0x01));
            let sum = vaddvq_u8(c);

            // Same skip rule as in the 64-byte loop.
            let line_next = line - sum as CoordType;
            if line_next < line_stop {
                break;
            }

            end = chunk_start;
            remaining -= 16;
            line = line_next;
        }

        lines_bwd_fallback(beg, end, line, line_stop)
    }
}
225
#[cfg(test)]
mod test {
    use super::*;
    use crate::helpers::CoordType;
    use crate::simd::test::*;

    /// Byte-at-a-time model that `pseudo_fuzz` compares `lines_bwd` against.
    ///
    /// It only models inputs with `line >= line_stop`, which is all the fuzz
    /// test generates. For `line < line_stop` it returns its arguments
    /// unchanged, whereas `lines_bwd` still rewinds to the start of the line
    /// (see `seeks_to_start`).
    fn reference_lines_bwd(
        haystack: &[u8],
        mut offset: usize,
        mut line: CoordType,
        line_stop: CoordType,
    ) -> (usize, CoordType) {
        if line < line_stop {
            return (offset, line);
        }

        while offset > 0 {
            if haystack[offset - 1] == b'\n' {
                if line == line_stop {
                    break;
                }
                line -= 1;
            }
            offset -= 1;
        }

        (offset, line)
    }

    /// Compares `lines_bwd` with the reference model on 1000 generated
    /// combinations of offset, line and line stop.
    #[test]
    fn pseudo_fuzz() {
        let text = generate_random_text(1024);
        let total_lines = count_lines(&text);
        let mut next_offset = make_rng();
        let mut next_extra = make_rng();
        let mut next_stop = make_rng();

        for _ in 0..1000 {
            let offset = next_offset() % (text.len() + 1);
            let stop = next_stop() % (total_lines + 1);
            // Start between 0 and 99 lines after the stop line.
            let start = stop + next_extra() % 100;

            let (line, line_stop) = (start as CoordType, stop as CoordType);
            let haystack = text.as_bytes();

            assert_eq!(
                reference_lines_bwd(haystack, offset, line, line_stop),
                lines_bwd(haystack, offset, line, line_stop)
            );
        }
    }

    /// With `line` below `line_stop`, every offset inside "World" must rewind
    /// to the start of that line without touching the line index.
    #[test]
    fn seeks_to_start() {
        let haystack = b"Hello\nWorld\n";

        for offset in 6..=11 {
            let result = lines_bwd(haystack, offset, 123, 456);
            assert_eq!(result.0, 6); // After "Hello\n"
            assert_eq!(result.1, 123); // Still on the same line
        }
    }
}