use core::fmt;
use core::future::{ready, Future, Ready};
use core::sync::atomic::{AtomicBool, AtomicU8, Ordering};
use core::task::{Context, Poll};
use core::time::Duration;
use std::fs::File;
use std::io;
use std::io::BufReader;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::RwLock;
use std::vec::Vec;
use futures_util::stream::{once, Empty, Once, Stream};
use octseq::{FreezeBuilder, Octets};
use tokio::net::{TcpListener, TcpSocket, TcpStream, UdpSocket};
use tokio::sync::mpsc::unbounded_channel;
use tokio::time::Instant;
use tokio_rustls::rustls;
use tokio_rustls::TlsAcceptor;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_tfo::{TfoListener, TfoStream};
use tracing_subscriber::EnvFilter;
use domain::base::iana::{Class, Rcode};
use domain::base::message_builder::{AdditionalBuilder, PushError};
use domain::base::name::ToLabelIter;
use domain::base::wire::Composer;
use domain::base::{MessageBuilder, Name, Rtype, Serial, StreamTarget, Ttl};
use domain::net::server::buf::VecBufSource;
use domain::net::server::dgram::DgramServer;
use domain::net::server::message::Request;
use domain::net::server::middleware::cookies::CookiesMiddlewareSvc;
use domain::net::server::middleware::edns::EdnsMiddlewareSvc;
use domain::net::server::middleware::mandatory::MandatoryMiddlewareSvc;
use domain::net::server::middleware::stream::{
MiddlewareStream, PostprocessingStream,
};
use domain::net::server::service::{
CallResult, Service, ServiceFeedback, ServiceResult,
};
use domain::net::server::sock::AsyncAccept;
use domain::net::server::stream::StreamServer;
use domain::net::server::util::{mk_builder_for_target, service_fn};
use domain::rdata::{Soa, A};
/// Builds a NOERROR response to `msg` whose answer section holds a single
/// `IN A 192.0.2.1` record (TTL 86400) owned by the root name.
///
/// The question section of `msg` is copied into the response by
/// `start_answer`.
///
/// # Errors
///
/// Returns [`PushError`] if the response header/question or the answer
/// record cannot be written into the builder's target.
fn mk_answer<Target>(
    msg: &Request<Vec<u8>, ()>,
    builder: MessageBuilder<StreamTarget<Target>>,
) -> Result<AdditionalBuilder<StreamTarget<Target>>, PushError>
where
    Target: Octets + Composer + FreezeBuilder<Octets = Target>,
    <Target as octseq::OctetsBuilder>::AppendError: fmt::Debug,
{
    // Propagate instead of panicking: this function already reports
    // `PushError` to its caller.
    let mut answer = builder.start_answer(msg.message(), Rcode::NOERROR)?;
    answer.push((
        Name::root_ref(),
        Class::IN,
        86400,
        A::from_octets(192, 0, 2, 1),
    ))?;
    Ok(answer.additional())
}
/// Builds a NOERROR response to `msg` whose answer section holds a single
/// SOA record (TTL 86390) owned by the root name, with MNAME
/// `a.root-servers.net`, RNAME `nstld.verisign-grs.com`, serial 2020081701
/// and refresh/retry/expire/minimum of 1800/900/604800/86400 seconds.
///
/// # Errors
///
/// Returns [`PushError`] if the response header/question or the SOA record
/// cannot be written into the builder's target.
fn mk_soa_answer<Target>(
    msg: &Request<Vec<u8>, ()>,
    builder: MessageBuilder<StreamTarget<Target>>,
) -> Result<AdditionalBuilder<StreamTarget<Target>>, PushError>
where
    Target: Octets + Composer + FreezeBuilder<Octets = Target>,
    <Target as octseq::OctetsBuilder>::AppendError: fmt::Debug,
{
    // Both names are fixed, well-formed literals; parsing cannot fail.
    let mname: Name<Vec<u8>> = "a.root-servers.net"
        .parse()
        .expect("valid domain name literal");
    let rname = "nstld.verisign-grs.com"
        .parse()
        .expect("valid domain name literal");
    // Propagate instead of panicking: this function already reports
    // `PushError` to its caller.
    let mut answer = builder.start_answer(msg.message(), Rcode::NOERROR)?;
    answer.push((
        Name::root_slice(),
        86390,
        Soa::new(
            mname,
            rname,
            Serial(2020081701),
            Ttl::from_secs(1800),
            Ttl::from_secs(900),
            Ttl::from_secs(604800),
            Ttl::from_secs(86400),
        ),
    ))?;
    Ok(answer.additional())
}
/// A service that answers every request with exactly one response, the one
/// produced by `mk_answer`, delivered as a single-item stream that is
/// available immediately.
#[derive(Clone)]
struct MySingleResultService;

