1use blake3::Hasher;
46use chacha20poly1305::{
47 ChaCha20Poly1305, Nonce,
48 aead::{Aead, KeyInit},
49};
50use curve25519_dalek::{
51 constants::RISTRETTO_BASEPOINT_TABLE,
52 ristretto::{CompressedRistretto, RistrettoPoint},
53 scalar::Scalar,
54};
55use rand::Rng;
56use serde::{Deserialize, Serialize};
57
/// Errors reported by the oblivious-transfer receiver and sender.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OTError {
    /// The receiver's choice index is not smaller than the number of items.
    InvalidChoice,
    /// The number of items is zero.
    InvalidItemCount,
    /// A request is malformed.
    InvalidRequest,
    /// A response is malformed.
    InvalidResponse,
    /// Authenticated decryption of the chosen item failed.
    DecryptionFailed,
    /// Authenticated encryption of an item failed.
    EncryptionFailed,
    /// A compressed public key does not decode to a valid group element.
    InvalidPublicKey,
    /// Two item counts that must agree (items vs. keys, or response vs. receiver) differ.
    MismatchedItemCount,
}

impl std::fmt::Display for OTError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Map each variant to its fixed human-readable message.
        let message = match self {
            Self::InvalidChoice => "Invalid choice index",
            Self::InvalidItemCount => "Invalid number of items",
            Self::InvalidRequest => "Invalid request format",
            Self::InvalidResponse => "Invalid response format",
            Self::DecryptionFailed => "Decryption failed",
            Self::EncryptionFailed => "Encryption failed",
            Self::InvalidPublicKey => "Invalid public key",
            Self::MismatchedItemCount => "Mismatched item count",
        };
        f.write_str(message)
    }
}

impl std::error::Error for OTError {}

