Skip to main content

hekate_program/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// This file is part of the hekate project.
3// Copyright (C) 2026 Andrei Kochergin <andrei@oumuamua.dev>
4// Copyright (C) 2026 Oumuamua Labs <info@oumuamua.dev>. All rights reserved.
5//
6// Licensed under the Apache License, Version 2.0 (the "License");
7// you may not use this file except in compliance with the License.
8// You may obtain a copy of the License at
9//
10//     http://www.apache.org/licenses/LICENSE-2.0
11//
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
17
18#![cfg_attr(not(feature = "std"), no_std)]
19
20extern crate alloc;
21extern crate core;
22
23use alloc::string::{String, ToString};
24use alloc::vec::Vec;
25use constraint::{BoundaryConstraint, Constraint, ConstraintAst};
26use core::marker::PhantomData;
27use expander::VirtualExpander;
28use hekate_core::errors;
29use hekate_core::trace::{ColumnTrace, ColumnType, Trace, TraceCompatibleField};
30use hekate_math::{Flat, HardwareField, TowerField};
31use permutation::PermutationCheckSpec;
32
33pub mod chiplet;
34pub mod constraint;
35pub mod expander;
36pub mod permutation;
37pub mod schema;
38
39// =================================================================
40// AIR TRAIT:
41// Core Algebraic Intermediate Representation
42// =================================================================
43
44/// Defines the algebraic structure, trace
45/// layout, and constraints of an AIR table.
46///
47/// Implemented by both standalone
48/// programs and independent chiplets.
49pub trait Air<F: TowerField>: Sized + Clone + Sync {
50    fn name(&self) -> String {
51        "HekateAir".to_string()
52    }
53
54    fn num_columns(&self) -> usize {
55        self.virtual_column_layout().len()
56    }
57
58    /// Flat expansion of `constraint_ast()`.
59    fn constraints(&self) -> Vec<Constraint<F>> {
60        self.constraint_ast().to_constraints()
61    }
62
63    /// Returns the list of boundary constraints. Each
64    /// constraint ties a specific trace cell to a public
65    /// input value. By default, returns an empty list.
66    fn boundary_constraints(&self) -> Vec<BoundaryConstraint<F>> {
67        Vec::new()
68    }
69
70    /// Returns the physical layout
71    /// of the columns in the trace.
72    ///
73    /// This describes the storage type
74    /// (Bit, B8, B32, etc.) of each column.
75    fn column_layout(&self) -> &[ColumnType];
76
77    /// Returns the virtual layout of the columns
78    /// (after unpacking). Defaults to the expander's
79    /// layout if present, else the physical layout.
80    fn virtual_column_layout(&self) -> &[ColumnType] {
81        match self.virtual_expander() {
82            Some(e) => e.virtual_layout(),
83            None => self.column_layout(),
84        }
85    }
86
87    /// Returns the permutation check
88    /// specifications for this AIR table.
89    ///
90    /// Each tuple contains:
91    /// - `String`:
92    ///   Unique bus identifier (e.g., `RomChiplet::BUS_ID`)
93    /// - `PermutationCheckSpec`:
94    ///   The GPA specification (sources, selector)
95    ///
96    /// Default:
97    /// No permutation checks.
98    fn permutation_checks(&self) -> Vec<(String, PermutationCheckSpec)> {
99        Vec::new()
100    }
101
102    /// Columns pinned to a fixed shape, bound by MLE
103    /// equality at `r_final`. Each shape must be a pure
104    /// function of the row index, not the witness.
105    fn fixed_columns(&self) -> Vec<FixedColumn<F>> {
106        Vec::new()
107    }
108
109    /// Returns the `VirtualExpander` for chiplets
110    /// with physical to virtual column expansion.
111    fn virtual_expander(&self) -> Option<&VirtualExpander> {
112        None
113    }
114
115    /// Parses a raw physical row (bytes) into
116    /// the full Virtual Row (fields). Used by
117    /// the Verifier to reconstruct the virtual
118    /// trace from committed data.
119    ///
120    /// Delegates to `virtual_expander().parse_row()`
121    /// when present. Falls back to 1:1 parsing
122    /// from `column_layout()`.
123    fn parse_virtual_row(&self, bytes: &[u8], res: &mut Vec<Flat<F>>)
124    where
125        F: TraceCompatibleField,
126    {
127        res.clear();
128
129        if let Some(e) = self.virtual_expander() {
130            e.parse_row(bytes, res)
131                .expect("committed row byte length must match physical_row_bytes");
132            return;
133        }
134
135        let mut offset = 0;
136        for col_type in self.column_layout() {
137            let size = col_type.byte_size();
138            if offset + size <= bytes.len() {
139                res.push(col_type.parse_from_bytes(&bytes[offset..offset + size]));
140                offset += size;
141            }
142        }
143    }
144
145    /// Returns the constraint system as an AST-DAG.
146    fn constraint_ast(&self) -> ConstraintAst<F>;
147
148    /// Chiplet defs used only for kernel dispatch.
149    fn inline_chiplets(&self) -> errors::Result<Vec<chiplet::ChipletDef<F>>> {
150        Ok(Vec::new())
151    }
152
153    /// Each hint's `chiplet_idx` indexes into `inline_chiplets()`.
154    fn inline_chiplet_kernels(&self) -> Vec<InlineKernelHint> {
155        Vec::new()
156    }
157}
158
159// =================================================================
160// PROGRAM TRAIT — Composition over Air
161// =================================================================
162
163/// Extends `Air<F>` with multi-table composition:
164/// independent chiplets, GKR gadgets, and public inputs.
165///
166/// The top-level prover and verifier require `Program<F>`.
167/// Internal sub-protocols (ZeroCheck, chiplet verification)
168/// operate on `Air<F>` alone.
169pub trait Program<F: TowerField>: Air<F> {
170    /// Number of public inputs for this program.
171    fn num_public_inputs(&self) -> usize {
172        0
173    }
174
175    /// Returns independent AIR chiplet definitions.
176    /// Each chiplet gets its own trace, commitment,
177    /// ZeroCheck, and evaluation argument.
178    /// Connected to the main trace via GPA bus.
179    fn chiplet_defs(&self) -> errors::Result<Vec<chiplet::ChipletDef<F>>> {
180        Ok(Vec::new())
181    }
182}
183
/// Handle to one cell of the program's execution trace:
/// a column index together with a relative row offset
/// (either the current row or the one right after it).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ProgramCell {
    pub col_idx: usize,

    /// `false` selects the current row (i);
    /// `true` selects the following row (i+1).
    pub next_row: bool,
}

