use std::fmt::{self, Display};
use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::task::{Context, Poll};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use futures_util::future::FutureExt;
use futures_util::stream::Stream;
use h3::client::{Connection, SendRequest};
use h3_quinn::OpenStreams;
use http::header::{self, CONTENT_LENGTH};
use quinn::{ClientConfig, Endpoint, EndpointConfig, TransportConfig};
use rustls::ClientConfig as TlsClientConfig;
use tracing::debug;
use crate::error::ProtoError;
use crate::http::Version;
use crate::op::Message;
use crate::quic::quic_socket::QuinnAsyncUdpSocketAdapter;
use crate::quic::QuicLocalAddr;
use crate::udp::{DnsUdpSocket, UdpSocket};
use crate::xfer::{DnsRequest, DnsRequestSender, DnsResponse, DnsResponseStream};
use super::ALPN_H3;
/// A DNS client connection over HTTP/3 (DNS-over-HTTPS carried on a QUIC connection).
///
/// Queries are issued through [`DnsRequestSender::send_message`]; each one runs on its own
/// HTTP/3 request stream. Polling this value as a [`Stream`] checks the connection via the
/// h3 driver's `poll_close`.
#[must_use = "futures do nothing unless polled"]
pub struct H3ClientStream {
    // Server name given to `crate::http::request::new` for every request and shown by `Display`;
    // it is the same `dns_name` passed to `Endpoint::connect` when the connection was built.
    name_server_name: Arc<str>,
    // Remote socket address of the name server.
    name_server: SocketAddr,
    // h3 connection handle, polled through `poll_close` in the `Stream` impl.
    driver: Connection<h3_quinn::Connection, Bytes>,
    // Request handle; a clone is moved into each in-flight `inner_send` future.
    send_request: SendRequest<OpenStreams, Bytes>,
    // Set by `DnsRequestSender::shutdown`; afterwards `send_message` panics and the stream ends.
    is_shutdown: bool,
}
/// Renders the connection as `H3(<socket address>,<server name>)`.
impl Display for H3ClientStream {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        let Self {
            name_server,
            name_server_name,
            ..
        } = self;
        write!(formatter, "H3({name_server},{name_server_name})")
    }
}
impl H3ClientStream {
    pub fn builder() -> H3ClientStreamBuilder {
        H3ClientStreamBuilder::default()
    }
    async fn inner_send(
        mut h3: SendRequest<OpenStreams, Bytes>,
        message: Bytes,
        name_server_name: Arc<str>,
    ) -> Result<DnsResponse, ProtoError> {
        let request =
            crate::http::request::new(Version::Http3, &name_server_name, message.remaining());
        let request =
            request.map_err(|err| ProtoError::from(format!("bad http request: {err}")))?;
        debug!("request: {:#?}", request);
        let mut stream = h3
            .send_request(request)
            .await
            .map_err(|err| ProtoError::from(format!("h3 send_request error: {err}")))?;
        stream
            .send_data(message)
            .await
            .map_err(|e| ProtoError::from(format!("h3 send_data error: {e}")))?;
        stream
            .finish()
            .await
            .map_err(|err| ProtoError::from(format!("received a stream error: {err}")))?;
        let response = stream
            .recv_response()
            .await
            .map_err(|err| ProtoError::from(format!("h3 recv_response error: {err}")))?;
        debug!("got response: {:#?}", response);
        let content_length = response
            .headers()
            .get(CONTENT_LENGTH)
            .map(|v| v.to_str())
            .transpose()
            .map_err(|e| ProtoError::from(format!("bad headers received: {e}")))?
            .map(usize::from_str)
            .transpose()
            .map_err(|e| ProtoError::from(format!("bad headers received: {e}")))?;
        let mut response_bytes =
            BytesMut::with_capacity(content_length.unwrap_or(512).clamp(512, 4096));
        while let Some(partial_bytes) = stream
            .recv_data()
            .await
            .map_err(|e| ProtoError::from(format!("h3 recv_data error: {e}")))?
        {
            debug!("got bytes: {}", partial_bytes.remaining());
            response_bytes.put(partial_bytes);
            if let Some(content_length) = content_length {
                if response_bytes.len() >= content_length {
                    break;
                }
            }
        }
        if let Some(content_length) = content_length {
            if response_bytes.len() != content_length {
                return Err(ProtoError::from(format!(
                    "expected byte length: {}, got: {}",
                    content_length,
                    response_bytes.len()
                )));
            }
        }
        if !response.status().is_success() {
            let error_string = String::from_utf8_lossy(response_bytes.as_ref());
            return Err(ProtoError::from(format!(
                "http unsuccessful code: {}, message: {}",
                response.status(),
                error_string
            )));
        } else {
            {
                let content_type = response
                    .headers()
                    .get(header::CONTENT_TYPE)
                    .map(|h| {
                        h.to_str().map_err(|err| {
                            ProtoError::from(format!("ContentType header not a string: {err}"))
                        })
                    })
                    .unwrap_or(Ok(crate::http::MIME_APPLICATION_DNS))?;
                if content_type != crate::http::MIME_APPLICATION_DNS {
                    return Err(ProtoError::from(format!(
                        "ContentType unsupported (must be '{}'): '{}'",
                        crate::http::MIME_APPLICATION_DNS,
                        content_type
                    )));
                }
            }
        };
        let message = Message::from_vec(&response_bytes)?;
        Ok(DnsResponse::new(message, response_bytes.to_vec()))
    }
}
impl DnsRequestSender for H3ClientStream {
    /// Serializes `message` with its ID forced to 0 and returns a response stream backed by
    /// a dedicated HTTP/3 request.
    ///
    /// A serialization failure is returned as an already-failed response stream.
    ///
    /// # Panics
    ///
    /// Panics if called after [`DnsRequestSender::shutdown`].
    fn send_message(&mut self, mut message: DnsRequest) -> DnsResponseStream {
        assert!(
            !self.is_shutdown,
            "can not send messages after stream is shutdown"
        );

        // RFC 8484 recommends a DNS ID of 0 for DoH requests to maximise HTTP cache hits.
        message.set_id(0);
        let serialized = match message.to_vec() {
            Ok(bytes) => Bytes::from(bytes),
            Err(err) => return err.into(),
        };

        let exchange = Self::inner_send(
            self.send_request.clone(),
            serialized,
            Arc::clone(&self.name_server_name),
        );
        Box::pin(exchange).into()
    }

