// miden-gpu 0.6.0
//
// GPU acceleration for the Miden VM prover.
use core::mem::size_of;

use metal::NSUInteger;
use winter_math::fields::f64::BaseElement;

use crate::{
    HashFn,
    metal::{
        plan::get_planner,
        utils::{buffer_mut_no_copy, buffer_no_copy, page_aligned_uninit_vector, void_ptr},
    },
};

/// Modulus applied to the column count when deriving the padding value that
/// seeds the first element of every initial hasher state.
const RATE: usize = 8;

/// Pipeline stage that encodes the `<hash>_absorb_columns_and_permute` Metal kernel
/// over eight column slices.
pub struct AbsorbColumnsStage256 {
    /// Length that `encode` asserts input columns against; also the width of `grid_dim`.
    row_size: usize,
    /// Compute pipeline built from the kernel function looked up in `new`.
    pipeline: metal::ComputePipelineState,
    /// Threadgroup size used for every dispatch (`HASHERS_PER_THREADGROUP` x 1 x 1).
    threadgroup_dim: metal::MTLSize,
    /// Dispatch grid (`row_size` x 1 x 1).
    grid_dim: metal::MTLSize,
    /// Host allocation handed to `buffer_mut_no_copy` to create `states_buffer`.
    /// Never read on the host in this file; presumably held so the memory outlives
    /// the buffer -- confirm against `buffer_mut_no_copy`.
    _states: Vec<[BaseElement; 4]>,
    /// Metal buffer over `_states`, bound at buffer index 8 in `encode`.
    states_buffer: metal::Buffer,
    /// Host allocation wrapped by `digests_buffer`. Obtained from
    /// `page_aligned_uninit_vector` and not initialised by this file; presumably the
    /// kernel writes its results here -- TODO confirm against the shader.
    pub digests: Vec<[BaseElement; 4]>,
    /// Metal buffer over `digests`, bound at buffer index 9 in `encode`.
    digests_buffer: metal::Buffer,
}

impl AbsorbColumnsStage256 {
    const HASHERS_PER_THREADGROUP: usize = 64;

    pub fn new(
        library: &metal::LibraryRef,
        row_size: usize,
        num_columns: usize,
        hash: HashFn,
    ) -> Self {
        let kernel_name = format!("{}_absorb_columns_and_permute_p18446744069414584321_fp", hash);
        let func = library.get_function(&kernel_name, None).unwrap();
        let pipeline = library.device().new_compute_pipeline_state_with_function(&func).unwrap();

        let threadgroup_dim =
            metal::MTLSize::new(Self::HASHERS_PER_THREADGROUP.try_into().unwrap(), 1, 1);
        let grid_dim = metal::MTLSize::new(row_size.try_into().unwrap(), 1, 1);

        // TODO: creating page aligned vectors in this fashion is rather brittle.
        // If the vector is resized there is no guarantee that the new memory will be
        // page aligned. Rust's Allocator api would be great but it's not currently
        // available on Rust Stable.
        let mut digests = unsafe { page_aligned_uninit_vector(row_size) };
        let digests_buffer = buffer_mut_no_copy(library.device(), &mut digests);

        let padding_rule = (num_columns % RATE) as u32;

        let mut _states = unsafe { page_aligned_uninit_vector(row_size) };
        _states.fill([
            BaseElement::from(padding_rule),
            BaseElement::from(0u32),
            BaseElement::from(0u32),
            BaseElement::from(0u32),
        ]);
        let states_buffer = buffer_mut_no_copy(library.device(), &mut _states);

        AbsorbColumnsStage256 {
            row_size,
            threadgroup_dim,
            pipeline,
            grid_dim,
            digests,
            digests_buffer,
            _states,
            states_buffer,
        }
    }

