use std::{fmt::Debug, sync::Arc};
use ahash::{AHashMap, AHashSet};
use anyhow::Result;
use cid::Cid;
use derivative::Derivative;
use futures::{future::BoxFuture, FutureExt};
use iroh_metrics::{bitswap::BitswapMetrics, core::MRecorder, inc};
use libp2p::PeerId;
use tokio::sync::{mpsc, oneshot};
use tracing::{debug, error, trace, warn};
use crate::network::Network;
use super::{message_queue::MessageQueue, peer_want_manager::PeerWantManager, session::Signaler};
/// Cloneable handle to the peer manager actor.
///
/// Every method forwards a [`Message`] over a bounded channel to the background task
/// spawned by [`PeerManager::new`]; that task exits once every handle has been dropped.
#[derive(Clone, Debug)]
pub struct PeerManager {
    /// Sending half of the actor's request channel.
    sender: mpsc::Sender<Message>,
}
/// Requests processed by the [`PeerManagerActor`] task.
///
/// Variants that carry a `oneshot::Sender` expect exactly one reply on it; the others
/// are fire-and-forget notifications.
#[derive(Derivative)]
#[derivative(Debug)]
enum Message {
    /// Reply with the ids of all peers the actor currently tracks.
    GetConnectedPeers(oneshot::Sender<Vec<PeerId>>),
    /// Reply with all CIDs currently wanted, as reported by the want manager.
    GetCurrentWants(oneshot::Sender<AHashSet<Cid>>),
    /// Reply with the CIDs currently wanted as want-blocks.
    GetCurrentWantBlocks(oneshot::Sender<AHashSet<Cid>>),
    /// Reply with the CIDs currently wanted as want-haves.
    GetCurrentWantHaves(oneshot::Sender<AHashSet<Cid>>),
    /// A peer connected (or reconnected).
    Connected(PeerId),
    /// A peer disconnected.
    Disconnected(PeerId),
    /// The peer answered for these CIDs; forwarded to that peer's message queue.
    ResponseReceived(PeerId, Vec<Cid>),
    /// Broadcast want-haves for these CIDs through the want manager.
    BroadcastWantHaves(AHashSet<Cid>),
    /// Send want-blocks and want-haves to a single connected peer.
    SendWants {
        peer: PeerId,
        want_blocks: Vec<Cid>,
        want_haves: Vec<Cid>,
    },
    /// Cancel wants for these CIDs through the want manager.
    SendCancels(Vec<Cid>),
    /// Add `peer` to the signaler's session, creating the session if needed.
    /// Replies `true` if the peer was newly added.
    RegisterSession {
        peer: PeerId,
        signaler: Signaler,
        response: oneshot::Sender<bool>,
    },
    /// Forget a session (by id) and unprotect its peers; replies once done.
    UnregisterSession(u64, oneshot::Sender<()>),
    /// Add a peer to an existing session. Replies `false` if it was already present,
    /// `true` otherwise (including when the session is unknown).
    AddPeerToSession {
        session: u64,
        peer: PeerId,
        response: oneshot::Sender<bool>,
    },
    /// Remove a peer from a session; replies whether it was present.
    RemovePeerFromSession {
        session: u64,
        peer: PeerId,
        response: oneshot::Sender<bool>,
    },
    /// Protect the connection to a peer if it belongs to the session; replies when done.
    ProtectConnection {
        session: u64,
        peer: PeerId,
        response: oneshot::Sender<()>,
    },
    /// Reply with the session's `peers_discovered` flag (`false` if unknown).
    PeersDiscoveredForSession {
        session: u64,
        response: oneshot::Sender<bool>,
    },
    /// Reply with the peers of a session (empty if unknown).
    PeersForSession {
        session: u64,
        response: oneshot::Sender<Vec<PeerId>>,
    },
    /// Reply whether the session has at least one peer (`false` if unknown).
    SessionHasPeers {
        session: u64,
        response: oneshot::Sender<bool>,
    },
    /// Reply whether the session contains the given peer (`false` if unknown).
    SessionHasPeer {
        session: u64,
        peer: PeerId,
        response: oneshot::Sender<bool>,
    },
    /// Replace the callback handed to newly created message queues.
    SetCb(#[derivative(Debug = "ignore")] Arc<dyn DontHaveTimeout>),
}
/// Callback taking a peer and a list of CIDs and returning a boxed future.
///
/// The peer manager hands a clone of its current callback to every `MessageQueue` it
/// creates. Presumably the queue invokes it when `peer` has not answered for those
/// CIDs in time (TODO: confirm against `MessageQueue`).
///
/// Any `Send + Sync + 'static` closure of this shape implements the trait through the
/// blanket impl that follows, so the trait acts as an alias for the `Fn` bound.
pub trait DontHaveTimeout:
    Fn(PeerId, Vec<Cid>) -> BoxFuture<'static, ()> + Send + Sync + 'static
{
}

impl<F> DontHaveTimeout for F where
    F: Fn(PeerId, Vec<Cid>) -> BoxFuture<'static, ()> + Send + Sync + 'static
{
}
impl PeerManager {
    /// Builds the actor state and spawns the task that serves requests for it.
    ///
    /// The returned handle (and its clones) hold the only senders on the actor's
    /// channel, so the task winds down once all of them are dropped.
    pub async fn new(self_id: PeerId, network: Network) -> Self {
        let (tx, rx) = mpsc::channel(2048);
        let state = PeerManagerActor::new(self_id, network, rx).await;
        // The task is detached; its join handle is intentionally not kept.
        let _task = tokio::task::spawn(async move { run(state).await });
        Self { sender: tx }
    }

