use core::fmt;
use std::{
io,
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll},
};
use crate::AlreadyRegistered;
use crate::{handler::NewStream, shared::Shared};
use futures::{
channel::{mpsc, oneshot},
SinkExt as _, StreamExt as _,
};
use libp2p_identity::PeerId;
use libp2p_swarm::{Stream, StreamProtocol};
#[derive(Clone)]
pub struct Control {
shared: Arc<Mutex<Shared>>,
}
impl Control {
pub(crate) fn new(shared: Arc<Mutex<Shared>>) -> Self {
Self { shared }
}
pub async fn open_stream(
&mut self,
peer: PeerId,
protocol: StreamProtocol,
) -> Result<Stream, OpenStreamError> {
tracing::debug!(%peer, "Requesting new stream");
let mut new_stream_sender = Shared::lock(&self.shared).sender(peer);
let (sender, receiver) = oneshot::channel();
new_stream_sender
.send(NewStream { protocol, sender })
.await
.map_err(|e| io::Error::new(io::ErrorKind::ConnectionReset, e))?;
let stream = receiver
.await
.map_err(|e| io::Error::new(io::ErrorKind::ConnectionReset, e))??;
Ok(stream)
}
pub fn accept(
&mut self,
protocol: StreamProtocol,
) -> Result<IncomingStreams, AlreadyRegistered> {
Shared::lock(&self.shared).accept(protocol)
}
}
#[derive(Debug)]
#[non_exhaustive]
pub enum OpenStreamError {
UnsupportedProtocol(StreamProtocol),
Io(std::io::Error),
}
impl From<std::io::Error> for OpenStreamError {
fn from(v: std::io::Error) -> Self {
Self::Io(v)
}
}
impl fmt::Display for OpenStreamError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
OpenStreamError::UnsupportedProtocol(p) => {
write!(f, "failed to open stream: remote peer does not support {p}")
}
OpenStreamError::Io(e) => {
write!(f, "failed to open stream: io error: {e}")
}
}
}
}
impl std::error::Error for OpenStreamError {
    /// Returns the wrapped I/O error for [`OpenStreamError::Io`]; other variants have no
    /// underlying source.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Io(error) => Some(error),
            // Listed explicitly (no `_` arm) so that adding a variant that carries an
            // inner error produces a compile error here instead of silently hiding it.
            Self::UnsupportedProtocol(_) => None,
        }
    }
}
#[must_use = "Streams do nothing unless polled."]
pub struct IncomingStreams {
receiver: mpsc::Receiver<(PeerId, Stream)>,
}
impl IncomingStreams {
pub(crate) fn new(receiver: mpsc::Receiver<(PeerId, Stream)>) -> Self {
Self { receiver }
}
}
impl futures::Stream for IncomingStreams {
    type Item = (PeerId, Stream);

    /// Forwards polling to the inner channel receiver. The stream ends when the
    /// receiver does.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `IncomingStreams` is `Unpin` (its only field is), so unpinning is free.
        let this = self.get_mut();
        this.receiver.poll_next_unpin(cx)
    }
}