use flash_map::{FlashMap, FlashMapError};
use super::gpu_hasher::GpuHasher;
/// Capacity handed to `FlashMap::with_capacity` by `GpuNodeStore::new`
/// (2^20 = 1,048,576 entries).
const DEFAULT_NODE_STORE_CAPACITY: usize = 1_048_576;
/// Store of 32-byte tree-node values keyed by packed tree position.
///
/// Keys are built with `encode_node_pos` (level in the top 8 bits, node index
/// within the level in the low 56 bits). Values are either inserted from the
/// host or written by `sync_level_on_device` from `GpuHasher` output.
pub struct GpuNodeStore {
/// Backing map, driven through host bulk calls (`bulk_insert`, `bulk_get`) and
/// device-side calls (`upload_keys`, `bulk_get_device`, `bulk_insert_device`).
map: FlashMap<u64, [u8; 32]>,
}
impl GpuNodeStore {
    /// Highest level representable in the 8-bit level field of an encoded key.
    const MAX_LEVEL: i64 = 255;
    /// Largest node index representable in the 56-bit index field of an encoded key.
    const MAX_NTH: u64 = (1 << 56) - 1;

    /// Creates a store sized with `DEFAULT_NODE_STORE_CAPACITY`.
    ///
    /// # Errors
    /// Propagates any `FlashMapError` from creating the underlying map.
    pub fn new() -> Result<Self, FlashMapError> {
        Self::with_capacity(DEFAULT_NODE_STORE_CAPACITY)
    }

    /// Creates a store whose backing map is created with `capacity`.
    ///
    /// # Errors
    /// Propagates any `FlashMapError` from `FlashMap::with_capacity`.
    pub fn with_capacity(capacity: usize) -> Result<Self, FlashMapError> {
        let map = FlashMap::<u64, [u8; 32]>::with_capacity(capacity)?;
        Ok(Self { map })
    }

    /// Inserts host-side `(key, value)` pairs in one bulk call.
    ///
    /// Returns the count reported by `FlashMap::bulk_insert` (presumably the
    /// number of pairs inserted — confirm against `flash_map`).
    pub fn insert_from_host(&mut self, pairs: &[(u64, [u8; 32])]) -> Result<usize, FlashMapError> {
        self.map.bulk_insert(pairs)
    }

    /// Looks up `keys` in one bulk call and copies the results back to the host.
    ///
    /// Each entry is `None` when its key is absent (presumably in the same
    /// order as `keys`).
    pub fn get_to_host(&self, keys: &[u64]) -> Result<Vec<Option<[u8; 32]>>, FlashMapError> {
        self.map.bulk_get(keys)
    }

    /// Looks up a single key; `Ok(None)` when it is absent.
    pub fn get_single(&self, key: u64) -> Result<Option<[u8; 32]>, FlashMapError> {
        // `flatten` maps an (unexpected) empty result to `None` instead of panicking.
        Ok(self.map.bulk_get(&[key])?.into_iter().next().flatten())
    }

