use std::pin::Pin;
use std::task::{Context, Poll};
use std::vec;
use std::vec::Vec;
use bitcoin::Network;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use crate::{
handshake::{self, GarbageResult, VersionResult},
io::{Payload, ProtocolError},
Error, Handshake, InboundCipher, OutboundCipher, Role, MAX_PACKET_SIZE_FOR_ALLOCATION,
NUM_ELLIGATOR_SWIFT_BYTES, NUM_GARBAGE_TERMINTOR_BYTES, NUM_LENGTH_BYTES,
};
/// Async reader handed back by [`handshake`] for use after the handshake completes.
///
/// While scanning for the garbage terminator, the handshake may read more bytes from the
/// underlying reader than it consumes. Those bytes are stored in `leftover`, and
/// `poll_read` replays them first. Once they are used up, reads go straight to `reader`.
pub struct ProtocolSessionReader<R> {
    // Bytes already pulled off the wire during the handshake that belong to the session stream.
    leftover: Vec<u8>,
    // Index of the next not-yet-returned byte in `leftover`.
    leftover_pos: usize,
    // The wrapped underlying reader.
    reader: R,
}
impl<R> ProtocolSessionReader<R> {
    /// Wraps `reader` so that every byte of `leftover` is yielded before any byte of `reader`.
    fn new(leftover: Vec<u8>, reader: R) -> Self {
        // Nothing has been replayed yet, so start at the front of the leftover buffer.
        let leftover_pos = 0;
        Self {
            reader,
            leftover,
            leftover_pos,
        }
    }
}
impl<R: AsyncRead + Unpin> AsyncRead for ProtocolSessionReader<R> {
    /// Serves buffered leftover bytes first.
    ///
    /// Once the leftover buffer is exhausted, delegates directly to the wrapped reader.
    /// A single call never mixes leftover bytes with bytes from the wrapped reader.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let me = self.get_mut();
        // `leftover_pos` never exceeds `leftover.len()`, so this slice is always valid.
        let pending = &me.leftover[me.leftover_pos..];
        if pending.is_empty() {
            return Pin::new(&mut me.reader).poll_read(cx, buf);
        }
        let n = std::cmp::min(pending.len(), buf.remaining());
        buf.put_slice(&pending[..n]);
        me.leftover_pos += n;
        Poll::Ready(Ok(()))
    }
}
/// Runs the handshake over `reader`/`writer` and returns the session ciphers.
///
/// The steps, as driven by the typestate [`Handshake`] calls below, are:
/// 1. Send our key message, sized to include the optional `garbage`, then flush.
/// 2. Read the peer's `NUM_ELLIGATOR_SWIFT_BYTES`-byte key.
/// 3. Send our version message, including any `decoys`, then flush.
/// 4. Read from `reader` until the handshake reports the peer's garbage terminator.
/// 5. Read encrypted packets until the peer's version packet completes. Decoy packets
///    are skipped.
///
/// Bytes read past the garbage terminator are not lost. The returned
/// [`ProtocolSessionReader`] replays them before reading from `reader` again.
///
/// # Errors
/// - I/O errors from `reader` or `writer`.
/// - `ProtocolError::eof()` if `reader` ends while searching for the garbage terminator.
/// - `ProtocolError::Internal(Error::PacketTooBig)` if a handshake packet's decrypted
///   length exceeds `MAX_PACKET_SIZE_FOR_ALLOCATION`.
/// - Any error reported by the handshake state machine.
pub async fn handshake<R, W>(
    network: Network,
    role: Role,
    garbage: Option<Vec<u8>>,
    decoys: Option<Vec<Vec<u8>>>,
    mut reader: R,
    writer: &mut W,
) -> Result<(InboundCipher, OutboundCipher, ProtocolSessionReader<R>), ProtocolError>
where
    R: AsyncRead + Send + Unpin,
    W: AsyncWrite + Unpin,
{
    // Size of each read while scanning for the garbage terminator.
    const GARBAGE_READ_CHUNK: usize = 256;

    let garbage_ref = garbage.as_deref();
    let decoy_refs: Option<Vec<&[u8]>> = decoys
        .as_ref()
        .map(|vecs| vecs.iter().map(Vec::as_slice).collect());
    let decoys_ref = decoy_refs.as_deref();

    // Step 1: send our key (plus any garbage).
    let handshake = Handshake::<handshake::Initialized>::new(network, role)?;
    let key_buffer_len = Handshake::<handshake::Initialized>::send_key_len(garbage_ref);
    let mut key_buffer = vec![0u8; key_buffer_len];
    let handshake = handshake.send_key(garbage_ref, &mut key_buffer)?;
    writer.write_all(&key_buffer).await?;
    writer.flush().await?;

    // Step 2: receive the peer's key.
    let mut remote_ellswift_buffer = [0u8; NUM_ELLIGATOR_SWIFT_BYTES];
    reader.read_exact(&mut remote_ellswift_buffer).await?;
    let handshake = handshake.receive_key(remote_ellswift_buffer)?;

    // Step 3: send our version message (plus any decoys).
    let version_buffer_len = Handshake::<handshake::ReceivedKey>::send_version_len(decoys_ref);
    let mut version_buffer = vec![0u8; version_buffer_len];
    let handshake = handshake.send_version(&mut version_buffer, decoys_ref)?;
    writer.write_all(&version_buffer).await?;
    writer.flush().await?;

    // Step 4: accumulate bytes until the garbage terminator is located. Any bytes read past
    // it are kept in `garbage_buffer` and replayed by the session reader.
    let mut garbage_buffer = vec![0u8; NUM_GARBAGE_TERMINTOR_BYTES];
    reader.read_exact(&mut garbage_buffer).await?;
    // One reusable scratch buffer instead of a fresh heap allocation per iteration.
    let mut read_chunk = [0u8; GARBAGE_READ_CHUNK];
    let mut handshake = handshake;
    let (mut handshake, garbage_bytes) = loop {
        match handshake.receive_garbage(&garbage_buffer) {
            Ok(GarbageResult::FoundGarbage {
                handshake,
                consumed_bytes,
            }) => {
                break (handshake, consumed_bytes);
            }
            Ok(GarbageResult::NeedMoreData(h)) => {
                handshake = h;
                let n = reader.read(&mut read_chunk).await?;
                if n == 0 {
                    return Err(ProtocolError::eof());
                }
                garbage_buffer.extend_from_slice(&read_chunk[..n]);
            }
            Err(e) => return Err(ProtocolError::Internal(e)),
        }
    };
    let leftover_bytes = garbage_buffer[garbage_bytes..].to_vec();
    let mut session_reader = ProtocolSessionReader::new(leftover_bytes, reader);

    // Step 5: read packets until the version packet completes the handshake.
    let mut length_bytes = [0u8; NUM_LENGTH_BYTES];
    loop {
        session_reader.read_exact(&mut length_bytes).await?;
        let packet_len = handshake.decrypt_packet_len(length_bytes)?;
        // Refuse to allocate for an attacker-controlled, oversized length.
        if packet_len > MAX_PACKET_SIZE_FOR_ALLOCATION {
            return Err(ProtocolError::Internal(Error::PacketTooBig));
        }
        let mut packet_bytes = vec![0u8; packet_len];
        session_reader.read_exact(&mut packet_bytes).await?;
        match handshake.receive_version(&mut packet_bytes) {
            Ok(VersionResult::Complete { cipher }) => {
                let (inbound_cipher, outbound_cipher) = cipher.into_split();
                return Ok((inbound_cipher, outbound_cipher, session_reader));
            }
            // A decoy was consumed; keep reading with the returned state.
            Ok(VersionResult::Decoy(h)) => {
                handshake = h;
            }
            Err(e) => return Err(ProtocolError::Internal(e)),
        }
    }
}
/// An established session that pairs a decrypting reader with an encrypting writer.
///
/// Construct it with `Protocol::new`, which runs [`handshake`]. Use `into_split` to drive
/// the two halves independently.
pub struct Protocol<R, W> {
    // Decrypting half, reading from the post-handshake session reader.
    reader: ProtocolReader<R>,
    // Encrypting half, writing to the caller-supplied writer.
    writer: ProtocolWriter<W>,
}
impl<R, W> Protocol<R, W>
where
R: AsyncRead + Unpin + Send,
W: AsyncWrite + Unpin + Send,
{
pub async fn new(
network: Network,
role: Role,
garbage: Option<Vec<u8>>,
decoys: Option<Vec<Vec<u8>>>,
reader: R,
mut writer: W,
) -> Result<Protocol<R, W>, ProtocolError> {
let (inbound_cipher, outbound_cipher, session_reader) =
handshake(network, role, garbage, decoys, reader, &mut writer).await?;
Ok(Protocol {
reader: ProtocolReader {
inbound_cipher,
reader: session_reader,
state: DecryptState::init_reading_length(),
},
writer: ProtocolWriter {
outbound_cipher,
writer,
},
})
}
pub fn into_split(self) -> (ProtocolReader<R>, ProtocolWriter<W>) {
(self.reader, self.writer)
}
pub async fn read(&mut self) -> Result<Payload, ProtocolError> {
self.reader.read().await
}
pub async fn write(&mut self, payload: &Payload) -> Result<(), ProtocolError> {
self.writer.write(payload).await
}
}
/// Progress of `ProtocolReader::read` through the current inbound packet.
///
/// The state lives on the reader, so partially read bytes survive between calls.
#[derive(Debug)]
enum DecryptState {
    /// Waiting for the encrypted length prefix.
    ReadingLength {
        // Buffer for the `NUM_LENGTH_BYTES` length prefix.
        length_bytes: [u8; NUM_LENGTH_BYTES],
        // How many bytes of `length_bytes` have been filled so far.
        bytes_read: usize,
    },
    /// The packet length is known; filling the packet body.
    ReadingPayload {
        // Buffer sized to the decrypted packet length.
        packet_bytes: Vec<u8>,
        // How many bytes of `packet_bytes` have been filled so far.
        bytes_read: usize,
    },
}
impl DecryptState {
    /// Fresh state waiting for the length prefix of the next packet.
    fn init_reading_length() -> Self {
        Self::ReadingLength {
            bytes_read: 0,
            length_bytes: [0; NUM_LENGTH_BYTES],
        }
    }

