#![cfg_attr(not(feature = "std"), no_std)]
extern crate alloc;
extern crate core;
use alloc::string::{String, ToString};
use alloc::vec::Vec;
use constraint::{BoundaryConstraint, Constraint, ConstraintAst};
use core::marker::PhantomData;
use expander::VirtualExpander;
use hekate_core::errors;
use hekate_core::trace::{ColumnTrace, ColumnType, Trace, TraceCompatibleField};
use hekate_math::{Flat, HardwareField, TowerField};
use permutation::PermutationCheckSpec;
pub mod chiplet;
pub mod constraint;
pub mod expander;
pub mod permutation;
pub mod schema;
/// An algebraic intermediate representation (AIR): the column layout of an
/// execution trace together with the constraints that trace must satisfy.
///
/// Implementors must provide [`Air::column_layout`] and [`Air::constraint_ast`];
/// every other method has a default.
pub trait Air<F: TowerField>: Sized + Clone + Sync {
    /// Human-readable name of this AIR. Defaults to `"HekateAir"`.
    fn name(&self) -> String {
        "HekateAir".to_string()
    }

    /// Number of columns seen by the constraints, i.e. the length of
    /// [`Air::virtual_column_layout`].
    fn num_columns(&self) -> usize {
        self.virtual_column_layout().len()
    }

    /// Constraints lowered from [`Air::constraint_ast`] via `to_constraints`.
    fn constraints(&self) -> Vec<Constraint<F>> {
        self.constraint_ast().to_constraints()
    }

    /// Boundary constraints. Defaults to none.
    fn boundary_constraints(&self) -> Vec<BoundaryConstraint<F>> {
        Vec::new()
    }

    /// Physical column layout: the type of each column, in row order.
    fn column_layout(&self) -> &[ColumnType];

    /// Column layout seen by the constraints.
    ///
    /// This is the expander's virtual layout when [`Air::virtual_expander`] returns
    /// `Some`, otherwise the physical [`Air::column_layout`].
    fn virtual_column_layout(&self) -> &[ColumnType] {
        match self.virtual_expander() {
            Some(e) => e.virtual_layout(),
            None => self.column_layout(),
        }
    }

    /// Named permutation checks. Defaults to none.
    fn permutation_checks(&self) -> Vec<(String, PermutationCheckSpec)> {
        Vec::new()
    }

    /// Columns pinned to fixed values. Defaults to none.
    fn fixed_columns(&self) -> Vec<FixedColumn<F>> {
        Vec::new()
    }

    /// Optional expander that parses physical row bytes into virtual columns.
    /// Defaults to `None`.
    fn virtual_expander(&self) -> Option<&VirtualExpander> {
        None
    }

    /// Decodes one row of raw bytes into `res` as virtual-column values.
    ///
    /// `res` is cleared first.
    ///
    /// With a virtual expander, the expander parses the row. Otherwise each column
    /// of [`Air::column_layout`] is decoded in order from consecutive
    /// `byte_size()`-long chunks of `bytes`. If `bytes` is too short, decoding stops
    /// at the first column that does not fit, so `res` holds a correctly aligned
    /// prefix of the row.
    ///
    /// # Panics
    ///
    /// With a virtual expander, panics if the expander rejects the row. Its
    /// `expect` message states the row byte length must match
    /// `physical_row_bytes`.
    fn parse_virtual_row(&self, bytes: &[u8], res: &mut Vec<Flat<F>>)
    where
        F: TraceCompatibleField,
    {
        res.clear();
        if let Some(e) = self.virtual_expander() {
            e.parse_row(bytes, res)
                .expect("committed row byte length must match physical_row_bytes");
            return;
        }
        let mut offset = 0;
        for col_type in self.column_layout() {
            let size = col_type.byte_size();
            // Stop (rather than skip) on the first column that does not fit:
            // continuing would decode later columns from a misaligned offset.
            let Some(chunk) = bytes.get(offset..offset + size) else {
                break;
            };
            res.push(col_type.parse_from_bytes(chunk));
            offset += size;
        }
    }

    /// The constraint AST describing this AIR's constraints.
    fn constraint_ast(&self) -> ConstraintAst<F>;

