#include "seal/util/common.h"
#include "seal/util/numth.h"
#include "seal/util/polyarithsmallmod.h"
#include "seal/util/rns.h"
#include "seal/util/uintarithmod.h"
#include "seal/util/uintarithsmallmod.h"
#include <algorithm>
using namespace std;
namespace seal
{
namespace util
{
// Builds an RNS base from the given moduli, backed by memory from `pool`.
// Throws invalid_argument if the list is empty, the pool is uninitialized, any modulus is zero, two moduli
// share a common factor, or the CRT precomputation (initialize()) fails.
RNSBase::RNSBase(const vector<Modulus> &rnsbase, MemoryPoolHandle pool)
    : pool_(move(pool)), size_(rnsbase.size())
{
    // An RNS base needs at least one modulus and a usable memory pool.
    if (!size_)
    {
        throw invalid_argument("rnsbase cannot be empty");
    }
    if (!pool_)
    {
        throw invalid_argument("pool is uninitialized");
    }

    // Validate the moduli in order: each one must be nonzero and coprime to every modulus that precedes it.
    for (auto current = rnsbase.cbegin(); current != rnsbase.cend(); ++current)
    {
        if (current->is_zero())
        {
            throw invalid_argument("rnsbase is invalid");
        }
        const bool shares_factor = any_of(rnsbase.cbegin(), current, [&current](const Modulus &earlier) {
            return !are_coprime(current->value(), earlier.value());
        });
        if (shares_factor)
        {
            throw invalid_argument("rnsbase is invalid");
        }
    }

    // The moduli are acceptable: keep a pool-backed copy and derive the CRT constants from it.
    base_ = allocate<Modulus>(size_, pool_);
    copy_n(rnsbase.cbegin(), size_, base_.get());
    if (!initialize())
    {
        throw invalid_argument("rnsbase is invalid");
    }
}
// Copy constructor that places a deep copy of `copy` in memory drawn from `pool`.
// The moduli and every precomputed CRT table (base product, punctured products, and their inverses) are
// duplicated, so the new object shares no storage with `copy`.
// Throws invalid_argument if `pool` is uninitialized.
// NOTE: the parameter was previously written as "const RNSBase ©", which is a mis-encoded "&copy" (an HTML
// entity) and does not compile. The reference declarator is restored here.
RNSBase::RNSBase(const RNSBase &copy, MemoryPoolHandle pool) : pool_(move(pool)), size_(copy.size_)
{
    if (!pool_)
    {
        throw invalid_argument("pool is uninitialized");
    }

    // Moduli.
    base_ = allocate<Modulus>(size_, pool_);
    copy_n(copy.base_.get(), size_, base_.get());

    // Product of all moduli (size_ words).
    base_prod_ = allocate_uint(size_, pool_);
    set_uint(copy.base_prod_.get(), size_, base_prod_.get());

    // Punctured products: size_ rows of size_ words each.
    punctured_prod_array_ = allocate_uint(size_ * size_, pool_);
    set_uint(copy.punctured_prod_array_.get(), size_ * size_, punctured_prod_array_.get());

    // Inverses of the punctured products modulo their own moduli.
    inv_punctured_prod_mod_base_array_ = allocate<MultiplyUIntModOperand>(size_, pool_);
    copy_n(copy.inv_punctured_prod_mod_base_array_.get(), size_, inv_punctured_prod_mod_base_array_.get());
}
bool RNSBase::contains(const Modulus &value) const noexcept
{
bool result = false;
SEAL_ITERATE(iter(base_), size_, [&](auto &I) { result = result || (I == value); });
return result;
}
bool RNSBase::is_subbase_of(const RNSBase &superbase) const noexcept
{
bool result = true;
SEAL_ITERATE(iter(base_), size_, [&](auto &I) { result = result && superbase.contains(I); });
return result;
}
// Returns a new base consisting of this base's moduli followed by `value`.
// Throws invalid_argument if `value` is zero. Throws logic_error if `value` shares a factor with an existing
// modulus, or if the CRT precomputation of the extended base fails.
RNSBase RNSBase::extend(const Modulus &value) const
{
    if (value.is_zero())
    {
        throw invalid_argument("value cannot be zero");
    }

    // The new modulus has to be coprime to every modulus already present.
    const Modulus *existing = base_.get();
    for (size_t idx = 0; idx != size_; ++idx)
    {
        if (!are_coprime(existing[idx].value(), value.value()))
        {
            throw logic_error("cannot extend by given value");
        }
    }

    // Assemble the extended base: the current moduli, then the new one in the final slot.
    RNSBase extended(pool_);
    extended.size_ = add_safe(size_, size_t(1));
    extended.base_ = allocate<Modulus>(extended.size_, extended.pool_);
    Modulus *slot = copy_n(base_.get(), size_, extended.base_.get());
    *slot = value;

    if (!extended.initialize())
    {
        throw logic_error("cannot extend by given value");
    }
    return extended;
}
// Returns a new base consisting of this base's moduli followed by all moduli of `other`.
// Throws invalid_argument if any modulus of `other` shares a factor with a modulus of this base. Throws
// logic_error if the CRT precomputation of the combined base fails.
RNSBase RNSBase::extend(const RNSBase &other) const
{
    // Every incoming modulus must be coprime to every modulus already in this base.
    const Modulus *own_first = base_.get();
    const Modulus *own_last = own_first + size_;
    for (size_t i = 0; i < other.size_; i++)
    {
        const uint64_t incoming = other[i].value();
        const bool conflict = any_of(own_first, own_last, [incoming](const Modulus &own) {
            return !are_coprime(incoming, own.value());
        });
        if (conflict)
        {
            throw invalid_argument("rnsbase is invalid");
        }
    }

    // Concatenate: this base first, then `other`.
    RNSBase combined(pool_);
    combined.size_ = add_safe(size_, other.size_);
    combined.base_ = allocate<Modulus>(combined.size_, combined.pool_);
    Modulus *tail = copy_n(base_.get(), size_, combined.base_.get());
    copy_n(other.base_.get(), other.size_, tail);

    if (!combined.initialize())
    {
        throw logic_error("cannot extend by given base");
    }
    return combined;
}
// Returns a new base containing every modulus of this base except the last one.
// Throws logic_error if this base has only one modulus, or if the CRT precomputation of the smaller base fails.
RNSBase RNSBase::drop() const
{
    if (size_ == 1)
    {
        throw logic_error("cannot drop from base of size 1");
    }

    // Copy all but the last modulus, preserving order.
    RNSBase newbase(pool_);
    newbase.size_ = size_ - 1;
    newbase.base_ = allocate<Modulus>(newbase.size_, newbase.pool_);
    copy_n(base_.get(), size_ - 1, newbase.base_.get());

    // Recompute the CRT data. The return value used to be ignored. Both extend() overloads treat a false result
    // as an error, so this does the same instead of returning an object with unusable precomputations.
    if (!newbase.initialize())
    {
        throw logic_error("cannot drop from base");
    }
    return newbase;
}
// Returns a new base containing every modulus of this base except `value`, in the original order.
// Throws logic_error if this base has only one modulus, if `value` is not in the base, or if the CRT
// precomputation of the smaller base fails.
RNSBase RNSBase::drop(const Modulus &value) const
{
    if (size_ == 1)
    {
        throw logic_error("cannot drop from base of size 1");
    }
    if (!contains(value))
    {
        throw logic_error("base does not contain value");
    }

    // Copy every modulus that differs from `value`. This assumes the moduli are distinct (pairwise-coprime
    // moduli other than 1 always are), so exactly size_ - 1 elements are copied.
    RNSBase newbase(pool_);
    newbase.size_ = size_ - 1;
    newbase.base_ = allocate<Modulus>(newbase.size_, newbase.pool_);
    size_t source_index = 0;
    size_t dest_index = 0;
    while (dest_index < size_ - 1)
    {
        if (base_[source_index] != value)
        {
            newbase.base_[dest_index] = base_[source_index];
            dest_index++;
        }
        source_index++;
    }

    // Recompute the CRT data. The return value used to be ignored. Both extend() overloads treat a false result
    // as an error, so this does the same instead of returning an object with unusable precomputations.
    if (!newbase.initialize())
    {
        throw logic_error("cannot drop given value");
    }
    return newbase;
}
// Computes the CRT precomputations for the moduli currently stored in base_[0..size_):
//   base_prod_                         : product of all moduli, size_ 64-bit words;
//   punctured_prod_array_              : size_ rows of size_ words. Row i is filled by
//                                        multiply_many_uint64_except(..., i, ...), presumably the product of
//                                        every modulus except base_[i];
//   inv_punctured_prod_mod_base_array_ : entry i is the inverse of row i modulo base_[i].
// Returns false if size_ * size_ does not fit in size_t, or if some row is not invertible modulo its own
// modulus. Returns true otherwise. Previously held buffers are replaced by fresh allocations from pool_.
bool RNSBase::initialize()
{
// The punctured-product table holds size_ * size_ words; refuse sizes for which that product overflows.
if (!product_fits_in(size_, size_))
{
return false;
}
base_prod_ = allocate_uint(size_, pool_);
punctured_prod_array_ = allocate_zero_uint(size_ * size_, pool_);
inv_punctured_prod_mod_base_array_ = allocate<MultiplyUIntModOperand>(size_, pool_);
if (size_ > 1)
{
// Flatten the moduli into a plain uint64_t array for the multi-precision product helper.
auto rnsbase_values = allocate<uint64_t>(size_, pool_);
SEAL_ITERATE(iter(base_, rnsbase_values), size_, [&](auto I) { get<1>(I) = get<0>(I).value(); });
// Walk the table with stride size_. Row i receives the product of the moduli excluding index i.
StrideIter<uint64_t *> punctured_prod(punctured_prod_array_.get(), size_);
SEAL_ITERATE(iter(punctured_prod, size_t(0)), size_, [&](auto I) {
multiply_many_uint64_except(rnsbase_values.get(), size_, get<1>(I), get<0>(I).ptr(), pool_);
});
// Full product = (row 0, the product of all moduli except base_[0]) * base_[0].
auto temp_mpi(allocate_uint(size_, pool_));
multiply_uint(punctured_prod_array_.get(), size_, base_[0].value(), size_, temp_mpi.get());
set_uint(temp_mpi.get(), size_, base_prod_.get());
// Invert each punctured product modulo its own modulus. Because '&&' short-circuits, once one inversion
// fails the later entries are stored without being inverted. The caller must then discard the whole base
// (every caller in this file either returns/throws on false).
bool invertible = true;
SEAL_ITERATE(iter(punctured_prod, base_, inv_punctured_prod_mod_base_array_), size_, [&](auto I) {
uint64_t temp = modulo_uint(get<0>(I), size_, get<1>(I));
invertible = invertible && try_invert_uint_mod(temp, get<1>(I), temp);
get<2>(I).set(temp, get<1>(I));
});
return invertible;
}
// Single-modulus base: the product is the modulus itself, the (empty) punctured product is 1, and its
// inverse is 1.
base_prod_[0] = base_[0].value();
punctured_prod_array_[0] = 1;
inv_punctured_prod_mod_base_array_[0].set(1, base_[0]);
return true;
}
void RNSBase::decompose(uint64_t *value, MemoryPoolHandle pool) const
{
if (!value)
{
throw invalid_argument("value cannot be null");
}
if (!pool)
{
throw invalid_argument("pool is uninitialized");
}
if (size_ > 1)
{
auto value_copy(allocate_uint(size_, pool));
set_uint(value, size_, value_copy.get());
SEAL_ITERATE(iter(value, base_), size_, [&](auto I) {
get<0>(I) = modulo_uint(value_copy.get(), size_, get<1>(I));
});
}
}
// Converts, in place, `count` multi-precision integers into RNS form. On input, integer k occupies
// value[k * size_ .. k * size_ + size_ - 1]. On output, value[i * count + k] holds integer k reduced modulo
// base_[i]. This output layout assumes RNSIter(value, count) advances by `count` words per RNS component;
// TODO confirm against the iterator definitions.
// Throws invalid_argument for a null value or an uninitialized pool, and logic_error if count * size_
// overflows. For a single-modulus base nothing is done.
void RNSBase::decompose_array(uint64_t *value, size_t count, MemoryPoolHandle pool) const
{
if (!value)
{
throw invalid_argument("value cannot be null");
}
if (!pool)
{
throw invalid_argument("pool is uninitialized");
}
if (size_ > 1)
{
if (!product_fits_in(count, size_))
{
throw logic_error("invalid parameters");
}
// Snapshot the input. value_copy steps by size_ words, i.e. one multi-precision integer per step.
SEAL_ALLOCATE_GET_STRIDE_ITER(value_copy, uint64_t, count, size_, pool);
set_uint(value, count * size_, value_copy);
// Note the swapped roles: value_copy steps by size_, while value_out steps by count.
RNSIter value_out(value, count);
// Outer loop over the base moduli, inner loop over the count integers.
SEAL_ITERATE(iter(base_, value_out), size_, [&](auto I) {
SEAL_ITERATE(iter(get<1>(I), value_copy), count, [&](auto J) {
get<0>(J) = modulo_uint(get<1>(J), size_, get<0>(I));
});
});
}
}
// In-place inverse of decompose(). On input `value` holds size_ residues (value[i] modulo base_[i]). It is
// overwritten with the size_-word integer
//   sum_i [value[i] * inv_punctured_prod_i]_{base_[i]} * punctured_prod_i   (accumulated mod base_prod_).
// Throws invalid_argument if `value` is null or `pool` is uninitialized. Single-modulus bases are left untouched.
void RNSBase::compose(uint64_t *value, MemoryPoolHandle pool) const
{
if (!value)
{
throw invalid_argument("value cannot be null");
}
if (!pool)
{
throw invalid_argument("pool is uninitialized");
}
if (size_ > 1)
{
// Keep the residues aside and clear value so it can serve as the accumulator.
auto temp_value(allocate_uint(size_, pool));
set_uint(value, size_, temp_value.get());
set_zero_uint(size_, value);
// Row i (stride size_) of the punctured-product table.
StrideIter<uint64_t *> punctured_prod(punctured_prod_array_.get(), size_);
auto temp_mpi(allocate_uint(size_, pool));
SEAL_ITERATE(
iter(temp_value, inv_punctured_prod_mod_base_array_, punctured_prod, base_), size_, [&](auto I) {
// Scalar factor: residue times inverse punctured product, reduced mod base_[i].
uint64_t temp_prod = multiply_uint_mod(get<0>(I), get<1>(I), get<3>(I));
// Term = factor * punctured_prod_i. The factor is below base_[i], so the term stays below base_prod_.
multiply_uint(get<2>(I), size_, temp_prod, size_, temp_mpi.get());
add_uint_uint_mod(temp_mpi.get(), value, base_prod_.get(), size_, value);
});
}
}
// Batch version of compose(). On input, value[j * count + i] is the residue of integer i modulo base_[j].
// The residues are first regrouped per integer into temp_array. Each integer is then reconstructed as in
// compose() and written back so that integer i occupies value[i * size_ .. i * size_ + size_ - 1].
// Throws invalid_argument for a null value or an uninitialized pool, and logic_error if count * size_
// overflows. Single-modulus bases are left untouched.
void RNSBase::compose_array(uint64_t *value, size_t count, MemoryPoolHandle pool) const
{
if (!value)
{
throw invalid_argument("value cannot be null");
}
if (!pool)
{
throw invalid_argument("pool is uninitialized");
}
if (size_ > 1)
{
if (!product_fits_in(count, size_))
{
throw logic_error("invalid parameters");
}
// Transpose: gather the size_ residues of integer i contiguously at temp_array[i * size_ ..].
auto temp_array(allocate_uint(count * size_, pool));
for (size_t i = 0; i < count; i++)
{
for (size_t j = 0; j < size_; j++)
{
temp_array[j + (i * size_)] = value[(j * count) + i];
}
}
// value now becomes the output: count multi-precision accumulators of size_ words each.
set_zero_uint(count * size_, value);
StrideIter<uint64_t *> temp_array_iter(temp_array.get(), size_);
StrideIter<uint64_t *> value_iter(value, size_);
StrideIter<uint64_t *> punctured_prod(punctured_prod_array_.get(), size_);
auto temp_mpi(allocate_uint(size_, pool));
// For each integer, accumulate sum_j [x_j * inv_punctured_prod_j]_{base_[j]} * punctured_prod_j mod base_prod_.
SEAL_ITERATE(iter(temp_array_iter, value_iter), count, [&](auto I) {
SEAL_ITERATE(
iter(get<0>(I), inv_punctured_prod_mod_base_array_, punctured_prod, base_), size_, [&](auto J) {
uint64_t temp_prod = multiply_uint_mod(get<0>(J), get<1>(J), get<3>(J));
multiply_uint(get<2>(J), size_, temp_prod, size_, temp_mpi.get());
add_uint_uint_mod(temp_mpi.get(), get<1>(I), base_prod_.get(), size_, get<1>(I));
});
});
}
}
// Converts one value given by its ibase_size residues `in` (one per ibase_ modulus) into obase_size residues
// `out` (one per obase_ modulus):
//   temp[i] = in[i] * inv_punctured_prod_mod_base_array()[i]  mod ibase_[i]
//   out[j]  = sum_i temp[i] * base_change_matrix_[j][i]        mod obase_[j]
// Unlike exact_convert_array, no correction term is subtracted. The result may therefore presumably differ from
// the exact conversion by a multiple of the ibase product. `pool` backs the temporary buffer.
void BaseConverter::fast_convert(ConstCoeffIter in, CoeffIter out, MemoryPoolHandle pool) const
{
size_t ibase_size = ibase_.size();
size_t obase_size = obase_.size();
SEAL_ALLOCATE_GET_COEFF_ITER(temp, ibase_size, pool);
SEAL_ITERATE(
iter(temp, in, ibase_.inv_punctured_prod_mod_base_array(), ibase_.base()), ibase_size,
[&](auto I) { get<0>(I) = multiply_uint_mod(get<1>(I), get<2>(I), get<3>(I)); });
// One modular dot product per output modulus, against the matching row of the base-change matrix.
SEAL_ITERATE(iter(out, base_change_matrix_, obase_.base()), obase_size, [&](auto I) {
get<0>(I) = dot_product_mod(temp, get<1>(I).get(), ibase_size, get<2>(I));
});
}
// Batch version of fast_convert over count = in.poly_modulus_degree() coefficients.
// `in` has ibase_size RNS components and `out` receives obase_size components. In debug builds the two
// iterators must have the same poly_modulus_degree (otherwise invalid_argument is thrown).
void BaseConverter::fast_convert_array(ConstRNSIter in, RNSIter out, MemoryPoolHandle pool) const
{
#ifdef SEAL_DEBUG
if (in.poly_modulus_degree() != out.poly_modulus_degree())
{
throw invalid_argument("in and out are incompatible");
}
#endif
size_t ibase_size = ibase_.size();
size_t obase_size = obase_.size();
size_t count = in.poly_modulus_degree();
// temp is stored transposed: one row of ibase_size words per coefficient, so that each coefficient's residues
// are contiguous for dot_product_mod below.
SEAL_ALLOCATE_GET_STRIDE_ITER(temp, uint64_t, count, ibase_size, pool);
SEAL_ITERATE(
iter(in, ibase_.inv_punctured_prod_mod_base_array(), ibase_.base(), size_t(0)), ibase_size,
[&](auto I) {
size_t ibase_index = get<3>(I);
// Shortcut: when the inverse punctured product is 1, a plain reduction replaces the multiplication.
if (get<1>(I).operand == 1)
{
SEAL_ITERATE(iter(get<0>(I), temp), count, [&](auto J) {
get<1>(J)[ibase_index] = barrett_reduce_64(get<0>(J), get<2>(I));
});
}
else
{
SEAL_ITERATE(iter(get<0>(I), temp), count, [&](auto J) {
get<1>(J)[ibase_index] = multiply_uint_mod(get<0>(J), get<1>(I), get<2>(I));
});
}
});
// For each output modulus j and coefficient k: out[j][k] = <temp row k, base-change row j> mod obase_[j].
SEAL_ITERATE(iter(out, base_change_matrix_, obase_.base()), obase_size, [&](auto I) {
SEAL_ITERATE(iter(get<0>(I), temp), count, [&](auto J) {
get<0>(J) = dot_product_mod(get<1>(J), get<1>(I).get(), ibase_size, get<2>(I));
});
});
}
// Exact base conversion of count = in.poly_modulus_degree() coefficients from ibase_ to a single output modulus p.
// For each coefficient:
//   temp[i] = in[i] * inv_punctured_prod_i mod q_i,    v = round(sum_i temp[i] / q_i)   (in double precision),
//   out     = (sum_i temp[i] * (punctured_prod_i mod p)) - v * (prod(q) mod p)   mod p.
// Subtracting v * prod(q) is what fast_convert_array omits. The rounding relies on double arithmetic, so its
// accuracy presumably degrades for very large ibase sizes; NOTE(review): confirm the intended bounds.
// Throws invalid_argument unless obase_ has exactly one modulus.
void BaseConverter::exact_convert_array(ConstRNSIter in, CoeffIter out, MemoryPoolHandle pool) const
{
size_t ibase_size = ibase_.size();
size_t obase_size = obase_.size();
size_t count = in.poly_modulus_degree();
if (obase_size != 1)
{
throw invalid_argument("out base in exact_convert_array must be one.");
}
// temp and v are stored transposed (one row of ibase_size entries per coefficient).
SEAL_ALLOCATE_GET_STRIDE_ITER(temp, uint64_t, count, ibase_size, pool);
SEAL_ALLOCATE_GET_STRIDE_ITER(v, double_t, count, ibase_size, pool);
SEAL_ALLOCATE_GET_PTR_ITER(aggregated_rounded_v, uint64_t, count, pool);
// Compute temp[k][i] = [in_i[k] * inv_punctured_prod_i]_{q_i} and its fraction v[k][i] = temp[k][i] / q_i.
SEAL_ITERATE(
iter(in, ibase_.inv_punctured_prod_mod_base_array(), ibase_.base(), size_t(0)), ibase_size,
[&](auto I) {
size_t ibase_index = get<3>(I);
double_t divisor = static_cast<double_t>(get<2>(I).value());
// Shortcut: when the inverse punctured product is 1, only a reduction is needed.
if (get<1>(I).operand == 1)
{
SEAL_ITERATE(iter(get<0>(I), temp, v), count, [&](auto J) {
get<1>(J)[ibase_index] = barrett_reduce_64(get<0>(J), get<2>(I));
double_t dividend = static_cast<double_t>(get<1>(J)[ibase_index]);
get<2>(J)[ibase_index] = dividend / divisor;
});
}
else
{
SEAL_ITERATE(iter(get<0>(I), temp, v), count, [&](auto J) {
get<1>(J)[ibase_index] = multiply_uint_mod(get<0>(J), get<1>(I), get<2>(I));
double_t dividend = static_cast<double_t>(get<1>(J)[ibase_index]);
get<2>(J)[ibase_index] = dividend / divisor;
});
}
});
// Sum the fractions per coefficient and round to nearest (adding 0.5 and truncating, valid since the sum is
// non-negative).
SEAL_ITERATE(iter(v, aggregated_rounded_v), count, [&](auto I) {
double_t aggregated_v = 0.0;
for (size_t i = 0; i < ibase_size; ++i)
{
aggregated_v += get<0>(I)[i];
}
aggregated_v += 0.5;
get<1>(I) = static_cast<uint64_t>(aggregated_v);
});
auto p = obase_.base()[0];
auto q_mod_p = modulo_uint(ibase_.base_prod(), ibase_size, p);
auto base_change_matrix_first = base_change_matrix_[0].get();
// out[k] = (fast conversion sum) - v[k] * (prod(q) mod p)   mod p.
SEAL_ITERATE(iter(out, temp, aggregated_rounded_v), count, [&](auto J) {
auto sum_mod_obase = dot_product_mod(get<1>(J), base_change_matrix_first, ibase_size, p);
auto v_q_mod_p = multiply_uint_mod(get<2>(J), q_mod_p, p);
get<0>(J) = sub_uint_mod(sum_mod_obase, v_q_mod_p, p);
});
}
// Builds the base-change matrix: base_change_matrix_[j][i] = (punctured product i of ibase_) mod obase_[j].
// There is one row of ibase_.size() words per output modulus. Throws logic_error if
// ibase_.size() * obase_.size() overflows.
void BaseConverter::initialize()
{
if (!product_fits_in(ibase_.size(), obase_.size()))
{
throw logic_error("invalid parameters");
}
base_change_matrix_ = allocate<Pointer<uint64_t>>(obase_.size(), pool_);
SEAL_ITERATE(iter(base_change_matrix_, obase_.base()), obase_.size(), [&](auto I) {
get<0>(I) = allocate_uint(ibase_.size(), pool_);
// Each punctured product is an ibase_.size()-word integer, so walk the table with that stride.
StrideIter<const uint64_t *> ibase_punctured_prod_array(ibase_.punctured_prod_array(), ibase_.size());
SEAL_ITERATE(iter(get<0>(I), ibase_punctured_prod_array), ibase_.size(), [&](auto J) {
get<0>(J) = modulo_uint(get<1>(J), ibase_.size(), get<1>(I));
});
});
}
// Constructs the RNS toolkit for the given degree, coefficient modulus base and plaintext modulus.
// All parameter validation and precomputation is performed by initialize().
// In debug builds, throws invalid_argument if `pool` is uninitialized.
RNSTool::RNSTool(
    size_t poly_modulus_degree, const RNSBase &coeff_modulus, const Modulus &plain_modulus,
    MemoryPoolHandle pool)
    : pool_(move(pool))
{
#if defined(SEAL_DEBUG)
    // The pool is only validated in debug builds.
    if (!pool_)
    {
        throw invalid_argument("pool is uninitialized");
    }
#endif
    initialize(poly_modulus_degree, coeff_modulus, plain_modulus);
}
// Validates the parameters and builds every auxiliary base, base converter, NTT table and modular constant used
// by the other RNSTool methods.
//   poly_modulus_degree : must be a power of two in [SEAL_POLY_MOD_DEGREE_MIN, SEAL_POLY_MOD_DEGREE_MAX].
//   q                   : coefficient modulus base; its size must be in
//                         [SEAL_COEFF_MOD_COUNT_MIN, SEAL_COEFF_MOD_COUNT_MAX].
//   t                   : plaintext modulus. When it is zero, everything involving t (the {t, gamma} base, the
//                         q -> t and q -> {t, gamma} converters, and the t-related constants) is skipped.
// Throws invalid_argument for an out-of-range q size or degree. Throws logic_error when a size check fails, an
// NTT table cannot be created, or a required modular inverse does not exist.
void RNSTool::initialize(size_t poly_modulus_degree, const RNSBase &q, const Modulus &t)
{
if (q.size() < SEAL_COEFF_MOD_COUNT_MIN || q.size() > SEAL_COEFF_MOD_COUNT_MAX)
{
throw invalid_argument("rnsbase is invalid");
}
int coeff_count_power = get_power_of_two(poly_modulus_degree);
if (coeff_count_power < 0 || poly_modulus_degree > SEAL_POLY_MOD_DEGREE_MAX ||
poly_modulus_degree < SEAL_POLY_MOD_DEGREE_MIN)
{
throw invalid_argument("poly_modulus_degree is invalid");
}
t_ = t;
coeff_count_ = poly_modulus_degree;
size_t base_q_size = q.size();
// Bit length of prod(q).
int total_coeff_bit_count = get_significant_bit_count_uint(q.base_prod(), q.size());
// B starts with as many primes as q. One extra prime is added when
// 32 + bits(t) + bits(prod(q)) >= SEAL_INTERNAL_MOD_BIT_COUNT * (|q| + 1).
// NOTE(review): the 32 bits look like head-room reserved for growth factors; confirm the derivation.
size_t base_B_size = base_q_size;
if (32 + t_.bit_count() + total_coeff_bit_count >=
SEAL_INTERNAL_MOD_BIT_COUNT * safe_cast<int>(base_q_size) + SEAL_INTERNAL_MOD_BIT_COUNT)
{
base_B_size++;
}
// Bsk = B plus {m_sk}. Bsk_m_tilde = Bsk plus {m_tilde}.
size_t base_Bsk_size = add_safe(base_B_size, size_t(1));
size_t base_Bsk_m_tilde_size = add_safe(base_Bsk_size, size_t(1));
size_t base_t_gamma_size = 0;
// Later buffers hold coeff_count_ * base_Bsk_m_tilde_size words.
if (!product_fits_in(coeff_count_, base_Bsk_m_tilde_size))
{
throw logic_error("invalid parameters");
}
// Draw base_B_size + 2 primes of SEAL_INTERNAL_MOD_BIT_COUNT bits: the first is m_sk, the second is gamma,
// and the remaining base_B_size primes form B. The 2 * coeff_count_ argument is presumably the congruence
// required for NTT-friendly primes; confirm in get_primes.
auto baseconv_primes =
get_primes(mul_safe(size_t(2), coeff_count_), SEAL_INTERNAL_MOD_BIT_COUNT, base_Bsk_m_tilde_size);
auto baseconv_primes_iter = baseconv_primes.cbegin();
m_sk_ = *baseconv_primes_iter++;
gamma_ = *baseconv_primes_iter++;
vector<Modulus> base_B_primes;
copy_n(baseconv_primes_iter, base_B_size, back_inserter(base_B_primes));
// m_tilde is the (non-prime) value 2^32.
m_tilde_ = uint64_t(1) << 32;
// Bases: q, B, Bsk = B + {m_sk}, Bsk + {m_tilde}, and (only for nonzero t) {t, gamma}.
base_q_ = allocate<RNSBase>(pool_, q, pool_);
base_B_ = allocate<RNSBase>(pool_, base_B_primes, pool_);
base_Bsk_ = allocate<RNSBase>(pool_, base_B_->extend(m_sk_));
base_Bsk_m_tilde_ = allocate<RNSBase>(pool_, base_Bsk_->extend(m_tilde_));
if (!t_.is_zero())
{
base_t_gamma_size = 2;
base_t_gamma_ = allocate<RNSBase>(pool_, vector<Modulus>{ t_, gamma_ }, pool_);
}
// NTT tables over Bsk; any logic_error from table creation is rethrown with a uniform message.
try
{
CreateNTTTables(
coeff_count_power, vector<Modulus>(base_Bsk_->base(), base_Bsk_->base() + base_Bsk_size),
base_Bsk_ntt_tables_, pool_);
}
catch (const logic_error &)
{
throw logic_error("invalid rns bases");
}
// Base converters: q -> {t} (nonzero t only), q -> Bsk, q -> {m_tilde}, B -> q, B -> {m_sk},
// and q -> {t, gamma} (nonzero t only).
if (!t_.is_zero())
{
base_q_to_t_conv_ = allocate<BaseConverter>(pool_, *base_q_, RNSBase({ t_ }, pool_), pool_);
}
base_q_to_Bsk_conv_ = allocate<BaseConverter>(pool_, *base_q_, *base_Bsk_, pool_);
base_q_to_m_tilde_conv_ = allocate<BaseConverter>(pool_, *base_q_, RNSBase({ m_tilde_ }, pool_), pool_);
base_B_to_q_conv_ = allocate<BaseConverter>(pool_, *base_B_, *base_q_, pool_);
base_B_to_m_sk_conv_ = allocate<BaseConverter>(pool_, *base_B_, RNSBase({ m_sk_ }, pool_), pool_);
if (base_t_gamma_)
{
base_q_to_t_gamma_conv_ = allocate<BaseConverter>(pool_, *base_q_, *base_t_gamma_, pool_);
}
// prod(B) mod q_i for every q_i.
prod_B_mod_q_ = allocate_uint(base_q_size, pool_);
SEAL_ITERATE(iter(prod_B_mod_q_, base_q_->base()), base_q_size, [&](auto I) {
get<0>(I) = modulo_uint(base_B_->base_prod(), base_B_size, get<1>(I));
});
uint64_t temp;
// prod(q)^(-1) mod each Bsk modulus.
inv_prod_q_mod_Bsk_ = allocate<MultiplyUIntModOperand>(base_Bsk_size, pool_);
for (size_t i = 0; i < base_Bsk_size; i++)
{
temp = modulo_uint(base_q_->base_prod(), base_q_size, (*base_Bsk_)[i]);
if (!try_invert_uint_mod(temp, (*base_Bsk_)[i], temp))
{
throw logic_error("invalid rns bases");
}
inv_prod_q_mod_Bsk_[i].set(temp, (*base_Bsk_)[i]);
}
// prod(B)^(-1) mod m_sk.
temp = modulo_uint(base_B_->base_prod(), base_B_size, m_sk_);
if (!try_invert_uint_mod(temp, m_sk_, temp))
{
throw logic_error("invalid rns bases");
}
inv_prod_B_mod_m_sk_.set(temp, m_sk_);
// m_tilde^(-1) mod each Bsk modulus.
inv_m_tilde_mod_Bsk_ = allocate<MultiplyUIntModOperand>(base_Bsk_size, pool_);
SEAL_ITERATE(iter(inv_m_tilde_mod_Bsk_, base_Bsk_->base()), base_Bsk_size, [&](auto I) {
if (!try_invert_uint_mod(barrett_reduce_64(m_tilde_.value(), get<1>(I)), get<1>(I), temp))
{
throw logic_error("invalid rns bases");
}
get<0>(I).set(temp, get<1>(I));
});
// -prod(q)^(-1) mod m_tilde.
temp = modulo_uint(base_q_->base_prod(), base_q_size, m_tilde_);
if (!try_invert_uint_mod(temp, m_tilde_, temp))
{
throw logic_error("invalid rns bases");
}
neg_inv_prod_q_mod_m_tilde_.set(negate_uint_mod(temp, m_tilde_), m_tilde_);
// prod(q) mod each Bsk modulus.
prod_q_mod_Bsk_ = allocate_uint(base_Bsk_size, pool_);
SEAL_ITERATE(iter(prod_q_mod_Bsk_, base_Bsk_->base()), base_Bsk_size, [&](auto I) {
get<0>(I) = modulo_uint(base_q_->base_prod(), base_q_size, get<1>(I));
});
if (base_t_gamma_)
{
// gamma^(-1) mod t.
if (!try_invert_uint_mod(barrett_reduce_64(gamma_.value(), t_), t_, temp))
{
throw logic_error("invalid rns bases");
}
inv_gamma_mod_t_.set(temp, t_);
// (t * gamma) mod q_i for every q_i.
prod_t_gamma_mod_q_ = allocate<MultiplyUIntModOperand>(base_q_size, pool_);
SEAL_ITERATE(iter(prod_t_gamma_mod_q_, base_q_->base()), base_q_size, [&](auto I) {
get<0>(I).set(
multiply_uint_mod((*base_t_gamma_)[0].value(), (*base_t_gamma_)[1].value(), get<1>(I)),
get<1>(I));
});
// -prod(q)^(-1) mod t and mod gamma. operand is used as scratch before the final set().
neg_inv_q_mod_t_gamma_ = allocate<MultiplyUIntModOperand>(base_t_gamma_size, pool_);
SEAL_ITERATE(iter(neg_inv_q_mod_t_gamma_, base_t_gamma_->base()), base_t_gamma_size, [&](auto I) {
get<0>(I).operand = modulo_uint(base_q_->base_prod(), base_q_size, get<1>(I));
if (!try_invert_uint_mod(get<0>(I).operand, get<1>(I), get<0>(I).operand))
{
throw logic_error("invalid rns bases");
}
get<0>(I).set(negate_uint_mod(get<0>(I).operand, get<1>(I)), get<1>(I));
});
}
// q_last^(-1) mod q_i for i = 0 .. base_q_size - 2 (used by the divide-by-q_last routines).
inv_q_last_mod_q_ = allocate<MultiplyUIntModOperand>(base_q_size - 1, pool_);
SEAL_ITERATE(iter(inv_q_last_mod_q_, base_q_->base()), base_q_size - 1, [&](auto I) {
if (!try_invert_uint_mod((*base_q_)[base_q_size - 1].value(), get<1>(I), temp))
{
throw logic_error("invalid rns bases");
}
get<0>(I).set(temp, get<1>(I));
});
// For nonzero t: q_last^(-1) mod t and q_last mod t.
if (t_.value() != 0)
{
if (!try_invert_uint_mod(base_q_->base()[base_q_size - 1].value(), t_, inv_q_last_mod_t_))
{
throw logic_error("invalid rns bases");
}
q_last_mod_t_ = barrett_reduce_64(base_q_->base()[base_q_size - 1].value(), t_);
}
}
// Divides the RNS polynomial `input` (base_q_size components of coeff_count_ coefficients; presumably in
// coefficient form, compare the _ntt_ variant) by the last modulus q_last with rounding. The result is written
// into the first base_q_size - 1 components:
//   c_i <- q_last^(-1) * (c_i - ([c_last + h]_{q_last} mod q_i - h mod q_i))   mod q_i,   h = floor(q_last / 2).
// Adding h before the division turns the floor into a rounding, and h is subtracted again per component.
// The last component is modified in place (it is left holding c_last + h mod q_last).
void RNSTool::divide_and_round_q_last_inplace(RNSIter input, MemoryPoolHandle pool) const
{
#ifdef SEAL_DEBUG
if (!input)
{
throw invalid_argument("input cannot be null");
}
if (input.poly_modulus_degree() != coeff_count_)
{
throw invalid_argument("input is not valid for encryption parameters");
}
if (!pool)
{
throw invalid_argument("pool is uninitialized");
}
#endif
size_t base_q_size = base_q_->size();
CoeffIter last_input = input[base_q_size - 1];
Modulus last_modulus = (*base_q_)[base_q_size - 1];
// Add h = floor(q_last / 2) to the last component so the division below rounds instead of flooring.
uint64_t half = last_modulus.value() >> 1;
add_poly_scalar_coeffmod(last_input, coeff_count_, half, last_modulus, last_input);
SEAL_ALLOCATE_GET_COEFF_ITER(temp, coeff_count_, pool);
SEAL_ITERATE(iter(input, inv_q_last_mod_q_, base_q_->base()), base_q_size - 1, [&](auto I) {
// temp = (c_last + h) mod q_i - h mod q_i.
modulo_poly_coeffs(last_input, coeff_count_, get<2>(I), temp);
uint64_t half_mod = barrett_reduce_64(half, get<2>(I));
sub_poly_scalar_coeffmod(temp, coeff_count_, half_mod, get<2>(I), temp);
// c_i = (c_i - temp) * q_last^(-1) mod q_i.
sub_poly_coeffmod(get<0>(I), temp, coeff_count_, get<2>(I), get<0>(I));
multiply_poly_scalar_coeffmod(get<0>(I), coeff_count_, get<1>(I), get<2>(I), get<0>(I));
});
}
// NTT-domain variant of divide_and_round_q_last_inplace. The last component of `input` is inverse-NTT'd with
// rns_ntt_tables[base_q_size - 1] and rounded by adding h = floor(q_last / 2). Then, for each i < base_q_size - 1,
// it is reduced mod q_i, shifted by -h, forward-NTT'd with rns_ntt_tables[i] and subtracted from c_i before
// multiplying by q_last^(-1) mod q_i. The first base_q_size - 1 components are therefore presumably expected in
// NTT form. The last component is left in coefficient form holding c_last + h mod q_last.
// Intermediate values are reduced lazily. This assumes ntt_negacyclic_harvey_lazy outputs values in
// [0, 4 * q_i) and that multiply_poly_scalar_coeffmod accepts the resulting unreduced inputs; confirm in the
// NTT / polyarith headers.
void RNSTool::divide_and_round_q_last_ntt_inplace(
RNSIter input, ConstNTTTablesIter rns_ntt_tables, MemoryPoolHandle pool) const
{
#ifdef SEAL_DEBUG
if (!input)
{
throw invalid_argument("input cannot be null");
}
if (input.poly_modulus_degree() != coeff_count_)
{
throw invalid_argument("input is not valid for encryption parameters");
}
if (!rns_ntt_tables)
{
throw invalid_argument("rns_ntt_tables cannot be null");
}
if (!pool)
{
throw invalid_argument("pool is uninitialized");
}
#endif
size_t base_q_size = base_q_->size();
CoeffIter last_input = input[base_q_size - 1];
// Bring the last component back to coefficient form, then add h for rounding.
inverse_ntt_negacyclic_harvey(last_input, rns_ntt_tables[base_q_size - 1]);
Modulus last_modulus = (*base_q_)[base_q_size - 1];
uint64_t half = last_modulus.value() >> 1;
add_poly_scalar_coeffmod(last_input, coeff_count_, half, last_modulus, last_input);
SEAL_ALLOCATE_GET_COEFF_ITER(temp, coeff_count_, pool);
SEAL_ITERATE(iter(input, inv_q_last_mod_q_, base_q_->base(), rns_ntt_tables), base_q_size - 1, [&](auto I) {
// Residues of the last component are below q_last. When q_i >= q_last they are already reduced mod q_i,
// so a plain copy suffices.
if (get<2>(I).value() < last_modulus.value())
{
modulo_poly_coeffs(last_input, coeff_count_, get<2>(I), temp);
}
else
{
set_uint(last_input, coeff_count_, temp);
}
// Subtract h lazily by adding q_i - (h mod q_i); temp stays below 2 * q_i.
uint64_t neg_half_mod = get<2>(I).value() - barrett_reduce_64(half, get<2>(I));
SEAL_ITERATE(temp, coeff_count_, [&](auto &J) { J += neg_half_mod; });
#if SEAL_USER_MOD_BIT_COUNT_MAX <= 60
// qi_lazy = 4 * q_i bounds the lazy NTT output (assumption noted above).
uint64_t qi_lazy = get<2>(I).value() << 2;
ntt_negacyclic_harvey_lazy(temp, get<3>(I));
#else
// Larger moduli: reduce the lazy NTT output once (branch-free) so it stays below qi_lazy = 2 * q_i.
uint64_t qi_lazy = get<2>(I).value() << 1;
ntt_negacyclic_harvey_lazy(temp, get<3>(I));
SEAL_ITERATE(temp, coeff_count_, [&](auto &J) {
J -= (qi_lazy & static_cast<uint64_t>(-static_cast<int64_t>(J >= qi_lazy)));
});
#endif
// c_i += qi_lazy - temp: a lazy subtraction that stays non-negative. It is reduced by the multiplication below.
SEAL_ITERATE(iter(get<0>(I), temp), coeff_count_, [&](auto J) { get<0>(J) += qi_lazy - get<1>(J); });
multiply_poly_scalar_coeffmod(get<0>(I), coeff_count_, get<1>(I), get<2>(I), get<0>(I));
});
}
// Converts `input`, given in base Bsk = B plus {m_sk} (component base_B_size is the m_sk residue), to base q,
// writing base_q_size components into `destination`. The name presumably refers to the Shenoy-Kumaresan technique.
// Steps:
//   1. destination = fast conversion of the B part to q.
//   2. alpha_sk = (fastconv_{B->m_sk}(x_B) - x_{m_sk}) * prod(B)^(-1)   mod m_sk.
//   3. destination_i -= alpha_sk * prod(B) mod q_i, interpreting alpha_sk as centered: values above m_sk/2 are
//      treated as alpha_sk - m_sk, so (m_sk - alpha_sk) * prod(B) is added instead.
void RNSTool::fastbconv_sk(ConstRNSIter input, RNSIter destination, MemoryPoolHandle pool) const
{
#ifdef SEAL_DEBUG
if (!input)
{
throw invalid_argument("input cannot be null");
}
if (input.poly_modulus_degree() != coeff_count_)
{
throw invalid_argument("input is not valid for encryption parameters");
}
if (!destination)
{
throw invalid_argument("destination cannot be null");
}
if (destination.poly_modulus_degree() != coeff_count_)
{
throw invalid_argument("destination is not valid for encryption parameters");
}
if (!pool)
{
throw invalid_argument("pool is uninitialized");
}
#endif
size_t base_q_size = base_q_->size();
size_t base_B_size = base_B_->size();
// The converters read only the first base_B_size components of input (the B part).
base_B_to_q_conv_->fast_convert_array(input, destination, pool);
SEAL_ALLOCATE_GET_COEFF_ITER(temp, coeff_count_, pool);
base_B_to_m_sk_conv_->fast_convert_array(input, RNSIter(temp, coeff_count_), pool);
SEAL_ALLOCATE_GET_COEFF_ITER(alpha_sk, coeff_count_, pool);
// alpha_sk = (temp + (m_sk - x_{m_sk})) * prod(B)^(-1) mod m_sk. The negation is not pre-reduced.
SEAL_ITERATE(iter(alpha_sk, temp, input[base_B_size]), coeff_count_, [&](auto I) {
get<0>(I) = multiply_uint_mod(get<1>(I) + (m_sk_.value() - get<2>(I)), inv_prod_B_mod_m_sk_, m_sk_);
});
const uint64_t m_sk_div_2 = m_sk_.value() >> 1;
SEAL_ITERATE(iter(prod_B_mod_q_, base_q_->base(), destination), base_q_size, [&](auto I) {
// Precomputed operands for +prod(B) and -prod(B) modulo q_i.
MultiplyUIntModOperand prod_B_mod_q_elt;
prod_B_mod_q_elt.set(get<0>(I), get<1>(I));
MultiplyUIntModOperand neg_prod_B_mod_q_elt;
neg_prod_B_mod_q_elt.set(get<1>(I).value() - get<0>(I), get<1>(I));
SEAL_ITERATE(iter(alpha_sk, get<2>(I)), coeff_count_, [&](auto J) {
// Centered correction: alpha_sk > m_sk/2 represents a negative value.
if (get<0>(J) > m_sk_div_2)
{
get<1>(J) = multiply_add_uint_mod(
negate_uint_mod(get<0>(J), m_sk_), prod_B_mod_q_elt, get<1>(J), get<1>(I));
}
else
{
get<1>(J) = multiply_add_uint_mod(get<0>(J), neg_prod_B_mod_q_elt, get<1>(J), get<1>(I));
}
});
});
}
// Reduces `input`, given in Bsk plus {m_tilde} (component base_Bsk_size is the m_tilde residue), to base Bsk,
// writing base_Bsk_size components into `destination`:
//   r       = -x_{m_tilde} * prod(q)^(-1)   mod m_tilde, lifted to the centered range (values >= m_tilde/2 are
//             treated as r - m_tilde);
//   out_i   = (x_i + prod(q) * r) * m_tilde^(-1)   mod Bsk_i.
// Lifting r to Bsk_i as temp + (Bsk_i - m_tilde) assumes every Bsk modulus exceeds m_tilde = 2^32. Bsk
// moduli are SEAL_INTERNAL_MOD_BIT_COUNT-bit primes from initialize().
void RNSTool::sm_mrq(ConstRNSIter input, RNSIter destination, MemoryPoolHandle pool) const
{
#ifdef SEAL_DEBUG
if (input == nullptr)
{
throw invalid_argument("input cannot be null");
}
if (input.poly_modulus_degree() != coeff_count_)
{
throw invalid_argument("input is not valid for encryption parameters");
}
if (!destination)
{
throw invalid_argument("destination cannot be null");
}
if (destination.poly_modulus_degree() != coeff_count_)
{
throw invalid_argument("destination is not valid for encryption parameters");
}
if (!pool)
{
throw invalid_argument("pool is uninitialized");
}
#endif
size_t base_Bsk_size = base_Bsk_->size();
ConstCoeffIter input_m_tilde = input[base_Bsk_size];
const uint64_t m_tilde_div_2 = m_tilde_.value() >> 1;
// r_m_tilde = -x_{m_tilde} * prod(q)^(-1) mod m_tilde.
SEAL_ALLOCATE_GET_COEFF_ITER(r_m_tilde, coeff_count_, pool);
multiply_poly_scalar_coeffmod(
input_m_tilde, coeff_count_, neg_inv_prod_q_mod_m_tilde_, m_tilde_, r_m_tilde);
SEAL_ITERATE(
iter(input, prod_q_mod_Bsk_, inv_m_tilde_mod_Bsk_, base_Bsk_->base(), destination), base_Bsk_size,
[&](auto I) {
MultiplyUIntModOperand prod_q_mod_Bsk_elt;
prod_q_mod_Bsk_elt.set(get<1>(I), get<3>(I));
SEAL_ITERATE(iter(get<0>(I), r_m_tilde, get<4>(I)), coeff_count_, [&](auto J) {
// Centered lift of r mod m_tilde into Bsk_i ('>=' because m_tilde is even).
uint64_t temp = get<1>(J);
if (temp >= m_tilde_div_2)
{
temp += get<3>(I).value() - m_tilde_.value();
}
// out = (x_i + prod(q) * r) * m_tilde^(-1) mod Bsk_i.
get<2>(J) = multiply_uint_mod(
multiply_add_uint_mod(temp, prod_q_mod_Bsk_elt, get<0>(J), get<3>(I)), get<2>(I),
get<3>(I));
});
});
}
// Approximate division by prod(q). `input` holds base_q_size components in q followed by base_Bsk_size
// components in Bsk. `destination` receives base_Bsk_size components:
//   out_i = (x_{Bsk,i} - fastconv_{q->Bsk}(x_q)_i) * prod(q)^(-1)   mod Bsk_i.
// Because the conversion is the fast (uncorrected) one, the result is presumably only approximately floor(x / q).
void RNSTool::fast_floor(ConstRNSIter input, RNSIter destination, MemoryPoolHandle pool) const
{
#ifdef SEAL_DEBUG
if (input == nullptr)
{
throw invalid_argument("input cannot be null");
}
if (input.poly_modulus_degree() != coeff_count_)
{
throw invalid_argument("input is not valid for encryption parameters");
}
if (!destination)
{
throw invalid_argument("destination cannot be null");
}
if (destination.poly_modulus_degree() != coeff_count_)
{
throw invalid_argument("destination is not valid for encryption parameters");
}
if (!pool)
{
throw invalid_argument("pool is uninitialized");
}
#endif
size_t base_q_size = base_q_->size();
size_t base_Bsk_size = base_Bsk_->size();
// destination = fast conversion of the q part of input into Bsk.
base_q_to_Bsk_conv_->fast_convert_array(input, destination, pool);
// Advance the (local) iterator past the q components to reach the Bsk part.
input += base_q_size;
SEAL_ITERATE(iter(input, inv_prod_q_mod_Bsk_, base_Bsk_->base(), destination), base_Bsk_size, [&](auto I) {
SEAL_ITERATE(iter(get<0>(I), get<3>(I)), coeff_count_, [&](auto J) {
// (x - conv) * prod(q)^(-1). The negation is not pre-reduced modulo Bsk_i.
get<1>(J) = multiply_uint_mod(get<0>(J) + (get<2>(I).value() - get<1>(J)), get<1>(I), get<2>(I));
});
});
}
// Multiplies `input` (base_q_size components in q) by m_tilde modulo q, then fast-converts the product to Bsk and
// to {m_tilde}. `destination` receives base_Bsk_size Bsk components followed by one m_tilde component.
void RNSTool::fastbconv_m_tilde(ConstRNSIter input, RNSIter destination, MemoryPoolHandle pool) const
{
#ifdef SEAL_DEBUG
if (input == nullptr)
{
throw invalid_argument("input cannot be null");
}
if (input.poly_modulus_degree() != coeff_count_)
{
throw invalid_argument("input is not valid for encryption parameters");
}
if (!destination)
{
throw invalid_argument("destination cannot be null");
}
if (destination.poly_modulus_degree() != coeff_count_)
{
throw invalid_argument("destination is not valid for encryption parameters");
}
if (!pool)
{
throw invalid_argument("pool is uninitialized");
}
#endif
size_t base_q_size = base_q_->size();
size_t base_Bsk_size = base_Bsk_->size();
// temp = input * m_tilde mod q_i (component-wise).
SEAL_ALLOCATE_GET_RNS_ITER(temp, coeff_count_, base_q_size, pool);
multiply_poly_scalar_coeffmod(input, base_q_size, m_tilde_.value(), base_q_->base(), temp);
// First base_Bsk_size components: conversion to Bsk. Final component: conversion to {m_tilde}.
base_q_to_Bsk_conv_->fast_convert_array(temp, destination, pool);
base_q_to_m_tilde_conv_->fast_convert_array(temp, destination + base_Bsk_size, pool);
}
// Scales the phase `input` (base_q_size components in q) down to a plaintext modulo t, writing coeff_count_
// values to `destination`:
//   1. temp = input * (t * gamma) mod q_i;
//   2. convert temp from q to {t, gamma} and multiply by -prod(q)^(-1) in each of t and gamma;
//   3. out = (t-component - centered gamma-component) * gamma^(-1) mod t.
// Requires a nonzero t: base_t_gamma_ and the related converters/constants are dereferenced without a check and
// are only created by initialize() when t is nonzero.
void RNSTool::decrypt_scale_and_round(ConstRNSIter input, CoeffIter destination, MemoryPoolHandle pool) const
{
#ifdef SEAL_DEBUG
if (input == nullptr)
{
throw invalid_argument("input cannot be null");
}
if (input.poly_modulus_degree() != coeff_count_)
{
throw invalid_argument("input is not valid for encryption parameters");
}
if (!destination)
{
throw invalid_argument("destination cannot be null");
}
if (!pool)
{
throw invalid_argument("pool is uninitialized");
}
#endif
size_t base_q_size = base_q_->size();
size_t base_t_gamma_size = base_t_gamma_->size();
// Step 1: temp_i = input_i * (t * gamma) mod q_i.
SEAL_ALLOCATE_GET_RNS_ITER(temp, coeff_count_, base_q_size, pool);
SEAL_ITERATE(iter(input, prod_t_gamma_mod_q_, base_q_->base(), temp), base_q_size, [&](auto I) {
multiply_poly_scalar_coeffmod(get<0>(I), coeff_count_, get<1>(I), get<2>(I), get<3>(I));
});
// Step 2: fast conversion q -> {t, gamma}, then in-place multiplication by -prod(q)^(-1).
SEAL_ALLOCATE_GET_RNS_ITER(temp_t_gamma, coeff_count_, base_t_gamma_size, pool);
base_q_to_t_gamma_conv_->fast_convert_array(temp, temp_t_gamma, pool);
SEAL_ITERATE(
iter(temp_t_gamma, neg_inv_q_mod_t_gamma_, base_t_gamma_->base(), temp_t_gamma), base_t_gamma_size,
[&](auto I) {
multiply_poly_scalar_coeffmod(get<0>(I), coeff_count_, get<1>(I), get<2>(I), get<3>(I));
});
// Step 3: subtract the gamma component, interpreted as centered around 0, from the t component.
uint64_t gamma_div_2 = (*base_t_gamma_)[1].value() >> 1;
SEAL_ITERATE(iter(temp_t_gamma[0], temp_t_gamma[1], destination), coeff_count_, [&](auto I) {
if (get<1>(I) > gamma_div_2)
{
// Gamma component represents a negative value: add (gamma - g) instead of subtracting g.
get<2>(I) = add_uint_mod(get<0>(I), barrett_reduce_64(gamma_.value() - get<1>(I), t_), t_);
}
else
{
get<2>(I) = sub_uint_mod(get<0>(I), barrett_reduce_64(get<1>(I), t_), t_);
}
// Final scaling by gamma^(-1) mod t (skipped for zero, where the product is zero anyway).
if (0 != get<2>(I))
{
get<2>(I) = multiply_uint_mod(get<2>(I), inv_gamma_mod_t_, t_);
}
});
}
// Divides `input` (base_q_size components in q) by the last modulus q_last after a correction modulo t.
// The correction is delta = q_last * ([-c_last * q_last^(-1)]_t), so that c_last + delta is divisible by q_last
// while c changes only by a multiple of q_last. For each i < base_q_size - 1:
//   c_i <- (c_i - c_last - delta) * q_last^(-1)   mod q_i.
// The subtraction is done lazily: 2*q_i is added to keep the value non-negative. This assumes
// multiply_poly_scalar_coeffmod reduces inputs below 3*q_i; confirm. The last component is left unchanged.
// Requires a nonzero t (inv_q_last_mod_t_ is only set by initialize() in that case).
void RNSTool::mod_t_and_divide_q_last_inplace(RNSIter input, MemoryPoolHandle pool) const
{
size_t modulus_size = base_q_->size();
const Modulus *curr_modulus = base_q_->base();
const Modulus plain_modulus = t_;
uint64_t last_modulus_value = curr_modulus[modulus_size - 1].value();
// neg_c_last_mod_t = -c_last * q_last^(-1) mod t (the multiplication is skipped when the inverse is 1).
SEAL_ALLOCATE_ZERO_GET_COEFF_ITER(neg_c_last_mod_t, coeff_count_, pool);
modulo_poly_coeffs(CoeffIter(input[modulus_size - 1]), coeff_count_, plain_modulus, neg_c_last_mod_t);
negate_poly_coeffmod(neg_c_last_mod_t, coeff_count_, plain_modulus, neg_c_last_mod_t);
if (inv_q_last_mod_t_ != 1)
{
multiply_poly_scalar_coeffmod(
neg_c_last_mod_t, coeff_count_, inv_q_last_mod_t_, plain_modulus, neg_c_last_mod_t);
}
SEAL_ALLOCATE_ZERO_GET_COEFF_ITER(delta_mod_q_i, coeff_count_, pool);
SEAL_ITERATE(iter(input, curr_modulus, inv_q_last_mod_q_), modulus_size - 1, [&](auto I) {
// delta_mod_q_i = (neg_c_last_mod_t mod q_i) * q_last mod q_i.
modulo_poly_coeffs(neg_c_last_mod_t, coeff_count_, get<1>(I), delta_mod_q_i);
multiply_poly_scalar_coeffmod(
delta_mod_q_i, coeff_count_, last_modulus_value, get<1>(I), delta_mod_q_i);
// c_i += 2*q_i - (c_last mod q_i) - delta_mod_q_i  (lazy, not yet reduced).
const uint64_t two_times_q_i = get<1>(I).value() << 1;
SEAL_ITERATE(iter(get<0>(I), delta_mod_q_i, input[modulus_size - 1]), coeff_count_, [&](auto J) {
get<0>(J) += two_times_q_i - barrett_reduce_64(get<2>(J), get<1>(I)) - get<1>(J);
});
// c_i = c_i * q_last^(-1) mod q_i.
multiply_poly_scalar_coeffmod(get<0>(I), coeff_count_, get<2>(I), get<1>(I), get<0>(I));
});
}
// Recovers the phase modulo t by exact base conversion from q to {t}, writing coeff_count_ values into
// `destination`. base_q_to_t_conv_ is only created by initialize() when t is nonzero, so a nonzero t is required.
void RNSTool::decrypt_modt(RNSIter phase, CoeffIter destination, MemoryPoolHandle pool) const
{
    const BaseConverter &q_to_t = *base_q_to_t_conv_;
    q_to_t.exact_convert_array(phase, destination, pool);
}
} }