use crate::public_key::bigint::{BigUint, MontgomeryCtx};
use crate::public_key::primes::{mod_inverse, random_nonzero_below};
use crate::Csprng;
use std::sync::OnceLock;
/// Parameters of a twisted Edwards curve `a·x² + y² = 1 + d·x²·y²` over the prime field
/// `GF(p)` (the equation checked by `TwistedEdwardsCurve::is_on_curve`), together with a
/// base point `G = (gx, gy)` and cached Montgomery contexts.
#[derive(Clone, Debug)]
pub struct TwistedEdwardsCurve {
/// Field modulus `p`.
pub p: BigUint,
/// Curve coefficient `a` (for Ed25519 this is `p - 1`, i.e. `-1`).
pub a: BigUint,
/// Curve coefficient `d`.
pub d: BigUint,
/// `2·d mod p`, precomputed for the point-addition formulas.
pub(crate) d2: BigUint,
/// Order of the base point: scalars are drawn below `n` and the tests check `n·G` is neutral
/// for Ed25519. Primality is assumed, not checked.
pub n: BigUint,
/// Affine x-coordinate of the base point.
pub gx: BigUint,
/// Affine y-coordinate of the base point.
pub gy: BigUint,
/// Montgomery context used for all multiplications modulo `p`.
pub(crate) field: MontgomeryCtx,
/// Montgomery context modulo `n`; not read by any code visible in this file (hence the
/// leading underscore).
pub(crate) _scalar: MontgomeryCtx,
/// Length in bytes of one compressed point encoding (set by `TwistedEdwardsCurve::new`).
pub coord_len: usize,
}
/// An affine curve point `(x, y)`, or the neutral element when `neutral` is set.
///
/// When `neutral` is `true` the coordinates are ignored by the curve operations (encoding
/// uses `(0, 1)`, `is_on_curve` returns `true`). Equality is the derived field-wise
/// comparison, so `EdwardsPoint::new(0, 1)` does NOT compare equal to `EdwardsPoint::neutral()`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct EdwardsPoint {
/// Affine x-coordinate (expected to be reduced modulo `p`).
pub x: BigUint,
/// Affine y-coordinate (expected to be reduced modulo `p`).
pub y: BigUint,
/// `true` for the group identity.
pub neutral: bool,
}
/// A point in extended projective coordinates `(X : Y : Z : T)`.
///
/// The affine point is `(X/Z, Y/Z)` (see `to_affine`); `T` is initialised to `x·y` with
/// `Z = 1` in `from_affine`, i.e. it tracks the usual invariant `T·Z = X·Y`.
#[derive(Clone, Debug)]
struct ExtendedPoint {
x: BigUint,
y: BigUint,
z: BigUint,
t: BigUint,
}
// Window widths for fixed-window scalar multiplication; a width of `w` means a table of
// `2^w` precomputed multiples.
/// Width used for one-off variable-base multiplications (16-entry table built per call).
const SCALAR_WINDOW_BITS: usize = 4;
/// Width of the process-wide Ed25519 base-point table (256 entries, built once).
const ED25519_BASE_WINDOW_BITS: usize = 8;
/// Width used for caller-cached tables built by `precompute_mul_table` (256 entries).
const CACHED_PUBLIC_WINDOW_BITS: usize = 8;
/// Precomputed multiples `[0·P, 1·P, …, (2^window_bits − 1)·P]` of a fixed point `P`,
/// reusable across calls to `TwistedEdwardsCurve::scalar_mul_cached`.
#[derive(Clone, Debug)]
pub(crate) struct EdwardsMulTable {
/// The multiples, indexed by window value.
table: Vec<ExtendedPoint>,
/// Window width the table was built for.
window_bits: usize,
}
impl EdwardsPoint {
    /// Returns the group identity, stored as `(0, 1)` with the `neutral` flag set.
    #[must_use]
    pub fn neutral() -> Self {
        let (x, y) = (BigUint::zero(), BigUint::one());
        Self {
            x,
            y,
            neutral: true,
        }
    }

    /// Wraps the affine coordinates `(x, y)` as a non-neutral point.
    ///
    /// No on-curve check is performed here.
    #[must_use]
    pub fn new(x: BigUint, y: BigUint) -> Self {
        Self {
            neutral: false,
            x,
            y,
        }
    }

    /// Reports whether the `neutral` flag is set.
    #[must_use]
    pub fn is_neutral(&self) -> bool {
        let Self { neutral, .. } = self;
        *neutral
    }
}
impl ExtendedPoint {
    /// The identity `(0 : 1 : 1 : 0)`.
    fn neutral() -> Self {
        Self {
            x: BigUint::zero(),
            t: BigUint::zero(),
            y: BigUint::one(),
            z: BigUint::one(),
        }
    }

    /// A projective point is the identity iff `X = 0` and `Y = Z`.
    fn is_neutral(&self) -> bool {
        self.y == self.z && self.x.is_zero()
    }

