#![cfg_attr(not(feature = "std"), no_std)]
extern crate alloc;
extern crate core;
use alloc::string::{String, ToString};
use alloc::vec::Vec;
use constraint::{BoundaryConstraint, Constraint, ConstraintAst};
use core::marker::PhantomData;
use expander::VirtualExpander;
use hekate_core::errors;
use hekate_core::trace::{ColumnTrace, ColumnType, Trace, TraceCompatibleField};
use hekate_math::{Flat, HardwareField, TowerField};
use permutation::PermutationCheckSpec;
pub mod chiplet;
pub mod constraint;
pub mod expander;
pub mod permutation;
pub mod schema;
/// Algebraic description of a computation over field `F`.
///
/// Implementors must provide [`column_layout`](Air::column_layout) and
/// [`constraint_ast`](Air::constraint_ast); every other method has a default
/// that may be overridden.
pub trait Air<F: TowerField>: Sized + Clone + Sync {
    /// Human-readable name of this AIR. Defaults to `"HekateAir"`.
    fn name(&self) -> String {
        "HekateAir".to_string()
    }

    /// Number of columns seen by constraints, i.e. the length of
    /// [`virtual_column_layout`](Air::virtual_column_layout).
    fn num_columns(&self) -> usize {
        self.virtual_column_layout().len()
    }

    /// Constraints lowered from [`constraint_ast`](Air::constraint_ast).
    fn constraints(&self) -> Vec<Constraint<F>> {
        self.constraint_ast().to_constraints()
    }

    /// Boundary constraints. Empty by default.
    fn boundary_constraints(&self) -> Vec<BoundaryConstraint<F>> {
        Vec::new()
    }

    /// Layout of the committed (physical) row. When no expander is present,
    /// columns are laid out consecutively in this order.
    fn column_layout(&self) -> &[ColumnType];

    /// Layout of the columns that constraints are evaluated over.
    ///
    /// This is the expander's virtual layout when a [`VirtualExpander`] is
    /// present, otherwise [`column_layout`](Air::column_layout).
    fn virtual_column_layout(&self) -> &[ColumnType] {
        match self.virtual_expander() {
            Some(e) => e.virtual_layout(),
            None => self.column_layout(),
        }
    }

    /// Named permutation-check specifications. Empty by default.
    fn permutation_checks(&self) -> Vec<(String, PermutationCheckSpec)> {
        Vec::new()
    }

    /// Columns associated with Lagrange points (see [`LagrangePin`]).
    /// Empty by default.
    fn lagrange_pinned_columns(&self) -> Vec<LagrangePin> {
        Vec::new()
    }

    /// Optional expander that turns committed row bytes into virtual rows.
    /// `None` by default.
    fn virtual_expander(&self) -> Option<&VirtualExpander> {
        None
    }

    /// Decodes one committed row from `bytes` into `res`.
    ///
    /// `res` is cleared first, so callers can reuse one buffer across rows.
    ///
    /// With a [`VirtualExpander`], decoding is delegated to it. Without one,
    /// columns are decoded consecutively following
    /// [`column_layout`](Air::column_layout). If `bytes` is too short, decoding
    /// stops at the first column that does not fit, so `res` holds only the
    /// complete leading columns.
    ///
    /// # Panics
    ///
    /// Panics if the expander reports an error for this row (e.g. the row's
    /// byte length does not match the expected physical row size).
    fn parse_virtual_row(&self, bytes: &[u8], res: &mut Vec<Flat<F>>)
    where
        F: TraceCompatibleField,
    {
        res.clear();
        if let Some(e) = self.virtual_expander() {
            e.parse_row(bytes, res)
                .expect("committed row byte length must match physical_row_bytes");
            return;
        }
        let mut offset = 0;
        for col_type in self.column_layout() {
            let size = col_type.byte_size();
            // Columns are packed back to back. If this one does not fit, every
            // later column would start beyond it, so skipping it and continuing
            // would decode later columns from the wrong offset. Stop here.
            let Some(chunk) = bytes.get(offset..offset + size) else {
                break;
            };
            res.push(col_type.parse_from_bytes(chunk));
            offset += size;
        }
    }

