use super::packet::{AuthHeader, AuthResponse, AuthTag, Nonce, Packet, Tag, MAGIC_LENGTH};
use crate::Discv5Error;
use enr::{CombinedKey, Enr, NodeId};
use log::debug;
use sha2::{Digest, Sha256};
use std::net::SocketAddr;
use zeroize::Zeroize;
mod crypto;
mod ecdh_ident;
/// ASCII suffix hashed after a node id's raw bytes to produce the `magic` field of a
/// WHOAREYOU packet (see `Session::new_whoareyou`).
const WHOAREYOU_STRING: &str = "WHOAREYOU";
/// Per-remote-node session: handshake/encryption state, trust status, the remote's ENR (if
/// known) and the socket address last recorded for the remote.
pub(crate) struct Session {
    /// Handshake progress and, once derived, the symmetric session keys.
    state: SessionState,
    /// Result of the last comparison between `last_seen_socket` and the ENR's UDP socket.
    trusted: TrustedState,
    /// The remote's ENR; `None` when it is not yet known (possible for WHOAREYOU sessions).
    remote_enr: Option<Enr<CombinedKey>>,
    /// Address recorded via `set_last_seen_socket`; `0.0.0.0:0` until first set.
    last_seen_socket: SocketAddr,
}
#[derive(Zeroize, PartialEq)]
pub(crate) struct Keys {
auth_resp_key: [u8; 16],
encryption_key: [u8; 16],
decryption_key: [u8; 16],
}
/// Whether the remote is trusted, as maintained by `Session::update_trusted`.
enum TrustedState {
    /// `last_seen_socket` matched the UDP socket in the remote's ENR when last checked.
    Trusted,
    /// No match has been observed yet, or the last check found a mismatch.
    Untrusted,
}
/// Handshake/encryption state of a `Session`.
///
/// `PartialEq` is derived so that `is_whoareyou_sent` / `is_random_sent` can compare against
/// the unit variants.
#[derive(PartialEq)]
pub(crate) enum SessionState {
    /// A WHOAREYOU packet was generated (`Session::new_whoareyou`); no keys exist yet.
    WhoAreYouSent,
    /// A random packet was generated (`Session::new_random`); no keys exist yet.
    RandomSent,
    /// Keys were generated by `encrypt_with_header` from a non-established state; promoted to
    /// `Established` on the first message these keys decrypt.
    AwaitingResponse(Keys),
    /// The session was established and a new handshake produced `new_keys`. Outgoing messages
    /// still use `current_keys`; incoming messages try `current_keys` first, then `new_keys`.
    EstablishedAwaitingResponse {
        current_keys: Keys,
        new_keys: Keys,
    },
    /// Keys are in use for both encryption and decryption.
    Established(Keys),
    /// Placeholder stored while the real state is temporarily moved out with
    /// `std::mem::replace`; it is never expected to be observed.
    Poisoned,
}
impl Session {
    /// Creates an untrusted session in the `RandomSent` state for `remote_enr`.
    ///
    /// Returns the session together with the random packet built from `tag` by
    /// `Packet::random`. The last-seen socket starts as `0.0.0.0:0`.
    pub(crate) fn new_random(tag: Tag, remote_enr: Enr<CombinedKey>) -> (Self, Packet) {
        let random_packet = Packet::random(tag);
        let session = Session {
            state: SessionState::RandomSent,
            trusted: TrustedState::Untrusted,
            remote_enr: Some(remote_enr),
            // Placeholder until `set_last_seen_socket` records a real address.
            last_seen_socket: SocketAddr::from(([0, 0, 0, 0], 0)),
        };
        (session, random_packet)
    }

    /// Creates an untrusted session in the `WhoAreYouSent` state plus its WHOAREYOU packet.
    ///
    /// The packet's `magic` is `SHA256(node_id.raw() || "WHOAREYOU")`, its `token` is
    /// `auth_tag`, its `id_nonce` is freshly random, and `enr_seq` is copied in unchanged.
    /// `remote_enr` may be `None` when the remote's ENR is not yet known.
    pub(crate) fn new_whoareyou(
        node_id: &NodeId,
        enr_seq: u64,
        remote_enr: Option<Enr<CombinedKey>>,
        auth_tag: AuthTag,
    ) -> (Self, Packet) {
        let magic = {
            let mut hasher = Sha256::new();
            hasher.input(node_id.raw());
            hasher.input(WHOAREYOU_STRING.as_bytes());
            let mut magic = [0u8; MAGIC_LENGTH];
            magic.copy_from_slice(&hasher.result());
            magic
        };
        let id_nonce: Nonce = rand::random();
        let whoareyou_packet = Packet::WhoAreYou {
            magic,
            token: auth_tag,
            id_nonce,
            enr_seq,
        };

        let session = Session {
            state: SessionState::WhoAreYouSent,
            trusted: TrustedState::Untrusted,
            remote_enr,
            // Placeholder until `set_last_seen_socket` records a real address.
            last_seen_socket: SocketAddr::from(([0, 0, 0, 0], 0)),
        };
        (session, whoareyou_packet)
    }

