1use core::{
13 iter::zip,
14 ops::{Add, AddAssign, BitXor, BitXorAssign, Index, Sub},
15};
16
17use bitvec::prelude::{BitVec, Lsb0};
18use rand::{rng, Rng, RngCore};
19use std::fmt::Debug;
20use std::io::{Cursor, Read};
21use subtle::{Choice, ConditionallyNegatable, ConditionallySelectable, ConstantTimeEq};
22
23use crate::{
24 bt::{BinaryTree, Node},
25 codec::{CodecError, Decode, Encode, ParameterizedDecode},
26 field::FieldElement,
27 idpf::{conditional_swap_seed, conditional_xor_seeds, xor_seeds, IdpfInput, IdpfValue},
28 vdaf::{
29 mastic,
30 xof::{Seed, Xof, XofFixedKeyAes128, XofTurboShake128},
31 },
32};
33
34#[derive(Debug, thiserror::Error)]
36#[non_exhaustive]
37pub enum VidpfError {
38 #[error("bit length too long")]
40 BitLengthTooLong,
41
42 #[error("invalid input length")]
44 InvalidInputLength,
45
46 #[error("invalid weight length")]
48 InvalidWeightLength,
49}
50
51pub type VidpfInput = IdpfInput;
53
54pub trait VidpfValue: IdpfValue + Clone + Debug + PartialEq + ConstantTimeEq {}
56
57#[derive(Clone, Debug)]
58pub struct Vidpf<W: VidpfValue> {
60 pub(crate) bits: u16,
61 pub(crate) weight_len: W::ValueParameter,
62}
63
64impl<W: VidpfValue> Vidpf<W> {
65 pub fn new(bits: usize, weight_len: W::ValueParameter) -> Result<Self, VidpfError> {
72 let bits = u16::try_from(bits).map_err(|_| VidpfError::BitLengthTooLong)?;
73 Ok(Self { bits, weight_len })
74 }
75
76 pub fn gen(
95 &self,
96 ctx: &[u8],
97 input: &VidpfInput,
98 weight: &W,
99 nonce: &[u8],
100 ) -> Result<(VidpfPublicShare<W>, [VidpfKey; 2]), VidpfError> {
101 let mut rng = rng();
102 let keys = rng.random();
103 let public = self.gen_with_keys(ctx, &keys, input, weight, nonce)?;
104 Ok((public, keys))
105 }
106
107 pub(crate) fn gen_with_keys(
109 &self,
110 ctx: &[u8],
111 keys: &[VidpfKey; 2],
112 input: &VidpfInput,
113 weight: &W,
114 nonce: &[u8],
115 ) -> Result<VidpfPublicShare<W>, VidpfError> {
116 let mut seed = [keys[0].0, keys[1].0];
117 let mut ctrl = [
118 Choice::from(VidpfServerId::S0),
119 Choice::from(VidpfServerId::S1),
120 ];
121
122 let mut cw = Vec::with_capacity(input.len());
123 for idx in self.index_iter(input)? {
124 let bit = idx.bit;
125
126 let e = [
128 Self::extend(seed[0], ctx, nonce),
129 Self::extend(seed[1], ctx, nonce),
130 ];
131
132 let (seed_keep_0, seed_lose_0) = &mut (e[0].seed_right, e[0].seed_left);
134 conditional_swap_seed(seed_keep_0, seed_lose_0, !bit);
135 let (seed_keep_1, seed_lose_1) = &mut (e[1].seed_right, e[1].seed_left);
136 conditional_swap_seed(seed_keep_1, seed_lose_1, !bit);
137 let ctrl_keep_0 = Choice::conditional_select(&e[0].ctrl_left, &e[0].ctrl_right, bit);
138 let ctrl_keep_1 = Choice::conditional_select(&e[1].ctrl_left, &e[1].ctrl_right, bit);
139
140 let cw_seed = xor_seeds(seed_lose_0, seed_lose_1);
142 let cw_ctrl_left = e[0].ctrl_left ^ e[1].ctrl_left ^ bit ^ Choice::from(1);
143 let cw_ctrl_right = e[0].ctrl_right ^ e[1].ctrl_right ^ bit;
144
145 let seed_keep_0 = conditional_xor_seeds(seed_keep_0, &cw_seed, ctrl[0]);
147 let seed_keep_1 = conditional_xor_seeds(seed_keep_1, &cw_seed, ctrl[1]);
148 let cw_ctrl_keep = Choice::conditional_select(&cw_ctrl_left, &cw_ctrl_right, bit);
149 let ctrl_keep_0 = ctrl_keep_0 ^ (ctrl[0] & cw_ctrl_keep);
150 let ctrl_keep_1 = ctrl_keep_1 ^ (ctrl[1] & cw_ctrl_keep);
151
152 let weight_0;
154 let weight_1;
155 (seed[0], weight_0) = self.convert(seed_keep_0, ctx, nonce);
156 (seed[1], weight_1) = self.convert(seed_keep_1, ctx, nonce);
157 ctrl[0] = ctrl_keep_0;
158 ctrl[1] = ctrl_keep_1;
159
160 let mut cw_weight = weight_1 - weight_0 + weight.clone();
162 cw_weight.conditional_negate(ctrl[1]);
163
164 let cw_proof = xor_proof(
166 idx.node_proof(&seed[0], ctx),
167 &idx.node_proof(&seed[1], ctx),
168 );
169
170 cw.push(VidpfCorrectionWord {
171 seed: cw_seed,
172 ctrl_left: cw_ctrl_left,
173 ctrl_right: cw_ctrl_right,
174 weight: cw_weight,
175 proof: cw_proof,
176 });
177 }
178
179 Ok(VidpfPublicShare { cw })
180 }
181
182 pub fn eval(
186 &self,
187 ctx: &[u8],
188 id: VidpfServerId,
189 key: &VidpfKey,
190 public: &VidpfPublicShare<W>,
191 input: &VidpfInput,
192 nonce: &[u8],
193 ) -> Result<(W, VidpfProof), VidpfError> {
194 use sha3::{Digest, Sha3_256};
195
196 let mut r = VidpfEvalResult {
197 state: VidpfEvalState::init_from_key(id, key),
198 share: W::zero(&self.weight_len), };
200
201 if input.len() > public.cw.len() {
202 return Err(VidpfError::InvalidInputLength);
203 }
204
205 let mut hash = Sha3_256::new();
206 for (idx, cw) in self.index_iter(input)?.zip(public.cw.iter()) {
207 r = self.eval_next(ctx, cw, idx, &r.state, nonce);
208 hash.update(r.state.node_proof);
209 }
210
211 let mut weight = r.share;
212 weight.conditional_negate(Choice::from(id));
213 Ok((weight, hash.finalize().into()))
214 }
215
216 fn eval_next(
219 &self,
220 ctx: &[u8],
221 cw: &VidpfCorrectionWord<W>,
222 idx: VidpfEvalIndex<'_>,
223 state: &VidpfEvalState,
224 nonce: &[u8],
225 ) -> VidpfEvalResult<W> {
226 let bit = idx.bit;
227
228 let e = Self::extend(state.seed, ctx, nonce);
230
231 let (seed_keep, seed_lose) = &mut (e.seed_right, e.seed_left);
233 conditional_swap_seed(seed_keep, seed_lose, !bit);
234 let ctrl_keep = Choice::conditional_select(&e.ctrl_left, &e.ctrl_right, bit);
235
236 let seed_keep = conditional_xor_seeds(seed_keep, &cw.seed, state.control_bit);
238 let cw_ctrl_keep = Choice::conditional_select(&cw.ctrl_left, &cw.ctrl_right, bit);
239 let next_ctrl = ctrl_keep ^ (state.control_bit & cw_ctrl_keep);
240
241 let (next_seed, w) = self.convert(seed_keep, ctx, nonce);
243 let mut weight = <W as IdpfValue>::conditional_select(
244 &<W as IdpfValue>::zero(&self.weight_len),
245 &cw.weight,
246 next_ctrl,
247 );
248 weight += w;
249
250 let node_proof =
252 conditional_xor_proof(idx.node_proof(&next_seed, ctx), &cw.proof, next_ctrl);
253
254 let next_state = VidpfEvalState {
255 seed: next_seed,
256 control_bit: next_ctrl,
257 node_proof,
258 };
259
260 VidpfEvalResult {
261 state: next_state,
262 share: weight,
263 }
264 }
265
266 pub(crate) fn get_beta_share(
267 &self,
268 ctx: &[u8],
269 id: VidpfServerId,
270 public: &VidpfPublicShare<W>,
271 key: &VidpfKey,
272 nonce: &[u8],
273 ) -> Result<W, VidpfError> {
274 let cw = public.cw.first().ok_or(VidpfError::InvalidInputLength)?;
275
276 let state = VidpfEvalState::init_from_key(id, key);
277 let input_left = VidpfInput::from_bools(&[false]);
278 let idx_left = self.index(&input_left)?;
279
280 let VidpfEvalResult {
281 state: _,
282 share: weight_share_left,
283 } = self.eval_next(ctx, cw, idx_left, &state, nonce);
284
285 let VidpfEvalResult {
286 state: _,
287 share: weight_share_right,
288 } = self.eval_next(ctx, cw, idx_left.right_sibling(), &state, nonce);
289
290 let mut beta_share = weight_share_left + weight_share_right;
291 beta_share.conditional_negate(Choice::from(id));
292 Ok(beta_share)
293 }
294
295 #[allow(clippy::too_many_arguments)]
299 pub(crate) fn eval_prefix_tree_with_siblings(
300 &self,
301 ctx: &[u8],
302 id: VidpfServerId,
303 public: &VidpfPublicShare<W>,
304 key: &VidpfKey,
305 nonce: &[u8],
306 prefixes: &[VidpfInput],
307 prefix_tree: &mut BinaryTree<VidpfEvalResult<W>>,
308 ) -> Result<Vec<W>, VidpfError> {
309 let mut out_shares = Vec::with_capacity(prefixes.len());
310
311 for prefix in prefixes {
312 if prefix.len() > public.cw.len() {
313 return Err(VidpfError::InvalidInputLength);
314 }
315
316 let mut sub_tree = prefix_tree.root.get_or_insert_with(|| {
317 Box::new(Node::new(VidpfEvalResult {
318 state: VidpfEvalState::init_from_key(id, key),
319 share: W::zero(&self.weight_len), }))
321 });
322
323 for (idx, cw) in self.index_iter(prefix)?.zip(public.cw.iter()) {
324 let left = sub_tree.left.get_or_insert_with(|| {
325 Box::new(Node::new(self.eval_next(
326 ctx,
327 cw,
328 idx.left_sibling(),
329 &sub_tree.value.state,
330 nonce,
331 )))
332 });
333 let right = sub_tree.right.get_or_insert_with(|| {
334 Box::new(Node::new(self.eval_next(
335 ctx,
336 cw,
337 idx.right_sibling(),
338 &sub_tree.value.state,
339 nonce,
340 )))
341 });
342
343 sub_tree = if idx.bit.unwrap_u8() == 0 {
344 left
345 } else {
346 right
347 };
348 }
349
350 out_shares.push(sub_tree.value.share.clone());
351 }
352
353 for out_share in out_shares.iter_mut() {
354 out_share.conditional_negate(Choice::from(id));
355 }
356 Ok(out_shares)
357 }
358
359 fn extend(seed: VidpfSeed, ctx: &[u8], nonce: &[u8]) -> ExtendedSeed {
360 let mut rng = XofFixedKeyAes128::seed_stream(
361 &seed,
362 &[&mastic::dst_usage(mastic::USAGE_EXTEND), ctx],
363 &[nonce],
364 );
365
366 let mut seed_left = VidpfSeed::default();
367 let mut seed_right = VidpfSeed::default();
368 rng.fill_bytes(&mut seed_left);
369 rng.fill_bytes(&mut seed_right);
370 let ctrl_left = Choice::from(seed_left[0] & 0x01);
374 let ctrl_right = Choice::from(seed_right[0] & 0x01);
375 seed_left[0] &= 0xFE;
376 seed_right[0] &= 0xFE;
377
378 ExtendedSeed {
379 seed_left,
380 ctrl_left,
381 seed_right,
382 ctrl_right,
383 }
384 }
385
386 fn convert(&self, seed: VidpfSeed, ctx: &[u8], nonce: &[u8]) -> (VidpfSeed, W) {
387 let mut seed_stream = XofFixedKeyAes128::seed_stream(
388 &seed,
389 &[&mastic::dst_usage(mastic::USAGE_CONVERT), ctx],
390 &[nonce],
391 );
392
393 let mut next_seed = VidpfSeed::default();
394 seed_stream.fill_bytes(&mut next_seed);
395 let weight = W::generate(&mut seed_stream, &self.weight_len);
396 (next_seed, weight)
397 }
398
399 fn index_iter<'a>(
400 &'a self,
401 input: &'a VidpfInput,
402 ) -> Result<impl Iterator<Item = VidpfEvalIndex<'a>>, VidpfError> {
403 let n = u16::try_from(input.len()).map_err(|_| VidpfError::InvalidInputLength)?;
404 if n > self.bits {
405 return Err(VidpfError::InvalidInputLength);
406 }
407 Ok(Box::new((0..n).zip(input.iter()).map(
408 move |(level, bit)| VidpfEvalIndex {
409 bit: Choice::from(u8::from(bit)),
410 input,
411 level,
412 bits: self.bits,
413 },
414 )))
415 }
416
417 fn index<'a>(&self, input: &'a VidpfInput) -> Result<VidpfEvalIndex<'a>, VidpfError> {
418 let level = u16::try_from(input.len()).map_err(|_| VidpfError::InvalidInputLength)? - 1;
419 if level >= self.bits {
420 return Err(VidpfError::InvalidInputLength);
421 }
422 let bit = Choice::from(u8::from(input.get(usize::from(level)).unwrap()));
423 Ok(VidpfEvalIndex {
424 bit,
425 input,
426 level,
427 bits: self.bits,
428 })
429 }
430}
431
432pub type VidpfKey = Seed<VIDPF_SEED_SIZE>;
436
437#[derive(Clone, Copy, Debug, PartialEq, Eq)]
441pub enum VidpfServerId {
442 S0,
444 S1,
446}
447
448impl From<VidpfServerId> for Choice {
449 fn from(value: VidpfServerId) -> Self {
450 match value {
451 VidpfServerId::S0 => Self::from(0),
452 VidpfServerId::S1 => Self::from(1),
453 }
454 }
455}
456
457#[derive(Clone, Debug)]
459struct VidpfCorrectionWord<W: VidpfValue> {
460 seed: VidpfSeed,
461 ctrl_left: Choice,
462 ctrl_right: Choice,
463 weight: W,
464 proof: VidpfProof,
465}
466
467impl<W: VidpfValue> ConstantTimeEq for VidpfCorrectionWord<W> {
468 fn ct_eq(&self, other: &Self) -> Choice {
469 self.seed.ct_eq(&other.seed)
470 & self.ctrl_left.ct_eq(&other.ctrl_left)
471 & self.ctrl_right.ct_eq(&other.ctrl_right)
472 & self.weight.ct_eq(&other.weight)
473 & self.proof.ct_eq(&other.proof)
474 }
475}
476
477impl<W: VidpfValue> PartialEq for VidpfCorrectionWord<W>
478where
479 W: ConstantTimeEq,
480{
481 fn eq(&self, other: &Self) -> bool {
482 self.ct_eq(other).into()
483 }
484}
485
486#[derive(Clone, Debug, PartialEq)]
488pub struct VidpfPublicShare<W: VidpfValue> {
489 cw: Vec<VidpfCorrectionWord<W>>,
490}
491
492impl<W: VidpfValue> Encode for VidpfPublicShare<W> {
493 fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
494 let mut control_bits: BitVec<u8, Lsb0> = BitVec::with_capacity(self.cw.len() * 2);
496 for cw in self.cw.iter() {
497 control_bits.push(bool::from(cw.ctrl_left));
498 control_bits.push(bool::from(cw.ctrl_right));
499 }
500 control_bits.set_uninitialized(false);
501 let mut packed_control = control_bits.into_vec();
502 bytes.append(&mut packed_control);
503
504 for cw in self.cw.iter() {
506 bytes.extend_from_slice(&cw.seed);
507 }
508
509 for cw in self.cw.iter() {
511 cw.weight.encode(bytes)?;
512 }
513
514 for cw in self.cw.iter() {
516 bytes.extend_from_slice(&cw.proof);
517 }
518
519 Ok(())
520 }
521
522 fn encoded_len(&self) -> Option<usize> {
523 let weight_len = self
525 .cw
526 .first()
527 .map_or(Some(0), |cw| cw.weight.encoded_len())?;
528
529 let mut len = 0;
530 len += (2 * self.cw.len()).div_ceil(8); len += self.cw.len() * VIDPF_SEED_SIZE; len += self.cw.len() * weight_len; len += self.cw.len() * VIDPF_PROOF_SIZE; Some(len)
535 }
536}
537
538impl<W: VidpfValue> ParameterizedDecode<Vidpf<W>> for VidpfPublicShare<W> {
539 fn decode_with_param(vidpf: &Vidpf<W>, bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
540 let bits = usize::from(vidpf.bits);
541 let packed_control_len = bits.div_ceil(4);
542 let mut packed_control_bits = vec![0u8; packed_control_len];
543 bytes.read_exact(&mut packed_control_bits)?;
544 let unpacked_control_bits: BitVec<u8, Lsb0> = BitVec::from_vec(packed_control_bits);
545
546 let mut control_bits = Vec::with_capacity(bits);
548 for chunk in unpacked_control_bits[0..bits * 2].chunks(2) {
549 control_bits.push([(chunk[0] as u8).into(), (chunk[1] as u8).into()]);
550 }
551
552 if unpacked_control_bits[bits * 2..].any() {
554 return Err(CodecError::UnexpectedValue);
555 }
556
557 let seeds = std::iter::repeat_with(|| Seed::decode(bytes).map(|seed| seed.0))
559 .take(bits)
560 .collect::<Result<Vec<_>, _>>()?;
561
562 let weights = std::iter::repeat_with(|| W::decode_with_param(&vidpf.weight_len, bytes))
564 .take(bits)
565 .collect::<Result<Vec<_>, _>>()?;
566
567 let proofs = std::iter::repeat_with(|| {
568 let mut proof = [0; VIDPF_PROOF_SIZE];
569 bytes.read_exact(&mut proof)?;
570 Ok::<_, CodecError>(proof)
571 })
572 .take(bits)
573 .collect::<Result<Vec<_>, _>>()?;
574
575 let cw = seeds
576 .into_iter()
577 .zip(
578 control_bits
579 .into_iter()
580 .zip(weights.into_iter().zip(proofs)),
581 )
582 .map(
583 |(seed, ([ctrl_left, ctrl_right], (weight, proof)))| VidpfCorrectionWord {
584 seed,
585 ctrl_left,
586 ctrl_right,
587 weight,
588 proof,
589 },
590 )
591 .collect::<Vec<_>>();
592
593 Ok(Self { cw })
594 }
595}
596
597#[derive(Debug)]
599pub(crate) struct VidpfEvalState {
600 seed: VidpfSeed,
601 control_bit: Choice,
602 pub(crate) node_proof: VidpfProof,
603}
604
605impl VidpfEvalState {
606 fn init_from_key(id: VidpfServerId, key: &VidpfKey) -> Self {
607 Self {
608 seed: key.0,
609 control_bit: Choice::from(id),
610 node_proof: VidpfProof::default(), }
612 }
613}
614
615#[derive(Debug)]
617pub(crate) struct VidpfEvalResult<W: VidpfValue> {
618 pub(crate) state: VidpfEvalState,
619 pub(crate) share: W,
620}
621
/// Size in bytes of a node proof.
pub(crate) const VIDPF_PROOF_SIZE: usize = 32;
/// Size in bytes of a seed.
const VIDPF_SEED_SIZE: usize = 16;

