// miden-gpu 0.6.0
//
// GPU acceleration for the Miden VM prover.
// Tests for the Metal backend: RPX hashing from columns and rows, and Merkle tree construction.
#![cfg(all(target_arch = "aarch64", target_os = "macos"))]

extern crate core;

use std::time::Instant;

use miden_crypto::{Felt, FieldElement, Word, hash::rpx::Rpx256};
use miden_gpu::{
    HashFn,
    metal::{ColumnHasher, RowHasher, build_merkle_tree, utils::page_aligned_uninit_vector},
};

/// Hashes `input_size` constant columns on the GPU with RPX, then builds a Merkle tree over the
/// resulting row digests, printing timings and a few sample values.
///
/// Every column is filled with `Felt::ONE`, so every row is identical and all row digests must
/// agree; this is asserted.
#[test]
fn gpu_rpx_from_columns() {
    let n = 8388608;
    let input_size = 8;

    // One column per hasher input. The columns are owned by `columns`, which lives until the end
    // of the test. NOTE(review): the page-aligned allocation suggests the GPU may read these
    // buffers in place, so they are deliberately kept alive past `finish` — confirm.
    let columns: Vec<Vec<Felt>> = (0..input_size)
        .map(|_| {
            // SAFETY: the uninitialized buffer is fully overwritten by `fill` before any read.
            let mut col = unsafe { page_aligned_uninit_vector(n) };
            col.fill(Felt::ONE);
            col
        })
        .collect();

    let mut rpx = ColumnHasher::new(n, input_size, HashFn::Rpx256);

    let now = Instant::now();
    for col in &columns {
        rpx.update(col);
    }
    println!("Encode in {:?}", now.elapsed());

    let now = Instant::now();
    let hashes = pollster::block_on(rpx.finish());
    println!("Run in {:?}", now.elapsed());
    println!("Hashes: {:?}", hashes[0]);
    println!("Hashes1: {:?}", hashes[1]);
    println!("Hashes2: {:?}", hashes[1024]);
    // Every row is identical, so each digest must equal its successor.
    hashes
        .iter()
        .zip(&hashes[1..])
        .enumerate()
        .for_each(|(i, (a, b))| assert_eq!(a, b, "mismatch at {i}"));

    let now = Instant::now();
    let merkle_tree = pollster::block_on(build_merkle_tree(&hashes, HashFn::Rpx256));
    // Read the clock before printing so console I/O is not counted as build time.
    let merkle_elapsed = now.elapsed();
    // Layout: index 0 is an unused slot, index 1 is the root, indices 2 and 3 are its children.
    println!("Slot 0 (unused): {:?}", merkle_tree[0]);
    println!("Root: {:?}", merkle_tree[1]);
    println!("Root child 0: {:?}", merkle_tree[2]);
    println!("Root child 1: {:?}", merkle_tree[3]);
    let mid = merkle_tree.len() / 2 + 1;
    println!("Node[{mid}]: {:?}", merkle_tree[mid]);
    println!("Merkle tree in {:?}", merkle_elapsed);
}