    /// Lifts an affine point to extended coordinates with `Z = 1` and `T = x·y`.
    fn from_affine(p: &EdwardsPoint, ctx: &MontgomeryCtx) -> Self {
        if p.neutral {
            Self::neutral()
        } else {
            Self {
                t: ctx.mul(&p.x, &p.y),
                x: p.x.clone(),
                y: p.y.clone(),
                z: BigUint::one(),
            }
        }
    }

    /// Projects back to affine `(X/Z, Y/Z)`, inverting `Z` via Fermat (`Z^(p-2)`).
    fn to_affine(&self, curve: &TwistedEdwardsCurve) -> EdwardsPoint {
        if self.is_neutral() {
            return EdwardsPoint::neutral();
        }
        // Already normalised: no inversion needed.
        if self.z == BigUint::one() {
            return EdwardsPoint::new(self.x.clone(), self.y.clone());
        }
        let field = &curve.field;
        let inverse_exponent = curve.p.sub_ref(&BigUint::from_u64(2));
        let z_inverse = field.pow(&self.z, &inverse_exponent);
        EdwardsPoint::new(field.mul(&self.x, &z_inverse), field.mul(&self.y, &z_inverse))
    }
}
/// Modular addition `(a + b) mod p` for inputs already reduced below `p`.
#[inline]
fn fadd(a: &BigUint, b: &BigUint, p: &BigUint) -> BigUint {
    let sum = a.add_ref(b);
    // A single conditional subtraction suffices because a, b < p implies a + b < 2p.
    if sum < *p {
        sum
    } else {
        sum.sub_ref(p)
    }
}
/// Modular subtraction `(a - b) mod p` for inputs already reduced below `p`.
#[inline]
fn fsub(a: &BigUint, b: &BigUint, p: &BigUint) -> BigUint {
    if a < b {
        // Wrap around: a - b ≡ p - (b - a).
        p.sub_ref(&b.sub_ref(a))
    } else {
        a.sub_ref(b)
    }
}
/// Modular negation `-a mod p`, mapping `0` to `0` rather than to `p`.
#[inline]
fn fneg(a: &BigUint, p: &BigUint) -> BigUint {
    if !a.is_zero() {
        return p.sub_ref(a);
    }
    BigUint::zero()
}
fn point_add_extended(
curve: &TwistedEdwardsCurve,
p1: &ExtendedPoint,
p2: &ExtendedPoint,
) -> ExtendedPoint {
let ctx = &curve.field;
let m = &curve.p;
let y1_m_x1 = fsub(&p1.y, &p1.x, m);
let y2_m_x2 = fsub(&p2.y, &p2.x, m);
let a = ctx.mul(&y1_m_x1, &y2_m_x2);
let y1_p_x1 = fadd(&p1.y, &p1.x, m);
let y2_p_x2 = fadd(&p2.y, &p2.x, m);
let b = ctx.mul(&y1_p_x1, &y2_p_x2);
let t2_scaled = ctx.mul(&p2.t, &curve.d2);
let c = ctx.mul(&p1.t, &t2_scaled);
let z1z2 = ctx.mul(&p1.z, &p2.z);
let d = fadd(&z1z2, &z1z2, m);
let e = fsub(&b, &a, m);
let f = fsub(&d, &c, m);
let g = fadd(&d, &c, m);
let h = fadd(&b, &a, m);
ExtendedPoint {
x: ctx.mul(&e, &f),
y: ctx.mul(&g, &h),
z: ctx.mul(&f, &g),
t: ctx.mul(&e, &h),
}
}
/// Doubles a point in extended coordinates using the general-`a` `dbl-2008-hwcd` formulas.
///
/// The identity is returned in canonical form `(0 : 1 : 1 : 0)` without computation.
fn point_double_extended(curve: &TwistedEdwardsCurve, p1: &ExtendedPoint) -> ExtendedPoint {
    if p1.is_neutral() {
        return ExtendedPoint::neutral();
    }
    let field = &curve.field;
    let modulus = &curve.p;

    let x_sq = field.square(&p1.x);
    let y_sq = field.square(&p1.y);
    let two_z_sq = {
        let z_sq = field.square(&p1.z);
        fadd(&z_sq, &z_sq, modulus)
    };
    let a_x_sq = field.mul(&curve.a, &x_sq);

    // E = (X + Y)² − X² − Y² = 2·X·Y
    let sum_sq = field.square(&fadd(&p1.x, &p1.y, modulus));
    let cross = fsub(&fsub(&sum_sq, &x_sq, modulus), &y_sq, modulus);

    let g = fadd(&a_x_sq, &y_sq, modulus);
    let f = fsub(&g, &two_z_sq, modulus);
    let h = fsub(&a_x_sq, &y_sq, modulus);

    ExtendedPoint {
        x: field.mul(&cross, &f),
        y: field.mul(&g, &h),
        z: field.mul(&f, &g),
        t: field.mul(&cross, &h),
    }
}
/// Extracts `width` bits of `k` starting at `bit_offset` (little-endian bit order) as an
/// integer in `[0, 2^width)`.
#[inline]
fn scalar_window(k: &BigUint, bit_offset: usize, width: usize) -> usize {
    (0..width)
        .filter(|&i| k.bit(bit_offset + i))
        .fold(0usize, |acc, i| acc | (1usize << i))
}
/// Builds the table `[0·P, 1·P, 2·P, …, (2^window_bits − 1)·P]` by repeated addition.
///
/// The table always contains at least the identity and `P` itself.
fn precompute_window_table(
    curve: &TwistedEdwardsCurve,
    point: &ExtendedPoint,
    window_bits: usize,
) -> Vec<ExtendedPoint> {
    let entries = 1usize << window_bits;
    let mut table = Vec::with_capacity(entries);
    table.extend([ExtendedPoint::neutral(), point.clone()]);
    while table.len() < entries {
        let previous = &table[table.len() - 1];
        let next = point_add_extended(curve, previous, point);
        table.push(next);
    }
    table
}
/// Fixed-window scalar multiplication `k·P` using a precomputed multiples table of `P`.
///
/// Windows are processed from most to least significant; each step doubles the accumulator
/// `window_bits` times and adds the table entry selected by the next window of `k`.
/// `table` must hold at least `2^window_bits` entries. Note that the table index depends on
/// `k`, so this is not a constant-time routine.
fn scalar_mul_with_table(
    curve: &TwistedEdwardsCurve,
    k: &BigUint,
    table: &[ExtendedPoint],
    window_bits: usize,
) -> EdwardsPoint {
    if k.is_zero() {
        return EdwardsPoint::neutral();
    }
    let window_count = k.bits().div_ceil(window_bits);
    let accumulated = (0..window_count)
        .rev()
        .fold(ExtendedPoint::neutral(), |acc, window_index| {
            let shifted =
                (0..window_bits).fold(acc, |q, _| point_double_extended(curve, &q));
            let digit = scalar_window(k, window_index * window_bits, window_bits);
            point_add_extended(curve, &shifted, &table[digit])
        });
    accumulated.to_affine(curve)
}
/// Process-wide Ed25519 curve instance, constructed lazily on first use.
fn cached_ed25519() -> &'static TwistedEdwardsCurve {
    static ED25519_CURVE: OnceLock<TwistedEdwardsCurve> = OnceLock::new();
    ED25519_CURVE.get_or_init(ed25519)
}
/// Reports whether `curve` carries exactly the Ed25519 parameters (p, a, d, n, G).
fn is_ed25519_curve(curve: &TwistedEdwardsCurve) -> bool {
    curve.same_curve(cached_ed25519())
}
/// Lazily-built 8-bit window table of Ed25519 base-point multiples, shared process-wide.
fn ed25519_base_table() -> &'static [ExtendedPoint] {
    static BASE_TABLE: OnceLock<Vec<ExtendedPoint>> = OnceLock::new();
    let table = BASE_TABLE.get_or_init(|| {
        let ed = cached_ed25519();
        let generator = ExtendedPoint::from_affine(&ed.base_point(), &ed.field);
        precompute_window_table(ed, &generator, ED25519_BASE_WINDOW_BITS)
    });
    &table[..]
}
/// Generic variable-base scalar multiplication `k·point` with a fresh 4-bit window table.
fn scalar_mul_extended(
    curve: &TwistedEdwardsCurve,
    point: &EdwardsPoint,
    k: &BigUint,
) -> EdwardsPoint {
    let trivially_neutral = point.is_neutral() || k.is_zero();
    if trivially_neutral {
        return EdwardsPoint::neutral();
    }
    let lifted = ExtendedPoint::from_affine(point, &curve.field);
    let multiples = precompute_window_table(curve, &lifted, SCALAR_WINDOW_BITS);
    scalar_mul_with_table(curve, k, &multiples, SCALAR_WINDOW_BITS)
}
impl TwistedEdwardsCurve {
    /// Builds the twisted Edwards curve `a·x² + y² = 1 + d·x²·y²` over `GF(p)` with base
    /// point `(gx, gy)` of order `n`.
    ///
    /// Returns `None` when:
    /// - a Montgomery context cannot be built for `p` or `n`;
    /// - any of `a`, `d`, `gx`, `gy` is not reduced modulo `p` (the field helpers assume
    ///   canonical inputs);
    /// - the curve is degenerate (`a == 0`, `d == 0` or `a == d`);
    /// - `n <= 1` (no non-zero scalar could be drawn by [`Self::random_scalar`]);
    /// - the base point does not satisfy the curve equation.
    #[must_use]
    pub fn new(
        p: BigUint,
        a: BigUint,
        d: BigUint,
        n: BigUint,
        gx: BigUint,
        gy: BigUint,
    ) -> Option<Self> {
        let field = MontgomeryCtx::new(&p)?;
        let scalar = MontgomeryCtx::new(&n)?;
        let reduced = a < p && d < p && gx < p && gy < p;
        let degenerate = a.is_zero() || d.is_zero() || a == d;
        if !reduced || degenerate || n <= BigUint::one() {
            return None;
        }
        // A compressed encoding stores `y` plus one extra bit carrying the parity of `x`
        // (RFC 8032 §5.1.2 / §5.2.2), so `bits(p) + 1` bits are required. This is 32 bytes
        // for Ed25519's 255-bit p and 57 bytes for a 448-bit p (Ed448); the previous
        // `bits(p).div_ceil(8)` let the sign bit collide with y's top bit whenever
        // `bits(p)` was a multiple of 8.
        let coord_len = (p.bits() + 1).div_ceil(8);
        // 2·d mod p, consumed by the a = -1 addition formulas.
        let d2 = fadd(&d, &d, &p);
        let curve = Self {
            p,
            a,
            d,
            d2,
            n,
            gx,
            gy,
            field,
            _scalar: scalar,
            coord_len,
        };
        if curve.is_on_curve(&curve.base_point()) {
            Some(curve)
        } else {
            None
        }
    }

