use core::fmt;
use std::io::{Chain, Cursor, Read, Write};
use std::vec;
use std::vec::Vec;
use bitcoin::Network;
use crate::{
handshake::{self, GarbageResult, VersionResult},
Error, Handshake, InboundCipher, OutboundCipher, PacketType, Role,
MAX_PACKET_SIZE_FOR_ALLOCATION, NUM_ELLIGATOR_SWIFT_BYTES, NUM_GARBAGE_TERMINTOR_BYTES,
NUM_LENGTH_BYTES,
};
/// Byte stream handed back to the caller once the handshake has finished.
///
/// It first yields the `leftover` bytes that were already pulled off the transport while
/// the handshake was being processed. It then continues reading from the wrapped reader
/// itself, so no post-handshake data is lost.
pub struct ProtocolSessionReader<R> {
    inner: Chain<Cursor<Vec<u8>>, R>,
}

impl<R> ProtocolSessionReader<R> {
    /// Place `leftover` in front of `reader`.
    fn new(leftover: Vec<u8>, reader: R) -> Self
    where
        R: Read,
    {
        let buffered = Cursor::new(leftover);
        let inner = Read::chain(buffered, reader);
        Self { inner }
    }
}

impl<R: Read> Read for ProtocolSessionReader<R> {
    /// Delegate to the leftover-then-transport chain.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        Read::read(&mut self.inner, buf)
    }
}
/// A single BIP324 packet payload: a header byte plus application contents.
///
/// Outgoing payloads are built with [`Payload::genuine`] or [`Payload::decoy`]. Incoming
/// payloads wrap the plaintext returned by the inbound cipher.
#[derive(Clone, Debug)]
pub struct Payload {
    data: PayloadData,
}

/// Internal storage for a [`Payload`].
#[derive(Clone, Debug)]
enum PayloadData {
    /// Decrypted plaintext: header byte at index 0, contents in the remaining bytes.
    Decrypted(Vec<u8>),
    /// Header byte and contents kept separately, as built by the public constructors.
    Parts { header: u8, contents: Vec<u8> },
}

impl Payload {
    /// Shared constructor for outgoing payloads of the given packet type.
    fn from_parts(packet_type: PacketType, contents: Vec<u8>) -> Self {
        Self {
            data: PayloadData::Parts {
                header: packet_type.to_byte(),
                contents,
            },
        }
    }

    /// Create an outgoing payload carrying real application data.
    pub fn genuine(contents: impl Into<Vec<u8>>) -> Self {
        Self::from_parts(PacketType::Genuine, contents.into())
    }

    /// Create an outgoing decoy payload whose contents the receiver is expected to discard.
    pub fn decoy(contents: impl Into<Vec<u8>>) -> Self {
        Self::from_parts(PacketType::Decoy, contents.into())
    }

    /// Wrap decrypted plaintext whose first byte is the packet header.
    ///
    /// `data` must contain at least the header byte. This is only checked in debug builds;
    /// in release builds an empty buffer makes [`Payload::packet_type`] panic.
    pub(crate) fn decrypted(data: Vec<u8>) -> Self {
        debug_assert!(
            !data.is_empty(),
            "Payload data must contain at least the header byte"
        );
        Self {
            data: PayloadData::Decrypted(data),
        }
    }

    /// The application contents, excluding the header byte.
    ///
    /// # Panics
    /// Panics if a decrypted payload was built from an empty buffer.
    pub fn contents(&self) -> &[u8] {
        match &self.data {
            PayloadData::Decrypted(data) => &data[1..],
            PayloadData::Parts { contents, .. } => contents,
        }
    }

