#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
use either::Either;
use futures::{future, prelude::*, ready};
use libp2p_core::muxing::{StreamMuxer, StreamMuxerEvent};
use libp2p_core::upgrade::{InboundConnectionUpgrade, OutboundConnectionUpgrade, UpgradeInfo};
use std::collections::VecDeque;
use std::io::{IoSlice, IoSliceMut};
use std::task::Waker;
use std::{
    io, iter,
    pin::Pin,
    task::{Context, Poll},
};
use thiserror::Error;
/// A Yamux connection.
///
/// Wraps either a `yamux012` or a `yamux013` connection together with a small
/// buffer of inbound streams that were accepted while driving the connection
/// but have not been handed out through `poll_inbound` yet.
#[derive(Debug)]
pub struct Muxer<C> {
    /// The underlying connection, backed by one of the two supported yamux versions.
    connection: Either<yamux012::Connection<C>, yamux013::Connection<C>>,
    /// Inbound streams accepted by `StreamMuxer::poll` and waiting to be
    /// returned from `StreamMuxer::poll_inbound`.
    inbound_stream_buffer: VecDeque<Stream>,
    /// Waker registered by the last `poll_inbound` call that returned
    /// `Poll::Pending`; woken once `poll` buffers a new inbound stream.
    inbound_stream_waker: Option<Waker>,
}
/// Upper bound on the number of inbound streams `StreamMuxer::poll` keeps
/// buffered. Once the buffer holds this many streams, newly arriving inbound
/// streams are dropped (and a warning is logged) instead of being buffered.
const MAX_BUFFERED_INBOUND_STREAMS: usize = 256;
impl<C> Muxer<C>
where
    C: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
    /// Wrap an already constructed yamux connection in a [`Muxer`].
    ///
    /// The muxer starts out with an empty inbound-stream buffer and no
    /// registered waker.
    fn new(connection: Either<yamux012::Connection<C>, yamux013::Connection<C>>) -> Self {
        Self {
            inbound_stream_waker: None,
            inbound_stream_buffer: VecDeque::new(),
            connection,
        }
    }
}
impl<C> StreamMuxer for Muxer<C>
where
    C: AsyncRead + AsyncWrite + Unpin + 'static,
{
    type Substream = Stream;
    type Error = Error;
    #[tracing::instrument(level = "trace", name = "StreamMuxer::poll_inbound", skip(self, cx))]
    fn poll_inbound(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Self::Substream, Self::Error>> {
        if let Some(stream) = self.inbound_stream_buffer.pop_front() {
            return Poll::Ready(Ok(stream));
        }
        if let Poll::Ready(res) = self.poll_inner(cx) {
            return Poll::Ready(res);
        }
        self.inbound_stream_waker = Some(cx.waker().clone());
        Poll::Pending
    }
    #[tracing::instrument(level = "trace", name = "StreamMuxer::poll_outbound", skip(self, cx))]
    fn poll_outbound(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Self::Substream, Self::Error>> {
        let stream = match self.connection.as_mut() {
            Either::Left(c) => ready!(c.poll_new_outbound(cx))
                .map_err(|e| Error(Either::Left(e)))
                .map(|s| Stream(Either::Left(s))),
            Either::Right(c) => ready!(c.poll_new_outbound(cx))
                .map_err(|e| Error(Either::Right(e)))
                .map(|s| Stream(Either::Right(s))),
        }?;
        Poll::Ready(Ok(stream))
    }
    #[tracing::instrument(level = "trace", name = "StreamMuxer::poll_close", skip(self, cx))]
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        match self.connection.as_mut() {
            Either::Left(c) => c.poll_close(cx).map_err(|e| Error(Either::Left(e))),
            Either::Right(c) => c.poll_close(cx).map_err(|e| Error(Either::Right(e))),
        }
    }
    #[tracing::instrument(level = "trace", name = "StreamMuxer::poll", skip(self, cx))]
    fn poll(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
        let this = self.get_mut();
        let inbound_stream = ready!(this.poll_inner(cx))?;
        if this.inbound_stream_buffer.len() >= MAX_BUFFERED_INBOUND_STREAMS {
            tracing::warn!(
                stream=%inbound_stream.0,
                "dropping stream because buffer is full"
            );
            drop(inbound_stream);
        } else {
            this.inbound_stream_buffer.push_back(inbound_stream);
            if let Some(waker) = this.inbound_stream_waker.take() {
                waker.wake()
            }
        }
        cx.waker().wake_by_ref();
        Poll::Pending
    }
}
/// A stream produced by the yamux multiplexer.
///
/// Thin wrapper around either a `yamux012::Stream` or a `yamux013::Stream`;
/// it is the `Substream` type of [`Muxer`].
#[derive(Debug)]
pub struct Stream(Either<yamux012::Stream, yamux013::Stream>);
impl AsyncRead for Stream {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        either::for_both!(self.0.as_mut(), s => Pin::new(s).poll_read(cx, buf))
    }
    fn poll_read_vectored(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &mut [IoSliceMut<'_>],
    ) -> Poll<io::Result<usize>> {
        either::for_both!(self.0.as_mut(), s => Pin::new(s).poll_read_vectored(cx, bufs))
    }
}
impl AsyncWrite for Stream {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        either::for_both!(self.0.as_mut(), s => Pin::new(s).poll_write(cx, buf))
    }
    fn poll_write_vectored(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        either::for_both!(self.0.as_mut(), s => Pin::new(s).poll_write_vectored(cx, bufs))
    }
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        either::for_both!(self.0.as_mut(), s => Pin::new(s).poll_flush(cx))
    }
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        either::for_both!(self.0.as_mut(), s => Pin::new(s).poll_close(cx))
    }
}
impl<C> Muxer<C>
where
    C: AsyncRead + AsyncWrite + Unpin + 'static,
{
    fn poll_inner(&mut self, cx: &mut Context<'_>) -> Poll<Result<Stream, Error>> {
        let stream = match self.connection.as_mut() {
            Either::Left(c) => ready!(c.poll_next_inbound(cx))
                .ok_or(Error(Either::Left(yamux012::ConnectionError::Closed)))?
                .map_err(|e| Error(Either::Left(e)))
                .map(|s| Stream(Either::Left(s)))?,
            Either::Right(c) => ready!(c.poll_next_inbound(cx))
                .ok_or(Error(Either::Right(yamux013::ConnectionError::Closed)))?
                .map_err(|e| Error(Either::Right(e)))
                .map(|s| Stream(Either::Right(s)))?,
        };
        Poll::Ready(Ok(stream))
    }
}
/// The yamux configuration.
///
/// Internally this is either a `yamux012` or a `yamux013` configuration.
/// `Config::default()` yields the `yamux013` variant; the constructors and
/// setters on `Config` that operate on a `yamux012::Config` switch it to the
/// `yamux012` variant.
#[derive(Debug, Clone)]
pub struct Config(Either<Config012, Config013>);
impl Default for Config {
    fn default() -> Self {
        Self(Either::Right(Config013::default()))
    }
}
/// Configuration for the `yamux012` backend.
#[derive(Debug, Clone)]
struct Config012 {
    /// The wrapped `yamux012` configuration that the setters on `Config` modify.
    inner: yamux012::Config,
    /// Explicit connection mode. When `None`, the mode is chosen by the
    /// upgrade direction: `Server` for inbound, `Client` for outbound.
    mode: Option<yamux012::Mode>,
}
impl Default for Config012 {
    fn default() -> Self {
        let mut inner = yamux012::Config::default();
        inner.set_read_after_close(false);
        Self { inner, mode: None }
    }
}
/// The window update mode, as understood by the `yamux012` backend.
///
/// Opaque wrapper around `yamux012::WindowUpdateMode`; construct it with
/// `WindowUpdateMode::on_read` (or the deprecated `on_receive`).
pub struct WindowUpdateMode(yamux012::WindowUpdateMode);
impl WindowUpdateMode {
    #[deprecated(note = "Use `WindowUpdateMode::on_read` instead.")]
    pub fn on_receive() -> Self {
        #[allow(deprecated)]
        WindowUpdateMode(yamux012::WindowUpdateMode::OnReceive)
    }
    pub fn on_read() -> Self {
        WindowUpdateMode(yamux012::WindowUpdateMode::OnRead)
    }
}
impl Config {
    /// Create a `yamux012` configuration whose mode is fixed to `Client`,
    /// regardless of the upgrade direction.
    #[deprecated(note = "Will be removed with the next breaking release.")]
    pub fn client() -> Self {
        let legacy = Config012 {
            mode: Some(yamux012::Mode::Client),
            ..Config012::default()
        };
        Self(Either::Left(legacy))
    }

