1use std::mem;
17
/// Marker trait for element types that [`memset`] may fill by replicating
/// the value's raw bytes over the destination slice.
///
/// # Safety
///
/// Implementors must be plain `Copy` data with no padding bytes and no
/// validity requirements beyond their bit pattern, so that overwriting a
/// slice element with the value's bytes (via `transmute_copy` + raw
/// stores) is sound. The optimized fill paths cover element sizes 1, 2,
/// 4, and 8 bytes.
pub unsafe trait MemsetSafe: Copy {}
25
// SAFETY: all fixed-width integer types (and usize/isize, which are 4 or
// 8 bytes depending on the target) are plain `Copy` data with no padding
// and no validity invariants, so a replicated-byte fill is sound.
unsafe impl MemsetSafe for u8 {}
unsafe impl MemsetSafe for u16 {}
unsafe impl MemsetSafe for u32 {}
unsafe impl MemsetSafe for u64 {}
unsafe impl MemsetSafe for usize {}

unsafe impl MemsetSafe for i8 {}
unsafe impl MemsetSafe for i16 {}
unsafe impl MemsetSafe for i32 {}
unsafe impl MemsetSafe for i64 {}
unsafe impl MemsetSafe for isize {}
37
38#[inline]
40pub fn memset<T: MemsetSafe>(dst: &mut [T], val: T) {
41 unsafe {
42 match mem::size_of::<T>() {
43 1 => {
44 let beg = dst.as_mut_ptr();
47 let val = mem::transmute_copy::<_, u8>(&val);
48 beg.write_bytes(val, dst.len());
49 }
50 2 => {
51 let beg = dst.as_mut_ptr();
52 let end = beg.add(dst.len());
53 let val = mem::transmute_copy::<_, u16>(&val);
54 memset_raw(beg as *mut u8, end as *mut u8, val as u64 * 0x0001000100010001);
55 }
56 4 => {
57 let beg = dst.as_mut_ptr();
58 let end = beg.add(dst.len());
59 let val = mem::transmute_copy::<_, u32>(&val);
60 memset_raw(beg as *mut u8, end as *mut u8, val as u64 * 0x0000000100000001);
61 }
62 8 => {
63 let beg = dst.as_mut_ptr();
64 let end = beg.add(dst.len());
65 let val = mem::transmute_copy::<_, u64>(&val);
66 memset_raw(beg as *mut u8, end as *mut u8, val);
67 }
68 _ => unreachable!(),
69 }
70 }
71}
72
/// Fills the byte range `[beg, end)` with the repeating 8-byte pattern
/// `val` (callers pre-replicate 2/4-byte values into all lanes of the u64).
///
/// Exactly one `return` survives `cfg` per target: the self-patching
/// dispatcher on x86/x86_64, NEON on aarch64, and the portable scalar
/// fallback everywhere else (the trailing `return` is unreachable on the
/// first two targets, hence the `#[allow]`).
#[inline]
fn memset_raw(beg: *mut u8, end: *mut u8, val: u64) {
    // SAFETY: callers derive `beg`/`end` from a valid `&mut [T]`, so the
    // whole byte range is writable. NOTE(review): this reads the
    // `static mut` dispatch slot without synchronization — see the note
    // on `MEMSET_DISPATCH`.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    return unsafe { MEMSET_DISPATCH(beg, end, val) };

    #[cfg(target_arch = "aarch64")]
    return unsafe { memset_neon(beg, end, val) };

    #[allow(unreachable_code)]
    return unsafe { memset_fallback(beg, end, val) };
}
84
/// Portable scalar fill of `[beg, end)` with the repeating 8-byte pattern
/// `val`: unaligned 64-bit stores for the bulk, then a single pair of
/// possibly-overlapping narrower stores for the tail.
#[inline(never)]
unsafe fn memset_fallback(beg: *mut u8, end: *mut u8, val: u64) {
    unsafe {
        let mut cur = beg;
        let mut left = end.offset_from_unsigned(beg);

        // Bulk: one unaligned 64-bit store per 8 bytes.
        while left >= 8 {
            (cur as *mut u64).write_unaligned(val);
            cur = cur.add(8);
            left -= 8;
        }

        // Tail (`left` < 8 here): one store at the cursor plus one anchored
        // at `end` covers any remainder; the end-anchored store is issued
        // last, so overlapping bytes take its value.
        match left {
            4..=7 => {
                (cur as *mut u32).write_unaligned(val as u32);
                (end.sub(4) as *mut u32).write_unaligned(val as u32);
            }
            2..=3 => {
                (cur as *mut u16).write_unaligned(val as u16);
                (end.sub(2) as *mut u16).write_unaligned(val as u16);
            }
            1 => cur.write(val as u8),
            _ => {}
        }
    }
}
110
// Self-patching dispatch slot: starts at `memset_dispatch`, which detects
// CPU features on the first call and overwrites this pointer with the best
// implementation so later calls go straight to it.
//
// NOTE(review): unsynchronized reads/writes of a `static mut` are formally
// a data race if `memset` is first called from multiple threads, even
// though every writer stores the same value — consider migrating to an
// atomic function-pointer slot. TODO confirm threading assumptions.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
static mut MEMSET_DISPATCH: unsafe fn(beg: *mut u8, end: *mut u8, val: u64) = memset_dispatch;
113
/// One-shot resolver installed as the initial value of `MEMSET_DISPATCH`:
/// detects CPU features, caches the chosen implementation back into the
/// dispatch slot, then performs the requested fill itself.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn memset_dispatch(beg: *mut u8, end: *mut u8, val: u64) {
    // AVX2 when available, otherwise SSE2. NOTE(review): SSE2 is baseline
    // on x86_64 but not on every 32-bit x86 target — presumably only
    // SSE2-capable targets are supported; confirm.
    let func = if is_x86_feature_detected!("avx2") { memset_avx2 } else { memset_sse2 };
    // Patch the slot so subsequent calls skip feature detection.
    // NOTE(review): unsynchronized `static mut` write — racy under
    // concurrent first calls; see the note on `MEMSET_DISPATCH`.
    unsafe { MEMSET_DISPATCH = func };
    // SAFETY: `func` matches the features just detected on this CPU, and
    // callers pass a byte range derived from a valid `&mut [T]`.
    unsafe { func(beg, end, val) }
}
120
/// SSE2 fill of `[beg, end)` with the repeating 8-byte pattern `val`:
/// a 2x-unrolled loop of unaligned 16-byte stores, then overlapping
/// begin/end-anchored stores for the tail.
///
/// # Safety
///
/// The caller must guarantee that `[beg, end)` is a valid writable byte
/// range and that the CPU supports SSE2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn memset_sse2(mut beg: *mut u8, end: *mut u8, val: u64) {
    unsafe {
        #[cfg(target_arch = "x86")]
        use std::arch::x86::*;
        #[cfg(target_arch = "x86_64")]
        use std::arch::x86_64::*;

        let mut remaining = end.offset_from_unsigned(beg);

        if remaining >= 16 {
            // Broadcast the pattern into both u64 lanes of an XMM register.
            let fill = _mm_set1_epi64x(val as i64);

            // Bulk: 32 bytes per iteration via two unaligned stores.
            while remaining >= 32 {
                _mm_storeu_si128(beg as *mut _, fill);
                _mm_storeu_si128(beg.add(16) as *mut _, fill);

                beg = beg.add(32);
                remaining -= 32;
            }

            // 16..=31 bytes left: one store at the cursor plus one anchored
            // at `end`. They may overlap; the end-anchored store is issued
            // last, and for the slice-derived ranges `memset` passes, the
            // overlap offset is a multiple of the element size so both
            // stores agree on the overlapping bytes anyway.
            if remaining >= 16 {
                _mm_storeu_si128(beg as *mut _, fill);
                _mm_storeu_si128(end.sub(16) as *mut _, fill);
                return;
            }
        }

        // Scalar tail (< 16 bytes): same overlapping begin/end-anchored
        // pair at progressively narrower widths.
        if remaining >= 8 {
            (beg as *mut u64).write_unaligned(val);
            (end.sub(8) as *mut u64).write_unaligned(val);
        } else if remaining >= 4 {
            (beg as *mut u32).write_unaligned(val as u32);
            (end.sub(4) as *mut u32).write_unaligned(val as u32);
        } else if remaining >= 2 {
            (beg as *mut u16).write_unaligned(val as u16);
            (end.sub(2) as *mut u16).write_unaligned(val as u16);
        } else if remaining >= 1 {
            beg.write(val as u8);
        }
    }
}
169
/// AVX2 fill of `[beg, end)` with the repeating 8-byte pattern `val`:
/// a 4x-unrolled loop of unaligned 32-byte stores (128 bytes per
/// iteration), then single unaligned 16-byte stores, then a scalar tail
/// of overlapping begin/end-anchored stores.
///
/// Declared `unsafe fn` — it writes through raw pointers and requires a
/// CPU feature — matching its siblings `memset_sse2`/`memset_neon`
/// (previously a safe `fn`, which hid the caller contract).
///
/// # Safety
///
/// The caller must guarantee that `[beg, end)` is a valid writable byte
/// range and that the CPU supports AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn memset_avx2(mut beg: *mut u8, end: *mut u8, val: u64) {
    unsafe {
        #[cfg(target_arch = "x86")]
        use std::arch::x86::*;
        #[cfg(target_arch = "x86_64")]
        use std::arch::x86_64::*;
        use std::hint::black_box;

        let mut remaining = end.offset_from_unsigned(beg);

        if remaining >= 128 {
            // Broadcast the pattern into all four u64 lanes of a YMM register.
            let fill = _mm256_set1_epi64x(val as i64);

            // Bulk: 128 bytes per iteration via four unaligned stores.
            loop {
                _mm256_storeu_si256(beg as *mut _, fill);
                _mm256_storeu_si256(beg.add(32) as *mut _, fill);
                _mm256_storeu_si256(beg.add(64) as *mut _, fill);
                _mm256_storeu_si256(beg.add(96) as *mut _, fill);

                beg = beg.add(128);
                remaining -= 128;
                if remaining < 128 {
                    break;
                }
            }
        }

        // Mid-range: one unaligned 16-byte store at a time.
        if remaining >= 16 {
            let fill = _mm_set1_epi64x(val as i64);

            loop {
                // NOTE(review): `black_box` appears to keep the optimizer
                // from rewriting this loop as a libc `memset` call —
                // confirm before removing.
                #[allow(clippy::unit_arg)]
                black_box(_mm_storeu_si128(beg as *mut _, fill));

                beg = beg.add(16);
                remaining -= 16;
                if remaining < 16 {
                    break;
                }
            }
        }

        // Scalar tail (< 16 bytes): overlapping begin/end-anchored stores.
        if remaining >= 8 {
            (beg as *mut u64).write_unaligned(val);
            (end.sub(8) as *mut u64).write_unaligned(val);
        } else if remaining >= 4 {
            (beg as *mut u32).write_unaligned(val as u32);
            (end.sub(4) as *mut u32).write_unaligned(val as u32);
        } else if remaining >= 2 {
            (beg as *mut u16).write_unaligned(val as u16);
            (end.sub(2) as *mut u16).write_unaligned(val as u16);
        } else if remaining >= 1 {
            beg.write(val as u8);
        }
    }
}
237
/// NEON fill of `[beg, end)` with the repeating 8-byte pattern `val`:
/// a 2x-unrolled loop of 16-byte stores, then a begin/end-anchored pair of
/// 16-byte stores, then a scalar tail of overlapping stores. No
/// `target_feature` gate — NEON is baseline on aarch64.
///
/// # Safety
///
/// The caller must guarantee that `[beg, end)` is a valid writable byte
/// range.
#[cfg(target_arch = "aarch64")]
unsafe fn memset_neon(mut beg: *mut u8, end: *mut u8, val: u64) {
    unsafe {
        use std::arch::aarch64::*;
        let mut remaining = end.offset_from_unsigned(beg);

        if remaining >= 32 {
            // Broadcast the pattern into both u64 lanes of a Q register.
            let fill = vdupq_n_u64(val);

            // Bulk: 32 bytes per iteration via two 16-byte stores.
            loop {
                vst1q_u64(beg as *mut _, fill);
                vst1q_u64(beg.add(16) as *mut _, fill);

                beg = beg.add(32);
                remaining -= 32;
                if remaining < 32 {
                    break;
                }
            }
        }

        // 16..=31 bytes left: one store at the cursor plus one anchored at
        // `end`; they may overlap, and the end-anchored store issued last
        // determines any overlapping bytes.
        if remaining >= 16 {
            let fill = vdupq_n_u64(val);
            vst1q_u64(beg as *mut _, fill);
            vst1q_u64(end.sub(16) as *mut _, fill);
        } else if remaining >= 8 {
            (beg as *mut u64).write_unaligned(val);
            (end.sub(8) as *mut u64).write_unaligned(val);
        } else if remaining >= 4 {
            (beg as *mut u32).write_unaligned(val as u32);
            (end.sub(4) as *mut u32).write_unaligned(val as u32);
        } else if remaining >= 2 {
            (beg as *mut u16).write_unaligned(val as u16);
            (end.sub(2) as *mut u16).write_unaligned(val as u16);
        } else if remaining >= 1 {
            beg.write(val as u8);
        }
    }
}
283
#[cfg(test)]
mod tests {
    use std::fmt;
    use std::ops::Not;

