use crate::{
field::{
element::FieldElement,
fields::fft_friendly::stark_252_prime_field::Stark252PrimeField,
traits::{IsFFTField, IsField, RootsConfig},
},
gpu::cuda::field::element::CUDAFieldElement,
};
use cudarc::{
driver::{
safe::CudaSlice, safe::DeviceSlice, CudaDevice, CudaFunction, LaunchAsync, LaunchConfig,
},
nvrtc::safe::Ptx,
};
use lambdaworks_gpu::cuda::abstractions::errors::CudaError;
use std::sync::Arc;
/// Threads per CUDA warp. The twiddle and bit-reversal launches in this file use it as
/// their thread-block size.
const WARP_SIZE: usize = 32;
/// PTX for the Stark252 field kernels, embedded into the binary at compile time and loaded
/// by `CudaState::new`.
const STARK256_PTX: &str = include_str!("../../../gpu/cuda/shaders/field/stark256.ptx");
/// Handle to the CUDA device that runs the field FFT kernels.
///
/// Construction opens device 0 and loads the Stark252 PTX module, after which the kernels
/// registered in `load_library` can be looked up by name.
pub struct CudaState {
    device: Arc<CudaDevice>,
}

impl CudaState {
    /// Opens CUDA device 0 and loads the Stark252 field kernels onto it.
    ///
    /// # Errors
    /// Returns [`CudaError::DeviceNotFound`] if the device cannot be opened and
    /// [`CudaError::PtxError`] if the PTX module fails to load.
    pub fn new() -> Result<Self, CudaError> {
        let device =
            CudaDevice::new(0).map_err(|err| CudaError::DeviceNotFound(err.to_string()))?;
        let state = Self { device };
        state.load_library::<Stark252PrimeField>(STARK256_PTX)?;
        Ok(state)
    }

    /// Loads the PTX source `src` as a module named `F::field_name()`, registering the
    /// FFT-related kernels so they can later be fetched with `get_function`.
    fn load_library<F: IsFFTField>(&self, src: &'static str) -> Result<(), CudaError> {
        let mod_name: &'static str = F::field_name();
        let functions = [
            "radix2_dit_butterfly",
            "calc_twiddles",
            "calc_twiddles_bitrev",
            "bitrev_permutation",
        ];
        self.device
            .load_ptx(Ptx::from_src(src), mod_name, &functions)
            .map_err(|err| CudaError::PtxError(err.to_string()))
    }

    /// Looks up kernel `func_name` in the module previously loaded for field `F`.
    ///
    /// # Errors
    /// Returns [`CudaError::FunctionError`] if the module or kernel is not registered.
    fn get_function<F: IsFFTField>(&self, func_name: &str) -> Result<CudaFunction, CudaError> {
        let mod_name = F::field_name();
        self.device
            .get_func(mod_name, func_name)
            .ok_or_else(|| CudaError::FunctionError(func_name.to_string()))
    }

    /// Converts `data` to its device representation and uploads it into a freshly allocated
    /// device buffer using cudarc's synchronous host-to-device copy.
    ///
    /// # Errors
    /// Returns [`CudaError::AllocateMemory`] if allocation or the copy fails.
    fn alloc_buffer_with_data<F: IsField>(
        &self,
        data: &[FieldElement<F>],
    ) -> Result<CudaSlice<CUDAFieldElement<F>>, CudaError> {
        self.device
            .htod_sync_copy(&data.iter().map(CUDAFieldElement::from).collect::<Vec<_>>())
            .map_err(|err| CudaError::AllocateMemory(err.to_string()))
    }

    /// Prepares a `radix2_dit_butterfly` launch by uploading `input` and `twiddles`.
    pub(crate) fn get_radix2_dit_butterfly<F: IsFFTField>(
        &self,
        input: &[FieldElement<F>],
        twiddles: &[FieldElement<F>],
    ) -> Result<Radix2DitButterflyFunction<F>, CudaError> {
        let function = self.get_function::<F>("radix2_dit_butterfly")?;
        let input_buffer = self.alloc_buffer_with_data(input)?;
        let twiddles_buffer = self.alloc_buffer_with_data(twiddles)?;
        Ok(Radix2DitButterflyFunction::new(
            Arc::clone(&self.device),
            function,
            input_buffer,
            twiddles_buffer,
        ))
    }

