mod arithmetic;
mod ctrl;
mod trace;
mod witness;
pub use ctrl::MlKemCtrlColumns;
use super::basemul::BasemulChiplet;
use super::ntt::NttChiplet;
use super::twiddle_rom::TwiddleRomChiplet;
use alloc::vec;
use ctrl::MlKemCtrlChiplet;
use hekate_core::trace::TraceCompatibleField;
use hekate_gadgets::chiplets::ram::RamChiplet;
use hekate_keccak::KeccakChiplet;
use hekate_math::{Flat, HardwareField, PackableField, TowerField};
use hekate_program::chiplet::CompositeChiplet;
use hekate_program::define_columns;
use hekate_program::permutation::{PermutationCheckSpec, REQUEST_IDX_LABEL, Source};
/// Prime modulus `q` of the ML-KEM ring `Z_q[X]/(X^256 + 1)` (FIPS 203).
pub const MLKEM_Q: u32 = 3329;
/// Bits needed for a reduced coefficient: `2^11 < q < 2^12`, so 12 bits suffice.
pub const MLKEM_BIT_WIDTH: usize = 12;

/// Identifier of the external data bus attached to the ML-KEM composite.
pub const MLKEM_DATA_BUS_ID: &str = "ml_kem_data";
/// Identifier of the external shared-secret bus attached to the ML-KEM composite.
pub const MLKEM_SS_BUS_ID: &str = "ml_kem_ss";
/// Bus id for binding Keccak inputs.
/// NOTE(review): private and unused by the code shown here; presumably
/// consumed by a child module — confirm.
const KEC_INPUT_BIND_BUS_ID: &str = "kec_input_bind";

/// Number of coefficients per polynomial (the ring degree `n` in FIPS 203).
const N: usize = 256;
/// An ML-KEM parameter set, using the parameter names of FIPS 203.
#[derive(Clone, Copy, Debug)]
pub struct MlKemLevel {
    /// Module rank `k`: number of polynomials in each vector.
    pub k: usize,
    /// `eta_1`: centered-binomial parameter for the key-generation secret/noise
    /// and the encryption vector `y`.
    pub eta1: usize,
    /// `eta_2`: centered-binomial parameter for the encryption noise `e1`, `e2`.
    pub eta2: usize,
    /// `d_u`: bits per coefficient of the compressed ciphertext vector `u`.
    pub du: usize,
    /// `d_v`: bits per coefficient of the compressed ciphertext polynomial `v`.
    pub dv: usize,
}

impl MlKemLevel {
    /// ML-KEM-512 (security category 1).
    pub const MLKEM_512: Self = Self {
        k: 2,
        eta1: 3,
        eta2: 2,
        du: 10,
        dv: 4,
    };

    /// ML-KEM-768 (security category 3).
    pub const MLKEM_768: Self = Self {
        k: 3,
        eta1: 2,
        eta2: 2,
        du: 10,
        dv: 4,
    };

    /// ML-KEM-1024 (security category 5).
    pub const MLKEM_1024: Self = Self {
        k: 4,
        eta1: 2,
        eta2: 2,
        du: 11,
        dv: 5,
    };

    /// Byte length of the decapsulation key `dk = dk_PKE || ek || H(ek) || z`.
    ///
    /// `dk_PKE` is `k` polynomials at 12 bits per coefficient; `H(ek)` and `z`
    /// are 32 bytes each.
    pub fn sk_bytes(&self) -> usize {
        let pke_secret = self.k * 12 * N / 8;
        let encaps_key = self.ek_bytes();
        let ek_hash = 32;
        let implicit_reject_seed = 32;
        pke_secret + encaps_key + ek_hash + implicit_reject_seed
    }

    /// Byte length of the encapsulation key: `k` polynomials at 12 bits per
    /// coefficient, followed by the 32-byte seed `rho`.
    pub fn ek_bytes(&self) -> usize {
        let t_hat = self.k * 12 * N / 8;
        let rho = 32;
        t_hat + rho
    }

