// Written for Graviola by Joe Birr-Pixton, 2024.
// SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT-0
use crate::error::Error;
use crate::low;
/// An RSA public key `(n, e)` together with values precomputed by
/// `RsaPublicKey::new` for Montgomery arithmetic modulo `n`.
#[derive(Clone, Debug)]
pub(crate) struct RsaPublicKey {
/// The public modulus.
pub(crate) n: RsaPosInt,
/// The public exponent.
pub(crate) e: u32,
/// `M^2 mod n`, as returned by `n.montifier()`; used to bring values into
/// the Montgomery domain.
montifier: RsaPosInt,
/// The value one in the Montgomery domain (ie, just `M`).
one: RsaPosInt,
/// Chosen such that `n * n0 == -1 (mod 2^64)`.
n0: u64,
}
impl RsaPublicKey {
    /// Builds a public key from the modulus `n` and public exponent `e`,
    /// precomputing the Montgomery constants that `public_op` relies on.
    ///
    /// # Errors
    ///
    /// Returns `Error::OutOfRange` if `n` is even, if the byte length of `n`
    /// is outside `MIN_PUBLIC_MODULUS_BYTES..=MAX_PUBLIC_MODULUS_BYTES`, or
    /// if `e` is zero.
    pub(crate) fn new(n: RsaPosInt, e: u32) -> Result<Self, Error> {
        let modulus_bytes = n.len_bytes();
        let permitted_sizes = MIN_PUBLIC_MODULUS_BYTES..=MAX_PUBLIC_MODULUS_BYTES;

        // reject unusable parameters one at a time
        if n.is_even() {
            return Err(Error::OutOfRange);
        }
        if !permitted_sizes.contains(&modulus_bytes) {
            return Err(Error::OutOfRange);
        }
        if e == 0 {
            return Err(Error::OutOfRange);
        }

        // M^2 mod n
        let montifier = n.montifier();

        // n0 satisfies n * n0 == -1 (mod 2^64)
        let n0 = n.mont_neg_inverse();

        // 1 * M^2 * M^-1 mod n, ie, just M
        let one = n.fixed_one().mont_mul(&montifier, &n, n0);

        let key = Self {
            n,
            e,
            montifier,
            one,
            n0,
        };
        Ok(key)
    }

    /// Returns the length of the modulus in bytes.
    pub(crate) fn modulus_len_bytes(&self) -> usize {
        let modulus = &self.n;
        modulus.len_bytes()
    }

