use std::{
collections::{HashMap, HashSet},
future::IntoFuture,
pin::Pin,
task::{Context, Poll, Waker},
time::Duration,
};
use bytes::Bytes;
use connexa::prelude::PeerId;
use futures::{future::BoxFuture, stream::FusedStream, FutureExt, Stream};
use futures_timer::Delay;
use indexmap::IndexMap;
use ipld_core::cid::Cid;
use std::fmt::Debug;
use crate::{
repo::{DefaultStorage, Repo},
Block,
};
/// Capacity of the want map above which `WantSession::poll_next` calls `shrink_to_fit`
/// on it (only when no pending want was found to send on that poll).
const CAP_THRESHOLD: usize = 100;
/// Unwraps a `std::task::Poll` inside a session's `poll_next`.
///
/// `$poll` is evaluated once. On `Ready(value)` the macro evaluates to `value`.
/// On `Pending` it stores a clone of `$cx`'s waker in `$session.waker` and makes the
/// enclosing function return `Poll::Pending`.
macro_rules! state_ready {
    ($cx:expr, $session:expr, $poll:expr $(,)?) => {{
        let std::task::Poll::Ready(value) = $poll else {
            $session.waker.replace($cx.waker().clone());
            return std::task::Poll::Pending;
        };
        value
    }};
}
/// Events yielded by the [`WantSession`] stream.
#[derive(Debug)]
pub enum WantSessionEvent {
    /// A tracked peer was reported disconnected; emitted once per disconnect.
    Dial { peer_id: PeerId },
    /// A want should be sent to a peer that is queued as `Pending`.
    SendWant { peer_id: PeerId },
    /// A cancel should be sent to the peer; the peer is dropped from the session.
    SendCancel { peer_id: PeerId },
    /// The block should be requested from this peer (it reported having it).
    SendBlock { peer_id: PeerId },
    /// The received block was written to the repo.
    BlockStored,
    /// No usable peer is known; emitted when discovery starts or its timer fires.
    NeedBlock,
    /// The session was cancelled (e.g. its overall timeout fired); emitted once.
    Cancelled,
}
/// State machine driven by `WantSession::poll_next`.
pub enum WantSessionState {
    /// Waiting for peers; discovery (`NeedBlock`) is driven from here.
    Idle,
    /// Pick the next peer whose state is `Have` and request the block from it.
    NextBlock {
        /// If set, this peer is marked `Failed` before the next peer is chosen.
        previous_peer_id: Option<PeerId>,
    },
    /// A block request is outstanding with `peer_id`; `timer` bounds the wait.
    NextBlockPending {
        peer_id: PeerId,
        timer: Delay,
    },
    /// The block received from `from_peer_id` is being written by `fut`.
    PutBlock {
        from_peer_id: PeerId,
        fut: BoxFuture<'static, Result<Cid, anyhow::Error>>,
    },
    /// Transient: marks every tracked peer `Cancel` and flags the session as terminated.
    Cancel,
    /// Transient: records that the block was obtained, then moves to `Cancel`.
    Complete,
}
impl Debug for WantSessionState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "WantSessionState")
}
}
/// Discovery sub-state of a [`WantSession`], consulted while the session is `Idle`.
#[derive(Debug)]
enum WantDiscovery {
    /// Discovery is off.
    Disable,
    /// Emit `NeedBlock` on the next idle poll, then switch to `Running`.
    Start,
    /// Switch to `Running` without emitting `NeedBlock` right away.
    SilentStart,
    /// Emit `NeedBlock` each time `timer` fires; the timer is then reset to the
    /// session's discovery timeout.
    Running { timer: Delay },
}
/// Per-peer progress within a [`WantSession`].
#[derive(Debug, PartialEq, Eq)]
enum PeerWantState {
    /// A want is queued; the next poll emits `SendWant` and moves to `Sent`.
    Pending,
    /// `SendWant` was emitted for this peer.
    Sent,
    /// The peer reported having the block.
    Have,
    /// Requesting the block from this peer timed out or storing its block failed.
    Failed,
    /// `SendBlock` was emitted; waiting for the peer to deliver the block.
    Waiting,
    /// The peer disconnected. `backoff == false` means a `Dial` is still to be
    /// emitted; `true` means it has been emitted already.
    Disconnect { backoff: bool },
    /// A `SendCancel` is queued for this peer.
    Cancel,
}
/// Stream-driven session that tries to fetch one block (`cid`) from remote peers.
#[derive(Debug)]
pub struct WantSession {
    /// The block being requested.
    cid: Cid,
    /// Per-peer state, kept in insertion order (peers are scanned in that order).
    wants: IndexMap<PeerId, PeerWantState>,
    /// Discovery sub-state used while idle.
    discovery: WantDiscovery,
    /// Set once the block has been stored; the stream ends after queued cancels drain.
    received: bool,
    /// Waker of the task polling this stream, woken by the mutating methods.
    waker: Option<Waker>,
    /// Repo the received block is written into.
    repo: Repo<DefaultStorage>,
    /// Current state of the fetch state machine.
    state: WantSessionState,
    /// Overall timeout given at construction (also bounds per-peer block requests
    /// when non-zero).
    timeout: Option<Duration>,
    /// Interval between `NeedBlock` events while discovery is running.
    discovery_timeout: Duration,
    /// Overall session timer; when it fires the session is cancelled.
    timer: Option<Delay>,
    /// `None` while active, `Some(false)` once `Cancelled` is due, `Some(true)` after.
    terminated: Option<bool>,
}
impl WantSession {
    /// Creates a session that tries to obtain the block identified by `cid`.
    ///
    /// When `timeout` is set, an overall timer is armed and `poll_next` cancels the
    /// session once it fires. Discovery re-emits `NeedBlock` every `timeout / 2`, or
    /// every 30 seconds when no timeout is given.
    pub fn new(repo: &Repo<DefaultStorage>, cid: Cid, timeout: Option<Duration>) -> Self {
        Self {
            cid,
            wants: Default::default(),
            discovery: WantDiscovery::Disable,
            received: false,
            repo: repo.clone(),
            waker: None,
            state: WantSessionState::Idle,
            timeout,
            timer: timeout.map(Delay::new),
            discovery_timeout: timeout.map(|d| d / 2).unwrap_or(Duration::from_secs(30)),
            terminated: None,
        }
    }

