use crate::check_cx_ok;
use crate::ecc::CxError;
use core::cell::Cell;
use core::cmp::Ordering;
use core::ffi::c_int;
use ledger_secure_sdk_sys::*;
/// Word size, in bytes, handed to `cx_bn_lock` when a `Bn` or `MontCtx`
/// allocation is the one that takes the BN lock (see `bn_retain`).
pub(crate) const BN_DEFAULT_WORD_NBYTES: usize = 32;

/// Counter of live holders of the crypto library's BN lock.
struct BnRefCount {
    count: Cell<u32>,
}

impl BnRefCount {
    /// A counter with no references taken yet.
    const fn new() -> Self {
        Self {
            count: Cell::new(0),
        }
    }
}

// SAFETY: `Cell` is not thread-safe, so this is only sound while `BN_RC` is
// never accessed from more than one thread at a time.
// NOTE(review): assumed to hold because the app runs single-threaded on the
// device — confirm before using this module from a threaded host.
unsafe impl Sync for BnRefCount {}

/// The single shared counter driven by `bn_retain` / `bn_release`.
static BN_RC: BnRefCount = BnRefCount::new();
/// Takes one reference on the shared BN lock.
///
/// Only the first reference (counter at zero) calls `cx_bn_lock(word_nbytes, 0)`.
/// While the lock is already held, `word_nbytes` is not used and the counter
/// is simply incremented.
///
/// # Errors
/// Returns the error reported by `cx_bn_lock` through `check_cx_ok!`. The
/// counter is not modified in that case.
pub(crate) fn bn_retain(word_nbytes: usize) -> Result<(), CxError> {
    match BN_RC.count.get() {
        0 => {
            check_cx_ok!(cx_bn_lock(word_nbytes, 0));
            BN_RC.count.set(1);
        }
        held => BN_RC.count.set(held + 1),
    }
    Ok(())
}
/// Gives back one reference on the shared BN lock and calls `cx_bn_unlock`
/// when the last reference goes away.
///
/// An unbalanced call (counter already zero) trips the `debug_assert!` in
/// debug builds. Elsewhere it is ignored rather than corrupting the counter.
pub(crate) fn bn_release() {
    let c = BN_RC.count.get();
    debug_assert!(c > 0, "bn_release called with zero ref count");
    if c == 0 {
        // Without overflow checks, `c - 1` would wrap to u32::MAX here. After
        // that, `cx_bn_unlock` is never reached and `bn_retain` never locks
        // again. With overflow checks it would panic instead. Treat the stray
        // release as a no-op.
        return;
    }
    let new = c - 1;
    BN_RC.count.set(new);
    if new == 0 {
        // SAFETY: the counter says no holder remains (assuming every holder
        // went through `bn_retain`), so nothing still relies on the lock.
        unsafe {
            cx_bn_unlock();
        }
    }
}
/// Guard holding one reference on the shared BN lock. It is obtained through
/// `BnLock::acquire` and gives the reference back when dropped (or through
/// `BnLock::release`).
///
/// `#[non_exhaustive]` prevents code outside this crate from writing `BnLock`
/// directly. Such a value would run `bn_release` on drop without ever having
/// called `bn_retain`, unbalancing the shared reference count.
#[must_use = "dropping a BnLock immediately gives its BN lock reference back"]
#[non_exhaustive]
pub struct BnLock;
impl BnLock {
    /// Takes a reference on the shared BN lock through `bn_retain`.
    ///
    /// `word_nbytes` is only used if no reference is currently held.
    ///
    /// # Errors
    /// Returns the error reported by `bn_retain`.
    pub fn acquire(word_nbytes: usize) -> Result<Self, CxError> {
        bn_retain(word_nbytes).map(|()| BnLock)
    }

    /// Gives the reference back now. This is equivalent to dropping the guard.
    pub fn release(self) {
        // `self` is owned by this call, so it is dropped (running `Drop`) on return.
    }

    /// Returns `true` when `cx_bn_locked()` reports `CX_OK`.
    pub fn is_locked() -> bool {
        let status = unsafe { cx_bn_locked() };
        status == CX_OK
    }

