prio/topology/
ping_pong.rs

1// SPDX-License-Identifier: MPL-2.0
2
3//! Implements the Ping-Pong Topology described in [VDAF]. This topology assumes there are exactly
4//! two aggregators, designated "Leader" and "Helper". This topology is required for implementing
5//! the [Distributed Aggregation Protocol][DAP].
6//!
7//! [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-08#section-5.8
8//! [DAP]: https://datatracker.ietf.org/doc/html/draft-ietf-ppm-dap
9
10use crate::{
11    codec::{decode_u32_items, encode_u32_items, CodecError, Decode, Encode, ParameterizedDecode},
12    vdaf::{Aggregator, PrepareTransition, VdafError},
13};
14use std::fmt::Debug;
15
16/// Errors emitted by this module.
17#[derive(Debug, thiserror::Error)]
18#[non_exhaustive]
19pub enum PingPongError {
20    /// Error running prepare_init
21    #[error("vdaf.prepare_init: {0}")]
22    VdafPrepareInit(VdafError),
23
24    /// Error running prepare_shares_to_prepare_message
25    #[error("vdaf.prepare_shares_to_prepare_message {0}")]
26    VdafPrepareSharesToPrepareMessage(VdafError),
27
28    /// Error running prepare_next
29    #[error("vdaf.prepare_next {0}")]
30    VdafPrepareNext(VdafError),
31
32    /// Error encoding or decoding a prepare share
33    #[error("encode/decode prep share {0}")]
34    CodecPrepShare(CodecError),
35
36    /// Error encoding or decoding a prepare message
37    #[error("encode/decode prep message {0}")]
38    CodecPrepMessage(CodecError),
39
40    /// Host is in an unexpected state
41    #[error("host state mismatch: in {found} expected {expected}")]
42    HostStateMismatch {
43        /// The state the host is in.
44        found: &'static str,
45        /// The state the host expected to be in.
46        expected: &'static str,
47    },
48
49    /// Message from peer indicates it is in an unexpected state
50    #[error("peer message mismatch: message is {found} expected {expected}")]
51    PeerMessageMismatch {
52        /// The state in the message from the peer.
53        found: &'static str,
54        /// The message expected from the peer.
55        expected: &'static str,
56    },
57
58    /// Internal error
59    #[error("internal error: {0}")]
60    InternalError(&'static str),
61}
62
/// Corresponds to `struct Message` in [VDAF's Ping-Pong Topology][VDAF].
///
/// Every field of every variant is an opaque, already-encoded byte buffer. Decoding preparation
/// shares and messages usually needs the preparation state, so the ping-pong routines perform
/// that decoding themselves instead of leaving it to this type.
///
/// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-08#section-5.8
#[derive(Clone, PartialEq, Eq)]
pub enum PingPongMessage {
    /// Corresponds to MessageType.initialize.
    Initialize {
        /// The leader's initial preparation share.
        prep_share: Vec<u8>,
    },
    /// Corresponds to MessageType.continue.
    Continue {
        /// The current round's preparation message.
        prep_msg: Vec<u8>,
        /// The next round's preparation share.
        prep_share: Vec<u8>,
    },
    /// Corresponds to MessageType.finish.
    Finish {
        /// The current round's preparation message.
        prep_msg: Vec<u8>,
    },
}

impl PingPongMessage {
    /// Human-readable name of this message's variant, used by the `Debug` impl and in
    /// `PingPongError::PeerMessageMismatch` reports.
    fn variant(&self) -> &'static str {
        match *self {
            PingPongMessage::Initialize { .. } => "Initialize",
            PingPongMessage::Continue { .. } => "Continue",
            PingPongMessage::Finish { .. } => "Finish",
        }
    }
}