    /// Byte length of the ciphertext: `k` polynomials compressed to `d_u` bits
    /// plus one polynomial compressed to `d_v` bits.
    pub fn ct_bytes(&self) -> usize {
        let u_part = self.k * N * self.du / 8;
        let v_part = N * self.dv / 8;
        u_part + v_part
    }
}
/// Row budgets for each sub-chiplet of the ML-KEM composite.
#[derive(Clone, Debug)]
pub struct MlKemParams {
    /// Rows for the ML-KEM control chiplet.
    pub ctrl_rows: usize,
    /// Rows for the Keccak chiplet.
    pub keccak_rows: usize,
    /// Rows for the NTT chiplet.
    pub ntt_rows: usize,
    /// Rows for the twiddle-factor ROM chiplet.
    pub twiddle_rows: usize,
    /// Rows for the base-multiplication chiplet.
    pub basemul_rows: usize,
    /// Rows for the RAM chiplet.
    pub ram_rows: usize,
}

impl Default for MlKemParams {
    /// Default budgets; every value is a power of two.
    fn default() -> Self {
        Self {
            ctrl_rows: 1 << 14,    // 16_384
            keccak_rows: 1 << 9,   //    512
            ntt_rows: 1 << 15,     // 32_768
            twiddle_rows: 1 << 10, //  1_024
            basemul_rows: 1 << 12, //  4_096
            ram_rows: 1 << 15,     // 32_768
        }
    }
}
/// ML-KEM composite chiplet.
///
/// Holds the assembled [`CompositeChiplet`] together with the parameter set
/// and row budgets it was constructed from.
#[derive(Clone)]
pub struct MlKemChiplet<F>
where
    F: TraceCompatibleField,
{
    /// The assembled composite of all ML-KEM sub-chiplets.
    composite: CompositeChiplet<F>,
    /// Parameter set supplied at construction.
    level: MlKemLevel,
    /// Row budgets supplied at construction.
    params: MlKemParams,
}
impl<F> MlKemChiplet<F>
where
    F: TowerField + TraceCompatibleField + PackableField + HardwareField + 'static,
    <F as PackableField>::Packed: Copy + Send + Sync,
    Flat<F>: Send + Sync,
{
    /// Assembles the ML-KEM composite (named `"mlkem"`) from six sub-chiplets.
    ///
    /// Registration order is: control, Keccak, NTT, twiddle ROM, basemul, RAM.
    /// The NTT, twiddle-ROM and basemul chiplets are instantiated for the
    /// modulus [`MLKEM_Q`]. Each row budget is taken from `params`. Two
    /// external buses are attached, using the control chiplet's linking specs:
    /// [`MLKEM_DATA_BUS_ID`] and [`MLKEM_SS_BUS_ID`].
    ///
    /// NOTE(review): `level` is only stored; no sub-chiplet is configured
    /// from it here. Confirm that this is intended.
    ///
    /// # Panics
    ///
    /// Panics if the composite builder rejects the configuration.
    pub fn new(level: MlKemLevel, params: MlKemParams) -> Self {
        // Instantiate every sub-chiplet up front, in registration order.
        let ctrl_chip = MlKemCtrlChiplet::new(params.ctrl_rows);
        let keccak_chip = KeccakChiplet::new(params.keccak_rows);
        let ntt_chip = NttChiplet::new(MLKEM_Q, params.ntt_rows);
        let twiddle_chip = TwiddleRomChiplet::new(MLKEM_Q, params.twiddle_rows);
        let basemul_chip = BasemulChiplet::new(MLKEM_Q, params.basemul_rows);
        let ram_chip = RamChiplet::new(params.ram_rows);

        let composite = CompositeChiplet::<F>::builder("mlkem")
            .chiplet(ctrl_chip)
            .chiplet(keccak_chip)
            .chiplet(ntt_chip)
            .chiplet(twiddle_chip)
            .chiplet(basemul_chip)
            .chiplet(ram_chip)
            // External buses are exposed to the rest of the system; the
            // control chiplet provides the spec for each side of the link.
            .external_bus(MLKEM_DATA_BUS_ID, MlKemCtrlChiplet::main_linking_spec())
            .external_bus(MLKEM_SS_BUS_ID, MlKemCtrlChiplet::ss_linking_spec())
            .build()
            .expect("ML-KEM composite build must succeed");

        Self {
            level,
            params,
            composite,
        }
    }

    /// Borrows the assembled composite chiplet.
    pub fn composite(&self) -> &CompositeChiplet<F> {
        &self.composite
    }

    /// Returns the parameter set passed to [`Self::new`].
    pub fn level(&self) -> MlKemLevel {
        self.level
    }

    /// Borrows the row budgets passed to [`Self::new`].
    pub fn params(&self) -> &MlKemParams {
        &self.params
    }
}
// CPU-side trace columns for ML-KEM requests.
//
// `DATA` and `SELECTOR` are consumed by `CpuMlKemUnit::linking_spec`.
// The eight `SS_DATA` columns and `SS_SELECTOR` are consumed by
// `CpuMlKemUnit::ss_linking_spec`.
// NOTE(review): these are presumably the CPU-side endpoints of the
// `MLKEM_DATA_BUS_ID` / `MLKEM_SS_BUS_ID` buses. Confirm where the specs are bound.
define_columns! {
pub CpuMlKemColumns {
// One `B32` data word per request row.
DATA: B32,
// Gates which rows participate in the data link.
SELECTOR: Bit,
// Eight `B32` words; linked as lo0..lo3 then hi0..hi3.
SS_DATA: [B32; 8],
// Gates which rows participate in the shared-secret link.
SS_SELECTOR: Bit,
}
}
/// CPU-side endpoint for ML-KEM requests.
///
/// Stateless: it only describes the CPU trace columns ([`CpuMlKemColumns`])
/// and the permutation-check specs defined over them.
pub struct CpuMlKemUnit;

