use core::marker::PhantomData;
use core::{fmt, mem, ops};
use super::Polynomial;
use crate::primitives::hrp::Hrp;
use crate::{Fe1024, Fe32};
/// Parameters of a BCH checksum over [`Fe32`], as consumed by [`Engine`].
///
/// Implementations are typically produced by rendering a [`PrintImpl`].
pub trait Checksum {
    /// Packed integer type holding the running residue, one `Fe32` per 5-bit slot.
    type MidstateRepr: PackedFe32;

    /// Extension field of `Fe32` that `ROOT_GENERATOR` lives in.
    type CorrectionField: super::ExtensionField<BaseField = Fe32>;

    /// Generator element emitted as `ROOT_GENERATOR` by [`PrintImpl`], which takes it
    /// from `bch_generator_primitive_element`.
    const ROOT_GENERATOR: Self::CorrectionField;

    /// Exponent range emitted as `ROOT_EXPONENTS` by [`PrintImpl`], which takes it from
    /// `bch_generator_primitive_element`.
    const ROOT_EXPONENTS: core::ops::RangeInclusive<usize>;

    /// Code length emitted as `CODE_LENGTH` by [`PrintImpl`], which takes it from
    /// `bch_generator_primitive_element`.
    const CODE_LENGTH: usize;

    /// Number of `Fe32` elements in the checksum residue.
    const CHECKSUM_LENGTH: usize;

    /// Five packed generator rows. [`Engine::input_fe`] XORs row `i` into the residue
    /// when bit `i` of the element shifted out of the top slot is set, and
    /// [`Checksum::sanity_check`] requires each row to be the previous one with every
    /// element multiplied by x.
    const GENERATOR_SH: [Self::MidstateRepr; 5];

    /// Packed target residue; [`Engine::input_target_residue`] feeds its elements in.
    const TARGET_RESIDUE: Self::MidstateRepr;

    /// Checks that the constants above are mutually consistent.
    ///
    /// # Panics
    ///
    /// * if `CHECKSUM_LENGTH` exceeds `MidstateRepr::WIDTH` (the residue would not fit);
    /// * if any row of `GENERATOR_SH` is not the previous row with each 5-bit element
    ///   shifted left by one bit and, when bit 4 was set, XORed with 41 (`0b101001`).
    fn sanity_check() {
        assert!(Self::CHECKSUM_LENGTH <= Self::MidstateRepr::WIDTH);

        let rows = Self::GENERATOR_SH;
        for shift in 1..rows.len() {
            let (prev_row, this_row) = (rows[shift - 1], rows[shift]);
            for j in 0..Self::MidstateRepr::WIDTH {
                let prev_elem = prev_row.unpack(j);
                // Fold the bit shifted out of position 4 back in with 41 = 0b101001.
                let reduction = if prev_elem & 0x10 == 0x10 { 41 } else { 0 };
                assert_eq!(
                    this_row.unpack(j),
                    (prev_elem << 1) ^ reduction,
                    "Element {} of generator << 2^{} was incorrectly computed. (Should have been {} << 1)",
                    j,
                    shift,
                    prev_elem,
                );
            }
        }
    }
}
/// Renders Rust source for an `impl Checksum` block describing a BCH code.
///
/// Build it with [`PrintImpl::new`] from a generator polynomial and a target
/// residue, then format it through `Display` (e.g. `to_string()` or `println!`).
/// `ExtField` is the field passed to `bch_generator_primitive_element` while
/// rendering; it defaults to [`Fe1024`].
pub struct PrintImpl<'a, ExtField = Fe1024> {
    /// Name of the type the generated impl is for.
    name: &'a str,
    /// Generator polynomial, built from the caller's coefficients with
    /// `Polynomial::with_monic_leading_term`.
    generator: Polynomial<Fe32>,
    /// Target residue coefficients exactly as passed to `new`.
    target: &'a [Fe32],
    /// Bits needed to pack the residue (5 per element).
    bit_len: usize,
    /// Minimum number of hex digits used when printing packed constants.
    hex_width: usize,
    /// Integer type the residue is packed into: `"u32"`, `"u64"` or `"u128"`.
    midstate_repr: &'static str,
    /// Carries the `ExtField` type parameter without storing a value.
    phantom: PhantomData<ExtField>,
}