    /// Prepares a twiddle-generation launch for a domain of size `2^order`.
    ///
    /// Uploads the primitive root of unity of order `order` (inverted for the `*Inversed`
    /// configs) and an output buffer of `2^order / 2` elements initialised to one. The
    /// bit-reversed configs select the `calc_twiddles_bitrev` kernel; the natural ones
    /// select `calc_twiddles`.
    ///
    /// # Errors
    /// Returns [`CudaError::FunctionError`] if no root of unity of that order exists, if the
    /// kernel is missing, or if `2^order` does not fit in `usize`; allocation failures are
    /// reported as [`CudaError::AllocateMemory`].
    pub(crate) fn get_calc_twiddles<F: IsFFTField>(
        &self,
        order: u64,
        config: RootsConfig,
    ) -> Result<CalcTwiddlesFunction<F>, CudaError> {
        let root: FieldElement<F> = F::get_primitive_root_of_unity(order).map_err(|_| {
            CudaError::FunctionError(format!(
                "Couldn't get primitive root of unity of order {}",
                order
            ))
        })?;
        let (root, function_name) = match config {
            RootsConfig::Natural => (root, "calc_twiddles"),
            RootsConfig::NaturalInversed => (root.inv(), "calc_twiddles"),
            RootsConfig::BitReverse => (root, "calc_twiddles_bitrev"),
            RootsConfig::BitReverseInversed => (root.inv(), "calc_twiddles_bitrev"),
        };
        let function = self.get_function::<F>(function_name)?;
        // Half the domain size. Computed explicitly in `usize` with a checked shift: an
        // untyped `1 << order` defaults to `i32` and overflows for `order >= 31`.
        let count = u32::try_from(order)
            .ok()
            .and_then(|shift| 1usize.checked_shl(shift))
            .map(|size| size / 2)
            .ok_or_else(|| {
                CudaError::FunctionError(format!(
                    "Twiddle count for order {} does not fit in usize",
                    order
                ))
            })?;
        let omega_buffer = self.alloc_buffer_with_data(&[root])?;
        let twiddles: Vec<FieldElement<F>> = (0..count).map(|_| FieldElement::one()).collect();
        let twiddles_buffer = self.alloc_buffer_with_data(&twiddles)?;
        Ok(CalcTwiddlesFunction::new(
            Arc::clone(&self.device),
            function,
            omega_buffer,
            twiddles_buffer,
        ))
    }

    /// Prepares a `bitrev_permutation` launch by uploading `input` and `result`; the kernel
    /// output is later read back from the `result` buffer.
    pub(crate) fn get_bitrev_permutation<F: IsFFTField>(
        &self,
        input: &[FieldElement<F>],
        result: &[FieldElement<F>],
    ) -> Result<BitrevPermutationFunction<F>, CudaError> {
        let function = self.get_function::<F>("bitrev_permutation")?;
        let input_buffer = self.alloc_buffer_with_data(input)?;
        let result_buffer = self.alloc_buffer_with_data(result)?;
        Ok(BitrevPermutationFunction::new(
            Arc::clone(&self.device),
            function,
            input_buffer,
            result_buffer,
        ))
    }
}
/// A ready-to-launch `radix2_dit_butterfly` kernel bundled with its device buffers.
pub(crate) struct Radix2DitButterflyFunction<F: IsField> {
    /// Device owning the kernel and buffers.
    device: Arc<CudaDevice>,
    /// Kernel handle.
    function: CudaFunction,
    /// Buffer handed to the kernel mutably and read back by `retrieve_result`.
    input: CudaSlice<CUDAFieldElement<F>>,
    /// Twiddle factors, handed to the kernel read-only.
    twiddles: CudaSlice<CUDAFieldElement<F>>,
}

impl<F: IsField> Radix2DitButterflyFunction<F> {
    /// Bundles the device, kernel handle and uploaded buffers.
    fn new(
        device: Arc<CudaDevice>,
        function: CudaFunction,
        input: CudaSlice<CUDAFieldElement<F>>,
        twiddles: CudaSlice<CUDAFieldElement<F>>,
    ) -> Self {
        Self {
            device,
            function,
            input,
            twiddles,
        }
    }

