mod arithmetic;
mod ctrl;
mod trace;
mod witness;
pub use ctrl::MlDsaCtrlColumns;
pub use witness::{MlDsaPublicKey, MlDsaSignature};
use super::high_bits::HighBitsChiplet;
use super::norm_check::NormCheckChiplet;
use super::ntt::NttChiplet;
use super::twiddle_rom::TwiddleRomChiplet;
use alloc::vec;
use ctrl::MlDsaCtrlChiplet;
use hekate_core::trace::TraceCompatibleField;
use hekate_gadgets::chiplets::ram::RamChiplet;
use hekate_keccak::KeccakChiplet;
use hekate_math::{Flat, HardwareField, PackableField, TowerField};
use hekate_program::chiplet::CompositeChiplet;
use hekate_program::define_columns;
use hekate_program::permutation::{PermutationCheckSpec, REQUEST_IDX_LABEL, Source};
/// ML-DSA prime modulus `q = 2^23 - 2^13 + 1`.
pub const MLDSA_Q: u32 = 8_380_417;
/// Bit length of [`MLDSA_Q`] (`2^22 < q < 2^23`).
///
/// NOTE(review): not referenced in this part of the file; presumably the
/// width used when range-checking field elements — confirm at the use sites.
pub const MLDSA_BIT_WIDTH: usize = 23;
/// Number of coefficients per polynomial; `MlDsaLevel::sig_bytes` uses it
/// when sizing the packed response vector.
pub const N: usize = 256;
/// Identifier of the external bus registered by `MlDsaChiplet::new`.
pub const MLDSA_DATA_BUS_ID: &str = "ml_dsa_data";
/// Bus identifier that is not referenced in this part of the file.
///
/// NOTE(review): presumably consumed by the submodules of this module —
/// confirm before renaming or removing.
const KEC_INPUT_BIND_BUS_ID: &str = "kec_input_bind";
/// One ML-DSA parameter set.
///
/// Field names follow the parameter names used in FIPS 204; the short
/// descriptions below restate the standard's definitions.
#[derive(Clone, Copy, Debug)]
pub struct MlDsaLevel {
    /// FIPS 204 `k`: number of rows of the public matrix.
    pub(crate) k: usize,
    /// FIPS 204 `l`: number of columns of the public matrix.
    pub(crate) l: usize,
    /// FIPS 204 `eta`: bound on secret-key coefficients.
    // NOTE(review): the lint allowance suggests nothing reads this field yet;
    // presumably it is kept so the parameter set stays complete — confirm.
    #[allow(dead_code)]
    pub(crate) eta: u32,
    /// FIPS 204 `tau`: number of non-zero coefficients in the challenge.
    pub(crate) tau: usize,
    /// FIPS 204 `gamma1`: coefficient range of the mask vector.
    pub(crate) gamma1: u32,
    /// FIPS 204 `gamma2`: low-order rounding range.
    pub(crate) gamma2: u32,
    /// FIPS 204 `beta` (`tau * eta`).
    pub(crate) beta: u32,
    /// FIPS 204 `omega`: maximum number of set positions in the hint.
    pub(crate) omega: usize,
    /// FIPS 204 `d`: number of bits dropped from `t`.
    pub(crate) d: usize,
}
impl MlDsaLevel {
    /// ML-DSA-44 parameter set (values as listed in FIPS 204).
    pub const MLDSA_44: Self = Self {
        k: 4,
        l: 4,
        eta: 2,
        tau: 39,
        gamma1: 1 << 17,
        gamma2: 95232,
        beta: 78,
        omega: 80,
        d: 13,
    };

    /// ML-DSA-65 parameter set (values as listed in FIPS 204).
    pub const MLDSA_65: Self = Self {
        k: 6,
        l: 5,
        eta: 4,
        tau: 49,
        gamma1: 1 << 19,
        gamma2: 261888,
        beta: 196,
        omega: 55,
        d: 13,
    };

    /// ML-DSA-87 parameter set (values as listed in FIPS 204).
    pub const MLDSA_87: Self = Self {
        k: 8,
        l: 7,
        eta: 2,
        tau: 60,
        gamma1: 1 << 19,
        gamma2: 261888,
        beta: 120,
        omega: 75,
        d: 13,
    };

    /// Returns the `k` parameter of this level.
    pub fn k(&self) -> usize {
        self.k
    }

