1use crate::{
6 GMT_FORMAT, HeadersDebug, LatencyDisplay, SERVER, WrappedRedactedHashingAlg, X_REQUEST_ID,
7 X_VIA, get_host, get_http_version_str, response_for_panic,
8};
9use ahash::AHasher;
10use axum::Router;
11use axum::extract::Request;
12use axum::http::{HeaderValue, StatusCode, header};
13use axum::middleware::Next;
14use axum::response::{IntoResponse, Response};
15use base64::{Engine as B64Engine, engine::general_purpose::URL_SAFE_NO_PAD as b64};
16use http_body_util::BodyExt;
17use hyper::HeaderMap;
18use ordinary_config::{HttpCache, HttpEtagAlgorithm, XXH3Variation};
19use std::hash::Hasher;
20use std::sync::Arc;
21use std::time::Duration;
22use time::UtcDateTime;
23use tower::ServiceBuilder;
24use tower_http::catch_panic::CatchPanicLayer;
25use tower_http::classify::ServerErrorsFailureClass;
26use tower_http::compression::CompressionLayer;
27use tower_http::decompression::RequestDecompressionLayer;
28use tower_http::request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer};
29use tower_http::set_header::{SetRequestHeaderLayer, SetResponseHeaderLayer};
30use tower_http::trace::TraceLayer;
31use tracing::Span;
32use uuid::Uuid;
33
/// The kind of service a router serves.
///
/// [`apply_common_middleware`] uses it to pick the name of the per-request tracing span
/// (`app`, `api`, `redirect` or `proxy`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ServiceKind {
    App,
    Api,
    Redirect,
    Proxy,
}
41
42#[allow(clippy::too_many_arguments, clippy::too_many_lines)]
43pub fn apply_common_middleware<T>(
44 router: Router<T>,
45 state: &T,
46 server_span: Option<Span>,
47 domain: String,
48 log_headers: bool,
49 log_ips: bool,
50 redacted_hash: Arc<Option<WrappedRedactedHashingAlg>>,
51 kind: ServiceKind,
52 via_domain: Option<String>,
53) -> Router
54where
55 T: Clone + Send + Sync + 'static,
56{
57 let redacted_hash_clone = redacted_hash.clone();
58
59 router
60 .with_state(state.clone())
61 .layer(
62 ServiceBuilder::new()
63 .layer(CatchPanicLayer::custom(response_for_panic))
64 .layer(RequestDecompressionLayer::new())
65 .layer(CompressionLayer::new()),
66 )
67 .layer(
68 ServiceBuilder::new()
69 .layer(SetRequestIdLayer::new(X_REQUEST_ID, MakeRequestUuid))
70 .layer(
71 TraceLayer::new_for_http()
72 .make_span_with(move |req: &axum::http::Request<_>| {
73 let request_id = req
74 .headers()
75 .get(X_REQUEST_ID)
76 .and_then(|rid| {
77 rid.to_str()
78 .ok()
79 .and_then(|rid| Uuid::parse_str(rid).ok())
80 .map(tracing::field::display)
81 })
82 .unwrap_or(tracing::field::display(Uuid::new_v4()));
83
84 let host =
85 get_host(req.headers(), req.uri()).map(tracing::field::display);
86
87 let ip = crate::get_display_ip(log_ips, req);
88
89 let query = req.uri().query().map(tracing::field::display);
90
91 match kind {
92 ServiceKind::App => {
93 if let Some(server_span) = &server_span {
94 server_span.in_scope(|| {
95 tracing::info_span!(
96 "app",
97 %domain,
98 host,
99 rid = %request_id,
100 ip,
101 path = %req.uri().path(),
102 query,
103 )
104 })
105 } else {
106 tracing::info_span!(
107 "app",
108 %domain,
109 host,
110 rid = %request_id,
111 ip,
112 path = %req.uri().path(),
113 query,
114 )
115 }
116 }
117 ServiceKind::Proxy => {
118 if let Some(server_span) = &server_span {
119 server_span.in_scope(|| {
120 tracing::info_span!(
121 "proxy",
122 %domain,
123 host,
124 rid = %request_id,
125 ip,
126 path = %req.uri().path(),
127 query,
128 )
129 })
130 } else {
131 tracing::info_span!(
132 "proxy",
133 %domain,
134 host,
135 rid = %request_id,
136 ip,
137 path = %req.uri().path(),
138 query,
139 )
140 }
141 }
142 ServiceKind::Api => {
143 if let Some(server_span) = &server_span {
144 server_span.in_scope(|| {
145 tracing::info_span!(
146 "api",
147 %domain,
148 host,
149 rid = %request_id,
150 ip,
151 path = %req.uri().path(),
152 query,
153 )
154 })
155 } else {
156 tracing::info_span!(
157 "api",
158 %domain,
159 host,
160 rid = %request_id,
161 ip,
162 path = %req.uri().path(),
163 query,
164 )
165 }
166 }
167 ServiceKind::Redirect => {
168 if let Some(server_span) = &server_span {
169 server_span.in_scope(|| {
170 tracing::info_span!(
171 "redirect",
172 host,
173 rid = %request_id,
174 ip,
175 path = %req.uri().path(),
176 query,
177 )
178 })
179 } else {
180 tracing::info_span!(
181 "redirect",
182 host,
183 rid = %request_id,
184 ip,
185 path = %req.uri().path(),
186 query,
187 )
188 }
189 }
190 }
191 })
192 .on_request(move |req: &axum::http::Request<_>, _: &Span| {
193 let hd = log_headers
194 .then_some(HeadersDebug(req.headers(), redacted_hash.clone()));
195
196 #[cfg(tracing_unstable)]
197 let headers = log_headers.then_some(tracing::field::valuable(&hd));
198 #[cfg(not(tracing_unstable))]
199 let headers = log_headers.then_some(tracing::field::debug(&hd));
200
201 tracing::info!(
202 version = ?req.version(),
203 method = %req.method(),
204 headers,
205 "req"
206 );
207 })
208 .on_response(move |res: &Response<_>, latency: Duration, _: &Span| {
209 let hd = log_headers.then_some(HeadersDebug(
210 res.headers(),
211 redacted_hash_clone.clone(),
212 ));
213
214 #[cfg(tracing_unstable)]
215 let headers = log_headers.then_some(tracing::field::valuable(&hd));
216
217 #[cfg(not(tracing_unstable))]
218 let headers = log_headers.then_some(tracing::field::debug(&hd));
219
220 let status = res.status().as_u16();
221 let latency = LatencyDisplay(latency.as_nanos() as f64);
222
223 if status >= 500 {
224 tracing::error!(status, headers, %latency, "res");
225 } else if status >= 400 {
226 tracing::warn!(status, headers, %latency, "res");
227 } else {
228 tracing::info!(status, headers, %latency, "res");
229 }
230 })
231 .on_failure(|error: ServerErrorsFailureClass, _: Duration, _: &Span| {
232 tracing::error!(
233 err = %error,
234 "fail"
235 );
236 }),
237 )
238 .layer(PropagateRequestIdLayer::new(X_REQUEST_ID))
239 .layer(SetResponseHeaderLayer::if_not_present(
240 header::SERVER,
241 HeaderValue::from_static(SERVER),
242 ))
243 .option_layer(via_domain.map(|domain| {
244 SetRequestHeaderLayer::overriding(X_VIA, move |req: &axum::http::Request<_>| {
245 let req_version = get_http_version_str(req.version());
246 HeaderValue::from_str(&format!("{req_version} {domain} (ordinaryd)")).ok()
247 })
248 })),
249 )
250}
251
252#[allow(clippy::similar_names)]
253pub async fn http_cache_middleware(
254 last_modified: UtcDateTime,
255 req_headers: HeaderMap,
256 request: Request,
257 next: Next,
258) -> Response {
259 let response = next.run(request).await;
260 let (mut parts, body) = response.into_parts();
261
262 let body_bytes = if let Ok(collected) = body.collect().await {
263 collected.to_bytes()
264 } else {
265 return StatusCode::INTERNAL_SERVER_ERROR.into_response();
266 };
267
268 let mut res_headers = HeaderMap::new();
269
270 let etag_string = get_etag_hash(body_bytes.as_ref(), None);
271 let etag_str = etag_string.as_str();
272
273 if let Some(if_none_match) = req_headers.get(header::IF_NONE_MATCH)
274 && let Ok(if_none_match_str) = if_none_match.to_str()
275 && if_none_match_str == etag_str
276 {
277 res_headers.insert(header::ETAG, if_none_match.to_owned());
278
279 return (StatusCode::NOT_MODIFIED, res_headers).into_response();
280 } else if let Ok(etag_header) = HeaderValue::from_str(etag_str) {
281 if let Some(if_modified_since) = req_headers.get(header::IF_MODIFIED_SINCE)
282 && let Ok(if_modified_since_str) = if_modified_since.to_str()
283 && let Ok(if_modified_since) = UtcDateTime::parse(if_modified_since_str, &GMT_FORMAT)
284 && if_modified_since >= last_modified
285 {
286 res_headers.insert(header::ETAG, etag_header);
287 return (StatusCode::NOT_MODIFIED, res_headers).into_response();
288 }
289
290 parts.headers.insert(header::ETAG, etag_header);
291 }
292
293 (parts, body_bytes).into_response()
294}
295
296#[must_use]
297pub fn get_etag_hash(content: &[u8], http_cache: Option<&HttpCache>) -> String {
298 if let Some(http_cache) = http_cache
299 && let Some(etag_config) = &http_cache.etag
300 && let Some(etag_alg) = &etag_config.alg
301 {
302 return match etag_alg {
303 HttpEtagAlgorithm::AHash => {
304 let mut hasher = AHasher::default();
305 hasher.write(content);
306 b64.encode(hasher.finish().to_be_bytes())
307 }
308 HttpEtagAlgorithm::XXH3(variation) => match variation {
309 XXH3Variation::Bit64 => {
310 b64.encode(xxhash_rust::xxh3::xxh3_64(content).to_be_bytes())
311 }
312 XXH3Variation::Bit128 => {
313 b64.encode(xxhash_rust::xxh3::xxh3_128(content).to_be_bytes())
314 }
315 },
316 HttpEtagAlgorithm::Rustc => {
317 let mut hasher = rustc_hash::FxHasher::default();
318 hasher.write(content);
319
320 b64.encode(hasher.finish().to_be_bytes())
321 }
322 HttpEtagAlgorithm::Blake3 => b64.encode(&blake3::hash(content).as_bytes()[0..16]),
323 };
324 }
325
326 let mut hasher = AHasher::default();
327 hasher.write(content);
328 b64.encode(hasher.finish().to_be_bytes())
329}
330
331#[must_use]
332pub async fn x_via(headers: HeaderMap, request: Request, next: Next) -> Response {
333 let mut response = next.run(request).await;
334
335 if let Some(x_via) = headers.get(X_VIA) {
336 response.headers_mut().insert(header::VIA, x_via.to_owned());
337 }
338
339 response
340}
341
342pub fn modify_etag_for_encoding(res: &Response) -> Option<HeaderValue> {
343 let headers = res.headers();
344
345 if let Some(curr_etag) = headers.get(header::ETAG)
346 && let Ok(curr_etag_str) = curr_etag.to_str()
347 {
348 let etag_len = curr_etag_str.len();
349
350 if (etag_len == 22 || etag_len == 11)
351 && let Some(compression) = headers.get(header::CONTENT_ENCODING)
352 && let Ok(compression_str) = compression.to_str()
353 {
354 let mut etag_string = curr_etag_str.to_owned();
355
356 match compression_str {
357 "gzip" => etag_string.push('1'),
358 "zstd" => etag_string.push('2'),
359 "br" => etag_string.push('3'),
360 "deflate" => etag_string.push('4'),
361 _ => (),
362 }
363
364 match HeaderValue::from_str(etag_string.as_str()) {
365 Ok(v) => return Some(v),
366 Err(err) => tracing::error!(%err),
367 }
368 } else {
369 return Some(curr_etag.clone());
370 }
371 }
372
373 None
374}
375
376pub fn check_if_none_match<'a>(headers: &'a HeaderMap, etag: &'a str) -> Option<&'a str> {
377 if let Some(if_none_match) = headers.get(header::IF_NONE_MATCH)
378 && let Ok(if_none_match_str) = if_none_match.to_str()
379 {
380 if if_none_match_str.len() < 11 {
381 return None;
382 }
383
384 if (etag.len() == 23
385 || etag.len() == 12
386 || if_none_match_str.len() == 22
387 || if_none_match_str.len() == 11)
388 && if_none_match_str == etag
389 {
390 return Some(etag);
391 }
392
393 if &if_none_match_str[..if_none_match_str.len() - 1] == etag {
394 Some(if_none_match_str)
395 } else {
396 None
397 }
398 } else {
399 None
400 }
401}