    /// Returns the base point `G = (gx, gy)` as a non-neutral affine point.
    #[must_use]
    pub fn base_point(&self) -> EdwardsPoint {
        EdwardsPoint::new(self.gx.clone(), self.gy.clone())
    }

    /// Checks `a·x² + y² ≡ 1 + d·x²·y² (mod p)`; the neutral element is always accepted.
    ///
    /// Coordinates are assumed to be reduced modulo `p`.
    #[must_use]
    pub fn is_on_curve(&self, point: &EdwardsPoint) -> bool {
        if point.neutral {
            return true;
        }
        let ctx = &self.field;
        let x2 = ctx.square(&point.x);
        let y2 = ctx.square(&point.y);
        let ax2 = ctx.mul(&self.a, &x2);
        let lhs = fadd(&ax2, &y2, &self.p);
        let x2y2 = ctx.mul(&x2, &y2);
        let dx2y2 = ctx.mul(&self.d, &x2y2);
        let rhs = fadd(&BigUint::one(), &dx2y2, &self.p);
        lhs == rhs
    }

    /// Returns `-P = (-x, y)`; the neutral element is its own negation.
    #[must_use]
    pub fn negate(&self, point: &EdwardsPoint) -> EdwardsPoint {
        if point.neutral {
            return point.clone();
        }
        EdwardsPoint::new(fneg(&point.x, &self.p), point.y.clone())
    }