    /// Installs the don't-have-timeout callback.
    ///
    /// The callback is cloned into message queues created afterwards; queues that
    /// already exist keep the callback they were created with.
    pub async fn set_cb<F>(&self, on_dont_have_timeout: F)
    where
        F: DontHaveTimeout,
    {
        let cb: Arc<dyn DontHaveTimeout> = Arc::new(on_dont_have_timeout);
        self.send(Message::SetCb(cb)).await;
    }

    /// Enqueues `message` for the actor. A closed channel is logged, not reported.
    async fn send(&self, message: Message) {
        match self.sender.send(message).await {
            Ok(()) => {}
            Err(err) => warn!("failed to send message: {:?}", err),
        }
    }

    /// Sends the message produced by `build` and waits for the actor's reply.
    ///
    /// Falls back to `T::default()` when the actor is gone, i.e. the request could not
    /// be delivered or its reply sender was dropped without answering.
    async fn request<T: Default>(&self, build: impl FnOnce(oneshot::Sender<T>) -> Message) -> T {
        let (reply_tx, reply_rx) = oneshot::channel();
        self.send(build(reply_tx)).await;
        reply_rx.await.unwrap_or_default()
    }

    /// Peers that can currently be sent requests; at present simply the connected peers.
    pub async fn available_peers(&self) -> Vec<PeerId> {
        self.connected_peers().await
    }

    /// Ids of all peers the actor currently tracks (empty if the actor is gone).
    pub async fn connected_peers(&self) -> Vec<PeerId> {
        self.request(Message::GetConnectedPeers).await
    }

    /// Notifies the actor that `peer` connected.
    pub async fn connected(&self, peer: &PeerId) {
        self.send(Message::Connected(*peer)).await;
    }

    /// Notifies the actor that `peer` disconnected.
    pub async fn disconnected(&self, peer: &PeerId) {
        self.send(Message::Disconnected(*peer)).await;
    }

    /// Reports that `peer` responded for `cids`.
    pub async fn response_received(&self, peer: &PeerId, cids: &[Cid]) {
        let owned = cids.to_vec();
        self.send(Message::ResponseReceived(*peer, owned)).await;
    }

    /// Asks the actor to broadcast want-haves for `want_haves`.
    pub async fn broadcast_want_haves(&self, want_haves: &AHashSet<Cid>) {
        let owned = want_haves.clone();
        self.send(Message::BroadcastWantHaves(owned)).await;
    }