impl CpuMlKemUnit {
    /// Number of trace columns declared by [`CpuMlKemColumns`].
    pub fn num_columns() -> usize {
        CpuMlKemColumns::NUM_COLUMNS
    }

    /// Permutation-check spec for the data link.
    ///
    /// Each row selected by `SELECTOR` contributes two entries:
    /// - the `DATA` column, labelled `kappa_mlkem_d0`;
    /// - its row index as 4 little-endian bytes, labelled [`REQUEST_IDX_LABEL`].
    pub fn linking_spec() -> PermutationCheckSpec {
        let selector = Some(CpuMlKemColumns::SELECTOR);
        PermutationCheckSpec::new(
            vec![
                (
                    Source::Column(CpuMlKemColumns::DATA),
                    b"kappa_mlkem_d0" as &[u8],
                ),
                (Source::RowIndexLeBytes(4), REQUEST_IDX_LABEL),
            ],
            selector,
        )
    }

    /// Permutation-check spec for the shared-secret link.
    ///
    /// Each row selected by `SS_SELECTOR` contributes the eight consecutive
    /// `SS_DATA` columns in order. They are labelled `kappa_ss_lo0..=lo3` and
    /// then `kappa_ss_hi0..=hi3`. The row index follows as 4 little-endian
    /// bytes, labelled [`REQUEST_IDX_LABEL`].
    pub fn ss_linking_spec() -> PermutationCheckSpec {
        // One label per SS_DATA column, in column order.
        const SS_LABELS: [&[u8]; 8] = [
            b"kappa_ss_lo0",
            b"kappa_ss_lo1",
            b"kappa_ss_lo2",
            b"kappa_ss_lo3",
            b"kappa_ss_hi0",
            b"kappa_ss_hi1",
            b"kappa_ss_hi2",
            b"kappa_ss_hi3",
        ];

        let mut sources = vec![];
        let mut column = CpuMlKemColumns::SS_DATA;
        for label in SS_LABELS {
            sources.push((Source::Column(column), label));
            column += 1;
        }
        sources.push((Source::RowIndexLeBytes(4), REQUEST_IDX_LABEL));

        PermutationCheckSpec::new(sources, Some(CpuMlKemColumns::SS_SELECTOR))
    }
}
/// Phases of the ML-KEM control flow, stored as a `u8` discriminant.
///
/// Variants order by discriminant: `Io < Decrypt < GHash < Encrypt <
/// CmpHash < Compare`.
/// NOTE(review): the names suggest the decapsulation sequence (decrypt,
/// G-hash, re-encrypt, hash for comparison, compare). Confirm against the
/// control chiplet.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) enum Phase {
    Io = 0,
    Decrypt = 1,
    GHash = 2,
    Encrypt = 3,
    CmpHash = 4,
    Compare = 5,
}
#[cfg(test)]
mod tests {
    use super::*;
    use hekate_math::Block128;
    use hekate_program::chiplet::ChipletDef;

