1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
// Implementation of MulPIR using the `fhe` crate.
//
// MulPIR is a Private Information Retrieval scheme that enables a client to
// retrieve a row from a database without revealing the index to the server.
// MulPIR is described in <https://eprint.iacr.org/2019/1483>.
// We use the same parameters as in the paper to enable an apples-to-apples
// comparison.
mod pir;
mod util;
use clap::Parser;
use fhe::bfv;
use fhe_traits::{
DeserializeParametrized, FheDecoder, FheDecrypter, FheEncoder, FheEncrypter, Serialize,
};
use fhe_util::{inverse, transcode_to_bytes};
use indicatif::HumanBytes;
use rand::{rng, RngCore};
use std::{error::Error, time::Instant};
use util::{
encode_database, generate_database, number_elements_per_plaintext,
timeit::{timeit, timeit_n},
};
fn main() -> Result<(), Box<dyn Error>> {
// We use the parameters reported in Table 1 of https://eprint.iacr.org/2019/1483.pdf.
let degree = 8192;
let plaintext_modulus: u64 = (1 << 20) + (1 << 19) + (1 << 17) + (1 << 16) + (1 << 14) + 1;
let moduli_sizes = [50, 55, 55];
let args = pir::Cli::parse();
let database_size = args.database_size;
let elements_size = args.element_size;
let mut rng = rng();
// Compute what is the maximum byte-length of an element to fit within one
// ciphertext. Each coefficient of the ciphertext polynomial can contain
// floor(log2(plaintext_modulus)) bits.
let max_element_size = ((plaintext_modulus.ilog2() as usize) * degree) / 8;
if elements_size > max_element_size || elements_size == 0 || database_size == 0 {
log::error!("Invalid parameters: database_size = {database_size}, elements_size = {elements_size}. The maximum element size if {max_element_size}.");
clap::Error::new(clap::error::ErrorKind::InvalidValue).exit();
}
// The parameters are within bound, let's go! Let's first display some
// information about the database.
println!("# MulPIR with fhe.rs");
println!(
"database of {}",
HumanBytes((database_size * elements_size) as u64)
);
println!("\tdatabase_size = {database_size}");
println!("\telements_size = {elements_size}");
// Generation of a random database.
let database = timeit!("Database generation", {
generate_database(database_size, elements_size)
});
// Let's generate the BFV parameters structure.
let params = timeit!(
"Parameters generation",
bfv::BfvParametersBuilder::new()
.set_degree(degree)
.set_plaintext_modulus(plaintext_modulus)
.set_moduli_sizes(&moduli_sizes)
.build_arc()?
);
// Proprocess the database on the server side: the database will be reshaped
// so as to pack as many values as possible in every row so that it fits in one
// ciphertext, and each element will be encoded as a polynomial in Ntt
// representation.
let (preprocessed_database, (dim1, dim2)) = timeit!("Database preprocessing", {
encode_database(&database, params.clone(), 1)
});
// Client setup: the client generates a secret key, an evaluation key for
// the server will which enable to obliviously expand a ciphertext up to (dim1 +
// dim2) values, i.e. with expansion level ceil(log2(dim1 + dim2)), and a
// relinearization key.
let (sk, ek_expansion_serialized, rk_serialized) = timeit!("Client setup", {
let sk = bfv::SecretKey::random(¶ms, &mut rng);
let level = (dim1 + dim2).next_power_of_two().ilog2() as usize;
println!("level = {level}");
let ek_expansion = bfv::EvaluationKeyBuilder::new_leveled(&sk, 1, 0)?
.enable_expansion(level)?
.build(&mut rng)?;
let rk = bfv::RelinearizationKey::new_leveled(&sk, 1, 1, &mut rng)?;
let ek_expansion_serialized = ek_expansion.to_bytes();
let rk_serialized = rk.to_bytes();
(sk, ek_expansion_serialized, rk_serialized)
});
println!(
"📄 Evaluation key (expansion): {}",
HumanBytes(ek_expansion_serialized.len() as u64)
);
println!(
"📄 Relinearization key: {}",
HumanBytes(rk_serialized.len() as u64)
);
// Server setup: the server receives the evaluation and relinearization keys and
// deserializes them.
let (ek_expansion, rk) = timeit!("Server setup", {
(
bfv::EvaluationKey::from_bytes(&ek_expansion_serialized, ¶ms)?,
bfv::RelinearizationKey::from_bytes(&rk_serialized, ¶ms)?,
)
});
// Client query: when the client wants to retrieve the `index`-th row of the
// original database, it first computes to which row it corresponds in the
// original database, and then encrypt a selection vector with 0 everywhere,
// except at two indices i and (dim1 + j) such that `query_index = i * dim 2 +
// j` where it sets the value (2^level)^(-1) modulo the plaintext space.
// It then encodes this vector as a `polynomial` and encrypt the plaintext.
// The ciphertext is set at level `1`, which means that one of the three moduli
// has been dropped already; the reason is that the expansion will happen at
// level 0 (with all three moduli) and then one of the moduli will be dropped
// to reduce the noise.
let index = (rng.next_u64() as usize) % database_size;
let query = timeit!("Client query", {
let level = (dim1 + dim2).next_power_of_two().ilog2();
let query_index = index
/ number_elements_per_plaintext(
params.degree(),
plaintext_modulus.ilog2() as usize,
elements_size,
);
let mut pt = vec![0u64; dim1 + dim2];
let inv = inverse(1 << level, plaintext_modulus).ok_or("No inverse")?;
pt[query_index / dim2] = inv;
pt[dim1 + (query_index % dim2)] = inv;
let query_pt = bfv::Plaintext::try_encode(&pt, bfv::Encoding::poly_at_level(1), ¶ms)?;
let query: bfv::Ciphertext = sk.try_encrypt(&query_pt, &mut rng)?;
query.to_bytes()
});
println!("📄 Query: {}", HumanBytes(query.len() as u64));
// Server response: The server receives the query, and after deserializing it,
// performs the following steps:
// 1- It expands the query ciphertext into `dim1 + dim2` ciphertexts.
// If the client created the query correctly, the server will have obtained
// `dim1 + dim2` ciphertexts all encrypting `0`, expect the `i`th and
// `dim1 + j`th ones encrypting `1`.
// 2- It computes the inner product of the first `dim1` ciphertexts with the
// columns if the database viewed as a dim1 * dim2 matrix.
// 3- It then multiplies the column of ciphertexts with the next `dim2`
// ciphertexts obtained after expansion of the query, then relinearize and
// modulus switch to the latest modulus to optimize communication.
// The operation is done `5` times to compute an average response time.
let response = timeit_n!("Server response", 5, {
let start = Instant::now();
let query = bfv::Ciphertext::from_bytes(&query, ¶ms)?;
let expanded_query = ek_expansion.expands(&query, dim1 + dim2)?;
println!("Expand: {:?}", start.elapsed());
let query_vec = &expanded_query[..dim1];
let dot_product_mod_switch =
move |i, database: &[bfv::Plaintext]| -> fhe::Result<bfv::Ciphertext> {
let column = database.iter().skip(i).step_by(dim2);
bfv::dot_product_scalar(query_vec.iter(), column)
};
let mut out = bfv::Ciphertext::zero(¶ms);
for (i, ci) in expanded_query[dim1..].iter().enumerate() {
out += &(&dot_product_mod_switch(i, &preprocessed_database)? * ci)
}
rk.relinearizes(&mut out)?;
out.switch_to_level(out.max_switchable_level())?;
out.to_bytes()
});
println!("📄 Response: {}", HumanBytes(response.len() as u64));
// Client processing: Upon reception of the response, the client decrypts.
// Finally, it outputs the plaintext bytes, offset by the correct value
// (remember the database was reshaped to maximize how many elements) were
// embedded in a single ciphertext.
let answer = timeit!("Client answer", {
let response = bfv::Ciphertext::from_bytes(&response, ¶ms)?;
let pt = sk.try_decrypt(&response)?;
let pt = Vec::<u64>::try_decode(&pt, bfv::Encoding::poly_at_level(2))?;
let plaintext = transcode_to_bytes(&pt, plaintext_modulus.ilog2() as usize);
let offset = index
% number_elements_per_plaintext(
params.degree(),
plaintext_modulus.ilog2() as usize,
elements_size,
);
println!("Noise in response: {:?}", unsafe {
sk.measure_noise(&response)
});
plaintext[offset * elements_size..(offset + 1) * elements_size].to_vec()
});
assert_eq!(&database[index], &answer);
Ok(())
}