use alloc::boxed::Box;
use alloc::string::{String, ToString};
use alloc::vec;
use alloc::vec::Vec;
use hekate_core::errors::Error;
use hekate_core::trace::{ColumnTrace, ColumnType, TraceCompatibleField};
use hekate_math::TowerField;
use hekate_math::{Flat, HardwareField, PackableField};
use hekate_program::Air;
use hekate_program::chiplet::CompositeChiplet;
use hekate_program::constraint::ConstraintAst;
use hekate_program::constraint::builder::ConstraintSystem;
use hekate_program::define_columns;
use hekate_program::expander::VirtualExpander;
use hekate_program::permutation::{PermutationCheckSpec, REQUEST_IDX_LABEL, Source};
use super::sbox_rom;
use super::{AES_BYTE_LABELS, ROT_MAP, SBOX_IN_LABELS, SBOX_OUT_LABELS};
/// Labels for the 32 AES-256 key-byte sources, indexed by byte position.
///
/// Both the round chiplet's key bus (`AesRound256Air::key_spec`) and the CPU-side
/// key link (`CpuAes256Unit::key_linking_spec`) attach label `i` to key byte `i`.
#[rustfmt::skip]
const AES256_KEY_LABELS: [&[u8]; 32] = [
    b"aes_key_byte_0",  b"aes_key_byte_1",  b"aes_key_byte_2",  b"aes_key_byte_3",
    b"aes_key_byte_4",  b"aes_key_byte_5",  b"aes_key_byte_6",  b"aes_key_byte_7",
    b"aes_key_byte_8",  b"aes_key_byte_9",  b"aes_key_byte_10", b"aes_key_byte_11",
    b"aes_key_byte_12", b"aes_key_byte_13", b"aes_key_byte_14", b"aes_key_byte_15",
    b"aes_key_byte_16", b"aes_key_byte_17", b"aes_key_byte_18", b"aes_key_byte_19",
    b"aes_key_byte_20", b"aes_key_byte_21", b"aes_key_byte_22", b"aes_key_byte_23",
    b"aes_key_byte_24", b"aes_key_byte_25", b"aes_key_byte_26", b"aes_key_byte_27",
    b"aes_key_byte_28", b"aes_key_byte_29", b"aes_key_byte_30", b"aes_key_byte_31",
];
// Physical (committed) column layout of the AES-256 round chiplet.
//
// Columns are numbered in declaration order (the tests pin P_STATE_IN = 0,
// P_ROUND_KEY = 32, P_KEY_AUX = 48), so fields must not be reordered.
// It mirrors `Aes256Columns` one-to-one, except that the key-schedule inverse
// witness is stored packed as 4 bytes (`P_KS_INV`) rather than 32 single bits.
define_columns! {
    pub PhysAes256Columns {
        // 16-byte blocks: round input state, S-box output, current round key,
        // and the auxiliary (previous) round-key block.
        P_STATE_IN: [B8; 16],
        P_SBOX_OUT: [B8; 16],
        P_ROUND_KEY: [B8; 16],
        P_KEY_AUX: [B8; 16],
        // Round counter and round constant.
        P_ROUND_NUM: B8,
        P_RCON: B8,
        // Row-type selector bits.
        P_S_ROUND: Bit,
        P_S_FINAL: Bit,
        P_S_IN_OUT: Bit,
        P_S_ACTIVE: Bit,
        P_S_INPUT: Bit,
        P_S_EVEN: Bit,
        // Full 32-byte cipher key.
        P_K0: [B8; 32],
        // Key-schedule S-box witnesses (input, output, packed inverse, zero flags).
        P_KS_INPUT: [B8; 4],
        P_KS_SUB: [B8; 4],
        P_KS_INV: [B8; 4],
        P_KS_Z: [Bit; 4],
        // Request indices carried on the link and key buses.
        P_REQUEST_IDX_LINK: B32,
        P_REQUEST_IDX_KEY: B32,
    }
}
// Virtual column layout seen by `constraint_ast` and the permutation specs.
//
// Produced from `PhysAes256Columns` by the chiplet's `VirtualExpander`; the only
// structural difference is `KS_INV_BITS: [Bit; 32]`, which replaces the packed
// `P_KS_INV: [B8; 4]` (presumably 8 bits per byte — confirm against the expander).
// Declaration order fixes the indices (tests pin STATE_IN = 0, SBOX_OUT = 16,
// ROUND_KEY = 32, KEY_AUX = 48), so fields must not be reordered.
define_columns! {
    pub Aes256Columns {
        STATE_IN: [B8; 16],
        SBOX_OUT: [B8; 16],
        ROUND_KEY: [B8; 16],
        KEY_AUX: [B8; 16],
        ROUND_NUM: B8,
        RCON: B8,
        S_ROUND: Bit,
        S_FINAL: Bit,
        S_IN_OUT: Bit,
        S_ACTIVE: Bit,
        S_INPUT: Bit,
        S_EVEN: Bit,
        K0: [B8; 32],
        KS_INPUT: [B8; 4],
        KS_SUB: [B8; 4],
        // Bit-decomposed view of the physical P_KS_INV bytes.
        KS_INV_BITS: [Bit; 32],
        KS_Z: [Bit; 4],
        REQUEST_IDX_LINK: B32,
        REQUEST_IDX_KEY: B32,
    }
}
/// AIR describing the AES-256 round chiplet (layout, buses and constraints).
///
/// The only state is the trace height; `AesRound256Air::for_constraints` builds a
/// zero-row instance for callers that only need the constraint system.
#[derive(Clone, Debug)]
pub struct AesRound256Air {
    /// Number of rows in the round trace this instance describes.
    pub num_rows: usize,
}
impl AesRound256Air {
pub const LINK_BUS_ID: &'static str = "aes256_link";
pub const KEY_BUS_ID: &'static str = "aes256_key_in";
pub(crate) fn new(num_rows: usize) -> Self {
Self { num_rows }
}
pub fn for_constraints() -> Self {
Self { num_rows: 0 }
}
pub fn link_spec() -> PermutationCheckSpec {
let mut sources: Vec<_> = (0..16)
.map(|i| {
(
Source::Column(Aes256Columns::STATE_IN + i),
AES_BYTE_LABELS[i],
)
})
.collect();
sources.push((
Source::Column(Aes256Columns::REQUEST_IDX_LINK),
REQUEST_IDX_LABEL,
));
PermutationCheckSpec::new(sources, Some(Aes256Columns::S_IN_OUT))
}
pub fn key_spec() -> PermutationCheckSpec {
let mut sources: Vec<_> = (0..32)
.map(|i| (Source::Column(Aes256Columns::K0 + i), AES256_KEY_LABELS[i]))
.collect();
sources.push((
Source::Column(Aes256Columns::REQUEST_IDX_KEY),
REQUEST_IDX_LABEL,
));
PermutationCheckSpec::new(sources, Some(Aes256Columns::S_INPUT))
}
pub fn sbox_specs() -> Vec<(String, PermutationCheckSpec)> {
let mut sources = Vec::with_capacity(32);
for i in 0..16 {
sources.push((
Source::Column(Aes256Columns::STATE_IN + i),
SBOX_IN_LABELS[i],
));
sources.push((
Source::Column(Aes256Columns::SBOX_OUT + i),
SBOX_OUT_LABELS[i],
));
}
let spec = PermutationCheckSpec::new(sources, Some(Aes256Columns::S_ACTIVE))
.with_clock_waiver(
"see hekate-chiplets/src/aes/aes256.rs: AES<>SboxRom internal; \
phantom blocks caught at link+key v3",
);
vec![(sbox_rom::SboxRomChiplet::BUS_ID.into(), spec)]
}
}
/// `Air` implementation for the AES-256 round chiplet: layout, buses, virtual-column
/// expansion and the algebraic constraints.
///
/// NOTE(review): equalities below are written as `a + b == 0`, which means `a == b`
/// only in characteristic 2; this assumes `F: TowerField` is a binary tower field,
/// and reads `cs.constrain(e)` as enforcing `e == 0` — confirm.
impl<F: TowerField> Air<F> for AesRound256Air {
    /// Fixed AIR name.
    fn name(&self) -> String {
        "AesRound256Air".to_string()
    }

