use std::{
cmp,
convert::TryInto,
io,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use futures::ready;
use log::*;
use snow::{error::StateProblem, HandshakeState, TransportState};
use tari_utilities::ByteArray;
use tokio::{
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf},
time,
};
use crate::types::CommsPublicKey;
/// `log` target used by every message this module emits.
const LOG_TARGET: &str = "comms::noise::socket";
/// Largest frame this socket reads or writes; the 2-byte length prefix caps it at `u16::MAX`.
const MAX_PAYLOAD_LENGTH: usize = u16::MAX as usize;
/// Plaintext bytes buffered before a frame is forced out. Kept 16 bytes below
/// `MAX_PAYLOAD_LENGTH` (presumably headroom for the authentication tag Noise appends when
/// encrypting) so the encrypted frame still fits the u16 length prefix.
const MAX_WRITE_BUFFER_LENGTH: usize = MAX_PAYLOAD_LENGTH - 16;
/// Fixed-size scratch space used by `NoiseSocket` for framing and encryption.
///
/// `NoiseSocket` holds this behind a `Box` (its `buffers` field), presumably so the roughly
/// 256 KiB of arrays does not live inline in the socket.
struct NoiseBuffers {
    /// Ciphertext of the frame currently being read from the socket.
    read_encrypted: [u8; MAX_PAYLOAD_LENGTH],
    /// Output of processing `read_encrypted`; handed out to readers.
    read_decrypted: [u8; MAX_PAYLOAD_LENGTH],
    /// Plaintext accumulated from writers until a frame is emitted.
    write_decrypted: [u8; MAX_WRITE_BUFFER_LENGTH],
    /// Ciphertext of the frame currently being written to the socket.
    write_encrypted: [u8; MAX_PAYLOAD_LENGTH],
}
impl NoiseBuffers {
    /// Creates all four scratch buffers, zero-filled.
    fn new() -> Self {
        // `[u8; N]` is `Copy`, so one zeroed template can seed the three full-size buffers.
        let zeroed_frame = [0u8; MAX_PAYLOAD_LENGTH];
        Self {
            read_encrypted: zeroed_frame,
            read_decrypted: zeroed_frame,
            write_encrypted: zeroed_frame,
            write_decrypted: [0u8; MAX_WRITE_BUFFER_LENGTH],
        }
    }
}
impl std::fmt::Debug for NoiseBuffers {
    /// Prints only the type name. The buffer contents (which include decrypted plaintext) are
    /// left out, presumably to keep logs small and free of message data.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("NoiseBuffers")
    }
}
/// Progress of `NoiseSocket`'s read side; kept across `Poll::Pending` returns so reads resume.
#[derive(Debug)]
enum ReadState {
    /// Idle: the next read starts a new frame.
    Init,
    /// Reading the 2-byte big-endian frame length; `offset` bytes of `buf` are filled so far.
    ReadFrameLen { buf: [u8; 2], offset: usize },
    /// Reading `frame_len` bytes of ciphertext; `offset` bytes received so far.
    ReadFrame { frame_len: u16, offset: usize },
    /// Handing out a processed frame; `offset` of `decrypted_len` bytes already copied out.
    CopyDecryptedFrame { decrypted_len: usize, offset: usize },
    /// Terminal. `Ok(())`: the stream ended cleanly before a new frame began.
    /// `Err(())`: the stream ended part-way through a frame.
    Eof(Result<(), ()>),
    /// Terminal: a frame failed to decrypt; every later read reports this error.
    DecryptionError(snow::Error),
}
/// Progress of `NoiseSocket`'s write side; kept across `Poll::Pending` returns so writes resume.
#[derive(Debug)]
enum WriteState {
    /// Idle: nothing buffered.
    Init,
    /// Accumulating plaintext; `offset` bytes of the plaintext write buffer are filled.
    BufferData { offset: usize },
    /// Writing the 2-byte big-endian length (`buf`) of a `frame_len`-byte encrypted frame.
    WriteFrameLen {
        frame_len: u16,
        buf: [u8; 2],
        offset: usize,
    },
    /// Writing the `frame_len`-byte ciphertext; `offset` bytes sent so far.
    WriteEncryptedFrame { frame_len: u16, offset: usize },
    /// Frame fully written; flushing the underlying socket.
    Flush,
    /// Terminal: the socket accepted zero bytes; all further writes fail with `WriteZero`.
    Eof,
    /// Terminal: encryption failed; all further writes and flushes return this error.
    EncryptionError(snow::Error),
}
/// An `AsyncRead`/`AsyncWrite` wrapper that frames traffic and runs it through a Noise session.
///
/// The same type is used during the handshake (`state` holds a handshake state) and afterwards
/// in transport mode; `Handshake` drives the transition.
#[derive(Debug)]
pub struct NoiseSocket<TSocket> {
    /// The underlying transport.
    socket: TSocket,
    /// Handshake or transport cipher state.
    state: NoiseState,
    /// Scratch buffers for framing and encryption (boxed, presumably to keep this struct small).
    buffers: Box<NoiseBuffers>,
    /// Read-side state machine progress.
    read_state: ReadState,
    /// Write-side state machine progress.
    write_state: WriteState,
}
impl<TSocket> NoiseSocket<TSocket> {
    /// Wraps `socket` with idle read/write state machines and zeroed buffers.
    ///
    /// `session` may be in either handshake or transport mode.
    fn new(socket: TSocket, session: NoiseState) -> Self {
        let buffers = Box::new(NoiseBuffers::new());
        Self {
            read_state: ReadState::Init,
            write_state: WriteState::Init,
            state: session,
            buffers,
            socket,
        }
    }