impl ProgramCell {
    /// Cell in column `col_idx` of the current row.
    pub fn current(col_idx: usize) -> Self {
        let next_row = false;
        Self { col_idx, next_row }
    }

    /// Cell in column `col_idx` of the following row.
    pub fn next(col_idx: usize) -> Self {
        let next_row = true;
        Self { col_idx, next_row }
    }
}
213
214// =================================================================
215// INSTANCE & WITNESS
216// =================================================================
217
218/// Public Instance (Common inputs)
219/// of the program execution.
220#[derive(Clone, Debug)]
221pub struct ProgramInstance<F: TowerField> {
222    num_rows: usize,
223    public_inputs: Vec<F>,
224}
225
226impl<F: TowerField> ProgramInstance<F> {
227    pub fn new(num_rows: usize, public_inputs: Vec<F>) -> Self {
228        assert!(
229            num_rows.is_power_of_two(),
230            "Program trace height must be power of 2"
231        );
232
233        Self {
234            num_rows,
235            public_inputs,
236        }
237    }
238
239    #[inline(always)]
240    pub fn num_rows(&self) -> usize {
241        self.num_rows
242    }
243
244    /// Public inputs in canonical basis.
245    #[inline(always)]
246    pub fn public_inputs(&self) -> &[F] {
247        &self.public_inputs
248    }
249
250    #[inline(always)]
251    pub fn public_input(&self, idx: usize) -> Option<F> {
252        self.public_inputs.get(idx).copied()
253    }
254}
255
256/// Secret Witness (The Execution Trace) of the program.
257/// Holds the trace data. Generic over T to support both
258/// raw ColumnTrace and specialized wrappers.
259pub struct ProgramWitness<F: TowerField, T: Trace = ColumnTrace> {
260    pub trace: T,
261    pub chiplet_traces: Vec<ColumnTrace>,
262    _marker: PhantomData<F>,
263}
264
265impl<F: TowerField, T: Trace> ProgramWitness<F, T> {
266    pub fn new(trace: T) -> Self {
267        Self {
268            trace,
269            chiplet_traces: Vec::new(),
270            _marker: PhantomData,
271        }
272    }
273
274    /// Attach independent chiplet traces.
275    /// Each entry corresponds by index to `chiplet_defs()`.
276    pub fn with_chiplets(mut self, chiplet_traces: Vec<ColumnTrace>) -> Self {
277        self.chiplet_traces = chiplet_traces;
278        self
279    }
280}
281
/// Locates a chiplet's inlined sub-AST in
/// the program's merged `constraint_ast()`
/// so the prover can dispatch its kernel.
///
/// All fields are plain indices, so the hint
/// derives equality and hashing in addition to
/// `Copy`, allowing hints to be compared,
/// deduplicated, and used as map keys.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct InlineKernelHint {
    /// Index into `Air::inline_chiplets()`.
    pub chiplet_idx: usize,

    /// Absolute index of the chiplet's
    /// first root in the program's `roots`.
    pub root_offset: usize,

    /// Absolute column index where
    /// the chiplet's columns start.
    pub column_offset: usize,
}
298
299// =================================================================
300// FIXED COLUMNS
301// =================================================================
302
/// Row-index-determined shape a fixed column is
/// pinned to. Every variant is a pure function of
/// the row index over a trace of `2^num_vars` rows;
/// row indices map to MLE points LSB-first.
///
/// `FirstRow` and `Custom` are single-row indicators.
/// `LastRow` is the *complement* of the last-row
/// indicator. `Periodic`/`Sparse`/`Dense` are
/// arbitrary row-indexed value patterns.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum FixedShape<F> {
    /// `0` on the last row (`2^num_vars - 1`) and
    /// `1` on every other row. Not a single-row
    /// indicator despite the name.
    LastRow,
    /// `1` on row `0` and `0` on every other row.
    FirstRow,
    /// `1` on the single row whose index has bits
    /// `bits` (LSB first, `bits[k]` is bit `k`),
    /// `0` elsewhere. Expected to hold exactly
    /// `num_vars` bits.
    Custom(Vec<bool>),
    /// `values[row % period]`. `period` must be a
    /// power of two and `values.len() == period`.
    Periodic { period: usize, values: Vec<F> },
    /// `v` on each listed `(row, v)` and `0` elsewhere.
    /// Rows are expected to be distinct and below
    /// `2^num_vars`.
    Sparse(Vec<(usize, F)>),
    /// `values[row]` for every row; expected to hold
    /// exactly `2^num_vars` values.
    Dense(Vec<F>),
}
316
317impl<F: HardwareField> FixedShape<F> {
318    /// MLE of the shape at point `r` (LSB-first),
319    /// in flat basis. `r.len()` must equal `num_vars`.
320    pub fn evaluate(&self, r: &[Flat<F>]) -> Flat<F> {
321        let one = Flat::from_raw(F::ONE);
322        match self {
323            FixedShape::LastRow => {
324                let mut prod = one;
325                for &r_k in r {
326                    prod *= r_k;
327                }
328
329                one - prod
330            }
331            FixedShape::FirstRow => {
332                let mut prod = one;
333                for &r_k in r {
334                    prod *= one - r_k;
335                }
336
337                prod
338            }
339            FixedShape::Custom(bits) => {
340                debug_assert_eq!(bits.len(), r.len(), "Custom point bit width != r.len()");
341
342                let mut prod = one;
343                for (k, &b) in bits.iter().enumerate() {
344                    let factor = if b { r[k] } else { one - r[k] };
345                    prod *= factor;
346                }
347
348                prod
349            }
350            FixedShape::Periodic { period, values } => {
351                // Low p = log2(period) coords only;
352                // high coords each sum to 1.
353                let p = period.trailing_zeros() as usize;
354
355                let mut acc = Flat::from_raw(F::ZERO);
356                for (j, &v) in values.iter().enumerate() {
357                    acc += v.to_hardware() * eq_index(&r[..p], j);
358                }
359
360                acc
361            }
362            FixedShape::Sparse(entries) => {
363                let mut acc = Flat::from_raw(F::ZERO);
364                for &(row, v) in entries {
365                    acc += v.to_hardware() * eq_index(r, row);
366                }
367
368                acc
369            }
370            FixedShape::Dense(values) => {
371                let mut acc = Flat::from_raw(F::ZERO);
372                for (i, &v) in values.iter().enumerate() {
373                    acc += v.to_hardware() * eq_index(r, i);
374                }
375
376                acc
377            }
378        }
379    }
380
381    /// Shape value at integer row `row`. O(1);
382    /// prefer over `evaluate` at a vertex,
383    /// which is O(N) for `Dense`.
384    pub fn value_at_row(&self, row: usize, num_vars: usize) -> Flat<F> {
385        let one = Flat::from_raw(F::ONE);
386        let zero = Flat::from_raw(F::ZERO);
387
388        match self {
389            FixedShape::FirstRow => {
390                if row == 0 {
391                    one
392                } else {
393                    zero
394                }
395            }
396            FixedShape::LastRow => {
397                if row == (1usize << num_vars) - 1 {
398                    zero
399                } else {
400                    one
401                }
402            }
403            FixedShape::Custom(bits) => {
404                let target = bits
405                    .iter()
406                    .enumerate()
407                    .fold(0usize, |acc, (k, &b)| acc | ((b as usize) << k));
408
409                if row == target { one } else { zero }
410            }
411            FixedShape::Periodic { period, values } => values[row % period].to_hardware(),
412            FixedShape::Sparse(entries) => {
413                let mut acc = zero;
414                for &(r, v) in entries {
415                    if r == row {
416                        acc += v.to_hardware();
417                    }
418                }
419
420                acc
421            }
422            FixedShape::Dense(values) => values[row].to_hardware(),
423        }
424    }
425}
426
427fn eq_index<F: HardwareField>(r: &[Flat<F>], index: usize) -> Flat<F> {
428    let one = Flat::from_raw(F::ONE);
429
430    let mut prod = one;
431    for (k, &r_k) in r.iter().enumerate() {
432        let factor = if (index >> k) & 1 == 1 {
433            r_k
434        } else {
435            one - r_k
436        };
437        prod *= factor;
438    }
439
440    prod
441}
442
443/// One committed column pinned to a fixed shape.
444#[derive(Clone, Debug, PartialEq, Eq)]
445pub struct FixedColumn<F> {
446    pub col_idx: usize,
447    pub shape: FixedShape<F>,
448}
449
450impl<F> FixedColumn<F> {
451    pub fn last_row(col_idx: usize) -> Self {
452        Self {
453            col_idx,
454            shape: FixedShape::LastRow,
455        }
456    }
457
458    pub fn first_row(col_idx: usize) -> Self {
459        Self {
460            col_idx,
461            shape: FixedShape::FirstRow,
462        }
463    }
464
465    pub fn custom(col_idx: usize, bits: Vec<bool>) -> Self {
466        Self {
467            col_idx,
468            shape: FixedShape::Custom(bits),
469        }
470    }
471
472    pub fn periodic(col_idx: usize, period: usize, values: Vec<F>) -> Self {
473        Self {
474            col_idx,
475            shape: FixedShape::Periodic { period, values },
476        }
477    }
478
479    pub fn sparse(col_idx: usize, entries: Vec<(usize, F)>) -> Self {
480        Self {
481            col_idx,
482            shape: FixedShape::Sparse(entries),
483        }
484    }
485
486    pub fn dense(col_idx: usize, values: Vec<F>) -> Self {
487        Self {
488            col_idx,
489            shape: FixedShape::Dense(values),
490        }
491    }
492}
493
494/// Declares a fixed column from a shape.
495pub fn fix<F>(col_idx: usize, shape: FixedShape<F>) -> FixedColumn<F> {
496    FixedColumn { col_idx, shape }
497}
498
499/// Rejects out-of-range `col_idx`, duplicate pins,
500/// malformed shapes, and out-of-domain values
501/// (`Bit` columns require values in {0, 1}).
502pub fn validate_fixed_columns<F: TowerField>(
503    fixed: &[FixedColumn<F>],
504    layout: &[ColumnType],
505    num_vars: Option<usize>,
506) -> errors::Result<()> {
507    for (i, fc) in fixed.iter().enumerate() {
508        if fc.col_idx >= layout.len() {
509            return Err(errors::Error::Protocol {
510                protocol: "fixed_column",
511                message: "col_idx out of range",
512            });
513        }
514
515        validate_shape(&fc.shape, layout[fc.col_idx], num_vars)?;
516
517        for prior in &fixed[..i] {
518            if prior.col_idx == fc.col_idx {
519                return Err(errors::Error::Protocol {
520                    protocol: "fixed_column",
521                    message: "duplicate pin on same column",
522                });
523            }
524        }
525    }
526
527    Ok(())
528}
529
530fn validate_shape<F: TowerField>(
531    shape: &FixedShape<F>,
532    col_type: ColumnType,
533    num_vars: Option<usize>,
534) -> errors::Result<()> {
535    match shape {
536        FixedShape::LastRow | FixedShape::FirstRow => Ok(()),
537        FixedShape::Custom(bits) => match num_vars {
538            Some(nv) if bits.len() != nv => Err(errors::Error::Protocol {
539                protocol: "fixed_column",
540                message: "Custom point bit width != num_vars",
541            }),
542            _ => Ok(()),
543        },
544        FixedShape::Periodic { period, values } => {
545            if !period.is_power_of_two() {
546                return Err(errors::Error::Protocol {
547                    protocol: "fixed_column",
548                    message: "Periodic period must be a power of two",
549                });
550            }
551
552            if values.len() != *period {
553                return Err(errors::Error::Protocol {
554                    protocol: "fixed_column",
555                    message: "Periodic values length != period",
556                });
557            }
558
559            if let Some(nv) = num_vars
560                && *period > (1usize << nv)
561            {
562                return Err(errors::Error::Protocol {
563                    protocol: "fixed_column",
564                    message: "Periodic period exceeds trace height",
565                });
566            }
567
568            check_bit_domain(values.iter().copied(), col_type)
569        }
570        FixedShape::Sparse(entries) => {
571            if let Some(nv) = num_vars {
572                let n = 1usize << nv;
573                for &(row, _) in entries {
574                    if row >= n {
575                        return Err(errors::Error::Protocol {
576                            protocol: "fixed_column",
577                            message: "Sparse row index exceeds trace height",
578                        });
579                    }
580                }
581            }
582
583            for (i, &(row, _)) in entries.iter().enumerate() {
584                if entries[..i].iter().any(|&(prior, _)| prior == row) {
585                    return Err(errors::Error::Protocol {
586                        protocol: "fixed_column",
587                        message: "duplicate Sparse row",
588                    });
589                }
590            }
591
592            check_bit_domain(entries.iter().map(|&(_, v)| v), col_type)
593        }
594        FixedShape::Dense(values) => {
595            if let Some(nv) = num_vars
596                && values.len() != (1usize << nv)
597            {
598                return Err(errors::Error::Protocol {
599                    protocol: "fixed_column",
600                    message: "Dense values length != trace height",
601                });
602            }
603
604            check_bit_domain(values.iter().copied(), col_type)
605        }
606    }
607}
608
609fn check_bit_domain<F: TowerField>(
610    values: impl Iterator<Item = F>,
611    col_type: ColumnType,
612) -> errors::Result<()> {
613    if col_type != ColumnType::Bit {
614        return Ok(());
615    }
616
617    for v in values {
618        if v != F::ZERO && v != F::ONE {
619            return Err(errors::Error::Protocol {
620                protocol: "fixed_column",
621                message: "Bit fixed column value not in {0,1}",
622            });
623        }
624    }
625
626    Ok(())
627}