// hyper_gen/hd.rs

use log::info;
use std::arch::x86_64::*;
use std::collections::HashSet;

use crate::types::FileSketch;
use rand::{RngCore, SeedableRng};
use wyhash::WyRng;

extern crate bitpacking;
use bitpacking::{BitPacker, BitPacker8x};

use rayon::prelude::*;
/// AVX2 hyperdimensional (HD) encoder: folds one pseudo-random ±1
/// hypervector per k-mer hash into a single `hv_d`-dimensional i16
/// accumulator.
///
/// Each hash in `kmer_hash_set` seeds a `WyRng` whose bit stream defines
/// that hash's hypervector (bit = 1 contributes +1, bit = 0 contributes -1).
/// `hv` starts at `-num_seed` and every set bit adds +2, so each dimension
/// ends up as `2 * (#seeds with that bit set) - num_seed`.
///
/// Seeds are processed four at a time; `hv_d` is assumed to be a multiple
/// of 64 (one u64 of RNG output covers 64 dimensions per seed).
///
/// NOTE(review): the i16 accumulator limits `num_seed` to <= 32767 —
/// `-(num_seed as i16)` itself wraps beyond that; confirm callers bound the
/// set size.
///
/// # Safety
/// Caller must ensure the CPU supports AVX2.
#[target_feature(enable = "avx2")]
pub unsafe fn encode_hash_hd_avx2(kmer_hash_set: &HashSet<u64>, sketch: &FileSketch) -> Vec<i16> {
    let hv_d = sketch.hv_d;
    // Per-16-bit-lane constant 1, used to isolate a single bit per lane.
    let _mm256_const_one_epi16 = _mm256_set1_epi16(1);
    let _mm256_const_zero = _mm256_setzero_si256();
    // Byte shuffle that interleaves the 16-bit words of the four seeds'
    // random u64s so lanes that must be summed end up adjacent for hadd.
    let shuffle_mask = _mm256_set_epi8(
        15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0, 15, 14, 7, 6, 13, 12, 5, 4, 11, 10,
        3, 2, 9, 8, 1, 0,
    );

    // One RNG per seed in the current 4-seed batch, plus a scratch buffer
    // holding each RNG's latest 64-bit output.
    let mut rng_vec = vec![WyRng::default(); 4];
    let mut rnd_vec: Vec<u64> = vec![0; 4];

    let num_seed = kmer_hash_set.len();
    // Start at -num_seed so "+2 per set bit" yields 2*popcount - num_seed.
    let mut hv = vec![-(num_seed as i16); hv_d];

    // Round the seed count up to a multiple of 4 (the batch width).
    let num_tail = num_seed % 4;
    let num_seed_round_4 = num_seed + (if num_tail == 0 { 0 } else { 4 - num_tail });
    let num_batch_round_4 = num_seed_round_4 / 4;
    // Each chunk covers 64 hv dimensions (one u64 of RNG output per seed).
    let num_chunk = hv_d / 64;

    let mut seed_vec = Vec::from_iter(kmer_hash_set.clone());
    // padding seed_vec with size rounded by 4
    seed_vec.resize(num_seed_round_4, 0);

    // loop through all batches with seeds<=4
    for b_i in 0..num_batch_round_4 {
        // fetch seeds and load into RNG
        for j in 0..4 {
            rng_vec[j] = WyRng::seed_from_u64(seed_vec[b_i * 4 + j]);
        }

        // SIMD-based HV encoding
        for i in 0..num_chunk {
            // load rnd into 256b buffer
            for j in 0..4 {
                rnd_vec[j] = rng_vec[j].next_u64();
            }

            // Last batch: zero the streams of the padding seeds so they
            // contribute nothing (a zero bit adds 0, not +2).
            if b_i == num_batch_round_4 - 1 && num_tail > 0 {
                for j in num_tail..4 {
                    rnd_vec[j] = 0;
                }
            }

            // HD aggregation encoding.
            // _mm256_set_epi64x places rnd_vec[0] in the highest lane; lane
            // order among the four seeds is irrelevant here because their
            // bits are summed across lanes below.
            let simd_rnd_4_shuffle = _mm256_shuffle_epi8(
                _mm256_set_epi64x(
                    rnd_vec[0] as i64,
                    rnd_vec[1] as i64,
                    rnd_vec[2] as i64,
                    rnd_vec[3] as i64,
                ),
                shuffle_mask,
            );

            // For each bit position k of the 16-bit words: isolate bit k in
            // every lane, sum it across the four seeds via two rounds of
            // horizontal adds (with a permute/shuffle to line lanes up),
            // double it (slli by 1 => +2 per set bit), and accumulate into
            // the four hv dimensions this k maps to.
            // NOTE(review): the resulting dimension ordering is assumed to
            // match the scalar encode_hash_hd — confirm if both paths must
            // produce interchangeable sketches.
            for k in 0..16_usize {
                let shift_and_256 = _mm256_and_si256(
                    _mm256_srl_epi16(simd_rnd_4_shuffle, _mm_set1_epi64x(k as i64)),
                    _mm256_const_one_epi16,
                );

                let mut hadd_ = _mm256_hadd_epi16(shift_and_256, _mm256_const_zero);
                hadd_ = _mm256_permute4x64_epi64(hadd_, 0xD8);
                hadd_ = _mm256_shuffle_epi8(hadd_, shuffle_mask);
                hadd_ = _mm256_hadd_epi16(hadd_, _mm256_const_zero);
                hadd_ = _mm256_slli_epi16(hadd_, 1);

                hv[i * 64 + k * 4] += _mm256_extract_epi16::<0>(hadd_) as i16;
                hv[i * 64 + k * 4 + 1] += _mm256_extract_epi16::<1>(hadd_) as i16;
                hv[i * 64 + k * 4 + 2] += _mm256_extract_epi16::<2>(hadd_) as i16;
                hv[i * 64 + k * 4 + 3] += _mm256_extract_epi16::<3>(hadd_) as i16;
            }
        }
    }
    hv
}
93
94pub fn encode_hash_hd(kmer_hash_set: &HashSet<u64>, sketch: &FileSketch) -> Vec<i16> {
95    let hv_d = sketch.hv_d;
96    let seed_vec = Vec::from_iter(kmer_hash_set.clone());
97    let mut hv = vec![-(kmer_hash_set.len() as i16); hv_d];
98
99    for hash in seed_vec {
100        let mut rng = WyRng::seed_from_u64(hash);
101
102        for i in 0..(hv_d / 64) {
103            let rnd_btis = rng.next_u64();
104
105            for j in 0..64 {
106                hv[i * 64 + j] += (((rnd_btis >> j) & 1) << 1) as i16;
107            }
108        }
109    }
110
111    hv
112}
113
114#[cfg(target_arch = "x86_64")]
115#[target_feature(enable = "avx2")]
116pub unsafe fn compress_hd_sketch(sketch: &mut FileSketch, hv: &Vec<i16>) -> u8 {
117    let hv_d = sketch.hv_d;
118
119    // find the lossless quantization bit width
120    let min_hv = hv.iter().min().unwrap().clone();
121    let max_hv = hv.iter().max().unwrap().clone();
122
123    let mut quant_bit: i16 = 6;
124    loop {
125        let quant_min: i16 = -(1 << (quant_bit - 1));
126        let quant_max: i16 = (1 << (quant_bit - 1)) - 1;
127
128        if quant_min <= min_hv && quant_max >= max_hv {
129            break;
130        }
131
132        if quant_bit == 16 {
133            break;
134        }
135        quant_bit += 1;
136    }
137
138    // bit packing
139    if is_x86_feature_detected!("avx2") {
140        let offset: i16 = 1 << (quant_bit - 1);
141        let hv_u32: Vec<u32> = hv.iter().map(|&i| (i + offset) as u32).collect();
142
143        let bitpacker = BitPacker8x::new();
144        let bits_per_block = quant_bit as usize * 32;
145
146        let mut hv_compress_bits = vec![0u8; (quant_bit as usize) * (hv_d >> 3)];
147        for i in 0..(hv_d / BitPacker8x::BLOCK_LEN) {
148            bitpacker.compress(
149                &hv_u32[i * BitPacker8x::BLOCK_LEN..(i + 1) * BitPacker8x::BLOCK_LEN],
150                &mut hv_compress_bits[(bits_per_block * i)..(bits_per_block * (i + 1))],
151                quant_bit as u8,
152            );
153        }
154
155        sketch
156            .hv
157            .clone_from(&hv_compress_bits[..].align_to::<i16>().1.to_vec());
158    } else {
159        let len_bit_vec_u16 = (quant_bit as usize * hv_d + 16) / 16;
160        let mut hv_compress_bits: Vec<i16> = vec![0; len_bit_vec_u16];
161        for i in 0..(quant_bit as usize * hv_d) {
162            hv_compress_bits[i / 16] |=
163                ((hv[i / quant_bit as usize] >> (i % quant_bit as usize)) & 1) << (i % 16);
164        }
165        sketch.hv.clone_from(&hv_compress_bits);
166    }
167
168    quant_bit as u8
169}
170
171pub fn decompress_file_sketch(file_sketch: &mut Vec<FileSketch>) {
172    let hv_dim = file_sketch[0].hv_d;
173
174    info!("Decompressing sketch with HV dim={}", hv_dim);
175
176    file_sketch.into_par_iter().for_each(|sketch| {
177        let hv_decompressed = unsafe { decompress_hd_sketch(sketch) };
178        sketch.hv.clone_from(&hv_decompressed);
179    });
180}
181
182#[cfg(target_arch = "x86_64")]
183#[target_feature(enable = "avx2")]
184pub unsafe fn decompress_hd_sketch(sketch: &mut FileSketch) -> Vec<i16> {
185    let hv_d = sketch.hv_d;
186    let quant_bit = sketch.hv_quant_bits;
187
188    let mut hv_decompressed: Vec<i16> = vec![0; hv_d];
189
190    if is_x86_feature_detected!("avx2") {
191        // SIMD-based Bit Unpacking
192        let bitpacker = BitPacker8x::new();
193        let bits_per_block = quant_bit as usize * 32;
194
195        let hv_u8 = sketch.hv.align_to::<u8>().1.to_vec();
196        let mut _hv_decompressed: Vec<u32> = vec![0; hv_d];
197
198        for i in 0..(hv_d / BitPacker8x::BLOCK_LEN) {
199            bitpacker.decompress(
200                &hv_u8[(bits_per_block * i)..(bits_per_block * (i + 1))],
201                &mut _hv_decompressed[i * BitPacker8x::BLOCK_LEN..(i + 1) * BitPacker8x::BLOCK_LEN],
202                quant_bit,
203            );
204        }
205
206        let offset: i16 = 1 << (quant_bit - 1);
207        hv_decompressed.clone_from(
208            &_hv_decompressed
209                .into_iter()
210                .map(|i| (i as i16 - offset))
211                .collect(),
212        );
213    } else {
214        // Scalar Bit Unpacking
215        for i in 0..(quant_bit as usize * hv_d) {
216            hv_decompressed[i / quant_bit as usize] |=
217                (((sketch.hv[i / 16] >> (i % 16)) & 1) << (i % quant_bit as usize)) as i16;
218
219            if (i + 1) % quant_bit as usize == 0 {
220                hv_decompressed[i / quant_bit as usize] = {
221                    if hv_decompressed[i / quant_bit as usize] > (1 << (quant_bit - 1)) {
222                        hv_decompressed[i / quant_bit as usize] - (1 << quant_bit)
223                    } else {
224                        hv_decompressed[i / quant_bit as usize]
225                    }
226                };
227            }
228        }
229    }
230
231    hv_decompressed
232}