1use std::ops::RangeInclusive;
3#[cfg(all(target_arch = "x86_64", any(target_feature = "avx512f", target_feature = "avx2", target_feature = "aes")))]
4use std::arch::x86_64::*;
5
/// Hashing strategy used to map a key to a shard index.
///
/// The SIMD variants (`Avx512`, `Avx2`, `AesNi`) are only selectable when
/// the corresponding CPU target feature was compiled in; `Fnv1a` and `Xxh3`
/// are portable software fallbacks that are always available.
//
// Fieldless enum: derive `Copy`/`Eq`/`Hash` as well so callers can pass it
// by value and use it as a map key without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ShardAlgorithm {
    /// 512-bit SIMD lane-sum hash (requires the `avx512f` target feature).
    Avx512,
    /// 256-bit SIMD hash (requires the `avx2` target feature).
    Avx2,
    /// AES-round-based hash (requires the `aes` target feature).
    AesNi,
    /// Portable FNV-1a hash.
    Fnv1a,
    /// Portable XXH3 hash.
    Xxh3,
}
14
/// A key-size tier: keys whose byte length falls inside `size_range` are
/// hashed with the first algorithm in `algorithms` that the current build
/// can execute (see `FastShard::get_available_algorithm`).
#[derive(Debug, Clone)]
pub struct ShardTier {
    /// Inclusive range of key lengths (in bytes) this tier covers.
    pub size_range: RangeInclusive<usize>,
    /// Algorithms in preference order; earlier entries win when supported.
    pub algorithms: Vec<ShardAlgorithm>,
}
20
/// Configuration mapping key sizes to preferred hashing algorithms.
#[derive(Debug, Clone)]
pub struct ShardConfig {
    /// Size tiers checked in order; the first tier whose `size_range`
    /// contains the key length is used.
    pub tiers: Vec<ShardTier>,
    /// Fallback preference list used when no tier matches the key length.
    pub default_algorithms: Vec<ShardAlgorithm>,
}
26
27impl Default for ShardConfig {
28 fn default() -> Self {
29 let small_key_algorithms = vec![
30 ShardAlgorithm::Avx512,
31 ShardAlgorithm::Avx2,
32 ShardAlgorithm::AesNi,
33 ShardAlgorithm::Fnv1a,
34 ShardAlgorithm::Xxh3,
35 ];
36
37 let large_key_algorithms = vec![
38 ShardAlgorithm::Avx512,
39 ShardAlgorithm::Avx2,
40 ShardAlgorithm::AesNi,
41 ShardAlgorithm::Xxh3,
42 ShardAlgorithm::Fnv1a,
43 ];
44
45 ShardConfig {
46 tiers: vec![
47 ShardTier {
48 size_range: 0..=16,
49 algorithms: small_key_algorithms,
50 },
51 ShardTier {
52 size_range: 17..=usize::MAX,
53 algorithms: large_key_algorithms,
54 },
55 ],
56 default_algorithms: vec![ShardAlgorithm::Xxh3],
57 }
58 }
59}
60
/// Size-tiered key-to-shard mapper.
///
/// For each key, an algorithm is chosen from the configured tiers based on
/// the key's byte length and the target features this binary was compiled
/// with; the resulting hash is reduced modulo `shard_count`.
#[derive(Debug)]
pub struct FastShard {
    // Number of shards; shard() returns indices in 0..shard_count.
    shard_count: u32,
    // Tier configuration consulted on every shard() call.
    config: ShardConfig,
}
66
67impl FastShard {
68 pub fn new(shard_count: u32) -> Self {
69 Self {
70 shard_count,
71 config: ShardConfig::default(),
72 }
73 }
74
75 pub fn with_config(shard_count: u32, config: ShardConfig) -> Self {
76 Self { shard_count, config }
77 }
78
79 fn get_available_algorithm(&self, algorithms: &[ShardAlgorithm]) -> ShardAlgorithm {
80 for algo in algorithms {
81 match algo {
82 ShardAlgorithm::Avx512 => {
83 #[cfg(target_feature = "avx512f")]
84 return ShardAlgorithm::Avx512;
85 }
86 ShardAlgorithm::Avx2 => {
87 #[cfg(target_feature = "avx2")]
88 return ShardAlgorithm::Avx2;
89 }
90 ShardAlgorithm::AesNi => {
91 #[cfg(target_feature = "aes")]
92 return ShardAlgorithm::AesNi;
93 }
94 ShardAlgorithm::Fnv1a => return ShardAlgorithm::Fnv1a,
95 ShardAlgorithm::Xxh3 => return ShardAlgorithm::Xxh3,
96 }
97 }
98 ShardAlgorithm::Xxh3 }
100
101 fn get_algorithm_for_size(&self, size: usize) -> ShardAlgorithm {
102 for tier in &self.config.tiers {
103 if tier.size_range.contains(&size) {
104 return self.get_available_algorithm(&tier.algorithms);
105 }
106 }
107 self.get_available_algorithm(&self.config.default_algorithms)
108 }
109
110 pub fn shard(&self, key: &[u8]) -> u32 {
111 let algorithm = self.get_algorithm_for_size(key.len());
112 match algorithm {
113 ShardAlgorithm::Avx512 => self.shard_with_avx512(key),
114 ShardAlgorithm::Avx2 => self.shard_with_avx2(key),
115 ShardAlgorithm::AesNi => self.shard_with_aesni(key),
116 ShardAlgorithm::Fnv1a => self.shard_with_fnv1a(key),
117 ShardAlgorithm::Xxh3 => self.shard_with_xxh3(key),
118 }
119 }
120
121 #[cfg(target_feature = "avx512f")]
122 fn shard_with_avx512(&self, key: &[u8]) -> u32 {
123 unsafe {
124 if is_x86_feature_detected!("avx512f") {
125 let mut hash = 0u32;
126 for chunk in key.chunks(64) {
127 let vec = if chunk.len() == 64 {
128 _mm512_loadu_si512(chunk.as_ptr() as *const _)
129 } else {
130 let mut padded = [0u8; 64];
131 padded[..chunk.len()].copy_from_slice(chunk);
132 _mm512_loadu_si512(padded.as_ptr() as *const _)
133 };
134
135 let reduced = _mm512_reduce_add_epi32(vec);
136 hash = hash.wrapping_add(reduced as u32);
137 }
138 hash % self.shard_count
139 } else {
140 self.shard_with_xxh3(key)
141 }
142 }
143 }
144
145 #[cfg(not(target_feature = "avx512f"))]
146 fn shard_with_avx512(&self, key: &[u8]) -> u32 {
147 self.shard_with_xxh3(key)
148 }
149
150 #[cfg(target_feature = "avx2")]
151 fn shard_with_avx2(&self, key: &[u8]) -> u32 {
152 unsafe {
153 if is_x86_feature_detected!("avx2") {
154 let mut hash = 0u32;
155 for chunk in key.chunks(32) {
156 let vec = if chunk.len() == 32 {
157 _mm256_loadu_si256(chunk.as_ptr() as *const _)
158 } else {
159 let mut padded = [0u8; 32];
160 padded[..chunk.len()].copy_from_slice(chunk);
161 _mm256_loadu_si256(padded.as_ptr() as *const _)
162 };
163
164 let reduced = _mm256_extract_epi32::<0>(vec) as u32;
165 hash = hash.wrapping_add(reduced);
166 }
167 hash % self.shard_count
168 } else {
169 self.shard_with_xxh3(key)
170 }
171 }
172 }
173
174 #[cfg(not(target_feature = "avx2"))]
175 fn shard_with_avx2(&self, key: &[u8]) -> u32 {
176 self.shard_with_xxh3(key)
177 }
178
179 #[cfg(target_feature = "aes")]
180 fn shard_with_aesni(&self, key: &[u8]) -> u32 {
181 unsafe {
182 if is_x86_feature_detected!("aes") {
183 let mut hash = _mm_set1_epi32(0);
184 for chunk in key.chunks(16) {
185 let data = if chunk.len() == 16 {
186 _mm_loadu_si128(chunk.as_ptr() as *const _)
187 } else {
188 let mut padded = [0u8; 16];
189 padded[..chunk.len()].copy_from_slice(chunk);
190 _mm_loadu_si128(padded.as_ptr() as *const _)
191 };
192
193 hash = _mm_aesenc_si128(hash, data);
194 }
195 let result = _mm_extract_epi32::<0>(hash) as u32;
196 result % self.shard_count
197 } else {
198 self.shard_with_xxh3(key)
199 }
200 }
201 }
202
203 #[cfg(not(target_feature = "aes"))]
204 fn shard_with_aesni(&self, key: &[u8]) -> u32 {
205 self.shard_with_xxh3(key)
206 }
207
208 fn shard_with_fnv1a(&self, key: &[u8]) -> u32 {
209 let mut hasher = fnv::FnvHasher::default();
210 use std::hash::Hasher;
211 hasher.write(key);
212 (hasher.finish() % self.shard_count as u64) as u32
213 }
214
215 fn shard_with_xxh3(&self, key: &[u8]) -> u32 {
216 use xxhash_rust::xxh3::xxh3_64;
217 (xxh3_64(key) % self.shard_count as u64) as u32
218 }
219}
220
221
#[cfg(test)]
mod tests {
    use super::*;

