use std::collections::{HashMap, HashSet};
use std::net;
use bitcoin::network::constants::ServiceFlags;
use nakamoto_common::block::time::{LocalDuration, LocalTime};
use nakamoto_common::p2p::peer::{self, AddressSource, Source};
use super::channel::{Disconnect, SetTimeout};
use crate::protocol::{DisconnectReason, Link, PeerId, Timeout};
/// Timeout handed to [`Connect::connect`] for every outbound connection attempt.
pub const CONNECTION_TIMEOUT: LocalDuration = LocalDuration::from_secs(3);
/// Period of the connection manager's idle timer. On each idle tick the manager
/// tops up its outbound connections.
pub const IDLE_TIMEOUT: LocalDuration = LocalDuration::from_mins(1);
/// Suggested number of outbound peers to maintain.
/// NOTE(review): not referenced in this chunk — presumably the default for
/// `Config::target_outbound_peers`; confirm.
pub const TARGET_OUTBOUND_PEERS: usize = 8;
/// Suggested cap on the number of inbound peers.
/// NOTE(review): not referenced in this chunk — presumably the default for
/// `Config::max_inbound_peers`; confirm.
pub const MAX_INBOUND_PEERS: usize = 16;
/// Capability to open outbound connections on behalf of the connection manager.
pub trait Connect {
/// Begin connecting to `addr`. The connection manager passes [`CONNECTION_TIMEOUT`]
/// as `timeout`.
fn connect(&self, addr: net::SocketAddr, timeout: Timeout);
}
/// Sink for the [`Event`]s emitted by the connection manager.
pub trait Events {
/// Report a single connection-manager event.
fn event(&self, event: Event);
}
/// Events emitted by the connection manager through [`Events::event`].
#[derive(Debug, Clone)]
pub enum Event {
/// An outbound connection attempt was started to an address sampled from the
/// address source. Dials to `Config::retry` addresses do not emit this event.
Connecting(PeerId, Source),
/// A peer connected over the given [`Link`]. Emitted even for inbound peers that
/// are then immediately dropped because the inbound limit is reached.
Connected(PeerId, Link),
/// A peer disconnected, or a connection attempt to it ended.
Disconnected(PeerId),
/// The address source had no address to offer while the manager was trying to
/// reach its outbound target.
AddressBookExhausted,
}
impl std::fmt::Display for Event {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Event::Connecting(addr, source) => {
write!(fmt, "Connecting to peer {} from source `{}`", addr, source)
}
Event::Connected(addr, link) => write!(fmt, "{}: Peer connected ({:?})", &addr, link),
Event::Disconnected(addr) => write!(fmt, "Disconnected from {}", &addr),
Event::AddressBookExhausted => {
write!(fmt, "Address book exhausted when attempting to connect..")
}
}
}
}
/// Connection manager configuration.
#[derive(Debug, Clone)]
pub struct Config {
/// Number of outbound peers, counting both established connections and in-flight
/// attempts, that the manager tries to maintain.
pub target_outbound_peers: usize,
/// Maximum number of inbound peers. Inbound connections beyond this are
/// disconnected with `DisconnectReason::ConnectionLimit`.
pub max_inbound_peers: usize,
/// Addresses dialed on start-up before the address source is sampled. At most
/// `target_outbound_peers` of them are used.
pub retry: Vec<net::SocketAddr>,
/// Service flags passed to the address sampler as a fallback when sampling with
/// `preferred_services` yields nothing.
pub required_services: ServiceFlags,
/// Service flags the address sampler is asked for first when picking new peers.
pub preferred_services: ServiceFlags,
}
/// Book-keeping for a peer the connection manager has registered as connected.
#[derive(Debug)]
struct Peer {
/// Remote address of the peer.
address: net::SocketAddr,
/// Local address of our side of the connection, as reported to `peer_connected`.
local_address: net::SocketAddr,
/// Direction of the connection (inbound or outbound).
link: Link,
/// Services recorded by `peer_negotiated`; `ServiceFlags::NONE` until then.
services: ServiceFlags,
/// Time passed to `peer_connected` when the peer was registered.
time: LocalTime,
}
/// Tracks peer connections and tries to keep the number of outbound peers at
/// `Config::target_outbound_peers`.
#[derive(Debug)]
pub struct ConnectionManager<U> {
/// Targets, limits, retry list and service requirements.
pub config: Config,
/// Outbound connection attempts that were dialed but have not connected yet.
connecting: HashSet<PeerId>,
/// Registered connections, keyed by peer address.
connected: HashMap<PeerId, Peer>,
/// Peers that disconnected and have not been registered as connected since.
disconnected: HashSet<PeerId>,
/// Local time of the last timer tick that triggered connection maintenance.
last_idle: Option<LocalTime>,
/// Handle used to connect, disconnect, arm timers and emit events.
upstream: U,
}
impl<U: Connect + Disconnect + Events + SetTimeout> ConnectionManager<U> {
    /// Create a new connection manager that reports to `upstream`.
    ///
    /// No connections are attempted until [`ConnectionManager::initialize`] is called.
    pub fn new(upstream: U, config: Config) -> Self {
        Self {
            connecting: HashSet::new(),
            connected: HashMap::new(),
            disconnected: HashSet::new(),
            last_idle: None,
            config,
            upstream,
        }
    }

