use super::*;
use crate::{
common::{
x25519_elligator2::{PublicRepresentative, REPRESENTATIVE_LENGTH},
HmacSha256,
},
framing::{build_and_marshall, ClientHandshakeMessage, MessageTypes, ServerHandshakeMessage},
};
use ptrs::{debug, trace};
use rand::thread_rng;
use tokio_util::codec::Encoder;
use std::time::Instant;
#[derive(Clone)]
pub(crate) struct HandshakeMaterials {
pub(crate) identity_keys: Obfs4NtorSecretKey,
pub(crate) session_id: String,
pub(crate) len_seed: [u8; SEED_LENGTH],
}
impl<'a> HandshakeMaterials {
pub fn get_hmac(&self) -> HmacSha256 {
let mut key = self.identity_keys.pk.pk.as_bytes().to_vec();
key.append(&mut self.identity_keys.pk.id.as_bytes().to_vec());
HmacSha256::new_from_slice(&key[..]).unwrap()
}
pub fn new<'b>(
identity_keys: &'b Obfs4NtorSecretKey,
session_id: String,
len_seed: [u8; SEED_LENGTH],
) -> Self
where
'b: 'a,
{
HandshakeMaterials {
identity_keys: identity_keys.clone(),
session_id,
len_seed,
}
}
}
impl Server {
pub(super) fn server_handshake_obfs4<T>(
&self,
msg: T,
materials: HandshakeMaterials,
) -> RelayHandshakeResult<(NtorHkdfKeyGenerator, Vec<u8>)>
where
T: AsRef<[u8]>,
{
let rng = thread_rng();
let session_sk = Keys::ephemeral_from_rng(rng);
self.server_handshake_obfs4_no_keygen(session_sk, msg, materials)
}
pub(crate) fn server_handshake_obfs4_no_keygen<T>(
&self,
session_sk: EphemeralSecret,
msg: T,
mut materials: HandshakeMaterials,
) -> RelayHandshakeResult<(NtorHkdfKeyGenerator, Vec<u8>)>
where
T: AsRef<[u8]>,
{
if CLIENT_MIN_HANDSHAKE_LENGTH > msg.as_ref().len() {
Err(RelayHandshakeError::EAgain)?;
}
let mut client_hs = match self.try_parse_client_handshake(msg, &mut materials) {
Ok(chs) => chs,
Err(Error::HandshakeErr(RelayHandshakeError::EAgain)) => {
return Err(RelayHandshakeError::EAgain);
}
Err(_e) => {
debug!(
"{} failed to parse client handshake: {_e}",
materials.session_id
);
return Err(RelayHandshakeError::BadClientHandshake);
}
};
debug!(
"{} successfully parsed client handshake",
materials.session_id
);
let their_pk = client_hs.get_public();
let ephem_pub = (&session_sk).into();
let session_repres = PublicRepresentative::from(&session_sk);
let xy = session_sk.diffie_hellman(&their_pk);
let xb = materials.identity_keys.sk.diffie_hellman(&their_pk);
let okay =
ct::bool_to_choice(xy.was_contributory()) & ct::bool_to_choice(xb.was_contributory());
trace!("x {} y {}", hex::encode(their_pk), hex::encode(ephem_pub));
let (key_seed, authcode) =
ntor_derive(&xy, &xb, &materials.identity_keys.pk, &their_pk, &ephem_pub)
.map_err(into_internal!("Error deriving keys"))?;
trace!(
"seed: {} auth: {}",
hex::encode(key_seed.as_slice()),
hex::encode(authcode)
);
let mut keygen = NtorHkdfKeyGenerator::new(key_seed, false);
let reply =
self.complete_server_hs(&client_hs, materials, session_repres, &mut keygen, authcode)?;
if okay.into() {
Ok((keygen, reply))
} else {
Err(RelayHandshakeError::BadClientHandshake)
}
}
pub(crate) fn complete_server_hs(
&self,
client_hs: &ClientHandshakeMessage,
materials: HandshakeMaterials,
session_repres: PublicRepresentative,
keygen: &mut NtorHkdfKeyGenerator,
authcode: Authcode,
) -> RelayHandshakeResult<Vec<u8>> {
let epoch_hr = client_hs.get_epoch_hr();
let mut sh_msg = ServerHandshakeMessage::new(session_repres, authcode, epoch_hr);
let h = materials.get_hmac();
let mut buf = BytesMut::with_capacity(MAX_HANDSHAKE_LENGTH);
sh_msg
.marshall(&mut buf, h)
.map_err(|e| RelayHandshakeError::FrameError(format!("{e}")))?;
trace!("adding encoded prng seed");
let mut prng_pkt_buf = BytesMut::new();
build_and_marshall(
&mut prng_pkt_buf,
MessageTypes::PrngSeed.into(),
materials.len_seed,
0,
)
.map_err(|e| RelayHandshakeError::FrameError(format!("{e}")))?;
let codec = &mut keygen.codec;
codec
.encode(prng_pkt_buf.clone(), &mut buf)
.map_err(|e| RelayHandshakeError::FrameError(format!("{e}")))?;
debug!(
"{} writing server handshake {}B ...{}",
materials.session_id,
buf.len(),
hex::encode(&buf[buf.len() - 10..]),
);
Ok(buf.to_vec())
}
fn try_parse_client_handshake(
&self,
buf: impl AsRef<[u8]>,
materials: &mut HandshakeMaterials,
) -> Result<ClientHandshakeMessage> {
let buf = buf.as_ref();
let mut h = materials.get_hmac();
if CLIENT_MIN_HANDSHAKE_LENGTH > buf.len() {
Err(Error::HandshakeErr(RelayHandshakeError::EAgain))?;
}
let r_bytes: [u8; 32] = buf[0..REPRESENTATIVE_LENGTH].try_into().unwrap();
h.update(&r_bytes[..]);
let repres = PublicRepresentative::from(&r_bytes);
let m = h.finalize_reset().into_bytes();
let mark: [u8; MARK_LENGTH] = m[..MARK_LENGTH].try_into()?;
trace!("{} mark?:{}", materials.session_id, hex::encode(mark));
let pos = match find_mac_mark(
mark,
buf,
REPRESENTATIVE_LENGTH + CLIENT_MIN_PAD_LENGTH,
MAX_HANDSHAKE_LENGTH,
true,
) {
Some(p) => p,
None => {
trace!("{} didn't find mark", materials.session_id);
if buf.len() > MAX_HANDSHAKE_LENGTH {
Err(Error::HandshakeErr(RelayHandshakeError::BadClientHandshake))?
}
Err(Error::HandshakeErr(RelayHandshakeError::EAgain))?
}
};
let mut mac_found = false;
let mut epoch_hr = String::new();
for offset in [0_i64, -1, 1] {
trace!("server trying offset: {offset}");
let eh = format!("{}", offset + get_epoch_hour() as i64);
h.reset();
h.update(&buf[..pos + MARK_LENGTH]);
h.update(eh.as_bytes());
let mac_calculated = &h.finalize_reset().into_bytes()[..MAC_LENGTH];
let mac_received = &buf[pos + MARK_LENGTH..pos + MARK_LENGTH + MAC_LENGTH];
trace!(
"server {}-{}",
hex::encode(mac_calculated),
hex::encode(mac_received)
);
if mac_calculated.ct_eq(mac_received).into() {
trace!("correct mac");
if self
.replay_filter
.test_and_set(Instant::now(), mac_received)
{
Err(Error::HandshakeErr(RelayHandshakeError::ReplayedHandshake))?
}
epoch_hr = eh;
mac_found = true;
}
}
if !mac_found {
Err(Error::HandshakeErr(RelayHandshakeError::BadClientHandshake))?
}
if buf.len() != pos + MARK_LENGTH + MAC_LENGTH {
Err(Error::HandshakeErr(RelayHandshakeError::BadClientHandshake))?
}
Ok(ClientHandshakeMessage::new(
repres, 0, epoch_hr,
))
}
}