Skip to main content

pakery_spake2plus/
prover.rs

//! SPAKE2+ Prover (client) state machine.
//!
//! The Prover knows the password. The caller derives the scalars `(w0, w1)` from it
//! (including any password stretching) and passes them to [`Prover::start`].
4
5use alloc::vec::Vec;
6use rand_core::CryptoRng;
7use subtle::ConstantTimeEq;
8use zeroize::{Zeroize, ZeroizeOnDrop, Zeroizing};
9
10use pakery_core::crypto::CpaceGroup;
11use pakery_core::SharedSecret;
12
13use crate::ciphersuite::Spake2PlusCiphersuite;
14use crate::encoding::build_transcript;
15use crate::error::Spake2PlusError;
16use crate::transcript::derive_key_schedule;
17
/// State held by the Prover between sending shareP and receiving (shareV, confirmV).
///
/// Holds the secret scalars `x`, `w0` and `w1`, and derives `Zeroize`/`ZeroizeOnDrop`.
/// [`ProverState::finish`] takes `self` by value. As a result the state is wiped
/// when that call returns, whether the exchange succeeded or failed.
#[derive(Zeroize, ZeroizeOnDrop)]
pub struct ProverState<C: Spake2PlusCiphersuite> {
    /// Ephemeral secret scalar `x`; shareP = x*G + w0*M, and Z = x * (shareV - w0*N).
    x: <C::Group as CpaceGroup>::Scalar,
    /// Password-derived scalar `w0`: blinds shareP, unblinds shareV, and is
    /// included (encoded) in the transcript.
    w0: <C::Group as CpaceGroup>::Scalar,
    /// Password-derived scalar `w1`; V = w1 * (shareV - w0*N).
    w1: <C::Group as CpaceGroup>::Scalar,
    /// Encoded shareP exactly as sent to the Verifier; reused in the transcript
    /// and the key schedule.
    share_p_bytes: Vec<u8>,
    /// Protocol context bytes, bound into the transcript.
    context: Vec<u8>,
    /// Prover identity bytes, bound into the transcript.
    id_prover: Vec<u8>,
    /// Verifier identity bytes, bound into the transcript.
    id_verifier: Vec<u8>,
    /// Ties the state to ciphersuite `C`; carries no data, so zeroization skips it.
    #[zeroize(skip)]
    _marker: core::marker::PhantomData<C>,
}
31
/// Output returned by the Prover after verifying confirmV.
///
/// Derives `ZeroizeOnDrop`, so callers cannot move fields out of it by
/// destructuring. Use [`ProverOutput::into_session_key`] /
/// [`ProverOutput::into_confirm_p`], or clone a field, instead.
#[derive(Zeroize, ZeroizeOnDrop)]
pub struct ProverOutput {
    /// The shared session key.
    // Excluded from the derived zeroization. NOTE(review): presumably
    // `SharedSecret` wipes its own contents on drop — confirm in pakery_core.
    #[zeroize(skip)]
    pub session_key: SharedSecret,
    /// The Prover's confirmation MAC to send to the Verifier.
    pub confirm_p: Vec<u8>,
}
41
42impl ProverOutput {
43    /// Consume the output and yield the session key.
44    ///
45    /// Because [`ProverOutput`] derives `ZeroizeOnDrop`, it cannot be
46    /// pattern-destructured by the caller. This consumer extracts the
47    /// session key cleanly without the boilerplate `mem::replace` shim
48    /// users would otherwise have to write themselves.
49    ///
50    /// Both fields remain `pub`, so to keep `confirm_p` while consuming
51    /// `session_key`, clone it first:
52    ///
53    /// ```ignore
54    /// let confirm_p = output.confirm_p.clone();
55    /// let session_key = output.into_session_key();
56    /// ```
57    #[must_use]
58    pub fn into_session_key(mut self) -> SharedSecret {
59        core::mem::replace(&mut self.session_key, SharedSecret::new(Vec::new()))
60    }
61
62    /// Consume the output and yield the `confirmP` MAC.
63    ///
64    /// Mirror of [`Self::into_session_key`]. To also keep `session_key`,
65    /// clone it first:
66    ///
67    /// ```ignore
68    /// let session_key = output.session_key.clone();
69    /// let confirm_p = output.into_confirm_p();
70    /// ```
71    #[must_use]
72    pub fn into_confirm_p(mut self) -> Vec<u8> {
73        core::mem::take(&mut self.confirm_p)
74    }
75}
76
/// SPAKE2+ Prover: generates the first message and processes the Verifier's response.
///
/// Carries only a `PhantomData<C>` selecting the ciphersuite; the protocol entry
/// points are associated functions. [`Prover::start`] produces shareP and a
/// [`ProverState`], and [`ProverState::finish`] completes the exchange.
pub struct Prover<C: Spake2PlusCiphersuite>(core::marker::PhantomData<C>);
79
80impl<C: Spake2PlusCiphersuite> Prover<C> {
81    /// Start the SPAKE2+ protocol as the Prover.
82    ///
83    /// `w0` and `w1` are the password-derived scalars. The caller is responsible
84    /// for password stretching.
85    ///
86    /// Returns `(shareP_bytes, state)` where `shareP_bytes` is sent to the Verifier.
87    pub fn start(
88        w0: &<C::Group as CpaceGroup>::Scalar,
89        w1: &<C::Group as CpaceGroup>::Scalar,
90        context: &[u8],
91        id_prover: &[u8],
92        id_verifier: &[u8],
93        rng: &mut impl CryptoRng,
94    ) -> Result<(Vec<u8>, ProverState<C>), Spake2PlusError> {
95        let x = C::Group::random_scalar(rng);
96        Self::start_inner(w0.clone(), w1.clone(), x, context, id_prover, id_verifier)
97    }
98
99    /// Start with a deterministic scalar (for testing).
100    ///
101    /// # Security
102    ///
103    /// Using a non-random scalar completely breaks security.
104    /// This method is gated behind the `test-utils` feature and must
105    /// only be used for RFC test vector validation.
106    #[cfg(feature = "test-utils")]
107    pub fn start_with_scalar(
108        w0: &<C::Group as CpaceGroup>::Scalar,
109        w1: &<C::Group as CpaceGroup>::Scalar,
110        x: &<C::Group as CpaceGroup>::Scalar,
111        context: &[u8],
112        id_prover: &[u8],
113        id_verifier: &[u8],
114    ) -> Result<(Vec<u8>, ProverState<C>), Spake2PlusError> {
115        Self::start_inner(
116            w0.clone(),
117            w1.clone(),
118            x.clone(),
119            context,
120            id_prover,
121            id_verifier,
122        )
123    }
124
125    fn start_inner(
126        w0: <C::Group as CpaceGroup>::Scalar,
127        w1: <C::Group as CpaceGroup>::Scalar,
128        x: <C::Group as CpaceGroup>::Scalar,
129        context: &[u8],
130        id_prover: &[u8],
131        id_verifier: &[u8],
132    ) -> Result<(Vec<u8>, ProverState<C>), Spake2PlusError> {
133        // Decode M from ciphersuite constants
134        let m = C::Group::from_bytes(C::M_BYTES)?;
135
136        // shareP = x*G + w0*M
137        let x_g = C::Group::basepoint_mul(&x);
138        let w0_m = m.scalar_mul(&w0);
139        let share_p = x_g.add(&w0_m);
140
141        let share_p_bytes = share_p.to_bytes();
142
143        let state = ProverState {
144            x,
145            w0,
146            w1,
147            share_p_bytes: share_p_bytes.clone(),
148            context: context.to_vec(),
149            id_prover: id_prover.to_vec(),
150            id_verifier: id_verifier.to_vec(),
151            _marker: core::marker::PhantomData,
152        };
153
154        Ok((share_p_bytes, state))
155    }
156}
157
158impl<C: Spake2PlusCiphersuite> ProverState<C> {
159    /// Finish the SPAKE2+ protocol by processing the Verifier's response.
160    ///
161    /// The Prover receives `(shareV_bytes, confirm_v)` from the Verifier,
162    /// verifies `confirm_v`, and returns `ProverOutput` containing the session
163    /// key and `confirm_p` to send back.
164    pub fn finish(
165        self,
166        share_v_bytes: &[u8],
167        confirm_v: &[u8],
168    ) -> Result<ProverOutput, Spake2PlusError> {
169        // Decode shareV and reject identity (defense-in-depth)
170        let share_v = C::Group::from_bytes(share_v_bytes)?;
171        if share_v.is_identity() {
172            return Err(Spake2PlusError::IdentityPoint);
173        }
174
175        // Decode N from ciphersuite constants
176        let n = C::Group::from_bytes(C::N_BYTES)?;
177
178        // tmp = shareV - w0*N (= y*G)
179        let w0_n = n.scalar_mul(&self.w0);
180        let tmp = share_v.add(&w0_n.negate());
181
182        // Z = x * tmp (= x*y*G, since cofactor h=1 for ristretto255)
183        let z = tmp.scalar_mul(&self.x);
184
185        // V = w1 * tmp (= w1*y*G)
186        let v = tmp.scalar_mul(&self.w1);
187
188        // Check Z != identity, V != identity
189        if z.is_identity() {
190            return Err(Spake2PlusError::IdentityPoint);
191        }
192        if v.is_identity() {
193            return Err(Spake2PlusError::IdentityPoint);
194        }
195
196        let z_bytes = Zeroizing::new(z.to_bytes());
197        let v_bytes = Zeroizing::new(v.to_bytes());
198        let w0_bytes = Zeroizing::new(C::Group::scalar_to_bytes(&self.w0));
199
200        // Decode M and N to get canonical group element encoding for transcript.
201        // This ensures M/N use the same encoding as other group elements (e.g.
202        // uncompressed SEC1 for P-256), regardless of how they are stored in the
203        // ciphersuite constants.
204        let m = C::Group::from_bytes(C::M_BYTES)?;
205        let n_point = C::Group::from_bytes(C::N_BYTES)?;
206        let m_bytes = m.to_bytes();
207        let n_bytes = n_point.to_bytes();
208
209        // Build transcript TT (10 fields)
210        let tt = build_transcript(
211            &self.context,
212            &self.id_prover,
213            &self.id_verifier,
214            &m_bytes,
215            &n_bytes,
216            &self.share_p_bytes,
217            share_v_bytes,
218            &z_bytes,
219            &v_bytes,
220            &w0_bytes,
221        );
222
223        // Derive key schedule
224        let mut ks = derive_key_schedule::<C>(&tt, &self.share_p_bytes, share_v_bytes)?;
225
226        // Verify confirmV: MAC(K_confirmV, shareP)
227        if !bool::from(ks.confirm_v.ct_eq(confirm_v)) {
228            return Err(Spake2PlusError::ConfirmationFailed);
229        }
230
231        Ok(ProverOutput {
232            session_key: core::mem::replace(&mut ks.session_key, SharedSecret::new(Vec::new())),
233            confirm_p: core::mem::take(&mut ks.confirm_p),
234        })
235    }
236}