1use rustcrypto_ff::Field;
11use std::marker::PhantomData;
12use subtle::ConstantTimeEq;
13use zeroize::{Zeroize, ZeroizeOnDrop};
14
15use crate::curve::DklsCurve;
16use crate::utilities::hashes::{scalar_to_bytes, tagged_hash, tagged_hash_as_scalar, HashOutput};
17use crate::utilities::oracle_tags::{
18 TAG_MUL_CHI_HAT, TAG_MUL_CHI_TILDE, TAG_MUL_GADGET, TAG_MUL_VERIFY,
19};
20use crate::utilities::proofs::{DLogProof, EncProof};
21use crate::utilities::rng;
22
23#[cfg(feature = "serde")]
24use super::ot::extension::{deserialize_vec_prg, serialize_vec_prg};
25use crate::utilities::ot::base::{OTReceiver, OTSender, Seed};
26use crate::utilities::ot::extension::{
27 OTEDataToSender, OTEReceiver, OTESender, PRGOutput, BATCH_SIZE,
28};
29use crate::utilities::ot::ErrorOT;
30use rand::RngExt;
31
/// Number of scalars multiplied in one protocol run: the sender's `input` has
/// `L` entries, and both parties produce `L` output shares.
pub const L: u8 = 2;

/// Number of correlation rows handed to the OT extension per run: one "tilde"
/// row and one "hat" row for each of the `L` inputs.
pub const OT_WIDTH: u8 = 2 * L;
38
/// Sender side of the two-party multiplication protocol.
///
/// Built by `MulSender::init_phase2`; `MulSender::run` consumes the receiver's
/// OT-extension message and returns the sender's output shares.
#[derive(Clone, Debug, Zeroize, ZeroizeOnDrop)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "serde",
    serde(bound(
        serialize = "C::Scalar: serde::Serialize",
        deserialize = "C::Scalar: serde::Deserialize<'de>"
    ))
)]
pub struct MulSender<C: DklsCurve> {
    /// Public gadget vector (`BATCH_SIZE` entries when built by `init_phase2`),
    /// derived from the session id and the receiver's nonce.
    pub public_gadget: Vec<C::Scalar>,
    /// OT-extension sender state produced during initialization.
    pub ote_sender: OTESender,
    // Zero-sized marker tying the struct to the curve `C`; nothing to zeroize or serialize.
    #[zeroize(skip)]
    #[cfg_attr(feature = "serde", serde(skip))]
    pub(crate) _curve: PhantomData<C>,
}
56
/// Receiver side of the two-party multiplication protocol.
///
/// Built by `MulReceiver::init_phase2`; the protocol then runs as
/// `run_phase1` (produces the message for the sender) followed by `run_phase2`
/// (consumes the sender's reply and returns the receiver's output shares).
#[derive(Clone, Debug, Zeroize, ZeroizeOnDrop)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "serde",
    serde(bound(
        serialize = "C::Scalar: serde::Serialize",
        deserialize = "C::Scalar: serde::Deserialize<'de>"
    ))
)]
pub struct MulReceiver<C: DklsCurve> {
    /// Public gadget vector (`BATCH_SIZE` entries when built by `init_phase2`),
    /// derived from the session id and this party's nonce.
    pub public_gadget: Vec<C::Scalar>,
    /// OT-extension receiver state produced during initialization.
    pub ote_receiver: OTEReceiver,
    // Zero-sized marker tying the struct to the curve `C`; nothing to zeroize or serialize.
    #[zeroize(skip)]
    #[cfg_attr(feature = "serde", serde(skip))]
    pub(crate) _curve: PhantomData<C>,
}
74
/// Message sent from the sender to the receiver by `MulSender::run`.
#[derive(Clone, Debug, Zeroize, ZeroizeOnDrop)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "serde",
    serde(bound(
        serialize = "C::Scalar: serde::Serialize",
        deserialize = "C::Scalar: serde::Deserialize<'de>"
    ))
)]
pub struct MulDataToReceiver<C: DklsCurve> {
    /// `tau` output of the sender's OT-extension run (`OTESender::run`).
    pub vector_of_tau: Vec<Vec<C::Scalar>>,
    /// Tagged hash (`TAG_MUL_VERIFY`) over the session id and the sender's
    /// consistency-check matrix.
    pub verify_r: HashOutput,
    /// `L` consistency-check values `chi_tilde[i] * a_tilde[i] + chi_hat[i] * a_hat[i]`.
    pub verify_u: Vec<C::Scalar>,
    /// `L` masked inputs `input[i] - a_tilde[i]`.
    pub gamma_sender: Vec<C::Scalar>,
}
91
/// State the receiver keeps between `MulReceiver::run_phase1` and
/// `MulReceiver::run_phase2`.
///
/// It is not part of `OTEDataToSender`. NOTE(review): presumably it must stay
/// local to the receiver; confirm before persisting or transmitting it.
#[derive(Clone, Debug, Zeroize, ZeroizeOnDrop)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "serde",
    serde(bound(
        serialize = "C::Scalar: serde::Serialize",
        deserialize = "C::Scalar: serde::Deserialize<'de>"
    ))
)]
pub struct MulDataToKeepReceiver<C: DklsCurve> {
    /// Sum of the `public_gadget` entries whose choice bit is set; also returned
    /// directly by `run_phase1`.
    pub b: C::Scalar,
    /// The `BATCH_SIZE` random choice bits fed to the OT-extension receiver.
    pub choice_bits: Vec<bool>,
    /// PRG outputs returned by `OTEReceiver::run_phase1`, needed again in `run_phase2`.
    #[cfg_attr(
        feature = "serde",
        serde(
            serialize_with = "serialize_vec_prg",
            deserialize_with = "deserialize_vec_prg"
        )
    )]
    pub extended_seeds: Vec<PRGOutput>,
    /// `L` challenges derived (with `TAG_MUL_CHI_TILDE`) from the message sent to the sender.
    pub chi_tilde: Vec<C::Scalar>,
    /// `L` challenges derived (with `TAG_MUL_CHI_HAT`) from the message sent to the sender.
    pub chi_hat: Vec<C::Scalar>,
    // Zero-sized marker tying the struct to the curve `C`; nothing to zeroize or serialize.
    #[zeroize(skip)]
    #[cfg_attr(feature = "serde", serde(skip))]
    _curve: PhantomData<C>,
}
119
/// Error returned by the multiplication protocol.
#[derive(Debug)]
pub struct ErrorMul {
    /// Human-readable description of what went wrong.
    pub description: String,
}
125
126impl ErrorMul {
127 #[must_use]
129 pub fn new(description: &str) -> ErrorMul {
130 ErrorMul {
131 description: String::from(description),
132 }
133 }
134}
135
136impl<C: DklsCurve> MulSender<C> {
138 #[must_use]
149 pub fn init_phase1(
150 session_id: &[u8],
151 ) -> (OTReceiver, Vec<bool>, Vec<C::Scalar>, Vec<EncProof<C>>) {
152 OTESender::init_phase1::<C>(session_id)
153 }
154
155 pub fn init_phase2(
165 ot_receiver: &OTReceiver,
166 session_id: &[u8],
167 correlation: Vec<bool>,
168 vec_r: &[C::Scalar],
169 dlog_proof: &DLogProof<C>,
170 nonce: &C::Scalar,
171 ) -> Result<MulSender<C>, ErrorOT> {
172 let ote_sender =
173 OTESender::init_phase2::<C>(ot_receiver, session_id, correlation, vec_r, dlog_proof)?;
174
175 let mut public_gadget: Vec<C::Scalar> = Vec::with_capacity(BATCH_SIZE as usize);
178 let mut counter = *nonce;
179 for _ in 0..BATCH_SIZE {
180 counter += <C::Scalar as Field>::ONE;
181 let counter_bytes = scalar_to_bytes::<C>(&counter);
182 public_gadget.push(tagged_hash_as_scalar::<C>(
183 TAG_MUL_GADGET,
184 &[session_id, &counter_bytes],
185 ));
186 }
187
188 let mul_sender = MulSender {
189 public_gadget,
190 ote_sender,
191 _curve: PhantomData,
192 };
193
194 Ok(mul_sender)
195 }
196
197 pub fn run(
215 &self,
216 session_id: &[u8],
217 input: &[C::Scalar],
218 data: &OTEDataToSender,
219 ) -> Result<(Vec<C::Scalar>, MulDataToReceiver<C>), ErrorMul> {
220 let mut a_tilde: Vec<C::Scalar> = Vec::with_capacity(L as usize);
229 let mut a_hat: Vec<C::Scalar> = Vec::with_capacity(L as usize);
230 for _ in 0..L {
231 a_tilde.push(<C::Scalar as Field>::random(&mut rng::get_rng()));
232 a_hat.push(<C::Scalar as Field>::random(&mut rng::get_rng()));
233 }
234
235 let mut correlation_tilde: Vec<Vec<C::Scalar>> = Vec::with_capacity(L as usize);
246 let mut correlation_hat: Vec<Vec<C::Scalar>> = Vec::with_capacity(L as usize);
247 for i in 0..L {
248 let correlation_tilde_i = vec![a_tilde[i as usize]; BATCH_SIZE as usize];
249 let correlation_hat_i = vec![a_hat[i as usize]; BATCH_SIZE as usize];
250
251 correlation_tilde.push(correlation_tilde_i);
252 correlation_hat.push(correlation_hat_i);
253 }
254
255 let correlations = [correlation_tilde, correlation_hat].concat();
257
258 let ote_sid = ["OT Extension protocol".as_bytes(), session_id].concat();
273
274 let result = self
275 .ote_sender
276 .run::<C>(&ote_sid, OT_WIDTH, &correlations, data);
277
278 let ot_outputs: Vec<Vec<C::Scalar>>;
279 let vector_of_tau: Vec<Vec<C::Scalar>>; match result {
281 Ok((out, tau)) => {
282 (ot_outputs, vector_of_tau) = (out, tau);
283 }
284 Err(error) => {
285 return Err(ErrorMul::new(&format!(
286 "OTE error during multiplication: {:?}",
287 error.description
288 )));
289 }
290 }
291
292 let (z_tilde, z_hat) = ot_outputs.split_at(L as usize);
294
295 let transcript = [
299 data.u.concat(),
300 data.verify_x.to_vec(),
301 data.verify_t.concat(),
302 ]
303 .concat();
304
305 let mut chi_tilde: Vec<C::Scalar> = Vec::with_capacity(L as usize);
308 let mut chi_hat: Vec<C::Scalar> = Vec::with_capacity(L as usize);
309 for i in 0..L {
310 chi_tilde.push(tagged_hash_as_scalar::<C>(
311 TAG_MUL_CHI_TILDE,
312 &[session_id, &i.to_be_bytes(), &transcript],
313 ));
314 chi_hat.push(tagged_hash_as_scalar::<C>(
315 TAG_MUL_CHI_HAT,
316 &[session_id, &i.to_be_bytes(), &transcript],
317 ));
318 }
319
320 let mut rows_r_as_bytes: Vec<Vec<u8>> = Vec::with_capacity(L as usize);
329 let mut verify_u: Vec<C::Scalar> = Vec::with_capacity(L as usize);
330 for i in 0..L {
331 let mut entries_as_bytes: Vec<Vec<u8>> = Vec::with_capacity(BATCH_SIZE as usize);
333 for j in 0..BATCH_SIZE {
334 let entry = (chi_tilde[i as usize] * z_tilde[i as usize][j as usize])
335 + (chi_hat[i as usize] * z_hat[i as usize][j as usize]);
336 let entry_as_bytes = scalar_to_bytes::<C>(&entry);
337 entries_as_bytes.push(entry_as_bytes);
338 }
339 let row_i_as_bytes = entries_as_bytes.concat();
340 rows_r_as_bytes.push(row_i_as_bytes);
341
342 let entry = (chi_tilde[i as usize] * a_tilde[i as usize])
344 + (chi_hat[i as usize] * a_hat[i as usize]);
345 verify_u.push(entry);
346 }
347 let r_as_bytes = rows_r_as_bytes.concat();
348
349 let verify_r: HashOutput = tagged_hash(TAG_MUL_VERIFY, &[session_id, &r_as_bytes]);
351
352 let mut gamma: Vec<C::Scalar> = Vec::with_capacity(L as usize);
359 for i in 0..L {
360 let difference = input[i as usize] - a_tilde[i as usize];
361 gamma.push(difference);
362 }
363
364 let mut output: Vec<C::Scalar> = Vec::with_capacity(L as usize);
368 for i in 0..L {
369 let mut summation = <C::Scalar as Field>::ZERO;
370 for j in 0..BATCH_SIZE {
371 summation += self.public_gadget[j as usize] * z_tilde[i as usize][j as usize];
372 }
373 output.push(summation);
374 }
375
376 let data_to_receiver = MulDataToReceiver {
379 vector_of_tau,
380 verify_r,
381 verify_u,
382 gamma_sender: gamma,
383 };
384
385 Ok((output, data_to_receiver))
386 }
387}
388
389impl<C: DklsCurve> MulReceiver<C> {
390 #[must_use]
405 pub fn init_phase1(session_id: &[u8]) -> (OTSender<C>, DLogProof<C>, C::Scalar) {
406 let (ot_sender, proof) = OTEReceiver::init_phase1::<C>(session_id);
407
408 let nonce = <C::Scalar as Field>::random(&mut rng::get_rng());
412
413 (ot_sender, proof, nonce)
414 }
415
416 pub fn init_phase2(
425 ot_sender: &OTSender<C>,
426 session_id: &[u8],
427 seed: &Seed,
428 enc_proofs: &[EncProof<C>],
429 nonce: &C::Scalar,
430 ) -> Result<MulReceiver<C>, ErrorOT> {
431 let ote_receiver = OTEReceiver::init_phase2::<C>(ot_sender, session_id, seed, enc_proofs)?;
432
433 let mut public_gadget: Vec<C::Scalar> = Vec::with_capacity(BATCH_SIZE as usize);
436 let mut counter = *nonce;
437 for _ in 0..BATCH_SIZE {
438 counter += <C::Scalar as Field>::ONE;
439 let counter_bytes = scalar_to_bytes::<C>(&counter);
440 public_gadget.push(tagged_hash_as_scalar::<C>(
441 TAG_MUL_GADGET,
442 &[session_id, &counter_bytes],
443 ));
444 }
445
446 let mul_receiver = MulReceiver {
447 public_gadget,
448 ote_receiver,
449 _curve: PhantomData,
450 };
451
452 Ok(mul_receiver)
453 }
454
455 pub fn run_phase1(
475 &self,
476 session_id: &[u8],
477 ) -> Result<(C::Scalar, MulDataToKeepReceiver<C>, OTEDataToSender), ErrorMul> {
478 let mut choice_bits: Vec<bool> = Vec::with_capacity(BATCH_SIZE as usize);
487 let mut b = <C::Scalar as Field>::ZERO;
488 for i in 0..BATCH_SIZE {
489 let current_bit: bool = rng::get_rng().random();
490 if current_bit {
491 b += &self.public_gadget[i as usize];
492 }
493 choice_bits.push(current_bit);
494 }
495
496 let ote_sid = ["OT Extension protocol".as_bytes(), session_id].concat();
505
506 let (extended_seeds, data_to_sender) =
507 match self.ote_receiver.run_phase1(&ote_sid, &choice_bits) {
508 Ok(values) => values,
509 Err(error) => {
510 return Err(ErrorMul::new(&format!(
511 "OTE error during multiplication: {:?}",
512 error.description
513 )));
514 }
515 };
516
517 let transcript = [
521 data_to_sender.u.concat(),
522 data_to_sender.verify_x.to_vec(),
523 data_to_sender.verify_t.concat(),
524 ]
525 .concat();
526
527 let mut chi_tilde: Vec<C::Scalar> = Vec::with_capacity(L as usize);
530 let mut chi_hat: Vec<C::Scalar> = Vec::with_capacity(L as usize);
531 for i in 0..L {
532 chi_tilde.push(tagged_hash_as_scalar::<C>(
533 TAG_MUL_CHI_TILDE,
534 &[session_id, &i.to_be_bytes(), &transcript],
535 ));
536 chi_hat.push(tagged_hash_as_scalar::<C>(
537 TAG_MUL_CHI_HAT,
538 &[session_id, &i.to_be_bytes(), &transcript],
539 ));
540 }
541
542 let data_to_keep = MulDataToKeepReceiver {
548 b,
549 choice_bits,
550 extended_seeds,
551 chi_tilde,
552 chi_hat,
553 _curve: PhantomData,
554 };
555
556 Ok((b, data_to_keep, data_to_sender))
557 }
558
559 pub fn run_phase2(
570 &self,
571 session_id: &[u8],
572 data_kept: &MulDataToKeepReceiver<C>,
573 data_received: &MulDataToReceiver<C>,
574 ) -> Result<Vec<C::Scalar>, ErrorMul> {
575 if data_received.verify_u.len() != L as usize
576 || data_received.gamma_sender.len() != L as usize
577 {
578 return Err(ErrorMul::new("Received data has incorrect dimensions"));
579 }
580
581 let ote_sid = ["OT Extension protocol".as_bytes(), session_id].concat();
588
589 let result = self.ote_receiver.run_phase2::<C>(
590 &ote_sid,
591 OT_WIDTH,
592 &data_kept.choice_bits,
593 &data_kept.extended_seeds,
594 &data_received.vector_of_tau,
595 );
596
597 let ot_outputs: Vec<Vec<C::Scalar>> = match result {
598 Ok(out) => out,
599 Err(error) => {
600 return Err(ErrorMul::new(&format!(
601 "OTE error during multiplication: {:?}",
602 error.description
603 )));
604 }
605 };
606
607 let (z_tilde, z_hat) = ot_outputs.split_at(L as usize);
609
610 let mut rows_r_as_bytes: Vec<Vec<u8>> = Vec::with_capacity(L as usize);
620 for i in 0..L {
621 let mut entries_as_bytes: Vec<Vec<u8>> = Vec::with_capacity(BATCH_SIZE as usize);
623 for j in 0..BATCH_SIZE {
624 let mut entry = (-(data_kept.chi_tilde[i as usize]
626 * z_tilde[i as usize][j as usize]))
627 - (data_kept.chi_hat[i as usize] * z_hat[i as usize][j as usize]);
628 if data_kept.choice_bits[j as usize] {
629 entry += &data_received.verify_u[i as usize];
630 }
631
632 let entry_as_bytes = scalar_to_bytes::<C>(&entry);
633 entries_as_bytes.push(entry_as_bytes);
634 }
635 let row_i_as_bytes = entries_as_bytes.concat();
636 rows_r_as_bytes.push(row_i_as_bytes);
637 }
638 let r_as_bytes = rows_r_as_bytes.concat();
639
640 let expected_verify_r: HashOutput = tagged_hash(TAG_MUL_VERIFY, &[session_id, &r_as_bytes]);
642
643 if !bool::from(data_received.verify_r.ct_eq(&expected_verify_r)) {
646 return Err(ErrorMul::new(
647 "Sender cheated in multiplication protocol: Consistency check failed!",
648 ));
649 }
650
651 let mut output: Vec<C::Scalar> = Vec::with_capacity(L as usize);
660 for i in 0..L {
661 let mut summation = <C::Scalar as Field>::ZERO;
662 for j in 0..BATCH_SIZE {
663 summation += self.public_gadget[j as usize] * z_tilde[i as usize][j as usize];
664 }
665 let final_sum = (data_kept.b * data_received.gamma_sender[i as usize]) + summation;
666 output.push(final_sum);
667 }
668
669 Ok(output)
670 }
671}
672
#[cfg(test)]
mod tests {
    use super::*;
    use elliptic_curve::CurveArithmetic;
    use k256::Secp256k1;
    use rand::RngExt;