    /// Queues a want for `peer_id` so the next poll emits `SendWant` for it.
    ///
    /// A peer already `Pending` or `Sent` is left as is; any other state is reset to
    /// `Pending`. Discovery is disabled.
    pub fn send_have_block(&mut self, peer_id: PeerId) {
        let state = self.wants.entry(peer_id).or_insert(PeerWantState::Pending);
        if !matches!(state, PeerWantState::Pending | PeerWantState::Sent) {
            *state = PeerWantState::Pending;
        }
        tracing::trace!(session = %self.cid, %peer_id, name = "want_session", "send have block");
        self.discovery = WantDiscovery::Disable;
        self.wake_task();
    }

    /// Records that `peer_id` has the block.
    ///
    /// An idle session switches to `NextBlock` so a block request goes out on the next
    /// poll. A session that is already waiting on a peer or storing a block is left
    /// alone: replacing that state would drop the in-flight request or the store
    /// future. The new `Have` peer is picked up later if the current attempt fails.
    pub fn has_block(&mut self, peer_id: PeerId) {
        tracing::debug!(session = %self.cid, %peer_id, name = "want_session", "have block");
        self.wants.insert(peer_id, PeerWantState::Have);
        if matches!(self.state, WantSessionState::Idle) {
            tracing::debug!(session = %self.cid, %peer_id, name = "want_session", "change state to next_block");
            self.state = WantSessionState::NextBlock {
                previous_peer_id: None,
            };
        }
        self.discovery = WantDiscovery::Disable;
        self.wake_task();
    }

