use std::{convert::TryFrom, io, time::Duration};
use bytes::Bytes;
use log::*;
use prost::Message;
use thiserror::Error;
use tokio::{
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
time,
};
use crate::{
bans::{BAN_DURATION_LONG, BAN_DURATION_SHORT},
message::MessageExt,
peer_manager::NodeIdentity,
proto::identity::PeerIdentityMsg,
protocol::{NodeNetworkInfo, ProtocolId},
};
/// Log target shared by every log statement in this module.
const LOG_TARGET: &str = "comms::protocol::identity";
/// Largest payload, in bytes, that a single identity protocol frame may carry (the 3-byte version/length header is
/// not counted). Incoming frames advertising a larger length are rejected, and this node refuses to send one.
const MAX_IDENTITY_PROTOCOL_MSG_SIZE: u16 = 1_024;
pub async fn identity_exchange<'p, TSocket, P>(
node_identity: &NodeIdentity,
our_supported_protocols: P,
network_info: NodeNetworkInfo,
socket: &mut TSocket,
) -> Result<PeerIdentityMsg, IdentityProtocolError>
where
TSocket: AsyncRead + AsyncWrite + Unpin,
P: IntoIterator<Item = &'p ProtocolId>,
{
let supported_protocols = our_supported_protocols.into_iter().map(|p| p.to_vec()).collect();
let msg_bytes = PeerIdentityMsg {
addresses: node_identity.public_addresses().iter().map(|a| a.to_vec()).collect(),
features: node_identity.features().bits(),
supported_protocols,
user_agent: network_info.user_agent,
identity_signature: node_identity.identity_signature_read().as_ref().map(Into::into),
}
.to_encoded_bytes();
write_protocol_frame(socket, network_info.major_version, &msg_bytes).await?;
socket.flush().await?;
let (_, msg_bytes) = time::timeout(
Duration::from_secs(10),
read_protocol_frame(socket, network_info.major_version),
)
.await??;
debug!(
target: LOG_TARGET,
"Identity message received {} bytes",
msg_bytes.len()
);
let identity_msg = PeerIdentityMsg::decode(Bytes::from(msg_bytes))?;
Ok(identity_msg)
}
/// Reads one identity protocol frame from `socket`.
///
/// Wire layout: `[version: u8][payload length: u16, little-endian][payload]`.
///
/// Returns the frame's version byte together with its payload.
///
/// # Errors
/// - [`IdentityProtocolError::UnsupportedProtocolVersion`] if the version byte exceeds `max_supported_version`.
/// - [`IdentityProtocolError::MaxMsgSizeExceeded`] if the advertised length exceeds `MAX_IDENTITY_PROTOCOL_MSG_SIZE`.
///   The payload is not read in this case.
/// - [`IdentityProtocolError::IoError`] if reading the header or payload fails, including a premature EOF.
async fn read_protocol_frame<S: AsyncRead + Unpin>(
    socket: &mut S,
    max_supported_version: u8,
) -> Result<(u8, Vec<u8>), IdentityProtocolError> {
    let mut header = [0u8; 3];
    socket.read_exact(&mut header).await?;
    let [version, len_lo, len_hi] = header;

    if version > max_supported_version {
        return Err(IdentityProtocolError::UnsupportedProtocolVersion {
            max_supported_version,
            provided_version: version,
        });
    }

    // Check the advertised length before allocating, so a peer cannot make us allocate an oversized buffer.
    let payload_len = u16::from_le_bytes([len_lo, len_hi]);
    if payload_len > MAX_IDENTITY_PROTOCOL_MSG_SIZE {
        return Err(IdentityProtocolError::MaxMsgSizeExceeded {
            expected: MAX_IDENTITY_PROTOCOL_MSG_SIZE,
            got: payload_len,
        });
    }

    let mut payload = vec![0u8; usize::from(payload_len)];
    socket.read_exact(&mut payload).await?;
    Ok((version, payload))
}
/// Writes one identity protocol frame to `socket`.
///
/// Wire layout: `[version: u8][payload length: u16, little-endian][payload]`. The header and payload are assembled
/// into a single buffer and written with one `write_all`, so the frame is not split across several writes. The
/// socket is not flushed here; the caller is responsible for that.
///
/// # Errors
/// - [`IdentityProtocolError::InvariantError`] if `msg_bytes` is longer than `MAX_IDENTITY_PROTOCOL_MSG_SIZE`.
///   Nothing is written in this case.
/// - [`IdentityProtocolError::IoError`] if writing to the socket fails.
async fn write_protocol_frame<S: AsyncWrite + Unpin>(
    socket: &mut S,
    version: u8,
    msg_bytes: &[u8],
) -> Result<(), IdentityProtocolError> {
    /// Size of the version byte plus the little-endian u16 length prefix.
    const FRAME_HEADER_LEN: usize = 3;

    // MAX_IDENTITY_PROTOCOL_MSG_SIZE is itself a u16, so one combined check covers both "too large for the protocol"
    // and "does not fit in the length prefix". Every oversized case yields the same error message.
    let len = u16::try_from(msg_bytes.len())
        .ok()
        .filter(|len| *len <= MAX_IDENTITY_PROTOCOL_MSG_SIZE)
        .ok_or_else(|| {
            IdentityProtocolError::InvariantError(format!(
                "Sending identity protocol message of size {}, greater than {} bytes. This is a protocol violation",
                msg_bytes.len(),
                MAX_IDENTITY_PROTOCOL_MSG_SIZE
            ))
        })?;

    let mut frame = Vec::with_capacity(FRAME_HEADER_LEN + msg_bytes.len());
    frame.push(version);
    frame.extend_from_slice(&len.to_le_bytes());
    frame.extend_from_slice(msg_bytes);

    trace!(target: LOG_TARGET, "Writing {} bytes", frame.len());
    socket.write_all(&frame).await?;
    Ok(())
}
#[derive(Debug, Error, Clone)]
pub enum IdentityProtocolError {
#[error("IoError: {0}")]
IoError(String),
#[error("Possible bug: InvariantError {0}")]
InvariantError(String),
#[error("ProtobufDecodeError: {0}")]
ProtobufDecodeError(String),
#[error("Peer unexpectedly closed the connection")]
PeerUnexpectedCloseConnection,
#[error("Timeout waiting for peer to send identity information")]
Timeout,
#[error(
"Unsupported protocol version. Max supported version: {max_supported_version}, provided version: \
{provided_version}"
)]
UnsupportedProtocolVersion {
max_supported_version: u8,
provided_version: u8,
},
#[error("Max identity protocol message size exceeded. Expected <= {expected} got {got}")]
MaxMsgSizeExceeded { expected: u16, got: u16 },
}
impl IdentityProtocolError {
    /// Returns how long the offending peer should be banned for this error, or `None` if no ban is warranted.
    ///
    /// - Malformed or oversized peer messages get `BAN_DURATION_LONG`.
    /// - Timeouts, unexpected closes and unsupported versions get `BAN_DURATION_SHORT`.
    /// - Local invariant violations and socket I/O errors are not blamed on the peer.
    pub fn as_ban_duration(&self) -> Option<Duration> {
        match self {
            Self::ProtobufDecodeError(_) | Self::MaxMsgSizeExceeded { .. } => Some(BAN_DURATION_LONG),
            Self::PeerUnexpectedCloseConnection | Self::UnsupportedProtocolVersion { .. } | Self::Timeout => {
                Some(BAN_DURATION_SHORT)
            },
            Self::InvariantError(_) | Self::IoError(_) => None,
        }
    }
}
impl From<time::error::Elapsed> for IdentityProtocolError {
fn from(_: time::error::Elapsed) -> Self {
IdentityProtocolError::Timeout
}
}
impl From<io::Error> for IdentityProtocolError {
fn from(err: io::Error) -> Self {
IdentityProtocolError::IoError(err.to_string())
}
}
/// Converts a protobuf decoding failure into [`IdentityProtocolError::ProtobufDecodeError`], keeping its display
/// string.
impl From<prost::DecodeError> for IdentityProtocolError {
    fn from(decode_err: prost::DecodeError) -> Self {
        Self::ProtobufDecodeError(decode_err.to_string())
    }
}
#[cfg(test)]
mod test {
    use futures::{StreamExt, future};