impl Debug for PingPongMessage {
    // Only the variant name is printed. The prepare shares and messages inside each variant are
    // deliberately left out: their contents are sensitive, and they are long byte strings that a
    // human cannot interpret, so they should not end up in logs. A customized derive (normally
    // done with the `derivative` crate) is not an option, because that crate has not been audited
    // in the `cargo vet` sense and using it would mean auditing 8,000+ lines of proc macros.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.variant())
    }
}
110
111impl Encode for PingPongMessage {
112    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
113        // The encoding includes an implicit discriminator byte, called MessageType in the VDAF
114        // spec.
115        match self {
116            Self::Initialize { prep_share } => {
117                0u8.encode(bytes)?;
118                encode_u32_items(bytes, &(), prep_share)?;
119            }
120            Self::Continue {
121                prep_msg,
122                prep_share,
123            } => {
124                1u8.encode(bytes)?;
125                encode_u32_items(bytes, &(), prep_msg)?;
126                encode_u32_items(bytes, &(), prep_share)?;
127            }
128            Self::Finish { prep_msg } => {
129                2u8.encode(bytes)?;
130                encode_u32_items(bytes, &(), prep_msg)?;
131            }
132        }
133        Ok(())
134    }
135
136    fn encoded_len(&self) -> Option<usize> {
137        match self {
138            Self::Initialize { prep_share } => Some(1 + 4 + prep_share.len()),
139            Self::Continue {
140                prep_msg,
141                prep_share,
142            } => Some(1 + 4 + prep_msg.len() + 4 + prep_share.len()),
143            Self::Finish { prep_msg } => Some(1 + 4 + prep_msg.len()),
144        }
145    }
146}
147
148impl Decode for PingPongMessage {
149    fn decode(bytes: &mut std::io::Cursor<&[u8]>) -> Result<Self, CodecError> {
150        let message_type = u8::decode(bytes)?;
151        Ok(match message_type {
152            0 => {
153                let prep_share = decode_u32_items(&(), bytes)?;
154                Self::Initialize { prep_share }
155            }
156            1 => {
157                let prep_msg = decode_u32_items(&(), bytes)?;
158                let prep_share = decode_u32_items(&(), bytes)?;
159                Self::Continue {
160                    prep_msg,
161                    prep_share,
162                }
163            }
164            2 => {
165                let prep_msg = decode_u32_items(&(), bytes)?;
166                Self::Finish { prep_msg }
167            }
168            _ => return Err(CodecError::UnexpectedValue),
169        })
170    }
171}
172
/// A transition in the ping-pong topology. This represents the `ping_pong_transition` function
/// defined in [VDAF].
///
/// # Discussion
///
/// The obvious implementation of `ping_pong_transition` would be a method on trait
/// [`PingPongTopology`] that returns `(State, Message)`, and then `ContinuedValue::WithMessage`
/// would contain those values. But then DAP implementations would have to store relatively large
/// VDAF prepare shares between rounds of input preparation.
///
/// Instead, this structure stores just the previous round's prepare state and the current round's
/// preprocessed prepare message. Their encoding is much smaller than the `(State, Message)` tuple,
/// which can always be recomputed with [`Self::evaluate`].
///
/// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-08#section-5.8
// NOTE(review): `PartialEq` is implemented by hand in this module rather than derived; presumably
// this avoids the derive's extra bound on `A` — confirm before switching to a derive.
#[derive(Clone, Debug, Eq)]
pub struct PingPongTransition<
    const VERIFY_KEY_SIZE: usize,
    const NONCE_SIZE: usize,
    A: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>,