    type TestCurve = Secp256k1;
    type Scalar = <TestCurve as CurveArithmetic>::Scalar;

    /// Runs both initialization phases for a fresh pair of parties and returns the
    /// initialized `(receiver, sender)`. Panics if either `init_phase2` fails.
    fn init_parties(session_id: &[u8]) -> (MulReceiver<TestCurve>, MulSender<TestCurve>) {
        let (ot_sender, dlog_proof, nonce) = MulReceiver::<TestCurve>::init_phase1(session_id);
        let (ot_receiver, correlation, vec_r, enc_proofs) =
            MulSender::<TestCurve>::init_phase1(session_id);
        let seed = ot_receiver.seed;

        let mul_receiver =
            MulReceiver::init_phase2(&ot_sender, session_id, &seed, &enc_proofs, &nonce)
                .unwrap_or_else(|error| {
                    panic!("Two-party multiplication error: {:?}", error.description)
                });
        let mul_sender = MulSender::init_phase2(
            &ot_receiver,
            session_id,
            correlation,
            &vec_r,
            &dlog_proof,
            &nonce,
        )
        .unwrap_or_else(|error| panic!("Two-party multiplication error: {:?}", error.description));

        (mul_receiver, mul_sender)
    }

    /// Draws `L` random scalars to use as the sender's input.
    fn random_sender_input() -> Vec<Scalar> {
        (0..L).map(|_| Scalar::random(&mut rng::get_rng())).collect()
    }