    use crate::{
        peer_manager::{NodeIdentity, PeerFeatures},
        protocol::{IdentityProtocolError, NodeNetworkInfo},
        test_utils::node_identity::build_node_identity,
        transports::{MemoryTransport, Transport},
    };

    /// Encodes `node_identity`'s public addresses with the same expression `identity_exchange` uses for the wire
    /// message.
    fn encoded_public_addresses(node_identity: &NodeIdentity) -> Vec<Vec<u8>> {
        node_identity.public_addresses().iter().map(|a| a.to_vec()).collect()
    }

    /// Two peers on the default major version exchange identities, and each learns the other's features and public
    /// addresses.
    #[tokio::test]
    async fn identity_exchange() {
        let listener_identity = build_node_identity(PeerFeatures::COMMUNICATION_NODE);
        let dialer_identity = build_node_identity(PeerFeatures::COMMUNICATION_CLIENT);

        let transport = MemoryTransport;
        let listen_addr = "/memory/0".parse().unwrap();
        let (mut listener, bound_addr) = transport.listen(&listen_addr).await.unwrap();
        let (dialed, accepted) = future::join(transport.dial(&bound_addr), listener.next()).await;
        let mut dialer_sock = dialed.unwrap();
        let (mut listener_sock, _) = accepted.unwrap().unwrap();

        let listener_info = NodeNetworkInfo {
            minor_version: 1,
            ..Default::default()
        };
        let dialer_info = NodeNetworkInfo {
            minor_version: 2,
            ..Default::default()
        };
        let (listener_result, dialer_result) = future::join(
            super::identity_exchange(&listener_identity, &[], listener_info, &mut listener_sock),
            super::identity_exchange(&dialer_identity, &[], dialer_info, &mut dialer_sock),
        )
        .await;

        // Each side ends up holding the *other* side's identity.
        let seen_by_listener = listener_result.unwrap();
        let seen_by_dialer = dialer_result.unwrap();

        assert_eq!(seen_by_dialer.features, listener_identity.features().bits());
        assert_eq!(seen_by_dialer.addresses, encoded_public_addresses(&listener_identity));
        assert_eq!(seen_by_listener.features, dialer_identity.features().bits());
        assert_eq!(seen_by_listener.addresses, encoded_public_addresses(&dialer_identity));
    }

