Skip to main content

flash_map/
lib.rs

1//! FlashMap — GPU-native concurrent hash map.
2//!
3//! Bulk-only API designed for maximum GPU throughput:
4//! - `bulk_get` / `bulk_insert` / `bulk_remove` — host-facing (H↔D transfers)
5//! - `bulk_get_device` / `bulk_insert_device` — device-resident (zero-copy)
6//! - `bulk_get_values_device` — values-only lookup (no found mask)
7//! - `bulk_insert_device_uncounted` — fire-and-forget insert (no readback)
8//!
9//! SoA (Struct of Arrays) memory layout on GPU for coalesced access.
10//! Robin Hood hashing with probe distance early exit.
11//! Warp-cooperative probing (32 slots per iteration).
12//!
13//! # Features
14//!
15//! - `cuda` — GPU backend via CUDA (requires NVIDIA GPU + CUDA toolkit)
16//! - `rayon` — multi-threaded CPU backend (default)
17//!
18//! # Host API Example
19//!
20//! ```rust,no_run
21//! use flash_map::{FlashMap, HashStrategy};
22//!
23//! let mut map: FlashMap<[u8; 32], [u8; 128]> =
24//!     FlashMap::with_capacity(1_000_000).unwrap();
25//!
26//! let pairs: Vec<([u8; 32], [u8; 128])> = generate_pairs();
27//! map.bulk_insert(&pairs).unwrap();
28//!
29//! let keys: Vec<[u8; 32]> = pairs.iter().map(|(k, _)| *k).collect();
30//! let results: Vec<Option<[u8; 128]>> = map.bulk_get(&keys).unwrap();
31//! # fn generate_pairs() -> Vec<([u8; 32], [u8; 128])> { vec![] }
32//! ```
33//!
34//! # Device-Resident Pipeline Example
35//!
36//! ```rust,no_run,ignore
37//! use flash_map::FlashMap;
38//!
39//! let mut map = FlashMap::<u64, [u8; 32]>::with_capacity(1_000_000).unwrap();
40//!
41//! // Upload keys once (H→D), then all operations stay on GPU
42//! let d_keys = map.upload_keys(&[42u64]).unwrap();
43//! let d_vals = map.bulk_get_values_device(&d_keys, 1).unwrap();
44//! // d_vals is on GPU — pass to your CUDA kernel, then insert results back:
45//! // map.bulk_insert_device_uncounted(&d_new_keys, &d_new_vals, n).unwrap();
46//! ```
47
48#[cfg(not(any(feature = "cuda", feature = "rayon")))]
49compile_error!(
50    "flash-map: enable at least one of 'cuda' or 'rayon' features"
51);
52
53mod error;
54mod hash;
55
56#[cfg(feature = "cuda")]
57mod gpu;
58
59#[cfg(feature = "rayon")]
60mod rayon_cpu;
61
62#[cfg(feature = "tokio")]
63mod async_map;
64
65pub use bytemuck::Pod;
66pub use error::FlashMapError;
67pub use hash::HashStrategy;
68
69#[cfg(feature = "cuda")]
70pub use cudarc::driver::CudaSlice;
71
72#[cfg(feature = "cuda")]
73pub use cudarc::driver::CudaDevice;
74
75#[cfg(feature = "tokio")]
76pub use async_map::AsyncFlashMap;
77
78use bytemuck::Pod as PodBound;
79
80// ---------------------------------------------------------------------------
81// FlashMap — public API
82// ---------------------------------------------------------------------------
83
84/// GPU-native concurrent hash map with bulk-only operations.
85///
86/// Generic over fixed-size key `K` and value `V` types that implement
87/// [`bytemuck::Pod`] (plain old data — `Copy`, fixed layout, any bit
88/// pattern valid).
89///
90/// Common type combinations:
91/// - `FlashMap<[u8; 32], [u8; 128]>` — blockchain state (pubkey → account)
92/// - `FlashMap<u64, u64>` — numeric keys and values
93/// - `FlashMap<[u8; 32], [u8; 32]>` — hash → hash mappings
94pub struct FlashMap<K: PodBound, V: PodBound> {
95    inner: FlashMapBackend<K, V>,
96}
97
98enum FlashMapBackend<K: PodBound, V: PodBound> {
99    #[cfg(feature = "cuda")]
100    Gpu(gpu::GpuFlashMap<K, V>),
101    #[cfg(feature = "rayon")]
102    Rayon(rayon_cpu::RayonFlashMap<K, V>),
103}
104
105impl<K: PodBound + Send + Sync, V: PodBound + Send + Sync> FlashMap<K, V> {
106    /// Create a FlashMap with the given capacity using default settings.
107    ///
108    /// Tries GPU first (if `cuda` feature enabled), falls back to Rayon.
109    /// Capacity is rounded up to the next power of 2.
110    pub fn with_capacity(capacity: usize) -> Result<Self, FlashMapError> {
111        FlashMapBuilder::new(capacity).build()
112    }
113
114    /// Create a builder for fine-grained configuration.
115    pub fn builder(capacity: usize) -> FlashMapBuilder {
116        FlashMapBuilder::new(capacity)
117    }
118
119    // =================================================================
120    // Host-facing API (H↔D transfers per call)
121    // =================================================================
122
123    /// Look up multiple keys in parallel. Returns `None` for missing keys.
124    pub fn bulk_get(&self, keys: &[K]) -> Result<Vec<Option<V>>, FlashMapError> {
125        match &self.inner {
126            #[cfg(feature = "cuda")]
127            FlashMapBackend::Gpu(m) => m.bulk_get(keys),
128            #[cfg(feature = "rayon")]
129            FlashMapBackend::Rayon(m) => m.bulk_get(keys),
130        }
131    }
132
133    /// Insert multiple key-value pairs in parallel.
134    ///
135    /// Returns the number of **new** insertions (updates don't count).
136    /// If a key already exists, its value is updated in place.
137    ///
138    /// # Invariant
139    ///
140    /// No duplicate keys within a single batch. If the same key appears
141    /// multiple times, behavior is undefined (one will win, but which
142    /// one is non-deterministic on GPU).
143    pub fn bulk_insert(&mut self, pairs: &[(K, V)]) -> Result<usize, FlashMapError> {
144        match &mut self.inner {
145            #[cfg(feature = "cuda")]
146            FlashMapBackend::Gpu(m) => m.bulk_insert(pairs),
147            #[cfg(feature = "rayon")]
148            FlashMapBackend::Rayon(m) => m.bulk_insert(pairs),
149        }
150    }
151
152    /// Remove multiple keys in parallel (tombstone-based).
153    ///
154    /// Returns the number of keys actually removed.
155    pub fn bulk_remove(&mut self, keys: &[K]) -> Result<usize, FlashMapError> {
156        match &mut self.inner {
157            #[cfg(feature = "cuda")]
158            FlashMapBackend::Gpu(m) => m.bulk_remove(keys),
159            #[cfg(feature = "rayon")]
160            FlashMapBackend::Rayon(m) => m.bulk_remove(keys),
161        }
162    }
163
164    /// Number of occupied entries.
165    pub fn len(&self) -> usize {
166        match &self.inner {
167            #[cfg(feature = "cuda")]
168            FlashMapBackend::Gpu(m) => m.len(),
169            #[cfg(feature = "rayon")]
170            FlashMapBackend::Rayon(m) => m.len(),
171        }
172    }
173
174    /// Whether the map is empty.
175    pub fn is_empty(&self) -> bool {
176        self.len() == 0
177    }
178
179    /// Total slot capacity (always a power of 2).
180    pub fn capacity(&self) -> usize {
181        match &self.inner {
182            #[cfg(feature = "cuda")]
183            FlashMapBackend::Gpu(m) => m.capacity(),
184            #[cfg(feature = "rayon")]
185            FlashMapBackend::Rayon(m) => m.capacity(),
186        }
187    }
188
189    /// Current load factor (0.0 to 1.0).
190    pub fn load_factor(&self) -> f64 {
191        match &self.inner {
192            #[cfg(feature = "cuda")]
193            FlashMapBackend::Gpu(m) => m.load_factor(),
194            #[cfg(feature = "rayon")]
195            FlashMapBackend::Rayon(m) => m.load_factor(),
196        }
197    }
198
199    /// Remove all entries, resetting to empty.
200    pub fn clear(&mut self) -> Result<(), FlashMapError> {
201        match &mut self.inner {
202            #[cfg(feature = "cuda")]
203            FlashMapBackend::Gpu(m) => m.clear(),
204            #[cfg(feature = "rayon")]
205            FlashMapBackend::Rayon(m) => m.clear(),
206        }
207    }
208
209    // =================================================================
210    // Device-resident API (zero-copy, GPU only)
211    // =================================================================
212
213    /// Reference to the CUDA device. Allows sharing device context
214    /// with external CUDA kernels (e.g., SHA256 hashers).
215    ///
216    /// Returns `None` if using the Rayon backend.
217    #[cfg(feature = "cuda")]
218    pub fn device(&self) -> Option<&std::sync::Arc<CudaDevice>> {
219        match &self.inner {
220            FlashMapBackend::Gpu(m) => Some(m.device()),
221            #[cfg(feature = "rayon")]
222            FlashMapBackend::Rayon(_) => None,
223        }
224    }
225
226    /// Transfer host keys to a device buffer (H→D).
227    ///
228    /// Returns a `CudaSlice<u8>` of `keys.len() * size_of::<K>()` bytes
229    /// for use with `bulk_get_device` / `bulk_insert_device`.
230    #[cfg(feature = "cuda")]
231    pub fn upload_keys(&self, keys: &[K]) -> Result<CudaSlice<u8>, FlashMapError> {
232        match &self.inner {
233            FlashMapBackend::Gpu(m) => m.upload_keys(keys),
234            #[cfg(feature = "rayon")]
235            FlashMapBackend::Rayon(_) => Err(FlashMapError::GpuRequired),
236        }
237    }
238
239    /// Transfer host values to a device buffer (H→D).
240    #[cfg(feature = "cuda")]
241    pub fn upload_values(&self, values: &[V]) -> Result<CudaSlice<u8>, FlashMapError> {
242        match &self.inner {
243            FlashMapBackend::Gpu(m) => m.upload_values(values),
244            #[cfg(feature = "rayon")]
245            FlashMapBackend::Rayon(_) => Err(FlashMapError::GpuRequired),
246        }
247    }
248
249    /// Allocate a zeroed device buffer of `n` bytes.
250    #[cfg(feature = "cuda")]
251    pub fn alloc_device(&self, n: usize) -> Result<CudaSlice<u8>, FlashMapError> {
252        match &self.inner {
253            FlashMapBackend::Gpu(m) => m.alloc_device(n),
254            #[cfg(feature = "rayon")]
255            FlashMapBackend::Rayon(_) => Err(FlashMapError::GpuRequired),
256        }
257    }
258
259    /// Download a device buffer to host memory (D→H).
260    #[cfg(feature = "cuda")]
261    pub fn download(&self, d_buf: &CudaSlice<u8>) -> Result<Vec<u8>, FlashMapError> {
262        match &self.inner {
263            FlashMapBackend::Gpu(m) => m.download(d_buf),
264            #[cfg(feature = "rayon")]
265            FlashMapBackend::Rayon(_) => Err(FlashMapError::GpuRequired),
266        }
267    }
268
269    /// Device-to-device bulk lookup. No host memory touched.
270    ///
271    /// Returns `(d_values, d_found)` — both `CudaSlice<u8>` on GPU.
272    /// `d_found` has 1 byte per query (1 = found, 0 = miss).
273    #[cfg(feature = "cuda")]
274    pub fn bulk_get_device(
275        &self,
276        d_query_keys: &CudaSlice<u8>,
277        count: usize,
278    ) -> Result<(CudaSlice<u8>, CudaSlice<u8>), FlashMapError> {
279        match &self.inner {
280            FlashMapBackend::Gpu(m) => m.bulk_get_device(d_query_keys, count),
281            #[cfg(feature = "rayon")]
282            FlashMapBackend::Rayon(_) => Err(FlashMapError::GpuRequired),
283        }
284    }
285
286    /// Device-to-device values-only lookup. No found mask allocated.
287    ///
288    /// For pipelines where all keys are guaranteed to exist.
289    /// Missing keys get zeroed values (from alloc_zeros).
290    #[cfg(feature = "cuda")]
291    pub fn bulk_get_values_device(
292        &self,
293        d_query_keys: &CudaSlice<u8>,
294        count: usize,
295    ) -> Result<CudaSlice<u8>, FlashMapError> {
296        match &self.inner {
297            FlashMapBackend::Gpu(m) => m.bulk_get_values_device(d_query_keys, count),
298            #[cfg(feature = "rayon")]
299            FlashMapBackend::Rayon(_) => Err(FlashMapError::GpuRequired),
300        }
301    }
302
303    /// Device-to-device bulk insert. Keys and values already on GPU.
304    ///
305    /// Returns the number of new insertions (4-byte D→H readback).
306    #[cfg(feature = "cuda")]
307    pub fn bulk_insert_device(
308        &mut self,
309        d_keys: &CudaSlice<u8>,
310        d_values: &CudaSlice<u8>,
311        count: usize,
312    ) -> Result<usize, FlashMapError> {
313        match &mut self.inner {
314            FlashMapBackend::Gpu(m) => m.bulk_insert_device(d_keys, d_values, count),
315            #[cfg(feature = "rayon")]
316            FlashMapBackend::Rayon(_) => Err(FlashMapError::GpuRequired),
317        }
318    }
319
320    /// Device-to-device insert without count readback. Fully async.
321    ///
322    /// No D→H sync. No load factor check (caller's responsibility).
323    /// Assumes all insertions are new. Call `recount()` if exact len needed.
324    #[cfg(feature = "cuda")]
325    pub fn bulk_insert_device_uncounted(
326        &mut self,
327        d_keys: &CudaSlice<u8>,
328        d_values: &CudaSlice<u8>,
329        count: usize,
330    ) -> Result<(), FlashMapError> {
331        match &mut self.inner {
332            FlashMapBackend::Gpu(m) => {
333                m.bulk_insert_device_uncounted(d_keys, d_values, count)
334            }
335            #[cfg(feature = "rayon")]
336            FlashMapBackend::Rayon(_) => Err(FlashMapError::GpuRequired),
337        }
338    }
339
340    /// Recount occupied entries by scanning flags on GPU.
341    ///
342    /// Corrects internal `len` after `bulk_insert_device_uncounted`.
343    #[cfg(feature = "cuda")]
344    pub fn recount(&self) -> Result<usize, FlashMapError> {
345        match &self.inner {
346            FlashMapBackend::Gpu(m) => m.recount(),
347            #[cfg(feature = "rayon")]
348            FlashMapBackend::Rayon(_) => Err(FlashMapError::GpuRequired),
349        }
350    }
351}
352
353impl<K: PodBound + Send + Sync, V: PodBound + Send + Sync> std::fmt::Debug
354    for FlashMap<K, V>
355{
356    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
357        let backend = match &self.inner {
358            #[cfg(feature = "cuda")]
359            FlashMapBackend::Gpu(_) => "GPU",
360            #[cfg(feature = "rayon")]
361            FlashMapBackend::Rayon(_) => "Rayon",
362        };
363        f.debug_struct("FlashMap")
364            .field("backend", &backend)
365            .field("len", &self.len())
366            .field("capacity", &self.capacity())
367            .field("load_factor", &format!("{:.1}%", self.load_factor() * 100.0))
368            .finish()
369    }
370}
371
372// ---------------------------------------------------------------------------
373// Builder
374// ---------------------------------------------------------------------------
375
376/// Builder for configuring a [`FlashMap`].
377pub struct FlashMapBuilder {
378    capacity: usize,
379    hash_strategy: HashStrategy,
380    device_id: usize,
381    force_rayon: bool,
382}
383
384impl FlashMapBuilder {
385    /// Create a builder targeting the given capacity.
386    pub fn new(capacity: usize) -> Self {
387        Self {
388            capacity,
389            hash_strategy: HashStrategy::Identity,
390            device_id: 0,
391            force_rayon: false,
392        }
393    }
394
395    /// Set the hash strategy (default: Identity).
396    pub fn hash_strategy(mut self, strategy: HashStrategy) -> Self {
397        self.hash_strategy = strategy;
398        self
399    }
400
401    /// Set the CUDA device ordinal (default: 0).
402    pub fn device_id(mut self, id: usize) -> Self {
403        self.device_id = id;
404        self
405    }
406
407    /// Force Rayon backend even if CUDA is available.
408    pub fn force_cpu(mut self) -> Self {
409        self.force_rayon = true;
410        self
411    }
412
413    /// Build the FlashMap. Tries GPU first, falls back to Rayon.
414    pub fn build<K: PodBound + Send + Sync, V: PodBound + Send + Sync>(
415        self,
416    ) -> Result<FlashMap<K, V>, FlashMapError> {
417        let mut _gpu_err: Option<FlashMapError> = None;
418
419        #[cfg(feature = "cuda")]
420        if !self.force_rayon {
421            match gpu::GpuFlashMap::<K, V>::new(
422                self.capacity,
423                self.hash_strategy,
424                self.device_id,
425            ) {
426                Ok(m) => return Ok(FlashMap { inner: FlashMapBackend::Gpu(m) }),
427                Err(e) => _gpu_err = Some(e),
428            }
429        }
430
431        #[cfg(feature = "rayon")]
432        {
433            if let Some(ref e) = _gpu_err {
434                eprintln!("[flash-map] GPU unavailable ({e}), using Rayon backend");
435            }
436            return Ok(FlashMap {
437                inner: FlashMapBackend::Rayon(rayon_cpu::RayonFlashMap::new(
438                    self.capacity,
439                    self.hash_strategy,
440                )),
441            });
442        }
443
444        #[allow(unreachable_code)]
445        match _gpu_err {
446            Some(e) => Err(e),
447            None => Err(FlashMapError::NoBackend),
448        }
449    }
450}