    /// Raw static public key of the remote party, or `None` when the session does not have it.
    pub fn get_remote_static(&self) -> Option<&[u8]> {
        self.state.get_remote_static()
    }

    /// Remote static key parsed as a `CommsPublicKey`.
    ///
    /// Returns `None` when the key is unavailable or is not a canonical encoding.
    pub fn get_remote_public_key(&self) -> Option<CommsPublicKey> {
        let remote_static = self.get_remote_static()?;
        CommsPublicKey::from_canonical_bytes(remote_static).ok()
    }
}
/// Drives `socket` until every byte of `buf[*offset..]` has been written.
///
/// `offset` is a resumable cursor that advances after each partial write. After a
/// `Poll::Pending`, the caller resumes by calling again with the same `buf`/`offset`.
///
/// # Errors
/// - `InvalidInput` if `*offset > buf.len()`.
/// - `WriteZero` if the socket accepts zero bytes.
/// - Any error returned by the socket's `poll_write`.
///
/// # Panics
/// If the socket reports writing more bytes than it was offered.
fn poll_write_all<TSocket>(
    context: &mut Context,
    mut socket: Pin<&mut TSocket>,
    buf: &[u8],
    offset: &mut usize,
) -> Poll<io::Result<()>>
where
    TSocket: AsyncWrite,
{
    let total = buf.len();
    loop {
        let Some(pending) = buf.get(*offset..) else {
            return Poll::Ready(Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "Offset exceeds buffer length",
            )));
        };
        let written = ready!(socket.as_mut().poll_write(context, pending))?;
        trace!(
            target: LOG_TARGET,
            "poll_write_all: wrote {}/{} bytes",
            *offset + written,
            total
        );
        if written == 0 {
            return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
        }
        *offset += written;
        assert!(*offset <= total);
        if *offset == total {
            return Poll::Ready(Ok(()));
        }
    }
}
/// Reads a 2-byte big-endian frame length into `buf`, resuming from `*offset`.
///
/// Returns `Ok(Some(len))` once both bytes are read. Returns `Ok(None)` if the stream ends
/// before any byte of the length arrives, which is a clean end-of-stream.
///
/// # Errors
/// `UnexpectedEof` if the stream ends after one byte, or any error from `poll_read_exact`.
fn poll_read_u16frame_len<TSocket>(
    context: &mut Context,
    socket: Pin<&mut TSocket>,
    buf: &mut [u8; 2],
    offset: &mut usize,
) -> Poll<io::Result<Option<u16>>>
where
    TSocket: AsyncRead,
{
    let outcome = ready!(poll_read_exact(context, socket, buf, offset));
    Poll::Ready(match outcome {
        Ok(()) => Ok(Some(u16::from_be_bytes(*buf))),
        // Nothing at all was read: the peer closed between frames.
        Err(e) if *offset == 0 && e.kind() == io::ErrorKind::UnexpectedEof => Ok(None),
        Err(e) => Err(e),
    })
}
/// Reads from `socket` until `buf[*offset..]` is completely filled.
///
/// `offset` is a resumable cursor that advances after every partial read. After a
/// `Poll::Pending`, the caller resumes by calling again with the same `buf`/`offset`.
///
/// # Errors
/// - `InvalidInput` if `*offset > buf.len()`.
/// - `UnexpectedEof` if the socket yields zero bytes before `buf` is full.
/// - Any error from the socket's `poll_read`.
///
/// # Panics
/// If the socket reports filling more than the space it was offered.
fn poll_read_exact<TSocket>(
    context: &mut Context,
    mut socket: Pin<&mut TSocket>,
    buf: &mut [u8],
    offset: &mut usize,
) -> Poll<io::Result<()>>
where
    TSocket: AsyncRead,
{
    let total = buf.len();
    loop {
        let Some(unfilled) = buf.get_mut(*offset..) else {
            return Poll::Ready(Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "Offset exceeds buffer length",
            )));
        };
        let mut read_buf = ReadBuf::new(unfilled);
        let capacity_before = read_buf.remaining();
        ready!(socket.as_mut().poll_read(context, &mut read_buf))?;
        // Bytes read this round = how much the remaining capacity shrank.
        let n_read = capacity_before
            .checked_sub(read_buf.remaining())
            .ok_or_else(|| io::Error::other("buffer underflow: prev_rem < read_buf.remaining()"))?;
        trace!(
            target: LOG_TARGET,
            "poll_read_exact: read {}/{} bytes",
            *offset + n_read,
            total
        );
        if n_read == 0 {
            return Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into()));
        }
        *offset += n_read;
        assert!(*offset <= total);
        if *offset == total {
            return Poll::Ready(Ok(()));
        }
    }
}
impl<TSocket> NoiseSocket<TSocket>
where TSocket: AsyncRead + Unpin
{
#[allow(clippy::too_many_lines)]
fn poll_read(&mut self, context: &mut Context, buf: &mut [u8]) -> Poll<io::Result<usize>> {
loop {
trace!(target: LOG_TARGET, "NoiseSocket ReadState::{:?}", self.read_state);
match self.read_state {
ReadState::Init => {
self.read_state = ReadState::ReadFrameLen { buf: [0, 0], offset: 0 };
},
ReadState::ReadFrameLen {
ref mut buf,
ref mut offset,
} => {
match ready!(poll_read_u16frame_len(context, Pin::new(&mut self.socket), buf, offset)) {
Ok(Some(frame_len)) => {
if frame_len == 0 {
self.read_state = ReadState::Init;
} else {
self.read_state = ReadState::ReadFrame { frame_len, offset: 0 };
}
},
Ok(None) => {
self.read_state = ReadState::Eof(Ok(()));
},
Err(e) => {
if e.kind() == io::ErrorKind::UnexpectedEof {
self.read_state = ReadState::Eof(Err(()));
}
return Poll::Ready(Err(e));
},
}
},
ReadState::ReadFrame {
frame_len,
ref mut offset,
} => {
let bytes = match self.buffers.read_encrypted.get_mut(..(frame_len as usize)) {
Some(bytes) => bytes,
None => {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput,
"frame length exceeds buffer length",
)));
},
};
match ready!(poll_read_exact(context, Pin::new(&mut self.socket), bytes, offset)) {
Ok(()) => match self.state.read_message(bytes, &mut self.buffers.read_decrypted) {
Ok(decrypted_len) => {
self.read_state = ReadState::CopyDecryptedFrame {
decrypted_len,
offset: 0,
};
},
Err(e) => {
warn!(target: LOG_TARGET, "Decryption Error: {e}");
self.read_state = ReadState::DecryptionError(e);
},
},
Err(e) => {
if e.kind() == io::ErrorKind::UnexpectedEof {
self.read_state = ReadState::Eof(Err(()));
}
return Poll::Ready(Err(e));
},
}
},
ReadState::CopyDecryptedFrame {
decrypted_len,
ref mut offset,
} => {
let num_bytes_to_copy = cmp::min(decrypted_len - *offset, buf.len());
let bytes_to_copy = match self.buffers.read_decrypted.get(*offset..(*offset + num_bytes_to_copy)) {
Some(bytes) => bytes,
None => {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Offset exceeds buffer length",
)));
},
};
buf.get_mut(..num_bytes_to_copy)
.expect("this is checked")
.copy_from_slice(bytes_to_copy);
trace!(
target: LOG_TARGET,
"CopyDecryptedFrame: copied {}/{} bytes",
*offset + num_bytes_to_copy,
decrypted_len
);
*offset += num_bytes_to_copy;
if *offset == decrypted_len {
self.read_state = ReadState::Init;
}
return Poll::Ready(Ok(num_bytes_to_copy));
},
ReadState::Eof(Ok(())) => return Poll::Ready(Ok(0)),
ReadState::Eof(Err(())) => return Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into())),
ReadState::DecryptionError(ref e) => {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("DecryptionError: {e}"),
)))
},
}
}
}
}
impl<TSocket> AsyncRead for NoiseSocket<TSocket>
where TSocket: AsyncRead + Unpin
{
fn poll_read(self: Pin<&mut Self>, context: &mut Context, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
let slice = buf.initialize_unfilled();
let n = futures::ready!(self.get_mut().poll_read(context, slice))?;
buf.advance(n);
Poll::Ready(Ok(()))
}
}
impl<TSocket> NoiseSocket<TSocket>
where TSocket: AsyncWrite + Unpin
{
    /// Shared state machine behind `poll_write` (`buf = Some(data)`) and `poll_flush`
    /// (`buf = None`).
    ///
    /// Plaintext is accumulated in `buffers.write_decrypted`. A frame is emitted when that
    /// buffer reaches `MAX_WRITE_BUFFER_LENGTH` bytes or when a flush is requested. Emitting a
    /// frame means:
    /// 1. pass the contents through `NoiseState::write_message` into `buffers.write_encrypted`;
    /// 2. write a 2-byte big-endian length;
    /// 3. write the ciphertext;
    /// 4. flush the underlying socket.
    ///
    /// Returns `Ok(Some(n))` once `n` bytes of `buf` have been buffered (write path). Returns
    /// `Ok(None)` once everything buffered has been written and flushed (flush path).
    ///
    /// A flush issued while in `BufferData` always emits a frame, even with zero bytes
    /// buffered. The handshake relies on this: it calls `write(&[])` followed by `flush()` to
    /// send each message.
    ///
    /// # Errors
    /// - `InvalidData` if encryption fails. This is sticky: later calls return the same error.
    /// - `WriteZero` if the socket stops accepting bytes. This is also sticky.
    /// - `Other` if the encrypted length does not fit in a `u16`.
    /// - Any other I/O error from the underlying socket.
    #[allow(clippy::too_many_lines)]
    fn poll_write_or_flush(&mut self, context: &mut Context, buf: Option<&[u8]>) -> Poll<io::Result<Option<usize>>> {
        loop {
            trace!(
                target: LOG_TARGET,
                "NoiseSocket {} WriteState::{:?}",
                if buf.is_some() { "poll_write" } else { "poll_flush" },
                self.write_state,
            );
            match self.write_state {
                WriteState::Init => {
                    // Nothing is buffered, so a flush is already complete.
                    if buf.is_some() {
                        self.write_state = WriteState::BufferData { offset: 0 };
                    } else {
                        return Poll::Ready(Ok(None));
                    }
                },
                WriteState::BufferData { ref mut offset } => {
                    // Write path: copy as much of `buf` as fits in the remaining plaintext space.
                    let bytes_buffered = if let Some(buf) = buf {
                        let num_bytes_to_copy = ::std::cmp::min(MAX_WRITE_BUFFER_LENGTH - *offset, buf.len());
                        let bytes = match buf.get(..num_bytes_to_copy) {
                            Some(bytes) => bytes,
                            None => {
                                return Poll::Ready(Err(io::Error::new(
                                    io::ErrorKind::InvalidInput,
                                    "frame length exceeds buffer length",
                                )));
                            },
                        };
                        self.buffers
                            .write_decrypted
                            .get_mut(*offset..(*offset + num_bytes_to_copy))
                            .expect("this is checked")
                            .copy_from_slice(bytes);
                        trace!(
                            target: LOG_TARGET,
                            "BufferData: buffered {}/{} bytes",
                            num_bytes_to_copy,
                            buf.len()
                        );
                        *offset += num_bytes_to_copy;
                        Some(num_bytes_to_copy)
                    } else {
                        None
                    };
                    // Seal a frame when flushing or when the plaintext buffer is full.
                    if buf.is_none() || *offset == MAX_WRITE_BUFFER_LENGTH {
                        let bytes = match self.buffers.write_decrypted.get(..*offset) {
                            Some(bytes) => bytes,
                            None => {
                                return Poll::Ready(Err(io::Error::new(
                                    io::ErrorKind::InvalidInput,
                                    "frame length exceeds buffer length",
                                )));
                            },
                        };
                        match self.state.write_message(bytes, &mut self.buffers.write_encrypted) {
                            Ok(encrypted_len) => {
                                // The wire length prefix is a u16.
                                let frame_len = encrypted_len
                                    .try_into()
                                    .map_err(|_| io::Error::other("offset should be able to fit in u16"))?;
                                self.write_state = WriteState::WriteFrameLen {
                                    frame_len,
                                    buf: u16::to_be_bytes(frame_len),
                                    offset: 0,
                                };
                            },
                            Err(e) => {
                                warn!(target: LOG_TARGET, "Encryption Error: {e}");
                                let err = io::Error::new(io::ErrorKind::InvalidData, format!("EncryptionError: {e}"));
                                self.write_state = WriteState::EncryptionError(e);
                                return Poll::Ready(Err(err));
                            },
                        }
                    }
                    // On the write path, report progress right away even if a frame was just
                    // sealed; that frame is written out on a later poll. On the flush path,
                    // keep looping to write the frame.
                    if let Some(bytes_buffered) = bytes_buffered {
                        return Poll::Ready(Ok(Some(bytes_buffered)));
                    }
                },
                WriteState::WriteFrameLen {
                    frame_len,
                    ref buf,
                    ref mut offset,
                } => match ready!(poll_write_all(context, Pin::new(&mut self.socket), buf, offset)) {
                    Ok(()) => {
                        self.write_state = WriteState::WriteEncryptedFrame { frame_len, offset: 0 };
                    },
                    Err(e) => {
                        // A zero-byte write means the socket no longer accepts data; stop for good.
                        if e.kind() == io::ErrorKind::WriteZero {
                            self.write_state = WriteState::Eof;
                        }
                        return Poll::Ready(Err(e));
                    },
                },
                WriteState::WriteEncryptedFrame {
                    frame_len,
                    ref mut offset,
                } => {
                    let bytes = match self.buffers.write_encrypted.get(..(frame_len as usize)) {
                        Some(bytes) => bytes,
                        None => {
                            return Poll::Ready(Err(io::Error::new(
                                io::ErrorKind::InvalidInput,
                                "frame length exceeds buffer length",
                            )));
                        },
                    };
                    match ready!(poll_write_all(context, Pin::new(&mut self.socket), bytes, offset)) {
                        Ok(()) => {
                            self.write_state = WriteState::Flush;
                        },
                        Err(e) => {
                            if e.kind() == io::ErrorKind::WriteZero {
                                self.write_state = WriteState::Eof;
                            }
                            return Poll::Ready(Err(e));
                        },
                    }
                },
                WriteState::Flush => {
                    // Each frame is flushed to the underlying socket before going idle again.
                    ready!(Pin::new(&mut self.socket).poll_flush(context))?;
                    self.write_state = WriteState::Init;
                },
                WriteState::Eof => return Poll::Ready(Err(io::ErrorKind::WriteZero.into())),
                WriteState::EncryptionError(ref e) => {
                    return Poll::Ready(Err(io::Error::new(
                        io::ErrorKind::InvalidData,
                        format!("EncryptionError: {e}"),
                    )))
                },
            }
        }
    }

    /// Buffers up to `MAX_WRITE_BUFFER_LENGTH` bytes of `buf` and returns how many were taken.
    fn poll_write(&mut self, context: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
        if let Some(bytes_written) = ready!(self.poll_write_or_flush(context, Some(buf)))? {
            Poll::Ready(Ok(bytes_written))
        } else {
            // `poll_write_or_flush` only yields `Ok(None)` when called with `buf == None`.
            unreachable!();
        }
    }

    /// Encrypts and writes any buffered plaintext, then flushes the underlying socket.
    fn poll_flush(&mut self, context: &mut Context) -> Poll<io::Result<()>> {
        if ready!(self.poll_write_or_flush(context, None))?.is_none() {
            Poll::Ready(Ok(()))
        } else {
            // `poll_write_or_flush` only yields `Ok(Some(_))` when called with `buf == Some(_)`.
            unreachable!();
        }
    }
}
impl<TSocket> AsyncWrite for NoiseSocket<TSocket>
where TSocket: AsyncWrite + Unpin
{
fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
self.get_mut().poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
self.get_mut().poll_flush(cx)
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
Pin::new(&mut self.socket).poll_shutdown(cx)
}
}
/// Drives a Noise handshake over a socket and yields a transport-mode `NoiseSocket`.
pub struct Handshake<TSocket> {
    /// Socket whose Noise state is still in handshake mode.
    socket: NoiseSocket<TSocket>,
    /// Upper bound on each wait for a handshake message from the peer.
    recv_timeout: Duration,
}
impl<TSocket> Handshake<TSocket> {
pub fn new(socket: TSocket, state: HandshakeState, recv_timeout: Duration) -> Self {
Self {
socket: NoiseSocket::new(socket, state.into()),
recv_timeout,
}
}
}
impl<TSocket> Handshake<TSocket>
where TSocket: AsyncRead + AsyncWrite + Unpin
{
pub async fn perform_handshake(mut self) -> io::Result<NoiseSocket<TSocket>> {
match self.handshake_1_5rtt().await {
Ok(_) => self.build(),
Err(err) => {
info!(
target: LOG_TARGET,
"Noise handshake failed because '{err:?}'. Closing socket."
);
self.socket.shutdown().await?;
Err(err)
},
}
}
async fn handshake_1_5rtt(&mut self) -> io::Result<()> {
if self.socket.state.is_initiator() {
self.send().await?;
self.flush().await?;
self.receive().await?;
self.send().await?;
self.flush().await?;
} else {
self.receive().await?;
self.send().await?;
self.flush().await?;
self.receive().await?;
}
Ok(())
}
async fn send(&mut self) -> io::Result<usize> {
self.socket.write(&[]).await
}
async fn flush(&mut self) -> io::Result<()> {
self.socket.flush().await
}
async fn receive(&mut self) -> io::Result<usize> {
time::timeout(self.recv_timeout, self.socket.read(&mut []))
.await
.map_err(|_| io::Error::from(io::ErrorKind::TimedOut))?
}
fn build(self) -> io::Result<NoiseSocket<TSocket>> {
let transport_state = self
.socket
.state
.into_transport_mode()
.map_err(|err| io::Error::other(format!("Invalid snow state: {err}")))?;
Ok(NoiseSocket {
state: transport_state,
..self.socket
})
}
}
/// snow session state: either still handshaking or in transport (post-handshake) mode.
///
/// Both variants are boxed.
#[derive(Debug)]
enum NoiseState {
    /// Handshake in progress.
    HandshakeState(Box<HandshakeState>),
    /// Handshake complete; used to encrypt and decrypt application data.
    TransportState(Box<TransportState>),
}
/// Generates a `NoiseState` method that forwards to the same-named method on whichever inner
/// state (handshake or transport) is active.
///
/// There are two arms: one for `&mut self` methods and one for `&self` methods. Argument names
/// and types are repeated verbatim in the generated signature and in the forwarded call.
macro_rules! proxy_state_method {
    // `&mut self` receiver.
    (pub fn $name:ident(&mut self$(,)? $($arg_name:ident : $arg_type:ty),*) -> $ret:ty) => {
        pub fn $name(&mut self, $($arg_name:$arg_type),*) -> $ret {
            match self {
                NoiseState::HandshakeState(state) => state.$name($($arg_name),*),
                NoiseState::TransportState(state) => state.$name($($arg_name),*),
            }
        }
    };
    // `&self` receiver.
    (pub fn $name:ident(&self$(,)? $($arg_name:ident : $arg_type:ty),*) -> $ret:ty) => {
        pub fn $name(&self, $($arg_name:$arg_type),*) -> $ret {
            match self {
                NoiseState::HandshakeState(state) => state.$name($($arg_name),*),
                NoiseState::TransportState(state) => state.$name($($arg_name),*),
            }
        }
    }
}
impl NoiseState {
proxy_state_method!(pub fn write_message(&mut self, message: &[u8], payload: &mut [u8]) -> Result<usize, snow::Error>);
proxy_state_method!(pub fn is_initiator(&self) -> bool);
proxy_state_method!(pub fn read_message(&mut self, message: &[u8], payload: &mut [u8]) -> Result<usize, snow::Error>);
proxy_state_method!(pub fn get_remote_static(&self) -> Option<&[u8]>);
pub fn into_transport_mode(self) -> Result<Self, snow::Error> {
match self {
NoiseState::HandshakeState(state) => Ok(NoiseState::TransportState(Box::new(state.into_transport_mode()?))),
_ => Err(snow::Error::State(StateProblem::HandshakeAlreadyFinished)),
}
}
}
impl From<HandshakeState> for NoiseState {
fn from(state: HandshakeState) -> Self {
NoiseState::HandshakeState(Box::new(state))
}
}
impl From<TransportState> for NoiseState {
fn from(state: TransportState) -> Self {
NoiseState::TransportState(Box::new(state))
}
}
// Tests that drive two `NoiseSocket`s connected by an in-memory socket pair.
#[cfg(test)]
mod test {
    use futures::future::join;
    use snow::{params::NoiseParams, Builder, Error, Keypair};

