chie_crypto/
psi.rs

1//! Private Set Intersection (PSI) for privacy-preserving P2P content discovery.
2//!
3//! This module provides protocols for finding common elements between two sets
4//! without revealing elements that are not in the intersection.
5//!
6//! # Use Cases in CHIE Protocol
7//!
8//! - **Content Discovery**: Peers can find common content without revealing their full catalogs
9//! - **Privacy-Preserving Matching**: Match chunks/files without exposing complete inventories
10//! - **Efficient Peer Selection**: Find peers with desired content while maintaining privacy
11//!
12//! # Protocol
13//!
14//! 1. **Hash-based PSI**: Uses keyed hashing for exact intersection
15//! 2. **Bloom Filter PSI**: Uses Bloom filters for approximate intersection with better efficiency
16//!
17//! # Example
18//!
19//! ```
20//! use chie_crypto::psi::{BloomPsiClient, BloomPsiServer};
21//!
22//! // Server has a set of content hashes
23//! let server_set = vec![
24//!     b"content_1".to_vec(),
25//!     b"content_2".to_vec(),
26//!     b"content_3".to_vec(),
27//! ];
28//!
29//! // Client has their own set
30//! let client_set = vec![
31//!     b"content_2".to_vec(),
32//!     b"content_4".to_vec(),
33//! ];
34//!
35//! // Server generates Bloom filter PSI
36//! let server = BloomPsiServer::new(10, 0.01);
37//! let bloom_msg = server.encode_set(&server_set);
38//!
39//! // Client computes approximate intersection
40//! let client = BloomPsiClient::new();
41//! let intersection = client.compute_intersection(&client_set, &bloom_msg).unwrap();
42//!
43//! // Intersection should contain common elements
44//! assert!(intersection.contains(&b"content_2".to_vec()));
45//! ```
46
47use crate::hash::hash;
48use blake3::Hasher;
49use rand::RngCore;
50use serde::{Deserialize, Serialize};
51use std::collections::HashSet;
52use thiserror::Error;
53
/// Errors reported by the PSI types in this module.
#[derive(Error, Debug)]
pub enum PsiError {
    /// A PSI message was rejected as invalid.
    #[error("Invalid PSI message")]
    InvalidMessage,
    /// Encoding or decoding a message failed; the payload is the underlying
    /// error rendered as text.
    #[error("Serialization error: {0}")]
    Serialization(String),
    /// An operation was given an empty set.
    #[error("Empty set provided")]
    EmptySet,
}
63
/// Convenience alias for results whose error type is [`PsiError`].
pub type PsiResult<T> = Result<T, PsiError>;
65
/// PSI server message containing encoded set elements.
///
/// Produced by `PsiServer::encode_set` and consumed by
/// `PsiClient::compute_intersection`.
#[derive(Clone, Serialize, Deserialize)]
pub struct PsiServerMessage {
    /// One digest per element of the server's set, in input order.
    hashed_elements: Vec<Vec<u8>>,
    /// Commitment to the server's secret key: the hash of the key, not the
    /// key itself. The client mixes this value into its own element hashes.
    key_commitment: Vec<u8>,
}
74
/// PSI server for encoding sets.
pub struct PsiServer {
    /// 32-byte secret key. Only its hash is placed in the messages this
    /// server produces (as `key_commitment`).
    secret_key: [u8; 32],
}
79
80impl PsiServer {
81    /// Create a new PSI server with random secret key
82    pub fn new() -> Self {
83        let mut secret_key = [0u8; 32];
84        rand::rngs::OsRng.fill_bytes(&mut secret_key);
85        Self { secret_key }
86    }
87
88    /// Create PSI server with specific key (for testing)
89    pub fn with_key(key: [u8; 32]) -> Self {
90        Self { secret_key: key }
91    }
92
93    /// Encode a set of elements for PSI
94    pub fn encode_set(&self, elements: &[Vec<u8>]) -> PsiServerMessage {
95        let hashed_elements = elements
96            .iter()
97            .map(|elem| self.hash_element(elem))
98            .collect();
99
100        // Commit to the key (hash of key)
101        let key_commitment = hash(&self.secret_key).to_vec();
102
103        PsiServerMessage {
104            hashed_elements,
105            key_commitment,
106        }
107    }
108
109    /// Hash an element with server's secret key
110    fn hash_element(&self, element: &[u8]) -> Vec<u8> {
111        let mut hasher = Hasher::new();
112        hasher.update(&self.secret_key);
113        hasher.update(element);
114        hasher.finalize().as_bytes().to_vec()
115    }
116
117    /// Get the secret key (for deriving trapdoor in advanced protocols)
118    pub fn secret_key(&self) -> &[u8; 32] {
119        &self.secret_key
120    }
121}
122
123impl Default for PsiServer {
124    fn default() -> Self {
125        Self::new()
126    }
127}
128
/// PSI client for computing intersections.
pub struct PsiClient {
    // The key is stored but not read by any method in this file, hence the
    // lint suppression.
    #[allow(dead_code)]
    secret_key: [u8; 32],
}
134
135impl PsiClient {
136    /// Create a new PSI client with random secret key
137    pub fn new() -> Self {
138        let mut secret_key = [0u8; 32];
139        rand::rngs::OsRng.fill_bytes(&mut secret_key);
140        Self { secret_key }
141    }
142
143    /// Create PSI client with specific key (for testing)
144    pub fn with_key(key: [u8; 32]) -> Self {
145        Self { secret_key: key }
146    }
147
148    /// Compute intersection with server's encoded set
149    pub fn compute_intersection(
150        &self,
151        client_elements: &[Vec<u8>],
152        server_msg: &PsiServerMessage,
153    ) -> PsiResult<Vec<Vec<u8>>> {
154        if client_elements.is_empty() {
155            return Ok(Vec::new());
156        }
157
158        // Build HashSet of server's hashed elements for O(1) lookup
159        let server_set: HashSet<&[u8]> = server_msg
160            .hashed_elements
161            .iter()
162            .map(|v| v.as_slice())
163            .collect();
164
165        // Find elements in client's set that are also in server's set
166        let mut intersection = Vec::new();
167        for elem in client_elements {
168            let hashed = self.hash_element(elem, &server_msg.key_commitment);
169            if server_set.contains(hashed.as_slice()) {
170                intersection.push(elem.clone());
171            }
172        }
173
174        Ok(intersection)
175    }
176
177    /// Hash an element (needs to use server's key commitment for matching)
178    fn hash_element(&self, element: &[u8], server_key_commitment: &[u8]) -> Vec<u8> {
179        // In a real PSI protocol, this would use the server's committed key
180        // For simplicity, we use a combined hash
181        let mut hasher = Hasher::new();
182        hasher.update(server_key_commitment);
183        hasher.update(element);
184        hasher.finalize().as_bytes().to_vec()
185    }
186}
187
188impl Default for PsiClient {
189    fn default() -> Self {
190        Self::new()
191    }
192}
193
/// Bloom filter-based PSI message for approximate intersection with better
/// efficiency.
///
/// Produced by `BloomPsiServer::encode_set` and consumed by
/// `BloomPsiClient::compute_intersection`.
#[derive(Clone, Serialize, Deserialize)]
pub struct BloomPsiMessage {
    /// Bloom filter bits, packed eight per byte; bit `i` lives in byte
    /// `i / 8` at bit position `i % 8` (least significant bit first).
    filter: Vec<u8>,
    /// Number of hash functions (bit positions derived per element).
    num_hashes: usize,
    /// Filter size in bits.
    filter_size: usize,
}
204
/// Bloom filter PSI server.
pub struct BloomPsiServer {
    /// Number of bit positions set in the filter for each encoded element.
    num_hashes: usize,
    /// Filter size in bits.
    filter_size: usize,
}
210
211impl BloomPsiServer {
212    /// Create a new Bloom PSI server
213    ///
214    /// # Parameters
215    /// - `expected_items`: Expected number of items in the set
216    /// - `false_positive_rate`: Desired false positive rate (e.g., 0.01 for 1%)
217    pub fn new(expected_items: usize, false_positive_rate: f64) -> Self {
218        // Calculate optimal filter size and number of hash functions
219        let filter_size = Self::optimal_filter_size(expected_items, false_positive_rate);
220        let num_hashes = Self::optimal_num_hashes(filter_size, expected_items);
221
222        Self {
223            num_hashes,
224            filter_size,
225        }
226    }
227
228    /// Calculate optimal Bloom filter size
229    fn optimal_filter_size(n: usize, p: f64) -> usize {
230        let ln2_squared = std::f64::consts::LN_2 * std::f64::consts::LN_2;
231        (-(n as f64 * p.ln()) / ln2_squared).ceil() as usize
232    }
233
234    /// Calculate optimal number of hash functions
235    fn optimal_num_hashes(m: usize, n: usize) -> usize {
236        ((m as f64 / n as f64) * std::f64::consts::LN_2).ceil() as usize
237    }
238
239    /// Encode a set into a Bloom filter
240    pub fn encode_set(&self, elements: &[Vec<u8>]) -> BloomPsiMessage {
241        let filter_bytes = self.filter_size.div_ceil(8);
242        let mut filter = vec![0u8; filter_bytes];
243
244        for elem in elements {
245            let indices = self.hash_element(elem);
246            for idx in indices {
247                let byte_idx = idx / 8;
248                let bit_idx = idx % 8;
249                filter[byte_idx] |= 1 << bit_idx;
250            }
251        }
252
253        BloomPsiMessage {
254            filter,
255            num_hashes: self.num_hashes,
256            filter_size: self.filter_size,
257        }
258    }
259
260    /// Hash an element to k positions in the filter
261    fn hash_element(&self, element: &[u8]) -> Vec<usize> {
262        let mut indices = Vec::with_capacity(self.num_hashes);
263        let base_hash = hash(element);
264
265        for i in 0..self.num_hashes {
266            let mut hasher = Hasher::new();
267            hasher.update(&base_hash);
268            hasher.update(&(i as u64).to_le_bytes());
269            let hash_val = hasher.finalize();
270            let idx = u64::from_le_bytes(hash_val.as_bytes()[0..8].try_into().unwrap()) as usize;
271            indices.push(idx % self.filter_size);
272        }
273
274        indices
275    }
276}
277
/// Bloom filter PSI client.
///
/// A field-less unit struct: everything needed to test membership is read
/// from the `BloomPsiMessage` passed to `compute_intersection`.
pub struct BloomPsiClient;
280
281impl BloomPsiClient {
282    /// Create a new Bloom PSI client
283    pub fn new() -> Self {
284        Self
285    }
286
287    /// Compute approximate intersection (may have false positives)
288    pub fn compute_intersection(
289        &self,
290        client_elements: &[Vec<u8>],
291        bloom_msg: &BloomPsiMessage,
292    ) -> PsiResult<Vec<Vec<u8>>> {
293        let mut intersection = Vec::new();
294
295        for elem in client_elements {
296            if self.check_membership(elem, bloom_msg) {
297                intersection.push(elem.clone());
298            }
299        }
300
301        Ok(intersection)
302    }
303
304    /// Check if element is (probably) in the Bloom filter
305    fn check_membership(&self, element: &[u8], bloom_msg: &BloomPsiMessage) -> bool {
306        let indices = self.hash_element(element, bloom_msg.num_hashes, bloom_msg.filter_size);
307
308        for idx in indices {
309            let byte_idx = idx / 8;
310            let bit_idx = idx % 8;
311            if (bloom_msg.filter[byte_idx] & (1 << bit_idx)) == 0 {
312                return false;
313            }
314        }
315
316        true
317    }
318
319    /// Hash an element to k positions
320    fn hash_element(&self, element: &[u8], num_hashes: usize, filter_size: usize) -> Vec<usize> {
321        let mut indices = Vec::with_capacity(num_hashes);
322        let base_hash = hash(element);
323
324        for i in 0..num_hashes {
325            let mut hasher = Hasher::new();
326            hasher.update(&base_hash);
327            hasher.update(&(i as u64).to_le_bytes());
328            let hash_val = hasher.finalize();
329            let idx = u64::from_le_bytes(hash_val.as_bytes()[0..8].try_into().unwrap()) as usize;
330            indices.push(idx % filter_size);
331        }
332
333        indices
334    }
335}
336
337impl Default for BloomPsiClient {
338    fn default() -> Self {
339        Self::new()
340    }
341}
342
343// Serialization helpers
344impl PsiServerMessage {
345    /// Serialize to bytes
346    pub fn to_bytes(&self) -> PsiResult<Vec<u8>> {
347        crate::codec::encode(self).map_err(|e| PsiError::Serialization(e.to_string()))
348    }
349
350    /// Deserialize from bytes
351    pub fn from_bytes(bytes: &[u8]) -> PsiResult<Self> {
352        crate::codec::decode(bytes).map_err(|e| PsiError::Serialization(e.to_string()))
353    }
354}
355
356impl BloomPsiMessage {
357    /// Serialize to bytes
358    pub fn to_bytes(&self) -> PsiResult<Vec<u8>> {
359        crate::codec::encode(self).map_err(|e| PsiError::Serialization(e.to_string()))
360    }
361
362    /// Deserialize from bytes
363    pub fn from_bytes(bytes: &[u8]) -> PsiResult<Self> {
364        crate::codec::decode(bytes).map_err(|e| PsiError::Serialization(e.to_string()))
365    }
366}
367
#[cfg(test)]
mod tests {
    //! Unit tests for the hash-based and Bloom-filter-based PSI types.

