1#[cfg(not(any(feature = "cuda", feature = "rayon")))]
49compile_error!(
50 "flash-map: enable at least one of 'cuda' or 'rayon' features"
51);
52
53mod error;
54mod hash;
55
56#[cfg(feature = "cuda")]
57mod gpu;
58
59#[cfg(feature = "rayon")]
60mod rayon_cpu;
61
62#[cfg(feature = "tokio")]
63mod async_map;
64
65pub use bytemuck::Pod;
66pub use error::FlashMapError;
67pub use hash::HashStrategy;
68
69#[cfg(feature = "cuda")]
70pub use cudarc::driver::CudaSlice;
71
72#[cfg(feature = "cuda")]
73pub use cudarc::driver::CudaDevice;
74
75#[cfg(feature = "tokio")]
76pub use async_map::AsyncFlashMap;
77
78use bytemuck::Pod as PodBound;
79
80pub struct FlashMap<K: PodBound, V: PodBound> {
95 inner: FlashMapBackend<K, V>,
96}
97
98enum FlashMapBackend<K: PodBound, V: PodBound> {
99 #[cfg(feature = "cuda")]
100 Gpu(gpu::GpuFlashMap<K, V>),
101 #[cfg(feature = "rayon")]
102 Rayon(rayon_cpu::RayonFlashMap<K, V>),
103}
104
105impl<K: PodBound + Send + Sync, V: PodBound + Send + Sync> FlashMap<K, V> {
106 pub fn with_capacity(capacity: usize) -> Result<Self, FlashMapError> {
111 FlashMapBuilder::new(capacity).build()
112 }
113
114 pub fn builder(capacity: usize) -> FlashMapBuilder {
116 FlashMapBuilder::new(capacity)
117 }
118
119 pub fn bulk_get(&self, keys: &[K]) -> Result<Vec<Option<V>>, FlashMapError> {
125 match &self.inner {
126 #[cfg(feature = "cuda")]
127 FlashMapBackend::Gpu(m) => m.bulk_get(keys),
128 #[cfg(feature = "rayon")]
129 FlashMapBackend::Rayon(m) => m.bulk_get(keys),
130 }
131 }
132
133 pub fn bulk_insert(&mut self, pairs: &[(K, V)]) -> Result<usize, FlashMapError> {
144 match &mut self.inner {
145 #[cfg(feature = "cuda")]
146 FlashMapBackend::Gpu(m) => m.bulk_insert(pairs),
147 #[cfg(feature = "rayon")]
148 FlashMapBackend::Rayon(m) => m.bulk_insert(pairs),
149 }
150 }
151
152 pub fn bulk_remove(&mut self, keys: &[K]) -> Result<usize, FlashMapError> {
156 match &mut self.inner {
157 #[cfg(feature = "cuda")]
158 FlashMapBackend::Gpu(m) => m.bulk_remove(keys),
159 #[cfg(feature = "rayon")]
160 FlashMapBackend::Rayon(m) => m.bulk_remove(keys),
161 }
162 }
163
164 pub fn len(&self) -> usize {
166 match &self.inner {
167 #[cfg(feature = "cuda")]
168 FlashMapBackend::Gpu(m) => m.len(),
169 #[cfg(feature = "rayon")]
170 FlashMapBackend::Rayon(m) => m.len(),
171 }
172 }
173
174 pub fn is_empty(&self) -> bool {
176 self.len() == 0
177 }
178
179 pub fn capacity(&self) -> usize {
181 match &self.inner {
182 #[cfg(feature = "cuda")]
183 FlashMapBackend::Gpu(m) => m.capacity(),
184 #[cfg(feature = "rayon")]
185 FlashMapBackend::Rayon(m) => m.capacity(),
186 }
187 }
188
189 pub fn load_factor(&self) -> f64 {
191 match &self.inner {
192 #[cfg(feature = "cuda")]
193 FlashMapBackend::Gpu(m) => m.load_factor(),
194 #[cfg(feature = "rayon")]
195 FlashMapBackend::Rayon(m) => m.load_factor(),
196 }
197 }
198
199 pub fn clear(&mut self) -> Result<(), FlashMapError> {
201 match &mut self.inner {
202 #[cfg(feature = "cuda")]
203 FlashMapBackend::Gpu(m) => m.clear(),
204 #[cfg(feature = "rayon")]
205 FlashMapBackend::Rayon(m) => m.clear(),
206 }
207 }
208
209 #[cfg(feature = "cuda")]
218 pub fn device(&self) -> Option<&std::sync::Arc<CudaDevice>> {
219 match &self.inner {
220 FlashMapBackend::Gpu(m) => Some(m.device()),
221 #[cfg(feature = "rayon")]
222 FlashMapBackend::Rayon(_) => None,
223 }
224 }
225
226 #[cfg(feature = "cuda")]
231 pub fn upload_keys(&self, keys: &[K]) -> Result<CudaSlice<u8>, FlashMapError> {
232 match &self.inner {
233 FlashMapBackend::Gpu(m) => m.upload_keys(keys),
234 #[cfg(feature = "rayon")]
235 FlashMapBackend::Rayon(_) => Err(FlashMapError::GpuRequired),
236 }
237 }
238
239 #[cfg(feature = "cuda")]
241 pub fn upload_values(&self, values: &[V]) -> Result<CudaSlice<u8>, FlashMapError> {
242 match &self.inner {
243 FlashMapBackend::Gpu(m) => m.upload_values(values),
244 #[cfg(feature = "rayon")]
245 FlashMapBackend::Rayon(_) => Err(FlashMapError::GpuRequired),
246 }
247 }
248
249 #[cfg(feature = "cuda")]
251 pub fn alloc_device(&self, n: usize) -> Result<CudaSlice<u8>, FlashMapError> {
252 match &self.inner {
253 FlashMapBackend::Gpu(m) => m.alloc_device(n),
254 #[cfg(feature = "rayon")]
255 FlashMapBackend::Rayon(_) => Err(FlashMapError::GpuRequired),
256 }
257 }
258
259 #[cfg(feature = "cuda")]
261 pub fn download(&self, d_buf: &CudaSlice<u8>) -> Result<Vec<u8>, FlashMapError> {
262 match &self.inner {
263 FlashMapBackend::Gpu(m) => m.download(d_buf),
264 #[cfg(feature = "rayon")]
265 FlashMapBackend::Rayon(_) => Err(FlashMapError::GpuRequired),
266 }
267 }
268
269 #[cfg(feature = "cuda")]
274 pub fn bulk_get_device(
275 &self,
276 d_query_keys: &CudaSlice<u8>,
277 count: usize,
278 ) -> Result<(CudaSlice<u8>, CudaSlice<u8>), FlashMapError> {
279 match &self.inner {
280 FlashMapBackend::Gpu(m) => m.bulk_get_device(d_query_keys, count),
281 #[cfg(feature = "rayon")]
282 FlashMapBackend::Rayon(_) => Err(FlashMapError::GpuRequired),
283 }
284 }
285
286 #[cfg(feature = "cuda")]
291 pub fn bulk_get_values_device(
292 &self,
293 d_query_keys: &CudaSlice<u8>,
294 count: usize,
295 ) -> Result<CudaSlice<u8>, FlashMapError> {
296 match &self.inner {
297 FlashMapBackend::Gpu(m) => m.bulk_get_values_device(d_query_keys, count),
298 #[cfg(feature = "rayon")]
299 FlashMapBackend::Rayon(_) => Err(FlashMapError::GpuRequired),
300 }
301 }
302
303 #[cfg(feature = "cuda")]
307 pub fn bulk_insert_device(
308 &mut self,
309 d_keys: &CudaSlice<u8>,
310 d_values: &CudaSlice<u8>,
311 count: usize,
312 ) -> Result<usize, FlashMapError> {
313 match &mut self.inner {
314 FlashMapBackend::Gpu(m) => m.bulk_insert_device(d_keys, d_values, count),
315 #[cfg(feature = "rayon")]
316 FlashMapBackend::Rayon(_) => Err(FlashMapError::GpuRequired),
317 }
318 }
319
320 #[cfg(feature = "cuda")]
325 pub fn bulk_insert_device_uncounted(
326 &mut self,
327 d_keys: &CudaSlice<u8>,
328 d_values: &CudaSlice<u8>,
329 count: usize,
330 ) -> Result<(), FlashMapError> {
331 match &mut self.inner {
332 FlashMapBackend::Gpu(m) => {
333 m.bulk_insert_device_uncounted(d_keys, d_values, count)
334 }
335 #[cfg(feature = "rayon")]
336 FlashMapBackend::Rayon(_) => Err(FlashMapError::GpuRequired),
337 }
338 }
339
340 #[cfg(feature = "cuda")]
344 pub fn recount(&self) -> Result<usize, FlashMapError> {
345 match &self.inner {
346 FlashMapBackend::Gpu(m) => m.recount(),
347 #[cfg(feature = "rayon")]
348 FlashMapBackend::Rayon(_) => Err(FlashMapError::GpuRequired),
349 }
350 }
351}
352
353impl<K: PodBound + Send + Sync, V: PodBound + Send + Sync> std::fmt::Debug
354 for FlashMap<K, V>
355{
356 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
357 let backend = match &self.inner {
358 #[cfg(feature = "cuda")]
359 FlashMapBackend::Gpu(_) => "GPU",
360 #[cfg(feature = "rayon")]
361 FlashMapBackend::Rayon(_) => "Rayon",
362 };
363 f.debug_struct("FlashMap")
364 .field("backend", &backend)
365 .field("len", &self.len())
366 .field("capacity", &self.capacity())
367 .field("load_factor", &format!("{:.1}%", self.load_factor() * 100.0))
368 .finish()
369 }
370}
371
372pub struct FlashMapBuilder {
378 capacity: usize,
379 hash_strategy: HashStrategy,
380 device_id: usize,
381 force_rayon: bool,
382}
383
384impl FlashMapBuilder {
385 pub fn new(capacity: usize) -> Self {
387 Self {
388 capacity,
389 hash_strategy: HashStrategy::Identity,
390 device_id: 0,
391 force_rayon: false,
392 }
393 }
394
395 pub fn hash_strategy(mut self, strategy: HashStrategy) -> Self {
397 self.hash_strategy = strategy;
398 self
399 }
400
401 pub fn device_id(mut self, id: usize) -> Self {
403 self.device_id = id;
404 self
405 }
406
407 pub fn force_cpu(mut self) -> Self {
409 self.force_rayon = true;
410 self
411 }
412
413 pub fn build<K: PodBound + Send + Sync, V: PodBound + Send + Sync>(
415 self,
416 ) -> Result<FlashMap<K, V>, FlashMapError> {
417 let mut _gpu_err: Option<FlashMapError> = None;
418
419 #[cfg(feature = "cuda")]
420 if !self.force_rayon {
421 match gpu::GpuFlashMap::<K, V>::new(
422 self.capacity,
423 self.hash_strategy,
424 self.device_id,
425 ) {
426 Ok(m) => return Ok(FlashMap { inner: FlashMapBackend::Gpu(m) }),
427 Err(e) => _gpu_err = Some(e),
428 }
429 }
430
431 #[cfg(feature = "rayon")]
432 {
433 if let Some(ref e) = _gpu_err {
434 eprintln!("[flash-map] GPU unavailable ({e}), using Rayon backend");
435 }
436 return Ok(FlashMap {
437 inner: FlashMapBackend::Rayon(rayon_cpu::RayonFlashMap::new(
438 self.capacity,
439 self.hash_strategy,
440 )),
441 });
442 }
443
444 #[allow(unreachable_code)]
445 match _gpu_err {
446 Some(e) => Err(e),
447 None => Err(FlashMapError::NoBackend),
448 }
449 }
450}