Skip to main content

hekate_program/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// This file is part of the hekate project.
3// Copyright (C) 2026 Andrei Kochergin <andrei@oumuamua.dev>
4// Copyright (C) 2026 Oumuamua Labs <info@oumuamua.dev>. All rights reserved.
5//
6// Licensed under the Apache License, Version 2.0 (the "License");
7// you may not use this file except in compliance with the License.
8// You may obtain a copy of the License at
9//
10//     http://www.apache.org/licenses/LICENSE-2.0
11//
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
17
18#![cfg_attr(not(feature = "std"), no_std)]
19
20extern crate alloc;
21extern crate core;
22
23use alloc::string::{String, ToString};
24use alloc::vec::Vec;
25use constraint::{BoundaryConstraint, Constraint, ConstraintAst};
26use core::marker::PhantomData;
27use expander::VirtualExpander;
28use hekate_core::errors;
29use hekate_core::trace::{ColumnTrace, ColumnType, Trace, TraceCompatibleField};
30use hekate_math::{Flat, HardwareField, TowerField};
31use permutation::PermutationCheckSpec;
32
33pub mod chiplet;
34pub mod constraint;
35pub mod expander;
36pub mod permutation;
37pub mod schema;
38
39// =================================================================
40// AIR TRAIT:
41// Core Algebraic Intermediate Representation
42// =================================================================
43
44/// Defines the algebraic structure, trace
45/// layout, and constraints of an AIR table.
46///
47/// Implemented by both standalone
48/// programs and independent chiplets.
49pub trait Air<F: TowerField>: Sized + Clone + Sync {
50    fn name(&self) -> String {
51        "HekateAir".to_string()
52    }
53
54    fn num_columns(&self) -> usize {
55        self.virtual_column_layout().len()
56    }
57
58    /// Flat expansion of `constraint_ast()`.
59    fn constraints(&self) -> Vec<Constraint<F>> {
60        self.constraint_ast().to_constraints()
61    }
62
63    /// Returns the list of boundary constraints. Each
64    /// constraint ties a specific trace cell to a public
65    /// input value. By default, returns an empty list.
66    fn boundary_constraints(&self) -> Vec<BoundaryConstraint<F>> {
67        Vec::new()
68    }
69
70    /// Returns the physical layout
71    /// of the columns in the trace.
72    ///
73    /// This describes the storage type
74    /// (Bit, B8, B32, etc.) of each column.
75    fn column_layout(&self) -> &[ColumnType];
76
77    /// Returns the virtual layout of the columns
78    /// (after unpacking). Defaults to the expander's
79    /// layout if present, else the physical layout.
80    fn virtual_column_layout(&self) -> &[ColumnType] {
81        match self.virtual_expander() {
82            Some(e) => e.virtual_layout(),
83            None => self.column_layout(),
84        }
85    }
86
87    /// Returns the permutation check
88    /// specifications for this AIR table.
89    ///
90    /// Each tuple contains:
91    /// - `String`:
92    ///   Unique bus identifier (e.g., `RomChiplet::BUS_ID`)
93    /// - `PermutationCheckSpec`:
94    ///   The GPA specification (sources, selector)
95    ///
96    /// Default:
97    /// No permutation checks.
98    fn permutation_checks(&self) -> Vec<(String, PermutationCheckSpec)> {
99        Vec::new()
100    }
101
102    /// Columns whose MLE evaluation must equal a fixed
103    /// Lagrange kernel at the verifier's r_final.
104    fn lagrange_pinned_columns(&self) -> Vec<LagrangePin> {
105        Vec::new()
106    }
107
108    /// Returns the `VirtualExpander` for chiplets
109    /// with physical to virtual column expansion.
110    fn virtual_expander(&self) -> Option<&VirtualExpander> {
111        None
112    }
113
114    /// Parses a raw physical row (bytes) into
115    /// the full Virtual Row (fields). Used by
116    /// the Verifier to reconstruct the virtual
117    /// trace from committed data.
118    ///
119    /// Delegates to `virtual_expander().parse_row()`
120    /// when present. Falls back to 1:1 parsing
121    /// from `column_layout()`.
122    fn parse_virtual_row(&self, bytes: &[u8], res: &mut Vec<Flat<F>>)
123    where
124        F: TraceCompatibleField,
125    {
126        res.clear();
127
128        if let Some(e) = self.virtual_expander() {
129            e.parse_row(bytes, res)
130                .expect("committed row byte length must match physical_row_bytes");
131            return;
132        }
133
134        let mut offset = 0;
135        for col_type in self.column_layout() {
136            let size = col_type.byte_size();
137            if offset + size <= bytes.len() {
138                res.push(col_type.parse_from_bytes(&bytes[offset..offset + size]));
139                offset += size;
140            }
141        }
142    }
143
144    /// Returns the constraint system as an AST-DAG.
145    fn constraint_ast(&self) -> ConstraintAst<F>;
146
147    /// Chiplet defs used only for kernel dispatch.
148    fn inline_chiplets(&self) -> errors::Result<Vec<chiplet::ChipletDef<F>>> {
149        Ok(Vec::new())
150    }
151
152    /// Each hint's `chiplet_idx` indexes into `inline_chiplets()`.
153    fn inline_chiplet_kernels(&self) -> Vec<InlineKernelHint> {
154        Vec::new()
155    }
156}
157
158// =================================================================
159// PROGRAM TRAIT — Composition over Air
160// =================================================================
161
162/// Extends `Air<F>` with multi-table composition:
163/// independent chiplets, GKR gadgets, and public inputs.
164///
165/// The top-level prover and verifier require `Program<F>`.
166/// Internal sub-protocols (ZeroCheck, chiplet verification)
167/// operate on `Air<F>` alone.
168pub trait Program<F: TowerField>: Air<F> {
169    /// Number of public inputs for this program.
170    fn num_public_inputs(&self) -> usize {
171        0
172    }
173
174    /// Returns independent AIR chiplet definitions.
175    /// Each chiplet gets its own trace, commitment,
176    /// ZeroCheck, and evaluation argument.
177    /// Connected to the main trace via GPA bus.
178    fn chiplet_defs(&self) -> errors::Result<Vec<chiplet::ChipletDef<F>>> {
179        Ok(Vec::new())
180    }
181}
182
/// Represents a reference to a trace cell within
/// the program's execution trace. Points to a specific
/// column and relative row offset (current or next).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ProgramCell {
    pub col_idx: usize,

    /// false = current row (i),
    /// true = next row (i+1)
    pub next_row: bool,
}