    /// Constraint expression tree describing this AIR.
    fn constraint_ast(&self) -> ConstraintAst<F>;

    /// Chiplets inlined into this AIR. Empty by default.
    fn inline_chiplets(&self) -> errors::Result<Vec<chiplet::ChipletDef<F>>> {
        Ok(Vec::new())
    }

    /// Placement hints for inlined chiplet kernels. Empty by default.
    fn inline_chiplet_kernels(&self) -> Vec<InlineKernelHint> {
        Vec::new()
    }
}
/// An [`Air`] that can be run as a standalone program, with public inputs
/// and attached chiplets.
pub trait Program<F: TowerField>: Air<F> {
    /// How many public inputs the program takes. Defaults to zero.
    fn num_public_inputs(&self) -> usize {
        0
    }

    /// Chiplet definitions used by this program. Defaults to none.
    fn chiplet_defs(&self) -> errors::Result<Vec<chiplet::ChipletDef<F>>> {
        Ok(Vec::default())
    }
}
/// A reference to one trace cell: a column, on either the current row or the
/// next row.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ProgramCell {
    /// Index of the referenced column.
    pub col_idx: usize,
    /// `true` to refer to the next row, `false` for the current row.
    pub next_row: bool,
}
impl ProgramCell {
    /// Shared constructor behind [`ProgramCell::current`] and [`ProgramCell::next`].
    fn at(col_idx: usize, next_row: bool) -> Self {
        Self { col_idx, next_row }
    }

    /// Cell in column `col_idx` on the current row.
    pub fn current(col_idx: usize) -> Self {
        Self::at(col_idx, false)
    }

    /// Cell in column `col_idx` on the next row.
    pub fn next(col_idx: usize) -> Self {
        Self::at(col_idx, true)
    }
}
/// Public data for one program run: the trace height and the public inputs.
#[derive(Clone, Debug)]
pub struct ProgramInstance<F: TowerField> {
    num_rows: usize,
    public_inputs: Vec<F>,
}
impl<F: TowerField> ProgramInstance<F> {
    /// Builds an instance whose trace has `num_rows` rows.
    ///
    /// # Panics
    ///
    /// Panics if `num_rows` is not a power of two (this includes `0`).
    pub fn new(num_rows: usize, public_inputs: Vec<F>) -> Self {
        let height_is_pow2 = num_rows.is_power_of_two();
        assert!(height_is_pow2, "Program trace height must be power of 2");
        Self {
            public_inputs,
            num_rows,
        }
    }

    /// Trace height (always a power of two).
    #[inline(always)]
    pub fn num_rows(&self) -> usize {
        self.num_rows
    }

    /// All public inputs, in order.
    #[inline(always)]
    pub fn public_inputs(&self) -> &[F] {
        self.public_inputs.as_slice()
    }

    /// Public input at `idx`, or `None` if `idx` is out of range.
    #[inline(always)]
    pub fn public_input(&self, idx: usize) -> Option<F> {
        self.public_inputs().get(idx).copied()
    }
}
/// Prover-side witness: the main trace plus any chiplet traces.
pub struct ProgramWitness<F: TowerField, T: Trace = ColumnTrace> {
    /// Main program trace.
    pub trace: T,
    /// Chiplet traces. Empty unless supplied via [`ProgramWitness::with_chiplets`].
    pub chiplet_traces: Vec<ColumnTrace>,
    /// Ties the witness to field `F` without storing an `F`.
    _marker: PhantomData<F>,
}
impl<F: TowerField, T: Trace> ProgramWitness<F, T> {
    /// Wraps `trace` with no chiplet traces.
    pub fn new(trace: T) -> Self {
        Self {
            trace,
            chiplet_traces: Vec::default(),
            _marker: PhantomData,
        }
    }

