use std::{fmt, io, num::NonZeroU16};
use ntex_util::future::Either;
use crate::v5::codec::DisconnectReasonCode;
/// Top-level error type, generic over the service error type `E`.
///
/// A value is either an error produced by a service or a [`HandshakeError`].
#[derive(Debug, thiserror::Error)]
pub enum MqttError<E> {
    /// Error value of type `E`; the inner value is not part of the `Display` output.
    #[error("Service error")]
    Service(E),
    /// Handshake failure. `#[from]` provides the `From<HandshakeError<E>>` conversion.
    #[error("Mqtt handshake error: {}", _0)]
    Handshake(#[from] HandshakeError<E>),
}
/// Errors reported for the MQTT handshake, generic over the handshake
/// service error type `E`.
#[derive(Debug, thiserror::Error)]
pub enum HandshakeError<E> {
    /// Error value of type `E`; the inner value is not part of the `Display` output.
    #[error("Handshake service error")]
    Service(E),
    /// Protocol-level failure. `#[from]` provides the `From<ProtocolError>` conversion.
    #[error("Mqtt protocol error: {}", _0)]
    Protocol(#[from] ProtocolError),
    /// The handshake timed out.
    #[error("Handshake timeout")]
    Timeout,
    /// The peer is disconnected; optionally carries an I/O error, which is
    /// rendered with `{:?}` in the message.
    #[error("Peer is disconnected, error: {:?}", _0)]
    Disconnected(Option<io::Error>),
}
/// Protocol-level errors: decoding, encoding, protocol violations and timeouts.
///
/// The type is `Copy` and comparable with `==`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, thiserror::Error)]
pub enum ProtocolError {
    /// Decoding failure; convertible from [`DecodeError`] via `#[from]`.
    #[error("Decoding error: {0:?}")]
    Decode(#[from] DecodeError),
    /// Encoding failure; convertible from [`EncodeError`] via `#[from]`.
    #[error("Encoding error: {0:?}")]
    Encode(#[from] EncodeError),
    /// Protocol violation; convertible from [`ProtocolViolationError`] via `#[from]`.
    #[error("Protocol violation: {0}")]
    ProtocolViolation(#[from] ProtocolViolationError),
    /// Keep Alive timeout.
    #[error("Keep Alive timeout")]
    KeepAliveTimeout,
    /// Timeout while reading a frame.
    #[error("Read frame timeout")]
    ReadTimeout,
}
/// Public wrapper around the private [`ViolationInner`] description of a
/// protocol violation.
///
/// `#[error(transparent)]` forwards the `Display` output to the inner value.
#[derive(Debug, Copy, Clone, PartialEq, Eq, thiserror::Error)]
#[error(transparent)]
pub struct ProtocolViolationError {
    // Private so values can only be built inside this module.
    inner: ViolationInner,
}
/// Private payload of [`ProtocolViolationError`].
#[derive(Debug, Copy, Clone, PartialEq, Eq, thiserror::Error)]
enum ViolationInner {
    /// Violation with an explicit disconnect reason code; only `message` is displayed.
    #[error("{message}")]
    Common { reason: DisconnectReasonCode, message: &'static str },
    /// Unexpected packet; `packet_type` is rendered in binary (`{:b}`) after `message`.
    #[error("{message}; received packet with type `{packet_type:b}`")]
    UnexpectedPacket { packet_type: u8, message: &'static str },
}
impl ProtocolViolationError {
    /// Returns the disconnect reason code associated with this violation.
    ///
    /// An unexpected-packet violation always maps to
    /// `DisconnectReasonCode::ProtocolError`; a common violation returns the
    /// reason code it was constructed with.
    pub(crate) fn reason(&self) -> DisconnectReasonCode {
        match &self.inner {
            ViolationInner::UnexpectedPacket { .. } => DisconnectReasonCode::ProtocolError,
            ViolationInner::Common { reason, .. } => *reason,
        }
    }
}
impl ProtocolError {
    pub(crate) fn violation(reason: DisconnectReasonCode, message: &'static str) -> Self {
        Self::ProtocolViolation(ProtocolViolationError {
            inner: ViolationInner::Common { reason, message },
        })
    }
    pub fn generic_violation(message: &'static str) -> Self {
        Self::violation(DisconnectReasonCode::ProtocolError, message)
    }
    pub(crate) fn unexpected_packet(packet_type: u8, message: &'static str) -> ProtocolError {
        Self::ProtocolViolation(ProtocolViolationError {
            inner: ViolationInner::UnexpectedPacket { packet_type, message },
        })
    }
    pub(crate) fn packet_id_mismatch() -> Self {
        Self::generic_violation(
            "Packet id of PUBACK packet does not match expected next value according to sending order of PUBLISH packets [MQTT-4.6.0-2]"
        )
    }
}
impl<E> From<io::Error> for MqttError<E> {
    /// Wraps an I/O error as `Handshake(Disconnected(Some(err)))`.
    fn from(err: io::Error) -> Self {
        let disconnected = HandshakeError::Disconnected(Some(err));
        Self::Handshake(disconnected)
    }
}
impl<E> From<Either<io::Error, io::Error>> for MqttError<E> {
    fn from(err: Either<io::Error, io::Error>) -> Self {
        MqttError::Handshake(HandshakeError::Disconnected(Some(err.into_inner())))
    }
}
impl<E> From<EncodeError> for MqttError<E> {
    /// Wraps an encoding error as `Handshake(Protocol(Encode(err)))`.
    fn from(err: EncodeError) -> Self {
        let protocol = ProtocolError::Encode(err);
        Self::Handshake(HandshakeError::Protocol(protocol))
    }
}
impl<E> From<Either<DecodeError, io::Error>> for HandshakeError<E> {
    /// Maps a decode error to `Protocol(Decode(..))` and an I/O error to
    /// `Disconnected(Some(..))`.
    fn from(err: Either<DecodeError, io::Error>) -> Self {
        match err {
            Either::Right(io_err) => Self::Disconnected(Some(io_err)),
            Either::Left(decode_err) => Self::Protocol(ProtocolError::Decode(decode_err)),
        }
    }
}
/// Errors reported while decoding.
///
/// All variants are field-less, so the type is `Copy`, comparable and hashable.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, thiserror::Error)]
pub enum DecodeError {
    /// Invalid protocol.
    #[error("Invalid protocol")]
    InvalidProtocol,
    /// Invalid length.
    #[error("Invalid length")]
    InvalidLength,
    /// Malformed packet.
    #[error("Malformed packet")]
    MalformedPacket,
    /// Unsupported protocol level.
    #[error("Unsupported protocol level")]
    UnsupportedProtocolLevel,
    /// The reserved flag of a Connect frame is set.
    #[error("Connect frame's reserved flag is set")]
    ConnectReservedFlagSet,
    /// The reserved flag of a ConnectAck frame is set.
    #[error("ConnectAck frame's reserved flag is set")]
    ConnAckReservedFlagSet,
    /// Invalid client id.
    #[error("Invalid client id")]
    InvalidClientId,
    /// Unsupported packet type.
    #[error("Unsupported packet type")]
    UnsupportedPacketType,
    /// A packet id is required.
    #[error("Packet id is required")]
    PacketIdRequired,
    /// Maximum size exceeded.
    #[error("Max size exceeded")]
    MaxSizeExceeded,
    /// UTF-8 error.
    #[error("utf8 error")]
    Utf8Error,
}
/// Errors reported while encoding.
///
/// All variants are field-less, so the type is `Copy`, comparable and hashable.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, thiserror::Error)]
pub enum EncodeError {
    /// The packet is bigger than the peer's Maximum Packet Size.
    #[error("Packet is bigger than peer's Maximum Packet Size")]
    OverMaxPacketSize,
    /// Invalid length.
    #[error("Invalid length")]
    InvalidLength,
    /// Malformed packet.
    #[error("Malformed packet")]
    MalformedPacket,
    /// A packet id is required.
    #[error("Packet id is required")]
    PacketIdRequired,
    /// Unsupported version.
    #[error("Unsupported version")]
    UnsupportedVersion,
}
/// Errors reported when sending a packet.
#[derive(Debug, PartialEq, Eq, Copy, Clone, thiserror::Error)]
pub enum SendPacketError {
    /// Encoding failure; convertible from [`EncodeError`] via `#[from]`.
    #[error("Encoding error {:?}", _0)]
    Encode(#[from] EncodeError),
    /// The provided packet id is in use. Carries a `NonZeroU16` (presumably
    /// the conflicting id — confirm against callers); it is not part of the
    /// `Display` output.
    #[error("Provided packet id is in use")]
    PacketIdInUse(NonZeroU16),
    /// The peer is disconnected.
    #[error("Peer is disconnected")]
    Disconnected,
}
/// Client-side errors, generic over the payload type `T` carried by the
/// `Ack` variant.
#[derive(Debug, thiserror::Error)]
pub enum ClientError<T: fmt::Debug> {
    /// Connect ack failed; the `T` payload is rendered with `{:?}`.
    #[error("Connect ack failed: {:?}", _0)]
    Ack(T),
    /// Protocol-level failure; convertible from [`ProtocolError`] via `#[from]`.
    #[error("Protocol error: {:?}", _0)]
    Protocol(#[from] ProtocolError),
    /// The handshake timed out.
    #[error("Handshake timeout")]
    HandshakeTimeout,
    /// The peer disconnected; optionally carries an I/O error, which is not
    /// part of the `Display` output.
    #[error("Peer disconnected")]
    Disconnected(Option<std::io::Error>),
    /// Connection failure; convertible from `ntex_net::connect::ConnectError`
    /// via `#[from]`.
    #[error("Connect error: {}", _0)]
    Connect(#[from] ntex_net::connect::ConnectError),
}
impl<T: fmt::Debug> From<EncodeError> for ClientError<T> {
    /// Wraps an encoding error as `Protocol(Encode(err))`.
    fn from(err: EncodeError) -> Self {
        let protocol = ProtocolError::Encode(err);
        Self::Protocol(protocol)
    }
}
impl<T: fmt::Debug> From<Either<DecodeError, std::io::Error>> for ClientError<T> {
    /// Maps a decode error to `Protocol(Decode(..))` and an I/O error to
    /// `Disconnected(Some(..))`.
    fn from(err: Either<DecodeError, std::io::Error>) -> Self {
        match err {
            Either::Right(io_err) => Self::Disconnected(Some(io_err)),
            Either::Left(decode_err) => Self::Protocol(ProtocolError::Decode(decode_err)),
        }
    }
}