use alloc::vec::Vec;
use rand_core::CryptoRng;
use subtle::ConstantTimeEq;
use zeroize::{Zeroize, Zeroizing};
use pakery_core::crypto::CpaceGroup;
use pakery_core::SharedSecret;
use crate::ciphersuite::Spake2PlusCiphersuite;
use crate::encoding::build_transcript;
use crate::error::Spake2PlusError;
use crate::transcript::{derive_key_schedule, Spake2PlusOutput};
/// State retained by the verifier between [`Verifier::start`] and [`VerifierState::finish`].
///
/// Holds the prover confirmation value the verifier expects to receive and the
/// session key that is handed out only after that confirmation is checked.
pub struct VerifierState {
    /// Confirmation value the prover is expected to send (`confirm_p` from the key
    /// schedule); wiped in this type's `Drop` impl.
    expected_confirm_p: Vec<u8>,
    /// Session key released by `finish` once the prover's confirmation matches.
    session_key: SharedSecret,
}
impl Drop for VerifierState {
    /// Wipes the expected prover confirmation before its memory is released.
    ///
    /// `session_key` is not touched here; it is dropped by its own drop glue.
    fn drop(&mut self) {
        let Self {
            expected_confirm_p, ..
        } = self;
        expected_confirm_p.zeroize();
    }
}
impl VerifierState {
    /// Checks the prover's confirmation message and, on success, releases the session key.
    ///
    /// The comparison uses `subtle`'s `ConstantTimeEq`, so matching contents are not
    /// revealed through early exit on the first differing byte.
    ///
    /// # Errors
    ///
    /// Returns [`Spake2PlusError::ConfirmationFailed`] when `confirm_p` does not equal
    /// the expected confirmation value.
    pub fn finish(mut self, confirm_p: &[u8]) -> Result<Spake2PlusOutput, Spake2PlusError> {
        let confirmed: bool = self.expected_confirm_p.ct_eq(confirm_p).into();
        if confirmed {
            // `VerifierState` implements `Drop`, so the key cannot be moved out directly;
            // swap an empty secret in its place instead.
            let placeholder = SharedSecret::new(Vec::new());
            let session_key = core::mem::replace(&mut self.session_key, placeholder);
            Ok(Spake2PlusOutput { session_key })
        } else {
            Err(Spake2PlusError::ConfirmationFailed)
        }
    }
}
/// Verifier role of the SPAKE2+ exchange, parameterized by the ciphersuite `C`.
///
/// Zero-sized marker type: all functionality lives in its associated functions.
pub struct Verifier<C>(core::marker::PhantomData<C>)
where
    C: Spake2PlusCiphersuite;
impl<C: Spake2PlusCiphersuite> Verifier<C> {
    /// Starts the verifier side of the exchange with a freshly sampled ephemeral scalar.
    ///
    /// Samples `y` from `rng`, then decodes the prover's share, computes the verifier's
    /// share and derives the key schedule.
    ///
    /// # Returns
    ///
    /// `(share_v, confirm_v, state)`:
    /// - the encoded verifier share;
    /// - the verifier confirmation value (`confirm_v`) from the key schedule;
    /// - the state used to check the prover's confirmation via [`VerifierState::finish`].
    ///
    /// # Errors
    ///
    /// Returns an error if any point fails to decode, or if the prover's share or one
    /// of the derived points `Z`/`V` is the identity. Errors from key-schedule
    /// derivation are also propagated.
    pub fn start(
        share_p_bytes: &[u8],
        w0: &<C::Group as CpaceGroup>::Scalar,
        l_bytes: &[u8],
        context: &[u8],
        id_prover: &[u8],
        id_verifier: &[u8],
        rng: &mut impl CryptoRng,
    ) -> Result<(Vec<u8>, Vec<u8>, VerifierState), Spake2PlusError> {
        let y = C::Group::random_scalar(rng);
        Self::start_inner(
            share_p_bytes,
            w0,
            l_bytes,
            &y,
            context,
            id_prover,
            id_verifier,
        )
    }

