use std::rc::Rc;
use metal::CommandBufferRef;
use once_cell::sync::Lazy;
use winter_math::fields::f64::BaseElement;
use crate::{
HashFn,
metal::{
stage::{
AbsorbColumnsStage256, AbsorbRowsStage256, GenMerkleNodesFirstRowStage256,
GenMerkleNodesRowStage256,
},
utils::{buffer_mut_no_copy, buffer_no_copy, is_page_aligned, page_aligned_uninit_vector},
},
};
/// Precompiled Metal shader library, embedded into the binary at compile time.
const LIBRARY_DATA: &[u8] = include_bytes!("shaders/shaders.metallib");
/// Number of columns absorbed per GPU batch by `ColumnHasher`, and the width of each row
/// accepted by `RowHasher::update`.
const RATE: usize = 8;
/// Lazily-initialised global planner, built via `Planner::default` (the system default Metal
/// device) on first access.
static PLANNER: Lazy<Planner> = Lazy::new(Planner::default);
pub fn get_planner() -> &'static Planner {
&PLANNER
}
/// Shared GPU context: the compiled shader library and the command queue that all hashing and
/// Merkle-tree work in this module is submitted to.
pub struct Planner {
    /// Shader library (loaded from the embedded `LIBRARY_DATA` by `Planner::new`); the GPU
    /// stages are built from it.
    pub library: metal::Library,
    /// Queue on which every command buffer in this module is created.
    pub command_queue: Rc<metal::CommandQueue>,
}
// SAFETY: NOTE(review) — no justification is visible here. `Planner` lives in a process-wide
// static, so it can be reached from several threads. `Rc` is not thread-safe: cloning
// `command_queue` concurrently would race on its reference count. These impls appear to assume
// that callers only borrow the queue (e.g. `new_command_buffer`) and never clone the `Rc`. They
// also appear to assume that the wrapped Metal objects may be used from multiple threads.
// Confirm both assumptions, or switch the field to `Arc`.
unsafe impl Send for Planner {}
unsafe impl Sync for Planner {}
impl Planner {
    /// Creates a planner bound to `device`: loads the embedded shader library and opens a fresh
    /// command queue on it.
    ///
    /// # Panics
    /// Panics if the embedded shader library cannot be loaded on `device`.
    pub fn new(device: &metal::DeviceRef) -> Self {
        Self {
            library: device.new_library_with_data(LIBRARY_DATA).unwrap(),
            command_queue: Rc::new(device.new_command_queue()),
        }
    }
}
impl Default for Planner {
    /// Builds a planner on the system's default Metal device.
    ///
    /// # Panics
    /// Panics with "no device found" when no Metal device is available, or if `Planner::new`
    /// panics.
    fn default() -> Self {
        let device = metal::Device::system_default().expect("no device found");
        Self::new(&device)
    }
}
/// Hashes a matrix column by column on the GPU.
///
/// Columns are buffered until `RATE` of them are available, then absorbed as one batch by
/// `AbsorbColumnsStage256`. The digests are taken from that stage (presumably one per row —
/// confirm against the stage).
pub struct ColumnHasher<'a> {
    /// Number of elements in each column; used to size the padding columns in `finish`.
    num_rows: usize,
    /// GPU stage that absorbs batches of columns and holds the resulting digests.
    stage: AbsorbColumnsStage256,
    /// Columns received since the last batch was dispatched (never more than `RATE`).
    state: Vec<&'a [BaseElement]>,
    /// Most recently committed command buffer; `finish` waits on it.
    command_buffer: Option<&'a CommandBufferRef>,
    /// Hash function in use; also selects which path `finish` takes for buffered columns.
    hash_fn: HashFn,
}
impl<'a> ColumnHasher<'a> {
    /// Creates a hasher for `num_columns` columns of `num_rows` elements each.
    ///
    /// The absorb stage is built from the global planner's shader library. Feed columns in with
    /// [`ColumnHasher::update`] and collect the digests with [`ColumnHasher::finish`].
    pub fn new(num_rows: usize, num_columns: usize, hash_fn: HashFn) -> Self {
        Self {
            num_rows,
            stage: AbsorbColumnsStage256::new(
                &get_planner().library,
                num_rows,
                num_columns,
                hash_fn,
            ),
            state: Vec::new(),
            command_buffer: None,
            hash_fn,
        }
    }

