use crate::proto::Protocol;
use crate::proto::{ALPN_H1, ALPN_H2};
use crate::Error;
use crate::Stream;
use crate::{AsyncRead, AsyncWrite};
use futures_util::future::poll_fn;
use futures_util::ready;
use rustls::Session;
use rustls::{ClientConfig, ClientSession};
use std::io;
use std::io::Read;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use webpki::DNSNameRef;
use webpki_roots::TLS_SERVER_ROOTS;
/// Wraps `stream` in a client-side TLS session for `domain` and drives the
/// handshake until the session is no longer handshaking.
///
/// The config trusts `TLS_SERVER_ROOTS` and advertises `ALPN_H2` followed by
/// `ALPN_H1`. When `tls_disable_verify` is set, certificate verification is
/// replaced by `DisabledCertVerified`, which accepts any certificate.
///
/// Returns the encrypted stream together with the protocol derived from the
/// ALPN value negotiated during the handshake.
///
/// Errors if `domain` is not a valid DNS name or the handshake fails.
pub(crate) async fn wrap_tls_client(
    stream: impl Stream,
    domain: &str,
    tls_disable_verify: bool,
) -> Result<(impl Stream, Protocol), Error> {
    let mut cfg = ClientConfig::new();
    cfg.root_store.add_server_trust_anchors(&TLS_SERVER_ROOTS);
    if tls_disable_verify {
        cfg.dangerous()
            .set_certificate_verifier(Arc::new(DisabledCertVerified));
    }
    cfg.alpn_protocols = vec![ALPN_H2.to_owned(), ALPN_H1.to_owned()];
    let cfg = Arc::new(cfg);

    let server_name = DNSNameRef::try_from_ascii_str(domain)?;
    let session = ClientSession::new(&cfg, server_name);
    let mut tls = TlsStream::new(stream, session);

    let handshake = poll_fn(|cx| Pin::new(&mut tls).poll_handshake(cx)).await;
    trace!("tls handshake: {:?}", handshake);
    handshake?;

    let negotiated = tls.tls.get_alpn_protocol();
    let proto = Protocol::from_alpn(negotiated);
    Ok((tls, proto))
}
/// Server certificate verifier that accepts every certificate.
///
/// Installed by `wrap_tls_client` when `tls_disable_verify` is requested.
struct DisabledCertVerified;

