1use log::info;
2use std::arch::x86_64::*;
3use std::collections::HashSet;
4
5use crate::types::FileSketch;
6use rand::{RngCore, SeedableRng};
7use wyhash::WyRng;
8
9extern crate bitpacking;
10use bitpacking::{BitPacker, BitPacker8x};
11
12use rayon::prelude::*;
13
/// AVX2-accelerated variant of `encode_hash_hd`: encodes a set of k-mer
/// hashes into a hyperdimensional vector (HV) of `sketch.hv_d` signed
/// counters.
///
/// Every hash seeds its own `WyRng` stream; bit `d` of that stream decides a
/// +1/-1 contribution to HV dimension `d`. The vector starts at `-num_seed`
/// and each set bit adds +2, so the final value per dimension equals
/// (#set bits) - (#clear bits) over all seeds.
///
/// Seeds are processed four at a time so the per-bit extraction and summation
/// can be done across all four RNG streams with 256-bit integer ops.
///
/// # Safety
/// Must only be called on a CPU that supports AVX2 (`target_feature` is
/// enabled; no runtime detection is performed here).
#[target_feature(enable = "avx2")]
pub unsafe fn encode_hash_hd_avx2(kmer_hash_set: &HashSet<u64>, sketch: &FileSketch) -> Vec<i16> {
    let hv_d = sketch.hv_d;
    // Constants reused by the inner loops.
    let _mm256_const_one_epi16 = _mm256_set1_epi16(1);
    let _mm256_const_zero = _mm256_setzero_si256();
    // Byte shuffle used both to interleave 16-bit slices of the four RNG
    // words and to reorder partial sums between the two hadd passes below.
    let shuffle_mask = _mm256_set_epi8(
        15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0, 15, 14, 7, 6, 13, 12, 5, 4, 11, 10,
        3, 2, 9, 8, 1, 0,
    );

    // One RNG per lane (4 seeds per batch) plus the latest 64-bit word drawn
    // from each.
    let mut rng_vec = vec![WyRng::default(); 4];
    let mut rnd_vec: Vec<u64> = vec![0; 4];

    let num_seed = kmer_hash_set.len();
    // Base of -num_seed turns the "+2 per set bit" updates into net +1/-1
    // contributions per (seed, dimension) pair.
    let mut hv = vec![-(num_seed as i16); hv_d];

    // Round the seed count up to a multiple of 4; the padding lanes are
    // masked out on the last batch so they contribute nothing.
    let num_tail = num_seed % 4;
    let num_seed_round_4 = num_seed + (if num_tail == 0 { 0 } else { 4 - num_tail });
    let num_batch_round_4 = num_seed_round_4 / 4;
    // 64 HV dimensions are produced per 64-bit RNG word.
    // NOTE(review): assumes hv_d is a multiple of 64 — confirm at call sites.
    let num_chunk = hv_d / 64;

    let mut seed_vec = Vec::from_iter(kmer_hash_set.clone());
    seed_vec.resize(num_seed_round_4, 0); // pad with dummy seeds

    for b_i in 0..num_batch_round_4 {
        // Re-seed the four RNG lanes with the next four hashes.
        for j in 0..4 {
            rng_vec[j] = WyRng::seed_from_u64(seed_vec[b_i * 4 + j]);
        }

        for i in 0..num_chunk {
            for j in 0..4 {
                rnd_vec[j] = rng_vec[j].next_u64();
            }

            // Last batch: zero the words of the padding lanes so all their
            // extracted bits are 0 and add nothing to the counters.
            if b_i == num_batch_round_4 - 1 && num_tail > 0 {
                for j in num_tail..4 {
                    rnd_vec[j] = 0;
                }
            }

            // Pack the four 64-bit words into one 256-bit register and
            // byte-shuffle so that corresponding 16-bit slices of the four
            // streams sit in adjacent lanes for the horizontal adds.
            let simd_rnd_4_shuffle = _mm256_shuffle_epi8(
                _mm256_set_epi64x(
                    rnd_vec[0] as i64,
                    rnd_vec[1] as i64,
                    rnd_vec[2] as i64,
                    rnd_vec[3] as i64,
                ),
                shuffle_mask,
            );

            // For each bit position k of the 16-bit lanes: isolate bit k of
            // every lane, sum across the four streams, and fold the result
            // into four HV counters.
            for k in 0..16_usize {
                // Bit k of each 16-bit lane, as 0 or 1 per lane.
                let shift_and_256 = _mm256_and_si256(
                    _mm256_srl_epi16(simd_rnd_4_shuffle, _mm_set1_epi64x(k as i64)),
                    _mm256_const_one_epi16,
                );

                // Two hadd passes (with a permute + shuffle in between to fix
                // the cross-lane ordering) reduce the per-stream bits down to
                // per-dimension counts in the low four 16-bit lanes.
                let mut hadd_ = _mm256_hadd_epi16(shift_and_256, _mm256_const_zero);
                hadd_ = _mm256_permute4x64_epi64(hadd_, 0xD8);
                hadd_ = _mm256_shuffle_epi8(hadd_, shuffle_mask);
                hadd_ = _mm256_hadd_epi16(hadd_, _mm256_const_zero);
                // Double the count: each set bit contributes +2 on top of the
                // -num_seed base.
                hadd_ = _mm256_slli_epi16(hadd_, 1);

                hv[i * 64 + k * 4] += _mm256_extract_epi16::<0>(hadd_) as i16;
                hv[i * 64 + k * 4 + 1] += _mm256_extract_epi16::<1>(hadd_) as i16;
                hv[i * 64 + k * 4 + 2] += _mm256_extract_epi16::<2>(hadd_) as i16;
                hv[i * 64 + k * 4 + 3] += _mm256_extract_epi16::<3>(hadd_) as i16;
            }
        }
    }
    hv
}
93
94pub fn encode_hash_hd(kmer_hash_set: &HashSet<u64>, sketch: &FileSketch) -> Vec<i16> {
95 let hv_d = sketch.hv_d;
96 let seed_vec = Vec::from_iter(kmer_hash_set.clone());
97 let mut hv = vec![-(kmer_hash_set.len() as i16); hv_d];
98
99 for hash in seed_vec {
100 let mut rng = WyRng::seed_from_u64(hash);
101
102 for i in 0..(hv_d / 64) {
103 let rnd_btis = rng.next_u64();
104
105 for j in 0..64 {
106 hv[i * 64 + j] += (((rnd_btis >> j) & 1) << 1) as i16;
107 }
108 }
109 }
110
111 hv
112}
113
/// Quantizes and bit-packs a hypervector into `sketch.hv`, returning the
/// chosen number of quantization bits per element.
///
/// The smallest `quant_bit` in 6..=16 is selected such that the signed range
/// `[-2^(quant_bit-1), 2^(quant_bit-1) - 1]` covers every element of `hv`
/// (the search stops at 16 regardless, since elements are i16).
///
/// Two storage layouts exist depending on runtime CPU support:
/// - AVX2: elements are offset into unsigned range and packed with
///   `BitPacker8x` (256 values per block); the byte buffer is then
///   reinterpreted as `Vec<i16>` for storage.
/// - Fallback: the low `quant_bit` two's-complement bits of each element are
///   packed LSB-first into an i16 bit stream.
///
/// NOTE(review): the two layouts are not interchangeable; compression and
/// decompression must take the same branch of the feature check — confirm
/// this cannot diverge between the two calls (e.g. sketches moved between
/// machines).
///
/// # Safety
/// Must only be called on x86_64 with AVX2 available (`target_feature`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn compress_hd_sketch(sketch: &mut FileSketch, hv: &Vec<i16>) -> u8 {
    let hv_d = sketch.hv_d;

    // The value range of the hypervector decides the bit width needed.
    let min_hv = hv.iter().min().unwrap().clone();
    let max_hv = hv.iter().max().unwrap().clone();

    // Smallest width in 6..=16 whose signed range covers [min_hv, max_hv].
    let mut quant_bit: i16 = 6;
    loop {
        let quant_min: i16 = -(1 << (quant_bit - 1));
        let quant_max: i16 = (1 << (quant_bit - 1)) - 1;

        if quant_min <= min_hv && quant_max >= max_hv {
            break;
        }

        if quant_bit == 16 {
            break;
        }
        quant_bit += 1;
    }

    if is_x86_feature_detected!("avx2") {
        // Shift into unsigned range [0, 2^quant_bit - 1] for the bit packer.
        let offset: i16 = 1 << (quant_bit - 1);
        let hv_u32: Vec<u32> = hv.iter().map(|&i| (i + offset) as u32).collect();

        let bitpacker = BitPacker8x::new();
        // Bytes per compressed 256-value block: 256 * quant_bit / 8.
        let bits_per_block = quant_bit as usize * 32;

        // Total output: quant_bit * hv_d / 8 bytes.
        let mut hv_compress_bits = vec![0u8; (quant_bit as usize) * (hv_d >> 3)];
        for i in 0..(hv_d / BitPacker8x::BLOCK_LEN) {
            bitpacker.compress(
                &hv_u32[i * BitPacker8x::BLOCK_LEN..(i + 1) * BitPacker8x::BLOCK_LEN],
                &mut hv_compress_bits[(bits_per_block * i)..(bits_per_block * (i + 1))],
                quant_bit as u8,
            );
        }

        // Reinterpret the packed bytes as i16 for storage in sketch.hv.
        // NOTE(review): align_to may return a non-empty prefix if the u8
        // allocation happens not to be 2-byte aligned, which would silently
        // drop data here — confirm, or assert the prefix is empty.
        sketch
            .hv
            .clone_from(&hv_compress_bits[..].align_to::<i16>().1.to_vec());
    } else {
        // Scalar fallback: pack the low quant_bit two's-complement bits of
        // each element LSB-first into an i16 bit stream.
        // NOTE(review): ceiling division would be `+ 15`; `+ 16` allocates
        // one extra trailing i16 whenever the bit count is a multiple of 16 —
        // left as-is for storage-format compatibility.
        let len_bit_vec_u16 = (quant_bit as usize * hv_d + 16) / 16;
        let mut hv_compress_bits: Vec<i16> = vec![0; len_bit_vec_u16];
        for i in 0..(quant_bit as usize * hv_d) {
            // Bit i of the stream = bit (i % quant_bit) of element
            // (i / quant_bit).
            hv_compress_bits[i / 16] |=
                ((hv[i / quant_bit as usize] >> (i % quant_bit as usize)) & 1) << (i % 16);
        }
        sketch.hv.clone_from(&hv_compress_bits);
    }

    quant_bit as u8
}
170
171pub fn decompress_file_sketch(file_sketch: &mut Vec<FileSketch>) {
172 let hv_dim = file_sketch[0].hv_d;
173
174 info!("Decompressing sketch with HV dim={}", hv_dim);
175
176 file_sketch.into_par_iter().for_each(|sketch| {
177 let hv_decompressed = unsafe { decompress_hd_sketch(sketch) };
178 sketch.hv.clone_from(&hv_decompressed);
179 });
180}
181
182#[cfg(target_arch = "x86_64")]
183#[target_feature(enable = "avx2")]
184pub unsafe fn decompress_hd_sketch(sketch: &mut FileSketch) -> Vec<i16> {
185 let hv_d = sketch.hv_d;
186 let quant_bit = sketch.hv_quant_bits;
187
188 let mut hv_decompressed: Vec<i16> = vec![0; hv_d];
189
190 if is_x86_feature_detected!("avx2") {
191 let bitpacker = BitPacker8x::new();
193 let bits_per_block = quant_bit as usize * 32;
194
195 let hv_u8 = sketch.hv.align_to::<u8>().1.to_vec();
196 let mut _hv_decompressed: Vec<u32> = vec![0; hv_d];
197
198 for i in 0..(hv_d / BitPacker8x::BLOCK_LEN) {
199 bitpacker.decompress(
200 &hv_u8[(bits_per_block * i)..(bits_per_block * (i + 1))],
201 &mut _hv_decompressed[i * BitPacker8x::BLOCK_LEN..(i + 1) * BitPacker8x::BLOCK_LEN],
202 quant_bit,
203 );
204 }
205
206 let offset: i16 = 1 << (quant_bit - 1);
207 hv_decompressed.clone_from(
208 &_hv_decompressed
209 .into_iter()
210 .map(|i| (i as i16 - offset))
211 .collect(),
212 );
213 } else {
214 for i in 0..(quant_bit as usize * hv_d) {
216 hv_decompressed[i / quant_bit as usize] |=
217 (((sketch.hv[i / 16] >> (i % 16)) & 1) << (i % quant_bit as usize)) as i16;
218
219 if (i + 1) % quant_bit as usize == 0 {
220 hv_decompressed[i / quant_bit as usize] = {
221 if hv_decompressed[i / quant_bit as usize] > (1 << (quant_bit - 1)) {
222 hv_decompressed[i / quant_bit as usize] - (1 << quant_bit)
223 } else {
224 hv_decompressed[i / quant_bit as usize]
225 }
226 };
227 }
228 }
229 }
230
231 hv_decompressed
232}