chalametpir_client/
client.rs

1use chalametpir_common::{
2    binary_fuse_filter::{self, BinaryFuseFilter},
3    branch_opt_util,
4    error::ChalametPIRError,
5    matrix::Matrix,
6    params::{HASHED_KEY_BYTE_LEN, LWE_DIMENSION, SEED_BYTE_LEN},
7    serialization,
8};
9use std::collections::HashMap;
10
11/// Represents a PIR query. This struct is used to store secret vector `c`, which is used to recover value from PIR server response.
12#[derive(Clone)]
13pub struct Query {
14    vec_c: Matrix,
15}
16
17/// Represents a client, performing Chalamet Private Information Retrieval (PIR) queries.
18///
19/// This struct holds the necessary data and methods for setting up a PIR client, generating PIR queries, processing PIR server response, privately fetching value associated with queried key.
20#[derive(Clone)]
21pub struct Client {
22    pub_mat_a: Matrix,
23    hint_mat_m: Matrix,
24    filter: BinaryFuseFilter,
25    pending_queries: HashMap<Vec<u8>, Query>,
26}
27
28impl Client {
29    /// Sets up a new keyword **P**rivate **I**nformation **R**etrieval client instance.
30    ///
31    /// This function initializes a client object with the necessary parameters for performing private information retrieval (PIR) queries.
32    /// It takes as input:
33    ///
34    /// * `seed_μ`: A byte array representing the seed for generating the public matrix A.  The length is determined by `SEED_BYTE_LEN`.
35    /// * `hint_bytes`: A byte array representing the hint matrix M. This matrix is used to help reconstruct the result of the PIR query.
36    /// * `filter_param_bytes`: A byte array containing the parameters for the underlying binary fuse filter in-use.
37    ///
38    /// Errors can occur if the `BinaryFuseFilter` cannot be constructed from the provided bytes, or if matrix generation fails.  These errors will result in a `ChalametPIRError` being returned.
39    pub fn setup(seed_μ: &[u8; SEED_BYTE_LEN], hint_bytes: &[u8], filter_param_bytes: &[u8]) -> Result<Client, ChalametPIRError> {
40        let filter = BinaryFuseFilter::from_bytes(filter_param_bytes)?;
41
42        let pub_mat_a_num_rows = LWE_DIMENSION;
43        let pub_mat_a_num_cols = filter.num_fingerprints as u32;
44
45        let pub_mat_a = Matrix::generate_from_seed(pub_mat_a_num_rows, pub_mat_a_num_cols, seed_μ)?;
46        let hint_mat_m = Matrix::from_bytes(hint_bytes)?;
47        if branch_opt_util::unlikely(hint_mat_m.num_rows() != LWE_DIMENSION) {
48            return Err(ChalametPIRError::InvalidHintMatrix);
49        }
50
51        Ok(Client {
52            pub_mat_a,
53            hint_mat_m,
54            filter,
55            pending_queries: HashMap::new(),
56        })
57    }
58
59    /// Used only for benchmarking. You are not supposed to use this.
60    #[cfg(feature = "mutate_internal_client_state")]
61    #[inline(always)]
62    pub fn discard_query(&mut self, key: &[u8]) -> Option<Query> {
63        self.pending_queries.remove(key)
64    }
65
66    /// Used only for benchmarking. You are not supposed to use this.
67    #[cfg(feature = "mutate_internal_client_state")]
68    #[inline(always)]
69    pub fn insert_query(&mut self, key: &[u8], query: Query) {
70        self.pending_queries.insert(key.to_vec(), query);
71    }
72
73    /// Generates a PIR query for the specified key.
74    ///
75    /// The query is added to the client's pending queries, awaiting a response. If a query for the same key already exists, this function returns an error.
76    ///
77    /// # Arguments
78    ///
79    /// * `key`: The key to query.
80    ///
81    /// # Returns
82    ///
83    /// `Result<Vec<u8>, ChalametPIRError>` containing the query bytes if successful, or an error if a query for the same key already exists or if arithmetic overflow occurs during query generation.
84    pub fn query(&mut self, key: &[u8]) -> Result<Vec<u8>, ChalametPIRError> {
85        match self.filter.arity {
86            3 => self.query_for_3_wise_xor_filter(key),
87            4 => self.query_for_4_wise_xor_filter(key),
88            _ => {
89                branch_opt_util::cold();
90                Err(ChalametPIRError::UnsupportedArityForBinaryFuseFilter)
91            }
92        }
93    }
94
95    fn query_for_3_wise_xor_filter(&mut self, key: &[u8]) -> Result<Vec<u8>, ChalametPIRError> {
96        if branch_opt_util::unlikely(self.pending_queries.contains_key(key)) {
97            return Err(ChalametPIRError::PendingQueryExistsForKey);
98        }
99
100        let secret_vec_num_cols = LWE_DIMENSION;
101        let secret_vec_s = unsafe { Matrix::sample_from_uniform_ternary_dist(1, secret_vec_num_cols).unwrap_unchecked() };
102
103        let error_vector_num_cols = self.pub_mat_a.num_cols();
104        let error_vec_e = unsafe { Matrix::sample_from_uniform_ternary_dist(1, error_vector_num_cols).unwrap_unchecked() };
105
106        let mut query_vec_b = unsafe { ((&secret_vec_s * &self.pub_mat_a).unwrap_unchecked() + error_vec_e).unwrap_unchecked() };
107        let secret_vec_c = unsafe { (&secret_vec_s * &self.hint_mat_m).unwrap_unchecked() };
108
109        let hashed_key = binary_fuse_filter::hash_of_key(key);
110        let hash = binary_fuse_filter::mix256(&hashed_key, &self.filter.seed);
111        let (h0, h1, h2) = binary_fuse_filter::hash_batch_for_3_wise_xor_filter(hash, self.filter.segment_length, self.filter.segment_count_length);
112
113        let query_indicator = self.calculate_query_indicator();
114
115        let (added_val, flag) = query_vec_b[(0, h0 as usize)].overflowing_add(query_indicator);
116        if branch_opt_util::unlikely(flag) {
117            return Err(ChalametPIRError::ArithmeticOverflowAddingQueryIndicator);
118        } else {
119            query_vec_b[(0, h0 as usize)] = added_val;
120        }
121
122        let (added_val, flag) = query_vec_b[(0, h1 as usize)].overflowing_add(query_indicator);
123        if branch_opt_util::unlikely(flag) {
124            return Err(ChalametPIRError::ArithmeticOverflowAddingQueryIndicator);
125        } else {
126            query_vec_b[(0, h1 as usize)] = added_val;
127        }
128
129        let (added_val, flag) = query_vec_b[(0, h2 as usize)].overflowing_add(query_indicator);
130        if branch_opt_util::unlikely(flag) {
131            return Err(ChalametPIRError::ArithmeticOverflowAddingQueryIndicator);
132        } else {
133            query_vec_b[(0, h2 as usize)] = added_val;
134        }
135
136        let query_bytes = query_vec_b.to_bytes();
137        self.pending_queries.insert(key.to_vec(), Query { vec_c: secret_vec_c });
138
139        Ok(query_bytes)
140    }
141
142    fn query_for_4_wise_xor_filter(&mut self, key: &[u8]) -> Result<Vec<u8>, ChalametPIRError> {
143        if branch_opt_util::unlikely(self.pending_queries.contains_key(key)) {
144            return Err(ChalametPIRError::PendingQueryExistsForKey);
145        }
146
147        let secret_vec_num_cols = LWE_DIMENSION;
148        let secret_vec_s = unsafe { Matrix::sample_from_uniform_ternary_dist(1, secret_vec_num_cols).unwrap_unchecked() };
149
150        let error_vector_num_cols = self.pub_mat_a.num_cols();
151        let error_vec_e = unsafe { Matrix::sample_from_uniform_ternary_dist(1, error_vector_num_cols).unwrap_unchecked() };
152
153        let mut query_vec_b = unsafe { ((&secret_vec_s * &self.pub_mat_a).unwrap_unchecked() + error_vec_e).unwrap_unchecked() };
154        let secret_vec_c = unsafe { (&secret_vec_s * &self.hint_mat_m).unwrap_unchecked() };
155
156        let hashed_key = binary_fuse_filter::hash_of_key(key);
157        let hash = binary_fuse_filter::mix256(&hashed_key, &self.filter.seed);
158        let (h0, h1, h2, h3) = binary_fuse_filter::hash_batch_for_4_wise_xor_filter(hash, self.filter.segment_length, self.filter.segment_count_length);
159
160        let query_indicator = self.calculate_query_indicator();
161
162        let (added_val, flag) = query_vec_b[(0, h0 as usize)].overflowing_add(query_indicator);
163        if branch_opt_util::unlikely(flag) {
164            return Err(ChalametPIRError::ArithmeticOverflowAddingQueryIndicator);
165        } else {
166            query_vec_b[(0, h0 as usize)] = added_val;
167        }
168
169        let (added_val, flag) = query_vec_b[(0, h1 as usize)].overflowing_add(query_indicator);
170        if branch_opt_util::unlikely(flag) {
171            return Err(ChalametPIRError::ArithmeticOverflowAddingQueryIndicator);
172        } else {
173            query_vec_b[(0, h1 as usize)] = added_val;
174        }
175
176        let (added_val, flag) = query_vec_b[(0, h2 as usize)].overflowing_add(query_indicator);
177        if branch_opt_util::unlikely(flag) {
178            return Err(ChalametPIRError::ArithmeticOverflowAddingQueryIndicator);
179        } else {
180            query_vec_b[(0, h2 as usize)] = added_val;
181        }
182
183        let (added_val, flag) = query_vec_b[(0, h3 as usize)].overflowing_add(query_indicator);
184        if branch_opt_util::unlikely(flag) {
185            return Err(ChalametPIRError::ArithmeticOverflowAddingQueryIndicator);
186        } else {
187            query_vec_b[(0, h3 as usize)] = added_val;
188        }
189
190        let query_bytes = query_vec_b.to_bytes();
191        self.pending_queries.insert(key.to_vec(), Query { vec_c: secret_vec_c });
192
193        Ok(query_bytes)
194    }
195
196    /// Processes a response to a PIR query.
197    ///
198    /// This function takes the key associated with a pending query and the received response bytes as input.
199    /// It reconstructs the original data from the response, removes the query from the pending queries, and returns the result.
200    ///
201    /// # Arguments
202    ///
203    /// * `key`: The key associated with the query.
204    /// * `response_bytes`: The bytes received as a response to the query.
205    ///
206    /// # Returns
207    ///
208    /// `Result<Vec<u8>, ChalametPIRError>` containing the retrieved data if successful, or an error if the response vector has an unexpected dimension, if decoding fails, or if the query is not found in `pending_queries`.
209    pub fn process_response(&mut self, key: &[u8], response_bytes: &[u8]) -> Result<Vec<u8>, ChalametPIRError> {
210        match self.pending_queries.get(key) {
211            Some(query) => {
212                let secret_vec_c = &query.vec_c;
213
214                let response_vector = Matrix::from_bytes(response_bytes)?;
215                if branch_opt_util::unlikely(!(response_vector.num_rows() == 1 && response_vector.num_cols() == secret_vec_c.num_cols())) {
216                    return Err(ChalametPIRError::InvalidResponseVector);
217                }
218
219                let rounding_factor = self.calculate_query_indicator();
220                let rounding_floor = rounding_factor / 2;
221                let mat_elem_mask = (1u32 << self.filter.mat_elem_bit_len) - 1;
222
223                let hashed_key = binary_fuse_filter::hash_of_key(key);
224                let hash = binary_fuse_filter::mix256(&hashed_key, &self.filter.seed);
225
226                let recovered_row = (0..response_vector.num_cols() as usize)
227                    .map(|idx| {
228                        let unscaled_res = response_vector[(0, idx)].wrapping_sub(secret_vec_c[(0, idx)]);
229
230                        let scaled_res = unscaled_res / rounding_factor;
231                        let scaled_rem = unscaled_res % rounding_factor;
232
233                        let mut rounded_res = scaled_res;
234                        if scaled_rem > rounding_floor {
235                            rounded_res += 1;
236                        }
237
238                        let masked = rounded_res & mat_elem_mask;
239                        masked.wrapping_add(binary_fuse_filter::mix(hash, idx as u64) as u32) & mat_elem_mask
240                    })
241                    .collect::<Vec<u32>>();
242
243                let value = match serialization::decode_kv_from_row(&recovered_row, self.filter.mat_elem_bit_len) {
244                    Ok(mut decoded_kv) => {
245                        let mut hashed_key_as_bytes = [0u8; HASHED_KEY_BYTE_LEN];
246
247                        hashed_key_as_bytes[..8].copy_from_slice(&hashed_key[0].to_le_bytes());
248                        hashed_key_as_bytes[8..16].copy_from_slice(&hashed_key[1].to_le_bytes());
249                        hashed_key_as_bytes[16..24].copy_from_slice(&hashed_key[2].to_le_bytes());
250                        hashed_key_as_bytes[24..].copy_from_slice(&hashed_key[3].to_le_bytes());
251
252                        let is_key_matching = (0..hashed_key_as_bytes.len()).fold(0u8, |acc, idx| acc ^ (decoded_kv[idx] ^ hashed_key_as_bytes[idx])) == 0;
253
254                        if branch_opt_util::likely(is_key_matching) {
255                            decoded_kv.drain(..hashed_key_as_bytes.len());
256                            Ok(decoded_kv)
257                        } else {
258                            Err(ChalametPIRError::DecodedRowNotPrependedWithDigestOfKey)
259                        }
260                    }
261                    Err(e) => {
262                        branch_opt_util::cold();
263                        Err(e)
264                    }
265                };
266
267                self.pending_queries.remove(key);
268                value
269            }
270            None => {
271                branch_opt_util::cold();
272                Err(ChalametPIRError::PendingQueryDoesNotExistForKey)
273            }
274        }
275    }
276
277    const fn calculate_query_indicator(&self) -> u32 {
278        const MODULUS: u64 = u32::MAX as u64 + 1;
279        let plaintext_modulo = 1u64 << self.filter.mat_elem_bit_len;
280
281        (MODULUS / plaintext_modulo) as u32
282    }
283}