    /// Runs initialization, the receiver's `run_phase1` and the sender's `run`, so that
    /// only the receiver's `run_phase2` remains.
    fn prepare_mul_receiver_inputs(
        session_id: &[u8; 32],
    ) -> (
        MulReceiver<TestCurve>,
        MulDataToKeepReceiver<TestCurve>,
        MulDataToReceiver<TestCurve>,
    ) {
        let (mul_receiver, mul_sender) = init_parties(session_id);
        let sender_input = random_sender_input();

        let (_, data_to_keep, data_to_sender) = mul_receiver
            .run_phase1(session_id)
            .expect("mul receiver phase1 should succeed");
        let (_, data_to_receiver) = mul_sender
            .run(session_id, &sender_input, &data_to_sender)
            .unwrap_or_else(|error| {
                panic!("Two-party multiplication error: {:?}", error.description)
            });

        (mul_receiver, data_to_keep, data_to_receiver)
    }

    /// Honest run: the two output shares must add up to `input[i] * b`.
    #[test]
    fn test_multiplication() {
        let session_id = rng::get_rng().random::<[u8; 32]>();
        let (mul_receiver, mul_sender) = init_parties(&session_id);
        let sender_input = random_sender_input();

        let (receiver_random, data_to_keep, data_to_sender) = mul_receiver
            .run_phase1(&session_id)
            .expect("mul receiver phase1 should succeed");

        let (sender_output, data_to_receiver) = mul_sender
            .run(&session_id, &sender_input, &data_to_sender)
            .unwrap_or_else(|error| {
                panic!("Two-party multiplication error: {:?}", error.description)
            });

        let receiver_output = mul_receiver
            .run_phase2(&session_id, &data_to_keep, &data_to_receiver)
            .unwrap_or_else(|error| {
                panic!("Two-party multiplication error: {:?}", error.description)
            });

        // Pin the lengths so the zip below cannot silently skip entries.
        assert_eq!(sender_output.len(), L as usize);
        assert_eq!(receiver_output.len(), L as usize);
        for ((sender_share, receiver_share), input) in sender_output
            .iter()
            .zip(&receiver_output)
            .zip(&sender_input)
        {
            assert_eq!(*sender_share + *receiver_share, *input * receiver_random);
        }
    }

