1#![cfg_attr(not(feature = "std"), no_std)]
19
20extern crate alloc;
21extern crate core;
22
23use alloc::string::{String, ToString};
24use alloc::vec::Vec;
25use constraint::{BoundaryConstraint, Constraint, ConstraintAst};
26use core::marker::PhantomData;
27use expander::VirtualExpander;
28use hekate_core::errors;
29use hekate_core::trace::{ColumnTrace, ColumnType, Trace, TraceCompatibleField};
30use hekate_math::{Flat, HardwareField, TowerField};
31use permutation::PermutationCheckSpec;
32
33pub mod chiplet;
34pub mod constraint;
35pub mod expander;
36pub mod permutation;
37pub mod schema;
38
39pub trait Air<F: TowerField>: Sized + Clone + Sync {
50 fn name(&self) -> String {
51 "HekateAir".to_string()
52 }
53
54 fn num_columns(&self) -> usize {
55 self.virtual_column_layout().len()
56 }
57
58 fn constraints(&self) -> Vec<Constraint<F>> {
60 self.constraint_ast().to_constraints()
61 }
62
63 fn boundary_constraints(&self) -> Vec<BoundaryConstraint<F>> {
67 Vec::new()
68 }
69
70 fn column_layout(&self) -> &[ColumnType];
76
77 fn virtual_column_layout(&self) -> &[ColumnType] {
81 match self.virtual_expander() {
82 Some(e) => e.virtual_layout(),
83 None => self.column_layout(),
84 }
85 }
86
87 fn permutation_checks(&self) -> Vec<(String, PermutationCheckSpec)> {
99 Vec::new()
100 }
101
102 fn fixed_columns(&self) -> Vec<FixedColumn<F>> {
106 Vec::new()
107 }
108
109 fn virtual_expander(&self) -> Option<&VirtualExpander> {
112 None
113 }
114
115 fn parse_virtual_row(&self, bytes: &[u8], res: &mut Vec<Flat<F>>)
124 where
125 F: TraceCompatibleField,
126 {
127 res.clear();
128
129 if let Some(e) = self.virtual_expander() {
130 e.parse_row(bytes, res)
131 .expect("committed row byte length must match physical_row_bytes");
132 return;
133 }
134
135 let mut offset = 0;
136 for col_type in self.column_layout() {
137 let size = col_type.byte_size();
138 if offset + size <= bytes.len() {
139 res.push(col_type.parse_from_bytes(&bytes[offset..offset + size]));
140 offset += size;
141 }
142 }
143 }
144
145 fn constraint_ast(&self) -> ConstraintAst<F>;
147
148 fn inline_chiplets(&self) -> errors::Result<Vec<chiplet::ChipletDef<F>>> {
150 Ok(Vec::new())
151 }
152
153 fn inline_chiplet_kernels(&self) -> Vec<InlineKernelHint> {
155 Vec::new()
156 }
157}
158
159pub trait Program<F: TowerField>: Air<F> {
170 fn num_public_inputs(&self) -> usize {
172 0
173 }
174
175 fn chiplet_defs(&self) -> errors::Result<Vec<chiplet::ChipletDef<F>>> {
180 Ok(Vec::new())
181 }
182}
183
/// A reference to one trace cell: a column index plus a flag selecting the
/// current row or the next row.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ProgramCell {
    /// Column index of the referenced cell.
    pub col_idx: usize,

    /// `true` when the cell is read from the next row, `false` for the
    /// current row.
    pub next_row: bool,
}

impl ProgramCell {
    /// Cell in column `col_idx` of the current row.
    ///
    /// This is a `const fn`, so it can be used to build cells in constant
    /// contexts.
    pub const fn current(col_idx: usize) -> Self {
        Self {
            col_idx,
            next_row: false,
        }
    }

