use std::borrow::Borrow;
use std::fmt::{self, Display};
use std::marker::PhantomData;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use futures_util::{future::Future, stream::Stream};
use tracing::{debug, trace, warn};
use crate::error::ProtoError;
use crate::op::message::NoopMessageFinalizer;
use crate::op::{Message, MessageFinalizer, MessageVerifier};
use crate::udp::udp_stream::{NextRandomUdpSocket, UdpCreator, UdpSocket};
use crate::udp::{DnsUdpSocket, MAX_RECEIVE_BUFFER_SIZE};
use crate::xfer::{DnsRequest, DnsRequestSender, DnsResponse, DnsResponseStream, SerialMessage};
use crate::Time;
/// A DNS request sender that talks to a single name server over UDP.
///
/// Every request is sent from a newly created socket (obtained through `creator`), and the
/// whole exchange is bounded by `timeout`. Instances are produced by polling a
/// [`UdpClientConnect`].
#[must_use = "futures do nothing unless polled"]
pub struct UdpClientStream<S: Send, MF: MessageFinalizer = NoopMessageFinalizer> {
    /// Address every request is sent to.
    name_server: SocketAddr,
    /// Upper bound applied to each request/response exchange.
    timeout: Duration,
    /// Set by `shutdown`; sending after it is set panics.
    is_shutdown: bool,
    /// Optional finalizer applied to outgoing requests that it accepts.
    signer: Option<Arc<MF>>,
    /// Factory used to create the socket for each request.
    creator: UdpCreator<S>,
    marker: PhantomData<S>,
}
impl<S: UdpSocket + Send + 'static> UdpClientStream<S, NoopMessageFinalizer> {
    #[allow(clippy::new_ret_no_self)]
    pub fn new(name_server: SocketAddr) -> UdpClientConnect<S, NoopMessageFinalizer> {
        Self::with_timeout(name_server, Duration::from_secs(5))
    }
    pub fn with_timeout(
        name_server: SocketAddr,
        timeout: Duration,
    ) -> UdpClientConnect<S, NoopMessageFinalizer> {
        Self::with_bind_addr_and_timeout(name_server, None, timeout)
    }
    pub fn with_bind_addr_and_timeout(
        name_server: SocketAddr,
        bind_addr: Option<SocketAddr>,
        timeout: Duration,
    ) -> UdpClientConnect<S, NoopMessageFinalizer> {
        Self::with_timeout_and_signer_and_bind_addr(name_server, timeout, None, bind_addr)
    }
}
impl<S: UdpSocket + Send + 'static, MF: MessageFinalizer> UdpClientStream<S, MF> {
    pub fn with_timeout_and_signer(
        name_server: SocketAddr,
        timeout: Duration,
        signer: Option<Arc<MF>>,
    ) -> UdpClientConnect<S, MF> {
        UdpClientConnect {
            name_server,
            timeout,
            signer,
            creator: Arc::new(|local_addr: _, server_addr: _| {
                Box::pin(NextRandomUdpSocket::<S>::new(
                    &server_addr,
                    &Some(local_addr),
                ))
            }),
            marker: PhantomData::<S>,
        }
    }
    pub fn with_timeout_and_signer_and_bind_addr(
        name_server: SocketAddr,
        timeout: Duration,
        signer: Option<Arc<MF>>,
        bind_addr: Option<SocketAddr>,
    ) -> UdpClientConnect<S, MF> {
        UdpClientConnect {
            name_server,
            timeout,
            signer,
            creator: Arc::new(move |local_addr: _, server_addr: _| {
                Box::pin(NextRandomUdpSocket::<S>::new(
                    &server_addr,
                    &Some(bind_addr.unwrap_or(local_addr)),
                ))
            }),
            marker: PhantomData::<S>,
        }
    }
}
impl<S: DnsUdpSocket + Send, MF: MessageFinalizer> UdpClientStream<S, MF> {
    /// Connects to `name_server` using a caller-supplied socket `creator`.
    ///
    /// Unlike the other constructors, this one only requires `S: DnsUdpSocket`, not
    /// `UdpSocket`. Note the parameter order: `signer` comes before `timeout`.
    pub fn with_creator(
        name_server: SocketAddr,
        signer: Option<Arc<MF>>,
        timeout: Duration,
        creator: UdpCreator<S>,
    ) -> UdpClientConnect<S, MF> {
        let marker = PhantomData;
        UdpClientConnect {
            name_server,
            timeout,
            signer,
            creator,
            marker,
        }
    }
}
impl<S: Send, MF: MessageFinalizer> Display for UdpClientStream<S, MF> {
    /// Renders the stream as `UDP(<name server address>)`.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        formatter.write_fmt(format_args!("UDP({})", self.name_server))
    }
}
/// Draws a DNS query ID from the thread-local RNG using the `Standard` distribution.
fn random_query_id() -> u16 {
    use rand::distributions::{Distribution, Standard};
    Standard.sample(&mut rand::thread_rng())
}
impl<S: DnsUdpSocket + Send + 'static, MF: MessageFinalizer> DnsRequestSender
    for UdpClientStream<S, MF>
{
    /// Sends one DNS request over a freshly created UDP socket.
    ///
    /// The request gets a new random query ID. If a signer is configured and accepts the
    /// message, the message is finalized before serialization. Socket creation, the send and
    /// the wait for the reply are all bounded by `self.timeout`.
    ///
    /// # Panics
    ///
    /// Panics if called after `shutdown`.
    fn send_message(&mut self, mut message: DnsRequest) -> DnsResponseStream {
        if self.is_shutdown {
            panic!("can not send messages after stream is shutdown")
        }

        message.set_id(random_query_id());

        let unix_secs = match SystemTime::now().duration_since(UNIX_EPOCH) {
            Ok(since_epoch) => since_epoch.as_secs(),
            Err(_) => return ProtoError::from("Current time is before the Unix epoch.").into(),
        };
        // NOTE: truncating cast; u32 seconds since the epoch overflow in the year 2106.
        let now = unix_secs as u32;

        // Finalize (sign) only when a signer exists and wants this message; keep any
        // verifier it returns so the response can be checked.
        let verifier = match self.signer {
            Some(ref signer) if signer.should_finalize_message(&message) => {
                match message.finalize::<MF>(signer.borrow(), now) {
                    Ok(answer_verifier) => answer_verifier,
                    Err(e) => {
                        debug!("could not sign message: {}", e);
                        return e.into();
                    }
                }
            }
            _ => None,
        };

        // Receive buffer: the smaller of the request's max payload and the global cap.
        let recv_buf_size = MAX_RECEIVE_BUFFER_SIZE.min(message.max_payload() as usize);
        let wire_bytes = match message.to_vec() {
            Ok(wire_bytes) => wire_bytes,
            Err(err) => return err.into(),
        };
        let query_id = message.id();

        let serial = SerialMessage::new(wire_bytes, self.name_server);
        debug!(
            "final message: {}",
            serial
                .to_message()
                .expect("bizarre we just made this message")
        );

        let target = serial.addr();
        let socket_factory = self.creator.clone();
        S::Time::timeout::<Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send>>>(
            self.timeout,
            Box::pin(async move {
                let socket: S =
                    NextRandomUdpSocket::new_with_closure(&target, socket_factory).await?;
                send_serial_message_inner(serial, query_id, verifier, socket, recv_buf_size)
                    .await
            }),
        )
        .into()
    }

    /// Marks the stream as shut down; later `send_message` calls panic.
    fn shutdown(&mut self) {
        self.is_shutdown = true;
    }

    /// Returns `true` once `shutdown` has been called.
    fn is_shutdown(&self) -> bool {
        self.is_shutdown
    }
}
impl<S: Send, MF: MessageFinalizer> Stream for UdpClientStream<S, MF> {
    type Item = Result<(), ProtoError>;

