use alloc::vec::Vec;
use rand_core::CryptoRng;
use subtle::ConstantTimeEq;
use zeroize::{Zeroize, ZeroizeOnDrop, Zeroizing};
use pakery_core::crypto::CpaceGroup;
use pakery_core::SharedSecret;
use crate::ciphersuite::Spake2PlusCiphersuite;
use crate::encoding::build_transcript;
use crate::error::Spake2PlusError;
use crate::transcript::derive_key_schedule;
/// Prover-side state carried from [`Prover::start`] to [`ProverState::finish`].
///
/// It keeps the scalars and byte strings that `finish` reads when it unblinds
/// the verifier share and builds the transcript. The derived `ZeroizeOnDrop`
/// wipes every field on drop except the zero-sized marker.
#[derive(Zeroize, ZeroizeOnDrop)]
pub struct ProverState<C: Spake2PlusCiphersuite> {
/// Scalar `x` behind the `x*G` term of the prover share; sampled from the RNG
/// in `Prover::start`.
x: <C::Group as CpaceGroup>::Scalar,
/// Caller-supplied scalar `w0`; multiplies `M` in `start` and `N` in `finish`.
w0: <C::Group as CpaceGroup>::Scalar,
/// Caller-supplied scalar `w1`; used in `finish` to compute `V`.
w1: <C::Group as CpaceGroup>::Scalar,
/// Encoded prover share, identical to the bytes returned by `start`; reused by
/// `finish` for the transcript and the key schedule.
share_p_bytes: Vec<u8>,
/// Copy of the caller's context bytes, fed into the transcript.
context: Vec<u8>,
/// Copy of the prover identity bytes, fed into the transcript.
id_prover: Vec<u8>,
/// Copy of the verifier identity bytes, fed into the transcript.
id_verifier: Vec<u8>,
/// Binds the state to ciphersuite `C` without storing a value; skipped by the
/// derive because there is nothing to wipe.
#[zeroize(skip)]
_marker: core::marker::PhantomData<C>,
}
/// Values produced by a successful [`ProverState::finish`].
///
/// The derived `ZeroizeOnDrop` wipes `confirm_p` when the value is dropped.
#[derive(Zeroize, ZeroizeOnDrop)]
pub struct ProverOutput {
/// Session key taken from the key schedule.
///
/// Excluded from the derived wiping via `#[zeroize(skip)]`.
/// NOTE(review): presumably `SharedSecret` wipes itself on drop — confirm in
/// `pakery_core`.
#[zeroize(skip)]
pub session_key: SharedSecret,
/// Prover confirmation value (`confirm_p`) taken from the key schedule.
pub confirm_p: Vec<u8>,
}
impl ProverOutput {
    /// Consumes the output and hands back the session key.
    ///
    /// `ProverOutput` derives `ZeroizeOnDrop`, so its fields cannot be moved
    /// out directly; the key is exchanged with an empty placeholder secret
    /// that is dropped together with the rest of `self`.
    #[must_use]
    pub fn into_session_key(mut self) -> SharedSecret {
        let mut key = SharedSecret::new(Vec::new());
        core::mem::swap(&mut key, &mut self.session_key);
        key
    }

    /// Consumes the output and hands back the prover confirmation bytes.
    ///
    /// The field is exchanged with an empty vector, which is what remains in
    /// `self` when it is dropped.
    #[must_use]
    pub fn into_confirm_p(mut self) -> Vec<u8> {
        let mut confirmation = Vec::new();
        core::mem::swap(&mut confirmation, &mut self.confirm_p);
        confirmation
    }
}
/// Entry point for the prover side of the exchange.
///
/// A zero-sized handle parameterised by the ciphersuite `C`; it stores no data
/// and exists only to host the associated `start` functions.
pub struct Prover<C: Spake2PlusCiphersuite>(core::marker::PhantomData<C>);
impl<C: Spake2PlusCiphersuite> Prover<C> {
    /// Begins a prover run using a scalar freshly drawn from `rng`.
    ///
    /// `w0` and `w1` are cloned into the returned state, and `context`,
    /// `id_prover` and `id_verifier` are copied into it.
    ///
    /// # Returns
    ///
    /// The encoded prover share together with the [`ProverState`] needed to
    /// call `finish`.
    ///
    /// # Errors
    ///
    /// Propagates the error from decoding `C::M_BYTES` as a group element.
    pub fn start(
        w0: &<C::Group as CpaceGroup>::Scalar,
        w1: &<C::Group as CpaceGroup>::Scalar,
        context: &[u8],
        id_prover: &[u8],
        id_verifier: &[u8],
        rng: &mut impl CryptoRng,
    ) -> Result<(Vec<u8>, ProverState<C>), Spake2PlusError> {
        let ephemeral = C::Group::random_scalar(rng);
        let (w0_owned, w1_owned) = (w0.clone(), w1.clone());
        Self::start_inner(
            w0_owned,
            w1_owned,
            ephemeral,
            context,
            id_prover,
            id_verifier,
        )
    }

    /// Begins a prover run with a caller-chosen scalar `x` instead of a random one.
    ///
    /// Only compiled when the `test-utils` feature is enabled.
    ///
    /// # Errors
    ///
    /// Propagates the error from decoding `C::M_BYTES` as a group element.
    #[cfg(feature = "test-utils")]
    pub fn start_with_scalar(
        w0: &<C::Group as CpaceGroup>::Scalar,
        w1: &<C::Group as CpaceGroup>::Scalar,
        x: &<C::Group as CpaceGroup>::Scalar,
        context: &[u8],
        id_prover: &[u8],
        id_verifier: &[u8],
    ) -> Result<(Vec<u8>, ProverState<C>), Spake2PlusError> {
        let (w0_owned, w1_owned, fixed_x) = (w0.clone(), w1.clone(), x.clone());
        Self::start_inner(w0_owned, w1_owned, fixed_x, context, id_prover, id_verifier)
    }