    /// Returns `p + q` in affine form.
    #[must_use]
    pub fn add(&self, p: &EdwardsPoint, q: &EdwardsPoint) -> EdwardsPoint {
        let pe = ExtendedPoint::from_affine(p, &self.field);
        let qe = ExtendedPoint::from_affine(q, &self.field);
        point_add_extended(self, &pe, &qe).to_affine(self)
    }

    /// Returns `2·p` in affine form.
    #[must_use]
    pub fn double(&self, p: &EdwardsPoint) -> EdwardsPoint {
        let pe = ExtendedPoint::from_affine(p, &self.field);
        point_double_extended(self, &pe).to_affine(self)
    }

    /// Returns `k·point`, routing multiplications of the base point to
    /// [`Self::scalar_mul_base`] (which may use a cached table).
    #[must_use]
    pub fn scalar_mul(&self, point: &EdwardsPoint, k: &BigUint) -> EdwardsPoint {
        if !point.neutral && point.x == self.gx && point.y == self.gy {
            return self.scalar_mul_base(k);
        }
        scalar_mul_extended(self, point, k)
    }

    /// Returns `k·G`; on Ed25519 this uses a process-wide 8-bit window table.
    #[must_use]
    pub fn scalar_mul_base(&self, k: &BigUint) -> EdwardsPoint {
        if k.is_zero() {
            return EdwardsPoint::neutral();
        }
        if is_ed25519_curve(self) {
            return scalar_mul_with_table(self, k, ed25519_base_table(), ED25519_BASE_WINDOW_BITS);
        }
        scalar_mul_extended(self, &self.base_point(), k)
    }

    /// Computes the shared point `private_scalar · public_point`.
    ///
    /// No on-curve, subgroup or cofactor validation is performed on `public_point`; callers
    /// handling untrusted points should check them (e.g. with [`Self::is_on_curve`]) first.
    #[must_use]
    pub fn diffie_hellman(
        &self,
        private_scalar: &BigUint,
        public_point: &EdwardsPoint,
    ) -> EdwardsPoint {
        self.scalar_mul(public_point, private_scalar)
    }

    /// Draws a non-zero scalar below `n` via `random_nonzero_below`.
    ///
    /// # Panics
    /// Panics if `random_nonzero_below` returns `None`; [`Self::new`] rejects `n <= 1`, which
    /// is presumably its only failure mode (verify against that helper).
    pub fn random_scalar<R: Csprng>(&self, rng: &mut R) -> BigUint {
        random_nonzero_below(rng, &self.n)
            .expect("curve order n is always > 1 for any valid cryptographic curve")
    }

    /// Generates `(d, d·G)` with `d` drawn by [`Self::random_scalar`].
    pub fn generate_keypair<R: Csprng>(&self, rng: &mut R) -> (BigUint, EdwardsPoint) {
        let d = self.random_scalar(rng);
        let q = self.scalar_mul_base(&d);
        (d, q)
    }