    use super::*;

    #[test]
    fn test_psi_basic() {
        let server_set = vec![
            b"content_hash_1".to_vec(),
            b"content_hash_2".to_vec(),
            b"content_hash_3".to_vec(),
        ];

        let client_set = vec![b"content_hash_2".to_vec(), b"content_hash_4".to_vec()];

        let server = PsiServer::new();
        let server_msg = server.encode_set(&server_set);

        // One digest per server element.
        assert_eq!(server_msg.hashed_elements.len(), server_set.len());

        let client = PsiClient::new();
        let intersection = client
            .compute_intersection(&client_set, &server_msg)
            .unwrap();

        // An element the server never encoded must not be reported, and at
        // most the single shared element can come back.
        assert!(!intersection.contains(&b"content_hash_4".to_vec()));
        assert!(intersection.len() <= 1);
    }

    #[test]
    fn test_psi_empty_client_set() {
        let server_set = vec![b"content_hash_1".to_vec()];
        let client_set: Vec<Vec<u8>> = vec![];

        let server = PsiServer::new();
        let server_msg = server.encode_set(&server_set);

        let client = PsiClient::new();
        let intersection = client
            .compute_intersection(&client_set, &server_msg)
            .unwrap();

        assert!(intersection.is_empty());
    }

    #[test]
    fn test_psi_no_intersection() {
        let server_set = vec![b"hash_1".to_vec(), b"hash_2".to_vec()];
        let client_set = vec![b"hash_3".to_vec(), b"hash_4".to_vec()];

        let server = PsiServer::new();
        let server_msg = server.encode_set(&server_set);

        let client = PsiClient::new();
        let intersection = client
            .compute_intersection(&client_set, &server_msg)
            .unwrap();

        assert!(intersection.is_empty());
    }