    /// Asks the actor to send want-blocks and want-haves to `peer`.
    pub async fn send_wants(&self, peer: &PeerId, want_blocks: &[Cid], want_haves: &[Cid]) {
        let message = Message::SendWants {
            peer: *peer,
            want_blocks: want_blocks.to_vec(),
            want_haves: want_haves.to_vec(),
        };
        self.send(message).await;
    }

    /// Asks the actor to cancel wants for `cancels`.
    pub async fn send_cancels(&self, cancels: &[Cid]) {
        self.send(Message::SendCancels(cancels.to_vec())).await;
    }

    /// All CIDs currently wanted (empty if the actor is gone).
    pub async fn current_wants(&self) -> AHashSet<Cid> {
        self.request(Message::GetCurrentWants).await
    }

    /// CIDs currently wanted as want-blocks (empty if the actor is gone).
    pub async fn current_want_blocks(&self) -> AHashSet<Cid> {
        self.request(Message::GetCurrentWantBlocks).await
    }

    /// CIDs currently wanted as want-haves (empty if the actor is gone).
    pub async fn current_want_haves(&self) -> AHashSet<Cid> {
        self.request(Message::GetCurrentWantHaves).await
    }

    /// Adds `peer` to `signaler`'s session, creating the session if needed.
    ///
    /// Returns `true` if the peer was newly added, `false` if it was already present
    /// or the actor is gone.
    pub async fn register_session(&self, peer: &PeerId, signaler: Signaler) -> bool {
        let peer = *peer;
        self.request(move |response| Message::RegisterSession {
            peer,
            signaler,
            response,
        })
        .await
    }

    /// Forgets session `session_id`, waiting until the actor has processed the request.
    pub async fn unregister_session(&self, session_id: u64) {
        self.request(move |done| Message::UnregisterSession(session_id, done))
            .await
    }

    /// Consumes this handle.
    ///
    /// Nothing is shut down here directly: the actor stops its message queues only
    /// after every clone of the handle has been dropped and its channel has closed.
    pub async fn stop(self) -> Result<()> {
        debug!("stopping peer manager");
        Ok(())
    }

    /// Adds `peer` to `session`; returns the actor's answer (`false` if it is gone).
    pub async fn add_peer_to_session(&self, session: u64, peer: PeerId) -> bool {
        self.request(move |response| Message::AddPeerToSession {
            session,
            peer,
            response,
        })
        .await
    }

    /// Removes `peer` from `session`; returns whether it was a member.
    pub async fn remove_peer_from_session(&self, session: u64, peer: PeerId) -> bool {
        self.request(move |response| Message::RemovePeerFromSession {
            session,
            peer,
            response,
        })
        .await
    }

    /// Protects the connection to `peer` if it belongs to `session`; waits for completion.
    pub async fn protect_connection(&self, session: u64, peer: PeerId) {
        self.request(move |response| Message::ProtectConnection {
            session,
            peer,
            response,
        })
        .await
    }

    /// Whether peers have been discovered for `session` (`false` if unknown).
    pub async fn peers_discovered_for_session(&self, session: u64) -> bool {
        self.request(move |response| Message::PeersDiscoveredForSession { session, response })
            .await
    }

    /// The peers currently in `session` (empty if unknown).
    pub async fn peers_for_session(&self, session: u64) -> Vec<PeerId> {
        self.request(move |response| Message::PeersForSession { session, response })
            .await
    }

    /// Whether `session` has at least one peer (`false` if unknown).
    pub async fn session_has_peers(&self, session: u64) -> bool {
        self.request(move |response| Message::SessionHasPeers { session, response })
            .await
    }