    /// Establishes the session from a received authentication header.
    ///
    /// Derives keys from `local_key` and the header's ephemeral public key. Then it:
    /// - decrypts the auth response;
    /// - adopts the ENR it carries if that is newer than the stored one;
    /// - verifies the `id_nonce` signature against the remote's ENR public key;
    /// - moves the session to `Established`.
    ///
    /// Returns `Ok(true)` if this call made the session trusted (see `update_trusted`).
    ///
    /// # Errors
    ///
    /// - Key derivation / header decryption errors are propagated.
    /// - `InvalidEnr` if the carried ENR's node id differs from `remote_id`, or if no ENR is
    ///   available to verify against.
    /// - `InvalidSignature` if the nonce signature does not verify.
    pub(crate) fn establish_from_header(
        &mut self,
        local_key: &CombinedKey,
        local_id: &NodeId,
        remote_id: &NodeId,
        id_nonce: Nonce,
        auth_header: &AuthHeader,
    ) -> Result<bool, Discv5Error> {
        // Tuple order from this helper: (decryption, encryption, auth-response) keys.
        let (decryption_key, encryption_key, auth_resp_key) = crypto::derive_keys_from_pubkey(
            local_key,
            local_id,
            remote_id,
            &id_nonce,
            &auth_header.ephemeral_pubkey,
        )?;

        let auth_response = crypto::decrypt_authentication_header(&auth_resp_key, auth_header)?;

        if let Some(enr) = auth_response.node_record {
            // An ENR for a different node would let a peer sign with its own key while
            // claiming to be `remote_id`; reject it outright.
            if enr.node_id() != *remote_id {
                return Err(Discv5Error::InvalidEnr);
            }
            let is_newer = match &self.remote_enr {
                Some(known) => known.seq() < enr.seq(),
                None => true,
            };
            if is_newer {
                self.remote_enr = Some(enr);
            }
        }

        // Without an ENR there is no public key to verify the signature against.
        let remote_public_key = match &self.remote_enr {
            Some(enr) => enr.public_key(),
            None => return Err(Discv5Error::InvalidEnr),
        };

        if !crypto::verify_authentication_nonce(
            &remote_public_key,
            &auth_header.ephemeral_pubkey,
            &id_nonce,
            &auth_response.signature,
        ) {
            return Err(Discv5Error::InvalidSignature);
        }

        self.state = SessionState::Established(Keys {
            auth_resp_key,
            encryption_key,
            decryption_key,
        });
        Ok(self.update_trusted())
    }

