use base64::DecodeError;
use ff::Field;
use group::{Curve, GroupEncoding};
use midnight_curves::k256::{Fp as K256Base, Fq as K256Scalar, K256Affine, K256};
use rand::RngCore;
use crate::CircuitField;
/// Namespace for ECDSA over secp256k1: key generation, signing and verification.
#[derive(Clone, Debug)]
pub struct Ecdsa;
/// ECDSA public key: a secp256k1 curve point (`pk = sk·G` when produced by `Ecdsa::keygen`).
pub type PublicKey = K256;
/// ECDSA secret key: a secp256k1 scalar.
pub type SecretKey = K256Scalar;
/// An ECDSA signature `(r, s)` over secp256k1.
#[derive(Clone, Copy, Debug)]
pub struct ECDSASig {
    // `r` as 32 little-endian bytes. `Ecdsa::sign` fills it with the base-field
    // encoding (`to_bytes_le`) of the x-coordinate of `k·G`; `from_bytes_be`
    // reverses big-endian input into this layout.
    r: [u8; 32],
    // The `s` component of the signature, as a scalar.
    s: K256Scalar,
}
impl ECDSASig {
    /// Returns the `r` component exactly as stored (32 little-endian bytes).
    pub fn get_r(&self) -> [u8; 32] {
        let Self { r, .. } = self;
        *r
    }

    /// Returns the `s` component.
    pub fn get_s(&self) -> K256Scalar {
        let Self { s, .. } = *self;
        s
    }

    /// Parses a 64-byte big-endian `r || s` encoding.
    ///
    /// The first 32 bytes are `r` (big-endian; stored internally little-endian),
    /// the last 32 bytes are `s` (big-endian).
    ///
    /// # Panics
    ///
    /// Panics if `bytes` is not exactly 64 bytes long, or if the `s` half is
    /// rejected by `K256Scalar::from_bytes_be`.
    pub fn from_bytes_be(bytes: &[u8]) -> Self {
        assert_eq!(bytes.len(), 64);
        let (r_be, s_be) = bytes.split_at(32);
        // Flip big-endian `r` into the little-endian layout used internally.
        let r: [u8; 32] = std::array::from_fn(|i| r_be[31 - i]);
        let s = K256Scalar::from_bytes_be(s_be).expect("Valid secp256k1 scalar in signature");
        Self { r, s }
    }
}
impl Ecdsa {
pub fn keygen<R: RngCore>(rng: &mut R) -> (PublicKey, SecretKey) {
let sk = K256Scalar::random(rng);
let pk = K256::generator() * sk;
(pk, sk)
}
pub fn sign<R: RngCore>(sk: &SecretKey, msg_hash: &K256Scalar, rng: &mut R) -> ECDSASig {
let k = K256Scalar::random(rng);
let k_point: K256 = K256::generator() * k;
let r_as_base = k_point.to_affine().x();
let r = r_as_base.to_bytes_le();
let r_as_scalar = K256Scalar::from_bytes_le(&r).unwrap();
let s = k.invert().unwrap() * (msg_hash + r_as_scalar * sk);
ECDSASig { r, s }
}
pub fn verify(pk: &PublicKey, msg_hash: &K256Scalar, signature: &ECDSASig) -> bool {
let g = K256::generator();
let r_as_scalar = K256Scalar::from_bytes_le(&signature.r).unwrap();
let r_as_base = K256Base::from_bytes_le(&signature.r).unwrap();
let s_inv = signature.s.invert().unwrap();
let k_point = g * (s_inv * msg_hash) + *pk * (s_inv * r_as_scalar);
k_point.to_affine().x() == r_as_base
}
}
/// Construction of a value from base64-encoded bytes.
///
/// Each implementor decides which base64 alphabet(s) and payload layout it accepts.
pub trait FromBase64: Sized {
    /// Decodes `base64_bytes` into `Self`, returning a [`DecodeError`] when the
    /// input is rejected.
    fn from_base64(base64_bytes: &[u8]) -> Result<Self, DecodeError>;
}
impl FromBase64 for ECDSASig {
fn from_base64(base64_bytes: &[u8]) -> Result<Self, DecodeError> {
let bytes = base64::decode_config(base64_bytes, base64::URL_SAFE_NO_PAD)?;
Ok(ECDSASig::from_bytes_be(&bytes))
}
}
impl FromBase64 for PublicKey {
    /// Decodes a secp256k1 public key from base64.
    ///
    /// Two input shapes are accepted, selected by the length of `base64_bytes`:
    /// - 44 characters: standard-alphabet base64 of a 33-byte compressed point;
    /// - 86 characters: two 43-character URL-safe base64 coordinates `x || y`
    ///   (JWK style), decoded by `from_jwk`.
    ///
    /// # Errors
    ///
    /// Returns `DecodeError::InvalidLength` for any other input length or when
    /// the compressed form does not decode to exactly 33 bytes. Returns the base64
    /// decoder's error for malformed base64.
    ///
    /// # Panics
    ///
    /// Panics if the decoded bytes are not a valid compressed secp256k1 point
    /// (not expressible as a `DecodeError`).
    fn from_base64(base64_bytes: &[u8]) -> Result<Self, DecodeError> {
        match base64_bytes.len() {
            44 => {
                let bytes = base64::decode_config(base64_bytes, base64::STANDARD_NO_PAD)?;
                // Do not assume 44 characters always yield 33 bytes; report any
                // other decoded length as an error instead of panicking.
                let repr: [u8; 33] = bytes.try_into().map_err(|_| DecodeError::InvalidLength)?;
                let ret = K256Affine::from_bytes(&repr.into())
                    .expect("Valid compressed secp256k1 point.");
                Ok(ret.into())
            }
            86 => from_jwk(&base64_bytes[..43], &base64_bytes[43..]),
            _ => Err(DecodeError::InvalidLength),
        }
    }
}
fn from_jwk(x: &[u8], y: &[u8]) -> Result<PublicKey, DecodeError> {
let x_bytes = base64::decode_config(x, base64::URL_SAFE)?;
let y_bytes = base64::decode_config(y, base64::URL_SAFE)?;
assert_eq!(x_bytes.len(), 32);
assert_eq!(y_bytes.len(), 32);
let x_fp = K256Base::from_bytes_be(&x_bytes).expect("Valid x coordinate");
let y_fp = K256Base::from_bytes_be(&y_bytes).expect("Valid y coordinate");
let ret = K256Affine::from_xy(x_fp, y_fp).expect("Valid point on curve");
Ok(ret.into())
}