    /// Create a `yamux012` configuration whose mode is fixed to `Server`,
    /// regardless of the upgrade direction.
    #[deprecated(note = "Will be removed with the next breaking release.")]
    pub fn server() -> Self {
        let legacy = Config012 {
            mode: Some(yamux012::Mode::Server),
            ..Config012::default()
        };
        Self(Either::Left(legacy))
    }

    /// Forward `num_bytes` to `yamux012::Config::set_receive_window`.
    ///
    /// Switches this configuration to the `yamux012` backend if necessary.
    #[deprecated(
        note = "Will be replaced in the next breaking release with a connection receive window size limit."
    )]
    pub fn set_receive_window_size(&mut self, num_bytes: u32) -> &mut Self {
        self.set(|legacy| legacy.set_receive_window(num_bytes))
    }

    /// Forward `num_bytes` to `yamux012::Config::set_max_buffer_size`.
    ///
    /// Switches this configuration to the `yamux012` backend if necessary.
    #[deprecated(note = "Will be removed with the next breaking release.")]
    pub fn set_max_buffer_size(&mut self, num_bytes: usize) -> &mut Self {
        self.set(|legacy| legacy.set_max_buffer_size(num_bytes))
    }

    /// Forward `num_streams` to `yamux012::Config::set_max_num_streams`.
    ///
    /// Switches this configuration to the `yamux012` backend if necessary.
    pub fn set_max_num_streams(&mut self, num_streams: usize) -> &mut Self {
        self.set(|legacy| legacy.set_max_num_streams(num_streams))
    }

    /// Forward `mode` to `yamux012::Config::set_window_update_mode`.
    ///
    /// Switches this configuration to the `yamux012` backend if necessary.
    #[deprecated(
        note = "`WindowUpdate::OnRead` is the default. `WindowUpdate::OnReceive` breaks backpressure, is thus not recommended, and will be removed in the next breaking release. Thus this method becomes obsolete and will be removed with the next breaking release."
    )]
    pub fn set_window_update_mode(&mut self, mode: WindowUpdateMode) -> &mut Self {
        self.set(|legacy| legacy.set_window_update_mode(mode.0))
    }

    /// Apply `f` to the `yamux012` configuration.
    ///
    /// If this `Config` currently holds a `yamux013` configuration it is first
    /// replaced by `Config012::default()`, discarding the `yamux013` settings.
    fn set(&mut self, f: impl FnOnce(&mut yamux012::Config) -> &mut yamux012::Config) -> &mut Self {
        if self.0.is_right() {
            self.0 = Either::Left(Config012::default());
        }

        // After the switch above the left variant is guaranteed to be present.
        let legacy = self.0.as_mut().unwrap_left();
        f(&mut legacy.inner);

        self
    }
}
impl UpgradeInfo for Config {
    type Info = &'static str;
    type InfoIter = iter::Once<Self::Info>;

