use sha2_const_stable::Sha256;
mod choice;
mod field_element;
use choice::Choice;
use field_element::FieldElement;
/// A 32-byte compressed Edwards point encoding, stored as raw bytes.
///
/// NOTE(review): presumably a local mirror of
/// `curve25519_dalek::edwards::CompressedEdwardsY` (same derives, same single
/// `[u8; 32]` field); it is not used by the code visible in this file — confirm
/// it is still needed.
#[derive(Copy, Clone, Eq, PartialEq, Hash)]
#[repr(C)]
pub struct CompressedEdwardsY(pub [u8; 32]);
/// A curve point held as four `FieldElement` coordinates `X`, `Y`, `Z`, `T`.
///
/// NOTE(review): presumably extended twisted Edwards coordinates as in
/// curve25519-dalek (x = X/Z, y = Y/Z, x·y = T/Z); nothing in this file
/// constructs or reads it — confirm before relying on that interpretation.
#[derive(Copy, Clone)]
#[repr(C)]
// Uppercase field names follow the usual notation for projective coordinates.
#[allow(non_snake_case)]
pub struct EdwardsPoint {
pub(crate) X: FieldElement,
pub(crate) Y: FieldElement,
pub(crate) Z: FieldElement,
pub(crate) T: FieldElement,
}
/// Reports whether `key` decompresses to a point on the curve.
///
/// `key` is decoded as a y-coordinate, and the result is the validity flag from
/// `decompress_step_1`: `true` when a matching x-coordinate exists. Usable in
/// `const` contexts.
///
/// NOTE(review): the `crypto_unsafe` prefix presumably warns that this path is
/// not hardened (e.g. not constant-time); confirm before using it on secrets.
pub const fn crypto_unsafe_is_on_curve(key: &[u8; 32]) -> bool {
    let first_step = decompress_step_1(key);
    first_step.0.into()
}
/// Fixed suffix hashed after the program id in every PDA preimage.
const PDA_MARKER: &[u8; 21] = b"ProgramDerivedAddress";
/// Finds a program-derived address (PDA) for `seeds` under `program`; usable in
/// `const` contexts.
///
/// For each bump from `255` down to `0`, computes
/// `SHA-256(seeds[0] || ... || seeds[n-1] || [bump] || program || PDA_MARKER)`
/// and returns the first digest that is *not* a valid curve point (as judged by
/// [`crypto_unsafe_is_on_curve`]), together with the bump that produced it.
///
/// Seeds are concatenated exactly as given; their number and lengths are not
/// checked here. NOTE(review): the Solana runtime caps seed count and per-seed
/// length — confirm callers stay within those limits.
///
/// # Panics
///
/// Panics if every bump from `255` down to `0` yields an on-curve digest.
pub const fn derive_program_address(seeds: &[&[u8]], program: &[u8; 32]) -> ([u8; 32], u8) {
    let mut bump = u8::MAX;
    loop {
        let mut hasher = Sha256::new();
        let mut i = 0;
        while i < seeds.len() {
            hasher = hasher.update(seeds[i]);
            i += 1;
        }
        let candidate = hasher
            .update(&[bump])
            .update(program)
            .update(PDA_MARKER)
            .finalize();
        if !crypto_unsafe_is_on_curve(&candidate) {
            return (candidate, bump);
        }
        // Stop explicitly rather than letting `bump -= 1` overflow past zero.
        // Overflowing would be a const-evaluation error, a debug-build panic, or,
        // with overflow checks off, an endless loop: after wrapping to 255 the same
        // on-curve digests repeat forever.
        bump = match bump.checked_sub(1) {
            Some(next) => next,
            None => panic!("Unable to find a viable program address bump seed"),
        };
    }
}
/// A SHA-256 state that has already absorbed the leading seeds of a PDA
/// derivation.
///
/// When many program-derived addresses share the same leading seeds, build a
/// `PartialPda` once with [`PartialPda::from_partial_preimage`]. Then finish each
/// derivation with [`PartialPda::finalize_with`]. The result matches
/// [`derive_program_address`] called with the full, concatenated seed list.
#[derive(Clone)]
pub struct PartialPda {
    /// Hasher state after absorbing the partial preimage (the leading seeds).
    inner: Sha256,
}

impl PartialPda {
    /// Hashes `partial_preimage`, in order, into a fresh SHA-256 state.
    ///
    /// The slices are fed to the hasher back to back, with no separators or
    /// length prefixes.
    pub const fn from_partial_preimage(partial_preimage: &[&[u8]]) -> PartialPda {
        let mut inner = Sha256::new();
        let mut i = 0;
        while i < partial_preimage.len() {
            inner = inner.update(partial_preimage[i]);
            i += 1;
        }
        PartialPda { inner }
    }