    /// A custom two-tier config: small keys route to FNV-1a, larger keys
    /// to XXH3. Previously this test asserted nothing; it now checks that
    /// every result is a valid shard index and is deterministic.
    #[test]
    fn test_custom_config() {
        let config = ShardConfig {
            tiers: vec![
                ShardTier {
                    size_range: 0..=16,
                    algorithms: vec![ShardAlgorithm::Fnv1a],
                },
                ShardTier {
                    size_range: 17..=1024,
                    algorithms: vec![ShardAlgorithm::Xxh3],
                },
            ],
            default_algorithms: vec![ShardAlgorithm::Xxh3],
        };

        let shard = FastShard::with_config(16, config);

        let small_key = b"small";
        let large_key = vec![0u8; 100];

        // In range and stable across repeated calls.
        let s = shard.shard(small_key);
        assert!(s < 16);
        assert_eq!(s, shard.shard(small_key));

        let l = shard.shard(&large_key);
        assert!(l < 16);
        assert_eq!(l, shard.shard(&large_key));
    }

    /// The default config must handle a spread of key sizes spanning both
    /// default tiers, always returning an in-range shard index.
    #[test]
    fn test_default_config() {
        let shard = FastShard::new(16);

        let keys = vec![vec![0u8; 8], vec![0u8; 16], vec![0u8; 32], vec![0u8; 1024]];

        for key in keys {
            assert!(shard.shard(&key) < 16);
        }
    }
}
270