use std::{
io,
marker::Unpin,
};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use cipher::{SyncStreamCipher, NewStreamCipher};
use salsa20::{Key, Nonce, Salsa20};
use sha3::{Digest, Sha3_256};
use chrono::{DateTime, Utc};
use super::{Psk, SalsaStream, Randomness, erase_bytes};
/// Length in bytes of each fixed-size handshake block used by
/// `Connector::connect` in this file (the client's random bytes and the
/// server's response are both this long).
pub(crate) const HANDSHAKE_TIP: usize = 72;
/// One fixed-size handshake block.
type A72 = [u8; HANDSHAKE_TIP];
/// All-zero handshake block, copied to initialise buffers before they are filled.
static Z72: A72 = [0u8; HANDSHAKE_TIP];
/// Borrowed view over one 72-byte handshake block.
///
/// `to_salsa20` reads bytes 0..32 as the Salsa20 key and bytes 32..40 as the
/// nonce; `sha3_256` hashes all 72 bytes.
struct K32N8S32<'a>(&'a A72);

impl K32N8S32<'_> {
    /// Returns the SHA3-256 digest of the whole block.
    fn sha3_256(&self) -> [u8; 32] {
        let digest = {
            let mut hasher = Sha3_256::new();
            hasher.update(self.0);
            hasher.finalize()
        };
        let mut out = [0u8; 32];
        out.copy_from_slice(digest.as_slice());
        out
    }

    /// Builds a Salsa20 cipher keyed by bytes 0..32 with nonce bytes 32..40.
    fn to_salsa20(&self) -> Salsa20 {
        let block: &[u8] = self.0;
        let key = Key::from_slice(&block[..32]);
        let nonce = Nonce::from_slice(&block[32..40]);
        Salsa20::new(key, nonce)
    }
}
#[derive(Clone, Debug)]
pub struct Connector {
psk: Psk,
randomness: Randomness,
}
impl Connector {
pub fn new(psk: Psk, randomness: Randomness) -> Self {
Self { psk , randomness}
}
pub async fn connect<S>(&self, mut stream: S) -> io::Result<SalsaStream<S>>
where S: AsyncRead + AsyncWrite + Unpin,
{
let mut initiate_start: A72 = Z72;
let mut response_packet: A72 = Z72;
self.randomness
.try_fill(&mut initiate_start)
.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?;
let now: DateTime<Utc> = Utc::now();
let mut tsb: [u8; 8] = now
.timestamp()
.to_le_bytes();
let nonce = Nonce::from_slice(&tsb);
let mut initiate_salsa = Salsa20::new(self.psk.wrap_k().key(), &nonce);
let mut initiate_packet = initiate_start.to_vec();
initiate_packet.extend_from_slice(self.psk.check().as_slice());
initiate_salsa.apply_keystream(initiate_packet.as_mut());
stream
.write_all(initiate_packet.as_slice())
.await?;
stream
.read_exact(&mut response_packet) .await?;
let initiate_funcs = K32N8S32(&initiate_start);
let initiate_hash = initiate_funcs.sha3_256();
let mut read_cipher = initiate_funcs.to_salsa20();
read_cipher.apply_keystream(&mut response_packet);
let (_, response_hash) = response_packet.split_at(40);
if &initiate_hash != response_hash {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Hash check failed. Server returned invalid data.",
));
}
let response_funcs = K32N8S32(&response_packet);
let write_cipher = response_funcs.to_salsa20();
drop(initiate_funcs);
drop(response_hash);
erase_bytes(&mut tsb);
erase_bytes(&mut initiate_start);
erase_bytes(&mut response_packet);
erase_bytes(initiate_packet.as_mut());
Ok(SalsaStream::new(stream, read_cipher, write_cipher))
}
}