    use super::*;

    /// Fills a buffer initialized to the bitwise complement of `val` and
    /// checks that every element was overwritten with `val`.
    fn check_memset<T>(val: T, len: usize)
    where
        T: MemsetSafe + Not<Output = T> + PartialEq + fmt::Debug,
    {
        let mut buf = vec![!val; len];
        memset(&mut buf, val);
        assert!(buf.iter().all(|&x| x == val));
    }

    #[test]
    fn test_memset_empty() {
        check_memset(0u8, 0);
        check_memset(0u16, 0);
        check_memset(0u32, 0);
        check_memset(0u64, 0);
    }

    #[test]
    fn test_memset_single() {
        check_memset(0u8, 1);
        check_memset(0xFFu8, 1);
        check_memset(0xABu16, 1);
        check_memset(0x12345678u32, 1);
        check_memset(0xDEADBEEFu64, 1);
    }

    #[test]
    fn test_memset_small() {
        for &len in &[2, 3, 4, 5, 7, 8, 9] {
            check_memset(0xAAu8, len);
            check_memset(0xBEEFu16, len);
            check_memset(0xCAFEBABEu32, len);
            check_memset(0x1234567890ABCDEFu64, len);
        }
    }

    #[test]
    fn test_memset_boundary_lengths() {
        // Straddle the SIMD chunk sizes (16/32/128 bytes) so every
        // bulk-loop/tail combination in the backends is exercised.
        for &len in &[15, 16, 17, 31, 32, 33, 127, 128, 129, 255, 256, 257] {
            check_memset(0x5Au8, len);
            check_memset(0xA55Au16, len);
            check_memset(0xDEADBEEFu32, len);
            check_memset(0x0123456789ABCDEFu64, len);
        }
    }