    /// Buffers `col` for absorption.
    ///
    /// Once `RATE` columns are buffered, they are encoded as one batch and committed on the
    /// planner's command queue. The batch is not waited on here; [`ColumnHasher::finish`] waits
    /// on the most recently committed command buffer. The `'a` borrow keeps `col` alive for the
    /// hasher's lifetime, which matters because the GPU may still be reading it after this
    /// returns.
    ///
    /// # Panics
    /// Panics if `col` is not page aligned.
    pub fn update(&mut self, col: &'a [BaseElement]) {
        assert!(is_page_aligned(col));
        self.state.push(col);
        // `state` is drained every time it fills up, so its length never exceeds `RATE`.
        if self.state.len() == RATE {
            let command_buffer = get_planner().command_queue.new_command_buffer();
            #[cfg(debug_assertions)]
            command_buffer.set_label("update columns");
            let batch = core::mem::take(&mut self.state);
            let state = &batch[..RATE];
            self.stage.encode(command_buffer, state.try_into().unwrap());
            command_buffer.commit();
            self.command_buffer = Some(command_buffer);
        }
    }

    /// Flushes any buffered columns, waits for the GPU, and returns the stage's digests.
    ///
    /// A leftover partial batch (fewer than `RATE` columns) is handled as follows, except for
    /// `HashFn::Rpx256`:
    /// - one column of ones is appended;
    /// - columns of zeros are then appended until the batch holds `RATE` columns;
    /// - the padded batch is absorbed and waited on.
    pub async fn finish(mut self) -> Vec<[BaseElement; 4]> {
        // Wait for the last batch committed by `update`, if any.
        if let Some(cb) = self.command_buffer {
            cb.wait_until_completed();
        }

        // NOTE(review): for `HashFn::Rpx256`, columns still buffered in `state` are never
        // encoded, so they never reach the GPU stage. Confirm this is intended (e.g. callers
        // always supply a multiple of `RATE` columns for RPX). Otherwise those columns are
        // silently left out of the digests.
        if self.state.is_empty() || self.hash_fn == HashFn::Rpx256 {
            return self.stage.digests;
        }

        // SAFETY: every element is overwritten by `fill` before the vector is read.
        let mut ones = unsafe { page_aligned_uninit_vector(self.num_rows) };
        ones.fill(BaseElement::from(1u32));
        self.state.push(&ones);

        // Declared outside the `if` so the padding column outlives the final GPU dispatch below.
        let mut zeros: Vec<BaseElement>;
        if self.state.len() != RATE {
            // SAFETY: every element is overwritten by `fill` before the vector is read.
            zeros = unsafe { page_aligned_uninit_vector(self.num_rows) };
            zeros.fill(BaseElement::from(0u32));
            self.state.resize(RATE, zeros.as_slice());
        }

        let command_buffer = get_planner().command_queue.new_command_buffer();
        let state = &self.state[..RATE];
        self.stage.encode(command_buffer, state.try_into().unwrap());
        command_buffer.commit();
        // `ones` and `zeros` are dropped when this function returns, so the GPU must finish
        // reading them first.
        command_buffer.wait_until_completed();
        self.stage.digests
    }
}
/// Hashes rows of `RATE` elements on the GPU via `AbsorbRowsStage256`.
pub struct RowHasher<'a> {
    /// GPU stage that absorbs rows and holds the resulting digests.
    stage: AbsorbRowsStage256,
    /// Most recently committed command buffer; `finish` waits on it and panics if it is unset.
    command_buffer: Option<&'a CommandBufferRef>,
}
impl<'a> RowHasher<'a> {
    /// Creates a hasher whose absorb stage is built from the global planner's shader library.
    pub fn new(row_size: usize, num_columns: usize, hash_fn: HashFn) -> Self {
        Self {
            stage: AbsorbRowsStage256::new(&get_planner().library, row_size, num_columns, hash_fn),
            command_buffer: None,
        }
    }