> {
    /// Prepare state from the previous round; `evaluate` passes a clone of it to `prepare_next`.
    previous_prepare_state: A::PrepareState,
    /// The current round's prepare message; `evaluate` encodes it into the outbound message and
    /// passes a clone of it to `prepare_next`.
    current_prepare_message: A::PrepareMessage,
}
197
198impl<
199        const VERIFY_KEY_SIZE: usize,
200        const NONCE_SIZE: usize,
201        A: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>,
202    > PingPongTransition<VERIFY_KEY_SIZE, NONCE_SIZE, A>
203{
204    /// Evaluate this transition to obtain a new [`PingPongState`] and a [`PingPongMessage`] which
205    /// should be transmitted to the peer.
206    #[allow(clippy::type_complexity)]
207    pub fn evaluate(
208        &self,
209        vdaf: &A,
210    ) -> Result<
211        (
212            PingPongState<VERIFY_KEY_SIZE, NONCE_SIZE, A>,
213            PingPongMessage,
214        ),
215        PingPongError,
216    > {
217        let prep_msg = self
218            .current_prepare_message
219            .get_encoded()
220            .map_err(PingPongError::CodecPrepMessage)?;
221
222        vdaf.prepare_next(
223            self.previous_prepare_state.clone(),
224            self.current_prepare_message.clone(),
225        )
226        .map_err(PingPongError::VdafPrepareNext)
227        .and_then(|transition| match transition {
228            PrepareTransition::Continue(prep_state, prep_share) => Ok((
229                PingPongState::Continued(prep_state),
230                PingPongMessage::Continue {
231                    prep_msg,
232                    prep_share: prep_share
233                        .get_encoded()
234                        .map_err(PingPongError::CodecPrepShare)?,
235                },
236            )),
237            PrepareTransition::Finish(output_share) => Ok((
238                PingPongState::Finished(output_share),
239                PingPongMessage::Finish { prep_msg },
240            )),
241        })
242    }
243}
244
245impl<
246        const VERIFY_KEY_SIZE: usize,
247        const NONCE_SIZE: usize,
248        A: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>,
249    > PartialEq for PingPongTransition<VERIFY_KEY_SIZE, NONCE_SIZE, A>
250{
251    fn eq(&self, other: &Self) -> bool {
252        self.previous_prepare_state == other.previous_prepare_state
253            && self.current_prepare_message == other.current_prepare_message
254    }
255}
256
257impl<const VERIFY_KEY_SIZE: usize, const NONCE_SIZE: usize, A> Encode
258    for PingPongTransition<VERIFY_KEY_SIZE, NONCE_SIZE, A>
259where
260    A: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>,
261    A::PrepareState: Encode,
262{
263    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
264        self.previous_prepare_state.encode(bytes)?;
265        self.current_prepare_message.encode(bytes)
266    }
267
268    fn encoded_len(&self) -> Option<usize> {
269        Some(
270            self.previous_prepare_state.encoded_len()?
271                + self.current_prepare_message.encoded_len()?,
272        )
273    }
274}
275
276impl<const VERIFY_KEY_SIZE: usize, const NONCE_SIZE: usize, A, PrepareStateDecode>
277    ParameterizedDecode<PrepareStateDecode> for PingPongTransition<VERIFY_KEY_SIZE, NONCE_SIZE, A>
278where
279    A: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>,
280    A::PrepareState: ParameterizedDecode<PrepareStateDecode> + PartialEq,
281    A::PrepareMessage: PartialEq,
282{
283    fn decode_with_param(
284        decoding_param: &PrepareStateDecode,
285        bytes: &mut std::io::Cursor<&[u8]>,
286    ) -> Result<Self, CodecError> {
287        let previous_prepare_state = A::PrepareState::decode_with_param(decoding_param, bytes)?;
288        let current_prepare_message =
289            A::PrepareMessage::decode_with_param(&previous_prepare_state, bytes)?;
290
291        Ok(Self {
292            previous_prepare_state,
293            current_prepare_message,
294        })
295    }
296}
297
/// Corresponds to the `State` enumeration implicitly defined in [VDAF's Ping-Pong Topology][VDAF].
/// VDAF describes `Start` and `Rejected` states, but the `Start` state is never instantiated in
/// code, and the `Rejected` state is represented as `std::result::Result::Err`, so this enum does
/// not include those variants.
///
/// Values of this type are returned by [`PingPongTopology::leader_initialized`] (always
/// `Continued`) and by [`PingPongTransition::evaluate`].
///
/// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-08#section-5.8
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum PingPongState<
    const VERIFY_KEY_SIZE: usize,
    const NONCE_SIZE: usize,
    A: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>,
> {
    /// Preparation of the report will continue with the enclosed state.
    Continued(A::PrepareState),
    /// Preparation of the report is finished and has yielded the enclosed output share.
    Finished(A::OutputShare),
}
315
/// Values returned by [`PingPongTopology::leader_continued`] or
/// [`PingPongTopology::helper_continued`].
#[derive(Clone, Debug)]
pub enum PingPongContinuedValue<
    const VERIFY_KEY_SIZE: usize,
    const NONCE_SIZE: usize,
    A: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>,