impl<'a, ExtField> PrintImpl<'a, ExtField> {
    /// Constructs a printer for the checksum type called `name`.
    ///
    /// `generator` lists the generator polynomial's coefficients *without* its monic
    /// term; that term is supplied by `Polynomial::with_monic_leading_term`.
    /// `target` is the target residue and must have the same length as `generator`.
    ///
    /// # Panics
    ///
    /// * if `name`, `generator` or `target` is empty;
    /// * if `generator` and `target` differ in length (the message hints at the
    ///   likely mistake);
    /// * if the residue needs more than 128 bits, i.e. more than 25 elements.
    pub fn new(name: &'a str, generator: &'a [Fe32], target: &'a [Fe32]) -> Self {
        assert_ne!(name.len(), 0, "type name cannot be the empty string");
        assert_ne!(
            generator.len(),
            0,
            "generator polynomial cannot be the empty string (constant 1)"
        );
        assert_ne!(target.len(), 0, "target residue cannot be the empty string");
        if generator.len() != target.len() {
            let hint = if generator.len() == target.len() + 1 {
                // Exactly one extra element: most likely the caller included the monic term.
                " (you should not include the monic term of the generator polynomial)"
            } else if generator.len() > target.len() {
                " (you may need to zero-pad your target residue)"
            } else {
                ""
            };
            panic!(
                "Generator length {} does not match target residue length {}{}",
                generator.len(),
                target.len(),
                hint
            );
        }

        // Each Fe32 occupies 5 bits. Pick the narrowest unsigned integer that holds the
        // whole residue, together with the hex digit count for printing that type.
        let bit_len = 5 * target.len();
        let (hex_width, midstate_repr) = match bit_len {
            0..=32 => (8, "u32"),
            33..=64 => (16, "u64"),
            65..=128 => (32, "u128"),
            _ => panic!(
                "Generator length {} cannot exceed 25, as we cannot represent it by packing bits into a Rust numeric type",
                generator.len()
            ),
        };

        PrintImpl {
            name,
            generator: Polynomial::with_monic_leading_term(generator),
            target,
            bit_len,
            hex_width,
            midstate_repr,
            phantom: PhantomData,
        }
    }
}
impl<ExtField> fmt::Display for PrintImpl<'_, ExtField>
where
    ExtField: super::Bech32Field + super::ExtensionField<BaseField = Fe32>,
{
    /// Writes Rust source for an `impl Checksum for <name>` block.
    ///
    /// The output starts with a `//` line listing the stored generator coefficients
    /// and the target residue. `ROOT_GENERATOR`, `ROOT_EXPONENTS` and `CODE_LENGTH`
    /// come from `bch_generator_primitive_element::<ExtField>()`, and `CHECKSUM_LENGTH`
    /// is the generator's degree. Packed constants are printed as zero-padded hex of
    /// at least `hex_width` digits.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let (root_gen, code_len, root_exponents) =
            self.generator.bch_generator_primitive_element::<ExtField>();
        let repr = self.midstate_repr;
        let width = self.hex_width;

        // Provenance comment.
        f.write_str("// Code block generated by Checksum::print_impl polynomial ")?;
        for coeff in self.generator.as_inner() {
            write!(f, "{}", coeff)?;
        }
        f.write_str(" target ")?;
        for coeff in self.target {
            write!(f, "{}", coeff)?;
        }
        writeln!(f)?;

        // Associated types and extension-field constants.
        writeln!(f, "impl Checksum for {} {{", self.name)?;
        writeln!(f, "    type MidstateRepr = {}; // checksum packs into {} bits", repr, self.bit_len)?;
        writeln!(f)?;
        writeln!(f, "    type CorrectionField = {};", core::any::type_name::<ExtField>())?;
        f.write_str("    const ROOT_GENERATOR: Self::CorrectionField = ")?;
        root_gen.format_as_rust_code(f)?;
        f.write_str(";\n")?;
        writeln!(
            f,
            "    const ROOT_EXPONENTS: core::ops::RangeInclusive<usize> = {}..={};",
            root_exponents.start(),
            root_exponents.end()
        )?;
        writeln!(f)?;

        // Length constants.
        writeln!(f, "    const CODE_LENGTH: usize = {};", code_len)?;
        writeln!(f, "    const CHECKSUM_LENGTH: usize = {};", self.generator.degree())?;

        // Five rows, each one the previous row with every coefficient multiplied by `Fe32::Z`.
        // NOTE(review): `new` builds `generator` with `with_monic_leading_term`. If
        // `into_inner` includes that monic coefficient, each row packs `degree + 1`
        // elements, one more than `CHECKSUM_LENGTH`, and may not fit `midstate_repr`.
        // Confirm against `Polynomial`'s coefficient layout.
        writeln!(f, "    const GENERATOR_SH: [{}; 5] = [", repr)?;
        let mut row = self.generator.clone().into_inner();
        for _ in 0..5 {
            let packed_row = u128::pack(row.iter().copied().map(From::from));
            writeln!(f, "        0x{:0width$x},", packed_row, width = width)?;
            row.iter_mut().for_each(|coeff| *coeff *= Fe32::Z);
        }
        writeln!(f, "    ];")?;

        let packed_target = u128::pack(self.target.iter().copied().map(From::from));
        writeln!(
            f,
            "    const TARGET_RESIDUE: {} = 0x{:0width$x};",
            repr,
            packed_target,
            width = width,
        )?;
        f.write_str("}")
    }
}
/// Incremental checksum state for the checksum described by `Ck`.
///
/// The running residue is kept packed in `Ck::MidstateRepr`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct Engine<Ck: Checksum> {
    residue: Ck::MidstateRepr,
}

