use alloc::{string::ToString, vec::Vec};
use miden_crypto_derive::{SilentDebug, SilentDisplay};
use num::Complex;
use num_complex::Complex64;
use rand::Rng;
use super::{
super::{
ByteReader, ByteWriter, Deserializable, DeserializationError, MODULUS, N, Nonce,
SIG_L2_BOUND, SIGMA, Serializable, ShortLatticeBasis, Signature,
math::{FalconFelt, FastFft, LdlTree, Polynomial, ffldl, ffsampling, gram, normalize_tree},
signature::SignaturePoly,
},
PublicKey,
};
use crate::{
Word,
dsa::falcon512_poseidon2::{
LOG_N, SK_LEN, hash_to_point::hash_to_point_poseidon2, math::ntru_gen,
},
hash::blake::Blake3_256,
utils::zeroize::{Zeroize, ZeroizeOnDrop},
};
/// Number of bits used per coefficient when packing the polynomial that `SecretKey`
/// serialization writes last (named `big_f` in the (de)serialization code below).
pub(crate) const WIDTH_BIG_POLY_COEFFICIENT: usize = 8;
/// Number of bits used per coefficient when packing the two polynomials that `SecretKey`
/// serialization writes first (named `f` and `g` in the (de)serialization code below).
pub(crate) const WIDTH_SMALL_POLY_COEFFICIENT: usize = 6;
/// Secret key of the signature scheme: a short lattice basis plus the LDL tree derived from it.
///
/// `Debug` and `Display` come from the `SilentDebug` / `SilentDisplay` derives.
/// NOTE(review): presumably these avoid printing key material — confirm against the derive crate.
#[derive(Clone, SilentDebug, SilentDisplay)]
pub struct SecretKey {
// Short basis; the code in this file reads entry 0 as `g`, entry 1 as `-f`, and entry 3 as
// `-F`, and deserialization stores the recomputed `G` in entry 2.
secret_key: ShortLatticeBasis,
// Tree produced by `ffldl` from the Gram matrix of the FFT'd basis and normalized with
// `SIGMA`; it is handed to `ffsampling` while signing.
tree: LdlTree,
}
/// Wipes the secret material owned by a [`SecretKey`].
impl Zeroize for SecretKey {
    /// Zeroizes the short lattice basis first and the LDL tree second.
    fn zeroize(&mut self) {
        let Self { secret_key, tree } = self;
        secret_key.zeroize();
        tree.zeroize();
    }
}
impl Drop for SecretKey {
fn drop(&mut self) {
self.zeroize();
}
}
// Marker impl: the `Drop` implementation of `SecretKey` calls `zeroize()`.
impl ZeroizeOnDrop for SecretKey {}
#[allow(clippy::new_without_default)]
impl SecretKey {
#[cfg(feature = "std")]
pub fn new() -> Self {
let mut rng = rand::rng();
Self::with_rng(&mut rng)
}
pub fn with_rng<R: Rng>(rng: &mut R) -> Self {
let basis = ntru_gen(N, rng);
Self::from_short_lattice_basis(basis)
}
pub(crate) fn from_short_lattice_basis(basis: ShortLatticeBasis) -> SecretKey {
let basis_fft = to_complex_fft(&basis);
let gram_fft = gram(basis_fft);
let mut tree = ffldl(gram_fft);
normalize_tree(&mut tree, SIGMA);
Self { secret_key: basis, tree }
}
pub fn short_lattice_basis(&self) -> &ShortLatticeBasis {
&self.secret_key
}
pub fn public_key(&self) -> PublicKey {
self.compute_pub_key_poly()
}
pub fn tree(&self) -> &LdlTree {
&self.tree
}
pub fn sign(&self, message: Word) -> Signature {
use rand::SeedableRng;
use rand_chacha::ChaCha20Rng;
let mut seed = self.generate_seed(&message);
let mut rng = ChaCha20Rng::from_seed(seed);
let signature = self.sign_with_rng(message, &mut rng);
seed.zeroize();
signature
}
pub fn sign_with_rng<R: Rng>(&self, message: Word, rng: &mut R) -> Signature {
let nonce = Nonce::deterministic();
let h = self.compute_pub_key_poly();
let c = hash_to_point_poseidon2(message, &nonce);
let s2 = self.sign_helper(c, rng);
Signature::new(nonce, h, s2)
}
#[cfg(test)]
pub fn sign_with_rng_testing<R: Rng>(&self, message: &[u8], rng: &mut R) -> Signature {
use crate::dsa::falcon512_poseidon2::{
hash_to_point::hash_to_point_shake256, tests::ChaCha,
};
let nonce = Nonce::random(rng);
let h = self.compute_pub_key_poly();
let c = hash_to_point_shake256(message, &nonce);
let mut chacha_prng = ChaCha::new(rng);
let s2 = self.sign_helper(c, &mut chacha_prng);
Signature::new(nonce, h, s2)
}
fn compute_pub_key_poly(&self) -> PublicKey {
let g: Polynomial<FalconFelt> = self.secret_key[0].clone().into();
let g_fft = g.fft();
let minus_f: Polynomial<FalconFelt> = self.secret_key[1].clone().into();
let f = -minus_f;
let f_fft = f.fft();
let h_fft = g_fft.hadamard_div(&f_fft);
h_fft.ifft().into()
}
fn sign_helper<R: Rng>(&self, c: Polynomial<FalconFelt>, rng: &mut R) -> SignaturePoly {
let one_over_q = 1.0 / (MODULUS as f64);
let c_over_q_fft = c.map(|cc| Complex::new(one_over_q * cc.value() as f64, 0.0)).fft();
let [g_fft, minus_f_fft, big_g_fft, minus_big_f_fft] = to_complex_fft(&self.secret_key);
let t0 = c_over_q_fft.hadamard_mul(&minus_big_f_fft);
let t1 = -c_over_q_fft.hadamard_mul(&minus_f_fft);
loop {
let bold_s = loop {
let z = ffsampling(&(t0.clone(), t1.clone()), &self.tree, rng);
let t0_min_z0 = t0.clone() - z.0;
let t1_min_z1 = t1.clone() - z.1;
let s0 = t0_min_z0.hadamard_mul(&g_fft) + t1_min_z1.hadamard_mul(&big_g_fft);
let s1 =
t0_min_z0.hadamard_mul(&minus_f_fft) + t1_min_z1.hadamard_mul(&minus_big_f_fft);
let length_squared: f64 =
(s0.coefficients.iter().map(|a| (a * a.conj()).re).sum::<f64>()
+ s1.coefficients.iter().map(|a| (a * a.conj()).re).sum::<f64>())
/ (N as f64);
if length_squared > (SIG_L2_BOUND as f64) {
continue;
}
break [-s0, s1];
};
let s2 = bold_s[1].ifft();
let s2_coef: [i16; N] = s2
.coefficients
.iter()
.map(|a| a.re.round() as i16)
.collect::<Vec<i16>>()
.try_into()
.expect("The number of coefficients should be equal to N");
if let Ok(s2) = SignaturePoly::try_from(&s2_coef) {
return s2;
}
}
}
fn generate_seed(&self, message: &Word) -> [u8; 32] {
let mut buffer = Vec::with_capacity(1 + SK_LEN + Word::SERIALIZED_SIZE);
buffer.push(LOG_N);
buffer.extend_from_slice(&self.to_bytes());
buffer.extend_from_slice(&message.to_bytes());
let digest = Blake3_256::hash(&buffer);
buffer.zeroize();
digest.into()
}
}
/// Equality on secret keys, evaluated on their serialized form.
impl PartialEq for SecretKey {
    /// Compares the serialized bytes of both keys with `subtle::ConstantTimeEq`.
    ///
    /// The temporary serializations contain the full secret keys, so they are zeroized
    /// before returning instead of being dropped as-is.
    fn eq(&self, other: &Self) -> bool {
        use subtle::ConstantTimeEq;

        let mut lhs = self.to_bytes();
        let mut rhs = other.to_bytes();
        let equal: bool = lhs.ct_eq(&rhs).into();
        lhs.zeroize();
        rhs.zeroize();
        equal
    }
}
// Marker impl; the comparison itself lives in the `PartialEq` implementation.
impl Eq for SecretKey {}
impl Serializable for SecretKey {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
let basis = &self.secret_key;
let n = basis[0].coefficients.len();
let l = n.checked_ilog2().unwrap() as u8;
let header: u8 = (5 << 4) | l;
let neg_f = &basis[1];
let g = &basis[0];
let neg_big_f = &basis[3];
let mut buffer = Vec::with_capacity(1281);
buffer.push(header);
let mut f_i8: Vec<i8> = neg_f
.coefficients
.iter()
.map(|&a| FalconFelt::new(-a).balanced_value() as i8)
.collect();
let f_i8_encoded = encode_i8(&f_i8, WIDTH_SMALL_POLY_COEFFICIENT).unwrap();
buffer.extend_from_slice(&f_i8_encoded);
f_i8.zeroize();
let mut g_i8: Vec<i8> = g
.coefficients
.iter()
.map(|&a| FalconFelt::new(a).balanced_value() as i8)
.collect();
let g_i8_encoded = encode_i8(&g_i8, WIDTH_SMALL_POLY_COEFFICIENT).unwrap();
buffer.extend_from_slice(&g_i8_encoded);
g_i8.zeroize();
let mut big_f_i8: Vec<i8> = neg_big_f
.coefficients
.iter()
.map(|&a| FalconFelt::new(-a).balanced_value() as i8)
.collect();
let big_f_i8_encoded = encode_i8(&big_f_i8, WIDTH_BIG_POLY_COEFFICIENT).unwrap();
buffer.extend_from_slice(&big_f_i8_encoded);
big_f_i8.zeroize();
target.write_bytes(&buffer);
}
}
/// Binary deserialization of a [`SecretKey`].
impl Deserializable for SecretKey {
    /// Reads a key laid out as `header || f || g || F` (`SK_LEN` bytes in total).
    ///
    /// `G` is not part of the encoding; it is recomputed as `g * F / f` over `FalconFelt`.
    ///
    /// # Errors
    /// Returns `DeserializationError::InvalidValue` if
    /// * the high nibble of the header is not `5`,
    /// * the low nibble of the header does not encode `log2(N)`, or
    /// * any of the three coefficient sections fails to decode.
    ///
    /// Errors from reading the underlying bytes are propagated unchanged.
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let byte_vector: [u8; SK_LEN] = source.read_array()?;

