edit/simd/
lines_fwd.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4use std::ptr;
5
6use crate::helpers::CoordType;
7
8/// Starting from the `offset` in `haystack` with a current line index of
9/// `line`, this seeks to the `line_stop`-nth line and returns the
10/// new offset and the line index at that point.
11///
12/// It returns an offset *past* the newline.
13/// If `line` is already at or past `line_stop`, it returns immediately.
14pub fn lines_fwd(
15    haystack: &[u8],
16    offset: usize,
17    line: CoordType,
18    line_stop: CoordType,
19) -> (usize, CoordType) {
20    unsafe {
21        let beg = haystack.as_ptr();
22        let end = beg.add(haystack.len());
23        let it = beg.add(offset.min(haystack.len()));
24        let (it, line) = lines_fwd_raw(it, end, line, line_stop);
25        (it.offset_from_unsigned(beg), line)
26    }
27}
28
29unsafe fn lines_fwd_raw(
30    beg: *const u8,
31    end: *const u8,
32    line: CoordType,
33    line_stop: CoordType,
34) -> (*const u8, CoordType) {
35    #[cfg(target_arch = "x86_64")]
36    return unsafe { LINES_FWD_DISPATCH(beg, end, line, line_stop) };
37
38    #[cfg(target_arch = "aarch64")]
39    return unsafe { lines_fwd_neon(beg, end, line, line_stop) };
40
41    #[allow(unreachable_code)]
42    return unsafe { lines_fwd_fallback(beg, end, line, line_stop) };
43}
44
45unsafe fn lines_fwd_fallback(
46    mut beg: *const u8,
47    end: *const u8,
48    mut line: CoordType,
49    line_stop: CoordType,
50) -> (*const u8, CoordType) {
51    unsafe {
52        if line < line_stop {
53            while !ptr::eq(beg, end) {
54                let c = *beg;
55                beg = beg.add(1);
56                if c == b'\n' {
57                    line += 1;
58                    if line == line_stop {
59                        break;
60                    }
61                }
62            }
63        }
64        (beg, line)
65    }
66}
67
68#[cfg(target_arch = "x86_64")]
69static mut LINES_FWD_DISPATCH: unsafe fn(
70    beg: *const u8,
71    end: *const u8,
72    line: CoordType,
73    line_stop: CoordType,
74) -> (*const u8, CoordType) = lines_fwd_dispatch;
75
76#[cfg(target_arch = "x86_64")]
77unsafe fn lines_fwd_dispatch(
78    beg: *const u8,
79    end: *const u8,
80    line: CoordType,
81    line_stop: CoordType,
82) -> (*const u8, CoordType) {
83    let func = if is_x86_feature_detected!("avx2") { lines_fwd_avx2 } else { lines_fwd_fallback };
84    unsafe { LINES_FWD_DISPATCH = func };
85    unsafe { func(beg, end, line, line_stop) }
86}
87
/// AVX2 line seek.
///
/// Counts `\n` bytes 128 bytes at a time (four 32-byte vectors), then 32 bytes at
/// a time. A chunk is committed only if it does not bring the line index to
/// `line_stop`. The chunk containing the target newline, and any tail shorter
/// than 32 bytes, is handed to `lines_fwd_fallback`, which finds the exact
/// stopping position.
///
/// # Safety
///
/// The CPU must support AVX2, and `beg..end` must be a readable byte range
/// within a single allocation, with `beg <= end`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn lines_fwd_avx2(
    mut beg: *const u8,
    end: *const u8,
    mut line: CoordType,
    line_stop: CoordType,
) -> (*const u8, CoordType) {
    unsafe {
        use std::arch::x86_64::*;

        // Adds the four 64-bit lanes of `v` and returns the total.
        // The upper 128 bits are folded onto the lower 128 bits, then the high
        // 64-bit half is folded onto the low half.
        #[inline(always)]
        unsafe fn horizontal_sum_i64(v: __m256i) -> i64 {
            unsafe {
                let hi = _mm256_extracti128_si256::<1>(v);
                let lo = _mm256_castsi256_si128(v);
                let sum = _mm_add_epi64(lo, hi);
                let shuf = _mm_shuffle_epi32::<0b11_10_11_10>(sum);
                let sum = _mm_add_epi64(sum, shuf);
                _mm_cvtsi128_si64(sum)
            }
        }

        let lf = _mm256_set1_epi8(b'\n' as i8);
        let mut remaining = end.offset_from_unsigned(beg);

        if line < line_stop {
            // Unrolling the loop by 4x speeds things up by >3x.
            // It allows us to accumulate matches before doing a single `vpsadbw`.
            while remaining >= 128 {
                let v1 = _mm256_loadu_si256(beg.add(0) as *const _);
                let v2 = _mm256_loadu_si256(beg.add(32) as *const _);
                let v3 = _mm256_loadu_si256(beg.add(64) as *const _);
                let v4 = _mm256_loadu_si256(beg.add(96) as *const _);

                // `vpcmpeqb` leaves each comparison result byte as 0 or -1 (0xff).
                // This allows us to accumulate the comparisons by subtracting them.
                // Each byte lane grows by at most 1 per vector (at most 4 total),
                // so the 8-bit lanes cannot overflow.
                let mut sum = _mm256_setzero_si256();
                sum = _mm256_sub_epi8(sum, _mm256_cmpeq_epi8(v1, lf));
                sum = _mm256_sub_epi8(sum, _mm256_cmpeq_epi8(v2, lf));
                sum = _mm256_sub_epi8(sum, _mm256_cmpeq_epi8(v3, lf));
                sum = _mm256_sub_epi8(sum, _mm256_cmpeq_epi8(v4, lf));

                // Calculate the total number of matches in this chunk.
                let sum = _mm256_sad_epu8(sum, _mm256_setzero_si256());
                let sum = horizontal_sum_i64(sum);

                // If this chunk would reach the target, leave `beg` at its start
                // and let the loops below locate the exact newline.
                let line_next = line + sum as CoordType;
                if line_next >= line_stop {
                    break;
                }

                beg = beg.add(128);
                remaining -= 128;
                line = line_next;
            }

            while remaining >= 32 {
                let v = _mm256_loadu_si256(beg as *const _);
                let c = _mm256_cmpeq_epi8(v, lf);

                // A commonly suggested approach (e.g. by LLMs) is a `vpmovmskb`
                // followed by `popcnt`. On contemporary hardware that's a bad
                // idea, though, so matches are masked to 0/1 and summed with
                // `vpsadbw` instead.
                let ones = _mm256_and_si256(c, _mm256_set1_epi8(0x01));
                let sum = _mm256_sad_epu8(ones, _mm256_setzero_si256());
                let sum = horizontal_sum_i64(sum);

                let line_next = line + sum as CoordType;
                if line_next >= line_stop {
                    break;
                }

                beg = beg.add(32);
                remaining -= 32;
                line = line_next;
            }
        }

        // Handles the tail (< 32 bytes) and/or the chunk that contains the
        // newline reaching `line_stop`.
        lines_fwd_fallback(beg, end, line, line_stop)
    }
}
170
/// NEON line seek for aarch64.
///
/// Counts `\n` bytes 64 bytes at a time (four 16-byte vectors), then 16 bytes at
/// a time. A chunk is committed only if it does not bring the line index to
/// `line_stop`. The chunk containing the target newline, and any tail shorter
/// than 16 bytes, is handed to `lines_fwd_fallback`, which finds the exact
/// stopping position.
///
/// # Safety
///
/// `beg..end` must be a readable byte range within a single allocation, with
/// `beg <= end`.
#[cfg(target_arch = "aarch64")]
unsafe fn lines_fwd_neon(
    mut beg: *const u8,
    end: *const u8,
    mut line: CoordType,
    line_stop: CoordType,
) -> (*const u8, CoordType) {
    unsafe {
        use std::arch::aarch64::*;

        let lf = vdupq_n_u8(b'\n');
        let mut remaining = end.offset_from_unsigned(beg);

        if line < line_stop {
            while remaining >= 64 {
                let v1 = vld1q_u8(beg.add(0));
                let v2 = vld1q_u8(beg.add(16));
                let v3 = vld1q_u8(beg.add(32));
                let v4 = vld1q_u8(beg.add(48));

                // `vceqq_u8` leaves each comparison result byte as 0 or -1 (0xff).
                // This allows us to accumulate the comparisons by subtracting them.
                let mut sum = vdupq_n_u8(0);
                sum = vsubq_u8(sum, vceqq_u8(v1, lf));
                sum = vsubq_u8(sum, vceqq_u8(v2, lf));
                sum = vsubq_u8(sum, vceqq_u8(v3, lf));
                sum = vsubq_u8(sum, vceqq_u8(v4, lf));

                // Horizontal add across all 16 lanes. Each lane is at most 4 here,
                // so the total (at most 64) fits in the `u8` result.
                let sum = vaddvq_u8(sum);

                // If this chunk would reach the target, leave `beg` at its start
                // and let the loops below locate the exact newline.
                let line_next = line + sum as CoordType;
                if line_next >= line_stop {
                    break;
                }

                beg = beg.add(64);
                remaining -= 64;
                line = line_next;
            }

            while remaining >= 16 {
                let v = vld1q_u8(beg);
                // Turn each 0xff match into 1 so the horizontal add counts matches
                // (at most 16, which fits in `u8`).
                let c = vceqq_u8(v, lf);
                let c = vandq_u8(c, vdupq_n_u8(0x01));
                let sum = vaddvq_u8(c);

                let line_next = line + sum as CoordType;
                if line_next >= line_stop {
                    break;
                }

                beg = beg.add(16);
                remaining -= 16;
                line = line_next;
            }
        }

        // Handles the tail (< 16 bytes) and/or the chunk that contains the
        // newline reaching `line_stop`.
        lines_fwd_fallback(beg, end, line, line_stop)
    }
}
231
#[cfg(test)]
mod test {
    use super::*;
    use crate::helpers::CoordType;
    use crate::simd::test::*;

