#![cfg(all(target_arch = "aarch64", target_os = "macos"))]
use std::time::Instant;
use miden_crypto::{Felt, FieldElement, Word, hash::rpo::Rpo256};
use miden_gpu::{
HashFn,
metal::{ColumnHasher, RowHasher, build_merkle_tree, utils::page_aligned_uninit_vector},
};
/// Hashes eight columns of `Felt::ONE` on the GPU with [`ColumnHasher`] (RPO-256), builds a
/// Merkle tree over the resulting row digests, and prints timings.
///
/// Every row consists of eight ones, so every row digest must be identical. The test asserts
/// this so that it checks correctness and not only timing.
#[test]
fn gpu_rpo_from_columns() {
    // 2^23 rows.
    let n = 8_388_608;
    let input_size = 8;

    let columns: Vec<Vec<Felt>> = (0..input_size)
        .map(|_| {
            // SAFETY: `page_aligned_uninit_vector` returns uninitialized storage of length
            // `n`; `fill` overwrites every element before any element is read. This assumes
            // `Felt` has no drop glue, so overwriting an uninitialized slot drops nothing.
            let mut col: Vec<Felt> = unsafe { page_aligned_uninit_vector(n) };
            col.fill(Felt::ONE);
            col
        })
        .collect();

    let mut rpo = ColumnHasher::new(n, input_size, HashFn::Rpo256);
    let now = Instant::now();
    for col in &columns {
        rpo.update(col);
    }
    println!("Encode in {:?}", now.elapsed());

    let now = Instant::now();
    let hashes = pollster::block_on(rpo.finish());
    println!("Run in {:?}", now.elapsed());
    println!("Hashes: {:?}", hashes[0]);
    println!("Hashes1: {:?}", hashes[1]);
    println!("Hashes[1024]: {:?}", hashes[1024]);

    // All input rows are identical, so all row digests must be identical as well.
    hashes
        .iter()
        .zip(&hashes[1..])
        .enumerate()
        .for_each(|(i, (a, b))| assert_eq!(a, b, "mismatch at {i}"));

    let now = Instant::now();
    let merkle_tree = pollster::block_on(build_merkle_tree(&hashes, HashFn::Rpo256));
    // Stop the clock before printing so console I/O is not counted as tree-building time.
    let merkle_time = now.elapsed();
    // `compare_cpu_gpu_rpo` compares index 1 against the CPU root, so index 1 is labelled as
    // the root here; the other indices are printed only for inspection.
    println!("Node[0]: {:?}", merkle_tree[0]);
    println!("Root (node 1): {:?}", merkle_tree[1]);
    println!("Node[2]: {:?}", merkle_tree[2]);
    println!("Node[3]: {:?}", merkle_tree[3]);
    let mid = merkle_tree.len() / 2 + 1;
    println!("Node[{mid}]: {:?}", merkle_tree[mid]);
    println!("Merkle tree in {merkle_time:?}");
}
/// Hashes rows of eight `Felt::ONE` elements on the GPU with [`RowHasher`] (RPO-256), asserts
/// that every row digest is identical, builds a Merkle tree over the digests, and prints
/// timings.
#[test]
fn gpu_rpo_from_rows() {
    // 2^23 rows.
    let n = 8_388_608;
    // SAFETY: `page_aligned_uninit_vector` returns uninitialized storage of length `n`; `fill`
    // overwrites every element before any element is read. This assumes `Felt` has no drop
    // glue, so overwriting an uninitialized slot drops nothing.
    let mut rows: Vec<[Felt; 8]> = unsafe { page_aligned_uninit_vector(n) };
    rows.fill([Felt::ONE; 8]);

    let mut rpo = RowHasher::new(n, 8, HashFn::Rpo256);
    let now = Instant::now();
    rpo.update(&rows);
    println!("Encode in {:?}", now.elapsed());

    let now = Instant::now();
    let hashes = pollster::block_on(rpo.finish());
    println!("Run in {:?}", now.elapsed());
    println!("Hashes (row): {:?}", hashes[0]);
    println!("Hashes1 (row): {:?}", hashes[1]);
    println!("Hashes[1024] (row): {:?}", hashes[1024]);

    // Every input row is identical, so every digest must be identical as well.
    hashes
        .iter()
        .zip(&hashes[1..])
        .enumerate()
        .for_each(|(i, (a, b))| assert_eq!(a, b, "mismatch at {i}"));

    let now = Instant::now();
    let merkle_tree = pollster::block_on(build_merkle_tree(&hashes, HashFn::Rpo256));
    // Stop the clock before printing so console I/O is not counted as tree-building time.
    let merkle_time = now.elapsed();
    // `compare_cpu_gpu_rpo` compares index 1 against the CPU root, so index 1 is labelled as
    // the root here; the other indices are printed only for inspection.
    println!("Node[0] (row): {:?}", merkle_tree[0]);
    println!("Root (node 1) (row): {:?}", merkle_tree[1]);
    println!("Node[2] (row): {:?}", merkle_tree[2]);
    println!("Node[3] (row): {:?}", merkle_tree[3]);
    let mid = merkle_tree.len() / 2 + 1;
    println!("Node[{mid}] (row): {:?}", merkle_tree[mid]);
    println!("Merkle tree in {merkle_time:?}");
}
/// Checks that the GPU [`RowHasher`] and GPU Merkle tree agree with the CPU implementations
/// (`Rpo256::hash_elements` and `winter_crypto::MerkleTree`) on 32768 rows of eight
/// `Felt::ONE` elements.
#[test]
fn compare_cpu_gpu_rpo() {
    let n = 32768;
    // SAFETY: `page_aligned_uninit_vector` returns uninitialized storage of length `n`; `fill`
    // overwrites every element before any element is read. This assumes `Felt` has no drop
    // glue, so overwriting an uninitialized slot drops nothing.
    let mut rows: Vec<[Felt; 8]> = unsafe { page_aligned_uninit_vector(n) };
    rows.fill([Felt::ONE; 8]);
    let mut rpo = RowHasher::new(n, 8, HashFn::Rpo256);
    rpo.update(&rows);
    let hashes = pollster::block_on(rpo.finish());

    // Each row is eight ones, so each leaf must equal the CPU digest of eight ones.
    let cpu_elements = std::iter::repeat_n(Felt::ONE, 8).collect::<Vec<Felt>>();
    let cpu_rpo_digest = Rpo256::hash_elements(&cpu_elements);
    let cpu_fields: &[Felt] = cpu_rpo_digest.as_elements();

    // Compare whole slices rather than zipping element-wise: `Iterator::zip` stops at the
    // shorter side, so a short or empty GPU digest would otherwise pass vacuously.
    let gpu_fields: Vec<Felt> = hashes[0].iter().copied().collect();
    assert_eq!(
        gpu_fields.as_slice(),
        cpu_fields,
        "GPU leaf digest differs from CPU Rpo256 digest"
    );
    for (i, leaf) in hashes.iter().enumerate() {
        assert_eq!(leaf, &hashes[0], "GPU leaf {i} differs from leaf 0");
    }

    let gpu_merkle_tree = pollster::block_on(build_merkle_tree(&hashes, HashFn::Rpo256));
    let cpu_merkle_digests: Vec<Word> = vec![cpu_rpo_digest; n];
    let cpu_merkle_tree = winter_crypto::MerkleTree::<Rpo256>::new(cpu_merkle_digests)
        .expect("failed to construct trace Merkle tree");

    // Index 1 of the GPU tree is the node compared against the CPU root.
    let gpu_root: Vec<Felt> = gpu_merkle_tree[1].iter().copied().collect();
    let cpu_root = cpu_merkle_tree.root().as_elements();
    assert_eq!(
        gpu_root.as_slice(),
        cpu_root,
        "GPU Merkle root differs from CPU Merkle root"
    );
}