use std::{net::SocketAddr, time::Duration};
use http::{header::InvalidHeaderValue, status::InvalidStatusCode, HeaderMap};
use hyper::{
body, header,
server::conn::{AddrIncoming, AddrStream},
service::{make_service_fn, service_fn},
Body, Request, Response, Server, StatusCode,
};
use metrics::{register_counter, Counter};
use once_cell::unsync::OnceCell;
use serde::{Deserialize, Serialize};
use tower::ServiceBuilder;
use tracing::{debug, error, info};
use crate::signals::Shutdown;
use super::General;
/// Scratch cell for the response body that `srv` builds.
///
/// NOTE: because this is a `const`, every evaluation of `RESPONSE` produces a
/// brand-new, empty `OnceCell`. Calling `get_or_init` on it therefore always
/// runs the initializer, and nothing is cached between requests. This is what
/// `clippy::declare_interior_mutable_const` warns about. Real caching would
/// require computing the body once (e.g. when the server is constructed) and
/// sharing it with the handler.
#[allow(clippy::declare_interior_mutable_const)]
const RESPONSE: OnceCell<Vec<u8>> = OnceCell::new();
/// Serde default for `Config::concurrent_requests_max`.
fn default_concurrent_requests_max() -> usize {
    const DEFAULT_MAX: usize = 100;
    DEFAULT_MAX
}
/// Errors produced by the HTTP blackhole.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// The hyper server failed, either while binding the listening address or
    /// while serving.
    #[error("HTTP server error: {0}")]
    Hyper(hyper::Error),
    /// A configured content type could not be turned into a header value.
    #[error("The configured content type value was not valid: {0}")]
    InvalidContentType(InvalidHeaderValue),
    /// `Config::status` is not a valid HTTP status code.
    #[error("The configured status code was not valid: {0}")]
    InvalidStatusCode(InvalidStatusCode),
}
/// Selects the body the HTTP blackhole sends back to clients.
///
/// Deserialized from snake_case names: `nothing`, `aws_kinesis` or `static`.
#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum BodyVariant {
    /// Respond with an empty body.
    Nothing,
    /// Respond with a JSON Kinesis `PutRecordBatch`-style body that reports
    /// zero failed records.
    AwsKinesis,
    /// Respond with the bytes of the given string.
    Static(String),
}
/// Serde default for `Config::body_variant`: reply with an empty body.
fn default_body_variant() -> BodyVariant {
    BodyVariant::Nothing
}

/// Serde default for `Config::status`: `200 OK`, as a raw `u16`.
fn default_status_code() -> u16 {
    StatusCode::as_u16(&StatusCode::OK)
}

/// Serde default for `Config::headers`: a single
/// `Content-Type: application/json` header.
fn default_headers() -> HeaderMap {
    // A static, known-valid value: no runtime parse (and no unwrap) needed.
    let json = header::HeaderValue::from_static("application/json");
    let mut defaults = HeaderMap::new();
    defaults.insert(header::CONTENT_TYPE, json);
    defaults
}
/// Configuration for the HTTP blackhole.
#[derive(Debug, Deserialize, Clone, PartialEq, Eq)]
pub struct Config {
    /// Value passed to the tower concurrency-limit layer. Defaults to 100.
    #[serde(default = "default_concurrent_requests_max")]
    pub concurrent_requests_max: usize,
    /// Address the server binds to.
    pub binding_addr: SocketAddr,
    /// Body sent in responses to requests whose body decodes successfully.
    /// Defaults to an empty body.
    #[serde(default = "default_body_variant")]
    pub body_variant: BodyVariant,
    /// Headers sent in responses to requests whose body decodes successfully,
    /// (de)serialized via `http_serde::header_map`. Defaults to
    /// `Content-Type: application/json`.
    #[serde(with = "http_serde::header_map", default = "default_headers")]
    pub headers: HeaderMap,
    /// Status code of responses to requests whose body decodes successfully.
    /// Defaults to 200. `Http::new` rejects values that are not valid HTTP
    /// status codes.
    #[serde(default = "default_status_code")]
    pub status: u16,
}
/// One per-record entry of a Kinesis (Firehose) `PutRecordBatch` response.
///
/// AWS uses PascalCase member names on the wire, so fields serialize as
/// `ErrorCode`, `ErrorMessage` and `RecordId`.
#[derive(Serialize)]
#[serde(rename_all = "PascalCase")]
struct KinesisPutRecordBatchResponseEntry {
    error_code: Option<String>,
    error_message: Option<String>,
    record_id: String,
}