    /// Physical column layout, built once from `PhysAes256Columns`.
    ///
    /// A `static` inside a generic impl is a single item shared by every `F`,
    /// which is fine here because the layout does not depend on `F`.
    fn column_layout(&self) -> &[ColumnType] {
        static LAYOUT: once_cell::race::OnceBox<Vec<ColumnType>> = once_cell::race::OnceBox::new();
        LAYOUT.get_or_init(|| Box::new(PhysAes256Columns::build_layout()))
    }

    /// The link bus, the key bus, then the S-box ROM lookups.
    fn permutation_checks(&self) -> Vec<(String, PermutationCheckSpec)> {
        let mut checks = Vec::with_capacity(3);
        checks.push((Self::LINK_BUS_ID.into(), Self::link_spec()));
        checks.push((Self::KEY_BUS_ID.into(), Self::key_spec()));
        checks.extend(Self::sbox_specs());
        checks
    }

    /// Maps the physical `PhysAes256Columns` onto the virtual `Aes256Columns`
    /// used by `constraint_ast`; built once and cached in a `static`.
    fn virtual_expander(&self) -> Option<&VirtualExpander> {
        static E: once_cell::race::OnceBox<VirtualExpander> = once_cell::race::OnceBox::new();
        Some(E.get_or_init(|| {
            Box::new(
                VirtualExpander::new()
                    // STATE_IN, SBOX_OUT, ROUND_KEY, KEY_AUX (4 x 16) + ROUND_NUM + RCON
                    .pass_through(66, ColumnType::B8)
                    // S_ROUND, S_FINAL, S_IN_OUT, S_ACTIVE, S_INPUT, S_EVEN
                    .control_bits(6)
                    // K0
                    .pass_through(32, ColumnType::B8)
                    // KS_INPUT
                    .pass_through(4, ColumnType::B8)
                    // KS_SUB
                    .pass_through(4, ColumnType::B8)
                    // P_KS_INV (4 bytes) -> KS_INV_BITS (32 bits)
                    .expand_bits(4, ColumnType::B8)
                    // KS_Z
                    .control_bits(4)
                    // REQUEST_IDX_LINK, REQUEST_IDX_KEY
                    .pass_through(2, ColumnType::B32)
                    .build()
                    .expect("AesRound256Air expander"),
            )
        }))
    }

