prio/
vidpf.rs

1// SPDX-License-Identifier: MPL-2.0
2
3//! Verifiable Incremental Distributed Point Function (VIDPF).
4//!
5//! The VIDPF construction is specified in [[draft-mouris-cfrg-mastic]] and builds
6//! on techniques from [[MST23]] and [[CP22]] to lift an IDPF to a VIDPF.
7//!
8//! [CP22]: https://eprint.iacr.org/2021/580
9//! [MST23]: https://eprint.iacr.org/2023/080
10//! [draft-mouris-cfrg-mastic]: https://datatracker.ietf.org/doc/draft-mouris-cfrg-mastic/02/
11
12use core::{
13    iter::zip,
14    ops::{Add, AddAssign, BitXor, BitXorAssign, Index, Sub},
15};
16
17use bitvec::prelude::{BitVec, Lsb0};
18use rand::{rng, Rng, RngCore};
19use std::fmt::Debug;
20use std::io::{Cursor, Read};
21use subtle::{Choice, ConditionallyNegatable, ConditionallySelectable, ConstantTimeEq};
22
23use crate::{
24    bt::{BinaryTree, Node},
25    codec::{CodecError, Decode, Encode, ParameterizedDecode},
26    field::FieldElement,
27    idpf::{conditional_swap_seed, conditional_xor_seeds, xor_seeds, IdpfInput, IdpfValue},
28    vdaf::{
29        mastic,
30        xof::{Seed, Xof, XofFixedKeyAes128, XofTurboShake128},
31    },
32};
33
34/// VIDPF errors.
35#[derive(Debug, thiserror::Error)]
36#[non_exhaustive]
37pub enum VidpfError {
38    /// Input is too long to be represented.
39    #[error("bit length too long")]
40    BitLengthTooLong,
41
42    /// Error when an input has an unexpected bit length.
43    #[error("invalid input length")]
44    InvalidInputLength,
45
46    /// Error when a weight has an unexpected length.
47    #[error("invalid weight length")]
48    InvalidWeightLength,
49}
50
51/// Represents the domain of an incremental point function.
52pub type VidpfInput = IdpfInput;
53
54/// Represents the codomain of an incremental point function.
55pub trait VidpfValue: IdpfValue + Clone + Debug + PartialEq + ConstantTimeEq {}
56
57#[derive(Clone, Debug)]
58/// An instance of the VIDPF.
59pub struct Vidpf<W: VidpfValue> {
60    pub(crate) bits: u16,
61    pub(crate) weight_len: W::ValueParameter,
62}
63
64impl<W: VidpfValue> Vidpf<W> {
65    /// Creates a VIDPF instance.
66    ///
67    /// # Arguments
68    ///
69    /// * `bits`, the length of the input in bits.
70    /// * `weight_len`, the length of the weight in number of field elements.
71    pub fn new(bits: usize, weight_len: W::ValueParameter) -> Result<Self, VidpfError> {
72        let bits = u16::try_from(bits).map_err(|_| VidpfError::BitLengthTooLong)?;
73        Ok(Self { bits, weight_len })
74    }
75
76    /// Splits an incremental point function `F` into two private keys
77    /// used by the aggregation servers, and a common public share.
78    ///
79    /// The incremental point function is defined as `F`: [`VidpfInput`] --> [`VidpfValue`]
80    /// such that:
81    ///
82    /// ```txt
83    /// F(x) = weight, if x is a prefix of the input.
84    /// F(x) = 0,      if x is not a prefix of the input.
85    /// ```
86    ///
87    /// # Arguments
88    ///
89    /// * `input`, determines the input of the function.
90    /// * `weight`, determines the input's weight of the function.
91    /// * `nonce`, a nonce, typically the same value provided to the
92    ///   [`Client`](crate::vdaf::Client) and [`Aggregator`](crate::vdaf::Aggregator).
93    ///   APIs.
94    pub fn gen(
95        &self,
96        ctx: &[u8],
97        input: &VidpfInput,
98        weight: &W,
99        nonce: &[u8],
100    ) -> Result<(VidpfPublicShare<W>, [VidpfKey; 2]), VidpfError> {
101        let mut rng = rng();
102        let keys = rng.random();
103        let public = self.gen_with_keys(ctx, &keys, input, weight, nonce)?;
104        Ok((public, keys))
105    }
106
107    /// Produce the public share for the given keys, input, and weight.
108    pub(crate) fn gen_with_keys(
109        &self,
110        ctx: &[u8],
111        keys: &[VidpfKey; 2],
112        input: &VidpfInput,
113        weight: &W,
114        nonce: &[u8],
115    ) -> Result<VidpfPublicShare<W>, VidpfError> {
116        let mut seed = [keys[0].0, keys[1].0];
117        let mut ctrl = [
118            Choice::from(VidpfServerId::S0),
119            Choice::from(VidpfServerId::S1),
120        ];
121
122        let mut cw = Vec::with_capacity(input.len());
123        for idx in self.index_iter(input)? {
124            let bit = idx.bit;
125
126            // Extend.
127            let e = [
128                Self::extend(seed[0], ctx, nonce),
129                Self::extend(seed[1], ctx, nonce),
130            ];
131
132            // Select the seed and control bit.
133            let (seed_keep_0, seed_lose_0) = &mut (e[0].seed_right, e[0].seed_left);
134            conditional_swap_seed(seed_keep_0, seed_lose_0, !bit);
135            let (seed_keep_1, seed_lose_1) = &mut (e[1].seed_right, e[1].seed_left);
136            conditional_swap_seed(seed_keep_1, seed_lose_1, !bit);
137            let ctrl_keep_0 = Choice::conditional_select(&e[0].ctrl_left, &e[0].ctrl_right, bit);
138            let ctrl_keep_1 = Choice::conditional_select(&e[1].ctrl_left, &e[1].ctrl_right, bit);
139
140            // Compute the correction word seed and control bit.
141            let cw_seed = xor_seeds(seed_lose_0, seed_lose_1);
142            let cw_ctrl_left = e[0].ctrl_left ^ e[1].ctrl_left ^ bit ^ Choice::from(1);
143            let cw_ctrl_right = e[0].ctrl_right ^ e[1].ctrl_right ^ bit;
144
145            // Correct the seed and control bit.
146            let seed_keep_0 = conditional_xor_seeds(seed_keep_0, &cw_seed, ctrl[0]);
147            let seed_keep_1 = conditional_xor_seeds(seed_keep_1, &cw_seed, ctrl[1]);
148            let cw_ctrl_keep = Choice::conditional_select(&cw_ctrl_left, &cw_ctrl_right, bit);
149            let ctrl_keep_0 = ctrl_keep_0 ^ (ctrl[0] & cw_ctrl_keep);
150            let ctrl_keep_1 = ctrl_keep_1 ^ (ctrl[1] & cw_ctrl_keep);
151
152            // Convert.
153            let weight_0;
154            let weight_1;
155            (seed[0], weight_0) = self.convert(seed_keep_0, ctx, nonce);
156            (seed[1], weight_1) = self.convert(seed_keep_1, ctx, nonce);
157            ctrl[0] = ctrl_keep_0;
158            ctrl[1] = ctrl_keep_1;
159
160            // Compute the correction word payload.
161            let mut cw_weight = weight_1 - weight_0 + weight.clone();
162            cw_weight.conditional_negate(ctrl[1]);
163
164            // Compute the correction word node proof.
165            let cw_proof = xor_proof(
166                idx.node_proof(&seed[0], ctx),
167                &idx.node_proof(&seed[1], ctx),
168            );
169
170            cw.push(VidpfCorrectionWord {
171                seed: cw_seed,
172                ctrl_left: cw_ctrl_left,
173                ctrl_right: cw_ctrl_right,
174                weight: cw_weight,
175                proof: cw_proof,
176            });
177        }
178
179        Ok(VidpfPublicShare { cw })
180    }
181
182    /// Evaluate a given VIDPF (comprised of the key and public share) at a given prefix. Return
183    /// the weight for that prefix along with a hash of the node proofs along the path from the
184    /// root to the prefix.
185    pub fn eval(
186        &self,
187        ctx: &[u8],
188        id: VidpfServerId,
189        key: &VidpfKey,
190        public: &VidpfPublicShare<W>,
191        input: &VidpfInput,
192        nonce: &[u8],
193    ) -> Result<(W, VidpfProof), VidpfError> {
194        use sha3::{Digest, Sha3_256};
195
196        let mut r = VidpfEvalResult {
197            state: VidpfEvalState::init_from_key(id, key),
198            share: W::zero(&self.weight_len), // not used
199        };
200
201        if input.len() > public.cw.len() {
202            return Err(VidpfError::InvalidInputLength);
203        }
204
205        let mut hash = Sha3_256::new();
206        for (idx, cw) in self.index_iter(input)?.zip(public.cw.iter()) {
207            r = self.eval_next(ctx, cw, idx, &r.state, nonce);
208            hash.update(r.state.node_proof);
209        }
210
211        let mut weight = r.share;
212        weight.conditional_negate(Choice::from(id));
213        Ok((weight, hash.finalize().into()))
214    }
215
216    /// Evaluates the `input` at the given level using the provided initial
217    /// state, and returns a new state and a share of the input's weight at that level.
218    fn eval_next(
219        &self,
220        ctx: &[u8],
221        cw: &VidpfCorrectionWord<W>,
222        idx: VidpfEvalIndex<'_>,
223        state: &VidpfEvalState,
224        nonce: &[u8],
225    ) -> VidpfEvalResult<W> {
226        let bit = idx.bit;
227
228        // Extend.
229        let e = Self::extend(state.seed, ctx, nonce);
230
231        // Select the seed and control bit.
232        let (seed_keep, seed_lose) = &mut (e.seed_right, e.seed_left);
233        conditional_swap_seed(seed_keep, seed_lose, !bit);
234        let ctrl_keep = Choice::conditional_select(&e.ctrl_left, &e.ctrl_right, bit);
235
236        // Correct the seed and control bit.
237        let seed_keep = conditional_xor_seeds(seed_keep, &cw.seed, state.control_bit);
238        let cw_ctrl_keep = Choice::conditional_select(&cw.ctrl_left, &cw.ctrl_right, bit);
239        let next_ctrl = ctrl_keep ^ (state.control_bit & cw_ctrl_keep);
240
241        // Convert and correct the payload.
242        let (next_seed, w) = self.convert(seed_keep, ctx, nonce);
243        let mut weight = <W as IdpfValue>::conditional_select(
244            &<W as IdpfValue>::zero(&self.weight_len),
245            &cw.weight,
246            next_ctrl,
247        );
248        weight += w;
249
250        // Compute and correct the node proof.
251        let node_proof =
252            conditional_xor_proof(idx.node_proof(&next_seed, ctx), &cw.proof, next_ctrl);
253
254        let next_state = VidpfEvalState {
255            seed: next_seed,
256            control_bit: next_ctrl,
257            node_proof,
258        };
259
260        VidpfEvalResult {
261            state: next_state,
262            share: weight,
263        }
264    }
265
266    pub(crate) fn get_beta_share(
267        &self,
268        ctx: &[u8],
269        id: VidpfServerId,
270        public: &VidpfPublicShare<W>,
271        key: &VidpfKey,
272        nonce: &[u8],
273    ) -> Result<W, VidpfError> {
274        let cw = public.cw.first().ok_or(VidpfError::InvalidInputLength)?;
275
276        let state = VidpfEvalState::init_from_key(id, key);
277        let input_left = VidpfInput::from_bools(&[false]);
278        let idx_left = self.index(&input_left)?;
279
280        let VidpfEvalResult {
281            state: _,
282            share: weight_share_left,
283        } = self.eval_next(ctx, cw, idx_left, &state, nonce);
284
285        let VidpfEvalResult {
286            state: _,
287            share: weight_share_right,
288        } = self.eval_next(ctx, cw, idx_left.right_sibling(), &state, nonce);
289
290        let mut beta_share = weight_share_left + weight_share_right;
291        beta_share.conditional_negate(Choice::from(id));
292        Ok(beta_share)
293    }
294
295    /// Ensure `prefix_tree` contains the prefix tree for `prefixes`, as well as the sibling of
296    /// each node in the prefix tree. The return value is the weights for the prefixes
297    /// concatenated together.
298    #[allow(clippy::too_many_arguments)]
299    pub(crate) fn eval_prefix_tree_with_siblings(
300        &self,
301        ctx: &[u8],
302        id: VidpfServerId,
303        public: &VidpfPublicShare<W>,
304        key: &VidpfKey,
305        nonce: &[u8],
306        prefixes: &[VidpfInput],
307        prefix_tree: &mut BinaryTree<VidpfEvalResult<W>>,
308    ) -> Result<Vec<W>, VidpfError> {
309        let mut out_shares = Vec::with_capacity(prefixes.len());
310
311        for prefix in prefixes {
312            if prefix.len() > public.cw.len() {
313                return Err(VidpfError::InvalidInputLength);
314            }
315
316            let mut sub_tree = prefix_tree.root.get_or_insert_with(|| {
317                Box::new(Node::new(VidpfEvalResult {
318                    state: VidpfEvalState::init_from_key(id, key),
319                    share: W::zero(&self.weight_len), // not used
320                }))
321            });
322
323            for (idx, cw) in self.index_iter(prefix)?.zip(public.cw.iter()) {
324                let left = sub_tree.left.get_or_insert_with(|| {
325                    Box::new(Node::new(self.eval_next(
326                        ctx,
327                        cw,
328                        idx.left_sibling(),
329                        &sub_tree.value.state,
330                        nonce,
331                    )))
332                });
333                let right = sub_tree.right.get_or_insert_with(|| {
334                    Box::new(Node::new(self.eval_next(
335                        ctx,
336                        cw,
337                        idx.right_sibling(),
338                        &sub_tree.value.state,
339                        nonce,
340                    )))
341                });
342
343                sub_tree = if idx.bit.unwrap_u8() == 0 {
344                    left
345                } else {
346                    right
347                };
348            }
349
350            out_shares.push(sub_tree.value.share.clone());
351        }
352
353        for out_share in out_shares.iter_mut() {
354            out_share.conditional_negate(Choice::from(id));
355        }
356        Ok(out_shares)
357    }
358
359    fn extend(seed: VidpfSeed, ctx: &[u8], nonce: &[u8]) -> ExtendedSeed {
360        let mut rng = XofFixedKeyAes128::seed_stream(
361            &seed,
362            &[&mastic::dst_usage(mastic::USAGE_EXTEND), ctx],
363            &[nonce],
364        );
365
366        let mut seed_left = VidpfSeed::default();
367        let mut seed_right = VidpfSeed::default();
368        rng.fill_bytes(&mut seed_left);
369        rng.fill_bytes(&mut seed_right);
370        // Use the LSB of seeds as control bits, and clears the bit,
371        // i.e., seeds produced by `prg` always have their LSB = 0.
372        // This ensures `prg` costs two AES calls only.
373        let ctrl_left = Choice::from(seed_left[0] & 0x01);
374        let ctrl_right = Choice::from(seed_right[0] & 0x01);
375        seed_left[0] &= 0xFE;
376        seed_right[0] &= 0xFE;
377
378        ExtendedSeed {
379            seed_left,
380            ctrl_left,
381            seed_right,
382            ctrl_right,
383        }
384    }
385
386    fn convert(&self, seed: VidpfSeed, ctx: &[u8], nonce: &[u8]) -> (VidpfSeed, W) {
387        let mut seed_stream = XofFixedKeyAes128::seed_stream(
388            &seed,
389            &[&mastic::dst_usage(mastic::USAGE_CONVERT), ctx],
390            &[nonce],
391        );
392
393        let mut next_seed = VidpfSeed::default();
394        seed_stream.fill_bytes(&mut next_seed);
395        let weight = W::generate(&mut seed_stream, &self.weight_len);
396        (next_seed, weight)
397    }
398
399    fn index_iter<'a>(
400        &'a self,
401        input: &'a VidpfInput,
402    ) -> Result<impl Iterator<Item = VidpfEvalIndex<'a>>, VidpfError> {
403        let n = u16::try_from(input.len()).map_err(|_| VidpfError::InvalidInputLength)?;
404        if n > self.bits {
405            return Err(VidpfError::InvalidInputLength);
406        }
407        Ok(Box::new((0..n).zip(input.iter()).map(
408            move |(level, bit)| VidpfEvalIndex {
409                bit: Choice::from(u8::from(bit)),
410                input,
411                level,
412                bits: self.bits,
413            },
414        )))
415    }
416
417    fn index<'a>(&self, input: &'a VidpfInput) -> Result<VidpfEvalIndex<'a>, VidpfError> {
418        let level = u16::try_from(input.len()).map_err(|_| VidpfError::InvalidInputLength)? - 1;
419        if level >= self.bits {
420            return Err(VidpfError::InvalidInputLength);
421        }
422        let bit = Choice::from(u8::from(input.get(usize::from(level)).unwrap()));
423        Ok(VidpfEvalIndex {
424            bit,
425            input,
426            level,
427            bits: self.bits,
428        })
429    }
430}
431
432/// VIDPF key.
433///
434/// Private key of an aggregation server.
435pub type VidpfKey = Seed<VIDPF_SEED_SIZE>;
436
/// VIDPF server ID, telling the two aggregation servers apart.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VidpfServerId {
    /// The first server.
    S0,
    /// The second server.
    S1,
}
447
448impl From<VidpfServerId> for Choice {
449    fn from(value: VidpfServerId) -> Self {
450        match value {
451            VidpfServerId::S0 => Self::from(0),
452            VidpfServerId::S1 => Self::from(1),
453        }
454    }
455}
456
457/// VIDPF correction word.
458#[derive(Clone, Debug)]
459struct VidpfCorrectionWord<W: VidpfValue> {
460    seed: VidpfSeed,
461    ctrl_left: Choice,
462    ctrl_right: Choice,
463    weight: W,
464    proof: VidpfProof,
465}
466
467impl<W: VidpfValue> ConstantTimeEq for VidpfCorrectionWord<W> {
468    fn ct_eq(&self, other: &Self) -> Choice {
469        self.seed.ct_eq(&other.seed)
470            & self.ctrl_left.ct_eq(&other.ctrl_left)
471            & self.ctrl_right.ct_eq(&other.ctrl_right)
472            & self.weight.ct_eq(&other.weight)
473            & self.proof.ct_eq(&other.proof)
474    }
475}
476
477impl<W: VidpfValue> PartialEq for VidpfCorrectionWord<W>
478where
479    W: ConstantTimeEq,
480{
481    fn eq(&self, other: &Self) -> bool {
482        self.ct_eq(other).into()
483    }
484}
485
486/// VIDPF public share.
487#[derive(Clone, Debug, PartialEq)]
488pub struct VidpfPublicShare<W: VidpfValue> {
489    cw: Vec<VidpfCorrectionWord<W>>,
490}
491
492impl<W: VidpfValue> Encode for VidpfPublicShare<W> {
493    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
494        // Control bits
495        let mut control_bits: BitVec<u8, Lsb0> = BitVec::with_capacity(self.cw.len() * 2);
496        for cw in self.cw.iter() {
497            control_bits.push(bool::from(cw.ctrl_left));
498            control_bits.push(bool::from(cw.ctrl_right));
499        }
500        control_bits.set_uninitialized(false);
501        let mut packed_control = control_bits.into_vec();
502        bytes.append(&mut packed_control);
503
504        // Seeds
505        for cw in self.cw.iter() {
506            bytes.extend_from_slice(&cw.seed);
507        }
508
509        // Weights
510        for cw in self.cw.iter() {
511            cw.weight.encode(bytes)?;
512        }
513
514        // Node proofs
515        for cw in self.cw.iter() {
516            bytes.extend_from_slice(&cw.proof);
517        }
518
519        Ok(())
520    }
521
522    fn encoded_len(&self) -> Option<usize> {
523        // We assume the weight has the same length at each level of the tree.
524        let weight_len = self
525            .cw
526            .first()
527            .map_or(Some(0), |cw| cw.weight.encoded_len())?;
528
529        let mut len = 0;
530        len += (2 * self.cw.len()).div_ceil(8); // packed control bits
531        len += self.cw.len() * VIDPF_SEED_SIZE; // seeds
532        len += self.cw.len() * weight_len; // weights
533        len += self.cw.len() * VIDPF_PROOF_SIZE; // nod proofs
534        Some(len)
535    }
536}
537
538impl<W: VidpfValue> ParameterizedDecode<Vidpf<W>> for VidpfPublicShare<W> {
539    fn decode_with_param(vidpf: &Vidpf<W>, bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
540        let bits = usize::from(vidpf.bits);
541        let packed_control_len = bits.div_ceil(4);
542        let mut packed_control_bits = vec![0u8; packed_control_len];
543        bytes.read_exact(&mut packed_control_bits)?;
544        let unpacked_control_bits: BitVec<u8, Lsb0> = BitVec::from_vec(packed_control_bits);
545
546        // Control bits
547        let mut control_bits = Vec::with_capacity(bits);
548        for chunk in unpacked_control_bits[0..bits * 2].chunks(2) {
549            control_bits.push([(chunk[0] as u8).into(), (chunk[1] as u8).into()]);
550        }
551
552        // Check that unused packed bits are zero.
553        if unpacked_control_bits[bits * 2..].any() {
554            return Err(CodecError::UnexpectedValue);
555        }
556
557        // Seeds
558        let seeds = std::iter::repeat_with(|| Seed::decode(bytes).map(|seed| seed.0))
559            .take(bits)
560            .collect::<Result<Vec<_>, _>>()?;
561
562        // Weights
563        let weights = std::iter::repeat_with(|| W::decode_with_param(&vidpf.weight_len, bytes))
564            .take(bits)
565            .collect::<Result<Vec<_>, _>>()?;
566
567        let proofs = std::iter::repeat_with(|| {
568            let mut proof = [0; VIDPF_PROOF_SIZE];
569            bytes.read_exact(&mut proof)?;
570            Ok::<_, CodecError>(proof)
571        })
572        .take(bits)
573        .collect::<Result<Vec<_>, _>>()?;
574
575        let cw = seeds
576            .into_iter()
577            .zip(
578                control_bits
579                    .into_iter()
580                    .zip(weights.into_iter().zip(proofs)),
581            )
582            .map(
583                |(seed, ([ctrl_left, ctrl_right], (weight, proof)))| VidpfCorrectionWord {
584                    seed,
585                    ctrl_left,
586                    ctrl_right,
587                    weight,
588                    proof,
589                },
590            )
591            .collect::<Vec<_>>();
592
593        Ok(Self { cw })
594    }
595}
596
597/// VIDPF evaluation state.
598#[derive(Debug)]
599pub(crate) struct VidpfEvalState {
600    seed: VidpfSeed,
601    control_bit: Choice,
602    pub(crate) node_proof: VidpfProof,
603}
604
605impl VidpfEvalState {
606    fn init_from_key(id: VidpfServerId, key: &VidpfKey) -> Self {
607        Self {
608            seed: key.0,
609            control_bit: Choice::from(id),
610            node_proof: VidpfProof::default(), // not used
611        }
612    }
613}
614
615/// Result of VIDPF evaluation.
616#[derive(Debug)]
617pub(crate) struct VidpfEvalResult<W: VidpfValue> {
618    pub(crate) state: VidpfEvalState,
619    pub(crate) share: W,
620}
621
/// Size of a node proof, in bytes.
pub(crate) const VIDPF_PROOF_SIZE: usize = 32;
/// Size of a seed, in bytes.
const VIDPF_SEED_SIZE: usize = 16;

