1use crate::{configs::OTConfig, error::OTError, util::xor, Message};
5use ark_ec::{AffineRepr, CurveGroup, Group};
6use ark_ff::PrimeField;
7use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
8use ark_std::{cfg_into_iter, cfg_iter, log2, rand::RngCore, vec, vec::Vec, UniformRand};
9use digest::{ExtendableOutput, Update};
10use dock_crypto_utils::msm::WindowTable;
11use itertools::Itertools;
12
13#[cfg(feature = "parallel")]
14use rayon::prelude::*;
15
/// Sender's secret state for a batch of Naor-Pinkas 1-of-n oblivious transfers, produced by
/// `OTSenderSetup::new` together with the matching [`SenderPubKey`].
#[derive(Clone, Debug, PartialEq, Eq, CanonicalSerialize, CanonicalDeserialize)]
pub struct OTSenderSetup<G: AffineRepr> {
    /// Number of OTs (`num_ot`) and number of messages per OT (`num_messages`).
    pub ot_config: OTConfig,
    /// Secret random scalar `r`.
    pub r: G::ScalarField,
    /// `r * C_j` for each of the `num_messages - 1` random points `C_j` published in the
    /// sender's public key.
    pub C_r: Vec<G>,
}
24
/// Receiver's secret state for a batch of Naor-Pinkas 1-of-n oblivious transfers, produced by
/// `OTReceiver::new` together with the [`ReceiverPubKey`] sent to the sender.
#[derive(Clone, Debug, PartialEq, Eq, CanonicalSerialize, CanonicalDeserialize)]
pub struct OTReceiver<G: AffineRepr> {
    /// Number of OTs (`num_ot`) and number of messages per OT (`num_messages`).
    pub ot_config: OTConfig,
    /// Index of the message to receive in each OT (checked by
    /// `OTConfig::verify_receiver_choices` when the receiver is created).
    pub choices: Vec<u16>,
    /// Secret random scalar `k[i]` for each OT.
    pub k: Vec<G::ScalarField>,
    /// `k[i] * g` for each OT.
    pub pk: Vec<G>,
    /// Decryption key `k[i] * (r * g)` for each OT; hashed into the pad of the chosen message.
    pub dk: Vec<G>,
}
33
/// Sender's public key, sent to the receiver.
#[derive(Clone, Debug, PartialEq, Eq, CanonicalSerialize, CanonicalDeserialize)]
pub struct SenderPubKey<G: AffineRepr>(
    /// `r * g`, where `r` is the sender's secret scalar.
    pub G,
    /// The `num_messages - 1` random points `C_1, ..., C_{num_messages - 1}`.
    pub Vec<G>,
);
41
/// Receiver's public key, one point per OT: `pk[i]` when the choice is 0, otherwise
/// `C_choice - pk[i]`.
#[derive(Clone, Debug, PartialEq, Eq, CanonicalSerialize, CanonicalDeserialize)]
pub struct ReceiverPubKey<G: AffineRepr>(pub Vec<G>);
44
/// Sender's output for a batch of OTs: for each OT, the encrypted candidate messages (one per
/// message index) and the random salt `R` that was hashed into their pads.
#[derive(Clone, Debug, PartialEq, Eq, CanonicalSerialize, CanonicalDeserialize)]
pub struct SenderEncryptions(Vec<(Vec<Message>, Vec<u8>)>);
47
48impl<G: AffineRepr> OTSenderSetup<G> {
49 pub fn new<R: RngCore>(rng: &mut R, ot_config: OTConfig, g: &G) -> (Self, SenderPubKey<G>) {
51 let r = G::ScalarField::rand(rng);
52 let r_repr = r.into_bigint();
53 let g_r = g.mul_bigint(&r_repr).into_affine();
54 let C_proj = (0..ot_config.num_messages - 1)
55 .map(|_| G::Group::rand(rng))
56 .collect::<Vec<_>>();
57 let C_r = G::Group::normalize_batch(
58 &cfg_iter!(C_proj)
59 .map(|c| c.mul_bigint(&r_repr))
60 .collect::<Vec<_>>(),
61 );
62 (
63 Self { ot_config, r, C_r },
64 SenderPubKey(g_r, G::Group::normalize_batch(&C_proj)),
65 )
66 }
67
68 pub fn encrypt<R: RngCore, D: Default + Update + ExtendableOutput>(
69 &self,
70 rng: &mut R,
71 pk_not: ReceiverPubKey<G>,
72 messages: Vec<Vec<Message>>,
73 ) -> Result<SenderEncryptions, OTError> {
74 let m = self.ot_config.num_ot as usize;
75 let n = self.ot_config.num_messages as usize;
76 if pk_not.0.len() != m {
77 return Err(OTError::IncorrectReceiverPubKeySize(
78 self.ot_config.num_ot,
79 pk_not.0.len() as u16,
80 ));
81 }
82 if messages.len() != m {
83 return Err(OTError::IncorrectMessageBatchSize(
84 self.ot_config.num_ot,
85 messages.len() as u16,
86 ));
87 }
88 if !messages.iter().all(|m| m.len() == n) {
89 return Err(OTError::IncorrectNoOfMessages(self.ot_config.num_messages));
90 }
91
92 let R = (0..m)
93 .map(|_| {
94 let mut bytes = vec![0u8; log2(m) as usize];
95 rng.fill_bytes(&mut bytes);
96 bytes
97 })
98 .collect::<Vec<_>>();
99
100 let C_r = cfg_iter!(self.C_r)
101 .map(|c| c.into_group())
102 .collect::<Vec<_>>();
103 let r_repr = self.r.into_bigint();
104 let enc: Vec<_> = cfg_into_iter!(R)
105 .enumerate()
106 .map(|(i, R)| {
107 let pk_not_i = pk_not.0[i].mul_bigint(r_repr);
108 let mut pk_i = cfg_into_iter!(0..n - 1)
109 .map(|j| C_r[j] - pk_not_i)
110 .collect::<Vec<_>>();
111 pk_i.insert(0, pk_not_i);
112 let pk_i = G::Group::normalize_batch(&pk_i);
113 let enc: Vec<_> = cfg_iter!(messages[i])
114 .enumerate()
115 .map(|(j, m)| {
116 let pad = hash_to_otp::<G, D>(
117 j as u16,
118 &pk_i[j],
119 &R,
120 m.len()
121 .try_into()
122 .map_err(|_| OTError::MessageIsTooLong(m.len()))?,
123 );
124
125 Ok(xor(&pad, m))
126 })
127 .collect::<Result<_, OTError>>()?;
128
129 Ok((enc, R))
130 })
131 .collect::<Result<Vec<_>, OTError>>()?;
132 Ok(SenderEncryptions(enc))
133 }
134}
135
136impl<G: AffineRepr> OTReceiver<G> {
137 pub fn new<R: RngCore>(
138 rng: &mut R,
139 ot_config: OTConfig,
140 choices: Vec<u16>,
141 pub_key: SenderPubKey<G>,
142 g: &G,
143 ) -> Result<(Self, ReceiverPubKey<G>), OTError> {
144 ot_config.verify_receiver_choices(&choices)?;
145 if pub_key.1.len() != ot_config.num_messages as usize - 1 {
146 return Err(OTError::IncorrectSenderPubKeySize(
147 pub_key.1.len() as u16,
148 ot_config.num_messages,
149 ));
150 }
151
152 let k = (0..ot_config.num_ot)
153 .map(|_| G::ScalarField::rand(rng))
154 .collect::<Vec<_>>();
155 let g_table = WindowTable::new(ot_config.num_ot as usize, g.into_group());
156 let g_r_table = WindowTable::new(ot_config.num_ot as usize, pub_key.0.into_group());
157 let (pk, pk_not, dk) = cfg_into_iter!(0..ot_config.num_ot as usize)
158 .map(|i| {
159 let pk = g_table.multiply(&k[i]);
160 let dk = g_r_table.multiply(&k[i]);
161
162 let pk_times_2 = pk.double();
164 let choice = choices[i];
165 let pk_not = if choice == 0 {
166 pk_times_2 - pk
168 } else {
169 pub_key.1[choice as usize - 1].into_group() - pk
170 };
171
172 (pk, pk_not, dk)
173 })
174 .collect::<Vec<_>>()
175 .into_iter()
176 .multiunzip::<(Vec<_>, Vec<_>, Vec<_>)>();
177 Ok((
178 Self {
179 ot_config,
180 choices,
181 k,
182 pk: G::Group::normalize_batch(&pk),
183 dk: G::Group::normalize_batch(&dk),
184 },
185 ReceiverPubKey(G::Group::normalize_batch(&pk_not)),
186 ))
187 }
188
189 pub fn decrypt<D: Default + Update + ExtendableOutput>(
190 &self,
191 sender_encryptions: SenderEncryptions,
192 message_size: u32,
193 ) -> Result<Vec<Message>, OTError> {
194 if sender_encryptions.0.len() != self.ot_config.num_ot as usize {
195 return Err(OTError::IncorrectMessageBatchSize(
196 self.ot_config.num_ot,
197 sender_encryptions.0.len() as u16,
198 ));
199 }
200 if !sender_encryptions
201 .0
202 .iter()
203 .all(|(m, _)| m.len() == self.ot_config.num_messages as usize)
204 {
205 return Err(OTError::IncorrectNoOfMessages(self.ot_config.num_messages));
206 }
207 Ok(cfg_into_iter!(sender_encryptions.0)
208 .enumerate()
209 .map(|(i, (m, r))| {
210 let pad = hash_to_otp::<G, D>(self.choices[i], &self.dk[i], &r, message_size);
211 let m = &m[self.choices[i] as usize];
212 xor(&pad, m)
213 })
214 .collect())
215 }
216}
217
218fn hash_to_otp<G: CanonicalSerialize, D: Default + Update + ExtendableOutput>(
220 index: u16,
221 pk: &G,
222 R: &[u8],
223 pad_size: u32,
224) -> Vec<u8> {
225 let mut bytes = index.to_be_bytes().to_vec();
226 pk.serialize_compressed(&mut bytes).unwrap();
227 bytes.extend_from_slice(R);
228 let mut pad = vec![0; pad_size as usize];
229 let mut hasher = D::default();
230 hasher.update(&bytes);
231 hasher.finalize_xof_into(&mut pad);
232 pad
233}
234
#[cfg(test)]
pub mod tests {
    use super::*;
    use ark_bls12_381::Bls12_381;
    use ark_ec::pairing::Pairing;
    use ark_std::{
        rand::{rngs::StdRng, SeedableRng},
        UniformRand,
    };
    use sha3::Shake256;
    use std::time::Instant;

