hekate_program/lib.rs
1// SPDX-License-Identifier: Apache-2.0
2// This file is part of the hekate project.
3// Copyright (C) 2026 Andrei Kochergin <andrei@oumuamua.dev>
4// Copyright (C) 2026 Oumuamua Labs <info@oumuamua.dev>. All rights reserved.
5//
6// Licensed under the Apache License, Version 2.0 (the "License");
7// you may not use this file except in compliance with the License.
8// You may obtain a copy of the License at
9//
10// http://www.apache.org/licenses/LICENSE-2.0
11//
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
17
18#![cfg_attr(not(feature = "std"), no_std)]
19
20extern crate alloc;
21extern crate core;
22
23use alloc::string::{String, ToString};
24use alloc::vec::Vec;
25use constraint::{BoundaryConstraint, Constraint, ConstraintAst};
26use core::marker::PhantomData;
27use expander::VirtualExpander;
28use hekate_core::errors;
29use hekate_core::trace::{ColumnTrace, ColumnType, Trace, TraceCompatibleField};
30use hekate_math::{Flat, HardwareField, TowerField};
31use permutation::PermutationCheckSpec;
32
33pub mod chiplet;
34pub mod constraint;
35pub mod expander;
36pub mod permutation;
37pub mod schema;
38
39// =================================================================
40// AIR TRAIT:
41// Core Algebraic Intermediate Representation
42// =================================================================
43
44/// Defines the algebraic structure, trace
45/// layout, and constraints of an AIR table.
46///
47/// Implemented by both standalone
48/// programs and independent chiplets.
49pub trait Air<F: TowerField>: Sized + Clone + Sync {
50 fn name(&self) -> String {
51 "HekateAir".to_string()
52 }
53
54 fn num_columns(&self) -> usize {
55 self.virtual_column_layout().len()
56 }
57
58 /// Flat expansion of `constraint_ast()`.
59 fn constraints(&self) -> Vec<Constraint<F>> {
60 self.constraint_ast().to_constraints()
61 }
62
63 /// Returns the list of boundary constraints. Each
64 /// constraint ties a specific trace cell to a public
65 /// input value. By default, returns an empty list.
66 fn boundary_constraints(&self) -> Vec<BoundaryConstraint<F>> {
67 Vec::new()
68 }
69
70 /// Returns the physical layout
71 /// of the columns in the trace.
72 ///
73 /// This describes the storage type
74 /// (Bit, B8, B32, etc.) of each column.
75 fn column_layout(&self) -> &[ColumnType];
76
77 /// Returns the virtual layout of the columns
78 /// (after unpacking). Defaults to the expander's
79 /// layout if present, else the physical layout.
80 fn virtual_column_layout(&self) -> &[ColumnType] {
81 match self.virtual_expander() {
82 Some(e) => e.virtual_layout(),
83 None => self.column_layout(),
84 }
85 }
86
87 /// Returns the permutation check
88 /// specifications for this AIR table.
89 ///
90 /// Each tuple contains:
91 /// - `String`:
92 /// Unique bus identifier (e.g., `RomChiplet::BUS_ID`)
93 /// - `PermutationCheckSpec`:
94 /// The GPA specification (sources, selector)
95 ///
96 /// Default:
97 /// No permutation checks.
98 fn permutation_checks(&self) -> Vec<(String, PermutationCheckSpec)> {
99 Vec::new()
100 }
101
102 /// Columns whose MLE evaluation must equal a fixed
103 /// Lagrange kernel at the verifier's r_final.
104 fn lagrange_pinned_columns(&self) -> Vec<LagrangePin> {
105 Vec::new()
106 }
107
108 /// Returns the `VirtualExpander` for chiplets
109 /// with physical to virtual column expansion.
110 fn virtual_expander(&self) -> Option<&VirtualExpander> {
111 None
112 }
113
114 /// Parses a raw physical row (bytes) into
115 /// the full Virtual Row (fields). Used by
116 /// the Verifier to reconstruct the virtual
117 /// trace from committed data.
118 ///
119 /// Delegates to `virtual_expander().parse_row()`
120 /// when present. Falls back to 1:1 parsing
121 /// from `column_layout()`.
122 fn parse_virtual_row(&self, bytes: &[u8], res: &mut Vec<Flat<F>>)
123 where
124 F: TraceCompatibleField,
125 {
126 res.clear();
127
128 if let Some(e) = self.virtual_expander() {
129 e.parse_row(bytes, res)
130 .expect("committed row byte length must match physical_row_bytes");
131 return;
132 }
133
134 let mut offset = 0;
135 for col_type in self.column_layout() {
136 let size = col_type.byte_size();
137 if offset + size <= bytes.len() {
138 res.push(col_type.parse_from_bytes(&bytes[offset..offset + size]));
139 offset += size;
140 }
141 }
142 }
143
144 /// Returns the constraint system as an AST-DAG.
145 fn constraint_ast(&self) -> ConstraintAst<F>;
146
147 /// Chiplet defs used only for kernel dispatch.
148 fn inline_chiplets(&self) -> errors::Result<Vec<chiplet::ChipletDef<F>>> {
149 Ok(Vec::new())
150 }
151
152 /// Each hint's `chiplet_idx` indexes into `inline_chiplets()`.
153 fn inline_chiplet_kernels(&self) -> Vec<InlineKernelHint> {
154 Vec::new()
155 }
156}
157
158// =================================================================
159// PROGRAM TRAIT — Composition over Air
160// =================================================================
161
162/// Extends `Air<F>` with multi-table composition:
163/// independent chiplets, GKR gadgets, and public inputs.
164///
165/// The top-level prover and verifier require `Program<F>`.
166/// Internal sub-protocols (ZeroCheck, chiplet verification)
167/// operate on `Air<F>` alone.
168pub trait Program<F: TowerField>: Air<F> {
169 /// Number of public inputs for this program.
170 fn num_public_inputs(&self) -> usize {
171 0
172 }
173
174 /// Returns independent AIR chiplet definitions.
175 /// Each chiplet gets its own trace, commitment,
176 /// ZeroCheck, and evaluation argument.
177 /// Connected to the main trace via GPA bus.
178 fn chiplet_defs(&self) -> errors::Result<Vec<chiplet::ChipletDef<F>>> {
179 Ok(Vec::new())
180 }
181}
182
/// A reference to one cell of the program's
/// execution trace, identified by a column and
/// a relative row offset (current or next).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ProgramCell {
    /// Index of the referenced column.
    pub col_idx: usize,

    /// false = current row (i),
    /// true = next row (i+1)
    pub next_row: bool,
}

