use std::convert::{Infallible, TryFrom};
use std::net::SocketAddr;

use anyhow::Result;
use async_channel::Receiver;
use async_stream::stream;
use bytes::{BufMut, BytesMut};
use dnstap_utils::dnstap;
use dnstap_utils::framestreams_codec::{self, Frame, FrameStreamsCodec};
use hyper::header::CONTENT_TYPE;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Request, Response};
use hyper::{Method, StatusCode};
use log::*;
use prometheus::core::{AtomicU64, GenericCounter};
use prometheus::Encoder as PrometheusEncoder;
use prometheus::TextEncoder;
use prost::Message;
use tokio_util::codec::Encoder as CodecEncoder;

use crate::Channels;
/// HTTP front end that serves Prometheus metrics and the contents of the
/// error and timeout channels.
pub struct HttpHandler {
    /// Channel bundle whose error and timeout receivers are exposed over HTTP.
    channels: Channels,
    /// Socket address the HTTP server binds to.
    http_address: SocketAddr,
}
/// A dnstap message receiver paired with the counter that records each
/// message taken from it.
struct HttpChannel {
    /// Counter incremented once per message received from `receiver`.
    success_metric: &'static GenericCounter<AtomicU64>,
    /// Source of dnstap messages that are streamed back to the HTTP client.
    receiver: Receiver<dnstap::Dnstap>,
}
impl HttpHandler {
    /// Creates a handler for `http_address` that keeps its own clone of `channels`.
    pub fn new(http_address: SocketAddr, channels: &Channels) -> Self {
        let channels = channels.clone();
        Self {
            http_address,
            channels,
        }
    }

    /// Binds the HTTP server to the configured address and serves requests
    /// until the server future completes.
    ///
    /// # Errors
    ///
    /// Returns an error if the address cannot be bound or if the server
    /// terminates with an error.
    pub async fn run(&self) -> Result<()> {
        let error_rx = self.channels.error_receiver.clone();
        let timeout_rx = self.channels.timeout_receiver.clone();

        let make_svc = make_service_fn(move |_conn| {
            // Each connection gets its own handles to the receivers.
            let error_rx = error_rx.clone();
            let timeout_rx = timeout_rx.clone();
            async move {
                let svc = service_fn(move |req| {
                    // Each request gets fresh HttpChannel wrappers with the
                    // metric matching the channel being drained.
                    let errors = HttpChannel {
                        receiver: error_rx.clone(),
                        success_metric: &crate::metrics::CHANNEL_ERROR_RX.success,
                    };
                    let timeouts = HttpChannel {
                        receiver: timeout_rx.clone(),
                        success_metric: &crate::metrics::CHANNEL_TIMEOUT_RX.success,
                    };
                    http_service(req, errors, timeouts)
                });
                Ok::<_, Infallible>(svc)
            }
        });

        let server = hyper::server::Server::try_bind(&self.http_address)?.serve(make_svc);
        info!("HTTP server listening on http://{}", &self.http_address);
        server.await?;
        Ok(())
    }
}
async fn http_service(
req: Request<Body>,
channel_error: HttpChannel,
channel_timeout: HttpChannel,
) -> Result<Response<Body>> {
match (req.method(), req.uri().path()) {
(&Method::GET, "/metrics") => get_metrics_response(),
(&Method::GET, "/errors") => get_channel_response(channel_error),
(&Method::GET, "/timeouts") => get_channel_response(channel_timeout),
_ => {
let mut not_found = Response::default();
*not_found.status_mut() = StatusCode::NOT_FOUND;
Ok(not_found)
}
}
}
fn get_metrics_response() -> Result<Response<Body>> {
let encoder = TextEncoder::new();
let metric_families = prometheus::gather();
let mut buffer = vec![];
encoder.encode(&metric_families, &mut buffer).unwrap();
let response = Response::builder()
.status(200)
.header(CONTENT_TYPE, encoder.format_type())
.body(Body::from(buffer))
.unwrap();
Ok(response)
}
/// Builds a response whose body streams the Frame Streams encoding of the
/// messages drained from `channel`.
fn get_channel_response(channel: HttpChannel) -> Result<Response<Body>> {
    let frames = dnstap_receiver_to_stream(channel);
    let body = Body::wrap_stream(frames);
    Ok(Response::new(body))
}
/// Converts the messages currently queued on `channel` into a Frame Streams
/// byte stream.
///
/// The stream contains the following frames, in order:
/// 1. A `ControlStart` frame advertising the `protobuf:dnstap.Dnstap` content
///    type.
/// 2. One data frame per message that `try_recv` returns without waiting. Each
///    data frame is a 32-bit big-endian length followed by the protobuf-encoded
///    message.
/// 3. A `ControlStop` frame, emitted as soon as `try_recv` fails.
///
/// Each received message increments `channel.success_metric`.
///
/// # Errors
///
/// The stream yields an `Err` and ends in two cases:
/// - the codec fails to encode a control frame;
/// - a message's encoded length does not fit in the 32-bit Frame Streams
///   length field.
fn dnstap_receiver_to_stream(
    channel: HttpChannel,
) -> impl tokio_stream::Stream<Item = std::result::Result<BytesMut, std::io::Error>> {
    let mut f = FrameStreamsCodec {};
    stream! {
        let mut buf = BytesMut::with_capacity(64);
        f.encode(
            Frame::ControlStart(framestreams_codec::encode_content_type_payload(
                b"protobuf:dnstap.Dnstap",
            )),
            &mut buf,
        )?;
        yield Ok(buf);

        loop {
            match channel.receiver.try_recv() {
                Ok(d) => {
                    channel.success_metric.inc();
                    let len = d.encoded_len();
                    // The Frame Streams length prefix is 32 bits. Fail loudly rather
                    // than emit a truncated (and therefore corrupt) length.
                    let frame_len = u32::try_from(len).map_err(|e| {
                        std::io::Error::new(std::io::ErrorKind::InvalidData, e)
                    })?;
                    let mut buf = BytesMut::with_capacity(4 + len);
                    buf.put_u32(frame_len);
                    d.encode(&mut buf).expect(
                        "BytesMut grows on demand, so prost encoding cannot run out of space",
                    );
                    yield Ok(buf);
                }
                // try_recv failed, e.g. because the channel is currently empty or
                // closed. Close the frame stream.
                Err(_) => {
                    let mut buf = BytesMut::with_capacity(64);
                    f.encode(Frame::ControlStop, &mut buf)?;
                    yield Ok(buf);
                    break;
                }
            }
        }
    }
}