    /// Generates fresh session keys and builds an `AuthMessage` packet for `message`.
    ///
    /// The auth header carries:
    /// - a signature over `id_nonce` and the ephemeral public key;
    /// - the optional `updated_enr`, encrypted with the auth-response key.
    ///
    /// The message body is encrypted with the new encryption key, using `tag` as associated
    /// data.
    ///
    /// State transition:
    /// - `Established` becomes `EstablishedAwaitingResponse` (the old keys stay in use);
    /// - any other state becomes `AwaitingResponse`.
    ///
    /// # Errors
    ///
    /// Key generation, signing and encryption failures.
    ///
    /// # Panics
    ///
    /// If the remote ENR is unknown, or (a coding error) the state is `Poisoned`.
    pub(crate) fn encrypt_with_header(
        &mut self,
        tag: Tag,
        local_key: &CombinedKey,
        updated_enr: Option<Enr<CombinedKey>>,
        local_node_id: &NodeId,
        id_nonce: &Nonce,
        message: &[u8],
    ) -> Result<Packet, Discv5Error> {
        // Tuple order from this helper: (encryption, decryption, auth-response, ephemeral pubkey).
        let (encryption_key, decryption_key, auth_resp_key, ephem_pubkey) =
            crypto::generate_session_keys(
                local_node_id,
                self.remote_enr
                    .as_ref()
                    .expect("Should never be None at this point"),
                id_nonce,
            )?;

        let sig = crypto::sign_nonce(local_key, id_nonce, &ephem_pubkey)
            .map_err(|_| Discv5Error::Custom("Could not sign WHOAREYOU nonce"))?;
        let auth_pt = AuthResponse::new(&sig, updated_enr).encode();
        // NOTE(review): a fixed all-zero nonce is used here; presumably acceptable only because
        // `auth_resp_key` is freshly generated for this single encryption — confirm.
        let auth_response_ciphertext =
            crypto::encrypt_message(&auth_resp_key, [0u8; 12], &auth_pt, &[])?;

        let auth_tag: [u8; 12] = rand::random();
        let auth_header = AuthHeader::new(
            auth_tag,
            *id_nonce,
            ephem_pubkey.to_vec(),
            auth_response_ciphertext,
        );
        let message_ciphertext =
            crypto::encrypt_message(&encryption_key, auth_tag, message, &tag[..])?;

        let keys = Keys {
            auth_resp_key,
            encryption_key,
            decryption_key,
        };
        // An established session keeps its current keys until the remote is seen using the
        // new ones (see `decrypt_message`).
        self.state = match std::mem::replace(&mut self.state, SessionState::Poisoned) {
            SessionState::Established(current_keys) => SessionState::EstablishedAwaitingResponse {
                current_keys,
                new_keys: keys,
            },
            SessionState::Poisoned => unreachable!("Coding error if this is possible"),
            _ => SessionState::AwaitingResponse(keys),
        };

        Ok(Packet::AuthMessage {
            tag,
            auth_header,
            message: message_ciphertext,
        })
    }

    /// Encrypts `message` with the current session keys into a `Message` packet.
    ///
    /// A random auth tag is used as the nonce and `tag` as associated data. During a
    /// re-handshake (`EstablishedAwaitingResponse`) the current, not the pending, keys are used.
    ///
    /// # Errors
    ///
    /// - `SessionNotEstablished` when no established keys exist.
    /// - Encryption failures are propagated.
    pub(crate) fn encrypt_message(&self, tag: Tag, message: &[u8]) -> Result<Packet, Discv5Error> {
        let keys = match &self.state {
            SessionState::Established(keys) => keys,
            SessionState::EstablishedAwaitingResponse { current_keys, .. } => current_keys,
            _ => return Err(Discv5Error::SessionNotEstablished),
        };
        let auth_tag: AuthTag = rand::random();
        let cipher = crypto::encrypt_message(&keys.encryption_key, auth_tag, message, &tag)?;
        Ok(Packet::Message {
            tag,
            auth_tag,
            message: cipher,
        })
    }

    /// Decrypts `message` using `nonce` and associated data `aad`, advancing the state.
    ///
    /// Behaviour per state:
    /// - `Established`: decrypt with its keys; the state is unchanged.
    /// - `EstablishedAwaitingResponse`:
    ///   - try the current keys first; on success the pending keys are dropped;
    ///   - otherwise try the new keys; on success they become the established keys;
    ///   - if both fail, the state is unchanged.
    /// - `AwaitingResponse`: on success the session becomes `Established`.
    ///
    /// # Errors
    ///
    /// - `SessionNotEstablished` in the `WhoAreYouSent` / `RandomSent` states.
    /// - Otherwise, the decryption error.
    ///
    /// # Panics
    ///
    /// Only if the state is `Poisoned`, which indicates a coding error.
    pub(crate) fn decrypt_message(
        &mut self,
        nonce: AuthTag,
        message: &[u8],
        aad: &[u8],
    ) -> Result<Vec<u8>, Discv5Error> {
        // The state is moved out so its keys can be moved into the next state; every arm
        // below must store a real state back before returning.
        match std::mem::replace(&mut self.state, SessionState::Poisoned) {
            SessionState::Established(keys) => {
                let result = crypto::decrypt_message(&keys.decryption_key, nonce, message, aad);
                self.state = SessionState::Established(keys);
                result
            }
            SessionState::EstablishedAwaitingResponse {
                current_keys,
                new_keys,
            } => {
                if let Ok(plaintext) =
                    crypto::decrypt_message(&current_keys.decryption_key, nonce, message, aad)
                {
                    self.state = SessionState::Established(current_keys);
                    return Ok(plaintext);
                }
                debug!("Old session key failed to decrypt message");
                match crypto::decrypt_message(&new_keys.decryption_key, nonce, message, aad) {
                    Ok(plaintext) => {
                        // Looked up only for logging, so a missing ENR must not panic here.
                        if let Some(enr) = &self.remote_enr {
                            debug!("Session keys have been updated for node: {}", enr.node_id());
                        }
                        self.state = SessionState::Established(new_keys);
                        Ok(plaintext)
                    }
                    Err(e) => {
                        self.state = SessionState::EstablishedAwaitingResponse {
                            current_keys,
                            new_keys,
                        };
                        Err(e)
                    }
                }
            }
            SessionState::AwaitingResponse(keys) => {
                match crypto::decrypt_message(&keys.decryption_key, nonce, message, aad) {
                    Ok(plaintext) => {
                        self.state = SessionState::Established(keys);
                        Ok(plaintext)
                    }
                    Err(e) => {
                        self.state = SessionState::AwaitingResponse(keys);
                        Err(e)
                    }
                }
            }
            SessionState::Poisoned => unreachable!(),
            keyless_state => {
                self.state = keyless_state;
                Err(Discv5Error::SessionNotEstablished)
            }
        }
    }