    /// The packet type, decoded from the header byte via `PacketType::from_byte`.
    ///
    /// # Panics
    /// Panics if a decrypted payload was built from an empty buffer.
    pub fn packet_type(&self) -> PacketType {
        match &self.data {
            PayloadData::Decrypted(data) => PacketType::from_byte(&data[0]),
            PayloadData::Parts { header, .. } => PacketType::from_byte(header),
        }
    }
}
/// Errors produced by the blocking BIP324 helpers in this module.
#[derive(Debug)]
pub enum ProtocolError {
    /// Transport failure, paired with a hint about whether retrying is worthwhile.
    Io(std::io::Error, ProtocolFailureSuggestion),
    /// Failure reported by the handshake or cipher layer.
    Internal(Error),
}
/// Recommended follow-up after a [`ProtocolError::Io`] failure.
///
/// This is a plain fieldless enum, so it is `Copy` and comparable. Callers can branch with
/// `==` instead of having to `match`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ProtocolFailureSuggestion {
    /// Retry the connection using the V1 protocol.
    RetryV1,
    /// Abort; retrying is not suggested.
    Abort,
}
impl From<std::io::Error> for ProtocolError {
    /// Wrap an I/O error and attach a suggestion derived from its kind.
    ///
    /// `ConnectionReset`, `ConnectionAborted` and `UnexpectedEof` suggest retrying with V1.
    /// Presumably this is because a peer without V2 support tends to drop the connection.
    /// Every other kind suggests aborting.
    fn from(error: std::io::Error) -> Self {
        use std::io::ErrorKind::{ConnectionAborted, ConnectionReset, UnexpectedEof};

        let retry_v1 = matches!(
            error.kind(),
            ConnectionReset | ConnectionAborted | UnexpectedEof
        );
        let suggestion = if retry_v1 {
            ProtocolFailureSuggestion::RetryV1
        } else {
            ProtocolFailureSuggestion::Abort
        };
        Self::Io(error, suggestion)
    }
}