    /// Chiplet definitions inlined into this AIR. Defaults to none.
    fn inline_chiplets(&self) -> errors::Result<Vec<chiplet::ChipletDef<F>>> {
        Ok(Vec::new())
    }

    /// Kernel hints for inlined chiplets. Defaults to none.
    fn inline_chiplet_kernels(&self) -> Vec<InlineKernelHint> {
        Vec::new()
    }
}
/// An [`Air`] extended with program-level metadata: the number of public
/// inputs and the chiplet definitions the program relies on.
pub trait Program<F: TowerField>: Air<F> {
    /// Number of public inputs the program expects. Defaults to `0`.
    fn num_public_inputs(&self) -> usize {
        0_usize
    }

    /// Chiplet definitions used by the program. Defaults to none.
    fn chiplet_defs(&self) -> errors::Result<Vec<chiplet::ChipletDef<F>>> {
        let defs = Vec::new();
        Ok(defs)
    }
}
/// A reference to one trace cell: a column index, read either from the
/// current row or from the next row.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ProgramCell {
    /// Column index of the referenced cell.
    pub col_idx: usize,
    /// `true` when the cell is taken from the next row instead of the current one.
    pub next_row: bool,
}

impl ProgramCell {
    /// Shared constructor behind [`ProgramCell::current`] and [`ProgramCell::next`].
    const fn at(col_idx: usize, next_row: bool) -> Self {
        Self { col_idx, next_row }
    }

    /// Cell `col_idx` on the current row.
    pub fn current(col_idx: usize) -> Self {
        Self::at(col_idx, false)
    }

    /// Cell `col_idx` on the next row.
    pub fn next(col_idx: usize) -> Self {
        Self::at(col_idx, true)
    }
}
/// Public description of a program run: the trace height and the public inputs.
#[derive(Clone, Debug)]
pub struct ProgramInstance<F: TowerField> {
    num_rows: usize,
    public_inputs: Vec<F>,
}

impl<F: TowerField> ProgramInstance<F> {
    /// Creates an instance with `num_rows` trace rows and the given public inputs.
    ///
    /// # Panics
    ///
    /// Panics if `num_rows` is not a power of two (including `0`).
    pub fn new(num_rows: usize, public_inputs: Vec<F>) -> Self {
        let height_is_pow2 = num_rows.is_power_of_two();
        assert!(height_is_pow2, "Program trace height must be power of 2");
        Self {
            num_rows,
            public_inputs,
        }
    }

    /// Trace height (always a power of two).
    #[inline(always)]
    pub fn num_rows(&self) -> usize {
        self.num_rows
    }

    /// All public inputs, in order.
    #[inline(always)]
    pub fn public_inputs(&self) -> &[F] {
        self.public_inputs.as_slice()
    }

    /// Public input at position `idx`, or `None` if out of range.
    #[inline(always)]
    pub fn public_input(&self, idx: usize) -> Option<F> {
        self.public_inputs().get(idx).copied()
    }
}
/// Witness data for a program: the main `trace` plus any chiplet traces.
pub struct ProgramWitness<F: TowerField, T: Trace = ColumnTrace> {
    /// Main execution trace.
    pub trace: T,
    /// Traces for chiplets; empty unless set via [`ProgramWitness::with_chiplets`].
    pub chiplet_traces: Vec<ColumnTrace>,
    // Ties the witness to the field `F` without storing any `F` values.
    _marker: PhantomData<F>,
}

impl<F: TowerField, T: Trace> ProgramWitness<F, T> {
    /// Wraps `trace` as a witness with no chiplet traces.
    pub fn new(trace: T) -> Self {
        let chiplet_traces = Vec::new();
        Self {
            trace,
            chiplet_traces,
            _marker: PhantomData,
        }
    }