impl ProgramCell {
    /// Cell of column `col_idx` in the current row.
    pub fn current(col_idx: usize) -> Self {
        Self::at(col_idx, false)
    }

    /// Cell of column `col_idx` in the next row.
    pub fn next(col_idx: usize) -> Self {
        Self::at(col_idx, true)
    }

    /// Shared constructor behind `current` and `next`.
    fn at(col_idx: usize, next_row: bool) -> Self {
        ProgramCell { next_row, col_idx }
    }
}
212
213// =================================================================
214// INSTANCE & WITNESS
215// =================================================================
216
217/// Public Instance (Common inputs)
218/// of the program execution.
219#[derive(Clone, Debug)]
220pub struct ProgramInstance<F: TowerField> {
221 num_rows: usize,
222 public_inputs: Vec<F>,
223}
224
225impl<F: TowerField> ProgramInstance<F> {
226 pub fn new(num_rows: usize, public_inputs: Vec<F>) -> Self {
227 assert!(
228 num_rows.is_power_of_two(),
229 "Program trace height must be power of 2"
230 );
231
232 Self {
233 num_rows,
234 public_inputs,
235 }
236 }
237
238 #[inline(always)]
239 pub fn num_rows(&self) -> usize {
240 self.num_rows
241 }
242
243 /// Public inputs in canonical basis.
244 #[inline(always)]
245 pub fn public_inputs(&self) -> &[F] {
246 &self.public_inputs
247 }
248
249 #[inline(always)]
250 pub fn public_input(&self, idx: usize) -> Option<F> {
251 self.public_inputs.get(idx).copied()
252 }
253}
254
255/// Secret Witness (The Execution Trace) of the program.
256/// Holds the trace data. Generic over T to support both
257/// raw ColumnTrace and specialized wrappers.
258pub struct ProgramWitness<F: TowerField, T: Trace = ColumnTrace> {
259 pub trace: T,
260 pub chiplet_traces: Vec<ColumnTrace>,
261 _marker: PhantomData<F>,
262}
263
264impl<F: TowerField, T: Trace> ProgramWitness<F, T> {
265 pub fn new(trace: T) -> Self {
266 Self {
267 trace,
268 chiplet_traces: Vec::new(),
269 _marker: PhantomData,
270 }
271 }
272
273 /// Attach independent chiplet traces.
274 /// Each entry corresponds by index to `chiplet_defs()`.
275 pub fn with_chiplets(mut self, chiplet_traces: Vec<ColumnTrace>) -> Self {
276 self.chiplet_traces = chiplet_traces;
277 self
278 }
279}
280
/// Tells the prover where a chiplet's inlined
/// sub-AST sits inside the program's merged
/// `constraint_ast()`, so that the chiplet's
/// kernel can be dispatched.
#[derive(Clone, Copy, Debug)]
pub struct InlineKernelHint {
    /// Position of the chiplet in `Air::inline_chiplets()`.
    pub chiplet_idx: usize,