    #[test]
    fn test_memset_large() {
        check_memset(0u8, 1000);
        check_memset(0xFFu8, 1024);
        check_memset(0xBEEFu16, 512);
        check_memset(0xCAFEBABEu32, 256);
        check_memset(0x1234567890ABCDEFu64, 128);
    }

    #[test]
    fn test_memset_various_values() {
        check_memset(0u8, 17);
        check_memset(0x7Fu8, 17);
        check_memset(0x8001u16, 17);
        check_memset(0xFFFFFFFFu32, 17);
        check_memset(0x8000000000000001u64, 17);
    }

    #[test]
    fn test_memset_signed_types() {
        check_memset(-1i8, 8);
        check_memset(-2i16, 8);
        check_memset(-3i32, 8);
        check_memset(-4i64, 8);
        check_memset(-5isize, 8);
    }

    #[test]
    fn test_memset_usize_isize() {
        check_memset(0usize, 4);
        check_memset(usize::MAX, 4);
        check_memset(0isize, 4);
        check_memset(isize::MIN, 4);
    }

    #[test]
    fn test_memset_alignment() {
        // Run the fill at every alignment offset and additionally verify
        // that bytes outside the target slice are never written (a fresh
        // buffer per offset makes stray writes detectable).
        for offset in 0..8 {
            let mut buf = [0u8; 15];
            memset(&mut buf[offset..offset + 7], 0x5A);
            for (i, &b) in buf.iter().enumerate() {
                let expected = if (offset..offset + 7).contains(&i) { 0x5A } else { 0 };
                assert_eq!(b, expected, "offset {offset}, index {i}");
            }
        }
    }
}
372}