    /// Whether `peer` is part of `session` (`false` if unknown).
    pub async fn session_has_peer(&self, session: u64, peer: PeerId) -> bool {
        self.request(move |response| Message::SessionHasPeer {
            session,
            peer,
            response,
        })
        .await
    }
}
async fn run(mut actor: PeerManagerActor) {
loop {
inc!(BitswapMetrics::PeerManagerLoopTick);
tokio::select! {
message = actor.receiver.recv() => {
match message {
Some(Message::GetConnectedPeers(r)) => {
let _= r.send(actor.connected_peers().await);
},
Some(Message::GetCurrentWants(r)) => {
let _ = r.send(actor.current_wants());
},
Some(Message::GetCurrentWantBlocks(r)) => {
let _ = r.send(actor.current_want_blocks());
},
Some(Message::GetCurrentWantHaves(r)) => {
let _ = r.send(actor.current_want_haves());
},
Some(Message::Connected(peer)) => {
actor.connected(peer).await;
},
Some(Message::Disconnected(peer)) => {
actor.disconnected(peer).await;
},
Some(Message::ResponseReceived(peer, responses)) => {
actor.response_received(peer, responses).await;
},
Some(Message::BroadcastWantHaves(list)) => {
actor.broadcast_want_haves(list).await;
},
Some(Message::SendWants {
peer,
want_blocks,
want_haves,
}) => {
actor.send_wants(peer, want_blocks, want_haves).await;
},
Some(Message::SendCancels(cancels)) => {
actor.send_cancels(cancels).await;
},
Some(Message::RegisterSession { peer, signaler, response }) => {
let _ = response.send(actor.register_session(peer, signaler).await);
},
Some(Message::UnregisterSession(session, response)) => {
actor.unregister_session(session, response).await;
},
Some(Message::SetCb(cb)) => {
actor.on_dont_have_timeout = cb;
}
Some(Message::AddPeerToSession{
session,
peer,
response,
}) => {
actor.add_peer_to_session(session, peer, response).await;
},
Some(Message::RemovePeerFromSession{
session,
peer,
response,
}) => {
actor.remove_peer_from_session(session, peer, response).await;
},
Some(Message::ProtectConnection{
session,
peer,
response,
}) => {
actor.protect_connection(session, peer, response).await;
},
Some(Message::PeersDiscoveredForSession{
session,
response,
}) => {
actor.peers_discovered_for_session(session, response).await;
},
Some(Message::PeersForSession{
session,
response,
}) => {
actor.peers_for_session(session, response).await;
},
Some(Message::SessionHasPeers{
session,
response,
}) => {
actor.session_has_peers(session, response).await;
},
Some(Message::SessionHasPeer{
session,
peer,
response,
}) => {
actor.session_has_peer(session, peer, response).await;
},
None => {
break;
}
}
}
}
}
if let Err(err) = actor.stop().await {
warn!("failed to shutdown peer manager: {:?}", err);
}
}
/// State owned by the background task spawned in `PeerManager::new`.
///
/// Only that task touches this struct; handles talk to it through `receiver`.
#[derive(Derivative)]
#[derivative(Debug)]
struct PeerManagerActor {
    /// Incoming requests from `PeerManager` handles.
    receiver: mpsc::Receiver<Message>,
    /// Per-peer state for every peer the actor currently tracks.
    peers: AHashMap<PeerId, PeerState>,
    /// Want bookkeeping (want-blocks, want-haves, broadcasts, cancels) delegated to by the actor.
    peer_want_manager: PeerWantManager,
    /// Known sessions, keyed by session id (`Signaler::id`).
    sessions: AHashMap<u64, SessionState>,
    /// Our own peer id. NOTE(review): not otherwise read in this file.
    self_id: PeerId,
    /// Network handle used to create message queues and to (un)protect peers.
    network: Network,
    /// Callback cloned into each newly created `MessageQueue`; not printed by `Debug`.
    #[derivative(Debug = "ignore")]
    on_dont_have_timeout: Arc<dyn DontHaveTimeout>,
}
/// Per-peer state kept by the actor for each tracked peer.
#[derive(Debug)]
pub(super) struct PeerState {
    /// Outgoing message queue for this peer; restarted on reconnect if it has stopped.
    pub(super) message_queue: MessageQueue,
    /// Ids of the sessions notified when this peer's availability changes.
    pub(super) sessions: AHashSet<u64>,
}
/// Actor-side bookkeeping for a single session.
#[derive(Debug)]
struct SessionState {
    /// Handle used to tell the session about peer availability changes.
    signaler: Signaler,
    /// Peers that have been added to this session.
    peers: AHashSet<PeerId>,
    /// Set to `true` whenever a peer is added via register/add-peer requests.
    peers_discovered: bool,
}
impl PeerManagerActor {
    /// Creates an actor with no peers and no sessions.
    ///
    /// The don't-have-timeout callback starts as a no-op until replaced via `Message::SetCb`.
    async fn new(self_id: PeerId, network: Network, receiver: mpsc::Receiver<Message>) -> Self {
        Self {
            self_id,
            receiver,
            network,
            peers: Default::default(),
            peer_want_manager: Default::default(),
            sessions: Default::default(),
            on_dont_have_timeout: Arc::new(|_, _| async move {}.boxed()),
        }
    }

