use core::mem::size_of;
use metal::NSUInteger;
use winter_math::fields::f64::BaseElement;
use crate::{
HashFn,
metal::{
plan::get_planner,
utils::{buffer_mut_no_copy, buffer_no_copy, page_aligned_uninit_vector, void_ptr},
},
};
/// Modulus applied to the column count to derive the padding rule stored in
/// each initial hasher state. Presumably the sponge rate (field elements
/// absorbed per permutation) — confirm against the kernel sources.
const RATE: usize = 8;
/// Compute stage that encodes the `*_absorb_columns_and_permute_*` kernel over
/// eight separate column slices per `encode` call.
pub struct AbsorbColumnsStage256 {
/// Length every column passed to `encode` must have; also the grid width.
row_size: usize,
/// Compiled compute pipeline for the absorb-columns kernel.
pipeline: metal::ComputePipelineState,
/// Threads per threadgroup (`HASHERS_PER_THREADGROUP` x 1 x 1).
threadgroup_dim: metal::MTLSize,
/// Total threads dispatched (`row_size` x 1 x 1).
grid_dim: metal::MTLSize,
/// Host storage that `states_buffer` was created from. Held only to keep the
/// allocation alive (assumes the buffer aliases this memory — TODO confirm
/// against `buffer_mut_no_copy`).
_states: Vec<[BaseElement; 4]>,
/// Metal buffer over `_states`, bound at kernel buffer index 8.
states_buffer: metal::Buffer,
/// Host storage that `digests_buffer` was created from. Allocated
/// uninitialized; presumably filled by the kernel — verify before reading.
pub digests: Vec<[BaseElement; 4]>,
/// Metal buffer over `digests`, bound at kernel buffer index 9.
digests_buffer: metal::Buffer,
}
impl AbsorbColumnsStage256 {
    /// Threads per threadgroup used when dispatching the kernel.
    const HASHERS_PER_THREADGROUP: usize = 64;

    /// Builds the stage for columns of length `row_size`.
    ///
    /// Looks up the `{hash}_absorb_columns_and_permute_p18446744069414584321_fp`
    /// kernel in `library`, allocates the digest and state storage, and seeds
    /// every state with `[num_columns % RATE, 0, 0, 0]`.
    ///
    /// # Panics
    ///
    /// Panics if the kernel is missing from `library`, if the pipeline state
    /// cannot be created, or if a dimension does not fit in `NSUInteger`.
    pub fn new(
        library: &metal::LibraryRef,
        row_size: usize,
        num_columns: usize,
        hash: HashFn,
    ) -> Self {
        let kernel_name = format!("{}_absorb_columns_and_permute_p18446744069414584321_fp", hash);
        let func = library.get_function(&kernel_name, None).unwrap();
        let pipeline = library.device().new_compute_pipeline_state_with_function(&func).unwrap();

        let threadgroup_dim =
            metal::MTLSize::new(Self::HASHERS_PER_THREADGROUP.try_into().unwrap(), 1, 1);
        let grid_dim = metal::MTLSize::new(row_size.try_into().unwrap(), 1, 1);

        // SAFETY: the digest elements are never read in this constructor.
        // NOTE(review): assumes the kernel writes every digest before the host
        // reads `digests` — confirm against the kernel and the callers.
        let mut digests = unsafe { page_aligned_uninit_vector(row_size) };
        let digests_buffer = buffer_mut_no_copy(library.device(), &mut digests);

        // The remainder is always below `RATE`, so the cast to `u32` is lossless.
        let padding_rule = (num_columns % RATE) as u32;

        // SAFETY: every element is overwritten by `fill` immediately below,
        // before the storage is handed to the device.
        let mut _states = unsafe { page_aligned_uninit_vector(row_size) };
        _states.fill([
            BaseElement::from(padding_rule),
            BaseElement::from(0u32),
            BaseElement::from(0u32),
            BaseElement::from(0u32),
        ]);
        let states_buffer = buffer_mut_no_copy(library.device(), &mut _states);

        AbsorbColumnsStage256 {
            row_size,
            threadgroup_dim,
            pipeline,
            grid_dim,
            digests,
            digests_buffer,
            _states,
            states_buffer,
        }
    }

