use crate::{BignP256, ProjectivePoint, Sec1Point, U256, arithmetic::FieldElement};
use belt_kwp::{BeltKwp, KeyInit};
use core::num::{NonZero, NonZeroU16};
use elliptic_curve::{
Field,
array::Array,
bigint::U384,
consts::{U32, U48, U128},
sec1::FromSec1Point,
subtle::ConditionallySelectable,
};
use hash2curve::{ExpandMsg, Expander, MapToCurve};
use primefield::bigint::{NonZero as BigintNonZero, Reduce};
use primeorder::PrimeCurveParams;
/// Reduces a 48-byte little-endian integer modulo the bign-p256 field prime `p = 2^256 - 189`.
impl Reduce<Array<u8, U48>> for FieldElement {
    // The only arithmetic is `%` by a non-zero constant, which cannot overflow.
    #[allow(clippy::arithmetic_side_effects)]
    fn reduce(value: &Array<u8, U48>) -> Self {
        // `p` zero-extended to 384 bits so it can divide the 48-byte input directly.
        const WIDE_MODULUS: BigintNonZero<U384> = BigintNonZero::<U384>::new_unwrap(
            U384::from_be_hex(
                "00000000000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF43",
            ),
        );
        let reduced = U384::from_le_slice(value) % WIDE_MODULUS;

        // `reduced < p < 2^256`, so every word above the low `U256::LIMBS` is zero and can be
        // dropped without losing information.
        let all_words = reduced.to_words();
        let mut low_words = [0; U256::LIMBS];
        low_words.copy_from_slice(&all_words[..U256::LIMBS]);

        // Already below `p`, so the unchecked constructor's range precondition is satisfied.
        FieldElement::from_uint_unchecked(U256::from_words(low_words))
    }
}
/// Holds the 48-byte output of `belt-keywrap` and tracks how much of it has been consumed.
struct BeltKwpExpander {
    /// Bytes produced by `belt-keywrap` in `expand_message`.
    buf: [u8; 48],
    /// Number of bytes of `buf` that have not been handed out yet.
    remaining: usize,
    /// Index in `buf` of the next byte to hand out.
    offset: usize,
}
impl Expander for BeltKwpExpander {
#[allow(clippy::arithmetic_side_effects)]
fn fill_bytes(&mut self, dst: &mut [u8]) -> Result<usize, elliptic_curve::Error> {
if self.remaining == 0 {
return Err(elliptic_curve::Error);
}
let len = dst.len().min(self.remaining);
dst[..len].copy_from_slice(&self.buf[self.offset..self.offset + len]);
self.offset += len;
self.remaining -= len;
Ok(len)
}
}
impl ExpandMsg<U32> for BeltKwpExpander {
type Hash = ();
type Expander<'dst> = Self;
type Error = elliptic_curve::Error;
#[allow(clippy::arithmetic_side_effects)]
fn expand_message<'dst>(
msg: &[&[u8]],
dst: &'dst [&[u8]],
len_in_bytes: NonZero<u16>,
) -> Result<Self::Expander<'dst>, Self::Error> {
if len_in_bytes.get() != 48 {
return Err(elliptic_curve::Error);
}
let mut input = [0u8; 32];
let mut pos = 0;
for m in msg {
let to_copy = (m.len()).min(32 - pos);
if to_copy == 0 {
break;
}
input[pos..pos + to_copy].copy_from_slice(&m[..to_copy]);
pos += to_copy;
}
for d in dst {
let to_copy = (d.len()).min(32 - pos);
if to_copy == 0 {
break;
}
input[pos..pos + to_copy].copy_from_slice(&d[..to_copy]);
pos += to_copy;
}
let kwp = BeltKwp::new_from_slice(&[0u8; 32]).map_err(|_| elliptic_curve::Error)?;
let mut buf = [0u8; 48];
kwp.wrap_key(&input, &[0u8; 16], &mut buf)
.map_err(|_| elliptic_curve::Error)?;
Ok(Self {
buf,
offset: 0,
remaining: 48,
})
}
}
impl FieldElement {
    /// Maps this field element `s` to a point on bign-p256 with the SWU map.
    ///
    /// With `t = -s²`:
    /// - `x1 = -B·(1 + t + t²) / (A·(t + t²))` and `x2 = t·x1`;
    /// - if `g(x1) = x1³ + A·x1 + B` is a square, the point is `(x1, y1)`, where `y1` is
    ///   the value returned by `sqrt_alt(g(x1))`; otherwise the point is `(x2, s³·y1)`.
    ///
    /// The denominator `A·(t + t²)` is zero for `s ∈ {0, 1, -1}`. It is then "inverted" as
    /// zero (the inv0 convention, i.e. raising it to `p - 2`) instead of panicking, which gives
    /// `x1 = x2 = 0`.
    ///
    /// # Panics
    ///
    /// Panics with "Always possible" if the selected coordinates fail to decode as a point.
    // NOTE(review): the `(x2, s³·y1)` branch assumes `sqrt_alt` returns `g(x1)^((p+1)/4)` even
    // when `g(x1)` is not a square, and the inv0 case relies on `B` being a square (bign base
    // points have x = 0) — confirm against `sqrt_alt` and the curve constants.
    #[allow(clippy::arithmetic_side_effects)]
    fn swu(&self) -> ProjectivePoint {
        let t = self.square().neg();
        let t_squared = t.square();
        let num = FieldElement::ONE + t + t_squared;
        let den = BignP256::EQUATION_A * (t + t_squared);
        // Constant-time inv0: a zero denominator maps to zero rather than panicking, so the
        // public `map_to_curve` is total over all field elements.
        let den_inv = den.invert().unwrap_or(FieldElement::ZERO);
        let x1 = -BignP256::EQUATION_B * num * den_inv;
        let x2 = t * x1;
        let gx1 = x1.cube() + BignP256::EQUATION_A * x1 + BignP256::EQUATION_B;
        let (is_square, y1) = gx1.sqrt_alt();
        let y2 = self.cube() * y1;
        // Constant-time choice: (x1, y1) when g(x1) is a square, otherwise (x2, y2).
        let x = FieldElement::conditional_select(&x2, &x1, is_square);
        let y = FieldElement::conditional_select(&y2, &y1, is_square);
        let point = Sec1Point::from_affine_coordinates(&x.to_bytes(), &y.to_bytes(), false);
        ProjectivePoint::from_sec1_point(&point).expect("Always possible")
    }
}
impl MapToCurve for BignP256 {
type SecurityLevel = U128;
type FieldElement = FieldElement;
type Length = U48;
fn map_to_curve(element: Self::FieldElement) -> ProjectivePoint {
element.swu()
}
}
impl BignP256 {
    /// Hashes `secret` to a point on the bign-p256 curve.
    ///
    /// `secret` is passed as the message to the `belt-keywrap` expander, which operates on a
    /// 32-byte input. The 48 expanded bytes are reduced modulo the field prime and mapped to
    /// the curve with the SWU map.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the expander while expanding or reading its output.
    pub fn hash_secret_to_curve(secret: &[u8]) -> elliptic_curve::Result<ProjectivePoint> {
        let expanded_len = NonZeroU16::new(48).expect("48 is always nonzero");
        let mut wide = Array::<u8, U48>::default();
        BeltKwpExpander::expand_message(&[secret], &[], expanded_len)?.fill_bytes(&mut wide)?;
        Ok(BignP256::map_to_curve(FieldElement::reduce(&wide)))
    }
}
/// Known-answer test: wrapping this 32-byte input must yield the fixed 48-byte output.
#[test]
fn test_expander() {
    use hex_literal::hex;

    let x = hex!(
        "AD1362A8 F9A3D42F BE1B8E6F 1C88AAD5"
        "0F51D913 47617C20 BD4AB07A EF4F26A1"
    );
    let wrapped = hex!(
        "CFE573F6745C633867EFBF702504394B585B1D6F454721F4"
        "7BD28E3DFF19230E18D8A279A5C2047069585F26315BF1A5"
    );
    let out_len = NonZeroU16::new(48).unwrap();
    let mut exp = BeltKwpExpander::expand_message(&[&x], &[b""], out_len).unwrap();
    let mut actual = [0u8; 48];
    exp.fill_bytes(&mut actual).unwrap();
    assert_eq!(actual, wrapped);
}