    /// Stops the message queue of every known peer, all concurrently.
    ///
    /// Every queue is asked to stop even if another one fails; afterwards the first
    /// failure among the collected results (if any) is returned.
    async fn stop(self) -> Result<()> {
        let results = futures::future::join_all(
            self.peers
                .into_iter()
                .map(|(_, state)| async move { state.message_queue.stop().await }),
        )
        .await;
        for r in results {
            r?;
        }
        Ok(())
    }

    /// Returns the ids of all peers that currently have a [`PeerState`].
    async fn connected_peers(&self) -> Vec<PeerId> {
        self.peers.keys().copied().collect()
    }

    /// Handles a (re)connection of `peer`.
    ///
    /// Ensures the peer has state and a running message queue (a stopped queue is
    /// replaced), registers it with the want manager and signals its sessions that it
    /// is available.
    async fn connected(&mut self, peer: PeerId) {
        self.insert_peer(peer, None).await;
        let peer_state = self
            .peers
            .get_mut(&peer)
            .expect("insert_peer always leaves an entry for the peer");
        if !peer_state.message_queue.is_running() {
            trace!("found stopped peer_queue, restarting: {}", peer);
            inc!(BitswapMetrics::MessageQueuesCreated);
            peer_state.message_queue = MessageQueue::new(
                peer,
                self.network.clone(),
                self.on_dont_have_timeout.clone(),
            )
            .await;
        }
        self.peer_want_manager
            .add_peer(&peer_state.message_queue, &peer)
            .await;
        self.signal_availability(peer, true).await;
    }

    /// Handles a disconnect of `peer`.
    ///
    /// Signals the peer's sessions that it is unavailable, then drops its state, removes
    /// it from the want manager and stops its message queue. Unknown peers are ignored.
    async fn disconnected(&mut self, peer: PeerId) {
        // Must happen before the state is removed: the set of sessions to notify lives
        // in `PeerState::sessions`. This mirrors `connected`, which signals `true`.
        self.signal_availability(peer, false).await;
        if let Some(peer_state) = self.peers.remove(&peer) {
            inc!(BitswapMetrics::MessageQueuesDestroyed);
            self.peer_want_manager.remove_peer(&peer);
            if let Err(err) = peer_state.message_queue.stop().await {
                error!("failed to shutdown message queue for {}: {:?}", peer, err);
            }
        }
    }

    /// Forwards a response notification to `peer`'s message queue, if the peer is tracked.
    async fn response_received(&self, peer: PeerId, cids: Vec<Cid>) {
        if let Some(peer_state) = self.peers.get(&peer) {
            peer_state.message_queue.response_received(cids).await;
        }
    }

    /// Broadcasts want-haves to the tracked peers through the want manager.
    async fn broadcast_want_haves(&mut self, want_haves: AHashSet<Cid>) {
        self.peer_want_manager
            .broadcast_want_haves(&want_haves, &self.peers)
            .await;
    }