        // Validate the header byte.
        let header = byte_vector[0];
        if (header >> 4) != 5 {
            return Err(DeserializationError::InvalidValue("Invalid header format".to_string()));
        }
        let logn = (header & 15) as usize;
        let n = 1 << logn;
        if n != N {
            return Err(DeserializationError::InvalidValue(
                "Unsupported Falcon DSA variant".to_string(),
            ));
        }

        // Byte lengths of the three packed sections, and the offsets where they start/end.
        let chunk_size_f = ((n * WIDTH_SMALL_POLY_COEFFICIENT) + 7) >> 3;
        let chunk_size_g = ((n * WIDTH_SMALL_POLY_COEFFICIENT) + 7) >> 3;
        let chunk_size_big_f = ((n * WIDTH_BIG_POLY_COEFFICIENT) + 7) >> 3;
        let f_start = 1;
        let g_start = f_start + chunk_size_f;
        let big_f_start = g_start + chunk_size_g;
        let big_f_end = big_f_start + chunk_size_big_f;

        // The input is untrusted: a malformed section must surface as an error, never a panic.
        let f = decode_i8(&byte_vector[f_start..g_start], WIDTH_SMALL_POLY_COEFFICIENT)
            .ok_or_else(|| {
                DeserializationError::InvalidValue("Failed to decode f coefficients".to_string())
            })?;
        let g = decode_i8(&byte_vector[g_start..big_f_start], WIDTH_SMALL_POLY_COEFFICIENT)
            .ok_or_else(|| {
                DeserializationError::InvalidValue("Failed to decode g coefficients".to_string())
            })?;
        let big_f = decode_i8(&byte_vector[big_f_start..big_f_end], WIDTH_BIG_POLY_COEFFICIENT)
            .ok_or_else(|| {
                DeserializationError::InvalidValue("Failed to decode F coefficients".to_string())
            })?;