    /// Replaces the stored ENR with `enr` if the stored one has a lower sequence number,
    /// then re-evaluates trust.
    ///
    /// Returns `true` only if the update made the session trusted. If no ENR is stored yet,
    /// `enr` is ignored and `false` is returned.
    /// NOTE(review): presumably callers only use this for sessions with a known ENR — confirm.
    pub(crate) fn update_enr(&mut self, enr: Enr<CombinedKey>) -> bool {
        let is_newer = match &self.remote_enr {
            Some(known) => known.seq() < enr.seq(),
            None => false,
        };
        if !is_newer {
            return false;
        }
        self.remote_enr = Some(enr);
        self.update_trusted()
    }

    /// Re-evaluates trust by comparing `last_seen_socket` with the ENR's UDP socket.
    ///
    /// - `Untrusted` becomes `Trusted` when they match; this is the only case returning `true`.
    /// - `Trusted` becomes `Untrusted` when they differ.
    /// - Without a stored ENR nothing changes.
    pub(crate) fn update_trusted(&mut self) -> bool {
        let socket_matches = match &self.remote_enr {
            Some(enr) => Some(self.last_seen_socket) == enr.udp_socket(),
            None => return false,
        };
        match self.trusted {
            TrustedState::Untrusted if socket_matches => {
                self.trusted = TrustedState::Trusted;
                true
            }
            TrustedState::Trusted if !socket_matches => {
                self.trusted = TrustedState::Untrusted;
                false
            }
            _ => false,
        }
    }

    /// Records the socket the remote was last seen on.
    ///
    /// Trust is not re-evaluated here; call `update_trusted` for that.
    pub(crate) fn set_last_seen_socket(&mut self, socket: SocketAddr) {
        self.last_seen_socket = socket;
    }

    /// Returns `true` while the session is in the `WhoAreYouSent` state.
    pub(crate) fn is_whoareyou_sent(&self) -> bool {
        match self.state {
            SessionState::WhoAreYouSent => true,
            _ => false,
        }
    }

    /// Returns `true` while the session is in the `RandomSent` state.
    pub(crate) fn is_random_sent(&self) -> bool {
        match self.state {
            SessionState::RandomSent => true,
            _ => false,
        }
    }

    /// Returns `true` while the session is in the `AwaitingResponse` state.
    pub(crate) fn is_awaiting_response(&self) -> bool {
        match self.state {
            SessionState::AwaitingResponse(_) => true,
            _ => false,
        }
    }

    /// The remote's ENR, if known.
    pub(crate) fn remote_enr(&self) -> &Option<Enr<CombinedKey>> {
        &self.remote_enr
    }

    /// Returns `true` if the remote is currently marked trusted.
    pub(crate) fn is_trusted(&self) -> bool {
        match self.trusted {
            TrustedState::Trusted => true,
            TrustedState::Untrusted => false,
        }
    }

    /// Returns `true` if the remote is trusted and the session has established keys in use
    /// (`Established` or `EstablishedAwaitingResponse`).
    ///
    /// # Panics
    ///
    /// If the state is `Poisoned` (a coding error).
    pub(crate) fn trusted_established(&self) -> bool {
        let established = match &self.state {
            SessionState::Established(_) | SessionState::EstablishedAwaitingResponse { .. } => true,
            SessionState::WhoAreYouSent
            | SessionState::RandomSent
            | SessionState::AwaitingResponse(_) => false,
            SessionState::Poisoned => unreachable!(),
        };
        self.is_trusted() && established
    }
}