use super::{InputCounts, InputLayout, InputRegion, LayoutRegions, StarkVarIndices};
use crate::{EXT_DEGREE, randomness};
/// Alignment requirement applied to a region's starting offset.
///
/// The discriminant is the multiple, in input slots, that the offset is
/// rounded up to before the region is placed.
#[derive(Copy, Clone)]
enum Alignment {
    /// No padding: any offset is accepted.
    Unaligned = 1 << 0,
    /// Offset rounded up to a multiple of 2 slots.
    Word = 1 << 1,
    /// Offset rounded up to a multiple of 4 slots.
    DoubleWord = 1 << 2,
    /// Offset rounded up to a multiple of 8 slots.
    QuadWord = 1 << 3,
}
#[derive(Clone, Copy)]
struct LayoutPolicy {
public_values: Alignment,
vlpi: Alignment,
vlpi_stride: usize,
randomness: Alignment,
main: Alignment,
aux: Alignment,
quotient: Alignment,
aux_bus_boundary: Alignment,
stark_vars: Alignment,
end_align: Option<Alignment>,
}
impl LayoutPolicy {
fn native() -> Self {
Self {
public_values: Alignment::Unaligned,
vlpi: Alignment::Unaligned,
vlpi_stride: 1,
randomness: Alignment::Unaligned,
main: Alignment::Unaligned,
aux: Alignment::Unaligned,
quotient: Alignment::Unaligned,
aux_bus_boundary: Alignment::Unaligned,
stark_vars: Alignment::Unaligned,
end_align: None,
}
}
fn masm() -> Self {
Self {
public_values: Alignment::QuadWord,
vlpi: Alignment::Word,
vlpi_stride: 2,
randomness: Alignment::Word,
main: Alignment::DoubleWord,
aux: Alignment::DoubleWord,
quotient: Alignment::DoubleWord,
aux_bus_boundary: Alignment::Word,
stark_vars: Alignment::Word,
end_align: Some(Alignment::Word),
}
}
}
/// Bump-style allocator that hands out consecutive, aligned input regions.
struct LayoutBuilder {
    /// First unallocated input slot; only ever grows.
    offset: usize,
}

impl LayoutBuilder {
    /// Starts allocating at slot 0.
    fn new() -> Self {
        LayoutBuilder { offset: 0 }
    }

    /// Rounds the cursor up to the next multiple of `alignment` (no-op if already aligned).
    fn align(&mut self, alignment: Alignment) {
        let step = alignment as usize;
        self.offset = self.offset.next_multiple_of(step);
    }

    /// Aligns the cursor, then reserves `width` slots and returns their region.
    fn alloc(&mut self, width: usize, alignment: Alignment) -> InputRegion {
        self.align(alignment);
        let start = self.offset;
        self.offset = start + width;
        InputRegion { offset: start, width }
    }
}
impl InputLayout {
    /// Builds the layout with the native policy: every region is packed back to
    /// back with no padding, and VLPI groups are one slot apart.
    pub(crate) fn new(counts: InputCounts) -> Self {
        Self::build_with_policy(counts, LayoutPolicy::native())
    }

    /// Builds the layout with the MASM policy.
    ///
    /// That policy pads region starts, and the total input count, to its
    /// alignments, and spaces VLPI groups two slots apart.
    pub(crate) fn new_masm(counts: InputCounts) -> Self {
        Self::build_with_policy(counts, LayoutPolicy::masm())
    }