    /// A peer on major version 0 rejects a version-1 frame with `UnsupportedProtocolVersion`. The version-1 peer
    /// still accepts the older version-0 frame.
    #[tokio::test]
    async fn fail_cases() {
        let listener_identity = build_node_identity(PeerFeatures::COMMUNICATION_NODE);
        let dialer_identity = build_node_identity(PeerFeatures::COMMUNICATION_CLIENT);

        let transport = MemoryTransport;
        let listen_addr = "/memory/0".parse().unwrap();
        let (mut listener, bound_addr) = transport.listen(&listen_addr).await.unwrap();
        let (dialed, accepted) = future::join(transport.dial(&bound_addr), listener.next()).await;
        let mut dialer_sock = dialed.unwrap();
        let (mut listener_sock, _) = accepted.unwrap().unwrap();

        let older_info = NodeNetworkInfo {
            major_version: 0,
            ..Default::default()
        };
        let newer_info = NodeNetworkInfo {
            major_version: 1,
            ..Default::default()
        };
        let (older_result, newer_result) = future::join(
            super::identity_exchange(&listener_identity, &[], older_info, &mut listener_sock),
            super::identity_exchange(&dialer_identity, &[], newer_info, &mut dialer_sock),
        )
        .await;

        let older_err = older_result.unwrap_err();
        assert!(matches!(
            older_err,
            IdentityProtocolError::UnsupportedProtocolVersion { .. }
        ));
        newer_result.unwrap();
    }
}