miden_ace_codegen/layout/
policy.rs1use super::{InputCounts, InputLayout, InputRegion, LayoutRegions, StarkVarIndices};
2use crate::{EXT_DEGREE, randomness};
3
/// Alignment constraint for the start of an input region, measured in input
/// slots: the layout cursor is rounded up to a multiple of the discriminant
/// before a region is placed.
///
/// NOTE(review): the `Word`-family names suggest that one input slot is half of
/// a MASM word (e.g. one quadratic-extension element) — confirm against the
/// MASM memory layout before relying on that interpretation.
#[derive(Clone, Copy)]
enum Alignment {
    /// No padding; a region may start at any slot.
    Unaligned = 1 << 0,
    /// Start on a multiple of 2 slots.
    Word = 1 << 1,
    /// Start on a multiple of 4 slots.
    DoubleWord = 1 << 2,
    /// Start on a multiple of 8 slots.
    QuadWord = 1 << 3,
}
11
12#[derive(Clone, Copy)]
13struct LayoutPolicy {
14 public_values: Alignment,
15 vlpi: Alignment,
16 vlpi_stride: usize,
17 randomness: Alignment,
18 main: Alignment,
19 aux: Alignment,
20 quotient: Alignment,
21 aux_bus_boundary: Alignment,
22 stark_vars: Alignment,
23 end_align: Option<Alignment>,
24}
25
26impl LayoutPolicy {
27 fn native() -> Self {
28 Self {
29 public_values: Alignment::Unaligned,
30 vlpi: Alignment::Unaligned,
31 vlpi_stride: 1,
32 randomness: Alignment::Unaligned,
33 main: Alignment::Unaligned,
34 aux: Alignment::Unaligned,
35 quotient: Alignment::Unaligned,
36 aux_bus_boundary: Alignment::Unaligned,
37 stark_vars: Alignment::Unaligned,
38 end_align: None,
39 }
40 }
41
42 fn masm() -> Self {
43 Self {
44 public_values: Alignment::QuadWord,
45 vlpi: Alignment::Word,
46 vlpi_stride: 2,
47 randomness: Alignment::Word,
48 main: Alignment::DoubleWord,
49 aux: Alignment::DoubleWord,
50 quotient: Alignment::DoubleWord,
51 aux_bus_boundary: Alignment::Word,
52 stark_vars: Alignment::Word,
53 end_align: Some(Alignment::Word),
54 }
55 }
56}
57
58struct LayoutBuilder {
59 offset: usize,
60}
61
62impl LayoutBuilder {
63 fn new() -> Self {
64 Self { offset: 0 }
65 }
66
67 fn align(&mut self, alignment: Alignment) {
68 self.offset = self.offset.next_multiple_of(alignment as usize);
69 }
70
71 fn alloc(&mut self, width: usize, alignment: Alignment) -> InputRegion {
72 self.align(alignment);
73 let region = InputRegion { offset: self.offset, width };
74 self.offset += width;
75 region
76 }
77}
78
79impl InputLayout {
80 pub fn new(counts: InputCounts) -> Self {
82 Self::build_with_policy(counts, LayoutPolicy::native(), false)
83 }
84
85 pub fn new_masm(counts: InputCounts) -> Self {
87 Self::build_with_policy(counts, LayoutPolicy::masm(), false)
88 }
89
90 pub fn new_multi_air(counts: InputCounts) -> Self {
92 Self::build_with_policy(counts, LayoutPolicy::native(), true)
93 }
94
95 pub fn new_masm_multi_air(counts: InputCounts) -> Self {
98 Self::build_with_policy(counts, LayoutPolicy::masm(), true)
99 }
100
101 fn build_with_policy(counts: InputCounts, policy: LayoutPolicy, is_multi_air: bool) -> Self {
102 const NUM_STARK_VARS_BASE: usize = 10;
107 let num_stark_vars = NUM_STARK_VARS_BASE + if is_multi_air { 8 } else { 0 };
108
109 let mut builder = LayoutBuilder::new();
110
111 let public_values = builder.alloc(counts.num_public, policy.public_values);
112 let vlpi_reductions = builder.alloc(counts.num_vlpi, policy.vlpi);
113 const NUM_RANDOMNESS_INPUTS: usize = 2;
115 let randomness = builder.alloc(NUM_RANDOMNESS_INPUTS, policy.randomness);
116 let (aux_rand_alpha, aux_rand_beta) = randomness::aux_rand_indices(randomness);
117 let main_curr = builder.alloc(counts.width, policy.main);
118 let aux_coord_width = counts.aux_width * EXT_DEGREE;
119 let aux_curr = builder.alloc(aux_coord_width, policy.aux);
120 let quotient_curr = builder.alloc(counts.num_quotient_chunks * EXT_DEGREE, policy.quotient);
121 let main_next = builder.alloc(counts.width, policy.main);
122 let aux_next = builder.alloc(aux_coord_width, policy.aux);
123 let quotient_next = builder.alloc(counts.num_quotient_chunks * EXT_DEGREE, policy.quotient);
124 let aux_bus_boundary = builder.alloc(counts.num_aux_boundary, policy.aux_bus_boundary);
125
126 let stark_vars = builder.alloc(num_stark_vars, policy.stark_vars);
127
128 let b = stark_vars.offset;
146 let alpha = b;
147 let z_pow_n = b + 1;
148 let z_k = b + 2;
149 let is_first = b + 3;
150 let is_last = b + 4;
151 let is_transition = b + 5;
152 let gamma = b + 6;
153 let weight0 = b + 7;
154 let f = b + 8;
155 let s0 = b + 9;
156 let multi_air_beta_core = is_multi_air.then_some(b + 10);
157 let multi_air_beta_chip = is_multi_air.then_some(b + 11);
158 let is_first_core = is_multi_air.then_some(b + 12);
159 let is_last_core = is_multi_air.then_some(b + 13);
160 let is_transition_core = is_multi_air.then_some(b + 14);
161 let is_first_chip = is_multi_air.then_some(b + 15);
162 let is_last_chip = is_multi_air.then_some(b + 16);
163 let is_transition_chip = is_multi_air.then_some(b + 17);
164
165 if let Some(end_align) = policy.end_align {
166 builder.align(end_align);
167 }
168
169 Self {
170 regions: LayoutRegions {
171 public_values,
172 vlpi_reductions,
173 randomness,
174 main_curr,
175 aux_curr,
176 quotient_curr,
177 main_next,
178 aux_next,
179 quotient_next,
180 aux_bus_boundary,
181 stark_vars,
182 },
183 aux_rand_alpha,
184 aux_rand_beta,
185 vlpi_stride: policy.vlpi_stride,
186 stark: StarkVarIndices {
187 alpha,
188 z_pow_n,
189 z_k,
190 is_first,
191 is_last,
192 is_transition,
193 gamma,
194 weight0,
195 f,
196 s0,
197 multi_air_beta_core,
198 multi_air_beta_chip,
199 is_first_core,
200 is_last_core,
201 is_transition_core,
202 is_first_chip,
203 is_last_chip,
204 is_transition_chip,
205 },
206 total_inputs: builder.offset,
207 counts,
208 }
209 }
210}
211
#[cfg(test)]
mod tests {
    use super::super::{InputCounts, InputKey, InputLayout};