    /// Returns the boolean reported by `cx_bn_is_locked()`.
    pub fn is_locked_bool() -> bool {
        unsafe { cx_bn_is_locked() }
    }
}
impl Drop for BnLock {
    /// Returns this guard's reference on the shared BN lock.
    fn drop(&mut self) {
        bn_release()
    }
}
/// Owned handle to a big number allocated through the crypto library.
///
/// Each live `Bn` holds one reference on the shared BN lock. Dropping it
/// calls `cx_bn_destroy` on the handle and then gives that reference back.
#[derive(Debug)]
pub struct Bn {
    /// Raw handle filled in by `cx_bn_alloc` / `cx_bn_alloc_init`.
    handle: cx_bn_t,
}
impl Bn {
    /// Allocates a big number of `nbytes` bytes with `cx_bn_alloc`.
    ///
    /// First takes one reference on the shared BN lock, locking with
    /// `BN_DEFAULT_WORD_NBYTES` if no reference is held yet. That reference
    /// is given back if the allocation fails; otherwise it is released when
    /// the `Bn` is dropped.
    ///
    /// # Errors
    /// Returns the error from `bn_retain` or `cx_bn_alloc`.
    pub fn alloc(nbytes: usize) -> Result<Self, CxError> {
        bn_retain(BN_DEFAULT_WORD_NBYTES)?;
        let mut handle: cx_bn_t = CX_BN_FLAG_UNSET;
        // SAFETY: `handle` is a live local the call may write to, and the BN
        // lock is held because of the `bn_retain` above.
        let err = unsafe { cx_bn_alloc(&mut handle, nbytes) };
        if err != CX_OK {
            // Roll back the retain so a failed allocation does not pin the lock.
            bn_release();
            return Err(err.into());
        }
        Ok(Self { handle })
    }

    /// Allocates a big number initialised from `value`. Its size is
    /// `value.len()` rounded up by `align_bn_size`.
    ///
    /// # Errors
    /// Same as [`Bn::alloc_init_size`].
    pub fn alloc_init(value: &[u8]) -> Result<Self, CxError> {
        // Delegate so the retain / allocate / roll-back sequence is written
        // once rather than copied into every constructor.
        Self::alloc_init_size(align_bn_size(value.len()), value)
    }

    /// Allocates a big number of `nbytes` bytes and initialises it from
    /// `value` with `cx_bn_alloc_init`.
    ///
    /// The lock reference is handled as in [`Bn::alloc`].
    ///
    /// # Errors
    /// Returns the error from `bn_retain` or `cx_bn_alloc_init`.
    pub fn alloc_init_size(nbytes: usize, value: &[u8]) -> Result<Self, CxError> {
        bn_retain(BN_DEFAULT_WORD_NBYTES)?;
        let mut handle: cx_bn_t = CX_BN_FLAG_UNSET;
        // SAFETY: the pointer/length pair comes from the live `value` slice,
        // `handle` is a writable local, and the lock is held (see above).
        let err = unsafe { cx_bn_alloc_init(&mut handle, nbytes, value.as_ptr(), value.len()) };
        if err != CX_OK {
            bn_release();
            return Err(err.into());
        }
        Ok(Self { handle })
    }

    /// Returns the raw `cx_bn_t` handle for direct FFI calls.
    pub fn raw(&self) -> cx_bn_t {
        self.handle
    }

    /// Returns a mutable reference to the raw handle.
    ///
    /// Whatever handle is stored here when the `Bn` drops is the one passed
    /// to `cx_bn_destroy`. Replacing it means the old handle is no longer
    /// destroyed by this `Bn`.
    pub fn raw_mut(&mut self) -> &mut cx_bn_t {
        &mut self.handle
    }

    /// Re-initialises `self` from `value` with `cx_bn_init`.
    pub fn init(&self, value: &[u8]) -> Result<(), CxError> {
        check_cx_ok!(cx_bn_init(self.handle, value.as_ptr(), value.len()));
        Ok(())
    }

    /// Fills `self` with a random value (`cx_bn_rand`).
    pub fn rand(&self) -> Result<(), CxError> {
        check_cx_ok!(cx_bn_rand(self.handle));
        Ok(())
    }

    /// Copies `other` into `self` (`cx_bn_copy(self, other)`).
    pub fn copy_from(&self, other: &Bn) -> Result<(), CxError> {
        check_cx_ok!(cx_bn_copy(self.handle, other.handle));
        Ok(())
    }

    /// Returns the size in bytes reported by `cx_bn_nbytes`.
    pub fn nbytes(&self) -> Result<usize, CxError> {
        let mut n = 0usize;
        check_cx_ok!(cx_bn_nbytes(self.handle, &mut n));
        Ok(n)
    }