impl Service<Vec<u8>, ()> for MySingleResultService {
    type Target = Vec<u8>;
    type Stream = Once<Ready<ServiceResult<Self::Target>>>;
    type Future = Ready<Self::Stream>;

    fn call(&self, request: Request<Vec<u8>, ()>) -> Self::Future {
        // Panics if the answer cannot be built.
        let response = mk_answer(&request, mk_builder_for_target()).unwrap();
        let only_item: ServiceResult<Self::Target> =
            Ok(CallResult::new(response));
        ready(once(ready(only_item)))
    }
}
/// A service that answers AXFR queries with a stream of three responses
/// (SOA, A, SOA) produced by a spawned task, one every 100 ms, and answers
/// every other query immediately with NOTIMP.
#[derive(Clone)]
struct MyAsyncStreamingService;

impl Service<Vec<u8>, ()> for MyAsyncStreamingService {
    type Target = Vec<u8>;
    type Stream =
        Pin<Box<dyn Stream<Item = ServiceResult<Self::Target>> + Send>>;
    type Future = Pin<Box<dyn Future<Output = Self::Stream> + Send>>;

    fn call(&self, request: Request<Vec<u8>, ()>) -> Self::Future {
        Box::pin(async move {
            let is_axfr = matches!(
                request
                    .message()
                    .sole_question()
                    .map(|q| q.qtype() == Rtype::AXFR),
                Ok(true)
            );
            if !is_axfr {
                let builder = mk_builder_for_target();
                let additional = builder
                    .start_answer(request.message(), Rcode::NOTIMP)
                    .unwrap()
                    .additional();
                let item = Ok(CallResult::new(additional));
                let immediate_result = once(ready(item));
                return Box::pin(immediate_result) as Self::Stream;
            }

            let (sender, receiver) = unbounded_channel();
            tokio::spawn(async move {
                // Mimic the framing of a zone transfer: SOA, A, SOA.
                for step in 0..3 {
                    tokio::time::sleep(Duration::from_millis(100)).await;
                    let builder = mk_builder_for_target();
                    let additional = if step == 1 {
                        mk_answer(&request, builder)
                    } else {
                        mk_soa_answer(&request, builder)
                    }
                    .unwrap();
                    let item = Ok(CallResult::new(additional));
                    // `send` only fails once the receiver (and thus the
                    // response stream) has been dropped. Nobody is
                    // listening any more, so stop instead of panicking.
                    if sender.send(item).is_err() {
                        return;
                    }
                }
            });
            // The stream ends when the task finishes and drops `sender`.
            Box::pin(UnboundedReceiverStream::new(receiver)) as Self::Stream
        })
    }
}
#[allow(clippy::type_complexity)]
fn name_to_ip(
request: Request<Vec<u8>, ()>,
_: (),
) -> ServiceResult<Vec<u8>> {
let mut out_answer = None;
if let Ok(question) = request.message().sole_question() {
let qname = question.qname();
let num_labels = qname.label_count();
if num_labels >= 5 {
let mut iter = qname.iter_labels();
let a = iter.nth(num_labels - 5).unwrap();
let b = iter.next().unwrap();
let c = iter.next().unwrap();
let d = iter.next().unwrap();
let a_rec: Result<A, _> = format!("{a}.{b}.{c}.{d}").parse();
if let Ok(a_rec) = a_rec {
let builder = mk_builder_for_target();
let mut answer = builder
.start_answer(request.message(), Rcode::NOERROR)
.unwrap();
answer
.push((Name::root_ref(), Class::IN, 86400, a_rec))
.unwrap();
out_answer = Some(answer);
}
}
}
if out_answer.is_none() {
let builder = mk_builder_for_target();
eprintln!("Refusing request, only requests for A records in IPv4 dotted quad format are accepted by this service.");
out_answer = Some(
builder
.start_answer(request.message(), Rcode::REFUSED)
.unwrap(),
);
}
let additional = out_answer.unwrap().additional();
Ok(CallResult::new(additional))
}
/// Service function that answers like `mk_answer` and also asks the server
/// to reconfigure the connection idle timeout.
///
/// Each call decrements `count` (saturating at zero). The idle timeout is
/// `50 ms × (value before the decrement)`, so successive calls shrink it,
/// e.g. 250, 200, ... 0 ms when `count` starts at 5.
fn query(
    request: Request<Vec<u8>, ()>,
    count: Arc<AtomicU8>,
) -> ServiceResult<Vec<u8>> {
    // `fetch_update` yields the previous value. The closure always returns
    // `Some`, so the update cannot fail.
    let cnt = count
        .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |x| {
            Some(x.saturating_sub(1))
        })
        .expect("update closure always returns Some");
    // Widen before multiplying: `50 * cnt` in `u8` overflows for cnt > 5.
    let idle_timeout = Duration::from_millis(50 * u64::from(cnt));
    let cmd = ServiceFeedback::Reconfigure {
        idle_timeout: Some(idle_timeout),
    };
    eprintln!("Setting idle timeout to {idle_timeout:?}");
    let builder = mk_builder_for_target();
    let answer = mk_answer(&request, builder)?;
    Ok(CallResult::new(answer).with_feedback(cmd))
}
/// Accepts TCP connections from two underlying listeners through a single
/// [`AsyncAccept`] implementation.
struct DoubleListener {
    /// One of the two listeners.
    a: TcpListener,
    /// The other listener.
    b: TcpListener,
    /// Flipped on every `poll_accept`; its previous value decides which
    /// listener is polled first.
    alt: AtomicBool,
}