    /// Drops `peer_id` from the session after it reported not having the block.
    ///
    /// If no usable peer remains, discovery is (re)started. Work already in flight is
    /// kept: a block being stored, or a request outstanding with a different peer.
    /// Both of those advance on their own in `poll_next`.
    pub fn dont_have_block(&mut self, peer_id: PeerId) {
        tracing::trace!(session = %self.cid, %peer_id, name = "want_session", "dont have block");
        self.wants.shift_remove(&peer_id);

        let busy = match &self.state {
            WantSessionState::PutBlock { .. } => true,
            WantSessionState::NextBlockPending {
                peer_id: pending, ..
            } => *pending != peer_id,
            _ => false,
        };

        if self.is_empty() {
            if !matches!(
                self.discovery,
                WantDiscovery::Running { .. } | WantDiscovery::Start
            ) {
                self.discovery = WantDiscovery::Start;
            }
            if !busy {
                self.state = WantSessionState::Idle;
                tracing::warn!(session = %self.cid, %peer_id, name = "want_session", "session is empty. setting state to idle.");
            }
        } else if !busy {
            tracing::debug!(session = %self.cid, name = "want_session", "checking next peer for block");
            self.state = WantSessionState::NextBlock {
                previous_peer_id: None,
            };
        }
        self.wake_task();
    }

    /// Marks a tracked peer as disconnected so the stream emits a `Dial` for it.
    ///
    /// Returns `true` only when the peer's state was changed to
    /// `Disconnect { backoff: false }`. A peer that is unknown or already
    /// disconnected yields `false`. A peer that was already redialled
    /// (`backoff: true`) is removed from the session.
    pub fn peer_disconnected(&mut self, peer_id: PeerId) -> bool {
        let indexmap::map::Entry::Occupied(mut entry) = self.wants.entry(peer_id) else {
            return false;
        };
        match *entry.get() {
            PeerWantState::Disconnect { backoff: true } => {
                entry.shift_remove();
                false
            }
            PeerWantState::Disconnect { backoff: false } => false,
            _ => {
                *entry.get_mut() = PeerWantState::Disconnect { backoff: false };
                true
            }
        }
    }

    /// Starts writing `block`, received from `peer_id`, into the repo.
    ///
    /// Ignored (with a warning) if a block is already being stored. Otherwise the
    /// overall session timer and discovery are disabled while the write runs.
    pub fn put_block(&mut self, peer_id: PeerId, block: Block) {
        if matches!(self.state, WantSessionState::PutBlock { .. }) {
            tracing::warn!(session = %self.cid, %peer_id, cid = %block.cid(), name = "want_session", "state already putting block into store");
        } else {
            tracing::info!(%peer_id, cid = %block.cid(), name = "want_session", "storing block");
            let fut = self.repo.put_block(&block).into_future();
            self.state = WantSessionState::PutBlock {
                from_peer_id: peer_id,
                fut,
            };
            self.discovery = WantDiscovery::Disable;
            self.timer.take();
        }
        self.wake_task();
    }

    /// Removes `peer_id` from the session.
    ///
    /// If a block request was outstanding with that peer, the session moves on to the
    /// next candidate. If no usable peer is left afterwards, discovery is started
    /// (unless it is already starting or running).
    pub fn remove_peer(&mut self, peer_id: PeerId) {
        self.wants.shift_remove(&peer_id);
        if matches!(self.state, WantSessionState::NextBlockPending { peer_id: p, .. } if p == peer_id)
        {
            self.state = WantSessionState::NextBlock {
                previous_peer_id: Some(peer_id),
            };
        }
        if self.is_empty()
            && !matches!(
                self.discovery,
                WantDiscovery::Running { .. } | WantDiscovery::Start
            )
        {
            self.discovery = WantDiscovery::Start;
        }
        self.wake_task();
    }