    /// Deterministic variant of [`Verifier::start`] taking the ephemeral scalar `y`
    /// explicitly.
    ///
    /// Only available with the `test-utils` feature. Because `y` bypasses the RNG,
    /// callers must ensure it is never reused or predictable outside of testing.
    ///
    /// Returns and errors are the same as for [`Verifier::start`].
    #[cfg(feature = "test-utils")]
    pub fn start_with_scalar(
        share_p_bytes: &[u8],
        w0: &<C::Group as CpaceGroup>::Scalar,
        l_bytes: &[u8],
        y: &<C::Group as CpaceGroup>::Scalar,
        context: &[u8],
        id_prover: &[u8],
        id_verifier: &[u8],
    ) -> Result<(Vec<u8>, Vec<u8>, VerifierState), Spake2PlusError> {
        Self::start_inner(
            share_p_bytes,
            w0,
            l_bytes,
            y,
            context,
            id_prover,
            id_verifier,
        )
    }

    /// Shared core of `start` / `start_with_scalar` with the ephemeral scalar `y` given.
    ///
    /// Computes:
    /// - `share_v = y*G + w0*N`
    /// - `Z = y*(share_p - w0*M)`
    /// - `V = y*L`
    ///
    /// It then builds the transcript and derives the key schedule from it.
    fn start_inner(
        share_p_bytes: &[u8],
        w0: &<C::Group as CpaceGroup>::Scalar,
        l_bytes: &[u8],
        y: &<C::Group as CpaceGroup>::Scalar,
        context: &[u8],
        id_prover: &[u8],
        id_verifier: &[u8],
    ) -> Result<(Vec<u8>, Vec<u8>, VerifierState), Spake2PlusError> {
        // Decode the prover's share and refuse the identity element outright.
        let share_p = C::Group::from_bytes(share_p_bytes)?;
        if share_p.is_identity() {
            return Err(Spake2PlusError::IdentityPoint);
        }

        let m = C::Group::from_bytes(C::M_BYTES)?;
        let l = C::Group::from_bytes(l_bytes)?;
        let n = C::Group::from_bytes(C::N_BYTES)?;

        // Verifier share: y*G + w0*N.
        let y_g = C::Group::basepoint_mul(y);
        let w0_n = n.scalar_mul(w0);
        let share_v = y_g.add(&w0_n);
        let share_v_bytes = share_v.to_bytes();

        // Z = y*(share_p - w0*M) removes the prover's password blinding; V = y*L.
        let w0_m = m.scalar_mul(w0);
        let tmp = share_p.add(&w0_m.negate());
        let z = tmp.scalar_mul(y);
        let v = l.scalar_mul(y);
        if z.is_identity() {
            return Err(Spake2PlusError::IdentityPoint);
        }
        if v.is_identity() {
            return Err(Spake2PlusError::IdentityPoint);
        }

        // Secret encodings are wiped when they go out of scope.
        let z_bytes = Zeroizing::new(z.to_bytes());
        let v_bytes = Zeroizing::new(v.to_bytes());
        let w0_bytes = Zeroizing::new(C::Group::scalar_to_bytes(w0));
        let m_bytes = m.to_bytes();
        let n_bytes = n.to_bytes();

        // The transcript concatenates Z, V and w0, so it is exactly as sensitive as
        // those values and must be wiped as well rather than dropped as a plain buffer.
        let tt = Zeroizing::new(build_transcript(
            context,
            id_prover,
            id_verifier,
            &m_bytes,
            &n_bytes,
            share_p_bytes,
            &share_v_bytes,
            &z_bytes,
            &v_bytes,
            &w0_bytes,
        ));

        let mut ks = derive_key_schedule::<C>(&tt, share_p_bytes, &share_v_bytes)?;
        // Move secrets out of the key schedule (take/replace) instead of cloning them,
        // so no extra copies of the key material are created.
        let state = VerifierState {
            expected_confirm_p: core::mem::take(&mut ks.confirm_p),
            session_key: core::mem::replace(&mut ks.session_key, SharedSecret::new(Vec::new())),
        };
        Ok((share_v_bytes, core::mem::take(&mut ks.confirm_v), state))
    }
}