/// A node proof, or a correction to one.
pub(crate) type VidpfProof = [u8; VIDPF_PROOF_SIZE];

/// Returns the byte-wise XOR of `lhs` and `rhs`.
pub(crate) fn xor_proof(lhs: VidpfProof, rhs: &VidpfProof) -> VidpfProof {
    let mut out = lhs;
    for (dst, src) in out.iter_mut().zip(rhs.iter()) {
        *dst ^= *src;
    }
    out
}
632
633fn conditional_xor_proof(mut lhs: VidpfProof, rhs: &VidpfProof, choice: Choice) -> VidpfProof {
634 zip(&mut lhs, rhs).for_each(|(a, b)| a.conditional_assign(&a.bitxor(b), choice));
635 lhs
636}
637
638type VidpfSeed = [u8; VIDPF_SEED_SIZE];
640
641struct ExtendedSeed {
643 seed_left: VidpfSeed,
644 ctrl_left: Choice,
645 seed_right: VidpfSeed,
646 ctrl_right: Choice,
647}
648
649#[derive(Debug, PartialEq, Eq, Clone)]
651pub struct VidpfWeight<F: FieldElement>(pub(crate) Vec<F>);
652
653impl<F: FieldElement> From<Vec<F>> for VidpfWeight<F> {
654 fn from(value: Vec<F>) -> Self {
655 Self(value)
656 }
657}
658
659impl<F: FieldElement> AsRef<[F]> for VidpfWeight<F> {
660 fn as_ref(&self) -> &[F] {
661 &self.0
662 }
663}
664
665impl<F: FieldElement> VidpfValue for VidpfWeight<F> {}
666
667impl<F: FieldElement> IdpfValue for VidpfWeight<F> {
668 type ValueParameter = usize;
670
671 fn generate<S: RngCore>(seed_stream: &mut S, length: &Self::ValueParameter) -> Self {
672 Self(
673 (0..*length)
674 .map(|_| <F as IdpfValue>::generate(seed_stream, &()))
675 .collect(),
676 )
677 }
678
679 fn zero(length: &Self::ValueParameter) -> Self {
680 Self((0..*length).map(|_| <F as IdpfValue>::zero(&())).collect())
681 }
682
683 fn conditional_select(lhs: &Self, rhs: &Self, choice: Choice) -> Self {
685 assert_eq!(
686 lhs.0.len(),
687 rhs.0.len(),
688 "{}",
689 VidpfError::InvalidWeightLength
690 );
691
692 Self(
693 zip(&lhs.0, &rhs.0)
694 .map(|(a, b)| <F as IdpfValue>::conditional_select(a, b, choice))
695 .collect(),
696 )
697 }
698}
699
700impl<F: FieldElement> ConditionallyNegatable for VidpfWeight<F> {
701 fn conditional_negate(&mut self, choice: Choice) {
702 self.0.iter_mut().for_each(|a| a.conditional_negate(choice));
703 }
704}
705
706impl<F: FieldElement> Add for VidpfWeight<F> {
707 type Output = Self;
708
709 fn add(self, rhs: Self) -> Self::Output {
711 assert_eq!(
712 self.0.len(),
713 rhs.0.len(),
714 "{}",
715 VidpfError::InvalidWeightLength
716 );
717
718 Self(zip(self.0, rhs.0).map(|(a, b)| a.add(b)).collect())
719 }
720}
721
722impl<F: FieldElement> AddAssign for VidpfWeight<F> {
723 fn add_assign(&mut self, rhs: Self) {
725 assert_eq!(
726 self.0.len(),
727 rhs.0.len(),
728 "{}",
729 VidpfError::InvalidWeightLength
730 );
731
732 zip(&mut self.0, rhs.0).for_each(|(a, b)| a.add_assign(b));
733 }
734}
735
736impl<F: FieldElement> Sub for VidpfWeight<F> {
737 type Output = Self;
738
739 fn sub(self, rhs: Self) -> Self::Output {
741 assert_eq!(
742 self.0.len(),
743 rhs.0.len(),
744 "{}",
745 VidpfError::InvalidWeightLength
746 );
747
748 Self(zip(self.0, rhs.0).map(|(a, b)| a.sub(b)).collect())
749 }
750}
751
752impl<F: FieldElement> ConstantTimeEq for VidpfWeight<F> {
753 fn ct_eq(&self, other: &Self) -> Choice {
754 self.0[..].ct_eq(&other.0[..])
755 }
756}
757
758impl<F: FieldElement> Encode for VidpfWeight<F> {
759 fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
760 for e in &self.0 {
761 F::encode(e, bytes)?;
762 }
763 Ok(())
764 }
765
766 fn encoded_len(&self) -> Option<usize> {
767 Some(self.0.len() * F::ENCODED_SIZE)
768 }
769}
770
771impl<F: FieldElement> ParameterizedDecode<<Self as IdpfValue>::ValueParameter> for VidpfWeight<F> {
772 fn decode_with_param(
773 decoding_parameter: &<Self as IdpfValue>::ValueParameter,
774 bytes: &mut Cursor<&[u8]>,
775 ) -> Result<Self, CodecError> {
776 let mut v = Vec::with_capacity(*decoding_parameter);
777 for _ in 0..*decoding_parameter {
778 v.push(F::decode_with_param(&(), bytes)?);
779 }
780
781 Ok(Self(v))
782 }
783}
784
785#[derive(Copy, Clone)]
786struct VidpfEvalIndex<'a> {
787 bit: Choice,
788 input: &'a VidpfInput,
789 level: u16,
790 bits: u16,
791}
792
793impl VidpfEvalIndex<'_> {
794 fn left_sibling(&self) -> Self {
795 Self {
796 bit: Choice::from(0),
797 input: self.input,
798 level: self.level,
799 bits: self.bits,
800 }
801 }
802
803 fn right_sibling(&self) -> Self {
804 Self {
805 bit: Choice::from(1),
806 input: self.input,
807 level: self.level,
808 bits: self.bits,
809 }
810 }
811
812 fn node_proof(&self, seed: &VidpfSeed, ctx: &[u8]) -> VidpfProof {
813 let mut xof = XofTurboShake128::from_seed_slice(
814 &seed[..],
815 &[&mastic::dst_usage(mastic::USAGE_NODE_PROOF), ctx],
816 );
817 xof.update(&self.bits.to_le_bytes());
818 xof.update(&self.level.to_le_bytes());
819
820 for byte in self
821 .input
822 .index(..=usize::from(self.level))
823 .chunks(8)
824 .enumerate()
825 .map(|(byte_index, chunk)| {
826 let mut byte = 0;
827 for (bit_index, bit) in chunk.iter().enumerate() {
828 byte |= u8::from(*bit) << (7 - bit_index);
829 }
830
831 if byte_index == usize::from(self.level) / 8 {
834 let bit_index = self.level % 8;
835 let m = 1 << (7 - bit_index);
836 byte = u8::conditional_select(&(byte & !m), &(byte | m), self.bit);
837 }
838 byte
839 })
840 {
841 xof.update(&[byte]);
842 }
843 xof.into_seed().0
844 }
845}
846
#[cfg(test)]
mod tests {
    use crate::field::Field128;