    /// Computes `m = c ** e mod n`.
    ///
    /// # Errors
    ///
    /// Returns `Error::OutOfRange` unless `c` is less than the modulus.
    pub(crate) fn public_op(&self, mut c: RsaPosInt) -> Result<RsaPosInt, Error> {
        if !c.less_than(&self.n) {
            return Err(Error::OutOfRange);
        }

        c.expand(&self.n);

        // move the input into the montgomery domain:
        // base = c * M^2 * M^-1 mod n
        let base = c.to_montgomery(&self.montifier, &self.n);

        // index of the most significant set bit of e; `new` rejects e == 0,
        // which `ilog2` requires
        let top_bit = self.e.ilog2();

        // left-to-right square-and-multiply, starting from 1 in the
        // montgomery domain (that is, M)
        let mut accum = self.one.clone();
        for bit in (0..=top_bit).rev() {
            // squaring the multiplicative identity is pointless, so the
            // first (top) iteration passes it through unchanged
            let squared = if bit == top_bit {
                accum
            } else {
                accum.mont_sqr(&self.n, self.n0)
            };

            let bit_is_set = (self.e >> bit) & 1 == 1;
            accum = if bit_is_set {
                squared.mont_mul(&base, &self.n, self.n0)
            } else {
                squared
            };
        }

        // leave the montgomery domain before returning
        Ok(accum.from_montgomery(&self.n))
    }
}
/// Largest public modulus size, in bits.
const MAX_PUBLIC_MODULUS_BITS: usize = 8192;
/// `MAX_PUBLIC_MODULUS_BITS` expressed as a count of 64-bit words; this is the
/// const parameter of `RsaPosInt`.
const MAX_PUBLIC_MODULUS_WORDS: usize = MAX_PUBLIC_MODULUS_BITS / 64;
/// `MAX_PUBLIC_MODULUS_BITS` expressed in bytes; upper bound on the modulus
/// length accepted by `RsaPublicKey::new`.
pub(crate) const MAX_PUBLIC_MODULUS_BYTES: usize = MAX_PUBLIC_MODULUS_BITS / 8;
/// Smallest public modulus size, in bits.
const MIN_PUBLIC_MODULUS_BITS: usize = 2048;
/// `MIN_PUBLIC_MODULUS_BITS` expressed in bytes; lower bound on the modulus
/// length accepted by `RsaPublicKey::new`.
const MIN_PUBLIC_MODULUS_BYTES: usize = MIN_PUBLIC_MODULUS_BITS / 8;
/// Positive integer type used for public-key values in this module.
type RsaPosInt = low::PosInt<MAX_PUBLIC_MODULUS_WORDS>;
#[cfg(test)]
mod tests {
use super::*;
// Smoke test: builds a key from a fixed modulus with e = 0x10001, applies
// `public_op` to a fixed input and prints the result. It only checks that
// every step succeeds (via the `unwrap`s); the output value is not asserted.
#[test]
fn smoke() {
let n = RsaPosInt::from_bytes(b"\xe4\x46\x29\x68\xe3\xe2\x9c\xe7\x3b\xe8\xac\xda\xf9\xd5\x92\xbe\x99\x04\x36\x3a\xef\x33\x99\xf7\x93\xb9\x17\x13\x42\x9c\xea\xf9\x63\xa1\xe5\xc6\xbb\x57\x71\x4c\xc1\x46\x01\xec\xac\x5a\xe5\xb8\x95\x43\xaa\xfa\x68\x3d\x50\x73\x87\xfc\x83\x04\x66\x1f\xab\x1e\x0c\x6e\xf0\x32\x50\x63\x21\xc6\x74\xec\xe4\xf6\x7a\xb2\x94\xbe\xae\x81\x66\x3e\x1a\xa6\x98\xcd\x5b\x78\x2c\x7b\xf4\xdf\x39\x76\xf1\x5e\x88\xda\xa2\xe0\xe8\x2e\xb5\x83\xdb\x1b\x56\xe4\x6b\x6f\x4e\x3c\xde\x9f\x00\x7e\x3b\x8f\x8f\x5c\xb8\x55\x04\x22\xea\x1f\x6d\x92\xe1\x08\x76\x2a\x68\xc5\x35\xd2\x37\x9a\x54\xdc\xf7\x4f\x19\x38\xdb\x77\x02\xd9\xf9\x72\x4d\x7f\x98\xa5\xe3\x7c\xef\x06\xc7\xb0\x3f\x58\xbc\x9d\x38\x72\x8a\xac\x18\x03\xb9\xee\x60\xe7\x6e\x18\xf6\x90\x87\xb3\x8a\x5f\xbb\x95\xd0\x99\x09\x5b\x2c\xda\x4b\xd7\x88\xaa\x2a\x05\x07\x38\xae\xf6\xa1\x6e\x93\x00\x1f\xc3\x6b\xb4\xdc\x6b\xc1\xc6\x06\x1e\x34\x9c\x5b\x2b\xd6\x50\x5d\x64\xd9\x05\xdb\x95\xa0\xe1\x2c\xb3\xb1\x5b\xa4\x90\xa2\xa7\xcc\xbf\x10\xaf\x12\xe3\x16\xb3\xde\xc5\x4f\xb1\xb6\x63\x68\xd8\xd9\xb1").unwrap();
let c = RsaPosInt::from_bytes(b"\x00\x0b\x36\xb5\xc6\xd9\x32\xd0\x18\xa6\x31\x99\x82\xf6\xba\x83\xd5\x1b\xb6\xdb\x84\x99\x87\xc0\xe9\x8f\x06\x63\xac\x8d\xe4\x43\xb0\x45\xd3\x01\x3e\x03\xba\xed\xd0\xa9\xc6\x49\x08\x63\x22\x29\x0f\x1f\xf3\x25\xef\xfe\x65\xff\x27\xf2\x5d\xc6\xe7\x79\xe9\x5f\xd2\xf5\x09\x0c\x28\xfe\xe5\x6c\x75\x24\x0a\x79\xe4\xf6\x9e\x2b\x5b\x52\x71\xb6\x22\xd8\x08\x97\xea\xbd\x4b\x06\x53\xa6\x2e\xb9\x26\x91\x0f\xc7\x34\xa4\x5d\x3b\x9d\x23\xc0\x10\xf8\x82\xa7\xbb\x8c\x50\x35\x7d\x44\x9d\x14\x00\xcf\x5a\xe0\x92\xeb\x83\x60\x9a\x48\xbc\xac\xe0\x20\xd7\x44\xc9\xe7\xf7\x66\x25\x04\x0e\xa9\x20\x9c\xb6\x23\x02\x8f\x2b\xa3\x86\xfa\x23\x4e\xdd\xe9\xf8\xc8\xa4\x63\x65\x4c\x9d\x52\x24\x4a\x0d\x0a\xd6\x2d\x94\x95\x64\x45\xaa\xf9\xf5\x26\x8b\xf7\x21\xf7\x6a\xf9\x19\x46\xbc\x2e\xeb\x2a\xaf\x0f\x31\x2f\x27\x86\x4e\xd4\x2e\xf7\xbc\x0f\x14\xce\x75\xef\x93\xad\x3a\x84\x3a\xb3\x29\x6f\xe9\xd7\x33\xd8\x6c\xbe\x20\x11\xf3\x92\x3c\x16\x78\x0b\xc4\x79\xaa\x8d\xeb\xb1\xd1\xe2\xda\xf3\xd7\x43\x92\x72\x8c\x81\x52\x3d\xf1\xc9\x7e\x7c\xfd\x0e\xb2\x02\x84\x51").unwrap();
let k = RsaPublicKey::new(n, 0x10001).unwrap();
let m = k.public_op(c).unwrap();
// print the result in its internal form
println!("m = {m:016x?}");
// serialise the result into a 256-byte buffer and print the bytes
let mut mb = [0; 256];
let mb = m.to_bytes(&mut mb).unwrap();
println!("m = {mb:02x?}");
}
}