use zeroize::Zeroizing;
use crate::crypto::{Cipher, TAG_SIZE};
use crate::handshake_pattern::MessageToken::*;
use crate::TransportState;
use crate::{
crypto::{DHKeypair, HashFunction},
handshake_pattern::{HandshakePattern, MessageToken},
symmetric_state::SymmetricState,
};
pub struct HandshakeState<
const DHLEN: usize,
const HASHLEN: usize,
K: DHKeypair<DHLEN>,
H: HashFunction<HASHLEN>,
> {
symmetric_state: SymmetricState<HASHLEN, H>,
s: Option<K>,
e: Option<K>,
rs: Option<[u8; DHLEN]>,
re: Option<[u8; DHLEN]>,
next_psk: Option<Zeroizing<[u8; 32]>>,
pub initiator: bool,
pub turn: u8,
pub pattern: &'static HandshakePattern,
}
pub type ProtocolName = [&'static str; 8];
impl<const DHLEN: usize, const HASHLEN: usize, K: DHKeypair<DHLEN>, H: HashFunction<HASHLEN>>
HandshakeState<DHLEN, HASHLEN, K, H>
{
pub fn protocol_name(
pattern: &HandshakePattern,
cipher: &'static dyn Cipher,
) -> [&'static str; 8] {
[
"Noise_",
pattern.name,
"_",
K::NAME,
"_",
cipher.name(),
"_",
H::NAME,
]
}
#[inline(always)]
pub fn new(
cipher: &'static dyn Cipher,
pattern: &'static HandshakePattern,
initiator: bool,
prologue: &[u8],
s: Option<&K>,
e: Option<&K>,
rs: Option<[u8; DHLEN]>,
re: Option<[u8; DHLEN]>,
) -> Self {
pattern.validate().expect("pattern must be valid");
let mut state = Self {
symmetric_state: SymmetricState::new(&Self::protocol_name(&pattern, cipher), cipher),
s: s.cloned(),
e: e.cloned(),
rs,
re,
next_psk: None,
initiator,
turn: 0,
pattern,
};
state.symmetric_state.mix_hash(prologue);
for (i, pre_msg) in pattern.patterns[..2].iter().enumerate() {
let local = i == !initiator as usize;
for tok in pre_msg.iter() {
let key = match tok {
E => {
if local {
state
.e
.as_ref()
.expect("requires local ephemeral key in premessage")
.public_key()
} else {
state.re.expect("remote emphemeral key")
}
}
S => {
if local {
state
.s
.as_ref()
.expect("requires local static key in premessage")
.public_key()
} else {
state.rs.expect("remote static key")
}
}
_ => unreachable!(),
};
state.symmetric_state.mix_hash(&key);
}
}
state
}
fn dh(&mut self, tok: MessageToken) {
let (ini, resp) = match tok {
EE => (E, E),
ES => (E, S),
SE => (S, E),
SS => (S, S),
_ => unreachable!(),
};
let (local, remote) = if self.initiator {
(ini, resp)
} else {
(resp, ini)
};
let local = match local {
E => self.e.as_ref().expect("dh requires e"),
S => self.s.as_ref().expect("dh requires s"),
_ => unreachable!(),
};
let remote = match remote {
E => self.re.as_ref().expect("dh requires re"),
S => self.rs.as_ref().expect("dh requires rs"),
_ => unreachable!(),
};
let mut tmp = Zeroizing::new([0; DHLEN]);
local.dh(&remote, &mut tmp);
self.symmetric_state.mix_key(&tmp[..]);
}
pub fn is_our_turn(&self) -> bool {
self.turn % 2 == (!self.initiator as u8)
}
pub fn insert_psk(&mut self, psk: &[u8]) {
self.next_psk.insert(Zeroizing::default()).copy_from_slice(psk);
}
pub fn write_msg(&mut self, buf: &mut [u8], payload_size: usize) -> usize {
assert!(self.is_our_turn());
let msg = &self.pattern.patterns[self.turn as usize + 2];
self.turn += 1;
let buf_size = buf.len();
let mut buf_remaining = &mut buf[payload_size..];
for tok in msg.iter() {
match tok {
E => {
if self.e.is_none() {
K::generate_keypair(&mut self.e);
}
let e_pub = self.e.as_ref().unwrap().public_key();
self.symmetric_state.mix_hash(&e_pub);
if self.pattern.has_psk() { self.symmetric_state.mix_key(&e_pub); }
buf_remaining[..e_pub.len()].copy_from_slice(&e_pub);
buf_remaining = &mut buf_remaining[e_pub.len()..];
}
S => {
let s_pub = self.s.as_ref().unwrap().public_key();
buf_remaining[..s_pub.len()].copy_from_slice(&s_pub);
let len = self
.symmetric_state
.encrypt_and_hash(&mut buf_remaining[..s_pub.len() + TAG_SIZE]);
buf_remaining = &mut buf_remaining[len..];
}
EE | ES | SE | SS => self.dh(*tok),
PSK => {
let psk = self.next_psk.take().expect("pattern requires psk");
self.symmetric_state.mix_key_and_hash(psk.as_ref());
},
}
}
let payload_headers_size = buf_size - buf_remaining.len();
let headers_size = payload_headers_size - payload_size;
drop(buf_remaining);
buf[..payload_headers_size].rotate_left(payload_size);
let ciphertext_size = self
.symmetric_state
.encrypt_and_hash(&mut buf[headers_size..][..payload_size + TAG_SIZE]);
headers_size + ciphertext_size
}
pub fn read_msg<'b>(&mut self, mut buf: &'b mut [u8]) -> Option<&'b [u8]> {
assert!(!self.is_our_turn());
let msg = &self.pattern.patterns[self.turn as usize + 2];
let snapshot = self.symmetric_state.snapshot();
let mut fail = false;
for tok in msg.iter() {
match tok {
E => {
assert!(self.re.is_none());
let (re_buf, rest) = buf.split_at_mut(DHLEN);
buf = rest;
self.re = Some(re_buf.try_into().unwrap());
self.symmetric_state.mix_hash(&re_buf);
if self.pattern.has_psk() { self.symmetric_state.mix_key(&re_buf); }
}
S => {
assert!(self.rs.is_none());
let s_len = if self.symmetric_state.has_key() {
DHLEN + TAG_SIZE
} else {
DHLEN
};
let (tmp, rest) = buf.split_at_mut(s_len);
buf = rest;
match self.symmetric_state.decrypt_and_hash(tmp) {
Some(tmp) => {
self.rs = Some([0; DHLEN]);
self.rs.as_mut().unwrap().copy_from_slice(tmp);
}
None => {
fail = true;
break;
}
}
}
EE | ES | SE | SS => self.dh(*tok),
PSK => {
let psk = self.next_psk.take().expect("pattern requires psk");
self.symmetric_state.mix_key_and_hash(psk.as_ref());
},
}
}
if !fail {
if let Some(payload) = self.symmetric_state.decrypt_and_hash(buf) {
self.turn += 1;
return Some(payload);
};
}
if msg.contains(&E) {
self.re = None;
}
if msg.contains(&S) {
self.rs = None;
}
self.symmetric_state.restore(snapshot);
None
}
pub fn is_finished(&self) -> bool {
self.turn as usize + 2 >= self.pattern.patterns.len()
}
#[inline(always)]
pub fn split(self) -> TransportState<HASHLEN> {
assert!(self.is_finished());
let (ini, res, hash) = self.symmetric_state.split();
let (send, recv) = match self.initiator {
true => (ini, res),
false => (res, ini),
};
TransportState { send, recv, hash }
}
}