impl<Ck: Checksum> Default for Engine<Ck> {
    fn default() -> Self { Engine::new() }
}

impl<Ck: Checksum> Engine<Ck> {
    /// Creates an engine whose residue starts at `Ck::MidstateRepr::ONE`.
    #[inline]
    pub fn new() -> Self { Self { residue: Ck::MidstateRepr::ONE } }

    /// Feeds every field element yielded by [`HrpFe32Iter`] for `hrp` into the engine.
    #[inline]
    pub fn input_hrp(&mut self, hrp: Hrp) {
        HrpFe32Iter::new(&hrp).for_each(|fe| self.input_fe(fe));
    }

    /// Feeds a single field element into the engine.
    ///
    /// The residue is shifted up one slot with `e` placed in slot 0. Then, for each
    /// set bit `i` of the element shifted out of the top slot, row `i` of
    /// `Ck::GENERATOR_SH` is XORed into the residue.
    #[inline]
    pub fn input_fe(&mut self, e: Fe32) {
        let overflow = self.residue.mul_by_x_then_add(Ck::CHECKSUM_LENGTH, e.into());
        for (bit, shifted_gen) in Ck::GENERATOR_SH.iter().enumerate() {
            if (overflow >> bit) & 1 == 1 {
                self.residue = self.residue ^ *shifted_gen;
            }
        }
    }

    /// Feeds the `Ck::CHECKSUM_LENGTH` elements of `Ck::TARGET_RESIDUE` into the
    /// engine, starting from the most significant slot.
    #[inline]
    pub fn input_target_residue(&mut self) {
        for slot in (0..Ck::CHECKSUM_LENGTH).rev() {
            self.input_fe(Fe32(Ck::TARGET_RESIDUE.unpack(slot)));
        }
    }

    /// Returns the current packed residue.
    #[inline]
    pub fn residue(&self) -> &Ck::MidstateRepr { &self.residue }
}
/// A fixed number of 5-bit `Fe32` values packed into a single value.
///
/// The integer implementations store element `n` in bits `5n..5n + 5`, so slot 0
/// is the least significant.
pub trait PackedFe32: Copy + PartialEq + Eq + ops::BitXor<Self, Output = Self> {
    /// Packed value that a checksum [`Engine`] starts from.
    const ONE: Self;

