use alloc::string::ToString;
use core::ops::Deref;
use num::Zero;
use super::{
ByteReader, ByteWriter, Deserializable, DeserializationError, LOG_N, MODULUS, N, Nonce,
SIG_L2_BOUND, SIG_POLY_BYTE_LEN, Serializable,
hash_to_point::hash_to_point_poseidon2,
keys::PublicKey,
math::{FalconFelt, FastFft, Polynomial},
};
use crate::Word;
/// A signature: a one-byte header, a nonce, the signature polynomial `s2`, and the public key
/// `h` it was produced for.
///
/// When serialized, the fields are written in declaration order (header, nonce, s2, h).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Signature {
    /// Encoding/degree header; `Signature::new` always uses `SignatureHeader::default()`.
    header: SignatureHeader,
    /// Nonce hashed together with the message to obtain the target point during verification.
    nonce: Nonce,
    /// The signature polynomial `s2` (serialized in compressed form).
    s2: SignaturePoly,
    /// Public key carried with the signature; `verify` rejects any other key.
    h: PublicKey,
}
impl Signature {
pub fn new(nonce: Nonce, h: PublicKey, s2: SignaturePoly) -> Signature {
Self {
header: SignatureHeader::default(),
nonce,
s2,
h,
}
}
pub fn public_key(&self) -> &PublicKey {
&self.h
}
pub fn sig_poly(&self) -> &Polynomial<FalconFelt> {
&self.s2
}
pub fn nonce(&self) -> &Nonce {
&self.nonce
}
pub fn verify(&self, message: Word, pub_key: &PublicKey) -> bool {
if self.h != *pub_key {
return false;
}
let c = hash_to_point_poseidon2(message, &self.nonce);
verify_helper(&c, &self.s2, pub_key)
}
}
impl Serializable for Signature {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
target.write(&self.header);
target.write(&self.nonce);
target.write(&self.s2);
target.write(&self.h);
}
}
impl Deserializable for Signature {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let header = source.read()?;
let nonce = source.read()?;
let s2 = source.read()?;
let h = source.read()?;
Ok(Self { header, nonce, s2, h })
}
}
/// One-byte signature header.
///
/// The upper nibble identifies the encoding (`0b1011` is the only value accepted when decoding)
/// and the lower nibble holds the base-2 logarithm of the polynomial degree (must equal `LOG_N`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SignatureHeader(u8);
impl Default for SignatureHeader {
fn default() -> Self {
Self(0b1011_1001)
}
}
impl Serializable for &SignatureHeader {
    /// Emits the header as a single raw byte.
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        let header_byte: u8 = self.0;
        target.write_u8(header_byte);
    }
}
impl Deserializable for SignatureHeader {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let header = source.read_u8()?;
let (encoding, log_n) = (header >> 4, header & 0b00001111);
if encoding != 0b1011 {
return Err(DeserializationError::InvalidValue(
"Failed to decode signature: not supported encoding algorithm".to_string(),
));
}
if log_n != LOG_N {
return Err(DeserializationError::InvalidValue(format!(
"Failed to decode signature: only supported irreducible polynomial degree is 512, 2^{log_n} was provided"
)));
}
Ok(Self(header))
}
}
/// Newtype around the signature polynomial `s2`.
///
/// It carries its own compressed serialization format and dereferences to the inner
/// [`Polynomial`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SignaturePoly(pub Polynomial<FalconFelt>);
impl Deref for SignaturePoly {
type Target = Polynomial<FalconFelt>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl From<Polynomial<FalconFelt>> for SignaturePoly {
fn from(pk_poly: Polynomial<FalconFelt>) -> Self {
Self(pk_poly)
}
}
impl TryFrom<&[i16; N]> for SignaturePoly {
type Error = ();
fn try_from(coefficients: &[i16; N]) -> Result<Self, Self::Error> {
if are_coefficients_valid(coefficients) {
Ok(Self(coefficients.to_vec().into()))
} else {
Err(())
}
}
}
impl Serializable for &SignaturePoly {
    /// Serializes the signature polynomial in compressed form as exactly `SIG_POLY_BYTE_LEN`
    /// bytes. Bytes after the encoded data are left as zero padding.
    ///
    /// Each coefficient, taken in its balanced (signed) representation, is encoded as:
    /// - 1 sign bit (`1` for negative),
    /// - the 7 low bits of the magnitude `|c|`, in binary,
    /// - `|c| >> 7` zero bits followed by a terminating `1` bit (unary high part).
    ///
    /// Bits are packed most-significant first. A trailing partial byte is left-aligned and
    /// zero-filled in its low bits.
    ///
    /// # Panics
    /// Indexing `sk_bytes[v]` panics if the bit stream does not fit in `SIG_POLY_BYTE_LEN` bytes.
    /// NOTE(review): presumably only sufficiently short `s2` (e.g. produced by the signer) reach
    /// this point. Confirm this, because `SignaturePoly::try_from` alone (coefficients up to ±2047)
    /// does not bound the total encoded length.
    /// `acc <<= w + 1` also overflows the `u32` shift amount if `|c| >> 7 >= 31`; this panics
    /// in debug builds.
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        let sig_coeff = self.0.to_balanced_values();
        // Fixed-size, zero-initialized output buffer; unused tail bytes remain zero padding.
        let mut sk_bytes = vec![0_u8; SIG_POLY_BYTE_LEN];
        // Bit accumulator (u32, fixed by `mask` below) and count of pending, unflushed bits.
        let mut acc = 0;
        let mut acc_len = 0;
        // Index of the next output byte.
        let mut v = 0;
        let mut t;
        let mut w;
        for &c in sig_coeff.iter() {
            // Sign bit.
            acc <<= 1;
            t = c;
            if t < 0 {
                t = -t;
                acc |= 1;
            }
            w = t as u16;
            // 7 low bits of the magnitude.
            acc <<= 7;
            let mask = 127_u32;
            acc |= (w as u32) & mask;
            w >>= 7;
            acc_len += 8;
            // Unary high part: `w` zero bits, then a terminating 1 bit.
            acc <<= w + 1;
            acc |= 1;
            acc_len += w + 1;
            // Flush every complete byte, taking the oldest (most significant) pending bits first.
            while acc_len >= 8 {
                acc_len -= 8;
                sk_bytes[v] = (acc >> acc_len) as u8;
                v += 1;
            }
        }
        // Left-align the remaining bits in one last byte.
        if acc_len > 0 {
            sk_bytes[v] = (acc << (8 - acc_len)) as u8;
        }
        target.write_bytes(&sk_bytes);
    }
}
impl Deserializable for SignaturePoly {
    /// Decodes a signature polynomial from exactly `SIG_POLY_BYTE_LEN` bytes of compressed data.
    ///
    /// Each coefficient is read as:
    /// - 1 sign bit,
    /// - the 7 low bits of its magnitude,
    /// - the high part of the magnitude in unary (each `0` bit adds 128; a `1` bit terminates).
    ///
    /// A negative coefficient `c` is stored in the field as `MODULUS - |c|`.
    ///
    /// # Errors
    /// Returns a [`DeserializationError`] if fewer than `SIG_POLY_BYTE_LEN` bytes can be read,
    /// or if the data is not the canonical encoding of a polynomial, i.e. when:
    /// - a coefficient's magnitude would reach 2048;
    /// - a coefficient is encoded as "-0";
    /// - the `SIG_POLY_BYTE_LEN` bytes run out before all `N` coefficients are decoded;
    /// - any bit after the last coefficient is non-zero (the unused low bits of the final byte
    ///   or any padding byte after it).
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let input = source.read_array::<SIG_POLY_BYTE_LEN>()?;