    use super::VidpfWeight;

    /// Weight type used throughout these tests.
    type TestWeight = VidpfWeight<Field128>;
    /// Number of field elements in each test weight.
    const TEST_WEIGHT_LEN: usize = 3;
    const TEST_NONCE_SIZE: usize = 16;
    const TEST_NONCE: &[u8; TEST_NONCE_SIZE] = b"Test Nonce VIDPF";

    mod vidpf {
        use crate::{
            codec::{Encode, ParameterizedDecode},
            idpf::IdpfValue,
            vidpf::{
                Vidpf, VidpfCorrectionWord, VidpfEvalState, VidpfInput, VidpfKey, VidpfPublicShare,
                VidpfServerId,
            },
        };

        use super::{TestWeight, TEST_NONCE, TEST_NONCE_SIZE, TEST_WEIGHT_LEN};

        /// The weight `[21, 22, 23]` used by most tests.
        fn sample_weight() -> TestWeight {
            TestWeight::from(vec![21.into(), 22.into(), 23.into()])
        }

        #[test]
        fn roundtrip_codec() {
            let ctx = b"appliction context";
            let input = VidpfInput::from_bytes(&[0xFF]);
            let (vidpf, public, _, _) = vidpf_gen_setup(ctx, &input, &sample_weight());

            let encoded = public.get_encoded().unwrap();
            assert_eq!(public.encoded_len().unwrap(), encoded.len());

            let decoded = VidpfPublicShare::get_decoded_with_param(&vidpf, &encoded).unwrap();
            assert_eq!(public, decoded);
        }