/// Convenience alias for results whose error type is [`OTError`].
pub type OTResult<T> = std::result::Result<T, OTError>;
98
99#[derive(Clone, Serialize, Deserialize)]
101pub struct OTRequest {
102 pub_keys: Vec<CompressedRistretto>,
104}
105
106#[derive(Clone, Serialize, Deserialize)]
108pub struct OTResponse {
109 encrypted_items: Vec<EncryptedItem>,
111}
112
113#[derive(Clone, Serialize, Deserialize)]
115struct EncryptedItem {
116 ephemeral_pk: CompressedRistretto,
118 ciphertext: Vec<u8>,
120 nonce: [u8; 12],
122}
123
124pub struct OTReceiver {
126 n_items: usize,
128 choice: usize,
130 chosen_sk: Scalar,
132 pub_keys: Vec<CompressedRistretto>,
134}
135
136impl OTReceiver {
137 pub fn new(n_items: usize, choice: usize) -> OTResult<Self> {
143 if n_items == 0 {
144 return Err(OTError::InvalidItemCount);
145 }
146 if choice >= n_items {
147 return Err(OTError::InvalidChoice);
148 }
149
150 let mut rng = rand::thread_rng();
151 let mut pub_keys = Vec::with_capacity(n_items);
152
153 let mut sk_bytes = [0u8; 32];
155 rng.fill(&mut sk_bytes);
156 let chosen_sk = Scalar::from_bytes_mod_order(sk_bytes);
157 let chosen_pk = &chosen_sk * RISTRETTO_BASEPOINT_TABLE;
158
159 for i in 0..n_items {
161 if i == choice {
162 pub_keys.push(chosen_pk.compress());
164 } else {
165 let mut random_bytes = [0u8; 32];
167 rng.fill(&mut random_bytes);
168 let random_sk = Scalar::from_bytes_mod_order(random_bytes);
169 let random_pk = &random_sk * RISTRETTO_BASEPOINT_TABLE;
170 pub_keys.push(random_pk.compress());
171 }
172 }
173
174 Ok(Self {
175 n_items,
176 choice,
177 chosen_sk,
178 pub_keys,
179 })
180 }
181
182 pub fn create_request(&self) -> OTRequest {
184 OTRequest {
185 pub_keys: self.pub_keys.clone(),
186 }
187 }
188
189 pub fn retrieve(&self, response: &OTResponse) -> OTResult<Vec<u8>> {
191 if response.encrypted_items.len() != self.n_items {
192 return Err(OTError::MismatchedItemCount);
193 }
194
195 let item = &response.encrypted_items[self.choice];
196
197 let ephemeral_pk = item
199 .ephemeral_pk
200 .decompress()
201 .ok_or(OTError::InvalidPublicKey)?;
202
203 let shared_point = ephemeral_pk * self.chosen_sk;
205
206 let sym_key = derive_ot_key(&shared_point);
208
209 let cipher = ChaCha20Poly1305::new(&sym_key.into());
211 let nonce = Nonce::from_slice(&item.nonce);
212
213 cipher
214 .decrypt(nonce, item.ciphertext.as_ref())
215 .map_err(|_| OTError::DecryptionFailed)
216 }
217
218 pub fn choice(&self) -> usize {
220 self.choice
221 }
222
223 pub fn n_items(&self) -> usize {
225 self.n_items
226 }
227}
228
229pub struct OTSender;
231
232impl OTSender {
233 pub fn new() -> Self {
235 Self
236 }
237
238 pub fn respond(&self, request: &OTRequest, items: &[Vec<u8>]) -> OTResult<OTResponse> {
244 if items.len() != request.pub_keys.len() {
245 return Err(OTError::MismatchedItemCount);
246 }
247 if items.is_empty() {
248 return Err(OTError::InvalidItemCount);
249 }
250
251 let mut rng = rand::thread_rng();
252 let mut encrypted_items = Vec::with_capacity(items.len());
253
254 for (item, pk_compressed) in items.iter().zip(&request.pub_keys) {
256 let pk = pk_compressed
258 .decompress()
259 .ok_or(OTError::InvalidPublicKey)?;
260
261 let mut ephemeral_sk_bytes = [0u8; 32];
263 rng.fill(&mut ephemeral_sk_bytes);
264 let ephemeral_sk = Scalar::from_bytes_mod_order(ephemeral_sk_bytes);
265 let ephemeral_pk = &ephemeral_sk * RISTRETTO_BASEPOINT_TABLE;
266
267 let shared_point = pk * ephemeral_sk;
269
270 let sym_key = derive_ot_key(&shared_point);
272
273 let mut nonce_bytes = [0u8; 12];
275 rng.fill(&mut nonce_bytes);
276 let nonce = Nonce::from_slice(&nonce_bytes);
277
278 let cipher = ChaCha20Poly1305::new(&sym_key.into());
280 let ciphertext = cipher
281 .encrypt(nonce, item.as_ref())
282 .map_err(|_| OTError::EncryptionFailed)?;
283
284 encrypted_items.push(EncryptedItem {
285 ephemeral_pk: ephemeral_pk.compress(),
286 ciphertext,
287 nonce: nonce_bytes,
288 });
289 }
290
291 Ok(OTResponse { encrypted_items })
292 }
293}
294
295impl Default for OTSender {
296 fn default() -> Self {
297 Self::new()
298 }
299}
300
301fn derive_ot_key(point: &RistrettoPoint) -> [u8; 32] {
303 let mut hasher = Hasher::new();
304 hasher.update(b"chie-ot-v1");
305 hasher.update(&point.compress().to_bytes());
306 let hash = hasher.finalize();
307 *hash.as_bytes()
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs one complete OT exchange and returns what the receiver learns.
    fn run_ot(items: &[Vec<u8>], choice: usize) -> Vec<u8> {
        let receiver = OTReceiver::new(items.len(), choice).unwrap();
        let request = receiver.create_request();
        let response = OTSender::new().respond(&request, items).unwrap();
        receiver.retrieve(&response).unwrap()
    }

    #[test]
    fn test_basic_ot_1_of_2() {
        let items = vec![b"First item".to_vec(), b"Second item".to_vec()];
        assert_eq!(run_ot(&items, 0), items[0]);
    }

    #[test]
    fn test_basic_ot_1_of_3() {
        let items = vec![b"Item 0".to_vec(), b"Item 1".to_vec(), b"Item 2".to_vec()];
        assert_eq!(run_ot(&items, 1), items[1]);
    }

    #[test]
    fn test_ot_all_choices() {
        let items = vec![
            b"Alpha".to_vec(),
            b"Beta".to_vec(),
            b"Gamma".to_vec(),
            b"Delta".to_vec(),
        ];
        for choice in 0..items.len() {
            assert_eq!(run_ot(&items, choice), items[choice]);
        }
    }

    #[test]
    fn test_invalid_choice() {
        assert_eq!(OTReceiver::new(3, 3).err(), Some(OTError::InvalidChoice));
        assert_eq!(OTReceiver::new(3, 100).err(), Some(OTError::InvalidChoice));
    }

    #[test]
    fn test_invalid_item_count() {
        assert_eq!(OTReceiver::new(0, 0).err(), Some(OTError::InvalidItemCount));
    }

    #[test]
    fn test_mismatched_item_count() {
        let items = vec![b"Item 1".to_vec(), b"Item 2".to_vec()];
        let receiver = OTReceiver::new(3, 0).unwrap();
        let request = receiver.create_request();
        assert_eq!(
            OTSender::new().respond(&request, &items).err(),
            Some(OTError::MismatchedItemCount)
        );
    }

    #[test]
    fn test_empty_items() {
        let items: Vec<Vec<u8>> = vec![];
        let receiver = OTReceiver::new(1, 0).unwrap();
        let request = receiver.create_request();
        assert_eq!(
            OTSender::new().respond(&request, &items).err(),
            Some(OTError::MismatchedItemCount)
        );

        // An empty request paired with no items is rejected as an invalid count.
        let empty_request = OTRequest {
            pub_keys: Vec::new(),
        };
        assert_eq!(
            OTSender::new().respond(&empty_request, &items).err(),
            Some(OTError::InvalidItemCount)
        );
    }

    #[test]
    fn test_large_items() {
        let items = vec![vec![1u8; 10_000], vec![2u8; 10_000]];
        assert_eq!(run_ot(&items, 1), items[1]);
    }

    #[test]
    fn test_empty_item_content() {
        let items = vec![b"".to_vec(), b"Non-empty".to_vec()];
        assert_eq!(run_ot(&items, 0), items[0]);
    }

    #[test]
    fn test_request_serialization() {
        let receiver = OTReceiver::new(3, 1).unwrap();
        let request = receiver.create_request();

        let serialized = crate::codec::encode(&request).unwrap();
        let deserialized: OTRequest = crate::codec::decode(&serialized).unwrap();

        assert_eq!(request.pub_keys.len(), deserialized.pub_keys.len());
        for (a, b) in request.pub_keys.iter().zip(&deserialized.pub_keys) {
            assert_eq!(a.to_bytes(), b.to_bytes());
        }
    }

    #[test]
    fn test_response_serialization() {
        let items = vec![b"Item 1".to_vec(), b"Item 2".to_vec()];
        let receiver = OTReceiver::new(2, 0).unwrap();
        let request = receiver.create_request();

        let response = OTSender::new().respond(&request, &items).unwrap();

        let serialized = crate::codec::encode(&response).unwrap();
        let deserialized: OTResponse = crate::codec::decode(&serialized).unwrap();

        assert_eq!(receiver.retrieve(&deserialized).unwrap(), items[0]);
    }

    #[test]
    fn test_receiver_properties() {
        let receiver = OTReceiver::new(5, 2).unwrap();
        assert_eq!(receiver.choice(), 2);
        assert_eq!(receiver.n_items(), 5);
    }

    #[test]
    fn test_multiple_receivers_same_items() {
        let items = vec![
            b"Content A".to_vec(),
            b"Content B".to_vec(),
            b"Content C".to_vec(),
        ];

        let receiver1 = OTReceiver::new(3, 0).unwrap();
        let receiver2 = OTReceiver::new(3, 2).unwrap();

        let sender = OTSender::new();
        let response1 = sender.respond(&receiver1.create_request(), &items).unwrap();
        let response2 = sender.respond(&receiver2.create_request(), &items).unwrap();

        assert_eq!(receiver1.retrieve(&response1).unwrap(), items[0]);
        assert_eq!(receiver2.retrieve(&response2).unwrap(), items[2]);
    }

    #[test]
    fn test_wrong_response_to_receiver() {
        // A response built for a different receiver's request must not decrypt.
        let items = vec![b"Set 1 - Item A".to_vec(), b"Set 1 - Item B".to_vec()];
        let receiver = OTReceiver::new(2, 0).unwrap();
        let other = OTReceiver::new(2, 0).unwrap();

        let response_for_other = OTSender::new()
            .respond(&other.create_request(), &items)
            .unwrap();

        assert_eq!(
            receiver.retrieve(&response_for_other),
            Err(OTError::DecryptionFailed)
        );
    }

    #[test]
    fn test_same_request_answered_twice() {
        let items1 = vec![b"Set 1 - Item A".to_vec(), b"Set 1 - Item B".to_vec()];
        let items2 = vec![b"Set 2 - Item X".to_vec(), b"Set 2 - Item Y".to_vec()];

        let receiver = OTReceiver::new(2, 0).unwrap();
        let request = receiver.create_request();

        let sender = OTSender::new();
        let response1 = sender.respond(&request, &items1).unwrap();
        let response2 = sender.respond(&request, &items2).unwrap();

        assert_eq!(receiver.retrieve(&response1).unwrap(), items1[0]);
        assert_eq!(receiver.retrieve(&response2).unwrap(), items2[0]);
    }

    #[test]
    fn test_retrieve_rejects_mismatched_response() {
        let items = vec![b"Item 1".to_vec(), b"Item 2".to_vec()];
        let receiver = OTReceiver::new(3, 0).unwrap();
        let other = OTReceiver::new(2, 0).unwrap();

        let response = OTSender::new()
            .respond(&other.create_request(), &items)
            .unwrap();

        assert_eq!(
            receiver.retrieve(&response),
            Err(OTError::MismatchedItemCount)
        );
    }

    #[test]
    fn test_non_chosen_items_not_decryptable() {
        let items = vec![
            b"secret 0".to_vec(),
            b"secret 1".to_vec(),
            b"secret 2".to_vec(),
        ];
        let receiver = OTReceiver::new(3, 0).unwrap();
        let response = OTSender::new()
            .respond(&receiver.create_request(), &items)
            .unwrap();

        for other in 1..items.len() {
            // Same secret key, pointed at a slot the receiver did not choose.
            let snooper = OTReceiver {
                n_items: receiver.n_items,
                choice: other,
                chosen_sk: receiver.chosen_sk,
                pub_keys: receiver.pub_keys.clone(),
            };
            assert_eq!(snooper.retrieve(&response), Err(OTError::DecryptionFailed));
        }
    }

    #[test]
    fn test_1_of_10() {
        let items: Vec<Vec<u8>> = (0..10)
            .map(|i| format!("Item {}", i).into_bytes())
            .collect();
        assert_eq!(run_ot(&items, 7), items[7]);
    }
}