/// Body of a Kinesis (Firehose) `PutRecordBatch` response.
///
/// Serialized with AWS's PascalCase member names:
/// `{"Encrypted": ..., "FailedPutCount": ..., "RequestResponses": [...]}`.
/// With snake_case keys, AWS clients would not find `FailedPutCount` or
/// `RequestResponses`.
#[derive(Serialize)]
#[serde(rename_all = "PascalCase")]
struct KinesisPutRecordBatchResponse {
    encrypted: Option<bool>,
    failed_put_count: u32,
    request_responses: Vec<KinesisPutRecordBatchResponseEntry>,
}
/// Handle a single request to the HTTP blackhole.
///
/// Steps:
/// 1. Count the request in `requests_received` as soon as it arrives.
/// 2. Read the whole body and pass it, with the `Content-Encoding` header, to
///    `crate::codec::decode`.
/// 3. If decoding fails, return the response produced by `decode` unchanged.
/// 4. Otherwise, add the length of the decoded body to `bytes_received`.
/// 5. Reply with the configured `status` and `headers` and a body chosen by
///    `body_variant`.
///
/// # Errors
///
/// Returns a `hyper::Error` if the request body cannot be read.
#[allow(clippy::borrow_interior_mutable_const)]
async fn srv(
    status: StatusCode,
    bytes_received: Counter,
    requests_received: Counter,
    body_variant: BodyVariant,
    req: Request<Body>,
    headers: HeaderMap,
) -> Result<Response<Body>, hyper::Error> {
    // One request arrived, regardless of its size or whether it decodes.
    requests_received.increment(1);

    let (parts, body) = req.into_parts();
    let bytes = body::to_bytes(body).await?;

    match crate::codec::decode(parts.headers.get(hyper::header::CONTENT_ENCODING), bytes) {
        Err(response) => Ok(response),
        Ok(body) => {
            // Payload volume is measured on the body returned by `decode`,
            // not on the raw bytes read from the connection.
            bytes_received.increment(body.len() as u64);

            let mut okay = Response::default();
            *okay.status_mut() = status;
            *okay.headers_mut() = headers;

            // `RESPONSE` is a `const`, so this is a fresh cell on every call
            // and the body below is rebuilt for each request.
            let body_bytes = RESPONSE
                .get_or_init(|| match body_variant {
                    BodyVariant::AwsKinesis => {
                        let response = KinesisPutRecordBatchResponse {
                            encrypted: None,
                            failed_put_count: 0,
                            request_responses: vec![KinesisPutRecordBatchResponseEntry {
                                error_code: None,
                                error_message: None,
                                record_id: "foobar".to_string(),
                            }],
                        };
                        serde_json::to_vec(&response)
                            .expect("serializing a fixed Kinesis response cannot fail")
                    }
                    BodyVariant::Nothing => vec![],
                    BodyVariant::Static(val) => val.into_bytes(),
                })
                .clone();
            *okay.body_mut() = Body::from(body_bytes);
            Ok(okay)
        }
    }
}
/// The HTTP blackhole: accepts requests, counts them and replies with a
/// configured response.
#[derive(Debug)]
pub struct Http {
    /// Address to bind, taken from `Config::binding_addr`.
    httpd_addr: SocketAddr,
    /// Body to reply with.
    body_variant: BodyVariant,
    /// Concurrency limit for the service stack, from
    /// `Config::concurrent_requests_max`.
    concurrency_limit: usize,
    /// Signal that ends `run`.
    shutdown: Shutdown,
    /// Headers to reply with.
    headers: HeaderMap,
    /// Validated status code to reply with.
    status: StatusCode,
    /// Labels attached to the counters registered in `run`.
    metric_labels: Vec<(String, String)>,
}
impl Http {
    /// Create a new [`Http`] blackhole from its configuration.
    ///
    /// Counters registered by [`Http::run`] carry the labels
    /// `component="blackhole"` and `component_name="http"`, plus `id=<id>`
    /// when `general.id` is set.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidStatusCode`] if `config.status` is not a valid
    /// HTTP status code.
    pub fn new(general: General, config: &Config, shutdown: Shutdown) -> Result<Self, Error> {
        let status = StatusCode::from_u16(config.status).map_err(Error::InvalidStatusCode)?;
        let mut metric_labels = vec![
            ("component".to_string(), "blackhole".to_string()),
            ("component_name".to_string(), "http".to_string()),
        ];
        if let Some(id) = general.id {
            metric_labels.push(("id".to_string(), id));
        }
        Ok(Self {
            httpd_addr: config.binding_addr,
            body_variant: config.body_variant.clone(),
            concurrency_limit: config.concurrent_requests_max,
            headers: config.headers.clone(),
            status,
            shutdown,
            metric_labels,
        })
    }

    /// Run the HTTP server until it stops or a shutdown signal arrives.
    ///
    /// Binds `httpd_addr` with a 60 second TCP keepalive. Every request goes
    /// through `srv` together with the `bytes_received` and
    /// `requests_received` counters, registered with this blackhole's labels.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Hyper`] in two cases:
    /// - the address cannot be bound;
    /// - the server stops with an error before shutdown is signalled.
    pub async fn run(mut self) -> Result<(), Error> {
        let bytes_received = register_counter!("bytes_received", &self.metric_labels);
        let requests_received = register_counter!("requests_received", &self.metric_labels);
        let service = make_service_fn(|_: &AddrStream| {
            let bytes_received = bytes_received.clone();
            let requests_received = requests_received.clone();
            let body_variant = self.body_variant.clone();
            let headers = self.headers.clone();
            async move {
                Ok::<_, hyper::Error>(service_fn(move |request| {
                    debug!("REQUEST: {:?}", request);
                    srv(
                        self.status,
                        bytes_received.clone(),
                        requests_received.clone(),
                        body_variant.clone(),
                        request,
                        headers.clone(),
                    )
                }))
            }
        });
        // NOTE(review): these layers wrap the `make_service_fn` factory, i.e.
        // per-connection service creation rather than individual requests.
        // Confirm that this is the intended scope of the limit and timeout.
        let svc = ServiceBuilder::new()
            .load_shed()
            .concurrency_limit(self.concurrency_limit)
            .timeout(Duration::from_secs(1))
            .service(service);
        let addr = AddrIncoming::bind(&self.httpd_addr)
            .map(|mut addr| {
                addr.set_keepalive(Some(Duration::from_secs(60)));
                addr
            })
            .map_err(Error::Hyper)?;
        let server = Server::builder(addr).serve(svc);

        // Either branch ends `run`, so a single `select!` is enough; no loop.
        tokio::select! {
            res = server => {
                error!("server shutdown unexpectedly");
                res.map_err(Error::Hyper)
            }
            _ = self.shutdown.recv() => {
                info!("shutdown signal received");
                Ok(())
            }
        }
    }
}