// gm-sm9 0.2.0
//
// A Rust implementation of SM9, one of China's national standard cryptographic algorithms.
// Documentation
use crate::fields::FieldElement;
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use rand::RngCore;
use std::io::Cursor;

/// A 256-bit unsigned integer stored as four 64-bit limbs, least-significant limb first.
pub type U256 = [u64; 4];
/// A 512-bit unsigned integer stored as eight 64-bit limbs, least-significant limb first.
pub type U512 = [u64; 8];

/// The value 0 as a `U256`.
pub(crate) const SM9_ZERO: U256 = [0, 0, 0, 0];
/// The value 1 as a `U256`.
pub(crate) const SM9_ONE: U256 = [1, 0, 0, 0];
/// The value 2 as a `U256`.
pub(crate) const SM9_TWO: U256 = [2, 0, 0, 0];
/// The value 5 as a `U256`.
pub(crate) const SM9_FIVE: U256 = [5, 0, 0, 0];

/// Draws a random value `r` with `1 <= r < range` by rejection sampling.
///
/// Each attempt fills 32 bytes from the thread-local RNG, interprets them as a
/// big-endian 256-bit integer and keeps the value only if it lies in the
/// requested interval. The expected number of attempts therefore grows as
/// `range` gets smaller relative to 2^256.
///
/// # Panics
///
/// Panics if `range <= 1`, because the interval `[1, range)` would be empty and
/// the sampling loop could never terminate.
#[inline(always)]
pub fn sm9_random_u256(range: &U256) -> U256 {
    assert!(
        u256_cmp(range, &[1, 0, 0, 0]) > 0,
        "sm9_random_u256: range must be greater than 1"
    );

    let mut rng = rand::thread_rng();
    loop {
        let mut buf: [u8; 32] = [0; 32];
        rng.fill_bytes(&mut buf[..]);
        let candidate = u256_from_be_bytes(&buf);
        // Arrays compare lexicographically starting at limb 0, so `>= [1, 0, 0, 0]`
        // would wrongly reject every value whose lowest limb is zero (e.g. 2^64).
        // The intended lower bound is simply "non-zero".
        if candidate != [0, 0, 0, 0] && u256_cmp(&candidate, range) < 0 {
            return candidate;
        }
    }
}

/// Adds two 256-bit integers limb by limb, starting at the least-significant limb.
///
/// Returns the sum reduced modulo 2^256 together with a flag that is `true`
/// when a carry propagated out of the most-significant limb.
#[inline(always)]
pub const fn u256_add(a: &U256, b: &U256) -> (U256, bool) {
    let mut out = [0u64; 4];
    let mut carry_in = 0u64;
    let mut idx = 0;
    while idx < 4 {
        let (partial, overflow_ab) = a[idx].overflowing_add(b[idx]);
        let (limb, overflow_carry) = partial.overflowing_add(carry_in);
        out[idx] = limb;
        // At most one of the two additions can overflow, so the carry is 0 or 1.
        carry_in = (overflow_ab || overflow_carry) as u64;
        idx += 1;
    }
    (out, carry_in != 0)
}

/// Adds two 512-bit integers limb by limb, starting at the least-significant limb.
///
/// Returns the sum reduced modulo 2^512 together with a flag that is `true`
/// when a carry propagated out of the most-significant limb.
#[inline(always)]
pub const fn u512_add(a: &U512, b: &U512) -> (U512, bool) {
    let mut out = [0u64; 8];
    let mut carry_in = 0u64;
    let mut idx = 0;
    while idx < 8 {
        let (partial, overflow_ab) = a[idx].overflowing_add(b[idx]);
        let (limb, overflow_carry) = partial.overflowing_add(carry_in);
        out[idx] = limb;
        // At most one of the two additions can overflow, so the carry is 0 or 1.
        carry_in = (overflow_ab || overflow_carry) as u64;
        idx += 1;
    }
    (out, carry_in != 0)
}

/// Subtracts `b` from `a` limb by limb, starting at the least-significant limb.
///
/// Returns the difference reduced modulo 2^256 together with a flag that is
/// `true` when a borrow propagated out of the most-significant limb
/// (i.e. when `a < b`).
#[inline(always)]
pub const fn u256_sub(a: &U256, b: &U256) -> (U256, bool) {
    let mut out = [0u64; 4];
    let mut borrow_in = 0u64;
    let mut idx = 0;
    while idx < 4 {
        // Take the incoming borrow first, then the subtrahend limb.
        let (reduced, under_borrow) = a[idx].overflowing_sub(borrow_in);
        let (limb, under_sub) = reduced.overflowing_sub(b[idx]);
        out[idx] = limb;
        borrow_in = (under_borrow || under_sub) as u64;
        idx += 1;
    }
    (out, borrow_in != 0)
}

