edit/simd/
memset.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
//! `memset` for arbitrary sizes (1/2/4/8 bytes).
5//!
6//! Clang calls the C `memset` function only for byte-sized types (or 0 fills).
7//! We however need to fill other types as well. For that, clang generates
8//! SIMD loops under higher optimization levels. With `-Os` however, it only
9//! generates a trivial loop which is too slow for our needs.
10//!
11//! This implementation uses SWAR to only have a single implementation for all
12//! 4 sizes: By duplicating smaller types into a larger `u64` register we can
13//! treat all sizes as if they were `u64`. The only thing we need to take care
14//! of is the tail end of the array, which needs to write 0-7 additional bytes.
15
16use std::mem;
17
/// A marker trait for types that are safe to `memset`.
///
/// # Safety
///
/// Just like with C's `memset`, bad things happen
/// if you use this with non-trivial types. Implementors must be plain
/// data for which every value can simply be copied bit-by-bit.
///
/// Note that [`memset`] only handles sizes of 1, 2, 4, or 8 bytes and
/// hits `unreachable!()` for anything else, so implementing this trait
/// for other sizes is a logic error even where it would be memory-safe.
pub unsafe trait MemsetSafe: Copy {}
25
26unsafe impl MemsetSafe for u8 {}
27unsafe impl MemsetSafe for u16 {}
28unsafe impl MemsetSafe for u32 {}
29unsafe impl MemsetSafe for u64 {}
30unsafe impl MemsetSafe for usize {}
31
32unsafe impl MemsetSafe for i8 {}
33unsafe impl MemsetSafe for i16 {}
34unsafe impl MemsetSafe for i32 {}
35unsafe impl MemsetSafe for i64 {}
36unsafe impl MemsetSafe for isize {}
37
38/// Fills a slice with the given value.
39#[inline]
40pub fn memset<T: MemsetSafe>(dst: &mut [T], val: T) {
41    unsafe {
42        match mem::size_of::<T>() {
43            1 => {
44                // LLVM will compile this to a call to `memset`,
45                // which hopefully should be better optimized than my code.
46                let beg = dst.as_mut_ptr();
47                let val = mem::transmute_copy::<_, u8>(&val);
48                beg.write_bytes(val, dst.len());
49            }
50            2 => {
51                let beg = dst.as_mut_ptr();
52                let end = beg.add(dst.len());
53                let val = mem::transmute_copy::<_, u16>(&val);
54                memset_raw(beg as *mut u8, end as *mut u8, val as u64 * 0x0001000100010001);
55            }
56            4 => {
57                let beg = dst.as_mut_ptr();
58                let end = beg.add(dst.len());
59                let val = mem::transmute_copy::<_, u32>(&val);
60                memset_raw(beg as *mut u8, end as *mut u8, val as u64 * 0x0000000100000001);
61            }
62            8 => {
63                let beg = dst.as_mut_ptr();
64                let end = beg.add(dst.len());
65                let val = mem::transmute_copy::<_, u64>(&val);
66                memset_raw(beg as *mut u8, end as *mut u8, val);
67            }
68            _ => unreachable!(),
69        }
70    }
71}
72
73#[inline]
74fn memset_raw(beg: *mut u8, end: *mut u8, val: u64) {
75    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
76    return unsafe { MEMSET_DISPATCH(beg, end, val) };
77
78    #[cfg(target_arch = "aarch64")]
79    return unsafe { memset_neon(beg, end, val) };
80
81    #[allow(unreachable_code)]
82    return unsafe { memset_fallback(beg, end, val) };
83}
84
/// Portable SWAR fill of the byte range `[beg, end)` with the
/// u64-splatted pattern `val`. Kept out-of-line via `#[inline(never)]`.
#[inline(never)]
unsafe fn memset_fallback(mut beg: *mut u8, end: *mut u8, val: u64) {
    unsafe {
        let mut left = end.offset_from_unsigned(beg);

        // Bulk: unaligned 8-byte stores.
        while left >= 8 {
            beg.cast::<u64>().write_unaligned(val);
            beg = beg.add(8);
            left -= 8;
        }

        // Tail (0-7 bytes): at most two possibly-overlapping stores,
        // anchored at `beg` and at `end`, cover every case.
        match left {
            4..=7 => {
                beg.cast::<u32>().write_unaligned(val as u32);
                end.sub(4).cast::<u32>().write_unaligned(val as u32);
            }
            2..=3 => {
                beg.cast::<u16>().write_unaligned(val as u16);
                end.sub(2).cast::<u16>().write_unaligned(val as u16);
            }
            1 => beg.write(val as u8),
            _ => {}
        }
    }
}
110
// Function-pointer dispatch for x86: initially points at `memset_dispatch`,
// which detects CPU features on first use and then swaps the pointer to the
// chosen implementation (AVX2 or SSE2) so later calls skip detection.
//
// NOTE(review): reading/writing a plain `static mut` from multiple threads is
// a data race by the language rules. The write is idempotent, so it is likely
// benign in practice, but an `AtomicPtr` (or `OnceLock`) would express this
// soundly — confirm against the crate's threading expectations.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
static mut MEMSET_DISPATCH: unsafe fn(beg: *mut u8, end: *mut u8, val: u64) = memset_dispatch;