    #[test]
    fn test_psi_serialization() {
        let server_set = vec![b"content_hash_1".to_vec()];

        let server = PsiServer::new();
        let server_msg = server.encode_set(&server_set);

        // Round-trip through bytes must preserve every field.
        let bytes = server_msg.to_bytes().unwrap();
        let deserialized = PsiServerMessage::from_bytes(&bytes).unwrap();

        assert_eq!(server_msg.hashed_elements, deserialized.hashed_elements);
        assert_eq!(server_msg.key_commitment, deserialized.key_commitment);
    }

    #[test]
    fn test_bloom_psi_basic() {
        let server_set = vec![
            b"content_1".to_vec(),
            b"content_2".to_vec(),
            b"content_3".to_vec(),
        ];

        let client_set = vec![b"content_2".to_vec(), b"content_4".to_vec()];

        let server = BloomPsiServer::new(10, 0.01);
        let bloom_msg = server.encode_set(&server_set);

        let client = BloomPsiClient::new();
        let intersection = client
            .compute_intersection(&client_set, &bloom_msg)
            .unwrap();

        // Should find content_2, possibly false positive for content_4
        assert!(!intersection.is_empty());
        assert!(intersection.contains(&b"content_2".to_vec()));
    }

    #[test]
    fn test_bloom_psi_empty_set() {
        let server_set: Vec<Vec<u8>> = vec![];
        let client_set = vec![b"content_1".to_vec()];

        let server = BloomPsiServer::new(10, 0.01);
        let bloom_msg = server.encode_set(&server_set);

        let client = BloomPsiClient::new();
        let intersection = client
            .compute_intersection(&client_set, &bloom_msg)
            .unwrap();

        // No bit is set in the filter, so nothing can match.
        assert!(intersection.is_empty());
    }