    /// Sets `self` to `n` (`cx_bn_set_u32`).
    pub fn set_u32(&self, n: u32) -> Result<(), CxError> {
        check_cx_ok!(cx_bn_set_u32(self.handle, n));
        Ok(())
    }

    /// Reads `self` back as a `u32` (`cx_bn_get_u32`).
    pub fn get_u32(&self) -> Result<u32, CxError> {
        let mut n = 0u32;
        check_cx_ok!(cx_bn_get_u32(self.handle, &mut n));
        Ok(n)
    }

    /// Writes `self` into `bytes` with `cx_bn_export`. `bytes.len()` is passed
    /// as the output length.
    pub fn export(&self, bytes: &mut [u8]) -> Result<(), CxError> {
        check_cx_ok!(cx_bn_export(self.handle, bytes.as_mut_ptr(), bytes.len()));
        Ok(())
    }

    /// Compares `self` with `other`. The sign of the `cx_bn_cmp` result is
    /// mapped to an [`Ordering`].
    pub fn cmp_bn(&self, other: &Bn) -> Result<Ordering, CxError> {
        let mut diff: c_int = 0;
        check_cx_ok!(cx_bn_cmp(self.handle, other.handle, &mut diff));
        Ok(int_to_ordering(diff))
    }

    /// Compares `self` with the `u32` `other` (`cx_bn_cmp_u32`).
    pub fn cmp_u32(&self, other: u32) -> Result<Ordering, CxError> {
        let mut diff: c_int = 0;
        check_cx_ok!(cx_bn_cmp_u32(self.handle, other, &mut diff));
        Ok(int_to_ordering(diff))
    }

    /// Returns whether `self` is odd (`cx_bn_is_odd`).
    pub fn is_odd(&self) -> Result<bool, CxError> {
        let mut odd = false;
        check_cx_ok!(cx_bn_is_odd(self.handle, &mut odd));
        Ok(odd)
    }

    /// `self = a ^ b` (`cx_bn_xor`).
    pub fn xor(&self, a: &Bn, b: &Bn) -> Result<(), CxError> {
        check_cx_ok!(cx_bn_xor(self.handle, a.handle, b.handle));
        Ok(())
    }

    /// `self = a | b` (`cx_bn_or`).
    pub fn or(&self, a: &Bn, b: &Bn) -> Result<(), CxError> {
        check_cx_ok!(cx_bn_or(self.handle, a.handle, b.handle));
        Ok(())
    }

    /// `self = a & b` (`cx_bn_and`).
    pub fn and(&self, a: &Bn, b: &Bn) -> Result<(), CxError> {
        check_cx_ok!(cx_bn_and(self.handle, a.handle, b.handle));
        Ok(())
    }

    /// Returns whether bit `pos` is set (`cx_bn_tst_bit`).
    pub fn tst_bit(&self, pos: u32) -> Result<bool, CxError> {
        let mut set = false;
        check_cx_ok!(cx_bn_tst_bit(self.handle, pos, &mut set));
        Ok(set)
    }

    /// Sets bit `pos` (`cx_bn_set_bit`).
    pub fn set_bit(&self, pos: u32) -> Result<(), CxError> {
        check_cx_ok!(cx_bn_set_bit(self.handle, pos));
        Ok(())
    }

    /// Clears bit `pos` (`cx_bn_clr_bit`).
    pub fn clr_bit(&self, pos: u32) -> Result<(), CxError> {
        check_cx_ok!(cx_bn_clr_bit(self.handle, pos));
        Ok(())
    }

    /// Returns the bit count reported by `cx_bn_cnt_bits`.
    pub fn cnt_bits(&self) -> Result<u32, CxError> {
        let mut nbits = 0u32;
        check_cx_ok!(cx_bn_cnt_bits(self.handle, &mut nbits));
        Ok(nbits)
    }

    /// Shifts `self` right in place by `n` bits (`cx_bn_shr`).
    pub fn shr(&self, n: u32) -> Result<(), CxError> {
        check_cx_ok!(cx_bn_shr(self.handle, n));
        Ok(())
    }

    /// Shifts `self` left in place by `n` bits (`cx_bn_shl`).
    pub fn shl(&self, n: u32) -> Result<(), CxError> {
        check_cx_ok!(cx_bn_shl(self.handle, n));
        Ok(())
    }

