use anyhow::Context;
use anyhow::Result;
use async_channel::{bounded, Receiver, Sender};
use clap::{ArgAction, Parser, ValueHint};
use ip_network::IpNetwork;
use log::*;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::net::UnixListener;
pub mod metrics;
mod dnstap_handler;
use dnstap_handler::*;
mod frame_handler;
use frame_handler::*;
mod http_handler;
use http_handler::*;
mod monitor_handler;
use monitor_handler::*;
use dnstap_utils::dnstap;
/// Top-level server state: the parsed command-line options together with the
/// dnstap channels that `Server::run` hands to the tasks it spawns.
struct Server {
    /// Channel endpoints built from the capacities in `opts` (see `Server::new`).
    channels: Channels,
    /// A private copy of the parsed command-line options.
    opts: Opts,
}
/// The three bounded `async_channel` queues carrying `dnstap::Dnstap` messages.
///
/// Each sender/receiver pair is created together in `Server::new`, sized from
/// the matching `--channel-*-capacity` option. Deriving `Clone` copies channel
/// handles, not queued messages: a cloned endpoint refers to the same channel.
#[derive(Clone)]
pub struct Channels {
    // Producer side. `sender` is the main queue; a clone of it is given to
    // every `FrameHandler` accepted on the Unix socket (see `Server::run`).
    sender: Sender<dnstap::Dnstap>,
    // Presumably fed with messages whose replay errored / timed out; the
    // producers live outside this file — confirm in dnstap_handler.
    error_sender: Sender<dnstap::Dnstap>,
    timeout_sender: Sender<dnstap::Dnstap>,

    // Consumer side, one receiver per queue above.
    receiver: Receiver<dnstap::Dnstap>,
    error_receiver: Receiver<dnstap::Dnstap>,
    timeout_receiver: Receiver<dnstap::Dnstap>,
}
// Command-line options.
//
// NOTE: the notes in this struct deliberately use plain `//` comments: clap's
// derive turns `///` doc comments into `--help` text, and these notes are for
// maintainers rather than for the CLI's user-visible output.
#[derive(Parser, Clone)]
pub struct Opts {
    // Capacity of the main dnstap channel created in `Server::new`.
    // `async_channel::bounded` panics on a capacity of zero, so 0 is rejected
    // here at parse time instead of crashing the process at startup.
    #[clap(long,
           default_value = "10000",
           value_parser = clap::builder::RangedU64ValueParser::<usize>::new().range(1..))]
    channel_capacity: usize,

    // Capacity of the error channel created in `Server::new` (must be > 0, as above).
    #[clap(long,
           default_value = "100000",
           value_parser = clap::builder::RangedU64ValueParser::<usize>::new().range(1..))]
    channel_error_capacity: usize,

    // Capacity of the timeout channel created in `Server::new` (must be > 0, as above).
    #[clap(long,
           default_value = "100000",
           value_parser = clap::builder::RangedU64ValueParser::<usize>::new().range(1..))]
    channel_timeout_capacity: usize,

    // Address of the DNS server that replayed queries are sent to.
    #[clap(long, name = "DNS IP:PORT")]
    dns: SocketAddr,

    // DSCP value queries are sent with. DSCP is a 6-bit field, so every value
    // in 0..=63 is valid; the upper bound must be inclusive.
    #[clap(long,
           name = "DSCP code point",
           value_parser = clap::value_parser!(u8).range(0..=63))]
    dscp: Option<u8>,

    // Socket address handed to `HttpHandler`.
    #[clap(long, name = "HTTP IP:PORT")]
    http: SocketAddr,

    // Presumably: ignore responses with the TC (truncated) bit — confirm in
    // dnstap_handler, which reads this via `&Opts`.
    #[clap(long)]
    ignore_tc: bool,

    // Networks to ignore; presumably matched against query addresses —
    // confirm in dnstap_handler.
    #[clap(long, value_parser = clap::value_parser!(IpNetwork))]
    ignore_query_net: Vec<IpNetwork>,

    // Number of UDP query sockets; `Server::run` spawns one `DnstapHandler`
    // task per socket.
    #[clap(long, default_value = "10")]
    num_sockets: usize,

    // Send DNS queries with a PROXY v2 header.
    #[clap(long)]
    proxy: bool,

    // Read by handlers via `&Opts`; presumably adds a timespec to the PROXY
    // header — TODO confirm.
    #[clap(long)]
    proxy_timespec: bool,

    // Delay, in milliseconds, passed to `MonitorHandler::new`.
    #[clap(long, name = "MILLISECONDS", default_value = "5000", required = false)]
    match_status_delay: u64,

    // Files watched by `MonitorHandler`. When none are given, match status is
    // forced on at startup (see `Server::run`).
    // NOTE(review): `num_args(2..)` makes the flag require at least two paths;
    // confirm that a single status file is really meant to be rejected.
    #[clap(long = "match-status-files",
           name = "STATUS-FILE",
           required = false,
           num_args(2..),
           value_parser,
           value_hint = ValueHint::FilePath)
    ]
    status_files: Vec<PathBuf>,

    // Unix socket path to listen on; any existing file at this path is
    // removed before binding.
    #[clap(long, name = "PATH")]
    unix: String,

    // Number of `-v` flags; sets the stderr log verbosity in `main`.
    #[clap(short, long, action = ArgAction::Count)]
    verbose: u8,
}
impl Server {
    /// Builds a server from parsed options, allocating the main, error and
    /// timeout channels with the capacities configured in `opts`.
    ///
    /// # Panics
    /// Panics if any configured capacity is zero, because
    /// `async_channel::bounded` does not accept a zero capacity.
    pub fn new(opts: &Opts) -> Self {
        let (sender, receiver) = bounded(opts.channel_capacity);
        let (error_sender, error_receiver) = bounded(opts.channel_error_capacity);
        let (timeout_sender, timeout_receiver) = bounded(opts.channel_timeout_capacity);

        let channels = Channels {
            sender,
            receiver,
            error_sender,
            error_receiver,
            timeout_sender,
            timeout_receiver,
        };

        Self {
            channels,
            opts: opts.clone(),
        }
    }