    /// Start the connection manager.
    ///
    /// Dials up to `config.target_outbound_peers` addresses from `config.retry`,
    /// arms the idle timer with [`IDLE_TIMEOUT`], then fills any outbound slots
    /// that are still free with peers sampled from `addrs`.
    pub fn initialize<S: peer::Store, A: AddressSource>(
        &mut self,
        _time: LocalTime,
        addrs: &mut A,
    ) {
        // Copy the addresses out first: `connect` borrows `self` mutably, so we
        // cannot iterate over `self.config.retry` while calling it.
        let retry: Vec<net::SocketAddr> = self
            .config
            .retry
            .iter()
            .take(self.config.target_outbound_peers)
            .copied()
            .collect();

        for addr in retry {
            self.connect::<S, A>(&addr);
        }
        self.upstream.set_timeout(IDLE_TIMEOUT);
        self.maintain_connections::<S, A>(addrs);
    }

    /// Start an outbound connection attempt to `addr`.
    ///
    /// Returns `false` without doing anything if `addr` is already connected or an
    /// attempt to it is already in flight. Otherwise records the attempt, asks
    /// `upstream` to connect with [`CONNECTION_TIMEOUT`], and returns `true`.
    pub fn connect<S: peer::Store, A: AddressSource>(&mut self, addr: &PeerId) -> bool {
        if self.connected.contains_key(addr) || self.connecting.contains(addr) {
            return false;
        }
        self.connecting.insert(*addr);
        self.upstream.connect(*addr, CONNECTION_TIMEOUT);

        true
    }

    /// Ask `upstream` to disconnect from `addr` for the given `reason`.
    ///
    /// Only registered (connected) peers are affected; pending attempts and unknown
    /// addresses are ignored. Internal state is not changed here. Presumably it is
    /// updated once the disconnection is reported through
    /// [`ConnectionManager::peer_disconnected`].
    pub fn disconnect(&mut self, addr: PeerId, reason: DisconnectReason) {
        if self.connected.contains_key(&addr) {
            // `peer_connected` clears `disconnected` before registering a peer.
            debug_assert!(!self.disconnected.contains(&addr));
            self.upstream.disconnect(addr, reason);
        }
    }

    /// Register a newly established connection with `address`.
    ///
    /// Always emits [`Event::Connected`]. If `link` is inbound and
    /// `config.max_inbound_peers` inbound peers are already registered, the peer is
    /// disconnected with [`DisconnectReason::ConnectionLimit`] and not tracked.
    /// Otherwise it is removed from the pending/disconnected sets and registered
    /// with no services until [`ConnectionManager::peer_negotiated`] is called.
    pub fn peer_connected(
        &mut self,
        address: net::SocketAddr,
        local_address: net::SocketAddr,
        link: Link,
        time: LocalTime,
    ) {
        debug_assert!(!self.connected.contains_key(&address));

        Events::event(&self.upstream, Event::Connected(address, link));

        match link {
            Link::Inbound if self.inbound_peers().count() >= self.config.max_inbound_peers => {
                // Inbound limit reached: refuse the peer without tracking it.
                self.upstream
                    .disconnect(address, DisconnectReason::ConnectionLimit);
            }
            _ => {
                self.disconnected.remove(&address);
                self.connecting.remove(&address);
                self.connected.insert(
                    address,
                    Peer {
                        address,
                        local_address,
                        services: ServiceFlags::NONE,
                        link,
                        time,
                    },
                );
            }
        }
    }

