Skip to main content

pakery_spake2plus/
prover.rs

1//! SPAKE2+ Prover (client) state machine.
2//!
3//! The Prover knows the password and derives `(w0, w1)` from it.
4
5use alloc::vec::Vec;
6use rand_core::CryptoRngCore;
7use subtle::ConstantTimeEq;
8use zeroize::{Zeroize, ZeroizeOnDrop, Zeroizing};
9
10use pakery_core::crypto::CpaceGroup;
11use pakery_core::SharedSecret;
12
13use crate::ciphersuite::Spake2PlusCiphersuite;
14use crate::encoding::build_transcript;
15use crate::error::Spake2PlusError;
16use crate::transcript::derive_key_schedule;
17
18/// State held by the Prover between sending shareP and receiving (shareV, confirmV).
19#[derive(Zeroize, ZeroizeOnDrop)]
20pub struct ProverState<C: Spake2PlusCiphersuite> {
21    x: <C::Group as CpaceGroup>::Scalar,
22    w0: <C::Group as CpaceGroup>::Scalar,
23    w1: <C::Group as CpaceGroup>::Scalar,
24    share_p_bytes: Vec<u8>,
25    context: Vec<u8>,
26    id_prover: Vec<u8>,
27    id_verifier: Vec<u8>,
28    #[zeroize(skip)]
29    _marker: core::marker::PhantomData<C>,
30}
31
32/// Output returned by the Prover after verifying confirmV.
33#[derive(Zeroize, ZeroizeOnDrop)]
34pub struct ProverOutput {
35    /// The shared session key.
36    #[zeroize(skip)]
37    pub session_key: SharedSecret,
38    /// The Prover's confirmation MAC to send to the Verifier.
39    pub confirm_p: Vec<u8>,
40}
41
42/// SPAKE2+ Prover: generates the first message and processes the Verifier's response.
43pub struct Prover<C: Spake2PlusCiphersuite>(core::marker::PhantomData<C>);
44
45impl<C: Spake2PlusCiphersuite> Prover<C> {
46    /// Start the SPAKE2+ protocol as the Prover.
47    ///
48    /// `w0` and `w1` are the password-derived scalars. The caller is responsible
49    /// for password stretching.
50    ///
51    /// Returns `(shareP_bytes, state)` where `shareP_bytes` is sent to the Verifier.
52    pub fn start(
53        w0: &<C::Group as CpaceGroup>::Scalar,
54        w1: &<C::Group as CpaceGroup>::Scalar,
55        context: &[u8],
56        id_prover: &[u8],
57        id_verifier: &[u8],
58        rng: &mut impl CryptoRngCore,
59    ) -> Result<(Vec<u8>, ProverState<C>), Spake2PlusError> {
60        let x = C::Group::random_scalar(rng);
61        Self::start_inner(w0.clone(), w1.clone(), x, context, id_prover, id_verifier)
62    }
63
64    /// Start with a deterministic scalar (for testing).
65    ///
66    /// # Security
67    ///
68    /// Using a non-random scalar completely breaks security.
69    /// This method is gated behind the `test-utils` feature and must
70    /// only be used for RFC test vector validation.
71    #[cfg(feature = "test-utils")]
72    pub fn start_with_scalar(
73        w0: &<C::Group as CpaceGroup>::Scalar,
74        w1: &<C::Group as CpaceGroup>::Scalar,
75        x: &<C::Group as CpaceGroup>::Scalar,
76        context: &[u8],
77        id_prover: &[u8],
78        id_verifier: &[u8],
79    ) -> Result<(Vec<u8>, ProverState<C>), Spake2PlusError> {
80        Self::start_inner(
81            w0.clone(),
82            w1.clone(),
83            x.clone(),
84            context,
85            id_prover,
86            id_verifier,
87        )
88    }
89
90    fn start_inner(
91        w0: <C::Group as CpaceGroup>::Scalar,
92        w1: <C::Group as CpaceGroup>::Scalar,
93        x: <C::Group as CpaceGroup>::Scalar,
94        context: &[u8],
95        id_prover: &[u8],
96        id_verifier: &[u8],
97    ) -> Result<(Vec<u8>, ProverState<C>), Spake2PlusError> {
98        // Decode M from ciphersuite constants
99        let m = C::Group::from_bytes(C::M_BYTES)?;
100
101        // shareP = x*G + w0*M
102        let x_g = C::Group::basepoint_mul(&x);
103        let w0_m = m.scalar_mul(&w0);
104        let share_p = x_g.add(&w0_m);
105
106        let share_p_bytes = share_p.to_bytes();
107
108        let state = ProverState {
109            x,
110            w0,
111            w1,
112            share_p_bytes: share_p_bytes.clone(),
113            context: context.to_vec(),
114            id_prover: id_prover.to_vec(),
115            id_verifier: id_verifier.to_vec(),
116            _marker: core::marker::PhantomData,
117        };
118
119        Ok((share_p_bytes, state))
120    }
121}
122
123impl<C: Spake2PlusCiphersuite> ProverState<C> {
124    /// Finish the SPAKE2+ protocol by processing the Verifier's response.
125    ///
126    /// The Prover receives `(shareV_bytes, confirm_v)` from the Verifier,
127    /// verifies `confirm_v`, and returns `ProverOutput` containing the session
128    /// key and `confirm_p` to send back.
129    pub fn finish(
130        self,
131        share_v_bytes: &[u8],
132        confirm_v: &[u8],
133    ) -> Result<ProverOutput, Spake2PlusError> {
134        // Decode shareV and reject identity (defense-in-depth)
135        let share_v = C::Group::from_bytes(share_v_bytes)?;
136        if share_v.is_identity() {
137            return Err(Spake2PlusError::IdentityPoint);
138        }
139
140        // Decode N from ciphersuite constants
141        let n = C::Group::from_bytes(C::N_BYTES)?;
142
143        // tmp = shareV - w0*N (= y*G)
144        let w0_n = n.scalar_mul(&self.w0);
145        let tmp = share_v.add(&w0_n.negate());
146
147        // Z = x * tmp (= x*y*G, since cofactor h=1 for ristretto255)
148        let z = tmp.scalar_mul(&self.x);
149
150        // V = w1 * tmp (= w1*y*G)
151        let v = tmp.scalar_mul(&self.w1);
152
153        // Check Z != identity, V != identity
154        if z.is_identity() {
155            return Err(Spake2PlusError::IdentityPoint);
156        }
157        if v.is_identity() {
158            return Err(Spake2PlusError::IdentityPoint);
159        }
160
161        let z_bytes = Zeroizing::new(z.to_bytes());
162        let v_bytes = Zeroizing::new(v.to_bytes());
163        let w0_bytes = Zeroizing::new(C::Group::scalar_to_bytes(&self.w0));
164
165        // Decode M and N to get canonical group element encoding for transcript.
166        // This ensures M/N use the same encoding as other group elements (e.g.
167        // uncompressed SEC1 for P-256), regardless of how they are stored in the
168        // ciphersuite constants.
169        let m = C::Group::from_bytes(C::M_BYTES)?;
170        let n_point = C::Group::from_bytes(C::N_BYTES)?;
171        let m_bytes = m.to_bytes();
172        let n_bytes = n_point.to_bytes();
173
174        // Build transcript TT (10 fields)
175        let tt = build_transcript(
176            &self.context,
177            &self.id_prover,
178            &self.id_verifier,
179            &m_bytes,
180            &n_bytes,
181            &self.share_p_bytes,
182            share_v_bytes,
183            &z_bytes,
184            &v_bytes,
185            &w0_bytes,
186        );
187
188        // Derive key schedule
189        let mut ks = derive_key_schedule::<C>(&tt, &self.share_p_bytes, share_v_bytes)?;
190
191        // Verify confirmV: MAC(K_confirmV, shareP)
192        if !bool::from(ks.confirm_v.ct_eq(confirm_v)) {
193            return Err(Spake2PlusError::ConfirmationFailed);
194        }
195
196        Ok(ProverOutput {
197            session_key: core::mem::replace(&mut ks.session_key, SharedSecret::new(Vec::new())),
198            confirm_p: core::mem::take(&mut ks.confirm_p),
199        })
200    }
201}