    /// Number of whole 5-bit slots that fit in `Self`.
    const WIDTH: usize = 8 * mem::size_of::<Self>() / 5;

    /// Packs the elements of `iter`; the first element ends up in the most
    /// significant occupied slot.
    fn pack<I: Iterator<Item = u8>>(iter: I) -> Self;

    /// Returns the element stored in slot `n`.
    fn unpack(&self, n: usize) -> u8;

    /// Treats the value as a polynomial with `degree` slots, multiplies it by x,
    /// places `add` in slot 0, and returns the element shifted out of slot `degree - 1`.
    fn mul_by_x_then_add(&mut self, degree: usize, add: u8) -> u8;
}
/// A zero-sized [`PackedFe32`] with no slots.
///
/// Every query returns 0 and XOR yields `PackedNull`. Packing an empty iterator
/// succeeds; packing any element panics.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct PackedNull;

impl ops::BitXor<PackedNull> for PackedNull {
    type Output = PackedNull;

    #[inline]
    fn bitxor(self, _rhs: PackedNull) -> PackedNull { self }
}

impl PackedFe32 for PackedNull {
    const ONE: Self = PackedNull;

    #[inline]
    fn pack<I: Iterator<Item = u8>>(mut iter: I) -> Self {
        match iter.next() {
            None => PackedNull,
            Some(_) => panic!("Cannot pack anything into a PackedNull"),
        }
    }

    #[inline]
    fn unpack(&self, _n: usize) -> u8 { 0 }

    #[inline]
    fn mul_by_x_then_add(&mut self, _degree: usize, _add: u8) -> u8 { 0 }
}
// Implements `PackedFe32` for a primitive unsigned integer. Element `n` lives in
// bits `5n..5n + 5`, so slot 0 is the least significant.
macro_rules! impl_packed_fe32 {
    ($ty:ident) => {
        impl PackedFe32 for $ty {
            const ONE: Self = 1;

            #[inline]
            fn unpack(&self, n: usize) -> u8 {
                debug_assert!(n < Self::WIDTH);
                let slot = *self >> (5 * n);
                (slot & 0x1f) as u8
            }

            #[inline]
            fn mul_by_x_then_add(&mut self, degree: usize, add: u8) -> u8 {
                debug_assert!(degree > 0);
                debug_assert!(degree <= Self::WIDTH);
                debug_assert!(add < 32);
                let top = degree - 1;
                let overflow = self.unpack(top);
                // Drop the top slot so it does not survive the shift, move every
                // remaining slot up by one, and put `add` into slot 0.
                let kept = *self & !(0x1f << (5 * top));
                *self = (kept << 5) | Self::from(add);
                overflow
            }

            #[inline]
            fn pack<I: Iterator<Item = u8>>(iter: I) -> Self {
                iter.enumerate().fold(0, |acc: Self, (n, elem)| {
                    debug_assert!(elem < 32);
                    debug_assert!(n < Self::WIDTH);
                    (acc << 5) | Self::from(elem)
                })
            }
        }
    };
}
impl_packed_fe32!(u32);
impl_packed_fe32!(u64);
impl_packed_fe32!(u128);
/// Iterator over the field elements fed into a checksum for a human-readable part.
///
/// It walks the HRP's lowercase bytes twice. First it yields each byte's high bits
/// (`byte >> 5`), then a single `Fe32::Q` separator, and finally each byte's low five
/// bits (`byte & 0x1f`).
pub struct HrpFe32Iter<'hrp> {
    /// Source of the high halves; set to `None` once exhausted (separator emitted).
    high_iter: Option<crate::primitives::hrp::LowercaseByteIter<'hrp>>,
    /// Source of the low halves; set to `None` once exhausted.
    low_iter: Option<crate::primitives::hrp::LowercaseByteIter<'hrp>>,
}