/// Subtracts `b` from `a` limb by limb, starting at the least-significant limb.
///
/// Returns the difference reduced modulo 2^512 together with a flag that is
/// `true` when a borrow propagated out of the most-significant limb
/// (i.e. when `a < b`).
#[inline(always)]
pub const fn u512_sub(a: &U512, b: &U512) -> (U512, bool) {
    let mut out = [0u64; 8];
    let mut borrow_in = 0u64;
    let mut idx = 0;
    while idx < 8 {
        // Take the incoming borrow first, then the subtrahend limb.
        let (reduced, under_borrow) = a[idx].overflowing_sub(borrow_in);
        let (limb, under_sub) = reduced.overflowing_sub(b[idx]);
        out[idx] = limb;
        borrow_in = (under_borrow || under_sub) as u64;
        idx += 1;
    }
    (out, borrow_in != 0)
}

#[inline(always)]
pub const fn u256_cmp(a: &U256, b: &U256) -> i32 {
    if a[3] > b[3] {
        return 1;
    }
    if a[3] < b[3] {
        return -1;
    }
    if a[2] > b[2] {
        return 1;
    }
    if (a[2] < b[2]) {
        return -1;
    }
    if (a[1] > b[1]) {
        return 1;
    }
    if a[1] < b[1] {
        return -1;
    }
    if a[0] > b[0] {
        return 1;
    }
    if a[0] < b[0] {
        return -1;
    }
    0
}

/// Computes the full 512-bit product of two 256-bit integers.
///
/// Uses schoolbook multiplication over 64-bit limbs with a `u128` accumulator:
/// for every pair of limbs the product, the limb already stored at the target
/// position and the running carry are summed, the low 64 bits are stored and
/// the high 64 bits become the next carry.
///
/// The result is returned least-significant limb first and can never overflow,
/// since the product of two values below 2^256 is below 2^512.
#[inline(always)]
pub fn u256_mul(a: &U256, b: &U256) -> U512 {
    let mut ret: U512 = [0; 8];

    for (i, &ai) in a.iter().enumerate() {
        let mut carry: u128 = 0;
        for (j, &bj) in b.iter().enumerate() {
            // (2^64-1)^2 + (2^64-1) + (2^64-1) == 2^128-1, so this cannot overflow.
            let acc = u128::from(ai) * u128::from(bj) + u128::from(ret[i + j]) + carry;
            ret[i + j] = acc as u64; // keep the low 64 bits
            carry = acc >> 64;
        }
        // Earlier rows only reached index i + 3, so this limb is still untouched.
        ret[i + 4] = carry as u64;
    }
    ret
}

/// XORs the first `len` bytes of `k` with the first `len` bytes of `data`.
///
/// Returns a new vector of exactly `len` bytes.
///
/// # Panics
///
/// Panics if `len` exceeds the length of either input slice.
#[inline(always)]
pub fn xor(k: &[u8], data: &[u8], len: usize) -> Vec<u8> {
    k[..len]
        .iter()
        .zip(data[..len].iter())
        .map(|(key_byte, data_byte)| key_byte ^ data_byte)
        .collect()
}

#[inline(always)]
pub fn u256_to_be_bytes(a: &U256) -> Vec<u8> {
    let mut ret: Vec<u8> = Vec::new();
    for i in (0..4).rev() {
        ret.write_u64::<BigEndian>(a[i]).unwrap();
    }
    ret
}

/// Renders a 256-bit integer as 256 characters, each `'0'` or `'1'`.
///
/// Position 0 holds the most-significant bit (bit 63 of limb 3) and
/// position 255 holds the least-significant bit (bit 0 of limb 0).
#[inline(always)]
pub fn u256_to_bits(a: U256) -> [char; 256] {
    // Start from an all-'0' array and flip only the positions whose bit is set.
    let mut bits = ['0'; 256];
    for (pos, slot) in bits.iter_mut().enumerate() {
        // Limbs are visited from index 3 down to 0, 64 positions per limb.
        let limb = a[3 - pos / 64];
        // Within a limb the highest bit comes first.
        let shift = 63 - pos % 64;
        if (limb >> shift) & 1 == 1 {
            *slot = '1';
        }
    }
    bits
}

#[inline(always)]
pub fn u256_from_be_bytes(input: &[u8]) -> U256 {
    let mut elem = [0, 0, 0, 0];
    let mut c = Cursor::new(input);
    for i in (0..4).rev() {
        elem[i] = c.read_u64::<BigEndian>().unwrap();
    }
    elem
}