    pub fn encode(&self, command_buffer: &metal::CommandBufferRef, columns: [&[BaseElement]; 8]) {
        let [col0, col1, col2, col3, col4, col5, col6, col7] = columns;
        assert_eq!(self.row_size, col1.len());
        assert_eq!(self.row_size, col2.len());
        assert_eq!(self.row_size, col3.len());
        assert_eq!(self.row_size, col4.len());
        assert_eq!(self.row_size, col5.len());
        assert_eq!(self.row_size, col6.len());
        assert_eq!(self.row_size, col7.len());

        let planner = get_planner();
        let device = planner.library.device();
        let command_encoder = command_buffer
            .compute_command_encoder_with_dispatch_type(metal::MTLDispatchType::Concurrent);
        #[cfg(debug_assertions)]
        command_encoder.set_label("absorb and permute 8 columns");
        let state_width = 16;
        let field_size = size_of::<BaseElement>() as NSUInteger;
        let mem_per_hasher = state_width * field_size;
        let hashers_per_tg = Self::HASHERS_PER_THREADGROUP as NSUInteger;
        command_encoder.set_threadgroup_memory_length(0, mem_per_hasher * hashers_per_tg * 2);
        command_encoder.set_compute_pipeline_state(&self.pipeline);
        command_encoder.set_buffer(0, Some(&buffer_no_copy(device, col0)), 0);
        command_encoder.set_buffer(1, Some(&buffer_no_copy(device, col1)), 0);
        command_encoder.set_buffer(2, Some(&buffer_no_copy(device, col2)), 0);
        command_encoder.set_buffer(3, Some(&buffer_no_copy(device, col3)), 0);
        command_encoder.set_buffer(4, Some(&buffer_no_copy(device, col4)), 0);
        command_encoder.set_buffer(5, Some(&buffer_no_copy(device, col5)), 0);
        command_encoder.set_buffer(6, Some(&buffer_no_copy(device, col6)), 0);
        command_encoder.set_buffer(7, Some(&buffer_no_copy(device, col7)), 0);
        command_encoder.set_buffer(8, Some(&self.states_buffer), 0);
        command_encoder.set_buffer(9, Some(&self.digests_buffer), 0);
        command_encoder.dispatch_threads(self.grid_dim, self.threadgroup_dim);
        command_encoder.memory_barrier_with_resources(&[&self.states_buffer, &self.digests_buffer]);
        command_encoder.end_encoding()
    }
}

/// Pipeline stage that encodes the `<hash>_absorb_rows_and_permute` Metal kernel
/// over a slice of 8-element rows.
pub struct AbsorbRowsStage256 {
    /// Number of rows `encode` expects; also the width of `grid_dim`.
    row_size: usize,
    /// Compute pipeline built from the kernel function looked up in `new`.
    pipeline: metal::ComputePipelineState,
    /// Threadgroup size used for every dispatch (`HASHERS_PER_THREADGROUP` x 1 x 1).
    threadgroup_dim: metal::MTLSize,
    /// Dispatch grid (`row_size` x 1 x 1).
    grid_dim: metal::MTLSize,
    /// Host allocation handed to `buffer_mut_no_copy` to create `states_buffer`.
    /// Never read on the host in this file; presumably held so the memory outlives
    /// the buffer -- confirm against `buffer_mut_no_copy`.
    _states: Vec<[BaseElement; 4]>,
    /// Metal buffer over `_states`, bound at buffer index 1 in `encode`.
    states_buffer: metal::Buffer,
    /// Host allocation wrapped by `digests_buffer`. Obtained from
    /// `page_aligned_uninit_vector` and not initialised by this file; presumably the
    /// kernel writes its results here -- TODO confirm against the shader.
    pub digests: Vec<[BaseElement; 4]>,
    /// Metal buffer over `digests`, bound at buffer index 2 in `encode`.
    digests_buffer: metal::Buffer,
}

impl AbsorbRowsStage256 {
    /// Threads per threadgroup; also scales the threadgroup memory requested in `encode`.
    const HASHERS_PER_THREADGROUP: usize = 128;