    /// Recomputes the parents at `level` of every node index in `n_list`.
    ///
    /// For each `i`, the parent `(level, i)` is the hash of its children
    /// `(level - 1, 2 * i)` and `(level - 1, 2 * i + 1)`, computed with
    /// `GpuHasher::batch_node_hash_device_soa` using `level - 1` as the level
    /// byte, and written back into the store.
    ///
    /// Returns the parent indices `i / 2` with *adjacent* duplicates removed;
    /// pass `n_list` sorted ascending to get each parent exactly once.
    ///
    /// # Errors
    /// Returns `Err` when `level` is outside `1..=255`, when an index is too
    /// large for its children to be encoded, or when any map transfer,
    /// lookup or insert fails. An empty `n_list` always returns `Ok(vec![])`.
    pub fn sync_level_on_device(
        &mut self,
        gpu: &GpuHasher,
        level: i64,
        n_list: &[u64],
    ) -> Result<Vec<u64>, String> {
        if n_list.is_empty() {
            return Ok(Vec::new());
        }
        if !(1..=Self::MAX_LEVEL).contains(&level) {
            return Err(format!(
                "level {level} out of range 1..={}",
                Self::MAX_LEVEL
            ));
        }
        // The range check above keeps level - 1 in 0..=254, so neither cast truncates.
        let child_level = (level - 1) as u8;
        let parent_level = level as u64;
        let n = n_list.len();

        let mut left_keys: Vec<u64> = Vec::with_capacity(n);
        let mut right_keys: Vec<u64> = Vec::with_capacity(n);
        let mut parent_keys: Vec<u64> = Vec::with_capacity(n);
        for &i in n_list {
            // The right child 2*i + 1 must still fit in the 56-bit index field;
            // otherwise `2 * i` overflows or the key aliases another level.
            if i > Self::MAX_NTH / 2 {
                return Err(format!(
                    "node index {i} at level {level} is too large to encode its children"
                ));
            }
            left_keys.push(encode_node_pos(u64::from(child_level), 2 * i));
            right_keys.push(encode_node_pos(u64::from(child_level), 2 * i + 1));
            parent_keys.push(encode_node_pos(parent_level, i));
        }

        let d_left_keys = self
            .map
            .upload_keys(&left_keys)
            .map_err(|e| format!("upload left keys: {e}"))?;
        let d_right_keys = self
            .map
            .upload_keys(&right_keys)
            .map_err(|e| format!("upload right keys: {e}"))?;
        let d_parent_keys = self
            .map
            .upload_keys(&parent_keys)
            .map_err(|e| format!("upload parent keys: {e}"))?;

        // NOTE(review): the per-key `found` flags are discarded, so a missing child
        // is hashed with whatever `bulk_get_device` leaves in its value slot —
        // confirm that is the intended empty-subtree value.
        let (d_left_vals, _d_left_found) = self
            .map
            .bulk_get_device(&d_left_keys, n)
            .map_err(|e| format!("bulk_get left: {e}"))?;
        let (d_right_vals, _d_right_found) = self
            .map
            .bulk_get_device(&d_right_keys, n)
            .map_err(|e| format!("bulk_get right: {e}"))?;

        let d_levels = gpu.fill_device_bytes(child_level, n);
        let d_results = gpu.batch_node_hash_device_soa(&d_levels, &d_left_vals, &d_right_vals, n);
        // Presumably waits for the hash kernel to finish writing `d_results`
        // before the map consumes it.
        gpu.sync();

        self.map
            .bulk_insert_device(&d_parent_keys, &d_results, n)
            .map_err(|e| format!("bulk_insert results: {e}"))?;

        let mut parents: Vec<u64> = n_list.iter().map(|&i| i / 2).collect();
        parents.dedup();
        Ok(parents)
    }

    /// Runs `sync_level_on_device` for every level in `first_level..=max_level`,
    /// then reads the root `(max_level, 0)`.
    ///
    /// Returns the index list produced by the last level, plus the root
    /// value (`[0u8; 32]` if the root key is absent). An empty `n_list` skips
    /// the level loop and only reads the root.
    ///
    /// # Errors
    /// Returns `Err` when `max_level` is outside `0..=255`, when any level
    /// sync fails, or when the root lookup fails.
    pub fn sync_upper_nodes_on_device(
        &mut self,
        gpu: &GpuHasher,
        mut n_list: Vec<u64>,
        first_level: i64,
        max_level: i64,
    ) -> Result<(Vec<u64>, [u8; 32]), String> {
        if !(0..=Self::MAX_LEVEL).contains(&max_level) {
            return Err(format!(
                "max_level {max_level} out of range 0..={}",
                Self::MAX_LEVEL
            ));
        }
        if !n_list.is_empty() {
            for level in first_level..=max_level {
                n_list = self.sync_level_on_device(gpu, level, &n_list)?;
            }
        }
        // max_level is validated non-negative above, so the cast is lossless.
        let root_pos = encode_node_pos(max_level as u64, 0);
        let root = self
            .get_single(root_pos)
            .map_err(|e| format!("get root: {e}"))?
            .unwrap_or([0u8; 32]);
        Ok((n_list, root))
    }