    /// Starts all background tasks, then accepts dnstap connections on the
    /// Unix socket forever.
    ///
    /// Startup order: match-status monitoring, the HTTP handler, one
    /// `DnstapHandler` per query socket, a summary log, and finally the Unix
    /// listener.
    ///
    /// # Errors
    /// Returns an error if the monitor handler or any dnstap handler cannot be
    /// constructed, or if binding the Unix socket fails. Once the listener is
    /// up, this function never returns.
    async fn run(&mut self) -> Result<()> {
        let match_status = self.init_match_status()?;
        self.spawn_http_handler();
        self.spawn_dnstap_handlers(&match_status).await?;
        self.log_query_config();
        self.serve_unix_socket().await
    }

    /// Creates the shared match-status flag.
    ///
    /// If no status files were given, the flag (and the `MATCH_STATUS` metric)
    /// is set immediately. Otherwise a `MonitorHandler` task is spawned and
    /// owns updates to the flag.
    fn init_match_status(&self) -> Result<Arc<AtomicBool>> {
        let match_status = Arc::new(AtomicBool::new(false));

        if self.opts.status_files.is_empty() {
            // Nothing to monitor: report a match unconditionally.
            match_status.store(true, Ordering::Relaxed);
            crate::metrics::MATCH_STATUS.set(1);
            return Ok(match_status);
        }

        let flag = Arc::clone(&match_status);
        let mut monitor =
            MonitorHandler::new(&self.opts.status_files, self.opts.match_status_delay)?;
        tokio::spawn(async move {
            if let Err(err) = monitor.run(flag).await {
                error!("Monitor handler error: {}", err);
            }
        });

        Ok(match_status)
    }

    /// Spawns the HTTP handler task; failures are logged, not propagated.
    fn spawn_http_handler(&self) {
        let http = HttpHandler::new(self.opts.http, &self.channels);
        tokio::spawn(async move {
            if let Err(err) = http.run().await {
                error!("Hyper HTTP server error: {}", err);
            }
        });
    }

    /// Constructs and spawns one `DnstapHandler` per configured query socket.
    ///
    /// Stops at, and returns, the first construction error.
    async fn spawn_dnstap_handlers(&self, match_status: &Arc<AtomicBool>) -> Result<()> {
        for _ in 0..self.opts.num_sockets {
            let mut handler =
                DnstapHandler::new(&self.opts, &self.channels, Arc::clone(match_status)).await?;
            tokio::spawn(async move {
                if let Err(err) = handler.run().await {
                    error!("DnstapHandler error: {}", err);
                }
            });
        }
        Ok(())
    }

    /// Logs where and how DNS queries will be sent.
    fn log_query_config(&self) {
        info!(
            "Sending DNS queries to server {} using {} UDP query sockets",
            self.opts.dns, self.opts.num_sockets,
        );
        if self.opts.proxy {
            info!("Sending DNS queries with PROXY v2 header");
        }
        if let Some(dscp) = self.opts.dscp {
            info!("Sending DNS queries with DSCP value {}", dscp);
        }
    }

    /// Binds the Unix socket and spawns a `FrameHandler` per accepted
    /// connection, each with its own clone of the main channel's sender.
    ///
    /// Never returns `Ok`; only a bind failure ends it.
    async fn serve_unix_socket(&self) -> Result<()> {
        let path = &self.opts.unix;

        // Best-effort removal of whatever is at `path` before binding; any
        // failure (for example, no file existing yet) is ignored.
        let _ = std::fs::remove_file(path);
        let listener = UnixListener::bind(path)?;
        info!("Listening on Unix socket path {}", path);

        loop {
            let (stream, _addr) = match listener.accept().await {
                Ok(conn) => conn,
                Err(err) => {
                    // NOTE(review): a persistent accept error (e.g. file
                    // descriptor exhaustion) makes this retry immediately and
                    // log on every iteration; consider a short backoff.
                    warn!("Accept error: {}", err);
                    continue;
                }
            };

            let mut frame_handler = FrameHandler::new(stream, self.channels.sender.clone());
            tokio::spawn(async move {
                if let Err(err) = frame_handler.run().await {
                    warn!("FrameHandler error: {}", err);
                }
            });
        }
    }
}
/// Entry point: parses the command line, configures stderr logging and
/// metrics, and runs the server on a two-worker multi-threaded Tokio runtime.
///
/// # Errors
/// Returns an error (instead of panicking) if the logger or the Tokio runtime
/// cannot be initialized, or if `Server::run` fails. `Server::run` does not
/// return once it is listening, so in normal operation this never returns.
fn main() -> Result<()> {
    let opts = Opts::parse();

    // Restrict log output to this crate's module path; verbosity is the
    // number of `-v` flags.
    stderrlog::new()
        .verbosity(usize::from(opts.verbose))
        .module(module_path!())
        .init()
        .context("failed to initialize stderr logger")?;

    metrics::initialize_metrics();

    let mut server = Server::new(&opts);

    let runtime = tokio::runtime::Builder::new_multi_thread()
        .worker_threads(2)
        .enable_all()
        .build()
        .context("failed to build Tokio runtime")?;

    runtime.block_on(server.run())
}