use clap::{crate_version, Parser};
use donut::http::{http_route, HandlerContext};
use donut::request::{RequestParserJsonGet, RequestParserWireGet, RequestParserWirePost};
use donut::resolve::UdpResolver;
use donut::response::{ResponseEncoderJson, ResponseEncoderWire};
use donut::types::DonutResult;
use hyper::service::{make_service_fn, service_fn};
use hyper::Server;
use std::net::SocketAddr;
use std::process;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::net::UdpSocket;
use tracing::{event, span, Instrument, Level};
use trust_dns_client::client::AsyncClient;
use trust_dns_client::udp::UdpClientStream;
/// Upstream DNS server (127.0.0.1:53) used when `--upstream-udp` is not given.
const DEFAULT_UPSTREAM_UDP: ([u8; 4], u16) = ([127, 0, 0, 1], 53);
/// Upstream query timeout in milliseconds, used when `--upstream-timeout` is not given.
const DEFAULT_UPSTREAM_TIMEOUT_MS: u64 = 1_000;
/// Address (127.0.0.1:3000) the HTTP server listens on when `--bind` is not given.
const DEFAULT_BIND_ADDR: ([u8; 4], u16) = ([127, 0, 0, 1], 3000);
/// Maximum log level used when `--log-level` is not given.
const DEFAULT_LOG_LEVEL: Level = Level::INFO;
/// Serves DNS queries over HTTP (JSON and wire formats), resolving them through an
/// upstream DNS server over UDP.
//
// NOTE: with clap's derive API the `///` comments in this struct become the `--help`
// text shown to users, so keep them user-facing.
#[derive(Debug, Parser)]
#[clap(name = "donut", version = crate_version!())]
struct DonutApplication {
    /// Address (IP:port) of the upstream DNS server that queries are sent to over UDP.
    #[clap(long, default_value_t = DEFAULT_UPSTREAM_UDP.into())]
    upstream_udp: SocketAddr,

    /// Timeout for queries to the upstream DNS server, in milliseconds.
    #[clap(long, default_value_t = DEFAULT_UPSTREAM_TIMEOUT_MS)]
    upstream_timeout: u64,

    /// Maximum level of log messages to emit (trace, debug, info, warn, or error).
    #[clap(long, default_value_t = DEFAULT_LOG_LEVEL)]
    log_level: Level,

    /// Address (IP:port) for the HTTP server to listen on.
    #[clap(long, default_value_t = DEFAULT_BIND_ADDR.into())]
    bind: SocketAddr,
}
/// Creates a trust-dns [`AsyncClient`] that sends queries to `addr` over UDP, using
/// `timeout` as the stream's timeout.
///
/// The background future returned alongside the client is spawned onto the current
/// Tokio runtime and detached. Presumably the client cannot make progress unless that
/// future is polled, so it must not simply be dropped.
///
/// # Errors
///
/// Propagates any error from [`AsyncClient::connect`].
async fn new_udp_dns_client(addr: SocketAddr, timeout: Duration) -> DonutResult<AsyncClient> {
    let stream = UdpClientStream::<UdpSocket>::with_timeout(addr, timeout);
    let (handle, background) = AsyncClient::connect(stream).await?;

    tokio::spawn(background);

    Ok(handle)
}
/// Builds the [`HandlerContext`] shared by all HTTP requests.
///
/// The context is made up of:
/// - the JSON GET, wire GET and wire POST request parsers;
/// - a [`UdpResolver`] backed by a UDP DNS client for `addr` (with `timeout`);
/// - the JSON and wire-format response encoders.
///
/// # Errors
///
/// Fails if the upstream UDP DNS client cannot be created.
async fn new_handler_context(addr: SocketAddr, timeout: Duration) -> DonutResult<HandlerContext> {
    let upstream = new_udp_dns_client(addr, timeout).await?;
    let resolver = UdpResolver::new(upstream);

    Ok(HandlerContext::new(
        RequestParserJsonGet::default(),
        RequestParserWireGet::default(),
        RequestParserWirePost::default(),
        resolver,
        ResponseEncoderJson::default(),
        ResponseEncoderWire::default(),
    ))
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let opts = DonutApplication::parse();
tracing::subscriber::set_global_default(
tracing_subscriber::FmtSubscriber::builder()
.with_max_level(opts.log_level)
.finish(),
)
.expect("Failed to set tracing subscriber");
let startup = Instant::now();
let timeout = Duration::from_millis(opts.upstream_timeout);
let context = Arc::new(new_handler_context(opts.upstream_udp, timeout).await.unwrap());
let service = make_service_fn(move |_| {
let context = context.clone();
async move {
Ok::<_, hyper::Error>(service_fn(move |req| {
http_route(req, context.clone()).instrument(span!(Level::DEBUG, "donut_request"))
}))
}
});
let server = Server::try_bind(&opts.bind).unwrap_or_else(|e| {
event!(
Level::ERROR,
message = "server failed to start",
error = %e,
upstream = %opts.upstream_udp,
address = %opts.bind,
timeout_ms = %timeout.as_millis(),
);
process::exit(1);
});
event!(
Level::INFO,
message = "server started",
upstream = %opts.upstream_udp,
address = %opts.bind,
timeout_ms = %timeout.as_millis(),
);
server
.serve(service)
.with_graceful_shutdown(async {
let _ = tokio::signal::ctrl_c().await;
})
.await?;
event!(
Level::INFO,
message = "server shutdown",
runtime_secs = %startup.elapsed().as_secs(),
);
Ok(())
}