    /// Shared implementation behind the public `start*` constructors.
    ///
    /// Computes the prover share `x*G + w0*M`, encodes it, and packs every
    /// input into a [`ProverState`] alongside a copy of the encoded share.
    fn start_inner(
        w0: <C::Group as CpaceGroup>::Scalar,
        w1: <C::Group as CpaceGroup>::Scalar,
        x: <C::Group as CpaceGroup>::Scalar,
        context: &[u8],
        id_prover: &[u8],
        id_verifier: &[u8],
    ) -> Result<(Vec<u8>, ProverState<C>), Spake2PlusError> {
        let point_m = C::Group::from_bytes(C::M_BYTES)?;

        // shareP = x*G + w0*M, encoded straight to bytes.
        let encoded_share = C::Group::basepoint_mul(&x)
            .add(&point_m.scalar_mul(&w0))
            .to_bytes();

        let state = ProverState {
            x,
            w0,
            w1,
            // The state keeps its own copy; the original buffer goes to the caller.
            share_p_bytes: encoded_share.clone(),
            context: context.to_vec(),
            id_prover: id_prover.to_vec(),
            id_verifier: id_verifier.to_vec(),
            _marker: core::marker::PhantomData,
        };

        Ok((encoded_share, state))
    }
}
impl<C: Spake2PlusCiphersuite> ProverState<C> {
    /// Completes the prover side using the verifier's share and confirmation.
    ///
    /// Steps, in order:
    /// 1. decode `share_v_bytes` and reject the identity element;
    /// 2. compute `tmp = shareV - w0*N`, then `Z = x*tmp` and `V = w1*tmp`,
    ///    rejecting either if it is the identity;
    /// 3. build the transcript and derive the key schedule;
    /// 4. compare the expected verifier confirmation with `confirm_v` via
    ///    `ct_eq`.
    ///
    /// # Returns
    ///
    /// A [`ProverOutput`] holding the session key and the prover confirmation
    /// taken from the key schedule.
    ///
    /// # Errors
    ///
    /// * the error from `C::Group::from_bytes` if `share_v_bytes`,
    ///   `C::N_BYTES` or `C::M_BYTES` cannot be decoded;
    /// * [`Spake2PlusError::IdentityPoint`] if the verifier share, `Z` or `V`
    ///   is the identity element;
    /// * any error returned by `derive_key_schedule`;
    /// * [`Spake2PlusError::ConfirmationFailed`] if `confirm_v` does not match
    ///   the value from the key schedule.
    pub fn finish(
        self,
        share_v_bytes: &[u8],
        confirm_v: &[u8],
    ) -> Result<ProverOutput, Spake2PlusError> {
        let share_v = C::Group::from_bytes(share_v_bytes)?;
        if share_v.is_identity() {
            return Err(Spake2PlusError::IdentityPoint);
        }

        // N is decoded exactly once and reused below for the transcript.
        let n = C::Group::from_bytes(C::N_BYTES)?;

        // tmp = shareV - w0*N
        let w0_n = n.scalar_mul(&self.w0);
        let tmp = share_v.add(&w0_n.negate());

        // Z = x*tmp, V = w1*tmp
        let z = tmp.scalar_mul(&self.x);
        let v = tmp.scalar_mul(&self.w1);
        if z.is_identity() {
            return Err(Spake2PlusError::IdentityPoint);
        }
        if v.is_identity() {
            return Err(Spake2PlusError::IdentityPoint);
        }

        // Encodings of Z, V and w0 are wrapped so they are wiped when they go
        // out of scope, on every return path.
        let z_bytes = Zeroizing::new(z.to_bytes());
        let v_bytes = Zeroizing::new(v.to_bytes());
        let w0_bytes = Zeroizing::new(C::Group::scalar_to_bytes(&self.w0));

        // M and N enter the transcript in the group's own encoding.
        let m = C::Group::from_bytes(C::M_BYTES)?;
        let m_bytes = m.to_bytes();
        let n_bytes = n.to_bytes();

        // NOTE(review): `tt` embeds Z, V and w0. Whether `build_transcript`
        // returns a self-wiping buffer is not visible from this file — confirm.
        let tt = build_transcript(
            &self.context,
            &self.id_prover,
            &self.id_verifier,
            &m_bytes,
            &n_bytes,
            &self.share_p_bytes,
            share_v_bytes,
            &z_bytes,
            &v_bytes,
            &w0_bytes,
        );

        let mut ks = derive_key_schedule::<C>(&tt, &self.share_p_bytes, share_v_bytes)?;

        // Compare through `ct_eq` rather than `==`.
        if !bool::from(ks.confirm_v.ct_eq(confirm_v)) {
            return Err(Spake2PlusError::ConfirmationFailed);
        }

        // The outputs are swapped out of `ks` with `replace`/`take` rather than
        // moved, leaving empty placeholders behind for when `ks` is dropped.
        Ok(ProverOutput {
            session_key: core::mem::replace(&mut ks.session_key, SharedSecret::new(Vec::new())),
            confirm_p: core::mem::take(&mut ks.confirm_p),
        })
    }
}