//! miden-gpu 0.6.0
//!
//! GPU acceleration for the Miden VM prover.
use std::rc::Rc;

use metal::CommandBufferRef;
use once_cell::sync::Lazy;
use winter_math::fields::f64::BaseElement;

use crate::{
    HashFn,
    metal::{
        stage::{
            AbsorbColumnsStage256, AbsorbRowsStage256, GenMerkleNodesFirstRowStage256,
            GenMerkleNodesRowStage256,
        },
        utils::{buffer_mut_no_copy, buffer_no_copy, is_page_aligned, page_aligned_uninit_vector},
    },
};

/// Precompiled Metal shader library (`shaders/shaders.metallib`), embedded into the binary at
/// compile time and loaded by [`Planner::new`].
const LIBRARY_DATA: &[u8] = include_bytes!("shaders/shaders.metallib");
/// Number of columns absorbed per hashing block (the sponge rate); also the width of one row of
/// a segment passed to [`RowHasher::update`].
const RATE: usize = 8;

/// Process-wide planner, created on first access for the system default Metal device.
static PLANNER: Lazy<Planner> = Lazy::new(Planner::default);

pub fn get_planner() -> &'static Planner {
    &PLANNER
}

/// Shared Metal state: the compiled shader library and the command queue that the hashers and
/// the Merkle tree builder in this module submit their work to.
pub struct Planner {
    /// Shader library loaded from the embedded `shaders.metallib`.
    pub library: metal::Library,
    /// Command queue created on the same device as `library`.
    pub command_queue: Rc<metal::CommandQueue>,
}

// TODO: unsafe
// NOTE(review): these impls exist because `Planner` is stored in a `static` `Lazy`, which
// requires `Send + Sync`. They are not obviously sound: `Rc` uses a non-atomic reference count,
// so cloning `command_queue` from several threads is a data race. Replacing the field with `Arc`
// (or a bare `metal::CommandQueue`) changes a public field type and is left for a breaking release.
unsafe impl Send for Planner {}
unsafe impl Sync for Planner {}

impl Planner {
    /// Creates a planner for `device`: loads the embedded shader library and creates a new
    /// command queue on that device.
    ///
    /// # Panics
    /// Panics if the embedded shader library cannot be loaded on `device`.
    pub fn new(device: &metal::DeviceRef) -> Self {
        let library = device
            .new_library_with_data(LIBRARY_DATA)
            .expect("failed to load the embedded Metal shader library");
        let command_queue = Rc::new(device.new_command_queue());
        Self { library, command_queue }
    }
}

impl Default for Planner {
    /// Creates a planner for the system default Metal device.
    ///
    /// # Panics
    /// Panics if no Metal device is available or the shader library cannot be loaded.
    fn default() -> Self {
        Planner::new(&metal::Device::system_default().expect("no device found"))
    }
}

// COLUMN HASHER
// ================================================================================================

/// A hasher used for hashing segments of a matrix in column-major form using the Metal API.
///
/// # Usage
///
/// Instantiate a new `ColumnHasher` with the size of the low degree extension domain (the number
/// of elements in each column), the number of columns, and a hash function.
///
/// ```no_run
/// use miden_gpu::HashFn;
/// use miden_gpu::metal::ColumnHasher;
///
/// let row_size = 32768; // lde_domain_size
/// let num_columns = 8;
/// let hash_fn = HashFn::Rpo256;
/// let mut column_hasher = ColumnHasher::new(row_size, num_columns, hash_fn);
/// ```
///
/// Call `update` with each column of the matrix to add it to the state. Once `RATE` (8) columns
/// have been collected, the compute shader is encoded and submitted for that block of columns.
/// ```ignore
///  column_hasher.update(column);
/// ```
///
/// Call `finish`, which waits for the submitted work to complete and returns the resulting
/// digests (one per row, since columns are hashed row-wise).
/// ```ignore
///  column_hasher.finish().await;
/// ```
///
pub struct ColumnHasher<'a> {
    /// Number of elements in each column (the LDE domain size); used to size padding columns.
    num_rows: usize,
    /// GPU stage that absorbs blocks of `RATE` columns and holds the resulting digests.
    stage: AbsorbColumnsStage256,
    /// Columns collected since the last submitted block (fewer than `RATE` between calls).
    state: Vec<&'a [BaseElement]>,
    /// Most recently committed command buffer; `None` until the first full block is submitted.
    command_buffer: Option<&'a CommandBufferRef>,
    /// Hash function; `finish` uses it to decide how the last partial block is handled.
    hash_fn: HashFn,
}

impl<'a> ColumnHasher<'a> {
    /// Constructs a new `ColumnHasher`.
    ///
    /// - `num_rows`: the LDE domain size, i.e. the number of elements in every column.
    /// - `num_columns`: the number of base field columns (presumably the total that will be
    ///   passed to `update`).
    /// - `hash_fn`: the hash function run by the compute shader.
    pub fn new(num_rows: usize, num_columns: usize, hash_fn: HashFn) -> Self {
        let stage =
            AbsorbColumnsStage256::new(&get_planner().library, num_rows, num_columns, hash_fn);
        Self {
            num_rows,
            stage,
            // `update` never keeps more than RATE columns buffered.
            state: Vec::with_capacity(RATE),
            command_buffer: None,
            hash_fn,
        }
    }