    /// Cell in column `col_idx` of the next row.
    ///
    /// This is a `const fn`, so it can be used to build cells in constant
    /// contexts.
    pub const fn next(col_idx: usize) -> Self {
        Self {
            col_idx,
            next_row: true,
        }
    }
}
213
214#[derive(Clone, Debug)]
221pub struct ProgramInstance<F: TowerField> {
222 num_rows: usize,
223 public_inputs: Vec<F>,
224}
225
226impl<F: TowerField> ProgramInstance<F> {
227 pub fn new(num_rows: usize, public_inputs: Vec<F>) -> Self {
228 assert!(
229 num_rows.is_power_of_two(),
230 "Program trace height must be power of 2"
231 );
232
233 Self {
234 num_rows,
235 public_inputs,
236 }
237 }
238
239 #[inline(always)]
240 pub fn num_rows(&self) -> usize {
241 self.num_rows
242 }
243
244 #[inline(always)]
246 pub fn public_inputs(&self) -> &[F] {
247 &self.public_inputs
248 }
249
250 #[inline(always)]
251 pub fn public_input(&self, idx: usize) -> Option<F> {
252 self.public_inputs.get(idx).copied()
253 }
254}
255
256pub struct ProgramWitness<F: TowerField, T: Trace = ColumnTrace> {
260 pub trace: T,
261 pub chiplet_traces: Vec<ColumnTrace>,
262 _marker: PhantomData<F>,
263}
264
265impl<F: TowerField, T: Trace> ProgramWitness<F, T> {
266 pub fn new(trace: T) -> Self {
267 Self {
268 trace,
269 chiplet_traces: Vec::new(),
270 _marker: PhantomData,
271 }
272 }
273
274 pub fn with_chiplets(mut self, chiplet_traces: Vec<ColumnTrace>) -> Self {
277 self.chiplet_traces = chiplet_traces;
278 self
279 }
280}
281
/// Placement hint for an inline chiplet kernel.
///
/// Instances are returned by `Air::inline_chiplet_kernels`. `PartialEq`,
/// `Eq` and `Hash` are derived so hints can be compared, deduplicated and
/// used as map keys.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct InlineKernelHint {
    /// Index of the chiplet this hint applies to. Presumably this indexes
    /// the list from `Air::inline_chiplets` (TODO confirm).
    pub chiplet_idx: usize,

    /// Root offset for the kernel. The exact meaning is defined by the
    /// consumer (TODO confirm).
    pub root_offset: usize,

    /// Column offset for the kernel. The exact meaning is defined by the
    /// consumer (TODO confirm).
    pub column_offset: usize,
}
298
299#[derive(Clone, Debug, PartialEq, Eq)]
308pub enum FixedShape<F> {
309 LastRow,
310 FirstRow,
311 Custom(Vec<bool>),
312 Periodic { period: usize, values: Vec<F> },
313 Sparse(Vec<(usize, F)>),
314 Dense(Vec<F>),
315}
316
317impl<F: HardwareField> FixedShape<F> {
318 pub fn evaluate(&self, r: &[Flat<F>]) -> Flat<F> {
321 let one = Flat::from_raw(F::ONE);
322 match self {
323 FixedShape::LastRow => {
324 let mut prod = one;
325 for &r_k in r {
326 prod *= r_k;
327 }
328
329 one - prod
330 }
331 FixedShape::FirstRow => {
332 let mut prod = one;
333 for &r_k in r {
334 prod *= one - r_k;
335 }
336
337 prod
338 }
339 FixedShape::Custom(bits) => {
340 debug_assert_eq!(bits.len(), r.len(), "Custom point bit width != r.len()");
341
342 let mut prod = one;
343 for (k, &b) in bits.iter().enumerate() {
344 let factor = if b { r[k] } else { one - r[k] };
345 prod *= factor;
346 }
347
348 prod
349 }
350 FixedShape::Periodic { period, values } => {
351 let p = period.trailing_zeros() as usize;
354
355 let mut acc = Flat::from_raw(F::ZERO);
356 for (j, &v) in values.iter().enumerate() {
357 acc += v.to_hardware() * eq_index(&r[..p], j);
358 }
359
360 acc
361 }
362 FixedShape::Sparse(entries) => {
363 let mut acc = Flat::from_raw(F::ZERO);
364 for &(row, v) in entries {
365 acc += v.to_hardware() * eq_index(r, row);
366 }
367
368 acc
369 }
370 FixedShape::Dense(values) => {
371 let mut acc = Flat::from_raw(F::ZERO);
372 for (i, &v) in values.iter().enumerate() {
373 acc += v.to_hardware() * eq_index(r, i);
374 }
375
376 acc
377 }
378 }
379 }
380
381 pub fn value_at_row(&self, row: usize, num_vars: usize) -> Flat<F> {
385 let one = Flat::from_raw(F::ONE);
386 let zero = Flat::from_raw(F::ZERO);
387
388 match self {
389 FixedShape::FirstRow => {
390 if row == 0 {
391 one
392 } else {
393 zero
394 }
395 }
396 FixedShape::LastRow => {
397 if row == (1usize << num_vars) - 1 {
398 zero
399 } else {
400 one
401 }
402 }
403 FixedShape::Custom(bits) => {
404 let target = bits
405 .iter()
406 .enumerate()
407 .fold(0usize, |acc, (k, &b)| acc | ((b as usize) << k));
408
409 if row == target { one } else { zero }
410 }
411 FixedShape::Periodic { period, values } => values[row % period].to_hardware(),
412 FixedShape::Sparse(entries) => {
413 let mut acc = zero;
414 for &(r, v) in entries {
415 if r == row {
416 acc += v.to_hardware();
417 }
418 }
419
420 acc
421 }
422 FixedShape::Dense(values) => values[row].to_hardware(),
423 }
424 }
425}
426
427fn eq_index<F: HardwareField>(r: &[Flat<F>], index: usize) -> Flat<F> {
428 let one = Flat::from_raw(F::ONE);
429
430 let mut prod = one;
431 for (k, &r_k) in r.iter().enumerate() {
432 let factor = if (index >> k) & 1 == 1 {
433 r_k
434 } else {
435 one - r_k
436 };
437 prod *= factor;
438 }
439
440 prod
441}
442
443#[derive(Clone, Debug, PartialEq, Eq)]
445pub struct FixedColumn<F> {
446 pub col_idx: usize,
447 pub shape: FixedShape<F>,
448}
449
450impl<F> FixedColumn<F> {
451 pub fn last_row(col_idx: usize) -> Self {
452 Self {
453 col_idx,
454 shape: FixedShape::LastRow,
455 }
456 }
457
458 pub fn first_row(col_idx: usize) -> Self {
459 Self {
460 col_idx,
461 shape: FixedShape::FirstRow,
462 }
463 }
464
465 pub fn custom(col_idx: usize, bits: Vec<bool>) -> Self {
466 Self {
467 col_idx,
468 shape: FixedShape::Custom(bits),
469 }
470 }
471
472 pub fn periodic(col_idx: usize, period: usize, values: Vec<F>) -> Self {
473 Self {
474 col_idx,
475 shape: FixedShape::Periodic { period, values },
476 }
477 }
478
479 pub fn sparse(col_idx: usize, entries: Vec<(usize, F)>) -> Self {
480 Self {
481 col_idx,
482 shape: FixedShape::Sparse(entries),
483 }
484 }
485
486 pub fn dense(col_idx: usize, values: Vec<F>) -> Self {
487 Self {
488 col_idx,
489 shape: FixedShape::Dense(values),
490 }
491 }
492}
493
494pub fn fix<F>(col_idx: usize, shape: FixedShape<F>) -> FixedColumn<F> {
496 FixedColumn { col_idx, shape }
497}
498
499pub fn validate_fixed_columns<F: TowerField>(
503 fixed: &[FixedColumn<F>],
504 layout: &[ColumnType],
505 num_vars: Option<usize>,
506) -> errors::Result<()> {
507 for (i, fc) in fixed.iter().enumerate() {
508 if fc.col_idx >= layout.len() {
509 return Err(errors::Error::Protocol {
510 protocol: "fixed_column",
511 message: "col_idx out of range",
512 });
513 }
514
515 validate_shape(&fc.shape, layout[fc.col_idx], num_vars)?;
516
517 for prior in &fixed[..i] {
518 if prior.col_idx == fc.col_idx {
519 return Err(errors::Error::Protocol {
520 protocol: "fixed_column",
521 message: "duplicate pin on same column",
522 });
523 }
524 }
525 }
526
527 Ok(())
528}
529
530fn validate_shape<F: TowerField>(
531 shape: &FixedShape<F>,
532 col_type: ColumnType,
533 num_vars: Option<usize>,
534) -> errors::Result<()> {
535 match shape {
536 FixedShape::LastRow | FixedShape::FirstRow => Ok(()),
537 FixedShape::Custom(bits) => match num_vars {
538 Some(nv) if bits.len() != nv => Err(errors::Error::Protocol {
539 protocol: "fixed_column",
540 message: "Custom point bit width != num_vars",
541 }),
542 _ => Ok(()),
543 },
544 FixedShape::Periodic { period, values } => {
545 if !period.is_power_of_two() {
546 return Err(errors::Error::Protocol {
547 protocol: "fixed_column",
548 message: "Periodic period must be a power of two",
549 });
550 }
551
552 if values.len() != *period {
553 return Err(errors::Error::Protocol {
554 protocol: "fixed_column",
555 message: "Periodic values length != period",
556 });
557 }
558
559 if let Some(nv) = num_vars
560 && *period > (1usize << nv)
561 {
562 return Err(errors::Error::Protocol {
563 protocol: "fixed_column",
564 message: "Periodic period exceeds trace height",
565 });
566 }
567
568 check_bit_domain(values.iter().copied(), col_type)
569 }
570 FixedShape::Sparse(entries) => {
571 if let Some(nv) = num_vars {
572 let n = 1usize << nv;
573 for &(row, _) in entries {
574 if row >= n {
575 return Err(errors::Error::Protocol {
576 protocol: "fixed_column",
577 message: "Sparse row index exceeds trace height",
578 });
579 }
580 }
581 }
582
583 for (i, &(row, _)) in entries.iter().enumerate() {
584 if entries[..i].iter().any(|&(prior, _)| prior == row) {
585 return Err(errors::Error::Protocol {
586 protocol: "fixed_column",
587 message: "duplicate Sparse row",
588 });
589 }
590 }
591
592 check_bit_domain(entries.iter().map(|&(_, v)| v), col_type)
593 }
594 FixedShape::Dense(values) => {
595 if let Some(nv) = num_vars
596 && values.len() != (1usize << nv)
597 {
598 return Err(errors::Error::Protocol {
599 protocol: "fixed_column",
600 message: "Dense values length != trace height",
601 });
602 }
603
604 check_bit_domain(values.iter().copied(), col_type)
605 }
606 }
607}
608
609fn check_bit_domain<F: TowerField>(
610 values: impl Iterator<Item = F>,
611 col_type: ColumnType,
612) -> errors::Result<()> {
613 if col_type != ColumnType::Bit {
614 return Ok(());
615 }
616
617 for v in values {
618 if v != F::ZERO && v != F::ONE {
619 return Err(errors::Error::Protocol {
620 protocol: "fixed_column",
621 message: "Bit fixed column value not in {0,1}",
622 });
623 }
624 }
625
626 Ok(())
627}