    /// Enqueues one butterfly stage with `block_count` blocks of `block_size` threads and no
    /// dynamic shared memory, passing `(input, twiddles, stage, butterfly_count)`.
    ///
    /// # Errors
    /// Returns [`CudaError::Launch`] if the driver rejects the launch.
    pub(crate) fn launch(
        &mut self,
        block_count: usize,
        block_size: usize,
        stage: u32,
        butterfly_count: u32,
    ) -> Result<(), CudaError> {
        let config = LaunchConfig {
            grid_dim: (block_count as u32, 1, 1),
            block_dim: (block_size as u32, 1, 1),
            shared_mem_bytes: 0,
        };
        let kernel = self.function.clone();
        let params = (&mut self.input, &self.twiddles, stage, butterfly_count);
        // SAFETY: relies on the shader's signature matching this tuple (mutable input,
        // read-only twiddles, two 32-bit integers) and on the caller choosing `stage` and
        // `butterfly_count` so the kernel stays within the buffers; neither is checked here.
        let outcome = unsafe { kernel.launch(config, params) };
        outcome.map_err(|err| CudaError::Launch(err.to_string()))
    }

    /// Copies the input buffer back to the host, consuming `self`, and converts each device
    /// element into a [`FieldElement`].
    ///
    /// # Errors
    /// Returns [`CudaError::RetrieveMemory`] if the device-to-host copy fails.
    pub(crate) fn retrieve_result(self) -> Result<Vec<FieldElement<F>>, CudaError> {
        let raw = self
            .device
            .sync_reclaim(self.input)
            .map_err(|err| CudaError::RetrieveMemory(err.to_string()))?;
        Ok(raw.into_iter().map(FieldElement::from).collect())
    }
}
/// A ready-to-launch twiddle-generation kernel (`calc_twiddles` or `calc_twiddles_bitrev`)
/// bundled with its device buffers.
pub(crate) struct CalcTwiddlesFunction<F: IsField> {
    /// Device owning the kernel and buffers.
    device: Arc<CudaDevice>,
    /// Kernel handle.
    function: CudaFunction,
    /// Single-element buffer holding the root of unity, passed read-only.
    omega: CudaSlice<CUDAFieldElement<F>>,
    /// Output buffer, passed mutably and read back by `retrieve_result`.
    twiddles: CudaSlice<CUDAFieldElement<F>>,
}

impl<F: IsField> CalcTwiddlesFunction<F> {
    /// Bundles the device, kernel handle and uploaded buffers.
    fn new(
        device: Arc<CudaDevice>,
        function: CudaFunction,
        omega: CudaSlice<CUDAFieldElement<F>>,
        twiddles: CudaSlice<CUDAFieldElement<F>>,
    ) -> Self {
        Self {
            device,
            function,
            omega,
            twiddles,
        }
    }

    /// Enqueues the kernel to produce `count` twiddles into the output buffer, using blocks
    /// of [`WARP_SIZE`] threads.
    ///
    /// A `count` of zero is a no-op, because CUDA rejects launches with an empty grid.
    ///
    /// # Errors
    /// Returns [`CudaError::Launch`] if `count` exceeds the output buffer length, does not
    /// fit in the kernel's `u32` count argument, or if the driver rejects the launch.
    pub(crate) fn launch(&mut self, count: usize) -> Result<(), CudaError> {
        let capacity = self.twiddles.len();
        // The kernel writes `count` elements; a larger count would write past the buffer.
        if count > capacity {
            return Err(CudaError::Launch(format!(
                "twiddle count {} exceeds buffer length {}",
                count, capacity
            )));
        }
        if count == 0 {
            return Ok(());
        }
        let count_arg = u32::try_from(count).map_err(|_| {
            CudaError::Launch(format!("twiddle count {} does not fit in u32", count))
        })?;

        let block_size = WARP_SIZE;
        // `count` fits in u32, so the block count does too.
        let block_count = (count + block_size - 1) / block_size;
        let config = LaunchConfig {
            grid_dim: (block_count as u32, 1, 1),
            block_dim: (block_size as u32, 1, 1),
            shared_mem_bytes: 0,
        };
        // SAFETY: `count <= self.twiddles.len()` was checked above, so the kernel is never
        // asked for more outputs than the buffer holds. Assumes the shader signature is
        // (mutable output, read-only omega, 32-bit count) — confirm against the shader.
        unsafe {
            self.function
                .clone()
                .launch(config, (&mut self.twiddles, &self.omega, count_arg))
        }
        .map_err(|err| CudaError::Launch(err.to_string()))
    }