    /// Affine group used for every OT in these tests.
    type G1 = <Bls12_381 as Pairing>::G1Affine;

    /// Size in bytes of every random message the sender offers.
    const MESSAGE_SIZE: usize = 200;

    /// Runs one full session of `m` 1-of-`n` OTs with the given `choices`: sender setup,
    /// receiver init, encryption and decryption. Prints the time taken by each phase and
    /// asserts that the receiver recovers exactly its chosen message in every OT.
    fn check(rng: &mut StdRng, m: u16, n: u16, choices: Vec<u16>, g: &G1) {
        let ot_config = OTConfig {
            num_ot: m,
            num_messages: n,
        };

        let timer = Instant::now();
        let (sender_setup, sender_pk) = OTSenderSetup::new(rng, ot_config, g);
        println!("Sender setup for {} 1-of-{} OTs in {:?}", m, n, timer.elapsed());

        let timer = Instant::now();
        let (receiver, pk_not) = OTReceiver::new(rng, ot_config, choices, sender_pk, g).unwrap();
        println!("Receiver inits {} 1-of-{} OTs in {:?}", m, n, timer.elapsed());

        let mut random_message = || {
            let mut bytes = vec![0u8; MESSAGE_SIZE];
            rng.fill_bytes(&mut bytes);
            bytes
        };
        let messages: Vec<Vec<_>> = (0..m)
            .map(|_| (0..n).map(|_| random_message()).collect())
            .collect();

        let timer = Instant::now();
        let encryptions = sender_setup
            .encrypt::<_, Shake256>(rng, pk_not, messages.clone())
            .unwrap();
        println!(
            "Sender encrypts messages for {} 1-of-{} OTs in {:?}",
            m,
            n,
            timer.elapsed()
        );

        let timer = Instant::now();
        let decryptions = receiver
            .decrypt::<Shake256>(encryptions, MESSAGE_SIZE as u32)
            .unwrap();
        println!(
            "Receiver decrypts messages for {} 1-of-{} OTs in {:?}",
            m,
            n,
            timer.elapsed()
        );

        (0..m as usize).for_each(|i| {
            let chosen = receiver.choices[i] as usize;
            assert_eq!(messages[i][chosen], decryptions[i]);
        });
    }

