chie_crypto/
ot.rs

1//! Oblivious Transfer for private information retrieval.
2//!
3//! This module implements a 1-out-of-N oblivious transfer protocol where:
4//! - A sender has N items (messages)
5//! - A receiver wants to retrieve one of the N items by index
6//! - The sender doesn't learn which item was chosen
7//! - The receiver doesn't learn anything about the other items
8//!
9//! # Use Cases for CHIE Protocol
10//! - Private P2P content discovery (receiver queries without revealing interest)
11//! - Privacy-preserving content catalog browsing
12//! - Anonymous chunk retrieval from peers
13//! - Private database queries in distributed systems
14//!
//! # Protocol Overview
//! 1. Receiver publishes one public key per item, but holds the secret key only
//!    for the chosen index
//! 2. Sender encrypts each item to the corresponding public key, using a fresh
//!    ephemeral Diffie-Hellman key per item and ChaCha20-Poly1305
//! 3. Receiver can decrypt only the chosen item
20//!
21//! # Example
22//! ```
23//! use chie_crypto::ot::*;
24//!
25//! // Sender has 3 items
26//! let items = vec![
27//!     b"Item 0".to_vec(),
28//!     b"Item 1".to_vec(),
29//!     b"Item 2".to_vec(),
30//! ];
31//!
32//! // Receiver wants item at index 1
33//! let receiver = OTReceiver::new(items.len(), 1).unwrap();
34//! let request = receiver.create_request();
35//!
36//! // Sender responds
37//! let sender = OTSender::new();
38//! let response = sender.respond(&request, &items).unwrap();
39//!
40//! // Receiver retrieves only the chosen item
41//! let retrieved = receiver.retrieve(&response).unwrap();
42//! assert_eq!(retrieved, b"Item 1");
43//! ```
44
45use blake3::Hasher;
46use chacha20poly1305::{
47    ChaCha20Poly1305, Nonce,
48    aead::{Aead, KeyInit},
49};
50use curve25519_dalek::{
51    constants::RISTRETTO_BASEPOINT_TABLE,
52    ristretto::{CompressedRistretto, RistrettoPoint},
53    scalar::Scalar,
54};
55use rand::Rng;
56use serde::{Deserialize, Serialize};
57
/// Failure modes reported by the oblivious transfer API.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OTError {
    /// The requested index is not smaller than the number of items.
    InvalidChoice,
    /// The number of items is unusable (for example zero).
    InvalidItemCount,
    /// The receiver's request is malformed.
    InvalidRequest,
    /// The sender's response is malformed.
    InvalidResponse,
    /// Authenticated decryption of the chosen item failed.
    DecryptionFailed,
    /// Authenticated encryption of an item failed.
    EncryptionFailed,
    /// A compressed point could not be decoded into a valid public key.
    InvalidPublicKey,
    /// The number of items does not agree with the number of keys/ciphertexts.
    MismatchedItemCount,
}

impl std::fmt::Display for OTError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Map each variant to its fixed, human-readable description.
        let message = match self {
            Self::InvalidChoice => "Invalid choice index",
            Self::InvalidItemCount => "Invalid number of items",
            Self::InvalidRequest => "Invalid request format",
            Self::InvalidResponse => "Invalid response format",
            Self::DecryptionFailed => "Decryption failed",
            Self::EncryptionFailed => "Encryption failed",
            Self::InvalidPublicKey => "Invalid public key",
            Self::MismatchedItemCount => "Mismatched item count",
        };
        f.write_str(message)
    }
}

impl std::error::Error for OTError {}

