use crate::Version;
use crate::length_delimited::{LengthDelimited, LengthDelimitedReader};
use bytes::{Bytes, BytesMut, BufMut};
use futures::{prelude::*, io::IoSlice, ready};
use std::{convert::TryFrom, io, fmt, error::Error, pin::Pin, task::{Context, Poll}};
use unsigned_varint as uvi;
/// Maximum number of entries accepted when decoding a `Message::Protocols` list.
const MAX_PROTOCOLS: usize = 1000;
/// Wire encoding of `Message::Header(HeaderLine::V1)`.
const MSG_MULTISTREAM_1_0: &[u8] = b"/multistream/1.0.0\n";
/// Wire encoding of `Message::NotAvailable`.
const MSG_PROTOCOL_NA: &[u8] = b"na\n";
/// Wire encoding of `Message::ListProtocols`.
const MSG_LS: &[u8] = b"ls\n";
/// A multistream-select header line, identifying the protocol version in use.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum HeaderLine {
    /// The `/multistream/1.0.0` header line.
    V1,
}
impl From<Version> for HeaderLine {
    /// Selects the header line that goes on the wire for a negotiation `Version`.
    ///
    /// Both the eager (`V1`) and lazy (`V1Lazy`) variants share the same
    /// `/multistream/1.0.0` header.
    fn from(v: Version) -> HeaderLine {
        match v {
            Version::V1 => HeaderLine::V1,
            Version::V1Lazy => HeaderLine::V1,
        }
    }
}
/// A protocol name exchanged during multistream-select negotiation.
///
/// The `TryFrom` conversions only accept names that start with `/`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Protocol(Bytes);

impl AsRef<[u8]> for Protocol {
    /// Borrows the raw bytes of the protocol name.
    fn as_ref(&self) -> &[u8] {
        &self.0[..]
    }
}
impl TryFrom<Bytes> for Protocol {
    type Error = ProtocolError;
    fn try_from(value: Bytes) -> Result<Self, Self::Error> {
        if !value.as_ref().starts_with(b"/") {
            return Err(ProtocolError::InvalidProtocol)
        }
        Ok(Protocol(value))
    }
}
impl TryFrom<&[u8]> for Protocol {
    type Error = ProtocolError;
    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
        Self::try_from(Bytes::copy_from_slice(value))
    }
}
impl fmt::Display for Protocol {
    /// Writes the protocol name, replacing any invalid UTF-8 sequences with U+FFFD.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = String::from_utf8_lossy(&self.0);
        f.write_str(&name)
    }
}
/// A multistream-select protocol message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Message {
    /// A header line; `HeaderLine::V1` is encoded as `/multistream/1.0.0\n`.
    ///
    Header(HeaderLine),
    /// A single protocol name, encoded as the name followed by `\n`.
    Protocol(Protocol),
    /// The list-protocols message, encoded as `ls\n`.
    ///
    ListProtocols,
    /// A list of protocols: each entry is `varint(len) name \n`, followed by a bare `\n`.
    Protocols(Vec<Protocol>),
    /// The "not available" message, encoded as `na\n`.
    NotAvailable,
}
impl Message {
    /// Encodes this message and appends its wire representation to `dest`.
    ///
    /// The produced encodings are exactly those accepted by `Message::decode`:
    ///
    /// * `Header(V1)` → `/multistream/1.0.0\n`
    /// * `Protocol(p)` → `p\n`
    /// * `ListProtocols` → `ls\n`
    /// * `Protocols(ps)` → for each `p`: `varint(len(p) + 1) p \n`, then a final `\n`
    /// * `NotAvailable` → `na\n`
    ///
    /// # Errors
    ///
    /// No variant currently fails to encode; the `Result` is part of the signature
    /// so the `Sink` implementation can propagate encoding errors with `?`.
    pub fn encode(&self, dest: &mut BytesMut) -> Result<(), ProtocolError> {
        match self {
            Message::Header(HeaderLine::V1) => {
                dest.reserve(MSG_MULTISTREAM_1_0.len());
                dest.put_slice(MSG_MULTISTREAM_1_0);
                Ok(())
            }
            Message::Protocol(p) => {
                let name: &[u8] = p.as_ref();
                // +1 for the trailing newline.
                dest.reserve(name.len() + 1);
                dest.put_slice(name);
                dest.put_u8(b'\n');
                Ok(())
            }
            Message::ListProtocols => {
                dest.reserve(MSG_LS.len());
                dest.put_slice(MSG_LS);
                Ok(())
            }
            Message::Protocols(ps) => {
                let mut buf = uvi::encode::usize_buffer();
                // Write each entry straight into `dest` instead of staging the
                // whole list in a temporary Vec and copying it afterwards.
                for p in ps {
                    let name: &[u8] = p.as_ref();
                    // The varint prefix counts the name plus its trailing newline.
                    let prefix = uvi::encode::usize(name.len() + 1, &mut buf);
                    dest.reserve(prefix.len() + name.len() + 1);
                    dest.put_slice(prefix);
                    dest.put_slice(name);
                    dest.put_u8(b'\n');
                }
                // A bare newline terminates the list.
                dest.reserve(1);
                dest.put_u8(b'\n');
                Ok(())
            }
            Message::NotAvailable => {
                dest.reserve(MSG_PROTOCOL_NA.len());
                dest.put_slice(MSG_PROTOCOL_NA);
                Ok(())
            }
        }
    }

