use core::fmt::{self, Display};
use core::future::{Future, poll_fn};
use core::net::SocketAddr;
use core::pin::Pin;
use core::str::FromStr;
use core::task::{Context, Poll};
use std::sync::Arc;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use futures_util::stream::Stream;
use h3::client::SendRequest;
use h3_quinn::OpenStreams;
use http::header::{self, CONTENT_LENGTH};
use quinn::{Endpoint, EndpointConfig, TransportConfig};
use tokio::sync::mpsc;
use tracing::{debug, warn};
use crate::error::NetError;
use crate::http::{RequestContext, SetHeaders, Version};
use crate::proto::ProtoError;
use crate::proto::op::{DnsRequest, DnsResponse};
use crate::quic::connect_quic;
use crate::runtime::{RuntimeProvider, Spawn};
use crate::tls::client_config;
use crate::udp::UdpSocket;
use crate::xfer::{DnsExchange, DnsRequestSender, DnsResponseStream};
use super::ALPN_H3;
/// A DNS-over-HTTP/3 client connected to a single name server.
///
/// Each [`DnsRequestSender::send_message`] call opens a new HTTP/3 request stream
/// through `send_request`. Clones copy the same connection handle and the same
/// shutdown sender, so they talk over one connection.
#[must_use = "futures do nothing unless polled"]
#[derive(Clone)]
pub struct H3ClientStream {
    /// Remote address of the name server; also shown by the `Display` impl.
    name_server: SocketAddr,
    /// h3 handle used to open one request stream per DNS query.
    send_request: SendRequest<OpenStreams, Bytes>,
    /// Shared settings (HTTP version, server name, query path, header hook) used to
    /// build every HTTP request.
    context: Arc<RequestContext>,
    /// Sender paired with the receiver owned by the background h3 driver task;
    /// `is_closed()` becomes true once that task has finished and dropped it.
    shutdown_tx: mpsc::Sender<()>,
    /// Set by `DnsRequestSender::shutdown` on this particular handle.
    is_shutdown: bool,
}
impl H3ClientStream {
    /// Returns a builder used to connect an [`H3ClientStream`].
    ///
    /// The builder's initial settings are:
    /// - no explicit TLS configuration (`client_config()` is used at connect time);
    /// - the QUIC transport configuration from `super::transport()`;
    /// - no bind address;
    /// - no header hook;
    /// - GREASE enabled.
    pub fn builder() -> H3ClientStreamBuilder {
        H3ClientStreamBuilder {
            crypto_config: None,
            transport_config: Arc::new(super::transport()),
            bind_addr: None,
            set_headers: None,
            disable_grease: false,
        }
    }

