chalametpir_client/
client.rs1use chalametpir_common::{
2 binary_fuse_filter::{self, BinaryFuseFilter},
3 branch_opt_util,
4 error::ChalametPIRError,
5 matrix::Matrix,
6 params::{HASHED_KEY_BYTE_LEN, LWE_DIMENSION, SEED_BYTE_LEN},
7 serialization,
8};
9use std::collections::HashMap;
10
11#[derive(Clone)]
13pub struct Query {
14 vec_c: Matrix,
15}
16
17#[derive(Clone)]
21pub struct Client {
22 pub_mat_a: Matrix,
23 hint_mat_m: Matrix,
24 filter: BinaryFuseFilter,
25 pending_queries: HashMap<Vec<u8>, Query>,
26}
27
28impl Client {
29 pub fn setup(seed_μ: &[u8; SEED_BYTE_LEN], hint_bytes: &[u8], filter_param_bytes: &[u8]) -> Result<Client, ChalametPIRError> {
40 let filter = BinaryFuseFilter::from_bytes(filter_param_bytes)?;
41
42 let pub_mat_a_num_rows = LWE_DIMENSION;
43 let pub_mat_a_num_cols = filter.num_fingerprints as u32;
44
45 let pub_mat_a = Matrix::generate_from_seed(pub_mat_a_num_rows, pub_mat_a_num_cols, seed_μ)?;
46 let hint_mat_m = Matrix::from_bytes(hint_bytes)?;
47 if branch_opt_util::unlikely(hint_mat_m.num_rows() != LWE_DIMENSION) {
48 return Err(ChalametPIRError::InvalidHintMatrix);
49 }
50
51 Ok(Client {
52 pub_mat_a,
53 hint_mat_m,
54 filter,
55 pending_queries: HashMap::new(),
56 })
57 }
58
59 #[cfg(feature = "mutate_internal_client_state")]
61 #[inline(always)]
62 pub fn discard_query(&mut self, key: &[u8]) -> Option<Query> {
63 self.pending_queries.remove(key)
64 }
65
66 #[cfg(feature = "mutate_internal_client_state")]
68 #[inline(always)]
69 pub fn insert_query(&mut self, key: &[u8], query: Query) {
70 self.pending_queries.insert(key.to_vec(), query);
71 }
72
73 pub fn query(&mut self, key: &[u8]) -> Result<Vec<u8>, ChalametPIRError> {
85 match self.filter.arity {
86 3 => self.query_for_3_wise_xor_filter(key),
87 4 => self.query_for_4_wise_xor_filter(key),
88 _ => {
89 branch_opt_util::cold();
90 Err(ChalametPIRError::UnsupportedArityForBinaryFuseFilter)
91 }
92 }
93 }
94
95 fn query_for_3_wise_xor_filter(&mut self, key: &[u8]) -> Result<Vec<u8>, ChalametPIRError> {
96 if branch_opt_util::unlikely(self.pending_queries.contains_key(key)) {
97 return Err(ChalametPIRError::PendingQueryExistsForKey);
98 }
99
100 let secret_vec_num_cols = LWE_DIMENSION;
101 let secret_vec_s = unsafe { Matrix::sample_from_uniform_ternary_dist(1, secret_vec_num_cols).unwrap_unchecked() };
102
103 let error_vector_num_cols = self.pub_mat_a.num_cols();
104 let error_vec_e = unsafe { Matrix::sample_from_uniform_ternary_dist(1, error_vector_num_cols).unwrap_unchecked() };
105
106 let mut query_vec_b = unsafe { ((&secret_vec_s * &self.pub_mat_a).unwrap_unchecked() + error_vec_e).unwrap_unchecked() };
107 let secret_vec_c = unsafe { (&secret_vec_s * &self.hint_mat_m).unwrap_unchecked() };
108
109 let hashed_key = binary_fuse_filter::hash_of_key(key);
110 let hash = binary_fuse_filter::mix256(&hashed_key, &self.filter.seed);
111 let (h0, h1, h2) = binary_fuse_filter::hash_batch_for_3_wise_xor_filter(hash, self.filter.segment_length, self.filter.segment_count_length);
112
113 let query_indicator = self.calculate_query_indicator();
114
115 let (added_val, flag) = query_vec_b[(0, h0 as usize)].overflowing_add(query_indicator);
116 if branch_opt_util::unlikely(flag) {
117 return Err(ChalametPIRError::ArithmeticOverflowAddingQueryIndicator);
118 } else {
119 query_vec_b[(0, h0 as usize)] = added_val;
120 }
121
122 let (added_val, flag) = query_vec_b[(0, h1 as usize)].overflowing_add(query_indicator);
123 if branch_opt_util::unlikely(flag) {
124 return Err(ChalametPIRError::ArithmeticOverflowAddingQueryIndicator);
125 } else {
126 query_vec_b[(0, h1 as usize)] = added_val;
127 }
128
129 let (added_val, flag) = query_vec_b[(0, h2 as usize)].overflowing_add(query_indicator);
130 if branch_opt_util::unlikely(flag) {
131 return Err(ChalametPIRError::ArithmeticOverflowAddingQueryIndicator);
132 } else {
133 query_vec_b[(0, h2 as usize)] = added_val;
134 }
135
136 let query_bytes = query_vec_b.to_bytes();
137 self.pending_queries.insert(key.to_vec(), Query { vec_c: secret_vec_c });
138
139 Ok(query_bytes)
140 }
141
142 fn query_for_4_wise_xor_filter(&mut self, key: &[u8]) -> Result<Vec<u8>, ChalametPIRError> {
143 if branch_opt_util::unlikely(self.pending_queries.contains_key(key)) {
144 return Err(ChalametPIRError::PendingQueryExistsForKey);
145 }
146
147 let secret_vec_num_cols = LWE_DIMENSION;
148 let secret_vec_s = unsafe { Matrix::sample_from_uniform_ternary_dist(1, secret_vec_num_cols).unwrap_unchecked() };
149
150 let error_vector_num_cols = self.pub_mat_a.num_cols();
151 let error_vec_e = unsafe { Matrix::sample_from_uniform_ternary_dist(1, error_vector_num_cols).unwrap_unchecked() };
152
153 let mut query_vec_b = unsafe { ((&secret_vec_s * &self.pub_mat_a).unwrap_unchecked() + error_vec_e).unwrap_unchecked() };
154 let secret_vec_c = unsafe { (&secret_vec_s * &self.hint_mat_m).unwrap_unchecked() };
155
156 let hashed_key = binary_fuse_filter::hash_of_key(key);
157 let hash = binary_fuse_filter::mix256(&hashed_key, &self.filter.seed);
158 let (h0, h1, h2, h3) = binary_fuse_filter::hash_batch_for_4_wise_xor_filter(hash, self.filter.segment_length, self.filter.segment_count_length);
159
160 let query_indicator = self.calculate_query_indicator();
161
162 let (added_val, flag) = query_vec_b[(0, h0 as usize)].overflowing_add(query_indicator);
163 if branch_opt_util::unlikely(flag) {
164 return Err(ChalametPIRError::ArithmeticOverflowAddingQueryIndicator);
165 } else {
166 query_vec_b[(0, h0 as usize)] = added_val;
167 }
168
169 let (added_val, flag) = query_vec_b[(0, h1 as usize)].overflowing_add(query_indicator);
170 if branch_opt_util::unlikely(flag) {
171 return Err(ChalametPIRError::ArithmeticOverflowAddingQueryIndicator);
172 } else {
173 query_vec_b[(0, h1 as usize)] = added_val;
174 }
175
176 let (added_val, flag) = query_vec_b[(0, h2 as usize)].overflowing_add(query_indicator);
177 if branch_opt_util::unlikely(flag) {
178 return Err(ChalametPIRError::ArithmeticOverflowAddingQueryIndicator);
179 } else {
180 query_vec_b[(0, h2 as usize)] = added_val;
181 }
182
183 let (added_val, flag) = query_vec_b[(0, h3 as usize)].overflowing_add(query_indicator);
184 if branch_opt_util::unlikely(flag) {
185 return Err(ChalametPIRError::ArithmeticOverflowAddingQueryIndicator);
186 } else {
187 query_vec_b[(0, h3 as usize)] = added_val;
188 }
189
190 let query_bytes = query_vec_b.to_bytes();
191 self.pending_queries.insert(key.to_vec(), Query { vec_c: secret_vec_c });
192
193 Ok(query_bytes)
194 }
195
196 pub fn process_response(&mut self, key: &[u8], response_bytes: &[u8]) -> Result<Vec<u8>, ChalametPIRError> {
210 match self.pending_queries.get(key) {
211 Some(query) => {
212 let secret_vec_c = &query.vec_c;
213
214 let response_vector = Matrix::from_bytes(response_bytes)?;
215 if branch_opt_util::unlikely(!(response_vector.num_rows() == 1 && response_vector.num_cols() == secret_vec_c.num_cols())) {
216 return Err(ChalametPIRError::InvalidResponseVector);
217 }
218
219 let rounding_factor = self.calculate_query_indicator();
220 let rounding_floor = rounding_factor / 2;
221 let mat_elem_mask = (1u32 << self.filter.mat_elem_bit_len) - 1;
222
223 let hashed_key = binary_fuse_filter::hash_of_key(key);
224 let hash = binary_fuse_filter::mix256(&hashed_key, &self.filter.seed);
225
226 let recovered_row = (0..response_vector.num_cols() as usize)
227 .map(|idx| {
228 let unscaled_res = response_vector[(0, idx)].wrapping_sub(secret_vec_c[(0, idx)]);
229
230 let scaled_res = unscaled_res / rounding_factor;
231 let scaled_rem = unscaled_res % rounding_factor;
232
233 let mut rounded_res = scaled_res;
234 if scaled_rem > rounding_floor {
235 rounded_res += 1;
236 }
237
238 let masked = rounded_res & mat_elem_mask;
239 masked.wrapping_add(binary_fuse_filter::mix(hash, idx as u64) as u32) & mat_elem_mask
240 })
241 .collect::<Vec<u32>>();
242
243 let value = match serialization::decode_kv_from_row(&recovered_row, self.filter.mat_elem_bit_len) {
244 Ok(mut decoded_kv) => {
245 let mut hashed_key_as_bytes = [0u8; HASHED_KEY_BYTE_LEN];
246
247 hashed_key_as_bytes[..8].copy_from_slice(&hashed_key[0].to_le_bytes());
248 hashed_key_as_bytes[8..16].copy_from_slice(&hashed_key[1].to_le_bytes());
249 hashed_key_as_bytes[16..24].copy_from_slice(&hashed_key[2].to_le_bytes());
250 hashed_key_as_bytes[24..].copy_from_slice(&hashed_key[3].to_le_bytes());
251
252 let is_key_matching = (0..hashed_key_as_bytes.len()).fold(0u8, |acc, idx| acc ^ (decoded_kv[idx] ^ hashed_key_as_bytes[idx])) == 0;
253
254 if branch_opt_util::likely(is_key_matching) {
255 decoded_kv.drain(..hashed_key_as_bytes.len());
256 Ok(decoded_kv)
257 } else {
258 Err(ChalametPIRError::DecodedRowNotPrependedWithDigestOfKey)
259 }
260 }
261 Err(e) => {
262 branch_opt_util::cold();
263 Err(e)
264 }
265 };
266
267 self.pending_queries.remove(key);
268 value
269 }
270 None => {
271 branch_opt_util::cold();
272 Err(ChalametPIRError::PendingQueryDoesNotExistForKey)
273 }
274 }
275 }
276
277 const fn calculate_query_indicator(&self) -> u32 {
278 const MODULUS: u64 = u32::MAX as u64 + 1;
279 let plaintext_modulo = 1u64 << self.filter.mat_elem_bit_len;
280
281 (MODULUS / plaintext_modulo) as u32
282 }
283}