    /// Allocates every input region in a fixed order using `policy`'s
    /// alignments and records the resulting offsets.
    ///
    /// Allocation order:
    /// 1. public values
    /// 2. VLPI reductions
    /// 3. randomness
    /// 4. current-row main/aux/quotient openings
    /// 5. next-row main/aux/quotient openings
    /// 6. aux bus boundary values
    /// 7. STARK variables
    ///
    /// Aux and quotient regions reserve `EXT_DEGREE` slots per aux column or
    /// quotient chunk.
    fn build_with_policy(counts: InputCounts, policy: LayoutPolicy) -> Self {
        // Fixed width of the randomness region; `randomness::aux_rand_indices`
        // derives the aux alpha/beta slots from it.
        // NOTE(review): `counts.num_randomness` is not consulted here — confirm
        // callers never expect a region of a different width.
        const NUM_RANDOMNESS_INPUTS: usize = 2;
        // Must equal the number of names destructured from `stark_slots` below;
        // a mismatch is rejected at compile time by the array pattern.
        const NUM_STARK_VARS: usize = 10;

        let mut builder = LayoutBuilder::new();

        let public_values = builder.alloc(counts.num_public, policy.public_values);
        let vlpi_reductions = builder.alloc(counts.num_vlpi, policy.vlpi);
        let randomness = builder.alloc(NUM_RANDOMNESS_INPUTS, policy.randomness);
        let (aux_rand_alpha, aux_rand_beta) = randomness::aux_rand_indices(randomness);

        // Current-row openings followed by next-row openings, same region order.
        let aux_coord_width = counts.aux_width * EXT_DEGREE;
        let quotient_coord_width = counts.num_quotient_chunks * EXT_DEGREE;
        let main_curr = builder.alloc(counts.width, policy.main);
        let aux_curr = builder.alloc(aux_coord_width, policy.aux);
        let quotient_curr = builder.alloc(quotient_coord_width, policy.quotient);
        let main_next = builder.alloc(counts.width, policy.main);
        let aux_next = builder.alloc(aux_coord_width, policy.aux);
        let quotient_next = builder.alloc(quotient_coord_width, policy.quotient);

        let aux_bus_boundary = builder.alloc(counts.num_aux_boundary, policy.aux_bus_boundary);

        // The STARK variables occupy consecutive slots of their region, in this order.
        let stark_vars = builder.alloc(NUM_STARK_VARS, policy.stark_vars);
        let stark_slots: [usize; NUM_STARK_VARS] = std::array::from_fn(|i| stark_vars.offset + i);
        let [alpha, z_pow_n, z_k, is_first, is_last, is_transition, gamma, weight0, f, s0] =
            stark_slots;

        // Optional trailing padding; it only affects `total_inputs`.
        if let Some(end_align) = policy.end_align {
            builder.align(end_align);
        }

        Self {
            regions: LayoutRegions {
                public_values,
                vlpi_reductions,
                randomness,
                main_curr,
                aux_curr,
                quotient_curr,
                main_next,
                aux_next,
                quotient_next,
                aux_bus_boundary,
                stark_vars,
            },
            aux_rand_alpha,
            aux_rand_beta,
            vlpi_stride: policy.vlpi_stride,
            stark: StarkVarIndices {
                alpha,
                z_pow_n,
                z_k,
                is_first,
                is_last,
                is_transition,
                gamma,
                weight0,
                f,
                s0,
            },
            total_inputs: builder.offset,
            counts,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::super::{InputCounts, InputKey, InputLayout};
    use crate::EXT_DEGREE;

    /// Small fixed counts shared by the layout tests; only `num_vlpi` varies.
    fn sample_counts(num_vlpi: usize) -> InputCounts {
        InputCounts {
            width: 1,
            aux_width: 1,
            num_aux_boundary: 1,
            num_public: 8,
            num_vlpi,
            num_randomness: 2,
            num_periodic: 0,
            num_quotient_chunks: 1,
        }
    }

    #[test]
    fn masm_layout_vlpi_groups_use_word_stride() {
        let layout = InputLayout::new_masm(sample_counts(4));
        let vlpi_base = layout.index(InputKey::VlpiReduction(0)).unwrap();
        // The VLPI region is allocated with word (2-slot) alignment under the MASM policy.
        assert_eq!(vlpi_base % 2, 0, "MASM VLPI region should start on a word boundary");
        assert_eq!(
            layout.index(InputKey::VlpiReduction(1)),
            Some(vlpi_base + 2),
            "MASM VLPI groups should advance by a word-aligned stride"
        );
    }

    #[test]
    fn masm_layout_total_inputs_is_word_aligned() {
        let layout = InputLayout::new_masm(sample_counts(4));
        assert_eq!(layout.total_inputs % 2, 0, "MASM total input count should be word-padded");
    }

    #[test]
    fn native_layout_vlpi_groups_use_unit_stride() {
        let layout = InputLayout::new(sample_counts(2));
        let vlpi_base = layout.index(InputKey::VlpiReduction(0)).unwrap();
        assert_eq!(
            layout.index(InputKey::VlpiReduction(1)),
            Some(vlpi_base + 1),
            "Native VLPI groups should advance by unit stride"
        );
    }

    #[test]
    fn native_layout_is_densely_packed() {
        let layout = InputLayout::new(sample_counts(2));
        // The native total is the sum of these region widths:
        // - public values: 8
        // - VLPI reductions: 2
        // - randomness: 2
        // - per row (current and next): main 1 + aux EXT_DEGREE + quotient EXT_DEGREE
        // - aux bus boundary: 1
        // - STARK variables: 10
        let expected = 8 + 2 + 2 + 2 * (1 + EXT_DEGREE + EXT_DEGREE) + 1 + 10;
        assert_eq!(layout.total_inputs, expected, "native layout should have no padding");
    }
}