    use super::*;
    use crate::{memsocket::MemorySocket, noise::config::NOISE_PARAMETERS};

    /// Builds an initiator ("dialer") and a responder ("listener") `Handshake` over a connected
    /// `MemorySocket` pair. Each side gets a freshly generated static keypair and a 1s receive
    /// timeout. The keypairs are returned so tests can check the remote static keys.
    async fn build_test_connection(
    ) -> Result<((Keypair, Handshake<MemorySocket>), (Keypair, Handshake<MemorySocket>)), Error> {
        let parameters: NoiseParams = NOISE_PARAMETERS.parse().expect("Invalid protocol name");
        let dialer_keypair = Builder::new(parameters.clone()).generate_keypair()?;
        let listener_keypair = Builder::new(parameters.clone()).generate_keypair()?;
        let dialer_session = Builder::new(parameters.clone())
            .local_private_key(&dialer_keypair.private)
            .build_initiator()?;
        let listener_session = Builder::new(parameters)
            .local_private_key(&listener_keypair.private)
            .build_responder()?;
        let (dialer_socket, listener_socket) = MemorySocket::new_pair();
        let (dialer, listener) = (
            NoiseSocket::new(dialer_socket, dialer_session.into()),
            NoiseSocket::new(listener_socket, listener_session.into()),
        );
        Ok((
            (dialer_keypair, Handshake {
                socket: dialer,
                recv_timeout: Duration::from_secs(1),
            }),
            (listener_keypair, Handshake {
                socket: listener,
                recv_timeout: Duration::from_secs(1),
            }),
        ))
    }