    type F = Block128;

    /// Small row budgets shared by every composite test.
    fn small_params() -> MlKemParams {
        MlKemParams {
            ctrl_rows: 16,
            keccak_rows: 32,
            ntt_rows: 16,
            twiddle_rows: 16,
            basemul_rows: 16,
            ram_rows: 16,
        }
    }

    /// ML-KEM-768 composite built with [`small_params`].
    fn small_mlkem() -> MlKemChiplet<F> {
        MlKemChiplet::<F>::new(MlKemLevel::MLKEM_768, small_params())
    }

    /// Collects the id of every permutation check across all flattened defs.
    fn permutation_bus_ids(defs: &[ChipletDef<F>]) -> Vec<String> {
        let mut ids = Vec::new();
        for def in defs {
            for (id, _) in &def.permutation_checks {
                ids.push(id.clone());
            }
        }
        ids
    }

    #[test]
    fn composite_builds_six_chiplets() {
        let mlkem = small_mlkem();
        assert_eq!(mlkem.composite().len(), 6);
        assert_eq!(mlkem.composite().name(), "mlkem");
    }

    #[test]
    fn flatten_defs_produces_six_defs() {
        let mlkem = small_mlkem();
        let defs: Vec<ChipletDef<F>> = mlkem.composite().flatten_defs().unwrap();
        assert_eq!(defs.len(), 6);
    }

    #[test]
    fn internal_buses_namespaced() {
        let mlkem = small_mlkem();
        let defs: Vec<ChipletDef<F>> = mlkem.composite().flatten_defs().unwrap();
        let bus_ids = permutation_bus_ids(&defs);

        assert!(
            bus_ids.iter().any(|id| id.as_str() == "ml_kem_data"),
            "external bus must not be namespaced, got: {bus_ids:?}",
        );

        for expected in [
            "mlkem::keccak_link",
            "mlkem::ntt_data",
            "mlkem::ntt_twiddle",
            "mlkem::basemul",
            "mlkem::ram_link",
            "mlkem::ntt_bound_in",
            "mlkem::ntt_bound_out",
        ] {
            assert!(
                bus_ids.iter().any(|id| id.as_str() == expected),
                "{expected} bus must be namespaced, got: {bus_ids:?}",
            );
        }
    }

    #[test]
    fn external_bus_spec_returned() {
        let mlkem = small_mlkem();
        let ext = mlkem.composite().external_buses();
        assert_eq!(ext.len(), 2);
        assert_eq!(ext[0].0, "ml_kem_data");
        assert_eq!(ext[1].0, "ml_kem_ss");
    }

    #[test]
    fn level_sizes_match_fips203() {
        // (level, |ek|, |dk|, |c|) in bytes, as specified by FIPS 203.
        let cases = [
            (MlKemLevel::MLKEM_512, 800, 1632, 768),
            (MlKemLevel::MLKEM_768, 1184, 2400, 1088),
            (MlKemLevel::MLKEM_1024, 1568, 3168, 1568),
        ];
        for (level, ek, dk, ct) in cases {
            assert_eq!(level.ek_bytes(), ek, "ek size, k = {}", level.k);
            assert_eq!(level.sk_bytes(), dk, "dk size, k = {}", level.k);
            assert_eq!(level.ct_bytes(), ct, "ct size, k = {}", level.k);
        }
    }
}