/// Allows to validate user input and shares after evaluation.
pub(crate) type VidpfProof = [u8; VIDPF_PROOF_SIZE];

/// Returns the byte-wise XOR of `lhs` and `rhs`.
pub(crate) fn xor_proof(mut lhs: VidpfProof, rhs: &VidpfProof) -> VidpfProof {
    for (out, other) in lhs.iter_mut().zip(rhs) {
        *out ^= *other;
    }
    lhs
}
632
633fn conditional_xor_proof(mut lhs: VidpfProof, rhs: &VidpfProof, choice: Choice) -> VidpfProof {
634    zip(&mut lhs, rhs).for_each(|(a, b)| a.conditional_assign(&a.bitxor(b), choice));
635    lhs
636}
637
638/// Feeds a pseudorandom generator during evaluation.
639type VidpfSeed = [u8; VIDPF_SEED_SIZE];
640
641/// Output of [`extend()`](Vidpf::extend).
642struct ExtendedSeed {
643    seed_left: VidpfSeed,
644    ctrl_left: Choice,
645    seed_right: VidpfSeed,
646    ctrl_right: Choice,
647}
648
649/// Represents an array of field elements that implements the [`VidpfValue`] trait.
650#[derive(Debug, PartialEq, Eq, Clone)]
651pub struct VidpfWeight<F: FieldElement>(pub(crate) Vec<F>);
652
653impl<F: FieldElement> From<Vec<F>> for VidpfWeight<F> {
654    fn from(value: Vec<F>) -> Self {
655        Self(value)
656    }
657}
658
659impl<F: FieldElement> AsRef<[F]> for VidpfWeight<F> {
660    fn as_ref(&self) -> &[F] {
661        &self.0
662    }
663}
664
665impl<F: FieldElement> VidpfValue for VidpfWeight<F> {}
666
667impl<F: FieldElement> IdpfValue for VidpfWeight<F> {
668    /// The parameter determines the number of field elements in the vector.
669    type ValueParameter = usize;
670
671    fn generate<S: RngCore>(seed_stream: &mut S, length: &Self::ValueParameter) -> Self {
672        Self(
673            (0..*length)
674                .map(|_| <F as IdpfValue>::generate(seed_stream, &()))
675                .collect(),
676        )
677    }
678
679    fn zero(length: &Self::ValueParameter) -> Self {
680        Self((0..*length).map(|_| <F as IdpfValue>::zero(&())).collect())
681    }
682
683    /// Panics if weight lengths are different.
684    fn conditional_select(lhs: &Self, rhs: &Self, choice: Choice) -> Self {
685        assert_eq!(
686            lhs.0.len(),
687            rhs.0.len(),
688            "{}",
689            VidpfError::InvalidWeightLength
690        );
691
692        Self(
693            zip(&lhs.0, &rhs.0)
694                .map(|(a, b)| <F as IdpfValue>::conditional_select(a, b, choice))
695                .collect(),
696        )
697    }
698}
699
700impl<F: FieldElement> ConditionallyNegatable for VidpfWeight<F> {
701    fn conditional_negate(&mut self, choice: Choice) {
702        self.0.iter_mut().for_each(|a| a.conditional_negate(choice));
703    }
704}
705
706impl<F: FieldElement> Add for VidpfWeight<F> {
707    type Output = Self;
708
709    /// Panics if weight lengths are different.
710    fn add(self, rhs: Self) -> Self::Output {
711        assert_eq!(
712            self.0.len(),
713            rhs.0.len(),
714            "{}",
715            VidpfError::InvalidWeightLength
716        );
717
718        Self(zip(self.0, rhs.0).map(|(a, b)| a.add(b)).collect())
719    }
720}
721
722impl<F: FieldElement> AddAssign for VidpfWeight<F> {
723    /// Panics if weight lengths are different.
724    fn add_assign(&mut self, rhs: Self) {
725        assert_eq!(
726            self.0.len(),
727            rhs.0.len(),
728            "{}",
729            VidpfError::InvalidWeightLength
730        );
731
732        zip(&mut self.0, rhs.0).for_each(|(a, b)| a.add_assign(b));
733    }
734}
735
736impl<F: FieldElement> Sub for VidpfWeight<F> {
737    type Output = Self;
738
739    /// Panics if weight lengths are different.
740    fn sub(self, rhs: Self) -> Self::Output {
741        assert_eq!(
742            self.0.len(),
743            rhs.0.len(),
744            "{}",
745            VidpfError::InvalidWeightLength
746        );
747
748        Self(zip(self.0, rhs.0).map(|(a, b)| a.sub(b)).collect())
749    }
750}
751
752impl<F: FieldElement> ConstantTimeEq for VidpfWeight<F> {
753    fn ct_eq(&self, other: &Self) -> Choice {
754        self.0[..].ct_eq(&other.0[..])
755    }
756}
757
758impl<F: FieldElement> Encode for VidpfWeight<F> {
759    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
760        for e in &self.0 {
761            F::encode(e, bytes)?;
762        }
763        Ok(())
764    }
765
766    fn encoded_len(&self) -> Option<usize> {
767        Some(self.0.len() * F::ENCODED_SIZE)
768    }
769}
770
771impl<F: FieldElement> ParameterizedDecode<<Self as IdpfValue>::ValueParameter> for VidpfWeight<F> {
772    fn decode_with_param(
773        decoding_parameter: &<Self as IdpfValue>::ValueParameter,
774        bytes: &mut Cursor<&[u8]>,
775    ) -> Result<Self, CodecError> {
776        let mut v = Vec::with_capacity(*decoding_parameter);
777        for _ in 0..*decoding_parameter {
778            v.push(F::decode_with_param(&(), bytes)?);
779        }
780
781        Ok(Self(v))
782    }
783}
784
785#[derive(Copy, Clone)]
786struct VidpfEvalIndex<'a> {
787    bit: Choice,
788    input: &'a VidpfInput,
789    level: u16,
790    bits: u16,
791}
792
793impl VidpfEvalIndex<'_> {
794    fn left_sibling(&self) -> Self {
795        Self {
796            bit: Choice::from(0),
797            input: self.input,
798            level: self.level,
799            bits: self.bits,
800        }
801    }
802
803    fn right_sibling(&self) -> Self {
804        Self {
805            bit: Choice::from(1),
806            input: self.input,
807            level: self.level,
808            bits: self.bits,
809        }
810    }
811
812    fn node_proof(&self, seed: &VidpfSeed, ctx: &[u8]) -> VidpfProof {
813        let mut xof = XofTurboShake128::from_seed_slice(
814            &seed[..],
815            &[&mastic::dst_usage(mastic::USAGE_NODE_PROOF), ctx],
816        );
817        xof.update(&self.bits.to_le_bytes());
818        xof.update(&self.level.to_le_bytes());
819
820        for byte in self
821            .input
822            .index(..=usize::from(self.level))
823            .chunks(8)
824            .enumerate()
825            .map(|(byte_index, chunk)| {
826                let mut byte = 0;
827                for (bit_index, bit) in chunk.iter().enumerate() {
828                    byte |= u8::from(*bit) << (7 - bit_index);
829                }
830
831                // Typically `input[level] == bit` , but `bit` may be overwritten by either
832                // `left_sibling()` or `right_sibling()`. Adjust its value accordingly.
833                if byte_index == usize::from(self.level) / 8 {
834                    let bit_index = self.level % 8;
835                    let m = 1 << (7 - bit_index);
836                    byte = u8::conditional_select(&(byte & !m), &(byte | m), self.bit);
837                }
838                byte
839            })
840        {
841            xof.update(&[byte]);
842        }
843        xof.into_seed().0
844    }
845}
846
847#[cfg(test)]
848mod tests {
849
850    use crate::field::Field128;
851
852    use super::VidpfWeight;
853
854    type TestWeight = VidpfWeight<Field128>;
855    const TEST_WEIGHT_LEN: usize = 3;
856    const TEST_NONCE_SIZE: usize = 16;
857    const TEST_NONCE: &[u8; TEST_NONCE_SIZE] = b"Test Nonce VIDPF";
858
859    mod vidpf {
860        use crate::{
861            codec::{Encode, ParameterizedDecode},
862            idpf::IdpfValue,
863            vidpf::{
864                Vidpf, VidpfCorrectionWord, VidpfEvalState, VidpfInput, VidpfKey, VidpfPublicShare,
865                VidpfServerId,
866            },
867        };
868
869        use super::{TestWeight, TEST_NONCE, TEST_NONCE_SIZE, TEST_WEIGHT_LEN};
870
871        #[test]
872        fn roundtrip_codec() {
873            let ctx = b"appliction context";
874            let input = VidpfInput::from_bytes(&[0xFF]);
875            let weight = TestWeight::from(vec![21.into(), 22.into(), 23.into()]);
876            let (vidpf, public, _, _) = vidpf_gen_setup(ctx, &input, &weight);
877
878            let bytes = public.get_encoded().unwrap();
879            assert_eq!(public.encoded_len().unwrap(), bytes.len());
880
881            let decoded = VidpfPublicShare::get_decoded_with_param(&vidpf, &bytes).unwrap();
882            assert_eq!(public, decoded);
883        }
884
885        fn vidpf_gen_setup(
886            ctx: &[u8],
887            input: &VidpfInput,
888            weight: &TestWeight,
889        ) -> (
890            Vidpf<TestWeight>,
891            VidpfPublicShare<TestWeight>,
892            [VidpfKey; 2],
893            [u8; TEST_NONCE_SIZE],
894        ) {
895            let vidpf = Vidpf::new(input.len(), TEST_WEIGHT_LEN).unwrap();
896            let (public, keys) = vidpf.gen(ctx, input, weight, TEST_NONCE).unwrap();
897            (vidpf, public, keys, *TEST_NONCE)
898        }
899
900        #[test]
901        fn correctness_at_last_level() {
902            let ctx = b"some application";
903            let input = VidpfInput::from_bytes(&[0xFF]);
904            let weight = TestWeight::from(vec![21.into(), 22.into(), 23.into()]);
905            let (vidpf, public, [key_0, key_1], nonce) = vidpf_gen_setup(ctx, &input, &weight);
906
907            let (value_share_0, onehot_proof_0) = vidpf
908                .eval(ctx, VidpfServerId::S0, &key_0, &public, &input, &nonce)
909                .unwrap();
910            let (value_share_1, onehot_proof_1) = vidpf
911                .eval(ctx, VidpfServerId::S1, &key_1, &public, &input, &nonce)
912                .unwrap();
913
914            assert_eq!(
915                value_share_0 + value_share_1,
916                weight,
917                "shares must add up to the expected weight",
918            );
919
920            assert_eq!(onehot_proof_0, onehot_proof_1, "proofs must be equal");
921
922            let bad_input = VidpfInput::from_bytes(&[0x00]);
923            let zero = TestWeight::zero(&TEST_WEIGHT_LEN);
924            let (value_share_0, onehot_proof_0) = vidpf
925                .eval(ctx, VidpfServerId::S0, &key_0, &public, &bad_input, &nonce)
926                .unwrap();
927            let (value_share_1, onehot_proof_1) = vidpf
928                .eval(ctx, VidpfServerId::S1, &key_1, &public, &bad_input, &nonce)
929                .unwrap();
930
931            assert_eq!(
932                value_share_0 + value_share_1,
933                zero,
934                "shares must add up to zero",
935            );
936
937            assert_eq!(onehot_proof_0, onehot_proof_1, "proofs must be equal");
938        }
939
940        #[test]
941        fn correctness_at_each_level() {
942            let ctx = b"application context";
943            let input = VidpfInput::from_bytes(&[0xFF]);
944            let weight = TestWeight::from(vec![21.into(), 22.into(), 23.into()]);
945            let (vidpf, public, keys, nonce) = vidpf_gen_setup(ctx, &input, &weight);
946
947            assert_eval_at_each_level(&vidpf, ctx, &keys, &public, &input, &weight, &nonce);
948
949            let bad_input = VidpfInput::from_bytes(&[0x00]);
950            let zero = TestWeight::zero(&TEST_WEIGHT_LEN);
951
952            assert_eval_at_each_level(&vidpf, ctx, &keys, &public, &bad_input, &zero, &nonce);
953        }
954
955        fn assert_eval_at_each_level(
956            vidpf: &Vidpf<TestWeight>,
957            ctx: &[u8],
958            [key_0, key_1]: &[VidpfKey; 2],
959            public: &VidpfPublicShare<TestWeight>,
960            input: &VidpfInput,
961            weight: &TestWeight,
962            nonce: &[u8],
963        ) {
964            let mut state_0 = VidpfEvalState::init_from_key(VidpfServerId::S0, key_0);
965            let mut state_1 = VidpfEvalState::init_from_key(VidpfServerId::S1, key_1);
966
967            for (idx, cw) in vidpf.index_iter(input).unwrap().zip(public.cw.iter()) {
968                let r0 = vidpf.eval_next(ctx, cw, idx, &state_0, nonce);
969                let r1 = vidpf.eval_next(ctx, cw, idx, &state_1, nonce);
970
971                assert_eq!(
972                    r0.share - r1.share,
973                    *weight,
974                    "shares must add up to the expected weight at the current level: {:?}",
975                    idx.level
976                );
977
978                assert_eq!(
979                    r0.state.node_proof, r1.state.node_proof,
980                    "proofs must be equal at the current level: {:?}",
981                    idx.level
982                );
983
984                state_0 = r0.state;
985                state_1 = r1.state;
986            }
987        }
988
989        // Assert that the length of the weight is the same at each level of the tree. This
990        // assumption is made in `PublicShare::encoded_len()`.
991        #[test]
992        fn public_share_weight_len() {
993            let input = VidpfInput::from_bools(&vec![false; 237]);
994            let weight = TestWeight::from(vec![21.into(), 22.into(), 23.into()]);
995            let (vidpf, public, _, _) = vidpf_gen_setup(b"some application", &input, &weight);
996            for VidpfCorrectionWord { weight, .. } in public.cw {
997                assert_eq!(weight.0.len(), vidpf.weight_len);
998            }
999        }
1000    }
1001
1002    mod weight {
1003        use std::io::Cursor;
1004        use subtle::{Choice, ConditionallyNegatable};
1005
1006        use crate::{
1007            codec::{Encode, ParameterizedDecode},
1008            idpf::IdpfValue,
1009            vdaf::xof::{Xof, XofTurboShake128},
1010        };
1011
1012        use super::{TestWeight, TEST_WEIGHT_LEN};
1013
1014        #[test]
1015        fn roundtrip_codec() {
1016            let weight = TestWeight::from(vec![21.into(), 22.into(), 23.into()]);
1017
1018            let mut bytes = vec![];
1019            weight.encode(&mut bytes).unwrap();
1020
1021            let expected_bytes = [
1022                [vec![21], vec![0u8; 15]].concat(),
1023                [vec![22], vec![0u8; 15]].concat(),
1024                [vec![23], vec![0u8; 15]].concat(),
1025            ]
1026            .concat();
1027
1028            assert_eq!(weight.encoded_len().unwrap(), expected_bytes.len());
1029            // Check endianness of encoding
1030            assert_eq!(bytes, expected_bytes);
1031
1032            let decoded =
1033                TestWeight::decode_with_param(&TEST_WEIGHT_LEN, &mut Cursor::new(&bytes)).unwrap();
1034            assert_eq!(weight, decoded);
1035        }
1036
1037        #[test]
1038        fn add_sub() {
1039            let [a, b] = compatible_weights();
1040            let mut c = a.clone();
1041            c += a.clone();
1042
1043            assert_eq!(
1044                (a.clone() + b.clone()) + (a.clone() - b.clone()),
1045                c,
1046                "a: {:?} b:{:?}",
1047                a,
1048                b
1049            );
1050        }
1051
1052        #[test]
1053        fn conditional_negate() {
1054            let [a, _] = compatible_weights();
1055            let mut c = a.clone();
1056            c.conditional_negate(Choice::from(0));
1057            let mut d = a.clone();
1058            d.conditional_negate(Choice::from(1));
1059            let zero = TestWeight::zero(&TEST_WEIGHT_LEN);
1060
1061            assert_eq!(c + d, zero, "a: {:?}", a);
1062        }
1063
1064        #[test]
1065        #[should_panic = "invalid weight length"]
1066        fn add_panics() {
1067            let [w0, w1] = incompatible_weights();
1068            let _ = w0 + w1;
1069        }
1070
1071        #[test]
1072        #[should_panic = "invalid weight length"]
1073        fn add_assign_panics() {
1074            let [mut w0, w1] = incompatible_weights();
1075            w0 += w1;
1076        }
1077
1078        #[test]
1079        #[should_panic = "invalid weight length"]
1080        fn sub_panics() {
1081            let [w0, w1] = incompatible_weights();
1082            let _ = w0 - w1;
1083        }
1084
1085        #[test]
1086        #[should_panic = "invalid weight length"]
1087        fn conditional_select_panics() {
1088            let [w0, w1] = incompatible_weights();
1089            TestWeight::conditional_select(&w0, &w1, Choice::from(0));
1090        }
1091
1092        fn compatible_weights() -> [TestWeight; 2] {
1093            let mut xof = XofTurboShake128::seed_stream(&[0; 32], &[], &[]);
1094            [
1095                TestWeight::generate(&mut xof, &TEST_WEIGHT_LEN),
1096                TestWeight::generate(&mut xof, &TEST_WEIGHT_LEN),
1097            ]
1098        }
1099
1100        fn incompatible_weights() -> [TestWeight; 2] {
1101            let mut xof = XofTurboShake128::seed_stream(&[0; 32], &[], &[]);
1102            [
1103                TestWeight::generate(&mut xof, &TEST_WEIGHT_LEN),
1104                TestWeight::generate(&mut xof, &(2 * TEST_WEIGHT_LEN)),
1105            ]
1106        }
1107    }
1108}