    /// Runs both sides of the handshake concurrently and returns the transport-mode sockets.
    async fn perform_handshake(
        dialer: Handshake<MemorySocket>,
        listener: Handshake<MemorySocket>,
    ) -> io::Result<(NoiseSocket<MemorySocket>, NoiseSocket<MemorySocket>)> {
        let (dialer_result, listener_result) = join(dialer.perform_handshake(), listener.perform_handshake()).await;
        Ok((dialer_result?, listener_result?))
    }

    /// After the handshake, each side knows the other's static public key.
    #[tokio::test]
    async fn test_handshake() {
        let ((dialer_keypair, dialer), (listener_keypair, listener)) = build_test_connection().await.unwrap();
        let (dialer_socket, listener_socket) = perform_handshake(dialer, listener).await.unwrap();
        assert_eq!(
            dialer_socket.get_remote_static(),
            Some(listener_keypair.public.as_ref())
        );
        assert_eq!(
            listener_socket.get_remote_static(),
            Some(dialer_keypair.public.as_ref())
        );
    }

    /// Several small writes, one flush, then shutdown: the reader sees the concatenation
    /// followed by EOF.
    #[tokio::test]
    async fn simple_test() -> io::Result<()> {
        let ((_dialer_keypair, dialer), (_listener_keypair, listener)) = build_test_connection().await.unwrap();
        let (mut dialer_socket, mut listener_socket) = perform_handshake(dialer, listener).await?;
        dialer_socket.write_all(b"stormlight").await?;
        dialer_socket.write_all(b" ").await?;
        dialer_socket.write_all(b"archive").await?;
        dialer_socket.flush().await?;
        dialer_socket.shutdown().await?;
        let mut buf = Vec::new();
        listener_socket.read_to_end(&mut buf).await?;
        assert_eq!(buf, b"stormlight archive");
        Ok(())
    }