    /// Sends one wire-format DNS `message` on a new HTTP/3 request stream and parses
    /// the reply.
    ///
    /// The HTTP request head is built by `cx`, the message is sent as the body, the
    /// send side is finished, and the whole response body is collected.
    ///
    /// # Errors
    ///
    /// Returns a [`NetError`] if any of the following happens:
    /// - the request cannot be built;
    /// - any h3 operation fails;
    /// - the `Content-Length` header is malformed or disagrees with the received body;
    /// - the body is larger than 65535 bytes (the `application/dns-message` limit);
    /// - the status is not a success status;
    /// - a `Content-Type` header is present but is not `MIME_APPLICATION_DNS`;
    /// - the body does not parse as a DNS message.
    async fn inner_send(
        mut h3: SendRequest<OpenStreams, Bytes>,
        message: Bytes,
        cx: Arc<RequestContext>,
    ) -> Result<DnsResponse, NetError> {
        // RFC 8484 §6 restricts an `application/dns-message` payload to 65535 bytes.
        // Enforcing it stops a misbehaving server from making us buffer an unbounded
        // body. `u16 -> usize` is a widening cast and cannot truncate.
        const MAX_DNS_MESSAGE_LEN: usize = u16::MAX as usize;

        let request = cx
            .build(message.remaining())
            .map_err(|err| NetError::from(format!("bad http request: {err}")))?;
        debug!("request: {:#?}", request);

        let mut stream = h3
            .send_request(request)
            .await
            .map_err(|err| NetError::from(format!("h3 send_request error: {err}")))?;
        stream
            .send_data(message)
            .await
            .map_err(|e| NetError::from(format!("h3 send_data error: {e}")))?;
        stream
            .finish()
            .await
            .map_err(|err| NetError::from(format!("received a stream error: {err}")))?;

        let response = stream
            .recv_response()
            .await
            .map_err(|err| NetError::from(format!("h3 recv_response error: {err}")))?;
        debug!("got response: {:#?}", response);

        let content_length = response
            .headers()
            .get(CONTENT_LENGTH)
            .map(|v| v.to_str())
            .transpose()
            .map_err(|e| NetError::from(format!("bad headers received: {e}")))?
            .map(usize::from_str)
            .transpose()
            .map_err(|e| NetError::from(format!("bad headers received: {e}")))?;

        // Reject an oversized declared body before reading any of it.
        if let Some(declared) = content_length.filter(|&len| len > MAX_DNS_MESSAGE_LEN) {
            return Err(NetError::from(format!(
                "response too large (status: {}): content-length {} exceeds maximum DNS message size of {} bytes",
                response.status(),
                declared,
                MAX_DNS_MESSAGE_LEN
            )));
        }

        let mut response_bytes =
            BytesMut::with_capacity(content_length.unwrap_or(512).clamp(512, 4_096));
        while let Some(partial_bytes) = stream
            .recv_data()
            .await
            .map_err(|e| NetError::from(format!("h3 recv_data error: {e}")))?
        {
            debug!("got bytes: {}", partial_bytes.remaining());
            response_bytes.put(partial_bytes);

            if let Some(content_length) = content_length {
                // The declared length (already capped above) has been reached; any
                // overshoot is reported by the exact-length check below.
                if response_bytes.len() >= content_length {
                    break;
                }
            } else if response_bytes.len() > MAX_DNS_MESSAGE_LEN {
                // No declared length: enforce the cap while streaming.
                return Err(NetError::from(format!(
                    "response too large (status: {}): body exceeds maximum DNS message size of {} bytes",
                    response.status(),
                    MAX_DNS_MESSAGE_LEN
                )));
            }
        }

        if let Some(content_length) = content_length {
            if response_bytes.len() != content_length {
                return Err(NetError::from(format!(
                    "expected byte length: {}, got: {}",
                    content_length,
                    response_bytes.len()
                )));
            }
        }

        if !response.status().is_success() {
            // The body of an error response is surfaced as (lossy) text for diagnostics.
            let error_string = String::from_utf8_lossy(response_bytes.as_ref());
            return Err(NetError::from(format!(
                "http unsuccessful code: {}, message: {}",
                response.status(),
                error_string
            )));
        }

        // A missing Content-Type is accepted; a present one must be exactly the DNS
        // media type.
        let content_type = response
            .headers()
            .get(header::CONTENT_TYPE)
            .map(|h| {
                h.to_str().map_err(|err| {
                    NetError::from(format!("ContentType header not a string: {err}"))
                })
            })
            .unwrap_or(Ok(crate::http::MIME_APPLICATION_DNS))?;
        if content_type != crate::http::MIME_APPLICATION_DNS {
            return Err(NetError::from(format!(
                "ContentType unsupported (must be '{}'): '{}'",
                crate::http::MIME_APPLICATION_DNS,
                content_type
            )));
        }

        DnsResponse::from_buffer(response_bytes.to_vec()).map_err(NetError::from)
    }
}
impl DnsRequestSender for H3ClientStream {
fn send_message(&mut self, mut request: DnsRequest) -> DnsResponseStream {
if self.is_shutdown {
panic!("can not send messages after stream is shutdown")
}
request.metadata.id = 0;
let bytes = match request.to_vec() {
Ok(bytes) => bytes,
Err(err) => return NetError::from(err).into(),
};
Box::pin(Self::inner_send(
self.send_request.clone(),
Bytes::from(bytes),
self.context.clone(),
))
.into()
}
fn shutdown(&mut self) {
self.is_shutdown = true;
}
fn is_shutdown(&self) -> bool {
self.is_shutdown
}
}
impl Stream for H3ClientStream {
    type Item = Result<(), NetError>;