    /// Precomputes an 8-bit window table for `point`, for repeated use with
    /// [`Self::scalar_mul_cached`].
    #[must_use]
    pub(crate) fn precompute_mul_table(&self, point: &EdwardsPoint) -> EdwardsMulTable {
        let point_ext = ExtendedPoint::from_affine(point, &self.field);
        EdwardsMulTable {
            table: precompute_window_table(self, &point_ext, CACHED_PUBLIC_WINDOW_BITS),
            window_bits: CACHED_PUBLIC_WINDOW_BITS,
        }
    }

    /// Computes `k·P` using a table built by [`Self::precompute_mul_table`].
    #[must_use]
    pub(crate) fn scalar_mul_cached(&self, table: &EdwardsMulTable, k: &BigUint) -> EdwardsPoint {
        scalar_mul_with_table(self, k, &table.table, table.window_bits)
    }

    /// Reports whether both curves share `p`, `a`, `d`, `n` and the base point.
    #[must_use]
    pub fn same_curve(&self, other: &Self) -> bool {
        self.p == other.p
            && self.a == other.a
            && self.d == other.d
            && self.n == other.n
            && self.gx == other.gx
            && self.gy == other.gy
    }

    /// Returns `k⁻¹ mod n`, or `None` when `k` has no inverse modulo `n`.
    #[must_use]
    pub fn scalar_invert(&self, k: &BigUint) -> Option<BigUint> {
        mod_inverse(k, &self.n)
    }

    /// Encodes a point as `coord_len` bytes: `y` in little-endian order, with the parity
    /// of `x` stored in the most significant bit of the last byte (RFC 8032 style).
    /// The neutral element encodes as `(0, 1)`.
    #[must_use]
    pub fn encode_point(&self, point: &EdwardsPoint) -> Vec<u8> {
        let neutral_x = BigUint::zero();
        let neutral_y = BigUint::one();
        let (x, y) = if point.neutral {
            (&neutral_x, &neutral_y)
        } else {
            (&point.x, &point.y)
        };
        let mut out = pad_to(y.to_be_bytes(), self.coord_len);
        out.reverse();
        if x.is_odd() {
            *out.last_mut().expect("coord_len > 0") |= 0x80;
        }
        out
    }

    /// Decodes an encoding produced by [`Self::encode_point`].
    ///
    /// Rejects wrong lengths, non-canonical `y >= p`, values of `y` with no matching `x`,
    /// `x = 0` with the sign bit set, and anything not on the curve. `(0, 1)` decodes to the
    /// neutral element.
    #[must_use]
    pub fn decode_point(&self, bytes: &[u8]) -> Option<EdwardsPoint> {
        if bytes.len() != self.coord_len {
            return None;
        }
        let (&last, _) = bytes.split_last()?;
        let x_odd = (last & 0x80) != 0;
        // Little-endian input → big-endian digits; the sign bit lives in the first BE byte.
        let mut y_be: Vec<u8> = bytes.iter().rev().copied().collect();
        y_be[0] &= 0x7f;
        let y = BigUint::from_be_bytes(&y_be);
        if y >= self.p {
            return None;
        }
        let x = self.field_recover_x(&y, x_odd)?;
        if x.is_zero() && x_odd {
            return None;
        }
        let pt = if x.is_zero() && y == BigUint::one() {
            EdwardsPoint::neutral()
        } else {
            EdwardsPoint::new(x, y)
        };
        if self.is_on_curve(&pt) {
            Some(pt)
        } else {
            None
        }
    }

    /// Solves the curve equation for `x` given `y`: `x² = (y² − 1) / (d·y² − a)`, then picks
    /// the root whose parity matches `x_odd`. Returns `None` when `x²` has no square root.
    fn field_recover_x(&self, y: &BigUint, x_odd: bool) -> Option<BigUint> {
        let ctx = &self.field;
        let y2 = ctx.square(y);
        let numerator = fsub(&y2, &BigUint::one(), &self.p);
        let dy2 = ctx.mul(&self.d, &y2);
        let denominator = fsub(&dy2, &self.a, &self.p);
        let p_minus_2 = self.p.sub_ref(&BigUint::from_u64(2));
        let denom_inv = ctx.pow(&denominator, &p_minus_2);
        let x_squared = ctx.mul(&numerator, &denom_inv);
        let x_candidate = self.field_sqrt(&x_squared)?;
        let x = if x_candidate.is_odd() == x_odd {
            x_candidate
        } else {
            fneg(&x_candidate, &self.p)
        };
        Some(x)
    }