        /// Builds a VIDPF sized for `input` and generates keys for `weight` at `input`.
        fn vidpf_gen_setup(
            ctx: &[u8],
            input: &VidpfInput,
            weight: &TestWeight,
        ) -> (
            Vidpf<TestWeight>,
            VidpfPublicShare<TestWeight>,
            [VidpfKey; 2],
            [u8; TEST_NONCE_SIZE],
        ) {
            let vidpf = Vidpf::new(input.len(), TEST_WEIGHT_LEN).unwrap();
            let (public, keys) = vidpf.gen(ctx, input, weight, TEST_NONCE).unwrap();
            (vidpf, public, keys, *TEST_NONCE)
        }

        #[test]
        fn correctness_at_last_level() {
            let ctx = b"some application";
            let input = VidpfInput::from_bytes(&[0xFF]);
            let weight = sample_weight();
            let (vidpf, public, [key_0, key_1], nonce) = vidpf_gen_setup(ctx, &input, &weight);

            let bad_input = VidpfInput::from_bytes(&[0x00]);
            let zero = TestWeight::zero(&TEST_WEIGHT_LEN);

            // On the programmed point the shares reconstruct the weight; off it, zero.
            let cases = [
                (&input, &weight, "shares must add up to the expected weight"),
                (&bad_input, &zero, "shares must add up to zero"),
            ];

            for (point, expected, message) in cases {
                let (share_0, proof_0) = vidpf
                    .eval(ctx, VidpfServerId::S0, &key_0, &public, point, &nonce)
                    .unwrap();
                let (share_1, proof_1) = vidpf
                    .eval(ctx, VidpfServerId::S1, &key_1, &public, point, &nonce)
                    .unwrap();

                assert_eq!(share_0 + share_1, *expected, "{}", message);
                assert_eq!(proof_0, proof_1, "proofs must be equal");
            }
        }

