#include "seal/encryptor.h"
#include "seal/util/polyarithsmallmod.h"
#include "seal/util/scalingvariant.h"
#include "seal/util/uintarith.h"
using namespace std;
namespace seal
{
namespace util
{
// Adds a plaintext to an RNS polynomial without any Delta-scaling.
//
// For every RNS prime q_j of context_data's coefficient modulus and every
// index k < plain.coeff_count(), this sets
//     destination[j][k] = (destination[j][k] + (plain[k] mod q_j)) mod q_j.
// Coefficients of destination at indices >= plain.coeff_count() are left
// unchanged.
//
// In SEAL_DEBUG builds, throws std::invalid_argument if plain has more
// coefficients than poly_modulus_degree, or if destination's degree does not
// match the encryption parameters.
void add_plain_without_scaling_variant(
    const Plaintext &plain, const SEALContext::ContextData &context_data, RNSIter destination)
{
    const auto &params = context_data.parms();
    const auto &moduli = params.coeff_modulus();
    const size_t plain_count = plain.coeff_count();
#ifdef SEAL_DEBUG
    if (plain_count > params.poly_modulus_degree())
    {
        throw std::invalid_argument("invalid plaintext");
    }
    if (destination.poly_modulus_degree() != params.poly_modulus_degree())
    {
        throw std::invalid_argument("destination is not valid for encryption parameters");
    }
#endif
    const uint64_t *plain_coeffs = plain.data();
    SEAL_ITERATE(iter(destination, moduli), moduli.size(), [&](auto rns) {
        const auto &q_i = get<1>(rns);
        auto poly = get<0>(rns);
        for (size_t k = 0; k < plain_count; k++)
        {
            // Plaintext coefficients may exceed q_i, so reduce before the modular add.
            const uint64_t reduced = barrett_reduce_64(plain_coeffs[k], q_i);
            poly[k] = add_uint_mod(poly[k], reduced, q_i);
        }
    });
}
// Subtracts a plaintext from an RNS polynomial without any Delta-scaling.
//
// For every RNS prime q_j of context_data's coefficient modulus and every
// index k < plain.coeff_count(), this sets
//     destination[j][k] = (destination[j][k] - (plain[k] mod q_j)) mod q_j.
// Coefficients of destination at indices >= plain.coeff_count() are left
// unchanged.
//
// In SEAL_DEBUG builds, throws std::invalid_argument if plain has more
// coefficients than poly_modulus_degree, or if destination's degree does not
// match the encryption parameters.
void sub_plain_without_scaling_variant(
    const Plaintext &plain, const SEALContext::ContextData &context_data, RNSIter destination)
{
    const auto &params = context_data.parms();
    const auto &moduli = params.coeff_modulus();
    const size_t plain_count = plain.coeff_count();
#ifdef SEAL_DEBUG
    if (plain_count > params.poly_modulus_degree())
    {
        throw std::invalid_argument("invalid plaintext");
    }
    if (destination.poly_modulus_degree() != params.poly_modulus_degree())
    {
        throw std::invalid_argument("destination is not valid for encryption parameters");
    }
#endif
    const uint64_t *plain_coeffs = plain.data();
    SEAL_ITERATE(iter(destination, moduli), moduli.size(), [&](auto rns) {
        const auto &q_i = get<1>(rns);
        auto poly = get<0>(rns);
        for (size_t k = 0; k < plain_count; k++)
        {
            // Plaintext coefficients may exceed q_i, so reduce before the modular subtract.
            const uint64_t reduced = barrett_reduce_64(plain_coeffs[k], q_i);
            poly[k] = sub_uint_mod(poly[k], reduced, q_i);
        }
    });
}
// Adds the Delta-scaled plaintext to destination, optionally exporting the
// per-coefficient rounding correction.
//
// For each plaintext coefficient m at index i, this computes
//     fix = floor((m * q_mod_t + plain_upper_half_threshold) / t),
// where q_mod_t = context_data.coeff_modulus_mod_plain_modulus() and
// t = plain_modulus. Then, for every RNS prime q_j, it adds
//     m * coeff_div_plain_modulus[j] + fix
// to destination[j][i] modulo q_j.
// Assuming coeff_div_plain_modulus holds floor(q / t) and
// plain_upper_half_threshold holds (t + 1) / 2 (SEAL ContextData naming;
// confirm), the added value is q * m / t rounded to the nearest integer, with
// ties rounded up.
//
// Parameters:
//   plain                 plaintext whose coefficients are scaled and added.
//   context_data          source of q, t and the precomputed constants above.
//   destination           RNS polynomial receiving the scaled plaintext (in/out).
//   export_remainder      if true, remainder_destination is resized to
//                         plain.coeff_count() and entry i receives fix for
//                         coefficient i. If false, remainder_destination is
//                         not touched at all.
//   remainder_destination output for the values described above.
//
// NOTE(review): the exported value is fix[0], the *quotient* written by
// divide_uint128_inplace (the rounding correction). It is not the remainder of
// that division, which is presumably left in numerator[0]. The parameter name
// suggests the latter; confirm which one callers expect.
//
// Throws: std::invalid_argument (SEAL_DEBUG builds only) if plain has more
// coefficients than poly_modulus_degree, or if destination's degree does not
// match. Any exception from Plaintext::resize propagates.
void multiply_add_plain_with_scaling_variant(
    const Plaintext &plain,
    const SEALContext::ContextData &context_data,
    RNSIter destination,
    bool export_remainder,
    Plaintext &remainder_destination
)
{
    auto &parms = context_data.parms();
    size_t plain_coeff_count = plain.coeff_count();
    auto &coeff_modulus = parms.coeff_modulus();
    size_t coeff_modulus_size = coeff_modulus.size();
    const auto &plain_modulus = parms.plain_modulus();
    auto coeff_div_plain_modulus = context_data.coeff_div_plain_modulus();
    uint64_t plain_upper_half_threshold = context_data.plain_upper_half_threshold();
    uint64_t q_mod_t = context_data.coeff_modulus_mod_plain_modulus();
#ifdef SEAL_DEBUG
    if (plain_coeff_count > parms.poly_modulus_degree())
    {
        throw std::invalid_argument("invalid plaintext");
    }
    if (destination.poly_modulus_degree() != parms.poly_modulus_degree())
    {
        throw std::invalid_argument("destination is not valid for encryption parameters");
    }
#endif
    // Only form a pointer into remainder_destination when exporting. The
    // 3-argument overload passes an empty Plaintext whose data() may be null.
    // Stepping such a pointer (or one into a too-short buffer) through
    // plain_coeff_count positions is undefined behavior, even if it is never
    // dereferenced.
    uint64_t *remainder_ptr = nullptr;
    if (export_remainder)
    {
        remainder_destination.resize(plain_coeff_count);
        remainder_ptr = remainder_destination.data();
    }

    SEAL_ITERATE(iter(plain.data(), size_t(0)), plain_coeff_count, [&](auto I) {
        // Read the coefficient once, before anything is written to
        // remainder_destination, which the caller may have aliased with plain.
        const uint64_t plain_coeff = get<0>(I);
        const size_t coeff_index = get<1>(I);

        // numerator = plain_coeff * (q mod t) + plain_upper_half_threshold, as 128 bits.
        unsigned long long prod[2]{ 0, 0 };
        uint64_t numerator[2]{ 0, 0 };
        multiply_uint64(plain_coeff, q_mod_t, prod);
        unsigned char carry = add_uint64(*prod, plain_upper_half_threshold, numerator);
        numerator[1] = static_cast<uint64_t>(prod[1]) + static_cast<uint64_t>(carry);

        // fix[0] = floor(numerator / t)
        uint64_t fix[2] = { 0, 0 };
        divide_uint128_inplace(numerator, plain_modulus.value(), fix);

        if (export_remainder)
        {
            remainder_ptr[coeff_index] = fix[0];
        }

        // destination[j][i] += plain_coeff * coeff_div_plain_modulus[j] + fix (mod q_j)
        SEAL_ITERATE(
            iter(destination, coeff_modulus, coeff_div_plain_modulus), coeff_modulus_size, [&](auto J) {
                uint64_t scaled_rounded_coeff = multiply_add_uint_mod(plain_coeff, get<2>(J), fix[0], get<1>(J));
                get<0>(J)[coeff_index] = add_uint_mod(get<0>(J)[coeff_index], scaled_rounded_coeff, get<1>(J));
            });
    });
}
// Convenience overload: adds the Delta-scaled plaintext to destination without
// exporting the per-coefficient correction values. A default-constructed
// Plaintext is passed as the remainder destination, and export is switched off.
void multiply_add_plain_with_scaling_variant(
    const Plaintext &plain, const SEALContext::ContextData &context_data, RNSIter destination)
{
    Plaintext unused_remainder;
    constexpr bool export_remainder = false;
    multiply_add_plain_with_scaling_variant(plain, context_data, destination, export_remainder, unused_remainder);
}
// Subtracts the Delta-scaled plaintext from destination. This is the
// subtraction counterpart of the scaled add.
//
// For each plaintext coefficient m at index idx, this computes
//     correction = floor((m * q_mod_t + half_t) / t),
// with q_mod_t = context_data.coeff_modulus_mod_plain_modulus(),
// half_t = context_data.plain_upper_half_threshold() and t = plain_modulus.
// Then, for every RNS prime q_j, it subtracts
//     m * coeff_div_plain_modulus[j] + correction
// from destination[j][idx] modulo q_j.
// Assuming coeff_div_plain_modulus is floor(q / t) and half_t is (t + 1) / 2
// (confirm against ContextData), the subtracted value is round(q * m / t).
//
// In SEAL_DEBUG builds, throws std::invalid_argument if plain has more
// coefficients than poly_modulus_degree, or if destination's degree does not
// match the encryption parameters.
void multiply_sub_plain_with_scaling_variant(
    const Plaintext &plain, const SEALContext::ContextData &context_data, RNSIter destination)
{
    const auto &params = context_data.parms();
    const auto &moduli = params.coeff_modulus();
    const size_t rns_count = moduli.size();
    const size_t plain_count = plain.coeff_count();
    const auto &t = params.plain_modulus();
    const uint64_t q_mod_t = context_data.coeff_modulus_mod_plain_modulus();
    const uint64_t half_t = context_data.plain_upper_half_threshold();
    auto q_div_t = context_data.coeff_div_plain_modulus();
#ifdef SEAL_DEBUG
    if (plain_count > params.poly_modulus_degree())
    {
        throw std::invalid_argument("invalid plaintext");
    }
    if (destination.poly_modulus_degree() != params.poly_modulus_degree())
    {
        throw std::invalid_argument("destination is not valid for encryption parameters");
    }
#endif
    SEAL_ITERATE(iter(plain.data(), size_t(0)), plain_count, [&](auto coeff) {
        const uint64_t m = get<0>(coeff);
        const size_t idx = get<1>(coeff);

        // 128-bit num = m * (q mod t) + half_t
        unsigned long long product[2]{ 0, 0 };
        multiply_uint64(m, q_mod_t, product);
        uint64_t num[2]{ 0, 0 };
        const unsigned char carry = add_uint64(*product, half_t, num);
        num[1] = static_cast<uint64_t>(product[1]) + static_cast<uint64_t>(carry);

        // correction[0] = floor(num / t)
        uint64_t correction[2]{ 0, 0 };
        divide_uint128_inplace(num, t.value(), correction);

        SEAL_ITERATE(iter(destination, moduli, q_div_t), rns_count, [&](auto rns) {
            const uint64_t delta_m = multiply_add_uint_mod(m, get<2>(rns), correction[0], get<1>(rns));
            get<0>(rns)[idx] = sub_uint_mod(get<0>(rns)[idx], delta_m, get<1>(rns));
        });
    });
}
} }