    /// Reports connection health without doing any I/O; it is always ready.
    ///
    /// The result is one of:
    /// - `None` after this handle was shut down;
    /// - an error once the background driver task has dropped its shutdown receiver;
    /// - `Ok(())` otherwise.
    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let status = if self.is_shutdown {
            None
        } else if self.shutdown_tx.is_closed() {
            Some(Err(NetError::from("h3 connection is already shutdown")))
        } else {
            Some(Ok(()))
        };
        Poll::Ready(status)
    }
}
impl Display for H3ClientStream {
    /// Writes `H3(<name server address>,<server name>)`.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        let Self {
            name_server,
            context,
            ..
        } = self;
        write!(formatter, "H3({},{})", name_server, context.server_name)
    }
}
/// Settings used to connect an [`H3ClientStream`].
///
/// Obtain one from [`H3ClientStream::builder`].
#[derive(Clone)]
pub struct H3ClientStreamBuilder {
    /// TLS client configuration. `None` means `client_config()` is called at connect time.
    crypto_config: Option<rustls::ClientConfig>,
    /// QUIC transport parameters handed to `connect_quic`.
    transport_config: Arc<TransportConfig>,
    /// Local address for the UDP socket created by `build`. `None` uses the default bind.
    bind_addr: Option<SocketAddr>,
    /// Optional header hook stored in each stream's `RequestContext`.
    set_headers: Option<Arc<dyn SetHeaders>>,
    /// When `true`, the h3 client is built with GREASE sending turned off.
    disable_grease: bool,
}
impl H3ClientStreamBuilder {
    /// Uses `crypto_config` for the TLS handshake instead of `client_config()`.
    pub fn crypto_config(self, crypto_config: rustls::ClientConfig) -> Self {
        Self {
            crypto_config: Some(crypto_config),
            ..self
        }
    }

    /// Binds the UDP socket created by [`Self::build`] to `bind_addr`.
    pub fn bind_addr(self, bind_addr: SocketAddr) -> Self {
        Self {
            bind_addr: Some(bind_addr),
            ..self
        }
    }

    /// Stores a header hook in the `RequestContext` of streams built from this builder.
    ///
    /// Unlike the other setters, this mutates in place and returns nothing.
    pub fn set_headers(&mut self, headers: Arc<dyn SetHeaders>) {
        self.set_headers = Some(headers);
    }

    /// Turns h3 GREASE sending off when `disable_grease` is `true` (default `false`).
    pub fn disable_grease(self, disable_grease: bool) -> Self {
        Self {
            disable_grease,
            ..self
        }
    }