    #[test]
    fn test_bloom_psi_all_match() {
        let elements = vec![b"elem_1".to_vec(), b"elem_2".to_vec(), b"elem_3".to_vec()];

        let server = BloomPsiServer::new(10, 0.01);
        let bloom_msg = server.encode_set(&elements);

        let client = BloomPsiClient::new();
        let intersection = client.compute_intersection(&elements, &bloom_msg).unwrap();

        // Bloom filters have no false negatives.
        assert_eq!(intersection.len(), elements.len());
    }

    #[test]
    fn test_bloom_psi_false_positive_rate() {
        let server_set: Vec<Vec<u8>> = (0..100)
            .map(|i| format!("server_{}", i).into_bytes())
            .collect();
        let client_set: Vec<Vec<u8>> = (100..200)
            .map(|i| format!("server_{}", i).into_bytes())
            .collect();

        let server = BloomPsiServer::new(100, 0.01);
        let bloom_msg = server.encode_set(&server_set);

        let client = BloomPsiClient::new();
        let intersection = client
            .compute_intersection(&client_set, &bloom_msg)
            .unwrap();

        // The sets are disjoint, so every hit is a false positive. At a 1%
        // target rate about one hit in 100 probes is expected; allow a margin.
        assert!(intersection.len() < 5);
    }

