use crate::primitives::base_point::BasePoint;
use crate::primitives::big_number::{BigNumber, Endian};
use crate::primitives::curve::Curve;
use crate::primitives::drbg::Drbg;
use crate::primitives::error::PrimitivesError;
use crate::primitives::jacobian_point::JacobianPoint;
use crate::primitives::point::Point;
use crate::primitives::signature::Signature;
fn truncate_to_n(msg: &BigNumber, trunc_only: bool) -> BigNumber {
let curve = Curve::secp256k1();
let n_bit_length = curve.n.bit_length();
let delta = (msg.byte_length() * 8).saturating_sub(n_bit_length);
let mut result = msg.clone();
if delta > 0 {
result = result.ushrn(delta);
}
if !trunc_only && result.cmp(&curve.n) >= 0 {
result = result.sub(&curve.n);
}
result
}
pub fn ecdsa_sign(
message_hash: &[u8; 32],
private_key: &BigNumber,
force_low_s: bool,
) -> Result<Signature, PrimitivesError> {
let curve = Curve::secp256k1();
let n = &curve.n;
let n_byte_len = n.byte_length();
let msg_bn = BigNumber::from_bytes(message_hash, Endian::Big);
let msg = truncate_to_n(&msg_bn, false);
let key_bytes = private_key.to_array(Endian::Big, Some(n_byte_len));
let nonce_bytes = msg.to_array(Endian::Big, Some(n_byte_len));
let mut entropy = [0u8; 32];
let mut nonce = [0u8; 32];
entropy.copy_from_slice(&key_bytes[..32]);
nonce.copy_from_slice(&nonce_bytes[..32]);
let mut drbg = Drbg::new(&entropy, &nonce);
let ns1 = n.subn(1);
let base_point = BasePoint::instance();
loop {
let k_bytes = drbg.generate(n_byte_len);
let k_hex: String = k_bytes.iter().map(|b| format!("{:02x}", b)).collect();
let k_bn = BigNumber::from_hex(&k_hex)
.map_err(|_| PrimitivesError::ArithmeticError("invalid k hex".to_string()))?;
let k_bn = truncate_to_n(&k_bn, true);
if k_bn.cmpn(1) < 0 || k_bn.cmp(&ns1) > 0 {
continue;
}
let r_point = base_point.mul(&k_bn);
if r_point.is_infinity() {
continue;
}
let r_bn = r_point
.get_x()
.umod(n)
.map_err(|e| PrimitivesError::ArithmeticError(format!("r mod n: {}", e)))?;
if r_bn.is_zero() {
continue;
}
let k_inv = k_bn
.invm(n)
.map_err(|e| PrimitivesError::ArithmeticError(format!("k inverse: {}", e)))?;
let r_times_key = r_bn
.mul(private_key)
.umod(n)
.map_err(|e| PrimitivesError::ArithmeticError(format!("r*key mod n: {}", e)))?;
let sum = msg
.add(&r_times_key)
.umod(n)
.map_err(|e| PrimitivesError::ArithmeticError(format!("hash+r*key mod n: {}", e)))?;
let mut s_bn = k_inv
.mul(&sum)
.umod(n)
.map_err(|e| PrimitivesError::ArithmeticError(format!("s mod n: {}", e)))?;
if s_bn.is_zero() {
continue;
}
if force_low_s && s_bn.cmp(&curve.half_n) > 0 {
s_bn = n.sub(&s_bn);
}
return Ok(Signature::new(r_bn, s_bn));
}
}
pub fn ecdsa_sign_with_k(
message_hash: &[u8; 32],
private_key: &BigNumber,
k: &BigNumber,
force_low_s: bool,
) -> Result<Signature, PrimitivesError> {
let curve = Curve::secp256k1();
let n = &curve.n;
let msg_bn = BigNumber::from_bytes(message_hash, Endian::Big);
let msg = truncate_to_n(&msg_bn, false);
let ns1 = n.subn(1);
if k.cmpn(1) < 0 || k.cmp(&ns1) > 0 {
return Err(PrimitivesError::ArithmeticError(
"k must be in [1, n-1]".to_string(),
));
}
let base_point = BasePoint::instance();
let r_point = base_point.mul(k);
if r_point.is_infinity() {
return Err(PrimitivesError::ArithmeticError(
"k*G is point at infinity".to_string(),
));
}
let r_bn = r_point
.get_x()
.umod(n)
.map_err(|e| PrimitivesError::ArithmeticError(format!("r mod n: {}", e)))?;
if r_bn.is_zero() {
return Err(PrimitivesError::ArithmeticError("r is zero".to_string()));
}
let k_inv = k
.invm(n)
.map_err(|e| PrimitivesError::ArithmeticError(format!("k inverse: {}", e)))?;
let r_times_key = r_bn
.mul(private_key)
.umod(n)
.map_err(|e| PrimitivesError::ArithmeticError(format!("r*key mod n: {}", e)))?;
let sum = msg
.add(&r_times_key)
.umod(n)
.map_err(|e| PrimitivesError::ArithmeticError(format!("hash+r*key mod n: {}", e)))?;
let mut s_bn = k_inv
.mul(&sum)
.umod(n)
.map_err(|e| PrimitivesError::ArithmeticError(format!("s mod n: {}", e)))?;
if s_bn.is_zero() {
return Err(PrimitivesError::ArithmeticError("s is zero".to_string()));
}
if force_low_s && s_bn.cmp(&curve.half_n) > 0 {
s_bn = n.sub(&s_bn);
}
Ok(Signature::new(r_bn, s_bn))
}
pub fn ecdsa_verify(message_hash: &[u8; 32], signature: &Signature, public_key: &Point) -> bool {
let curve = Curve::secp256k1();
let n = &curve.n;
let msg_bn = BigNumber::from_bytes(message_hash, Endian::Big);
let r = signature.r();
let s = signature.s();
if r.cmpn(1) < 0 || r.cmp(n) >= 0 {
return false;
}
if s.cmpn(1) < 0 || s.cmp(n) >= 0 {
return false;
}
let s_inv = match s.invm(n) {
Ok(inv) => inv,
Err(_) => return false,
};
let u1 = match msg_bn.mul(&s_inv).umod(n) {
Ok(val) => val,
Err(_) => return false,
};
let u2 = match r.mul(&s_inv).umod(n) {
Ok(val) => val,
Err(_) => return false,
};
let base_point = BasePoint::instance();
let q_jac = JacobianPoint::from_affine(&public_key.x, &public_key.y);
let r_jac = JacobianPoint::shamir_mul_wnaf(&u1, base_point.table(), &u2, &q_jac);
if r_jac.is_infinity() {
return false;
}
let (rx, _ry) = r_jac.to_affine();
let v = match rx.umod(n) {
Ok(val) => val,
Err(_) => return false,
};
v.cmp(r) == 0
}
#[cfg(test)]
mod tests {
use super::*;
use crate::primitives::hash::sha256;
fn hex_to_bytes(hex: &str) -> Vec<u8> {
(0..hex.len())
.step_by(2)
.map(|i| u8::from_str_radix(&hex[i..i + 2], 16).unwrap())
.collect()
}
#[test]
fn test_ecdsa_sign_deterministic() {
let key = BigNumber::from_number(1);
let msg_hash = sha256(b"test message");
let sig1 = ecdsa_sign(&msg_hash, &key, false).unwrap();
let sig2 = ecdsa_sign(&msg_hash, &key, false).unwrap();
assert_eq!(
sig1.r().to_hex(),
sig2.r().to_hex(),
"r should be deterministic"
);
assert_eq!(
sig1.s().to_hex(),
sig2.s().to_hex(),
"s should be deterministic"
);
}
#[test]
fn test_ecdsa_sign_verify_roundtrip() {
let key = BigNumber::from_number(42);
let msg_hash = sha256(b"Hello, BSV!");
let sig = ecdsa_sign(&msg_hash, &key, true).unwrap();
let base_point = BasePoint::instance();
let pubkey = base_point.mul(&key);
assert!(
ecdsa_verify(&msg_hash, &sig, &pubkey),
"Valid signature should verify"
);
}
#[test]
fn test_ecdsa_verify_wrong_key() {
let key = BigNumber::from_number(1);
let wrong_key = BigNumber::from_number(2);
let msg_hash = sha256(b"test");
let sig = ecdsa_sign(&msg_hash, &key, false).unwrap();
let base_point = BasePoint::instance();
let wrong_pubkey = base_point.mul(&wrong_key);
assert!(
!ecdsa_verify(&msg_hash, &sig, &wrong_pubkey),
"Wrong public key should fail verification"
);
}
#[test]
fn test_ecdsa_verify_wrong_message() {
let key = BigNumber::from_number(1);
let msg_hash = sha256(b"correct message");
let wrong_hash = sha256(b"wrong message");
let sig = ecdsa_sign(&msg_hash, &key, false).unwrap();
let base_point = BasePoint::instance();
let pubkey = base_point.mul(&key);
assert!(
!ecdsa_verify(&wrong_hash, &sig, &pubkey),
"Wrong message should fail verification"
);
}
#[test]
fn test_ecdsa_sign_low_s() {
let curve = Curve::secp256k1();
let key = BigNumber::from_number(12345);
let msg_hash = sha256(b"low-s test");
let sig = ecdsa_sign(&msg_hash, &key, true).unwrap();
assert!(
sig.s().cmp(&curve.half_n) <= 0,
"S should be <= n/2 when force_low_s is true"
);
}
#[test]
fn test_ecdsa_sign_vectors() {
use serde::Deserialize;
#[derive(Deserialize)]
struct SignVector {
private_key_hex: String,
message_hash_hex: String,
expected_r: String,
expected_s: String,
force_low_s: bool,
#[allow(dead_code)]
description: String,
#[allow(dead_code)]
note: String,
}
let data = include_str!("../../test-vectors/ecdsa_sign.json");
let vectors: Vec<SignVector> = serde_json::from_str(data).unwrap();
for (i, v) in vectors.iter().enumerate() {
let key = BigNumber::from_hex(&v.private_key_hex).unwrap();
let msg_bytes = hex_to_bytes(&v.message_hash_hex);
let mut msg_hash = [0u8; 32];
msg_hash.copy_from_slice(&msg_bytes);
let sig = ecdsa_sign(&msg_hash, &key, v.force_low_s).unwrap();
let r_hex = sig.r().to_hex();
let s_hex = sig.s().to_hex();
let r_padded = format!("{:0>64}", r_hex);
let s_padded = format!("{:0>64}", s_hex);
assert_eq!(r_padded, v.expected_r, "Vector {}: r mismatch", i);
assert_eq!(s_padded, v.expected_s, "Vector {}: s mismatch", i);
}
}
#[test]
fn test_ecdsa_verify_vectors() {
use serde::Deserialize;
#[derive(Deserialize)]
struct VerifyVector {
message_hash_hex: String,
public_key_x: String,
public_key_y: String,
signature_r: String,
signature_s: String,
expected_valid: bool,
#[allow(dead_code)]
description: String,
#[allow(dead_code)]
note: String,
}
let data = include_str!("../../test-vectors/ecdsa_verify.json");
let vectors: Vec<VerifyVector> = serde_json::from_str(data).unwrap();
for (i, v) in vectors.iter().enumerate() {
let msg_bytes = hex_to_bytes(&v.message_hash_hex);
let mut msg_hash = [0u8; 32];
msg_hash.copy_from_slice(&msg_bytes);
let pub_x = BigNumber::from_hex(&v.public_key_x).unwrap();
let pub_y = BigNumber::from_hex(&v.public_key_y).unwrap();
let pubkey = Point::new(pub_x, pub_y);
let r = BigNumber::from_hex(&v.signature_r).unwrap();
let s = BigNumber::from_hex(&v.signature_s).unwrap();
let sig = Signature::new(r, s);
let result = ecdsa_verify(&msg_hash, &sig, &pubkey);
assert_eq!(
result, v.expected_valid,
"Vector {}: expected valid={}, got {}",
i, v.expected_valid, result
);
}
}
#[test]
fn test_ecdsa_verify_tampered_signature() {
let key = BigNumber::from_number(7);
let msg_hash = sha256(b"tamper test");
let sig = ecdsa_sign(&msg_hash, &key, false).unwrap();
let base_point = BasePoint::instance();
let pubkey = base_point.mul(&key);
let bad_r = sig.r().addn(1);
let bad_sig = Signature::new(bad_r, sig.s().clone());
assert!(
!ecdsa_verify(&msg_hash, &bad_sig, &pubkey),
"Tampered r should fail"
);
let bad_s = sig.s().addn(1);
let bad_sig = Signature::new(sig.r().clone(), bad_s);
assert!(
!ecdsa_verify(&msg_hash, &bad_sig, &pubkey),
"Tampered s should fail"
);
}
#[test]
fn test_ecdsa_multiple_keys() {
let base_point = BasePoint::instance();
for i in 1..=5 {
let key = BigNumber::from_number(i * 1000);
let msg_hash = sha256(format!("message {}", i).as_bytes());
let sig = ecdsa_sign(&msg_hash, &key, true).unwrap();
let pubkey = base_point.mul(&key);
assert!(
ecdsa_verify(&msg_hash, &sig, &pubkey),
"Key {} should verify",
i
);
}
}
}