#[doc(hidden)]
pub mod tests;
use icicle_cuda_runtime::{
device_context::{get_default_device_context, DeviceContext},
memory::HostOrDeviceSlice,
};
use crate::{error::IcicleResult, traits::FieldImpl};
/// Optimized Poseidon parameters (round constants and matrices) in the layout shared with
/// the native backend.
///
/// All fields are private. Values are produced by [`load_optimized_poseidon_constants`] or
/// [`create_optimized_poseidon_constants`], whose `impl_poseidon!` implementations have the
/// C side fill this struct through an out-pointer.
///
/// NOTE(review): `#[repr(C)]` pins the field order, but every `&'a [F]` field is a Rust fat
/// pointer (data pointer + length), which has no direct C equivalent. Confirm that the C-side
/// struct written by `*CreateOptimizedPoseidonConstants` / `*InitOptimizedPoseidonConstants`
/// has a matching layout.
#[repr(C)]
#[derive(Debug, Clone)]
pub struct PoseidonConstants<'a, F: FieldImpl> {
    /// Number of field elements absorbed per hash. When a full state is supplied
    /// (`PoseidonConfig::input_is_a_state`), each state is `arity + 1` elements wide.
    arity: u32,
    /// Number of partial rounds.
    partial_rounds: u32,
    /// Half of the full-round count (presumably full rounds are split evenly around the
    /// partial rounds — confirm against the backend).
    full_rounds_half: u32,
    // The four slices below are filled in by the backend; presumably they point at device
    // memory that stays valid for `'a` — TODO confirm ownership/lifetime on the C side.
    round_constants: &'a [F],
    mds_matrix: &'a [F],
    non_sparse_matrix: &'a [F],
    sparse_matrices: &'a [F],
    /// Domain-separation tag. NOTE(review): presumably placed in the first state element —
    /// confirm against the backend.
    domain_tag: F,
}
/// Per-call options passed by reference to the backend's Poseidon hash entry point.
///
/// The struct is `#[repr(C)]` because it crosses the FFI boundary, so field order matters.
#[repr(C)]
#[derive(Debug, Clone)]
pub struct PoseidonConfig<'a> {
    /// Device context the backend runs on.
    pub ctx: DeviceContext<'a>,
    // Private: `poseidon_hash_many` overwrites both flags from `is_on_device()` of the
    // actual input/output slices, so callers never set them by hand.
    are_inputs_on_device: bool,
    are_outputs_on_device: bool,
    /// When `true`, each input chunk is a full state of `arity + 1` elements instead of
    /// `arity` elements (this changes the length check in `poseidon_hash_many`).
    pub input_is_a_state: bool,
    /// Forwarded to the backend unchanged; semantics are defined there — TODO document.
    pub aligned: bool,
    /// Forwarded to the backend unchanged; semantics are defined there — TODO document.
    pub loop_state: bool,
    /// Forwarded to the backend; presumably makes the call return without waiting for the
    /// work on `ctx` to finish — confirm.
    pub is_async: bool,
}
impl<'a> Default for PoseidonConfig<'a> {
    /// Returns a configuration on the default device context with every flag off.
    ///
    /// The flags are: host-resident input and output, plain (non-state) input, `aligned`
    /// and `loop_state` off, and `is_async` off.
    fn default() -> Self {
        Self {
            ctx: get_default_device_context(),
            are_inputs_on_device: false,
            are_outputs_on_device: false,
            input_is_a_state: false,
            aligned: false,
            loop_state: false,
            is_async: false,
        }
    }
}
/// Backend hooks for Poseidon hashing over the field `F`.
///
/// Implemented on a field's `Config` type by the `impl_poseidon!` macro. Most callers should
/// use the free functions [`load_optimized_poseidon_constants`],
/// [`create_optimized_poseidon_constants`] and [`poseidon_hash_many`] rather than calling
/// these methods directly.
pub trait Poseidon<F: FieldImpl> {
    /// Asks the backend to build optimized constants from caller-supplied `constants`.
    ///
    /// NOTE(review): the expected layout of `constants` (presumably raw round constants
    /// followed by matrices) is defined by the backend — confirm. The slice is handed over
    /// as a mutable pointer, so the backend may modify it.
    fn create_optimized_constants<'a>(
        arity: u32,
        full_rounds_half: u32,
        partial_rounds: u32,
        constants: &mut [F],
        ctx: &DeviceContext,
    ) -> IcicleResult<PoseidonConstants<'a, F>>;
    /// Loads the backend's built-in optimized constants for `arity`.
    fn load_optimized_constants<'a>(arity: u32, ctx: &DeviceContext) -> IcicleResult<PoseidonConstants<'a, F>>;
    /// Hashes `number_of_states` inputs without validating slice lengths or the
    /// on-device flags in `config`.
    ///
    /// [`poseidon_hash_many`] performs those checks and then calls this method.
    fn poseidon_unchecked(
        input: &mut HostOrDeviceSlice<F>,
        output: &mut HostOrDeviceSlice<F>,
        number_of_states: u32,
        arity: u32,
        constants: &PoseidonConstants<F>,
        config: &PoseidonConfig,
    ) -> IcicleResult<()>;
}
/// Loads the backend's built-in optimized Poseidon constants for `arity`.
///
/// This is a thin dispatcher to [`Poseidon::load_optimized_constants`] as implemented by
/// `F`'s associated `Config` type.
///
/// # Errors
///
/// Returns whatever error the backend reports.
pub fn load_optimized_poseidon_constants<'a, F>(
    arity: u32,
    ctx: &DeviceContext,
) -> IcicleResult<PoseidonConstants<'a, F>>
where
    F: FieldImpl,
    F::Config: Poseidon<F>,
{
    <F::Config as Poseidon<F>>::load_optimized_constants(arity, ctx)
}
/// Builds optimized Poseidon constants from caller-supplied `constants`.
///
/// This dispatches to [`Poseidon::create_optimized_constants`] on `F`'s associated `Config`
/// type. Note that the argument order differs: `ctx` is second here but last in the trait
/// method.
///
/// # Errors
///
/// Returns whatever error the backend reports.
pub fn create_optimized_poseidon_constants<'a, F>(
    arity: u32,
    ctx: &DeviceContext,
    full_rounds_half: u32,
    partial_rounds: u32,
    constants: &mut [F],
) -> IcicleResult<PoseidonConstants<'a, F>>
where
    F: FieldImpl,
    F::Config: Poseidon<F>,
{
    <F::Config as Poseidon<F>>::create_optimized_constants(arity, full_rounds_half, partial_rounds, constants, ctx)
}
/// Hashes `number_of_states` independent Poseidon inputs with a single backend call.
///
/// * `input`: `number_of_states` back-to-back chunks of `arity` elements each, or of
///   `arity + 1` elements each when `config.input_is_a_state` is set.
/// * `output`: must hold at least `number_of_states` elements. It is passed mutably to the
///   backend.
/// * `constants`: obtained from [`load_optimized_poseidon_constants`] or
///   [`create_optimized_poseidon_constants`].
/// * `config`: cloned before use. The clone's private on-device flags are overwritten from
///   `input.is_on_device()` and `output.is_on_device()`, so the caller's values are ignored.
///
/// # Errors
///
/// Returns whatever error the backend reports.
///
/// # Panics
///
/// Panics if `input` or `output` is shorter than the lengths described above.
pub fn poseidon_hash_many<F>(
    input: &mut HostOrDeviceSlice<F>,
    output: &mut HostOrDeviceSlice<F>,
    number_of_states: u32,
    arity: u32,
    constants: &PoseidonConstants<F>,
    config: &PoseidonConfig,
) -> IcicleResult<()>
where
    F: FieldImpl,
    <F as FieldImpl>::Config: Poseidon<F>,
{
    // Widen to u64 before multiplying. In u32, `number_of_states * (arity + 1)` can overflow,
    // which wraps silently in release builds and would let an undersized `input` reach the
    // unchecked backend call. The largest product, (2^32 - 1) * 2^32, still fits in u64.
    let state_width = if config.input_is_a_state {
        u64::from(arity) + 1
    } else {
        u64::from(arity)
    };
    let input_len_required = u64::from(number_of_states) * state_width;
    // `usize` is at most 64 bits on every supported target, so these widenings are lossless.
    if (input.len() as u64) < input_len_required {
        panic!(
            "input len is {}; but needs to be at least {}",
            input.len(),
            input_len_required
        );
    }
    if (output.len() as u64) < u64::from(number_of_states) {
        panic!(
            "output len is {}; but needs to be at least {}",
            output.len(),
            number_of_states
        );
    }
    // The backend needs to know where each buffer lives; derive that from the slices
    // themselves instead of trusting the caller's config.
    let mut local_cfg = config.clone();
    local_cfg.are_inputs_on_device = input.is_on_device();
    local_cfg.are_outputs_on_device = output.is_on_device();
    <<F as FieldImpl>::Config as Poseidon<F>>::poseidon_unchecked(
        input,
        output,
        number_of_states,
        arity,
        constants,
        &local_cfg,
    )
}
/// Implements `Poseidon<$field>` for `$field_config` by binding the backend's C entry points.
///
/// * `$field_prefix`: string literal prepended to each C symbol name, for example
///   `concat!($field_prefix, "PoseidonHash")`.
/// * `$field_prefix_ident`: name of the private module that holds the `extern "C"`
///   declarations.
/// * `$field`, `$field_config`: the field element type and the type receiving the impl.
///
/// Scope requirements at the expansion site:
///
/// * These names must be in scope: `Poseidon`, `PoseidonConstants`, `PoseidonConfig`,
///   `DeviceContext`, `HostOrDeviceSlice`, `IcicleResult`, `MaybeUninit`, and the trait that
///   provides `.wrap()` on `CudaError`.
/// * `crate::poseidon` must re-export `$field`, `$field_config`, `CudaError`,
///   `DeviceContext`, `PoseidonConfig` and `PoseidonConstants`.
#[macro_export]
macro_rules! impl_poseidon {
    (
        $field_prefix:literal,
        $field_prefix_ident:ident,
        $field:ident,
        $field_config:ident
    ) => {
        mod $field_prefix_ident {
            use crate::poseidon::{$field, $field_config, CudaError, DeviceContext, PoseidonConfig, PoseidonConstants};
            extern "C" {
                #[link_name = concat!($field_prefix, "CreateOptimizedPoseidonConstants")]
                pub(crate) fn _create_optimized_constants(
                    arity: u32,
                    full_rounds_half: u32,
                    partial_rounds: u32,
                    constants: *mut $field,
                    ctx: &DeviceContext,
                    poseidon_constants: *mut PoseidonConstants<$field>,
                ) -> CudaError;
                #[link_name = concat!($field_prefix, "InitOptimizedPoseidonConstants")]
                pub(crate) fn _load_optimized_constants(
                    arity: u32,
                    ctx: &DeviceContext,
                    constants: *mut PoseidonConstants<$field>,
                ) -> CudaError;
                #[link_name = concat!($field_prefix, "PoseidonHash")]
                pub(crate) fn hash_many(
                    input: *mut $field,
                    output: *mut $field,
                    number_of_states: u32,
                    arity: u32,
                    constants: &PoseidonConstants<$field>,
                    config: &PoseidonConfig,
                ) -> CudaError;
            }
        }

        impl Poseidon<$field> for $field_config {
            fn create_optimized_constants<'a>(
                arity: u32,
                full_rounds_half: u32,
                partial_rounds: u32,
                constants: &mut [$field],
                ctx: &DeviceContext,
            ) -> IcicleResult<PoseidonConstants<'a, $field>> {
                let mut poseidon_constants = MaybeUninit::<PoseidonConstants<'a, $field>>::uninit();
                // SAFETY: `constants` is a live, exclusively borrowed slice, and
                // `poseidon_constants` provides a valid out-pointer for one struct.
                unsafe {
                    $field_prefix_ident::_create_optimized_constants(
                        arity,
                        full_rounds_half,
                        partial_rounds,
                        constants.as_mut_ptr(),
                        ctx,
                        poseidon_constants.as_mut_ptr(),
                    )
                    .wrap()?;
                }
                // `assume_init` must run only after success is confirmed. On failure the
                // backend may leave the struct unwritten, and materialising an uninitialised
                // value containing references is undefined behaviour. `Result::and` would
                // evaluate its argument eagerly, so the error is propagated with `?` instead.
                // SAFETY: the backend reported success, which (per its contract — confirm on
                // the C side) means it fully initialised `poseidon_constants`.
                Ok(unsafe { poseidon_constants.assume_init() })
            }

            fn load_optimized_constants<'a>(
                arity: u32,
                ctx: &DeviceContext,
            ) -> IcicleResult<PoseidonConstants<'a, $field>> {
                let mut constants = MaybeUninit::<PoseidonConstants<'a, $field>>::uninit();
                // SAFETY: `constants` provides a valid out-pointer for one struct.
                unsafe {
                    $field_prefix_ident::_load_optimized_constants(arity, ctx, constants.as_mut_ptr()).wrap()?;
                }
                // SAFETY: reached only when the backend reported success, so (per its contract
                // — confirm on the C side) it fully initialised `constants`.
                Ok(unsafe { constants.assume_init() })
            }

            fn poseidon_unchecked(
                input: &mut HostOrDeviceSlice<$field>,
                output: &mut HostOrDeviceSlice<$field>,
                number_of_states: u32,
                arity: u32,
                constants: &PoseidonConstants<$field>,
                config: &PoseidonConfig,
            ) -> IcicleResult<()> {
                // SAFETY: the pointers come from live, exclusively borrowed slices. Length
                // validation is the caller's responsibility (`poseidon_hash_many` does it).
                unsafe {
                    $field_prefix_ident::hash_many(
                        input.as_mut_ptr(),
                        output.as_mut_ptr(),
                        number_of_states,
                        arity,
                        constants,
                        config,
                    )
                    .wrap()
                }
            }
        }
    };
}
/// Generates the standard Poseidon `#[test]` functions for one field.
///
/// * `$field_ty`: field element type under test.
/// * `$bytes`, `$prefix`, `$rounds`: forwarded verbatim to `check_poseidon_custom_config`.
///
/// The expansion site must have `check_poseidon_hash_many` and
/// `check_poseidon_custom_config` in scope.
#[macro_export]
macro_rules! impl_poseidon_tests {
    (
        $field_ty:ident,
        $bytes:literal,
        $prefix:literal,
        $rounds:literal
    ) => {
        #[test]
        fn test_poseidon_custom_config() {
            check_poseidon_custom_config::<$field_ty>($bytes, $prefix, $rounds);
        }

        #[test]
        fn test_poseidon_hash_many() {
            check_poseidon_hash_many::<$field_ty>();
        }
    };
}