impl DoubleListener {
    /// Combines `a` and `b`. The first `poll_accept` polls `a` first.
    fn new(a: TcpListener, b: TcpListener) -> Self {
        Self {
            a,
            b,
            alt: AtomicBool::new(false),
        }
    }
}

impl AsyncAccept for DoubleListener {
    type Error = io::Error;
    type StreamType = TcpStream;
    type Future = Ready<Result<Self::StreamType, io::Error>>;

    fn poll_accept(
        &self,
        cx: &mut Context,
    ) -> Poll<Result<(Self::Future, SocketAddr), io::Error>> {
        // Alternate the polling order so one busy listener does not always
        // take precedence over the other.
        let a_first = !self.alt.fetch_xor(true, Ordering::SeqCst);
        let order = if a_first {
            [&self.a, &self.b]
        } else {
            [&self.b, &self.a]
        };
        for listener in order {
            if let Poll::Ready(res) = TcpListener::poll_accept(listener, cx) {
                return Poll::Ready(
                    res.map(|(stream, addr)| (ready(Ok(stream)), addr)),
                );
            }
        }
        // Both listeners returned `Pending` (and registered `cx`'s waker).
        Poll::Pending
    }
}
/// Newtype around [`TfoListener`] (TCP Fast Open). Both `TfoListener` and
/// [`AsyncAccept`] come from other crates, so the trait can only be
/// implemented on a local wrapper type.
struct LocalTfoListener(TfoListener);

