mod config;
mod error;
mod event;
mod gossiped_address;
mod message;
#[cfg(test)]
mod tests;
use std::{
collections::{HashMap, HashSet},
convert::Infallible,
fmt::{self, Debug, Display, Formatter},
io,
net::{SocketAddr, TcpListener},
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::Duration,
};
use anyhow::Context;
use datasize::DataSize;
use futures::{
future::{select, BoxFuture, Either},
stream::{SplitSink, SplitStream},
FutureExt, SinkExt, StreamExt,
};
use openssl::pkey;
use pkey::{PKey, Private};
use rand::seq::IteratorRandom;
use serde::{de::DeserializeOwned, Serialize};
use tokio::{
net::TcpStream,
sync::{
mpsc::{self, UnboundedReceiver, UnboundedSender},
watch,
},
task::JoinHandle,
};
use tokio_openssl::SslStream;
use tokio_serde::{formats::SymmetricalMessagePack, SymmetricallyFramed};
use tokio_util::codec::{Framed, LengthDelimitedCodec};
use tracing::{debug, error, info, trace, warn};
use self::error::Result;
pub(crate) use self::{event::Event, gossiped_address::GossipedAddress, message::Message};
use crate::{
components::Component,
effect::{
announcements::NetworkAnnouncement,
requests::{NetworkInfoRequest, NetworkRequest},
EffectBuilder, EffectExt, EffectResultExt, Effects,
},
fatal,
reactor::{EventQueueHandle, Finalize, QueueKind},
tls::{self, KeyFingerprint, TlsCert},
types::CryptoRngCore,
utils,
};
pub use config::Config;
pub use error::Error;
/// Identifier of a node on the network; an alias for the fingerprint of a TLS public key.
pub(crate) type NodeId = KeyFingerprint;
/// Value of a connection's `times_seen_asymmetric` counter at which the next symmetry check
/// schedules the connection for removal instead of incrementing the counter again.
const MAX_ASYMMETRIC_CONNECTION_SEEN: u16 = 3;
/// Bookkeeping for an established outgoing connection to a peer.
#[derive(DataSize, Debug)]
pub(crate) struct OutgoingConnection<P> {
/// Sending half of the unbounded channel whose receiver is drained by `message_sender`.
#[data_size(skip)] sender: UnboundedSender<Message<P>>,
/// Remote address of the connection's transport.
peer_address: SocketAddr,
/// Number of consecutive symmetry checks that found no incoming connection from this peer.
times_seen_asymmetric: u16,
}
/// Bookkeeping for an established incoming connection from a peer.
#[derive(DataSize, Debug)]
pub(crate) struct IncomingConnection {
/// Remote address reported when the connection was accepted.
peer_address: SocketAddr,
/// Number of consecutive symmetry checks that found no outgoing connection to this peer.
times_seen_asymmetric: u16,
}
/// Networking component: accepts incoming TLS connections, dials outgoing ones and tracks
/// which peers are connected in which direction.
#[derive(DataSize)]
pub(crate) struct SmallNetwork<REv, P>
where
REv: 'static,
{
/// Our own TLS certificate, used for both accepting and initiating connections.
certificate: Arc<TlsCert>,
/// Private key belonging to `certificate`.
secret_key: Arc<PKey<Private>>,
/// Address announced to other nodes when gossiping our own address.
public_address: SocketAddr,
/// Public key fingerprint of `certificate`.
our_id: NodeId,
/// Handle used by background tasks to schedule events for this component.
event_queue: EventQueueHandle<REv>,
/// Established incoming connections, keyed by the connecting peer's ID.
incoming: HashMap<NodeId, IncomingConnection>,
/// Established outgoing connections, keyed by the remote peer's ID.
outgoing: HashMap<NodeId, OutgoingConnection<P>>,
/// Addresses that `connect_to_peer_if_required` refuses to dial.
blocklist: HashSet<SocketAddr>,
/// Addresses for which an outgoing connection attempt is currently in flight.
pending: HashSet<SocketAddr>,
/// Delay between two rounds of gossiping our own address.
gossip_interval: Duration,
/// Counter attached to our gossiped address; incremented (wrapping) before every round.
next_gossip_address_index: u32,
/// Sender of the shutdown channel; it is dropped in `finalize`.
#[data_size(skip)]
shutdown_sender: Option<watch::Sender<()>>,
/// Receiver of the shutdown channel; a clone is handed to every message reader.
#[data_size(skip)]
shutdown_receiver: watch::Receiver<()>,
/// Set to `true` in `finalize`; checked by `connect_outgoing` after its TLS handshake.
is_stopped: Arc<AtomicBool>,
/// Join handle of the server task; taken and awaited in `finalize`.
server_join_handle: Option<JoinHandle<()>>,
}
impl<REv, P> SmallNetwork<REv, P>
where
P: Serialize + DeserializeOwned + Clone + Debug + Display + Send + 'static,
REv: Send + From<Event<P>> + From<NetworkAnnouncement<NodeId, P>>,
{
/// Creates a new small network component together with the effects needed to start it.
///
/// Generates and validates a fresh node certificate, binds a listener to `cfg.bind_address`,
/// spawns the server task that accepts incoming connections and starts an outgoing connection
/// attempt for every entry of `cfg.known_addresses` that can be resolved.
///
/// If `notify` is `true` and `cfg.systemd_support` is enabled, systemd is notified that the
/// node is ready once the listener is bound, provided the system was booted with systemd.
///
/// # Errors
///
/// Returns an error if certificate generation or validation, address resolution, listener
/// creation or conversion, or the systemd notification fails.
///
/// If known addresses were configured but none of them could be resolved, the returned effects
/// are those produced by `fatal!` rather than the first address gossip round.
#[allow(clippy::type_complexity)]
pub(crate) fn new(
event_queue: EventQueueHandle<REv>,
cfg: Config,
notify: bool,
) -> Result<(SmallNetwork<REv, P>, Effects<Event<P>>)> {
let (cert, secret_key) = tls::generate_node_cert().map_err(Error::CertificateGeneration)?;
let certificate = Arc::new(tls::validate_cert(cert).map_err(Error::OwnCertificateInvalid)?);
let bind_address = utils::resolve_address(&cfg.bind_address).map_err(Error::ResolveAddr)?;
let listener = TcpListener::bind(bind_address)
.map_err(|error| Error::ListenerCreation(error, bind_address))?;
// The socket is bound at this point, so readiness can be reported if requested.
if notify {
if cfg.systemd_support {
if sd_notify::booted().map_err(Error::SystemD)? {
info!("notifying systemd that the network is ready to receive connections");
sd_notify::notify(true, &[sd_notify::NotifyState::Ready])
.map_err(Error::SystemD)?;
} else {
warn!("systemd_support enabled but not booted with systemd, ignoring");
}
} else {
debug!("systemd_support disabled, not notifying");
}
}
let local_address = listener.local_addr().map_err(Error::ListenerAddr)?;
let mut public_address =
utils::resolve_address(&cfg.public_address).map_err(Error::ResolveAddr)?;
// A public port of 0 is replaced by the port the listener was actually bound to.
if public_address.port() == 0 {
public_address.set_port(local_address.port());
}
let our_id = certificate.public_key_fingerprint();
info!(%local_address, %public_address, "{}: starting server background task", our_id);
// The server task gets the original receiver; a clone is kept for the message readers.
let (server_shutdown_sender, server_shutdown_receiver) = watch::channel(());
let shutdown_receiver = server_shutdown_receiver.clone();
let server_join_handle = tokio::spawn(server_task(
event_queue,
tokio::net::TcpListener::from_std(listener).map_err(Error::ListenerConversion)?,
server_shutdown_receiver,
our_id,
));
let mut model = SmallNetwork {
certificate,
secret_key: Arc::new(secret_key),
public_address,
our_id,
event_queue,
incoming: HashMap::new(),
outgoing: HashMap::new(),
pending: HashSet::new(),
blocklist: HashSet::new(),
gossip_interval: cfg.gossip_interval,
next_gossip_address_index: 0,
shutdown_sender: Some(server_shutdown_sender),
shutdown_receiver,
server_join_handle: Some(server_join_handle),
is_stopped: Arc::new(AtomicBool::new(false)),
};
let mut effects = Effects::new();
// Dial every known address that resolves; unresolvable ones are logged and skipped.
for address in &cfg.known_addresses {
match utils::resolve_address(address) {
Ok(known_address) => {
model.pending.insert(known_address);
effects.extend(
connect_outgoing(
known_address,
Arc::clone(&model.certificate),
Arc::clone(&model.secret_key),
Arc::clone(&model.is_stopped),
)
.result(
move |(peer_id, transport)| Event::OutgoingEstablished {
peer_id,
transport,
},
move |error| Event::BootstrappingFailed {
peer_address: known_address,
error,
},
),
);
}
Err(err) => {
warn!("failed to resolve known address {}: {}", address, err);
}
}
}
let effect_builder = EffectBuilder::new(event_queue);
// `pending` is empty here only if no known address could be resolved.
if model.pending.is_empty() && !cfg.known_addresses.is_empty() {
effects.extend(fatal!(
effect_builder,
"was given known addresses, but failed to resolve any of them"
));
} else {
effects.extend(model.gossip_our_address(effect_builder));
}
Ok((model, effects))
}
/// Queues a copy of `msg` for every peer we hold an outgoing connection to.
fn broadcast_message(&self, msg: Message<P>) {
    self.outgoing
        .keys()
        .for_each(|&peer_id| self.send_message(peer_id, msg.clone()));
}
/// Sends `msg` to up to `count` randomly chosen peers with an outgoing connection, skipping
/// every peer listed in `exclude`.
///
/// Returns the set of peers the message was queued for. Fewer than `count` peers may be
/// selected if not enough eligible connections exist; this is only traced, not an error.
fn gossip_message(
    &self,
    rng: &mut dyn CryptoRngCore,
    msg: Message<P>,
    count: usize,
    exclude: HashSet<NodeId>,
) -> HashSet<NodeId> {
    let candidates = self
        .outgoing
        .keys()
        .filter(|&candidate| !exclude.contains(candidate));
    let recipients: Vec<&NodeId> = candidates.choose_multiple(rng, count);

    if recipients.len() != count {
        trace!(
            wanted = count,
            selected = recipients.len(),
            "{}: could not select enough random nodes for gossiping, not enough non-excluded \
             outgoing connections",
            self.our_id
        );
    }

    let mut sent_to = HashSet::with_capacity(recipients.len());
    for recipient in recipients {
        self.send_message(*recipient, msg.clone());
        sent_to.insert(*recipient);
    }
    sent_to
}
/// Queues `msg` on the outgoing connection to `dest`.
///
/// The message is dropped (and the drop logged) if there is no outgoing connection to `dest`
/// or if the connection's channel has been closed.
fn send_message(&self, dest: NodeId, msg: Message<P>) {
    match self.outgoing.get(&dest) {
        None => {
            debug!(%dest, ?msg, "{}: dropped outgoing message, no connection", self.our_id);
        }
        Some(connection) => {
            if let Err(msg) = connection.sender.send(msg) {
                warn!(%dest, ?msg, "{}: dropped outgoing message, lost connection", self.our_id);
            }
        }
    }
}
/// Processes the outcome of the TLS handshake of an incoming connection.
///
/// A failed handshake, a connection from ourselves, or a transport that can no longer report
/// its peer address is logged and yields no effects. Otherwise the connection is recorded in
/// `incoming` and a message reader is started for it; the reader's completion is turned into
/// an `Event::IncomingClosed`.
fn handle_incoming_handshake_completed(
    &mut self,
    effect_builder: EffectBuilder<REv>,
    result: Result<(NodeId, Transport)>,
    peer_address: SocketAddr,
) -> Effects<Event<P>> {
    let (peer_id, transport) = match result {
        Ok(handshake) => handshake,
        Err(err) => {
            warn!(%peer_address, %err, "{}: TLS handshake failed", self.our_id);
            return Effects::new();
        }
    };

    // Returning early drops the transport, which closes the connection.
    if peer_id == self.our_id {
        debug!(
            %peer_address,
            local_address=?transport.get_ref().local_addr(),
            "{}: connected incoming to ourself - closing connection",
            self.our_id
        );
        return Effects::new();
    }

    if let Err(error) = transport.get_ref().peer_addr() {
        debug!(
            %peer_address,
            local_address=?transport.get_ref().local_addr(),
            %error,
            "{}: incoming connection dropped",
            self.our_id
        );
        return Effects::new();
    }

    debug!(%peer_id, %peer_address, "{}: established incoming connection", self.our_id);

    // Only the reading half of an incoming connection is used.
    let (_sink, stream) = framed::<P>(transport).split();
    let connection = IncomingConnection {
        peer_address,
        times_seen_asymmetric: 0,
    };
    let _ = self.incoming.insert(peer_id, connection);

    let mut effects = self.check_connection_complete(effect_builder, peer_id);
    let reader = message_reader(
        self.event_queue,
        stream,
        self.shutdown_receiver.clone(),
        self.our_id,
        peer_id,
    );
    effects.extend(reader.event(move |result| Event::IncomingClosed {
        result,
        peer_id,
        peer_address,
    }));
    effects
}
/// Registers a freshly established outgoing connection and starts its sender task.
///
/// Nothing is set up if the peer's address is no longer pending or if we connected to
/// ourselves. Completion of the sender task is turned into an `Event::OutgoingFailed`.
///
/// # Panics
///
/// Panics if the transport cannot report its peer address.
fn setup_outgoing(
    &mut self,
    effect_builder: EffectBuilder<REv>,
    peer_id: NodeId,
    transport: Transport,
) -> Effects<Event<P>> {
    let peer_address = transport
        .get_ref()
        .peer_addr()
        .expect("should have peer address");

    let was_pending = self.pending.remove(&peer_address);
    if !was_pending {
        info!(
            %peer_address,
            "{}: this peer's incoming connection has dropped, so don't establish an outgoing",
            self.our_id
        );
        return Effects::new();
    }

    if self.our_id == peer_id {
        debug!(
            peer_address=?transport.get_ref().peer_addr(),
            local_address=?transport.get_ref().local_addr(),
            "{}: connected outgoing to ourself - closing connection",
            self.our_id,
        );
        return Effects::new();
    }

    // Only the writing half of an outgoing connection is used.
    let (sink, _stream) = framed::<P>(transport).split();
    debug!(%peer_id, %peer_address, "{}: established outgoing connection", self.our_id);

    let (sender, receiver) = mpsc::unbounded_channel();
    let replaced_existing = self
        .outgoing
        .insert(
            peer_id,
            OutgoingConnection {
                peer_address,
                sender,
                times_seen_asymmetric: 0,
            },
        )
        .is_some();
    if replaced_existing {
        error!(%peer_id, "{}: did not expect leftover channel in outgoing map", self.our_id);
    }

    let mut effects = self.check_connection_complete(effect_builder, peer_id);
    let sender_task = message_sender(receiver, sink);
    effects.extend(sender_task.event(move |result| Event::OutgoingFailed {
        peer_id: Some(peer_id),
        peer_address,
        error: result.err().map(Into::into),
    }));
    effects
}
/// Cleans up after an outgoing connection attempt failed or an outgoing connection ended.
///
/// The address is removed from `pending`, the incident is logged, and, if the peer's ID is
/// known, all connection state for that peer is removed. Never produces effects.
fn handle_outgoing_lost(
    &mut self,
    peer_id: Option<NodeId>,
    peer_address: SocketAddr,
    error: Option<Error>,
) -> Effects<Event<P>> {
    let _ = self.pending.remove(&peer_address);

    match (peer_id, error) {
        (Some(peer_id), Some(err)) => {
            warn!(%peer_id, %peer_address, %err, "{}: outgoing connection failed", self.our_id)
        }
        (Some(peer_id), None) => {
            warn!(%peer_id, %peer_address, "{}: outgoing connection closed", self.our_id)
        }
        (None, Some(err)) => {
            warn!(%peer_address, %err, "{}: outgoing connection failed", self.our_id)
        }
        (None, None) => {
            warn!(%peer_address, "{}: outgoing connection closed", self.our_id)
        }
    }

    if let Some(peer_id) = peer_id {
        self.remove(&peer_id);
    }
    Effects::new()
}
fn remove(&mut self, peer_id: &NodeId) {
if let Some(incoming) = self.incoming.remove(&peer_id) {
let _ = self.pending.remove(&incoming.peer_address);
}
let _ = self.outgoing.remove(&peer_id);
}
/// Announces our public address for gossiping and schedules the next gossip round.
///
/// The gossip index is incremented (wrapping on overflow) before it is attached to the
/// address. The returned effects contain the announcement and a timeout of `gossip_interval`
/// that results in `Event::GossipOurAddress`.
fn gossip_our_address(&mut self, effect_builder: EffectBuilder<REv>) -> Effects<Event<P>> {
    let index = self.next_gossip_address_index.wrapping_add(1);
    self.next_gossip_address_index = index;

    let mut effects = effect_builder
        .announce_gossip_our_address(GossipedAddress::new(self.public_address, index))
        .ignore();
    let next_round = effect_builder
        .set_timeout(self.gossip_interval)
        .event(|_| Event::GossipOurAddress);
    effects.extend(next_round);
    effects
}
/// Tracks connections that exist in only one direction and removes those that stay
/// asymmetric for too long.
///
/// For every incoming connection without a matching outgoing one (and vice versa) the
/// connection's `times_seen_asymmetric` counter is incremented; once the counter has reached
/// `MAX_ASYMMETRIC_CONNECTION_SEEN`, the peer is removed on the next check. Symmetric
/// connections have their counter reset to 0.
///
/// When an outgoing-only connection is removed, the address we dialed is added to the
/// blocklist, so `connect_to_peer_if_required` will not dial it again.
fn enforce_symmetric_connections(&mut self) {
    /// A connection scheduled for removal, tagged with the direction it exists in.
    enum Node {
        /// Peer connected to us, but we have no outgoing connection to it.
        Incoming(NodeId),
        /// We connected to the peer at the given address, but it never connected back.
        Outgoing(NodeId, SocketAddr),
    }

    let mut remove = Vec::new();

    for (node_id, conn) in self.incoming.iter_mut() {
        if self.outgoing.contains_key(node_id) {
            conn.times_seen_asymmetric = 0;
        } else if conn.times_seen_asymmetric >= MAX_ASYMMETRIC_CONNECTION_SEEN {
            remove.push(Node::Incoming(*node_id));
        } else {
            conn.times_seen_asymmetric += 1;
        }
    }

    for (node_id, conn) in self.outgoing.iter_mut() {
        if self.incoming.contains_key(node_id) {
            conn.times_seen_asymmetric = 0;
        } else if conn.times_seen_asymmetric >= MAX_ASYMMETRIC_CONNECTION_SEEN {
            // Remember the address we dialed: the blocklist is only consulted when dialing,
            // so this is the address that has to end up in it.
            remove.push(Node::Outgoing(*node_id, conn.peer_address));
        } else {
            conn.times_seen_asymmetric += 1;
        }
    }

    for connection in remove {
        match connection {
            Node::Incoming(node_id) => self.remove(&node_id),
            Node::Outgoing(node_id, peer_address) => {
                self.blocklist.insert(peer_address);
                self.remove(&node_id);
            }
        }
    }
}
fn handle_message(
&mut self,
effect_builder: EffectBuilder<REv>,
peer_id: NodeId,
msg: Message<P>,
) -> Effects<Event<P>>
where
REv: From<NetworkAnnouncement<NodeId, P>>,
{
effect_builder
.announce_message_received(peer_id, msg.0)
.ignore()
}
/// Starts an outgoing connection attempt to `peer_address` unless one is unnecessary.
///
/// No attempt is made if the address is already pending, is blocklisted, or belongs to an
/// established outgoing connection. Otherwise the address is marked as pending and the
/// attempt's outcome is reported as `Event::OutgoingEstablished` or `Event::OutgoingFailed`.
fn connect_to_peer_if_required(&mut self, peer_address: SocketAddr) -> Effects<Event<P>> {
    if self.pending.contains(&peer_address) || self.blocklist.contains(&peer_address) {
        return Effects::new();
    }

    let already_connected = self
        .outgoing
        .values()
        .any(|connection| connection.peer_address == peer_address);
    if already_connected {
        return Effects::new();
    }

    assert!(self.pending.insert(peer_address));

    let attempt = connect_outgoing(
        peer_address,
        Arc::clone(&self.certificate),
        Arc::clone(&self.secret_key),
        Arc::clone(&self.is_stopped),
    );
    attempt.result(
        move |(peer_id, transport)| Event::OutgoingEstablished { peer_id, transport },
        move |error| Event::OutgoingFailed {
            peer_id: None,
            peer_address,
            error: Some(error),
        },
    )
}
/// Announces `peer_id` as a new peer if it is connected in both directions.
///
/// Yields no effects while either the incoming or the outgoing connection is missing.
fn check_connection_complete(
    &self,
    effect_builder: EffectBuilder<REv>,
    peer_id: NodeId,
) -> Effects<Event<P>> {
    let has_outgoing = self.outgoing.contains_key(&peer_id);
    let is_complete = has_outgoing && self.incoming.contains_key(&peer_id);
    if !is_complete {
        return Effects::new();
    }

    debug!(%peer_id, "connection to peer is now complete");
    effect_builder.announce_new_peer(peer_id).ignore()
}
/// Returns every peer we are connected to in at least one direction, with its address.
///
/// For a peer connected in both directions, the address of the outgoing connection is used.
pub(crate) fn peers(&self) -> HashMap<NodeId, SocketAddr> {
    let incoming = self
        .incoming
        .iter()
        .map(|(peer_id, connection)| (*peer_id, connection.peer_address));
    let outgoing = self
        .outgoing
        .iter()
        .map(|(peer_id, connection)| (*peer_id, connection.peer_address));

    // Later entries replace earlier ones on collection, so outgoing addresses take precedence.
    incoming.chain(outgoing).collect()
}
/// Returns `true` if there is no pending, outgoing or incoming connection at all.
fn is_isolated(&self) -> bool {
    let nothing_pending = self.pending.is_empty();
    let nothing_established = self.outgoing.is_empty() && self.incoming.is_empty();
    nothing_pending && nothing_established
}
/// Returns the ID of this node (test builds only).
#[cfg(test)]
pub(crate) fn node_id(&self) -> NodeId {
self.our_id
}
}
impl<REv, P> Finalize for SmallNetwork<REv, P>
where
    REv: Send + 'static,
    P: Send + 'static,
{
    /// Shuts the component down: drops the shutdown sender, marks the server as stopped and
    /// waits for the server task to finish.
    fn finalize(mut self) -> BoxFuture<'static, ()> {
        let shutdown = async move {
            // Dropping the sender closes the shutdown channel watched by the background tasks.
            drop(self.shutdown_sender.take());
            self.is_stopped.store(true, Ordering::SeqCst);

            let join_handle = match self.server_join_handle.take() {
                Some(join_handle) => join_handle,
                None => {
                    warn!("{}: server shutdown while already shut down", self.our_id);
                    return;
                }
            };

            match join_handle.await {
                Err(err) => error!(%self.our_id,%err, "could not join server task cleanly"),
                Ok(_) => debug!("{}: server exited cleanly", self.our_id),
            }
        };
        shutdown.boxed()
    }
}
/// Wires the small network into the reactor by translating each event into state changes
/// and effects.
impl<REv, P> Component<REv> for SmallNetwork<REv, P>
where
REv: Send + From<Event<P>> + From<NetworkAnnouncement<NodeId, P>>,
P: Serialize + DeserializeOwned + Clone + Debug + Display + Send + 'static,
{
type Event = Event<P>;
type ConstructionError = Infallible;
/// Handles a single network event and returns the effects it causes.
///
/// # Panics
///
/// Panics on `Event::BootstrappingFailed` if the failed address was not in the set of
/// pending connections.
#[allow(clippy::cognitive_complexity)]
fn handle_event(
&mut self,
effect_builder: EffectBuilder<REv>,
rng: &mut dyn CryptoRngCore,
event: Self::Event,
) -> Effects<Self::Event> {
match event {
// Connecting to a known address failed; this is fatal if it leaves us without any
// pending or established connection.
Event::BootstrappingFailed {
peer_address,
error,
} => {
warn!(%error, "{}: connection to known node at {} failed", self.our_id, peer_address);
let was_removed = self.pending.remove(&peer_address);
assert!(
was_removed,
"Bootstrap failed for node, but it was not in the set of pending connections"
);
if self.is_isolated() {
fatal!(
effect_builder,
"failed to connect to any known node, now isolated"
)
} else {
Effects::new()
}
}
// A TCP connection was accepted by the server task; start the TLS handshake.
Event::IncomingNew {
stream,
peer_address,
} => {
debug!(%peer_address, "{}: incoming connection, starting TLS handshake", self.our_id);
setup_tls(stream, self.certificate.clone(), self.secret_key.clone())
.boxed()
.event(move |result| Event::IncomingHandshakeCompleted {
result,
peer_address,
})
}
Event::IncomingHandshakeCompleted {
result,
peer_address,
} => self.handle_incoming_handshake_completed(effect_builder, result, peer_address),
Event::IncomingMessage { peer_id, msg } => {
self.handle_message(effect_builder, peer_id, msg)
}
// The message reader of an incoming connection finished; forget the peer.
Event::IncomingClosed {
result,
peer_id,
peer_address,
} => {
match result {
Ok(()) => info!(%peer_id, %peer_address, "{}: connection closed", self.our_id),
Err(err) => {
warn!(%peer_id, %peer_address, %err, "{}: connection dropped", self.our_id)
}
}
self.remove(&peer_id);
Effects::new()
}
Event::OutgoingEstablished { peer_id, transport } => {
self.setup_outgoing(effect_builder, peer_id, transport)
}
Event::OutgoingFailed {
peer_id,
peer_address,
error,
} => self.handle_outgoing_lost(peer_id, peer_address, error),
// Requests from other components: the responder is answered right after the message
// has been handed to `send_message`, `broadcast_message` or `gossip_message`.
Event::NetworkRequest {
req:
NetworkRequest::SendMessage {
dest,
payload,
responder,
},
} => {
self.send_message(dest, Message(payload));
responder.respond(()).ignore()
}
Event::NetworkRequest {
req: NetworkRequest::Broadcast { payload, responder },
} => {
self.broadcast_message(Message(payload));
responder.respond(()).ignore()
}
Event::NetworkRequest {
req:
NetworkRequest::Gossip {
payload,
count,
exclude,
responder,
},
} => {
let sent_to = self.gossip_message(rng, Message(payload), count, exclude);
responder.respond(sent_to).ignore()
}
Event::NetworkInfoRequest {
req: NetworkInfoRequest::GetPeers { responder },
} => responder.respond(self.peers()).ignore(),
// Periodic tick: gossip our address again and check connections for symmetry.
Event::GossipOurAddress => {
let effects = self.gossip_our_address(effect_builder);
self.enforce_symmetric_connections();
effects
}
// A peer's address was received; connect to it unless that is unnecessary.
Event::PeerAddressReceived(gossiped_address) => {
self.connect_to_peer_if_required(gossiped_address.into())
}
}
}
}
/// Core accept loop of the networking server.
///
/// Accepts connections on `listener` and schedules an `Event::IncomingNew` for each of them
/// until the sending side of `shutdown_receiver` has gone away. Accept errors are logged and
/// the loop continues.
async fn server_task<P, REv>(
    event_queue: EventQueueHandle<REv>,
    mut listener: tokio::net::TcpListener,
    mut shutdown_receiver: watch::Receiver<()>,
    our_id: NodeId,
) where
    REv: From<Event<P>>,
{
    // Never terminates by itself; it only ends by being dropped.
    let accept_connections = async move {
        loop {
            let (stream, peer_address) = match listener.accept().await {
                Ok(accepted) => accepted,
                Err(err) => {
                    warn!(%err, "{}: dropping incoming connection during accept", our_id);
                    continue;
                }
            };
            event_queue
                .schedule(
                    Event::IncomingNew {
                        stream,
                        peer_address,
                    },
                    QueueKind::NetworkIncoming,
                )
                .await;
        }
    };

    // Completes once `recv` yields `None`.
    let shutdown_messages = async move { while shutdown_receiver.recv().await.is_some() {} };

    let outcome = select(Box::pin(shutdown_messages), Box::pin(accept_connections)).await;
    match outcome {
        Either::Right(_) => unreachable!(),
        Either::Left(_) => info!(
            "{}: shutting down socket, no longer accepting incoming connections",
            our_id
        ),
    }
}
async fn setup_tls(
stream: TcpStream,
cert: Arc<TlsCert>,
secret_key: Arc<PKey<Private>>,
) -> Result<(NodeId, Transport)> {
let tls_stream = tokio_openssl::accept(
&tls::create_tls_acceptor(&cert.as_x509().as_ref(), &secret_key.as_ref())
.map_err(Error::AcceptorCreation)?,
stream,
)
.await?;
let peer_cert = tls_stream
.ssl()
.peer_certificate()
.ok_or_else(|| Error::NoClientCertificate)?;
Ok((
tls::validate_cert(peer_cert)?.public_key_fingerprint(),
tls_stream,
))
}
/// Reads messages from an incoming connection and schedules an `Event::IncomingMessage` for
/// each of them.
///
/// Runs until the stream ends, reading fails, or the sending side of `shutdown_receiver` has
/// gone away.
///
/// # Errors
///
/// Returns the underlying error if receiving a message failed. A stream that ended cleanly
/// and a shutdown both yield `Ok(())`.
async fn message_reader<REv, P>(
    event_queue: EventQueueHandle<REv>,
    mut stream: SplitStream<FramedTransport<P>>,
    mut shutdown_receiver: watch::Receiver<()>,
    our_id: NodeId,
    peer_id: NodeId,
) -> io::Result<()>
where
    P: DeserializeOwned + Send + Display,
    REv: From<Event<P>>,
{
    let read_messages = async move {
        while let Some(msg_result) = stream.next().await {
            match msg_result {
                Ok(msg) => {
                    debug!(%msg, %peer_id, "{}: message received", our_id);
                    event_queue
                        .schedule(
                            Event::IncomingMessage { peer_id, msg },
                            QueueKind::NetworkIncoming,
                        )
                        .await;
                }
                Err(err) => {
                    warn!(%err, %peer_id, "{}: receiving message failed, closing connection", our_id);
                    return Err(err);
                }
            }
        }
        Ok(())
    };

    // Completes once `recv` yields `None`.
    let shutdown_messages = async move { while shutdown_receiver.recv().await.is_some() {} };

    match select(Box::pin(shutdown_messages), Box::pin(read_messages)).await {
        Either::Left(_) => {
            info!(
                %peer_id,
                "{}: shutting down incoming connection message reader",
                our_id
            );
            Ok(())
        }
        // Hand the reader's own result to the caller, so a failed read is reported as an
        // error instead of looking like a clean close.
        Either::Right((read_result, _)) => read_result,
    }
}
async fn message_sender<P>(
mut queue: UnboundedReceiver<Message<P>>,
mut sink: SplitSink<FramedTransport<P>, Message<P>>,
) -> Result<()>
where
P: Serialize + Send,
{
while let Some(payload) = queue.recv().await {
sink.send(payload).await.map_err(Error::MessageNotSent)?;
}
Ok(())
}
/// Transport used for all connections: a TLS stream on top of a TCP stream.
type Transport = SslStream<TcpStream>;
/// A `Transport` wrapped in length-delimited framing that carries `Message<P>` values encoded
/// as MessagePack in both directions.
type FramedTransport<P> = SymmetricallyFramed<
Framed<Transport, LengthDelimitedCodec>,
Message<P>,
SymmetricalMessagePack<Message<P>>,
>;
/// Wraps `stream` in length-delimited framing with MessagePack (de)serialization of
/// `Message<P>`.
fn framed<P>(stream: Transport) -> FramedTransport<P> {
    let codec = LengthDelimitedCodec::new();
    let format = SymmetricalMessagePack::<Message<P>>::default();
    SymmetricallyFramed::new(Framed::new(stream, codec), format)
}
async fn connect_outgoing(
peer_address: SocketAddr,
our_certificate: Arc<TlsCert>,
secret_key: Arc<PKey<Private>>,
server_is_stopped: Arc<AtomicBool>,
) -> Result<(NodeId, Transport)> {
let mut config = tls::create_tls_connector(&our_certificate.as_x509(), &secret_key)
.context("could not create TLS connector")?
.configure()
.map_err(Error::ConnectorConfiguration)?;
config.set_verify_hostname(false);
let stream = TcpStream::connect(peer_address)
.await
.context("TCP connection failed")?;
let tls_stream = tokio_openssl::connect(config, "this-will-not-be-checked.example.com", stream)
.await
.context("tls handshake failed")?;
let peer_cert = tls_stream
.ssl()
.peer_certificate()
.ok_or_else(|| Error::NoServerCertificate)?;
let peer_id = tls::validate_cert(peer_cert)?.public_key_fingerprint();
if server_is_stopped.load(Ordering::SeqCst) {
debug!(
our_id=%our_certificate.public_key_fingerprint(),
%peer_address,
"server stopped - aborting outgoing TLS connection"
);
Err(Error::ServerStopped)
} else {
Ok((peer_id, tls_stream))
}
}
impl<R, P> Debug for SmallNetwork<R, P>
where
    P: Debug,
{
    /// Formats the component for debugging, replacing the certificate, secret key and event
    /// queue with placeholders.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("SmallNetwork");
        builder.field("our_id", &self.our_id);
        builder.field("certificate", &"<SSL cert>");
        builder.field("secret_key", &"<hidden>");
        builder.field("public_address", &self.public_address);
        builder.field("event_queue", &"<event_queue>");
        builder.field("incoming", &self.incoming);
        builder.field("outgoing", &self.outgoing);
        builder.field("pending", &self.pending);
        builder.finish()
    }
}