    /// Advertise the single protocol name handled by this upgrade.
    fn protocol_info(&self) -> Self::InfoIter {
        let protocol: Self::Info = "/yamux/1.0.0";
        iter::once(protocol)
    }
}
impl<C> InboundConnectionUpgrade<C> for Config
where
    C: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
    type Output = Muxer<C>;
    type Error = io::Error;
    type Future = future::Ready<Result<Self::Output, Self::Error>>;
    fn upgrade_inbound(self, io: C, _: Self::Info) -> Self::Future {
        let connection = match self.0 {
            Either::Left(Config012 { inner, mode }) => Either::Left(yamux012::Connection::new(
                io,
                inner,
                mode.unwrap_or(yamux012::Mode::Server),
            )),
            Either::Right(Config013(cfg)) => {
                Either::Right(yamux013::Connection::new(io, cfg, yamux013::Mode::Server))
            }
        };
        future::ready(Ok(Muxer::new(connection)))
    }
}
impl<C> OutboundConnectionUpgrade<C> for Config
where
    C: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
    type Output = Muxer<C>;
    type Error = io::Error;
    type Future = future::Ready<Result<Self::Output, Self::Error>>;
    fn upgrade_outbound(self, io: C, _: Self::Info) -> Self::Future {
        let connection = match self.0 {
            Either::Left(Config012 { inner, mode }) => Either::Left(yamux012::Connection::new(
                io,
                inner,
                mode.unwrap_or(yamux012::Mode::Client),
            )),
            Either::Right(Config013(cfg)) => {
                Either::Right(yamux013::Connection::new(io, cfg, yamux013::Mode::Client))
            }
        };
        future::ready(Ok(Muxer::new(connection)))
    }
}
/// Configuration for the `yamux013` backend; a newtype around
/// `yamux013::Config` so that this crate can give it its own `Default`.
#[derive(Debug, Clone)]
struct Config013(yamux013::Config);
impl Default for Config013 {
    fn default() -> Self {
        let mut cfg = yamux013::Config::default();
        cfg.set_read_after_close(false);
        Self(cfg)
    }
}
/// The yamux `StreamMuxer` error type.
///
/// Wraps the `ConnectionError` of whichever yamux version produced it;
/// `#[error(transparent)]` makes `Display` and `source` defer to the wrapped
/// error.
#[derive(Debug, Error)]
#[error(transparent)]
pub struct Error(Either<yamux012::ConnectionError, yamux013::ConnectionError>);
impl From<Error> for io::Error {
    fn from(err: Error) -> Self {
        match err.0 {
            Either::Left(err) => match err {
                yamux012::ConnectionError::Io(e) => e,
                e => io::Error::new(io::ErrorKind::Other, e),
            },
            Either::Right(err) => match err {
                yamux013::ConnectionError::Io(e) => e,
                e => io::Error::new(io::ErrorKind::Other, e),
            },
        }
    }
}
#[cfg(test)]
mod test {
    use super::*;
    /// Using one of the `yamux012`-backed setters on a default (`yamux013`)
    /// configuration must switch the configuration over to `yamux012`.
    #[test]
    fn config_set_switches_to_v012() {
        // By default we use yamux v0.13.
        let mut cfg = Config::default();
        assert!(matches!(
            cfg,
            Config(Either::Right(Config013(yamux013::Config { .. })))
        ));
        // Any call that goes through `Config::set` replaces it with a v0.12 config.
        cfg.set_max_num_streams(42);
        assert!(matches!(cfg, Config(Either::Left(Config012 { .. }))));
    }
}