    /// Returns the `l` parameter of this level.
    pub fn l(&self) -> usize {
        self.l
    }

    /// Returns the `gamma2` parameter of this level.
    pub fn gamma2(&self) -> u32 {
        self.gamma2
    }

    /// Returns the `omega` parameter of this level.
    pub fn omega(&self) -> usize {
        self.omega
    }

    /// Returns `gamma1 - beta`, the bound handed to the norm-check chiplet
    /// by `MlDsaChiplet::new`.
    pub fn z_bound(&self) -> u32 {
        self.gamma1 - self.beta
    }

    /// Returns `2 * gamma2`, the divisor handed to the high-bits chiplet by
    /// `MlDsaChiplet::new`.
    pub fn highbits_divisor(&self) -> u32 {
        2 * self.gamma2
    }

    /// Size in bytes of an encoded public key: a 32-byte prefix plus
    /// 320 bytes for each of the `k` polynomials.
    pub fn pk_bytes(&self) -> usize {
        32 + self.k * 320
    }

    /// Size in bytes of an encoded signature.
    ///
    /// The total is the sum of three parts:
    /// * the challenge hash (32, 48 or 64 bytes, selected by `k`),
    /// * `l` response polynomials of `N` coefficients, each packed with
    ///   `bitlen(gamma1 - 1) + 1` bits,
    /// * the hint encoding (`omega + k` bytes).
    ///
    /// # Panics
    ///
    /// Panics if `k` is not 4, 6 or 8, i.e. the level is not one of the
    /// three parameter sets defined above.
    pub fn sig_bytes(&self) -> usize {
        let lambda_bytes = match self.k {
            4 => 32,
            6 => 48,
            8 => 64,
            k => unreachable!("no ML-DSA parameter set has k = {}", k),
        };

        // Derive the packed coefficient width from `gamma1` itself instead of
        // branching on one known value: 2^17 -> 18 bits, 2^19 -> 20 bits.
        let gamma1_bits =
            (u32::BITS - self.gamma1.saturating_sub(1).leading_zeros()) as usize + 1;

        let z_bytes = self.l * gamma1_bits * N / 8;
        let h_bytes = self.omega + self.k;

        lambda_bytes + z_bytes + h_bytes
    }
}
/// Row counts for the sub-chiplets assembled by `MlDsaChiplet::new`.
///
/// Each field is forwarded unchanged to the constructor of the matching
/// sub-chiplet.
#[derive(Clone, Debug)]
pub struct MlDsaParams {
    /// Row count passed to the control chiplet.
    pub ctrl_rows: usize,
    /// Row count passed to the Keccak chiplet.
    pub keccak_rows: usize,
    /// Row count passed to the NTT chiplet.
    pub ntt_rows: usize,
    /// Row count passed to the twiddle ROM chiplet.
    pub twiddle_rows: usize,
    /// Row count passed to the norm-check chiplet.
    pub norm_rows: usize,
    /// Row count passed to the high-bits chiplet.
    pub highbits_rows: usize,
    /// Row count passed to the RAM chiplet.
    pub ram_rows: usize,
}
impl Default for MlDsaParams {
    /// Returns the default row counts; every value is a power of two.
    fn default() -> Self {
        Self {
            ctrl_rows: 1 << 15,     // 32768
            keccak_rows: 1 << 11,   // 2048
            ntt_rows: 1 << 16,      // 65536
            twiddle_rows: 1 << 10,  // 1024
            norm_rows: 1 << 11,     // 2048
            highbits_rows: 1 << 11, // 2048
            ram_rows: 1 << 15,      // 32768
        }
    }
}
/// A composite chiplet together with the level and row-count parameters it
/// was built from.
#[derive(Clone)]
pub struct MlDsaChiplet<F>
where
    F: TraceCompatibleField,
{
    /// The assembled composite returned by the builder in `new`.
    composite: CompositeChiplet<F>,
    /// Parameter set supplied to `new`.
    level: MlDsaLevel,
    /// Row counts supplied to `new`.
    params: MlDsaParams,
}
impl<F> MlDsaChiplet<F>
where
    F: TowerField + TraceCompatibleField + PackableField + HardwareField + 'static,
    <F as PackableField>::Packed: Copy + Send + Sync,
    Flat<F>: Send + Sync,
{
    /// Builds the `"mldsa"` composite from its seven sub-chiplets.
    ///
    /// Row counts come from `params`; the norm-check bound and the high-bits
    /// divisor are derived from `level`. The control chiplet's main linking
    /// spec is exposed on the external bus [`MLDSA_DATA_BUS_ID`].
    ///
    /// # Panics
    ///
    /// Panics with "ML-DSA composite build must succeed" if the composite
    /// builder does not produce a value.
    pub fn new(level: MlDsaLevel, params: MlDsaParams) -> Self {
        // Sub-chiplets are created in the same order they are registered.
        let ctrl_chiplet = MlDsaCtrlChiplet::new(params.ctrl_rows);
        let keccak_chiplet = KeccakChiplet::new(params.keccak_rows);
        let ntt_chiplet = NttChiplet::new(MLDSA_Q, params.ntt_rows);
        let twiddle_rom = TwiddleRomChiplet::new(MLDSA_Q, params.twiddle_rows);
        let norm_check = NormCheckChiplet::new(MLDSA_Q, level.z_bound(), params.norm_rows);
        let high_bits =
            HighBitsChiplet::new(MLDSA_Q, level.highbits_divisor(), params.highbits_rows);
        let ram_chiplet = RamChiplet::new(params.ram_rows);

        let composite = CompositeChiplet::<F>::builder("mldsa")
            .chiplet(ctrl_chiplet)
            .chiplet(keccak_chiplet)
            .chiplet(ntt_chiplet)
            .chiplet(twiddle_rom)
            .chiplet(norm_check)
            .chiplet(high_bits)
            .chiplet(ram_chiplet)
            .external_bus(MLDSA_DATA_BUS_ID, MlDsaCtrlChiplet::main_linking_spec())
            .build()
            .expect("ML-DSA composite build must succeed");

        Self {
            composite,
            level,
            params,
        }
    }

    /// Borrows the assembled composite chiplet.
    pub fn composite(&self) -> &CompositeChiplet<F> {
        &self.composite
    }

    /// Returns a copy of the parameter set this chiplet was built for.
    pub fn level(&self) -> MlDsaLevel {
        self.level
    }

    /// Borrows the row counts this chiplet was built with.
    pub fn params(&self) -> &MlDsaParams {
        &self.params
    }
}
// Declares the `CpuMlDsaColumns` layout through the project's
// `define_columns!` macro: a `B32` column `DATA` and a `Bit` column
// `SELECTOR`. `CpuMlDsaUnit` reads `NUM_COLUMNS`, `DATA` and `SELECTOR`
// from the generated item.
define_columns! {
pub CpuMlDsaColumns {
DATA: B32,
SELECTOR: Bit,
}
}
/// Namespace type for the `CpuMlDsaColumns` layout and its linking spec.
pub struct CpuMlDsaUnit;