    /// Encodes one absorb-and-permute pass over eight columns into
    /// `command_buffer`.
    ///
    /// The columns are bound to kernel buffer indices 0..=7 in order, followed
    /// by the state buffer (8) and the digest buffer (9). `row_size` threads
    /// are dispatched, then a memory barrier on both output buffers is encoded.
    ///
    /// # Panics
    ///
    /// Panics if any column's length differs from the `row_size` given to
    /// [`AbsorbColumnsStage256::new`].
    pub fn encode(&self, command_buffer: &metal::CommandBufferRef, columns: [&[BaseElement]; 8]) {
        // Every column, including the first, is wrapped with `buffer_no_copy`
        // and indexed by all `row_size` threads, so each one must be validated.
        for column in &columns {
            assert_eq!(self.row_size, column.len());
        }

        let planner = get_planner();
        let device = planner.library.device();
        let command_encoder = command_buffer
            .compute_command_encoder_with_dispatch_type(metal::MTLDispatchType::Concurrent);
        #[cfg(debug_assertions)]
        command_encoder.set_label("absorb and permute 8 columns");

        // Threadgroup scratch: 16 field elements per hasher, doubled.
        // NOTE(review): the reason for the factor of two lives in the kernel.
        let state_width = 16;
        let field_size = size_of::<BaseElement>() as NSUInteger;
        let mem_per_hasher = state_width * field_size;
        let hashers_per_tg = Self::HASHERS_PER_THREADGROUP as NSUInteger;
        command_encoder.set_threadgroup_memory_length(0, mem_per_hasher * hashers_per_tg * 2);
        command_encoder.set_compute_pipeline_state(&self.pipeline);

        // Column `i` is bound at kernel buffer index `i`.
        for (index, &column) in columns.iter().enumerate() {
            command_encoder.set_buffer(
                index as NSUInteger,
                Some(&buffer_no_copy(device, column)),
                0,
            );
        }
        command_encoder.set_buffer(8, Some(&self.states_buffer), 0);
        command_encoder.set_buffer(9, Some(&self.digests_buffer), 0);

        command_encoder.dispatch_threads(self.grid_dim, self.threadgroup_dim);
        command_encoder.memory_barrier_with_resources(&[&self.states_buffer, &self.digests_buffer]);
        command_encoder.end_encoding();
    }
}
/// Compute stage that encodes the `*_absorb_rows_and_permute_*` kernel over a
/// slice of eight-element rows per `encode` call.
pub struct AbsorbRowsStage256 {
/// Number of rows `encode` expects; also the grid width.
row_size: usize,
/// Compiled compute pipeline for the absorb-rows kernel.
pipeline: metal::ComputePipelineState,
/// Threads per threadgroup (`HASHERS_PER_THREADGROUP` x 1 x 1).
threadgroup_dim: metal::MTLSize,
/// Total threads dispatched (`row_size` x 1 x 1).
grid_dim: metal::MTLSize,
/// Host storage that `states_buffer` was created from. Held only to keep the
/// allocation alive (assumes the buffer aliases this memory — TODO confirm
/// against `buffer_mut_no_copy`).
_states: Vec<[BaseElement; 4]>,
/// Metal buffer over `_states`, bound at kernel buffer index 1.
states_buffer: metal::Buffer,
/// Host storage that `digests_buffer` was created from. Allocated
/// uninitialized; presumably filled by the kernel — verify before reading.
pub digests: Vec<[BaseElement; 4]>,
/// Metal buffer over `digests`, bound at kernel buffer index 2.
digests_buffer: metal::Buffer,
}
impl AbsorbRowsStage256 {
    /// Threads per threadgroup used when dispatching the kernel.
    const HASHERS_PER_THREADGROUP: usize = 128;