impl<'hrp> HrpFe32Iter<'hrp> {
    /// Creates an iterator over the checksum field elements of `hrp`.
    #[inline]
    pub fn new(hrp: &'hrp Hrp) -> Self {
        Self {
            high_iter: Some(hrp.lowercase_byte_iter()),
            low_iter: Some(hrp.lowercase_byte_iter()),
        }
    }
}
impl Iterator for HrpFe32Iter<'_> {
type Item = Fe32;
#[inline]
fn next(&mut self) -> Option<Fe32> {
if let Some(ref mut high_iter) = &mut self.high_iter {
match high_iter.next() {
Some(high) => return Some(Fe32(high >> 5)),
None => {
self.high_iter = None;
return Some(Fe32::Q);
}
}
}
if let Some(ref mut low_iter) = &mut self.low_iter {
match low_iter.next() {
Some(low) => return Some(Fe32(low & 0x1f)),
None => self.low_iter = None,
}
}
None
}
#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
let high = match &self.high_iter {
Some(high_iter) => {
let (min, max) = high_iter.size_hint();
(min + 1, max.map(|max| max + 1)) }
None => (0, Some(0)),
};
let low = match &self.low_iter {
Some(low_iter) => low_iter.size_hint(),
None => (0, Some(0)),
};
let min = high.0 + low.0;
let max = high.1.zip(low.1).map(|(high, low)| high + low);
(min, max)
}
}
#[cfg(test)]
mod tests {
    #[cfg(feature = "alloc")]
    use alloc::vec::Vec;
    #[cfg(feature = "alloc")]
    use core::convert::TryFrom;

    use super::*;

    #[test]
    fn pack_unpack() {
        let one = u128::pack([0, 0, 0, 1].iter().copied());
        assert_eq!(one, 1);
        assert_eq!(one.unpack(0), 1);
        assert_eq!(one.unpack(3), 0);

        let counting = u128::pack([1, 2, 3, 4].iter().copied());
        assert_eq!(counting, 0b00001_00010_00011_00100);
        // The first element packed ends up in the most significant slot.
        assert_eq!(counting.unpack(0), 4);
        assert_eq!(counting.unpack(1), 3);
        assert_eq!(counting.unpack(2), 2);
        assert_eq!(counting.unpack(3), 1);
    }

    #[test]
    #[cfg(feature = "alloc")]
    fn bech32() {
        let generator = [Fe32::A, Fe32::K, Fe32::_5, Fe32::_4, Fe32::A, Fe32::J];
        let target = [Fe32::Q, Fe32::Q, Fe32::Q, Fe32::Q, Fe32::Q, Fe32::P];

        let unpacked_poly = (0..6)
            .rev()
            .map(|i| 0x3b6a57b2u128.unpack(i))
            .map(|u| Fe32::try_from(u).unwrap())
            .collect::<Vec<_>>();
        assert_eq!(unpacked_poly, generator);

        let _s = PrintImpl::<Fe1024>::new("Bech32", &generator, &target).to_string();
        #[cfg(feature = "std")]
        println!("{}", _s);
    }

    #[test]
    #[cfg(feature = "alloc")]
    fn descriptor() {
        let generator = [Fe32::_7, Fe32::H, Fe32::_0, Fe32::W, Fe32::_2, Fe32::X, Fe32::V, Fe32::F];
        let target = [Fe32::Q, Fe32::Q, Fe32::Q, Fe32::Q, Fe32::Q, Fe32::Q, Fe32::Q, Fe32::P];

        let unpacked_poly = (0..8)
            .rev()
            .map(|i| 0xf5dee51989u64.unpack(i))
            .map(|u| Fe32::try_from(u).unwrap())
            .collect::<Vec<_>>();
        assert_eq!(unpacked_poly, generator);

        let _s =
            PrintImpl::<crate::Fe32768>::new("DescriptorChecksum", &generator, &target).to_string();
        #[cfg(feature = "std")]
        println!("{}", _s);
    }