    /// Takes a column of a matrix and absorbs it into the `state`.
    ///
    /// Columns are buffered until `RATE` of them have been collected. The compute shader is then
    /// encoded and committed for that block, and the buffer is emptied. The GPU work is not
    /// waited on here; [`ColumnHasher::finish`] does that.
    ///
    /// The columns must be in column-major form of a matrix. The state is built from columns and
    /// computed in row-wise order, meaning the state(i)[0,1,2...7] would be computed by applying
    /// the hash to the [col0\[i\],col1\[i\],col2\[i\]...col7\[i\]].
    ///
    /// # Panics
    /// - The column is not page aligned.
    /// - Incorrect size of the column (not checked here; presumably enforced by the compute stage).
    pub fn update(&mut self, col: &'a [BaseElement]) {
        assert!(is_page_aligned(col), "column must be page aligned");
        self.state.push(col);
        if self.state.len() < RATE {
            return;
        }

        let command_buffer = get_planner().command_queue.new_command_buffer();
        #[cfg(debug_assertions)]
        command_buffer.set_label("update columns");
        // The state holds exactly RATE columns here, so the fixed-size conversion cannot fail.
        let block = &self.state[..RATE];
        self.stage.encode(command_buffer, block.try_into().unwrap());
        command_buffer.commit();
        // Reuse the allocation for the next block instead of replacing the vector.
        self.state.clear();
        self.command_buffer = Some(command_buffer);
    }

    /// Waits for all submitted blocks to finish, hashes any remaining partial block, and returns
    /// the digests.
    ///
    /// A partial last block (fewer than `RATE` columns) is padded before it is hashed:
    /// - RPO256: a column of ones followed by as many zero columns as needed to fill the block.
    /// - RPX256: zero columns only.
    ///
    /// If no column was ever absorbed, the stage's digests are returned as they are.
    pub async fn finish(mut self) -> Vec<[BaseElement; 4]> {
        // Wait for all update stages to finish. Stages run sequentially so waiting for
        // the last stage to finish is sufficient.
        if let Some(cb) = self.command_buffer {
            cb.wait_until_completed();
        }
        // TODO: error? "The zero-length input is not allowed." Note that `command_buffer` is
        // also `None` when fewer than RATE columns were absorbed, so it alone does not detect it.

        // Every absorbed column already went through a full block: nothing left to hash.
        if self.state.is_empty() {
            return self.stage.digests;
        }

        // Padding columns; declared before `block` because `block` borrows them.
        let ones: Vec<BaseElement>;
        let zeros: Vec<BaseElement>;
        // Move the partial block into a local so it can also hold borrows of the padding columns.
        let mut block: Vec<&[BaseElement]> = core::mem::take(&mut self.state);

        // RPO256 padding rule: "a single 1 element followed by as many zeros as are necessary
        // to make the input length a multiple of the rate." - https://eprint.iacr.org/2022/1577.pdf
        // RPX256 only zero-fills the partial block, but its columns must still be absorbed
        // rather than dropped.
        // NOTE(review): assumes the RPX256 stage accounts for a partial last block itself
        // (e.g. via the capacity, from `num_columns`); TODO confirm against the shader.
        if self.hash_fn != HashFn::Rpx256 {
            ones = Self::filled_column(self.num_rows, BaseElement::from(1u32));
            block.push(ones.as_slice());
        }

        // Only allocate the zero column if the block is still not full.
        if block.len() < RATE {
            zeros = Self::filled_column(self.num_rows, BaseElement::from(0u32));
            block.resize(RATE, zeros.as_slice());
        }

        let command_buffer = get_planner().command_queue.new_command_buffer();
        // `block` now holds exactly RATE columns, so the fixed-size conversion cannot fail.
        self.stage.encode(command_buffer, block.as_slice().try_into().unwrap());
        command_buffer.commit();
        command_buffer.wait_until_completed();
        self.stage.digests
    }

    /// Allocates a page-aligned column of `len` elements, every one set to `value`.
    fn filled_column(len: usize, value: BaseElement) -> Vec<BaseElement> {
        // SAFETY (assumed): `page_aligned_uninit_vector` returns `len` uninitialized elements,
        // and every one of them is overwritten by `fill` before the vector is read.
        let mut column = unsafe { page_aligned_uninit_vector(len) };
        column.fill(value);
        column
    }
}

// ROW HASHER
// ================================================================================================