    /// Square root modulo `p`, or `None` when `u` is a non-residue.
    ///
    /// Uses the direct exponentiation for `p ≡ 3 (mod 4)`, the RFC 8032 §5.1.3-style method
    /// for `p ≡ 5 (mod 8)`, and Tonelli–Shanks for `p ≡ 1 (mod 8)`.
    fn field_sqrt(&self, u: &BigUint) -> Option<BigUint> {
        let ctx = &self.field;
        match self.p.rem_u64(8) {
            3 | 7 => {
                let (exp, _) = self
                    .p
                    .add_ref(&BigUint::one())
                    .div_rem(&BigUint::from_u64(4));
                let candidate = ctx.pow(u, &exp);
                if ctx.square(&candidate) == *u {
                    Some(candidate)
                } else {
                    None
                }
            }
            5 => {
                // β = u^((p+3)/8) satisfies β² = ±u whenever u is a square.
                let (exp, _) = self
                    .p
                    .add_ref(&BigUint::from_u64(3))
                    .div_rem(&BigUint::from_u64(8));
                let beta = ctx.pow(u, &exp);
                let beta2 = ctx.square(&beta);
                if beta2 == *u {
                    return Some(beta);
                }
                if beta2 == fneg(u, &self.p) {
                    // 2 is a non-residue for p ≡ 5 (mod 8), so 2^((p-1)/4) is a root of -1.
                    let (sqrt_m1_exp, _) = self
                        .p
                        .sub_ref(&BigUint::one())
                        .div_rem(&BigUint::from_u64(4));
                    let sqrt_m1 = ctx.pow(&BigUint::from_u64(2), &sqrt_m1_exp);
                    return Some(ctx.mul(&beta, &sqrt_m1));
                }
                None
            }
            1 => self.field_sqrt_tonelli_shanks(u),
            _ => None,
        }
    }