    /// Creates the stage for `row_size` rows.
    ///
    /// Resolves the `{hash}_absorb_rows_and_permute_p18446744069414584321_fp`
    /// kernel from `library`, allocates digest and state storage, and seeds
    /// each state with `[num_columns % 8, 0, 0, 0]`.
    ///
    /// # Panics
    ///
    /// Panics if the kernel is missing from `library`, if the pipeline state
    /// cannot be created, or if a dimension does not fit in `NSUInteger`.
    pub fn new(
        library: &metal::LibraryRef,
        row_size: usize,
        num_columns: usize,
        hash: HashFn,
    ) -> Self {
        let function = library
            .get_function(
                &format!("{}_absorb_rows_and_permute_p18446744069414584321_fp", hash),
                None,
            )
            .unwrap();
        let pipeline =
            library.device().new_compute_pipeline_state_with_function(&function).unwrap();

        let threads_per_group = Self::HASHERS_PER_THREADGROUP.try_into().unwrap();
        let threadgroup_dim = metal::MTLSize::new(threads_per_group, 1, 1);
        let total_threads = row_size.try_into().unwrap();
        let grid_dim = metal::MTLSize::new(total_threads, 1, 1);

        // Digest storage is left uninitialized; nothing in this constructor
        // reads it.
        let mut digests = unsafe { page_aligned_uninit_vector(row_size) };
        let digests_buffer = buffer_mut_no_copy(library.device(), &mut digests);

        // Each state starts as the padding rule followed by three zeros.
        let zero = BaseElement::from(0u32);
        let padding_rule = BaseElement::from((num_columns % 8) as u32);
        let mut _states = unsafe { page_aligned_uninit_vector(row_size) };
        _states.fill([padding_rule, zero, zero, zero]);
        let states_buffer = buffer_mut_no_copy(library.device(), &mut _states);

        Self {
            row_size,
            pipeline,
            threadgroup_dim,
            grid_dim,
            _states,
            states_buffer,
            digests,
            digests_buffer,
        }
    }

    /// Encodes one absorb-and-permute pass over `rows` into `command_buffer`.
    ///
    /// `rows` is bound at kernel buffer index 0, the state buffer at 1, and
    /// the digest buffer at 2. `row_size` threads are dispatched, then a
    /// memory barrier on both output buffers is encoded.
    ///
    /// # Panics
    ///
    /// Panics if `rows.len()` differs from the `row_size` given to
    /// [`AbsorbRowsStage256::new`].
    pub fn encode(&self, command_buffer: &metal::CommandBufferRef, rows: &[[BaseElement; 8]]) {
        assert_eq!(self.row_size, rows.len());

        let planner = get_planner();
        let device = planner.library.device();
        let encoder = command_buffer
            .compute_command_encoder_with_dispatch_type(metal::MTLDispatchType::Concurrent);
        #[cfg(debug_assertions)]
        encoder.set_label("absorb and permute 8 column rows");

        // Threadgroup scratch: 16 field elements per hasher, doubled.
        const STATE_WIDTH: NSUInteger = 16;
        let bytes_per_hasher = STATE_WIDTH * size_of::<BaseElement>() as NSUInteger;
        let threadgroup_bytes =
            bytes_per_hasher * (Self::HASHERS_PER_THREADGROUP as NSUInteger) * 2;
        encoder.set_threadgroup_memory_length(0, threadgroup_bytes);
        encoder.set_compute_pipeline_state(&self.pipeline);

        encoder.set_buffer(0, Some(&buffer_no_copy(device, rows)), 0);
        encoder.set_buffer(1, Some(&self.states_buffer), 0);
        encoder.set_buffer(2, Some(&self.digests_buffer), 0);

        encoder.dispatch_threads(self.grid_dim, self.threadgroup_dim);
        encoder.memory_barrier_with_resources(&[&self.states_buffer, &self.digests_buffer]);
        encoder.end_encoding()
    }
}
/// Compute stage that encodes the `*_gen_merkle_nodes_first_row_*` kernel,
/// which takes a leaves buffer and a nodes buffer.
pub struct GenMerkleNodesFirstRowStage256 {
/// Compiled compute pipeline, specialized with the leaf count as function
/// constant 0.
pipeline: metal::ComputePipelineState,
/// Threads per threadgroup (`HASHERS_PER_THREADGROUP` x 1 x 1).
threadgroup_dim: metal::MTLSize,
/// Total threads dispatched (`num_leaves / 2` x 1 x 1).
grid_dim: metal::MTLSize,
}
impl GenMerkleNodesFirstRowStage256 {
    /// Threads per threadgroup used when dispatching the kernel.
    pub const HASHERS_PER_THREADGROUP: usize = 64;