    /// Sends wants to `peer` through the want manager; ignored if the peer is not tracked.
    async fn send_wants(&mut self, peer: PeerId, want_blocks: Vec<Cid>, want_haves: Vec<Cid>) {
        debug!(
            "send_wants to {}: {}, {} {:?}, {:?}",
            peer,
            want_blocks.len(),
            want_haves.len(),
            want_blocks,
            want_haves
        );
        if let Some(peer_state) = self.peers.get(&peer) {
            self.peer_want_manager
                .send_wants(&peer, &want_blocks, &want_haves, &peer_state.message_queue)
                .await;
        }
    }

    /// Cancels wants for `cancels` through the want manager.
    async fn send_cancels(&mut self, cancels: Vec<Cid>) {
        self.peer_want_manager
            .send_cancels(&cancels, &self.peers)
            .await;
    }

    /// All CIDs currently wanted, according to the want manager.
    fn current_wants(&self) -> AHashSet<Cid> {
        self.peer_want_manager.get_wants()
    }

    /// CIDs currently wanted as want-blocks, according to the want manager.
    fn current_want_blocks(&self) -> AHashSet<Cid> {
        self.peer_want_manager.get_want_blocks()
    }

    /// CIDs currently wanted as want-haves, according to the want manager.
    fn current_want_haves(&self) -> AHashSet<Cid> {
        self.peer_want_manager.get_want_haves()
    }

    /// Adds `peer` to the session `signaler.id()`, creating the session if needed, and
    /// marks the session's peers as discovered.
    ///
    /// Returns `true` when the peer was not already a member.
    ///
    /// NOTE(review): this does not add the session id to the peer's
    /// `PeerState::sessions`, so `signal_availability` will not notify this session
    /// about the peer. Confirm whether that is intended.
    async fn register_session(&mut self, peer: PeerId, signaler: Signaler) -> bool {
        let id = signaler.id();
        debug!("register session {}: {}", peer, id);
        match self.sessions.entry(id) {
            std::collections::hash_map::Entry::Occupied(mut entry) => {
                let state = entry.get_mut();
                state.peers_discovered = true;
                state.peers.insert(peer)
            }
            std::collections::hash_map::Entry::Vacant(entry) => {
                entry.insert(SessionState {
                    signaler,
                    peers: [peer].into_iter().collect(),
                    peers_discovered: true,
                });
                true
            }
        }
    }

    /// Ensures `peer` has a [`PeerState`], creating a fresh message queue if it is new.
    ///
    /// If `session` is given, its id is added to the peer's session set.
    async fn insert_peer(&mut self, peer: PeerId, session: Option<u64>) {
        match self.peers.entry(peer) {
            std::collections::hash_map::Entry::Occupied(mut entry) => {
                if let Some(id) = session {
                    entry.get_mut().sessions.insert(id);
                }
            }
            std::collections::hash_map::Entry::Vacant(entry) => {
                inc!(BitswapMetrics::MessageQueuesCreated);
                let message_queue = MessageQueue::new(
                    peer,
                    self.network.clone(),
                    self.on_dont_have_timeout.clone(),
                )
                .await;
                let sessions = session
                    .map(|id| [id].into_iter().collect())
                    .unwrap_or_default();
                entry.insert(PeerState {
                    message_queue,
                    sessions,
                });
            }
        }
    }

    /// Forgets a session: removes it from every peer's session set, unprotects all of
    /// its peers and finally acknowledges on `response`.
    async fn unregister_session(&mut self, session_id: u64, response: oneshot::Sender<()>) {
        for peer_state in self.peers.values_mut() {
            peer_state.sessions.remove(&session_id);
        }
        if let Some(session) = self.sessions.remove(&session_id) {
            for peer in session.peers {
                self.network.unprotect_peer(peer).await;
            }
        }
        let _ = response.send(());
    }

    /// Tells every session listed in `peer`'s state whether the peer is connected.
    /// Does nothing if the peer is not tracked.
    async fn signal_availability(&self, peer: PeerId, is_connected: bool) {
        if let Some(peer_state) = self.peers.get(&peer) {
            for session_id in &peer_state.sessions {
                if let Some(session) = self.sessions.get(session_id) {
                    session.signaler.signal_availability(peer, is_connected);
                }
            }
        }
    }