        // Fetch bytes through a checked iterator rather than `input[idx]`. A malformed encoding
        // (e.g. long unary runs) can demand more than SIG_POLY_BYTE_LEN bytes. On untrusted
        // input that must surface as an error, not an out-of-bounds panic.
        let mut bytes = input.iter().copied();
        let mut next_byte = || {
            bytes.next().map(u32::from).ok_or_else(|| {
                DeserializationError::InvalidValue(
                    "Failed to decode signature: encoded polynomial is truncated".to_string(),
                )
            })
        };

        let mut acc = 0u32;
        // Number of not-yet-consumed bits sitting at the bottom of `acc`.
        let mut acc_len = 0;
        let mut coefficients = [FalconFelt::zero(); N];

        for c in coefficients.iter_mut() {
            // After appending a fresh byte, the 8 bits just above the `acc_len` buffered bits
            // are the sign bit followed by the 7 low magnitude bits.
            acc = (acc << 8) | next_byte()?;
            let b = acc >> acc_len;
            let s = b & 128;
            let mut m = b & 127;

            // Unary high part of the magnitude.
            loop {
                if acc_len == 0 {
                    acc = (acc << 8) | next_byte()?;
                    acc_len = 8;
                }
                acc_len -= 1;
                if ((acc >> acc_len) & 1) != 0 {
                    break;
                }
                m += 128;
                if m >= 2048 {
                    return Err(DeserializationError::InvalidValue(format!(
                        "Failed to decode signature: high bits {m} exceed 2048",
                    )));
                }
            }

            // "-0" would be a second encoding of zero.
            if s != 0 && m == 0 {
                return Err(DeserializationError::InvalidValue(
                    "Failed to decode signature: -0 is forbidden".to_string(),
                ));
            }

            let felt = if s != 0 { (MODULUS as u32 - m) as u16 } else { m as u16 };
            *c = FalconFelt::new(felt as i16);
        }