    /// Creates a new UDP socket (honouring `bind_addr`) and connects over HTTP/3 to
    /// `name_server`.
    ///
    /// `server_name` is passed to the QUIC/TLS connect and stored in the request
    /// context. `path` becomes the query path of each request.
    pub fn build(
        self,
        name_server: SocketAddr,
        server_name: Arc<str>,
        path: Arc<str>,
    ) -> impl Future<Output = Result<H3ClientStream, NetError>> + Send + 'static {
        self.connect(name_server, server_name, path)
    }

    /// Connects over the caller-supplied `socket` and wraps the stream in a
    /// [`DnsExchange`].
    ///
    /// The exchange's background task is spawned through `provider`.
    pub async fn exchange<P: RuntimeProvider>(
        self,
        socket: Arc<dyn quinn::AsyncUdpSocket>,
        name_server: SocketAddr,
        server_name: Arc<str>,
        path: Arc<str>,
        provider: P,
    ) -> Result<DnsExchange<P>, NetError> {
        let h3_stream = self
            .connect_with_future(socket, name_server, server_name, path)
            .await?;
        let (exchange, background) = DnsExchange::from_stream(h3_stream);
        provider.create_handle().spawn_bg(background);
        Ok(exchange)
    }

    /// Like [`Self::build`], but runs QUIC over the caller-supplied `socket`.
    pub fn build_with_future(
        self,
        socket: Arc<dyn quinn::AsyncUdpSocket>,
        name_server: SocketAddr,
        server_name: Arc<str>,
        path: Arc<str>,
    ) -> impl Future<Output = Result<H3ClientStream, NetError>> + Send + 'static {
        self.connect_with_future(socket, name_server, server_name, path)
    }

    /// Builds a quinn endpoint (client-only, tokio runtime) on an abstract socket,
    /// then connects.
    async fn connect_with_future(
        self,
        socket: Arc<dyn quinn::AsyncUdpSocket>,
        name_server: SocketAddr,
        server_name: Arc<str>,
        path: Arc<str>,
    ) -> Result<H3ClientStream, NetError> {
        let quic_endpoint = Endpoint::new_with_abstract_socket(
            EndpointConfig::default(),
            None,
            socket,
            Arc::new(quinn::TokioRuntime),
        )?;
        self.connect_inner(quic_endpoint, name_server, server_name, path)
            .await
    }

    /// Opens a tokio UDP socket (bound to `bind_addr` if set), builds a quinn
    /// endpoint on it, then connects.
    async fn connect(
        self,
        name_server: SocketAddr,
        server_name: Arc<str>,
        path: Arc<str>,
    ) -> Result<H3ClientStream, NetError> {
        let tokio_socket = match self.bind_addr {
            Some(bind_addr) => {
                <tokio::net::UdpSocket as UdpSocket>::connect_with_bind(name_server, bind_addr)
                    .await?
            }
            None => <tokio::net::UdpSocket as UdpSocket>::connect(name_server).await?,
        };
        let std_socket = tokio_socket.into_std()?;
        let quic_endpoint = Endpoint::new(
            EndpointConfig::default(),
            None,
            std_socket,
            Arc::new(quinn::TokioRuntime),
        )?;
        self.connect_inner(quic_endpoint, name_server, server_name, path)
            .await
    }

    /// Performs the QUIC handshake and the h3 setup, then spawns the h3 driver task.
    ///
    /// The driver task ends when the connection closes or when a message arrives on
    /// (or every sender of) the shutdown channel.
    async fn connect_inner(
        self,
        endpoint: Endpoint,
        name_server: SocketAddr,
        server_name: Arc<str>,
        path: Arc<str>,
    ) -> Result<H3ClientStream, NetError> {
        let tls_config = match self.crypto_config {
            Some(configured) => configured,
            None => client_config()?,
        };
        let quic_connection = connect_quic(
            name_server,
            server_name.clone(),
            ALPN_H3,
            tls_config,
            self.transport_config,
            endpoint,
        )
        .await?;

        let h3_conn = h3_quinn::Connection::new(quic_connection);
        let (mut driver, send_request) = h3::client::builder()
            .send_grease(!self.disable_grease)
            .build(h3_conn)
            .await
            .map_err(|e| ProtoError::from(format!("h3 connection failed: {e}")))?;

        let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
        debug!("h3 connection is ready: {}", name_server);
        tokio::spawn(async move {
            tokio::select! {
                error = poll_fn(|cx| driver.poll_close(cx)) => {
                    if !error.is_h3_no_error() {
                        warn!(%error, "h3 connection failed to close")
                    }
                }
                _ = shutdown_rx.recv() => {
                    debug!("h3 connection is shutting down: {}", name_server);
                }
            }
        });

        let context = Arc::new(RequestContext {
            version: Version::Http3,
            server_name,
            query_path: path,
            set_headers: self.set_headers,
        });
        Ok(H3ClientStream {
            name_server,
            send_request,
            context,
            shutdown_tx,
            is_shutdown: false,
        })
    }
}
#[cfg(all(
    test,
    any(feature = "rustls-platform-verifier", feature = "webpki-roots")
))]
mod tests {
    use core::net::SocketAddr;
    use core::str::FromStr;
    use std::println;

    use rustls::KeyLogFile;
    use test_support::subscribe;
    use tokio::task::JoinSet;

    use super::*;
    use crate::proto::op::{DnsRequestOptions, Edns, Message, Query};
    use crate::proto::rr::{Name, RData, RecordType};
    use crate::xfer::FirstAnswer;

    /// Builds a recursion-desired `www.example.com.` query of `record_type` carrying
    /// an EDNS(0) record (version 0, max payload 1232).
    fn recursive_edns_query(record_type: RecordType) -> DnsRequest {
        let mut edns = Edns::new();
        edns.set_version(0);
        edns.set_max_payload(1232);

        let mut message = Message::query();
        message.add_query(Query::query(
            Name::from_str("www.example.com.").unwrap(),
            record_type,
        ));
        message.metadata.recursion_desired = true;
        message.edns = Some(edns);

        DnsRequest::new(message, DnsRequestOptions::default())
    }

