use core::future::{ready, Future};
use core::ops::Deref;
use core::pin::Pin;
use core::str::FromStr;
use std::boxed::Box;
use std::collections::{HashMap, VecDeque};
use std::fs::File;
use std::net::SocketAddr;
use std::path::{Path, PathBuf};
use std::result::Result;
use std::string::{String, ToString};
use std::sync::Arc;
use std::time::Duration;
use std::vec::Vec;
use rstest::rstest;
use tracing::instrument;
use tracing::{trace, warn};
use crate::base::iana::{Class, Rcode};
use crate::base::name::ToName;
use crate::base::net::IpAddr;
use crate::base::Name;
use crate::base::Rtype;
use crate::base::Serial;
use crate::logging::init_logging;
use crate::net::client::request::{RequestMessage, RequestMessageMulti};
use crate::net::client::{dgram, stream, tsig};
use crate::net::server;
use crate::net::server::buf::VecBufSource;
use crate::net::server::dgram::DgramServer;
use crate::net::server::message::Request;
use crate::net::server::middleware::cookies::CookiesMiddlewareSvc;
use crate::net::server::middleware::edns::EdnsMiddlewareSvc;
use crate::net::server::middleware::mandatory::MandatoryMiddlewareSvc;
use crate::net::server::middleware::notify::{
Notifiable, NotifyError, NotifyMiddlewareSvc,
};
use crate::net::server::middleware::tsig::TsigMiddlewareSvc;
use crate::net::server::middleware::xfr::XfrMiddlewareSvc;
use crate::net::server::service::{CallResult, Service, ServiceResult};
use crate::net::server::stream::StreamServer;
use crate::net::server::util::{mk_builder_for_target, service_fn};
use crate::stelline::channel::ClientServerChannel;
use crate::stelline::client::{
do_client, Client, ClientFactory, CurrStepValue,
PerClientAddressClientFactory, QueryTailoredClientFactory,
};
use crate::stelline::parse_stelline::{self, parse_file, Config, Matches};
use crate::stelline::simple_dgram_client;
use crate::tsig::{Algorithm, Key, KeyName, KeyStore};
use crate::utils::base16;
use crate::zonefile::inplace::Zonefile;
use crate::zonetree::{Answer, Zone};
use crate::zonetree::{StoredName, ZoneBuilder, ZoneTree};
/// Runs one Stelline `.rpl` script from `test-data/server/` against an
/// in-process datagram server and stream server.
///
/// The server settings (zones, cookies, EDNS TCP keepalive) are taken from
/// the `server:` block of the script's config via `parse_server_config`. A
/// single TSIG key named `TESTKEY` (HMAC-SHA256, all-zero key material) is
/// made available to both the server and the clients.
///
/// # Panics
///
/// Panics if the script or any referenced zone data cannot be opened or
/// parsed, or if a configured cookie secret is not 16 bytes of base16.
#[instrument(skip_all, fields(rpl = rpl_file.file_name().unwrap().to_str()))]
#[rstest]
#[tokio::test(start_paused = true)]
async fn server_tests(#[files("test-data/server/*.rpl")] rpl_file: PathBuf) {
    // `KeyType` is needed for the `len()` call on the HMAC algorithm below.
    use ring::{hkdf::KeyType, hmac};

    init_logging();

    // Load the test script and the server settings embedded in its config.
    let file = File::open(&rpl_file).unwrap();
    let stelline = parse_file(&file, rpl_file.to_str().unwrap());
    let server_config = parse_server_config(&stelline.config);

    // Create a TSIG key store holding a single key called 'TESTKEY'.
    let mut key_store = TestKeyStore::new();
    let key_name = KeyName::from_str("TESTKEY").unwrap();
    let key_bytes = vec![0u8; hmac::HMAC_SHA256.len()];
    let key =
        Key::new(Algorithm::Sha256, &key_bytes, key_name.clone(), None, None)
            .unwrap();
    key_store.insert((key_name, Algorithm::Sha256), key.into());
    let key_store = Arc::new(key_store);

    // In-memory channels that connect the test clients to the servers.
    let dgram_server_conn = ClientServerChannel::new_dgram();
    let stream_server_conn = ClientServerChannel::new_stream();

    // Populate the zone tree. Zone file content takes precedence; a bare zone
    // name yields an empty zone; with neither, the tree stays empty.
    let mut zones = ZoneTree::new();
    let ServerZone {
        zone_name,
        zone_files,
    } = &server_config.zone;
    if !zone_files.is_empty() {
        for zone_file in zone_files {
            zones
                .insert_zone(Zone::try_from(zone_file.clone()).unwrap())
                .unwrap();
        }
    } else if let Some(zone_name) = zone_name {
        let builder = ZoneBuilder::new(
            Name::from_str(zone_name).unwrap(),
            Class::IN,
        );
        zones.insert_zone(builder.build()).unwrap();
    }
    let zones = Arc::new(zones);

    // Cookies are only active when they are enabled AND a secret was given.
    let cookie_secret = if server_config.cookies.enabled {
        server_config.cookies.secret
    } else {
        None
    };
    let with_cookies = cookie_secret.is_some();
    let secret = match cookie_secret {
        Some(hex) => {
            let bytes = base16::decode_vec(hex).unwrap();
            <[u8; 16]>::try_from(bytes).unwrap()
        }
        None => Default::default(),
    };

    // Build the layered service. Each middleware wraps the service created
    // on the line before it, so `test_service` is the innermost layer and
    // the TSIG middleware is the outermost one.
    let svc = service_fn(test_service, zones.clone());
    let svc = CookiesMiddlewareSvc::new(svc, secret)
        .with_denied_ips(server_config.cookies.ip_deny_list.clone())
        .enable(with_cookies);
    let svc =
        EdnsMiddlewareSvc::new(svc).enable(server_config.edns_tcp_keepalive);
    let svc = XfrMiddlewareSvc::<Vec<u8>, _, Option<Arc<Key>>, _>::new(
        svc, zones, 1,
    );
    let svc = NotifyMiddlewareSvc::new(svc, TestNotifyTarget);
    let svc = MandatoryMiddlewareSvc::new(svc);
    let svc = TsigMiddlewareSvc::new(svc, key_store.clone());

    // Start the servers and create the matching client factory.
    let (dgram_srv, stream_srv) = mk_servers(
        svc,
        &server_config,
        dgram_server_conn.clone(),
        stream_server_conn.clone(),
    );
    let client_factory =
        mk_client_factory(dgram_server_conn, stream_server_conn, key_store);

    // Run the Stelline script against the servers.
    let step_value = Arc::new(CurrStepValue::new());
    do_client(&stelline, &step_value, client_factory).await;

    // Give both servers a bounded amount of time to finish; a timeout is
    // only logged, it does not fail the test.
    if !dgram_srv.await_shutdown(Duration::from_secs(5)).await {
        warn!("Datagram server did not shutdown on time.");
    }
    if !stream_srv.await_shutdown(Duration::from_secs(5)).await {
        warn!("Stream server did not shutdown on time.");
    }
}
/// Creates a datagram server and a stream server that share `service`, and
/// starts each of them on its own spawned Tokio task.
///
/// The per-transport configuration is derived from `server_config` by
/// `mk_server_configs`. The returned handles refer to the running servers so
/// that the caller can await their shutdown.
#[allow(clippy::type_complexity)]
fn mk_servers<Svc>(
    service: Svc,
    server_config: &ServerConfig<'_>,
    dgram_server_conn: ClientServerChannel,
    stream_server_conn: ClientServerChannel,
) -> (
    Arc<DgramServer<ClientServerChannel, VecBufSource, Svc>>,
    Arc<StreamServer<ClientServerChannel, VecBufSource, Svc>>,
)
where
    Svc: Service<Vec<u8>, ()> + Clone,
{
    let (dgram_config, stream_config) = mk_server_configs(server_config);

    // Datagram server: gets its own clone of the service.
    let dgram_server = Arc::new(DgramServer::with_config(
        dgram_server_conn.clone(),
        VecBufSource,
        service.clone(),
        dgram_config,
    ));
    tokio::spawn({
        let server = Arc::clone(&dgram_server);
        async move { server.run().await }
    });

    // Stream server: takes ownership of the original service value.
    let stream_server = Arc::new(StreamServer::with_config(
        stream_server_conn.clone(),
        VecBufSource,
        service,
        stream_config,
    ));
    tokio::spawn({
        let server = Arc::clone(&stream_server);
        async move { server.run().await }
    });

    (dgram_server, stream_server)
}
/// Builds the client factory handed to the Stelline runner.
///
/// Two per-client-address factories are combined, in this order:
///
/// 1. a stream client factory, selected only for entries that match on TCP;
/// 2. a datagram client factory, selected for every other entry.
///
/// If an entry names a TSIG key that is present in `key_store` (looked up
/// with HMAC-SHA256), the created connection is wrapped in a TSIG signing
/// connection. Stream entries whose first question is an AXFR or IXFR query
/// get a multi-response client; everything else gets a single-response one.
fn mk_client_factory(
    dgram_server_conn: ClientServerChannel,
    stream_server_conn: ClientServerChannel,
    key_store: Arc<TestKeyStore>,
) -> impl ClientFactory {
    // --- Stream (TCP) clients -------------------------------------------
    let wants_tcp = |entry: &parse_stelline::Entry| {
        matches!(entry.matches, Matches { tcp: true, .. })
    };

    let tcp_key_store = key_store.clone();
    let tcp_client_factory = PerClientAddressClientFactory::new(
        move |source_addr, entry| {
            let stream = stream_server_conn
                .connect(Some(SocketAddr::new(*source_addr, 0)));

            // Zone transfers produce multiple responses per request.
            let is_xfr = matches!(
                entry.sections.question.first().map(|q| q.qtype()),
                Some(Rtype::AXFR | Rtype::IXFR)
            );

            let tsig_key = entry.key_name.as_ref().and_then(|key_name| {
                tcp_key_store.get_key(&key_name, Algorithm::Sha256)
            });

            match tsig_key {
                Some(key) => {
                    let (conn, transport) = stream::Connection::<
                        tsig::RequestMessage<RequestMessage<Vec<u8>>, Arc<Key>>,
                        tsig::RequestMessage<
                            RequestMessageMulti<Vec<u8>>,
                            Arc<Key>,
                        >,
                    >::new(stream);
                    tokio::spawn(transport.run());
                    let conn = Box::new(tsig::Connection::new(key, conn));
                    if is_xfr {
                        Client::Multi(conn)
                    } else {
                        Client::Single(conn)
                    }
                }
                None => {
                    let (conn, transport) = stream::Connection::<
                        RequestMessage<Vec<u8>>,
                        RequestMessageMulti<Vec<u8>>,
                    >::new(stream);
                    tokio::spawn(transport.run());
                    let conn = Box::new(conn);
                    if is_xfr {
                        Client::Multi(conn)
                    } else {
                        Client::Single(conn)
                    }
                }
            }
        },
        wants_tcp,
    );

    // --- Datagram (UDP) clients -----------------------------------------
    let wants_anything_else = |_: &_| true;

    let udp_client_factory = PerClientAddressClientFactory::new(
        move |source_addr, entry| {
            let connect = dgram_server_conn
                .new_client(Some(SocketAddr::new(*source_addr, 0)));

            let tsig_key = entry.key_name.as_ref().and_then(|key_name| {
                key_store.get_key(&key_name, Algorithm::Sha256)
            });

            // Pick the connection type from (TSIG key?, mock client?).
            match (tsig_key, entry.matches.mock_client) {
                (Some(key), true) => {
                    Client::Single(Box::new(tsig::Connection::new(
                        key,
                        simple_dgram_client::Connection::new(connect),
                    )))
                }
                (Some(key), false) => {
                    Client::Single(Box::new(tsig::Connection::new(
                        key,
                        dgram::Connection::new(connect),
                    )))
                }
                (None, true) => Client::Single(Box::new(
                    simple_dgram_client::Connection::new(connect),
                )),
                (None, false) => {
                    // The plain datagram client is configured with zero
                    // retries.
                    let mut config = dgram::Config::new();
                    config.set_max_retries(0);
                    Client::Single(Box::new(dgram::Connection::with_config(
                        connect, config,
                    )))
                }
            }
        },
        wants_anything_else,
    );

    // The stream factory is listed first, the catch-all datagram factory
    // second.
    QueryTailoredClientFactory::new(vec![
        Box::new(tcp_client_factory),
        Box::new(udp_client_factory),
    ])
}
/// Derives the datagram and stream server configurations from the parsed
/// test settings.
///
/// The datagram configuration is always the default one. The stream
/// configuration is the default one too, except that a configured
/// `idle_timeout` is applied to its connection configuration.
fn mk_server_configs(
    config: &ServerConfig<'_>,
) -> (server::dgram::Config, server::stream::Config) {
    let mut stream_config = server::stream::Config::default();

    match config.idle_timeout {
        Some(idle_timeout) => {
            let mut connection_config = server::ConnectionConfig::default();
            connection_config.set_idle_timeout(idle_timeout);
            stream_config.set_connection_config(connection_config);
        }
        None => {
            // No timeout configured: keep the default connection settings.
        }
    }

    (server::dgram::Config::default(), stream_config)
}
/// Innermost service of the test server: answers the request's sole question
/// from the zone tree.
///
/// If no zone in `zones` matches the question's name and class, the answer
/// is an NXDOMAIN response.
///
/// # Panics
///
/// Panics if the request does not contain exactly one question, or if the
/// zone query itself fails.
#[allow(clippy::type_complexity)]
fn test_service<RequestMeta>(
    request: Request<Vec<u8>, RequestMeta>,
    zones: Arc<ZoneTree>,
) -> ServiceResult<Vec<u8>> {
    let question = request.message().sole_question().unwrap();

    let answer = if let Some(zone) =
        zones.find_zone(question.qname(), question.qclass())
    {
        zone.read()
            .query(question.qname().to_bytes(), question.qtype())
            .unwrap()
    } else {
        Answer::new(Rcode::NXDOMAIN)
    };

    // Render the answer into a response message for this request.
    let builder = mk_builder_for_target();
    let response = answer.to_message(request.message(), builder);
    Ok(CallResult::new(response))
}
/// Zone data the test server should serve, as collected by
/// `parse_server_config`.
#[derive(Default)]
struct ServerZone {
/// Value of the `zone` setting, if one was present.
zone_name: Option<String>,
/// Zone files loaded from `local-file` settings, followed by one zone file
/// built from the accumulated `local-data` settings (if there were any).
zone_files: Vec<Zonefile>,
}
/// Server settings extracted from the `server:` block of a Stelline config.
///
/// The lifetime `'a` ties borrowed string values to the config text they
/// were parsed from.
#[derive(Default)]
struct ServerConfig<'a> {
/// DNS cookie related settings.
cookies: CookieConfig<'a>,
/// Set when the config contains `edns-tcp-keepalive: yes`.
edns_tcp_keepalive: bool,
/// Parsed from `edns-tcp-keepalive-timeout` (in milliseconds); only
/// recorded if `edns_tcp_keepalive` was already set when that line was read.
idle_timeout: Option<Duration>,
/// The zone data to serve.
zone: ServerZone,
}
/// DNS cookie settings extracted from the `server:` block of a Stelline
/// config.
#[derive(Default)]
struct CookieConfig<'a> {
/// Set when the config contains `answer-cookie: yes`.
enabled: bool,
/// Value of `cookie-secret` with surrounding double quotes removed. The
/// test decodes it as base16 and expects exactly 16 bytes.
secret: Option<&'a str>,
/// Addresses taken from `access-control: <ip> allow_cookie` lines; passed
/// to the cookies middleware as its denied IP list.
ip_deny_list: Vec<IpAddr>,
}
/// Extracts the settings of the `server:` block from a Stelline config.
///
/// A line starting with `server:` opens the block. Every following line that
/// starts with whitespace is read as a `setting: value` pair, with anything
/// after a `#` in the value discarded. The first line that does not start
/// with whitespace closes the block. Lines outside the block, and lines
/// inside it without a `:`, are skipped.
///
/// Unknown settings and unusable `access-control` values are reported on
/// stderr and otherwise ignored.
///
/// # Panics
///
/// Panics if a `local-file` cannot be opened or loaded, if the collected
/// `local-data` cannot be loaded as a zone file, or if the value of
/// `edns-tcp-keepalive-timeout` is not a valid number.
fn parse_server_config(config: &Config) -> ServerConfig<'_> {
    let mut parsed_config = ServerConfig::default();
    let mut zone_name = None;
    let mut zone_files = vec![];
    // Zone file text assembled line by line from `local-data` settings.
    let mut zone_file_bytes = VecDeque::<u8>::new();
    let mut in_server_block = false;

    for line in config.lines() {
        if line.starts_with("server:") {
            in_server_block = true;
            continue;
        }
        if !in_server_block {
            continue;
        }
        if !line.starts_with(char::is_whitespace) {
            // A non-indented line ends the server block.
            in_server_block = false;
            continue;
        }

        let (setting, value) = match line.trim().split_once(':') {
            Some((setting, value)) => (setting.trim(), value),
            None => continue,
        };
        // Drop a trailing `# comment`, if any.
        let value = value.split('#').next().unwrap_or(value).trim();

        match (setting, value) {
            ("answer-cookie", "yes") => parsed_config.cookies.enabled = true,
            ("cookie-secret", v) => {
                parsed_config.cookies.secret = Some(v.trim_matches('"'));
            }
            ("access-control", v) => {
                // Expected form: `<ip> <action>`; anything without
                // whitespace is silently skipped.
                if let Some((ip, action)) = v.split_once(char::is_whitespace)
                {
                    if action == "allow_cookie" {
                        match ip.parse() {
                            Ok(addr) => {
                                parsed_config.cookies.ip_deny_list.push(addr)
                            }
                            Err(_) => {
                                eprintln!("Ignoring malformed IP address '{ip}' in 'access-control' setting");
                            }
                        }
                    } else {
                        eprintln!("Ignoring unknown action '{action}' for 'access-control' setting");
                    }
                }
            }
            ("local-data", v) => {
                // Each entry becomes one line of the in-memory zone file.
                zone_file_bytes.extend(v.trim_matches('"').bytes());
                zone_file_bytes.push_back(b'\n');
            }
            ("local-file", v) => {
                let zone_path =
                    Path::new("test-data").join(v.trim_matches('"'));
                let mut zone_source = File::open(zone_path).unwrap();
                zone_files.push(Zonefile::load(&mut zone_source).unwrap());
            }
            ("edns-tcp-keepalive", "yes") => {
                parsed_config.edns_tcp_keepalive = true;
            }
            ("edns-tcp-keepalive-timeout", v) => {
                // Only honoured if keepalive was enabled on an earlier line.
                if parsed_config.edns_tcp_keepalive {
                    let millis = v.parse().unwrap();
                    parsed_config.idle_timeout =
                        Some(Duration::from_millis(millis));
                }
            }
            ("zone", v) => {
                // A zone setting discards `local-data` collected so far.
                zone_file_bytes = Default::default();
                zone_name = Some(v.to_string());
            }
            _ => {
                eprintln!("Ignoring unknown server setting '{setting}' with value: {value:?}");
            }
        }
    }

    // Turn any collected `local-data` into one more zone file.
    if !zone_file_bytes.is_empty() {
        zone_files.push(Zonefile::load(&mut zone_file_bytes).unwrap());
    }

    parsed_config.zone = ServerZone {
        zone_name,
        zone_files,
    };
    parsed_config
}
/// Stateless NOTIFY target handed to the notify middleware of the test
/// server; its response depends only on the name of the notified zone.
#[derive(Copy, Clone, Default, Debug)]
struct TestNotifyTarget;
impl Notifiable for TestNotifyTarget {
    /// Logs the notification and returns an already-completed future whose
    /// result is selected by the lower-cased zone name:
    ///
    /// - `example.com` is accepted,
    /// - `othererror.com` fails with `NotifyError::Other`,
    /// - any other zone fails with `NotifyError::NotAuthForZone`.
    fn notify_zone_changed(
        &self,
        class: Class,
        apex_name: &StoredName,
        serial: Option<Serial>,
        source: IpAddr,
    ) -> Pin<
        Box<dyn Future<Output = Result<(), NotifyError>> + Sync + Send + '_>,
    > {
        trace!("Notify received from {source} of change to zone {apex_name} in class {class} with serial {serial:?}");

        let zone = apex_name.to_string().to_lowercase();
        let outcome = if zone == "example.com" {
            Ok(())
        } else if zone == "othererror.com" {
            Err(NotifyError::Other)
        } else {
            Err(NotifyError::NotAuthForZone)
        };

        Box::pin(ready(outcome))
    }
}
/// TSIG key store used by the tests: maps a (key name, algorithm) pair to a
/// shared key.
type TestKeyStore = HashMap<(KeyName, Algorithm), Arc<Key>>;
/// Lets a shared (`Arc`-wrapped) test key store be used wherever a
/// `KeyStore` is required.
impl KeyStore for Arc<TestKeyStore> {
    type Key = Arc<Key>;

    /// Looks up the key for `name` and `algorithm` by forwarding to
    /// `get_key` on the wrapped map.
    fn get_key<N: ToName>(
        &self,
        name: &N,
        algorithm: Algorithm,
    ) -> Option<Self::Key> {
        // Dereference to the inner map first so that the call resolves to
        // the map's `get_key` instead of recursing into this implementation.
        let inner: &TestKeyStore = self.deref();
        inner.get_key(name, algorithm)
    }
}