lading/blackhole/
http.rs

1//! The HTTP protocol speaking blackhole.
2//!
3//! ## Metrics
4//!
5//! `bytes_received`: Total bytes received
6//! `requests_received`: Total requests received
7//!
8
9use std::{net::SocketAddr, time::Duration};
10
11use http::{header::InvalidHeaderValue, status::InvalidStatusCode, HeaderMap};
12use hyper::{
13    body, header,
14    server::conn::{AddrIncoming, AddrStream},
15    service::{make_service_fn, service_fn},
16    Body, Request, Response, Server, StatusCode,
17};
18use metrics::{register_counter, Counter};
19use once_cell::unsync::OnceCell;
20use serde::{Deserialize, Serialize};
21use tower::ServiceBuilder;
22use tracing::{debug, error, info};
23
24use crate::signals::Shutdown;
25
26use super::General;
27
// NOTE(review): because this is a `const` (not a `static`), every place that
// names `RESPONSE` gets its own brand-new, empty `OnceCell`. `get_or_init` on
// it therefore runs its initializer every time, and nothing is cached between
// uses. The clippy lint silenced below warns about exactly this pattern.
#[allow(clippy::declare_interior_mutable_const)]
const RESPONSE: OnceCell<Vec<u8>> = OnceCell::new();
30
/// Serde default for `Config::concurrent_requests_max`.
fn default_concurrent_requests_max() -> usize {
    const DEFAULT_LIMIT: usize = 100;
    DEFAULT_LIMIT
}
34
35/// Errors produced by [`Http`].
36#[derive(thiserror::Error, Debug)]
37pub enum Error {
38    /// Wrapper for [`hyper::Error`].
39    #[error("HTTP server error: {0}")]
40    Hyper(hyper::Error),
41    /// The configured content type value was not valid.
42    #[error("The configured content type value was not valid: {0}")]
43    InvalidContentType(InvalidHeaderValue),
44    /// The configured status code was not valid.
45    #[error("The configured status code was not valid: {0}")]
46    InvalidStatusCode(InvalidStatusCode),
47}
48
49/// Body variant supported by this blackhole.
50#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
51#[serde(rename_all = "snake_case")]
52pub enum BodyVariant {
53    /// All response bodies will be empty.
54    Nothing,
55    /// All response bodies will mimic AWS Kinesis.
56    AwsKinesis,
57    /// Respond with a hardcoded string value
58    Static(String),
59}
60
61fn default_body_variant() -> BodyVariant {
62    BodyVariant::Nothing
63}
64
65fn default_status_code() -> u16 {
66    StatusCode::OK.as_u16()
67}
68
69fn default_headers() -> HeaderMap {
70    let mut map = HeaderMap::new();
71    map.insert(header::CONTENT_TYPE, "application/json".parse().unwrap());
72    map
73}
74
75#[derive(Debug, Deserialize, Clone, PartialEq, Eq)]
76/// Configuration for [`Http`]
77pub struct Config {
78    /// number of concurrent HTTP connections to allow
79    #[serde(default = "default_concurrent_requests_max")]
80    pub concurrent_requests_max: usize,
81    /// address -- IP plus port -- to bind to
82    pub binding_addr: SocketAddr,
83    /// the body variant to respond with, default nothing
84    #[serde(default = "default_body_variant")]
85    pub body_variant: BodyVariant,
86    /// Headers to include in the response; default is `Content-Type: application/json`
87    #[serde(with = "http_serde::header_map", default = "default_headers")]
88    pub headers: HeaderMap,
89    /// the content-type header to respond with, defaults to 200
90    #[serde(default = "default_status_code")]
91    pub status: u16,
92}
93
94#[derive(Serialize)]
95#[serde(rename_all = "snake_case")]
96struct KinesisPutRecordBatchResponseEntry {
97    error_code: Option<String>,
98    error_message: Option<String>,
99    record_id: String,
100}
101
102#[derive(Serialize)]
103#[serde(rename_all = "snake_case")]
104struct KinesisPutRecordBatchResponse {
105    encrypted: Option<bool>,
106    failed_put_count: u32,
107    request_responses: Vec<KinesisPutRecordBatchResponseEntry>,
108}
109
110#[allow(clippy::borrow_interior_mutable_const)]
111async fn srv(
112    status: StatusCode,
113    bytes_received: Counter,
114    requests_received: Counter,
115    body_variant: BodyVariant,
116    req: Request<Body>,
117    headers: HeaderMap,
118) -> Result<Response<Body>, hyper::Error> {
119    bytes_received.increment(1);
120
121    let (parts, body) = req.into_parts();
122
123    let bytes = body::to_bytes(body).await?;
124
125    match crate::codec::decode(parts.headers.get(hyper::header::CONTENT_ENCODING), bytes) {
126        Err(response) => Ok(response),
127        Ok(body) => {
128            requests_received.increment(body.len() as u64);
129
130            let mut okay = Response::default();
131            *okay.status_mut() = status;
132
133            *okay.headers_mut() = headers;
134
135            let body_bytes = RESPONSE
136                .get_or_init(|| match body_variant {
137                    BodyVariant::AwsKinesis => {
138                        let response = KinesisPutRecordBatchResponse {
139                            encrypted: None,
140                            failed_put_count: 0,
141                            request_responses: vec![KinesisPutRecordBatchResponseEntry {
142                                error_code: None,
143                                error_message: None,
144                                record_id: "foobar".to_string(),
145                            }],
146                        };
147                        serde_json::to_vec(&response).unwrap()
148                    }
149                    BodyVariant::Nothing => vec![],
150                    BodyVariant::Static(val) => val.as_bytes().to_vec(),
151                })
152                .clone();
153            *okay.body_mut() = Body::from(body_bytes);
154            Ok(okay)
155        }
156    }
157}
158
159#[derive(Debug)]
160/// The HTTP blackhole.
161pub struct Http {
162    httpd_addr: SocketAddr,
163    body_variant: BodyVariant,
164    concurrency_limit: usize,
165    shutdown: Shutdown,
166    headers: HeaderMap,
167    status: StatusCode,
168    metric_labels: Vec<(String, String)>,
169}
170
171impl Http {
172    /// Create a new [`Http`] server instance
173    ///
174    /// # Errors
175    ///
176    /// Returns an error if the configuration is invalid.
177    pub fn new(general: General, config: &Config, shutdown: Shutdown) -> Result<Self, Error> {
178        let status = StatusCode::from_u16(config.status).map_err(Error::InvalidStatusCode)?;
179
180        let mut metric_labels = vec![
181            ("component".to_string(), "blackhole".to_string()),
182            ("component_name".to_string(), "http".to_string()),
183        ];
184        if let Some(id) = general.id {
185            metric_labels.push(("id".to_string(), id));
186        }
187
188        Ok(Self {
189            httpd_addr: config.binding_addr,
190            body_variant: config.body_variant.clone(),
191            concurrency_limit: config.concurrent_requests_max,
192            headers: config.headers.clone(),
193            status,
194            shutdown,
195            metric_labels,
196        })
197    }
198
199    /// Run [`Http`] to completion
200    ///
201    /// This function runs the HTTP server forever, unless a shutdown signal is
202    /// received or an unrecoverable error is encountered.
203    ///
204    /// # Errors
205    ///
206    /// Function will return an error if the configuration is invalid or if
207    /// receiving a packet fails.
208    ///
209    /// # Panics
210    ///
211    /// None known.
212    pub async fn run(mut self) -> Result<(), Error> {
213        let bytes_received = register_counter!("bytes_received", &self.metric_labels);
214        let requests_received = register_counter!("requests_received", &self.metric_labels);
215        let service = make_service_fn(|_: &AddrStream| {
216            let bytes_received = bytes_received.clone();
217            let requests_received = requests_received.clone();
218            let body_variant = self.body_variant.clone();
219            let headers = self.headers.clone();
220            async move {
221                Ok::<_, hyper::Error>(service_fn(move |request| {
222                    debug!("REQUEST: {:?}", request);
223                    srv(
224                        self.status,
225                        bytes_received.clone(),
226                        requests_received.clone(),
227                        body_variant.clone(),
228                        request,
229                        headers.clone(),
230                    )
231                }))
232            }
233        });
234        let svc = ServiceBuilder::new()
235            .load_shed()
236            .concurrency_limit(self.concurrency_limit)
237            .timeout(Duration::from_secs(1))
238            .service(service);
239
240        let addr = AddrIncoming::bind(&self.httpd_addr)
241            .map(|mut addr| {
242                addr.set_keepalive(Some(Duration::from_secs(60)));
243                addr
244            })
245            .map_err(Error::Hyper)?;
246
247        let server = Server::builder(addr).serve(svc);
248        loop {
249            tokio::select! {
250                res = server => {
251                    error!("server shutdown unexpectedly");
252                    return res.map_err(Error::Hyper);
253                }
254                _ = self.shutdown.recv() => {
255                    info!("shutdown signal received");
256                    return Ok(())
257                }
258            }
259        }
260    }
261}