    #[test]
    fn naor_pinkas_ot() {
        let mut rng = StdRng::seed_from_u64(0u64);
        let g = G1::rand(&mut rng);

        // (num_ot, num_messages, choices) for small configurations with hand-picked choices.
        let fixed_cases = vec![
            (1, 2, vec![0]),
            (1, 2, vec![1]),
            (1, 3, vec![0]),
            (1, 3, vec![1]),
            (1, 3, vec![2]),
            (2, 2, vec![0, 0]),
            (2, 2, vec![0, 1]),
            (2, 2, vec![1, 0]),
            (2, 2, vec![1, 1]),
            (3, 2, vec![1, 1, 1]),
            (3, 2, vec![0, 0, 0]),
            (3, 3, vec![0, 1, 2]),
            (3, 3, vec![1, 2, 2]),
            (3, 3, vec![1, 0, 2]),
            (3, 5, vec![4, 0, 1]),
            (4, 2, vec![1, 0, 1, 1]),
            (4, 3, vec![2, 1, 0, 1]),
            (4, 4, vec![3, 2, 1, 0]),
            (4, 8, vec![7, 6, 5, 4]),
        ];
        for (m, n, choices) in fixed_cases {
            check(&mut rng, m, n, choices, &g);
        }

        // Larger batches of 1-of-2 OTs with random choices.
        for &m in &[32u16, 64, 128, 192] {
            let choices = (0..m).map(|_| u16::rand(&mut rng) % 2).collect();
            check(&mut rng, m, 2, choices, &g);
        }
    }
}