    /// Never pending: yields `Ok(())` on every poll until the stream is shut down, after
    /// which it reports end-of-stream.
    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        Poll::Ready(if self.is_shutdown { None } else { Some(Ok(())) })
    }
}
/// A future that resolves, on its first poll, to a [`UdpClientStream`].
///
/// Returned by the `UdpClientStream` constructors. It only carries configuration; polling it
/// builds the stream without performing any I/O.
#[must_use = "futures do nothing unless polled"]
pub struct UdpClientConnect<S, MF = NoopMessageFinalizer>
where
    S: Send,
    MF: MessageFinalizer,
{
    /// Address the resulting stream sends requests to.
    name_server: SocketAddr,
    /// Per-request timeout handed to the resulting stream.
    timeout: Duration,
    /// Optional finalizer handed to the resulting stream (moved out on poll).
    signer: Option<Arc<MF>>,
    /// Socket factory handed to the resulting stream.
    creator: UdpCreator<S>,
    marker: PhantomData<S>,
}
impl<S: Send + Unpin, MF: MessageFinalizer> Future for UdpClientConnect<S, MF> {
    type Output = Result<UdpClientStream<S, MF>, ProtoError>;

    /// Completes immediately, moving the stored configuration into a new stream.
    ///
    /// The signer is moved out with `take()`, so polling again would produce a stream
    /// without a signer.
    fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        let stream = UdpClientStream::<S, MF> {
            name_server: this.name_server,
            timeout: this.timeout,
            signer: this.signer.take(),
            creator: this.creator.clone(),
            is_shutdown: false,
            marker: PhantomData,
        };
        Poll::Ready(Ok(stream))
    }
}
async fn send_serial_message_inner<S: DnsUdpSocket + Send>(
    msg: SerialMessage,
    msg_id: u16,
    verifier: Option<MessageVerifier>,
    socket: S,
    recv_buf_size: usize,
) -> Result<DnsResponse, ProtoError> {
    let bytes = msg.bytes();
    let addr = msg.addr();
    let len_sent: usize = socket.send_to(bytes, addr).await?;
    if bytes.len() != len_sent {
        return Err(ProtoError::from(format!(
            "Not all bytes of message sent, {} of {}",
            len_sent,
            bytes.len()
        )));
    }
    trace!("creating UDP receive buffer with size {recv_buf_size}");
    let mut recv_buf = vec![0; recv_buf_size];
    loop {
        let (len, src) = socket.recv_from(&mut recv_buf).await?;
        let buffer: Vec<_> = Vec::from(&recv_buf[0..len]);
        let request_target = msg.addr();
        if src != request_target {
            warn!(
                "ignoring response from {} because it does not match name_server: {}.",
                src, request_target,
            );
            continue;
        }
        match Message::from_vec(&buffer) {
            Ok(message) => {
                if msg_id == message.id() {
                    debug!("received message id: {}", message.id());
                    if let Some(mut verifier) = verifier {
                        return verifier(&buffer);
                    } else {
                        return Ok(DnsResponse::new(message, buffer));
                    }
                } else {
                    warn!(
                        "expected message id: {} got: {}, dropped",
                        msg_id,
                        message.id()
                    );
                    continue;
                }
            }
            Err(e) => {
                warn!(
                    "dropped malformed message waiting for id: {} err: {}",
                    msg_id, e
                );
                continue;
            }
        }
    }
}
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
mod tests {
    #![allow(clippy::dbg_macro, clippy::print_stdout)]

    #[cfg(not(target_os = "linux"))]
    use std::net::Ipv6Addr;
    use std::net::{IpAddr, Ipv4Addr};

    use tokio::{net::UdpSocket as TokioUdpSocket, runtime::Runtime};

    use crate::tests::udp_client_stream_test;

    /// Builds a fresh tokio runtime for a single test.
    fn new_runtime() -> Runtime {
        Runtime::new().expect("failed to create tokio runtime")
    }

    #[test]
    fn test_udp_client_stream_ipv4() {
        let loopback = IpAddr::V4(Ipv4Addr::LOCALHOST);
        udp_client_stream_test::<TokioUdpSocket, Runtime>(loopback, new_runtime())
    }

    // Not compiled on Linux targets.
    #[test]
    #[cfg(not(target_os = "linux"))]
    fn test_udp_client_stream_ipv6() {
        let loopback = IpAddr::V6(Ipv6Addr::LOCALHOST);
        udp_client_stream_test::<TokioUdpSocket, Runtime>(loopback, new_runtime())
    }
}