    /// Adds `peer` to an existing session and marks its peers as discovered.
    ///
    /// Replies `false` if the peer was already a member and `true` if it was added.
    /// NOTE(review): also replies `true` when the session is unknown, although nothing
    /// is recorded in that case (the other session queries answer `false`); confirm
    /// that callers rely on this.
    async fn add_peer_to_session(
        &mut self,
        session_id: u64,
        peer: PeerId,
        response: oneshot::Sender<bool>,
    ) {
        debug!("add peer to session {}: {}", peer, session_id);
        if let Some(session) = self.sessions.get_mut(&session_id) {
            debug!("found session: {}: {}", peer, session_id);
            if session.peers.contains(&peer) {
                let _ = response.send(false);
                return;
            }
            session.peers.insert(peer);
            session.peers_discovered = true;
            let _ = response.send(true);
        } else {
            debug!("found no session: {}: {}", peer, session_id);
            let _ = response.send(true);
        }
    }

    /// Asks the network to protect the connection to `peer` if it belongs to the
    /// session, then acknowledges on `response` in every case.
    async fn protect_connection(
        &mut self,
        session: u64,
        peer: PeerId,
        response: oneshot::Sender<()>,
    ) {
        if let Some(state) = self.sessions.get(&session) {
            if state.peers.contains(&peer) {
                self.network.protect_peer(peer).await;
            }
        }
        let _ = response.send(());
    }

    /// Removes `peer` from the session and replies whether it was a member. If it was,
    /// the connection is unprotected after the reply has been sent.
    async fn remove_peer_from_session(
        &mut self,
        session: u64,
        peer: PeerId,
        response: oneshot::Sender<bool>,
    ) {
        if let Some(state) = self.sessions.get_mut(&session) {
            let existed = state.peers.remove(&peer);
            let _ = response.send(existed);
            if existed {
                self.network.unprotect_peer(peer).await;
            }
        } else {
            let _ = response.send(false);
        }
    }

    /// Replies with the session's `peers_discovered` flag, or `false` if it is unknown.
    async fn peers_discovered_for_session(&self, session: u64, response: oneshot::Sender<bool>) {
        if let Some(state) = self.sessions.get(&session) {
            let _ = response.send(state.peers_discovered);
        } else {
            let _ = response.send(false);
        }
    }

    /// Replies with the session's peers, or an empty list if it is unknown.
    async fn peers_for_session(&self, session: u64, response: oneshot::Sender<Vec<PeerId>>) {
        if let Some(state) = self.sessions.get(&session) {
            let _ = response.send(state.peers.iter().copied().collect());
        } else {
            let _ = response.send(Vec::new());
        }
    }

    /// Replies whether the session has any peer, or `false` if it is unknown.
    async fn session_has_peers(&self, session: u64, response: oneshot::Sender<bool>) {
        if let Some(state) = self.sessions.get(&session) {
            let _ = response.send(!state.peers.is_empty());
        } else {
            let _ = response.send(false);
        }
    }

    /// Replies whether `peer` is in the session, or `false` if it is unknown.
    async fn session_has_peer(&self, session: u64, peer: PeerId, response: oneshot::Sender<bool>) {
        if let Some(state) = self.sessions.get(&session) {
            let _ = response.send(state.peers.contains(&peer));
        } else {
            let _ = response.send(false);
        }
    }
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use crate::block::tests::create_random_block_v1;

    use super::*;

    #[tokio::test]
    async fn test_adding_removing_peers() {
        let this = PeerId::random();
        let peer1 = PeerId::random();
        let peer2 = PeerId::random();
        let peer3 = PeerId::random();
        let peer4 = PeerId::random();
        let peer5 = PeerId::random();
        let network = Network::new(this);
        let peer_manager = PeerManager::new(this, network).await;

        peer_manager.connected(&peer1).await;
        peer_manager.connected(&peer2).await;
        peer_manager.connected(&peer3).await;
        let connected_peers = peer_manager.connected_peers().await;
        assert!(connected_peers.contains(&peer1));
        assert!(connected_peers.contains(&peer2));
        assert!(connected_peers.contains(&peer3));
        assert!(!connected_peers.contains(&peer4));
        assert!(!connected_peers.contains(&peer5));

        peer_manager.disconnected(&peer1).await;
        let connected_peers = peer_manager.connected_peers().await;
        assert!(!connected_peers.contains(&peer1));

        peer_manager.connected(&peer1).await;
        let connected_peers = peer_manager.connected_peers().await;
        assert!(connected_peers.contains(&peer1));

        peer_manager.stop().await.unwrap();
    }

