use std::{
error::Error as StdError,
fmt::Display,
io,
net::SocketAddr,
pin::Pin,
sync::{Arc, Weak},
time::Duration,
};
use casper_types::PublicKey;
use futures::{
future::{self, Either},
stream::{SplitSink, SplitStream},
Future, SinkExt, StreamExt,
};
use openssl::{
pkey::{PKey, Private},
ssl::Ssl,
};
use prometheus::IntGauge;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use tokio::{
net::TcpStream,
sync::{mpsc::UnboundedReceiver, watch},
};
use tokio_openssl::SslStream;
use tracing::{
debug, error_span,
field::{self, Empty},
info, trace, warn, Instrument, Span,
};
use super::{
chain_info::ChainInfo,
counting_format::{ConnectionId, Role},
error::{ConnectionError, IoError},
event::{IncomingConnection, OutgoingConnection},
framed,
limiter::LimiterHandle,
message::{ConsensusKeyPair, PayloadWeights},
Event, FramedTransport, Message, Metrics, Payload, Transport,
};
use crate::{
reactor::{EventQueueHandle, QueueKind},
tls::{self, TlsCert},
types::NodeId,
utils::display_error,
};
/// Upper bound for each individual step (send and receive) of the handshake exchange.
const HANDSHAKE_TIMEOUT: Duration = Duration::from_millis(20_000);
/// Opens a TCP connection to `peer_addr` and runs the client side of the TLS handshake.
///
/// Hostname verification is switched off; the peer is identified instead by the public key
/// fingerprint of its validated certificate, which becomes its `NodeId`.
///
/// # Errors
///
/// Returns a `ConnectionError` if the TCP connect, TLS initialization or TLS handshake fails,
/// if the peer presents no certificate, or if its certificate does not validate.
async fn tls_connect<REv>(
    context: &NetworkContext<REv>,
    peer_addr: SocketAddr,
) -> Result<(NodeId, Transport), ConnectionError>
where
    REv: 'static,
{
    let tcp_stream = match TcpStream::connect(peer_addr).await {
        Ok(tcp_stream) => tcp_stream,
        Err(err) => return Err(ConnectionError::TcpConnection(err)),
    };

    let ssl = tls::create_tls_connector(context.our_cert.as_x509(), &context.secret_key)
        .and_then(|connector| connector.configure())
        .and_then(|mut config| {
            // Peers are authenticated by certificate fingerprint, so the hostname is irrelevant.
            config.set_verify_hostname(false);
            config.into_ssl("this-will-not-be-checked.example.com")
        })
        .map_err(ConnectionError::TlsInitialization)?;
    let mut transport =
        SslStream::new(ssl, tcp_stream).map_err(ConnectionError::TlsInitialization)?;

    SslStream::connect(Pin::new(&mut transport))
        .await
        .map_err(ConnectionError::TlsHandshake)?;

    let peer_cert = match transport.ssl().peer_certificate() {
        Some(peer_cert) => peer_cert,
        None => return Err(ConnectionError::NoPeerCertificate),
    };
    let validated_cert =
        tls::validate_cert(peer_cert).map_err(ConnectionError::PeerCertificateInvalid)?;
    let peer_id = NodeId::from(validated_cert.public_key_fingerprint());

    Ok((peer_id, transport))
}
/// Establishes an outgoing connection to `peer_addr`.
///
/// Runs the client side of the TLS handshake and rejects connections whose peer ID equals our
/// own (loopback). It then wraps the transport in the framed message codec as the dialer and
/// exchanges handshake messages. On success only the sink half of the framed transport is
/// returned; the stream half is dropped.
///
/// The returned `OutgoingConnection` variant reports the outcome:
/// * `FailedEarly` if TLS setup failed (no peer ID known yet),
/// * `Loopback` if we connected to ourselves,
/// * `Failed` if the handshake exchange failed,
/// * `Established` otherwise.
pub(super) async fn connect_outgoing<P, REv>(
    context: Arc<NetworkContext<REv>>,
    peer_addr: SocketAddr,
) -> OutgoingConnection<P>
where
    REv: 'static,
    P: Payload,
{
    let (peer_id, transport) = match tls_connect(&context, peer_addr).await {
        Ok(value) => value,
        Err(error) => return OutgoingConnection::FailedEarly { peer_addr, error },
    };

    Span::current().record("peer_id", &field::display(peer_id));

    if peer_id == context.our_id {
        // This is the dialing side, so the loopback must be reported as outgoing.
        info!("outgoing loopback connection");
        return OutgoingConnection::Loopback { peer_addr };
    }

    debug!("Outgoing TLS connection established");

    let connection_id = ConnectionId::from_connection(transport.ssl(), context.our_id, peer_id);
    let mut transport = framed::<P>(
        context.net_metrics.clone(),
        connection_id,
        transport,
        Role::Dialer,
        context.chain_info.maximum_net_message_size,
    );

    match negotiate_handshake(&context, &mut transport, connection_id).await {
        Ok((public_addr, peer_consensus_public_key)) => {
            if let Some(ref public_key) = peer_consensus_public_key {
                Span::current().record("validator_id", &field::display(public_key));
            }

            // Only a warning: the connection is still used even if the advertised address
            // differs from the one we dialed.
            if public_addr != peer_addr {
                warn!(%public_addr, %peer_addr, "peer advertises a different public address than what we connected to");
            }

            let (sink, _stream) = transport.split();
            OutgoingConnection::Established {
                peer_addr,
                peer_id,
                peer_consensus_public_key,
                sink,
            }
        }
        Err(error) => OutgoingConnection::Failed {
            peer_addr,
            peer_id,
            error,
        },
    }
}
/// State shared by the networking tasks that set up and run connections.
pub(crate) struct NetworkContext<REv: 'static> {
    /// Handle used to schedule connection and message events on the reactor's queue.
    pub(super) event_queue: EventQueueHandle<REv>,
    /// Our own node ID; a peer presenting the same ID is treated as a loopback connection.
    pub(super) our_id: NodeId,
    /// Our TLS certificate, presented when connecting and accepting.
    pub(super) our_cert: Arc<TlsCert>,
    /// Private key belonging to `our_cert`.
    pub(super) secret_key: Arc<PKey<Private>>,
    /// Networking metrics, held weakly and handed to each framed transport.
    pub(super) net_metrics: Weak<Metrics>,
    /// Chain parameters: network name, maximum message size and handshake construction.
    pub(super) chain_info: ChainInfo,
    /// Address we advertise to peers in our handshake.
    pub(super) public_addr: SocketAddr,
    /// Optional consensus key pair, passed along when building our handshake.
    pub(super) consensus_keys: Option<ConsensusKeyPair>,
    /// Weights used to estimate the resource cost of incoming messages for rate limiting.
    pub(super) payload_weights: PayloadWeights,
}
async fn handle_incoming<P, REv>(
context: Arc<NetworkContext<REv>>,
stream: TcpStream,
peer_addr: SocketAddr,
) -> IncomingConnection<P>
where
REv: From<Event<P>> + 'static,
P: Payload,
for<'de> P: Serialize + Deserialize<'de>,
for<'de> Message<P>: Serialize + Deserialize<'de>,
{
let (peer_id, transport) =
match server_setup_tls(stream, &context.our_cert, &context.secret_key).await {
Ok(value) => value,
Err(error) => {
return IncomingConnection::FailedEarly { peer_addr, error };
}
};
Span::current().record("peer_id", &field::display(peer_id));
if peer_id == context.our_id {
info!("incoming loopback connection");
return IncomingConnection::Loopback;
}
debug!("Incoming TLS connection established");
let connection_id = ConnectionId::from_connection(transport.ssl(), context.our_id, peer_id);
let mut transport = framed::<P>(
context.net_metrics.clone(),
connection_id,
transport,
Role::Listener,
context.chain_info.maximum_net_message_size,
);
match negotiate_handshake(&context, &mut transport, connection_id).await {
Ok((public_addr, peer_consensus_public_key)) => {
if let Some(ref public_key) = peer_consensus_public_key {
Span::current().record("validator_id", &field::display(public_key));
}
let (_sink, stream) = transport.split();
IncomingConnection::Established {
peer_addr,
public_addr,
peer_id,
peer_consensus_public_key,
stream,
}
}
Err(error) => IncomingConnection::Failed {
peer_addr,
peer_id,
error,
},
}
}
/// Performs the server side of a TLS handshake on an accepted TCP stream.
///
/// Returns the peer's `NodeId`, derived from the public key fingerprint of its validated
/// certificate, together with the established TLS transport.
///
/// # Errors
///
/// Returns a `ConnectionError` if TLS initialization or the TLS handshake fails, if the peer
/// presents no certificate, or if its certificate does not validate.
pub(super) async fn server_setup_tls(
    stream: TcpStream,
    cert: &TlsCert,
    secret_key: &PKey<Private>,
) -> Result<(NodeId, Transport), ConnectionError> {
    let ssl = tls::create_tls_acceptor(cert.as_x509().as_ref(), secret_key.as_ref())
        .and_then(|acceptor| Ssl::new(acceptor.context()))
        .map_err(ConnectionError::TlsInitialization)?;
    let mut tls_stream =
        SslStream::new(ssl, stream).map_err(ConnectionError::TlsInitialization)?;

    SslStream::accept(Pin::new(&mut tls_stream))
        .await
        .map_err(ConnectionError::TlsHandshake)?;

    let peer_cert = match tls_stream.ssl().peer_certificate() {
        Some(peer_cert) => peer_cert,
        None => return Err(ConnectionError::NoPeerCertificate),
    };
    let validated_cert =
        tls::validate_cert(peer_cert).map_err(ConnectionError::PeerCertificateInvalid)?;
    let peer_id = NodeId::from(validated_cert.public_key_fingerprint());

    Ok((peer_id, tls_stream))
}
async fn io_timeout<F, T, E>(duration: Duration, future: F) -> Result<T, IoError<E>>
where
F: Future<Output = Result<T, E>>,
E: StdError + 'static,
{
tokio::time::timeout(duration, future)
.await
.map_err(|_elapsed| IoError::Timeout)?
.map_err(IoError::Error)
}
async fn io_opt_timeout<F, T, E>(duration: Duration, future: F) -> Result<T, IoError<E>>
where
F: Future<Output = Option<Result<T, E>>>,
E: StdError + 'static,
{
let item = tokio::time::timeout(duration, future)
.await
.map_err(|_elapsed| IoError::Timeout)?;
match item {
Some(Ok(value)) => Ok(value),
Some(Err(err)) => Err(IoError::Error(err)),
None => Err(IoError::UnexpectedEof),
}
}
/// Exchanges handshake messages over a freshly framed transport.
///
/// Sends our own handshake, then waits for the peer's. Each of the two steps is bounded by
/// `HANDSHAKE_TIMEOUT`. The peer must be on the same network, and any consensus certificate it
/// presents must validate against `connection_id`.
///
/// Returns the public address the peer advertised and, if it presented a certificate, its
/// consensus public key.
///
/// # Errors
///
/// * `HandshakeSend` / `HandshakeRecv` on I/O failure or timeout,
/// * `DidNotSendHandshake` if the first received message is not a handshake,
/// * `WrongNetwork` if the peer's network name differs from ours,
/// * `InvalidConsensusCertificate` if the peer's consensus certificate does not validate.
async fn negotiate_handshake<P, REv>(
    context: &NetworkContext<REv>,
    transport: &mut FramedTransport<P>,
    connection_id: ConnectionId,
) -> Result<(SocketAddr, Option<PublicKey>), ConnectionError>
where
    P: Payload,
{
    let our_handshake = context.chain_info.create_handshake(
        context.public_addr,
        context.consensus_keys.as_ref(),
        connection_id,
    );
    io_timeout(HANDSHAKE_TIMEOUT, transport.send(Arc::new(our_handshake)))
        .await
        .map_err(ConnectionError::HandshakeSend)?;

    let received = io_opt_timeout(HANDSHAKE_TIMEOUT, transport.next())
        .await
        .map_err(ConnectionError::HandshakeRecv)?;

    let (network_name, public_addr, protocol_version, consensus_certificate) = match received {
        Message::Handshake {
            network_name,
            public_addr,
            protocol_version,
            consensus_certificate,
        } => (
            network_name,
            public_addr,
            protocol_version,
            consensus_certificate,
        ),
        _ => return Err(ConnectionError::DidNotSendHandshake),
    };

    debug!(%protocol_version, "handshake received");

    if network_name != context.chain_info.network_name {
        return Err(ConnectionError::WrongNetwork(network_name));
    }

    let peer_consensus_public_key = match consensus_certificate {
        None => None,
        Some(certificate) => Some(
            certificate
                .validate(connection_id)
                .map_err(ConnectionError::InvalidConsensusCertificate)?,
        ),
    };

    Ok((public_addr, peer_consensus_public_key))
}
pub(super) async fn server<P, REv>(
context: Arc<NetworkContext<REv>>,
listener: tokio::net::TcpListener,
mut shutdown_receiver: watch::Receiver<()>,
) where
REv: From<Event<P>> + Send,
P: Payload,
{
let accept_connections = async {
loop {
match listener.accept().await {
Ok((stream, peer_addr)) => {
let span =
error_span!("incoming", %peer_addr, peer_id=Empty, validator_id=Empty);
let context = context.clone();
let handler_span = span.clone();
tokio::spawn(
async move {
let incoming =
handle_incoming(context.clone(), stream, peer_addr).await;
context
.event_queue
.schedule(
Event::IncomingConnection {
incoming: Box::new(incoming),
span,
},
QueueKind::NetworkIncoming,
)
.await;
}
.instrument(handler_span),
);
}
Err(ref err) => {
warn!(%context.our_id, err=display_error(err), "dropping incoming connection during accept")
}
}
}
};
let shutdown_messages = async move { while shutdown_receiver.changed().await.is_ok() {} };
match future::select(Box::pin(shutdown_messages), Box::pin(accept_connections)).await {
Either::Left(_) => info!(
%context.our_id,
"shutting down socket, no longer accepting incoming connections"
),
Either::Right(_) => unreachable!(),
}
}
/// Reads messages from an incoming connection and schedules them on the reactor.
///
/// Each received message is first charged against `limiter` using its estimated incoming
/// resource cost. It is then scheduled as `Event::IncomingMessage`, tagged with `peer_id` and
/// `span`.
///
/// Reading stops when any of these happens:
/// * the remote end closes the stream,
/// * receiving a message fails,
/// * the sender of `shutdown_receiver` is dropped.
///
/// # Errors
///
/// Returns the receive error if reading from the stream fails. A clean close of the stream, or
/// a shutdown, yields `Ok(())`.
pub(super) async fn message_reader<REv, P>(
    context: Arc<NetworkContext<REv>>,
    mut stream: SplitStream<FramedTransport<P>>,
    limiter: Box<dyn LimiterHandle>,
    mut shutdown_receiver: watch::Receiver<()>,
    peer_id: NodeId,
    span: Span,
) -> io::Result<()>
where
    P: DeserializeOwned + Send + Display + Payload,
    REv: From<Event<P>>,
{
    let read_messages = async move {
        while let Some(msg_result) = stream.next().await {
            match msg_result {
                Ok(msg) => {
                    trace!(%msg, "message received");

                    // Wait for the rate limiter before handing the message to the reactor.
                    limiter
                        .request_allowance(
                            msg.payload_incoming_resource_estimate(&context.payload_weights),
                        )
                        .await;

                    context
                        .event_queue
                        .schedule(
                            Event::IncomingMessage {
                                peer_id: Box::new(peer_id),
                                msg: Box::new(msg),
                                span: span.clone(),
                            },
                            QueueKind::NetworkIncoming,
                        )
                        .await;
                }
                Err(err) => {
                    warn!(
                        err = display_error(&err),
                        "receiving message failed, closing connection"
                    );
                    return Err(err);
                }
            }
        }
        Ok(())
    };

    // `changed()` returns an error once the sender is dropped, which is our shutdown signal.
    let shutdown_messages = async move { while shutdown_receiver.changed().await.is_ok() {} };

    match future::select(Box::pin(shutdown_messages), Box::pin(read_messages)).await {
        Either::Left(_) => {
            info!("shutting down incoming connection message reader");
            Ok(())
        }
        // Propagate the reader's outcome so callers can tell a dropped connection apart from a
        // regular close; previously any read error was silently discarded.
        Either::Right((read_result, _)) => read_result,
    }
}
/// Drains `queue` and writes each message to `sink`.
///
/// For every dequeued message, the function:
/// * decrements `counter`,
/// * estimates the message's serialized size with `rmp_serde` (0 if serialization fails),
/// * waits for `limiter` to allow that many bytes,
/// * sends the message.
///
/// It stops when the queue is closed or a send fails.
pub(super) async fn message_sender<P>(
    mut queue: UnboundedReceiver<Arc<Message<P>>>,
    mut sink: SplitSink<FramedTransport<P>, Arc<Message<P>>>,
    limiter: Box<dyn LimiterHandle>,
    counter: IntGauge,
) where
    P: Payload,
{
    loop {
        let message = match queue.recv().await {
            Some(message) => message,
            None => break,
        };
        counter.dec();

        // Approximate the on-wire size by serializing the message once up front.
        let estimated_wire_size = match rmp_serde::to_vec(&message) {
            Ok(encoded) => encoded.len() as u32,
            Err(_) => 0,
        };
        limiter.request_allowance(estimated_wire_size).await;

        if let Err(err) = sink.send(message).await {
            info!(
                err = display_error(&err),
                "message send failed, closing outgoing connection"
            );
            break;
        }
    }
}