    /// Fresh state waiting for `packet_bytes_len` bytes of packet body.
    fn init_reading_payload(packet_bytes_len: usize) -> Self {
        Self::ReadingPayload {
            bytes_read: 0,
            packet_bytes: vec![0; packet_bytes_len],
        }
    }
}
/// Decrypting half of a session: reads packets from the session reader and decrypts them.
pub struct ProtocolReader<R> {
    // Cipher used to decrypt packet lengths and packet bodies.
    inbound_cipher: InboundCipher,
    // Post-handshake reader (replays handshake leftovers first).
    reader: ProtocolSessionReader<R>,
    // Progress through the packet currently being read.
    state: DecryptState,
}
impl<R> ProtocolReader<R>
where
R: AsyncRead + Unpin + Send,
{
pub async fn read(&mut self) -> Result<Payload, ProtocolError> {
loop {
match &mut self.state {
DecryptState::ReadingLength {
length_bytes,
bytes_read,
} => {
while *bytes_read < NUM_LENGTH_BYTES {
*bytes_read += self.reader.read(&mut length_bytes[*bytes_read..]).await?;
}
let packet_bytes_len = self.inbound_cipher.decrypt_packet_len(*length_bytes);
self.state = DecryptState::init_reading_payload(packet_bytes_len);
}
DecryptState::ReadingPayload {
packet_bytes,
bytes_read,
} => {
while *bytes_read < packet_bytes.len() {
*bytes_read += self.reader.read(&mut packet_bytes[*bytes_read..]).await?;
}
let plaintext_len = InboundCipher::decryption_buffer_len(packet_bytes.len());
let mut plaintext_buffer = vec![0u8; plaintext_len];
self.inbound_cipher
.decrypt(packet_bytes, &mut plaintext_buffer, None)?;
self.state = DecryptState::init_reading_length();
return Ok(Payload::decrypted(plaintext_buffer));
}
}
}
}
pub fn into_inner(self) -> (InboundCipher, ProtocolSessionReader<R>) {
(self.inbound_cipher, self.reader)
}
}
/// Encrypting half of a session: encrypts payloads into packets and writes them out.
pub struct ProtocolWriter<W> {
    // Cipher used to encrypt outgoing packets.
    outbound_cipher: OutboundCipher,
    // Destination for encrypted packets.
    writer: W,
}
impl<W> ProtocolWriter<W>
where
    W: AsyncWrite + Unpin + Send,
{
    /// Encrypts `payload` into a single packet, writes it, and flushes the writer.
    ///
    /// # Errors
    /// Returns encryption errors from the outbound cipher and I/O errors from the writer.
    pub async fn write(&mut self, payload: &Payload) -> Result<(), ProtocolError> {
        let contents = payload.contents();
        let mut ciphertext = vec![0u8; OutboundCipher::encryption_buffer_len(contents.len())];
        self.outbound_cipher
            .encrypt(contents, &mut ciphertext, payload.packet_type(), None)?;
        self.writer.write_all(&ciphertext).await?;
        // Flush so the packet is not left sitting in an intermediate buffer.
        self.writer.flush().await?;
        Ok(())
    }

    /// Consumes the writer and returns its cipher and the underlying writer.
    pub fn into_inner(self) -> (OutboundCipher, W) {
        let Self {
            outbound_cipher,
            writer,
        } = self;
        (outbound_cipher, writer)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use bitcoin::Network;
    use tokio::io::{duplex, split};

    /// Both peers send garbage and decoy packets; both handshakes must succeed.
    #[tokio::test]
    async fn test_async_handshake_functions() {
        let (initiator_stream, responder_stream) = duplex(1024);
        let (initiator_rx, mut initiator_tx) = split(initiator_stream);
        let (responder_rx, mut responder_tx) = split(responder_stream);

        let initiator_task = tokio::spawn(async move {
            let garbage = Some(b"local garbage".to_vec());
            let decoys = Some(vec![b"local decoy".to_vec()]);
            handshake(
                Network::Bitcoin,
                Role::Initiator,
                garbage,
                decoys,
                initiator_rx,
                &mut initiator_tx,
            )
            .await
        });
        let responder_task = tokio::spawn(async move {
            let garbage = Some(b"remote garbage".to_vec());
            let decoys = Some(vec![
                b"remote decoy 1".to_vec(),
                b"remote decoy 2".to_vec(),
            ]);
            handshake(
                Network::Bitcoin,
                Role::Responder,
                garbage,
                decoys,
                responder_rx,
                &mut responder_tx,
            )
            .await
        });

        let (initiator_outcome, responder_outcome) = tokio::join!(initiator_task, responder_task);
        initiator_outcome.unwrap().unwrap();
        responder_outcome.unwrap().unwrap();
    }

    /// A decoy larger than `MAX_PACKET_SIZE_FOR_ALLOCATION` must be rejected by the receiver.
    #[tokio::test]
    async fn test_async_handshake_packet_too_big_protection() {
        let (initiator_stream, responder_stream) = duplex(MAX_PACKET_SIZE_FOR_ALLOCATION * 2);
        let (initiator_rx, mut initiator_tx) = split(initiator_stream);
        let (responder_rx, mut responder_tx) = split(responder_stream);

        let initiator_task = tokio::spawn(async move {
            handshake(
                Network::Bitcoin,
                Role::Initiator,
                None,
                None,
                initiator_rx,
                &mut initiator_tx,
            )
            .await
        });
        let responder_task = tokio::spawn(async move {
            let oversized_decoy = vec![0u8; MAX_PACKET_SIZE_FOR_ALLOCATION + 1];
            handshake(
                Network::Bitcoin,
                Role::Responder,
                None,
                Some(vec![oversized_decoy]),
                responder_rx,
                &mut responder_tx,
            )
            .await
        });

        let (initiator_outcome, _responder_outcome) = tokio::join!(initiator_task, responder_task);
        let initiator_result = initiator_outcome.unwrap();
        assert!(matches!(
            initiator_result,
            Err(ProtocolError::Internal(Error::PacketTooBig))
        ));
    }
}