    /// Completes the PDA search started by [`PartialPda::from_partial_preimage`].
    ///
    /// First absorbs `remaining_seeds`. Then, for each bump from `255` downward,
    /// hashes `[bump] || program || PDA_MARKER` on a copy of that state. Returns
    /// the first digest that is not a valid curve point, together with its bump.
    ///
    /// # Panics
    ///
    /// Panics if every bump from `255` down to `0` yields an on-curve digest.
    pub const fn finalize_with(
        self,
        remaining_seeds: &[&[u8]],
        program: &[u8; 32],
    ) -> ([u8; 32], u8) {
        // The remaining seeds are identical for every bump, so absorb them once
        // instead of re-hashing them on each attempt. The digest is unchanged
        // because SHA-256 `update` calls compose as one concatenated stream.
        let mut seeded = self.inner;
        let mut i = 0;
        while i < remaining_seeds.len() {
            seeded = seeded.update(remaining_seeds[i]);
            i += 1;
        }

        let mut bump = u8::MAX;
        loop {
            // `Clone::clone` cannot be called in a `const fn`, so duplicate the
            // hasher state bitwise instead.
            // SAFETY: source and destination are both `Sha256`, so size and
            // alignment match. This `const fn` drops a `Sha256` (`seeded`, on
            // return), and stable Rust only permits that for types without drop
            // glue. So `Sha256` owns no heap memory or other resources, and the
            // bitwise copy is an independent, valid value.
            let hasher: Sha256 = unsafe { core::mem::transmute_copy(&seeded) };
            let candidate = hasher
                .update(&[bump])
                .update(program)
                .update(PDA_MARKER)
                .finalize();
            if !crypto_unsafe_is_on_curve(&candidate) {
                return (candidate, bump);
            }
            // Stop explicitly instead of letting `bump -= 1` overflow past zero.
            // That would be a const-eval error, a debug panic, or an endless loop
            // after wrapping when overflow checks are off.
            bump = match bump.checked_sub(1) {
                Some(next) => next,
                None => panic!("Unable to find a viable program address bump seed"),
            };
        }
    }
}
/// First step of Edwards-Y point decompression, usable in `const` contexts.
///
/// Decodes `repr` into the field element `y` via `FieldElement::from_bytes`.
/// It then attempts to recover `x` from the curve equation by evaluating
/// `FieldElement::sqrt_ratio_i(y² - 1, d·y² + 1)`, where `d` is
/// `FieldElement::EDWARDS_D`.
///
/// Returns `(is_valid_y_coord, x, y, z)`:
/// - `is_valid_y_coord` is the flag reported by `sqrt_ratio_i`;
/// - `x` is its root;
/// - `y` is the decoded coordinate;
/// - `z` is always `FieldElement::ONE`.
#[rustfmt::skip] // keep the explanatory comments below aligned
pub(super) const fn decompress_step_1(
    repr: &[u8; 32],
) -> (Choice, FieldElement, FieldElement, FieldElement) {
    let one = FieldElement::ONE;
    let y = FieldElement::from_bytes(repr);
    let y_sq = y.square();
    let numerator   = y_sq.sub(one);                               // u = y² - 1
    let denominator = y_sq.mul(FieldElement::EDWARDS_D).add(one);  // v = d·y² + 1
    let (is_valid_y_coord, x) = FieldElement::sqrt_ratio_i(numerator, denominator);
    (is_valid_y_coord, x, y, one)
}
/// Cross-checks `crypto_unsafe_is_on_curve` against curve25519-dalek's
/// decompression on random 32-byte inputs.
#[test]
fn test_on_curve() {
    fn dalek_decompresses(bytes: &[u8; 32]) -> bool {
        curve25519_dalek::edwards::CompressedEdwardsY(*bytes)
            .decompress()
            .is_some()
    }

    (0..50_000).for_each(|_| {
        let candidate: [u8; 32] = rand::random();
        assert_eq!(
            crypto_unsafe_is_on_curve(&candidate),
            dalek_decompresses(&candidate)
        );
    });
}
/// `PartialPda` must derive exactly the same address and bump as
/// `derive_program_address`. This must hold for every way of splitting the seed
/// list between the pre-hashed prefix and the remaining seeds, including an
/// empty prefix and an empty remainder.
#[test]
fn test_with_partial_preimage() {
    for _ in 0..50_000 {
        let first = rand::random::<[u8; 32]>();
        let second = rand::random::<[u8; 32]>();
        let program = rand::random::<[u8; 32]>();
        let seeds: [&[u8]; 2] = [&first, &second];
        let direct = derive_program_address(&seeds, &program);
        for split in 0..=seeds.len() {
            let (prefix, rest) = seeds.split_at(split);
            let via_partial =
                PartialPda::from_partial_preimage(prefix).finalize_with(rest, &program);
            assert_eq!(direct, via_partial, "mismatch with {split} seed(s) pre-hashed");
        }
    }
}