use crate::error;
use crate::error::Error;
use blake2::digest::Mac;
use blake2::digest::consts::U16;
use blake2::digest::generic_array::GenericArray;
use blake2::{Blake2s256, Blake2sMac, Digest};
use std::collections::{HashSet, VecDeque};
use std::hash::{BuildHasher, Hasher};
use std::ops::Range;
/// Pass-through hasher used as the hashing strategy for the `u64` MAC fragments kept by
/// `Handshake`.
///
/// `write_u64` stores the value verbatim as the hash, so `u64` keys hash to themselves.
/// `write` shifts each byte into the low end of the state, so only the last eight bytes
/// written survive. The type is also its own [`BuildHasher`]: every hasher it builds is
/// a copy of the builder and starts from the builder's state.
#[derive(Debug, Clone, Copy)]
struct MacHasher(u64);

impl Hasher for MacHasher {
    fn write(&mut self, bytes: &[u8]) {
        // Big-endian accumulation; anything older than the last eight bytes falls off the top.
        self.0 = bytes
            .iter()
            .fold(self.0, |state, &byte| (state << 8) | u64::from(byte));
    }

    fn write_u64(&mut self, value: u64) {
        // `Hash for u64` routes here, so a `u64` key becomes its own hash.
        self.0 = value;
    }

    fn finish(&self) -> u64 {
        let Self(state) = *self;
        state
    }
}

impl BuildHasher for MacHasher {
    type Hasher = Self;

    fn build_hasher(&self) -> Self::Hasher {
        Self(self.0)
    }
}
/// Stateful validator for incoming handshake initiation packets.
///
/// The layout checked here (148 bytes, message type 1, `mac1----` label) matches the
/// WireGuard handshake initiation message. A packet is accepted when its `mac1` verifies
/// against the server public key and has not been seen recently. A 64-bit fragment of
/// every accepted `mac1` is remembered for replay detection. Only the most recent
/// `HISTORY_SIZE` fragments are kept, and older ones are evicted first-in, first-out.
#[derive(Debug)]
pub struct Handshake {
    /// Server public key mixed into the `mac1` key derivation.
    public_key: [u8; 32],
    /// Membership index over exactly the fragments stored in `mac_history`.
    mac_index: HashSet<u64, MacHasher>,
    /// Accepted fragments in arrival order; the front is the next to be evicted.
    mac_history: VecDeque<u64>,
}

impl Handshake {
    /// Maximum number of `mac1` fragments remembered for replay detection.
    const HISTORY_SIZE: usize = 1024 * 256;

    /// Creates a validator for `public_key`, preallocating the replay history for
    /// `HISTORY_SIZE` entries.
    pub fn new(public_key: [u8; 32]) -> Self {
        let mac_index = HashSet::with_capacity_and_hasher(Self::HISTORY_SIZE, MacHasher(0));
        let mac_history = VecDeque::with_capacity(Self::HISTORY_SIZE);
        Self { public_key, mac_index, mac_history }
    }

    /// Validates a handshake initiation packet and records its `mac1`.
    ///
    /// The packet must be exactly 148 bytes long and start with the message type bytes
    /// `01 00 00 00`. Bytes `116..132` must equal BLAKE2s-MAC (16-byte output) over bytes
    /// `0..116`. The MAC key is `BLAKE2s-256("mac1----" || public_key)`. On success the
    /// `mac1` is registered, so presenting the same packet again fails.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - the length or message type is wrong;
    /// - `mac1` does not verify;
    /// - the `mac1` fragment (bytes 4..12 of the MAC) matches one still held in the
    ///   replay history.
    pub fn is_valid_handshake(&mut self, packet: &[u8]) -> Result<(), Error> {
        const PACKET_LENGTH: usize = 148;
        const MTYPE_RANGE: Range<usize> = 0..4;
        const MTYPE_VALUE: &[u8] = b"\x01\x00\x00\x00";
        const PAYLOAD_RANGE: Range<usize> = 0..116;
        const MAC1_RANGE: Range<usize> = 116..132;
        const MAC1_LABEL: &[u8] = b"mac1----";

        // Plain comparisons rather than constant patterns in `let ... else`. If a
        // constant were renamed, such a pattern would silently turn into a fresh binding
        // and the check would vanish. `||` short-circuits, so the length check guards
        // the slicing, and every range used below is in bounds.
        if packet.len() != PACKET_LENGTH || &packet[MTYPE_RANGE] != MTYPE_VALUE {
            return Err(error!("Packet is not a handshake initiation packet"));
        }

        let label_pubkey_hash = Blake2s256::new()
            .chain_update(MAC1_LABEL)
            .chain_update(self.public_key)
            .finalize();
        let mac1 =
            Blake2sMac::<U16>::new(&label_pubkey_hash).chain_update(&packet[PAYLOAD_RANGE]);
        let packet_mac1 = GenericArray::from_slice(&packet[MAC1_RANGE]);
        // `Mac::verify` performs the tag comparison itself (constant-time in `digest`).
        if mac1.verify(packet_mac1).is_err() {
            return Err(error!("MAC1 does not match the server public key"));
        }

        let packet_mac1 = <[u8; 16]>::from(*packet_mac1);
        self.register_mac1(&packet_mac1)
    }

    /// Records the 64-bit fragment (bytes `4..12`) of an already-verified `mac1`.
    ///
    /// Only the fragment is stored. Assuming MAC output bytes are uniformly
    /// distributed, a false "seen before" within a history of `HISTORY_SIZE` entries is
    /// vanishingly unlikely. When the history is full, the oldest fragment is evicted.
    ///
    /// # Errors
    ///
    /// Returns an error if the fragment is already present in the history.
    fn register_mac1(&mut self, mac: &[u8; 16]) -> Result<(), Error> {
        let mac64 = u64::from_ne_bytes([
            mac[4], mac[5], mac[6], mac[7], mac[8], mac[9], mac[10], mac[11],
        ]);
        // `insert` reports whether the value was new, so one hash lookup does both the
        // duplicate check and the insertion. This runs before eviction, so a replay of
        // the oldest entry is still rejected.
        if !self.mac_index.insert(mac64) {
            let mac = u128::from_be_bytes(*mac);
            return Err(error!("MAC1 {mac:032x} has already been seen before"));
        }
        // Evict before pushing so the deque never grows past its preallocated size.
        if self.mac_history.len() == Self::HISTORY_SIZE
            && let Some(to_evict) = self.mac_history.pop_front()
        {
            self.mac_index.remove(&to_evict);
        }
        self.mac_history.push_back(mac64);
        Ok(())
    }
}