use std::{
    borrow::Borrow,
    collections::{hash_map::Entry, HashMap},
    fmt::{self, Display},
    marker::Unpin,
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
    time::{Duration, SystemTime, UNIX_EPOCH},
};
use futures_channel::mpsc;
use futures_util::{
    future::Future,
    ready,
    stream::{Stream, StreamExt},
    FutureExt,
};
use rand::{
    self,
    distributions::{Distribution, Standard},
};
use tracing::debug;
use crate::{
    error::{ProtoError, ProtoErrorKind},
    op::{MessageFinalizer, MessageVerifier},
    xfer::{
        ignore_send, BufDnsStreamHandle, DnsClientStream, DnsRequest, DnsRequestSender,
        DnsResponse, DnsResponseStream, SerialMessage, CHANNEL_BUFFER_SIZE,
    },
    DnsStreamHandle, Time,
};
/// Upper bound on how many messages `DnsMultiplexer::poll_next` reads from the underlying
/// stream in one call before handing control back to the executor.
const QOS_MAX_RECEIVE_MSGS: usize = 100;

/// State kept for one request that has been sent and is waiting for responses.
struct ActiveRequest {
    /// Channel on which responses, or a terminal error, are delivered to the requester.
    completion: mpsc::Sender<Result<DnsResponse, ProtoError>>,
    /// DNS message id the request was sent with.
    request_id: u16,
    /// Deadline future; once it resolves the request is failed with a timeout error.
    timeout: Box<dyn Future<Output = ()> + Send + Unpin>,
    /// Verifier applied to raw response bytes, present when the request was finalized by a signer.
    verifier: Option<MessageVerifier>,
}
impl ActiveRequest {
    fn new(
        completion: mpsc::Sender<Result<DnsResponse, ProtoError>>,
        request_id: u16,
        timeout: Box<dyn Future<Output = ()> + Send + Unpin>,
        verifier: Option<MessageVerifier>,
    ) -> Self {
        Self {
            completion,
            request_id,
            timeout,
            verifier,
        }
    }
    fn poll_timeout(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        self.timeout.poll_unpin(cx)
    }
    fn is_canceled(&self) -> bool {
        self.completion.is_closed()
    }
    fn request_id(&self) -> u16 {
        self.request_id
    }
    fn complete_with_error(mut self, error: ProtoError) {
        ignore_send(self.completion.try_send(Err(error)));
    }
}
/// Multiplexes many concurrent DNS requests over a single underlying client stream.
///
/// Each request sent through `send_message` gets a random message id that is not already in
/// use. Responses read from `stream` are routed back to the matching requester by that id.
/// The multiplexer must itself be polled as a `Stream` for responses to be delivered and for
/// request timeouts to fire.
#[must_use = "futures do nothing unless polled"]
pub struct DnsMultiplexer<S, MF>
where
    S: DnsClientStream + 'static,
    MF: MessageFinalizer,
{
    /// Underlying connection; responses are read from it in `poll_next`.
    stream: S,
    /// How long each request may wait for responses before it is failed with a timeout.
    timeout_duration: Duration,
    /// Handle through which serialized requests are sent.
    stream_handle: BufDnsStreamHandle,
    /// In-flight requests, keyed by the DNS message id they were sent with.
    active_requests: HashMap<u16, ActiveRequest>,
    /// Optional finalizer used to sign outgoing messages it elects to sign.
    signer: Option<Arc<MF>>,
    /// Set by `shutdown` or when the underlying stream ends or fails; sending afterwards panics.
    is_shutdown: bool,
}
impl<S, MF> DnsMultiplexer<S, MF>
where
    S: DnsClientStream + Unpin + 'static,
    MF: MessageFinalizer,
{
    #[allow(clippy::new_ret_no_self)]
    pub fn new<F>(
        stream: F,
        stream_handle: BufDnsStreamHandle,
        signer: Option<Arc<MF>>,
    ) -> DnsMultiplexerConnect<F, S, MF>
    where
        F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,
    {
        Self::with_timeout(stream, stream_handle, Duration::from_secs(5), signer)
    }
    pub fn with_timeout<F>(
        stream: F,
        stream_handle: BufDnsStreamHandle,
        timeout_duration: Duration,
        signer: Option<Arc<MF>>,
    ) -> DnsMultiplexerConnect<F, S, MF>
    where
        F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,
    {
        DnsMultiplexerConnect {
            stream,
            stream_handle: Some(stream_handle),
            timeout_duration,
            signer,
        }
    }
    fn drop_cancelled(&mut self, cx: &mut Context<'_>) {
        let mut canceled = HashMap::<u16, ProtoError>::new();
        for (&id, ref mut active_req) in &mut self.active_requests {
            if active_req.is_canceled() {
                canceled.insert(id, ProtoError::from("requestor canceled"));
            }
            match active_req.poll_timeout(cx) {
                Poll::Ready(()) => {
                    debug!("request timed out: {}", id);
                    canceled.insert(id, ProtoError::from(ProtoErrorKind::Timeout));
                }
                Poll::Pending => (),
            }
        }
        for (id, error) in canceled {
            if let Some(active_request) = self.active_requests.remove(&id) {
                active_request.complete_with_error(error);
            }
        }
    }
    fn next_random_query_id(&self) -> Result<u16, ProtoError> {
        let mut rand = rand::thread_rng();
        for _ in 0..100 {
            let id: u16 = Standard.sample(&mut rand); if !self.active_requests.contains_key(&id) {
                return Ok(id);
            }
        }
        Err(ProtoError::from(
            "id space exhausted, consider filing an issue",
        ))
    }
    fn stream_closed_close_all(&mut self, error: ProtoError) {
        debug!(error = error.as_dyn(), stream = %self.stream);
        for (_, active_request) in self.active_requests.drain() {
            active_request.complete_with_error(error.clone());
        }
    }
}
/// Future returned by `DnsMultiplexer::new` / `DnsMultiplexer::with_timeout`.
///
/// It resolves to a ready `DnsMultiplexer` once the connection future `F` completes.
#[must_use = "futures do nothing unless polled"]
pub struct DnsMultiplexerConnect<F, S, MF>
where
    F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,
    S: Stream<Item = Result<SerialMessage, ProtoError>> + Unpin,
    MF: MessageFinalizer + Send + Sync + 'static,
{
    /// Future producing the connected client stream.
    stream: F,
    /// Outbound handle; moved into the multiplexer on completion, leaving `None`.
    stream_handle: Option<BufDnsStreamHandle>,
    /// Per-request timeout passed through to the multiplexer.
    timeout_duration: Duration,
    /// Optional message signer passed through to the multiplexer.
    signer: Option<Arc<MF>>,
}
impl<F, S, MF> Future for DnsMultiplexerConnect<F, S, MF>
where
    F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,
    S: DnsClientStream + Unpin + 'static,
    MF: MessageFinalizer + Send + Sync + 'static,
{
    type Output = Result<DnsMultiplexer<S, MF>, ProtoError>;

    /// Waits for the connection future, then assembles a multiplexer with no active requests.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has completed successfully.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let stream: S = match ready!(self.stream.poll_unpin(cx)) {
            Ok(connected) => connected,
            Err(e) => return Poll::Ready(Err(From::from(e))),
        };

        let stream_handle = self
            .stream_handle
            .take()
            .expect("must not poll after complete");

        let multiplexer = DnsMultiplexer {
            stream,
            timeout_duration: self.timeout_duration,
            stream_handle,
            active_requests: HashMap::new(),
            signer: self.signer.clone(),
            is_shutdown: false,
        };
        Poll::Ready(Ok(multiplexer))
    }
}
impl<S, MF> Display for DnsMultiplexer<S, MF>
where
    S: DnsClientStream + 'static,
    MF: MessageFinalizer + Send + Sync + 'static,
{
    /// Renders the multiplexer as the `Display` output of its underlying stream.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        let stream = &self.stream;
        formatter.write_fmt(format_args!("{}", stream))
    }
}
impl<S, MF> DnsRequestSender for DnsMultiplexer<S, MF>
where
    S: DnsClientStream + Unpin + 'static,
    MF: MessageFinalizer + Send + Sync + 'static,
{
    /// Sends `request` over the shared stream and returns a stream of its responses.
    ///
    /// The request is assigned a fresh random id and signed when the signer asks for it. It is
    /// then serialized, handed to the outbound stream handle, and registered so that incoming
    /// responses can be routed back to it.
    ///
    /// These failures are returned as an error-yielding response stream:
    /// - too many in-flight requests;
    /// - id exhaustion;
    /// - a clock set before the Unix epoch;
    /// - a signing, serialization or send error.
    ///
    /// # Panics
    ///
    /// Panics if called after `shutdown`.
    fn send_message(&mut self, request: DnsRequest) -> DnsResponseStream {
        if self.is_shutdown {
            panic!("can not send messages after stream is shutdown")
        }

        // Back-pressure: refuse new work once too many requests are outstanding.
        if self.active_requests.len() > CHANNEL_BUFFER_SIZE {
            return ProtoError::from(ProtoErrorKind::Busy).into();
        }

        let query_id = match self.next_random_query_id() {
            Ok(id) => id,
            Err(e) => return e.into(),
        };

        let (mut message, _) = request.into_parts();
        message.set_id(query_id);

        // Signing timestamp: whole seconds since the Unix epoch, truncated to 32 bits.
        let now = match SystemTime::now().duration_since(UNIX_EPOCH) {
            Ok(elapsed) => elapsed.as_secs() as u32,
            Err(_) => return ProtoError::from("Current time is before the Unix epoch.").into(),
        };

        let verifier = match self.signer {
            Some(ref signer) if signer.should_finalize_message(&message) => {
                match message.finalize::<MF>(signer.borrow(), now) {
                    Ok(answer_verifier) => answer_verifier,
                    Err(e) => {
                        debug!("could not sign message: {}", e);
                        return e.into();
                    }
                }
            }
            _ => None,
        };

        let timeout = S::Time::delay_for(self.timeout_duration);
        let (complete, receiver) = mpsc::channel(CHANNEL_BUFFER_SIZE);
        let request_id = message.id();
        let active_request = ActiveRequest::new(complete, request_id, Box::new(timeout), verifier);

        let buffer = match message.to_vec() {
            Ok(buffer) => buffer,
            Err(e) => {
                debug!(
                    id = %request_id,
                    error = e.as_dyn(),
                    "error message"
                );
                return e.into();
            }
        };

        debug!(id = %request_id, "sending message");
        let serial_message = SerialMessage::new(buffer, self.stream.name_server_addr());
        debug!(
            "final message: {}",
            serial_message
                .to_message()
                .expect("bizarre we just made this message")
        );

        if let Err(err) = self.stream_handle.send(serial_message) {
            return err.into();
        }
        // Register only after a successful hand-off, so responses can find this request.
        self.active_requests.insert(request_id, active_request);

        receiver.into()
    }

    /// Marks the multiplexer as shut down; it finishes once all in-flight requests complete.
    fn shutdown(&mut self) {
        self.is_shutdown = true;
    }

    /// Whether `shutdown` was called or the underlying stream has closed.
    fn is_shutdown(&self) -> bool {
        self.is_shutdown
    }
}
impl<S, MF> Stream for DnsMultiplexer<S, MF>
where
    S: DnsClientStream + Unpin + 'static,
    MF: MessageFinalizer + Send + Sync + 'static,
{
    type Item = Result<(), ProtoError>;
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.drop_cancelled(cx);
        if self.is_shutdown && self.active_requests.is_empty() {
            debug!("stream is done: {}", self);
            return Poll::Ready(None);
        }
        let mut messages_received = 0;
        for i in 0..QOS_MAX_RECEIVE_MSGS {
            match self.stream.poll_next_unpin(cx) {
                Poll::Ready(Some(Ok(buffer))) => {
                    messages_received = i;
                    match buffer.to_message() {
                        Ok(message) => match self.active_requests.entry(message.id()) {
                            Entry::Occupied(mut request_entry) => {
                                let active_request = request_entry.get_mut();
                                if let Some(ref mut verifier) = active_request.verifier {
                                    ignore_send(
                                        active_request
                                            .completion
                                            .try_send(verifier(buffer.bytes())),
                                    );
                                } else {
                                    ignore_send(active_request.completion.try_send(Ok(
                                        DnsResponse::new(message, buffer.into_parts().0),
                                    )));
                                }
                            }
                            Entry::Vacant(..) => debug!("unexpected request_id: {}", message.id()),
                        },
                        Err(error) => debug!(error = error.as_dyn(), "error decoding message"),
                    }
                }
                Poll::Ready(err) => {
                    let err = match err {
                        Some(Err(e)) => e,
                        None => ProtoError::from("stream closed"),
                        _ => unreachable!(),
                    };
                    self.stream_closed_close_all(err);
                    self.is_shutdown = true;
                    return Poll::Ready(None);
                }
                Poll::Pending => break,
            }
        }
        if messages_received == QOS_MAX_RECEIVE_MSGS {
            cx.waker().wake_by_ref();
        }
        Poll::Pending
    }
}
// Unit tests driving `DnsMultiplexer` against an in-memory mock client stream.
#[cfg(test)]
mod test {
    use super::*;
    use crate::op::message::NoopMessageFinalizer;
    use crate::op::op_code::OpCode;
    use crate::op::{Message, MessageType, Query};
    use crate::rr::record_type::RecordType;
    use crate::rr::{DNSClass, Name, RData, Record};
    use crate::serialize::binary::BinEncodable;
    use crate::xfer::StreamReceiver;
    use crate::xfer::{DnsClientStream, DnsRequestOptions};
    use futures_util::future;
    use futures_util::stream::TryStreamExt;
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
    /// Test double for a `DnsClientStream` that replays canned response messages.
    struct MockClientStream {
        // Canned responses, stored reversed so `pop()` yields them in original order.
        messages: Vec<Message>,
        // Address reported as the name server and attached to each emitted message.
        addr: SocketAddr,
        // Id of the first request seen on `receiver`; stamped onto every canned response.
        id: Option<u16>,
        // Outbound request channel; must be set before the stream is first polled.
        receiver: Option<StreamReceiver>,
    }
    impl MockClientStream {
        /// Returns an already-resolved connect future producing a mock with `messages` queued.
        fn new(
            mut messages: Vec<Message>,
            addr: SocketAddr,
        ) -> Pin<Box<dyn Future<Output = Result<Self, ProtoError>> + Send>> {
            // Reverse so that `pop()` in `poll_next` hands messages out in their original order.
            messages.reverse(); Box::pin(future::ok(Self {
                messages,
                addr,
                id: None,
                receiver: None,
            }))
        }
    }
    impl fmt::Display for MockClientStream {
        fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
            write!(formatter, "TestClientStream")
        }
    }
    impl Stream for MockClientStream {
        type Item = Result<SerialMessage, ProtoError>;
        /// On the first poll, waits for a request on `receiver` and records its id. After that,
        /// each poll returns the next canned message re-stamped with that id.
        ///
        /// Once the canned messages run out it returns `Pending` without registering a waker.
        fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            let id = if let Some(id) = self.id {
                id
            } else {
                // Panics (unwrap) if the request channel is closed or the request fails to decode.
                let serial = ready!(self
                    .receiver
                    .as_mut()
                    .expect("should only be polled after receiver has been set")
                    .poll_next_unpin(cx));
                let message = serial.unwrap().to_message().unwrap();
                self.id = Some(message.id());
                message.id()
            };
            if let Some(mut message) = self.messages.pop() {
                message.set_id(id);
                Poll::Ready(Some(Ok(SerialMessage::new(
                    message.to_bytes().unwrap(),
                    self.addr,
                ))))
            } else {
                Poll::Pending
            }
        }
    }
    impl DnsClientStream for MockClientStream {
        type Time = crate::TokioTime;
        fn name_server_addr(&self) -> SocketAddr {
            self.addr
        }
    }
    /// Builds a multiplexer over a `MockClientStream` preloaded with `mock_response`.
    ///
    /// The multiplexer uses a 100ms request timeout and no signer.
    async fn get_mocked_multiplexer(
        mock_response: Vec<Message>,
    ) -> DnsMultiplexer<MockClientStream, NoopMessageFinalizer> {
        let addr = SocketAddr::from(([127, 0, 0, 1], 1234));
        let mock_response = MockClientStream::new(mock_response, addr);
        // `receiver` is assumed to be the receiving half paired with `handler` — confirm in
        // BufDnsStreamHandle::new.
        let (handler, receiver) = BufDnsStreamHandle::new(addr);
        let mut multiplexer =
            DnsMultiplexer::with_timeout(mock_response, handler, Duration::from_millis(100), None)
                .await
                .unwrap();
        // Give the mock access to outbound requests so it can learn the request id to echo back.
        multiplexer.stream.receiver = Some(receiver); multiplexer
    }
    /// An A query for www.example.com, plus one response answering it with 93.184.216.34.
    fn a_query_answer() -> (DnsRequest, Vec<Message>) {
        let name = Name::from_ascii("www.example.com").unwrap();
        let mut msg = Message::new();
        msg.add_query({
            let mut query = Query::query(name.clone(), RecordType::A);
            query.set_query_class(DNSClass::IN);
            query
        })
        .set_message_type(MessageType::Query)
        .set_op_code(OpCode::Query)
        .set_recursion_desired(true);
        // Snapshot the query before `msg` is turned into the response.
        let query = msg.clone();
        msg.set_message_type(MessageType::Response).add_answer(
            Record::new()
                .set_name(name)
                .set_ttl(86400)
                .set_rr_type(RecordType::A)
                .set_dns_class(DNSClass::IN)
                .set_data(Some(RData::A(Ipv4Addr::new(93, 184, 216, 34).into())))
                .clone(),
        );
        (
            DnsRequest::new(query, DnsRequestOptions::default()),
            vec![msg],
        )
    }
    /// An AXFR query for example.com.
    fn axfr_query() -> Message {
        let name = Name::from_ascii("example.com").unwrap();
        let mut msg = Message::new();
        msg.add_query({
            let mut query = Query::query(name, RecordType::AXFR);
            query.set_query_class(DNSClass::IN);
            query
        })
        .set_message_type(MessageType::Query)
        .set_op_code(OpCode::Query)
        .set_recursion_desired(true);
        msg
    }
    /// Six records for example.com: SOA, two NS, A, AAAA, then the SOA again.
    ///
    /// The SOA at both ends brackets the transfer.
    fn axfr_response() -> Vec<Record> {
        use crate::rr::rdata::*;
        let origin = Name::from_ascii("example.com").unwrap();
        let soa = Record::new()
            .set_name(origin.clone())
            .set_ttl(3600)
            .set_rr_type(RecordType::SOA)
            .set_dns_class(DNSClass::IN)
            .set_data(Some(RData::SOA(SOA::new(
                Name::parse("sns.dns.icann.org.", None).unwrap(),
                Name::parse("noc.dns.icann.org.", None).unwrap(),
                2015082403,
                7200,
                3600,
                1209600,
                3600,
            ))))
            .clone();
        vec![
            soa.clone(),
            Record::new()
                .set_name(origin.clone())
                .set_ttl(86400)
                .set_rr_type(RecordType::NS)
                .set_dns_class(DNSClass::IN)
                .set_data(Some(RData::NS(NS(Name::parse(
                    "a.iana-servers.net.",
                    None,
                )
                .unwrap()))))
                .clone(),
            Record::new()
                .set_name(origin.clone())
                .set_ttl(86400)
                .set_rr_type(RecordType::NS)
                .set_dns_class(DNSClass::IN)
                .set_data(Some(RData::NS(NS(Name::parse(
                    "b.iana-servers.net.",
                    None,
                )
                .unwrap()))))
                .clone(),
            Record::new()
                .set_name(origin.clone())
                .set_ttl(86400)
                .set_rr_type(RecordType::A)
                .set_dns_class(DNSClass::IN)
                .set_data(Some(RData::A(Ipv4Addr::new(93, 184, 216, 34).into())))
                .clone(),
            Record::new()
                .set_name(origin)
                .set_ttl(86400)
                .set_rr_type(RecordType::AAAA)
                .set_dns_class(DNSClass::IN)
                .set_data(Some(RData::AAAA(
                    Ipv6Addr::new(0x2606, 0x2800, 0x220, 0x1, 0x248, 0x1893, 0x25c8, 0x1946).into(),
                )))
                .clone(),
            soa,
        ]
    }
    /// AXFR request whose whole transfer arrives in a single response message.
    fn axfr_query_answer() -> (DnsRequest, Vec<Message>) {
        let mut msg = axfr_query();
        let query = msg.clone();
        msg.set_message_type(MessageType::Response)
            .insert_answers(axfr_response());
        (
            DnsRequest::new(query, DnsRequestOptions::default()),
            vec![msg],
        )
    }
    /// AXFR request whose transfer is split across two responses.
    ///
    /// The first response carries the first 3 records and the second carries the rest.
    fn axfr_query_answer_multi() -> (DnsRequest, Vec<Message>) {
        let base = axfr_query();
        let query = base.clone();
        let mut rr = axfr_response();
        let rr2 = rr.split_off(3);
        let mut msg1 = base.clone();
        msg1.set_message_type(MessageType::Response)
            .insert_answers(rr);
        let mut msg2 = base;
        msg2.set_message_type(MessageType::Response)
            .insert_answers(rr2);
        (
            DnsRequest::new(query, DnsRequestOptions::default()),
            vec![msg1, msg2],
        )
    }
    // Each test polls the multiplexer concurrently with collecting the response stream. The
    // multiplexer must not finish first. The collection presumably ends once the 100ms request
    // timeout expires — verify against DnsResponseStream's handling of timeout errors.
    #[tokio::test]
    async fn test_multiplexer_a() {
        let (query, answer) = a_query_answer();
        let mut multiplexer = get_mocked_multiplexer(answer).await;
        let response = multiplexer.send_message(query);
        let response = tokio::select! {
            _ = multiplexer.next() => {
                panic!("should never end")
            },
            r = response.try_collect::<Vec<_>>() => r.unwrap(),
        };
        assert_eq!(response.len(), 1);
    }
    #[tokio::test]
    async fn test_multiplexer_axfr() {
        let (query, answer) = axfr_query_answer();
        let mut multiplexer = get_mocked_multiplexer(answer).await;
        let response = multiplexer.send_message(query);
        let response = tokio::select! {
            _ = multiplexer.next() => {
                panic!("should never end")
            },
            r = response.try_collect::<Vec<_>>() => r.unwrap(),
        };
        assert_eq!(response.len(), 1);
        assert_eq!(response[0].answers().len(), axfr_response().len());
    }
    #[tokio::test]
    async fn test_multiplexer_axfr_multi() {
        let (query, answer) = axfr_query_answer_multi();
        let mut multiplexer = get_mocked_multiplexer(answer).await;
        let response = multiplexer.send_message(query);
        let response = tokio::select! {
            _ = multiplexer.next() => {
                panic!("should never end")
            },
            r = response.try_collect::<Vec<_>>() => r.unwrap(),
        };
        // Both partial responses are delivered, and together they carry every record.
        assert_eq!(response.len(), 2);
        assert_eq!(
            response.iter().map(|m| m.answers().len()).sum::<usize>(),
            axfr_response().len()
        );
    }
}