    /// Tonelli–Shanks square root, used for `p ≡ 1 (mod 8)` where no closed form applies.
    ///
    /// Returns `None` for non-residues (and, defensively, if no non-residue `z < p` can be
    /// found, which cannot happen for a prime `p`).
    fn field_sqrt_tonelli_shanks(&self, u: &BigUint) -> Option<BigUint> {
        let ctx = &self.field;
        if u.is_zero() {
            return Some(BigUint::zero());
        }
        let one = BigUint::one();
        let two = BigUint::from_u64(2);
        let p_minus_1 = self.p.sub_ref(&one);
        let (half, _) = p_minus_1.div_rem(&two);
        // Euler's criterion: u is a square iff u^((p-1)/2) == 1.
        if ctx.pow(u, &half) != one {
            return None;
        }
        // Write p - 1 = q · 2^s with q odd.
        let mut q = p_minus_1.clone();
        let mut s = 0usize;
        while !q.is_odd() {
            q = q.div_rem(&two).0;
            s += 1;
        }
        // Find a non-residue z, i.e. z^((p-1)/2) == -1.
        let mut z = two.clone();
        while ctx.pow(&z, &half) != p_minus_1 {
            z = z.add_ref(&one);
            if z >= self.p {
                return None;
            }
        }
        let mut m = s;
        let mut c = ctx.pow(&z, &q);
        let mut t = ctx.pow(u, &q);
        let (r_exp, _) = q.add_ref(&one).div_rem(&two);
        let mut r = ctx.pow(u, &r_exp);
        while t != one {
            // Least i in (0, m) with t^(2^i) == 1.
            let mut i = 0usize;
            let mut t_pow = t.clone();
            while t_pow != one {
                t_pow = ctx.square(&t_pow);
                i += 1;
                if i >= m {
                    return None;
                }
            }
            let mut b = c.clone();
            for _ in 0..(m - i - 1) {
                b = ctx.square(&b);
            }
            m = i;
            c = ctx.square(&b);
            t = ctx.mul(&t, &c);
            r = ctx.mul(&r, &b);
        }
        if ctx.square(&r) == *u {
            Some(r)
        } else {
            None
        }
    }
}
/// Left-pads a big-endian byte string with zeros up to `len` bytes.
///
/// Inputs that are already `len` bytes are returned unchanged. Longer inputs trip a debug
/// assertion and are otherwise returned as-is.
fn pad_to(bytes: Vec<u8>, len: usize) -> Vec<u8> {
    let Some(missing) = len.checked_sub(bytes.len()).filter(|&m| m > 0) else {
        debug_assert_eq!(
            bytes.len(),
            len,
            "field encodings must fit exactly in coord_len bytes"
        );
        return bytes;
    };
    let mut padded = Vec::with_capacity(len);
    padded.resize(missing, 0u8);
    padded.extend(bytes);
    padded
}
/// Parses a big-endian hex string (ASCII whitespace ignored) into a `BigUint`.
///
/// # Panics
/// Panics on an odd number of hex digits or on a non-hex character; intended only for the
/// hard-coded curve constants in this file.
fn from_hex(hex: &str) -> BigUint {
    let mut digits = String::with_capacity(hex.len());
    digits.extend(hex.chars().filter(|ch| !ch.is_ascii_whitespace()));
    assert!(
        digits.len().is_multiple_of(2),
        "hex string must have even length"
    );
    let mut raw = Vec::with_capacity(digits.len() / 2);
    for start in (0..digits.len()).step_by(2) {
        let pair = &digits[start..start + 2];
        raw.push(u8::from_str_radix(pair, 16).expect("valid hex digit"));
    }
    BigUint::from_be_bytes(&raw)
}
/// Returns the Ed25519 curve `-x² + y² = 1 + d·x²·y²` over `GF(2^255 - 19)` with its
/// standard base point and prime-order subgroup size `n`.
///
/// # Panics
/// Panics only if the hard-coded parameters are rejected by [`TwistedEdwardsCurve::new`].
#[must_use]
pub fn ed25519() -> TwistedEdwardsCurve {
    // Field prime p = 2^255 - 19.
    const P: &str = "7FFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFED";
    const D: &str = "52036CEE 2B6FFE73 8CC74079 7779E898 00700A4D 4141D8AB 75EB4DCA 135978A3";
    const N: &str = "10000000 00000000 00000000 00000000 14DEF9DE A2F79CD6 5812631A 5CF5D3ED";
    const GX: &str = "216936D3 CD6E53FE C0A4E231 FDD6DC5C 692CC760 9525A7B2 C9562D60 8F25D51A";
    const GY: &str = "66666666 66666666 66666666 66666666 66666666 66666666 66666666 66666658";

    let p = from_hex(P);
    // a = -1, stored canonically as p - 1.
    let minus_one = p.sub_ref(&BigUint::one());
    TwistedEdwardsCurve::new(
        p,
        minus_one,
        from_hex(D),
        from_hex(N),
        from_hex(GX),
        from_hex(GY),
    )
    .expect("Ed25519 parameters are well-formed")
}
// Unit tests for the Ed25519 instantiation: group law, scalar multiplication paths,
// compressed encoding/decoding, ECDH agreement and scalar inversion.
#[cfg(test)]
mod tests {
use super::*;
// Decodes a lowercase/uppercase hex string into bytes; panics on non-hex input and
// silently ignores a trailing odd nibble (chunks_exact).
fn decode_hex(hex: &str) -> Vec<u8> {
let bytes = hex.as_bytes();
let mut out = Vec::with_capacity(bytes.len() / 2);
for chunk in bytes.chunks_exact(2) {
let hi = (chunk[0] as char).to_digit(16).expect("hex") as u8;
let lo = (chunk[1] as char).to_digit(16).expect("hex") as u8;
out.push((hi << 4) | lo);
}
out
}
// The hard-coded base point must satisfy the curve equation.
#[test]
fn ed25519_base_point_on_curve() {
let curve = ed25519();
let g = curve.base_point();
assert!(
curve.is_on_curve(&g),
"Ed25519 base point G must satisfy −x² + y² = 1 + d·x²·y²"
);
}
// Dedicated doubling must agree with the unified addition formula applied to G + G.
#[test]
fn ed25519_double_equals_add_self() {
let curve = ed25519();
let g = curve.base_point();
let via_double = curve.double(&g);
let via_add = curve.add(&g, &g);
assert_eq!(via_double, via_add, "2G via double must equal G+G via add");
assert!(curve.is_on_curve(&via_double), "2G must lie on Ed25519");
}
// Windowed scalar multiplication cross-checked against explicit doubling/addition.
#[test]
fn ed25519_scalar_mul_matches_repeated_add() {
let curve = ed25519();
let g = curve.base_point();
let four_g_scalar = curve.scalar_mul(&g, &BigUint::from_u64(4));
let two_g = curve.double(&g);
let four_g_add = curve.add(&two_g, &two_g);
assert_eq!(
four_g_scalar, four_g_add,
"4G via scalar_mul must equal 2G+2G"
);
}
// The cached 8-bit base table path must match the generic 4-bit window path.
#[test]
fn ed25519_scalar_mul_base_matches_generic_base_path() {
let curve = ed25519();
let g = curve.base_point();
let scalar = BigUint::from_u64(77);
let via_base = curve.scalar_mul_base(&scalar);
let via_generic = scalar_mul_extended(&curve, &g, &scalar);
assert_eq!(
via_base, via_generic,
"fixed-base path must match generic path"
);
}
// n is the order of G, so n·G must be the identity.
#[test]
fn ed25519_order_times_base_point_is_neutral() {
let curve = ed25519();
let g = curve.base_point();
let n = curve.n.clone();
let result = curve.scalar_mul(&g, &n);
assert!(
result.is_neutral(),
"n·G must be the neutral element for Ed25519"
);
}
#[test]
fn ed25519_negation_sums_to_neutral() {
let curve = ed25519();
let g = curve.base_point();
let neg_g = curve.negate(&g);
let sum = curve.add(&g, &neg_g);
assert!(sum.is_neutral(), "G + (−G) must be the neutral element");
}
#[test]
fn ed25519_encode_decode_roundtrip() {
let curve = ed25519();
let g = curve.base_point();
let encoded = curve.encode_point(&g);
assert_eq!(encoded.len(), 32, "Ed25519 encoding must be 32 bytes");
let decoded = curve
.decode_point(&encoded)
.expect("decode must succeed for the standard base point");
assert_eq!(decoded, g, "encode/decode must be the identity");
}
#[test]
fn ed25519_encode_decode_2g_roundtrip() {
let curve = ed25519();
let two_g = curve.double(&curve.base_point());
let encoded = curve.encode_point(&two_g);
let decoded = curve.decode_point(&encoded).expect("decode 2G");
assert_eq!(decoded, two_g);
}
// Fixture table: (k, hex of the expected 32-byte compressed encoding of k·G).
// Each entry is checked in both directions (encode and decode).
#[test]
fn ed25519_known_basepoint_multiples_match_fixture_encodings() {
let curve = ed25519();
let g = curve.base_point();
let fixtures = [
(
1_u64,
"5866666666666666666666666666666666666666666666666666666666666666",
),
(
2_u64,
"c9a3f86aae465f0e56513864510f3997561fa2c9e85ea21dc2292309f3cd6022",
),
(
3_u64,
"d4b4f5784868c3020403246717ec169ff79e26608ea126a1ab69ee77d1b16712",
),
(
4_u64,
"2f1132ca61ab38dff00f2fea3228f24c6c71d58085b80e47e19515cb27e8d047",
),
(
5_u64,
"edc876d6831fd2105d0b4389ca2e283166469289146e2ce06faefe98b22548df",
),
(
7_u64,
"b862409fb5c4c4123df2abf7462b88f041ad36dd6864ce872fd5472be363c5b1",
),
(
11_u64,
"1337036ac32d8f30d4589c3c1c595812ce0fff40e37c6f5a97ab213f318290ad",
),
(
77_u64,
"aa6df914f7a0f04e7f852adf459873f17dba5b1671ea62e82cc10ed6aecc489c",
),
(
82_u64,
"b03ed935d1de5bba7f51574b9fd88239083116ff867ee8562ae990c487579623",
),
];
for (scalar, encoding_hex) in fixtures {
let point = curve.scalar_mul(&g, &BigUint::from_u64(scalar));
let encoding = curve.encode_point(&point);
assert_eq!(
encoding,
decode_hex(encoding_hex),
"{scalar}G encoding mismatch"
);
let decoded = curve
.decode_point(&encoding)
.expect("decode known multiple");
assert_eq!(decoded, point, "{scalar}G decode mismatch");
}
}
// The identity (0, 1) encodes as little-endian 1 with a clear sign bit.
#[test]
fn ed25519_neutral_encodes_correctly() {
let curve = ed25519();
let neutral = EdwardsPoint::neutral();
let enc = curve.encode_point(&neutral);
assert_eq!(
enc[0], 0x01,
"first byte of neutral encoding must be 1 (LE)"
);
assert!(
enc[1..].iter().all(|&b| b == 0),
"remaining bytes of neutral must be 0"
);
}
#[test]
fn ed25519_neutral_roundtrip_preserves_identity() {
let curve = ed25519();
let neutral = EdwardsPoint::neutral();
let enc = curve.encode_point(&neutral);
let dec = curve.decode_point(&enc).expect("decode neutral");
assert!(
dec.is_neutral(),
"decode_point must preserve the neutral element"
);
}
// Negative decoding cases: wrong length, x = 0 with sign bit, and non-canonical y.
#[test]
fn ed25519_decode_rejects_bad_length() {
let curve = ed25519();
let g = curve.base_point();
let mut enc = curve.encode_point(&g);
enc.pop();
assert!(
curve.decode_point(&enc).is_none(),
"truncated encoding must be rejected"
);
}
#[test]
fn ed25519_decode_rejects_neutral_with_sign_bit_set() {
let curve = ed25519();
let mut enc = curve.encode_point(&EdwardsPoint::neutral());
*enc.last_mut().expect("32-byte encoding") |= 0x80;
assert!(
curve.decode_point(&enc).is_none(),
"RFC 8032 forbids x = 0 with the sign bit set"
);
}
#[test]
fn ed25519_decode_rejects_non_canonical_y() {
let curve = ed25519();
// Encodes y = p itself (little-endian), which is out of range.
let y_be = pad_to(curve.p.to_be_bytes(), curve.coord_len);
let enc: Vec<u8> = y_be.into_iter().rev().collect();
assert!(
curve.decode_point(&enc).is_none(),
"compressed encodings must reject y >= p"
);
}
// Two key pairs from a deterministic DRBG must derive the same shared point.
#[test]
fn ed25519_ecdh_shared_secret_agrees() {
use crate::CtrDrbgAes256;
let curve = ed25519();
let mut rng = CtrDrbgAes256::new(&[0xcd; 48]);
let (d_a, q_a) = curve.generate_keypair(&mut rng);
let (d_b, q_b) = curve.generate_keypair(&mut rng);
let shared_a = curve.diffie_hellman(&d_a, &q_b);
let shared_b = curve.diffie_hellman(&d_b, &q_a);
assert_eq!(shared_a, shared_b, "ECDH shared points must agree");
assert!(
!shared_a.is_neutral(),
"ECDH shared point must not be neutral"
);
assert!(
curve.is_on_curve(&shared_a),
"ECDH shared point must lie on Ed25519"
);
}
#[test]
fn ed25519_scalar_invert_roundtrip() {
let curve = ed25519();
let k = BigUint::from_u64(0x1234_5678_abcd_ef01);
let k_inv = curve.scalar_invert(&k).expect("k is non-zero");
let product = BigUint::mod_mul(&k, &k_inv, &curve.n);
assert_eq!(product, BigUint::one(), "k * k⁻¹ must equal 1 mod n");
}
}