    /// Marks the stream as shut down; later `send_message` calls panic.
    fn shutdown(&mut self) {
        self.is_shutdown = true;
    }

    /// Reports whether [`DnsRequestSender::shutdown`] has been called.
    fn is_shutdown(&self) -> bool {
        self.is_shutdown
    }
}
impl Stream for H3ClientStream {
    type Item = Result<(), ProtoError>;

    /// Polls the underlying HTTP/3 connection.
    ///
    /// Yields `None` after a local shutdown or when the connection closes cleanly, and
    /// `Some(Err(_))` when the h3 driver reports an error.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        if self.is_shutdown {
            return Poll::Ready(None);
        }

        let closed = match self.driver.poll_close(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(result) => result,
        };

        Poll::Ready(
            closed
                .err()
                .map(|e| Err(ProtoError::from(format!("h3 stream errored: {e}")))),
        )
    }
}
/// Builder for [`H3ClientStream`] connections.
///
/// Obtain one from [`H3ClientStream::builder`] (or `Default`), adjust it, then call `build`
/// or `build_with_future` to get an [`H3ClientConnect`] future.
#[derive(Clone)]
pub struct H3ClientStreamBuilder {
    // TLS settings for the QUIC handshake; if `alpn_protocols` is empty, `ALPN_H3` is
    // filled in at connect time.
    crypto_config: TlsClientConfig,
    // QUIC transport settings installed on the quinn `ClientConfig`.
    transport_config: Arc<TransportConfig>,
    // Optional local bind address; only consulted by `build` (via `connect`), not by
    // `build_with_future`.
    bind_addr: Option<SocketAddr>,
}
impl H3ClientStreamBuilder {
    /// Replaces the TLS client configuration used for the QUIC handshake.
    ///
    /// If its `alpn_protocols` list is left empty, `ALPN_H3` is supplied when connecting.
    pub fn crypto_config(&mut self, crypto_config: TlsClientConfig) -> &mut Self {
        self.crypto_config = crypto_config;
        self
    }

    /// Sets the local address that [`Self::build`] binds its UDP socket to.
    pub fn bind_addr(&mut self, bind_addr: SocketAddr) {
        self.bind_addr = Some(bind_addr);
    }

    /// Starts connecting to `name_server` over a tokio UDP socket.
    ///
    /// `dns_name` is passed to quinn's `Endpoint::connect` as the server name and is also
    /// used when building each HTTP request.
    pub fn build(self, name_server: SocketAddr, dns_name: String) -> H3ClientConnect {
        let connecting: Pin<Box<dyn Future<Output = Result<H3ClientStream, ProtoError>> + Send>> =
            Box::pin(self.connect(name_server, dns_name));
        H3ClientConnect(connecting)
    }

