use boojum::cs::implementations::utils::domain_generator_for_size;
use boojum::fft::{bitreverse_enumeration_inplace, distribute_powers};
use boojum::field::goldilocks::GoldilocksField;
use boojum::field::{Field, PrimeField};
use era_cudart::memory::{memory_copy, DeviceAllocation};
use era_cudart::result::{CudaResult, CudaResultWrap};
use era_cudart::slice::DeviceSlice;
use era_cudart_sys::{cudaMemcpyToSymbol, cuda_struct_and_stub, CudaMemoryCopyKind};
use std::mem::size_of;
use std::os::raw::c_void;
/// Log2 of the domain size (`2^OMEGA_LOG_ORDER`) the precomputed power tables
/// are built for; `inv_sizes` holds `OMEGA_LOG_ORDER + 1` entries.
pub const OMEGA_LOG_ORDER: u32 = 24;
/// Descriptor of one power table as seen by device code: a pointer to
/// `2^log_count` elements plus a precomputed index mask.
///
/// `#[repr(C)]` because instances are copied byte-for-byte into device symbols
/// via `copy_to_symbol`; the field order presumably mirrors the CUDA-side
/// struct and must not change independently of it.
#[repr(C)]
struct PowersLayerData {
    /// Device address of the table (obtained from a `DeviceAllocation` in this file).
    values: *const GoldilocksField,
    /// `2^log_count - 1`, computed in `PowersLayerData::new`.
    mask: u32,
    /// Log2 of the number of elements in the table.
    log_count: u32,
}
impl PowersLayerData {
    /// Describes a table of `2^log_count` elements starting at `values`,
    /// precomputing the low-bit index mask `2^log_count - 1`.
    fn new(values: *const GoldilocksField, log_count: u32) -> Self {
        Self {
            values,
            mask: (1u32 << log_count) - 1,
            log_count,
        }
    }
}
// SAFETY: the raw `values` pointer makes this type `!Sync` by default. Host code
// in this file never dereferences it; it is only copied as bytes into device
// memory, so sharing a reference across threads cannot cause a data race.
unsafe impl Sync for PowersLayerData {}
/// A fine/coarse pair of power tables describing one family of powers.
///
/// `#[repr(C)]` because a value of this type is copied byte-for-byte into a
/// device symbol (`powers_data_*`) via `copy_to_symbol`.
#[repr(C)]
struct PowersData {
    /// Fine table; presumably indexed by the low bits of the exponent (confirm
    /// against the device-side lookup).
    fine: PowersLayerData,
    /// Coarse table; presumably indexed by the remaining high bits.
    coarse: PowersLayerData,
}
impl PowersData {
    /// Bundles the fine and coarse table descriptors; each mask is derived from
    /// its log count by `PowersLayerData::new`.
    fn new(
        fine_values: *const GoldilocksField,
        fine_log_count: u32,
        coarse_values: *const GoldilocksField,
        coarse_log_count: u32,
    ) -> Self {
        Self {
            fine: PowersLayerData::new(fine_values, fine_log_count),
            coarse: PowersLayerData::new(coarse_values, coarse_log_count),
        }
    }
}
// SAFETY: the `cuda_struct_and_stub!` statics are of type `PowersData` and so
// presumably require `Sync`. The embedded raw pointers are never dereferenced
// on the host in this file; they are only copied as bytes to the device.
unsafe impl Sync for PowersData {}
// Host-side stubs for device symbols, declared via `cuda_struct_and_stub!` and
// written by `copy_to_symbols`. Their names presumably have to match the
// CUDA-side declarations, so rename them only together with the device code.
//
// Natural-order powers of the domain generator `w`.
cuda_struct_and_stub! { static powers_data_w: PowersData; }
// Bit-reversed powers of `w` and of `w^-1` (named for use by the NTT).
cuda_struct_and_stub! { static powers_data_w_bitrev_for_ntt: PowersData; }
cuda_struct_and_stub! { static powers_data_w_inv_bitrev_for_ntt: PowersData; }
// Powers of the multiplicative generator `g` (`_f`) and of `g^-1` (`_i`).
cuda_struct_and_stub! { static powers_data_g_f: PowersData; }
cuda_struct_and_stub! { static powers_data_g_i: PowersData; }
// Filled by applying `distribute_powers` with `1/2` to an all-ones array, so
// entry `i` is expected to be `2^-i` (the inverse of a size-`2^i` domain).
cuda_struct_and_stub! { static inv_sizes: [GoldilocksField; OMEGA_LOG_ORDER as usize + 1]; }
/// Copies the host value `src` into the device symbol whose host-side stub is
/// `symbol`, as `size_of::<T>()` raw bytes written at offset 0.
///
/// # Safety
///
/// `symbol` must be the stub of a device symbol of type `T` (as declared by
/// `cuda_struct_and_stub!`), and `T` must be plain data whose byte layout the
/// device side expects.
unsafe fn copy_to_symbol<T>(symbol: &T, src: &T) -> CudaResult<()> {
    let symbol_ptr = (symbol as *const T).cast::<c_void>();
    let src_ptr = (src as *const T).cast::<c_void>();
    let byte_count = size_of::<T>();
    let offset = 0;
    cudaMemcpyToSymbol(
        symbol_ptr,
        src_ptr,
        byte_count,
        offset,
        CudaMemoryCopyKind::HostToDevice,
    )
    .wrap()
}
/// Publishes the power-table descriptors and the `inv_sizes` table to their
/// device-side symbols.
///
/// Only table *addresses* (plus the derived log counts and masks) are copied;
/// the tables themselves must already reside in device memory.
///
/// * `powers_of_w_coarse_log_count` — log2 length of the coarse `w` tables.
///   The natural-order fine table gets `OMEGA_LOG_ORDER - coarse` as its log
///   length; the `*_bitrev_for_ntt` fine tables get one less.
/// * `powers_of_g_coarse_log_count` — log2 length of the coarse `g` tables;
///   the fine `g` tables get `OMEGA_LOG_ORDER - coarse`.
/// * `inv_sizes_host` — copied verbatim into the `inv_sizes` symbol.
///
/// # Panics
///
/// Panics, before any symbol is written, if `powers_of_w_coarse_log_count >=
/// OMEGA_LOG_ORDER` or `powers_of_g_coarse_log_count > OMEGA_LOG_ORDER`.
/// Without this check the derived fine log counts would underflow. In builds
/// without overflow checks they would wrap and upload nonsense sizes.
///
/// # Errors
///
/// Returns the first error reported by `cudaMemcpyToSymbol`; symbols written
/// before the failure keep their new values.
///
/// # Safety
///
/// Every pointer must refer to a device allocation holding as many elements as
/// its log count implies. Each must stay valid for as long as device code may
/// read these symbols.
// One argument per device table; grouping them would only move the list.
#[allow(clippy::too_many_arguments)]
unsafe fn copy_to_symbols(
    powers_of_w_coarse_log_count: u32,
    powers_of_w_fine: *const GoldilocksField,
    powers_of_w_coarse: *const GoldilocksField,
    powers_of_w_fine_bitrev_for_ntt: *const GoldilocksField,
    powers_of_w_coarse_bitrev_for_ntt: *const GoldilocksField,
    powers_of_w_inv_fine_bitrev_for_ntt: *const GoldilocksField,
    powers_of_w_inv_coarse_bitrev_for_ntt: *const GoldilocksField,
    powers_of_g_coarse_log_count: u32,
    powers_of_g_f_fine: *const GoldilocksField,
    powers_of_g_f_coarse: *const GoldilocksField,
    powers_of_g_i_fine: *const GoldilocksField,
    powers_of_g_i_coarse: *const GoldilocksField,
    inv_sizes_host: [GoldilocksField; OMEGA_LOG_ORDER as usize + 1],
) -> CudaResult<()> {
    // Validate up front so no symbol is partially updated on bad input.
    assert!(
        powers_of_w_coarse_log_count < OMEGA_LOG_ORDER,
        "powers_of_w_coarse_log_count must be less than OMEGA_LOG_ORDER"
    );
    assert!(
        powers_of_g_coarse_log_count <= OMEGA_LOG_ORDER,
        "powers_of_g_coarse_log_count must not exceed OMEGA_LOG_ORDER"
    );
    let coarse_log_count = powers_of_w_coarse_log_count;
    let fine_log_count = OMEGA_LOG_ORDER - coarse_log_count;
    copy_to_symbol(
        &powers_data_w,
        &PowersData::new(
            powers_of_w_fine,
            fine_log_count,
            powers_of_w_coarse,
            coarse_log_count,
        ),
    )?;
    // The bit-reversed NTT tables keep the same coarse length but a fine table
    // half as long (cannot underflow: fine_log_count >= 1 after the assert).
    let fine_log_count = fine_log_count - 1;
    copy_to_symbol(
        &powers_data_w_bitrev_for_ntt,
        &PowersData::new(
            powers_of_w_fine_bitrev_for_ntt,
            fine_log_count,
            powers_of_w_coarse_bitrev_for_ntt,
            coarse_log_count,
        ),
    )?;
    copy_to_symbol(
        &powers_data_w_inv_bitrev_for_ntt,
        &PowersData::new(
            powers_of_w_inv_fine_bitrev_for_ntt,
            fine_log_count,
            powers_of_w_inv_coarse_bitrev_for_ntt,
            coarse_log_count,
        ),
    )?;
    let coarse_log_count = powers_of_g_coarse_log_count;
    let fine_log_count = OMEGA_LOG_ORDER - coarse_log_count;
    copy_to_symbol(
        &powers_data_g_f,
        &PowersData::new(
            powers_of_g_f_fine,
            fine_log_count,
            powers_of_g_f_coarse,
            coarse_log_count,
        ),
    )?;
    copy_to_symbol(
        &powers_data_g_i,
        &PowersData::new(
            powers_of_g_i_fine,
            fine_log_count,
            powers_of_g_i_coarse,
            coarse_log_count,
        ),
    )?;
    copy_to_symbol(&inv_sizes, &inv_sizes_host)?;
    Ok(())
}
/// Fills `powers_dev` with successive powers of `base`, computed on the host.
/// The table is optionally permuted into bit-reversed index order, then copied
/// to the device.
///
/// `distribute_powers` is applied to an all-ones buffer, so entry `i` is
/// expected to be `base^i` before any bit reversal.
fn generate_powers_dev(
    base: GoldilocksField,
    powers_dev: &mut DeviceSlice<GoldilocksField>,
    bit_reverse: bool,
) -> CudaResult<()> {
    let mut host_powers: Vec<GoldilocksField> = std::iter::repeat(GoldilocksField::ONE)
        .take(powers_dev.len())
        .collect();
    distribute_powers(&mut host_powers, base);
    if bit_reverse {
        bitreverse_enumeration_inplace(&mut host_powers);
    }
    memory_copy(powers_dev, &host_powers)
}
/// Owns the device-resident power tables whose addresses `Context::create`
/// publishes to the device-side `powers_data_*` symbols.
///
/// The tables must outlive any device code that reads those symbols;
/// `Context::destroy` frees them.
pub struct Context {
    /// Fine factor of the natural-order powers of `w`, the generator of the
    /// size-`2^OMEGA_LOG_ORDER` domain.
    pub powers_of_w_fine: DeviceAllocation<GoldilocksField>,
    /// Coarse factor of the natural-order powers of `w`.
    pub powers_of_w_coarse: DeviceAllocation<GoldilocksField>,
    /// Fine factor of the powers of `w`, in bit-reversed order.
    pub powers_of_w_fine_bitrev_for_ntt: DeviceAllocation<GoldilocksField>,
    /// Coarse factor of the bit-reversed powers of `w`.
    pub powers_of_w_coarse_bitrev_for_ntt: DeviceAllocation<GoldilocksField>,
    /// Fine factor of the powers of `w^-1`, in bit-reversed order.
    pub powers_of_w_inv_fine_bitrev_for_ntt: DeviceAllocation<GoldilocksField>,
    /// Coarse factor of the bit-reversed powers of `w^-1`.
    pub powers_of_w_inv_coarse_bitrev_for_ntt: DeviceAllocation<GoldilocksField>,
    /// Fine factor of the powers of the multiplicative generator `g`.
    pub powers_of_g_f_fine: DeviceAllocation<GoldilocksField>,
    /// Coarse factor of the powers of `g`.
    pub powers_of_g_f_coarse: DeviceAllocation<GoldilocksField>,
    /// Fine factor of the powers of `g^-1`.
    pub powers_of_g_i_fine: DeviceAllocation<GoldilocksField>,
    /// Coarse factor of the powers of `g^-1`.
    pub powers_of_g_i_coarse: DeviceAllocation<GoldilocksField>,
}
impl Context {
pub fn create(
powers_of_w_coarse_log_count: u32,
powers_of_g_coarse_log_count: u32,
) -> CudaResult<Self> {
assert!(powers_of_w_coarse_log_count <= OMEGA_LOG_ORDER);
assert!(powers_of_g_coarse_log_count <= OMEGA_LOG_ORDER);
let length_fine = 1usize << (OMEGA_LOG_ORDER - powers_of_w_coarse_log_count);
let length_coarse = 1usize << powers_of_w_coarse_log_count;
let mut powers_of_w_fine = DeviceAllocation::<GoldilocksField>::alloc(length_fine)?;
let mut powers_of_w_coarse = DeviceAllocation::<GoldilocksField>::alloc(length_coarse)?;
generate_powers_dev(
domain_generator_for_size::<GoldilocksField>(1u64 << OMEGA_LOG_ORDER),
&mut powers_of_w_fine,
false,
)?;
generate_powers_dev(
domain_generator_for_size::<GoldilocksField>(length_coarse as u64),
&mut powers_of_w_coarse,
false,
)?;
let length_fine = 1usize << (OMEGA_LOG_ORDER - powers_of_w_coarse_log_count - 1);
let length_coarse = 1usize << powers_of_w_coarse_log_count;
let mut powers_of_w_fine_bitrev_for_ntt =
DeviceAllocation::<GoldilocksField>::alloc(length_fine)?;
let mut powers_of_w_coarse_bitrev_for_ntt =
DeviceAllocation::<GoldilocksField>::alloc(length_coarse)?;
let mut powers_of_w_inv_fine_bitrev_for_ntt =
DeviceAllocation::<GoldilocksField>::alloc(length_fine)?;
let mut powers_of_w_inv_coarse_bitrev_for_ntt =
DeviceAllocation::<GoldilocksField>::alloc(length_coarse)?;
generate_powers_dev(
domain_generator_for_size::<GoldilocksField>(1u64 << OMEGA_LOG_ORDER),
&mut powers_of_w_fine_bitrev_for_ntt,
true,
)?;
generate_powers_dev(
domain_generator_for_size::<GoldilocksField>((length_coarse * 2) as u64),
&mut powers_of_w_coarse_bitrev_for_ntt,
true,
)?;
generate_powers_dev(
domain_generator_for_size::<GoldilocksField>(1u64 << OMEGA_LOG_ORDER)
.inverse()
.expect("must exist"),
&mut powers_of_w_inv_fine_bitrev_for_ntt,
true,
)?;
generate_powers_dev(
domain_generator_for_size::<GoldilocksField>((length_coarse * 2) as u64)
.inverse()
.expect("must exist"),
&mut powers_of_w_inv_coarse_bitrev_for_ntt,
true,
)?;
let length_fine = 1usize << (OMEGA_LOG_ORDER - powers_of_g_coarse_log_count);
let length_coarse = 1usize << powers_of_g_coarse_log_count;
let mut powers_of_g_f_fine = DeviceAllocation::<GoldilocksField>::alloc(length_fine)?;
let mut powers_of_g_f_coarse = DeviceAllocation::<GoldilocksField>::alloc(length_coarse)?;
let mut powers_of_g_i_fine = DeviceAllocation::<GoldilocksField>::alloc(length_fine)?;
let mut powers_of_g_i_coarse = DeviceAllocation::<GoldilocksField>::alloc(length_coarse)?;
generate_powers_dev(
GoldilocksField::multiplicative_generator(),
&mut powers_of_g_f_fine,
false,
)?;
generate_powers_dev(
GoldilocksField::multiplicative_generator().pow_u64(length_fine as u64),
&mut powers_of_g_f_coarse,
false,
)?;
let g_inv = GoldilocksField::multiplicative_generator()
.inverse()
.expect("inv of generator must exist");
generate_powers_dev(g_inv, &mut powers_of_g_i_fine, false)?;
generate_powers_dev(
g_inv.pow_u64(length_fine as u64),
&mut powers_of_g_i_coarse,
false,
)?;
let two_inv = GoldilocksField(2).inverse().expect("must exist");
let mut inv_sizes_host = [GoldilocksField::ONE; (OMEGA_LOG_ORDER + 1) as usize];
distribute_powers(&mut inv_sizes_host, two_inv);
unsafe {
copy_to_symbols(
powers_of_w_coarse_log_count,
powers_of_w_fine.as_ptr(),
powers_of_w_coarse.as_ptr(),
powers_of_w_fine_bitrev_for_ntt.as_ptr(),
powers_of_w_coarse_bitrev_for_ntt.as_ptr(),
powers_of_w_inv_fine_bitrev_for_ntt.as_ptr(),
powers_of_w_inv_coarse_bitrev_for_ntt.as_ptr(),
powers_of_g_coarse_log_count,
powers_of_g_f_fine.as_ptr(),
powers_of_g_f_coarse.as_ptr(),
powers_of_g_i_fine.as_ptr(),
powers_of_g_i_coarse.as_ptr(),
inv_sizes_host,
)?;
}
Ok(Self {
powers_of_w_fine,
powers_of_w_coarse,
powers_of_w_fine_bitrev_for_ntt,
powers_of_w_coarse_bitrev_for_ntt,
powers_of_w_inv_fine_bitrev_for_ntt,
powers_of_w_inv_coarse_bitrev_for_ntt,
powers_of_g_f_fine,
powers_of_g_f_coarse,
powers_of_g_i_fine,
powers_of_g_i_coarse,
})
}
pub fn destroy(self) -> CudaResult<()> {
self.powers_of_w_fine.free()?;
self.powers_of_w_coarse.free()?;
self.powers_of_w_fine_bitrev_for_ntt.free()?;
self.powers_of_w_coarse_bitrev_for_ntt.free()?;
self.powers_of_w_inv_fine_bitrev_for_ntt.free()?;
self.powers_of_w_inv_coarse_bitrev_for_ntt.free()?;
self.powers_of_g_f_fine.free()?;
self.powers_of_g_f_coarse.free()?;
self.powers_of_g_i_fine.free()?;
self.powers_of_g_i_coarse.free()?;
Ok(())
}
}