/// Hashes `n` constant 8-element rows on the GPU with RPX, asserts that all row digests agree,
/// then builds a Merkle tree over them, printing timings and a few sample values.
#[test]
fn gpu_rpx_from_rows() {
    let n = 8388608;
    // SAFETY: the uninitialized buffer is fully overwritten by `fill` before any read.
    let mut rows: Vec<[Felt; 8]> = unsafe { page_aligned_uninit_vector(n) };
    rows.fill([Felt::ONE; 8]);
    let mut rpx = RowHasher::new(n, 8, HashFn::Rpx256);

    let now = Instant::now();
    rpx.update(&rows);
    println!("Encode in {:?}", now.elapsed());

    let now = Instant::now();
    let hashes = pollster::block_on(rpx.finish());
    println!("Run in {:?}", now.elapsed());
    println!("Hashes (row): {:?}", hashes[0]);
    println!("Hashes1 (row): {:?}", hashes[1]);
    println!("Hashes2 (row): {:?}", hashes[1024]);
    // Every row is identical, so each digest must equal its successor.
    hashes
        .iter()
        .zip(&hashes[1..])
        .enumerate()
        .for_each(|(i, (a, b))| assert_eq!(a, b, "mismatch at {i}"));

    let now = Instant::now();
    let merkle_tree = pollster::block_on(build_merkle_tree(&hashes, HashFn::Rpx256));
    // Read the clock before printing so console I/O is not counted as build time.
    let merkle_elapsed = now.elapsed();
    // Layout: index 0 is an unused slot, index 1 is the root, indices 2 and 3 are its children.
    println!("Slot 0 (unused, row): {:?}", merkle_tree[0]);
    println!("Root (row): {:?}", merkle_tree[1]);
    println!("Root child 0 (row): {:?}", merkle_tree[2]);
    println!("Root child 1 (row): {:?}", merkle_tree[3]);
    let mid = merkle_tree.len() / 2 + 1;
    println!("Node[{mid}] (row): {:?}", merkle_tree[mid]);
    println!("Merkle tree in {:?}", merkle_elapsed);
}

/// Checks that GPU RPX row hashing and GPU Merkle tree construction agree with the CPU
/// implementations (`Rpx256::hash_elements` and `winter_crypto::MerkleTree`).
///
/// All rows are `[Felt::ONE; 8]`, so every GPU digest must equal the single CPU digest, and the
/// Merkle roots built from `n` copies of that digest must match.
#[test]
fn compare_cpu_gpu_rpx() {
    let n = 32768;
    // SAFETY: the uninitialized buffer is fully overwritten by `fill` before any read.
    let mut rows: Vec<[Felt; 8]> = unsafe { page_aligned_uninit_vector(n) };
    rows.fill([Felt::ONE; 8]);
    let mut rpx = RowHasher::new(n, 8, HashFn::Rpx256);
    rpx.update(&rows);

    let hashes = pollster::block_on(rpx.finish());

    // Hash a single row with the CPU implementation.
    let cpu_elements = vec![Felt::ONE; 8];
    let cpu_rpx_digest = Rpx256::hash_elements(&cpu_elements);
    let cpu_fields: &[Felt] = cpu_rpx_digest.as_elements();

    // Compare the string forms as whole vectors, so a length mismatch fails instead of being
    // silently truncated by `zip`.
    let gpu_strings: Vec<String> = hashes[0].iter().map(ToString::to_string).collect();
    let cpu_strings: Vec<String> = cpu_fields.iter().map(ToString::to_string).collect();
    assert_eq!(gpu_strings, cpu_strings);

    // Every row is identical, so every GPU digest must equal the first one.
    hashes
        .iter()
        .enumerate()
        .for_each(|(i, h)| assert_eq!(h, &hashes[0], "GPU digest mismatch at row {i}"));

    // Generate a Merkle tree using the GPU.
    let gpu_merkle_tree = pollster::block_on(build_merkle_tree(&hashes, HashFn::Rpx256));

    // Create the same leaves for the CPU and build its Merkle tree.
    let cpu_merkle_digests: Vec<Word> = vec![cpu_rpx_digest; n];
    let cpu_merkle_tree = winter_crypto::MerkleTree::<Rpx256>::new(cpu_merkle_digests)
        .expect("failed to construct Merkle tree");

    // The 0th element is always 0, the 1st element is the root, elements 2 & 3 are the root's
    // children, etc. Compare the whole root so a length mismatch cannot pass unnoticed.
    let gpu_root: Vec<Felt> = gpu_merkle_tree[1].iter().copied().collect();
    let cpu_root = cpu_merkle_tree.root().as_elements();
    assert_eq!(gpu_root, cpu_root);
}