1use std::{
65 arch::x86_64::{_mm_loadu_si128, _mm_shuffle_epi8, _mm_storeu_si128},
66 debug_assert,
67 io::{self, BufRead, Write},
68 mem,
69};
70
/// Four `u32` lanes: the element type of the shuffle masks stored in `MASKS`
/// and of the output written by `simd_decode`.
#[allow(non_camel_case_types)]
type u32x4 = [u32; 4];

/// Lookup table indexed by a control word (one byte). Each entry pairs the
/// shuffle mask consumed by `simd_decode` with the number of data-stream
/// bytes occupied by the quadruple that control word describes.
/// Built at compile time by `u32_shuffle_masks`.
const MASKS: [(u32x4, u8); 256] = u32_shuffle_masks();

/// Two-byte marker written at the start of every segment header and checked
/// when a header is parsed.
const SEGMENT_MAGIC: u16 = 0x0B0D;

/// Size of a serialized segment header in bytes: a 2-byte magic followed by
/// three 4-byte big-endian fields (element count, control stream length,
/// data stream length).
const SEGMENT_HEADER_LENGTH: usize = 14;
86
/// Source of encoded segments, each made of a control stream and a data
/// stream.
pub trait Segments {
    /// Advances to the next segment and returns the element count recorded
    /// for it. The implementations in this file return `Ok(0)` once no more
    /// segments are available.
    fn next(&mut self) -> io::Result<usize>;

    /// Data stream bytes of the segment most recently loaded by `next`.
    fn data_stream(&self) -> &[u8];