/// Extracts the `i`-th signed window digit of `a` for a window of `window_size` bits.
///
/// For `i == 0` the digit is derived from the lowest limb alone, treating the
/// bit below bit 0 as zero. For `i > 0` a bit string starting one bit below the
/// window (at bit `i * window_size - 1`) is read, pulling in bits from the next
/// limb when the window straddles a limb boundary and a next limb exists among
/// the first four. The digit is that string masked to `window_size` bits minus
/// the same string shifted right by one and masked.
///
/// # Panics
///
/// Panics if the addressed limb lies outside `a`.
pub fn sm9_u256_get_booth(a: &[u64], window_size: u64, i: u64) -> i32 {
    let mask: u64 = (1u64 << window_size) - 1;

    if i == 0 {
        let shifted = (a[0] << 1) & mask;
        let plain = a[0] & mask;
        return (shifted as i32) - (plain as i32);
    }

    // Absolute position of the bit just below the window, split into limb and offset.
    let bit_pos = (i * window_size - 1) as usize;
    let limb = bit_pos / 64;
    let offset = bit_pos % 64;

    let mut window = a[limb] >> offset;
    // Fewer than window_size + 1 bits remain in this limb: borrow from the next one.
    let straddles = (64 - offset) < (window_size + 1) as usize;
    if straddles && limb < 3 {
        window |= a[limb + 1] << (64 - offset);
    }

    let low = (window & mask) as i32;
    let high = ((window >> 1) & mask) as i32;
    low - high
}

#[cfg(test)]
mod test_operation {
    use num_bigint::BigUint;

    use crate::u256::{
        sm9_u256_get_booth, u256_add, u256_from_be_bytes, u256_mul, u256_sub, u256_to_be_bytes,
    };

    /// First operand, least-significant limb first.
    const LHS: [u64; 4] = [
        0x54806C11D8806141,
        0xF1DD2C190F5E93C4,
        0x597B6027B441A01F,
        0x85AEF3D078640C98,
    ];

    /// Second operand, least-significant limb first.
    const RHS: [u64; 4] = [
        0x0E75C05FB4E3216D,
        0x1006E85F5CDFF073,
        0x1A7CE027B7A46F74,
        0x41E00A53DDA532DA,
    ];

    /// Decodes a big-endian hex string into a `BigUint` reference value.
    fn big_from_hex(text: &str) -> BigUint {
        BigUint::from_bytes_be(&hex::decode(text).unwrap())
    }

    /// Returns the 64-bit digits of `value`, most-significant digit first.
    fn digits_msb_first(value: BigUint) -> Vec<u64> {
        let mut digits = value.to_u64_digits();
        digits.reverse();
        digits
    }

    /// Checks add, sub and mul against `num_bigint` on one pair of operands.
    #[test]
    fn test_raw_add_u64() {
        let lhs_big =
            big_from_hex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141");
        let rhs_big =
            big_from_hex("41E00A53DDA532DA1A7CE027B7A46F741006E85F5CDFF0730E75C05FB4E3216D");

        // Results are reversed so both sides are most-significant limb first.
        let (mut sum, _carry) = u256_add(&LHS, &RHS);
        sum.reverse();
        let expected_sum = digits_msb_first(&lhs_big + &rhs_big);
        assert_eq!(sum, *expected_sum);

        let (mut diff, _borrow) = u256_sub(&LHS, &RHS);
        diff.reverse();
        let expected_diff = digits_msb_first(&lhs_big - &rhs_big);
        assert_eq!(diff, *expected_diff);

        let mut product = u256_mul(&LHS, &RHS);
        product.reverse();
        let expected_product = digits_msb_first(&lhs_big * &rhs_big);
        assert_eq!(product, *expected_product);
    }

    /// Round-trips a value through the big-endian byte encoding.
    #[test]
    fn test_to_bytes_be() {
        let encoded: [u8; 32] = [
            133, 174, 243, 208, 120, 100, 12, 152, 89, 123, 96, 39, 180, 65, 160, 31, 241, 221, 44,
            25, 15, 94, 147, 196, 84, 128, 108, 17, 216, 128, 97, 65,
        ];

        let bytes = u256_to_be_bytes(&LHS);
        assert_eq!(bytes, encoded);

        let limbs = u256_from_be_bytes(&encoded);
        assert_eq!(limbs, LHS);
    }

    /// Prints every 7-bit window digit of a fixed scalar; it asserts nothing.
    #[test]
    fn test_get_booth() {
        let scalar_bytes =
            hex::decode("123456789abcdef00fedcba987654321123456789abcdef00fedcba987654321")
                .unwrap();
        let k = u256_from_be_bytes(&scalar_bytes);

        let window_size = 7u64;
        // Number of windows needed to cover 256 bits, rounded up.
        let window_count = (256 + window_size - 1) / window_size;
        for i in (0..window_count).rev() {
            let booth = sm9_u256_get_booth(&k, window_size, i);
            println!("i = {}, booth = {}", i, booth);
        }
    }
}