    /// Both directions carry independent frames; separate flushes produce separate frames that
    /// are read back in order.
    #[tokio::test]
    async fn interleaved_writes() -> io::Result<()> {
        let ((_dialer_keypair, dialer), (_listener_keypair, listener)) = build_test_connection().await.unwrap();
        let (mut a, mut b) = perform_handshake(dialer, listener).await?;
        a.write_all(b"The Name of the Wind").await?;
        a.flush().await?;
        a.write_all(b"The Wise Man's Fear").await?;
        a.flush().await?;
        b.write_all(b"The Doors of Stone").await?;
        b.flush().await?;
        let mut buf = [0; 20];
        b.read_exact(&mut buf).await?;
        assert_eq!(&buf, b"The Name of the Wind");
        let mut buf = [0; 19];
        b.read_exact(&mut buf).await?;
        assert_eq!(&buf, b"The Wise Man's Fear");
        let mut buf = [0; 18];
        a.read_exact(&mut buf).await?;
        assert_eq!(&buf, b"The Doors of Stone");
        Ok(())
    }

    /// A payload one byte larger than `MAX_PAYLOAD_LENGTH` must be split across frames and
    /// reassembled intact.
    #[tokio::test]
    async fn u16_max_writes() -> io::Result<()> {
        let ((_dialer_keypair, dialer), (_listener_keypair, listener)) = build_test_connection().await.unwrap();
        let (mut a, mut b) = perform_handshake(dialer, listener).await?;
        let buf_send = &[1; MAX_PAYLOAD_LENGTH + 1];
        a.write_all(buf_send).await?;
        a.flush().await?;
        let mut buf_receive = vec![0; MAX_PAYLOAD_LENGTH + 1];
        b.read_exact(&mut buf_receive).await?;
        assert_eq!(&buf_receive[..], &buf_send[..]);
        Ok(())
    }