    /// Builds the constraint system. Root order and count are pinned by the
    /// `constraint_count` test, so statement order matters.
    // The loops index by `j`/`i` into column offsets and lookup tables (e.g.
    // `ROT_MAP[j]`) at the same time, so plain range loops are kept on purpose.
    #[allow(clippy::needless_range_loop)]
    fn constraint_ast(&self) -> ConstraintAst<F> {
        let cs = ConstraintSystem::<F>::new();

        // --- Row-type selectors ---
        let s_round = cs.col(Aes256Columns::S_ROUND);
        let s_final = cs.col(Aes256Columns::S_FINAL);
        let s_in_out = cs.col(Aes256Columns::S_IN_OUT);
        let s_active = cs.col(Aes256Columns::S_ACTIVE);
        cs.assert_boolean(s_round);
        cs.assert_boolean(s_final);
        cs.assert_boolean(s_in_out);
        cs.assert_boolean(s_active);
        // A row is never both a regular round and the final round.
        cs.constrain(s_round * s_final);
        // s_active = s_round + s_final + s_round * s_final, i.e. s_active = s_round OR s_final.
        cs.constrain(s_active + s_round + s_final + s_round * s_final);
        // On an active row, next_s_active + next_s_in_out must equal 1.
        let next_s_active = cs.next(Aes256Columns::S_ACTIVE);
        let next_s_in_out = cs.next(Aes256Columns::S_IN_OUT);
        cs.constrain(s_active * (cs.one() + next_s_active + next_s_in_out));

        // --- Round counter, parity flag and round constant ---
        let round_num = cs.col(Aes256Columns::ROUND_NUM);
        let next_round_num = cs.next(Aes256Columns::ROUND_NUM);
        let s_input = cs.col(Aes256Columns::S_INPUT);
        let s_even = cs.col(Aes256Columns::S_EVEN);
        let rcon = cs.col(Aes256Columns::RCON);
        let next_rcon = cs.next(Aes256Columns::RCON);
        let one = cs.one();
        let two = cs.constant(F::from(2u8));
        cs.assert_boolean(s_input);
        cs.assert_boolean(s_even);
        // s_input = s_in_out * s_round: the input row is an in/out row that is also a round row.
        cs.constrain(s_input + s_in_out * s_round);
        // round_num is 1 on the input row and is multiplied by F::from(2) across each round row.
        cs.assert_zero_when(s_input, round_num + one);
        cs.assert_zero_when(s_round, next_round_num + two * round_num);
        // On the final row round_num must equal 0x4D. 0x4D is 2^13 in the AES byte field
        // GF(2^8)/(x^8+x^4+x^3+x+1), consistent with 13 doubling round rows before it.
        // NOTE(review): assumes F::from(u8) embeds bytes so that multiplication matches
        // that field — confirm.
        cs.assert_zero_when(s_final, round_num + cs.constant(F::from(0x4Du8)));
        // s_even is 1 on the input row and flips across every round row.
        cs.assert_zero_when(s_input, s_even + one);
        cs.assert_zero_when(s_round, cs.next(Aes256Columns::S_EVEN) + s_even + one);
        // rcon is 1 on the input row, carried unchanged out of s_even rounds and
        // multiplied by F::from(2) out of the other rounds.
        cs.assert_zero_when(s_input, rcon + one);
        cs.assert_zero_when(s_round * s_even, next_rcon + rcon);
        cs.assert_zero_when(s_round * (one + s_even), next_rcon + two * rcon);

        // --- Round function (shared helper in the parent module) ---
        super::build_round_constraints(
            &cs,
            Aes256Columns::STATE_IN,
            Aes256Columns::SBOX_OUT,
            Aes256Columns::ROUND_KEY,
            Aes256Columns::S_ROUND,
            Aes256Columns::S_FINAL,
        );

        // --- Key-schedule S-box input ---
        let s_round_even = s_round * s_even;
        let s_round_odd = s_round * (one + s_even);
        for j in 0..4usize {
            let ks_in = cs.col(Aes256Columns::KS_INPUT + j);
            // s_even rounds read ROUND_KEY through ROT_MAP (presumably RotWord of the
            // last key word — confirm); other rounds read ROUND_KEY[12 + j] directly.
            let rot = cs.col(Aes256Columns::ROUND_KEY + ROT_MAP[j]);
            let direct = cs.col(Aes256Columns::ROUND_KEY + 12 + j);
            cs.assert_zero_when(s_round_even, ks_in + rot);
            cs.assert_zero_when(s_round_odd, ks_in + direct);
        }
        // KS_INPUT -> KS_SUB substitution, witnessed by KS_INV_BITS and KS_Z and gated
        // by S_ROUND (helper in the parent module).
        super::build_sbox_inversion_constraints(
            &cs,
            core::array::from_fn(|j| Aes256Columns::KS_INPUT + j),
            Aes256Columns::KS_SUB,
            Aes256Columns::KS_INV_BITS,
            Aes256Columns::KS_Z,
            Aes256Columns::S_ROUND,
        );

        // --- Next round key (presumably the AES-256 key expansion, FIPS-197) ---
        // Bytes 0..4: next ROUND_KEY = KEY_AUX + KS_SUB (+ rcon on byte 0 when s_even).
        // Bytes 4..16: next ROUND_KEY[j] = KEY_AUX[j] + next ROUND_KEY[j - 4].
        for j in 0..16usize {
            let next_rk = cs.next(Aes256Columns::ROUND_KEY + j);
            let aux = cs.col(Aes256Columns::KEY_AUX + j);
            let body = match j {
                0 => next_rk + aux + cs.col(Aes256Columns::KS_SUB) + s_even * rcon,
                1..=3 => next_rk + aux + cs.col(Aes256Columns::KS_SUB + j),
                4..=15 => next_rk + aux + cs.next(Aes256Columns::ROUND_KEY + j - 4),
                _ => unreachable!(),
            };
            cs.assert_zero_when(s_round, body);
        }
        // On round rows the current ROUND_KEY becomes the next row's KEY_AUX.
        for j in 0..16usize {
            cs.assert_zero_when(
                s_round,
                cs.next(Aes256Columns::KEY_AUX + j) + cs.col(Aes256Columns::ROUND_KEY + j),
            );
        }
        // Input row: ROUND_KEY = K0[16..32] and KEY_AUX = K0[0..16].
        for j in 0..16usize {
            cs.assert_zero_when(
                s_input,
                cs.col(Aes256Columns::ROUND_KEY + j) + cs.col(Aes256Columns::K0 + 16 + j),
            );
            cs.assert_zero_when(
                s_input,
                cs.col(Aes256Columns::KEY_AUX + j) + cs.col(Aes256Columns::K0 + j),
            );
        }

        // --- Non-round rows: force s_even, KS_Z and the packed inverse bytes to zero ---
        let not_s_round = one + s_round;
        cs.assert_zero_when(not_s_round, s_even);
        for i in 0..4 {
            cs.assert_zero_when(not_s_round, cs.col(Aes256Columns::KS_Z + i));
            // Recombine the 8 bits of inverse byte `i`, bit k scaled by F::from(1 << k).
            let ks_inv_byte = cs.sum(
                &(0..8)
                    .map(|k| {
                        cs.scale(
                            F::from(1u8 << k),
                            cs.col(Aes256Columns::KS_INV_BITS + i * 8 + k),
                        )
                    })
                    .collect::<Vec<_>>(),
            );
            cs.assert_zero_when(not_s_round, ks_inv_byte);
        }
        cs.build()
    }
}
// CPU-side columns that feed the AES-256 buses.
//
// `KEY` (32 bytes) gated by `KEY_SELECTOR` is the source of
// `CpuAes256Unit::key_linking_spec`; `DATA` (16 bytes) gated by `SELECTOR` is the
// source of `CpuAes256Unit::linking_spec`. Declaration order fixes the column
// indices, so fields must not be reordered.
define_columns! {
    pub CpuAes256Columns {
        KEY: [B8; 32],
        KEY_SELECTOR: Bit,
        DATA: [B8; 16],
        SELECTOR: Bit,
    }
}
/// CPU-side endpoint of the AES-256 buses: the column count and the
/// permutation-check specs over `CpuAes256Columns`.
pub struct CpuAes256Unit;

