#include "seal/batchencoder.h"
#include "seal/valcheck.h"
#include "seal/util/common.h"
#include <algorithm>
#include <limits>
#include <random>
#include <stdexcept>
using namespace std;
using namespace seal::util;
namespace seal
{
// Binds the encoder to `context` and precomputes the lookup tables used by
// encode/decode.
//
// Throws invalid_argument when the context's parameters are not set, when the
// first context's scheme is neither BFV nor BGV, or when its qualifiers report
// that batching is not usable.
BatchEncoder::BatchEncoder(const SEALContext &context) : context_(context)
{
    if (!context_.parameters_set())
    {
        throw invalid_argument("encryption parameters are not set correctly");
    }

    auto &first_data = *context_.first_context_data();
    const auto scheme_kind = first_data.parms().scheme();
    const bool scheme_supported = (scheme_kind == scheme_type::bfv) || (scheme_kind == scheme_type::bgv);
    if (!scheme_supported)
    {
        throw invalid_argument("unsupported scheme");
    }
    if (!first_data.qualifiers().using_batching)
    {
        throw invalid_argument("encryption parameters are not valid for batching");
    }

    // One slot per coefficient of the plaintext polynomial.
    slots_ = first_data.parms().poly_modulus_degree();
    roots_of_unity_ = allocate_uint(slots_, pool_);
    populate_roots_of_unity_vector(first_data);
    populate_matrix_reps_index_map();
}
// Fills roots_of_unity_ with the odd powers of the plaintext NTT root:
// roots_of_unity_[i] = root^(2*i + 1) mod plain_modulus for i in [0, slots_).
// Consecutive entries differ by a factor of root^2.
void BatchEncoder::populate_roots_of_unity_vector(const SEALContext::ContextData &context_data)
{
    const uint64_t base_root = context_data.plain_ntt_tables()->get_root();
    auto &plain_mod = context_data.parms().plain_modulus();

    // Multiplicative step between consecutive odd powers of the root.
    const uint64_t step = multiply_uint_mod(base_root, base_root, plain_mod);

    uint64_t current = base_root;
    roots_of_unity_[0] = current;
    for (size_t idx = 1; idx < slots_; idx++)
    {
        current = multiply_uint_mod(current, step, plain_mod);
        roots_of_unity_[idx] = current;
    }
}
// Builds matrix_reps_index_map_, which maps a slot of the 2 x (slots_/2)
// plaintext matrix to the coefficient position that holds it. Slots
// [0, slots_/2) form the first row and [slots_/2, slots_) the second row.
//
// For the i-th column, the first row uses exponent e = 3^i mod 2N and the
// second row uses 2N - e (N = slots_). Each exponent is turned into the index
// (exponent - 1) / 2, which is then bit-reversed over log2(slots_) bits.
void BatchEncoder::populate_matrix_reps_index_map()
{
    const int log_slots = get_power_of_two(slots_);
    matrix_reps_index_map_ = allocate<size_t>(slots_, pool_);

    const size_t half = slots_ >> 1;
    const size_t two_n = slots_ << 1;
    // slots_ is assumed to be a power of two (as get_power_of_two requires), so
    // masking with 2N - 1 reduces modulo 2N.
    const uint64_t mask = two_n - 1;

    uint64_t exponent = 1; // 3^i mod 2N
    for (size_t i = 0; i < half; i++)
    {
        const uint64_t top_idx = (exponent - 1) >> 1;
        const uint64_t bottom_idx = (two_n - exponent - 1) >> 1;

        matrix_reps_index_map_[i] = safe_cast<size_t>(util::reverse_bits(top_idx, log_slots));
        matrix_reps_index_map_[half | i] = safe_cast<size_t>(util::reverse_bits(bottom_idx, log_slots));

        exponent = (exponent * 3) & mask;
    }
}
// Permutes `input` in place into bit-reversed order. Its length is taken as the
// first context's poly_modulus_degree; each index is swapped with its bit
// reversal over log2(length) bits. In SEAL_DEBUG builds a null `input` throws
// invalid_argument.
void BatchEncoder::reverse_bits(uint64_t *input)
{
#ifdef SEAL_DEBUG
    if (!input)
    {
        throw invalid_argument("input cannot be null");
    }
#endif
    const size_t n = context_.first_context_data()->parms().poly_modulus_degree();
    const int bit_count = get_power_of_two(n);
    for (size_t idx = 0; idx < n; idx++)
    {
        const uint64_t partner = util::reverse_bits(idx, bit_count);
        // Swap each pair exactly once, from its smaller index.
        if (partner > idx)
        {
            swap(input[idx], input[partner]);
        }
    }
}
void BatchEncoder::encode(const vector<uint64_t> &values_matrix, Plaintext &destination) const
{
auto &context_data = *context_.first_context_data();
size_t values_matrix_size = values_matrix.size();
if (values_matrix_size > slots_)
{
throw invalid_argument("values_matrix size is too large");
}
#ifdef SEAL_DEBUG
uint64_t modulus = context_data.parms().plain_modulus().value();
for (auto v : values_matrix)
{
if (v >= modulus)
{
throw invalid_argument("input value is larger than plain_modulus");
}
}
#endif
destination.resize(slots_);
destination.parms_id() = parms_id_zero;
for (size_t i = 0; i < values_matrix_size; i++)
{
*(destination.data() + matrix_reps_index_map_[i]) = values_matrix[i];
}
for (size_t i = values_matrix_size; i < slots_; i++)
{
*(destination.data() + matrix_reps_index_map_[i]) = 0;
}
inverse_ntt_negacyclic_harvey(destination.data(), *context_data.plain_ntt_tables());
}
// Encodes up to slots_ signed values into `destination`. Values fill the first
// matrix row, then the second; slots beyond values_matrix.size() are zero.
// A negative value v is stored as plain_modulus + v, computed in unsigned
// arithmetic. `destination` is resized to slots_ coefficients and its parms_id
// is set to parms_id_zero. The slot values are then mapped to polynomial
// coefficients with the inverse negacyclic NTT of the plaintext NTT tables.
//
// Throws invalid_argument if values_matrix has more than slots_ entries and,
// in SEAL_DEBUG builds only, if |v| > plain_modulus / 2 for any value.
void BatchEncoder::encode(const vector<int64_t> &values_matrix, Plaintext &destination) const
{
    auto &context_data = *context_.first_context_data();
    uint64_t modulus = context_data.parms().plain_modulus().value();

    size_t values_matrix_size = values_matrix.size();
    if (values_matrix_size > slots_)
    {
        throw invalid_argument("values_matrix size is too large");
    }
#ifdef SEAL_DEBUG
    uint64_t plain_modulus_div_two = modulus >> 1;
    for (auto v : values_matrix)
    {
        // Compute |v| in unsigned arithmetic. llabs(INT64_MIN) is undefined
        // behavior; unsigned negation is well defined and yields 2^63 for it.
        uint64_t magnitude = (v < 0) ? (uint64_t{ 0 } - static_cast<uint64_t>(v)) : static_cast<uint64_t>(v);
        if (magnitude > plain_modulus_div_two)
        {
            throw invalid_argument("input value is larger than plain_modulus");
        }
    }
#endif

    destination.resize(slots_);
    destination.parms_id() = parms_id_zero;

    // Write the input slots to their (already bit-reversed) coefficient positions.
    for (size_t i = 0; i < values_matrix_size; i++)
    {
        *(destination.data() + matrix_reps_index_map_[i]) =
            (values_matrix[i] < 0) ? (modulus + static_cast<uint64_t>(values_matrix[i]))
                                   : static_cast<uint64_t>(values_matrix[i]);
    }
    // Zero the remaining slots.
    for (size_t i = values_matrix_size; i < slots_; i++)
    {
        *(destination.data() + matrix_reps_index_map_[i]) = 0;
    }

    inverse_ntt_negacyclic_harvey(destination.data(), *context_data.plain_ntt_tables());
}
#ifdef SEAL_USE_MSGSL
// Span overload of the unsigned encoder. Values fill the first matrix row, then
// the second; slots beyond values_matrix.size() are zero. `destination` is
// resized to slots_ coefficients and its parms_id is set to parms_id_zero. The
// slot values are then mapped to polynomial coefficients with the inverse
// negacyclic NTT of the plaintext NTT tables.
//
// Throws invalid_argument if values_matrix has more than slots_ entries and,
// in SEAL_DEBUG builds only, if any value is >= plain_modulus.
void BatchEncoder::encode(gsl::span<const uint64_t> values_matrix, Plaintext &destination) const
{
    auto &first_data = *context_.first_context_data();

    const size_t input_count = static_cast<size_t>(values_matrix.size());
    if (input_count > slots_)
    {
        throw invalid_argument("values_matrix size is too large");
    }
#ifdef SEAL_DEBUG
    const uint64_t plain_mod_value = first_data.parms().plain_modulus().value();
    for (const uint64_t value : values_matrix)
    {
        if (value >= plain_mod_value)
        {
            throw invalid_argument("input value is larger than plain_modulus");
        }
    }
#endif

    destination.resize(slots_);
    destination.parms_id() = parms_id_zero;

    // Scatter every slot to its (already bit-reversed) coefficient position;
    // slots with no input value become zero.
    auto *coeffs = destination.data();
    for (size_t slot = 0; slot < slots_; slot++)
    {
        coeffs[matrix_reps_index_map_[slot]] = (slot < input_count) ? values_matrix[slot] : 0;
    }

    inverse_ntt_negacyclic_harvey(coeffs, *first_data.plain_ntt_tables());
}
// Span overload of the signed encoder. Values fill the first matrix row, then
// the second; slots beyond values_matrix.size() are zero. A negative value v is
// stored as plain_modulus + v, computed in unsigned arithmetic. `destination`
// is resized to slots_ coefficients and its parms_id is set to parms_id_zero.
// The slot values are then mapped to polynomial coefficients with the inverse
// negacyclic NTT of the plaintext NTT tables.
//
// Throws invalid_argument if values_matrix has more than slots_ entries and,
// in SEAL_DEBUG builds only, if |v| > plain_modulus / 2 for any value.
void BatchEncoder::encode(gsl::span<const int64_t> values_matrix, Plaintext &destination) const
{
    auto &context_data = *context_.first_context_data();
    uint64_t modulus = context_data.parms().plain_modulus().value();

    size_t values_matrix_size = static_cast<size_t>(values_matrix.size());
    if (values_matrix_size > slots_)
    {
        throw invalid_argument("values_matrix size is too large");
    }
#ifdef SEAL_DEBUG
    uint64_t plain_modulus_div_two = modulus >> 1;
    for (auto v : values_matrix)
    {
        // Compute |v| in unsigned arithmetic. llabs(INT64_MIN) is undefined
        // behavior; unsigned negation is well defined and yields 2^63 for it.
        uint64_t magnitude = (v < 0) ? (uint64_t{ 0 } - static_cast<uint64_t>(v)) : static_cast<uint64_t>(v);
        if (magnitude > plain_modulus_div_two)
        {
            throw invalid_argument("input value is larger than plain_modulus");
        }
    }
#endif

    destination.resize(slots_);
    destination.parms_id() = parms_id_zero;

    // Write the input slots to their (already bit-reversed) coefficient positions.
    for (size_t i = 0; i < values_matrix_size; i++)
    {
        *(destination.data() + matrix_reps_index_map_[i]) =
            (values_matrix[i] < 0) ? (modulus + static_cast<uint64_t>(values_matrix[i]))
                                   : static_cast<uint64_t>(values_matrix[i]);
    }
    // Zero the remaining slots.
    for (size_t i = values_matrix_size; i < slots_; i++)
    {
        *(destination.data() + matrix_reps_index_map_[i]) = 0;
    }

    inverse_ntt_negacyclic_harvey(destination.data(), *context_data.plain_ntt_tables());
}
#endif
// Decodes `plain` into slots_ unsigned values: first matrix row, then second.
// At most slots_ coefficients of `plain` are copied into a zero-padded scratch
// buffer allocated from `pool`. The buffer is transformed with the negacyclic
// NTT of the plaintext NTT tables and then read back through
// matrix_reps_index_map_. `destination` is resized to slots_.
//
// Throws invalid_argument if `plain` is not valid for the context, if it is in
// NTT form, or if `pool` is uninitialized.
void BatchEncoder::decode(const Plaintext &plain, vector<uint64_t> &destination, MemoryPoolHandle pool) const
{
    if (!is_valid_for(plain, context_))
    {
        throw invalid_argument("plain is not valid for encryption parameters");
    }
    if (plain.is_ntt_form())
    {
        throw invalid_argument("plain cannot be in NTT form");
    }
    if (!pool)
    {
        throw invalid_argument("pool is uninitialized");
    }

    auto &first_data = *context_.first_context_data();
    destination.resize(slots_);

    // Zero-padded working copy of the plaintext coefficients.
    const size_t copy_count = min(plain.coeff_count(), slots_);
    auto scratch(allocate_uint(slots_, pool));
    uint64_t *work = scratch.get();
    set_uint(plain.data(), copy_count, work);
    set_zero_uint(slots_ - copy_count, work + copy_count);

    ntt_negacyclic_harvey(work, *first_data.plain_ntt_tables());

    for (size_t slot = 0; slot < slots_; slot++)
    {
        destination[slot] = work[matrix_reps_index_map_[slot]];
    }
}
// Decodes `plain` into slots_ signed values: first matrix row, then second.
// Decoding works like the unsigned overload, except that each slot residue
// greater than plain_modulus / 2 is mapped to residue - plain_modulus.
// `destination` is resized to slots_.
//
// Throws invalid_argument if `plain` is not valid for the context, if it is in
// NTT form, or if `pool` is uninitialized.
void BatchEncoder::decode(const Plaintext &plain, vector<int64_t> &destination, MemoryPoolHandle pool) const
{
    if (!is_valid_for(plain, context_))
    {
        throw invalid_argument("plain is not valid for encryption parameters");
    }
    if (plain.is_ntt_form())
    {
        throw invalid_argument("plain cannot be in NTT form");
    }
    if (!pool)
    {
        throw invalid_argument("pool is uninitialized");
    }

    auto &first_data = *context_.first_context_data();
    const uint64_t plain_mod_value = first_data.parms().plain_modulus().value();
    destination.resize(slots_);

    // Zero-padded working copy of the plaintext coefficients.
    const size_t copy_count = min(plain.coeff_count(), slots_);
    auto scratch(allocate_uint(slots_, pool));
    uint64_t *work = scratch.get();
    set_uint(plain.data(), copy_count, work);
    set_zero_uint(slots_ - copy_count, work + copy_count);

    ntt_negacyclic_harvey(work, *first_data.plain_ntt_tables());

    const uint64_t half_mod = plain_mod_value >> 1;
    const int64_t signed_mod = static_cast<int64_t>(plain_mod_value);
    for (size_t slot = 0; slot < slots_; slot++)
    {
        const uint64_t residue = work[matrix_reps_index_map_[slot]];
        int64_t centered = static_cast<int64_t>(residue);
        // Residues in the upper half represent negative numbers.
        if (residue > half_mod)
        {
            centered -= signed_mod;
        }
        destination[slot] = centered;
    }
}
#ifdef SEAL_USE_MSGSL
void BatchEncoder::decode(const Plaintext &plain, gsl::span<uint64_t> destination, MemoryPoolHandle pool) const
{
if (!is_valid_for(plain, context_))
{
throw invalid_argument("plain is not valid for encryption parameters");
}
if (plain.is_ntt_form())
{
throw invalid_argument("plain cannot be in NTT form");
}
if (!pool)
{
throw invalid_argument("pool is uninitialized");
}
auto &context_data = *context_.first_context_data();
if (unsigned_gt(destination.size(), numeric_limits<int>::max()) || unsigned_neq(destination.size(), slots_))
{
throw invalid_argument("destination has incorrect size");
}
size_t plain_coeff_count = min(plain.coeff_count(), slots_);
auto temp_dest(allocate_uint(slots_, pool));
set_uint(plain.data(), plain_coeff_count, temp_dest.get());
set_zero_uint(slots_ - plain_coeff_count, temp_dest.get() + plain_coeff_count);
ntt_negacyclic_harvey(temp_dest.get(), *context_data.plain_ntt_tables());
for (size_t i = 0; i < slots_; i++)
{
destination[i] = temp_dest[matrix_reps_index_map_[i]];
}
}
void BatchEncoder::decode(const Plaintext &plain, gsl::span<int64_t> destination, MemoryPoolHandle pool) const
{
if (!is_valid_for(plain, context_))
{
throw invalid_argument("plain is not valid for encryption parameters");
}
if (plain.is_ntt_form())
{
throw invalid_argument("plain cannot be in NTT form");
}
if (!pool)
{
throw invalid_argument("pool is uninitialized");
}
auto &context_data = *context_.first_context_data();
uint64_t modulus = context_data.parms().plain_modulus().value();
if (unsigned_gt(destination.size(), numeric_limits<int>::max()) || unsigned_neq(destination.size(), slots_))
{
throw invalid_argument("destination has incorrect size");
}
size_t plain_coeff_count = min(plain.coeff_count(), slots_);
auto temp_dest(allocate_uint(slots_, pool));
set_uint(plain.data(), plain_coeff_count, temp_dest.get());
set_zero_uint(slots_ - plain_coeff_count, temp_dest.get() + plain_coeff_count);
ntt_negacyclic_harvey(temp_dest.get(), *context_data.plain_ntt_tables());
uint64_t plain_modulus_div_two = modulus >> 1;
for (size_t i = 0; i < slots_; i++)
{
uint64_t curr_value = temp_dest[matrix_reps_index_map_[i]];
destination[i] = (curr_value > plain_modulus_div_two)
? (static_cast<int64_t>(curr_value) - static_cast<int64_t>(modulus))
: static_cast<int64_t>(curr_value);
}
}
#endif
}