    /// Returns `true` if `peer_id` is tracked by this session, in any state.
    pub fn contains(&self, peer_id: PeerId) -> bool {
        self.wants.contains_key(&peer_id)
    }

    /// Returns `true` when no peer can still provide the block.
    ///
    /// That is the case when no peers are tracked or every tracked peer is `Failed`.
    pub fn is_empty(&self) -> bool {
        self.wants
            .values()
            .all(|state| matches!(state, PeerWantState::Failed))
    }

    /// Wakes the task that last parked on this stream, if any.
    fn wake_task(&mut self) {
        if let Some(waker) = self.waker.take() {
            waker.wake();
        }
    }
}
// `poll_next` takes `&mut *self` from `Pin<&mut Self>`, which requires `Self: Unpin`.
impl Unpin for WantSession {}
impl Stream for WantSession {
    type Item = WantSessionEvent;

    /// Drives the want state machine, yielding at most one event per poll.
    ///
    /// Checks happen in this order:
    /// 1. a queued `SendCancel` for a peer marked `Cancel`;
    /// 2. end of stream once the block was stored;
    /// 3. `Cancelled` (once), then end of stream, after a cancellation;
    /// 4. `Dial` for a disconnected peer not yet redialled;
    /// 5. `SendWant` for a `Pending` peer (skipped while completing/cancelling);
    /// 6. whatever the current [`WantSessionState`] produces.
    #[tracing::instrument(level = "trace", name = "WantSession::poll_next", skip(self, cx))]
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let cid = self.cid;
        if let Some(peer_id) = self
            .wants
            .iter()
            .find(|(_, state)| matches!(state, PeerWantState::Cancel))
            .map(|(peer_id, _)| peer_id)
            .copied()
        {
            self.wants.shift_remove(&peer_id);
            return Poll::Ready(Some(WantSessionEvent::SendCancel { peer_id }));
        }
        if self.received {
            return Poll::Ready(None);
        }
        if let Some(terminated) = self.terminated.as_mut() {
            match terminated {
                true => return Poll::Ready(None),
                false => {
                    *terminated = true;
                    return Poll::Ready(Some(WantSessionEvent::Cancelled));
                }
            }
        }
        if let Some((peer_id, state)) = self
            .wants
            .iter_mut()
            .find(|(_, state)| matches!(state, PeerWantState::Disconnect { backoff } if !backoff))
        {
            let PeerWantState::Disconnect { backoff } = state else {
                unreachable!("peer state is set to disconnect");
            };
            *backoff = true;
            tracing::info!(session = %cid, %peer_id, name = "want_session", "peer is disconnected. Attempting to dial peer");
            return Poll::Ready(Some(WantSessionEvent::Dial { peer_id: *peer_id }));
        }
        if !matches!(
            self.state,
            WantSessionState::Complete | WantSessionState::Cancel
        ) {
            if let Some((peer_id, state)) = self
                .wants
                .iter_mut()
                .find(|(_, state)| **state == PeerWantState::Pending)
            {
                *state = PeerWantState::Sent;
                tracing::debug!(session = %cid, %peer_id, name = "want_session", "sent want block");
                return Poll::Ready(Some(WantSessionEvent::SendWant { peer_id: *peer_id }));
            } else if self.wants.capacity() > CAP_THRESHOLD {
                self.wants.shrink_to_fit()
            }
        }