impl CpuMlDsaUnit {
    /// Number of columns declared by `CpuMlDsaColumns`.
    pub fn num_columns() -> usize {
        CpuMlDsaColumns::NUM_COLUMNS
    }

    /// Builds the permutation-check spec for this layout.
    ///
    /// Two sources are listed, in this order: the `DATA` column under the
    /// label `b"kappa_mldsa_d0"`, and `Source::RowIndexLeBytes(4)` under
    /// `REQUEST_IDX_LABEL`. `SELECTOR` is passed as the optional second
    /// argument.
    // NOTE(review): `SELECTOR` presumably marks which rows take part in the
    // check — confirm against `PermutationCheckSpec::new`.
    pub fn linking_spec() -> PermutationCheckSpec {
        let sources = vec![
            (
                Source::Column(CpuMlDsaColumns::DATA),
                b"kappa_mldsa_d0" as &[u8],
            ),
            (Source::RowIndexLeBytes(4), REQUEST_IDX_LABEL),
        ];

        PermutationCheckSpec::new(sources, Some(CpuMlDsaColumns::SELECTOR))
    }
}
/// Phase tag with an explicit `u8` discriminant per variant.
///
/// The derived ordering follows declaration order, which here matches the
/// numeric discriminants (0 through 7).
// NOTE(review): the variant names suggest the stages of ML-DSA signature
// verification; the exact meaning of each is defined by the code that
// produces and consumes these tags — confirm there.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
#[repr(u8)]
pub(crate) enum Phase {
    Io = 0,
    ExpandSample = 1,
    NttForward = 2,
    PointwiseMul = 3,
    NttInverse = 4,
    UseHint = 5,
    HashCompare = 6,
    NormCheck = 7,
}