use super::InputKey;
use crate::EXT_DEGREE;
/// A contiguous span of slots in a flat index space, described by where it
/// starts (`offset`) and how many slots it covers (`width`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct InputRegion {
    /// Absolute index of the first slot of the region.
    pub offset: usize,
    /// Number of slots in the region; a width of zero means the region is empty.
    pub width: usize,
}

impl InputRegion {
    /// Translates a region-local position into an absolute index.
    ///
    /// Returns `Some(offset + local)` when `local < width`.
    ///
    /// Returns `None` when `local` falls outside the region, or when
    /// `offset + local` is not representable as a `usize`.
    pub fn index(&self, local: usize) -> Option<usize> {
        if local < self.width {
            // `checked_add` turns an unrepresentable index into `None` instead of
            // panicking (debug builds) or silently wrapping (release builds).
            self.offset.checked_add(local)
        } else {
            None
        }
    }
}
/// Size parameters stored in the `counts` field of [`InputLayout`].
///
/// NOTE(review): only `aux_width`, `num_aux_boundary` and `num_quotient_chunks`
/// are read in the visible part of this file (by `InputLayout::validate`); the
/// descriptions of the other fields are inferred from their names — confirm.
#[derive(Debug, Clone, Copy)]
pub struct InputCounts {
/// Presumably the main trace width — TODO confirm against the code that builds the layout.
pub width: usize,
/// Auxiliary width in units of `EXT_DEGREE` slots: `validate` requires the
/// `aux_curr` and `aux_next` regions to each be `aux_width * EXT_DEGREE` wide.
pub aux_width: usize,
/// Expected width of the `aux_bus_boundary` region (checked by `validate`).
pub num_aux_boundary: usize,
/// Presumably the number of public values — TODO confirm.
pub num_public: usize,
/// Presumably the number of "vlpi" entries; the abbreviation is not expanded
/// in this part of the file — TODO confirm.
pub num_vlpi: usize,
/// Presumably the number of randomness inputs — TODO confirm.
pub num_randomness: usize,
/// Presumably the number of periodic values — TODO confirm.
pub num_periodic: usize,
/// Quotient chunk count: `validate` requires the `quotient_curr` and
/// `quotient_next` regions to each be `num_quotient_chunks * EXT_DEGREE` wide.
pub num_quotient_chunks: usize,
}
/// The named regions (each an offset plus a width) that make up an input layout.
///
/// `InputLayout::validate` requires every region listed here to end at or
/// before `total_inputs`. The field order below is declaration order only;
/// whether it matches the order of the offsets is not established in this
/// part of the file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct LayoutRegions {
/// Region for public values (NOTE(review): contents inferred from the name).
pub public_values: InputRegion,
/// Region for "vlpi" reductions (NOTE(review): abbreviation not expanded here).
pub vlpi_reductions: InputRegion,
/// Randomness region; `validate` requires `aux_rand_alpha` and `aux_rand_beta`
/// to be indices inside it.
pub randomness: InputRegion,
/// Main-trace region for the current row (NOTE(review): inferred from the name).
pub main_curr: InputRegion,
/// Auxiliary region, "curr" variant; width must equal `aux_width * EXT_DEGREE`.
pub aux_curr: InputRegion,
/// Quotient region, "curr" variant; width must equal
/// `num_quotient_chunks * EXT_DEGREE`.
pub quotient_curr: InputRegion,
/// Main-trace region for the next row (NOTE(review): inferred from the name).
pub main_next: InputRegion,
/// Auxiliary region, "next" variant; width must equal `aux_width * EXT_DEGREE`.
pub aux_next: InputRegion,
/// Quotient region, "next" variant; width must equal
/// `num_quotient_chunks * EXT_DEGREE`.
pub quotient_next: InputRegion,
/// Auxiliary bus boundary region; width must equal `num_aux_boundary`.
pub aux_bus_boundary: InputRegion,
/// Region that must contain every index stored in `StarkVarIndices`.
pub stark_vars: InputRegion,
}
/// Indices of the individual STARK variables.
///
/// `InputLayout::validate` compares each field directly against
/// `stark_vars.offset .. stark_vars.offset + stark_vars.width`, so these are
/// absolute indices, not positions relative to the start of the region.
///
/// NOTE(review): the meaning of each variable is not established in this part
/// of the file; the field names are the only description available here.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct StarkVarIndices {
pub alpha: usize,
pub z_pow_n: usize,
pub z_k: usize,
pub is_first: usize,
pub is_last: usize,
pub is_transition: usize,
pub gamma: usize,
pub weight0: usize,
pub f: usize,
pub s0: usize,
}
/// Describes where each group of inputs lives: the regions, a few individually
/// tracked indices, the total input count and the size parameters.
///
/// The invariants tying these fields together are asserted by `validate`.
#[derive(Debug, Clone)]
pub struct InputLayout {
/// Offset and width of every named region.
pub(crate) regions: LayoutRegions,
/// Index of the auxiliary randomness "alpha"; must lie inside `regions.randomness`.
pub(crate) aux_rand_alpha: usize,
/// Index of the auxiliary randomness "beta"; must lie inside `regions.randomness`.
pub(crate) aux_rand_beta: usize,
/// NOTE(review): not read in this part of the file; presumably the stride
/// between consecutive entries of the `vlpi_reductions` region — confirm.
pub(crate) vlpi_stride: usize,
/// Indices of the individual STARK variables; each must lie inside
/// `regions.stark_vars`.
pub(crate) stark: StarkVarIndices,
/// Total number of inputs; no region may end beyond this value.
pub total_inputs: usize,
/// Size parameters that the region widths are checked against.
pub counts: InputCounts,
}
impl InputLayout {
    /// Builds a key mapper that borrows this layout.
    pub(crate) fn mapper(&self) -> super::InputKeyMapper<'_> {
        super::InputKeyMapper { layout: self }
    }

    /// Resolves `key` to an index by delegating to the mapper's `index_of`.
    ///
    /// Returns whatever `index_of` returns for `key`.
    pub fn index(&self, key: InputKey) -> Option<usize> {
        let mapper = self.mapper();
        mapper.index_of(key)
    }

    /// Asserts the internal consistency of the layout.
    ///
    /// # Panics
    ///
    /// Panics, in this order of checking, when:
    /// - any region ends beyond `total_inputs`;
    /// - `aux_curr` or `aux_next` is not `counts.aux_width * EXT_DEGREE` wide;
    /// - `quotient_curr` or `quotient_next` is not
    ///   `counts.num_quotient_chunks * EXT_DEGREE` wide;
    /// - `aux_bus_boundary` is not `counts.num_aux_boundary` wide;
    /// - any index in `stark` lies outside the `stark_vars` region;
    /// - `aux_rand_alpha` or `aux_rand_beta` lies outside the `randomness` region.
    pub(crate) fn validate(&self) {
        let regions = &self.regions;
        let counts = &self.counts;

        // Largest end position over all regions; the addition saturates so an
        // absurdly large region is reported by the assertion below.
        let furthest_end = [
            regions.public_values,
            regions.vlpi_reductions,
            regions.randomness,
            regions.main_curr,
            regions.aux_curr,
            regions.quotient_curr,
            regions.main_next,
            regions.aux_next,
            regions.quotient_next,
            regions.aux_bus_boundary,
            regions.stark_vars,
        ]
        .iter()
        .map(|region| region.offset.saturating_add(region.width))
        .fold(0usize, |furthest, end| furthest.max(end));
        assert!(furthest_end <= self.total_inputs, "regions exceed total_inputs");

        // Both auxiliary regions share one expected width.
        let expected_aux = counts.aux_width * EXT_DEGREE;
        for (region, mismatch) in [
            (regions.aux_curr, "aux_curr width mismatch"),
            (regions.aux_next, "aux_next width mismatch"),
        ] {
            assert_eq!(region.width, expected_aux, "{mismatch}");
        }

        // Likewise for the two quotient regions.
        let expected_quotient = counts.num_quotient_chunks * EXT_DEGREE;
        for (region, mismatch) in [
            (regions.quotient_curr, "quotient_curr width mismatch"),
            (regions.quotient_next, "quotient_next width mismatch"),
        ] {
            assert_eq!(region.width, expected_quotient, "{mismatch}");
        }

        assert_eq!(
            regions.aux_bus_boundary.width,
            counts.num_aux_boundary,
            "aux bus boundary width mismatch"
        );

        // Every STARK variable index must fall inside the stark_vars region.
        let stark_span = {
            let first = regions.stark_vars.offset;
            first..first + regions.stark_vars.width
        };
        let stark = &self.stark;
        for (name, idx) in [
            ("alpha", stark.alpha),
            ("z_pow_n", stark.z_pow_n),
            ("z_k", stark.z_k),
            ("is_first", stark.is_first),
            ("is_last", stark.is_last),
            ("is_transition", stark.is_transition),
            ("gamma", stark.gamma),
            ("weight0", stark.weight0),
            ("f", stark.f),
            ("s0", stark.s0),
        ] {
            assert!(stark_span.contains(&idx), "stark var {name} out of range");
        }

        // The two auxiliary randomness indices must fall inside the randomness region.
        let rand_span = {
            let first = regions.randomness.offset;
            first..first + regions.randomness.width
        };
        assert!(
            rand_span.contains(&self.aux_rand_alpha),
            "aux_rand_alpha out of randomness region"
        );
        assert!(
            rand_span.contains(&self.aux_rand_beta),
            "aux_rand_beta out of randomness region"
        );
    }
}