impl ProgramCell {
    /// Reference to a cell in the current row.
    ///
    /// `const` so cells can be declared in `const`/`static` items.
    pub const fn current(col_idx: usize) -> Self {
        Self {
            col_idx,
            next_row: false,
        }
    }

    /// Reference to a cell in the next row.
    ///
    /// `const` so cells can be declared in `const`/`static` items.
    pub const fn next(col_idx: usize) -> Self {
        Self {
            col_idx,
            next_row: true,
        }
    }
}
212
213// =================================================================
214// INSTANCE & WITNESS
215// =================================================================
216
217/// Public Instance (Common inputs)
218/// of the program execution.
219#[derive(Clone, Debug)]
220pub struct ProgramInstance<F: TowerField> {
221    num_rows: usize,
222    public_inputs: Vec<F>,
223}
224
225impl<F: TowerField> ProgramInstance<F> {
226    pub fn new(num_rows: usize, public_inputs: Vec<F>) -> Self {
227        assert!(
228            num_rows.is_power_of_two(),
229            "Program trace height must be power of 2"
230        );
231
232        Self {
233            num_rows,
234            public_inputs,
235        }
236    }
237
238    #[inline(always)]
239    pub fn num_rows(&self) -> usize {
240        self.num_rows
241    }
242
243    /// Public inputs in canonical basis.
244    #[inline(always)]
245    pub fn public_inputs(&self) -> &[F] {
246        &self.public_inputs
247    }
248
249    #[inline(always)]
250    pub fn public_input(&self, idx: usize) -> Option<F> {
251        self.public_inputs.get(idx).copied()
252    }
253}
254
255/// Secret Witness (The Execution Trace) of the program.
256/// Holds the trace data. Generic over T to support both
257/// raw ColumnTrace and specialized wrappers.
258pub struct ProgramWitness<F: TowerField, T: Trace = ColumnTrace> {
259    pub trace: T,
260    pub chiplet_traces: Vec<ColumnTrace>,
261    _marker: PhantomData<F>,
262}
263
264impl<F: TowerField, T: Trace> ProgramWitness<F, T> {
265    pub fn new(trace: T) -> Self {
266        Self {
267            trace,
268            chiplet_traces: Vec::new(),
269            _marker: PhantomData,
270        }
271    }
272
273    /// Attach independent chiplet traces.
274    /// Each entry corresponds by index to `chiplet_defs()`.
275    pub fn with_chiplets(mut self, chiplet_traces: Vec<ColumnTrace>) -> Self {
276        self.chiplet_traces = chiplet_traces;
277        self
278    }
279}
280
/// Locates a chiplet's inlined sub-AST in
/// the program's merged `constraint_ast()`
/// so the prover can dispatch its kernel.
///
/// Plain `Copy` data; also comparable and hashable
/// so hints can be checked for equality or deduplicated.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct InlineKernelHint {
    /// Index into `Air::inline_chiplets()`.
    pub chiplet_idx: usize,

    /// Absolute index of the chiplet's
    /// first root in the program's `roots`.
    pub root_offset: usize,

    /// Absolute column index where
    /// the chiplet's columns start.
    pub column_offset: usize,
}
297
298// =================================================================
299// LAGRANGE-PINNED COLUMNS
300// =================================================================
301
302/// Hypercube point at which a Lagrange MLE is anchored.
303///
304/// `Custom(bits)` carries the bit-decomposition of an arbitrary
305/// row index, LSB first, length must equal the trace's `num_vars`.
306#[derive(Clone, Debug, PartialEq, Eq)]
307pub enum LagrangePoint {
308    LastRow,
309    FirstRow,
310    Custom(Vec<bool>),
311}
312
313impl LagrangePoint {
314    /// MLE evaluation of the pinned column
315    /// at point `r` (LSB-first bit order).
316    /// `r.len()` must equal `num_vars`.
317    pub fn evaluate<F>(&self, r: &[Flat<F>]) -> Flat<F>
318    where
319        F: HardwareField,
320    {
321        let one = Flat::from_raw(F::ONE);
322        match self {
323            LagrangePoint::LastRow => {
324                let mut prod = one;
325                for &r_k in r {
326                    prod *= r_k;
327                }
328
329                one - prod
330            }
331            LagrangePoint::FirstRow => {
332                let mut prod = one;
333                for &r_k in r {
334                    prod *= one - r_k;
335                }
336
337                prod
338            }
339            LagrangePoint::Custom(bits) => {
340                debug_assert_eq!(bits.len(), r.len(), "Custom point bit width != r.len()");
341
342                let mut prod = one;
343                for (k, &b) in bits.iter().enumerate() {
344                    let factor = if b { r[k] } else { one - r[k] };
345                    prod *= factor;
346                }
347
348                prod
349            }
350        }
351    }
352}
353
354/// Single binding:
355/// one virtual column pinned to a Lagrange MLE.
356#[derive(Clone, Debug, PartialEq, Eq)]
357pub struct LagrangePin {
358    pub col_idx: usize,
359    pub point: LagrangePoint,
360}
361
362impl LagrangePin {
363    pub fn last_row(col_idx: usize) -> Self {
364        Self {
365            col_idx,
366            point: LagrangePoint::LastRow,
367        }
368    }
369
370    pub fn first_row(col_idx: usize) -> Self {
371        Self {
372            col_idx,
373            point: LagrangePoint::FirstRow,
374        }
375    }
376
377    pub fn custom(col_idx: usize, bits: Vec<bool>) -> Self {
378        Self {
379            col_idx,
380            point: LagrangePoint::Custom(bits),
381        }
382    }
383}
384
385/// Rejects out-of-range `col_idx`, mis-sized `Custom`
386/// bit vectors, and duplicate pins on the same column
387/// (a column anchored to two distinct points is unsatisfiable).
388pub fn validate_lagrange_pins(
389    pins: &[LagrangePin],
390    num_columns: usize,
391    num_vars: Option<usize>,
392) -> errors::Result<()> {
393    for (i, pin) in pins.iter().enumerate() {
394        if pin.col_idx >= num_columns {
395            return Err(errors::Error::Protocol {
396                protocol: "lagrange_pin",
397                message: "col_idx out of range",
398            });
399        }
400
401        if let (LagrangePoint::Custom(bits), Some(nv)) = (&pin.point, num_vars)
402            && bits.len() != nv
403        {
404            return Err(errors::Error::Protocol {
405                protocol: "lagrange_pin",
406                message: "Custom point bit width != num_vars",
407            });
408        }
409
410        for prior in &pins[..i] {
411            if prior.col_idx == pin.col_idx {
412                return Err(errors::Error::Protocol {
413                    protocol: "lagrange_pin",
414                    message: "duplicate pin on same column",
415                });
416            }
417        }
418    }
419
420    Ok(())
421}