    /// A `verify_u` vector of the wrong length must be rejected before use.
    #[test]
    fn test_multiplication_receiver_rejects_wrong_verify_vector_lengths() {
        let session_id = rng::get_rng().random::<[u8; 32]>();
        let (mul_receiver, data_to_keep, mut data_to_receiver) =
            prepare_mul_receiver_inputs(&session_id);

        data_to_receiver.verify_u.pop();
        let result = mul_receiver.run_phase2(&session_id, &data_to_keep, &data_to_receiver);
        let error = result.expect_err("wrong verify_u length should fail");
        assert!(error.description.contains("incorrect dimensions"));
    }

    /// A `gamma_sender` vector of the wrong length must be rejected before use.
    #[test]
    fn test_multiplication_receiver_rejects_wrong_gamma_sender_length() {
        let session_id = rng::get_rng().random::<[u8; 32]>();
        let (mul_receiver, data_to_keep, mut data_to_receiver) =
            prepare_mul_receiver_inputs(&session_id);

        data_to_receiver.gamma_sender.pop();
        let result = mul_receiver.run_phase2(&session_id, &data_to_keep, &data_to_receiver);
        let error = result.expect_err("wrong gamma_sender length should fail");
        assert!(error.description.contains("incorrect dimensions"));
    }

    /// Flipping a bit of the consistency hash must fail the check.
    #[test]
    fn test_multiplication_rejects_tampered_verify_r() {
        let session_id = rng::get_rng().random::<[u8; 32]>();
        let (mul_receiver, data_to_keep, mut data_to_receiver) =
            prepare_mul_receiver_inputs(&session_id);

        data_to_receiver.verify_r[0] ^= 1;

        let result = mul_receiver.run_phase2(&session_id, &data_to_keep, &data_to_receiver);
        let error = result.expect_err("tampered verify_r should fail");
        assert!(error.description.contains("Consistency check failed"));
    }