    /// Under the MASM policy the VLPI region is word-aligned, and consecutive
    /// reductions are spaced by the policy's stride of 2 slots.
    #[test]
    fn masm_layout_vlpi_groups_use_word_stride() {
        // NOTE(review): num_vlpi is 4 here but 2 in the native test; presumably
        // it accounts for the stride-2 spacing — confirm against `InputLayout::index`.
        let counts = InputCounts {
            width: 1,
            aux_width: 1,
            num_aux_boundary: 1,
            num_public: 8,
            num_vlpi: 4,
            num_randomness: 2,
            num_periodic: 0,
            num_quotient_chunks: 1,
        };
        let layout = InputLayout::new_masm(counts);

        let vlpi_base = layout.index(InputKey::VlpiReduction(0)).unwrap();
        // The previous self-comparison of `VlpiReduction(0)` against its own
        // result could never fail. Check the region's word alignment instead.
        assert_eq!(vlpi_base % 2, 0, "MASM VLPI region should start on a word boundary");
        assert_eq!(
            layout.index(InputKey::VlpiReduction(1)),
            Some(vlpi_base + 2),
            "MASM VLPI groups should advance by a word-aligned stride"
        );
    }

    /// Under the native policy consecutive VLPI reductions are adjacent.
    #[test]
    fn native_layout_vlpi_groups_use_unit_stride() {
        let counts = InputCounts {
            width: 1,
            aux_width: 1,
            num_aux_boundary: 1,
            num_public: 8,
            num_vlpi: 2,
            num_randomness: 2,
            num_periodic: 0,
            num_quotient_chunks: 1,
        };
        let layout = InputLayout::new(counts);

        let vlpi_base = layout.index(InputKey::VlpiReduction(0)).unwrap();
        assert_eq!(
            layout.index(InputKey::VlpiReduction(1)),
            Some(vlpi_base + 1),
            "Native VLPI groups should advance by unit stride"
        );
    }
}