        let this = &mut *self;
        loop {
            // Overall session timeout: once it fires the session is cancelled.
            if let Some(timer) = this.timer.as_mut() {
                if timer.poll_unpin(cx).is_ready() {
                    this.state = WantSessionState::Cancel;
                    this.timer.take();
                }
            }
            match this.state {
                WantSessionState::Idle => {
                    this.waker = Some(cx.waker().clone());
                    if let Some(peer_id) = this
                        .wants
                        .iter()
                        .find(|(_, state)| matches!(state, PeerWantState::Cancel))
                        .map(|(peer_id, _)| peer_id)
                        .copied()
                    {
                        this.wants.shift_remove(&peer_id);
                        return Poll::Ready(Some(WantSessionEvent::SendCancel { peer_id }));
                    }
                    let no_candidates = this.wants.is_empty()
                        || this
                            .wants
                            .values()
                            .all(|state| matches!(state, PeerWantState::Failed));
                    if no_candidates && matches!(this.discovery, WantDiscovery::Disable) {
                        this.discovery = WantDiscovery::SilentStart;
                    }
                    match &mut this.discovery {
                        WantDiscovery::Disable => {}
                        WantDiscovery::Start => {
                            this.discovery = WantDiscovery::Running {
                                timer: Delay::new(this.discovery_timeout),
                            };
                            return Poll::Ready(Some(WantSessionEvent::NeedBlock));
                        }
                        WantDiscovery::SilentStart => {
                            this.discovery = WantDiscovery::Running {
                                timer: Delay::new(this.discovery_timeout),
                            };
                            cx.waker().wake_by_ref();
                        }
                        WantDiscovery::Running { timer } => {
                            if timer.poll_unpin(cx).is_ready() {
                                timer.reset(this.discovery_timeout);
                                return Poll::Ready(Some(WantSessionEvent::NeedBlock));
                            }
                        }
                    }
                    return Poll::Pending;
                }
                WantSessionState::NextBlock {
                    ref mut previous_peer_id,
                } => {
                    if let Some(peer_id) = previous_peer_id.take() {
                        tracing::debug!(session = %cid, %peer_id, name = "want_session", "failed block");
                        if let Some(state) = this.wants.get_mut(&peer_id) {
                            *state = PeerWantState::Failed;
                        }
                    }
                    if let Some((next_peer_id, state)) = this
                        .wants
                        .iter_mut()
                        .find(|(_, state)| matches!(state, PeerWantState::Have))
                    {
                        tracing::info!(session = %cid, %next_peer_id, name = "want_session", "sending block request to next peer");
                        this.discovery = WantDiscovery::Disable;
                        // A zero or absent session timeout falls back to a 15s per-peer wait.
                        let timeout = match this.timeout {
                            Some(timeout) if !timeout.is_zero() => timeout,
                            _ => Duration::from_secs(15),
                        };
                        let timer = Delay::new(timeout);
                        this.state = WantSessionState::NextBlockPending {
                            peer_id: *next_peer_id,
                            timer,
                        };
                        *state = PeerWantState::Waiting;
                        return Poll::Ready(Some(WantSessionEvent::SendBlock {
                            peer_id: *next_peer_id,
                        }));
                    }
                    tracing::debug!(session = %cid, name = "want_session", "session is idle");
                    this.state = WantSessionState::Idle;
                    // Grouped explicitly: `&&` binds tighter than `||`, and an ungrouped
                    // condition would restart discovery (emitting a fresh `NeedBlock`)
                    // whenever the want list is empty, even with a round already running.
                    let no_candidates = this.wants.is_empty()
                        || this
                            .wants
                            .values()
                            .all(|state| matches!(state, PeerWantState::Failed));
                    if no_candidates && matches!(this.discovery, WantDiscovery::Disable) {
                        this.discovery = WantDiscovery::Start;
                    }
                }
                WantSessionState::NextBlockPending {
                    peer_id,
                    ref mut timer,
                } => {
                    state_ready!(cx, this, timer.poll_unpin(cx));
                    tracing::warn!(session = %cid, name = "want_session", %peer_id, "request timeout attempting to get next block");
                    this.state = WantSessionState::NextBlock {
                        previous_peer_id: Some(peer_id),
                    };
                }
                WantSessionState::PutBlock {
                    from_peer_id: peer_id,
                    ref mut fut,
                } => match state_ready!(cx, this, fut.poll_unpin(cx)) {
                    Ok(stored) => {
                        tracing::info!(session = %cid, %peer_id, block = %stored, name = "want_session", "block stored in block store");
                        this.state = WantSessionState::Complete;
                        return Poll::Ready(Some(WantSessionEvent::BlockStored));
                    }
                    Err(e) => {
                        tracing::error!(session = %cid, %peer_id, error = %e, name = "want_session", "error storing block in store");
                        this.state = WantSessionState::NextBlock {
                            previous_peer_id: Some(peer_id),
                        };
                    }
                },
                WantSessionState::Cancel => {
                    tracing::info!(session = %cid, "cancelled");
                    for (peer_id, state) in this.wants.iter_mut() {
                        tracing::info!(%peer_id, session = %cid, "setting peer state to cancel");
                        *state = PeerWantState::Cancel;
                    }
                    this.state = WantSessionState::Idle;
                    this.terminated = Some(false);
                    cx.waker().wake_by_ref();
                }
                WantSessionState::Complete => {
                    tracing::info!(session = %cid, "obtaining block completed.");
                    this.received = true;
                    this.state = WantSessionState::Cancel;
                }
            }
        }
    }
}
impl FusedStream for WantSession {
fn is_terminated(&self) -> bool {
self.received
}
}
/// Events yielded by the [`HaveSession`] stream.
#[derive(Debug)]
pub enum HaveSessionEvent {
    /// Tell the peer the block is available locally.
    Have { peer_id: PeerId },
    /// Tell the peer the block is not available locally.
    DontHave { peer_id: PeerId },
    /// Send the block's raw bytes to the peer.
    Block { peer_id: PeerId, bytes: Bytes },
    /// Every requesting peer has been served (or none remain); the session is done.
    Cancelled,
}
/// State machine driven by `HaveSession::poll_next`.
enum HaveSessionState {
    /// Answering queued peers with `Have`/`DontHave`, or parked awaiting new requests.
    Idle,
    /// Checking whether the repo contains the block.
    ContainBlock {
        fut: BoxFuture<'static, Result<bool, anyhow::Error>>,
    },
    /// Loading the block from the repo to send it to requesting peers.
    GetBlock {
        fut: BoxFuture<'static, Result<Option<Block>, anyhow::Error>>,
    },
    /// Block bytes loaded; handing them out to peers in the `Block` want state.
    Block {
        bytes: Bytes,
    },
    /// Terminal state; the stream yields `None`.
    Complete,
}
/// Per-peer progress within a [`HaveSession`].
#[derive(Debug)]
enum HaveWantState {
    /// The peer asked about the block and has not been answered yet.
    // The `send_dont_have` flag is stored but not read by the code in view, hence
    // `dead_code` is allowed.
    #[allow(dead_code)]
    Pending {
        send_dont_have: bool,
    },
    /// A `Have`/`DontHave` answer was emitted for the peer.
    Sent,
    /// The peer requested the block itself.
    Block,
    /// The block bytes were emitted for the peer.
    BlockSent,
}
/// Stream-driven session answering remote peers that want one block (`cid`).
pub struct HaveSession {
    /// The block peers are asking about.
    cid: Cid,
    /// Per-peer request state.
    want: HashMap<PeerId, HaveWantState>,
    /// Peers for which `reset` re-queues `Pending { send_dont_have: true }`.
    send_dont_have: HashSet<PeerId>,
    /// Result of the last repo lookup: `None` until known.
    have: Option<bool>,
    /// Repo the block is looked up in and read from.
    repo: Repo<DefaultStorage>,
    /// Waker of the task polling this stream.
    waker: Option<Waker>,
    /// Current state of the state machine.
    state: HaveSessionState,
}
impl HaveSession {
    /// Creates a session for `cid` that immediately starts checking whether the repo
    /// contains the block.
    pub fn new(repo: &Repo<DefaultStorage>, cid: Cid) -> Self {
        let lookup_repo = repo.clone();
        let fut = async move { lookup_repo.contains(&cid).await }.boxed();
        Self {
            cid,
            want: HashMap::new(),
            send_dont_have: HashSet::new(),
            have: None,
            repo: repo.clone(),
            waker: None,
            state: HaveSessionState::ContainBlock { fut },
        }
    }