impl rustls::ServerCertVerifier for DisabledCertVerified {
    /// Logs a warning naming the server and unconditionally reports the
    /// certificate chain as verified.
    fn verify_server_cert(
        &self,
        _roots: &rustls::RootCertStore,
        _presented_certs: &[rustls::Certificate],
        dns_name: DNSNameRef,
        _ocsp_response: &[u8],
    ) -> Result<rustls::ServerCertVerified, rustls::TLSError> {
        warn!("Ignoring TLS verification for {:?}", dns_name);
        let verified = rustls::ServerCertVerified::assertion();
        Ok(verified)
    }
}
#[cfg(feature = "server")]
use rustls::ServerConfig;
/// Sets the server's ALPN list to `ALPN_H2` followed by `ALPN_H1`. This is the
/// same list the client side advertises.
#[cfg(feature = "server")]
pub(crate) fn configure_tls_server(config: &mut ServerConfig) {
    let protocols = vec![ALPN_H2.to_owned(), ALPN_H1.to_owned()];
    config.alpn_protocols = protocols;
}
/// Wraps an accepted `stream` in a server-side TLS session built from `config`.
///
/// Drives the handshake until the session is no longer handshaking. Returns
/// the encrypted stream together with the protocol derived from the
/// negotiated ALPN value.
#[cfg(feature = "server")]
pub(crate) async fn wrap_tls_server(
    stream: impl Stream,
    config: Arc<ServerConfig>,
) -> Result<(impl Stream, Protocol), Error> {
    use rustls::ServerSession;

    let session = ServerSession::new(&config);
    let mut tls = TlsStream::new(stream, session);

    let handshake = poll_fn(|cx| Pin::new(&mut tls).poll_handshake(cx)).await;
    trace!("tls handshake: {:?}", handshake);
    handshake?;

    let negotiated = tls.tls.get_alpn_protocol();
    let proto = Protocol::from_alpn(negotiated);
    Ok((tls, proto))
}
/// Async stream adapter running a TLS session `E` over a transport `S`.
///
/// Ciphertext travels between `stream` and `tls` via `read_buf` and
/// `write_buf`. Decrypted application data waits in `plaintext` until a
/// reader takes it.
struct TlsStream<S, E> {
    /// Underlying transport carrying ciphertext.
    stream: S,
    /// TLS session state.
    tls: E,
    /// Ciphertext read from `stream` that has not yet been fed to `tls`.
    read_buf: Vec<u8>,
    /// Ciphertext produced by `tls` that has not yet been written to `stream`.
    write_buf: Vec<u8>,
    /// Set when the session flushed its output; `stream` must then be flushed.
    wants_flush: bool,
    /// Decrypted application data.
    plaintext: Vec<u8>,
    /// Index of the first byte in `plaintext` not yet returned to a reader.
    plaintext_idx: usize,
}
impl<S: Stream, E: Session + Unpin + 'static> TlsStream<S, E> {
    /// Creates a TLS stream running the session `tls` over `stream`, with every
    /// internal buffer empty.
    pub fn new(stream: S, tls: E) -> Self {
        TlsStream {
            stream,
            tls,
            read_buf: Vec::new(),
            write_buf: Vec::new(),
            wants_flush: false,
            plaintext: Vec::new(),
            plaintext_idx: 0,
        }
    }

    /// Number of decrypted bytes not yet handed out to a reader.
    fn plaintext_left(&self) -> usize {
        self.plaintext.len() - self.plaintext_idx
    }

    /// Drives the TLS session as far as it can go without waiting.
    ///
    /// Each pass of the loop does the following:
    /// - writes out buffered ciphertext;
    /// - flushes the underlying stream if the session asked for it;
    /// - reads more ciphertext when it is needed;
    /// - feeds that ciphertext to the session;
    /// - collects any outgoing ciphertext the session produces.
    ///
    /// The loop repeats as long as the session did any work.
    ///
    /// With `poll_for_read` set, `Ready(Ok(()))` means one of two things:
    /// decrypted plaintext is available, or the underlying stream reached EOF.
    /// In the EOF case the plaintext buffer is left empty, so a reader sees
    /// end-of-stream. Without `poll_for_read`, `Ready(Ok(()))` means no further
    /// TLS work is possible right now.
    ///
    /// Errors:
    /// - I/O errors from the underlying stream;
    /// - errors from `process_new_packets`, mapped to `InvalidData`;
    /// - `UnexpectedEof` if the stream ends while the handshake is still in
    ///   progress.
    // Kept from the original; presumably silences the lint on the
    // `did_tls_read_or_write` flag pattern.
    #[allow(clippy::useless_let_if_seq)]
    fn poll_tls(&mut self, cx: &mut Context, poll_for_read: bool) -> Poll<io::Result<()>> {
        // Set once the underlying stream reports EOF, so it is not read again
        // during this call.
        let mut eof = false;
        loop {
            // Send ciphertext left over from earlier passes before anything else.
            ready!(self.try_write_buf(cx))?;
            if self.wants_flush {
                ready!(Pin::new(&mut self.stream).poll_flush(cx))?;
                self.wants_flush = false;
            }
            // Only read more ciphertext when it is needed: either the handshake
            // is in progress, or a reader is waiting and no plaintext is buffered.
            if !eof
                && self.read_buf.is_empty()
                && (poll_for_read && self.plaintext_left() == 0 || self.tls.is_handshaking())
            {
                match self.try_read_buf(cx) {
                    Poll::Ready(Ok(0)) => eof = true,
                    Poll::Ready(Ok(_)) => {}
                    // Previously ignored. That left the task pending with no
                    // waker registered, so it hung forever.
                    Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                    // The underlying stream registered `cx`'s waker. Carry on so
                    // that pending TLS output is still produced and sent.
                    Poll::Pending => {}
                }
            }
            let mut did_tls_read_or_write = false;
            if self.tls.wants_read() && !self.read_buf.is_empty() {
                let mut sync = SyncStream::new(
                    &mut self.read_buf,
                    &mut self.write_buf,
                    &mut self.wants_flush,
                );
                let _ = ready!(blocking_to_poll(self.tls.read_tls(&mut sync), cx))?;
                self.tls
                    .process_new_packets()
                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
                // Application data is only drained once the handshake is done.
                if !self.tls.is_handshaking() {
                    let _ = self.tls.read_to_end(&mut self.plaintext)?;
                }
                did_tls_read_or_write = true;
            }
            if self.tls.wants_write() {
                let mut sync = SyncStream::new(
                    &mut self.read_buf,
                    &mut self.write_buf,
                    &mut self.wants_flush,
                );
                let _ = ready!(blocking_to_poll(self.tls.write_tls(&mut sync), cx))?;
                did_tls_read_or_write = true;
            }
            if did_tls_read_or_write {
                // Progress was made. Go round again to send or consume the results.
                continue;
            }
            if eof && self.tls.is_handshaking() {
                // Without this check, poll_handshake would return Pending forever.
                return Poll::Ready(Err(io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    "connection closed during tls handshake",
                )));
            }
            if poll_for_read && self.plaintext_left() == 0 && !eof {
                // Nothing to hand out yet. Normally the read above returned
                // Pending and registered the waker.
                return Poll::Pending;
            }
            return Poll::Ready(Ok(()));
        }
    }

    /// Writes buffered ciphertext to the underlying stream until `write_buf` is
    /// empty.
    ///
    /// Returns `Pending` if the stream cannot accept more bytes yet. Returns a
    /// `WriteZero` error if the stream accepts 0 bytes of a non-empty buffer;
    /// without that check the loop would spin.
    fn try_write_buf(&mut self, cx: &mut Context) -> Poll<Result<(), io::Error>> {
        while !self.write_buf.is_empty() {
            let amount =
                ready!(Pin::new(&mut self.stream).poll_write(cx, &self.write_buf[..]))?;
            if amount == 0 {
                return Poll::Ready(Err(io::Error::new(
                    io::ErrorKind::WriteZero,
                    "failed to write buffered tls data",
                )));
            }
            self.write_buf.drain(..amount);
        }
        Poll::Ready(Ok(()))
    }

    /// Reads at most 8 KiB of ciphertext from the underlying stream and appends
    /// it to `read_buf`.
    ///
    /// Resolves to the number of bytes read; `0` signals EOF.
    fn try_read_buf(&mut self, cx: &mut Context) -> Poll<Result<usize, io::Error>> {
        let mut tmp = [0; 8_192];
        let amount = ready!(Pin::new(&mut self.stream).poll_read(cx, &mut tmp[..]))?;
        self.read_buf.extend_from_slice(&tmp[..amount]);
        Poll::Ready(Ok(amount))
    }

    /// Drives the TLS session until it is no longer handshaking.
    ///
    /// Returns `Pending` while the handshake still needs I/O that cannot
    /// complete yet.
    fn poll_handshake(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        ready!(this.poll_tls(cx, false))?;
        if this.tls.is_handshaking() {
            Poll::Pending
        } else {
            Ok(()).into()
        }
    }
}
impl<S: Stream, E: Session + Unpin + 'static> AsyncRead for TlsStream<S, E> {
    /// Copies as many buffered plaintext bytes as fit into `buf`.
    ///
    /// The TLS engine (`poll_tls` in read mode) is only driven when every
    /// previously decrypted byte has already been handed out.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let me = self.get_mut();
        if me.plaintext_left() == 0 {
            ready!(me.poll_tls(cx, true))?;
        }
        let pending = &me.plaintext[me.plaintext_idx..];
        let copied = pending.len().min(buf.len());
        buf[..copied].copy_from_slice(&pending[..copied]);
        me.plaintext_idx += copied;
        // Once everything has been handed out, reset so the buffer is reused.
        if me.plaintext_idx == me.plaintext.len() {
            me.plaintext.clear();
            me.plaintext_idx = 0;
        }
        Poll::Ready(Ok(copied))
    }
}
impl<S: Stream, E: Session + Unpin + 'static> AsyncWrite for TlsStream<S, E> {
    /// Pushes out pending TLS traffic, then hands `buf` to the TLS session.
    ///
    /// This call does not drive the session afterwards. Whatever the session
    /// produces from `buf` is written to the underlying stream by a later call
    /// that runs `poll_tls` (another write, a flush, a read, or a close).
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        let this = self.get_mut();
        ready!(this.poll_tls(cx, false))?;
        let amount = this.tls.write(buf)?;
        Ok(amount).into()
    }

    /// Vectored form of `poll_write`, with the same buffering behaviour.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context,
        bufs: &[io::IoSlice],
    ) -> Poll<Result<usize, io::Error>> {
        let this = self.get_mut();
        ready!(this.poll_tls(cx, false))?;
        let amount = this.tls.write_vectored(bufs)?;
        Ok(amount).into()
    }

    /// Flushes the TLS session, drives its output to the underlying stream,
    /// and finally flushes the underlying stream itself.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
        let this = self.get_mut();
        ready!(this.poll_tls(cx, false))?;
        this.tls.flush()?;
        ready!(this.poll_tls(cx, false))?;
        // poll_tls only flushes the underlying stream when `wants_flush` was
        // set by the session. Forward the flush unconditionally so that bytes
        // buffered inside the underlying stream are pushed out too.
        Pin::new(&mut this.stream).poll_flush(cx)
    }

    /// Sends a TLS close_notify (via `send_close_notify`) and drives it out.
    /// Then closes the underlying stream.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
        let this = self.get_mut();
        ready!(this.poll_tls(cx, false))?;
        this.tls.send_close_notify();
        ready!(this.poll_tls(cx, false))?;
        Pin::new(&mut this.stream).poll_close(cx)
    }
}
/// Synchronous `io::Read` / `io::Write` view over a `TlsStream`'s buffers.
///
/// It lets the TLS session read and write ciphertext without touching the
/// async transport.
struct SyncStream<'a> {
    /// Ciphertext waiting to be consumed by the TLS session.
    read_buf: &'a mut Vec<u8>,
    /// Ciphertext produced by the TLS session.
    write_buf: &'a mut Vec<u8>,
    /// Raised when the TLS session calls `flush`.
    wants_flush: &'a mut bool,
}