    /// Absolute index, within the program's
    /// `roots`, of the chiplet's first root.
    pub root_offset: usize,

    /// Absolute index of the first
    /// column belonging to the chiplet.
    pub column_offset: usize,
}
297
298// =================================================================
299// LAGRANGE-PINNED COLUMNS
300// =================================================================
301
302/// Hypercube point at which a Lagrange MLE is anchored.
303///
304/// `Custom(bits)` carries the bit-decomposition of an arbitrary
305/// row index, LSB first, length must equal the trace's `num_vars`.
306#[derive(Clone, Debug, PartialEq, Eq)]
307pub enum LagrangePoint {
308 LastRow,
309 FirstRow,
310 Custom(Vec<bool>),
311}
312
313impl LagrangePoint {
314 /// MLE evaluation of the pinned column
315 /// at point `r` (LSB-first bit order).
316 /// `r.len()` must equal `num_vars`.
317 pub fn evaluate<F>(&self, r: &[Flat<F>]) -> Flat<F>
318 where
319 F: HardwareField,
320 {
321 let one = Flat::from_raw(F::ONE);
322 match self {
323 LagrangePoint::LastRow => {
324 let mut prod = one;
325 for &r_k in r {
326 prod *= r_k;
327 }
328
329 one - prod
330 }
331 LagrangePoint::FirstRow => {
332 let mut prod = one;
333 for &r_k in r {
334 prod *= one - r_k;
335 }
336
337 prod
338 }
339 LagrangePoint::Custom(bits) => {
340 debug_assert_eq!(bits.len(), r.len(), "Custom point bit width != r.len()");
341
342 let mut prod = one;
343 for (k, &b) in bits.iter().enumerate() {
344 let factor = if b { r[k] } else { one - r[k] };
345 prod *= factor;
346 }
347
348 prod
349 }
350 }
351 }
352}
353
354/// Single binding:
355/// one virtual column pinned to a Lagrange MLE.
356#[derive(Clone, Debug, PartialEq, Eq)]
357pub struct LagrangePin {
358 pub col_idx: usize,
359 pub point: LagrangePoint,
360}
361
362impl LagrangePin {
363 pub fn last_row(col_idx: usize) -> Self {
364 Self {
365 col_idx,
366 point: LagrangePoint::LastRow,
367 }
368 }
369
370 pub fn first_row(col_idx: usize) -> Self {
371 Self {
372 col_idx,
373 point: LagrangePoint::FirstRow,
374 }
375 }
376
377 pub fn custom(col_idx: usize, bits: Vec<bool>) -> Self {
378 Self {
379 col_idx,
380 point: LagrangePoint::Custom(bits),
381 }
382 }
383}
384
385/// Rejects out-of-range `col_idx`, mis-sized `Custom`
386/// bit vectors, and duplicate pins on the same column
387/// (a column anchored to two distinct points is unsatisfiable).
388pub fn validate_lagrange_pins(
389 pins: &[LagrangePin],
390 num_columns: usize,
391 num_vars: Option<usize>,
392) -> errors::Result<()> {
393 for (i, pin) in pins.iter().enumerate() {
394 if pin.col_idx >= num_columns {
395 return Err(errors::Error::Protocol {
396 protocol: "lagrange_pin",
397 message: "col_idx out of range",
398 });
399 }
400
401 if let (LagrangePoint::Custom(bits), Some(nv)) = (&pin.point, num_vars)
402 && bits.len() != nv
403 {
404 return Err(errors::Error::Protocol {
405 protocol: "lagrange_pin",
406 message: "Custom point bit width != num_vars",
407 });
408 }
409
410 for prior in &pins[..i] {
411 if prior.col_idx == pin.col_idx {
412 return Err(errors::Error::Protocol {
413 protocol: "lagrange_pin",
414 message: "duplicate pin on same column",
415 });
416 }
417 }
418 }
419
420 Ok(())
421}