use std::{
io,
mem,
pin::Pin,
sync::Arc,
marker::Unpin,
task::{Context, Poll},
future::Future,
net::SocketAddr,
};
use futures::{
ready,
Stream,
future::FusedFuture,
};
use tokio::{
net::{TcpStream, TcpListener},
io::{AsyncRead, AsyncWrite, ReadBuf},
};
use cipher::{NewStreamCipher, SyncStreamCipher};
use salsa20::{Salsa20, Key, Nonce};
use sha3::{Digest, Sha3_256};
use chrono::{DateTime, Utc};
use super::{
Psk,
SalsaStream,
Randomness,
erase_bytes,
connector::HANDSHAKE_TIP,
};
/// Builds one candidate Salsa20 cipher per second in the half-open window
/// `[timestamp - secs_before, timestamp + secs_after)`.
///
/// Every cipher is keyed with `key`; its 8-byte nonce is the little-endian
/// encoding of the second it stands for. The returned vector is ordered from
/// the earliest second to the latest.
fn salsa_battery(
    key: &Key, timestamp: i64, secs_before: u16, secs_after: u16
) -> Vec<Salsa20> {
    let window_start = timestamp - i64::from(secs_before);
    let window_end = timestamp + i64::from(secs_after);
    let mut ciphers = Vec::new();
    for second in window_start..window_end {
        let nonce_bytes = second.to_le_bytes();
        let nonce = Nonce::clone_from_slice(&nonce_bytes);
        ciphers.push(Salsa20::new(key, &nonce));
    }
    ciphers
}
fn check_handshake(
battery: Vec<Salsa20>, handshake: &[u8], check: &[u8], split: usize,
) -> Option<Vec<u8>> {
let mut buffer = handshake.to_vec();
let mut sliceb = buffer.as_mut_slice();
for mut trial_cipher in battery {
sliceb.copy_from_slice(&handshake);
trial_cipher.apply_keystream(&mut sliceb);
let (_, tail) = sliceb.split_at(split);
if tail == check {
return Some(buffer)
}
}
None
}
/// Progress of the server-side handshake driven by `Accept::poll`.
///
/// `poll` moves the current state out (leaving `Finished` in its place) and
/// stores the next one, so each variant owns exactly the data its step needs.
enum State<S> {
    /// Not yet polled; holds the raw stream.
    Start(S),
    /// Reading the encrypted handshake from the client.
    RecvHandshake {
        stream: S,
        /// Candidate ciphers, one per second of the accepted time window.
        battery: Vec<Salsa20>,
        /// Buffer sized for the whole handshake (`HANDSHAKE_TIP` bytes plus the PSK check value).
        encrypted_h: Vec<u8>,
        /// Number of bytes of `encrypted_h` received so far.
        bytes_r: usize,
    },
    /// Whole handshake received; next, each candidate cipher is tried against it.
    Check {
        stream: S,
        battery: Vec<Salsa20>,
        encrypted_h: Vec<u8>,
    },
    /// Handshake decrypted and verified; the session ciphers and response are built next.
    GenKey {
        stream: S,
        decrypted_h: Vec<u8>,
    },
    /// Writing the encrypted response to the client.
    Respond {
        stream: S,
        /// Cipher keyed with the server-generated random key/nonce carried in the response.
        read_c: Salsa20,
        /// Cipher keyed with the key/nonce from the decrypted handshake; already
        /// advanced by encrypting the response.
        write_c: Salsa20,
        response: Vec<u8>,
        /// Number of bytes of `response` written so far.
        bytes_w: usize,
    },
    /// Response fully written; flushing the stream before handing it off.
    Finalize {
        stream: S,
        read_c: Salsa20,
        write_c: Salsa20,
    },
    /// Handshake completed (successfully or with an error), or the transient
    /// placeholder left while `poll` processes a state.
    Finished,
}
/// Future performing the server side of the handshake on a stream `S`.
///
/// It is created by `Acceptor::accept` and resolves to a `SalsaStream<S>`.
/// No I/O happens until it is first polled.
pub struct Accept<S> {
    psk: Arc<Psk>,
    randomness: Randomness,
    state: State<S>,
}
impl<S> Accept<S>
where S: AsyncRead + AsyncWrite + Unpin,
{
    /// Creates a handshake future in its initial state, owning `stream`.
    fn new(psk: Arc<Psk>, randomness: Randomness, stream: S) -> Self {
        let state = State::Start(stream);
        Accept { psk, randomness, state }
    }
}
impl<S> Future for Accept<S>
where S: AsyncRead + AsyncWrite + Unpin
{
    type Output = io::Result<SalsaStream<S>>;
    /// Drives the handshake state machine as far as the stream currently allows.
    ///
    /// Each loop iteration takes the current state out of `self`, leaving
    /// `State::Finished` behind, then handles it and stores the next state.
    /// The loop exits only by returning:
    /// - `Poll::Pending`, after putting the in-progress state back;
    /// - `Poll::Ready` with the finished `SalsaStream`;
    /// - `Poll::Ready` with an error. The state then stays `Finished`.
    ///
    /// # Errors
    ///
    /// - `UnexpectedEof` if the client closes the connection mid-handshake.
    /// - `InvalidInput` if no candidate cipher decrypts the handshake's tail to
    ///   the PSK check value.
    /// - `WriteZero` if the stream accepts zero bytes of the response.
    /// - `Other` if the randomness source fails.
    /// - Any I/O error from the underlying stream.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has completed.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let pin = self.get_mut();
        loop {
            let state = mem::replace(&mut pin.state, State::Finished);
            match state {
                State::Start(stream) => {
                    let now: DateTime<Utc> = Utc::now();
                    let timestamp = now.timestamp();
                    // Candidate ciphers for every second from 30 s before `now`
                    // up to (but excluding) 30 s after it.
                    let battery = salsa_battery(&pin.psk.wrap_k().key(), timestamp, 30, 30);
                    // The handshake is HANDSHAKE_TIP bytes followed by the PSK check value.
                    let encrypted_h = vec![0u8; HANDSHAKE_TIP + pin.psk.check().len()];
                    pin.state = State::RecvHandshake {
                        stream,
                        battery,
                        encrypted_h,
                        bytes_r: 0,
                    };
                },
                State::RecvHandshake { mut stream, battery, mut encrypted_h, bytes_r } => {
                    // Read into the part of the buffer not filled yet.
                    let (_, buf) = encrypted_h.split_at_mut(bytes_r);
                    let mut read_buf = ReadBuf::new(buf);
                    match Pin::new(&mut stream).poll_read(cx, &mut read_buf) {
                        Poll::Pending => {
                            pin.state = State::RecvHandshake {
                                stream, battery, encrypted_h, bytes_r
                            };
                            return Poll::Pending;
                        },
                        Poll::Ready(Ok(())) => {
                            let read = read_buf
                                .filled()
                                .len();
                            if read == 0 {
                                // The buffer still had room, so an empty read means EOF.
                                return Poll::Ready(Err(io::Error::new(
                                    io::ErrorKind::UnexpectedEof, "Client closed connection"
                                )));
                            } else if read_buf.remaining() == 0 {
                                pin.state = State::Check { stream, battery, encrypted_h };
                            } else {
                                pin.state = State::RecvHandshake {
                                    stream, battery, encrypted_h, bytes_r: bytes_r + read
                                };
                            }
                        },
                        Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                    }
                },
                State::Check { stream, battery, encrypted_h } => {
                    // Accept the handshake if some candidate cipher decrypts its
                    // tail (after HANDSHAKE_TIP) to the PSK check value.
                    if let Some(decrypted_h) = check_handshake(
                        battery, &encrypted_h[..], &pin.psk.check()[..], HANDSHAKE_TIP,
                    ) {
                        pin.state = State::GenKey { stream, decrypted_h };
                    } else {
                        return Poll::Ready(Err(io::Error::new(
                            io::ErrorKind::InvalidInput, "Client supplied bad key/nonce"
                        )));
                    }
                },
                State::GenKey { stream, decrypted_h } => {
                    // The decrypted handshake starts with a 32-byte key and an
                    // 8-byte nonce for the server's write cipher.
                    let (write_key, rest) = decrypted_h.split_at(32);
                    let (write_nonce, _) = rest.split_at(8);
                    // SHA3-256 of the first HANDSHAKE_TIP bytes goes back to the client.
                    let (hash_portion, _) = decrypted_h.split_at(HANDSHAKE_TIP);
                    let mut hasher = Sha3_256::new();
                    hasher.update(hash_portion);
                    let hash = hasher.finalize();
                    // The server picks a fresh random key (32 bytes) and nonce (8 bytes)
                    // for its read cipher.
                    let mut read_keynonce: [u8; 40] = [0; 40];
                    let rand_result = pin.randomness
                        .try_fill(&mut read_keynonce)
                        .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()));
                    match rand_result {
                        Ok(()) => (),
                        Err(e) => return Poll::Ready(Err(e)),
                    }
                    let (read_key, read_nonce) = read_keynonce.split_at(32);
                    let read_key = Key::from_slice(read_key);
                    let read_nonce = Nonce::from_slice(read_nonce);
                    let read_cipher = Salsa20::new(&read_key, &read_nonce);
                    // Response = read key || read nonce || hash, encrypted with the
                    // write cipher (which also advances its keystream).
                    let mut response = read_keynonce.to_vec();
                    response.extend_from_slice(hash.as_slice());
                    let write_key = Key::from_slice(write_key);
                    let write_nonce = Nonce::from_slice(write_nonce);
                    let mut write_cipher = Salsa20::new(&write_key, &write_nonce);
                    write_cipher.apply_keystream(response.as_mut());
                    let mut decrypted_h = decrypted_h;
                    erase_bytes(decrypted_h.as_mut());
                    pin.state = State::Respond {
                        stream,
                        read_c: read_cipher,
                        write_c: write_cipher,
                        response,
                        bytes_w: 0,
                    };
                },
                State::Respond { mut stream, read_c, write_c, response, bytes_w } => {
                    let (_, rest) = response
                        .as_slice()
                        .split_at(bytes_w);
                    match Pin::new(&mut stream).poll_write(cx, rest) {
                        Poll::Pending => {
                            pin.state = State::Respond {
                                stream, read_c, write_c, response, bytes_w
                            };
                            return Poll::Pending;
                        },
                        Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                        // `rest` is never empty here: the state only stays `Respond`
                        // while bytes remain. A zero-length write therefore means the
                        // stream cannot take more data, and retrying would spin forever.
                        Poll::Ready(Ok(0)) => {
                            return Poll::Ready(Err(io::Error::new(
                                io::ErrorKind::WriteZero, "Failed to write handshake response"
                            )));
                        },
                        Poll::Ready(Ok(written)) => {
                            if written == rest.len() {
                                pin.state = State::Finalize { stream, read_c, write_c };
                            } else {
                                pin.state = State::Respond {
                                    stream,
                                    read_c,
                                    write_c,
                                    response,
                                    bytes_w: bytes_w + written
                                };
                            }
                        }
                    }
                },
                State::Finalize { mut stream, read_c, write_c } => {
                    match Pin::new(&mut stream).poll_flush(cx) {
                        Poll::Pending => {
                            pin.state = State::Finalize { stream, read_c, write_c };
                            return Poll::Pending;
                        },
                        Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                        Poll::Ready(Ok(())) => {
                            pin.state = State::Finished;
                            return Poll::Ready(Ok(
                                SalsaStream::new(stream, read_c, write_c)
                            ));
                        },
                    }
                },
                State::Finished => {
                    panic!("Polled finished future.");
                },
            }
        }
    }
}
impl<S> FusedFuture for Accept<S>
where S: AsyncRead + AsyncWrite + Unpin
{
    /// Reports whether the handshake has already resolved, either to a
    /// stream or to an error. Polling it again at that point would panic.
    fn is_terminated(&self) -> bool {
        matches!(self.state, State::Finished)
    }
}
/// Server-side configuration used to start handshakes on incoming streams.
///
/// The pre-shared key is held behind an `Arc`, so cloning an `Acceptor`
/// shares it instead of copying it.
#[derive(Clone)]
pub struct Acceptor {
    psk: Arc<Psk>,
    randomness: Randomness,
}
impl Acceptor {
    /// Creates an acceptor from a pre-shared key and a randomness source.
    pub fn new(psk: Psk, randomness: Randomness) -> Self {
        let psk = Arc::new(psk);
        Acceptor { psk, randomness }
    }
    /// Returns a future that performs the server handshake on `stream`.
    ///
    /// Nothing is read or written until the returned future is polled.
    pub fn accept<S>(&self, stream: S) -> Accept<S>
    where S: AsyncRead + AsyncWrite + Unpin,
    {
        let psk = Arc::clone(&self.psk);
        let randomness = self.randomness.clone();
        Accept::new(psk, randomness, stream)
    }
}
/// Connection phase of a `ServerStream`.
enum Inner<S> {
    /// Handshake still pending (a failed handshake also stays in this variant).
    Handshaking(Accept<S>),
    /// Handshake succeeded; I/O goes through the `SalsaStream` built from its ciphers.
    Streaming(SalsaStream<S>),
}
/// A server-side connection whose handshake runs lazily on the first read or
/// write, after which I/O is forwarded to the resulting `SalsaStream`.
pub struct ServerStream<S> {
    inner: Inner<S>,
    remote_addr: Option<SocketAddr>,
}
impl<S> ServerStream<S>
where S: AsyncRead + AsyncWrite + Unpin
{
    /// Wraps a not-yet-run handshake.
    ///
    /// `remote_addr` accepts either a `SocketAddr` or an `Option<SocketAddr>`,
    /// e.g. `None` when the peer address is unknown.
    pub fn new<A>(handshake: Accept<S>, remote_addr: A) -> Self
    where A: Into<Option<SocketAddr>>,
    {
        let inner = Inner::Handshaking(handshake);
        let remote_addr: Option<SocketAddr> = remote_addr.into();
        ServerStream { inner, remote_addr }
    }
    /// Returns the peer address given at construction, if any.
    pub fn remote_addr(&self) -> Option<&SocketAddr> {
        Option::as_ref(&self.remote_addr)
    }
}
impl<S> AsyncRead for ServerStream<S>
where S: AsyncRead + AsyncWrite + Unpin
{
    /// Reads from the session stream, first completing the handshake if needed.
    ///
    /// While the handshake is pending, this call drives it. Once it succeeds,
    /// the read is forwarded to the resulting `SalsaStream`, which is kept for
    /// all later I/O.
    ///
    /// # Errors
    ///
    /// - Any handshake or stream error.
    /// - `NotConnected` if the handshake already failed on an earlier call.
    ///   Re-polling the finished `Accept` future would panic, so this error is
    ///   returned instead.
    fn poll_read(
        self: Pin<&mut Self>,
        ctx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let pin = self.get_mut();
        match pin.inner {
            Inner::Handshaking(ref mut handshake) => {
                // A handshake that terminated without being swapped for a
                // stream must have failed; `Accept` panics if polled again.
                if handshake.is_terminated() {
                    return Poll::Ready(Err(io::Error::new(
                        io::ErrorKind::NotConnected, "Handshake failed"
                    )));
                }
                match ready!(Pin::new(handshake).poll(ctx)) {
                    Ok(mut stream) => {
                        let result = Pin::new(&mut stream).poll_read(ctx, buf);
                        pin.inner = Inner::Streaming(stream);
                        result
                    },
                    Err(e) => Poll::Ready(Err(e)),
                }
            },
            Inner::Streaming(ref mut stream) => {
                Pin::new(stream).poll_read(ctx, buf)
            },
        }
    }
}
impl<S> AsyncWrite for ServerStream<S>
where S: AsyncRead + AsyncWrite + Unpin
{
    /// Writes to the session stream, first completing the handshake if needed.
    ///
    /// While the handshake is pending, this call drives it. Once it succeeds,
    /// the write is forwarded to the resulting `SalsaStream`, which is kept for
    /// all later I/O.
    ///
    /// # Errors
    ///
    /// - Any handshake or stream error.
    /// - `NotConnected` if the handshake already failed on an earlier call.
    ///   Re-polling the finished `Accept` future would panic, so this error is
    ///   returned instead.
    fn poll_write(
        self: Pin<&mut Self>,
        ctx: &mut Context<'_>,
        buf: &[u8]
    ) -> Poll<io::Result<usize>> {
        let pin = self.get_mut();
        match pin.inner {
            Inner::Handshaking(ref mut handshake) => {
                // A handshake that terminated without being swapped for a
                // stream must have failed; `Accept` panics if polled again.
                if handshake.is_terminated() {
                    return Poll::Ready(Err(io::Error::new(
                        io::ErrorKind::NotConnected, "Handshake failed"
                    )));
                }
                match ready!(Pin::new(handshake).poll(ctx)) {
                    Ok(mut stream) => {
                        let result = Pin::new(&mut stream).poll_write(ctx, buf);
                        pin.inner = Inner::Streaming(stream);
                        result
                    },
                    Err(e) => Poll::Ready(Err(e)),
                }
            },
            Inner::Streaming(ref mut stream) => Pin::new(stream).poll_write(ctx, buf),
        }
    }
    /// Flushes the session stream.
    ///
    /// While the handshake has not completed, this returns `Ok(())` without
    /// touching the underlying stream.
    fn poll_flush(
        mut self: Pin<&mut Self>,
        ctx: &mut Context<'_>,
    ) -> Poll<io::Result<()>> {
        match self.inner {
            Inner::Handshaking(_) => Poll::Ready(Ok(())),
            Inner::Streaming(ref mut stream) => Pin::new(stream).poll_flush(ctx),
        }
    }
    /// Shuts down the session stream.
    ///
    /// NOTE(review): while the handshake has not completed, this returns
    /// `Ok(())` without shutting down the underlying stream, because the stream
    /// is owned by the handshake future. Confirm callers don't rely on it
    /// closing the socket.
    fn poll_shutdown(
        mut self: Pin<&mut Self>,
        ctx: &mut Context<'_>,
    ) -> Poll<io::Result<()>> {
        match self.inner {
            Inner::Handshaking(_) => Poll::Ready(Ok(())),
            Inner::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(ctx),
        }
    }
}
/// Turns a stream of incoming connections into a stream of `ServerStream`s
/// that perform the handshake lazily.
pub struct StreamAcceptor<S> {
    acceptor: Acceptor,
    incoming: S,
}
impl<S> StreamAcceptor<S> {
    /// Wraps `incoming`; every connection it yields gets a handshake started by `acceptor`.
    pub fn new(acceptor: Acceptor, incoming: S) -> Self {
        StreamAcceptor { incoming, acceptor }
    }
}
impl<S, T> Stream for StreamAcceptor<S>
where S: Stream<Item = io::Result<T>> + Unpin,
      T: AsyncRead + AsyncWrite + Unpin,
{
    type Item = io::Result<ServerStream<T>>;
    /// Yields one `ServerStream` per incoming connection, with no remote address.
    ///
    /// Errors from the incoming stream are passed through unchanged, and the
    /// stream ends when the incoming stream ends. The handshake itself only
    /// runs when the yielded `ServerStream` is first read from or written to.
    fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        match ready!(Pin::new(&mut this.incoming).poll_next(ctx)) {
            None => Poll::Ready(None),
            Some(Err(err)) => Poll::Ready(Some(Err(err))),
            Some(Ok(connection)) => {
                let handshake = this.acceptor.accept(connection);
                Poll::Ready(Some(Ok(ServerStream::new(handshake, None))))
            },
        }
    }
}
/// Accepts TCP connections from a listener and wraps each in a `ServerStream`.
pub struct TcpListenAcceptor {
    acceptor: Acceptor,
    listener: TcpListener,
}
impl TcpListenAcceptor {
    /// Pairs a TCP listener with the acceptor used for every connection's handshake.
    pub fn new(acceptor: Acceptor, listener: TcpListener) -> Self {
        TcpListenAcceptor { listener, acceptor }
    }
}
impl Stream for TcpListenAcceptor {
    type Item = io::Result<ServerStream<TcpStream>>;
    /// Accepts the next TCP connection and wraps it in a `ServerStream` that
    /// records the peer address.
    ///
    /// Accept errors are yielded as `Some(Err(..))`; this stream never yields `None`.
    fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        let accepted = ready!(this.listener.poll_accept(ctx));
        let item = accepted.map(|(socket, peer)| {
            ServerStream::new(this.acceptor.accept(socket), peer)
        });
        Poll::Ready(Some(item))
    }
}
#[cfg(test)]
mod tests {
    use base64::decode_config_slice;
    use super::*;
    const B64KEY: &str = "9gU7mziAmlCzlyhmUHq9LFbKXXlqpjvD4LuJtOv89Ik=";
    /// Decodes a standard-alphabet base64 string into a 32-byte Salsa20 key.
    fn to_key(b64: &str) -> Key {
        let mut raw = [0u8; 32];
        decode_config_slice(b64, base64::STANDARD, &mut raw).unwrap();
        Key::clone_from_slice(&raw)
    }
    #[test]
    fn salsa_battery_produces_right_number_of_candidates() {
        // 10 seconds before plus 10 after, end exclusive: 20 ciphers.
        let battery = salsa_battery(&to_key(B64KEY), 10000, 10, 10);
        assert_eq!(battery.len(), 20);
    }
    #[test]
    fn check_check_handshake() {
        let key = to_key(B64KEY);
        let battery = salsa_battery(&key, 10000, 10, 10);
        // Encrypt with the nonce for timestamp 10000, which is inside the window.
        let nonce_bytes: [u8; 8] = 10_000u64.to_le_bytes();
        let nonce = Nonce::from_slice(&nonce_bytes);
        let check = "check_val";
        let mut plaintext: Vec<u8> = (0..10).collect();
        plaintext.extend_from_slice(check.as_bytes());
        let mut handshake = plaintext.clone();
        Salsa20::new(&key, nonce).apply_keystream(handshake.as_mut_slice());
        let decrypted = check_handshake(battery, &handshake, check.as_bytes(), 10)
            .expect("a cipher in the battery should decrypt the handshake");
        assert_eq!(decrypted, plaintext);
    }
}