    /// Replaces the chiplet traces and returns the updated witness (builder style).
    pub fn with_chiplets(self, chiplet_traces: Vec<ColumnTrace>) -> Self {
        Self {
            chiplet_traces,
            ..self
        }
    }
}
/// Placement hint for an inlined chiplet kernel.
///
/// NOTE(review): the units of `root_offset` / `column_offset` (rows, columns,
/// or bytes) are not established in this file — confirm with the consumer.
#[derive(Debug, Clone, Copy)]
pub struct InlineKernelHint {
    /// Index of the chiplet this hint refers to.
    pub chiplet_idx: usize,
    /// Root offset of the kernel.
    pub root_offset: usize,
    /// Column offset of the kernel.
    pub column_offset: usize,
}
/// A point on the boolean hypercube, used as a Lagrange selector.
///
/// `Custom` carries explicit coordinate bits, one per variable of the
/// evaluation point.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum LagrangePoint {
    LastRow,
    FirstRow,
    Custom(Vec<bool>),
}
impl LagrangePoint {
    /// Evaluates this point's selector polynomial at `r`.
    ///
    /// - `FirstRow`: `∏_k (1 - r_k)`
    /// - `Custom(bits)`: `∏_k (bits[k] ? r_k : 1 - r_k)`
    /// - `LastRow`: `1 - ∏_k r_k`
    ///
    /// NOTE(review): `LastRow` returns the complement of the all-ones product.
    /// `FirstRow` and `Custom` return the product itself, so `LastRow` differs
    /// from `Custom(vec![true; r.len()])`. Confirm this asymmetry is intended.
    ///
    /// # Panics
    ///
    /// For `Custom`, debug builds assert `bits.len() == r.len()`. In release
    /// builds, indexing `r` panics if `bits` is longer than `r`.
    pub fn evaluate<F>(&self, r: &[Flat<F>]) -> Flat<F>
    where
        F: HardwareField,
    {
        let one = Flat::from_raw(F::ONE);
        match self {
            LagrangePoint::FirstRow => {
                let mut acc = one;
                for &coord in r {
                    acc *= one - coord;
                }
                acc
            }
            LagrangePoint::LastRow => {
                let mut all_ones = one;
                for &coord in r {
                    all_ones *= coord;
                }
                one - all_ones
            }
            LagrangePoint::Custom(bits) => {
                debug_assert_eq!(bits.len(), r.len(), "Custom point bit width != r.len()");
                let mut acc = one;
                for (idx, bit) in bits.iter().copied().enumerate() {
                    acc *= if bit { r[idx] } else { one - r[idx] };
                }
                acc
            }
        }
    }
}
/// Associates column `col_idx` with a [`LagrangePoint`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LagrangePin {
    /// Index of the pinned column.
    pub col_idx: usize,
    /// Point the column is pinned to.
    pub point: LagrangePoint,
}
impl LagrangePin {
    /// Shared constructor behind the public convenience constructors.
    fn at(col_idx: usize, point: LagrangePoint) -> Self {
        Self { col_idx, point }
    }

    /// Pins `col_idx` to [`LagrangePoint::LastRow`].
    pub fn last_row(col_idx: usize) -> Self {
        Self::at(col_idx, LagrangePoint::LastRow)
    }

    /// Pins `col_idx` to [`LagrangePoint::FirstRow`].
    pub fn first_row(col_idx: usize) -> Self {
        Self::at(col_idx, LagrangePoint::FirstRow)
    }

    /// Pins `col_idx` to an explicit point given by `bits`.
    pub fn custom(col_idx: usize, bits: Vec<bool>) -> Self {
        Self::at(col_idx, LagrangePoint::Custom(bits))
    }
}
pub fn validate_lagrange_pins(
pins: &[LagrangePin],
num_columns: usize,
num_vars: Option<usize>,
) -> errors::Result<()> {
for (i, pin) in pins.iter().enumerate() {
if pin.col_idx >= num_columns {
return Err(errors::Error::Protocol {
protocol: "lagrange_pin",
message: "col_idx out of range",
});
}
if let (LagrangePoint::Custom(bits), Some(nv)) = (&pin.point, num_vars)
&& bits.len() != nv
{
return Err(errors::Error::Protocol {
protocol: "lagrange_pin",
message: "Custom point bit width != num_vars",
});
}
for prior in &pins[..i] {
if prior.col_idx == pin.col_idx {
return Err(errors::Error::Protocol {
protocol: "lagrange_pin",
message: "duplicate pin on same column",
});
}
}
}
Ok(())
}