    /// Creates the stage for the
    /// `{hash}_absorb_rows_and_permute_p18446744069414584321_fp` kernel.
    ///
    /// * `library` - Metal library holding the kernel; its device creates the pipeline
    ///   and the buffers.
    /// * `row_size` - number of rows; becomes the width of the dispatch grid.
    /// * `num_columns` - total column count; only `num_columns % 8` is used, as the
    ///   first element of every initial state.
    /// * `hash` - hash function whose `Display` output prefixes the kernel name.
    ///
    /// # Panics
    ///
    /// Panics if the kernel function is not found, if the pipeline state cannot be
    /// created, or if a dimension does not fit the Metal integer type.
    pub fn new(
        library: &metal::LibraryRef,
        row_size: usize,
        num_columns: usize,
        hash: HashFn,
    ) -> Self {
        let device = library.device();

        let kernel_name = format!("{}_absorb_rows_and_permute_p18446744069414584321_fp", hash);
        let kernel = library.get_function(&kernel_name, None).unwrap();
        let pipeline = device.new_compute_pipeline_state_with_function(&kernel).unwrap();

        let threads_per_group: NSUInteger = Self::HASHERS_PER_THREADGROUP.try_into().unwrap();
        let total_threads: NSUInteger = row_size.try_into().unwrap();

        // Output storage; left exactly as the allocation helper returns it.
        let mut digests = unsafe { page_aligned_uninit_vector(row_size) };
        let digests_buffer = buffer_mut_no_copy(device, &mut digests);

        // Each initial state is [num_columns % 8, 0, 0, 0].
        let padding_rule = (num_columns % 8) as u32;
        let initial_state = [padding_rule, 0u32, 0u32, 0u32].map(BaseElement::from);

        let mut states = unsafe { page_aligned_uninit_vector(row_size) };
        states.fill(initial_state);
        let states_buffer = buffer_mut_no_copy(device, &mut states);

        Self {
            row_size,
            pipeline,
            threadgroup_dim: metal::MTLSize::new(threads_per_group, 1, 1),
            grid_dim: metal::MTLSize::new(total_threads, 1, 1),
            _states: states,
            states_buffer,
            digests,
            digests_buffer,
        }
    }

    /// Encodes one absorb-and-permute pass over `rows` into `command_buffer`.
    ///
    /// Buffer bindings: `rows` at index 0, the state buffer at index 1 and the digest
    /// buffer at index 2.
    ///
    /// # Panics
    ///
    /// Panics if `rows.len()` differs from the `row_size` passed to `new`.
    pub fn encode(&self, command_buffer: &metal::CommandBufferRef, rows: &[[BaseElement; 8]]) {
        assert_eq!(self.row_size, rows.len());

        let device = get_planner().library.device();
        let encoder = command_buffer
            .compute_command_encoder_with_dispatch_type(metal::MTLDispatchType::Concurrent);
        #[cfg(debug_assertions)]
        encoder.set_label("absorb and permute 8 column rows");

        // 16 field elements per hasher, times the hashers in a threadgroup, times two.
        let scratch_bytes = {
            let state_width: NSUInteger = 16;
            let element_bytes = size_of::<BaseElement>() as NSUInteger;
            let group_hashers = Self::HASHERS_PER_THREADGROUP as NSUInteger;
            state_width * element_bytes * group_hashers * 2
        };
        encoder.set_threadgroup_memory_length(0, scratch_bytes);
        encoder.set_compute_pipeline_state(&self.pipeline);

        let rows_buffer = buffer_no_copy(device, rows);
        encoder.set_buffer(0, Some(&rows_buffer), 0);
        encoder.set_buffer(1, Some(&self.states_buffer), 0);
        encoder.set_buffer(2, Some(&self.digests_buffer), 0);

        encoder.dispatch_threads(self.grid_dim, self.threadgroup_dim);
        encoder.memory_barrier_with_resources(&[&self.states_buffer, &self.digests_buffer]);
        encoder.end_encoding();
    }
}

/// Pipeline stage that encodes the `<hash>_gen_merkle_nodes_first_row` Metal kernel,
/// which takes a leaves buffer and a nodes buffer.
pub struct GenMerkleNodesFirstRowStage256 {
    /// Compute pipeline specialised with the leaf count as function constant 0.
    pipeline: metal::ComputePipelineState,
    /// Threadgroup size used for every dispatch (`HASHERS_PER_THREADGROUP` x 1 x 1).
    threadgroup_dim: metal::MTLSize,
    /// Dispatch grid (`num_leaves / 2` x 1 x 1).
    grid_dim: metal::MTLSize,
}

impl GenMerkleNodesFirstRowStage256 {
    /// Threads per threadgroup; also scales the threadgroup memory requested in `encode`.
    pub const HASHERS_PER_THREADGROUP: usize = 64;

