use std::future::Future;
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use futures_util::future::BoxFuture;
use rustls::pki_types::ServerName;
use crate::proto::BufDnsStreamHandle;
use crate::proto::ProtoError;
use crate::proto::rustls::TlsClientStream;
use crate::proto::rustls::tls_client_stream::tls_client_connect_with_future;
use crate::proto::tcp::DnsTcpStream;
/// Starts a DNS-over-TLS client connection on top of a pending transport connection.
///
/// `future` resolves to the underlying transport stream `S` (or an I/O error).
/// It is handed, together with `socket_addr` and `server_name`, to
/// `tls_client_connect_with_future`, which presumably drives the TLS handshake
/// once the transport is ready.
///
/// `tls_config` is consumed and always has `enable_sni` forced to `false` before it
/// is shared via `Arc`, so any SNI setting chosen by the caller is ignored.
/// NOTE(review): disabling SNI looks intentional for DoT; confirm the rationale
/// before changing it.
///
/// Returns the boxed (`Send + 'static`) connection future together with the
/// `BufDnsStreamHandle` produced by `tls_client_connect_with_future`.
pub(crate) fn new_tls_stream_with_future<S, F>(
    future: F,
    socket_addr: SocketAddr,
    server_name: ServerName<'static>,
    tls_config: rustls::ClientConfig,
) -> (
    BoxFuture<'static, Result<TlsClientStream<S>, ProtoError>>,
    BufDnsStreamHandle,
)
where
    S: DnsTcpStream,
    F: Future<Output = io::Result<S>> + Send + Unpin + 'static,
{
    // SNI is unconditionally switched off for this connection, whatever the caller set.
    let shared_config = {
        let mut config = tls_config;
        config.enable_sni = false;
        Arc::new(config)
    };

    let (connect, handle) =
        tls_client_connect_with_future(future, socket_addr, server_name, shared_config);

    // Erase the concrete future type so callers get a uniform `BoxFuture`.
    (Box::pin(connect), handle)
}
#[cfg(feature = "__tls")]
#[cfg(any(feature = "webpki-roots", feature = "rustls-platform-verifier"))]
#[cfg(test)]
mod tests {
    use test_support::subscribe;

    use crate::TokioResolver;
    use crate::config::{CLOUDFLARE, GOOGLE, ResolverConfig};
    use crate::name_server::TokioConnectionProvider;

    /// Builds a Tokio resolver from `config` (with `try_tcp_on_error` enabled),
    /// resolves `www.example.com.`, and asserts that at least one IP is returned.
    ///
    /// These tests perform real network lookups against public resolvers.
    async fn tls_test(config: ResolverConfig) {
        let provider = TokioConnectionProvider::default();
        let mut builder = TokioResolver::builder_with_config(config, provider);
        builder.options_mut().try_tcp_on_error = true;
        let resolver = builder.build();

        let lookup = resolver.lookup_ip("www.example.com.").await;
        let ips = lookup.expect("failed to run lookup");

        let found = ips.iter().count();
        assert_ne!(found, 0);
    }

    /// Resolves through Google's public DNS-over-TLS configuration.
    #[tokio::test]
    async fn test_google_tls() {
        subscribe();
        let config = ResolverConfig::tls(&GOOGLE);
        tls_test(config).await;
    }

    /// Resolves through Cloudflare's public DNS-over-TLS configuration.
    #[tokio::test]
    async fn test_cloudflare_tls() {
        subscribe();
        let config = ResolverConfig::tls(&CLOUDFLARE);
        tls_test(config).await;
    }
}