use k256::elliptic_curve::group::{Group, GroupEncoding};
use k256::elliptic_curve::sec1::{FromEncodedPoint, ToEncodedPoint};
use k256::{AffinePoint, ProjectivePoint, Scalar};
use crate::prelude::{Bigint, ByteString, Point};
/// Decodes a 64-byte `X || Y` point into a `ProjectivePoint`.
///
/// The all-zero buffer is the in-band encoding of the point at infinity.
/// Any other input is prefixed with the SEC1 "uncompressed" tag `0x04`
/// and decoded through `k256`.
///
/// # Panics
///
/// Panics if `p` is not exactly 64 bytes long, if the SEC1 parser rejects
/// the encoding, or if the coordinates do not describe a point on the curve.
fn point_to_projective(p: &[u8]) -> ProjectivePoint {
    assert_eq!(p.len(), 64, "Point must be exactly 64 bytes");

    let is_all_zero = p.iter().all(|&b| b == 0);
    if is_all_zero {
        return ProjectivePoint::IDENTITY;
    }

    // SEC1 uncompressed layout: one tag byte followed by the 64 coordinate bytes.
    let mut tagged = [0u8; 65];
    tagged[0] = 0x04;
    tagged[1..].copy_from_slice(p);

    let encoded = k256::EncodedPoint::from_bytes(&tagged[..]).expect("invalid SEC1 encoding");
    let affine = AffinePoint::from_encoded_point(&encoded).expect("point not on curve");
    ProjectivePoint::from(affine)
}
/// Encodes a `ProjectivePoint` as a 64-byte `X || Y` buffer.
///
/// The point at infinity is written as 64 zero bytes. Every other point is
/// serialized in SEC1 uncompressed form and the leading tag byte is dropped.
fn projective_to_point(p: &ProjectivePoint) -> Point {
    let at_infinity: bool = p.is_identity().into();
    if at_infinity {
        return vec![0u8; 64];
    }

    // Uncompressed SEC1 is `tag || X || Y`; bytes 1..65 are the coordinates.
    let uncompressed = p.to_affine().to_encoded_point(false);
    uncompressed.as_bytes()[1..65].to_vec()
}
/// Converts a signed integer into a curve `Scalar`.
///
/// Non-negative values map directly. A negative `k` maps to the additive
/// inverse of `|k|`, computed as `Scalar::ZERO - |k|`.
///
/// The magnitude is taken with `unsigned_abs`, so `i64::MIN` is handled
/// without the overflow that negating it as a signed value would cause.
fn i64_to_scalar(k: Bigint) -> Scalar {
    let magnitude = Scalar::from(k.unsigned_abs());
    if k >= 0 {
        magnitude
    } else {
        Scalar::ZERO - magnitude
    }
}
/// Adds two points given in the 64-byte `X || Y` encoding.
///
/// Returns the sum in the same encoding (all zeros for the point at infinity).
///
/// # Panics
///
/// Panics if either input fails to decode (wrong length or not on the curve).
pub fn ec_add(a: &[u8], b: &[u8]) -> Point {
    let lhs = point_to_projective(a);
    let rhs = point_to_projective(b);
    let sum = lhs + rhs;
    projective_to_point(&sum)
}
/// Multiplies the 64-byte encoded point `p` by the signed integer `k`.
///
/// Returns the product in the same 64-byte encoding.
///
/// # Panics
///
/// Panics if `p` fails to decode (wrong length or not on the curve).
pub fn ec_mul(p: &[u8], k: Bigint) -> Point {
    let base = point_to_projective(p);
    let factor = i64_to_scalar(k);
    let product = base * factor;
    projective_to_point(&product)
}
/// Multiplies the curve generator by the signed integer `k`.
///
/// Returns the product in the 64-byte `X || Y` encoding.
pub fn ec_mul_gen(k: Bigint) -> Point {
    let product = ProjectivePoint::GENERATOR * i64_to_scalar(k);
    projective_to_point(&product)
}
/// Returns the additive inverse of the 64-byte encoded point `p`.
///
/// # Panics
///
/// Panics if `p` fails to decode (wrong length or not on the curve).
pub fn ec_negate(p: &[u8]) -> Point {
    let negated = -point_to_projective(p);
    projective_to_point(&negated)
}
/// Reports whether `p` is a valid 64-byte `X || Y` point encoding.
///
/// Unlike the decoding helpers this never panics: a wrong length or a
/// rejected encoding yields `false`. The all-zero buffer (the point at
/// infinity in this encoding) is accepted.
pub fn ec_on_curve(p: &[u8]) -> bool {
    if p.len() != 64 {
        return false;
    }
    if p.iter().all(|&b| b == 0) {
        return true;
    }

    // Prefix the SEC1 "uncompressed" tag so k256 can parse the coordinates.
    let mut tagged = [0u8; 65];
    tagged[0] = 0x04;
    tagged[1..].copy_from_slice(p);

    match k256::EncodedPoint::from_bytes(&tagged[..]) {
        Ok(encoded) => bool::from(AffinePoint::from_encoded_point(&encoded).is_some()),
        Err(_) => false,
    }
}
/// Reduces `value` modulo `m` to the non-negative residue in `0..|m|`.
///
/// Euclidean remainder is used so the result is non-negative for a negative
/// modulus as well. `i64::MIN` reduced by `-1` yields `0` instead of
/// overflowing.
///
/// # Panics
///
/// Panics if `m` is zero.
pub fn ec_mod_reduce(value: Bigint, m: Bigint) -> Bigint {
    // `wrapping_rem_euclid` only "wraps" for MIN % -1, where the
    // mathematically correct remainder (0) is returned.
    value.wrapping_rem_euclid(m)
}
/// Serializes the 64-byte encoded point `p` using the affine point's
/// `GroupEncoding::to_bytes` representation (compressed form).
///
/// # Panics
///
/// Panics if `p` fails to decode (wrong length or not on the curve).
pub fn ec_encode_compressed(p: &[u8]) -> ByteString {
    point_to_projective(p).to_affine().to_bytes().to_vec()
}
/// Packs two integer coordinates into the 64-byte `X || Y` layout.
///
/// Each coordinate occupies a 32-byte big-endian field: the upper 24 bytes
/// are zero and the lower 8 bytes hold the value reinterpreted as `u64`
/// (so negative inputs are stored as their two's-complement bit pattern).
/// No check is made that the result lies on the curve.
pub fn ec_make_point(x: Bigint, y: Bigint) -> Point {
    let mut encoded = Vec::with_capacity(64);
    for coord in [x, y] {
        // 24 bytes of zero padding, then the 8 significant bytes.
        encoded.extend_from_slice(&[0u8; 24]);
        encoded.extend_from_slice(&(coord as u64).to_be_bytes());
    }
    encoded
}
/// Reads the X coordinate of a 64-byte point as an integer.
///
/// Only bytes 24..32 (the low 8 bytes of the 32-byte big-endian X field)
/// are read; the upper 24 bytes are ignored. The 64-bit pattern is
/// reinterpreted as a signed value.
///
/// # Panics
///
/// Panics if `p` is not exactly 64 bytes long.
pub fn ec_point_x(p: &[u8]) -> Bigint {
    assert_eq!(p.len(), 64, "Point must be exactly 64 bytes");
    // Accumulate the big-endian bytes most-significant first.
    p[24..32]
        .iter()
        .fold(0u64, |acc, &byte| (acc << 8) | u64::from(byte)) as i64
}
/// Reads the Y coordinate of a 64-byte point as an integer.
///
/// Only bytes 56..64 (the low 8 bytes of the 32-byte big-endian Y field)
/// are read; the upper 24 bytes are ignored. The 64-bit pattern is
/// reinterpreted as a signed value.
///
/// # Panics
///
/// Panics if `p` is not exactly 64 bytes long.
pub fn ec_point_y(p: &[u8]) -> Bigint {
    assert_eq!(p.len(), 64, "Point must be exactly 64 bytes");
    // Accumulate the big-endian bytes most-significant first.
    p[56..64]
        .iter()
        .fold(0u64, |acc, &byte| (acc << 8) | u64::from(byte)) as i64
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The curve generator, obtained as `1 * G`.
    fn ec_g() -> Point {
        ec_mul_gen(1)
    }

    #[test]
    fn ec_g_is_64_bytes() {
        let generator = ec_g();
        assert_eq!(generator.len(), 64);
    }

    #[test]
    fn ec_g_is_on_curve() {
        let generator = ec_g();
        assert!(ec_on_curve(&generator));
    }

    // G + G must agree with scalar multiplication of G by 2.
    #[test]
    fn ec_add_g_g_equals_ec_mul_g_2() {
        let generator = ec_g();
        assert_eq!(ec_add(&generator, &generator), ec_mul(&generator, 2));
    }

    // G + G must agree with the generator-specific multiplication by 2.
    #[test]
    fn ec_add_g_g_equals_ec_mul_gen_2() {
        let generator = ec_g();
        assert_eq!(ec_add(&generator, &generator), ec_mul_gen(2));
    }

    #[test]
    fn ec_mul_gen_1_equals_g() {
        let expected = ec_g();
        assert_eq!(ec_mul_gen(1), expected);
    }

    #[test]
    fn ec_negate_produces_on_curve_point() {
        let generator = ec_g();
        let negated = ec_negate(&generator);
        assert_eq!(negated.len(), 64);
        assert!(ec_on_curve(&negated));
        assert_ne!(negated, generator);
    }

    // Negating twice returns the original point.
    #[test]
    fn ec_negate_double_negate_is_identity() {
        let generator = ec_g();
        let round_trip = ec_negate(&ec_negate(&generator));
        assert_eq!(round_trip, generator);
    }

    // P + (-P) is the point at infinity, encoded as 64 zero bytes.
    #[test]
    fn ec_add_point_and_negation_is_identity() {
        let generator = ec_g();
        let inverse = ec_negate(&generator);
        let identity: Point = vec![0u8; 64];
        assert_eq!(ec_add(&generator, &inverse), identity);
    }

    #[test]
    fn ec_make_point_round_trip() {
        let x: Bigint = 12345;
        let y: Bigint = 67890;
        let packed = ec_make_point(x, y);
        assert_eq!(packed.len(), 64);
        assert_eq!(ec_point_x(&packed), x);
        assert_eq!(ec_point_y(&packed), y);
    }

    #[test]
    fn ec_encode_compressed_produces_33_bytes() {
        let encoded = ec_encode_compressed(&ec_g());
        assert_eq!(encoded.len(), 33);
        // The leading byte is the SEC1 compressed tag.
        assert!(encoded[0] == 0x02 || encoded[0] == 0x03);
    }

    #[test]
    fn ec_on_curve_rejects_invalid_point() {
        let all_ones = [0xffu8; 64];
        assert!(!ec_on_curve(&all_ones));
    }

    #[test]
    fn ec_on_curve_rejects_wrong_length() {
        let too_short = [0u8; 32];
        assert!(!ec_on_curve(&too_short));
    }

    #[test]
    fn ec_on_curve_accepts_identity() {
        let identity = [0u8; 64];
        assert!(ec_on_curve(&identity));
    }

    #[test]
    fn ec_mod_reduce_basic() {
        assert_eq!(ec_mod_reduce(10, 3), 1);
        assert_eq!(ec_mod_reduce(-1, 5), 4);
        assert_eq!(ec_mod_reduce(0, 7), 0);
    }

    // (3 * G) * 2 must equal 6 * G.
    #[test]
    fn ec_mul_associative() {
        let tripled = ec_mul(&ec_g(), 3);
        let tripled_then_doubled = ec_mul(&tripled, 2);
        assert_eq!(tripled_then_doubled, ec_mul_gen(6));
    }
}