fast_shard/
lib.rs

1// File: src/lib.rs
2use std::ops::RangeInclusive;
3#[cfg(all(target_arch = "x86_64", any(target_feature = "avx512f", target_feature = "avx2", target_feature = "aes")))]
4use std::arch::x86_64::*;
5
/// Hash strategy used to map a key onto a shard.
///
/// The SIMD-backed variants (`Avx512`, `Avx2`, `AesNi`) are only selected when
/// the matching target feature is enabled at compile time; otherwise they are
/// skipped when an algorithm preference list is resolved.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ShardAlgorithm {
    /// SIMD hash over 64-byte blocks using AVX-512F instructions.
    Avx512,
    /// SIMD hash over 32-byte blocks using AVX2 instructions.
    Avx2,
    /// AES-round based hash over 16-byte blocks using AES-NI instructions.
    AesNi,
    /// 64-bit FNV-1a (via the `fnv` crate).
    Fnv1a,
    /// 64-bit XXH3 (via the `xxhash_rust` crate).
    Xxh3,
}
14
15#[derive(Debug, Clone)]
16pub struct ShardTier {
17    pub size_range: RangeInclusive<usize>,
18    pub algorithms: Vec<ShardAlgorithm>,
19}
20
21#[derive(Debug, Clone)]
22pub struct ShardConfig {
23    pub tiers: Vec<ShardTier>,
24    pub default_algorithms: Vec<ShardAlgorithm>,
25}
26
27impl Default for ShardConfig {
28    fn default() -> Self {
29        let small_key_algorithms = vec![
30            ShardAlgorithm::Avx512,
31            ShardAlgorithm::Avx2,
32            ShardAlgorithm::AesNi,
33            ShardAlgorithm::Fnv1a,
34            ShardAlgorithm::Xxh3,
35        ];
36
37        let large_key_algorithms = vec![
38            ShardAlgorithm::Avx512,
39            ShardAlgorithm::Avx2,
40            ShardAlgorithm::AesNi,
41            ShardAlgorithm::Xxh3,
42            ShardAlgorithm::Fnv1a,
43        ];
44
45        ShardConfig {
46            tiers: vec![
47                ShardTier {
48                    size_range: 0..=16,
49                    algorithms: small_key_algorithms,
50                },
51                ShardTier {
52                    size_range: 17..=usize::MAX,
53                    algorithms: large_key_algorithms,
54                },
55            ],
56            default_algorithms: vec![ShardAlgorithm::Xxh3],
57        }
58    }
59}
60
61#[derive(Debug)]
62pub struct FastShard {
63    shard_count: u32,
64    config: ShardConfig,
65}
66
67impl FastShard {
68    pub fn new(shard_count: u32) -> Self {
69        Self {
70            shard_count,
71            config: ShardConfig::default(),
72        }
73    }
74
75    pub fn with_config(shard_count: u32, config: ShardConfig) -> Self {
76        Self { shard_count, config }
77    }
78
79    fn get_available_algorithm(&self, algorithms: &[ShardAlgorithm]) -> ShardAlgorithm {
80        for algo in algorithms {
81            match algo {
82                ShardAlgorithm::Avx512 => {
83                    #[cfg(target_feature = "avx512f")]
84                    return ShardAlgorithm::Avx512;
85                }
86                ShardAlgorithm::Avx2 => {
87                    #[cfg(target_feature = "avx2")]
88                    return ShardAlgorithm::Avx2;
89                }
90                ShardAlgorithm::AesNi => {
91                    #[cfg(target_feature = "aes")]
92                    return ShardAlgorithm::AesNi;
93                }
94                ShardAlgorithm::Fnv1a => return ShardAlgorithm::Fnv1a,
95                ShardAlgorithm::Xxh3 => return ShardAlgorithm::Xxh3,
96            }
97        }
98        ShardAlgorithm::Xxh3 // Final fallback
99    }
100
101    fn get_algorithm_for_size(&self, size: usize) -> ShardAlgorithm {
102        for tier in &self.config.tiers {
103            if tier.size_range.contains(&size) {
104                return self.get_available_algorithm(&tier.algorithms);
105            }
106        }
107        self.get_available_algorithm(&self.config.default_algorithms)
108    }
109
110    pub fn shard(&self, key: &[u8]) -> u32 {
111        let algorithm = self.get_algorithm_for_size(key.len());
112        match algorithm {
113            ShardAlgorithm::Avx512 => self.shard_with_avx512(key),
114            ShardAlgorithm::Avx2 => self.shard_with_avx2(key),
115            ShardAlgorithm::AesNi => self.shard_with_aesni(key),
116            ShardAlgorithm::Fnv1a => self.shard_with_fnv1a(key),
117            ShardAlgorithm::Xxh3 => self.shard_with_xxh3(key),
118        }
119    }
120
121    #[cfg(target_feature = "avx512f")]
122    fn shard_with_avx512(&self, key: &[u8]) -> u32 {
123        unsafe {
124            if is_x86_feature_detected!("avx512f") {
125                let mut hash = 0u32;
126                for chunk in key.chunks(64) {
127                    let vec = if chunk.len() == 64 {
128                        _mm512_loadu_si512(chunk.as_ptr() as *const _)
129                    } else {
130                        let mut padded = [0u8; 64];
131                        padded[..chunk.len()].copy_from_slice(chunk);
132                        _mm512_loadu_si512(padded.as_ptr() as *const _)
133                    };
134                    
135                    let reduced = _mm512_reduce_add_epi32(vec);
136                    hash = hash.wrapping_add(reduced as u32);
137                }
138                hash % self.shard_count
139            } else {
140                self.shard_with_xxh3(key)
141            }
142        }
143    }
144
145    #[cfg(not(target_feature = "avx512f"))]
146    fn shard_with_avx512(&self, key: &[u8]) -> u32 {
147        self.shard_with_xxh3(key)
148    }
149
150    #[cfg(target_feature = "avx2")]
151    fn shard_with_avx2(&self, key: &[u8]) -> u32 {
152        unsafe {
153            if is_x86_feature_detected!("avx2") {
154                let mut hash = 0u32;
155                for chunk in key.chunks(32) {
156                    let vec = if chunk.len() == 32 {
157                        _mm256_loadu_si256(chunk.as_ptr() as *const _)
158                    } else {
159                        let mut padded = [0u8; 32];
160                        padded[..chunk.len()].copy_from_slice(chunk);
161                        _mm256_loadu_si256(padded.as_ptr() as *const _)
162                    };
163                    
164                    let reduced = _mm256_extract_epi32::<0>(vec) as u32;
165                    hash = hash.wrapping_add(reduced);
166                }
167                hash % self.shard_count
168            } else {
169                self.shard_with_xxh3(key)
170            }
171        }
172    }
173
174    #[cfg(not(target_feature = "avx2"))]
175    fn shard_with_avx2(&self, key: &[u8]) -> u32 {
176        self.shard_with_xxh3(key)
177    }
178
179    #[cfg(target_feature = "aes")]
180    fn shard_with_aesni(&self, key: &[u8]) -> u32 {
181        unsafe {
182            if is_x86_feature_detected!("aes") {
183                let mut hash = _mm_set1_epi32(0);
184                for chunk in key.chunks(16) {
185                    let data = if chunk.len() == 16 {
186                        _mm_loadu_si128(chunk.as_ptr() as *const _)
187                    } else {
188                        let mut padded = [0u8; 16];
189                        padded[..chunk.len()].copy_from_slice(chunk);
190                        _mm_loadu_si128(padded.as_ptr() as *const _)
191                    };
192                    
193                    hash = _mm_aesenc_si128(hash, data);
194                }
195                let result = _mm_extract_epi32::<0>(hash) as u32;
196                result % self.shard_count
197            } else {
198                self.shard_with_xxh3(key)
199            }
200        }
201    }
202
203    #[cfg(not(target_feature = "aes"))]
204    fn shard_with_aesni(&self, key: &[u8]) -> u32 {
205        self.shard_with_xxh3(key)
206    }
207
208    fn shard_with_fnv1a(&self, key: &[u8]) -> u32 {
209        let mut hasher = fnv::FnvHasher::default();
210        use std::hash::Hasher;
211        hasher.write(key);
212        (hasher.finish() % self.shard_count as u64) as u32
213    }
214
215    fn shard_with_xxh3(&self, key: &[u8]) -> u32 {
216        use xxhash_rust::xxh3::xxh3_64;
217        (xxh3_64(key) % self.shard_count as u64) as u32
218    }
219}
220
221
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Config with exactly one algorithm per tier so routing is observable.
    fn tiered_config() -> ShardConfig {
        ShardConfig {
            tiers: vec![
                ShardTier {
                    size_range: 0..=16,
                    algorithms: vec![ShardAlgorithm::Fnv1a],
                },
                ShardTier {
                    size_range: 17..=1024,
                    algorithms: vec![ShardAlgorithm::Xxh3],
                },
            ],
            default_algorithms: vec![ShardAlgorithm::Xxh3],
        }
    }

    #[test]
    fn test_custom_config() {
        let shard = FastShard::with_config(16, tiered_config());

        let small_key = b"small";
        let large_key = vec![0u8; 100];
        let untiered_key = vec![7u8; 2048]; // longer than any tier covers

        // Each key must be routed to the algorithm its tier (or the default) names.
        assert_eq!(shard.shard(small_key), shard.shard_with_fnv1a(small_key));
        assert_eq!(shard.shard(&large_key), shard.shard_with_xxh3(&large_key));
        assert_eq!(
            shard.shard(&untiered_key),
            shard.shard_with_xxh3(&untiered_key)
        );
    }

    #[test]
    fn test_empty_preference_list_falls_back_to_xxh3() {
        let config = ShardConfig {
            tiers: vec![ShardTier {
                size_range: 0..=usize::MAX,
                algorithms: Vec::new(),
            }],
            default_algorithms: Vec::new(),
        };
        let shard = FastShard::with_config(7, config);
        let key = b"no algorithms listed";
        assert_eq!(shard.shard(key), shard.shard_with_xxh3(key));
    }

    #[test]
    fn test_default_config() {
        let shard = FastShard::new(16);
        // Empty, small, the 16/17-byte tier boundary, medium and large keys.
        for len in [0usize, 8, 16, 17, 32, 1024] {
            let key = vec![0xA5u8; len];
            let index = shard.shard(&key);
            assert!(index < 16, "shard {} out of range for a {}-byte key", index, len);
            assert_eq!(index, shard.shard(&key), "sharding must be deterministic");
        }
    }

    #[test]
    fn test_similar_keys_spread_across_shards() {
        let shard = FastShard::new(16);
        let used: HashSet<u32> = (0..100)
            .map(|i| shard.shard(format!("user:{i}").as_bytes()))
            .collect();
        // Keys sharing a prefix must not collapse onto a handful of shards.
        assert!(used.len() >= 8, "only {} of 16 shards used", used.len());
    }

    #[test]
    fn test_single_shard_maps_everything_to_zero() {
        let shard = FastShard::new(1);
        for len in [0usize, 5, 16, 17, 200] {
            assert_eq!(shard.shard(&vec![1u8; len]), 0);
        }
    }
}
270