1use crate::hash::hash;
48use blake3::Hasher;
49use rand::RngCore;
50use serde::{Deserialize, Serialize};
51use std::collections::HashSet;
52use thiserror::Error;
53
54#[derive(Error, Debug)]
55pub enum PsiError {
56 #[error("Invalid PSI message")]
57 InvalidMessage,
58 #[error("Serialization error: {0}")]
59 Serialization(String),
60 #[error("Empty set provided")]
61 EmptySet,
62}
63
64pub type PsiResult<T> = Result<T, PsiError>;
65
66#[derive(Clone, Serialize, Deserialize)]
68pub struct PsiServerMessage {
69 hashed_elements: Vec<Vec<u8>>,
71 key_commitment: Vec<u8>,
73}
74
75pub struct PsiServer {
77 secret_key: [u8; 32],
78}
79
80impl PsiServer {
81 pub fn new() -> Self {
83 let mut secret_key = [0u8; 32];
84 rand::rngs::OsRng.fill_bytes(&mut secret_key);
85 Self { secret_key }
86 }
87
88 pub fn with_key(key: [u8; 32]) -> Self {
90 Self { secret_key: key }
91 }
92
93 pub fn encode_set(&self, elements: &[Vec<u8>]) -> PsiServerMessage {
95 let hashed_elements = elements
96 .iter()
97 .map(|elem| self.hash_element(elem))
98 .collect();
99
100 let key_commitment = hash(&self.secret_key).to_vec();
102
103 PsiServerMessage {
104 hashed_elements,
105 key_commitment,
106 }
107 }
108
109 fn hash_element(&self, element: &[u8]) -> Vec<u8> {
111 let mut hasher = Hasher::new();
112 hasher.update(&self.secret_key);
113 hasher.update(element);
114 hasher.finalize().as_bytes().to_vec()
115 }
116
117 pub fn secret_key(&self) -> &[u8; 32] {
119 &self.secret_key
120 }
121}
122
123impl Default for PsiServer {
124 fn default() -> Self {
125 Self::new()
126 }
127}
128
129pub struct PsiClient {
131 #[allow(dead_code)]
132 secret_key: [u8; 32],
133}
134
135impl PsiClient {
136 pub fn new() -> Self {
138 let mut secret_key = [0u8; 32];
139 rand::rngs::OsRng.fill_bytes(&mut secret_key);
140 Self { secret_key }
141 }
142
143 pub fn with_key(key: [u8; 32]) -> Self {
145 Self { secret_key: key }
146 }
147
148 pub fn compute_intersection(
150 &self,
151 client_elements: &[Vec<u8>],
152 server_msg: &PsiServerMessage,
153 ) -> PsiResult<Vec<Vec<u8>>> {
154 if client_elements.is_empty() {
155 return Ok(Vec::new());
156 }
157
158 let server_set: HashSet<&[u8]> = server_msg
160 .hashed_elements
161 .iter()
162 .map(|v| v.as_slice())
163 .collect();
164
165 let mut intersection = Vec::new();
167 for elem in client_elements {
168 let hashed = self.hash_element(elem, &server_msg.key_commitment);
169 if server_set.contains(hashed.as_slice()) {
170 intersection.push(elem.clone());
171 }
172 }
173
174 Ok(intersection)
175 }
176
177 fn hash_element(&self, element: &[u8], server_key_commitment: &[u8]) -> Vec<u8> {
179 let mut hasher = Hasher::new();
182 hasher.update(server_key_commitment);
183 hasher.update(element);
184 hasher.finalize().as_bytes().to_vec()
185 }
186}
187
188impl Default for PsiClient {
189 fn default() -> Self {
190 Self::new()
191 }
192}
193
194#[derive(Clone, Serialize, Deserialize)]
196pub struct BloomPsiMessage {
197 filter: Vec<u8>,
199 num_hashes: usize,
201 filter_size: usize,
203}
204
205pub struct BloomPsiServer {
207 num_hashes: usize,
208 filter_size: usize,
209}
210
211impl BloomPsiServer {
212 pub fn new(expected_items: usize, false_positive_rate: f64) -> Self {
218 let filter_size = Self::optimal_filter_size(expected_items, false_positive_rate);
220 let num_hashes = Self::optimal_num_hashes(filter_size, expected_items);
221
222 Self {
223 num_hashes,
224 filter_size,
225 }
226 }
227
228 fn optimal_filter_size(n: usize, p: f64) -> usize {
230 let ln2_squared = std::f64::consts::LN_2 * std::f64::consts::LN_2;
231 (-(n as f64 * p.ln()) / ln2_squared).ceil() as usize
232 }
233
234 fn optimal_num_hashes(m: usize, n: usize) -> usize {
236 ((m as f64 / n as f64) * std::f64::consts::LN_2).ceil() as usize
237 }
238
239 pub fn encode_set(&self, elements: &[Vec<u8>]) -> BloomPsiMessage {
241 let filter_bytes = self.filter_size.div_ceil(8);
242 let mut filter = vec![0u8; filter_bytes];
243
244 for elem in elements {
245 let indices = self.hash_element(elem);
246 for idx in indices {
247 let byte_idx = idx / 8;
248 let bit_idx = idx % 8;
249 filter[byte_idx] |= 1 << bit_idx;
250 }
251 }
252
253 BloomPsiMessage {
254 filter,
255 num_hashes: self.num_hashes,
256 filter_size: self.filter_size,
257 }
258 }
259
260 fn hash_element(&self, element: &[u8]) -> Vec<usize> {
262 let mut indices = Vec::with_capacity(self.num_hashes);
263 let base_hash = hash(element);
264
265 for i in 0..self.num_hashes {
266 let mut hasher = Hasher::new();
267 hasher.update(&base_hash);
268 hasher.update(&(i as u64).to_le_bytes());
269 let hash_val = hasher.finalize();
270 let idx = u64::from_le_bytes(hash_val.as_bytes()[0..8].try_into().unwrap()) as usize;
271 indices.push(idx % self.filter_size);
272 }
273
274 indices
275 }
276}
277
278pub struct BloomPsiClient;
280
281impl BloomPsiClient {
282 pub fn new() -> Self {
284 Self
285 }
286
287 pub fn compute_intersection(
289 &self,
290 client_elements: &[Vec<u8>],
291 bloom_msg: &BloomPsiMessage,
292 ) -> PsiResult<Vec<Vec<u8>>> {
293 let mut intersection = Vec::new();
294
295 for elem in client_elements {
296 if self.check_membership(elem, bloom_msg) {
297 intersection.push(elem.clone());
298 }
299 }
300
301 Ok(intersection)
302 }
303
304 fn check_membership(&self, element: &[u8], bloom_msg: &BloomPsiMessage) -> bool {
306 let indices = self.hash_element(element, bloom_msg.num_hashes, bloom_msg.filter_size);
307
308 for idx in indices {
309 let byte_idx = idx / 8;
310 let bit_idx = idx % 8;
311 if (bloom_msg.filter[byte_idx] & (1 << bit_idx)) == 0 {
312 return false;
313 }
314 }
315
316 true
317 }
318
319 fn hash_element(&self, element: &[u8], num_hashes: usize, filter_size: usize) -> Vec<usize> {
321 let mut indices = Vec::with_capacity(num_hashes);
322 let base_hash = hash(element);
323
324 for i in 0..num_hashes {
325 let mut hasher = Hasher::new();
326 hasher.update(&base_hash);
327 hasher.update(&(i as u64).to_le_bytes());
328 let hash_val = hasher.finalize();
329 let idx = u64::from_le_bytes(hash_val.as_bytes()[0..8].try_into().unwrap()) as usize;
330 indices.push(idx % filter_size);
331 }
332
333 indices
334 }
335}
336
337impl Default for BloomPsiClient {
338 fn default() -> Self {
339 Self::new()
340 }
341}
342
343impl PsiServerMessage {
345 pub fn to_bytes(&self) -> PsiResult<Vec<u8>> {
347 crate::codec::encode(self).map_err(|e| PsiError::Serialization(e.to_string()))
348 }
349
350 pub fn from_bytes(bytes: &[u8]) -> PsiResult<Self> {
352 crate::codec::decode(bytes).map_err(|e| PsiError::Serialization(e.to_string()))
353 }
354}
355
356impl BloomPsiMessage {
357 pub fn to_bytes(&self) -> PsiResult<Vec<u8>> {
359 crate::codec::encode(self).map_err(|e| PsiError::Serialization(e.to_string()))
360 }
361
362 pub fn from_bytes(bytes: &[u8]) -> PsiResult<Self> {
364 crate::codec::decode(bytes).map_err(|e| PsiError::Serialization(e.to_string()))
365 }
366}
367
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_psi_basic() {
        let server_set = vec![
            b"content_hash_1".to_vec(),
            b"content_hash_2".to_vec(),
            b"content_hash_3".to_vec(),
        ];

        let client_set = vec![b"content_hash_2".to_vec(), b"content_hash_4".to_vec()];

        let server = PsiServer::new();
        let server_msg = server.encode_set(&server_set);

        let client = PsiClient::new();
        let intersection = client
            .compute_intersection(&client_set, &server_msg)
            .unwrap();

        // Every reported element must come from the client's own set, and an element the
        // server never encoded must not be reported.
        assert!(intersection.iter().all(|e| client_set.contains(e)));
        assert!(!intersection.contains(&b"content_hash_4".to_vec()));
    }

    #[test]
    fn test_psi_empty_client_set() {
        let server_set = vec![b"content_hash_1".to_vec()];
        let client_set: Vec<Vec<u8>> = vec![];

        let server = PsiServer::new();
        let server_msg = server.encode_set(&server_set);

        let client = PsiClient::new();
        let intersection = client
            .compute_intersection(&client_set, &server_msg)
            .unwrap();

        assert!(intersection.is_empty());
    }

    #[test]
    fn test_psi_no_intersection() {
        let server_set = vec![b"hash_1".to_vec(), b"hash_2".to_vec()];
        let client_set = vec![b"hash_3".to_vec(), b"hash_4".to_vec()];

        let server = PsiServer::new();
        let server_msg = server.encode_set(&server_set);

        let client = PsiClient::new();
        let intersection = client
            .compute_intersection(&client_set, &server_msg)
            .unwrap();

        assert!(intersection.is_empty());
    }

    #[test]
    fn test_psi_with_key_is_deterministic() {
        let key = [7u8; 32];
        let set = vec![b"a".to_vec(), b"b".to_vec()];

        let first_server = PsiServer::with_key(key);
        assert_eq!(first_server.secret_key(), &key);

        let first = first_server.encode_set(&set);
        let second = PsiServer::with_key(key).encode_set(&set);

        // Same key and same input must yield the same message, one digest per element.
        assert_eq!(first.hashed_elements, second.hashed_elements);
        assert_eq!(first.key_commitment, second.key_commitment);
        assert_eq!(first.hashed_elements.len(), set.len());
    }

    #[test]
    fn test_psi_serialization() {
        let server_set = vec![b"content_hash_1".to_vec()];

        let server = PsiServer::new();
        let server_msg = server.encode_set(&server_set);

        let bytes = server_msg.to_bytes().unwrap();
        let deserialized = PsiServerMessage::from_bytes(&bytes).unwrap();

        assert_eq!(server_msg.hashed_elements, deserialized.hashed_elements);
        assert_eq!(server_msg.key_commitment, deserialized.key_commitment);
    }

    #[test]
    fn test_bloom_psi_basic() {
        let server_set = vec![
            b"content_1".to_vec(),
            b"content_2".to_vec(),
            b"content_3".to_vec(),
        ];

        let client_set = vec![b"content_2".to_vec(), b"content_4".to_vec()];

        let server = BloomPsiServer::new(10, 0.01);
        let bloom_msg = server.encode_set(&server_set);

        let client = BloomPsiClient::new();
        let intersection = client
            .compute_intersection(&client_set, &bloom_msg)
            .unwrap();

        assert!(!intersection.is_empty());
        assert!(intersection.contains(&b"content_2".to_vec()));
    }

    #[test]
    fn test_bloom_psi_empty_set() {
        let server_set: Vec<Vec<u8>> = vec![];
        let client_set = vec![b"content_1".to_vec()];

        let server = BloomPsiServer::new(10, 0.01);
        let bloom_msg = server.encode_set(&server_set);

        let client = BloomPsiClient::new();
        let intersection = client
            .compute_intersection(&client_set, &bloom_msg)
            .unwrap();

        assert!(intersection.is_empty());
    }

    #[test]
    fn test_bloom_psi_all_match() {
        let elements = vec![b"elem_1".to_vec(), b"elem_2".to_vec(), b"elem_3".to_vec()];

        let server = BloomPsiServer::new(10, 0.01);
        let bloom_msg = server.encode_set(&elements);

        let client = BloomPsiClient::new();
        let intersection = client.compute_intersection(&elements, &bloom_msg).unwrap();

        assert_eq!(intersection.len(), elements.len());
    }

    #[test]
    fn test_bloom_psi_false_positive_rate() {
        let server_set: Vec<Vec<u8>> = (0..100)
            .map(|i| format!("server_{}", i).into_bytes())
            .collect();
        let client_set: Vec<Vec<u8>> = (100..200)
            .map(|i| format!("server_{}", i).into_bytes())
            .collect();

        let server = BloomPsiServer::new(100, 0.01);
        let bloom_msg = server.encode_set(&server_set);

        let client = BloomPsiClient::new();
        let intersection = client
            .compute_intersection(&client_set, &bloom_msg)
            .unwrap();

        // Disjoint sets: anything reported is a false positive (~1 expected at p = 0.01).
        assert!(intersection.len() < 5);
    }

    #[test]
    fn test_bloom_psi_serialization() {
        let server_set = vec![b"content_1".to_vec()];

        let server = BloomPsiServer::new(10, 0.01);
        let bloom_msg = server.encode_set(&server_set);

        let bytes = bloom_msg.to_bytes().unwrap();
        let deserialized = BloomPsiMessage::from_bytes(&bytes).unwrap();

        assert_eq!(bloom_msg.filter, deserialized.filter);
        assert_eq!(bloom_msg.num_hashes, deserialized.num_hashes);
        assert_eq!(bloom_msg.filter_size, deserialized.filter_size);
    }

    #[test]
    fn test_bloom_filter_parameters() {
        let server = BloomPsiServer::new(1000, 0.01);
        assert!(server.filter_size > 0);
        assert!(server.num_hashes > 0);

        let server2 = BloomPsiServer::new(1000, 0.001);
        assert!(server2.filter_size > server.filter_size);
    }

    #[test]
    fn test_psi_server_default() {
        let server = PsiServer::default();
        let set = vec![b"test".to_vec()];
        let msg = server.encode_set(&set);
        assert!(!msg.hashed_elements.is_empty());
    }

    #[test]
    fn test_psi_client_default() {
        let client = PsiClient::default();
        let server = PsiServer::new();
        let server_msg = server.encode_set(&[b"test".to_vec()]);
        let result = client.compute_intersection(&[b"test".to_vec()], &server_msg);
        assert!(result.is_ok());
    }

    #[test]
    fn test_bloom_psi_client_default() {
        // Go through the `Default` impl rather than the unit-struct literal.
        let client = BloomPsiClient::default();
        let server = BloomPsiServer::new(10, 0.01);
        let bloom_msg = server.encode_set(&[b"test".to_vec()]);
        let result = client.compute_intersection(&[b"test".to_vec()], &bloom_msg);
        assert!(result.is_ok());
    }
}