    /// Builds the pipeline for the
    /// `{rpo_128|rpx_128}_gen_merkle_nodes_first_row_p18446744069414584321_fp` kernel.
    ///
    /// * `library` - Metal library holding the kernel.
    /// * `num_leaves` - leaf count; passed to the kernel as `uint` function constant 0
    ///   and halved to obtain the dispatch grid width.
    /// * `hash` - selects the `rpo_128` or `rpx_128` kernel prefix.
    ///
    /// # Panics
    ///
    /// Panics if `num_leaves` is not a power of two, if `num_leaves / 2` is smaller
    /// than `HASHERS_PER_THREADGROUP`, if `num_leaves` does not fit in a `u32`, if the
    /// kernel function is not found, or if the pipeline state cannot be created.
    pub fn new(library: &metal::LibraryRef, num_leaves: usize, hash: HashFn) -> Self {
        use metal::MTLDataType::UInt;
        assert!(num_leaves.is_power_of_two());
        assert!((num_leaves / 2) >= Self::HASHERS_PER_THREADGROUP);

        // The kernel receives the leaf count as a 32-bit constant. A plain `as u32`
        // would silently wrap large powers of two to 0, so convert with a check.
        let leaf_count: u32 =
            num_leaves.try_into().expect("num_leaves must fit in a u32 function constant");

        let kernel_call = match hash {
            HashFn::Rpo256 => "rpo_128",
            HashFn::Rpx256 => "rpx_128",
        };

        let constants = metal::FunctionConstantValues::new();
        constants.set_constant_value_at_index(void_ptr(&leaf_count), UInt, 0);
        let kernel_name =
            format!("{}_gen_merkle_nodes_first_row_p18446744069414584321_fp", kernel_call);
        let func = library.get_function(&kernel_name, Some(constants)).unwrap();
        let pipeline = library.device().new_compute_pipeline_state_with_function(&func).unwrap();

        let threadgroup_dim =
            metal::MTLSize::new(Self::HASHERS_PER_THREADGROUP.try_into().unwrap(), 1, 1);
        let grid_dim = metal::MTLSize::new((num_leaves / 2).try_into().unwrap(), 1, 1);

        GenMerkleNodesFirstRowStage256 { pipeline, threadgroup_dim, grid_dim }
    }

    /// Encodes the first-row kernel into `command_buffer`.
    ///
    /// * `leaves` - bound at buffer index 0.
    /// * `nodes` - bound at buffer index 1; a memory barrier on it follows the dispatch.
    ///
    /// # Panics
    ///
    /// Panics if `BaseElement` is not 8 bytes wide.
    pub fn encode(
        &self,
        command_buffer: &metal::CommandBufferRef,
        leaves: &metal::Buffer,
        nodes: &metal::Buffer,
    ) {
        let command_encoder = command_buffer.new_compute_command_encoder();
        // TODO: use param
        // Threadgroup memory: 12 field elements of 8 bytes per hasher, times the
        // hashers in a threadgroup, times two.
        let state_width = 12;
        let field_size: NSUInteger = 8;
        assert_eq!(field_size as usize, size_of::<BaseElement>());
        let mem_per_hasher = state_width * field_size;
        let hashers_per_tg = Self::HASHERS_PER_THREADGROUP as NSUInteger;
        command_encoder.set_threadgroup_memory_length(0, mem_per_hasher * hashers_per_tg * 2);
        command_encoder.set_compute_pipeline_state(&self.pipeline);
        command_encoder.set_buffer(0, Some(leaves), 0);
        command_encoder.set_buffer(1, Some(nodes), 0);
        command_encoder.dispatch_threads(self.grid_dim, self.threadgroup_dim);
        command_encoder.memory_barrier_with_resources(&[nodes]);
        command_encoder.end_encoding()
    }
}

/// Pipeline stage that encodes the `<hash>_gen_merkle_nodes_row` Metal kernel for a
/// caller-chosen row of the nodes buffer.
pub struct GenMerkleNodesRowStage256 {
    /// Leaf count given to `new`; `encode` dispatches `num_leaves >> row` threads.
    num_leaves: usize,
    /// Compute pipeline specialised with the leaf count as function constant 0.
    pipeline: metal::ComputePipelineState,
    /// Threadgroup size used for every dispatch (`HASHERS_PER_THREADGROUP` x 1 x 1).
    threadgroup_dim: metal::MTLSize,
}

impl GenMerkleNodesRowStage256 {
    /// Threads per threadgroup; also scales the threadgroup memory requested in `encode`.
    pub const HASHERS_PER_THREADGROUP: usize = 32;