    /// Changing a `verify_u` value must fail the check.
    #[test]
    fn test_multiplication_rejects_tampered_verify_u() {
        let session_id = rng::get_rng().random::<[u8; 32]>();
        let (mul_receiver, data_to_keep, mut data_to_receiver) =
            prepare_mul_receiver_inputs(&session_id);

        data_to_receiver.verify_u[0] += Scalar::ONE;

        let result = mul_receiver.run_phase2(&session_id, &data_to_keep, &data_to_receiver);
        let error = result.expect_err("tampered verify_u should fail");
        assert!(error.description.contains("Consistency check failed"));
    }

    /// Tampered tau vectors must either be rejected or at least change the output.
    #[test]
    fn test_multiplication_rejects_tampered_tau_vectors() {
        let session_id = rng::get_rng().random::<[u8; 32]>();
        let (mul_receiver, data_to_keep, mut data_to_receiver) =
            prepare_mul_receiver_inputs(&session_id);

        let honest_output = mul_receiver
            .run_phase2(&session_id, &data_to_keep, &data_to_receiver)
            .unwrap_or_else(|error| {
                panic!("Two-party multiplication error: {:?}", error.description)
            });

        for tau_row in &mut data_to_receiver.vector_of_tau {
            for value in tau_row {
                *value += Scalar::ONE;
            }
        }

        let result = mul_receiver.run_phase2(&session_id, &data_to_keep, &data_to_receiver);
        match result {
            Err(error) => {
                assert!(
                    error
                        .description
                        .contains("OTE error during multiplication")
                        || error.description.contains("Consistency check failed")
                );
            }
            Ok(tampered_output) => {
                assert_ne!(
                    tampered_output, honest_output,
                    "tampered tau vectors should not preserve honest receiver output"
                );
            }
        }
    }
}