#if HAVE_CONFIG_H
# include "config.h"
#endif
#include <assert.h>
#include "rsa.h"
#include "rsa-internal.h"
#include "gmp-glue.h"
#if !NETTLE_USE_MINI_GMP
/* Larger of a and b. This is a function-like macro, so each argument may
   be evaluated more than once: pass only side-effect-free expressions. */
#define MAX(a, b) ((a) > (b) ? (a) : (b))
/* Scratch size (in limbs) for sec_mul. The larger limb count is always
   handed to mpn_sec_mul_itch as its first argument, so callers may give
   the two operand sizes in either order. */
static mp_size_t
sec_mul_itch (mp_size_t an, mp_size_t bn)
{
  mp_size_t longer = an;
  mp_size_t shorter = bn;

  if (an < bn)
    {
      longer = bn;
      shorter = an;
    }
  return mpn_sec_mul_itch (longer, shorter);
}
/* Multiplies {ap, an} by {bp, bn} into rp using mpn_sec_mul. The operands
   are swapped when needed so that the longer one is always passed first;
   callers therefore need not order them. scratch is forwarded unchanged
   (see sec_mul_itch for its size). */
static void
sec_mul (mp_limb_t *rp,
	 const mp_limb_t *ap, mp_size_t an,
	 const mp_limb_t *bp, mp_size_t bn, mp_limb_t *scratch)
{
  if (an < bn)
    {
      /* b is the longer operand: pass it first. */
      mpn_sec_mul (rp, bp, bn, ap, an, scratch);
      return;
    }
  mpn_sec_mul (rp, ap, an, bp, bn, scratch);
}
/* Scratch size (in limbs) for sec_mod_mul: the multiplication and the
   following reduction reuse the same scratch area, so the larger of the
   two requirements is returned. */
static mp_size_t
sec_mod_mul_itch (mp_size_t an, mp_size_t bn, mp_size_t mn)
{
  mp_size_t need = sec_mul_itch (an, bn);
  mp_size_t reduce_need = mpn_sec_div_r_itch (an + bn, mn);

  if (reduce_need > need)
    need = reduce_need;
  return need;
}
/* Forms the product of {ap, an} and {bp, bn} in rp and then reduces it in
   place modulo {mp, mn}. rp is used as an an + bn limb area, and
   an + bn >= mn is required (asserted). scratch is shared by both steps;
   see sec_mod_mul_itch for its size. */
static void
sec_mod_mul (mp_limb_t *rp,
	     const mp_limb_t *ap, mp_size_t an,
	     const mp_limb_t *bp, mp_size_t bn,
	     const mp_limb_t *mp, mp_size_t mn,
	     mp_limb_t *scratch)
{
  mp_size_t prod_n = an + bn;

  assert (prod_n >= mn);

  sec_mul (rp, ap, an, bp, bn, scratch);
  mpn_sec_div_r (rp, prod_n, mp, mn, scratch);
}
/* Scratch size (in limbs) for sec_powm. Two phases share the area:
     - reduction: bn limbs for the copy of the base, plus the
       mpn_sec_div_r scratch;
     - exponentiation: mn limbs for the reduced base, plus the
       mpn_sec_powm scratch (exponent size given in bits).
   The larger of the two is returned. */
static mp_size_t
sec_powm_itch (mp_size_t bn, mp_size_t en, mp_size_t mn)
{
  mp_size_t reduce_need = mpn_sec_div_r_itch (bn, mn) + bn;
  mp_size_t exp_need = mpn_sec_powm_itch (mn, en * GMP_NUMB_BITS, mn) + mn;

  return (exp_need > reduce_need) ? exp_need : reduce_need;
}
/* Modular exponentiation with an initial reduction of the base:
   {bp, bn} is copied to the start of scratch, reduced there modulo
   {mp, mn}, and the mn-limb result is then raised to the exponent
   {ep, en} (passed to mpn_sec_powm as en * GMP_NUMB_BITS bits) modulo
   {mp, mn}, with the output going to rp.

   Requires bn >= mn and en <= mn (both asserted). scratch must be sized
   per sec_powm_itch. */
static void
sec_powm (mp_limb_t *rp,
	  const mp_limb_t *bp, mp_size_t bn,
	  const mp_limb_t *ep, mp_size_t en,
	  const mp_limb_t *mp, mp_size_t mn, mp_limb_t *scratch)
{
  mp_limb_t *reduced;

  assert (bn >= mn);
  assert (en <= mn);

  /* The input base is const, so reduce a private copy; the division
     scratch starts right after the bn copied limbs. */
  reduced = scratch;
  mpn_copyi (reduced, bp, bn);
  mpn_sec_div_r (reduced, bn, mp, mn, reduced + bn);

  /* Only the low mn limbs are used as the base from here on, so the
     exponentiation scratch may start at offset mn. */
  mpn_sec_powm (rp, reduced, mn, ep, en * GMP_NUMB_BITS, mp, mn,
		reduced + mn);
}
/* Returns the number of scratch limbs _rsa_sec_compute_root needs for
   the given key. The phases of that function reuse one work area, so the
   result is the largest per-phase requirement, plus pn + qn limbs for
   the r_mod_p and r_mod_q temporaries kept at the start of scratch. */