    /// A payload spanning several frames is reassembled intact.
    #[tokio::test]
    async fn larger_writes() -> io::Result<()> {
        let ((_dialer_keypair, dialer), (_listener_keypair, listener)) = build_test_connection().await.unwrap();
        let (mut a, mut b) = perform_handshake(dialer, listener).await?;
        let buf_send = &[1; MAX_PAYLOAD_LENGTH * 2 + 1024];
        a.write_all(buf_send).await?;
        a.flush().await?;
        let mut buf_receive = vec![0; MAX_PAYLOAD_LENGTH * 2 + 1024];
        b.read_exact(&mut buf_receive).await?;
        assert_eq!(&buf_receive[..], &buf_send[..]);
        Ok(())
    }

    /// After all sent data is consumed, asking for more than the peer sent yields
    /// `UnexpectedEof`.
    #[tokio::test]
    async fn unexpected_eof() -> io::Result<()> {
        let ((_dialer_keypair, dialer), (_listener_keypair, listener)) = build_test_connection().await.unwrap();
        let (mut a, mut b) = perform_handshake(dialer, listener).await?;
        let buf_send = &[1; MAX_PAYLOAD_LENGTH];
        a.write_all(buf_send).await?;
        a.flush().await?;
        // Shut down the raw inner socket (bypassing NoiseSocket) and drop the writer.
        a.socket.shutdown().await.unwrap();
        drop(a);
        let mut buf_receive = vec![0; MAX_PAYLOAD_LENGTH];
        b.read_exact(&mut buf_receive).await.unwrap();
        assert_eq!(&buf_receive[..], &buf_send[..]);
        let err = b.read_exact(&mut buf_receive).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
        Ok(())
    }
}