        #[test]
        fn correctness_at_each_level() {
            let ctx = b"application context";
            let input = VidpfInput::from_bytes(&[0xFF]);
            let weight = sample_weight();
            let (vidpf, public, keys, nonce) = vidpf_gen_setup(ctx, &input, &weight);

            assert_eval_at_each_level(&vidpf, ctx, &keys, &public, &input, &weight, &nonce);

            let bad_input = VidpfInput::from_bytes(&[0x00]);
            let zero = TestWeight::zero(&TEST_WEIGHT_LEN);

            assert_eval_at_each_level(&vidpf, ctx, &keys, &public, &bad_input, &zero, &nonce);
        }

        /// Steps both servers level by level along `input`. At every level it checks that
        /// the un-negated shares differ by `weight` and that the node proofs agree.
        fn assert_eval_at_each_level(
            vidpf: &Vidpf<TestWeight>,
            ctx: &[u8],
            [key_0, key_1]: &[VidpfKey; 2],
            public: &VidpfPublicShare<TestWeight>,
            input: &VidpfInput,
            weight: &TestWeight,
            nonce: &[u8],
        ) {
            let mut state_s0 = VidpfEvalState::init_from_key(VidpfServerId::S0, key_0);
            let mut state_s1 = VidpfEvalState::init_from_key(VidpfServerId::S1, key_1);

            for (idx, cw) in vidpf.index_iter(input).unwrap().zip(public.cw.iter()) {
                let next_s0 = vidpf.eval_next(ctx, cw, idx, &state_s0, nonce);
                let next_s1 = vidpf.eval_next(ctx, cw, idx, &state_s1, nonce);

                assert_eq!(
                    next_s0.share - next_s1.share,
                    *weight,
                    "shares must add up to the expected weight at the current level: {:?}",
                    idx.level
                );

                assert_eq!(
                    next_s0.state.node_proof, next_s1.state.node_proof,
                    "proofs must be equal at the current level: {:?}",
                    idx.level
                );

                state_s0 = next_s0.state;
                state_s1 = next_s1.state;
            }
        }

