use core::net::{IpAddr, Ipv4Addr, SocketAddr};
use core::str::FromStr;
use core::sync::atomic::{AtomicBool, Ordering};
use core::time::Duration;
use std::net::{ToSocketAddrs, UdpSocket};
use std::sync::Arc;
use std::{println, process, thread};
use futures_util::stream::StreamExt;
use tracing::debug;
use crate::NetError;
use crate::proto::op::{DnsRequest, DnsRequestOptions, DnsResponse, Message, Query, SerialMessage};
use crate::proto::rr::rdata::NULL;
use crate::proto::rr::{Name, RData, Record, RecordType};
use crate::runtime::RuntimeProvider;
use crate::udp::{UdpClientStream, UdpStream};
use crate::xfer::dns_handle::DnsStreamHandle;
use crate::xfer::{DnsRequestSender, FirstAnswer};
/// Creates a `UdpStream` aimed at `127.0.0.1:52`, awaits the future returned by
/// `UdpStream::new`, and requires it to succeed. The resulting value is dropped
/// right away; the second tuple element is never bound.
pub(super) async fn next_random_socket_test(provider: impl RuntimeProvider) {
    let remote = SocketAddr::from((Ipv4Addr::LOCALHOST, 52));
    let (bind, _) = UdpStream::new(remote, None, None, false, provider);
    let socket = bind.await.expect("failed to get next socket address");
    drop(socket);
}
/// Round-trips the payload `DEADBEEF` four times between a `UdpStream` and a
/// blocking std echo server running on its own thread.
///
/// The server binds an ephemeral port on `server_addr`. The client binds the
/// loopback address of the same family. Each echoed datagram must match the
/// payload and must come from the server's address. A watchdog from
/// `start_thread_killer` aborts the process if the test hangs.
pub(super) async fn udp_stream_test<P: RuntimeProvider>(server_addr: IpAddr, provider: P) {
    let killer_flag = start_thread_killer();

    // Blocking echo socket; the 5 s timeouts stop the server thread from hanging forever.
    let echo_socket = UdpSocket::bind(SocketAddr::new(server_addr, 0)).unwrap();
    echo_socket
        .set_read_timeout(Some(Duration::from_secs(5)))
        .unwrap();
    echo_socket
        .set_write_timeout(Some(Duration::from_secs(5)))
        .unwrap();
    let server_addr = echo_socket.local_addr().unwrap();
    println!("server listening on: {server_addr}");

    let payload: &'static [u8; 8] = b"DEADBEEF";
    let round_trips = 4u32;

    let echo_thread = thread::Builder::new()
        .name("test_udp_stream_ipv4:server".to_string())
        .spawn(move || {
            let mut buf = [0_u8; 512];
            (0..round_trips).for_each(|_| {
                let (received, peer) = echo_socket.recv_from(&mut buf).expect("receive failed");
                let datagram = &buf[0..received];
                assert_eq!(datagram, payload);
                let sent = echo_socket.send_to(datagram, peer).expect("send failed");
                assert_eq!(sent, received);
            });
        })
        .unwrap();

    // Bind the client on the loopback address of the server's address family.
    let client_spec = if server_addr.is_ipv4() {
        "127.0.0.1:0"
    } else {
        "[::1]:0"
    };
    println!("binding client socket");
    let local_addr = client_spec.to_socket_addrs().unwrap().next().unwrap();
    let socket = provider
        .bind_udp(local_addr, server_addr)
        .await
        .expect("could not create socket");
    println!("bound client socket");

    let (mut replies, mut outbound) = UdpStream::<P>::with_bound(socket, server_addr);
    for _ in 0..round_trips {
        outbound
            .send(SerialMessage::new(payload.to_vec(), server_addr))
            .unwrap();
        let reply = replies
            .next()
            .await
            .expect("no message")
            .expect("io error");
        assert_eq!(reply.bytes(), payload);
        assert_eq!(reply.addr(), server_addr);
    }

    killer_flag.store(true, Ordering::Relaxed);
    echo_thread.join().expect("server thread failed");
}
#[allow(clippy::print_stdout)]
pub(super) async fn udp_client_stream_test(server_addr: IpAddr, provider: impl RuntimeProvider) {
udp_client_stream_test_inner(
server_addr,
provider,
"udp_client_stream",
4,
1,
|_, _| {},
|response| match response {
Ok(response) => {
let response = Message::from(response);
if let RData::NULL(null) = &response.answers[0].data {
assert_eq!(null.anything, b"DEADBEEF");
true
} else {
panic!("not a NULL response");
}
}
Err(_) => false,
},
)
.await;
}
/// Sends one query and has the server reply once with a message ID one greater
/// than the request's. The test only accepts `Err(NetError::Timeout)`, so the
/// client is expected not to treat the mismatched-ID reply as the answer.
#[allow(clippy::print_stdout)]
pub(super) async fn udp_client_stream_bad_id_test(
    server_addr: IpAddr,
    provider: impl RuntimeProvider,
) {
    udp_client_stream_test_inner(
        server_addr,
        provider,
        "udp_client_stream_bad_id",
        1,
        1,
        |idx, message| {
            if idx == 0 {
                // Read and write the ID through the same field so the corruption
                // is applied to exactly the value that gets written.
                message.metadata.id = message.metadata.id.wrapping_add(1);
            }
        },
        |response| matches!(response, Err(NetError::Timeout)),
    )
    .await;
}
#[allow(clippy::print_stdout)]
pub(super) async fn udp_client_stream_response_limit_test(
server_addr: IpAddr,
provider: impl RuntimeProvider,
) {
udp_client_stream_test_inner(
server_addr,
provider,
"udp_client_stream_response_limit",
1,
4,
|idx, message| {
if idx < 3 {
message.queries.clear();
message.add_query(Query::query(
Name::from_str("wrong.name.").unwrap(),
RecordType::A,
));
}
},
|response| {
matches!(
response,
Err(NetError::Message("udp receive attempts exceeded"))
)
},
)
.await;
}
/// Drives a `UdpClientStream` against a blocking std UDP server thread.
///
/// The server expects `request_count` requests. For each one, it asserts that
/// the first question is `dead.beef.`/NULL. It then sends `response_count`
/// responses. Each response is built from the request, gets a NULL answer
/// carrying `DEADBEEF`, and is passed through
/// `response_mutator(response_index, &mut message)` before sending.
///
/// The client sends `request_count` requests with a 500 ms timeout. It passes
/// every first-answer result to `accept_response`, which may assert or panic.
/// The test fails unless `accept_response` returns `true` at least once.
///
/// The flag returned by `start_thread_killer` also tells the server thread to
/// exit, so the server socket stays alive until the client loop has finished.
async fn udp_client_stream_test_inner(
    server_addr: IpAddr,
    provider: impl RuntimeProvider,
    test_name: &str,
    request_count: usize,
    response_count: usize,
    response_mutator: impl Fn(usize, &mut Message) + Send + 'static,
    accept_response: impl Fn(Result<DnsResponse, NetError>) -> bool,
) {
    let stop_thread_killer = start_thread_killer();

    // Blocking server socket on an ephemeral port; the 5 s timeouts keep a stuck
    // test from blocking the server thread forever.
    let server = UdpSocket::bind(SocketAddr::new(server_addr, 0)).unwrap();
    server
        .set_read_timeout(Some(Duration::from_secs(5)))
        .unwrap();
    server
        .set_write_timeout(Some(Duration::from_secs(5)))
        .unwrap();
    let server_addr = server.local_addr().unwrap();

    let mut query = Message::query();
    let query_name = Name::from_str("dead.beef.").unwrap();
    query.add_query(Query::query(query_name.clone(), RecordType::NULL));

    let test_bytes: &'static [u8; 8] = b"DEADBEEF";
    let server_done = stop_thread_killer.clone();
    let server_handle = thread::Builder::new()
        .name(format!("{test_name}:server"))
        .spawn(move || {
            let mut buffer = [0_u8; 512];
            for i in 0..request_count {
                debug!("server receiving request {i}");
                let (len, addr) = server.recv_from(&mut buffer).expect("receive failed");
                debug!("server received request {i} from: {addr}");
                let request = Message::from_vec(&buffer[0..len]).expect("failed parse of request");
                // Compare by reference: cloning the expected name per request is unnecessary.
                assert_eq!(request.queries[0].name(), &query_name);
                assert_eq!(request.queries[0].query_type(), RecordType::NULL);
                for response_idx in 0..response_count {
                    let mut message = request.clone().into_response();
                    message.add_answer(Record::from_rdata(
                        query_name.clone(),
                        0,
                        RData::NULL(NULL::with(test_bytes.to_vec())),
                    ));
                    response_mutator(response_idx, &mut message);
                    let bytes = message.to_vec().unwrap();
                    server.send_to(&bytes, addr).expect("send failed");
                    debug!("server sent response {response_idx} for request {i}");
                }
                thread::yield_now();
            }
            // Park until the main thread flips the shared flag, keeping `server`
            // open for the whole client loop.
            while !server_done.load(Ordering::Relaxed) {
                thread::sleep(Duration::from_millis(10));
            }
        })
        .unwrap();

    let mut stream = UdpClientStream::builder(server_addr, provider)
        .with_timeout(Some(Duration::from_millis(500)))
        .build();
    // Every response goes through `accept_response` (it may assert), so don't short-circuit.
    let mut worked_once = false;
    for i in 0..request_count {
        let response_stream =
            stream.send_message(DnsRequest::new(query.clone(), DnsRequestOptions::default()));
        println!("client sending request {i}");
        let response = response_stream.first_answer().await;
        println!("client got response {i}");
        if accept_response(response) {
            worked_once = true;
        }
    }

    stop_thread_killer.store(true, Ordering::Relaxed);
    server_handle.join().expect("server thread failed");
    assert!(worked_once);
}
/// Spawns a detached watchdog thread named `thread_killer` and returns its
/// completion flag, which starts out `false`.
///
/// The watchdog checks the flag once per second for 15 seconds. If the flag is
/// still `false` after the last check, it prints a notice and terminates the
/// whole process with `process::exit(-1)`. Store `true` into the returned flag
/// to let the watchdog return quietly.
fn start_thread_killer() -> Arc<AtomicBool> {
    let done = Arc::new(AtomicBool::new(false));
    let watched = Arc::clone(&done);
    thread::Builder::new()
        .name("thread_killer".to_string())
        .spawn(move || {
            // Short-circuits on the first tick that observes the flag set.
            let finished_in_time = (0..15).any(|_| {
                thread::sleep(Duration::from_secs(1));
                watched.load(Ordering::Relaxed)
            });
            if !finished_in_time {
                println!("Thread Killer has been awoken, killing process");
                process::exit(-1);
            }
        })
        .unwrap();
    done
}