impl<'a> SyncStream<'a> {
    /// Borrows the three buffers for the duration of one TLS read/write call.
    fn new(
        read_buf: &'a mut Vec<u8>,
        write_buf: &'a mut Vec<u8>,
        wants_flush: &'a mut bool,
    ) -> Self {
        Self {
            read_buf,
            write_buf,
            wants_flush,
        }
    }
}
impl<'a> io::Read for SyncStream<'a> {
    /// Moves up to `buf.len()` bytes from the front of the shared read buffer
    /// into `buf`.
    ///
    /// Returns a `WouldBlock` error when nothing is buffered, because more
    /// ciphertext has to arrive from the async transport first.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if self.read_buf.is_empty() {
            return would_block();
        }
        let amt = self.read_buf.as_slice().read(buf)?;
        // Drop the consumed prefix in place instead of allocating a new Vec for
        // the remainder on every read.
        self.read_buf.drain(..amt);
        Ok(amt)
    }
}
impl<'a> io::Write for SyncStream<'a> {
    /// Appends all of `buf` to the shared write buffer; never fails or blocks.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.write_buf.extend_from_slice(buf);
        Ok(buf.len())
    }

    /// Appends every slice in `bufs` in one call.
    ///
    /// The default `write_vectored` only consumes the first non-empty slice.
    /// That would force one extra round trip through the caller for each
    /// remaining chunk. (The TLS session's `write_tls` presumably writes
    /// vectored; this is an assumption, not shown in this file.)
    fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> {
        let total: usize = bufs.iter().map(|b| b.len()).sum();
        self.write_buf.reserve(total);
        for b in bufs {
            self.write_buf.extend_from_slice(b);
        }
        Ok(total)
    }

    /// Records that the writer asked for a flush. The owning `TlsStream`
    /// flushes the underlying transport when it sees the flag.
    fn flush(&mut self) -> io::Result<()> {
        *self.wants_flush = true;
        Ok(())
    }
}
/// Builds the `WouldBlock` error that `SyncStream` returns when it has no
/// buffered ciphertext.
fn would_block() -> io::Result<usize> {
    let err = io::Error::new(io::ErrorKind::WouldBlock, "block");
    Err(err)
}
/// Converts a synchronous I/O result into a `Poll`.
///
/// A `WouldBlock` error becomes `Pending`, and the task is woken immediately so
/// it gets polled again. Every other result, success or failure, is returned
/// as `Ready`.
fn blocking_to_poll<T>(result: io::Result<T>, cx: &mut Context) -> Poll<io::Result<T>> {
    match result {
        Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
            cx.waker().wake_by_ref();
            Poll::Pending
        }
        other => Poll::Ready(other),
    }
}