/// A hasher used for hashing segments of a row-major matrix using the Metal API.
///
/// A segment is a set of `RATE` (8) columns of the matrix stored in row-major form.
///
/// # Usage
///
/// Instantiate a new `RowHasher` with the size of the low degree extension domain, the number of
/// columns, and a hash function. The example is `no_run` because it needs a Metal device.
///
/// ```no_run
/// use miden_gpu::HashFn;
/// use miden_gpu::metal::RowHasher;
///
/// let row_size = 32768; // lde_domain_size
/// let num_columns = 8;
/// let hash_fn = HashFn::Rpo256;
/// let mut row_hasher = RowHasher::new(row_size, num_columns, hash_fn);
/// ```
///
/// Call `update` with each segment of the matrix in row-major form; every call submits the
/// segment to the GPU for hashing.
/// ```ignore
///  row_hasher.update(segment);
/// ```
///
/// Call `finish`, which waits for the submitted work to complete and returns the digests.
/// ```ignore
///  row_hasher.finish().await;
/// ```
///
pub struct RowHasher<'a> {
    /// GPU stage that absorbs row segments and holds the resulting digests.
    stage: AbsorbRowsStage256,
    /// Most recently committed command buffer; `None` until the first call to `update`.
    command_buffer: Option<&'a CommandBufferRef>,
}

impl<'a> RowHasher<'a> {
    /// Constructs a new `RowHasher` for a matrix with `row_size` rows (the LDE domain size) and
    /// `num_columns` columns, using the given hash function.
    pub fn new(row_size: usize, num_columns: usize, hash_fn: HashFn) -> Self {
        Self {
            stage: AbsorbRowsStage256::new(&get_planner().library, row_size, num_columns, hash_fn),
            command_buffer: None,
        }
    }

    /// Takes a segment of the matrix and submits it to the GPU for hashing.
    ///
    /// A segment is a set of `RATE` columns of a matrix stored in row-major form, so every
    /// element of `rows` is one row of the segment. The work is committed but not waited on;
    /// [`RowHasher::finish`] does that.
    ///
    /// # Panics
    /// - The rows are not page aligned.
    pub fn update(&mut self, rows: &'a [[BaseElement; RATE]]) {
        assert!(is_page_aligned(rows), "segment rows must be page aligned");
        let command_buffer = get_planner().command_queue.new_command_buffer();
        #[cfg(debug_assertions)]
        command_buffer.set_label("update rows");
        self.stage.encode(command_buffer, rows);
        command_buffer.commit();
        // Only the most recent buffer is kept; `finish` waits on it.
        self.command_buffer = Some(command_buffer);
    }

    /// Waits for the submitted segments to be hashed and returns a copy of the digests.
    ///
    /// Only the most recently committed command buffer is waited on. This relies on buffers
    /// submitted to the shared command queue running in order.
    ///
    /// # Panics
    /// - `update` was never called (zero-length input).
    pub async fn finish(&self) -> Vec<[BaseElement; 4]> {
        // TODO: error? "The zero-length input is not allowed."
        let Some(cb) = self.command_buffer else {
            panic!("RowHasher::finish called before any segment was passed to update");
        };
        cb.wait_until_completed();
        self.stage.digests.clone()
    }
}

// MERKLE TREE BUILDER
// ================================================================================================

/// Constructs a Merkle tree from the provided leaves using a specified hash function.
///
/// Returns the internal nodes of the tree as a `Vec` with `leaves.len()` digests. The zeroth
/// element is always the zero digest. The first element is the root, the second and third are
/// the children of the root, and so on level by level.
///
/// # Panics
/// - `leaves` is not page aligned.
/// - The number of leaves is not a power of two (this includes an empty slice).
pub async fn build_merkle_tree(
    leaves: &[[BaseElement; 4]],
    hash_fn: HashFn,
) -> Vec<[BaseElement; 4]> {
    assert!(is_page_aligned(leaves), "leaves must be page aligned");
    let num_leaves = leaves.len();
    // The node layout requires a full binary tree; this also rejects 0, on which `ilog2` below
    // would panic only after the GPU resources were set up.
    assert!(
        num_leaves.is_power_of_two(),
        "number of leaves must be a power of two, but was {num_leaves}"
    );

    let planner = get_planner();
    let device = planner.library.device();
    let leaves_buffer = buffer_no_copy(device, leaves);
    // SAFETY (assumed): every element of the uninitialized vector is overwritten by `fill`
    // before it is read.
    let mut nodes = unsafe { page_aligned_uninit_vector(num_leaves) };
    // TODO: might be unnecessary. only zero first item?
    nodes.fill([BaseElement::from(0u32); 4]);
    let nodes_buffer = buffer_mut_no_copy(device, &mut nodes);

    let first_row_stage =
        GenMerkleNodesFirstRowStage256::new(&planner.library, num_leaves, hash_fn);
    let nth_row_stage = GenMerkleNodesRowStage256::new(&planner.library, num_leaves, hash_fn);

    let command_buffer = planner.command_queue.new_command_buffer();
    #[cfg(debug_assertions)]
    command_buffer.set_label(&format!("{hash_fn} Merkle tree"));
    // The first-row stage derives the lowest row of internal nodes from the leaves. The remaining
    // stages (rows 2..=log2(n)) presumably build each higher row from the one below it.
    first_row_stage.encode(command_buffer, &leaves_buffer, &nodes_buffer);
    for row in 2..=num_leaves.ilog2() {
        nth_row_stage.encode(command_buffer, &nodes_buffer, row);
    }
    command_buffer.commit();
    command_buffer.wait_until_completed();

    nodes
}