impl std::ops::Deref for LocalTfoListener {
    type Target = TfoListener;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl std::ops::DerefMut for LocalTfoListener {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

impl AsyncAccept for LocalTfoListener {
    type Error = io::Error;
    type StreamType = TfoStream;
    type Future = Ready<Result<Self::StreamType, io::Error>>;

    fn poll_accept(
        &self,
        cx: &mut Context,
    ) -> Poll<Result<(Self::Future, SocketAddr), io::Error>> {
        // The accepted stream needs no further setup, so it is handed back
        // as an already-completed future.
        match TfoListener::poll_accept(&self.0, cx) {
            Poll::Ready(Ok((stream, addr))) => {
                Poll::Ready(Ok((ready(Ok(stream)), addr)))
            }
            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
            Poll::Pending => Poll::Pending,
        }
    }
}
/// Newtype around [`TcpListener`] whose accepted streams are wrapped in a
/// [`tokio::io::BufReader`].
struct BufferedTcpListener(TcpListener);

impl std::ops::Deref for BufferedTcpListener {
    type Target = TcpListener;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl std::ops::DerefMut for BufferedTcpListener {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

impl AsyncAccept for BufferedTcpListener {
    type Error = io::Error;
    type StreamType = tokio::io::BufReader<TcpStream>;
    type Future = Ready<Result<Self::StreamType, io::Error>>;

    fn poll_accept(
        &self,
        cx: &mut Context,
    ) -> Poll<Result<(Self::Future, SocketAddr), io::Error>> {
        TcpListener::poll_accept(&self.0, cx).map(|accepted| {
            accepted.map(|(stream, addr)| {
                let buffered = tokio::io::BufReader::new(stream);
                (ready(Ok(buffered)), addr)
            })
        })
    }
}
/// A TCP listener that runs a server-side TLS handshake, using a
/// `tokio_rustls` acceptor, on every accepted connection.
pub struct RustlsTcpListener {
    /// Underlying plain TCP listener.
    listener: TcpListener,
    /// TLS acceptor used to start the handshake on each accepted stream.
    acceptor: tokio_rustls::TlsAcceptor,
}

impl RustlsTcpListener {
    /// Wraps `listener`, using `acceptor` for the TLS handshake of each
    /// accepted connection.
    pub fn new(
        listener: TcpListener,
        acceptor: tokio_rustls::TlsAcceptor,
    ) -> Self {
        Self { listener, acceptor }
    }
}

impl AsyncAccept for RustlsTcpListener {
    type Error = io::Error;
    type StreamType = tokio_rustls::server::TlsStream<TcpStream>;
    type Future = tokio_rustls::Accept<TcpStream>;

    #[allow(clippy::type_complexity)]
    fn poll_accept(
        &self,
        cx: &mut Context,
    ) -> Poll<Result<(Self::Future, SocketAddr), io::Error>> {
        match TcpListener::poll_accept(&self.listener, cx) {
            Poll::Ready(Ok((tcp, peer))) => {
                // The handshake is not driven here. It runs as the returned
                // `Accept` future is polled.
                let handshake = self.acceptor.accept(tcp);
                Poll::Ready(Ok((handshake, peer)))
            }
            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
            Poll::Pending => Poll::Pending,
        }
    }
}
/// Aggregated request/response counters, rendered as a one-line summary by
/// the `Display` implementation.
#[derive(Default)]
pub struct Stats {
    /// Longest observed request duration, `None` until one is recorded.
    slowest_req: Option<Duration>,
    /// Shortest observed request duration, `None` until one is recorded.
    fastest_req: Option<Duration>,
    /// Total bytes of received requests.
    num_req_bytes: u32,
    /// Total bytes of sent responses.
    num_resp_bytes: u32,
    /// Number of requests.
    num_reqs: u32,
    /// Requests from IPv4 clients.
    num_ipv4: u32,
    /// Requests from non-IPv4 (IPv6) clients.
    num_ipv6: u32,
    /// Requests received over UDP.
    num_udp: u32,
}

impl std::fmt::Display for Stats {
    /// Shows the fastest time in microseconds and the slowest in
    /// milliseconds; `-` means no value has been recorded yet.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let fastest = match self.fastest_req {
            Some(d) => format!("{}μs", d.as_micros()),
            None => String::from("-"),
        };
        let slowest = match self.slowest_req {
            Some(d) => format!("{}ms", d.as_millis()),
            None => String::from("-"),
        };
        write!(
            f,
            "# Reqs={} [UDP={}, IPv4={}, IPv6={}] Bytes [rx={}, tx={}] Speed [fastest={}, slowest={}]",
            self.num_reqs,
            self.num_udp,
            self.num_ipv4,
            self.num_ipv6,
            self.num_req_bytes,
            self.num_resp_bytes,
            fastest,
            slowest,
        )
    }
}
/// Middleware that records request and response statistics into a shared
/// [`Stats`] value around calls to the wrapped service.
#[derive(Clone)]
pub struct StatsMiddlewareSvc<Svc> {
    /// The wrapped (inner) service.
    svc: Svc,
    /// Shared counters updated by this middleware.
    stats: Arc<RwLock<Stats>>,
}

impl<Svc> StatsMiddlewareSvc<Svc> {
    /// Wraps `svc`, recording statistics into `stats`.
    #[must_use]
    pub fn new(svc: Svc, stats: Arc<RwLock<Stats>>) -> Self {
        Self { svc, stats }
    }

    /// Takes the write lock on `stats`, recovering the guard if the lock
    /// was poisoned.
    ///
    /// The counters are best-effort diagnostics. A panic in some other lock
    /// holder should not make every later request panic as well.
    fn write_stats(
        stats: &RwLock<Stats>,
    ) -> std::sync::RwLockWriteGuard<'_, Stats> {
        stats
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Counts an incoming request: total requests, request bytes, UDP
    /// requests, and IPv4 vs. other (IPv6) clients.
    fn preprocess<RequestOctets>(&self, request: &Request<RequestOctets, ()>)
    where
        RequestOctets: Octets + Send + Sync + Unpin,
    {
        // Gather everything before locking to keep the critical section
        // short.
        let req_len = u32::try_from(request.message().as_slice().len())
            .unwrap_or(u32::MAX);
        let is_udp = request.transport_ctx().is_udp();
        let is_ipv4 = request.client_addr().is_ipv4();

        let mut stats = Self::write_stats(&self.stats);
        // Saturating arithmetic: a plain `+=` panics on overflow in debug
        // builds while the lock is held, which would poison it.
        stats.num_reqs = stats.num_reqs.saturating_add(1);
        stats.num_req_bytes = stats.num_req_bytes.saturating_add(req_len);
        if is_udp {
            stats.num_udp = stats.num_udp.saturating_add(1);
        }
        if is_ipv4 {
            stats.num_ipv4 = stats.num_ipv4.saturating_add(1);
        } else {
            stats.num_ipv6 = stats.num_ipv6.saturating_add(1);
        }
    }

    /// Records one response: adds its size to the response byte count and
    /// updates the fastest/slowest durations, measured from when the request
    /// was received.
    fn postprocess<RequestOctets>(
        request: &Request<RequestOctets, ()>,
        response: &AdditionalBuilder<StreamTarget<Svc::Target>>,
        stats: &RwLock<Stats>,
    ) where
        RequestOctets: Octets + Send + Sync + Unpin,
        Svc: Service<RequestOctets, ()>,
        Svc::Target: AsRef<[u8]>,
    {
        let duration = Instant::now().duration_since(request.received_at());
        let resp_len =
            u32::try_from(response.as_slice().len()).unwrap_or(u32::MAX);
        let mut stats = Self::write_stats(stats);
        stats.num_resp_bytes = stats.num_resp_bytes.saturating_add(resp_len);
        if duration < stats.fastest_req.unwrap_or(Duration::MAX) {
            stats.fastest_req = Some(duration);
        }
        if duration > stats.slowest_req.unwrap_or(Duration::ZERO) {
            stats.slowest_req = Some(duration);
        }
    }

    /// Stream-item hook: runs `postprocess` for every successful item that
    /// carries a response, then passes the item through unchanged.
    fn map_stream_item<RequestOctets>(
        request: Request<RequestOctets, ()>,
        stream_item: ServiceResult<Svc::Target>,
        stats: &mut Arc<RwLock<Stats>>,
    ) -> ServiceResult<Svc::Target>
    where
        RequestOctets: Octets + Send + Sync + Unpin,
        Svc: Service<RequestOctets, ()>,
        Svc::Target: AsRef<[u8]>,
    {
        if let Ok(cr) = &stream_item {
            if let Some(response) = cr.response() {
                Self::postprocess(&request, response, stats);
            }
        }
        stream_item
    }
}
/// Counts each request before forwarding it to the inner service. Every item
/// of the inner service's response stream is then passed through
/// `map_stream_item`.
impl<RequestOctets, Svc> Service<RequestOctets, ()>
    for StatsMiddlewareSvc<Svc>
where
    RequestOctets: Octets + Send + Sync + 'static + Unpin,
    Svc: Service<RequestOctets, ()>,
    Svc::Target: AsRef<[u8]>,
    Svc::Future: Unpin,
{
    type Target = Svc::Target;
    type Stream = MiddlewareStream<
        Svc::Future,
        Svc::Stream,
        PostprocessingStream<
            RequestOctets,
            Svc::Future,
            Svc::Stream,
            (),
            Arc<RwLock<Stats>>,
        >,
        Empty<ServiceResult<Self::Target>>,
        ServiceResult<Self::Target>,
    >;
    type Future = Ready<Self::Stream>;

    fn call(&self, request: Request<RequestOctets, ()>) -> Self::Future {
        self.preprocess(&request);
        // The inner service gets a clone; the original request is kept for
        // per-item post-processing.
        let inner_future = self.svc.call(request.clone());
        let shared_stats = Arc::clone(&self.stats);
        let postprocessed = PostprocessingStream::new(
            inner_future,
            request,
            shared_stats,
            Self::map_stream_item,
        );
        ready(MiddlewareStream::Map(postprocessed))
    }
}
/// Wraps `svc` in this example's standard middleware chain. From the inside
/// out, the layers are:
/// 1. cookies (only when the `siphasher` feature is enabled, with a random
///    secret);
/// 2. EDNS;
/// 3. mandatory processing;
/// 4. statistics collection into `stats` (outermost).
fn build_middleware_chain<Svc, Octs>(
    svc: Svc,
    stats: Arc<RwLock<Stats>>,
) -> impl Service<Octs, ()>
where
    Octs: Octets + Send + Sync + Clone + Unpin + 'static,
    Svc: Service<Octs, ()>,
    <Svc as Service<Octs, ()>>::Future: Unpin,
{
    #[cfg(feature = "siphasher")]
    let svc = CookiesMiddlewareSvc::<Octs, _, ()>::with_random_secret(svc);
    let svc = EdnsMiddlewareSvc::new(svc);
    let svc = MandatoryMiddlewareSvc::new(svc);
    // `stats` is owned and not used again, so move it instead of cloning.
    StatsMiddlewareSvc::new(svc, stats)
}
/// Starts several example DNS servers, each showing a different listener or
/// transport, and prints aggregated [`Stats`] every five seconds.
///
/// Endpoints (all on loopback):
/// - `127.0.0.1:8053` UDP: `name_to_ip` behind the middleware chain.
/// - `127.0.0.1:8053` TCP: `query`, with a pre-connect hook that enables TCP
///   keepalive.
/// - `127.0.0.1:8054` UDP (Linux only): `MySingleResultService`, with the
///   socket's `IP_MTU_DISCOVER` option set to OMIT (falling back to DONT).
/// - `127.0.0.1:8080` and `[::1]:8080` TCP: `MyAsyncStreamingService` via
///   one `DoubleListener`.
/// - `127.0.0.1:8081` TCP Fast Open: `MySingleResultService`.
/// - `127.0.0.1:8082` TCP with buffered streams: `query`.
/// - `127.0.0.1:8443` TLS (`examples/sample.pem` / `examples/sample.rsa`):
///   `MySingleResultService`.
///
/// # Panics
///
/// Panics if:
/// - a socket cannot be created or bound;
/// - the TLS certificate or key cannot be loaded;
/// - a server task panics.
#[tokio::main(flavor = "multi_thread")]
async fn main() {
    eprintln!("Test with commands such as:");
    eprintln!(" dig +short -4 @127.0.0.1 -p 8053 A 1.2.3.4");
    eprintln!(" dig +short -4 @127.0.0.1 +tcp -p 8053 A google.com");
    eprintln!(" dig +short -4 @127.0.0.1 -p 8054 A google.com");
    eprintln!(" dig +short -4 @127.0.0.1 +tcp -p 8080 AXFR google.com");
    eprintln!(" dig +short -6 @::1 +tcp -p 8080 AXFR google.com");
    eprintln!(" dig +short -4 @127.0.0.1 +tcp -p 8081 A google.com");
    eprintln!(" dig +short -4 @127.0.0.1 +tcp -p 8082 A google.com");
    eprintln!(" dig +short -4 @127.0.0.1 +tls -p 8443 A google.com");

    tracing_subscriber::fmt()
        .with_env_filter(EnvFilter::from_default_env())
        .with_thread_ids(true)
        .without_time()
        .try_init()
        .ok();

    // Shared by every statistics middleware below and by the reporter task.
    let stats = Arc::new(RwLock::new(Stats::default()));

    let my_svc = Arc::new(build_middleware_chain(
        MySingleResultService,
        stats.clone(),
    ));
    let my_async_svc = Arc::new(build_middleware_chain(
        MyAsyncStreamingService,
        stats.clone(),
    ));
    let name_into_ip_svc = Arc::new(build_middleware_chain(
        service_fn(name_to_ip, ()),
        stats.clone(),
    ));

    // Hand-built chain for `query`: mandatory, then (optionally) cookies
    // with a fixed secret, then stats.
    let count = Arc::new(AtomicU8::new(5));
    let svc = service_fn(query, count);
    let svc = MandatoryMiddlewareSvc::<Vec<u8>, _, _>::new(svc);
    #[cfg(feature = "siphasher")]
    let svc = {
        let server_secret = "server12secret34".as_bytes().try_into().unwrap();
        CookiesMiddlewareSvc::<Vec<u8>, _, _>::new(svc, server_secret)
    };
    let svc = StatsMiddlewareSvc::new(svc, stats.clone());
    let query_svc = Arc::new(svc);

    // UDP 127.0.0.1:8053
    let udpsocket = UdpSocket::bind("127.0.0.1:8053").await.unwrap();
    let buf = Arc::new(VecBufSource);
    let srv = DgramServer::new(udpsocket, buf.clone(), name_into_ip_svc);
    let udp_join_handle = tokio::spawn(async move { srv.run().await });

    // TCP 127.0.0.1:8053, with a hook run for each new connection.
    let v4socket = TcpSocket::new_v4().unwrap();
    v4socket.set_reuseaddr(true).unwrap();
    v4socket.bind("127.0.0.1:8053".parse().unwrap()).unwrap();
    let v4listener = v4socket.listen(1024).unwrap();
    let buf = Arc::new(VecBufSource);
    let srv = StreamServer::new(v4listener, buf.clone(), query_svc.clone());
    let srv = srv.with_pre_connect_hook(|stream| {
        eprintln!("TCP connection detected: enabling socket TCP keepalive.");
        let keep_alive = socket2::TcpKeepalive::new()
            .with_time(Duration::from_secs(20))
            .with_interval(Duration::from_secs(20));
        let socket = socket2::SockRef::from(&stream);
        // Report the failure instead of panicking inside the server.
        if let Err(err) = socket.set_tcp_keepalive(&keep_alive) {
            eprintln!("Failed to enable TCP keepalive: {err}");
            return;
        }
        eprintln!("Waiting for 5 seconds so you can run a command like:");
        eprintln!(" ss -nte | grep 8053 | grep keepalive");
        eprintln!("and see `timer:(keepalive,20sec,0) or similar.");
        // The hook is synchronous, so this blocks the calling thread.
        // NOTE(review): that is presumably a Tokio worker thread; acceptable
        // only for this demo.
        std::thread::sleep(Duration::from_secs(5));
    });
    let tcp_join_handle = tokio::spawn(async move { srv.run().await });

    // UDP 127.0.0.1:8054 with path-MTU discovery adjusted (Linux only).
    #[cfg(target_os = "linux")]
    let udp_mtu_join_handle = {
        /// Sets the IP-level `IP_MTU_DISCOVER` option on `socket` to `flag`.
        /// Returns the raw `setsockopt(2)` result: `-1` on error, with the
        /// cause in `errno`.
        fn setsockopt(socket: libc::c_int, flag: libc::c_int) -> libc::c_int {
            // SAFETY: the option pointer/length pair describes exactly
            // `flag`, a live local `c_int`, for the duration of the call,
            // and the kernel only reads through it.
            unsafe {
                libc::setsockopt(
                    socket,
                    // `IP_MTU_DISCOVER` is an `IPPROTO_IP`-level option (see
                    // ip(7)); at the `IPPROTO_UDP` level the kernel rejects
                    // it.
                    libc::IPPROTO_IP,
                    libc::IP_MTU_DISCOVER,
                    &flag as *const libc::c_int as *const libc::c_void,
                    std::mem::size_of_val(&flag) as libc::socklen_t,
                )
            }
        }

        let udpsocket = UdpSocket::bind("127.0.0.1:8054").await.unwrap();
        let fd = <UdpSocket as std::os::fd::AsRawFd>::as_raw_fd(&udpsocket);
        if setsockopt(fd, libc::IP_PMTUDISC_OMIT) == -1 {
            eprintln!(
                "setsockopt error when setting IP_MTU_DISCOVER to IP_PMTUDISC_OMIT, will retry with IP_PMTUDISC_DONT: {}",
                std::io::Error::last_os_error()
            );
            if setsockopt(fd, libc::IP_PMTUDISC_DONT) == -1 {
                eprintln!(
                    "setsockopt error when setting IP_MTU_DISCOVER to IP_PMTUDISC_DONT: {}",
                    std::io::Error::last_os_error()
                );
            }
        }
        let srv = DgramServer::new(udpsocket, buf.clone(), my_svc.clone());
        tokio::spawn(async move { srv.run().await })
    };

    // TCP on both 127.0.0.1:8080 and [::1]:8080 through one listener.
    let v4socket = TcpSocket::new_v4().unwrap();
    v4socket.set_reuseaddr(true).unwrap();
    v4socket.bind("127.0.0.1:8080".parse().unwrap()).unwrap();
    let v4listener = v4socket.listen(1024).unwrap();
    let v6socket = TcpSocket::new_v6().unwrap();
    v6socket.set_reuseaddr(true).unwrap();
    v6socket.bind("[::1]:8080".parse().unwrap()).unwrap();
    let v6listener = v6socket.listen(1024).unwrap();
    let listener = DoubleListener::new(v4listener, v6listener);
    let srv = StreamServer::new(listener, buf.clone(), my_async_svc);
    let double_tcp_join_handle = tokio::spawn(async move { srv.run().await });

    // TCP Fast Open on 127.0.0.1:8081.
    let listener = TfoListener::bind("127.0.0.1:8081".parse().unwrap())
        .await
        .unwrap();
    let listener = LocalTfoListener(listener);
    let srv = StreamServer::new(listener, buf.clone(), my_svc.clone());
    let tfo_join_handle = tokio::spawn(async move { srv.run().await });

    // Buffered TCP on 127.0.0.1:8082.
    let listener = TcpListener::bind("127.0.0.1:8082").await.unwrap();
    let listener = BufferedTcpListener(listener);
    let srv = StreamServer::new(listener, buf.clone(), query_svc);
    let fn_join_handle = tokio::spawn(async move { srv.run().await });

    // TLS on 127.0.0.1:8443.
    let certs = rustls_pemfile::certs(&mut BufReader::new(
        File::open("examples/sample.pem").unwrap(),
    ))
    .collect::<Result<Vec<_>, _>>()
    .unwrap();
    let key = rustls_pemfile::private_key(&mut BufReader::new(
        File::open("examples/sample.rsa").unwrap(),
    ))
    .unwrap()
    .unwrap();
    let config = rustls::ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(certs, key)
        .unwrap();
    let acceptor = TlsAcceptor::from(Arc::new(config));
    let listener = TcpListener::bind("127.0.0.1:8443").await.unwrap();
    let listener = RustlsTcpListener::new(listener, acceptor);
    let srv = StreamServer::new(listener, buf.clone(), my_svc.clone());
    let tls_join_handle = tokio::spawn(async move { srv.run().await });

    // Periodic statistics report.
    tokio::spawn(async move {
        let mut interval = tokio::time::interval(Duration::from_secs(5));
        loop {
            interval.tick().await;
            println!("Statistics report: {}", stats.read().unwrap());
        }
    });

    udp_join_handle.await.unwrap();
    tcp_join_handle.await.unwrap();
    #[cfg(target_os = "linux")]
    udp_mtu_join_handle.await.unwrap();
    double_tcp_join_handle.await.unwrap();
    tfo_join_handle.await.unwrap();
    fn_join_handle.await.unwrap();
    tls_join_handle.await.unwrap();
}