use std::ptr::{null, null_mut};
use boojum::field::goldilocks::GoldilocksField;
use era_cudart::result::{CudaResult, CudaResultWrap};
use era_cudart::slice::{DeviceSlice, DeviceVariable};
use era_cudart::stream::CudaStream;
use era_cudart_sys::{cudaError_t, cudaStream_t, cuda_fn_and_stub};
cuda_fn_and_stub! {
fn encode_u32(
d_temp_storage: *mut u8,
temp_storage_bytes: &mut usize,
d_in: *const u32,
d_unique_out: *mut u32,
d_counts_out: *mut u32,
d_num_runs_out: *mut u32,
num_items: i32,
stream: cudaStream_t,
) -> cudaError_t;
}
cuda_fn_and_stub! {
fn encode_u64(
d_temp_storage: *mut u8,
temp_storage_bytes: &mut usize,
d_in: *const u64,
d_unique_out: *mut u64,
d_counts_out: *mut u32,
d_num_runs_out: *mut u32,
num_items: i32,
stream: cudaStream_t,
) -> cudaError_t;
}
pub type EncodeFunction<T> = unsafe extern "C" fn(
d_temp_storage: *mut u8,
temp_storage_bytes: &mut usize,
d_in: *const T,
d_unique_out: *mut T,
d_counts_out: *mut u32,
d_num_runs_out: *mut u32,
num_items: i32,
stream: cudaStream_t,
) -> cudaError_t;
pub trait Encode: Sized {
fn get_function() -> EncodeFunction<Self>;
fn get_encode_temp_storage_bytes(num_items: i32) -> CudaResult<usize> {
let mut temp_storage_bytes = 0;
let function = Self::get_function();
unsafe {
function(
null_mut(),
&mut temp_storage_bytes,
null(),
null_mut(),
null_mut(),
null_mut(),
num_items,
null_mut(),
)
.wrap_value(temp_storage_bytes)
}
}
fn encode(
d_temp_storage: &mut DeviceSlice<u8>,
d_in: &DeviceSlice<Self>,
d_unique_out: &mut DeviceSlice<Self>,
d_counts_out: &mut DeviceSlice<u32>,
d_num_runs_out: &mut DeviceVariable<u32>,
stream: &CudaStream,
) -> CudaResult<()> {
let mut temp_storage_bytes = d_temp_storage.len();
assert_eq!(d_in.len(), d_counts_out.len());
assert_eq!(d_unique_out.len(), d_counts_out.len());
assert!(d_counts_out.len() <= i32::MAX as usize);
let num_items = d_counts_out.len() as i32;
let function = Self::get_function();
unsafe {
function(
d_temp_storage.as_mut_ptr(),
&mut temp_storage_bytes,
d_in.as_ptr(),
d_unique_out.as_mut_ptr(),
d_counts_out.as_mut_ptr(),
d_num_runs_out.as_mut_ptr(),
num_items,
stream.into(),
)
.wrap()
}
}
}
impl Encode for u32 {
fn get_function() -> EncodeFunction<Self> {
encode_u32
}
}
impl Encode for u64 {
fn get_function() -> EncodeFunction<Self> {
encode_u64
}
}
impl Encode for GoldilocksField {
fn get_function() -> EncodeFunction<Self> {
unimplemented!()
}
fn encode(
d_temp_storage: &mut DeviceSlice<u8>,
d_in: &DeviceSlice<Self>,
d_unique_out: &mut DeviceSlice<Self>,
d_counts_out: &mut DeviceSlice<u32>,
d_num_runs_out: &mut DeviceVariable<u32>,
stream: &CudaStream,
) -> CudaResult<()> {
let d_in = unsafe { d_in.transmute() };
let d_unique_out = unsafe { d_unique_out.transmute_mut() };
u64::encode(
d_temp_storage,
d_in,
d_unique_out,
d_counts_out,
d_num_runs_out,
stream,
)
}
}
pub fn get_encode_temp_storage_bytes<T: Encode>(num_items: i32) -> CudaResult<usize> {
T::get_encode_temp_storage_bytes(num_items)
}
pub fn encode<T: Encode>(
d_temp_storage: &mut DeviceSlice<u8>,
d_in: &DeviceSlice<T>,
d_unique_out: &mut DeviceSlice<T>,
d_counts_out: &mut DeviceSlice<u32>,
d_num_runs_out: &mut DeviceVariable<u32>,
stream: &CudaStream,
) -> CudaResult<()> {
T::encode(
d_temp_storage,
d_in,
d_unique_out,
d_counts_out,
d_num_runs_out,
stream,
)
}
#[cfg(test)]
mod tests {
use std::fmt::Debug;
use itertools::Itertools;
use rand::distributions::{Distribution, Standard, Uniform};
use rand::{thread_rng, Rng};
use era_cudart::memory::{memory_copy_async, DeviceAllocation};
use era_cudart::stream::CudaStream;
fn encode<T>()
where
T: super::Encode + Default + Copy + Clone + Ord + Eq + Debug + From<u32>,
Standard: Distribution<T>,
{
const NUM_ITEMS: usize = 1 << 16;
const RANGE_MAX: u32 = 2;
let temp_storage_bytes =
super::get_encode_temp_storage_bytes::<T>(NUM_ITEMS as i32).unwrap();
let mut d_temp_storage = DeviceAllocation::alloc(temp_storage_bytes).unwrap();
let h_in = thread_rng()
.sample_iter(Uniform::from(0..RANGE_MAX))
.map(|x| x.into())
.take(NUM_ITEMS)
.collect_vec();
let mut h_unique_out = vec![T::default(); NUM_ITEMS];
let mut h_counts_out = vec![0u32; NUM_ITEMS];
let mut h_num_runs_out = [0u32];
let mut d_in = DeviceAllocation::alloc(NUM_ITEMS).unwrap();
let mut d_unique_out = DeviceAllocation::alloc(NUM_ITEMS).unwrap();
let mut d_counts_out = DeviceAllocation::alloc(NUM_ITEMS).unwrap();
let mut d_num_runs_out = DeviceAllocation::alloc(1).unwrap();
let stream = CudaStream::default();
memory_copy_async(&mut d_in, &h_in, &stream).unwrap();
super::encode(
&mut d_temp_storage,
&d_in,
&mut d_unique_out,
&mut d_counts_out,
&mut d_num_runs_out[0],
&stream,
)
.unwrap();
memory_copy_async(&mut h_unique_out, &d_unique_out, &stream).unwrap();
memory_copy_async(&mut h_counts_out, &d_counts_out, &stream).unwrap();
memory_copy_async(&mut h_num_runs_out, &d_num_runs_out, &stream).unwrap();
stream.synchronize().unwrap();
let mut run_index = 0usize;
let mut current_value = h_in[0];
let mut current_count = 0;
for x in h_in {
if x == current_value {
current_count += 1;
} else {
assert_eq!(current_value, h_unique_out[run_index]);
assert_eq!(current_count, h_counts_out[run_index]);
run_index += 1;
current_value = x;
current_count = 1;
}
}
assert_eq!(current_value, h_unique_out[run_index]);
assert_eq!(current_count, h_counts_out[run_index]);
assert_eq!(run_index + 1, h_num_runs_out[0] as usize);
}
#[test]
fn encode_u32() {
encode::<u32>();
}
#[test]
fn encode_u64() {
encode::<u64>();
}
}