    /// Control stream bytes of the segment most recently loaded by `next`.
    fn control_stream(&self) -> &[u8];
}
109
110pub struct BufReadSegments<R> {
112 source: R,
113 control_stream: Vec<u8>,
114 data_stream: Vec<u8>,
115}
116
117impl<R> BufReadSegments<R> {
118 pub fn new(source: R) -> Self {
119 Self {
120 source,
121 control_stream: vec![],
122 data_stream: vec![],
123 }
124 }
125}
126
127impl<R: BufRead> Segments for BufReadSegments<R> {
128 fn next(&mut self) -> io::Result<usize> {
129 let result = read_segment(
130 &mut self.source,
131 &mut self.control_stream,
132 &mut self.data_stream,
133 );
134 match result {
135 Ok(elements) => Ok(elements),
136 Err(e) => {
137 if e.kind() == io::ErrorKind::UnexpectedEof {
138 Ok(0)
139 } else {
140 Err(e)
141 }
142 }
143 }
144 }
145
146 fn data_stream(&self) -> &[u8] {
147 self.control_stream.as_ref()
148 }
149
150 fn control_stream(&self) -> &[u8] {
151 self.data_stream.as_ref()
152 }
153}
154
155pub struct MemorySegments<'a> {
157 data: &'a [u8],
158 control_stream: &'a [u8],
159 data_stream: &'a [u8],
160}
161
162impl<'a> MemorySegments<'a> {
163 pub fn new(data: &'a [u8]) -> Self {
164 Self {
165 data,
166 control_stream: &data[0..0],
167 data_stream: &data[0..0],
168 }
169 }
170}
171
172impl<'a> Segments for MemorySegments<'a> {
173 fn next(&mut self) -> io::Result<usize> {
174 if self.data.is_empty() {
175 return Ok(0);
176 }
177
178 let segment = SegmentHeader::parse(self.data);
179 self.control_stream =
180 &self.data[SEGMENT_HEADER_LENGTH..SEGMENT_HEADER_LENGTH + segment.cs_length];
181 self.data_stream = &self.data[SEGMENT_HEADER_LENGTH + segment.cs_length
182 ..SEGMENT_HEADER_LENGTH + segment.cs_length + segment.ds_length];
183 self.data = &self.data[SEGMENT_HEADER_LENGTH + segment.cs_length + segment.ds_length..];
184
185 Ok(segment.count)
186 }
187
188 fn data_stream(&self) -> &[u8] {
189 self.data_stream
190 }
191 fn control_stream(&self) -> &[u8] {
192 self.control_stream
193 }
194}
195
196pub struct DecodeCursor<S: Segments> {
201 elements_left: usize,
202 control_stream_offset: usize,
203 data_stream_offset: usize,
204 segments: S,
205}
206
207impl<S: Segments> DecodeCursor<S> {
208 pub fn new(segments: S) -> io::Result<Self> {
209 Ok(Self {
210 elements_left: 0,
211 control_stream_offset: 0,
212 data_stream_offset: 0,
213 segments,
214 })
215 }
216
217 #[inline(never)]
218 fn refill(&mut self) -> io::Result<usize> {
219 debug_assert!(
220 self.elements_left == 0,
221 "Should be 0, got: {}",
222 self.elements_left
223 );
224
225 let elements = self.segments.next()?;
226 if elements > 0 {
227 let cs = self.segments.control_stream();
228 let ds = self.segments.data_stream();
229 assert!(
230 cs.len() * 4 >= elements,
231 "Invalid control stream length. Expected: {}, got: {}",
232 (elements + 3) / 4,
233 cs.len()
234 );
235 assert!(
236 ds.len() >= elements,
237 "Invalid data stream length. Expected: >={}, got: {}",
238 elements,
239 ds.len()
240 );
241 self.data_stream_offset = 0;
242 self.control_stream_offset = 0;
243 self.elements_left = elements;
244 }
245 Ok(elements)
246 }
247}
248
249#[derive(Debug, PartialEq)]
253struct SegmentHeader {
254 count: usize,
255 cs_length: usize,
256 ds_length: usize,
257}
258
259impl SegmentHeader {
260 fn new(count: usize, cs_size: usize, ds_size: usize) -> Self {
261 Self {
262 count,
263 cs_length: cs_size,
264 ds_length: ds_size,
265 }
266 }
267
268 fn parse(input: &[u8]) -> Self {
269 assert!(
270 input.len() >= SEGMENT_HEADER_LENGTH,
271 "Expected slice of len >={}, got: {}",
272 SEGMENT_HEADER_LENGTH,
273 input.len()
274 );
275 let input = &input[..SEGMENT_HEADER_LENGTH];
276
277 let magic = u16::from_be_bytes(input[0..2].try_into().unwrap());
278 let count = u32::from_be_bytes(input[2..6].try_into().unwrap()) as usize;
279 let cs_length = u32::from_be_bytes(input[6..10].try_into().unwrap()) as usize;
280 let ds_length = u32::from_be_bytes(input[10..14].try_into().unwrap()) as usize;
281
282 assert!(
283 magic == SEGMENT_MAGIC,
284 "Expected magic: {}, got: {}",
285 SEGMENT_MAGIC,
286 magic,
287 );
288
289 Self {
290 count,
291 cs_length,
292 ds_length,
293 }
294 }
295
296 fn write(&self, out: &mut dyn Write) -> io::Result<()> {
297 out.write_all(&SEGMENT_MAGIC.to_be_bytes())?;
298
299 debug_assert!(self.count <= u32::MAX as usize);
300 let number_of_elements = (self.count as u32).to_be_bytes();
301 out.write_all(&number_of_elements)?;
302
303 debug_assert!(self.cs_length <= u32::MAX as usize);
304 let cs_len = (self.cs_length as u32).to_be_bytes();
305 out.write_all(&cs_len)?;
306
307 debug_assert!(self.ds_length <= u32::MAX as usize);
308 let ds_len = (self.ds_length as u32).to_be_bytes();
309 out.write_all(&ds_len)?;
310
311 Ok(())
312 }
313}
314
315fn read_segment(input: &mut impl BufRead, cs: &mut Vec<u8>, ds: &mut Vec<u8>) -> io::Result<usize> {
319 let mut buf = [0u8; SEGMENT_HEADER_LENGTH];
320 input.read_exact(&mut buf)?;
321 let header = SegmentHeader::parse(&buf);
322
323 cs.resize(header.cs_length, 0);
324 input.read_exact(&mut cs[..header.cs_length])?;
325
326 ds.resize(header.ds_length, 0);
327 input.read_exact(&mut ds[..header.ds_length])?;
328
329 Ok(header.count)
330}
331
impl<S: Segments> Decoder<u32> for DecodeCursor<S> {
    /// Decodes values from the current segment into `buffer` and returns how
    /// many valid values were produced; `Ok(0)` means the segment source is
    /// exhausted.
    ///
    /// Values are decoded four at a time (one control word per quadruple), so
    /// at most `buffer.len() / 4` quadruples are processed per call, further
    /// limited by the control words remaining in the current segment. The
    /// last quadruple of a segment may contain padding lanes: they are
    /// written to `buffer` but excluded from the returned count.
    ///
    /// # Panics
    /// If `buffer` holds fewer than four elements.
    ///
    /// # Errors
    /// Propagates errors from loading the next segment.
    fn decode(&mut self, buffer: &mut [u32]) -> io::Result<usize> {
        // Number of values produced per control word.
        const DECODE_WIDTH: usize = 4;
        assert!(
            buffer.len() >= DECODE_WIDTH,
            "Buffer should be at least {} elements long",
            DECODE_WIDTH
        );
        // Load the next segment only when the current one is fully consumed.
        if self.elements_left == 0 && self.refill()? == 0 {
            return Ok(0);
        }

        // Resume where the previous call stopped inside the current segment.
        let mut data_stream_offset = self.data_stream_offset;
        let control_stream = &self.segments.control_stream()[self.control_stream_offset..];
        let data_stream = &self.segments.data_stream()[data_stream_offset..];
        let mut data_stream = data_stream.as_ptr();

        // One iteration per control word: bounded by the output capacity and
        // by the control words left in this segment.
        let mut iterations = buffer.len() / 4;
        iterations = iterations.min(control_stream.len());

        self.control_stream_offset += iterations;
        let decoded = iterations * DECODE_WIDTH;

        // From here on the output is addressed as a sequence of 4-lane blocks.
        let mut buffer: *mut u32x4 = buffer.as_mut_ptr().cast();
        let mut control_words = control_stream.as_ptr();

        // Same body as the tail loop below, run in batches of UNROLL_FACTOR.
        // NOTE(review): presumably a manual unroll for throughput; confirm
        // with benchmarks before changing.
        const UNROLL_FACTOR: usize = 8;
        while iterations >= UNROLL_FACTOR {
            for _ in 0..UNROLL_FACTOR {
                // SAFETY (as far as this block shows): `control_words` and
                // `buffer` are advanced at most `iterations` times, which is
                // bounded by `control_stream.len()` and `buffer.len() / 4`.
                // The transmutes reinterpret the raw pointers as references to
                // 16 input bytes / one output block. That 16 bytes are really
                // available in the data stream is only checked by the
                // `debug_assert!`; release builds rely on the data stream
                // being padded (see `EncodeCursor::write_segment`).
                let encoded_len = unsafe {
                    debug_assert!(
                        self.segments.data_stream()[data_stream_offset..].len() >= 16,
                        "At least 16 bytes should be available in data stream"
                    );
                    let data_stream = mem::transmute(data_stream);
                    let output = mem::transmute(buffer);
                    simd_decode(data_stream, *control_words, output)
                };

                control_words = control_words.wrapping_add(1);
                buffer = buffer.wrapping_add(1);

                // Advance by the bytes this quadruple actually used, so
                // consecutive 16-byte loads overlap.
                data_stream = data_stream.wrapping_add(encoded_len as usize);
                data_stream_offset += encoded_len as usize;
            }

            iterations -= UNROLL_FACTOR;
        }

        // Remaining quadruples (fewer than UNROLL_FACTOR), one at a time.
        while iterations > 0 {
            // SAFETY: same reasoning as in the unrolled loop above.
            let encoded_len = unsafe {
                debug_assert!(
                    self.segments.data_stream()[data_stream_offset..].len() >= 16,
                    "At least 16 bytes should be available in data stream"
                );
                let data_stream = mem::transmute(data_stream);
                let output = mem::transmute(buffer);
                simd_decode(data_stream, *control_words, output)
            };

            control_words = control_words.wrapping_add(1);
            buffer = buffer.wrapping_add(1);

            data_stream = data_stream.wrapping_add(encoded_len as usize);
            data_stream_offset += encoded_len as usize;

            iterations -= 1;
        }

        self.data_stream_offset = data_stream_offset;
        // The final quadruple of a segment may be partially filled: never
        // report more values than the segment still holds.
        let decoded = decoded.min(self.elements_left);
        self.elements_left -= decoded;
        Ok(decoded)
    }
}
422
/// Decodes one quadruple: expands the variable-length values at the start of
/// `input` into the four `u32` lanes of `output` using the shuffle mask
/// selected by `control_word`.
///
/// Returns the number of `input` bytes the quadruple occupies, i.e. how far
/// the caller has to advance in the data stream.
#[inline]
fn simd_decode(input: &[u8; 16], control_word: u8, output: &mut u32x4) -> u8 {
    // Every control word has a precomputed (mask, encoded length) entry.
    let (ref mask, encoded_len) = MASKS[control_word as usize];
    // SAFETY: all three pointers come from references to 16-byte values, and
    // the `loadu`/`storeu` intrinsics place no alignment requirement on them.
    // NOTE(review): `_mm_shuffle_epi8` is an SSSE3 instruction; nothing in
    // this function checks that the running CPU supports it; confirm the
    // build or callers guarantee that.
    unsafe {
        let mask = _mm_loadu_si128(mask.as_ptr().cast());
        let input = _mm_loadu_si128(input.as_ptr().cast());
        // Each mask byte selects an input byte, or zeroes the output byte
        // when its high bit is set.
        let answer = _mm_shuffle_epi8(input, mask);
        _mm_storeu_si128(output.as_mut_ptr().cast(), answer);
    }

    encoded_len
}
451
/// Builds the shuffle mask for a single `u32` lane.
///
/// `len` is the encoded size of the value in bytes (1..=4) and `offset` the
/// position of its first byte within the 16-byte input. The result, read as
/// big-endian bytes, lists the input positions from the most significant
/// output byte to the least significant one; positions the value does not
/// cover are `0x80`, which makes the shuffle write zero.
///
/// # Panics
/// If `offset >= 16` or `len` is outside `1..=4`.
const fn u32_shuffle_mask(len: usize, offset: usize) -> u32 {
    // High bit set: the shuffle zeroes this output byte.
    const ZERO_LANE: u8 = 0x80;
    assert!(offset < 16, "Offset should be <16");
    let first = offset as u8;
    let selectors = match len {
        1 => [ZERO_LANE, ZERO_LANE, ZERO_LANE, first],
        2 => [ZERO_LANE, ZERO_LANE, first, first + 1],
        3 => [ZERO_LANE, first, first + 1, first + 2],
        4 => [first, first + 1, first + 2, first + 3],
        _ => panic!("Length of u32 is 1..=4 bytes"),
    };
    u32::from_be_bytes(selectors)
}
474
475const fn u32_shuffle_masks() -> [(u32x4, u8); 256] {
508 let mut masks = [([0u32; 4], 0u8); 256];
509
510 let mut a = 1;
511 while a <= 4 {
512 let mut b = 1;
513 while b <= 4 {
514 let mut c = 1;
515 while c <= 4 {
516 let mut d = 1;
517 while d <= 4 {
518 let mask = [
520 u32_shuffle_mask(a, 0),
521 u32_shuffle_mask(b, a),
522 u32_shuffle_mask(c, a + b),
523 u32_shuffle_mask(d, a + b + c),
524 ];
525
526 let idx = (a - 1) << 6 | (b - 1) << 4 | (c - 1) << 2 | (d - 1);
528 assert!(a + b + c + d <= 16);
529 masks[idx] = (mask, (a + b + c + d) as u8);
530 d += 1;
531 }
532 c += 1;
533 }
534 b += 1;
535 }
536 a += 1;
537 }
538 masks
539}
540
541pub struct EncodeCursor<W> {
564 data_stream: Vec<u8>,
565 control_stream: Vec<u8>,
566 output: Box<W>,
567 written: usize,
568}
569
570impl<W: Write> EncodeCursor<W> {
571 pub fn new(output: W) -> Self {
572 Self {
573 data_stream: vec![],
574 control_stream: vec![],
575 output: Box::new(output),
576 written: 0,
577 }
578 }
579 pub fn encode(&mut self, input: &[u32]) -> io::Result<()> {
581 for n in input {
582 let bytes: [u8; 4] = n.to_be_bytes();
583 let length = 4 - n.leading_zeros() as u8 / 8;
584 let length = length.max(1);
585 debug_assert!((1..=4).contains(&length));
586
587 let control_word = self.get_control_word();
588 *control_word <<= 2;
589 *control_word |= length - 1;
590 self.written += 1;
591
592 self.data_stream.write_all(&bytes[4 - length as usize..])?;
593 self.write_segment_if_needed()?;
594 }
595 Ok(())
596 }
597
598 fn get_control_word(&mut self) -> &mut u8 {
599 if self.written % 4 == 0 {
600 self.control_stream.push(0);
601 }
602 self.control_stream.last_mut().unwrap()
603 }
604
605 fn write_segment_if_needed(&mut self) -> io::Result<()> {
606 const MAX_SEGMENT_SIZE: usize = 8 * 1024;
607 let segment_size = 2 + 4 + 4 + 4 + self.data_stream.len() + self.control_stream.len();
612 if segment_size >= MAX_SEGMENT_SIZE {
613 self.write_segment()?;
614
615 self.written = 0;
616 self.data_stream.clear();
617 self.control_stream.clear();
618 }
619 Ok(())
620 }
621
622 fn write_segment(&mut self) -> io::Result<()> {
623 let tail = self.written % 4;
624 if tail > 0 {
627 let control_word = self.control_stream.last_mut().unwrap();
628 *control_word <<= 2 * (4 - tail);
629 }
630
631 let control_word = self.control_stream.last().unwrap();
635 let quadruple_length =
636 byte_to_4_length(*control_word).iter().sum::<u8>() as usize - (4 - tail);
637
638 for _ in quadruple_length..16 {
639 self.data_stream.write_all(&[0])?;
640 }
641
642 let header = SegmentHeader::new(
643 self.written,
644 self.control_stream.len(),
645 self.data_stream.len(),
646 );
647 header.write(&mut self.output)?;
648
649 self.output.write_all(&self.control_stream)?;
650 self.output.write_all(&self.data_stream)?;
651
652 Ok(())
653 }
654
655 pub fn finish(mut self) -> io::Result<W> {
660 self.write_segment()?;
661 Ok(*self.output)
662 }
663}
664
/// Something that can decode a stream of `T` values in batches.
pub trait Decoder<T: Copy + From<u8>> {
    /// Decodes the next batch into `buffer` and returns how many leading
    /// elements of `buffer` are valid. `to_vec` treats a return value of `0`
    /// as the end of the stream.
    fn decode(&mut self, buffer: &mut [T]) -> io::Result<usize>;