    /// Copies the twiddle buffer back to the host, consuming `self`, and converts each
    /// device element into a [`FieldElement`].
    ///
    /// # Errors
    /// Returns [`CudaError::RetrieveMemory`] if the device-to-host copy fails.
    pub(crate) fn retrieve_result(self) -> Result<Vec<FieldElement<F>>, CudaError> {
        let Self {
            device, twiddles, ..
        } = self;
        let output = device
            .sync_reclaim(twiddles)
            .map_err(|err| CudaError::RetrieveMemory(err.to_string()))?
            .into_iter()
            .map(FieldElement::from)
            .collect();
        Ok(output)
    }
}
/// A ready-to-launch `bitrev_permutation` kernel bundled with its device buffers.
pub(crate) struct BitrevPermutationFunction<F: IsField> {
    /// Device owning the kernel and buffers.
    device: Arc<CudaDevice>,
    /// Kernel handle.
    function: CudaFunction,
    /// Elements to permute; only read by the kernel.
    input: CudaSlice<CUDAFieldElement<F>>,
    /// Destination of the permuted elements; read back by `retrieve_result`.
    result: CudaSlice<CUDAFieldElement<F>>,
}

impl<F: IsField> BitrevPermutationFunction<F> {
    /// Bundles the device, kernel handle and uploaded buffers.
    fn new(
        device: Arc<CudaDevice>,
        function: CudaFunction,
        input: CudaSlice<CUDAFieldElement<F>>,
        result: CudaSlice<CUDAFieldElement<F>>,
    ) -> Self {
        Self {
            device,
            function,
            input,
            result,
        }
    }

    /// Enqueues the permutation of `input` into `result`, with one thread per input element
    /// in blocks of [`WARP_SIZE`] threads.
    ///
    /// An empty input is a no-op, because CUDA rejects launches with an empty grid.
    ///
    /// # Errors
    /// Returns [`CudaError::Launch`] if `result` is shorter than `input`, if the length does
    /// not fit in `u32`, or if the driver rejects the launch.
    pub(crate) fn launch(&mut self) -> Result<(), CudaError> {
        let len = self.input.len();
        // One output is written per input element, so `result` must be at least as long.
        if len > self.result.len() {
            return Err(CudaError::Launch(format!(
                "bit-reversal input length {} exceeds result buffer length {}",
                len,
                self.result.len()
            )));
        }
        if len == 0 {
            return Ok(());
        }
        // Passed as u32 for consistency with the other kernel counts in this file
        // (assumes the shader declares a 32-bit length — confirm against the .cu source).
        let len_arg = u32::try_from(len).map_err(|_| {
            CudaError::Launch(format!("bit-reversal length {} does not fit in u32", len))
        })?;

        let block_size = WARP_SIZE;
        // `len` fits in u32, so the block count does too.
        let block_count = (len + block_size - 1) / block_size;
        let config = LaunchConfig {
            grid_dim: (block_count as u32, 1, 1),
            block_dim: (block_size as u32, 1, 1),
            shared_mem_bytes: 0,
        };
        // SAFETY: `result.len() >= input.len()` was checked above. `input` is only read, so
        // it is borrowed shared, while `result` is the buffer the kernel writes and is
        // borrowed mutably. Assumes the shader signature is (input, result, 32-bit len).
        unsafe {
            self.function
                .clone()
                .launch(config, (&self.input, &mut self.result, len_arg))
        }
        .map_err(|err| CudaError::Launch(err.to_string()))
    }

    /// Copies the result buffer back to the host, consuming `self`, and converts each device
    /// element into a [`FieldElement`].
    ///
    /// # Errors
    /// Returns [`CudaError::RetrieveMemory`] if the device-to-host copy fails.
    pub(crate) fn retrieve_result(self) -> Result<Vec<FieldElement<F>>, CudaError> {
        let Self { device, result, .. } = self;
        let output = device
            .sync_reclaim(result)
            .map_err(|err| CudaError::RetrieveMemory(err.to_string()))?
            .into_iter()
            .map(FieldElement::from)
            .collect();
        Ok(output)
    }
}