    /// Builds the pipeline for the
    /// `{rpo_128|rpx_128}_gen_merkle_nodes_row_p18446744069414584321_fp` kernel.
    ///
    /// * `library` - Metal library holding the kernel.
    /// * `num_leaves` - leaf count; passed to the kernel as `uint` function constant 0
    ///   and kept to size each row's dispatch in `encode`.
    /// * `hash` - selects the `rpo_128` or `rpx_128` kernel prefix.
    ///
    /// # Panics
    ///
    /// Panics if `num_leaves` is not a power of two or does not fit in a `u32`, if the
    /// kernel function is not found, or if the pipeline state cannot be created.
    pub fn new(library: &metal::LibraryRef, num_leaves: usize, hash: HashFn) -> Self {
        use metal::MTLDataType::UInt;
        assert!(num_leaves.is_power_of_two());

        // The kernel receives the leaf count as a 32-bit constant. A plain `as u32`
        // would silently wrap large powers of two to 0, so convert with a check.
        let leaf_count: u32 =
            num_leaves.try_into().expect("num_leaves must fit in a u32 function constant");

        let kernel_call = match hash {
            HashFn::Rpo256 => "rpo_128",
            HashFn::Rpx256 => "rpx_128",
        };

        let constants = metal::FunctionConstantValues::new();
        constants.set_constant_value_at_index(void_ptr(&leaf_count), UInt, 0);
        let kernel_name = format!("{}_gen_merkle_nodes_row_p18446744069414584321_fp", kernel_call);
        let func = library.get_function(&kernel_name, Some(constants)).unwrap();
        let pipeline = library.device().new_compute_pipeline_state_with_function(&func).unwrap();

        let threadgroup_dim =
            metal::MTLSize::new(Self::HASHERS_PER_THREADGROUP as NSUInteger, 1, 1);

        GenMerkleNodesRowStage256 { num_leaves, pipeline, threadgroup_dim }
    }

    /// Encodes the kernel for one `row` into `command_buffer`.
    ///
    /// * `nodes` - bound at buffer index 0; a memory barrier on it follows the dispatch.
    /// * `row` - passed to the kernel as bytes at index 1; the dispatch grid is
    ///   `num_leaves >> row` threads wide.
    ///
    /// # Panics
    ///
    /// Panics if `row` is `1` (the assertion message points to the first-row stage),
    /// if `row` is at least the bit width of `usize`, or if `BaseElement` is not
    /// 8 bytes wide.
    pub fn encode(
        &self,
        command_buffer: &metal::CommandBufferRef,
        nodes: &metal::Buffer,
        row: u32,
    ) {
        assert_ne!(1, row, "use GenMerkleNodesFirstRowStage");
        // A bare `>>` panics in debug builds but wraps the shift amount in release
        // builds when `row` is too large. Use the checked form, and do it before any
        // encoder is opened so a bad `row` fails early in every build profile.
        let threads_in_row = self
            .num_leaves
            .checked_shr(row)
            .expect("row must be smaller than the bit width of num_leaves");
        let grid_dim = metal::MTLSize::new(threads_in_row.try_into().unwrap(), 1, 1);

        let command_encoder = command_buffer.new_compute_command_encoder();
        #[cfg(debug_assertions)]
        command_encoder.set_label(&format!("merkle tree row={row}"));
        // TODO: use param
        // Threadgroup memory: 12 field elements of 8 bytes per hasher, times the
        // hashers in a threadgroup, times two.
        let state_width = 12;
        let field_size: NSUInteger = 8;
        assert_eq!(field_size as usize, size_of::<BaseElement>());
        let mem_per_hasher = state_width * field_size;
        let hashers_per_tg = Self::HASHERS_PER_THREADGROUP as NSUInteger;
        command_encoder.set_threadgroup_memory_length(0, mem_per_hasher * hashers_per_tg * 2);
        command_encoder.set_compute_pipeline_state(&self.pipeline);
        command_encoder.set_buffer(0, Some(nodes), 0);
        command_encoder.set_bytes(1, size_of::<u32>() as NSUInteger, void_ptr(&row));
        command_encoder.dispatch_threads(grid_dim, self.threadgroup_dim);
        command_encoder.memory_barrier_with_resources(&[nodes]);
        command_encoder.end_encoding()
    }
}