        let f = Polynomial::new(f.iter().map(|&c| FalconFelt::new(c.into())).collect());
        let g = Polynomial::new(g.iter().map(|&c| FalconFelt::new(c.into())).collect());
        let big_f = Polynomial::new(big_f.iter().map(|&c| FalconFelt::new(c.into())).collect());

        // Recompute G = g * F / f using coefficient-wise operations in the transformed domain.
        // NOTE(review): the behavior of `hadamard_div` when `f` is not invertible is not
        // visible from here — confirm it cannot panic on attacker-chosen `f`.
        let big_g = g.fft().hadamard_div(&f.fft()).hadamard_mul(&big_f.fft()).ifft();

        // The basis stores `f` and `F` negated.
        let basis = [
            Polynomial::new(g.to_balanced_values()),
            -Polynomial::new(f.to_balanced_values()),
            Polynomial::new(big_g.to_balanced_values()),
            -Polynomial::new(big_f.to_balanced_values()),
        ];
        Ok(Self::from_short_lattice_basis(basis))
    }
}
fn to_complex_fft(basis: &[Polynomial<i16>; 4]) -> [Polynomial<Complex<f64>>; 4] {
let [g, f, big_g, big_f] = basis.clone();
let g_fft = g.map(|cc| Complex64::new(*cc as f64, 0.0)).fft();
let minus_f_fft = f.map(|cc| -Complex64::new(*cc as f64, 0.0)).fft();
let big_g_fft = big_g.map(|cc| Complex64::new(*cc as f64, 0.0)).fft();
let minus_big_f_fft = big_f.map(|cc| -Complex64::new(*cc as f64, 0.0)).fft();
[g_fft, minus_f_fft, big_g_fft, minus_big_f_fft]
}
pub fn encode_i8(x: &[i8], bits: usize) -> Option<Vec<u8>> {
let maxv = (1 << (bits - 1)) - 1_usize;
let maxv = maxv as i8;
let minv = -maxv;
for &c in x {
if c > maxv || c < minv {
return None;
}
}
let out_len = ((N * bits) + 7) >> 3;
let mut buf = vec![0_u8; out_len];
let mut acc = 0_u32;
let mut acc_len = 0;
let mask = ((1_u16 << bits) - 1) as u8;
let mut input_pos = 0;
for &c in x {
acc = (acc << bits) | (c as u8 & mask) as u32;
acc_len += bits;
while acc_len >= 8 {
acc_len -= 8;
buf[input_pos] = (acc >> acc_len) as u8;
input_pos += 1;
}
}
if acc_len > 0 {
buf[input_pos] = (acc >> (8 - acc_len)) as u8;
}
Some(buf)
}
/// Unpacks `N` signed values of `bits` bits each from a big-endian bit stream.
///
/// Each `bits`-wide field is interpreted as a two's complement number. Bytes of `buf`
/// beyond those needed for `N` values are ignored.
///
/// Returns `None` if
/// * `bits` is not in `1..=8`,
/// * `buf` is too short to supply `N` values,
/// * a field holds `-2^(bits-1)`, which `encode_i8` never produces, or
/// * the unused low bits of the last consumed byte are not all zero.
pub fn decode_i8(buf: &[u8], bits: usize) -> Option<Vec<i8>> {
    if bits == 0 || bits > 8 {
        return None;
    }

    let mask = (1_u32 << bits) - 1;
    // Value of the sign bit inside a `bits`-wide field.
    let sign_bit = 1_u32 << (bits - 1);

    let mut x = Vec::with_capacity(N);
    let mut bytes = buf.iter();
    let mut acc = 0_u32;
    let mut acc_len = 0;
    while x.len() < N {
        // Running out of input is a decoding failure, not a reason to panic.
        let &byte = bytes.next()?;
        acc = (acc << 8) | (byte as u32);
        acc_len += 8;

        while acc_len >= bits && x.len() < N {
            acc_len -= bits;
            let w = (acc >> acc_len) & mask;
            // Only the sign bit set means -2^(bits-1): reject it so that every decoded
            // vector can be re-encoded by `encode_i8`.
            if w == sign_bit {
                return None;
            }
            // Sign-extend the field.
            let z = if w > sign_bit { w as i32 - (1_i32 << bits) } else { w as i32 };
            x.push(z as i8);
        }
    }

    // Padding bits left over in the last consumed byte must be zero.
    if (acc & ((1u32 << acc_len) - 1)) == 0 {
        Some(x)
    } else {
        None
    }
}