        #[test]
        fn public_share_weight_len() {
            let input = VidpfInput::from_bools(&[false; 237]);
            let (vidpf, public, _, _) =
                vidpf_gen_setup(b"some application", &input, &sample_weight());
            for VidpfCorrectionWord { weight, .. } in &public.cw {
                assert_eq!(weight.0.len(), vidpf.weight_len);
            }
        }
    }

    mod weight {
        use std::io::Cursor;
        use subtle::{Choice, ConditionallyNegatable};

        use crate::{
            codec::{Encode, ParameterizedDecode},
            idpf::IdpfValue,
            vdaf::xof::{Xof, XofTurboShake128},
        };

        use super::{TestWeight, TEST_WEIGHT_LEN};

        #[test]
        fn roundtrip_codec() {
            let weight = TestWeight::from(vec![21.into(), 22.into(), 23.into()]);

            let mut bytes = Vec::new();
            weight.encode(&mut bytes).unwrap();

            // Each element occupies 16 bytes with the small value in the first byte.
            let expected_bytes: Vec<u8> = [21u8, 22, 23]
                .into_iter()
                .flat_map(|value| std::iter::once(value).chain([0u8; 15]))
                .collect();

            assert_eq!(weight.encoded_len().unwrap(), expected_bytes.len());
            assert_eq!(bytes, expected_bytes);

            let decoded =
                TestWeight::decode_with_param(&TEST_WEIGHT_LEN, &mut Cursor::new(&bytes)).unwrap();
            assert_eq!(weight, decoded);
        }