mp_size_t
_rsa_sec_compute_root_itch (const struct rsa_private_key *key)
{
  const mp_size_t nn = NETTLE_OCTET_SIZE_TO_LIMB_SIZE (key->size);
  const mp_size_t pn = mpz_size (key->p);
  const mp_size_t qn = mpz_size (key->q);
  const mp_size_t an = mpz_size (key->a);
  const mp_size_t bn = mpz_size (key->b);
  const mp_size_t cn = mpz_size (key->c);
  const mp_size_t wider = (pn > qn) ? pn : qn;

  mp_size_t peak;
  mp_size_t phase;
  mp_size_t add_1_need;

  /* Exponentiation modulo p. */
  peak = sec_powm_itch (nn, an, pn);

  /* Exponentiation modulo q. */
  phase = sec_powm_itch (nn, bn, qn);
  if (phase > peak)
    peak = phase;

  /* Multiplication by c modulo p: the product area (cn plus the wider
     of pn and qn limbs) followed by the mod-mul scratch. */
  phase = cn + wider + sec_mod_mul_itch (wider, cn, pn);
  if (phase > peak)
    peak = phase;

  /* Final recombination: pn + qn limbs for the product with q, followed
     by scratch for whichever of the multiply or the carry propagation
     needs more. */
  phase = sec_mul_itch (qn, pn);
  add_1_need = mpn_sec_add_1_itch (nn - qn);
  if (add_1_need > phase)
    phase = add_1_need;
  phase += pn + qn;
  if (phase > peak)
    peak = phase;

  return pn + qn + peak;
}
/* Computes the nn-limb result rp from the nn-limb input mp using the two
   prime moduli of the key: the input is exponentiated separately modulo
   key->p (exponent key->a) and modulo key->q (exponent key->b), and the
   two residues are recombined using key->c as
       x = r_mod_q + q * ((r_mod_p * c - r_mod_q * c) mod p).
   NOTE(review): this is the shape of an RSA CRT private-key operation,
   presumably with a = d mod (p-1), b = d mod (q-1), c = q^-1 mod p;
   confirm against the struct rsa_private_key documentation.

   key:     private key; p, q, a, b, c and size are read.
   rp:      output, nn limbs are written.
   mp:      input, nn limbs are read. (Note: in this function mp is the
            input value, not a modulus as in the static helpers.)
   scratch: work area, presumably sized by _rsa_sec_compute_root_itch.

   Only mpn_sec_* style primitives, mpn_add_n/mpn_sub_n and cnd_add_n are
   used on the data; branching elsewhere in this file depends only on
   limb counts. */
void
_rsa_sec_compute_root (const struct rsa_private_key *key,
mp_limb_t *rp, const mp_limb_t *mp,
mp_limb_t *scratch)
{
mp_size_t nn = NETTLE_OCTET_SIZE_TO_LIMB_SIZE (key->size);
const mp_limb_t *pp = mpz_limbs_read (key->p);
const mp_limb_t *qp = mpz_limbs_read (key->q);
mp_size_t pn = mpz_size (key->p);
mp_size_t qn = mpz_size (key->q);
mp_size_t an = mpz_size (key->a);
mp_size_t bn = mpz_size (key->b);
mp_size_t cn = mpz_size (key->c);
/* Scratch layout: [ r_mod_p : pn ][ r_mod_q : qn ][ scratch_out ... ]. */
mp_limb_t *r_mod_p = scratch;
mp_limb_t *r_mod_q = scratch + pn;
mp_limb_t *scratch_out = r_mod_q + qn;
mp_limb_t cy;
/* Size preconditions on the key; they also satisfy the bn >= mn and
   en <= mn requirements asserted by sec_powm below. */
assert (pn <= nn);
assert (qn <= nn);
assert (an <= pn);
assert (bn <= qn);
assert (cn <= pn);
/* r_mod_p = (m mod p)^a mod p */
sec_powm (r_mod_p, mp, nn, mpz_limbs_read (key->a), an, pp, pn, scratch_out);
/* r_mod_q = (m mod q)^b mod q */
sec_powm (r_mod_q, mp, nn, mpz_limbs_read (key->b), bn, qp, qn, scratch_out);
/* r_mod_p <- r_mod_p * c mod p. The pn + cn limb product is built at
   scratch_out, with the helper's scratch placed right after it; the low
   pn limbs are then copied back as the reduced value. */
sec_mod_mul (scratch_out, r_mod_p, pn, mpz_limbs_read (key->c), cn, pp, pn,
scratch_out + cn + pn);
mpn_copyi (r_mod_p, scratch_out, pn);
/* scratch_out <- r_mod_q * c mod p. The modulus here is p (not q), since
   the difference below is taken modulo p. */
sec_mod_mul (scratch_out, r_mod_q, qn, mpz_limbs_read (key->c), cn, pp, pn,
scratch_out + cn + qn);
/* r_mod_p <- (r_mod_p - scratch_out) mod p: subtract, then let cnd_add_n
   add p back based on the borrow (cnd_add_n comes from gmp-glue.h;
   presumably it adds only when cy is nonzero, without branching -
   TODO confirm). */
cy = mpn_sub_n (r_mod_p, r_mod_p, scratch_out, pn);
cnd_add_n (cy, r_mod_p, pp, pn);
/* scratch_out <- q * r_mod_p (pn + qn limbs); multiply scratch follows. */
sec_mul (scratch_out, qp, qn, r_mod_p, pn, scratch_out + pn + qn);
/* rp <- r_mod_q + q * r_mod_p: add the low qn limbs, then propagate the
   carry through the remaining nn - qn limbs of the product. */
cy = mpn_add_n (rp, scratch_out, r_mod_q, qn);
mpn_sec_add_1 (rp + qn, scratch_out + qn, nn - qn, cy, scratch_out + pn + qn);
}
#endif