    /// Encodes `rows` into the absorb stage and commits the work on the planner's queue.
    ///
    /// The command buffer is not waited on here; [`RowHasher::finish`] waits on the most
    /// recently committed one. The `'a` borrow keeps `rows` alive for the hasher's lifetime.
    ///
    /// # Panics
    /// Panics if `rows` is not page aligned.
    pub fn update(&mut self, rows: &'a [[BaseElement; RATE]]) {
        assert!(is_page_aligned(rows));
        let planner = get_planner();
        let command_buffer = planner.command_queue.new_command_buffer();
        #[cfg(debug_assertions)]
        command_buffer.set_label("update rows");
        self.stage.encode(command_buffer, rows);
        command_buffer.commit();
        self.command_buffer = Some(command_buffer);
    }

    /// Waits for the most recently committed work and returns a copy of the stage's digests.
    ///
    /// # Panics
    /// Panics if [`RowHasher::update`] was never called, since there is no work to wait on.
    pub async fn finish(&self) -> Vec<[BaseElement; 4]> {
        let Some(cb) = self.command_buffer else {
            panic!("RowHasher::finish called before any call to RowHasher::update");
        };
        cb.wait_until_completed();
        self.stage.digests.clone()
    }
}
/// Builds the internal nodes of a Merkle tree over `leaves` on the GPU.
///
/// Returns a vector of `leaves.len()` nodes. The vector is zero-initialised, then filled as
/// follows:
/// - the first-row stage reads the leaves and writes into it;
/// - the row stage is encoded once for each row from 2 up to `log2(leaves.len())`.
///
/// All stages go into a single command buffer, which is waited on before returning. The node
/// index layout is defined by the shaders (NOTE(review): confirm against them).
///
/// # Panics
/// Panics if `leaves` is not page aligned, or if `leaves` is empty (`ilog2` of zero).
pub async fn build_merkle_tree(
    leaves: &[[BaseElement; 4]],
    hash_fn: HashFn,
) -> Vec<[BaseElement; 4]> {
    assert!(is_page_aligned(leaves));
    let planner = get_planner();
    let device = planner.library.device();
    let leaf_count = leaves.len();

    let leaves_buffer = buffer_no_copy(device, leaves);
    // SAFETY: every element is overwritten by `fill` before the vector is read.
    let mut nodes = unsafe { page_aligned_uninit_vector(leaf_count) };
    let zero = BaseElement::from(0u32);
    nodes.fill([zero; 4]);
    let nodes_buffer = buffer_mut_no_copy(device, &mut nodes);

    let first_row = GenMerkleNodesFirstRowStage256::new(&planner.library, leaf_count, hash_fn);
    let later_rows = GenMerkleNodesRowStage256::new(&planner.library, leaf_count, hash_fn);

    let command_buffer = planner.command_queue.new_command_buffer();
    #[cfg(debug_assertions)]
    command_buffer.set_label(&format!("{hash_fn} Merkle tree"));
    first_row.encode(command_buffer, &leaves_buffer, &nodes_buffer);
    let tree_height = leaf_count.ilog2();
    for level in 2..=tree_height {
        later_rows.encode(command_buffer, &nodes_buffer, level);
    }
    command_buffer.commit();
    // The stages write into `nodes` through `nodes_buffer`; wait before handing it back.
    command_buffer.wait_until_completed();
    nodes
}