#include "seal/encryptor.h"
#include "seal/modulus.h"
#include "seal/randomtostd.h"
#include "seal/util/common.h"
#include "seal/util/iterator.h"
#include "seal/util/polyarithsmallmod.h"
#include "seal/util/rlwe.h"
#include "seal/util/scalingvariant.h"
#include <algorithm>
#include <stdexcept>
using namespace std;
using namespace seal::util;
namespace seal
{
// Constructs an Encryptor bound to `context` and configured with a public key.
//
// Throws invalid_argument when the context reports that its parameters are not
// set, and logic_error when product_fits_in rejects the key-level sizes.
// set_public_key may raise its own errors (not visible in this file).
Encryptor::Encryptor(const SEALContext &context, const PublicKey &public_key) : context_(context)
{
    // Refuse to work with a context whose parameters were never validated.
    if (!context_.parameters_set())
    {
        throw invalid_argument("encryption parameters are not set correctly");
    }

    set_public_key(public_key);

    // Sizes are read from the key-level parameters of the context.
    auto &key_parms = context_.key_context_data()->parms();
    size_t poly_degree = key_parms.poly_modulus_degree();
    size_t modulus_count = key_parms.coeff_modulus().size();

    // Sanity check: degree * modulus count * 2 must pass product_fits_in
    // (presumably an overflow guard for later buffer-size computations).
    bool sizes_fit = product_fits_in(poly_degree, modulus_count, size_t(2));
    if (!sizes_fit)
    {
        throw logic_error("invalid parameters");
    }
}
// Constructs an Encryptor bound to `context` and configured with a secret key.
//
// Throws invalid_argument when the context reports that its parameters are not
// set, and logic_error when product_fits_in rejects the key-level sizes.
// set_secret_key may raise its own errors (not visible in this file).
Encryptor::Encryptor(const SEALContext &context, const SecretKey &secret_key) : context_(context)
{
    // Refuse to work with a context whose parameters were never validated.
    if (!context_.parameters_set())
    {
        throw invalid_argument("encryption parameters are not set correctly");
    }

    set_secret_key(secret_key);

    // Sizes are read from the key-level parameters of the context.
    auto &key_parms = context_.key_context_data()->parms();
    size_t poly_degree = key_parms.poly_modulus_degree();
    size_t modulus_count = key_parms.coeff_modulus().size();

    // Sanity check: degree * modulus count * 2 must pass product_fits_in
    // (presumably an overflow guard for later buffer-size computations).
    bool sizes_fit = product_fits_in(poly_degree, modulus_count, size_t(2));
    if (!sizes_fit)
    {
        throw logic_error("invalid parameters");
    }
}
// Constructs an Encryptor bound to `context` and configured with both a public
// key and a secret key. The public key is installed first, then the secret key.
//
// Throws invalid_argument when the context reports that its parameters are not
// set, and logic_error when product_fits_in rejects the key-level sizes.
// set_public_key / set_secret_key may raise their own errors (not visible here).
Encryptor::Encryptor(const SEALContext &context, const PublicKey &public_key, const SecretKey &secret_key)
    : context_(context)
{
    // Refuse to work with a context whose parameters were never validated.
    if (!context_.parameters_set())
    {
        throw invalid_argument("encryption parameters are not set correctly");
    }

    set_public_key(public_key);
    set_secret_key(secret_key);

    // Sizes are read from the key-level parameters of the context.
    auto &key_parms = context_.key_context_data()->parms();
    size_t poly_degree = key_parms.poly_modulus_degree();
    size_t modulus_count = key_parms.coeff_modulus().size();

    // Sanity check: degree * modulus count * 2 must pass product_fits_in
    // (presumably an overflow guard for later buffer-size computations).
    bool sizes_fit = product_fits_in(poly_degree, modulus_count, size_t(2));
    if (!sizes_fit)
    {
        throw logic_error("invalid parameters");
    }
}
// Convenience overload of encrypt_zero_internal for callers that do not need
// the exported components or a caller-supplied seed.
//
// Delegates to the full overload with disable_special_modulus = false,
// export_components = false, an empty seed, and two throw-away PolynomialArray
// objects standing in for the u / e outputs. `pool` is forwarded unchanged.
void Encryptor::encrypt_zero_internal(
    parms_id_type parms_id, bool is_asymmetric, bool save_seed, Ciphertext &destination,
    MemoryPoolHandle pool) const
{
    // Placeholders required by the full overload's signature; discarded on return.
    PolynomialArray scratch_u;
    PolynomialArray scratch_e;

    const bool disable_special_modulus = false;
    const bool export_components = false;

    encrypt_zero_internal(
        parms_id, is_asymmetric, save_seed, disable_special_modulus, export_components, destination, scratch_u,
        scratch_e, {}, pool);
}
// Produces an encryption of zero at the level identified by `parms_id` and
// stores it in `destination` (resized to two polynomials).
//
// Parameters:
//   parms_id                 Level of the resulting ciphertext; must be known to context_.
//   is_asymmetric            true  -> util::encrypt_zero_asymmetric with public_key_;
//                            false -> util::encrypt_zero_symmetric with secret_key_.
//   save_seed                Forwarded to the symmetric path only.
//   disable_special_modulus  Honoured for BFV only: when set, the asymmetric path encrypts
//                            directly at `parms_id` even if a previous (larger) level exists.
//   export_components        Forwarded to util::encrypt_zero_asymmetric together with
//   u_destination,           u_destination / e_destination / seed. None of these four are
//   e_destination, seed      used on the symmetric path.
//   pool                     Memory pool for the temporary ciphertext and the RNS tool calls.
//
// Throws invalid_argument if `pool` is uninitialized, `parms_id` is unknown to the
// context, or the scheme is not CKKS, BFV or BGV.
void Encryptor::encrypt_zero_internal(
    parms_id_type parms_id, bool is_asymmetric, bool save_seed, bool disable_special_modulus,
    bool export_components, Ciphertext &destination, PolynomialArray &u_destination,
    PolynomialArray &e_destination, optional<prng_seed_type> seed, MemoryPoolHandle pool) const
{
    // Verify parameters.
    if (!pool)
    {
        throw invalid_argument("pool is uninitialized");
    }

    auto context_data_ptr = context_.get_context_data(parms_id);
    if (!context_data_ptr)
    {
        throw invalid_argument("parms_id is not valid for encryption parameters");
    }

    // Reuse the pointer validated above instead of looking the context data up a second time.
    auto &context_data = *context_data_ptr;
    auto &parms = context_data.parms();
    size_t coeff_modulus_size = parms.coeff_modulus().size();
    size_t coeff_count = parms.poly_modulus_degree();

    // CKKS ciphertexts are produced in NTT form; BFV and BGV ones are not.
    bool is_ntt_form = false;
    if (parms.scheme() == scheme_type::ckks)
    {
        is_ntt_form = true;
    }
    else if (parms.scheme() != scheme_type::bfv && parms.scheme() != scheme_type::bgv)
    {
        throw invalid_argument("unsupported scheme");
    }

    // Resize destination to hold a fresh two-polynomial ciphertext at parms_id.
    destination.resize(context_, parms_id, 2);

    if (is_asymmetric)
    {
        auto prev_context_data_ptr = context_data.prev_context_data();
        auto disable_special = disable_special_modulus && parms.scheme() == scheme_type::bfv;
        if (prev_context_data_ptr && !disable_special)
        {
            // A previous level exists: encrypt there into a temporary, then bring every
            // polynomial down to parms_id with that level's RNS tool.
            auto &prev_context_data = *prev_context_data_ptr;
            auto &prev_parms_id = prev_context_data.parms_id();
            auto rns_tool = prev_context_data.rns_tool();

            Ciphertext temp(pool);
            util::encrypt_zero_asymmetric(
                public_key_, context_, prev_parms_id, is_ntt_form, export_components, temp, u_destination,
                e_destination, seed);

            SEAL_ITERATE(iter(temp, destination), temp.size(), [&](auto I) {
                if (is_ntt_form)
                {
                    // CKKS
                    rns_tool->divide_and_round_q_last_ntt_inplace(
                        get<0>(I), prev_context_data.small_ntt_tables(), pool);
                }
                else if (parms.scheme() != scheme_type::bgv)
                {
                    // BFV
                    rns_tool->divide_and_round_q_last_inplace(get<0>(I), pool);
                }
                else
                {
                    // BGV
                    rns_tool->mod_t_and_divide_q_last_inplace(get<0>(I), pool);
                }
                // Copy the reduced polynomial (coeff_modulus_size components) into destination.
                set_poly(get<0>(I), coeff_count, coeff_modulus_size, get<1>(I));
            });

            // Carry the ciphertext metadata over from the temporary.
            destination.parms_id() = parms_id;
            destination.is_ntt_form() = is_ntt_form;
            destination.scale() = temp.scale();
            destination.correction_factor() = temp.correction_factor();
        }
        else
        {
            // No previous level (or it was explicitly bypassed): encrypt directly at parms_id.
            util::encrypt_zero_asymmetric(
                public_key_, context_, parms_id, is_ntt_form, export_components, destination, u_destination,
                e_destination, seed);
        }
    }
    else
    {
        // Secret-key encryption is always performed directly at parms_id.
        util::encrypt_zero_symmetric(secret_key_, context_, parms_id, is_ntt_form, save_seed, destination);
    }
}
// Convenience overload of encrypt_internal for callers that do not need the
// exported components, the remainder plaintext, or a caller-supplied seed.
//
// Delegates to the full overload with disable_special_modulus = false,
// export_components = false, an empty seed, and throw-away objects standing in
// for the u / e / remainder outputs.
//
// `pool` is forwarded explicitly so that the caller's memory pool is actually
// used by the underlying encryption, matching the encrypt_zero_internal wrapper.
void Encryptor::encrypt_internal(
    const Plaintext &plain, bool is_asymmetric, bool save_seed, Ciphertext &destination,
    MemoryPoolHandle pool) const
{
    // Placeholders required by the full overload's signature; discarded on return.
    PolynomialArray u_destination;
    PolynomialArray e_destination;
    Plaintext remainder_destination;

    encrypt_internal(
        plain, is_asymmetric, save_seed, false, false, destination, u_destination, e_destination,
        remainder_destination, {}, pool);
}
void Encryptor::encrypt_internal(
const Plaintext &plain,
bool is_asymmetric,
bool save_seed,
bool disable_special_modulus,
bool export_components,
Ciphertext &destination,
PolynomialArray &u_destination,
PolynomialArray &e_destination,
Plaintext &remainder_destination,
optional<prng_seed_type> seed,
MemoryPoolHandle pool
) const
{
if (is_asymmetric)
{
if (!is_metadata_valid_for(public_key_, context_))
{
throw logic_error("public key is not set");
}
}
else
{
if (!is_metadata_valid_for(secret_key_, context_))
{
throw logic_error("secret key is not set");
}
}
if (!is_valid_for(plain, context_))
{
throw invalid_argument("plain is not valid for encryption parameters");
}
auto scheme = context_.key_context_data()->parms().scheme();
if (scheme == scheme_type::bfv)
{
if (plain.is_ntt_form())
{
throw invalid_argument("plain cannot be in NTT form");
}
encrypt_zero_internal(
context_.first_parms_id(),
is_asymmetric,
save_seed,
disable_special_modulus,
export_components,
destination,
u_destination,
e_destination,
seed,
pool
);
if (export_components) {
multiply_add_plain_with_scaling_variant(plain, *context_.first_context_data(), *iter(destination), true, remainder_destination);
} else {
multiply_add_plain_with_scaling_variant(plain, *context_.first_context_data(), *iter(destination));
}
}
else if (scheme == scheme_type::ckks)
{
if (!plain.is_ntt_form())
{
throw invalid_argument("plain must be in NTT form");
}
auto context_data_ptr = context_.get_context_data(plain.parms_id());
if (!context_data_ptr)
{
throw invalid_argument("plain is not valid for encryption parameters");
}
encrypt_zero_internal(
plain.parms_id(),
is_asymmetric,
save_seed,
disable_special_modulus,
export_components,
destination,
u_destination,
e_destination,
seed,
pool
);
auto &parms = context_.get_context_data(plain.parms_id())->parms();
auto &coeff_modulus = parms.coeff_modulus();
size_t coeff_modulus_size = coeff_modulus.size();
size_t coeff_count = parms.poly_modulus_degree();
ConstRNSIter plain_iter(plain.data(), coeff_count);
RNSIter destination_iter = *iter(destination);
add_poly_coeffmod(destination_iter, plain_iter, coeff_modulus_size, coeff_modulus, destination_iter);
destination.scale() = plain.scale();
}
else if (scheme == scheme_type::bgv)
{
if (plain.is_ntt_form())
{
throw invalid_argument("plain cannot be in NTT form");
}
encrypt_zero_internal(
context_.first_parms_id(),
is_asymmetric,
save_seed,
disable_special_modulus,
export_components,
destination,
u_destination,
e_destination,
seed,
pool
);
auto context_data_ptr = context_.first_context_data();
auto &parms = context_data_ptr->parms();
size_t coeff_count = parms.poly_modulus_degree();
add_plain_without_scaling_variant(plain, *context_data_ptr, RNSIter(destination.data(0), coeff_count));
}
else
{
throw invalid_argument("unsupported scheme");
}
}
}