    /// Like [`Self::build`], but runs QUIC over the socket produced by `future`.
    ///
    /// The configured bind address is ignored on this path.
    pub fn build_with_future<S, F>(
        self,
        future: F,
        name_server: SocketAddr,
        dns_name: String,
    ) -> H3ClientConnect
    where
        S: DnsUdpSocket + QuicLocalAddr + 'static,
        F: Future<Output = std::io::Result<S>> + Send + Unpin + 'static,
    {
        let connecting: Pin<Box<dyn Future<Output = Result<H3ClientStream, ProtoError>> + Send>> =
            Box::pin(self.connect_with_future(future, name_server, dns_name));
        H3ClientConnect(connecting)
    }

    /// Awaits the caller-supplied socket, wraps it for quinn, and completes the connection.
    async fn connect_with_future<S, F>(
        self,
        future: F,
        name_server: SocketAddr,
        dns_name: String,
    ) -> Result<H3ClientStream, ProtoError>
    where
        S: DnsUdpSocket + QuicLocalAddr + 'static,
        F: Future<Output = std::io::Result<S>> + Send,
    {
        let adapter = QuinnAsyncUdpSocketAdapter { io: future.await? };
        let endpoint = Endpoint::new_with_abstract_socket(
            EndpointConfig::default(),
            None,
            adapter,
            Arc::new(quinn::TokioRuntime),
        )?;
        self.connect_inner(endpoint, name_server, dns_name).await
    }

    /// Opens a tokio UDP socket (bound to `bind_addr` when set) and completes the connection.
    async fn connect(
        self,
        name_server: SocketAddr,
        dns_name: String,
    ) -> Result<H3ClientStream, ProtoError> {
        let socket = match self.bind_addr {
            Some(bind_addr) => {
                <tokio::net::UdpSocket as UdpSocket>::connect_with_bind(name_server, bind_addr)
                    .await?
            }
            None => <tokio::net::UdpSocket as UdpSocket>::connect(name_server).await?,
        };
        let std_socket = socket.into_std()?;
        let endpoint = Endpoint::new(
            EndpointConfig::default(),
            None,
            std_socket,
            Arc::new(quinn::TokioRuntime),
        )?;
        self.connect_inner(endpoint, name_server, dns_name).await
    }

    /// Configures `endpoint` with the TLS and transport settings, performs the QUIC handshake
    /// (through 0-RTT when the TLS config enables early data), then sets up the h3 connection.
    async fn connect_inner(
        self,
        mut endpoint: Endpoint,
        name_server: SocketAddr,
        dns_name: String,
    ) -> Result<H3ClientStream, ProtoError> {
        let Self {
            crypto_config: mut tls,
            transport_config,
            ..
        } = self;
        if tls.alpn_protocols.is_empty() {
            tls.alpn_protocols.push(ALPN_H3.to_vec());
        }
        let use_0rtt = tls.enable_early_data;

        let mut quic_client = ClientConfig::new(Arc::new(tls));
        quic_client.transport_config(transport_config);
        endpoint.set_default_client_config(quic_client);

        let connecting = endpoint.connect(name_server, &dns_name)?;
        let quic_connection = if !use_0rtt {
            connecting.await?
        } else {
            // Fall back to a full handshake when 0-RTT cannot be attempted.
            match connecting.into_0rtt() {
                Ok((early_connection, _)) => early_connection,
                Err(pending) => pending.await?,
            }
        };

        let (driver, send_request) = h3::client::new(h3_quinn::Connection::new(quic_connection))
            .await
            .map_err(|e| ProtoError::from(format!("h3 connection failed: {e}")))?;

        Ok(H3ClientStream {
            name_server_name: dns_name.into(),
            name_server,
            driver,
            send_request,
            is_shutdown: false,
        })
    }
}
impl Default for H3ClientStreamBuilder {
    /// Starts from `super::client_config_tls13()` and `super::transport()`, with no bind
    /// address.
    ///
    /// # Panics
    ///
    /// Panics if `super::client_config_tls13()` returns an error.
    fn default() -> Self {
        let crypto_config = super::client_config_tls13().unwrap();
        let transport_config = Arc::new(super::transport());
        Self {
            bind_addr: None,
            transport_config,
            crypto_config,
        }
    }
}
pub struct H3ClientConnect(
    Pin<Box<dyn Future<Output = Result<H3ClientStream, ProtoError>> + Send>>,
);
impl Future for H3ClientConnect {
    type Output = Result<H3ClientStream, ProtoError>;
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        self.0.poll_unpin(cx)
    }
}
pub struct H3ClientResponse(Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send>>);
impl Future for H3ClientResponse {
    type Output = Result<DnsResponse, ProtoError>;
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        self.0.as_mut().poll(cx).map_err(ProtoError::from)
    }
}
#[cfg(all(test, any(feature = "native-certs", feature = "webpki-roots")))]
mod tests {
    //! Live-network tests that resolve `www.example.com.` through public DNS-over-HTTP/3
    //! resolvers.
    //!
    //! NOTE(review): the expected addresses are hard-coded and these tests will fail if the
    //! records for `www.example.com.` change; confirm they are still current.

