1use crate::{
6 GMT_FORMAT, HeadersDebug, LatencyDisplay, SERVER, WrappedRedactedHashingAlg, X_REQUEST_ID,
7 X_VIA, get_host, get_http_version_str, response_for_panic,
8};
9use ahash::AHasher;
10use axum::Router;
11use axum::body::HttpBody;
12use axum::extract::Request;
13use axum::http::{HeaderValue, StatusCode, header};
14use axum::middleware::Next;
15use axum::response::{IntoResponse, Response};
16use base64::{Engine as B64Engine, engine::general_purpose::URL_SAFE_NO_PAD as b64};
17use hyper::HeaderMap;
18use hyper::header::LAST_MODIFIED;
19use ordinary_config::{HttpCache, HttpEtagAlgorithm, XXH3Variation};
20use std::hash::Hasher;
21use std::sync::Arc;
22use std::time::Duration;
23use time::UtcDateTime;
24use tower::ServiceBuilder;
25use tower_http::catch_panic::CatchPanicLayer;
26use tower_http::classify::ServerErrorsFailureClass;
27use tower_http::compression::CompressionLayer;
28use tower_http::decompression::RequestDecompressionLayer;
29use tower_http::request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer};
30use tower_http::set_header::{SetRequestHeaderLayer, SetResponseHeaderLayer};
31use tower_http::trace::TraceLayer;
32use tracing::Span;
33use uuid::Uuid;
34
/// The kind of service a router is being assembled for.
///
/// [`apply_common_middleware`] uses this to pick the name of the per-request
/// tracing span (`"app"`, `"api"`, `"redirect"` or `"proxy"`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ServiceKind {
    /// Requests are traced under an `"app"` span.
    App,
    /// Requests are traced under an `"api"` span.
    Api,
    /// Requests are traced under a `"redirect"` span, which carries no
    /// `domain` field.
    Redirect,
    /// Requests are traced under a `"proxy"` span.
    Proxy,
}
42
43#[allow(clippy::too_many_arguments, clippy::too_many_lines)]
44pub fn apply_common_middleware<T>(
45 router: Router<T>,
46 state: &T,
47 server_span: Option<Span>,
48 domain: String,
49 log_headers: bool,
50 log_ips: bool,
51 redacted_hash: Arc<Option<WrappedRedactedHashingAlg>>,
52 kind: ServiceKind,
53 via_domain: Option<String>,
54) -> Router
55where
56 T: Clone + Send + Sync + 'static,
57{
58 let redacted_hash_clone = redacted_hash.clone();
59
60 router
61 .with_state(state.clone())
62 .layer(
63 ServiceBuilder::new()
64 .layer(CatchPanicLayer::custom(response_for_panic))
65 .layer(RequestDecompressionLayer::new())
66 .layer(CompressionLayer::new()),
67 )
68 .layer(
69 ServiceBuilder::new()
70 .layer(SetRequestIdLayer::new(X_REQUEST_ID, MakeRequestUuid))
71 .layer(
72 TraceLayer::new_for_http()
73 .make_span_with(move |req: &axum::http::Request<_>| {
74 let request_id = req
75 .headers()
76 .get(X_REQUEST_ID)
77 .and_then(|rid| {
78 rid.to_str()
79 .ok()
80 .and_then(|rid| Uuid::parse_str(rid).ok())
81 .map(tracing::field::display)
82 })
83 .unwrap_or(tracing::field::display(Uuid::new_v4()));
84
85 let host =
86 get_host(req.headers(), req.uri()).map(tracing::field::display);
87
88 let ip = crate::get_display_ip(log_ips, req);
89
90 let query = req.uri().query().map(tracing::field::display);
91
92 match kind {
93 ServiceKind::App => {
94 if let Some(server_span) = &server_span {
95 server_span.in_scope(|| {
96 tracing::info_span!(
97 "app",
98 %domain,
99 host,
100 rid = %request_id,
101 ip,
102 path = %req.uri().path(),
103 query,
104 )
105 })
106 } else {
107 tracing::info_span!(
108 "app",
109 %domain,
110 host,
111 rid = %request_id,
112 ip,
113 path = %req.uri().path(),
114 query,
115 )
116 }
117 }
118 ServiceKind::Proxy => {
119 if let Some(server_span) = &server_span {
120 server_span.in_scope(|| {
121 tracing::info_span!(
122 "proxy",
123 %domain,
124 host,
125 rid = %request_id,
126 ip,
127 path = %req.uri().path(),
128 query,
129 )
130 })
131 } else {
132 tracing::info_span!(
133 "proxy",
134 %domain,
135 host,
136 rid = %request_id,
137 ip,
138 path = %req.uri().path(),
139 query,
140 )
141 }
142 }
143 ServiceKind::Api => {
144 if let Some(server_span) = &server_span {
145 server_span.in_scope(|| {
146 tracing::info_span!(
147 "api",
148 %domain,
149 host,
150 rid = %request_id,
151 ip,
152 path = %req.uri().path(),
153 query,
154 )
155 })
156 } else {
157 tracing::info_span!(
158 "api",
159 %domain,
160 host,
161 rid = %request_id,
162 ip,
163 path = %req.uri().path(),
164 query,
165 )
166 }
167 }
168 ServiceKind::Redirect => {
169 if let Some(server_span) = &server_span {
170 server_span.in_scope(|| {
171 tracing::info_span!(
172 "redirect",
173 host,
174 rid = %request_id,
175 ip,
176 path = %req.uri().path(),
177 query,
178 )
179 })
180 } else {
181 tracing::info_span!(
182 "redirect",
183 host,
184 rid = %request_id,
185 ip,
186 path = %req.uri().path(),
187 query,
188 )
189 }
190 }
191 }
192 })
193 .on_request(move |req: &axum::http::Request<_>, _: &Span| {
194 let hd = log_headers
195 .then_some(HeadersDebug(req.headers(), redacted_hash.clone()));
196
197 #[cfg(tracing_unstable)]
198 let headers = log_headers.then_some(tracing::field::valuable(&hd));
199 #[cfg(not(tracing_unstable))]
200 let headers = log_headers.then_some(tracing::field::debug(&hd));
201
202 tracing::info!(
203 version = ?req.version(),
204 method = %req.method(),
205 headers,
206 "req"
207 );
208 })
209 .on_response(move |res: &Response<_>, latency: Duration, _: &Span| {
210 let hd = log_headers.then_some(HeadersDebug(
211 res.headers(),
212 redacted_hash_clone.clone(),
213 ));
214
215 #[cfg(tracing_unstable)]
216 let headers = log_headers.then_some(tracing::field::valuable(&hd));
217
218 #[cfg(not(tracing_unstable))]
219 let headers = log_headers.then_some(tracing::field::debug(&hd));
220
221 let status = res.status().as_u16();
222 let latency = LatencyDisplay(latency.as_nanos() as f64);
223
224 if status >= 500 {
225 tracing::error!(status, headers, %latency, "res");
226 } else if status >= 400 {
227 tracing::warn!(status, headers, %latency, "res");
228 } else {
229 tracing::info!(status, headers, %latency, "res");
230 }
231 })
232 .on_failure(|error: ServerErrorsFailureClass, _: Duration, _: &Span| {
233 tracing::error!(
234 err = %error,
235 "fail"
236 );
237 }),
238 )
239 .layer(PropagateRequestIdLayer::new(X_REQUEST_ID))
240 .layer(SetResponseHeaderLayer::if_not_present(
241 header::SERVER,
242 HeaderValue::from_static(SERVER),
243 ))
244 .option_layer(via_domain.map(|domain| {
245 SetRequestHeaderLayer::overriding(X_VIA, move |req: &axum::http::Request<_>| {
246 let req_version = get_http_version_str(req.version());
247 HeaderValue::from_str(&format!("{req_version} {domain} (ordinaryd)")).ok()
248 })
249 })),
250 )
251}
252
253#[allow(clippy::similar_names)]
254pub async fn http_cache_middleware(
255 last_modified: UtcDateTime,
256 last_modified_header: HeaderValue,
257 req_headers: HeaderMap,
258 request: Request,
259 next: Next,
260) -> Response {
261 let response = next.run(request).await;
262 let (mut parts, body) = response.into_parts();
263
264 let body_bytes = if let Some(limit) = body.size_hint().upper()
265 && let Ok(limit) = usize::try_from(limit)
266 && let Ok(body_bytes) = axum::body::to_bytes(body, limit).await
267 {
268 body_bytes
269 } else {
270 return StatusCode::INTERNAL_SERVER_ERROR.into_response();
271 };
272
273 let mut res_headers = HeaderMap::new();
274 res_headers.insert(LAST_MODIFIED, last_modified_header);
275
276 let etag_string = get_etag_hash(body_bytes.as_ref(), None);
277 let etag_str = etag_string.as_str();
278
279 if let Some(if_none_match) = req_headers.get(header::IF_NONE_MATCH)
280 && let Ok(if_none_match_str) = if_none_match.to_str()
281 && if_none_match_str == etag_str
282 {
283 res_headers.insert(header::ETAG, if_none_match.to_owned());
284
285 return (StatusCode::NOT_MODIFIED, res_headers).into_response();
286 } else if let Ok(etag_header) = HeaderValue::from_str(etag_str) {
287 if let Some(if_modified_since) = req_headers.get(header::IF_MODIFIED_SINCE)
288 && let Ok(if_modified_since_str) = if_modified_since.to_str()
289 && let Ok(if_modified_since) = UtcDateTime::parse(if_modified_since_str, &GMT_FORMAT)
290 && if_modified_since >= last_modified
291 {
292 res_headers.insert(header::ETAG, etag_header);
293 return (StatusCode::NOT_MODIFIED, res_headers).into_response();
294 }
295
296 parts.headers.insert(header::ETAG, etag_header);
297 }
298
299 (parts, body_bytes).into_response()
300}
301
302#[must_use]
303pub fn get_etag_hash(content: &[u8], http_cache: Option<&HttpCache>) -> String {
304 if let Some(http_cache) = http_cache
305 && let Some(etag_config) = &http_cache.etag
306 && let Some(etag_alg) = &etag_config.alg
307 {
308 return match etag_alg {
309 HttpEtagAlgorithm::AHash => {
310 let mut hasher = AHasher::default();
311 hasher.write(content);
312 b64.encode(hasher.finish().to_be_bytes())
313 }
314 HttpEtagAlgorithm::XXH3(variation) => match variation {
315 XXH3Variation::Bit64 => {
316 b64.encode(xxhash_rust::xxh3::xxh3_64(content).to_be_bytes())
317 }
318 XXH3Variation::Bit128 => {
319 b64.encode(xxhash_rust::xxh3::xxh3_128(content).to_be_bytes())
320 }
321 },
322 HttpEtagAlgorithm::Rustc => {
323 let mut hasher = rustc_hash::FxHasher::default();
324 hasher.write(content);
325
326 b64.encode(hasher.finish().to_be_bytes())
327 }
328 HttpEtagAlgorithm::Blake3 => b64.encode(&blake3::hash(content).as_bytes()[0..16]),
329 };
330 }
331
332 let mut hasher = AHasher::default();
333 hasher.write(content);
334 b64.encode(hasher.finish().to_be_bytes())
335}
336
337#[must_use]
339pub async fn x_via(headers: HeaderMap, request: Request, next: Next) -> Response {
340 let mut response = next.run(request).await;
341
342 if let Some(x_via) = headers.get(X_VIA) {
343 response.headers_mut().insert(header::VIA, x_via.to_owned());
344 }
345
346 response
347}
348
349pub fn modify_etag_for_encoding(res: &Response) -> Option<HeaderValue> {
350 let headers = res.headers();
351
352 if let Some(curr_etag) = headers.get(header::ETAG)
353 && let Ok(curr_etag_str) = curr_etag.to_str()
354 {
355 let etag_len = curr_etag_str.len();
356
357 if (etag_len == 22 || etag_len == 11)
358 && let Some(compression) = headers.get(header::CONTENT_ENCODING)
359 && let Ok(compression_str) = compression.to_str()
360 {
361 let mut etag_string = curr_etag_str.to_owned();
362
363 match compression_str {
364 "gzip" => etag_string.push('1'),
365 "zstd" => etag_string.push('2'),
366 "br" => etag_string.push('3'),
367 "deflate" => etag_string.push('4'),
368 _ => (),
369 }
370
371 match HeaderValue::from_str(etag_string.as_str()) {
372 Ok(v) => return Some(v),
373 Err(err) => tracing::error!(%err),
374 }
375 } else {
376 return Some(curr_etag.clone());
377 }
378 }
379
380 None
381}
382
383pub fn check_if_none_match<'a>(headers: &'a HeaderMap, etag: &'a str) -> Option<&'a str> {
384 if let Some(if_none_match) = headers.get(header::IF_NONE_MATCH)
385 && let Ok(if_none_match_str) = if_none_match.to_str()
386 {
387 if if_none_match_str.len() < 11 {
388 return None;
389 }
390
391 if (etag.len() == 23
392 || etag.len() == 12
393 || if_none_match_str.len() == 22
394 || if_none_match_str.len() == 11)
395 && if_none_match_str == etag
396 {
397 return Some(etag);
398 }
399
400 if &if_none_match_str[..if_none_match_str.len() - 1] == etag {
401 Some(if_none_match_str)
402 } else {
403 None
404 }
405 } else {
406 None
407 }
408}