    #[tokio::test]
    async fn test_broadcast_on_connect() {
        let this = PeerId::random();
        let peer1 = PeerId::random();
        let network = Network::new(this);
        let peer_manager = PeerManager::new(this, network).await;
        let cids: AHashSet<_> = gen_cids(2).into_iter().collect();
        peer_manager.broadcast_want_haves(&cids).await;
        peer_manager.connected(&peer1).await;
        // TODO: assert that peer1 received the broadcast want-haves; this needs a way
        // to observe the messages handed to the network.
        peer_manager.stop().await.unwrap();
    }

    #[tokio::test]
    async fn test_broadcast_want_haves() {
        let this = PeerId::random();
        let peer1 = PeerId::random();
        let peer2 = PeerId::random();
        let network = Network::new(this);
        let peer_manager = PeerManager::new(this, network).await;
        let cids = gen_cids(3);
        peer_manager
            .broadcast_want_haves(&cids[..2].iter().copied().collect::<AHashSet<_>>())
            .await;
        peer_manager.connected(&peer1).await;
        // TODO: assert peer1 was sent want-haves for cids[0] and cids[1].
        peer_manager.connected(&peer2).await;
        peer_manager
            .broadcast_want_haves(&[cids[0], cids[2]].into_iter().collect::<AHashSet<_>>())
            .await;
        // TODO: assert on the want-haves sent to peer1 and peer2 by the second broadcast.
        peer_manager.stop().await.unwrap();
    }

    #[tokio::test]
    async fn test_send_wants() {
        let this = PeerId::random();
        let peer1 = PeerId::random();
        let network = Network::new(this);
        let peer_manager = PeerManager::new(this, network).await;
        let cids = gen_cids(4);
        peer_manager.connected(&peer1).await;
        peer_manager
            .send_wants(&peer1, &[cids[0]][..], &[cids[2]][..])
            .await;
        // TODO: assert peer1 was sent want-block cids[0] and want-have cids[2].
        peer_manager
            .send_wants(&peer1, &[cids[0], cids[1]][..], &[cids[2], cids[3]][..])
            .await;
        // TODO: assert only the new wants (cids[1], cids[3]) were sent the second time.
        peer_manager.stop().await.unwrap();
    }

    #[tokio::test]
    async fn test_send_cancels() {
        let this = PeerId::random();
        let peer1 = PeerId::random();
        let peer2 = PeerId::random();
        let network = Network::new(this);
        let peer_manager = PeerManager::new(this, network).await;
        let cids = gen_cids(4);
        peer_manager.connected(&peer1).await;
        peer_manager.connected(&peer2).await;
        peer_manager
            .send_wants(&peer1, &[cids[0], cids[1]][..], &[cids[2]][..])
            .await;
        // Yield to the runtime so the actor task can run. `std::thread::sleep` would
        // block the current-thread test runtime and stall the actor instead.
        tokio::time::sleep(Duration::from_millis(100)).await;
        // TODO: assert on the wants sent to peer1.
        peer_manager.send_cancels(&[cids[0], cids[2]][..]).await;
        tokio::time::sleep(Duration::from_millis(100)).await;
        // TODO: assert cancels were sent only to peer1, for cids[0] and cids[2].
        peer_manager.stop().await.unwrap();
    }

    /// Generates `n` CIDs of freshly created random blocks.
    fn gen_cids(n: usize) -> Vec<Cid> {
        (0..n).map(|_| *create_random_block_v1().cid()).collect()
    }
}