    /// Compares the dispatched implementation against a scalar reference at
    /// random offsets and line targets.
    #[test]
    fn pseudo_fuzz() {
        let text = generate_random_text(1024);
        let lines = count_lines(&text);
        let mut offset_rng = make_rng();
        let mut line_rng = make_rng();
        let mut line_distance_rng = make_rng();

        for _ in 0..1000 {
            let offset = offset_rng() % (text.len() + 1);
            let line = line_rng() % 100;
            let line_stop = line + line_distance_rng() % (lines + 1);

            let line = line as CoordType;
            let line_stop = line_stop as CoordType;

            let expected = reference_lines_fwd(text.as_bytes(), offset, line, line_stop);
            let actual = lines_fwd(text.as_bytes(), offset, line, line_stop);

            assert_eq!(expected, actual);
        }
    }

    /// Pins behavior on inputs the fuzzer never generates: an empty haystack,
    /// offsets past the end (which `lines_fwd` clamps), and targets that are
    /// already reached on entry.
    #[test]
    fn edge_cases() {
        let text: &[u8] = b"a\nb\nc";

        // Empty haystack: nothing to scan.
        assert_eq!(lines_fwd(b"", 0, 0, 10), (0, 0));
        // Offsets past the end are clamped to `haystack.len()`.
        assert_eq!(lines_fwd(text, 100, 0, 10), (5, 0));
        // Already at or past `line_stop`: returns the (clamped) offset immediately.
        assert_eq!(lines_fwd(text, 2, 3, 3), (2, 3));
        assert_eq!(lines_fwd(text, 100, 4, 3), (5, 4));
        // Stops right after the newline that reaches `line_stop`.
        assert_eq!(lines_fwd(text, 0, 0, 1), (2, 1));
        assert_eq!(lines_fwd(text, 0, 0, 2), (4, 2));
        // Runs to the end when there are not enough newlines.
        assert_eq!(lines_fwd(text, 0, 0, 10), (5, 2));
    }

    /// Exercises the wide SIMD chunk loops with exact expectations, including a
    /// target newline that falls in the middle of a chunk.
    #[test]
    fn simd_chunk_boundaries() {
        let newlines = [b'\n'; 300];
        assert_eq!(lines_fwd(&newlines, 0, 0, 200), (200, 200));
        assert_eq!(lines_fwd(&newlines, 7, 0, 130), (137, 130));
        assert_eq!(lines_fwd(&newlines, 0, 0, 1000), (300, 300));

        let no_newlines = [b'x'; 300];
        assert_eq!(lines_fwd(&no_newlines, 0, 0, 1), (300, 0));
    }

    /// Straightforward byte-by-byte model of `lines_fwd` (without offset clamping).
    fn reference_lines_fwd(
        haystack: &[u8],
        mut offset: usize,
        mut line: CoordType,
        line_stop: CoordType,
    ) -> (usize, CoordType) {
        if line < line_stop {
            while offset < haystack.len() {
                let c = haystack[offset];
                offset += 1;
                if c == b'\n' {
                    line += 1;
                    if line == line_stop {
                        break;
                    }
                }
            }
        }
        (offset, line)
    }
}