1use crate::bit_util::ceil;
21
22pub fn set_bits(
29 write_data: &mut [u8],
30 data: &[u8],
31 offset_write: usize,
32 offset_read: usize,
33 len: usize,
34) -> usize {
35 assert!(
36 offset_write
37 .checked_add(len)
38 .expect("operation will overflow write buffer")
39 <= write_data.len() * 8
40 );
41 assert!(
42 offset_read
43 .checked_add(len)
44 .expect("operation will overflow read buffer")
45 <= data.len() * 8
46 );
47 let mut null_count = 0;
48 let mut acc = 0;
49 while len > acc {
50 let (n, len_set) = unsafe {
54 set_upto_64bits(
55 write_data,
56 data,
57 offset_write + acc,
58 offset_read + acc,
59 len - acc,
60 )
61 };
62 null_count += n;
63 acc += len_set;
64 }
65
66 null_count
67}
68
69#[inline]
/// Copies up to 64 bits from `data` (starting at bit `offset_read`) into
/// `write_data` (starting at bit `offset_write`), OR-ing them into the
/// destination bytes.
///
/// Returns `(null_count, len_set)`: the number of zero bits among the bits
/// copied, and the number of bits actually copied, which may be smaller
/// than `len` (the caller is expected to loop).
///
/// # Safety
/// The caller must guarantee that `offset_read + len` bits are readable from
/// `data` and `offset_write + len` bits are writable in `write_data`. When
/// `len >= 64` this performs an 8-byte unaligned read from `data` and an
/// 8-byte unaligned write to `write_data`.
#[inline]
unsafe fn set_upto_64bits(
    write_data: &mut [u8],
    data: &[u8],
    offset_write: usize,
    offset_read: usize,
    len: usize,
) -> (usize, usize) {
    // Split bit offsets into byte index + intra-byte bit shift.
    let read_byte = offset_read / 8;
    let read_shift = offset_read % 8;
    let write_byte = offset_write / 8;
    let write_shift = offset_write % 8;

    if len >= 64 {
        // SAFETY: caller guarantees at least 64 readable bits past
        // `offset_read`, so 8 bytes starting at `read_byte` are in bounds.
        let chunk = unsafe { (data.as_ptr().add(read_byte) as *const u64).read_unaligned() };
        if read_shift == 0 {
            if write_shift == 0 {
                // Fully aligned: copy all 64 bits in one shot.
                let len = 64;
                let null_count = chunk.count_zeros() as usize;
                unsafe { write_u64_bytes(write_data, write_byte, chunk) };
                (null_count, len)
            } else {
                // Destination unaligned: shifting left drops the top
                // `write_shift` bits, so only `64 - write_shift` bits are set.
                let len = 64 - write_shift;
                let chunk = chunk << write_shift;
                let null_count = len - chunk.count_ones() as usize;
                unsafe { or_write_u64_bytes(write_data, write_byte, chunk) };
                (null_count, len)
            }
        } else if write_shift == 0 {
            // Source unaligned, destination aligned: take a conservative
            // 56 bits (64 minus a full byte) so the mask is valid for any
            // `read_shift` in 1..=7.
            let len = 64 - 8;
            let chunk = (chunk >> read_shift) & 0x00FFFFFFFFFFFFFF;
            let null_count = len - chunk.count_ones() as usize;
            unsafe { write_u64_bytes(write_data, write_byte, chunk) };
            (null_count, len)
        } else {
            // Both sides unaligned: the larger of the two shifts bounds how
            // many bits survive the shift-right-then-left round trip.
            let len = 64 - std::cmp::max(read_shift, write_shift);
            let chunk = (chunk >> read_shift) << write_shift;
            let null_count = len - chunk.count_ones() as usize;
            unsafe { or_write_u64_bytes(write_data, write_byte, chunk) };
            (null_count, len)
        }
    } else if len == 1 {
        // Single-bit fast path: OR the one bit into place.
        let byte_chunk = (unsafe { data.get_unchecked(read_byte) } >> read_shift) & 1;
        unsafe { *write_data.get_unchecked_mut(write_byte) |= byte_chunk << write_shift };
        // A zero bit counts as one null.
        ((byte_chunk ^ 1) as usize, 1)
    } else {
        // Short tail (2..=63 bits): read only the bytes needed, mask down to
        // `len` bits, shift into position, then OR byte-by-byte.
        let len = std::cmp::min(len, 64 - std::cmp::max(read_shift, write_shift));
        let bytes = ceil(len + read_shift, 8);
        // SAFETY: `len + read_shift <= 64`, so at most 8 bytes are read, all
        // within bounds per the caller's contract.
        let chunk = unsafe { read_bytes_to_u64(data, read_byte, bytes) };
        let mask = u64::MAX >> (64 - len);
        let chunk = (chunk >> read_shift) & mask;
        let chunk = chunk << write_shift;
        let null_count = len - chunk.count_ones() as usize;
        let bytes = ceil(len + write_shift, 8);
        for (i, c) in chunk.to_le_bytes().iter().enumerate().take(bytes) {
            // SAFETY: `write_byte + bytes` bytes are writable per the
            // caller's contract (`len + write_shift <= 64`).
            unsafe { *write_data.get_unchecked_mut(write_byte + i) |= c };
        }
        (null_count, len)
    }
}
138
139#[inline]
/// Reads `count` bytes (`count <= 8`) from `data` starting at byte `offset`
/// and returns them as a `u64` in little-endian order; unread high bytes
/// are zero.
///
/// Little-endian interpretation matches the buffer's LSB-first bit packing
/// (bit `i` in byte `i / 8`). The previous implementation copied the bytes
/// over a native-endian `u64`, which was only correct on little-endian
/// targets.
///
/// # Safety
/// `offset + count` must not exceed `data.len()`.
#[inline]
unsafe fn read_bytes_to_u64(data: &[u8], offset: usize, count: usize) -> u64 {
    debug_assert!(count <= 8);
    let mut buf = [0u8; 8];
    // SAFETY: caller guarantees `offset + count <= data.len()`, and
    // `count <= 8` keeps the write inside `buf`.
    unsafe {
        std::ptr::copy_nonoverlapping(data.as_ptr().add(offset), buf.as_mut_ptr(), count);
    }
    u64::from_le_bytes(buf)
}
149
150#[inline]
/// Writes all 8 bytes of `chunk` into `data` starting at byte `offset`, in
/// little-endian order, overwriting whatever was there.
///
/// Little-endian order matches the buffer's LSB-first bit packing. The
/// previous implementation used a native-endian unaligned write, which was
/// only correct on little-endian targets.
///
/// # Safety
/// `offset + 8` must not exceed `data.len()`.
#[inline]
unsafe fn write_u64_bytes(data: &mut [u8], offset: usize, chunk: u64) {
    let bytes = chunk.to_le_bytes();
    // SAFETY: caller guarantees 8 writable bytes starting at `offset`.
    unsafe {
        std::ptr::copy_nonoverlapping(bytes.as_ptr(), data.as_mut_ptr().add(offset), 8);
    }
}
157
158#[inline]
/// ORs `chunk` with the byte currently at `data[offset]` and writes the
/// result as 8 little-endian bytes starting at `offset`.
///
/// Only the FIRST destination byte is merged; the following 7 bytes are
/// overwritten by the high bytes of `chunk`. The caller must therefore
/// ensure those trailing bytes hold no data that needs preserving (in
/// `set_upto_64bits` they are entirely covered by the bits being set).
/// Little-endian order matches the buffer's LSB-first bit packing; the
/// previous native-endian `write_unaligned` was only correct on
/// little-endian targets.
///
/// # Safety
/// `offset + 8` must not exceed `data.len()`.
#[inline]
unsafe fn or_write_u64_bytes(data: &mut [u8], offset: usize, chunk: u64) {
    // SAFETY: caller guarantees at least 8 readable/writable bytes at
    // `offset`, so reading the first byte is in bounds.
    let existing = unsafe { *data.get_unchecked(offset) };
    let merged = chunk | existing as u64;
    let bytes = merged.to_le_bytes();
    // SAFETY: caller guarantees 8 writable bytes starting at `offset`.
    unsafe {
        std::ptr::copy_nonoverlapping(bytes.as_ptr(), data.as_mut_ptr().add(offset), 8);
    }
}
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bit_util::{get_bit, set_bit, unset_bit};
    use rand::prelude::StdRng;
    use rand::{Rng, SeedableRng, TryRngCore};
    use std::fmt::Display;

    /// Byte-aligned source and destination: exercises the fast fully-aligned
    /// 64-bit copy path.
    #[test]
    fn test_set_bits_aligned() {
        SetBitsTest {
            write_data: vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            data: vec![
                0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111,
                0b10100101,
            ],
            offset_write: 8,
            offset_read: 0,
            len: 64,
            expected_data: vec![
                0, 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011,
                0b11100111, 0b10100101, 0,
            ],
            expected_null_count: 24,
        }
        .verify();
    }

    /// Aligned source, unaligned destination start (write offset not a
    /// multiple of 8).
    #[test]
    fn test_set_bits_unaligned_destination_start() {
        SetBitsTest {
            write_data: vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            data: vec![
                0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111,
                0b10100101,
            ],
            offset_write: 3,
            offset_read: 0,
            len: 64,
            expected_data: vec![
                0b00111000, 0b00101111, 0b11001101, 0b11011100, 0b01011110, 0b00011111, 0b00111110,
                0b00101111, 0b00000101, 0b00000000,
            ],
            expected_null_count: 24,
        }
        .verify();
    }

    /// Aligned offsets but a length that is not a multiple of 8, so the
    /// copy ends mid-byte.
    #[test]
    fn test_set_bits_unaligned_destination_end() {
        SetBitsTest {
            write_data: vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            data: vec![
                0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111,
                0b10100101,
            ],
            offset_write: 8,
            offset_read: 0,
            len: 62,
            expected_data: vec![
                0, 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011,
                0b11100111, 0b00100101, 0,
            ],
            expected_null_count: 23,
        }
        .verify();
    }

    /// Both offsets unaligned and length spanning multiple 64-bit chunks
    /// plus a tail.
    #[test]
    fn test_set_bits_unaligned() {
        SetBitsTest {
            write_data: vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            data: vec![
                0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111,
                0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111, 0b10100101,
                0b10011001, 0b11011011, 0b11101011, 0b11000011,
            ],
            offset_write: 3,
            offset_read: 5,
            len: 95,
            expected_data: vec![
                0b01111000, 0b01101001, 0b11100110, 0b11110110, 0b11111010, 0b11110000, 0b01111001,
                0b01101001, 0b11100110, 0b11110110, 0b11111010, 0b11110000, 0b00000001,
            ],
            expected_null_count: 35,
        }
        .verify();
    }

    /// Randomized testing against the bit-by-bit reference implementation
    /// in `regen`, using a fixed seed for reproducibility.
    #[test]
    fn set_bits_fuzz() {
        let mut rng = StdRng::seed_from_u64(42);
        let mut data = SetBitsTest::new();
        for _ in 0..100 {
            data.regen(&mut rng);
            data.verify();
        }
    }

    /// One `set_bits` scenario together with its expected outcome.
    #[derive(Debug, Default)]
    struct SetBitsTest {
        /// Destination buffer handed to `set_bits` (copied before use).
        write_data: Vec<u8>,
        /// Source bit buffer.
        data: Vec<u8>,
        /// Bit offset into `write_data` at which copying starts.
        offset_write: usize,
        /// Bit offset into `data` at which reading starts.
        offset_read: usize,
        /// Number of bits to copy.
        len: usize,
        /// Expected contents of the destination after the copy.
        expected_data: Vec<u8>,
        /// Expected number of zero bits in the copied range.
        expected_null_count: usize,
    }

    /// Formats a byte slice as space-separated 8-bit binary groups, for
    /// readable assertion failure output.
    struct BinaryFormatter<'a>(&'a [u8]);

    impl Display for BinaryFormatter<'_> {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            for byte in self.0 {
                write!(f, "{byte:08b} ")?;
            }
            write!(f, " ")?;
            Ok(())
        }
    }

    impl Display for SetBitsTest {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            writeln!(f, "SetBitsTest {{")?;
            writeln!(f, "  write_data: {}", BinaryFormatter(&self.write_data))?;
            writeln!(f, "  data: {}", BinaryFormatter(&self.data))?;
            writeln!(
                f,
                "  expected_data: {}",
                BinaryFormatter(&self.expected_data)
            )?;
            writeln!(f, "  offset_write: {}", self.offset_write)?;
            writeln!(f, "  offset_read: {}", self.offset_read)?;
            writeln!(f, "  len: {}", self.len)?;
            writeln!(f, "  expected_null_count: {}", self.expected_null_count)?;
            writeln!(f, "}}")
        }
    }

    impl SetBitsTest {
        /// Creates an empty scenario; call `regen` to populate it.
        fn new() -> Self {
            Self::default()
        }

        /// Regenerates random inputs and recomputes the expected output and
        /// null count with a simple bit-by-bit reference implementation.
        fn regen(&mut self, rng: &mut StdRng) {
            let len = rng.random_range(0..=200);

            // Random write offset; round the byte count up so the offset
            // fits, and pad with a few extra bytes past the written range.
            let offset_write_bits = rng.random_range(0..=200);
            let offset_write_bytes = if offset_write_bits % 8 == 0 {
                offset_write_bits / 8
            } else {
                (offset_write_bits / 8) + 1
            };
            let extra_write_data_bytes = rng.random_range(0..=5);
            let extra_read_data_bytes = rng.random_range(0..=5);
            let offset_read_bits = rng.random_range(0..=200);
            let offset_read_bytes = if offset_read_bits % 8 != 0 {
                (offset_read_bits / 8) + 1
            } else {
                offset_read_bits / 8
            };

            // Zero-initialized destination (generously sized: `len` is a bit
            // count but is used here as a byte count, which over-allocates).
            self.write_data.clear();
            self.write_data
                .resize(offset_write_bytes + len + extra_write_data_bytes, 0);

            self.offset_write = offset_write_bits;

            // Random source bytes over the whole buffer.
            self.data
                .resize(offset_read_bytes + len + extra_read_data_bytes, 0);
            rng.try_fill_bytes(self.data.as_mut_slice()).unwrap();
            self.offset_read = offset_read_bits;

            self.len = len;

            // Reference result: start from the destination, then copy each
            // bit individually, counting zeros.
            self.expected_data.resize(self.write_data.len(), 0);
            self.expected_data.copy_from_slice(&self.write_data);

            self.expected_null_count = 0;
            for i in 0..self.len {
                let bit = get_bit(&self.data, self.offset_read + i);
                if bit {
                    set_bit(&mut self.expected_data, self.offset_write + i);
                } else {
                    unset_bit(&mut self.expected_data, self.offset_write + i);
                    self.expected_null_count += 1;
                }
            }
        }

        /// Runs `set_bits` on a copy of `write_data` and checks both the
        /// resulting bytes and the returned null count.
        fn verify(&self) {
            let mut actual = self.write_data.to_vec();
            let null_count = set_bits(
                &mut actual,
                &self.data,
                self.offset_write,
                self.offset_read,
                self.len,
            );

            assert_eq!(actual, self.expected_data, "self: {self}");
            assert_eq!(null_count, self.expected_null_count, "self: {self}");
        }
    }

    /// Direct tests of the internal `set_upto_64bits` helper: a >64-bit
    /// request that is truncated to 63 bits, and the single-bit fast path.
    #[test]
    fn test_set_upto_64bits() {
        let write_data: &mut [u8] = &mut [0; 9];
        let data: &[u8] = &[
            0b00000001, 0b00000001, 0b00000001, 0b00000001, 0b00000001, 0b00000001, 0b00000001,
            0b00000001, 0b00000001,
        ];
        let offset_write = 1;
        let offset_read = 0;
        let len = 65;
        let (n, len_set) =
            unsafe { set_upto_64bits(write_data, data, offset_write, offset_read, len) };
        assert_eq!(n, 55);
        assert_eq!(len_set, 63);
        assert_eq!(
            write_data,
            &[
                0b00000010, 0b00000010, 0b00000010, 0b00000010, 0b00000010, 0b00000010, 0b00000010,
                0b00000010, 0b00000000
            ]
        );

        // len == 1 fast path.
        let write_data: &mut [u8] = &mut [0b00000000];
        let data: &[u8] = &[0b00000001];
        let offset_write = 1;
        let offset_read = 0;
        let len = 1;
        let (n, len_set) =
            unsafe { set_upto_64bits(write_data, data, offset_write, offset_read, len) };
        assert_eq!(n, 0);
        assert_eq!(len_set, 1);
        assert_eq!(write_data, &[0b00000010]);
    }

    /// `offset_read + len` overflows usize: must panic via the checked_add
    /// expect, not read out of bounds.
    #[test]
    #[should_panic(expected = "operation will overflow read buffer")]
    fn test_overflow_read_buffer_bounds() {
        let data = [0u8; 1];
        let mut write_data = [0u8; 1];

        let offset_write: usize = 0;
        let offset_read: usize = usize::MAX - 7;
        let len: usize = 8;

        let _nulls = set_bits(&mut write_data, &data, offset_write, offset_read, len);
    }

    /// `offset_write + len` overflows usize: must panic via the checked_add
    /// expect, not write out of bounds.
    #[test]
    #[should_panic(expected = "operation will overflow write buffer")]
    fn test_overflow_write_buffer_bounds() {
        let data = [0u8; 1];
        let mut write_data = [0u8; 1];

        let offset_write: usize = usize::MAX - 7;
        let offset_read: usize = 0;
        let len: usize = 8;

        let _nulls = set_bits(&mut write_data, &data, offset_write, offset_read, len);
    }
}