    /// Returns `true` if `peer_id` has an active request in this session.
    pub fn has_peer(&self, peer_id: PeerId) -> bool {
        self.want.contains_key(&peer_id)
    }

    /// Registers that `peer_id` wants to know whether the block is available.
    ///
    /// Repeated requests from the same peer are ignored. The `send_dont_have` flag is
    /// also recorded in `self.send_dont_have` so that `reset` can restore it when it
    /// re-queues the peer.
    pub fn want_block(&mut self, peer_id: PeerId, send_dont_have: bool) {
        if self.want.contains_key(&peer_id) {
            tracing::warn!(session = %self.cid, %peer_id, "peer already requested block. Ignoring additional request");
            return;
        }
        tracing::info!(session = %self.cid, %peer_id, name = "have_session", "peer want block");
        if send_dont_have {
            self.send_dont_have.insert(peer_id);
        }
        self.want
            .insert(peer_id, HaveWantState::Pending { send_dont_have });
        self.wake_task();
    }

    /// Registers that `peer_id` wants the block bytes and starts loading the block if
    /// it is not already being loaded or handed out.
    pub fn need_block(&mut self, peer_id: PeerId) {
        if matches!(
            self.want.get(&peer_id),
            Some(HaveWantState::Block | HaveWantState::BlockSent)
        ) {
            tracing::warn!(session = %self.cid, %peer_id, name = "have_session", "already sending block to peer");
            return;
        }
        tracing::info!(session = %self.cid, %peer_id, name = "have_session", "peer requested block");
        self.want.insert(peer_id, HaveWantState::Block);
        if !matches!(
            self.state,
            HaveSessionState::GetBlock { .. } | HaveSessionState::Block { .. }
        ) {
            let repo = self.repo.clone();
            let cid = self.cid;
            let fut = async move { repo.get_block_now(&cid).await }.boxed();
            tracing::info!(session = %self.cid, %peer_id, name = "have_session", "change state to get_block");
            self.state = HaveSessionState::GetBlock { fut };
        }
        // Wake unconditionally so a parked stream re-checks for peers awaiting the block.
        self.wake_task();
    }