    /// Builds the stage for a tree with `num_leaves` leaves.
    ///
    /// Resolves `{rpo_128|rpx_128}_gen_merkle_nodes_first_row_p18446744069414584321_fp`
    /// from `library`, passing the leaf count as `UInt` function constant 0.
    ///
    /// # Panics
    ///
    /// Panics if `num_leaves` is not a power of two, if `num_leaves / 2` is
    /// smaller than [`Self::HASHERS_PER_THREADGROUP`], if `num_leaves` does
    /// not fit in a `u32`, if the kernel is missing from `library`, or if the
    /// pipeline state cannot be created.
    pub fn new(library: &metal::LibraryRef, num_leaves: usize, hash: HashFn) -> Self {
        use metal::MTLDataType::UInt;

        assert!(num_leaves.is_power_of_two());
        assert!((num_leaves / 2) >= Self::HASHERS_PER_THREADGROUP);
        // The kernel receives the leaf count as a 32-bit constant. A plain
        // `as u32` would silently truncate (2^32 leaves would become 0), so
        // fail loudly instead.
        let leaf_count: u32 = num_leaves
            .try_into()
            .expect("num_leaves must fit in the kernel's u32 function constant");

        let kernel_call = match hash {
            HashFn::Rpo256 => "rpo_128",
            HashFn::Rpx256 => "rpx_128",
        };
        let constants = metal::FunctionConstantValues::new();
        constants.set_constant_value_at_index(void_ptr(&leaf_count), UInt, 0);
        let kernel_name =
            format!("{}_gen_merkle_nodes_first_row_p18446744069414584321_fp", kernel_call);
        let func = library.get_function(&kernel_name, Some(constants)).unwrap();
        let pipeline = library.device().new_compute_pipeline_state_with_function(&func).unwrap();

        let threadgroup_dim =
            metal::MTLSize::new(Self::HASHERS_PER_THREADGROUP.try_into().unwrap(), 1, 1);
        let grid_dim = metal::MTLSize::new((num_leaves / 2).try_into().unwrap(), 1, 1);
        GenMerkleNodesFirstRowStage256 { pipeline, threadgroup_dim, grid_dim }
    }

    /// Encodes the first-row node generation into `command_buffer`.
    ///
    /// `leaves` is bound at kernel buffer index 0 and `nodes` at index 1.
    /// `num_leaves / 2` threads are dispatched, then a memory barrier on
    /// `nodes` is encoded.
    ///
    /// # Panics
    ///
    /// Panics if `BaseElement` is not 8 bytes wide.
    pub fn encode(
        &self,
        command_buffer: &metal::CommandBufferRef,
        leaves: &metal::Buffer,
        nodes: &metal::Buffer,
    ) {
        let command_encoder = command_buffer.new_compute_command_encoder();
        // Label the encoder in debug builds, as the other stages in this file do.
        #[cfg(debug_assertions)]
        command_encoder.set_label("merkle tree first row");

        // Threadgroup scratch: 12 field elements per hasher, doubled.
        // NOTE(review): the reason for the factor of two lives in the kernel.
        let state_width = 12;
        let field_size: NSUInteger = 8;
        assert_eq!(field_size as usize, size_of::<BaseElement>());
        let mem_per_hasher = state_width * field_size;
        let hashers_per_tg = Self::HASHERS_PER_THREADGROUP as NSUInteger;
        command_encoder.set_threadgroup_memory_length(0, mem_per_hasher * hashers_per_tg * 2);
        command_encoder.set_compute_pipeline_state(&self.pipeline);

        command_encoder.set_buffer(0, Some(leaves), 0);
        command_encoder.set_buffer(1, Some(nodes), 0);

        command_encoder.dispatch_threads(self.grid_dim, self.threadgroup_dim);
        command_encoder.memory_barrier_with_resources(&[nodes]);
        command_encoder.end_encoding();
    }
}
/// Compute stage that encodes the `*_gen_merkle_nodes_row_*` kernel for a
/// given row of the nodes buffer (any row other than 1).
pub struct GenMerkleNodesRowStage256 {
/// Leaf count the stage was built for; `encode` dispatches
/// `num_leaves >> row` threads.
num_leaves: usize,
/// Compiled compute pipeline, specialized with the leaf count as function
/// constant 0.
pipeline: metal::ComputePipelineState,
/// Threads per threadgroup (`HASHERS_PER_THREADGROUP` x 1 x 1).
threadgroup_dim: metal::MTLSize,
}
impl GenMerkleNodesRowStage256 {
    /// Threads per threadgroup used when dispatching the kernel.
    pub const HASHERS_PER_THREADGROUP: usize = 32;