    /// Returns the default TLS client config with a `KeyLogFile` key logger installed.
    fn key_logging_client_config() -> rustls::ClientConfig {
        let mut config = client_config().unwrap();
        config.key_log = Arc::new(KeyLogFile::new());
        config
    }

    /// Sends `request` over `h3` and asserts that some answer's data satisfies `expected`.
    async fn assert_answer(
        h3: &mut H3ClientStream,
        request: DnsRequest,
        expected: fn(&RData) -> bool,
    ) {
        let response = h3
            .send_message(request)
            .first_answer()
            .await
            .expect("send_message failed");
        assert!(response.answers.iter().any(|record| expected(&record.data)));
    }

    #[tokio::test]
    async fn test_h3_google() {
        subscribe();
        let google = SocketAddr::from(([8, 8, 8, 8], 443));
        let mut h3 = H3ClientStream::builder()
            .crypto_config(key_logging_client_config())
            .build(google, Arc::from("dns.google"), Arc::from("/dns-query"))
            .await
            .expect("h3 connect failed");

        assert_answer(
            &mut h3,
            recursive_edns_query(RecordType::A),
            |data: &RData| matches!(data, RData::A(_)),
        )
        .await;
        assert_answer(
            &mut h3,
            recursive_edns_query(RecordType::AAAA),
            |data: &RData| matches!(data, RData::AAAA(_)),
        )
        .await;
    }

    #[tokio::test]
    async fn test_h3_google_with_pure_ip_address_server() {
        subscribe();
        let google = SocketAddr::from(([8, 8, 8, 8], 443));
        let mut h3 = H3ClientStream::builder()
            .crypto_config(key_logging_client_config())
            .build(
                google,
                Arc::from(google.ip().to_string()),
                Arc::from("/dns-query"),
            )
            .await
            .expect("h3 connect failed");

        assert_answer(
            &mut h3,
            recursive_edns_query(RecordType::A),
            |data: &RData| matches!(data, RData::A(_)),
        )
        .await;
        assert_answer(
            &mut h3,
            recursive_edns_query(RecordType::AAAA),
            |data: &RData| matches!(data, RData::AAAA(_)),
        )
        .await;
    }

    #[tokio::test]
    async fn test_h3_cloudflare() {
        subscribe();
        let cloudflare = SocketAddr::from(([1, 1, 1, 1], 443));
        let mut h3 = H3ClientStream::builder()
            .crypto_config(key_logging_client_config())
            .disable_grease(true)
            .build(
                cloudflare,
                Arc::from("cloudflare-dns.com"),
                Arc::from("/dns-query"),
            )
            .await
            .expect("h3 connect failed");

        assert_answer(
            &mut h3,
            recursive_edns_query(RecordType::A),
            |data: &RData| matches!(data, RData::A(_)),
        )
        .await;
        assert_answer(
            &mut h3,
            recursive_edns_query(RecordType::AAAA),
            |data: &RData| matches!(data, RData::AAAA(_)),
        )
        .await;
    }

    #[tokio::test]
    #[allow(clippy::print_stdout)]
    async fn test_h3_client_stream_clonable() {
        subscribe();
        let google = SocketAddr::from(([8, 8, 8, 8], 443));
        let h3 = H3ClientStream::builder()
            .crypto_config(key_logging_client_config())
            .build(google, Arc::from("dns.google"), Arc::from("/dns-query"))
            .await
            .expect("h3 connect failed");

        // A bare AAAA query: no recursion-desired flag and no EDNS record.
        let mut message = Message::query();
        message.add_query(Query::query(
            Name::from_str("www.example.com.").unwrap(),
            RecordType::AAAA,
        ));
        let request = DnsRequest::new(message, DnsRequestOptions::default());

        // Fifty concurrent requests, each on its own clone of the same handle.
        let mut join_set = JoinSet::new();
        for i in 0..50 {
            let mut h3 = h3.clone();
            let request = request.clone();
            join_set.spawn(async move {
                let start = std::time::Instant::now();
                h3.send_message(request)
                    .first_answer()
                    .await
                    .expect("send_message failed");
                println!("request[{i}] completed: {:?}", start.elapsed());
            });
        }

        let total = join_set.len();
        let mut idx = 0usize;
        while join_set.join_next().await.is_some() {
            println!("join_set completed {idx}/{total}");
            idx += 1;
        }
    }
}