    /// Record the services of a connected peer after the handshake.
    ///
    /// # Panics
    ///
    /// Panics if `address` was not registered via [`ConnectionManager::peer_connected`].
    pub fn peer_negotiated(&mut self, address: net::SocketAddr, services: ServiceFlags) {
        let peer = self.connected.get_mut(&address).expect(
            "ConnectionManager::peer_negotiated: negotiated peers should be connected first",
        );
        peer.services = services;
    }

    /// Record that the connection to `addr` closed, or that an attempt to reach it
    /// ended.
    ///
    /// Emits [`Event::Disconnected`] and marks `addr` as disconnected. If a
    /// registered outbound peer was lost, tries to refill outbound slots from `addrs`.
    pub fn peer_disconnected<S: peer::Store, A: AddressSource>(
        &mut self,
        addr: &net::SocketAddr,
        addrs: &A,
    ) {
        // `addr` is not required to be registered here. It may be a pending
        // attempt that failed (`connecting`), or an inbound peer refused at the
        // limit in `peer_connected` (tracked nowhere). It may also already be in
        // `disconnected` if a reconnection attempt failed. Asserting membership
        // would therefore fire on legitimate paths.
        Events::event(&self.upstream, Event::Disconnected(*addr));

        self.disconnected.insert(*addr);

        match self.connected.remove(addr) {
            Some(peer) => {
                // Only outbound peers count towards the outbound target.
                if peer.link.is_outbound() {
                    self.maintain_connections::<S, A>(addrs);
                }
            }
            None => {
                self.connecting.remove(addr);
            }
        }
    }

    /// Handle a timer tick.
    ///
    /// If at least [`IDLE_TIMEOUT`] has passed since the last maintenance tick, tops
    /// up outbound connections, re-arms the idle timer, and records `local_time`. A
    /// missing previous tick is treated as `LocalTime::default()`. Earlier ticks are
    /// ignored.
    pub fn received_timeout<S: peer::Store, A: AddressSource>(
        &mut self,
        local_time: LocalTime,
        addrs: &A,
    ) {
        if local_time - self.last_idle.unwrap_or_default() >= IDLE_TIMEOUT {
            self.maintain_connections::<S, A>(addrs);
            self.upstream.set_timeout(IDLE_TIMEOUT);
            self.last_idle = Some(local_time);
        }
    }

    /// Addresses of registered outbound peers.
    pub fn outbound_peers(&self) -> impl Iterator<Item = &PeerId> {
        self.connected
            .iter()
            .filter(|(_, p)| p.link.is_outbound())
            .map(|(addr, _)| addr)
    }

    /// Addresses of registered inbound peers.
    pub fn inbound_peers(&self) -> impl Iterator<Item = &PeerId> {
        self.connected
            .iter()
            .filter(|(_, p)| p.link.is_inbound())
            .map(|(addr, _)| addr)
    }

    /// Dial sampled peers until registered outbound peers plus in-flight attempts
    /// reach `config.target_outbound_peers`.
    ///
    /// Addresses are sampled with `preferred_services` first, falling back to
    /// `required_services`. Emits [`Event::Connecting`] for every new attempt. Emits
    /// [`Event::AddressBookExhausted`] and stops if the sampler has nothing left.
    fn maintain_connections<S: peer::Store, A: AddressSource>(&mut self, addrs: &A) {
        // Each successful `connect` adds to `connecting`, so the deficit shrinks
        // and the loop terminates once the target is met.
        // NOTE(review): samples that are unresolvable or already pending make no
        // progress. This assumes `sample` eventually yields fresh addresses or `None`.
        while self.outbound().count() + self.connecting.len() < self.config.target_outbound_peers {
            let result = addrs
                .sample(self.config.preferred_services)
                .or_else(|| addrs.sample(self.config.required_services));

            if let Some((addr, source)) = result {
                if let Ok(sockaddr) = addr.socket_addr() {
                    debug_assert!(!self.connected.contains_key(&sockaddr));

                    if self.connect::<S, A>(&sockaddr) {
                        // Keep going: there may be more free outbound slots to fill.
                        self.upstream.event(Event::Connecting(sockaddr, source));
                    }
                }
            } else {
                // Out of addresses to try.
                Events::event(&self.upstream, Event::AddressBookExhausted);
                break;
            }
        }
    }

    /// Registered outbound peers.
    fn outbound(&self) -> impl Iterator<Item = &Peer> + Clone {
        self.connected.values().filter(|p| p.link.is_outbound())
    }
}