    /// Drops `peer_id` and any state kept for it.
    pub fn remove_peer(&mut self, peer_id: PeerId) {
        tracing::info!(session = %self.cid, %peer_id, name = "have_session", "removing peer from have_session");
        self.want.remove(&peer_id);
        self.send_dont_have.remove(&peer_id);
        self.wake_task();
    }

    /// Re-checks the repo after a previous lookup found the block missing.
    ///
    /// Does nothing unless the last known result is "don't have". Every peer is put
    /// back to `Pending` so it is answered again once the new lookup finishes.
    pub fn reset(&mut self) {
        if self.have != Some(false) {
            return;
        }
        tracing::info!(session = %self.cid, name = "have_session", "resetting session");
        for (peer_id, state) in self.want.iter_mut() {
            *state = HaveWantState::Pending {
                send_dont_have: self.send_dont_have.contains(peer_id),
            };
            tracing::debug!(session = %self.cid, name = "have_session", %peer_id, "resetting peer state");
        }
        let repo = self.repo.clone();
        let cid = self.cid;
        let fut = async move { repo.contains(&cid).await }.boxed();
        self.state = HaveSessionState::ContainBlock { fut };
        self.wake_task();
    }

    /// Cancels `peer_id`'s request.
    ///
    /// Wakes the stream like `remove_peer` does, so it can notice that the session may
    /// now have nobody left to serve.
    pub fn cancel(&mut self, peer_id: PeerId) {
        self.want.remove(&peer_id);
        self.send_dont_have.remove(&peer_id);
        tracing::info!(session = %self.cid, %peer_id, name = "have_session", "cancelling request");
        self.wake_task();
    }

