#![allow(clippy::print_stdout)]
use core::{net::SocketAddr, str::FromStr};
use std::sync::Arc;
use std::{env, path::Path, println};
use futures_util::StreamExt;
use rustls::{
ClientConfig, KeyLogFile,
pki_types::{
CertificateDer, PrivateKeyDer,
pem::{self, PemObject},
},
sign::{CertifiedKey, SingleCertAndKey},
};
use test_support::subscribe;
use crate::{
proto::{
op::{Message, Query},
rr::{Name, RecordType},
},
quic::QuicClientStreamBuilder,
tls::default_provider,
xfer::DnsRequestSender,
};
use super::quic_server::QuicServer;
/// Echo-style responder that backs the QUIC test server.
///
/// Keeps accepting connections from `server` until `next()` yields `Ok(None)`. For each
/// connection, every incoming stream is read once, the bytes are parsed as a [`Message`],
/// and `into_response()` of that message is sent back on the same stream.
///
/// # Panics
///
/// Panics if accepting a connection, opening a stream, receiving bytes, parsing the
/// message, or sending the response fails.
async fn server_responder(mut server: QuicServer) {
    loop {
        let accepted = server
            .next()
            .await
            .expect("failed to get next quic session");
        // `None` means the server has no further connections to hand out.
        let Some((mut conn, addr)) = accepted else {
            break;
        };

        println!("received client request {addr}");

        loop {
            // `None` means this connection has no further streams.
            let Some(incoming) = conn.next().await else {
                break;
            };
            let mut stream = incoming.expect("new client stream failed");

            let request_bytes = stream.receive_bytes().await.expect("failed to receive");
            let request = Message::from_vec(&request_bytes).expect("failed to parse message");

            let reply = request.into_response();
            stream.send(reply).await.expect("failed to send response");
        }
    }
}
/// Sends one AAAA query over QUIC to an in-process [`QuicServer`] and checks the reply.
///
/// Certificates and the private key are loaded from `tests/test-data` beneath the path in
/// `TDNS_WORKSPACE_ROOT` (falling back to `../..`). The server listens on `127.0.0.1` with
/// port `0` and runs [`server_responder`] in a spawned task, which is aborted once the
/// response has been compared with `into_response()` of the request.
#[tokio::test]
async fn test_quic_stream() {
    subscribe();

    // --- fixtures -------------------------------------------------------------------------
    let server_path = env::var("TDNS_WORKSPACE_ROOT").unwrap_or_else(|_| "../..".to_owned());
    println!("using server src path: {server_path}");

    let ca_certs = read_certs(format!("{server_path}/tests/test-data/ca.pem")).unwrap();
    let server_chain = read_certs(format!("{server_path}/tests/test-data/cert.pem")).unwrap();
    let server_key =
        PrivateKeyDer::from_pem_file(format!("{server_path}/tests/test-data/cert.key")).unwrap();

    // --- server ---------------------------------------------------------------------------
    let certified = CertifiedKey::from_der(server_chain, server_key, &default_provider()).unwrap();
    let resolver = Arc::new(SingleCertAndKey::from(certified));

    let listen_on = SocketAddr::from(([127, 0, 0, 1], 0));
    let quic_ns = QuicServer::new(listen_on, resolver)
        .await
        .expect("failed to initialize QuicServer");

    // Port 0 was requested, so the address actually bound has to be read back.
    let server_addr = quic_ns.local_addr().expect("no address");
    println!("testing quic on: {server_addr}");

    let server_task = tokio::spawn(server_responder(quic_ns));

    // --- client TLS configuration ---------------------------------------------------------
    let mut trust_anchors = rustls::RootCertStore::empty();
    let (_accepted, rejected) = trust_anchors.add_parsable_certificates(ca_certs);
    // Every certificate in the CA file must have been usable.
    assert_eq!(rejected, 0);

    let mut client_config = ClientConfig::builder_with_provider(Arc::new(default_provider()))
        .with_safe_default_protocol_versions()
        .unwrap()
        .with_root_certificates(trust_anchors)
        .with_no_client_auth();
    client_config.key_log = Arc::new(KeyLogFile::new());

    // --- connect --------------------------------------------------------------------------
    println!("starting quic connect");
    let mut client_stream = QuicClientStreamBuilder::default()
        .crypto_config(client_config)
        .build(server_addr, Arc::from("ns.example.com"))
        .await
        .expect("failed to connect");
    println!("connected client to server");

    // --- request / response ---------------------------------------------------------------
    let mut draft = Message::query();
    let question = Query::query(
        Name::from_str("www.example.test.").unwrap(),
        RecordType::AAAA,
    );
    draft.add_query(question);
    draft.metadata.id = 0;

    // Round-trip through the wire encoding so the request compared below is the decoded form.
    let wire = draft.to_vec().unwrap();
    let request = Message::from_vec(&wire).unwrap();

    let outgoing = request.clone();
    let response = client_stream
        .send_message(outgoing.into())
        .next()
        .await
        .expect("no response received")
        .expect("failed to read response");

    let expected = request.into_response();
    assert_eq!(*response, expected);

    server_task.abort();
}
/// Collects every certificate yielded by `CertificateDer::pem_file_iter` for `cert_path`.
///
/// # Errors
///
/// Returns the `pem::Error` from creating the iterator, or the first one produced by an
/// individual item while collecting.
fn read_certs(cert_path: impl AsRef<Path>) -> Result<Vec<CertificateDer<'static>>, pem::Error> {
    let entries = CertificateDer::pem_file_iter(cert_path)?;
    // The target collection is inferred from the return type; the first `Err` short-circuits.
    entries.collect()
}