/// Shorthand for results produced by oblivious transfer operations.
pub type OTResult<T> = Result<T, OTError>;
98
99/// Oblivious transfer request from receiver.
100#[derive(Clone, Serialize, Deserialize)]
101pub struct OTRequest {
102    /// Public keys for each possible choice (one real, others random)
103    pub_keys: Vec<CompressedRistretto>,
104}
105
106/// Oblivious transfer response from sender.
107#[derive(Clone, Serialize, Deserialize)]
108pub struct OTResponse {
109    /// Encrypted items (one for each public key)
110    encrypted_items: Vec<EncryptedItem>,
111}
112
113/// Encrypted item in oblivious transfer.
114#[derive(Clone, Serialize, Deserialize)]
115struct EncryptedItem {
116    /// Ephemeral public key for this encryption
117    ephemeral_pk: CompressedRistretto,
118    /// Encrypted data
119    ciphertext: Vec<u8>,
120    /// Nonce for encryption
121    nonce: [u8; 12],
122}
123
124/// Receiver in oblivious transfer protocol.
125pub struct OTReceiver {
126    /// Number of items to choose from
127    n_items: usize,
128    /// Index of chosen item
129    choice: usize,
130    /// Secret key for the chosen item
131    chosen_sk: Scalar,
132    /// Public keys sent to sender
133    pub_keys: Vec<CompressedRistretto>,
134}
135
136impl OTReceiver {
137    /// Create a new receiver choosing item at given index.
138    ///
139    /// # Arguments
140    /// * `n_items` - Total number of items sender has
141    /// * `choice` - Index of item to retrieve (0-based)
142    pub fn new(n_items: usize, choice: usize) -> OTResult<Self> {
143        if n_items == 0 {
144            return Err(OTError::InvalidItemCount);
145        }
146        if choice >= n_items {
147            return Err(OTError::InvalidChoice);
148        }
149
150        let mut rng = rand::thread_rng();
151        let mut pub_keys = Vec::with_capacity(n_items);
152
153        // Generate secret key for chosen item
154        let mut sk_bytes = [0u8; 32];
155        rng.fill(&mut sk_bytes);
156        let chosen_sk = Scalar::from_bytes_mod_order(sk_bytes);
157        let chosen_pk = &chosen_sk * RISTRETTO_BASEPOINT_TABLE;
158
159        // Generate public keys for all items
160        for i in 0..n_items {
161            if i == choice {
162                // Use the real public key for chosen item
163                pub_keys.push(chosen_pk.compress());
164            } else {
165                // Generate random points for other items
166                let mut random_bytes = [0u8; 32];
167                rng.fill(&mut random_bytes);
168                let random_sk = Scalar::from_bytes_mod_order(random_bytes);
169                let random_pk = &random_sk * RISTRETTO_BASEPOINT_TABLE;
170                pub_keys.push(random_pk.compress());
171            }
172        }
173
174        Ok(Self {
175            n_items,
176            choice,
177            chosen_sk,
178            pub_keys,
179        })
180    }
181
182    /// Create the oblivious transfer request to send to the sender.
183    pub fn create_request(&self) -> OTRequest {
184        OTRequest {
185            pub_keys: self.pub_keys.clone(),
186        }
187    }
188
189    /// Retrieve the chosen item from the sender's response.
190    pub fn retrieve(&self, response: &OTResponse) -> OTResult<Vec<u8>> {
191        if response.encrypted_items.len() != self.n_items {
192            return Err(OTError::MismatchedItemCount);
193        }
194
195        let item = &response.encrypted_items[self.choice];
196
197        // Decompress ephemeral public key
198        let ephemeral_pk = item
199            .ephemeral_pk
200            .decompress()
201            .ok_or(OTError::InvalidPublicKey)?;
202
203        // Compute shared secret: chosen_sk * ephemeral_pk
204        let shared_point = ephemeral_pk * self.chosen_sk;
205
206        // Derive symmetric key
207        let sym_key = derive_ot_key(&shared_point);
208
209        // Decrypt
210        let cipher = ChaCha20Poly1305::new(&sym_key.into());
211        let nonce = Nonce::from_slice(&item.nonce);
212
213        cipher
214            .decrypt(nonce, item.ciphertext.as_ref())
215            .map_err(|_| OTError::DecryptionFailed)
216    }
217
218    /// Get the choice index.
219    pub fn choice(&self) -> usize {
220        self.choice
221    }
222
223    /// Get the number of items.
224    pub fn n_items(&self) -> usize {
225        self.n_items
226    }
227}
228
229/// Sender in oblivious transfer protocol.
230pub struct OTSender;
231
232impl OTSender {
233    /// Create a new sender.
234    pub fn new() -> Self {
235        Self
236    }
237
238    /// Respond to a receiver's request by encrypting all items.
239    ///
240    /// # Arguments
241    /// * `request` - The receiver's OT request
242    /// * `items` - All items (must match the number of public keys in request)
243    pub fn respond(&self, request: &OTRequest, items: &[Vec<u8>]) -> OTResult<OTResponse> {
244        if items.len() != request.pub_keys.len() {
245            return Err(OTError::MismatchedItemCount);
246        }
247        if items.is_empty() {
248            return Err(OTError::InvalidItemCount);
249        }
250
251        let mut rng = rand::thread_rng();
252        let mut encrypted_items = Vec::with_capacity(items.len());
253
254        // Encrypt each item with corresponding public key
255        for (item, pk_compressed) in items.iter().zip(&request.pub_keys) {
256            // Decompress public key
257            let pk = pk_compressed
258                .decompress()
259                .ok_or(OTError::InvalidPublicKey)?;
260
261            // Generate ephemeral keypair
262            let mut ephemeral_sk_bytes = [0u8; 32];
263            rng.fill(&mut ephemeral_sk_bytes);
264            let ephemeral_sk = Scalar::from_bytes_mod_order(ephemeral_sk_bytes);
265            let ephemeral_pk = &ephemeral_sk * RISTRETTO_BASEPOINT_TABLE;
266
267            // Compute shared secret: ephemeral_sk * receiver_pk
268            let shared_point = pk * ephemeral_sk;
269
270            // Derive symmetric key
271            let sym_key = derive_ot_key(&shared_point);
272
273            // Generate nonce
274            let mut nonce_bytes = [0u8; 12];
275            rng.fill(&mut nonce_bytes);
276            let nonce = Nonce::from_slice(&nonce_bytes);
277
278            // Encrypt item
279            let cipher = ChaCha20Poly1305::new(&sym_key.into());
280            let ciphertext = cipher
281                .encrypt(nonce, item.as_ref())
282                .map_err(|_| OTError::EncryptionFailed)?;
283
284            encrypted_items.push(EncryptedItem {
285                ephemeral_pk: ephemeral_pk.compress(),
286                ciphertext,
287                nonce: nonce_bytes,
288            });
289        }
290
291        Ok(OTResponse { encrypted_items })
292    }
293}
294
295impl Default for OTSender {
296    fn default() -> Self {
297        Self::new()
298    }
299}
300
301/// Derive a symmetric key from a shared point for OT encryption.
302fn derive_ot_key(point: &RistrettoPoint) -> [u8; 32] {
303    let mut hasher = Hasher::new();
304    hasher.update(b"chie-ot-v1");
305    hasher.update(&point.compress().to_bytes());
306    let hash = hasher.finalize();
307    *hash.as_bytes()
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    /// Run one complete request/response exchange and return what the receiver retrieves.
    fn round_trip(items: &[Vec<u8>], choice: usize) -> Vec<u8> {
        let receiver = OTReceiver::new(items.len(), choice).unwrap();
        let request = receiver.create_request();
        let response = OTSender::new().respond(&request, items).unwrap();
        receiver.retrieve(&response).unwrap()
    }

    #[test]
    fn test_basic_ot_1_of_2() {
        let items = vec![b"First item".to_vec(), b"Second item".to_vec()];
        assert_eq!(round_trip(&items, 0), items[0]);
    }

    #[test]
    fn test_basic_ot_1_of_3() {
        let items = vec![b"Item 0".to_vec(), b"Item 1".to_vec(), b"Item 2".to_vec()];
        assert_eq!(round_trip(&items, 1), items[1]);
    }

    #[test]
    fn test_ot_all_choices() {
        let items = vec![
            b"Alpha".to_vec(),
            b"Beta".to_vec(),
            b"Gamma".to_vec(),
            b"Delta".to_vec(),
        ];

        for choice in 0..items.len() {
            assert_eq!(round_trip(&items, choice), items[choice]);
        }
    }

    #[test]
    fn test_invalid_choice() {
        assert_eq!(OTReceiver::new(3, 3).err(), Some(OTError::InvalidChoice));
        assert_eq!(OTReceiver::new(3, 100).err(), Some(OTError::InvalidChoice));
    }

    #[test]
    fn test_invalid_item_count() {
        assert_eq!(OTReceiver::new(0, 0).err(), Some(OTError::InvalidItemCount));
    }

    #[test]
    fn test_mismatched_item_count() {
        let items = vec![b"Item 1".to_vec(), b"Item 2".to_vec()];
        let receiver = OTReceiver::new(3, 0).unwrap();
        let request = receiver.create_request();

        let sender = OTSender::new();
        assert_eq!(
            sender.respond(&request, &items).err(),
            Some(OTError::MismatchedItemCount)
        );
    }

    #[test]
    fn test_empty_items() {
        let sender = OTSender::new();

        // No items for a one-key request: the length check fires first.
        let receiver = OTReceiver::new(1, 0).unwrap();
        let request = receiver.create_request();
        assert_eq!(
            sender.respond(&request, &[]).err(),
            Some(OTError::MismatchedItemCount)
        );

        // No items and no keys: rejected as an invalid item count.
        let empty_request = OTRequest {
            pub_keys: Vec::new(),
        };
        assert_eq!(
            sender.respond(&empty_request, &[]).err(),
            Some(OTError::InvalidItemCount)
        );
    }

    #[test]
    fn test_large_items() {
        let items = vec![vec![1u8; 10_000], vec![2u8; 10_000]];
        assert_eq!(round_trip(&items, 1), items[1]);
    }

    #[test]
    fn test_empty_item_content() {
        let items = vec![b"".to_vec(), b"Non-empty".to_vec()];
        assert_eq!(round_trip(&items, 0), items[0]);
    }

    #[test]
    fn test_request_serialization() {
        let receiver = OTReceiver::new(3, 1).unwrap();
        let request = receiver.create_request();

        let serialized = crate::codec::encode(&request).unwrap();
        let deserialized: OTRequest = crate::codec::decode(&serialized).unwrap();

        assert_eq!(request.pub_keys.len(), deserialized.pub_keys.len());
        for (a, b) in request.pub_keys.iter().zip(&deserialized.pub_keys) {
            assert_eq!(a.to_bytes(), b.to_bytes());
        }
    }

    #[test]
    fn test_response_serialization() {
        let items = vec![b"Item 1".to_vec(), b"Item 2".to_vec()];
        let receiver = OTReceiver::new(2, 0).unwrap();
        let request = receiver.create_request();

        let sender = OTSender::new();
        let response = sender.respond(&request, &items).unwrap();

        let serialized = crate::codec::encode(&response).unwrap();
        let deserialized: OTResponse = crate::codec::decode(&serialized).unwrap();

        let retrieved = receiver.retrieve(&deserialized).unwrap();
        assert_eq!(retrieved, items[0]);
    }

    #[test]
    fn test_receiver_properties() {
        let receiver = OTReceiver::new(5, 2).unwrap();
        assert_eq!(receiver.choice(), 2);
        assert_eq!(receiver.n_items(), 5);
    }

    #[test]
    fn test_multiple_receivers_same_items() {
        let items = vec![
            b"Content A".to_vec(),
            b"Content B".to_vec(),
            b"Content C".to_vec(),
        ];

        // Multiple receivers with different choices
        let receiver1 = OTReceiver::new(3, 0).unwrap();
        let receiver2 = OTReceiver::new(3, 2).unwrap();

        let sender = OTSender::new();
        let response1 = sender.respond(&receiver1.create_request(), &items).unwrap();
        let response2 = sender.respond(&receiver2.create_request(), &items).unwrap();

        assert_eq!(receiver1.retrieve(&response1).unwrap(), items[0]);
        assert_eq!(receiver2.retrieve(&response2).unwrap(), items[2]);
    }

    #[test]
    fn test_wrong_response_to_receiver() {
        let items = vec![b"Set 1 - Item A".to_vec(), b"Set 1 - Item B".to_vec()];

        let receiver = OTReceiver::new(2, 0).unwrap();
        let other_receiver = OTReceiver::new(2, 0).unwrap();

        let sender = OTSender::new();
        let own = sender.respond(&receiver.create_request(), &items).unwrap();
        let foreign = sender
            .respond(&other_receiver.create_request(), &items)
            .unwrap();

        // The response built for this receiver's request works...
        assert_eq!(receiver.retrieve(&own).unwrap(), items[0]);
        // ...but one built for a different receiver's keys must not decrypt.
        assert_eq!(
            receiver.retrieve(&foreign).err(),
            Some(OTError::DecryptionFailed)
        );
    }

    #[test]
    fn test_response_length_mismatch() {
        let items = vec![b"A".to_vec(), b"B".to_vec()];
        let receiver = OTReceiver::new(2, 1).unwrap();
        let response = OTSender::new()
            .respond(&receiver.create_request(), &items)
            .unwrap();

        let wider_receiver = OTReceiver::new(3, 0).unwrap();
        assert_eq!(
            wider_receiver.retrieve(&response).err(),
            Some(OTError::MismatchedItemCount)
        );
    }

    #[test]
    fn test_other_items_not_decryptable() {
        let items = vec![b"Chosen".to_vec(), b"Hidden".to_vec()];
        let receiver = OTReceiver::new(2, 0).unwrap();
        let response = OTSender::new()
            .respond(&receiver.create_request(), &items)
            .unwrap();

        // Same secret key, pointed at the other slot: it only opens the chosen item.
        let peeking = OTReceiver {
            n_items: receiver.n_items,
            choice: 1,
            chosen_sk: receiver.chosen_sk,
            pub_keys: receiver.pub_keys.clone(),
        };
        assert_eq!(
            peeking.retrieve(&response).err(),
            Some(OTError::DecryptionFailed)
        );
    }

    #[test]
    fn test_1_of_10() {
        let items: Vec<Vec<u8>> = (0..10)
            .map(|i| format!("Item {}", i).into_bytes())
            .collect();

        assert_eq!(round_trip(&items, 7), items[7]);
    }
}