#include "examples.h"
using namespace std;
using namespace seal;
/*
Benchmarks the batched integer pipeline (BatchEncoder, Encryptor, Decryptor,
Evaluator and Ciphertext serialization) for the given context. It prints the
average time of each operation in microseconds.

All timed regions use std::chrono::steady_clock rather than
high_resolution_clock. high_resolution_clock is not guaranteed to be monotonic
(several standard libraries alias it to system_clock), so a wall-clock
adjustment during a run could corrupt the measured averages.

Key switching:
- When context.using_keyswitching() is true, relinearization keys and Galois
  keys are generated and the time taken for each is printed.
- If the key-level context data does not support batching, the function
  prints a message and returns before generating Galois keys.
- Relinearization and rotations are only benchmarked when key switching is
  in use.

Throws std::runtime_error if the batch/unbatch round trip or the
encrypt/decrypt round trip does not reproduce its input.
*/
void bfv_performance_test(SEALContext context)
{
    // Monotonic clock: unaffected by wall-clock adjustments during the run.
    using bench_clock = chrono::steady_clock;

    // Runs `operation` once and returns its elapsed time in microseconds.
    auto measure = [](auto &&operation) {
        const auto start = bench_clock::now();
        operation();
        const auto stop = bench_clock::now();
        return chrono::duration_cast<chrono::microseconds>(stop - start);
    };

    print_parameters(context);
    cout << endl;

    auto &parms = context.first_context_data()->parms();
    auto &plain_modulus = parms.plain_modulus();
    size_t poly_modulus_degree = parms.poly_modulus_degree();

    cout << "Generating secret/public keys: ";
    KeyGenerator keygen(context);
    cout << "Done" << endl;

    auto secret_key = keygen.secret_key();
    PublicKey public_key;
    keygen.create_public_key(public_key);

    RelinKeys relin_keys;
    GaloisKeys gal_keys;
    if (context.using_keyswitching())
    {
        cout << "Generating relinearization keys: ";
        chrono::microseconds time_diff = measure([&] { keygen.create_relin_keys(relin_keys); });
        cout << "Done [" << time_diff.count() << " microseconds]" << endl;

        // The rest of the benchmark is built on BatchEncoder; stop here if
        // batching is unavailable.
        if (!context.key_context_data()->qualifiers().using_batching)
        {
            cout << "Given encryption parameters do not support batching." << endl;
            return;
        }

        cout << "Generating Galois keys: ";
        time_diff = measure([&] { keygen.create_galois_keys(gal_keys); });
        cout << "Done [" << time_diff.count() << " microseconds]" << endl;
    }

    Encryptor encryptor(context, public_key);
    Decryptor decryptor(context, secret_key);
    Evaluator evaluator(context);
    BatchEncoder batch_encoder(context);

    // Running totals per operation, summed across all iterations.
    chrono::microseconds time_batch_sum(0);
    chrono::microseconds time_unbatch_sum(0);
    chrono::microseconds time_encrypt_sum(0);
    chrono::microseconds time_decrypt_sum(0);
    chrono::microseconds time_add_sum(0);
    chrono::microseconds time_multiply_sum(0);
    chrono::microseconds time_multiply_plain_sum(0);
    chrono::microseconds time_square_sum(0);
    chrono::microseconds time_relinearize_sum(0);
    chrono::microseconds time_rotate_rows_one_step_sum(0);
    chrono::microseconds time_rotate_rows_random_sum(0);
    chrono::microseconds time_rotate_columns_sum(0);
    chrono::microseconds time_serialize_sum(0);
#ifdef SEAL_USE_ZLIB
    chrono::microseconds time_serialize_zlib_sum(0);
#endif
#ifdef SEAL_USE_ZSTD
    chrono::microseconds time_serialize_zstd_sum(0);
#endif

    // Number of benchmark iterations.
    const long long count = 10;

    // Random slot values, each reduced modulo the plain modulus.
    size_t slot_count = batch_encoder.slot_count();
    vector<uint64_t> pod_vector;
    pod_vector.reserve(slot_count);
    random_device rd;
    for (size_t i = 0; i < slot_count; i++)
    {
        pod_vector.push_back(plain_modulus.reduce(rd()));
    }

    cout << "Running tests ";
    for (size_t i = 0; i < static_cast<size_t>(count); i++)
    {
        Plaintext plain(poly_modulus_degree, 0);
        Plaintext plain1(poly_modulus_degree, 0);
        Plaintext plain2(poly_modulus_degree, 0);

        // [Batch / Unbatch] The round trip must reproduce the input exactly.
        time_batch_sum += measure([&] { batch_encoder.encode(pod_vector, plain); });
        vector<uint64_t> pod_vector2(slot_count);
        time_unbatch_sum += measure([&] { batch_encoder.decode(plain, pod_vector2); });
        if (pod_vector2 != pod_vector)
        {
            throw runtime_error("Batch/unbatch failed. Something is wrong.");
        }

        // [Encrypt / Decrypt] The round trip must reproduce the plaintext.
        Ciphertext encrypted(context);
        time_encrypt_sum += measure([&] { encryptor.encrypt(plain, encrypted); });
        time_decrypt_sum += measure([&] { decryptor.decrypt(encrypted, plain2); });
        if (plain2 != plain)
        {
            throw runtime_error("Encrypt/decrypt failed. Something is wrong.");
        }

        // [Add] Three additions per iteration; the average divides by 3 * count.
        Ciphertext encrypted1(context);
        batch_encoder.encode(vector<uint64_t>(slot_count, i), plain1);
        encryptor.encrypt(plain1, encrypted1);
        Ciphertext encrypted2(context);
        batch_encoder.encode(vector<uint64_t>(slot_count, i + 1), plain2);
        encryptor.encrypt(plain2, encrypted2);
        time_add_sum += measure([&] {
            evaluator.add_inplace(encrypted1, encrypted1);
            evaluator.add_inplace(encrypted2, encrypted2);
            evaluator.add_inplace(encrypted1, encrypted2);
        });

        // [Multiply] Capacity for three polynomials is reserved before the
        // timed region, presumably so that growing the product is not timed.
        encrypted1.reserve(3);
        time_multiply_sum += measure([&] { evaluator.multiply_inplace(encrypted1, encrypted2); });

        // [Multiply Plain]
        time_multiply_plain_sum += measure([&] { evaluator.multiply_plain_inplace(encrypted2, plain); });

        // [Square]
        time_square_sum += measure([&] { evaluator.square_inplace(encrypted2); });

        if (context.using_keyswitching())
        {
            // [Relinearize]
            time_relinearize_sum += measure([&] { evaluator.relinearize_inplace(encrypted1, relin_keys); });

            // [Rotate Rows One Step] Two rotations per iteration; the average
            // divides by 2 * count.
            time_rotate_rows_one_step_sum += measure([&] {
                evaluator.rotate_rows_inplace(encrypted, 1, gal_keys);
                evaluator.rotate_rows_inplace(encrypted, -1, gal_keys);
            });

            // [Rotate Rows Random] The mask keeps the step in [0, row_size).
            // Assumes row_size is a power of two.
            size_t row_size = batch_encoder.slot_count() / 2;
            int random_rotation = static_cast<int>(rd() & (row_size - 1));
            time_rotate_rows_random_sum +=
                measure([&] { evaluator.rotate_rows_inplace(encrypted, random_rotation, gal_keys); });

            // [Rotate Columns]
            time_rotate_columns_sum += measure([&] { evaluator.rotate_columns_inplace(encrypted, gal_keys); });
        }

        // [Serialize] Uncompressed, then with each compiled-in compressor.
        size_t buf_size = static_cast<size_t>(encrypted.save_size(compr_mode_type::none));
        vector<seal_byte> buf(buf_size);
        time_serialize_sum += measure([&] { encrypted.save(buf.data(), buf_size, compr_mode_type::none); });
#ifdef SEAL_USE_ZLIB
        buf_size = static_cast<size_t>(encrypted.save_size(compr_mode_type::zlib));
        buf.resize(buf_size);
        time_serialize_zlib_sum += measure([&] { encrypted.save(buf.data(), buf_size, compr_mode_type::zlib); });
#endif
#ifdef SEAL_USE_ZSTD
        buf_size = static_cast<size_t>(encrypted.save_size(compr_mode_type::zstd));
        buf.resize(buf_size);
        time_serialize_zstd_sum += measure([&] { encrypted.save(buf.data(), buf_size, compr_mode_type::zstd); });
#endif

        // Progress indicator: one dot per iteration.
        cout << ".";
        cout.flush();
    }
    cout << " Done" << endl << endl;
    cout.flush();

    // Averages use integer division, so results are truncated to whole microseconds.
    auto avg_batch = time_batch_sum.count() / count;
    auto avg_unbatch = time_unbatch_sum.count() / count;
    auto avg_encrypt = time_encrypt_sum.count() / count;
    auto avg_decrypt = time_decrypt_sum.count() / count;
    auto avg_add = time_add_sum.count() / (3 * count);
    auto avg_multiply = time_multiply_sum.count() / count;
    auto avg_multiply_plain = time_multiply_plain_sum.count() / count;
    auto avg_square = time_square_sum.count() / count;
    auto avg_relinearize = time_relinearize_sum.count() / count;
    auto avg_rotate_rows_one_step = time_rotate_rows_one_step_sum.count() / (2 * count);
    auto avg_rotate_rows_random = time_rotate_rows_random_sum.count() / count;
    auto avg_rotate_columns = time_rotate_columns_sum.count() / count;
    auto avg_serialize = time_serialize_sum.count() / count;
#ifdef SEAL_USE_ZLIB
    auto avg_serialize_zlib = time_serialize_zlib_sum.count() / count;
#endif
#ifdef SEAL_USE_ZSTD
    auto avg_serialize_zstd = time_serialize_zstd_sum.count() / count;
#endif

    cout << "Average batch: " << avg_batch << " microseconds" << endl;
    cout << "Average unbatch: " << avg_unbatch << " microseconds" << endl;
    cout << "Average encrypt: " << avg_encrypt << " microseconds" << endl;
    cout << "Average decrypt: " << avg_decrypt << " microseconds" << endl;
    cout << "Average add: " << avg_add << " microseconds" << endl;
    cout << "Average multiply: " << avg_multiply << " microseconds" << endl;
    cout << "Average multiply plain: " << avg_multiply_plain << " microseconds" << endl;
    cout << "Average square: " << avg_square << " microseconds" << endl;
    if (context.using_keyswitching())
    {
        cout << "Average relinearize: " << avg_relinearize << " microseconds" << endl;
        cout << "Average rotate rows one step: " << avg_rotate_rows_one_step << " microseconds" << endl;
        cout << "Average rotate rows random: " << avg_rotate_rows_random << " microseconds" << endl;
        cout << "Average rotate columns: " << avg_rotate_columns << " microseconds" << endl;
    }
    cout << "Average serialize ciphertext: " << avg_serialize << " microseconds" << endl;
#ifdef SEAL_USE_ZLIB
    cout << "Average compressed (ZLIB) serialize ciphertext: " << avg_serialize_zlib << " microseconds" << endl;
#endif
#ifdef SEAL_USE_ZSTD
    cout << "Average compressed (Zstandard) serialize ciphertext: " << avg_serialize_zstd << " microseconds" << endl;
#endif
    cout.flush();
}
/*
Benchmarks the CKKS pipeline (CKKSEncoder, Encryptor, Decryptor, Evaluator and
Ciphertext serialization) for the given context. It prints the average time of
each operation in microseconds.

All timed regions use std::chrono::steady_clock rather than
high_resolution_clock. high_resolution_clock is not guaranteed to be monotonic
(several standard libraries alias it to system_clock), so a wall-clock
adjustment during a run could corrupt the measured averages.

Key switching:
- When context.using_keyswitching() is true, relinearization keys and Galois
  keys are generated and the time taken for each is printed.
- If the first context data does not support batching, the function prints a
  message and returns before generating Galois keys.
- Relinearization, rescaling, rotations and conjugation are only benchmarked
  when key switching is in use.
*/
void ckks_performance_test(SEALContext context)
{
    // Monotonic clock: unaffected by wall-clock adjustments during the run.
    using bench_clock = chrono::steady_clock;

    // Runs `operation` once and returns its elapsed time in microseconds.
    auto measure = [](auto &&operation) {
        const auto start = bench_clock::now();
        operation();
        const auto stop = bench_clock::now();
        return chrono::duration_cast<chrono::microseconds>(stop - start);
    };

    print_parameters(context);
    cout << endl;

    auto &parms = context.first_context_data()->parms();
    size_t poly_modulus_degree = parms.poly_modulus_degree();

    cout << "Generating secret/public keys: ";
    KeyGenerator keygen(context);
    cout << "Done" << endl;

    auto secret_key = keygen.secret_key();
    PublicKey public_key;
    keygen.create_public_key(public_key);

    RelinKeys relin_keys;
    GaloisKeys gal_keys;
    if (context.using_keyswitching())
    {
        cout << "Generating relinearization keys: ";
        chrono::microseconds time_diff = measure([&] { keygen.create_relin_keys(relin_keys); });
        cout << "Done [" << time_diff.count() << " microseconds]" << endl;

        if (!context.first_context_data()->qualifiers().using_batching)
        {
            cout << "Given encryption parameters do not support batching." << endl;
            return;
        }

        cout << "Generating Galois keys: ";
        time_diff = measure([&] { keygen.create_galois_keys(gal_keys); });
        cout << "Done [" << time_diff.count() << " microseconds]" << endl;
    }

    Encryptor encryptor(context, public_key);
    Decryptor decryptor(context, secret_key);
    Evaluator evaluator(context);
    CKKSEncoder ckks_encoder(context);

    // Running totals per operation, summed across all iterations.
    chrono::microseconds time_encode_sum(0);
    chrono::microseconds time_decode_sum(0);
    chrono::microseconds time_encrypt_sum(0);
    chrono::microseconds time_decrypt_sum(0);
    chrono::microseconds time_add_sum(0);
    chrono::microseconds time_multiply_sum(0);
    chrono::microseconds time_multiply_plain_sum(0);
    chrono::microseconds time_square_sum(0);
    chrono::microseconds time_relinearize_sum(0);
    chrono::microseconds time_rescale_sum(0);
    chrono::microseconds time_rotate_one_step_sum(0);
    chrono::microseconds time_rotate_random_sum(0);
    chrono::microseconds time_conjugate_sum(0);
    chrono::microseconds time_serialize_sum(0);
#ifdef SEAL_USE_ZLIB
    chrono::microseconds time_serialize_zlib_sum(0);
#endif
#ifdef SEAL_USE_ZSTD
    chrono::microseconds time_serialize_zstd_sum(0);
#endif

    // Number of benchmark iterations.
    const long long count = 10;

    // Deterministic input: slot i holds 1.001 * i.
    vector<double> pod_vector;
    pod_vector.reserve(ckks_encoder.slot_count());
    random_device rd;
    for (size_t i = 0; i < ckks_encoder.slot_count(); i++)
    {
        pod_vector.push_back(1.001 * static_cast<double>(i));
    }

    cout << "Running tests ";
    for (long long i = 0; i < count; i++)
    {
        // [Encode / Decode] The scale is the square root of the last prime
        // in parms.coeff_modulus().
        Plaintext plain(parms.poly_modulus_degree() * parms.coeff_modulus().size(), 0);
        double scale = sqrt(static_cast<double>(parms.coeff_modulus().back().value()));
        time_encode_sum += measure([&] { ckks_encoder.encode(pod_vector, scale, plain); });

        vector<double> pod_vector2(ckks_encoder.slot_count());
        time_decode_sum += measure([&] { ckks_encoder.decode(plain, pod_vector2); });

        // [Encrypt / Decrypt]
        Ciphertext encrypted(context);
        time_encrypt_sum += measure([&] { encryptor.encrypt(plain, encrypted); });

        Plaintext plain2(poly_modulus_degree, 0);
        time_decrypt_sum += measure([&] { decryptor.decrypt(encrypted, plain2); });

        // [Add] Three additions per iteration; the average divides by 3 * count.
        // `i + 1` is a long long, so the integer encode overload is used.
        Ciphertext encrypted1(context);
        ckks_encoder.encode(i + 1, plain);
        encryptor.encrypt(plain, encrypted1);
        Ciphertext encrypted2(context);
        ckks_encoder.encode(i + 1, plain2);
        encryptor.encrypt(plain2, encrypted2);
        time_add_sum += measure([&] {
            evaluator.add_inplace(encrypted1, encrypted1);
            evaluator.add_inplace(encrypted2, encrypted2);
            evaluator.add_inplace(encrypted1, encrypted2);
        });

        // [Multiply] Capacity for three polynomials is reserved before the
        // timed region, presumably so that growing the product is not timed.
        encrypted1.reserve(3);
        time_multiply_sum += measure([&] { evaluator.multiply_inplace(encrypted1, encrypted2); });

        // [Multiply Plain]
        time_multiply_plain_sum += measure([&] { evaluator.multiply_plain_inplace(encrypted2, plain); });

        // [Square]
        time_square_sum += measure([&] { evaluator.square_inplace(encrypted2); });

        if (context.using_keyswitching())
        {
            // [Relinearize]
            time_relinearize_sum += measure([&] { evaluator.relinearize_inplace(encrypted1, relin_keys); });

            // [Rescale]
            time_rescale_sum += measure([&] { evaluator.rescale_to_next_inplace(encrypted1); });

            // [Rotate Vector One Step] Two rotations per iteration; the
            // average divides by 2 * count.
            time_rotate_one_step_sum += measure([&] {
                evaluator.rotate_vector_inplace(encrypted, 1, gal_keys);
                evaluator.rotate_vector_inplace(encrypted, -1, gal_keys);
            });

            // [Rotate Vector Random] The mask keeps the step in
            // [0, slot_count). Assumes slot_count is a power of two.
            int random_rotation = static_cast<int>(rd() & (ckks_encoder.slot_count() - 1));
            time_rotate_random_sum +=
                measure([&] { evaluator.rotate_vector_inplace(encrypted, random_rotation, gal_keys); });

            // [Complex Conjugate]
            time_conjugate_sum += measure([&] { evaluator.complex_conjugate_inplace(encrypted, gal_keys); });
        }

        // [Serialize] Uncompressed, then with each compiled-in compressor.
        size_t buf_size = static_cast<size_t>(encrypted.save_size(compr_mode_type::none));
        vector<seal_byte> buf(buf_size);
        time_serialize_sum += measure([&] { encrypted.save(buf.data(), buf_size, compr_mode_type::none); });
#ifdef SEAL_USE_ZLIB
        buf_size = static_cast<size_t>(encrypted.save_size(compr_mode_type::zlib));
        buf.resize(buf_size);
        time_serialize_zlib_sum += measure([&] { encrypted.save(buf.data(), buf_size, compr_mode_type::zlib); });
#endif
#ifdef SEAL_USE_ZSTD
        buf_size = static_cast<size_t>(encrypted.save_size(compr_mode_type::zstd));
        buf.resize(buf_size);
        time_serialize_zstd_sum += measure([&] { encrypted.save(buf.data(), buf_size, compr_mode_type::zstd); });
#endif

        // Progress indicator: one dot per iteration.
        cout << ".";
        cout.flush();
    }
    cout << " Done" << endl << endl;
    cout.flush();

    // Averages use integer division, so results are truncated to whole microseconds.
    auto avg_encode = time_encode_sum.count() / count;
    auto avg_decode = time_decode_sum.count() / count;
    auto avg_encrypt = time_encrypt_sum.count() / count;
    auto avg_decrypt = time_decrypt_sum.count() / count;
    auto avg_add = time_add_sum.count() / (3 * count);
    auto avg_multiply = time_multiply_sum.count() / count;
    auto avg_multiply_plain = time_multiply_plain_sum.count() / count;
    auto avg_square = time_square_sum.count() / count;
    auto avg_relinearize = time_relinearize_sum.count() / count;
    auto avg_rescale = time_rescale_sum.count() / count;
    auto avg_rotate_one_step = time_rotate_one_step_sum.count() / (2 * count);
    auto avg_rotate_random = time_rotate_random_sum.count() / count;
    auto avg_conjugate = time_conjugate_sum.count() / count;
    auto avg_serialize = time_serialize_sum.count() / count;
#ifdef SEAL_USE_ZLIB
    auto avg_serialize_zlib = time_serialize_zlib_sum.count() / count;
#endif
#ifdef SEAL_USE_ZSTD
    auto avg_serialize_zstd = time_serialize_zstd_sum.count() / count;
#endif

    cout << "Average encode: " << avg_encode << " microseconds" << endl;
    cout << "Average decode: " << avg_decode << " microseconds" << endl;
    cout << "Average encrypt: " << avg_encrypt << " microseconds" << endl;
    cout << "Average decrypt: " << avg_decrypt << " microseconds" << endl;
    cout << "Average add: " << avg_add << " microseconds" << endl;
    cout << "Average multiply: " << avg_multiply << " microseconds" << endl;
    cout << "Average multiply plain: " << avg_multiply_plain << " microseconds" << endl;
    cout << "Average square: " << avg_square << " microseconds" << endl;
    if (context.using_keyswitching())
    {
        cout << "Average relinearize: " << avg_relinearize << " microseconds" << endl;
        cout << "Average rescale: " << avg_rescale << " microseconds" << endl;
        cout << "Average rotate vector one step: " << avg_rotate_one_step << " microseconds" << endl;
        cout << "Average rotate vector random: " << avg_rotate_random << " microseconds" << endl;
        cout << "Average complex conjugate: " << avg_conjugate << " microseconds" << endl;
    }
    cout << "Average serialize ciphertext: " << avg_serialize << " microseconds" << endl;
#ifdef SEAL_USE_ZLIB
    cout << "Average compressed (ZLIB) serialize ciphertext: " << avg_serialize_zlib << " microseconds" << endl;
#endif
#ifdef SEAL_USE_ZSTD
    cout << "Average compressed (Zstandard) serialize ciphertext: " << avg_serialize_zstd << " microseconds" << endl;
#endif
    cout.flush();
}
/*
Benchmarks the batched BGV pipeline for the given context and prints the
average time of each operation in microseconds.

This function delegates to bfv_performance_test rather than keeping a
line-for-line duplicate of it. That routine:
- is written entirely against the scheme-agnostic SEALContext, BatchEncoder,
  Encryptor, Decryptor and Evaluator APIs;
- never branches on the scheme type;
- takes the scheme from `context`, so print_parameters still reports BGV and
  every benchmarked operation runs under BGV.

The early-return and exception behaviour is therefore identical:
- it returns early if key switching is enabled but batching is unsupported;
- it throws std::runtime_error if a batch/unbatch or encrypt/decrypt round
  trip fails.

If the BGV benchmark ever needs BGV-specific steps, reintroduce a dedicated
body here.
*/
void bgv_performance_test(SEALContext context)
{
    bfv_performance_test(std::move(context));
}
/*
Runs the BFV benchmark three times, once each for poly_modulus_degree 4096,
8192 and 16384. Every run uses CoeffModulus::BFVDefault for its degree and
plain modulus 786433. A blank line separates consecutive runs.
*/
void example_bfv_performance_default()
{
    print_example_banner("BFV Performance Test with Degrees: 4096, 8192, and 16384");

    EncryptionParameters parms(scheme_type::bfv);
    const size_t degrees[] = { 4096, 8192, 16384 };
    bool first_run = true;
    for (size_t degree : degrees)
    {
        // Separator goes between runs, not after the last one.
        if (!first_run)
        {
            cout << endl;
        }
        first_run = false;

        parms.set_poly_modulus_degree(degree);
        parms.set_coeff_modulus(CoeffModulus::BFVDefault(degree));
        parms.set_plain_modulus(786433);
        bfv_performance_test(parms);
    }
}
/*
Prompts on stdout for a poly_modulus_degree, reads it from stdin and runs the
BFV benchmark for that degree.

Input handling:
- If the value cannot be parsed, the stream error state is cleared, the rest
  of the line is discarded, "Invalid option." is printed and the function
  returns.
- Only powers of two in [1024, 32768] are accepted; any other value prints
  "Invalid option." and returns.

Degree 1024 uses plain modulus 12289; every other accepted degree uses 786433.
*/
void example_bfv_performance_custom()
{
    size_t degree = 0;
    cout << endl << "Set poly_modulus_degree (1024, 2048, 4096, 8192, 16384, or 32768): ";

    const bool parsed = static_cast<bool>(cin >> degree);
    if (!parsed)
    {
        cout << "Invalid option." << endl;
        // Reset the failed stream so later prompts can read again.
        cin.clear();
        cin.ignore(numeric_limits<streamsize>::max(), '\n');
        return;
    }

    const bool in_range = degree >= 1024 && degree <= 32768;
    const bool power_of_two = (degree & (degree - 1)) == 0;
    if (!(in_range && power_of_two))
    {
        cout << "Invalid option." << endl;
        return;
    }

    print_example_banner("BFV Performance Test with Degree: " + to_string(degree));

    EncryptionParameters parms(scheme_type::bfv);
    parms.set_poly_modulus_degree(degree);
    parms.set_coeff_modulus(CoeffModulus::BFVDefault(degree));
    parms.set_plain_modulus(degree == 1024 ? 12289 : 786433);
    bfv_performance_test(parms);
}
/*
Runs the CKKS benchmark once for each of the poly_modulus_degree values 4096,
8192 and 16384, in that order, on a single reused EncryptionParameters object.
A blank line is printed between consecutive runs.

NOTE(review): the coefficient moduli come from CoeffModulus::BFVDefault even
though the scheme is CKKS. This presumably is acceptable for timing purposes
only. Confirm before relying on these parameters for precision.
*/
void example_ckks_performance_default()
{
    print_example_banner("CKKS Performance Test with Degrees: 4096, 8192, and 16384");

    EncryptionParameters parms(scheme_type::ckks);
    const size_t degrees[] = { 4096, 8192, 16384 };
    bool first_run = true;
    for (size_t degree : degrees)
    {
        // Separator goes between runs only, never after the last one.
        if (!first_run)
        {
            cout << endl;
        }
        first_run = false;

        parms.set_poly_modulus_degree(degree);
        parms.set_coeff_modulus(CoeffModulus::BFVDefault(degree));
        ckks_performance_test(parms);
    }
}
/*
Prompts for one poly_modulus_degree on standard input. If the value is a power
of two in [1024, 32768], it runs the CKKS benchmark for that degree using
CoeffModulus::BFVDefault coefficient moduli. Input that cannot be parsed is
discarded up to the next newline. Any rejected value prints "Invalid option."
and returns without running anything.
*/
void example_ckks_performance_custom()
{
    size_t poly_modulus_degree = 0;
    cout << endl << "Set poly_modulus_degree (1024, 2048, 4096, 8192, 16384, or 32768): ";
    cin >> poly_modulus_degree;
    if (cin.fail())
    {
        // Recover the stream and throw away the rest of the offending line.
        cout << "Invalid option." << endl;
        cin.clear();
        cin.ignore(numeric_limits<streamsize>::max(), '\n');
        return;
    }

    const bool in_range = poly_modulus_degree >= 1024 && poly_modulus_degree <= 32768;
    const bool power_of_two = (poly_modulus_degree & (poly_modulus_degree - 1)) == 0;
    if (!(in_range && power_of_two))
    {
        cout << "Invalid option." << endl;
        return;
    }

    print_example_banner("CKKS Performance Test with Degree: " + to_string(poly_modulus_degree));

    EncryptionParameters parms(scheme_type::ckks);
    parms.set_poly_modulus_degree(poly_modulus_degree);
    parms.set_coeff_modulus(CoeffModulus::BFVDefault(poly_modulus_degree));

    ckks_performance_test(parms);
}
/*
Runs the BGV benchmark once for each of the poly_modulus_degree values 4096,
8192 and 16384, in that order. Every run reuses one EncryptionParameters
object, reconfigured with CoeffModulus::BFVDefault for the degree and a plain
modulus of 786433. A blank line is printed between consecutive runs.
*/
void example_bgv_performance_default()
{
    print_example_banner("BGV Performance Test with Degrees: 4096, 8192, and 16384");

    EncryptionParameters parms(scheme_type::bgv);
    const size_t degrees[] = { 4096, 8192, 16384 };
    bool first_run = true;
    for (size_t degree : degrees)
    {
        // Separator goes between runs only, never after the last one.
        if (!first_run)
        {
            cout << endl;
        }
        first_run = false;

        parms.set_poly_modulus_degree(degree);
        parms.set_coeff_modulus(CoeffModulus::BFVDefault(degree));
        parms.set_plain_modulus(786433);
        bgv_performance_test(parms);
    }
}
/*
Prompts for one poly_modulus_degree on standard input. If the value is a power
of two in [1024, 32768], it runs the BGV benchmark for that degree. Coefficient
moduli come from CoeffModulus::BFVDefault. The plain modulus is 12289 for
degree 1024 and 786433 for every other degree. Input that cannot be parsed is
discarded up to the next newline. Any rejected value prints "Invalid option."
and returns without running anything.
*/
void example_bgv_performance_custom()
{
    size_t poly_modulus_degree = 0;
    cout << endl << "Set poly_modulus_degree (1024, 2048, 4096, 8192, 16384, or 32768): ";
    cin >> poly_modulus_degree;
    if (cin.fail())
    {
        // Recover the stream and throw away the rest of the offending line.
        cout << "Invalid option." << endl;
        cin.clear();
        cin.ignore(numeric_limits<streamsize>::max(), '\n');
        return;
    }

    const bool in_range = poly_modulus_degree >= 1024 && poly_modulus_degree <= 32768;
    const bool power_of_two = (poly_modulus_degree & (poly_modulus_degree - 1)) == 0;
    if (!(in_range && power_of_two))
    {
        cout << "Invalid option." << endl;
        return;
    }

    print_example_banner("BGV Performance Test with Degree: " + to_string(poly_modulus_degree));

    EncryptionParameters parms(scheme_type::bgv);
    parms.set_poly_modulus_degree(poly_modulus_degree);
    parms.set_coeff_modulus(CoeffModulus::BFVDefault(poly_modulus_degree));
    // Degree 1024 gets a dedicated plain modulus; all larger degrees share one.
    parms.set_plain_modulus(poly_modulus_degree == 1024 ? 12289 : 786433);

    bgv_performance_test(parms);
}
/*
Interactive menu for the performance examples. It repeatedly prompts on
standard input and dispatches to the matching BFV/CKKS/BGV runner. Entering 0
returns to the caller. Non-numeric input is discarded up to the next newline,
and the menu is shown again. If standard input is exhausted (end-of-file) or
the stream is in an unrecoverable (bad) state, the menu returns instead of
re-prompting, because no further input can arrive. The stream state is left
as-is so the caller can observe it.
*/
void example_performance_test()
{
    print_example_banner("Example: Performance Test");

    while (true)
    {
        cout << endl;
        cout << "Select a scheme (and optionally poly_modulus_degree):" << endl;
        cout << " 1. BFV with default degrees" << endl;
        cout << " 2. BFV with a custom degree" << endl;
        cout << " 3. CKKS with default degrees" << endl;
        cout << " 4. CKKS with a custom degree" << endl;
        cout << " 5. BGV with default degrees" << endl;
        cout << " 6. BGV with a custom degree" << endl;
        cout << " 0. Back to main menu" << endl;

        int selection = 0;
        cout << endl << "> Run performance test (1 ~ 6) or go back (0): ";
        if (!(cin >> selection))
        {
            // Once stdin has hit end-of-file (e.g. piped input ran out) or gone
            // bad, every later read fails at once. Clearing and retrying would
            // spin this loop forever printing "Invalid option.", so leave instead.
            if (cin.eof() || cin.bad())
            {
                cout << endl;
                return;
            }
            cout << "Invalid option." << endl;
            cin.clear();
            cin.ignore(numeric_limits<streamsize>::max(), '\n');
            continue;
        }

        switch (selection)
        {
        case 1:
            example_bfv_performance_default();
            break;
        case 2:
            example_bfv_performance_custom();
            break;
        case 3:
            example_ckks_performance_default();
            break;
        case 4:
            example_ckks_performance_custom();
            break;
        case 5:
            example_bgv_performance_default();
            break;
        case 6:
            example_bgv_performance_custom();
            break;
        case 0:
            cout << endl;
            return;
        default:
            cout << "Invalid option." << endl;
        }
    }
}