    /// Builder-style setter that replaces the chiplet traces.
    pub fn with_chiplets(self, chiplet_traces: Vec<ColumnTrace>) -> Self {
        let mut witness = self;
        witness.chiplet_traces = chiplet_traces;
        witness
    }
}
/// Placement hint for an inlined chiplet kernel, as returned by
/// [`Air::inline_chiplet_kernels`].
#[derive(Clone, Copy, Debug)]
pub struct InlineKernelHint {
    /// Index of the chiplet this hint refers to (presumably an index into
    /// [`Air::inline_chiplets`]; TODO confirm).
    pub chiplet_idx: usize,
    /// Root offset for the kernel (NOTE(review): exact meaning not visible here).
    pub root_offset: usize,
    /// Column offset for the kernel (assumed to be where its columns start; confirm).
    pub column_offset: usize,
}
/// The value pattern of a fixed column over a trace of `2^num_vars` rows.
///
/// Per-row semantics match `FixedShape::value_at_row`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum FixedShape<F> {
    /// `0` on the last row and `1` on every other row.
    LastRow,
    /// `1` on row 0 and `0` on every other row.
    FirstRow,
    /// `1` on the single row whose index has bit `k` equal to `bits[k]`
    /// (little-endian), `0` elsewhere.
    Custom(Vec<bool>),
    /// `values[row % period]`; `period` is validated to be a power of two.
    Periodic { period: usize, values: Vec<F> },
    /// The listed `(row, value)` pairs; `0` on unlisted rows.
    Sparse(Vec<(usize, F)>),
    /// One explicit value per row: `values[row]`.
    Dense(Vec<F>),
}
impl<F: HardwareField> FixedShape<F> {
    /// Evaluates the multilinear extension of this fixed column at the point `r`.
    ///
    /// `r` has one coordinate per trace variable. Row indices are read
    /// little-endian: bit `k` of a row index pairs with `r[k]` (see `eq_index`).
    /// For shapes accepted by `validate_fixed_columns`, evaluating at the Boolean
    /// point that encodes row `i` gives `self.value_at_row(i, r.len())`.
    ///
    /// # Panics
    ///
    /// - `Custom`: if `bits.len() != r.len()`.
    /// - `Periodic`: if `period.trailing_zeros()` exceeds `r.len()`.
    pub fn evaluate(&self, r: &[Flat<F>]) -> Flat<F> {
        let one = Flat::from_raw(F::ONE);
        match self {
            // 1 - eq(r, 1…1): one on every row except the last.
            FixedShape::LastRow => {
                let mut prod = one;
                for &r_k in r {
                    prod *= r_k;
                }
                one - prod
            }
            // eq(r, 0…0): one on row 0 only.
            FixedShape::FirstRow => {
                let mut prod = one;
                for &r_k in r {
                    prod *= one - r_k;
                }
                prod
            }
            // eq(r, bits): one on the single row whose index bits are `bits`.
            FixedShape::Custom(bits) => {
                // Enforced in release builds too. With fewer bits than variables,
                // the product would ignore the high variables and describe a
                // periodic selector, silently disagreeing with `value_at_row`.
                assert_eq!(bits.len(), r.len(), "Custom point bit width != r.len()");
                let mut prod = one;
                for (&b, &r_k) in bits.iter().zip(r) {
                    prod *= if b { r_k } else { one - r_k };
                }
                prod
            }
            // A power-of-two period depends only on the low log2(period) index
            // bits, i.e. on r[..p].
            FixedShape::Periodic { period, values } => {
                let p = period.trailing_zeros() as usize;
                let mut acc = Flat::from_raw(F::ZERO);
                for (j, &v) in values.iter().enumerate() {
                    acc += v.to_hardware() * eq_index(&r[..p], j);
                }
                acc
            }
            FixedShape::Sparse(entries) => {
                let mut acc = Flat::from_raw(F::ZERO);
                for &(row, v) in entries {
                    acc += v.to_hardware() * eq_index(r, row);
                }
                acc
            }
            FixedShape::Dense(values) => {
                let mut acc = Flat::from_raw(F::ZERO);
                for (i, &v) in values.iter().enumerate() {
                    acc += v.to_hardware() * eq_index(r, i);
                }
                acc
            }
        }
    }