    /// `self = a + b` (`cx_bn_add`).
    pub fn add(&self, a: &Bn, b: &Bn) -> Result<(), CxError> {
        check_cx_ok!(cx_bn_add(self.handle, a.handle, b.handle));
        Ok(())
    }

    /// `self = a - b` (`cx_bn_sub`).
    pub fn sub(&self, a: &Bn, b: &Bn) -> Result<(), CxError> {
        check_cx_ok!(cx_bn_sub(self.handle, a.handle, b.handle));
        Ok(())
    }

    /// `self = a * b` (`cx_bn_mul`).
    pub fn mul(&self, a: &Bn, b: &Bn) -> Result<(), CxError> {
        check_cx_ok!(cx_bn_mul(self.handle, a.handle, b.handle));
        Ok(())
    }

    /// `self = (a + b) mod n` (`cx_bn_mod_add`).
    pub fn mod_add(&self, a: &Bn, b: &Bn, n: &Bn) -> Result<(), CxError> {
        check_cx_ok!(cx_bn_mod_add(self.handle, a.handle, b.handle, n.handle));
        Ok(())
    }

    /// `self = (a - b) mod n` (`cx_bn_mod_sub`).
    pub fn mod_sub(&self, a: &Bn, b: &Bn, n: &Bn) -> Result<(), CxError> {
        check_cx_ok!(cx_bn_mod_sub(self.handle, a.handle, b.handle, n.handle));
        Ok(())
    }

    /// `self = (a * b) mod n` (`cx_bn_mod_mul`).
    pub fn mod_mul(&self, a: &Bn, b: &Bn, n: &Bn) -> Result<(), CxError> {
        check_cx_ok!(cx_bn_mod_mul(self.handle, a.handle, b.handle, n.handle));
        Ok(())
    }

    /// `self = d mod n` (`cx_bn_reduce`).
    pub fn reduce(&self, d: &Bn, n: &Bn) -> Result<(), CxError> {
        check_cx_ok!(cx_bn_reduce(self.handle, d.handle, n.handle));
        Ok(())
    }

    /// Modular square root of `a` modulo `n` into `self` (`cx_bn_mod_sqrt`).
    /// `sign` is passed through unchanged. See the SDK for how it selects
    /// the root.
    pub fn mod_sqrt(&self, a: &Bn, n: &Bn, sign: u32) -> Result<(), CxError> {
        check_cx_ok!(cx_bn_mod_sqrt(self.handle, a.handle, n.handle, sign));
        Ok(())
    }

    /// `self = a^e mod n` with a big-number exponent (`cx_bn_mod_pow_bn`).
    pub fn mod_pow_bn(&self, a: &Bn, e: &Bn, n: &Bn) -> Result<(), CxError> {
        check_cx_ok!(cx_bn_mod_pow_bn(self.handle, a.handle, e.handle, n.handle));
        Ok(())
    }

    /// `self = a^e mod n` with the exponent given as bytes (`cx_bn_mod_pow`).
    ///
    /// NOTE(review): `e.len()` is narrowed with `as u32`. This is lossless
    /// where `usize` is 32 bits, but it would truncate slices of 4 GiB or
    /// more on a 64-bit host.
    pub fn mod_pow(&self, a: &Bn, e: &[u8], n: &Bn) -> Result<(), CxError> {
        check_cx_ok!(cx_bn_mod_pow(
            self.handle,
            a.handle,
            e.as_ptr(),
            e.len() as u32,
            n.handle,
        ));
        Ok(())
    }

    /// Wraps `cx_bn_mod_pow2`, with the exponent given as bytes (same
    /// argument order as [`Bn::mod_pow`]).
    pub fn mod_pow2(&self, a: &Bn, e: &[u8], n: &Bn) -> Result<(), CxError> {
        check_cx_ok!(cx_bn_mod_pow2(
            self.handle,
            a.handle,
            e.as_ptr(),
            e.len() as u32,
            n.handle,
        ));
        Ok(())
    }

    /// `self = a^-1 mod n` for a prime `n` (`cx_bn_mod_invert_nprime`).
    pub fn mod_invert_nprime(&self, a: &Bn, n: &Bn) -> Result<(), CxError> {
        check_cx_ok!(cx_bn_mod_invert_nprime(self.handle, a.handle, n.handle));
        Ok(())
    }