    #[test]
    #[should_panic]
    fn sanity_check_panics_on_bad_checksums() {
        // Every generator row is identical, so the "shifted by x" check must fail.
        enum BadGeneratorShifts {}
        impl Checksum for BadGeneratorShifts {
            type MidstateRepr = u32;
            type CorrectionField = crate::Fe1024;
            const ROOT_GENERATOR: Self::CorrectionField = crate::Fe1024::new([Fe32::P, Fe32::X]);
            const ROOT_EXPONENTS: core::ops::RangeInclusive<usize> = 24..=26;
            const CODE_LENGTH: usize = 1023;
            const CHECKSUM_LENGTH: usize = 6;
            const GENERATOR_SH: [u32; 5] = [1, 1, 1, 1, 1];
            const TARGET_RESIDUE: u32 = 1;
        }
        BadGeneratorShifts::sanity_check();
    }

    #[test]
    #[cfg(feature = "alloc")]
    fn print_impl_renders_expected_bech32_impl() {
        let rendered = PrintImpl::<crate::Fe1024>::new(
            "Bech32",
            &[Fe32::A, Fe32::K, Fe32::_5, Fe32::_4, Fe32::A, Fe32::J],
            &[Fe32::Q, Fe32::Q, Fe32::Q, Fe32::Q, Fe32::Q, Fe32::P],
        )
        .to_string();

        // NOTE(review): this literal is 35 bits wide (9 hex digits), yet `PrintImpl::new`
        // selects `u32` as the midstate type for a 6-element residue. Confirm that the
        // rendered GENERATOR_SH rows are meant to look like this.
        let generator_shift = "0x4eb8406b0";
        assert!(
            rendered.contains(generator_shift),
            "generated impl missing expected fragment: {}\n\n{}",
            generator_shift,
            rendered
        );
    }

    #[test]
    fn packed_null() {
        let mut null = PackedNull;
        assert_eq!(null.unpack(0), 0);
        assert_eq!(null.mul_by_x_then_add(1, 31), 0);
    }

    #[test]
    fn hrp_fe32_iter_size_hint_matches_remaining_length() {
        let hrp = Hrp::parse_unchecked("ab");
        let mut iter = HrpFe32Iter::new(&hrp);
        // High halves, one separator, then low halves.
        let total = 2 * hrp.len() + 1;
        for remaining in (0..=total).rev() {
            assert_eq!(iter.size_hint(), (remaining, Some(remaining)));
            if remaining > 0 {
                iter.next().expect("iterator has more elements");
            }
        }
    }
}
#[cfg(bench)]
mod benches {
    use std::io::{sink, Write};

    use test::{black_box, Bencher};

    use crate::{Fe1024, Fe32, Fe32768, PrintImpl};

    /// Measures constructing and rendering the bech32 `PrintImpl`.
    #[bench]
    fn compute_bech32_params(bh: &mut Bencher) {
        bh.iter(|| {
            let generator = [Fe32::A, Fe32::K, Fe32::_5, Fe32::_4, Fe32::A, Fe32::J];
            let target = [Fe32::Q, Fe32::Q, Fe32::Q, Fe32::Q, Fe32::Q, Fe32::P];
            let printer = PrintImpl::<Fe1024>::new("Bech32", &generator, &target);
            let outcome = write!(sink(), "{}", printer);
            black_box(&printer);
            black_box(&outcome);
        })
    }

    /// Measures constructing and rendering the descriptor-checksum `PrintImpl`.
    #[bench]
    fn compute_descriptor_params(bh: &mut Bencher) {
        bh.iter(|| {
            let generator =
                [Fe32::_7, Fe32::H, Fe32::_0, Fe32::W, Fe32::_2, Fe32::X, Fe32::V, Fe32::F];
            let target = [Fe32::Q, Fe32::Q, Fe32::Q, Fe32::Q, Fe32::Q, Fe32::Q, Fe32::Q, Fe32::P];
            let printer = PrintImpl::<Fe32768>::new("DescriptorChecksum", &generator, &target);
            let outcome = write!(sink(), "{}", printer);
            black_box(&printer);
            black_box(&outcome);
        })
    }
}