    /// Returns the value of this fixed column at `row` in a trace of
    /// `2^num_vars` rows.
    ///
    /// `num_vars` is only consulted by `LastRow`.
    ///
    /// # Panics
    ///
    /// - `Periodic`: if `period == 0`, or if `values` has no entry at `row % period`.
    /// - `Dense`: if `row >= values.len()`.
    pub fn value_at_row(&self, row: usize, num_vars: usize) -> Flat<F> {
        let one = Flat::from_raw(F::ONE);
        let zero = Flat::from_raw(F::ZERO);
        match self {
            FixedShape::FirstRow => {
                if row == 0 {
                    one
                } else {
                    zero
                }
            }
            FixedShape::LastRow => {
                if row == (1usize << num_vars) - 1 {
                    zero
                } else {
                    one
                }
            }
            FixedShape::Custom(bits) => {
                // Little-endian: bits[k] is bit k of the target row index.
                let target = bits
                    .iter()
                    .enumerate()
                    .fold(0usize, |acc, (k, &b)| acc | ((b as usize) << k));
                if row == target { one } else { zero }
            }
            FixedShape::Periodic { period, values } => values[row % period].to_hardware(),
            FixedShape::Sparse(entries) => {
                let mut acc = zero;
                for &(r, v) in entries {
                    if r == row {
                        acc += v.to_hardware();
                    }
                }
                acc
            }
            FixedShape::Dense(values) => values[row].to_hardware(),
        }
    }
}
/// Evaluates the multilinear equality polynomial `eq(r, index)`.
///
/// The result is the product over `k` of `r[k]` when bit `k` of `index` is set
/// and `1 - r[k]` when it is clear. Bits are read little-endian (`r[0]` pairs
/// with the least-significant bit). Bits of `index` at positions `>= r.len()`
/// are not consulted.
fn eq_index<F: HardwareField>(r: &[Flat<F>], index: usize) -> Flat<F> {
    let one = Flat::from_raw(F::ONE);
    r.iter().enumerate().fold(one, |mut acc, (k, &r_k)| {
        let bit_clear = (index >> k) & 1 == 0;
        acc *= if bit_clear { one - r_k } else { r_k };
        acc
    })
}
/// Pins trace column `col_idx` to the values described by `shape`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct FixedColumn<F> {
    /// Index of the pinned column in the column layout.
    pub col_idx: usize,
    /// The value pattern the column is pinned to.
    pub shape: FixedShape<F>,
}

impl<F> FixedColumn<F> {
    /// Pins `col_idx` to [`FixedShape::LastRow`].
    pub fn last_row(col_idx: usize) -> Self {
        fix(col_idx, FixedShape::LastRow)
    }

    /// Pins `col_idx` to [`FixedShape::FirstRow`].
    pub fn first_row(col_idx: usize) -> Self {
        fix(col_idx, FixedShape::FirstRow)
    }

    /// Pins `col_idx` to a one-hot selector at the row encoded by `bits`.
    pub fn custom(col_idx: usize, bits: Vec<bool>) -> Self {
        fix(col_idx, FixedShape::Custom(bits))
    }

    /// Pins `col_idx` to `values` repeated with the given `period`.
    pub fn periodic(col_idx: usize, period: usize, values: Vec<F>) -> Self {
        fix(col_idx, FixedShape::Periodic { period, values })
    }

    /// Pins `col_idx` to the listed `(row, value)` entries.
    pub fn sparse(col_idx: usize, entries: Vec<(usize, F)>) -> Self {
        fix(col_idx, FixedShape::Sparse(entries))
    }

    /// Pins `col_idx` to one explicit value per row.
    pub fn dense(col_idx: usize, values: Vec<F>) -> Self {
        fix(col_idx, FixedShape::Dense(values))
    }
}