    /// Decodes a message from its wire representation (the inverse of `encode`).
    ///
    /// # Errors
    ///
    /// * `ProtocolError::InvalidProtocol` if a protocol name does not start with `/`.
    /// * `ProtocolError::InvalidMessage` if a protocol-list entry has a zero length,
    ///   a length exceeding the remaining input, or does not end with `\n`.
    /// * `ProtocolError::TooManyProtocols` if a protocol list has more than
    ///   `MAX_PROTOCOLS` entries.
    /// * `ProtocolError::IoError` (kind `InvalidData`) if a varint length prefix
    ///   cannot be decoded, which includes a list missing its terminating `\n`.
    pub fn decode(mut msg: Bytes) -> Result<Message, ProtocolError> {
        if msg == MSG_MULTISTREAM_1_0 {
            return Ok(Message::Header(HeaderLine::V1))
        }
        if msg == MSG_PROTOCOL_NA {
            return Ok(Message::NotAvailable);
        }
        if msg == MSG_LS {
            return Ok(Message::ListProtocols)
        }
        // A single protocol name: starts with '/', ends with '\n' and has no
        // other '\n' inside.
        if msg.get(0) == Some(&b'/') && msg.last() == Some(&b'\n') &&
            !msg[.. msg.len() - 1].contains(&b'\n')
        {
            let p = Protocol::try_from(msg.split_to(msg.len() - 1))?;
            return Ok(Message::Protocol(p));
        }
        // Otherwise the message must be a protocol list: a sequence of
        // `varint(len) name '\n'` entries terminated by a bare '\n'.
        let mut protocols = Vec::new();
        let mut remaining: &[u8] = &msg;
        loop {
            // A lone trailing newline terminates the list.
            if remaining == [b'\n'] {
                break
            } else if protocols.len() == MAX_PROTOCOLS {
                return Err(ProtocolError::TooManyProtocols)
            }
            // Read the length prefix; the announced entry must fit in the
            // remaining input and end with '\n'.
            let (len, tail) = uvi::decode::usize(remaining)?;
            if len == 0 || len > tail.len() || tail[len - 1] != b'\n' {
                return Err(ProtocolError::InvalidMessage)
            }
            // Strip the entry's trailing '\n' before validating the name.
            let p = Protocol::try_from(Bytes::copy_from_slice(&tail[.. len - 1]))?;
            protocols.push(p);
            // Advance past this entry.
            remaining = &tail[len ..];
        }
        Ok(Message::Protocols(protocols))
    }
}
/// A `Message` transport over an I/O resource. Messages are sent through the
/// `Sink` implementation and received through the `Stream` implementation, with
/// each message carried in one `LengthDelimited` frame.
#[pin_project::pin_project]
pub struct MessageIO<R> {
    #[pin]
    inner: LengthDelimited<R>,
}
impl<R> MessageIO<R> {
    