    use std::net::SocketAddr;
    use std::str::FromStr;

    use rustls::KeyLogFile;
    use tokio::runtime::Runtime;

    use crate::op::{Message, Query, ResponseCode};
    use crate::rr::rdata::{A, AAAA};
    use crate::rr::{Name, RData, RecordType};
    use crate::xfer::{DnsRequestOptions, FirstAnswer};

    use super::*;

    /// Builds a request for `www.example.com.` of the given record type.
    fn example_request(record_type: RecordType) -> DnsRequest {
        let mut message = Message::new();
        let query = Query::query(Name::from_str("www.example.com.").unwrap(), record_type);
        message.add_query(query);
        DnsRequest::new(message, DnsRequestOptions::default())
    }

    /// Connects to `server` over HTTP/3, using `dns_name` as the server name.
    fn connect(runtime: &Runtime, server: SocketAddr, dns_name: String) -> H3ClientStream {
        let mut client_config = super::super::client_config_tls13().unwrap();
        // rustls' KeyLogFile writes TLS secrets when SSLKEYLOGFILE is set, for packet captures.
        client_config.key_log = Arc::new(KeyLogFile::new());
        let mut h3_builder = H3ClientStream::builder();
        h3_builder.crypto_config(client_config);
        let connecting = h3_builder.build(server, dns_name);
        runtime.block_on(connecting).expect("h3 connect failed")
    }

    /// Sends `request` on `h3` and blocks until the first response arrives.
    fn send(runtime: &Runtime, h3: &mut H3ClientStream, request: DnsRequest) -> DnsResponse {
        runtime
            .block_on(h3.send_message(request).first_answer())
            .expect("send_message failed")
    }

    /// Asserts that the first answer is the expected A record for `www.example.com.`.
    fn assert_example_a(response: &DnsResponse) {
        let addr = response.answers()[0]
            .data()
            .and_then(RData::as_a)
            .expect("invalid response, expected A record");
        assert_eq!(addr, &A::new(93, 184, 216, 34));
    }

    /// Asserts that the first answer is the expected AAAA record for `www.example.com.`.
    fn assert_example_aaaa(response: &DnsResponse) {
        let addr = response.answers()[0]
            .data()
            .and_then(RData::as_aaaa)
            .expect("invalid response, expected AAAA record");
        assert_eq!(
            addr,
            &AAAA::new(0x2606, 0x2800, 0x0220, 0x0001, 0x0248, 0x1893, 0x25c8, 0x1946)
        );
    }

    /// Queries Google's resolver at 8.8.8.8:443 using `dns_name` as the server name.
    ///
    /// Sends one A query, then three AAAA queries on the same connection; any AAAA response
    /// with `ServFail` is skipped.
    fn check_google(dns_name: String) {
        let google = SocketAddr::from(([8, 8, 8, 8], 443));
        let runtime = Runtime::new().expect("could not start runtime");
        let mut h3 = connect(&runtime, google, dns_name);

        let response = send(&runtime, &mut h3, example_request(RecordType::A));
        assert_example_a(&response);

        let request = example_request(RecordType::AAAA);
        for _ in 0..3 {
            let response = send(&runtime, &mut h3, request.clone());
            if response.response_code() == ResponseCode::ServFail {
                continue;
            }
            assert_example_aaaa(&response);
        }
    }

    #[test]
    fn test_h3_google() {
        check_google("dns.google".to_string());
    }

    #[test]
    fn test_h3_google_with_pure_ip_address_server() {
        // Use the server's bare IP address as the name instead of a DNS hostname.
        check_google("8.8.8.8".to_string());
    }

    #[test]
    #[ignore]
    fn test_h3_cloudflare() {
        let cloudflare = SocketAddr::from(([1, 1, 1, 1], 443));
        let runtime = Runtime::new().expect("could not start runtime");
        let mut h3 = connect(&runtime, cloudflare, "cloudflare-dns.com".to_string());

        let response = send(&runtime, &mut h3, example_request(RecordType::A));
        assert_example_a(&response);

        let response = send(&runtime, &mut h3, example_request(RecordType::AAAA));
        assert_example_aaaa(&response);
    }
}