    /// Builds the stage for a tree with `num_leaves` leaves.
    ///
    /// Resolves `{rpo_128|rpx_128}_gen_merkle_nodes_row_p18446744069414584321_fp`
    /// from `library`, passing the leaf count as `UInt` function constant 0.
    ///
    /// # Panics
    ///
    /// Panics if `num_leaves` is not a power of two, if it does not fit in a
    /// `u32`, if the kernel is missing from `library`, or if the pipeline
    /// state cannot be created.
    pub fn new(library: &metal::LibraryRef, num_leaves: usize, hash: HashFn) -> Self {
        use metal::MTLDataType::UInt;

        assert!(num_leaves.is_power_of_two());
        // The kernel receives the leaf count as a 32-bit constant. A plain
        // `as u32` would silently truncate (2^32 leaves would become 0), so
        // fail loudly instead.
        let leaf_count: u32 = num_leaves
            .try_into()
            .expect("num_leaves must fit in the kernel's u32 function constant");

        let kernel_call = match hash {
            HashFn::Rpo256 => "rpo_128",
            HashFn::Rpx256 => "rpx_128",
        };
        let constants = metal::FunctionConstantValues::new();
        constants.set_constant_value_at_index(void_ptr(&leaf_count), UInt, 0);
        let kernel_name = format!("{}_gen_merkle_nodes_row_p18446744069414584321_fp", kernel_call);
        let func = library.get_function(&kernel_name, Some(constants)).unwrap();
        let pipeline = library.device().new_compute_pipeline_state_with_function(&func).unwrap();

        let threadgroup_dim =
            metal::MTLSize::new(Self::HASHERS_PER_THREADGROUP as NSUInteger, 1, 1);
        GenMerkleNodesRowStage256 { num_leaves, pipeline, threadgroup_dim }
    }

    /// Encodes node generation for `row` into `command_buffer`.
    ///
    /// `nodes` is bound at kernel buffer index 0 and `row` is passed by value
    /// at index 1. `num_leaves >> row` threads are dispatched, then a memory
    /// barrier on `nodes` is encoded.
    ///
    /// # Panics
    ///
    /// Panics if `row == 1` (that row is handled by the first-row stage) or
    /// if `BaseElement` is not 8 bytes wide.
    pub fn encode(
        &self,
        command_buffer: &metal::CommandBufferRef,
        nodes: &metal::Buffer,
        row: u32,
    ) {
        assert_ne!(1, row, "use GenMerkleNodesFirstRowStage");
        let command_encoder = command_buffer.new_compute_command_encoder();
        #[cfg(debug_assertions)]
        command_encoder.set_label(&format!("merkle tree row={row}"));

        // Threadgroup scratch: 12 field elements per hasher, doubled.
        // NOTE(review): the reason for the factor of two lives in the kernel.
        let state_width = 12;
        let field_size: NSUInteger = 8;
        assert_eq!(field_size as usize, size_of::<BaseElement>());
        let mem_per_hasher = state_width * field_size;
        let hashers_per_tg = Self::HASHERS_PER_THREADGROUP as NSUInteger;
        command_encoder.set_threadgroup_memory_length(0, mem_per_hasher * hashers_per_tg * 2);
        command_encoder.set_compute_pipeline_state(&self.pipeline);

        command_encoder.set_buffer(0, Some(nodes), 0);
        command_encoder.set_bytes(1, size_of::<u32>() as NSUInteger, void_ptr(&row));

        // The grid width halves with each successive row.
        let grid_dim = metal::MTLSize::new((self.num_leaves >> row).try_into().unwrap(), 1, 1);
        command_encoder.dispatch_threads(grid_dim, self.threadgroup_dim);
        command_encoder.memory_barrier_with_resources(&[nodes]);
        command_encoder.end_encoding();
    }
}