> {
    /// The operation resulted in a new state and a message to transmit to the peer.
    WithMessage {
        /// The transition that will be executed. Call `PingPongTransition::evaluate` to obtain the
        /// next [`PingPongState`] and a [`PingPongMessage`] to transmit to the peer.
        transition: PingPongTransition<VERIFY_KEY_SIZE, NONCE_SIZE, A>,
    },
    /// The operation caused the host to finish preparation of the input share, yielding an output
    /// share and no message for the peer.
    FinishedNoMessage {
        /// The output share which may now be accumulated.
        output_share: A::OutputShare,
    },
}
338
/// Extension trait on [`crate::vdaf::Aggregator`] which adds the [VDAF Ping-Pong Topology][VDAF].
///
/// This module provides a blanket implementation for every type implementing
/// [`crate::vdaf::Aggregator`], so these methods are available on any VDAF.
///
/// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-08#section-5.8
pub trait PingPongTopology<const VERIFY_KEY_SIZE: usize, const NONCE_SIZE: usize>:
    Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>
{
    /// Specialization of [`PingPongState`] for this VDAF.
    type State;
    /// Specialization of [`PingPongContinuedValue`] for this VDAF.
    type ContinuedValue;
    /// Specialization of [`PingPongTransition`] for this VDAF.
    type Transition;

    /// Initialize leader state using the leader's input share. Corresponds to
    /// `ping_pong_leader_init` in [VDAF].
    ///
    /// If successful, the returned [`PingPongMessage`] (which will always be
    /// `PingPongMessage::Initialize`) should be transmitted to the helper. The returned
    /// [`PingPongState`] (which will always be `PingPongState::Continued`) should be used by the
    /// leader along with the next [`PingPongMessage`] received from the helper as input to
    /// [`Self::leader_continued`] to advance to the next round.
    ///
    /// # Errors
    ///
    /// Fails if the VDAF's `prepare_init` fails or if the leader's prepare share cannot be
    /// encoded.
    ///
    /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-08#section-5.8
    fn leader_initialized(
        &self,
        verify_key: &[u8; VERIFY_KEY_SIZE],
        agg_param: &Self::AggregationParam,
        nonce: &[u8; NONCE_SIZE],
        public_share: &Self::PublicShare,
        input_share: &Self::InputShare,
    ) -> Result<(Self::State, PingPongMessage), PingPongError>;

    /// Initialize helper state using the helper's input share and the leader's first prepare share.
    /// Corresponds to `ping_pong_helper_init` in [VDAF].
    ///
    /// If successful, the returned [`PingPongTransition`] should be evaluated. Evaluation yields
    /// a [`PingPongMessage`], which should be transmitted to the leader, and a [`PingPongState`].
    ///
    /// If the state is `PingPongState::Continued`, then it should be used by the helper along with
    /// the next `PingPongMessage` received from the leader as input to [`Self::helper_continued`]
    /// to advance to the next round. The helper may store the `PingPongTransition` between rounds
    /// of preparation instead of the `PingPongState` and `PingPongMessage`.
    ///
    /// If the state is `PingPongState::Finished`, then preparation is finished and the output share
    /// may be accumulated.
    ///
    /// # Errors
    ///
    /// `inbound` must be `PingPongMessage::Initialize` or the function will fail.
    ///
    /// The function also fails if `prepare_init` or `prepare_shares_to_prepare_message` fails,
    /// or if the leader's prepare share cannot be decoded.
    ///
    /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-08#section-5.8
    fn helper_initialized(
        &self,
        verify_key: &[u8; VERIFY_KEY_SIZE],
        agg_param: &Self::AggregationParam,
        nonce: &[u8; NONCE_SIZE],
        public_share: &Self::PublicShare,
        input_share: &Self::InputShare,
        inbound: &PingPongMessage,
    ) -> Result<PingPongTransition<VERIFY_KEY_SIZE, NONCE_SIZE, Self>, PingPongError>;

    /// Continue preparation based on the leader's current state and an incoming [`PingPongMessage`]
    /// from the helper. Corresponds to `ping_pong_leader_continued` in [VDAF].
    ///
    /// If successful, the returned [`PingPongContinuedValue`] will either be:
    ///
    /// - `PingPongContinuedValue::WithMessage { transition }`: `transition` should be evaluated,
    ///   yielding a [`PingPongMessage`], which should be transmitted to the helper, and a
    ///   [`PingPongState`].
    ///
    ///   If the state is `PingPongState::Continued`, then it should be used by the leader along
    ///   with the next `PingPongMessage` received from the helper as input to
    ///   [`Self::leader_continued`] to advance to the next round. The leader may store the
    ///   `PingPongTransition` between rounds of preparation instead of the `PingPongState` and
    ///   `PingPongMessage`.
    ///
    ///   If the state is `PingPongState::Finished`, then preparation is finished and the output
    ///   share may be accumulated.
    ///
    /// - `PingPongContinuedValue::FinishedNoMessage`: preparation is finished and the output share
    ///   may be accumulated. No message needs to be sent to the helper.
    ///
    /// # Errors
    ///
    /// `leader_state` must be `PingPongState::Continued` or the function will fail.
    ///
    /// `inbound` must not be `PingPongMessage::Initialize` or the function will fail.
    ///
    /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-08#section-5.8
    fn leader_continued(
        &self,
        leader_state: Self::State,
        agg_param: &Self::AggregationParam,
        inbound: &PingPongMessage,
    ) -> Result<Self::ContinuedValue, PingPongError>;

    /// Continue preparation based on the helper's current state and an incoming
    /// [`PingPongMessage`] from the leader. Corresponds to `ping_pong_helper_continued` in [VDAF].
    ///
    /// If successful, the returned [`PingPongContinuedValue`] will either be:
    ///
    /// - `PingPongContinuedValue::WithMessage { transition }`: `transition` should be evaluated,
    ///   yielding a [`PingPongMessage`], which should be transmitted to the leader, and a
    ///   [`PingPongState`].
    ///
    ///   If the state is `PingPongState::Continued`, then it should be used by the helper along
    ///   with the next `PingPongMessage` received from the leader as input to
    ///   [`Self::helper_continued`] to advance to the next round. The helper may store the
    ///   `PingPongTransition` between rounds of preparation instead of the `PingPongState` and
    ///   `PingPongMessage`.
    ///
    ///   If the state is `PingPongState::Finished`, then preparation is finished and the output
    ///   share may be accumulated.
    ///
    /// - `PingPongContinuedValue::FinishedNoMessage`: preparation is finished and the output share
    ///   may be accumulated. No message needs to be sent to the leader.
    ///
    /// # Errors
    ///
    /// `helper_state` must be `PingPongState::Continued` or the function will fail.
    ///
    /// `inbound` must not be `PingPongMessage::Initialize` or the function will fail.
    ///
    /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-08#section-5.8
    fn helper_continued(
        &self,
        helper_state: Self::State,
        agg_param: &Self::AggregationParam,
        inbound: &PingPongMessage,
    ) -> Result<Self::ContinuedValue, PingPongError>;
}
470
/// Private interfaces for implementing ping-pong
trait PingPongTopologyPrivate<const VERIFY_KEY_SIZE: usize, const NONCE_SIZE: usize>:
    PingPongTopology<VERIFY_KEY_SIZE, NONCE_SIZE>
{
    /// Shared body of [`PingPongTopology::leader_continued`] (called with `is_leader == true`)
    /// and [`PingPongTopology::helper_continued`] (called with `is_leader == false`).
    ///
    /// `is_leader` only affects the order in which the host's and the peer's prepare shares are
    /// combined into the next prepare message.
    fn continued(
        &self,
        is_leader: bool,
        host_state: Self::State,
        agg_param: &Self::AggregationParam,
        inbound: &PingPongMessage,
    ) -> Result<Self::ContinuedValue, PingPongError>;
}
483
484impl<const VERIFY_KEY_SIZE: usize, const NONCE_SIZE: usize, A>
485    PingPongTopology<VERIFY_KEY_SIZE, NONCE_SIZE> for A
486where
487    A: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>,
488{
489    type State = PingPongState<VERIFY_KEY_SIZE, NONCE_SIZE, Self>;
490    type ContinuedValue = PingPongContinuedValue<VERIFY_KEY_SIZE, NONCE_SIZE, Self>;
491    type Transition = PingPongTransition<VERIFY_KEY_SIZE, NONCE_SIZE, Self>;
492
493    fn leader_initialized(
494        &self,
495        verify_key: &[u8; VERIFY_KEY_SIZE],
496        agg_param: &Self::AggregationParam,
497        nonce: &[u8; NONCE_SIZE],
498        public_share: &Self::PublicShare,
499        input_share: &Self::InputShare,
500    ) -> Result<(Self::State, PingPongMessage), PingPongError> {
501        self.prepare_init(
502            verify_key,
503            /* Leader */ 0,
504            agg_param,
505            nonce,
506            public_share,
507            input_share,
508        )
509        .map_err(PingPongError::VdafPrepareInit)
510        .and_then(|(prep_state, prep_share)| {
511            Ok((
512                PingPongState::Continued(prep_state),
513                PingPongMessage::Initialize {
514                    prep_share: prep_share
515                        .get_encoded()
516                        .map_err(PingPongError::CodecPrepShare)?,
517                },
518            ))
519        })
520    }
521
522    fn helper_initialized(
523        &self,
524        verify_key: &[u8; VERIFY_KEY_SIZE],
525        agg_param: &Self::AggregationParam,
526        nonce: &[u8; NONCE_SIZE],
527        public_share: &Self::PublicShare,
528        input_share: &Self::InputShare,
529        inbound: &PingPongMessage,
530    ) -> Result<Self::Transition, PingPongError> {
531        let (prep_state, prep_share) = self
532            .prepare_init(
533                verify_key,
534                /* Helper */ 1,
535                agg_param,
536                nonce,
537                public_share,
538                input_share,
539            )
540            .map_err(PingPongError::VdafPrepareInit)?;
541
542        let inbound_prep_share = if let PingPongMessage::Initialize { prep_share } = inbound {
543            Self::PrepareShare::get_decoded_with_param(&prep_state, prep_share)
544                .map_err(PingPongError::CodecPrepShare)?
545        } else {
546            return Err(PingPongError::PeerMessageMismatch {
547                found: inbound.variant(),
548                expected: "initialize",
549            });
550        };
551
552        let current_prepare_message = self
553            .prepare_shares_to_prepare_message(agg_param, [inbound_prep_share, prep_share])
554            .map_err(PingPongError::VdafPrepareSharesToPrepareMessage)?;
555
556        Ok(PingPongTransition {
557            previous_prepare_state: prep_state,
558            current_prepare_message,
559        })
560    }
561
562    fn leader_continued(
563        &self,
564        leader_state: Self::State,
565        agg_param: &Self::AggregationParam,
566        inbound: &PingPongMessage,
567    ) -> Result<Self::ContinuedValue, PingPongError> {
568        self.continued(true, leader_state, agg_param, inbound)
569    }
570
571    fn helper_continued(
572        &self,
573        helper_state: Self::State,
574        agg_param: &Self::AggregationParam,
575        inbound: &PingPongMessage,
576    ) -> Result<Self::ContinuedValue, PingPongError> {
577        self.continued(false, helper_state, agg_param, inbound)
578    }
579}
580
581impl<const VERIFY_KEY_SIZE: usize, const NONCE_SIZE: usize, A>
582    PingPongTopologyPrivate<VERIFY_KEY_SIZE, NONCE_SIZE> for A
583where
584    A: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>,
585{
586    fn continued(
587        &self,
588        is_leader: bool,
589        host_state: Self::State,
590        agg_param: &Self::AggregationParam,
591        inbound: &PingPongMessage,
592    ) -> Result<Self::ContinuedValue, PingPongError> {
593        let host_prep_state = if let PingPongState::Continued(state) = host_state {
594            state
595        } else {
596            return Err(PingPongError::HostStateMismatch {
597                found: "finished",
598                expected: "continue",
599            });
600        };
601
602        let (prep_msg, next_peer_prep_share) = match inbound {
603            PingPongMessage::Initialize { .. } => {
604                return Err(PingPongError::PeerMessageMismatch {
605                    found: inbound.variant(),
606                    expected: "continue",
607                });
608            }
609            PingPongMessage::Continue {
610                prep_msg,
611                prep_share,
612            } => (prep_msg, Some(prep_share)),
613            PingPongMessage::Finish { prep_msg } => (prep_msg, None),
614        };
615
616        let prep_msg = Self::PrepareMessage::get_decoded_with_param(&host_prep_state, prep_msg)
617            .map_err(PingPongError::CodecPrepMessage)?;
618        let host_prep_transition = self
619            .prepare_next(host_prep_state, prep_msg)
620            .map_err(PingPongError::VdafPrepareNext)?;
621
622        match (host_prep_transition, next_peer_prep_share) {
623            (
624                PrepareTransition::Continue(next_prep_state, next_host_prep_share),
625                Some(next_peer_prep_share),
626            ) => {
627                let next_peer_prep_share = Self::PrepareShare::get_decoded_with_param(
628                    &next_prep_state,
629                    next_peer_prep_share,
630                )
631                .map_err(PingPongError::CodecPrepShare)?;
632                let mut prep_shares = [next_peer_prep_share, next_host_prep_share];
633                if is_leader {
634                    prep_shares.reverse();
635                }
636                let current_prepare_message = self
637                    .prepare_shares_to_prepare_message(agg_param, prep_shares)
638                    .map_err(PingPongError::VdafPrepareSharesToPrepareMessage)?;
639
640                Ok(PingPongContinuedValue::WithMessage {
641                    transition: PingPongTransition {
642                        previous_prepare_state: next_prep_state,
643                        current_prepare_message,
644                    },
645                })
646            }
647            (PrepareTransition::Finish(output_share), None) => {
648                Ok(PingPongContinuedValue::FinishedNoMessage { output_share })
649            }
650            (PrepareTransition::Continue(_, _), None) => Err(PingPongError::PeerMessageMismatch {
651                found: inbound.variant(),
652                expected: "continue",
653            }),
654            (PrepareTransition::Finish(_), Some(_)) => Err(PingPongError::PeerMessageMismatch {
655                found: inbound.variant(),
656                expected: "finish",
657            }),
658        }
659    }
660}
661
662#[cfg(test)]
663mod tests {
664    use std::io::Cursor;
665
666    use super::*;
667    use crate::vdaf::dummy;
668    use assert_matches::assert_matches;
669
#[test]
fn ping_pong_one_round() {
    // One-round VDAF. The helper finishes as soon as it processes the leader's Initialize
    // message. The leader finishes on receipt of the helper's reply and sends nothing further.
    let verify_key = [];
    let aggregation_param = dummy::AggregationParam(0);
    let nonce = [0; 16];
    #[allow(clippy::let_unit_value)]
    let public_share = ();
    let input_share = dummy::InputShare(0);

    let (leader, helper) = (dummy::Vdaf::new(1), dummy::Vdaf::new(1));

    // Round 0: leader initialization.
    let (leader_state, initialize_msg) = leader
        .leader_initialized(
            &verify_key,
            &aggregation_param,
            &nonce,
            &public_share,
            &input_share,
        )
        .unwrap();

    // Round 1: helper initialization from the leader's message, evaluated immediately.
    let helper_transition = helper
        .helper_initialized(
            &verify_key,
            &aggregation_param,
            &nonce,
            &public_share,
            &input_share,
            &initialize_msg,
        )
        .unwrap();
    let (helper_state, helper_reply) = helper_transition.evaluate(&helper).unwrap();
    assert_matches!(helper_state, PingPongState::Finished(_));

    // The leader finishes on the helper's reply and has nothing to send back.
    let leader_outcome = leader
        .leader_continued(leader_state, &aggregation_param, &helper_reply)
        .unwrap();
    assert_matches!(
        leader_outcome,
        PingPongContinuedValue::FinishedNoMessage { .. }
    );
}
719
#[test]
fn ping_pong_two_rounds() {
    // Two-round VDAF, as a sequence of steps:
    // 1. The helper continues after initialization.
    // 2. The leader finishes and sends a Finish message.
    // 3. The helper finishes on that message without replying.
    let verify_key = [];
    let aggregation_param = dummy::AggregationParam(0);
    let nonce = [0; 16];
    #[allow(clippy::let_unit_value)]
    let public_share = ();
    let input_share = dummy::InputShare(0);

    let (leader, helper) = (dummy::Vdaf::new(2), dummy::Vdaf::new(2));

    // Round 0: leader initialization.
    let (leader_state, initialize_msg) = leader
        .leader_initialized(
            &verify_key,
            &aggregation_param,
            &nonce,
            &public_share,
            &input_share,
        )
        .unwrap();

    // Round 1: the helper initializes from the leader's message and must continue.
    let helper_transition = helper
        .helper_initialized(
            &verify_key,
            &aggregation_param,
            &nonce,
            &public_share,
            &input_share,
            &initialize_msg,
        )
        .unwrap();
    let (helper_state, helper_continue_msg) = helper_transition.evaluate(&helper).unwrap();
    assert_matches!(helper_state, PingPongState::Continued(_));

    // Round 1: the leader consumes the helper's message, finishes, and emits a Finish message.
    let leader_outcome = leader
        .leader_continued(leader_state, &aggregation_param, &helper_continue_msg)
        .unwrap();
    let finish_msg = assert_matches!(
        leader_outcome, PingPongContinuedValue::WithMessage { transition } => {
            let (leader_final_state, outbound) = transition.evaluate(&leader).unwrap();
            assert_matches!(leader_final_state, PingPongState::Finished(_));
            outbound
        }
    );

    // Round 2: the helper consumes the Finish message and finishes without replying.
    let helper_outcome = helper
        .helper_continued(helper_state, &aggregation_param, &finish_msg)
        .unwrap();
    assert_matches!(
        helper_outcome,
        PingPongContinuedValue::FinishedNoMessage { .. }
    );
}
781
782    #[test]
783    fn ping_pong_three_rounds() {
784        let verify_key = [];
785        let aggregation_param = dummy::AggregationParam(0);
786        let nonce = [0; 16];
787        #[allow(clippy::let_unit_value)]
788        let public_share = ();
789        let input_share = dummy::InputShare(0);
790
791        let leader = dummy::Vdaf::new(3);
792        let helper = dummy::Vdaf::new(3);
793
794        // Leader inits into round 0
795        let (leader_state, leader_message) = leader
796            .leader_initialized(
797                &verify_key,
798                &aggregation_param,
799                &nonce,
800                &public_share,
801                &input_share,
802            )
803            .unwrap();
804
805        // Helper inits into round 1
806        let (helper_state, helper_message) = helper
807            .helper_initialized(
808                &verify_key,
809                &aggregation_param,
810                &nonce,
811                &public_share,
812                &input_share,
813                &leader_message,
814            )
815            .unwrap()
816            .evaluate(&helper)
817            .unwrap();
818
819        // 3 round VDAF, round 1: helper should continue.
820        assert_matches!(helper_state, PingPongState::Continued(_));
821
822        let leader_state = leader
823            .leader_continued(leader_state, &aggregation_param, &helper_message)
824            .unwrap();
825        // 3 round VDAF, round 1: leader should continue and emit a continue message.
826        let (leader_state, leader_message) = assert_matches!(
827            leader_state, PingPongContinuedValue::WithMessage { transition } => {
828                let (state, message) = transition.evaluate(&leader).unwrap();
829                assert_matches!(state, PingPongState::Continued(_));
830                (state, message)
831            }
832        );
833
834        let helper_state = helper
835            .helper_continued(helper_state, &aggregation_param, &leader_message)
836            .unwrap();
837        // 3 round vdaf, round 2: helper should finish and emit a finish message.
838        let helper_message = assert_matches!(
839            helper_state, PingPongContinuedValue::WithMessage { transition } => {
840                let (state, message) = transition.evaluate(&helper).unwrap();
841                assert_matches!(state, PingPongState::Finished(_));
842                message
843            }
844        );
845
846        let leader_state = leader
847            .leader_continued(leader_state, &aggregation_param, &helper_message)
848            .unwrap();
849        // 3 round VDAF, round 2: leader should finish and emit no message.
850        assert_matches!(
851            leader_state,
852            PingPongContinuedValue::FinishedNoMessage { .. }
853        );
854    }
855
856    #[test]
857    fn roundtrip_message() {
858        let messages = [
859            (
860                PingPongMessage::Initialize {
861                    prep_share: Vec::from("prepare share"),
862                },
863                concat!(
864                    "00", // enum discriminant
865                    concat!(
866                        // prep_share
867                        "0000000d",                   // length
868                        "70726570617265207368617265", // contents
869                    ),
870                ),
871            ),
872            (
873                PingPongMessage::Continue {
874                    prep_msg: Vec::from("prepare message"),
875                    prep_share: Vec::from("prepare share"),
876                },
877                concat!(
878                    "01", // enum discriminant
879                    concat!(
880                        // prep_msg
881                        "0000000f",                       // length
882                        "70726570617265206d657373616765", // contents
883                    ),
884                    concat!(
885                        // prep_share
886                        "0000000d",                   // length
887                        "70726570617265207368617265", // contents
888                    ),
889                ),
890            ),
891            (
892                PingPongMessage::Finish {
893                    prep_msg: Vec::from("prepare message"),
894                },
895                concat!(
896                    "02", // enum discriminant
897                    concat!(
898                        // prep_msg
899                        "0000000f",                       // length
900                        "70726570617265206d657373616765", // contents
901                    ),
902                ),
903            ),
904        ];
905
906        for (message, expected_hex) in messages {
907            let mut encoded_val = Vec::new();
908            message.encode(&mut encoded_val).unwrap();
909            let got_hex = hex::encode(&encoded_val);
910            assert_eq!(
911                &got_hex, expected_hex,
912                "Couldn't roundtrip (encoded value differs): {message:?}",
913            );
914            let decoded_val = PingPongMessage::decode(&mut Cursor::new(&encoded_val)).unwrap();
915            assert_eq!(
916                decoded_val, message,
917                "Couldn't roundtrip (decoded value differs): {message:?}"
918            );
919            assert_eq!(
920                encoded_val.len(),
921                message.encoded_len().expect("No encoded length hint"),
922                "Encoded length hint is incorrect: {message:?}"
923            )
924        }
925    }
926
927    #[test]
928    fn roundtrip_transition() {
929        // VDAF implementations have tests for encoding/decoding their respective PrepareShare and
930        // PrepareMessage types, so we test here using the dummy VDAF.
931        let transition = PingPongTransition::<0, 16, dummy::Vdaf> {
932            previous_prepare_state: dummy::PrepareState::default(),
933            current_prepare_message: (),
934        };
935
936        let encoded = transition.get_encoded().unwrap();
937        let hex_encoded = hex::encode(&encoded);
938
939        assert_eq!(
940            hex_encoded,
941            concat!(
942                concat!(
943                    // previous_prepare_state
944                    "00",       // input_share
945                    "00000000", // current_round
946                ),
947                // current_prepare_message (0 length encoding)
948            )
949        );
950
951        let decoded = PingPongTransition::get_decoded_with_param(&(), &encoded).unwrap();
952        assert_eq!(transition, decoded);
953
954        assert_eq!(
955            encoded.len(),
956            transition.encoded_len().expect("No encoded length hint"),
957        );
958    }
959}