impl CpuAes256Unit {
    /// Number of columns declared in `CpuAes256Columns`.
    pub fn num_columns() -> usize {
        CpuAes256Columns::NUM_COLUMNS
    }

    /// Data link: the 16 `DATA` bytes (labelled with `AES_BYTE_LABELS`) followed by
    /// the row index as `Source::RowIndexLeBytes(4)`, gated by `SELECTOR`.
    pub fn linking_spec() -> PermutationCheckSpec {
        let mut entries = Vec::with_capacity(17);
        for byte in 0..16 {
            let data = Source::Column(CpuAes256Columns::DATA + byte);
            entries.push((data, AES_BYTE_LABELS[byte]));
        }
        entries.push((Source::RowIndexLeBytes(4), REQUEST_IDX_LABEL));
        PermutationCheckSpec::new(entries, Some(CpuAes256Columns::SELECTOR))
    }

    /// Key link: the 32 `KEY` bytes (labelled with `AES256_KEY_LABELS`) followed by
    /// the row index as `Source::RowIndexLeBytes(4)`, gated by `KEY_SELECTOR`.
    pub fn key_linking_spec() -> PermutationCheckSpec {
        let mut entries: Vec<_> = AES256_KEY_LABELS
            .iter()
            .enumerate()
            .map(|(byte, &label)| (Source::Column(CpuAes256Columns::KEY + byte), label))
            .collect();
        entries.push((Source::RowIndexLeBytes(4), REQUEST_IDX_LABEL));
        PermutationCheckSpec::new(entries, Some(CpuAes256Columns::KEY_SELECTOR))
    }
}
/// AES-256 composite chiplet: the `AesRound256Air` round AIR plus its S-box ROM,
/// bundled into a single `CompositeChiplet`.
#[derive(Clone)]
pub struct Aes256Chiplet<F: TraceCompatibleField> {
    /// Composite holding the round AIR and the S-box ROM chiplet.
    composite: CompositeChiplet<F>,
    /// Row count of the AES round trace.
    num_rows: usize,
    /// Row count of the S-box ROM trace.
    sbox_rom_rows: usize,
}
impl<F> Aes256Chiplet<F>
where
    F: TowerField + TraceCompatibleField + PackableField + HardwareField + 'static,
    <F as PackableField>::Packed: Copy + Send + Sync,
    Flat<F>: Send + Sync,
{
    /// Builds the composite from an `AesRound256Air` of `num_rows` rows and an
    /// S-box ROM of `sbox_rom_rows` rows, exposing the link and key buses externally.
    ///
    /// # Errors
    ///
    /// - [`Error::Protocol`] if `num_rows` is not a power of two (this includes 0).
    /// - Any error returned by `SboxRomChiplet::new` or by the composite builder.
    pub fn new(num_rows: usize, sbox_rom_rows: usize) -> Result<Self, Error> {
        if !num_rows.is_power_of_two() {
            return Err(Error::Protocol {
                protocol: "aes256_chiplet",
                message: "num_rows must be power of 2",
            });
        }
        let round_air = AesRound256Air::new(num_rows);
        let sbox_rom = sbox_rom::SboxRomChiplet::new(sbox_rom_rows)?;
        let composite = CompositeChiplet::<F>::builder("aes256")
            .chiplet(round_air)
            .chiplet(sbox_rom)
            .external_bus(AesRound256Air::LINK_BUS_ID, AesRound256Air::link_spec())
            .external_bus(AesRound256Air::KEY_BUS_ID, AesRound256Air::key_spec())
            .build()?;
        Ok(Self {
            composite,
            num_rows,
            sbox_rom_rows,
        })
    }

    /// Borrows the underlying composite chiplet.
    pub fn composite(&self) -> &CompositeChiplet<F> {
        &self.composite
    }

    /// Generates `[aes_round_trace, sbox_rom_trace]` for `calls`.
    ///
    /// Every row with `S_ACTIVE` set contributes one `SboxRound` to the S-box ROM
    /// trace: its 16 `STATE_IN` bytes as inputs and 16 `SBOX_OUT` bytes as outputs.
    ///
    /// # Errors
    ///
    /// Propagates errors from trace generation. Returns [`Error::Protocol`] if the
    /// `S_ACTIVE`, `STATE_IN` or `SBOX_OUT` columns do not have the expected
    /// bit/byte column types.
    pub fn generate_traces(
        &self,
        calls: &[super::trace::Aes256Call],
    ) -> Result<Vec<ColumnTrace>, Error> {
        let aes_trace = super::trace::generate_aes_trace(calls, None, self.num_rows)?;

        let s_active = aes_trace.columns[PhysAes256Columns::P_S_ACTIVE]
            .as_bit_slice()
            .ok_or(Error::Protocol {
                protocol: "aes256_chiplet",
                message: "S_ACTIVE column type mismatch",
            })?;
        // Resolve the 16 byte columns of each block once, up front. Previously they
        // were re-fetched and `unwrap`ped for every active row, which both repeated
        // loop-invariant work and panicked on a type mismatch.
        let state_in = (0..16)
            .map(|j| aes_trace.columns[PhysAes256Columns::P_STATE_IN + j].as_b8_slice())
            .collect::<Option<Vec<_>>>()
            .ok_or(Error::Protocol {
                protocol: "aes256_chiplet",
                message: "STATE_IN column type mismatch",
            })?;
        let sbox_out = (0..16)
            .map(|j| aes_trace.columns[PhysAes256Columns::P_SBOX_OUT + j].as_b8_slice())
            .collect::<Option<Vec<_>>>()
            .ok_or(Error::Protocol {
                protocol: "aes256_chiplet",
                message: "SBOX_OUT column type mismatch",
            })?;

        let sbox_rounds: Vec<_> = s_active
            .iter()
            .enumerate()
            .filter(|&(_, &active)| active == hekate_math::Bit::ONE)
            .map(|(row, _)| sbox_rom::SboxRound {
                inputs: core::array::from_fn(|j| state_in[j][row].to_tower().0),
                outputs: core::array::from_fn(|j| sbox_out[j][row].to_tower().0),
            })
            .collect();

        let sbox_trace = sbox_rom::generate_sbox_rom_trace(&sbox_rounds, self.sbox_rom_rows)?;
        Ok(vec![aes_trace, sbox_trace])
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use hekate_math::Block128;

    type F = Block128;

    /// The physical layout has NUM_COLUMNS entries and the leading byte blocks
    /// sit at their expected offsets.
    #[test]
    fn physical_column_count() {
        let layout_len = PhysAes256Columns::build_layout().len();
        assert_eq!(layout_len, PhysAes256Columns::NUM_COLUMNS);
        for (offset, expected) in [
            (PhysAes256Columns::P_STATE_IN, 0),
            (PhysAes256Columns::P_ROUND_KEY, 32),
            (PhysAes256Columns::P_KEY_AUX, 48),
        ] {
            assert_eq!(offset, expected);
        }
    }

    /// The virtual layout starts with four consecutive 16-byte blocks.
    #[test]
    fn virtual_column_count() {
        for (offset, expected) in [
            (Aes256Columns::STATE_IN, 0),
            (Aes256Columns::SBOX_OUT, 16),
            (Aes256Columns::ROUND_KEY, 32),
            (Aes256Columns::KEY_AUX, 48),
        ] {
            assert_eq!(offset, expected);
        }
    }

    /// Pins the number of constraint roots so accidental additions or removals are caught.
    #[test]
    fn constraint_count() {
        let air = AesRound256Air::for_constraints();
        let ast: ConstraintAst<F> = air.constraint_ast();
        assert_eq!(ast.roots.len(), 183);
    }

    #[test]
    fn link_spec_structure() {
        let link = AesRound256Air::link_spec();
        assert_eq!(link.num_sources(), 17);
        assert_eq!(link.selector, Some(Aes256Columns::S_IN_OUT));
        assert_eq!(link.sources[16].1, REQUEST_IDX_LABEL);
    }

    #[test]
    fn key_spec_structure() {
        let key = AesRound256Air::key_spec();
        assert_eq!(key.num_sources(), 33);
        assert_eq!(key.selector, Some(Aes256Columns::S_INPUT));
        assert_eq!(key.sources[32].1, REQUEST_IDX_LABEL);
    }

    #[test]
    fn sbox_specs_structure() {
        let all = AesRound256Air::sbox_specs();
        assert_eq!(all.len(), 1);
        let (bus_id, spec) = &all[0];
        assert_eq!(bus_id, sbox_rom::SboxRomChiplet::BUS_ID);
        assert_eq!(spec.num_sources(), 32);
        assert_eq!(spec.selector, Some(Aes256Columns::S_ACTIVE));
    }

    /// The expander maps exactly the physical columns onto exactly the virtual ones.
    #[test]
    fn virtual_expander_dimensions() {
        let air = AesRound256Air::for_constraints();
        let expander = Air::<F>::virtual_expander(&air).expect("expander must exist");
        assert_eq!(expander.num_physical_columns(), PhysAes256Columns::NUM_COLUMNS);
        assert_eq!(expander.num_virtual_columns(), Aes256Columns::NUM_COLUMNS);
    }

    #[test]
    fn composite_builds() {
        let chiplet = Aes256Chiplet::<F>::new(16, 256).unwrap();
        let defs = chiplet.composite().flatten_defs().unwrap();
        assert_eq!(defs.len(), 2);
    }

    /// A non-power-of-two round count and an invalid S-box ROM size are both rejected.
    #[test]
    fn new_validates() {
        let bad_rows = Aes256Chiplet::<F>::new(100, 256);
        let bad_rom = Aes256Chiplet::<F>::new(16, 7);
        let good = Aes256Chiplet::<F>::new(16, 16);
        assert!(bad_rows.is_err());
        assert!(bad_rom.is_err());
        assert!(good.is_ok());
    }
}