    #[test]
    fn test_bloom_psi_serialization() {
        let server_set = vec![b"content_1".to_vec()];

        let server = BloomPsiServer::new(10, 0.01);
        let bloom_msg = server.encode_set(&server_set);

        // Round-trip through bytes must preserve every field.
        let bytes = bloom_msg.to_bytes().unwrap();
        let deserialized = BloomPsiMessage::from_bytes(&bytes).unwrap();

        assert_eq!(bloom_msg.filter, deserialized.filter);
        assert_eq!(bloom_msg.num_hashes, deserialized.num_hashes);
        assert_eq!(bloom_msg.filter_size, deserialized.filter_size);
    }

    #[test]
    fn test_bloom_filter_parameters() {
        let server = BloomPsiServer::new(1000, 0.01);
        assert!(server.filter_size > 0);
        assert!(server.num_hashes > 0);

        let server2 = BloomPsiServer::new(1000, 0.001);
        // Lower false positive rate should require larger filter
        assert!(server2.filter_size > server.filter_size);
    }

    #[test]
    fn test_psi_server_default() {
        let server = PsiServer::default();
        let set = vec![b"test".to_vec()];
        let msg = server.encode_set(&set);
        assert!(!msg.hashed_elements.is_empty());
    }

    #[test]
    fn test_psi_client_default() {
        let client = PsiClient::default();
        let server = PsiServer::new();
        let server_msg = server.encode_set(&[b"test".to_vec()]);
        let result = client.compute_intersection(&[b"test".to_vec()], &server_msg);
        assert!(result.is_ok());
    }

    #[test]
    fn test_bloom_psi_client_default() {
        // Go through `Default` so the impl named by this test is exercised.
        let client = BloomPsiClient::default();
        let server = BloomPsiServer::new(10, 0.01);
        let bloom_msg = server.encode_set(&[b"test".to_vec()]);
        let result = client.compute_intersection(&[b"test".to_vec()], &bloom_msg);
        assert!(result.is_ok());
    }
}