    /// Wakes the task that last parked on this stream, if any.
    fn wake_task(&mut self) {
        if let Some(waker) = self.waker.take() {
            waker.wake();
        }
    }
}
// `poll_next` takes `&mut *self` from `Pin<&mut Self>`, which requires `Self: Unpin`.
impl Unpin for HaveSession {}
impl Stream for HaveSession {
type Item = HaveSessionEvent;
#[tracing::instrument(level = "trace", name = "HaveSession::poll_next", skip(self, cx))]
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
if matches!(self.state, HaveSessionState::Complete) {
return Poll::Ready(None);
}
let this = &mut *self;
if let HaveSessionState::Block { bytes } = &this.state {
if let Some((next_peer_id, state)) = this
.want
.iter_mut()
.find(|(_, state)| matches!(state, HaveWantState::Block))
{
*state = HaveWantState::BlockSent;
return Poll::Ready(Some(HaveSessionEvent::Block {
peer_id: *next_peer_id,
bytes: bytes.clone(),
}));
}
if this
.want
.iter()
.all(|(_, state)| matches!(state, HaveWantState::BlockSent))
|| this.want.is_empty()
{
this.state = HaveSessionState::Complete;
this.want.clear();
this.send_dont_have.clear();
return Poll::Ready(Some(HaveSessionEvent::Cancelled));
}
this.state = HaveSessionState::Idle;
return Poll::Pending;
}
loop {
match this.state {
HaveSessionState::Idle => {
if let Some(have) = this.have {
if let Some((peer_id, state)) = this
.want
.iter_mut()
.find(|(_, state)| matches!(state, HaveWantState::Pending { .. }))
{
let peer_id = *peer_id;
*state = HaveWantState::Sent;
tracing::debug!(%peer_id, peer_state = ?state, have_block=have, session = %this.cid, "notifying peer of block status");
return match have {
true => Poll::Ready(Some(HaveSessionEvent::Have { peer_id })),
false => Poll::Ready(Some(HaveSessionEvent::DontHave { peer_id })),
};
}
}
this.waker.replace(cx.waker().clone());
return Poll::Pending;
}
HaveSessionState::ContainBlock { ref mut fut } => {
let have = state_ready!(cx, this, fut.poll_unpin(cx)).unwrap_or_default();
this.have = Some(have);
this.state = HaveSessionState::Idle;
}
HaveSessionState::GetBlock { ref mut fut } => {
let result = state_ready!(cx, this, fut.poll_unpin(cx));
let (cid, bytes) = match result {
Ok(Some(block)) => block.into_inner(),
Ok(None) => {
tracing::warn!(session = %this.cid, "block does not exist");
this.state = HaveSessionState::Idle;
this.have = Some(false);
continue;
}
Err(e) => {
tracing::error!(session = %this.cid, error = %e, "error obtaining block");
this.state = HaveSessionState::Idle;
this.have = Some(false);
continue;
}
};
debug_assert_eq!(cid, this.cid);
this.have = Some(true);
this.state = HaveSessionState::Block { bytes };
}
HaveSessionState::Block { ref bytes } => {
match this
.want
.iter_mut()
.find(|(_, state)| matches!(state, HaveWantState::Block))
{
Some((peer_id, state)) => {
*state = HaveWantState::BlockSent;
return Poll::Ready(Some(HaveSessionEvent::Block {
peer_id: *peer_id,
bytes: bytes.clone(),
}));
}
None => return Poll::Pending,
}
}
HaveSessionState::Complete => {
return Poll::Ready(None);
}
}
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
(self.want.len(), None)
}
}
impl FusedStream for HaveSession {
    /// `true` once the session reached `Complete`; `poll_next` then only yields `None`.
    fn is_terminated(&self) -> bool {
        matches!(self.state, HaveSessionState::Complete)
    }
}