    /// Inverse of the `u32` `a` modulo `n` into `self` (`cx_bn_mod_u32_invert`).
    pub fn mod_u32_invert(&self, a: u32, n: &Bn) -> Result<(), CxError> {
        check_cx_ok!(cx_bn_mod_u32_invert(self.handle, a, n.handle));
        Ok(())
    }

    /// Returns the primality verdict of `cx_bn_is_prime`.
    pub fn is_prime(&self) -> Result<bool, CxError> {
        let mut prime = false;
        check_cx_ok!(cx_bn_is_prime(self.handle, &mut prime));
        Ok(prime)
    }

    /// Advances `self` in place with `cx_bn_next_prime`.
    pub fn next_prime(&self) -> Result<(), CxError> {
        check_cx_ok!(cx_bn_next_prime(self.handle));
        Ok(())
    }

    /// Wraps `cx_bn_rng(self, n)`. It presumably draws a random value below
    /// `n` (confirm against the SDK).
    pub fn rng(&self, n: &Bn) -> Result<(), CxError> {
        check_cx_ok!(cx_bn_rng(self.handle, n.handle));
        Ok(())
    }

    /// Wraps `cx_bn_gf2_n_mul(self, a, b, n, h)`.
    pub fn gf2_n_mul(&self, a: &Bn, b: &Bn, n: &Bn, h: &Bn) -> Result<(), CxError> {
        check_cx_ok!(cx_bn_gf2_n_mul(
            self.handle,
            a.handle,
            b.handle,
            n.handle,
            h.handle,
        ));
        Ok(())
    }
}
impl Drop for Bn {
    /// Destroys the handle, then gives back this value's BN lock reference.
    fn drop(&mut self) {
        // Destroy before releasing, because the release may call
        // `cx_bn_unlock`. The status code is discarded because `drop` cannot
        // report it.
        let slot = &mut self.handle;
        unsafe { cx_bn_destroy(slot) };
        bn_release();
    }
}
/// Owned Montgomery context (`cx_bn_mont_ctx_t`) allocated with `cx_mont_alloc`.
///
/// Each live `MontCtx` holds one reference on the shared BN lock.
pub struct MontCtx {
    /// Raw context filled in by `cx_mont_alloc` and `cx_mont_init` / `cx_mont_init2`.
    inner: cx_bn_mont_ctx_t,
}
impl MontCtx {
pub fn alloc(length: usize) -> Result<Self, CxError> {
bn_retain(BN_DEFAULT_WORD_NBYTES)?;
let mut inner = cx_bn_mont_ctx_t::default();
let err = unsafe { cx_mont_alloc(&mut inner, length) };
if err != CX_OK {
bn_release();
return Err(err.into());
}
Ok(Self { inner })
}
pub fn init(&mut self, n: &Bn) -> Result<(), CxError> {
check_cx_ok!(cx_mont_init(&mut self.inner, n.handle));
Ok(())
}
pub fn init2(&mut self, n: &Bn, h: &Bn) -> Result<(), CxError> {
check_cx_ok!(cx_mont_init2(&mut self.inner, n.handle, h.handle));
Ok(())
}
pub fn to_montgomery(&self, x: &Bn, z: &Bn) -> Result<(), CxError> {
check_cx_ok!(cx_mont_to_montgomery(x.handle, z.handle, &self.inner));
Ok(())
}
pub fn from_montgomery(&self, z: &Bn, x: &Bn) -> Result<(), CxError> {
check_cx_ok!(cx_mont_from_montgomery(z.handle, x.handle, &self.inner));
Ok(())
}
pub fn mul(&self, r: &Bn, a: &Bn, b: &Bn) -> Result<(), CxError> {
check_cx_ok!(cx_mont_mul(r.handle, a.handle, b.handle, &self.inner));
Ok(())
}
pub fn pow(&self, r: &Bn, a: &Bn, e: &[u8]) -> Result<(), CxError> {
check_cx_ok!(cx_mont_pow(
r.handle,
a.handle,
e.as_ptr(),
e.len() as u32,
&self.inner,
));
Ok(())
}
pub fn pow_bn(&self, r: &Bn, a: &Bn, e: &Bn) -> Result<(), CxError> {
check_cx_ok!(cx_mont_pow_bn(r.handle, a.handle, e.handle, &self.inner));
Ok(())
}
pub fn invert_nprime(&self, r: &Bn, a: &Bn) -> Result<(), CxError> {
check_cx_ok!(cx_mont_invert_nprime(r.handle, a.handle, &self.inner));
Ok(())
}
pub fn as_raw(&self) -> &cx_bn_mont_ctx_t {
&self.inner
}
pub fn as_raw_mut(&mut self) -> &mut cx_bn_mont_ctx_t {
&mut self.inner
}
}
impl Drop for MontCtx {
fn drop(&mut self) {
bn_release();
}
}
/// Rounds `n` up to the next multiple of `CX_BN_WORD_ALIGNEMENT`.
///
/// The bit-mask form relies on that alignment being a power of two.
fn align_bn_size(n: usize) -> usize {
    let word = CX_BN_WORD_ALIGNEMENT as usize;
    let low_bits = word - 1;
    let bumped = n + word - 1;
    bumped & !low_bits
}
/// Maps the sign of a C comparison result to an [`Ordering`]: negative
/// gives `Less`, zero gives `Equal`, positive gives `Greater`.
fn int_to_ordering(diff: c_int) -> Ordering {
    diff.cmp(&0)
}
// Tests for the BN wrappers. They use the project's own test attribute
// (`testmacro::test_item`, imported as `test`) rather than the standard runner.
#[cfg(test)]
mod tests {
use super::*;
// Project assertion macro that shadows `assert_eq!`.
// NOTE(review): it presumably reports a mismatch through the test's `Result`
// instead of panicking. Confirm.
use crate::assert_eq_err as assert_eq;
use crate::testing::TestType;
// The `?` in the test bodies below implies `test_item` makes each test return
// a `Result`. NOTE(review): confirm.
use testmacro::test_item as test;
// Logs a `CxError` (formatted through `crate::testing::to_hex`) and maps it to
// `()`, so `.map_err(err_to_unit)?` turns any BN failure into an `Err(())`.
fn err_to_unit(e: CxError) {
let ec = crate::testing::to_hex(e.into());
crate::log::info!(
"BN error: \x1b[1;33m{}\x1b[0m",
core::str::from_utf8(&ec).unwrap()
);
}
// Round-trips a u32 through `set_u32` / `get_u32`.
#[test]
fn bn_alloc_set_get_u32() {
let a = Bn::alloc(32).map_err(err_to_unit)?;
a.set_u32(12345).map_err(err_to_unit)?;
assert_eq!(a.get_u32().map_err(err_to_unit)?, 12345u32);
}
// 100 + 42 and 100 - 42, both written into the same result number `r`.
#[test]
fn bn_add_sub() {
let a = Bn::alloc(32).map_err(err_to_unit)?;
a.set_u32(100).map_err(err_to_unit)?;
let b = Bn::alloc(32).map_err(err_to_unit)?;
b.set_u32(42).map_err(err_to_unit)?;
let r = Bn::alloc(32).map_err(err_to_unit)?;
r.add(&a, &b).map_err(err_to_unit)?;
assert_eq!(r.get_u32().map_err(err_to_unit)?, 142u32);
r.sub(&a, &b).map_err(err_to_unit)?;
assert_eq!(r.get_u32().map_err(err_to_unit)?, 58u32);
}
// Checks the `Ordering` produced by `cmp_bn` in both directions and by `cmp_u32`.
#[test]
fn bn_cmp() {
let a = Bn::alloc(32).map_err(err_to_unit)?;
a.set_u32(10).map_err(err_to_unit)?;
let b = Bn::alloc(32).map_err(err_to_unit)?;
b.set_u32(20).map_err(err_to_unit)?;
assert_eq!(a.cmp_bn(&b).map_err(err_to_unit)?, Ordering::Less);
assert_eq!(b.cmp_bn(&a).map_err(err_to_unit)?, Ordering::Greater);
assert_eq!(a.cmp_u32(10).map_err(err_to_unit)?, Ordering::Equal);
}
// In-place shifts: 0b1010 << 1 == 0b10100, then >> 2 == 0b101.
#[test]
fn bn_shift_bits() {
let a = Bn::alloc(32).map_err(err_to_unit)?;
a.set_u32(0b1010).map_err(err_to_unit)?;
a.shl(1).map_err(err_to_unit)?;
assert_eq!(a.get_u32().map_err(err_to_unit)?, 0b10100u32);
a.shr(2).map_err(err_to_unit)?;
assert_eq!(a.get_u32().map_err(err_to_unit)?, 0b101u32);
}
}