    /// Drains the decoder into a vector by calling `decode` with a
    /// 128-element scratch buffer until it reports `0`.
    ///
    /// # Errors
    /// Returns the first error reported by `decode`.
    fn to_vec(mut self) -> io::Result<Vec<T>>
    where
        Self: Sized,
    {
        let mut scratch = [T::from(0u8); 128];
        let mut collected = Vec::new();
        loop {
            let produced = self.decode(&mut scratch)?;
            if produced == 0 {
                return Ok(collected);
            }
            collected.extend_from_slice(&scratch[..produced]);
        }
    }
}
689
/// Unpacks a control word into the byte lengths of its four values.
///
/// Each 2-bit field stores `length - 1`; the first value lives in the two
/// most significant bits.
fn byte_to_4_length(input: u8) -> [u8; 4] {
    let length_at = |shift: u32| ((input >> shift) & 0b11) + 1;
    [length_at(6), length_at(4), length_at(2), length_at(0)]
}
701
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{rngs::ThreadRng, thread_rng, Rng, RngCore};
    use std::io::{Cursor, Seek, SeekFrom};

    /// Checks the raw streams produced for five values of known byte lengths
    /// (1, 2, 3, 4 and 3 bytes).
    #[test]
    fn check_encode() {
        let (control, data, _) = encode_values(&[0x01, 0x0100, 0x010000, 0x01000000, 0x010000]);

        assert_eq!(
            data,
            [
                0x01, // 0x01
                0x01, 0x00, // 0x0100
                0x01, 0x00, 0x00, // 0x010000
                0x01, 0x00, 0x00, 0x00, // 0x01000000
                0x01, 0x00, 0x00, // 0x010000
                // Zero padding after the last (partial) quadruple.
                0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            ]
        );

        let len = byte_to_4_length(control[0]);
        assert_eq!(len, [1, 2, 3, 4]);

        // Only the first field of the second control word is in use.
        let len = byte_to_4_length(control[1]);
        assert_eq!(len, [3, 1, 1, 1]);
    }

    /// Round-trips many short random inputs.
    #[test]
    fn check_small_functional_encode_decode() {
        let mut rng = thread_rng();
        for _ in 0..1000 {
            let len = rng.gen_range(1..20);
            check_encode_decode_cycle(&mut rng, len);
        }
    }

    /// Round-trips a few long random inputs.
    #[test]
    fn check_large_functional_encode_decode() {
        let mut rng = thread_rng();
        for _ in 0..10 {
            let len = rng.gen_range(10000..20000);
            check_encode_decode_cycle(&mut rng, len);
        }
    }

    /// Encodes `len` random values, decodes them from memory and compares
    /// the result with the input quadruple by quadruple.
    fn check_encode_decode_cycle(rng: &mut ThreadRng, len: usize) {
        let input: Vec<u32> = generate_random_data(rng, len);
        let (_, _, encoded) = encode_values(&input);
        let output = DecodeCursor::new(MemorySegments::new(&encoded.into_inner()))
            .unwrap()
            .to_vec()
            .unwrap();
        assert_eq!(input.len(), output.len());
        let chunk_size = 4;
        for (i, (input, output)) in input
            .chunks(chunk_size)
            .zip(output.chunks(chunk_size))
            .enumerate()
        {
            assert_eq!(input, output, "Arrays differs position {}", i * chunk_size);
        }
    }

    /// Round-trips a small fixed input.
    #[test]
    fn check_decode() {
        let input = [1, 255, 1024, 2048, 0xFF000000];
        let (_, _, encoded) = encode_values(&input);
        let output = DecodeCursor::new(MemorySegments::new(&encoded.into_inner()))
            .unwrap()
            .to_vec()
            .unwrap();
        // Compare against the input length; comparing `output.len()` with
        // itself could never fail.
        assert_eq!(input.len(), output.len());
        assert_eq!(output, input);
    }

    #[allow(clippy::unusual_byte_groupings)]
    #[test]
    fn check_create_mask() {
        assert_eq!(u32_shuffle_mask(1, 0), 0x808080_00);
        assert_eq!(u32_shuffle_mask(2, 0), 0x8080_0001);

        assert_eq!(u32_shuffle_mask(1, 3), 0x808080_03);
        assert_eq!(u32_shuffle_mask(2, 3), 0x8080_0304);
    }

    #[allow(clippy::unusual_byte_groupings)]
    #[test]
    fn check_shuffle_masks() {
        let masks = u32_shuffle_masks();
        assert_eq!(
            masks[0b_00_00_00_00],
            ([0x808080_00, 0x808080_01, 0x808080_02, 0x808080_03], 4)
        );
        assert_eq!(
            masks[0b_11_11_11_11],
            ([0x00010203, 0x04050607, 0x08090a0b, 0x0c0d0e0f], 16)
        );
        assert_eq!(
            masks[0b_11_00_11_00],
            ([0x00010203, 0x808080_04, 0x05060708, 0x808080_09], 10)
        );
        assert_eq!(
            masks[0b_11_10_01_00],
            ([0x00010203, 0x80_040506, 0x8080_0708, 0x808080_09], 10)
        );
    }

    /// A written header parses back to the same value.
    #[test]
    fn check_header_format() {
        let expected = SegmentHeader::new(3, 1, 2);
        let mut out = vec![];

        expected.write(&mut out).unwrap();
        let header = SegmentHeader::parse(&out);
        assert_eq!(header, expected);
    }

    /// Encodes `input` and returns the control stream and data stream of the
    /// first segment, plus the complete encoded output rewound to its start.
    pub fn encode_values(input: &[u32]) -> (Vec<u8>, Vec<u8>, Cursor<Vec<u8>>) {
        let mut encoder = EncodeCursor::new(Cursor::new(vec![]));
        encoder.encode(input).unwrap();
        let mut source = encoder.finish().unwrap();
        let mut cs = vec![];
        let mut ds = vec![];
        source.seek(SeekFrom::Start(0)).unwrap();
        read_segment(&mut source, &mut cs, &mut ds).unwrap();
        source.seek(SeekFrom::Start(0)).unwrap();
        (cs, ds, source)
    }

    /// Produces `size` random values with a roughly even mix of 1-, 2-, 3-
    /// and 4-byte magnitudes.
    fn generate_random_data(rng: &mut ThreadRng, size: usize) -> Vec<u32> {
        let mut input = vec![];
        input.resize_with(size, || match rng.gen_range(1..=4) {
            1 => rng.next_u32() % (0xFF + 1),
            2 => rng.next_u32() % (0xFFFF + 1),
            3 => rng.next_u32() % (0xFFFFFF + 1),
            _ => rng.next_u32(),
        });
        input
    }
}