use futures::{Stream, StreamExt, stream::FusedStream};
use std::{collections::HashMap, task::Poll};
use crate::{
Behavior, BehaviorOutput, Message as MessageTrait, OutboundQueue, PeerId, protocol as proto,
};
use super::{AcceptedVersion, AnyMessage, BlockRange, ConnectionState};
mod blockfetch;
mod chainsync;
mod connection;
mod discovery;
mod handshake;
mod keepalive;
mod promotion;
pub use blockfetch::*;
pub use chainsync::*;
pub use connection::*;
pub use discovery::*;
pub use handshake::*;
pub use keepalive::*;
pub use promotion::*;
pub trait PeerVisitor {
#[allow(unused_variables)]
fn visit_connected(
&mut self,
pid: &PeerId,
state: &mut InitiatorState,
outbound: &mut OutboundQueue<InitiatorBehavior>,
) {
}
#[allow(unused_variables)]
fn visit_disconnected(
&mut self,
pid: &PeerId,
state: &mut InitiatorState,
outbound: &mut OutboundQueue<InitiatorBehavior>,
) {
}
#[allow(unused_variables)]
fn visit_errored(
&mut self,
pid: &PeerId,
state: &mut InitiatorState,
outbound: &mut OutboundQueue<InitiatorBehavior>,
) {
}
#[allow(unused_variables)]
fn visit_discovered(
&mut self,
pid: &PeerId,
state: &mut InitiatorState,
outbound: &mut OutboundQueue<InitiatorBehavior>,
) {
}
#[allow(unused_variables)]
fn visit_inbound_msg(
&mut self,
pid: &PeerId,
state: &mut InitiatorState,
outbound: &mut OutboundQueue<InitiatorBehavior>,
) {
}
#[allow(unused_variables)]
fn visit_outbound_msg(
&mut self,
pid: &PeerId,
state: &mut InitiatorState,
outbound: &mut OutboundQueue<InitiatorBehavior>,
) {
}
#[allow(unused_variables)]
fn visit_tagged(
&mut self,
pid: &PeerId,
state: &mut InitiatorState,
outbound: &mut OutboundQueue<InitiatorBehavior>,
) {
}
#[allow(unused_variables)]
fn visit_housekeeping(
&mut self,
pid: &PeerId,
state: &mut InitiatorState,
outbound: &mut OutboundQueue<InitiatorBehavior>,
) {
}
}
#[derive(PartialEq, Debug, Default, Copy, Clone)]
pub enum PromotionTag {
#[default]
Cold,
Warm,
Hot,
Banned,
}
#[derive(Default, Debug)]
pub struct InitiatorState {
pub(crate) connection: ConnectionState,
pub(crate) promotion: PromotionTag,
pub(crate) handshake: proto::handshake::State<proto::handshake::n2n::VersionData>,
pub(crate) keepalive: proto::keepalive::State,
pub(crate) peersharing: proto::peersharing::State,
pub(crate) blockfetch: proto::blockfetch::State,
pub(crate) chainsync: proto::chainsync::State<proto::chainsync::HeaderContent>,
pub(crate) tx_submission: proto::txsubmission::State,
pub(crate) violation: bool,
pub(crate) error_count: u32,
pub(crate) continue_sync: bool,
}
impl InitiatorState {
pub fn new() -> Self {
InitiatorState {
connection: ConnectionState::default(),
promotion: PromotionTag::default(),
handshake: proto::handshake::State::default(),
keepalive: proto::keepalive::State::default(),
peersharing: proto::peersharing::State::default(),
blockfetch: proto::blockfetch::State::default(),
chainsync: crate::protocol::chainsync::State::default(),
tx_submission: crate::protocol::txsubmission::State::default(),
violation: false,
error_count: 0,
continue_sync: false,
}
}
pub fn is_initialized(&self) -> bool {
matches!(self.connection, ConnectionState::Initialized)
}
pub fn version(&self) -> Option<proto::handshake::n2n::VersionData> {
match &self.handshake {
proto::handshake::State::Done(proto::handshake::DoneState::Accepted(_, data)) => {
Some(data.clone())
}
_ => None,
}
}
pub fn promotion(&self) -> PromotionTag {
self.promotion
}
pub fn supports_peer_sharing(&self) -> bool {
let val = self
.version()
.as_ref()
.and_then(|v| v.peer_sharing)
.unwrap_or(0);
val > 0
}
pub fn apply_msg(&mut self, msg: &AnyMessage) {
match msg {
AnyMessage::Handshake(msg) => {
let result = self.handshake.apply(msg);
let Ok(new) = result else {
tracing::warn!("handshake violation");
self.violation = true;
return;
};
self.handshake = new;
}
AnyMessage::KeepAlive(msg) => {
let result = self.keepalive.apply(msg);
let Ok(new) = result else {
tracing::warn!("keepalive violation");
self.violation = true;
return;
};
self.keepalive = new;
}
AnyMessage::PeerSharing(msg) => {
let result = self.peersharing.apply(msg);
let Ok(new) = result else {
tracing::warn!("peer sharing violation");
self.violation = true;
return;
};
self.peersharing = new;
}
AnyMessage::BlockFetch(msg) => {
let result = self.blockfetch.apply(msg);
let Ok(new) = result else {
tracing::warn!("block fetch violation");
self.violation = true;
return;
};
self.blockfetch = new;
}
AnyMessage::ChainSync(msg) => {
let result = self.chainsync.apply(msg);
let Ok(new) = result else {
tracing::warn!("chain sync violation");
self.violation = true;
return;
};
self.chainsync = new;
}
AnyMessage::TxSubmission(msg) => {
let result = self.tx_submission.apply(msg);
let Ok(new) = result else {
tracing::warn!("tx submission violation");
self.violation = true;
return;
};
self.tx_submission = new;
}
}
}
pub fn reset(&mut self) {
self.connection = ConnectionState::default();
self.promotion = PromotionTag::default();
self.handshake = proto::handshake::State::default();
self.keepalive = proto::keepalive::State::default();
self.peersharing = proto::peersharing::State::default();
self.blockfetch = proto::blockfetch::State::default();
self.chainsync = proto::chainsync::State::default();
self.tx_submission = proto::txsubmission::State::default();
self.continue_sync = false;
self.violation = false;
}
}
pub type TagFn = fn(&mut InitiatorState);
#[derive(Debug)]
pub enum InitiatorCommand {
IncludePeer(PeerId),
Housekeeping,
StartSync(Vec<proto::Point>),
ContinueSync(PeerId),
RequestBlocks(BlockRange),
SendTx(
PeerId,
proto::txsubmission::EraTxId,
proto::txsubmission::EraTxBody,
),
BanPeer(PeerId),
DemotePeer(PeerId),
}
#[derive(Debug)]
pub enum InitiatorEvent {
PeerInitialized(PeerId, AcceptedVersion),
IntersectionFound(PeerId, proto::Point, proto::chainsync::Tip),
BlockHeaderReceived(
PeerId,
proto::chainsync::HeaderContent,
proto::chainsync::Tip,
),
RollbackReceived(PeerId, proto::Point, proto::chainsync::Tip),
BlockBodyReceived(PeerId, proto::blockfetch::Body),
TxRequested(PeerId, proto::txsubmission::EraTxId),
}
#[derive(Default)]
pub struct InitiatorBehavior {
pub promotion: promotion::PromotionBehavior,
pub connection: connection::ConnectionBehavior,
pub handshake: handshake::HandshakeBehavior,
pub keepalive: keepalive::KeepaliveBehavior,
pub discovery: discovery::DiscoveryBehavior,
pub blockfetch: blockfetch::BlockFetchBehavior,
pub chainsync: chainsync::ChainSyncBehavior,
pub peers: HashMap<PeerId, InitiatorState>,
pub outbound: OutboundQueue<Self>,
}
macro_rules! all_visitors {
($self:ident, $pid:ident, $state:expr, $method:ident) => {
$self.promotion.$method($pid, $state, &mut $self.outbound);
$self.connection.$method($pid, $state, &mut $self.outbound);
$self.handshake.$method($pid, $state, &mut $self.outbound);
$self.keepalive.$method($pid, $state, &mut $self.outbound);
$self.discovery.$method($pid, $state, &mut $self.outbound);
$self.blockfetch.$method($pid, $state, &mut $self.outbound);
$self.chainsync.$method($pid, $state, &mut $self.outbound);
};
}
impl InitiatorBehavior {
#[tracing::instrument(skip_all, fields(pid = %pid, channel = %msg.channel()))]
pub fn on_inbound_msg(&mut self, pid: &PeerId, msg: &AnyMessage) {
tracing::debug!(channel = msg.channel(), "new inbound message");
self.peers.entry(pid.clone()).and_modify(|state| {
state.apply_msg(msg);
all_visitors!(self, pid, state, visit_inbound_msg);
});
}
#[tracing::instrument(skip_all, fields(pid = %pid, channel = %msg.channel()))]
pub fn on_outbound_msg(&mut self, pid: &PeerId, msg: &AnyMessage) {
tracing::debug!(channel = msg.channel(), "new outbound message");
self.peers.entry(pid.clone()).and_modify(|state| {
state.apply_msg(msg);
all_visitors!(self, pid, state, visit_outbound_msg);
});
}
#[tracing::instrument(skip_all, fields(pid = %pid))]
fn on_connected(&mut self, pid: &PeerId) {
tracing::info!("connected");
self.peers.entry(pid.clone()).and_modify(|state| {
state.connection = ConnectionState::Connected;
all_visitors!(self, pid, state, visit_connected);
});
}
#[tracing::instrument(skip_all, fields(pid = %pid))]
fn on_disconnected(&mut self, pid: &PeerId) {
tracing::info!("disconnected");
self.peers.entry(pid.clone()).and_modify(|state| {
state.connection = ConnectionState::Disconnected;
state.reset();
all_visitors!(self, pid, state, visit_disconnected);
});
}
#[tracing::instrument(skip_all, fields(pid = %pid))]
fn on_errored(&mut self, pid: &PeerId) {
tracing::error!("error");
self.peers.entry(pid.clone()).and_modify(|state| {
state.connection = ConnectionState::Errored;
state.error_count += 1;
all_visitors!(self, pid, state, visit_errored);
});
}
#[tracing::instrument(skip_all, fields(pid = %pid))]
fn on_tagged(&mut self, pid: &PeerId, tagger: TagFn) {
tracing::debug!("tagged");
self.peers.entry(pid.clone()).and_modify(|state| {
tagger(state);
all_visitors!(self, pid, state, visit_tagged);
});
}
#[tracing::instrument(skip_all, fields(pid = %pid))]
fn on_discovered(&mut self, pid: &PeerId) {
let mut state = InitiatorState::new();
all_visitors!(self, pid, &mut state, visit_discovered);
self.peers.insert(pid.clone(), state);
}
fn move_discovered_into_promotion(&mut self) {
let deficit = self.promotion.peer_deficit();
if deficit == 0 {
return;
}
let new = self.discovery.drain_new_peers(deficit);
if new.is_empty() {
tracing::trace!("no new peers discovered");
return;
}
tracing::info!(deficit = deficit, new = new.len(), "discovered new peers",);
for pid in new {
if !self.peers.contains_key(&pid) {
self.on_discovered(&pid);
}
}
}
#[tracing::instrument(skip_all)]
fn housekeeping(&mut self) {
for (pid, state) in self.peers.iter_mut() {
all_visitors!(self, pid, state, visit_housekeeping);
}
self.move_discovered_into_promotion();
}
}
impl Stream for InitiatorBehavior {
type Item = BehaviorOutput<Self>;
fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
let poll = self.outbound.futures.poll_next_unpin(cx);
match poll {
Poll::Ready(Some(x)) => Poll::Ready(Some(x)),
Poll::Ready(None) => Poll::Pending,
Poll::Pending => Poll::Pending,
}
}
}
impl FusedStream for InitiatorBehavior {
fn is_terminated(&self) -> bool {
false
}
}
impl Behavior for InitiatorBehavior {
type Event = InitiatorEvent;
type Command = InitiatorCommand;
type PeerState = InitiatorState;
type Message = AnyMessage;
fn handle_io(&mut self, event: crate::InterfaceEvent<Self::Message>) {
match &event {
crate::InterfaceEvent::Connected(pid) => {
self.on_connected(pid);
}
crate::InterfaceEvent::Disconnected(pid) => {
self.on_disconnected(pid);
}
crate::InterfaceEvent::Recv(pid, msgs) => {
for msg in msgs {
self.on_inbound_msg(pid, msg);
}
}
crate::InterfaceEvent::Sent(pid, msg) => {
self.on_outbound_msg(pid, msg);
}
crate::InterfaceEvent::Error(pid, _) => {
self.on_errored(pid);
}
crate::InterfaceEvent::Idle => {
self.housekeeping();
}
}
}
fn execute(&mut self, cmd: Self::Command) {
match cmd {
InitiatorCommand::IncludePeer(pid) => {
tracing::debug!("include peer command");
self.on_discovered(&pid);
}
InitiatorCommand::StartSync(points) => {
tracing::debug!("start sync command");
self.chainsync.start(points);
}
InitiatorCommand::ContinueSync(pid) => {
tracing::debug!("continue sync command");
self.on_tagged(&pid, |state| state.continue_sync = true);
}
InitiatorCommand::RequestBlocks(range) => {
tracing::debug!("request blocks command");
self.blockfetch.enqueue(range);
}
InitiatorCommand::Housekeeping => {
tracing::debug!("housekeeping command");
self.housekeeping();
}
InitiatorCommand::BanPeer(pid) => {
tracing::debug!("ban peer command");
self.on_tagged(&pid, |state| state.promotion = PromotionTag::Banned);
}
InitiatorCommand::DemotePeer(pid) => {
tracing::debug!("demote peer command");
self.on_tagged(&pid, |state| state.promotion = PromotionTag::Cold);
}
InitiatorCommand::SendTx(..) => {
tracing::warn!("SendTx not yet implemented");
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::protocol::{
MAINNET_MAGIC, Point, blockfetch as bf, chainsync as cs, handshake, keepalive, peersharing,
};
use crate::testing::BehaviorOutputExt;
use crate::{InterfaceError, InterfaceEvent};
use futures::StreamExt;
use std::collections::HashMap;
use std::net::Ipv4Addr;
fn drain_outputs(behavior: &mut InitiatorBehavior) -> Vec<BehaviorOutput<InitiatorBehavior>> {
let mut outputs = Vec::new();
let waker = futures::task::noop_waker();
let mut cx = std::task::Context::from_waker(&waker);
while let std::task::Poll::Ready(Some(output)) = behavior.poll_next_unpin(&mut cx) {
outputs.push(output);
}
outputs
}
fn complete_handshake(behavior: &mut InitiatorBehavior, pid: &PeerId) {
let version_data =
handshake::n2n::VersionData::new(MAINNET_MAGIC, false, Some(1), Some(false));
let mut values = HashMap::new();
values.insert(13u64, version_data.clone());
let version_table = handshake::VersionTable { values };
let propose = AnyMessage::Handshake(handshake::Message::Propose(version_table));
behavior.handle_io(InterfaceEvent::Sent(pid.clone(), propose));
drain_outputs(behavior);
let accept = AnyMessage::Handshake(handshake::Message::Accept(13, version_data));
behavior.handle_io(InterfaceEvent::Recv(pid.clone(), vec![accept]));
drain_outputs(behavior);
}
#[tokio::test]
async fn banned_peer_not_reconnected() {
tokio::time::pause();
let mut behavior = InitiatorBehavior::default();
let pid = PeerId::test(1);
behavior.execute(InitiatorCommand::IncludePeer(pid.clone()));
behavior.execute(InitiatorCommand::Housekeeping);
drain_outputs(&mut behavior);
behavior.handle_io(InterfaceEvent::Connected(pid.clone()));
drain_outputs(&mut behavior);
let bad_msg = AnyMessage::KeepAlive(keepalive::Message::ResponseKeepAlive(42));
behavior.handle_io(InterfaceEvent::Recv(pid.clone(), vec![bad_msg]));
behavior.execute(InitiatorCommand::Housekeeping);
drain_outputs(&mut behavior);
behavior.handle_io(InterfaceEvent::Disconnected(pid.clone()));
drain_outputs(&mut behavior);
for _ in 0..10 {
behavior.execute(InitiatorCommand::Housekeeping);
let outputs = drain_outputs(&mut behavior);
assert!(!outputs.has_connect_for(&pid));
}
}
#[tokio::test]
async fn demote_peer_returns_to_cold() {
tokio::time::pause();
let mut behavior = InitiatorBehavior::default();
let pid = PeerId::test(2);
behavior.execute(InitiatorCommand::IncludePeer(pid.clone()));
behavior.execute(InitiatorCommand::Housekeeping);
drain_outputs(&mut behavior);
behavior.handle_io(InterfaceEvent::Connected(pid.clone()));
drain_outputs(&mut behavior);
complete_handshake(&mut behavior, &pid);
behavior.execute(InitiatorCommand::Housekeeping);
drain_outputs(&mut behavior);
assert!(behavior.promotion.hot_peers.contains(&pid));
behavior.execute(InitiatorCommand::DemotePeer(pid.clone()));
drain_outputs(&mut behavior);
let state = behavior.peers.get(&pid).unwrap();
assert_eq!(state.promotion, PromotionTag::Cold);
behavior.execute(InitiatorCommand::Housekeeping);
let outputs = drain_outputs(&mut behavior);
assert!(outputs.has_disconnect_for(&pid));
}
#[tokio::test]
async fn error_count_persists_across_disconnect() {
tokio::time::pause();
let mut behavior = InitiatorBehavior {
promotion: PromotionBehavior::new(PromotionConfig {
max_error_count: 2,
..PromotionConfig::default()
}),
..Default::default()
};
let pid = PeerId::test(3);
behavior.execute(InitiatorCommand::IncludePeer(pid.clone()));
behavior.execute(InitiatorCommand::Housekeeping);
drain_outputs(&mut behavior);
for _ in 0..2 {
behavior.handle_io(InterfaceEvent::Error(
pid.clone(),
InterfaceError::Other("err".into()),
));
behavior.execute(InitiatorCommand::Housekeeping);
drain_outputs(&mut behavior);
behavior.handle_io(InterfaceEvent::Disconnected(pid.clone()));
drain_outputs(&mut behavior);
}
assert!(!behavior.promotion.banned_peers.contains(&pid));
behavior.handle_io(InterfaceEvent::Error(
pid.clone(),
InterfaceError::Other("err".into()),
));
behavior.execute(InitiatorCommand::Housekeeping);
drain_outputs(&mut behavior);
assert!(behavior.promotion.banned_peers.contains(&pid));
}
#[tokio::test]
async fn full_peer_lifecycle_include_to_chainsync() {
tokio::time::pause();
let mut behavior = InitiatorBehavior::default();
let pid = PeerId::test(10);
behavior.execute(InitiatorCommand::StartSync(vec![Point::Origin]));
behavior.execute(InitiatorCommand::IncludePeer(pid.clone()));
behavior.execute(InitiatorCommand::Housekeeping);
let outputs = drain_outputs(&mut behavior);
assert!(behavior.promotion.warm_peers.contains(&pid));
assert!(outputs.has_connect_for(&pid));
behavior.handle_io(InterfaceEvent::Connected(pid.clone()));
let outputs = drain_outputs(&mut behavior);
assert!(
outputs
.has_send(|m| matches!(m, AnyMessage::Handshake(handshake::Message::Propose(_))))
);
complete_handshake(&mut behavior, &pid);
behavior.execute(InitiatorCommand::Housekeeping);
let outputs = drain_outputs(&mut behavior);
assert!(behavior.promotion.hot_peers.contains(&pid));
assert!(
outputs.has_send(|m| matches!(m, AnyMessage::ChainSync(cs::Message::FindIntersect(_)))),
"chainsync should start for hot initialized peer"
);
}
#[tokio::test]
async fn housekeeping_promotes_and_connects_in_same_pass() {
tokio::time::pause();
let mut behavior = InitiatorBehavior::default();
let pid = PeerId::test(11);
behavior.execute(InitiatorCommand::IncludePeer(pid.clone()));
behavior.execute(InitiatorCommand::Housekeeping);
let outputs = drain_outputs(&mut behavior);
assert!(
behavior.promotion.warm_peers.contains(&pid),
"peer should be promoted to warm"
);
assert!(
outputs.has_connect_for(&pid),
"Connect should be issued in the same housekeeping pass"
);
}
#[tokio::test]
async fn discovery_feeds_into_promotion() {
tokio::time::pause();
let mut behavior = InitiatorBehavior::default();
let seed_pid = PeerId::test(12);
behavior.execute(InitiatorCommand::IncludePeer(seed_pid.clone()));
behavior.execute(InitiatorCommand::Housekeeping);
drain_outputs(&mut behavior);
behavior.handle_io(InterfaceEvent::Connected(seed_pid.clone()));
drain_outputs(&mut behavior);
complete_handshake(&mut behavior, &seed_pid);
behavior.execute(InitiatorCommand::Housekeeping);
drain_outputs(&mut behavior);
let share_response = AnyMessage::PeerSharing(peersharing::Message::SharePeers(vec![
peersharing::PeerAddress::V4(Ipv4Addr::new(192, 168, 1, 1), 3000),
peersharing::PeerAddress::V4(Ipv4Addr::new(192, 168, 1, 2), 3001),
]));
let share_req = AnyMessage::PeerSharing(peersharing::Message::ShareRequest(10));
behavior.handle_io(InterfaceEvent::Sent(seed_pid.clone(), share_req));
drain_outputs(&mut behavior);
behavior.handle_io(InterfaceEvent::Recv(seed_pid.clone(), vec![share_response]));
drain_outputs(&mut behavior);
behavior.execute(InitiatorCommand::Housekeeping);
drain_outputs(&mut behavior);
let discovered_1 = PeerId {
host: "192.168.1.1".to_string(),
port: 3000,
};
let discovered_2 = PeerId {
host: "192.168.1.2".to_string(),
port: 3001,
};
assert!(
behavior.peers.contains_key(&discovered_1),
"discovered peer 1 should be tracked after housekeeping"
);
assert!(
behavior.peers.contains_key(&discovered_2),
"discovered peer 2 should be tracked after housekeeping"
);
}
#[tokio::test]
async fn violation_bans_and_disconnects() {
tokio::time::pause();
let mut behavior = InitiatorBehavior::default();
let pid = PeerId::test(13);
behavior.execute(InitiatorCommand::IncludePeer(pid.clone()));
behavior.execute(InitiatorCommand::Housekeeping);
drain_outputs(&mut behavior);
behavior.handle_io(InterfaceEvent::Connected(pid.clone()));
drain_outputs(&mut behavior);
let bad_msg = AnyMessage::KeepAlive(keepalive::Message::ResponseKeepAlive(42));
behavior.handle_io(InterfaceEvent::Recv(pid.clone(), vec![bad_msg]));
behavior.execute(InitiatorCommand::Housekeeping);
let outputs = drain_outputs(&mut behavior);
assert!(
behavior.promotion.banned_peers.contains(&pid),
"promotion should ban the violating peer"
);
assert!(
outputs.has_disconnect_for(&pid),
"connection should disconnect the banned peer"
);
}
#[tokio::test]
async fn blockfetch_requires_initialized_and_idle() {
tokio::time::pause();
let mut behavior = InitiatorBehavior::default();
let pid = PeerId::test(14);
let range = (Point::Origin, Point::new(100, vec![0xAA; 32]));
behavior.blockfetch.enqueue(range.clone());
behavior.execute(InitiatorCommand::IncludePeer(pid.clone()));
behavior.execute(InitiatorCommand::Housekeeping);
drain_outputs(&mut behavior);
behavior.handle_io(InterfaceEvent::Connected(pid.clone()));
drain_outputs(&mut behavior);
behavior.execute(InitiatorCommand::Housekeeping);
let outputs = drain_outputs(&mut behavior);
assert!(
!outputs
.has_send(|m| matches!(m, AnyMessage::BlockFetch(bf::Message::RequestRange(_)))),
"should NOT send RequestRange before handshake"
);
complete_handshake(&mut behavior, &pid);
behavior.execute(InitiatorCommand::Housekeeping);
let outputs = drain_outputs(&mut behavior);
assert!(
outputs.has_send(|m| matches!(m, AnyMessage::BlockFetch(bf::Message::RequestRange(_)))),
"should send RequestRange after handshake completes"
);
}
}