/// Shorthand for building a [`FixedColumn`] from a column index and a shape.
pub fn fix<F>(col_idx: usize, shape: FixedShape<F>) -> FixedColumn<F> {
    FixedColumn { col_idx, shape }
}
/// Checks a list of fixed-column pins against a column layout.
///
/// Pins are checked in order. For each pin:
/// - `col_idx` must index into `layout`;
/// - its shape must pass `validate_shape` for that column's type (and for
///   `num_vars`, when known);
/// - no earlier pin may target the same column.
///
/// # Errors
///
/// Returns the first violation as `errors::Error::Protocol` with protocol
/// `"fixed_column"`.
pub fn validate_fixed_columns<F: TowerField>(
    fixed: &[FixedColumn<F>],
    layout: &[ColumnType],
    num_vars: Option<usize>,
) -> errors::Result<()> {
    fixed
        .iter()
        .enumerate()
        .try_for_each(|(i, fc)| -> errors::Result<()> {
            let Some(&col_type) = layout.get(fc.col_idx) else {
                return Err(errors::Error::Protocol {
                    protocol: "fixed_column",
                    message: "col_idx out of range",
                });
            };
            validate_shape(&fc.shape, col_type, num_vars)?;
            let already_pinned = fixed[..i].iter().any(|prior| prior.col_idx == fc.col_idx);
            if already_pinned {
                return Err(errors::Error::Protocol {
                    protocol: "fixed_column",
                    message: "duplicate pin on same column",
                });
            }
            Ok(())
        })
}
/// Validates one fixed-column shape against its column type and, when known,
/// `num_vars` (the trace height is `2^num_vars`).
///
/// Checks performed:
/// - `LastRow` / `FirstRow`: always valid.
/// - `Custom`: bit width equals `num_vars`.
/// - `Periodic`: `period` is a power of two, `values.len() == period`, and
///   `period` does not exceed the trace height.
/// - `Sparse`: every row index is below the trace height, and the row indices
///   are pairwise distinct.
/// - `Dense`: `values.len()` equals the trace height.
///
/// For `Bit` columns, `Periodic`, `Sparse` and `Dense` values must also be `0`
/// or `1`. Height-dependent checks are skipped when `num_vars` is `None`.
///
/// # Errors
///
/// Returns the first violation as `errors::Error::Protocol` with protocol
/// `"fixed_column"`.
fn validate_shape<F: TowerField>(
    shape: &FixedShape<F>,
    col_type: ColumnType,
    num_vars: Option<usize>,
) -> errors::Result<()> {
    match shape {
        FixedShape::LastRow | FixedShape::FirstRow => Ok(()),
        FixedShape::Custom(bits) => match num_vars {
            Some(nv) if bits.len() != nv => Err(errors::Error::Protocol {
                protocol: "fixed_column",
                message: "Custom point bit width != num_vars",
            }),
            _ => Ok(()),
        },
        FixedShape::Periodic { period, values } => {
            if !period.is_power_of_two() {
                return Err(errors::Error::Protocol {
                    protocol: "fixed_column",
                    message: "Periodic period must be a power of two",
                });
            }
            if values.len() != *period {
                return Err(errors::Error::Protocol {
                    protocol: "fixed_column",
                    message: "Periodic values length != period",
                });
            }
            if let Some(nv) = num_vars
                && *period > (1usize << nv)
            {
                return Err(errors::Error::Protocol {
                    protocol: "fixed_column",
                    message: "Periodic period exceeds trace height",
                });
            }
            check_bit_domain(values.iter().copied(), col_type)
        }
        FixedShape::Sparse(entries) => {
            // Range errors take precedence over duplicate errors.
            if let Some(nv) = num_vars {
                let height = 1usize << nv;
                if entries.iter().any(|&(row, _)| row >= height) {
                    return Err(errors::Error::Protocol {
                        protocol: "fixed_column",
                        message: "Sparse row index exceeds trace height",
                    });
                }
            }
            // Sort a copy of the row indices so duplicates become adjacent.
            // This is O(k log k) instead of a pairwise O(k^2) scan.
            let mut rows: Vec<usize> = entries.iter().map(|&(row, _)| row).collect();
            rows.sort_unstable();
            if rows.windows(2).any(|pair| pair[0] == pair[1]) {
                return Err(errors::Error::Protocol {
                    protocol: "fixed_column",
                    message: "duplicate Sparse row",
                });
            }
            check_bit_domain(entries.iter().map(|&(_, v)| v), col_type)
        }
        FixedShape::Dense(values) => {
            if let Some(nv) = num_vars
                && values.len() != (1usize << nv)
            {
                return Err(errors::Error::Protocol {
                    protocol: "fixed_column",
                    message: "Dense values length != trace height",
                });
            }
            check_bit_domain(values.iter().copied(), col_type)
        }
    }
}
fn check_bit_domain<F: TowerField>(
values: impl Iterator<Item = F>,
col_type: ColumnType,
) -> errors::Result<()> {
if col_type != ColumnType::Bit {
return Ok(());
}
for v in values {
if v != F::ZERO && v != F::ONE {
return Err(errors::Error::Protocol {
protocol: "fixed_column",
message: "Bit fixed column value not in {0,1}",
});
}
}
Ok(())
}