use crate::registration::handshake::error::HandshakeError;
use crate::registration::handshake::messages::{
HandshakeMessage, Initialisation, MaterialExchange,
};
use crate::registration::handshake::{
HandshakeResult, SharedSymmetricKey, WsItem, KDF_SALT_LENGTH,
};
use crate::shared_key::SharedKeySize;
use crate::{types, GatewayProtocolVersion};
use futures::{Sink, SinkExt, Stream, StreamExt};
use nym_crypto::asymmetric::{ed25519, x25519};
use nym_crypto::symmetric::aead::random_nonce;
use nym_crypto::{generic_array::typenum::Unsigned, hkdf};
use nym_sphinx::params::{GatewayEncryptionAlgorithm, GatewaySharedKeyHkdfAlgorithm};
use rand::{thread_rng, CryptoRng, RngCore};
use std::any::{type_name, Any};
use std::str::FromStr;
use std::time::Duration;
use tracing::log::*;
use tungstenite::Message as WsMessage;
#[cfg(not(target_arch = "wasm32"))]
use nym_task::ShutdownToken;
#[cfg(not(target_arch = "wasm32"))]
use tokio::time::timeout;
#[cfg(target_arch = "wasm32")]
use wasmtimer::tokio::timeout;
pub(crate) struct State<'a, S, R> {
ws_stream: &'a mut S,
rng: &'a mut R,
identity: &'a ed25519::KeyPair,
ephemeral_keypair: x25519::KeyPair,
derived_shared_keys: Option<SharedSymmetricKey>,
remote_pubkey: Option<ed25519::PublicKey>,
protocol_version: GatewayProtocolVersion,
#[cfg(not(target_arch = "wasm32"))]
shutdown_token: ShutdownToken,
}
impl<'a, S, R> State<'a, S, R> {
pub(crate) fn new(
rng: &'a mut R,
ws_stream: &'a mut S,
identity: &'a ed25519::KeyPair,
remote_pubkey: Option<ed25519::PublicKey>,
protocol_version: GatewayProtocolVersion,
#[cfg(not(target_arch = "wasm32"))] shutdown_token: ShutdownToken,
) -> Self
where
R: CryptoRng + RngCore,
{
let ephemeral_keypair = x25519::KeyPair::new(rng);
State {
ws_stream,
rng,
ephemeral_keypair,
identity,
remote_pubkey,
protocol_version,
derived_shared_keys: None,
#[cfg(not(target_arch = "wasm32"))]
shutdown_token,
}
}
#[cfg(not(target_arch = "wasm32"))]
pub(crate) fn local_ephemeral_key(&self) -> &x25519::PublicKey {
self.ephemeral_keypair.public_key()
}
pub(crate) fn proposed_protocol_version(&self) -> GatewayProtocolVersion {
self.protocol_version
}
pub(crate) fn set_protocol_version(&mut self, protocol_version: GatewayProtocolVersion) {
self.protocol_version = protocol_version;
}
pub(crate) fn generate_initiator_salt(&mut self) -> Vec<u8>
where
R: CryptoRng + RngCore,
{
let mut salt = vec![0u8; KDF_SALT_LENGTH];
self.rng.fill_bytes(&mut salt);
salt
}
pub(crate) fn init_message(&self, initiator_salt: Vec<u8>) -> Initialisation {
Initialisation {
identity: *self.identity.public_key(),
ephemeral_dh: *self.ephemeral_keypair.public_key(),
initiator_salt,
}
}
#[cfg(not(target_arch = "wasm32"))]
pub(crate) fn finalization_message(
&self,
) -> crate::registration::handshake::messages::Finalization {
crate::registration::handshake::messages::Finalization { success: true }
}
pub(crate) fn derive_shared_key(
&mut self,
remote_ephemeral_key: &x25519::PublicKey,
initiator_salt: &[u8],
) {
let dh_result = self
.ephemeral_keypair
.private_key()
.diffie_hellman(remote_ephemeral_key);
let key_size = SharedKeySize::to_usize();
#[allow(clippy::expect_used)]
let okm = hkdf::extract_then_expand::<GatewaySharedKeyHkdfAlgorithm>(
Some(initiator_salt),
&dh_result,
None,
key_size,
)
.expect("somehow too long okm was provided");
#[allow(clippy::expect_used)]
let shared_key = SharedSymmetricKey::try_from_bytes(&okm)
.expect("okm was expanded to incorrect length!");
self.derived_shared_keys = Some(shared_key)
}
pub(crate) fn prepare_key_material_sig(
&self,
remote_ephemeral_key: &x25519::PublicKey,
) -> Result<MaterialExchange, HandshakeError> {
let plaintext: Vec<_> = self
.ephemeral_keypair
.public_key()
.to_bytes()
.into_iter()
.chain(remote_ephemeral_key.to_bytes())
.collect();
let signature = self.identity.private_key().sign(plaintext);
let mut rng = thread_rng();
let nonce = random_nonce::<GatewayEncryptionAlgorithm, _>(&mut rng);
#[allow(clippy::expect_used)]
let signature_ciphertext = self
.derived_shared_keys
.as_ref()
.expect("shared key was not derived!")
.encrypt(&signature.to_bytes(), &nonce)?;
Ok(MaterialExchange {
signature_ciphertext,
nonce,
})
}
pub(crate) fn verify_remote_key_material(
&self,
remote_response: &MaterialExchange,
remote_ephemeral_key: &x25519::PublicKey,
) -> Result<(), HandshakeError> {
#[allow(clippy::expect_used)]
let derived_shared_key = self
.derived_shared_keys
.as_ref()
.expect("shared key was not derived!");
let decrypted_signature = derived_shared_key.decrypt(
&remote_response.signature_ciphertext,
&remote_response.nonce,
)?;
let signature = ed25519::Signature::from_bytes(&decrypted_signature)
.map_err(|_| HandshakeError::InvalidSignature)?;
let signed_payload: Vec<_> = remote_ephemeral_key
.to_bytes()
.into_iter()
.chain(self.ephemeral_keypair.public_key().to_bytes())
.collect();
#[allow(clippy::unwrap_used)]
self.remote_pubkey
.as_ref()
.unwrap()
.verify(signed_payload, &signature)
.map_err(|_| HandshakeError::InvalidSignature)
}
#[cfg(not(target_arch = "wasm32"))]
pub(crate) fn update_remote_identity(&mut self, remote_pubkey: ed25519::PublicKey) {
self.remote_pubkey = Some(remote_pubkey)
}
#[allow(clippy::complexity)]
fn on_wg_msg(
msg: Option<WsItem>,
) -> Result<Option<(Vec<u8>, GatewayProtocolVersion)>, HandshakeError> {
let Some(msg) = msg else {
return Err(HandshakeError::ClosedStream);
};
let Ok(msg) = msg else {
return Err(HandshakeError::NetworkError);
};
match msg {
WsMessage::Text(ref ws_msg) => {
match types::RegistrationHandshake::from_str(ws_msg) {
Ok(reg_handshake_msg) => {
match reg_handshake_msg {
types::RegistrationHandshake::HandshakePayload {
protocol_version,
data,
} => Ok(Some((data, protocol_version))),
types::RegistrationHandshake::HandshakeError { message } => {
Err(HandshakeError::RemoteError(message))
}
}
}
Err(_) => {
error!("Received a non-handshake message during the registration handshake! It's getting dropped. The received content was: '{msg}'");
Ok(None)
}
}
}
_ => {
error!("Received non-text message during registration handshake");
Ok(None)
}
}
}
#[cfg(not(target_arch = "wasm32"))]
async fn _receive_handshake_message_bytes(
&mut self,
) -> Result<(Vec<u8>, GatewayProtocolVersion), HandshakeError>
where
S: Stream<Item = WsItem> + Unpin,
{
loop {
tokio::select! {
biased;
_ = self.shutdown_token.cancelled() => return Err(HandshakeError::ReceivedShutdown),
msg = self.ws_stream.next() => {
let Some(ret) = Self::on_wg_msg(msg)? else {
continue;
};
return Ok(ret);
}
}
}
}
#[cfg(target_arch = "wasm32")]
async fn _receive_handshake_message_bytes(
&mut self,
) -> Result<(Vec<u8>, GatewayProtocolVersion), HandshakeError>
where
S: Stream<Item = WsItem> + Unpin,
{
loop {
let msg = self.ws_stream.next().await;
let Some(ret) = Self::on_wg_msg(msg)? else {
continue;
};
return Ok(ret);
}
}
pub(crate) async fn receive_handshake_message<M>(
&mut self,
) -> Result<(M, GatewayProtocolVersion), HandshakeError>
where
S: Stream<Item = WsItem> + Unpin,
M: HandshakeMessage,
{
let (bytes, protocol) = timeout(
Duration::from_secs(5),
self._receive_handshake_message_bytes(),
)
.await
.map_err(|_| HandshakeError::Timeout)??;
M::try_from_bytes(&bytes).map(|msg| (msg, protocol))
}
pub(crate) async fn send_handshake_error<M: Into<String>>(
&mut self,
message: M,
) -> Result<(), HandshakeError>
where
S: Sink<WsMessage> + Unpin,
{
let handshake_message = types::RegistrationHandshake::new_error(message);
self.ws_stream
.send(WsMessage::Text(handshake_message.into()))
.await
.map_err(|_| HandshakeError::ClosedStream)
}
pub(crate) async fn send_handshake_data<M>(
&mut self,
inner_message: M,
) -> Result<(), HandshakeError>
where
S: Sink<WsMessage> + Unpin,
M: HandshakeMessage + Any,
{
trace!("sending handshake message: {}", type_name::<M>());
let handshake_message = types::RegistrationHandshake::new_payload(
inner_message.into_bytes(),
self.protocol_version,
);
self.ws_stream
.send(WsMessage::Text(handshake_message.into()))
.await
.map_err(|_| HandshakeError::ClosedStream)
}
pub(crate) fn finalize_handshake(self) -> HandshakeResult {
#[allow(clippy::unwrap_used)]
HandshakeResult {
negotiated_protocol: self.proposed_protocol_version(),
derived_key: self.derived_shared_keys.unwrap(),
}
}
pub(crate) async fn check_for_handshake_processing_error<T>(
&mut self,
result: Result<T, HandshakeError>,
) -> Result<T, HandshakeError>
where
S: Sink<WsMessage> + Unpin,
{
match result {
Ok(ok) => Ok(ok),
Err(err) => {
self.send_handshake_error(err.to_string()).await?;
Err(err)
}
}
}
}