/// First-call path behind `MEMSET_DISPATCH`: picks AVX2 when detected (SSE2
/// otherwise), caches the choice in the dispatch pointer, and forwards the
/// current call to the chosen routine.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn memset_dispatch(beg: *mut u8, end: *mut u8, val: u64) {
    let func = if is_x86_feature_detected!("avx2") { memset_avx2 } else { memset_sse2 };
    unsafe { MEMSET_DISPATCH = func };
    unsafe { func(beg, end, val) }
}
120
/// SSE2 fill of `[beg, end)` with the u64-splatted pattern `val`.
///
/// Bulk data is written 32 bytes per iteration via unaligned 16-byte
/// stores; the remainder is finished with at most two overlapping stores
/// anchored at `beg` and at `end`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn memset_sse2(mut beg: *mut u8, end: *mut u8, val: u64) {
    unsafe {
        #[cfg(target_arch = "x86")]
        use std::arch::x86::*;
        #[cfg(target_arch = "x86_64")]
        use std::arch::x86_64::*;

        let mut left = end.offset_from_unsigned(beg);

        if left >= 16 {
            let pattern = _mm_set1_epi64x(val as i64);

            while left >= 32 {
                _mm_storeu_si128(beg.cast(), pattern);
                _mm_storeu_si128(beg.add(16).cast(), pattern);
                beg = beg.add(32);
                left -= 32;
            }

            if left >= 16 {
                // 16-31 bytes left: two (possibly overlapping) 16-byte stores.
                _mm_storeu_si128(beg.cast(), pattern);
                _mm_storeu_si128(end.sub(16).cast(), pattern);
                return;
            }
        }

        // 0-15 bytes left: overlapping scalar stores anchored at `beg`/`end`.
        match left {
            8..=15 => {
                beg.cast::<u64>().write_unaligned(val);
                end.sub(8).cast::<u64>().write_unaligned(val);
            }
            4..=7 => {
                beg.cast::<u32>().write_unaligned(val as u32);
                end.sub(4).cast::<u32>().write_unaligned(val as u32);
            }
            2..=3 => {
                beg.cast::<u16>().write_unaligned(val as u16);
                end.sub(2).cast::<u16>().write_unaligned(val as u16);
            }
            1 => beg.write(val as u8),
            _ => {}
        }
    }
}
169
/// AVX2 fill of `[beg, end)` with the u64-splatted pattern `val`.
///
/// A safe `fn` carrying `#[target_feature]`: it is only reached through the
/// `MEMSET_DISPATCH` pointer, which `memset_dispatch` assigns only after
/// `is_x86_feature_detected!("avx2")` succeeded.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
fn memset_avx2(mut beg: *mut u8, end: *mut u8, val: u64) {
    unsafe {
        #[cfg(target_arch = "x86")]
        use std::arch::x86::*;
        #[cfg(target_arch = "x86_64")]
        use std::arch::x86_64::*;
        use std::hint::black_box;

        let mut remaining = end.offset_from_unsigned(beg);

        // Main loop: 128 bytes per iteration via four unaligned 32-byte
        // stores. Do-while shape; entered only when >= 128 bytes remain.
        if remaining >= 128 {
            let fill = _mm256_set1_epi64x(val as i64);

            loop {
                _mm256_storeu_si256(beg as *mut _, fill);
                _mm256_storeu_si256(beg.add(32) as *mut _, fill);
                _mm256_storeu_si256(beg.add(64) as *mut _, fill);
                _mm256_storeu_si256(beg.add(96) as *mut _, fill);

                beg = beg.add(128);
                remaining -= 128;
                if remaining < 128 {
                    break;
                }
            }
        }

        // Mid tail: 16 bytes per iteration until fewer than 16 remain.
        if remaining >= 16 {
            let fill = _mm_set1_epi64x(val as i64);

            loop {
                // LLVM is _very_ eager to unroll loops. In the absence of an unroll attribute, black_box does the job.
                // Note that this must not be applied to the intrinsic parameters, as they're otherwise misoptimized.
                #[allow(clippy::unit_arg)]
                black_box(_mm_storeu_si128(beg as *mut _, fill));

                beg = beg.add(16);
                remaining -= 16;
                if remaining < 16 {
                    break;
                }
            }
        }

        // `remaining` is between 0 and 15 at this point.
        // By overlapping the stores we can write all of them in at most 2 stores. This approach
        // can be seen in various libraries, such as wyhash which uses it for loading data in `wyr3`.
        if remaining >= 8 {
            // 8-15 bytes
            (beg as *mut u64).write_unaligned(val);
            (end.sub(8) as *mut u64).write_unaligned(val);
        } else if remaining >= 4 {
            // 4-7 bytes
            (beg as *mut u32).write_unaligned(val as u32);
            (end.sub(4) as *mut u32).write_unaligned(val as u32);
        } else if remaining >= 2 {
            // 2-3 bytes
            (beg as *mut u16).write_unaligned(val as u16);
            (end.sub(2) as *mut u16).write_unaligned(val as u16);
        } else if remaining >= 1 {
            // 1 byte
            beg.write(val as u8);
        }
    }
}
237
/// NEON fill of `[beg, end)` with the u64-splatted pattern `val`.
///
/// Writes 32 bytes per iteration with paired 16-byte stores, then
/// finishes the remainder with at most two overlapping stores anchored
/// at `beg` and at `end`.
#[cfg(target_arch = "aarch64")]
unsafe fn memset_neon(mut beg: *mut u8, end: *mut u8, val: u64) {
    unsafe {
        use std::arch::aarch64::*;

        let mut left = end.offset_from_unsigned(beg);

        if left >= 32 {
            let pattern = vdupq_n_u64(val);

            loop {
                // Compiles to a single `stp` instruction.
                vst1q_u64(beg.cast(), pattern);
                vst1q_u64(beg.add(16).cast(), pattern);

                beg = beg.add(32);
                left -= 32;
                if left < 32 {
                    break;
                }
            }
        }

        // 0-31 bytes left: overlapping stores anchored at `beg` and `end`.
        match left {
            16..=31 => {
                let pattern = vdupq_n_u64(val);
                vst1q_u64(beg.cast(), pattern);
                vst1q_u64(end.sub(16).cast(), pattern);
            }
            8..=15 => {
                beg.cast::<u64>().write_unaligned(val);
                end.sub(8).cast::<u64>().write_unaligned(val);
            }
            4..=7 => {
                beg.cast::<u32>().write_unaligned(val as u32);
                end.sub(4).cast::<u32>().write_unaligned(val as u32);
            }
            2..=3 => {
                beg.cast::<u16>().write_unaligned(val as u16);
                end.sub(2).cast::<u16>().write_unaligned(val as u16);
            }
            1 => beg.write(val as u8),
            _ => {}
        }
    }
}
283
#[cfg(test)]
mod tests {
    use std::fmt;
    use std::ops::Not;