    pub fn new(inner: R) -> MessageIO<R>
    where
        R: AsyncRead + AsyncWrite
    {
        Self { inner: LengthDelimited::new(inner) }
    }
    
    
    
    
    
    
    
    pub fn into_reader(self) -> MessageReader<R> {
        MessageReader { inner: self.inner.into_reader() }
    }
    
    
    
    
    
    
    
    
    
    pub fn into_inner(self) -> R {
        self.inner.into_inner()
    }
}
impl<R> Sink<Message> for MessageIO<R>
where
    R: AsyncWrite,
{
    type Error = ProtocolError;
    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project().inner.poll_ready(cx).map_err(From::from)
    }
    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
        let mut buf = BytesMut::new();
        item.encode(&mut buf)?;
        self.project().inner.start_send(buf.freeze()).map_err(From::from)
    }
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project().inner.poll_flush(cx).map_err(From::from)
    }
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project().inner.poll_close(cx).map_err(From::from)
    }
}
impl<R> Stream for MessageIO<R>
where
    R: AsyncRead
{
    type Item = Result<Message, ProtocolError>;

    /// Reads the next length-delimited frame and decodes it into a `Message`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `poll_stream` already yields exactly `Poll<Option<Self::Item>>`, so it
        // is returned as-is instead of being re-matched variant by variant.
        poll_stream(self.project().inner, cx)
    }
}
/// A `Message` reader over a `LengthDelimitedReader`, which can be obtained from
/// `MessageIO::into_reader`. Incoming messages are received through the `Stream`
/// implementation; the `AsyncWrite` implementation delegates to the inner reader.
#[pin_project::pin_project]
#[derive(Debug)]
pub struct MessageReader<R> {
    #[pin]
    inner: LengthDelimitedReader<R>
}
impl<R> MessageReader<R> {
    
    
    
    
    
    
    
    
    
    
    
    pub fn into_inner(self) -> R {
        self.inner.into_inner()
    }
}
impl<R> Stream for MessageReader<R>
where
    R: AsyncRead
{
    type Item = Result<Message, ProtocolError>;

    /// Decodes the next frame read from the inner reader into a `Message`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        poll_stream(this.inner, cx)
    }
}
/// Every write operation is forwarded unchanged to the inner `LengthDelimitedReader`.
impl<TInner> AsyncWrite for MessageReader<TInner>
where
    TInner: AsyncWrite
{
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
        let inner = self.project().inner;
        inner.poll_write(cx, buf)
    }

    fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>]) -> Poll<Result<usize, io::Error>> {
        let inner = self.project().inner;
        inner.poll_write_vectored(cx, bufs)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        let inner = self.project().inner;
        inner.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        let inner = self.project().inner;
        inner.poll_close(cx)
    }
}
/// Polls `stream` for its next frame and decodes it into a `Message`.
///
/// Returns `Poll::Ready(None)` once `stream` is exhausted. I/O errors from
/// `stream` are converted into `ProtocolError::IoError`; decoding failures are
/// returned exactly as produced by `Message::decode`.
fn poll_stream<S>(stream: Pin<&mut S>, cx: &mut Context<'_>) -> Poll<Option<Result<Message, ProtocolError>>>
where
    S: Stream<Item = Result<Bytes, io::Error>>,
{
    let frame = match ready!(stream.poll_next(cx)?) {
        Some(frame) => frame,
        None => return Poll::Ready(None),
    };
    let msg = match Message::decode(frame) {
        Ok(decoded) => decoded,
        Err(err) => return Poll::Ready(Some(Err(err))),
    };
    log::trace!("Received message: {:?}", msg);
    Poll::Ready(Some(Ok(msg)))
}
/// Errors that can occur while encoding, decoding or exchanging `Message`s.
#[derive(Debug)]
pub enum ProtocolError {
    /// An I/O error (also used to report malformed varint length prefixes).
    IoError(io::Error),
    /// A received message could not be decoded.
    InvalidMessage,
    /// A protocol name is invalid (e.g. it does not start with `/`).
    InvalidProtocol,
    /// A received protocol list has too many entries.
    TooManyProtocols,
}

impl From<io::Error> for ProtocolError {
    fn from(err: io::Error) -> ProtocolError {
        ProtocolError::IoError(err)
    }
}

/// Converts a `ProtocolError` into an `io::Error`.
///
/// `IoError` unwraps to the original error; every other variant becomes a bare
/// error of kind `io::ErrorKind::InvalidData`.
///
/// Implemented as `From` rather than `Into` so that `io::Error::from(err)` works
/// too; `err.into()` keeps working through the standard blanket `Into` impl.
impl From<ProtocolError> for io::Error {
    fn from(err: ProtocolError) -> io::Error {
        match err {
            ProtocolError::IoError(e) => e,
            ProtocolError::InvalidMessage
            | ProtocolError::InvalidProtocol
            | ProtocolError::TooManyProtocols => io::ErrorKind::InvalidData.into(),
        }
    }
}
impl From<uvi::decode::Error> for ProtocolError {
    fn from(err: uvi::decode::Error) -> ProtocolError {
        Self::from(io::Error::new(io::ErrorKind::InvalidData, err.to_string()))
    }
}
impl Error for ProtocolError {
    /// Exposes the wrapped `io::Error` as the source of an `IoError`; the other
    /// variants have no underlying cause.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        if let ProtocolError::IoError(inner) = self {
            Some(inner)
        } else {
            None
        }
    }
}
impl fmt::Display for ProtocolError {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        match self {
            ProtocolError::IoError(e) =>
                write!(fmt, "I/O error: {}", e),
            ProtocolError::InvalidMessage =>
                write!(fmt, "Received an invalid message."),
            ProtocolError::InvalidProtocol =>
                write!(fmt, "A protocol (name) is invalid."),
            ProtocolError::TooManyProtocols =>
                write!(fmt, "Too many protocols received.")
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::*;
    use rand::Rng;
    use rand::distributions::Alphanumeric;
    use std::iter;

