use core::net::SocketAddr;
use std::str::FromStr;
use futures_util::stream::iter;
use test_support::subscribe;
use super::ClientStreamXfrState::*;
use super::*;
use crate::{
proto::rr::{
RData,
rdata::{A, SOA},
},
runtime::TokioRuntimeProvider,
};
/// Runs the README-style walkthrough: build a UDP client pointed at
/// `8.8.8.8:53`, query the `A` records of `www.example.com.` and check that
/// at least one `A` record comes back.
///
/// This test talks to a real resolver, so it needs network access.
#[tokio::test]
async fn readme_example() {
    use core::net::SocketAddr;

    use crate::client::{Client, ClientHandle};
    use crate::proto::rr::{DNSClass, Name, RecordType};
    use crate::runtime::TokioRuntimeProvider;
    use crate::udp::UdpClientStream;

    subscribe();

    // Set up the UDP transport and hand it to the client.
    let resolver_addr = SocketAddr::from(([8, 8, 8, 8], 53));
    let udp_conn =
        UdpClientStream::builder(resolver_addr, TokioRuntimeProvider::default()).build();
    let (mut client, background) = Client::<TokioRuntimeProvider>::from_sender(udp_conn);

    // The background future has to be driven for the client to make progress.
    tokio::spawn(background);

    let query_name = Name::from_str("www.example.com.").unwrap();
    let response = client
        .query(query_name, DNSClass::IN, RecordType::A)
        .await
        .unwrap();

    // Keep only the A-record payloads from the answer section.
    let mut v4_answers = Vec::new();
    for record in response.answers.iter() {
        if let RData::A(addr) = record.data {
            v4_answers.push(addr);
        }
    }

    assert!(!v4_answers.is_empty());
}
/// Test fixture: an SOA record owned by `example.com.` carrying the given
/// `serial`.
///
/// The first two `SOA::new` arguments are `example.com.` and
/// `admin.example.com.`; the four numeric arguments after the serial are all
/// `60`. The record is built with `600` as the second `Record::from_rdata`
/// argument (presumably the TTL).
fn soa_record(serial: u32) -> Record {
    let primary = Name::from_ascii("example.com.").unwrap();
    let admin = Name::from_ascii("admin.example.com.").unwrap();
    let rdata = RData::SOA(SOA::new(primary, admin, serial, 60, 60, 60, 60));

    let owner = Name::from_ascii("example.com.").unwrap();
    Record::from_rdata(owner, 600, rdata)
}
/// Test fixture: an `A` record owned by `www.example.com.` whose address is
/// `0.0.0.<ip>`, built with `600` as the second `Record::from_rdata` argument
/// (presumably the TTL).
fn a_record(ip: u8) -> Record {
    let owner = Name::from_ascii("www.example.com.").unwrap();
    Record::from_rdata(owner, 600, RData::A(A::new(0, 0, 0, ip)))
}
/// Turns a list of answer sections into a stream of successful responses.
///
/// Each inner `Vec<Record>` becomes the answer section of one response
/// message, and the messages are yielded in the order given. Every item is
/// `Ok`; the stream never produces an error.
fn get_stream_testcase(
    records: Vec<Vec<Record>>,
) -> impl Stream<Item = Result<DnsResponse, NetError>> + Send + Unpin + 'static {
    /// Wraps one answer section in a response message.
    fn to_response(answers: Vec<Record>) -> Result<DnsResponse, NetError> {
        let mut message = Message::query();
        message.insert_answers(answers);
        let response = DnsResponse::from_message(message.into_response()).unwrap();
        Ok(response)
    }

    iter(records.into_iter().map(to_response))
}
/// A complete transfer delivered in a single message (SOA, two A records,
/// closing SOA) with the second constructor argument `false` (the AXFR case,
/// going by the test name): the one response is yielded with all four answers
/// and the stream is then finished.
#[tokio::test]
async fn test_stream_xfr_valid_axfr() {
    subscribe();

    let single_message = vec![soa_record(3), a_record(1), a_record(2), soa_record(3)];
    let messages = get_stream_testcase(vec![single_message]);

    let mut xfr = ClientStreamXfr::new(messages, false);
    assert!(matches!(xfr.state, Start { .. }));

    let only_response = xfr.next().await.unwrap().unwrap();
    assert!(matches!(xfr.state, Ended));
    assert_eq!(only_response.answers.len(), 4);

    // Nothing more is produced once the transfer has ended.
    assert!(xfr.next().await.is_none());
}
/// The same kind of transfer split across several messages: the state is
/// expected to go `Start` -> `Second` -> `Axfr` -> `Ended`, one answer per
/// message.
#[tokio::test]
async fn test_stream_xfr_valid_axfr_multipart() {
    subscribe();

    let messages = get_stream_testcase(vec![
        vec![soa_record(3)],
        vec![a_record(1)],
        vec![soa_record(3)],
        // Expected never to be yielded: the stream reports the end of the
        // transfer after the closing SOA above.
        vec![a_record(2)],
    ]);

    let mut xfr = ClientStreamXfr::new(messages, false);
    assert!(matches!(xfr.state, Start { .. }));

    let opening_soa = xfr.next().await.unwrap().unwrap();
    assert!(matches!(xfr.state, Second { .. }));
    assert_eq!(opening_soa.answers.len(), 1);

    let zone_data = xfr.next().await.unwrap().unwrap();
    assert!(matches!(xfr.state, Axfr { .. }));
    assert_eq!(zone_data.answers.len(), 1);

    let closing_soa = xfr.next().await.unwrap().unwrap();
    assert!(matches!(xfr.state, Ended));
    assert_eq!(closing_soa.answers.len(), 1);

    assert!(xfr.next().await.is_none());
}
/// A transfer that consists only of the opening and closing SOA, each in its
/// own message: expected states are `Start` -> `Second` -> `Ended`.
#[tokio::test]
async fn test_stream_xfr_empty_axfr() {
    subscribe();

    let opening = vec![soa_record(3)];
    let closing = vec![soa_record(3)];
    let messages = get_stream_testcase(vec![opening, closing]);

    let mut xfr = ClientStreamXfr::new(messages, false);
    assert!(matches!(xfr.state, Start { .. }));

    let first = xfr.next().await.unwrap().unwrap();
    assert!(matches!(xfr.state, Second { .. }));
    assert_eq!(first.answers.len(), 1);

    let second = xfr.next().await.unwrap().unwrap();
    assert!(matches!(xfr.state, Ended));
    assert_eq!(second.answers.len(), 1);

    assert!(xfr.next().await.is_none());
}
/// Feeds the record sequence used by the valid-IXFR test into a stream
/// constructed with `false`: the first poll must yield an error and leave the
/// stream in the `Ended` state.
#[tokio::test]
async fn test_stream_xfr_axfr_with_ixfr_reply() {
    subscribe();

    let ixfr_shaped_message = vec![
        soa_record(3),
        soa_record(2),
        a_record(1),
        soa_record(3),
        a_record(2),
        soa_record(3),
    ];
    let messages = get_stream_testcase(vec![ixfr_shaped_message]);

    let mut xfr = ClientStreamXfr::new(messages, false);
    assert!(matches!(xfr.state, Start { .. }));

    // `unwrap_err` panics unless the yielded item is an `Err`.
    xfr.next().await.unwrap().unwrap_err();
    assert!(matches!(xfr.state, Ended));

    assert!(xfr.next().await.is_none());
}
/// The first message carries a lone A record rather than an SOA: the test
/// expects that message to be passed through as a successful response and the
/// stream to end immediately, without yielding the second message.
#[tokio::test]
async fn test_stream_xfr_axfr_with_non_xfr_reply() {
    subscribe();

    let messages = get_stream_testcase(vec![
        // Does not begin with an SOA, so it is not a zone-transfer reply.
        vec![a_record(1)],
        vec![a_record(2)],
    ]);

    let mut xfr = ClientStreamXfr::new(messages, false);
    assert!(matches!(xfr.state, Start { .. }));

    let passthrough = xfr.next().await.unwrap().unwrap();
    assert!(matches!(xfr.state, Ended));
    assert_eq!(passthrough.answers.len(), 1);

    assert!(xfr.next().await.is_none());
}
/// Multi-message transfer whose third message has a record after the SOA:
/// the first two messages are accepted (`Second`, then `Axfr`), the third
/// yields an error and the stream ends.
#[tokio::test]
async fn test_stream_xfr_invalid_axfr_multipart() {
    subscribe();

    let messages = get_stream_testcase(vec![
        vec![soa_record(3)],
        vec![a_record(1)],
        // The message the test expects to be rejected.
        vec![soa_record(3), a_record(2)],
        vec![soa_record(3)],
    ]);

    let mut xfr = ClientStreamXfr::new(messages, false);
    assert!(matches!(xfr.state, Start { .. }));

    let opening_soa = xfr.next().await.unwrap().unwrap();
    assert!(matches!(xfr.state, Second { .. }));
    assert_eq!(opening_soa.answers.len(), 1);

    let zone_data = xfr.next().await.unwrap().unwrap();
    assert!(matches!(xfr.state, Axfr { .. }));
    assert_eq!(zone_data.answers.len(), 1);

    // `unwrap_err` panics unless the yielded item is an `Err`.
    xfr.next().await.unwrap().unwrap_err();
    assert!(matches!(xfr.state, Ended));

    assert!(xfr.next().await.is_none());
}
/// A six-record sequence delivered in one message to a stream constructed
/// with `true` (the IXFR case, going by the test name): the single response
/// is yielded with all six answers and the stream ends.
#[tokio::test]
async fn test_stream_xfr_valid_ixfr() {
    subscribe();

    let single_message = vec![
        soa_record(3),
        soa_record(2),
        a_record(1),
        soa_record(3),
        a_record(2),
        soa_record(3),
    ];
    let messages = get_stream_testcase(vec![single_message]);

    let mut xfr = ClientStreamXfr::new(messages, true);
    assert!(matches!(xfr.state, Start { .. }));

    let only_response = xfr.next().await.unwrap().unwrap();
    assert!(matches!(xfr.state, Ended));
    assert_eq!(only_response.answers.len(), 6);

    assert!(xfr.next().await.is_none());
}
/// The six-record IXFR-style sequence delivered one record per message.
///
/// After each message the stream state is checked; the expected sequence is
/// `Second`, `Ixfr { even: true }` twice, `Ixfr { even: false }` twice, then
/// `Ended`. Every yielded response must carry exactly one answer.
#[tokio::test]
async fn test_stream_xfr_valid_ixfr_multipart() {
    subscribe();

    let messages = get_stream_testcase(vec![
        vec![soa_record(3)],
        vec![soa_record(2)],
        vec![a_record(1)],
        vec![soa_record(3)],
        vec![a_record(2)],
        vec![soa_record(3)],
        // Expected never to be yielded: the stream reports the end of the
        // transfer after the SOA above.
        vec![a_record(3)],
    ]);

    let mut xfr = ClientStreamXfr::new(messages, true);
    assert!(matches!(xfr.state, Start { .. }));

    // Pull one response, then check the resulting state and that the response
    // holds a single answer.
    macro_rules! next_leaves_state {
        ($expected:pat) => {{
            let response = xfr.next().await.unwrap().unwrap();
            assert!(matches!(xfr.state, $expected));
            assert_eq!(response.answers.len(), 1);
        }};
    }

    next_leaves_state!(Second { .. });
    next_leaves_state!(Ixfr { even: true, .. });
    next_leaves_state!(Ixfr { even: true, .. });
    next_leaves_state!(Ixfr { even: false, .. });
    next_leaves_state!(Ixfr { even: false, .. });
    next_leaves_state!(Ended);

    assert!(xfr.next().await.is_none());
}
/// Queries the `A` records of `dns.google.` over TCP against `8.8.8.8:53` and
/// checks the answers twice: once on the message returned by the client and
/// once on a message re-parsed from the raw response buffer.
///
/// This test talks to a real resolver, so it needs network access.
#[tokio::test]
async fn async_client() {
    use crate::client::{Client, ClientHandle};
    use crate::{
        proto::rr::{DNSClass, Name, RData, RecordType},
        tcp::TcpClientStream,
    };
    use core::str::FromStr;

    /// Asserts that the `A` records in `answers`, sorted by numeric address,
    /// equal `expected`.
    fn assert_a_records_match(answers: &[Record], expected: &[A]) {
        let mut found = Vec::new();
        for record in answers {
            if let RData::A(addr) = &record.data {
                found.push(*addr);
            }
        }
        // Sort so the comparison does not depend on answer order.
        found.sort_by_key(|a| u32::from(**a));
        assert_eq!(found, expected);
    }

    subscribe();

    // Connect over TCP and start the client's background task.
    let server = SocketAddr::from(([8, 8, 8, 8], 53));
    let (connecting, sender) =
        TcpClientStream::new(server, None, None, TokioRuntimeProvider::new());
    let tcp_stream = connecting.await.unwrap();
    let (mut client, background) = Client::<TokioRuntimeProvider>::new(tcp_stream, sender);
    tokio::spawn(background);

    let pending = client.query(
        Name::from_str("dns.google.").unwrap(),
        DNSClass::IN,
        RecordType::A,
    );
    let (message_returned, buffer) = pending.await.unwrap().into_parts();

    // Listed in ascending numeric order to match the sorted result.
    let expected_answers = [A::new(8, 8, 4, 4), A::new(8, 8, 8, 8)];
    assert_a_records_match(&message_returned.answers, &expected_answers);

    let message_parsed = Message::from_vec(&buffer)
        .expect("buffer was parsed already by Client so we should be able to do it again");
    assert_a_records_match(&message_parsed.answers, &expected_answers);
}