impl From<Error> for ProtocolError {
    /// Wrap a handshake/cipher error as [`ProtocolError::Internal`].
    fn from(error: Error) -> Self {
        Self::Internal(error)
    }
}
impl ProtocolError {
    /// Error used when the remote closes the connection mid-handshake.
    ///
    /// It carries an `UnexpectedEof` I/O error and suggests retrying with V1.
    pub fn eof() -> Self {
        let cause = std::io::Error::new(
            std::io::ErrorKind::UnexpectedEof,
            "Remote peer closed connection during handshake",
        );
        Self::Io(cause, ProtocolFailureSuggestion::RetryV1)
    }
}
impl std::error::Error for ProtocolError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
ProtocolError::Io(e, _) => Some(e),
ProtocolError::Internal(e) => Some(e),
}
}
}
impl fmt::Display for ProtocolError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ProtocolError::Io(e, suggestion) => {
write!(
f,
"IO error: {}. Suggestion: {}.",
e,
match suggestion {
ProtocolFailureSuggestion::RetryV1 => "Retry with V1 protocol",
ProtocolFailureSuggestion::Abort => "Abort, do not retry",
}
)
}
ProtocolError::Internal(e) => write!(f, "Internal error: {e}."),
}
}
}
/// Perform a complete BIP324 handshake over blocking I/O.
///
/// A fresh [`Handshake`] is created for `network` and `role` and then driven to completion.
/// We send our key (with optional `garbage`) and version packet (after optional `decoys`)
/// on `writer`, and consume the peer's side from `reader`.
///
/// On success, returns the inbound and outbound ciphers plus a [`ProtocolSessionReader`].
/// That reader first replays any bytes already read past the handshake, then continues
/// reading from `reader`.
///
/// # Errors
/// Fails if the handshake state cannot be created, if any handshake step is rejected, or if
/// the transport fails.
pub fn handshake<R, W>(
    network: Network,
    role: Role,
    garbage: Option<Vec<u8>>,
    decoys: Option<Vec<Vec<u8>>>,
    reader: R,
    writer: &mut W,
) -> Result<(InboundCipher, OutboundCipher, ProtocolSessionReader<R>), ProtocolError>
where
    R: Read,
    W: Write,
{
    let initialized = Handshake::<handshake::Initialized>::new(network, role)?;
    handshake_with_initialized(initialized, garbage, decoys, reader, writer)
}
/// Drive an already-initialized [`Handshake`] to completion over blocking I/O.
///
/// Steps, in order:
/// 1. Write our key (plus optional `garbage`) to `writer` and flush.
/// 2. Read the peer's `NUM_ELLIGATOR_SWIFT_BYTES`-byte key.
/// 3. Write our version output (optional `decoys`, then the version packet) and flush.
/// 4. Read `NUM_GARBAGE_TERMINTOR_BYTES` bytes, then keep reading until `receive_garbage`
///    reports `FoundGarbage`.
/// 5. Read length-prefixed packets, skipping decoys, until `receive_version` completes.
///
/// Bytes read in step 4 beyond the garbage, plus anything not consumed in step 5, remain in
/// the returned [`ProtocolSessionReader`] ahead of `reader`.
///
/// # Errors
/// - [`ProtocolError::Io`] for transport failures. If the reader hits end-of-stream while more
///   garbage is needed, this is [`ProtocolError::eof`].
/// - `ProtocolError::Internal(Error::PacketTooBig)` if a packet length exceeds
///   `MAX_PACKET_SIZE_FOR_ALLOCATION`. The check runs before allocating the packet buffer.
/// - [`ProtocolError::Internal`] for any other handshake failure.
fn handshake_with_initialized<R, W>(
    handshake: Handshake<handshake::Initialized>,
    garbage: Option<Vec<u8>>,
    decoys: Option<Vec<Vec<u8>>>,
    mut reader: R,
    writer: &mut W,
) -> Result<(InboundCipher, OutboundCipher, ProtocolSessionReader<R>), ProtocolError>
where
    R: Read,
    W: Write,
{
    let garbage_ref = garbage.as_deref();
    let decoy_refs: Option<Vec<&[u8]>> = decoys
        .as_ref()
        .map(|vecs| vecs.iter().map(Vec::as_slice).collect());
    let decoys_ref = decoy_refs.as_deref();

    // 1. Send our key (and any garbage).
    let key_buffer_len = Handshake::<handshake::Initialized>::send_key_len(garbage_ref);
    let mut key_buffer = vec![0u8; key_buffer_len];
    let handshake = handshake.send_key(garbage_ref, &mut key_buffer)?;
    writer.write_all(&key_buffer)?;
    writer.flush()?;

    // 2. Receive the peer's key.
    let mut remote_ellswift_buffer = [0u8; NUM_ELLIGATOR_SWIFT_BYTES];
    reader.read_exact(&mut remote_ellswift_buffer)?;
    let handshake = handshake.receive_key(remote_ellswift_buffer)?;

    // 3. Send our version packet (preceded by any decoys).
    let version_buffer_len = Handshake::<handshake::ReceivedKey>::send_version_len(decoys_ref);
    let mut version_buffer = vec![0u8; version_buffer_len];
    let handshake = handshake.send_version(&mut version_buffer, decoys_ref)?;
    writer.write_all(&version_buffer)?;
    writer.flush()?;

    // 4. Accumulate bytes until the handshake locates the end of the peer's garbage.
    let mut garbage_buffer = vec![0u8; NUM_GARBAGE_TERMINTOR_BYTES];
    reader.read_exact(&mut garbage_buffer)?;
    // Reused across iterations instead of allocating a new chunk per read.
    let mut read_chunk = [0u8; 256];
    let mut handshake = handshake;
    let (mut handshake, garbage_bytes) = loop {
        match handshake.receive_garbage(&garbage_buffer) {
            Ok(GarbageResult::FoundGarbage {
                handshake,
                consumed_bytes,
            }) => {
                break (handshake, consumed_bytes);
            }
            Ok(GarbageResult::NeedMoreData(h)) => {
                handshake = h;
                let read = loop {
                    match reader.read(&mut read_chunk) {
                        // `Interrupted` is transient per the `Read` contract; `read_exact`
                        // (used above) already retries it, so do the same here rather than
                        // failing the handshake with an `Abort` suggestion.
                        Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
                        Err(e) => return Err(ProtocolError::from(e)),
                        Ok(0) => return Err(ProtocolError::eof()),
                        Ok(n) => break n,
                    }
                };
                garbage_buffer.extend_from_slice(&read_chunk[..read]);
            }
            Err(e) => return Err(ProtocolError::Internal(e)),
        }
    };

    // Bytes past the consumed garbage already belong to the peer's packets. Keep them
    // (reusing the buffer in place) so the session reader replays them first.
    garbage_buffer.drain(..garbage_bytes);
    let mut session_reader = ProtocolSessionReader::new(garbage_buffer, reader);

    // 5. Decrypt packets until the version packet completes the handshake; decoys are skipped.
    let mut length_bytes = [0u8; NUM_LENGTH_BYTES];
    loop {
        session_reader.read_exact(&mut length_bytes)?;
        let packet_len = handshake.decrypt_packet_len(length_bytes)?;
        // The length is peer-controlled: refuse oversized allocations up front.
        if packet_len > MAX_PACKET_SIZE_FOR_ALLOCATION {
            return Err(ProtocolError::Internal(Error::PacketTooBig));
        }
        let mut packet_bytes = vec![0u8; packet_len];
        session_reader.read_exact(&mut packet_bytes)?;
        match handshake.receive_version(&mut packet_bytes) {
            Ok(VersionResult::Complete { cipher }) => {
                let (inbound_cipher, outbound_cipher) = cipher.into_split();
                return Ok((inbound_cipher, outbound_cipher, session_reader));
            }
            Ok(VersionResult::Decoy(h)) => {
                handshake = h;
            }
            Err(e) => return Err(ProtocolError::Internal(e)),
        }
    }
}
pub struct Protocol<R, W> {
reader: ProtocolReader<R>,
writer: ProtocolWriter<W>,
}
impl<R, W> Protocol<R, W>
where
R: Read,
W: Write,
{
pub fn new(
network: Network,
role: Role,
garbage: Option<Vec<u8>>,
decoys: Option<Vec<Vec<u8>>>,
reader: R,
mut writer: W,
) -> Result<Protocol<R, W>, ProtocolError> {
let (inbound_cipher, outbound_cipher, session_reader) =
handshake(network, role, garbage, decoys, reader, &mut writer)?;
Ok(Protocol {
reader: ProtocolReader {
inbound_cipher,
reader: session_reader,
},
writer: ProtocolWriter {
outbound_cipher,
writer,
},
})
}
pub fn into_split(self) -> (ProtocolReader<R>, ProtocolWriter<W>) {
(self.reader, self.writer)
}
pub fn read(&mut self) -> Result<Payload, ProtocolError> {
self.reader.read()
}
pub fn write(&mut self, payload: &Payload) -> Result<(), ProtocolError> {
self.writer.write(payload)
}
}
pub struct ProtocolReader<R> {
inbound_cipher: InboundCipher,
reader: ProtocolSessionReader<R>,
}
impl<R> ProtocolReader<R>
where
R: Read,
{
pub fn read(&mut self) -> Result<Payload, ProtocolError> {
let mut length_bytes = [0u8; NUM_LENGTH_BYTES];
self.reader.read_exact(&mut length_bytes)?;
let packet_bytes_len = self.inbound_cipher.decrypt_packet_len(length_bytes);
let mut packet_bytes = vec![0u8; packet_bytes_len];
self.reader.read_exact(&mut packet_bytes)?;
let (_, plaintext_buffer) = self.inbound_cipher.decrypt_to_vec(&packet_bytes, None)?;
Ok(Payload::decrypted(plaintext_buffer))
}
pub fn into_inner(self) -> (InboundCipher, ProtocolSessionReader<R>) {
(self.inbound_cipher, self.reader)
}
}
pub struct ProtocolWriter<W> {
outbound_cipher: OutboundCipher,
writer: W,
}
impl<W> ProtocolWriter<W>
where
W: Write,
{
pub fn write(&mut self, payload: &Payload) -> Result<(), ProtocolError> {
let packet_buffer =
self.outbound_cipher
.encrypt_to_vec(payload.contents(), payload.packet_type(), None);
self.writer.write_all(&packet_buffer)?;
self.writer.flush()?;
Ok(())
}
pub fn into_inner(self) -> (OutboundCipher, W) {
(self.outbound_cipher, self.writer)
}
}
#[cfg(test)]
mod tests {
use super::*;
use rand::{rngs::StdRng, SeedableRng};
use std::io::Cursor;
/// Produce the bytes a peer with role `local_role` (RNG seeded with `local_seed`) would send
/// during a handshake: its ElligatorSwift key, then `garbage` (if any), then the output of
/// `send_version` (decoys, if any, followed by the version packet).
///
/// A counterparty handshake seeded with `remote_seed` is created only to obtain its key
/// bytes, so the generated packets are encrypted for that counterparty. Callers build the
/// handshake under test from the same seed so it can decrypt them.
fn generate_handshake_messages(
local_seed: u64,
remote_seed: u64,
local_role: Role,
garbage: Option<&[u8]>,
decoys: Option<&[&[u8]]>,
) -> Vec<u8> {
let secp = bitcoin::secp256k1::Secp256k1::new();
let mut local_rng = StdRng::seed_from_u64(local_seed);
let local_handshake = Handshake::<handshake::Initialized>::new_with_rng(
Network::Bitcoin,
local_role,
&mut local_rng,
&secp,
)
.unwrap();
let mut remote_rng = StdRng::seed_from_u64(remote_seed);
// The counterparty always takes the opposite role.
let remote_role = match local_role {
Role::Initiator => Role::Responder,
Role::Responder => Role::Initiator,
};
let remote_handshake = Handshake::<handshake::Initialized>::new_with_rng(
Network::Bitcoin,
remote_role,
&mut remote_rng,
&secp,
)
.unwrap();
let mut local_key_buffer =
vec![0u8; Handshake::<handshake::Initialized>::send_key_len(garbage)];
let local_handshake = local_handshake
.send_key(garbage, &mut local_key_buffer)
.unwrap();
// Only the counterparty's key bytes are needed; its resulting handshake state is dropped.
let mut remote_key_buffer = vec![0u8; NUM_ELLIGATOR_SWIFT_BYTES];
remote_handshake
.send_key(None, &mut remote_key_buffer)
.unwrap();
let local_handshake = local_handshake
.receive_key(
remote_key_buffer[..NUM_ELLIGATOR_SWIFT_BYTES]
.try_into()
.unwrap(),
)
.unwrap();
let mut local_version_buffer =
vec![0u8; Handshake::<handshake::ReceivedKey>::send_version_len(decoys)];
local_handshake
.send_version(&mut local_version_buffer, decoys)
.unwrap();
// Assemble the stream: key bytes only from the key buffer, then the garbage taken
// directly from `garbage`, then the version-phase packets.
let garbage_bytes = garbage.map(|g| g.to_vec()).unwrap_or_default();
[
&local_key_buffer[..NUM_ELLIGATOR_SWIFT_BYTES],
&garbage_bytes[..],
&local_version_buffer[..],
]
.concat()
}
/// A byte sent after the handshake messages must be readable from the returned session
/// reader, i.e. data over-read during the handshake is not lost.
#[test]
fn test_handshake_session_reader() {
// Seed 42 matches `remote_seed` below, so this initiator is the counterparty the
// generated responder messages were encrypted for.
let mut init_rng = StdRng::seed_from_u64(42);
let secp = bitcoin::secp256k1::Secp256k1::new();
let init_handshake = Handshake::<handshake::Initialized>::new_with_rng(
Network::Bitcoin,
Role::Initiator,
&mut init_rng,
&secp,
)
.unwrap();
let resp_garbage = b"responder garbage";
let resp_decoys: &[&[u8]] = &[b"decoy1", b"another decoy packet"];
let mut messages = generate_handshake_messages(
1042,
42,
Role::Responder,
Some(resp_garbage),
Some(resp_decoys),
);
// Trailing application byte that must survive the handshake.
let session_byte = 0x42u8;
messages.push(session_byte);
let reader = Cursor::new(messages);
// The initiator's own output is not checked here.
let mut writer = Vec::new();
let result = handshake_with_initialized(init_handshake, None, None, reader, &mut writer);
let (_, _, mut session_reader) = result.unwrap();
let mut buffer = [0u8; 1];
match session_reader.read(&mut buffer) {
Ok(1) => {
assert_eq!(
buffer[0], session_byte,
"Session reader should contain the extra byte"
);
}
Ok(n) => panic!("Expected to read 1 byte but read {}", n),
Err(e) => panic!("Unexpected error reading from session reader: {}", e),
}
}
/// A decoy one byte larger than `MAX_PACKET_SIZE_FOR_ALLOCATION` must make the handshake
/// fail with `PacketTooBig`.
#[test]
fn test_handshake_packet_too_big_protection() {
// Same seed pairing as `test_handshake_session_reader`.
let mut init_rng = StdRng::seed_from_u64(42);
let secp = bitcoin::secp256k1::Secp256k1::new();
let init_handshake = Handshake::<handshake::Initialized>::new_with_rng(
Network::Bitcoin,
Role::Initiator,
&mut init_rng,
&secp,
)
.unwrap();
let large_decoy = vec![0; MAX_PACKET_SIZE_FOR_ALLOCATION + 1];
let resp_decoys: &[&[u8]] = &[large_decoy.as_slice()];
let messages =
generate_handshake_messages(1042, 42, Role::Responder, None, Some(resp_decoys));
let reader = Cursor::new(messages);
let mut writer = Vec::new();
let result = handshake_with_initialized(init_handshake, None, None, reader, &mut writer);
assert!(matches!(
result,
Err(ProtocolError::Internal(Error::PacketTooBig))
));
}
}