    impl Arbitrary for Protocol {
        /// Generates a `/`-prefixed protocol name with a non-empty alphanumeric body.
        fn arbitrary<G: Gen>(g: &mut G) -> Protocol {
            // `gen_range(low, high)` panics when `low >= high`, so keep the upper
            // bound at 2 or more even for tiny generator sizes.
            let max_len = g.size().max(2);
            let n = g.gen_range(1, max_len);
            let p: String = iter::repeat(())
                .map(|()| g.sample(Alphanumeric))
                .take(n)
                .collect();
            Protocol(Bytes::from(format!("/{}", p)))
        }
    }

    impl Arbitrary for Message {
        /// Generates any `Message` variant with equal probability.
        fn arbitrary<G: Gen>(g: &mut G) -> Message {
            match g.gen_range(0, 5) {
                0 => Message::Header(HeaderLine::V1),
                1 => Message::NotAvailable,
                2 => Message::ListProtocols,
                3 => Message::Protocol(Protocol::arbitrary(g)),
                4 => {
                    // `Message::decode` rejects lists longer than `MAX_PROTOCOLS`;
                    // cap generated lists so that large generator sizes do not
                    // produce spurious round-trip failures.
                    let mut ps: Vec<Protocol> = Vec::arbitrary(g);
                    ps.truncate(MAX_PROTOCOLS);
                    Message::Protocols(ps)
                }
                _ => unreachable!("gen_range(0, 5) only yields 0..=4"),
            }
        }
    }

    /// Every encodable message must decode back to itself.
    #[test]
    fn encode_decode_message() {
        fn prop(msg: Message) {
            let mut buf = BytesMut::new();
            // Build the failure message lazily, only if encoding actually fails.
            msg.encode(&mut buf)
                .unwrap_or_else(|e| panic!("Encoding message failed: {:?} ({:?})", msg, e));
            match Message::decode(buf.freeze()) {
                Ok(m) => assert_eq!(m, msg),
                Err(e) => panic!("Decoding failed: {:?}", e)
            }
        }
        quickcheck(prop as fn(_))
    }
}