        #[test]
        fn add_sub() {
            let [a, b] = compatible_weights();
            let mut doubled = a.clone();
            doubled += a.clone();

            assert_eq!(
                (a.clone() + b.clone()) + (a.clone() - b.clone()),
                doubled,
                "a: {:?} b:{:?}",
                a,
                b
            );
        }

        #[test]
        fn conditional_negate() {
            let [a, _] = compatible_weights();
            let mut kept = a.clone();
            kept.conditional_negate(Choice::from(0));
            let mut negated = a.clone();
            negated.conditional_negate(Choice::from(1));
            let zero = TestWeight::zero(&TEST_WEIGHT_LEN);

            assert_eq!(kept + negated, zero, "a: {:?}", a);
        }

        #[test]
        #[should_panic = "invalid weight length"]
        fn add_panics() {
            let [short, long] = incompatible_weights();
            let _ = short + long;
        }

        #[test]
        #[should_panic = "invalid weight length"]
        fn add_assign_panics() {
            let [mut short, long] = incompatible_weights();
            short += long;
        }

        #[test]
        #[should_panic = "invalid weight length"]
        fn sub_panics() {
            let [short, long] = incompatible_weights();
            let _ = short - long;
        }

        #[test]
        #[should_panic = "invalid weight length"]
        fn conditional_select_panics() {
            let [short, long] = incompatible_weights();
            TestWeight::conditional_select(&short, &long, Choice::from(0));
        }

        /// Two pseudorandom weights of the same length.
        fn compatible_weights() -> [TestWeight; 2] {
            let mut xof = XofTurboShake128::seed_stream(&[0; 32], &[], &[]);
            let first = TestWeight::generate(&mut xof, &TEST_WEIGHT_LEN);
            let second = TestWeight::generate(&mut xof, &TEST_WEIGHT_LEN);
            [first, second]
        }

        /// Two pseudorandom weights whose lengths differ.
        fn incompatible_weights() -> [TestWeight; 2] {
            let mut xof = XofTurboShake128::seed_stream(&[0; 32], &[], &[]);
            let first = TestWeight::generate(&mut xof, &TEST_WEIGHT_LEN);
            let second = TestWeight::generate(&mut xof, &(2 * TEST_WEIGHT_LEN));
            [first, second]
        }
    }
}