use std::convert::TryInto;
use std::sync::Arc;
use aes::cipher::{StreamCipher, StreamCipherSeek};
use everscale_crypto::ed25519;
use super::encryption::*;
use super::keystore::Key;
use super::node_id::{NodeIdFull, NodeIdShort};
use super::packet_view::*;
use crate::util::FastHashMap;
/// Returns the size in bytes of the header that precedes the handshake payload.
///
/// The base header is 96 bytes: a 32-byte destination id, a 32-byte temporary
/// public key and a 32-byte checksum. When a protocol `version` is present, an
/// additional 4-byte version tag is inserted, giving 100 bytes.
#[inline(always)]
pub fn compute_handshake_prefix_len(version: Option<u16>) -> usize {
    const BASE_HEADER: usize = 96;
    const VERSION_TAG: usize = 4;

    match version {
        Some(_) => BASE_HEADER + VERSION_TAG,
        None => BASE_HEADER,
    }
}
/// Turns `buffer` (the plaintext payload) into a handshake packet addressed to
/// `peer_id`, in place.
///
/// A fresh ed25519 key pair is generated on every call. Its secret half is
/// combined with the peer's public key (taken from `peer_id_full`) to derive the
/// shared secret used to encrypt the payload.
///
/// Resulting layout:
/// - `[0..32]`  `peer_id`;
/// - `[32..64]` the temporary public key;
/// - with `version`: `[64..68]` version tag, `[68..100]` payload checksum,
///   `[100..]` encrypted payload;
/// - without `version`: `[64..96]` payload checksum, `[96..]` encrypted payload.
///
/// The version tag starts as the big-endian `version` repeated twice. It is then
/// XOR-folded, in 4-byte lanes, with the first 64 header bytes and the checksum.
pub fn build_handshake_packet(
    peer_id: &NodeIdShort,
    peer_id_full: &NodeIdFull,
    buffer: &mut Vec<u8>,
    version: Option<u16>,
) {
    // One-off key pair; only the public half is written into the header.
    let seed = ed25519::SecretKey::generate(&mut rand::thread_rng());
    let ephemeral_secret = ed25519::ExpandedSecretKey::from(&seed);
    let ephemeral_public = ed25519::PublicKey::from(&ephemeral_secret);
    let shared_secret = ephemeral_secret.compute_shared_secret(peer_id_full.public_key());

    // The checksum is computed over the payload *before* encryption. It is also
    // passed to `build_packet_cipher` together with the shared secret.
    let checksum: [u8; 32] = compute_packet_data_hash(version, buffer.as_slice());

    // Slide the payload to the right to open room for the header.
    let payload_len = buffer.len();
    let prefix_len = compute_handshake_prefix_len(version);
    buffer.resize(prefix_len + payload_len, 0);
    buffer.copy_within(..payload_len, prefix_len);

    buffer[..32].copy_from_slice(peer_id.as_slice());
    buffer[32..64].copy_from_slice(ephemeral_public.as_bytes());

    let checksum_offset = if let Some(version) = version {
        let [hi, lo] = version.to_be_bytes();
        let mut tag = [hi, lo, hi, lo];
        // 64 and 32 are both multiples of 4, so every chunk is a full lane and
        // byte `i` of each source is folded into `tag[i % 4]`.
        for lane in buffer[..64].chunks_exact(4).chain(checksum.chunks_exact(4)) {
            for (acc, byte) in tag.iter_mut().zip(lane) {
                *acc ^= *byte;
            }
        }
        buffer[64..68].copy_from_slice(&tag);
        68
    } else {
        64
    };

    let payload_offset = checksum_offset + checksum.len();
    buffer[checksum_offset..payload_offset].copy_from_slice(&checksum);
    build_packet_cipher(&shared_secret, &checksum).apply_keystream(&mut buffer[payload_offset..]);
}
/// Parses and decrypts an incoming handshake packet in place.
///
/// Expected layout (as written by `build_handshake_packet`):
/// - `[0..32]`  short id of the local node the packet is addressed to;
/// - `[32..64]` the sender's temporary ed25519 public key;
/// - versioned packets: `[64..68]` version tag, `[68..100]` checksum,
///   `[100..]` encrypted data;
/// - unversioned packets: `[64..96]` checksum, `[96..]` encrypted data.
///
/// The versioned interpretation is tried first when the packet is long enough
/// and `decode_version` recognises a tag. If its checksum does not match, the
/// decryption is undone and the unversioned interpretation is tried instead.
///
/// On success the header is removed from `buffer`, leaving only the decrypted
/// payload. The function then returns the addressed local id together with the
/// decoded version, if any.
///
/// Returns `Ok(None)` when the packet is addressed to an id not present in
/// `keys`; `buffer` is left untouched in that case.
///
/// # Errors
/// - [`HandshakeError::BadHandshakePacketLength`] if the packet is shorter than
///   96 bytes;
/// - [`HandshakeError::InvalidPublicKey`] if bytes `32..64` are not a valid
///   ed25519 public key;
/// - [`HandshakeError::BadHandshakePacketChecksum`] if no interpretation yields a
///   matching checksum. The data part of `buffer` is then left decrypted with
///   the unversioned keystream; it is not restored.
pub fn parse_handshake_packet(
    keys: &FastHashMap<NodeIdShort, Arc<Key>>,
    buffer: &mut PacketView<'_>,
) -> Result<Option<(NodeIdShort, Option<u16>)>, HandshakeError> {
    const PUBLIC_KEY_RANGE: std::ops::Range<usize> = 32..64;
    const DATA_START: usize = 96;
    const CHECKSUM_RANGE: std::ops::Range<usize> = 64..DATA_START;
    const DATA_RANGE: std::ops::RangeFrom<usize> = DATA_START..;
    const EXT_DATA_START: usize = 100;
    const EXT_CHECKSUM_RANGE: std::ops::Range<usize> = 68..EXT_DATA_START;
    const EXT_DATA_RANGE: std::ops::RangeFrom<usize> = EXT_DATA_START..;

    if buffer.len() < DATA_START {
        return Err(HandshakeError::BadHandshakePacketLength);
    }

    // SAFETY: the length check above guarantees at least 96 (>= 32) readable bytes
    // at the start of `buffer`.
    // - The id is copied out *by value* rather than kept as a reference into
    //   `buffer`, because `buffer` is mutably re-borrowed and re-sliced below.
    // - `read_unaligned` removes any alignment requirement on the source pointer.
    // - As with the previous pointer cast, this assumes `NodeIdShort` is a plain
    //   32-byte value for which every bit pattern is valid.
    //   NOTE(review): confirm it is `#[repr(transparent)]` over `[u8; 32]`.
    let local_id: NodeIdShort =
        unsafe { std::ptr::read_unaligned(buffer.as_ptr() as *const NodeIdShort) };

    let local_key = match keys.get(&local_id) {
        Some(key) => key,
        None => return Ok(None),
    };

    let shared_secret =
        match ed25519::PublicKey::from_bytes(buffer[PUBLIC_KEY_RANGE].try_into().unwrap()) {
            Some(other_public_key) => local_key
                .secret_key()
                .compute_shared_secret(&other_public_key),
            None => return Err(HandshakeError::InvalidPublicKey),
        };

    // `>=` rather than `>`: a versioned packet with an empty payload is exactly
    // `EXT_DATA_START` bytes long (this is what `build_handshake_packet` produces
    // for an empty buffer) and must still be accepted.
    if buffer.len() >= EXT_DATA_START {
        if let Some(version) =
            decode_version::<EXT_DATA_START>((&buffer[..EXT_DATA_START]).try_into().unwrap())
        {
            let mut cipher = build_packet_cipher(
                &shared_secret,
                &buffer[EXT_CHECKSUM_RANGE].try_into().unwrap(),
            );
            cipher.apply_keystream(&mut buffer[EXT_DATA_RANGE]);
            if compute_packet_data_hash(Some(version), &buffer[EXT_DATA_RANGE]).as_slice()
                == &buffer[EXT_CHECKSUM_RANGE]
            {
                buffer.remove_prefix(EXT_DATA_START);
                return Ok(Some((local_id, Some(version))));
            }
            // Checksum mismatch: rewind the keystream and re-apply it to undo the
            // decryption, so the unversioned attempt sees the original bytes.
            cipher.seek(0);
            cipher.apply_keystream(&mut buffer[EXT_DATA_RANGE]);
        }
    }

    build_packet_cipher(&shared_secret, &buffer[CHECKSUM_RANGE].try_into().unwrap())
        .apply_keystream(&mut buffer[DATA_RANGE]);
    if compute_packet_data_hash(None, &buffer[DATA_RANGE]).as_slice() != &buffer[CHECKSUM_RANGE] {
        return Err(HandshakeError::BadHandshakePacketChecksum);
    }

    buffer.remove_prefix(DATA_START);
    Ok(Some((local_id, None)))
}
#[derive(thiserror::Error, Debug)]
pub enum HandshakeError {
#[error("Bad handshake packet length")]
BadHandshakePacketLength,
#[error("Bad handshake packet checksum")]
BadHandshakePacketChecksum,
#[error("Invalid public key")]
InvalidPublicKey,
}