    use super::*;

    /// Seeds a buffer with `!val`, runs `memset`, and verifies that every
    /// element was overwritten with `val`.
    fn check_memset<T>(val: T, len: usize)
    where
        T: MemsetSafe + Not<Output = T> + PartialEq + fmt::Debug,
    {
        let mut buf = vec![!val; len];
        memset(&mut buf, val);
        for x in &buf {
            assert!(*x == val);
        }
    }

    #[test]
    fn test_memset_empty() {
        check_memset(0u8, 0);
        check_memset(0u16, 0);
        check_memset(0u32, 0);
        check_memset(0u64, 0);
    }

    #[test]
    fn test_memset_single() {
        check_memset(0u8, 1);
        check_memset(0xFFu8, 1);
        check_memset(0xABu16, 1);
        check_memset(0x12345678u32, 1);
        check_memset(0xDEADBEEFu64, 1);
    }

    #[test]
    fn test_memset_small() {
        for len in [2, 3, 4, 5, 7, 8, 9] {
            check_memset(0xAAu8, len);
            check_memset(0xBEEFu16, len);
            check_memset(0xCAFEBABEu32, len);
            check_memset(0x1234567890ABCDEFu64, len);
        }
    }

    #[test]
    fn test_memset_large() {
        check_memset(0u8, 1000);
        check_memset(0xFFu8, 1024);
        check_memset(0xBEEFu16, 512);
        check_memset(0xCAFEBABEu32, 256);
        check_memset(0x1234567890ABCDEFu64, 128);
    }

    #[test]
    fn test_memset_various_values() {
        check_memset(0u8, 17);
        check_memset(0x7Fu8, 17);
        check_memset(0x8001u16, 17);
        check_memset(0xFFFFFFFFu32, 17);
        check_memset(0x8000000000000001u64, 17);
    }

    #[test]
    fn test_memset_signed_types() {
        check_memset(-1i8, 8);
        check_memset(-2i16, 8);
        check_memset(-3i32, 8);
        check_memset(-4i64, 8);
        check_memset(-5isize, 8);
    }

    #[test]
    fn test_memset_usize_isize() {
        check_memset(0usize, 4);
        check_memset(usize::MAX, 4);
        check_memset(0isize, 4);
        check_memset(isize::MIN, 4);
    }

    #[test]
    fn test_memset_alignment() {
        // memset must cope with any starting alignment: slide a 7-byte
        // window across every offset within an 8-byte period.
        let mut buf = [0u8; 15];
        for offset in 0..8 {
            let window = &mut buf[offset..offset + 7];
            memset(window, 0x5A);
            assert!(window.iter().all(|&x| x == 0x5A));
        }
    }
}