    /// Number of entries, as reported by the backing map.
    pub fn len(&self) -> usize {
        self.map.len()
    }

    /// Whether the backing map reports no entries.
    pub fn is_empty(&self) -> bool {
        self.map.is_empty()
    }

    /// Removes all entries via `FlashMap::clear`.
    pub fn clear(&mut self) -> Result<(), FlashMapError> {
        self.map.clear()
    }
}
/// Packs a tree position into a single `u64` map key.
///
/// The top 8 bits hold `level` and the low 56 bits hold `nth`, the node's
/// index within that level. As long as both fields are in range, distinct
/// positions yield distinct keys. Out-of-range inputs would silently alias
/// another position's key, so debug builds assert against them.
fn encode_node_pos(level: u64, nth: u64) -> u64 {
    const NTH_BITS: u32 = 56;
    debug_assert!(
        level < 1u64 << (64 - NTH_BITS),
        "level {level} does not fit in 8 bits"
    );
    debug_assert!(
        nth < 1u64 << NTH_BITS,
        "node index {nth} does not fit in 56 bits"
    );
    (level << NTH_BITS) | nth
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a default store, or logs and returns `None` when creation fails
    /// (treated as "no CUDA device available").
    fn store_or_skip(test_name: &str) -> Option<GpuNodeStore> {
        match GpuNodeStore::new() {
            Ok(store) => Some(store),
            Err(e) => {
                eprintln!("Skipping {test_name} (no CUDA): {e}");
                None
            }
        }
    }

    #[test]
    fn test_encode_node_pos() {
        let pos = encode_node_pos(13, 42);
        assert_eq!(pos >> 56, 13);
        assert_eq!((pos << 8) >> 8, 42);
    }

    #[test]
    fn test_encode_node_pos_boundaries() {
        let max_nth = (1u64 << 56) - 1;
        assert_eq!(encode_node_pos(0, 0), 0);
        let pos = encode_node_pos(255, max_nth);
        assert_eq!(pos >> 56, 255);
        assert_eq!(pos & max_nth, max_nth);
        // Same index on different levels must not collide.
        assert_ne!(encode_node_pos(1, 7), encode_node_pos(2, 7));
    }

    #[test]
    fn test_gpu_node_store_basic() {
        let mut store = match store_or_skip("GPU node store test") {
            Some(store) => store,
            None => return,
        };
        let pairs: Vec<(u64, [u8; 32])> = (0..100u64)
            .map(|i| {
                let mut val = [0u8; 32];
                val[0] = i as u8;
                (encode_node_pos(13, i), val)
            })
            .collect();
        store.insert_from_host(&pairs).unwrap();
        assert_eq!(store.len(), 100);

        let keys: Vec<u64> = pairs.iter().map(|(k, _)| *k).collect();
        let results = store.get_to_host(&keys).unwrap();
        assert_eq!(results.len(), pairs.len());
        for (i, result) in results.iter().enumerate() {
            let val = result.expect("inserted key must be found");
            assert_eq!(val[0], i as u8);
        }
    }

    #[test]
    fn test_gpu_node_store_get_single() {
        let mut store = match store_or_skip("GPU node store get_single test") {
            Some(store) => store,
            None => return,
        };
        let key = encode_node_pos(3, 5);
        let mut val = [0u8; 32];
        val[31] = 0xAB;
        store.insert_from_host(&[(key, val)]).unwrap();
        assert_eq!(store.get_single(key).unwrap(), Some(val));
        assert_eq!(store.get_single(encode_node_pos(3, 6)).unwrap(), None);
    }
}