        // Canonical encoding: the unread low bits of the last consumed byte must be zero...
        if (acc & ((1 << acc_len) - 1)) != 0 {
            return Err(DeserializationError::InvalidValue(
                "Failed to decode signature: Non-zero unused bits in the last byte".to_string(),
            ));
        }
        // ...and so must every padding byte that follows it.
        if bytes.any(|byte| byte != 0) {
            return Err(DeserializationError::InvalidValue(
                "Failed to decode signature: Non-zero padding after the encoded polynomial"
                    .to_string(),
            ));
        }

        Ok(Polynomial::new(coefficients.to_vec()).into())
    }
}
/// Evaluates the verification equation for target point `c`, signature polynomial `s2` and
/// public key `h`.
///
/// Recovers `s1 = c - s2 * h`, forming the product as a Hadamard (pointwise) product of the
/// `fft` transforms and mapping back with `ifft`. Accepts iff `||s1||^2 + ||s2||^2` is strictly
/// below `SIG_L2_BOUND`.
///
/// NOTE(review): the Falcon spec accepts norms `<= floor(beta^2)`. Confirm that `SIG_L2_BOUND`
/// is defined so the strict `<` here agrees with the signer.
fn verify_helper(c: &Polynomial<FalconFelt>, s2: &SignaturePoly, h: &PublicKey) -> bool {
    let key_in_fft = h.fft();
    let s2_in_fft = s2.fft();
    let target_in_fft = c.fft();

    let s1 = (target_in_fft - s2_in_fft.hadamard_mul(&key_in_fft)).ifft();

    s1.norm_squared() + s2.norm_squared() < SIG_L2_BOUND
}
/// Returns `true` iff `x` has exactly `N` entries and every entry lies in `[-2047, 2047]`.
fn are_coefficients_valid(x: &[i16]) -> bool {
    const MAX_MAGNITUDE: i16 = 2047;
    x.len() == N
        && x
            .iter()
            .all(|&coeff| (-MAX_MAGNITUDE..=MAX_MAGNITUDE).contains(&coeff))
}
#[cfg(test)]
mod tests {
    use rand::SeedableRng;
    use rand_chacha::ChaCha20Rng;

    use super::{
        super::{SIG_SERIALIZED_LEN, SecretKey},
        *,
    };

    /// A deterministically generated signature must serialize to `SIG_SERIALIZED_LEN` bytes.
    /// Deserializing it must give back the same signature polynomial.
    #[test]
    fn test_serialization_round_trip() {
        let mut rng = ChaCha20Rng::from_seed([0_u8; 32]);
        let secret_key = SecretKey::with_rng(&mut rng);
        let original = secret_key.sign_with_rng(Word::default(), &mut rng);

        let encoded = original.to_bytes();
        assert_eq!(encoded.len(), SIG_SERIALIZED_LEN);

        let decoded = Signature::read_from_bytes(&encoded).unwrap();
        assert_eq!(original.sig_poly(), decoded.sig_poly());
    }
}