1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
use core::fmt;
use std::{
    io,
    pin::Pin,
    sync::{Arc, Mutex},
    task::{Context, Poll},
};

use crate::AlreadyRegistered;
use crate::{handler::NewStream, shared::Shared};

use futures::{
    channel::{mpsc, oneshot},
    SinkExt as _, StreamExt as _,
};
use libp2p_identity::PeerId;
use libp2p_swarm::{Stream, StreamProtocol};

/// Handle for opening outbound streams and registering inbound protocols.
///
/// Every clone of a [`Control`] points at the same `Arc<Mutex<Shared>>`, so cloning is cheap
/// and the clones can be used from different tasks at the same time.
#[derive(Clone)]
pub struct Control {
    shared: Arc<Mutex<Shared>>,
}

impl Control {
    /// Builds a handle backed by the given shared state.
    pub(crate) fn new(shared: Arc<Mutex<Shared>>) -> Self {
        Control { shared }
    }

    /// Opens a new outbound stream to `peer` that speaks `protocol`.
    ///
    /// If there is currently no connection to `peer`, an attempt is made to establish one.
    ///
    /// # Backpressure
    ///
    /// Like a bounded channel, [`Control`] applies backpressure: every [`Control`] has one
    /// guaranteed slot for internal messages, and because this method takes `&mut self`, a
    /// single [`Control`] only ever opens one stream at a time.
    ///
    /// Cloning [`Control`]s excessively defeats this mechanism.
    ///
    /// # Errors
    ///
    /// - [`OpenStreamError::Io`] with [`io::ErrorKind::ConnectionReset`] if the request cannot
    ///   be handed off (the request channel is closed) or if the reply channel is dropped
    ///   before an answer arrives.
    /// - Any error delivered back through the reply channel is propagated to the caller.
    pub async fn open_stream(
        &mut self,
        peer: PeerId,
        protocol: StreamProtocol,
    ) -> Result<Stream, OpenStreamError> {
        // Both channel failures mean the other side went away; report them uniformly.
        fn connection_reset<E>(error: E) -> io::Error
        where
            E: Into<Box<dyn std::error::Error + Send + Sync>>,
        {
            io::Error::new(io::ErrorKind::ConnectionReset, error)
        }

        tracing::debug!(%peer, "Requesting new stream");

        // The guard returned by `lock` is a temporary and is released at the end of this
        // statement, so the mutex is never held across the `.await`s below.
        let mut requests = Shared::lock(&self.shared).sender(peer);
        let (reply_tx, reply_rx) = oneshot::channel();

        requests
            .send(NewStream {
                protocol,
                sender: reply_tx,
            })
            .await
            .map_err(connection_reset)?;

        // Outer `?`: the reply channel was dropped. Inner `?`: the request itself failed.
        Ok(reply_rx.await.map_err(connection_reset)??)
    }

    /// Starts accepting inbound streams for `protocol`.
    ///
    /// Dropping the returned [`IncomingStreams`] handle stops accepting streams for it.
    ///
    /// # Errors
    ///
    /// Returns [`AlreadyRegistered`] as decided by the shared state. Presumably this happens
    /// when `protocol` is already being accepted elsewhere; confirm against `Shared::accept`.
    pub fn accept(
        &mut self,
        protocol: StreamProtocol,
    ) -> Result<IncomingStreams, AlreadyRegistered> {
        Shared::lock(&self.shared).accept(protocol)
    }
}

/// Errors that can occur while opening a new outbound stream.
#[derive(Debug)]
#[non_exhaustive]
pub enum OpenStreamError {
    /// The remote does not support the requested protocol.
    UnsupportedProtocol(StreamProtocol),
    /// An I/O error.
    ///
    /// This covers failures during the protocol handshake. It also covers the case where the
    /// request could not be delivered or no answer came back, which
    /// [`Control::open_stream`] reports with kind [`std::io::ErrorKind::ConnectionReset`].
    Io(std::io::Error),
}

/// Wraps an I/O error as [`OpenStreamError::Io`], which lets `?` convert I/O errors directly.
impl From<std::io::Error> for OpenStreamError {
    fn from(v: std::io::Error) -> Self {
        Self::Io(v)
    }
}

impl fmt::Display for OpenStreamError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::UnsupportedProtocol(p) => {
                write!(f, "failed to open stream: remote peer does not support {p}")
            }
            Self::Io(e) => write!(f, "failed to open stream: io error: {e}"),
        }
    }
}

impl std::error::Error for OpenStreamError {
    /// Exposes the wrapped [`std::io::Error`] for the `Io` variant; other variants have no
    /// underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Io(error) => Some(error),
            // Listed explicitly instead of `_` so that adding a variant forces a decision here.
            Self::UnsupportedProtocol(_) => None,
        }
    }
}

/// Stream of inbound streams accepted for one particular protocol.
///
/// Each item is the remote [`PeerId`] paired with the negotiated [`Stream`]. The stream ends
/// (yields `None`) once the underlying channel is closed.
#[must_use = "Streams do nothing unless polled."]
pub struct IncomingStreams {
    receiver: mpsc::Receiver<(PeerId, Stream)>,
}

impl IncomingStreams {
    /// Wraps the receiving half of the channel on which inbound streams are delivered.
    pub(crate) fn new(rx: mpsc::Receiver<(PeerId, Stream)>) -> Self {
        IncomingStreams { receiver: rx }
    }
}

impl futures::Stream for IncomingStreams {
    type Item = (PeerId, Stream);

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `mpsc::Receiver` is `Unpin`, so it can be re-pinned in place and polled directly.
        futures::Stream::poll_next(Pin::new(&mut self.receiver), cx)
    }
}