actix_middleware_etag/
lib.rs

1#![deny(missing_docs)]
2#![deny(unsafe_code)]
3//! # Actix Middleware - ETag
4//!
//! To avoid sending unnecessary bodies downstream, this middleware compares the `If-None-Match`
//! request header with a hash calculated from the body of the `GET` response.
7//! Inspired by Node's [express framework](http://expressjs.com/en/api.html#etag.options.table) and how it does ETag calculation, this middleware behaves in a similar fashion.
8//!
//! It first hashes the response body, then base64-encodes the hash, prefixes it with the body
//! length in hexadecimal, and sends the result as the `ETag` header of the `GET` response.
10//!
//! This does not save CPU resources on the server side, since the body is still generated in full
//! (and then hashed).
12//!
//! Beware: only the body is hashed, not the response headers. If clients must refresh when only
//! the headers change while the body stays identical, use something else
//! (or, better yet, open a PR on this repo adding a sane way to take headers into account as well).
15use std::pin::Pin;
16
17use actix_service::{forward_ready, Service, Transform};
18use actix_web::body::{BodySize, BoxBody, EitherBody, MessageBody, None as BodyNone};
19use actix_web::dev::{ServiceRequest, ServiceResponse};
20use actix_web::http::header::{ETag, EntityTag, Header, IfNoneMatch, TryIntoHeaderPair};
21use actix_web::http::Method;
22use actix_web::web::Bytes;
23use actix_web::{HttpMessage, HttpResponse};
24use base64::Engine;
25use core::fmt::Write;
26use futures::{
27    future::{ok, Ready},
28    Future,
29};
30use xxhash_rust::xxh3::xxh3_128;
31
/// Middleware that adds an `ETag` header to `GET` responses and answers requests whose
/// `If-None-Match` header already names that tag with `304 Not Modified`.
///
/// The generated tag is the body length in hexadecimal, a dash, and the URL-safe base64 encoding
/// of the body's xxh3-128 hash, e.g. `W/"3-UDkviZRfr3iFYTpztlqwBg=="`. If the wrapped service
/// already set an `ETag` header, that tag is reused instead. Only bodies with a known size are
/// hashed; streaming bodies of unknown size pass through without an `ETag`.
///
/// Register it with the *first* `wrap()` call. Actix treats the last-registered middleware as the
/// outermost one, so wrapping `Etag` first places it closest to the handler, where it sees the
/// response body before outer middleware (such as `Compress`) transforms it.
///
/// # Examples
/// ```no_run
/// use actix_web::{web, App, HttpResponse, HttpServer};
/// use actix_middleware_etag::Etag;
///
/// #[actix_web::main]
/// async fn main() -> std::io::Result<()> {
///     HttpServer::new(|| {
///         App::new()
///             // Hash every GET response body and send the result in the `ETag` header.
///             .wrap(Etag::default())
///             .default_service(web::to(|| HttpResponse::Ok()))
///     })
///     .bind(("127.0.0.1", 8080))?
///     .run()
///     .await
/// }
/// ```
#[derive(Debug, Default, Clone)]
pub struct Etag {
    /// If true, generated tags are strong (`"…"`) instead of weak (`W/"…"`).
    /// Defaults to `false`.
    pub force_strong_etag: bool,
}
59
60impl<S, B> Transform<S, ServiceRequest> for Etag
61where
62    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error>,
63    S::Future: 'static,
64    B: MessageBody + 'static,
65{
66    type Response = ServiceResponse<EitherBody<BoxBody>>;
67    type Error = actix_web::Error;
68    type Transform = EtagMiddleware<S>;
69    type InitError = ();
70    type Future = Ready<Result<Self::Transform, Self::InitError>>;
71
72    fn new_transform(&self, service: S) -> Self::Future {
73        ok(EtagMiddleware {
74            service,
75            force_strong_etag: self.force_strong_etag,
76        })
77    }
78}
79type Buffer = str_buf::StrBuf<62>;
/// The service produced by the [`Etag`] transform: it wraps the inner service `S` and
/// post-processes that service's responses.
#[derive(Debug)]
pub struct EtagMiddleware<S> {
    /// The wrapped (inner) service that produces the response.
    service: S,
    /// Copied from [`Etag::force_strong_etag`]: emit strong rather than weak generated tags.
    force_strong_etag: bool,
}
86
87impl<S, B> Service<ServiceRequest> for EtagMiddleware<S>
88where
89    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error>,
90    S::Future: 'static,
91    B: MessageBody + 'static,
92{
93    type Response = ServiceResponse<EitherBody<BoxBody>>;
94    type Error = actix_web::Error;
95    #[allow(clippy::type_complexity)]
96    type Future =
97        Pin<Box<dyn Future<Output = Result<ServiceResponse<EitherBody<BoxBody>>, Self::Error>>>>;
98    forward_ready!(service);
99
100    fn call(&self, req: ServiceRequest) -> Self::Future {
101        let request_etag_header: Option<IfNoneMatch> = req.get_header();
102        let method = req.method().clone();
103        let fut = self.service.call(req);
104        let force_strong_etag = self.force_strong_etag;
105        Box::pin(async move {
106            let res: ServiceResponse<B> = fut.await?;
107            match method {
108                Method::GET => {
109                    let mut modified = true;
110                    let mut payload: Option<Bytes> = None;
111                    let mut res = res.map_body(|_h, body| match body.size() {
112                        BodySize::Sized(_size) => {
113                            let bytes = body.try_into_bytes().unwrap_or_else(|_| Bytes::new());
114                            payload = Some(bytes.clone());
115                            bytes.clone().boxed()
116                        }
117                        _ => body.boxed(),
118                    });
119                    if let Some(bytes) = payload {
120                        let custom_etag = res.response().headers().get(ETag::name());
121                        let tag = match custom_etag.and_then(|etag| etag.to_str().ok()) {
122                            Some(custom_etag) => EntityTag::new_strong(custom_etag.to_owned()),
123                            None => {
124                                let response_hash = xxh3_128(&bytes);
125                                let base64 = base64::prelude::BASE64_URL_SAFE
126                                    .encode(response_hash.to_le_bytes());
127                                let mut buff = Buffer::new();
128                                let _ = write!(buff, "{:x}-{}", bytes.len(), base64);
129                                if force_strong_etag {
130                                    EntityTag::new_strong(buff.to_string())
131                                } else {
132                                    EntityTag::new_weak(buff.to_string())
133                                }
134                            }
135                        };
136
137                        if let Some(request_etag_header) = request_etag_header {
138                            if request_etag_header == IfNoneMatch::Any
139                                || request_etag_header.to_string() == tag.to_string()
140                            {
141                                modified = false
142                            }
143                        }
144                        if modified {
145                            if let Ok((name, value)) = ETag(tag.clone()).try_into_pair() {
146                                res.headers_mut().insert(name, value);
147                            }
148                        }
149                    }
150
151                    Ok(match modified {
152                        false => res
153                            .into_response(HttpResponse::NotModified().body(BodyNone::new()))
154                            .map_into_right_body(),
155                        true => res.map_into_left_body(),
156                    })
157                }
158                _ => Ok(res.map_into_boxed_body().map_into_left_body()),
159            }
160        })
161    }
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use actix_service::IntoService;
    use actix_web::http::header::{ETag, EntityTag, Header, HeaderName};
    use actix_web::{
        http::StatusCode,
        test::{call_service, init_service, TestRequest},
        web, App, Responder,
    };

    /// Handler returning a small in-memory text body.
    async fn index() -> impl Responder {
        HttpResponse::Ok().body("abcd")
    }

    /// Handler returning a binary body (served as `image/png`) from the test assets.
    async fn image() -> impl Responder {
        HttpResponse::Ok()
            .content_type("image/png")
            .body(&include_bytes!("assets/favicon.ico")[..])
    }

    #[actix_web::test]
    async fn test_generates_etag() {
        let srv = |req: ServiceRequest| {
            ok(req.into_response(HttpResponse::build(StatusCode::OK).body("abc")))
        };
        let srv = Etag::default()
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let req = TestRequest::default().to_srv_request();
        let response = srv.call(req).await.expect("No response was generated!");
        assert_eq!(response.status(), StatusCode::OK);
        let etag = HeaderName::from_lowercase(b"etag").unwrap();
        assert_eq!(
            response.headers().get(etag).unwrap().to_str().unwrap(),
            r#"W/"3-UDkviZRfr3iFYTpztlqwBg==""#
        );
    }

    #[actix_web::test]
    async fn test_any_data_matches_wildcard_etag() {
        let app = init_service(
            App::new()
                .wrap(Etag::default())
                .route("/", web::get().to(index)),
        )
        .await;

        let req = TestRequest::default()
            .append_header(IfNoneMatch::Any)
            .to_request();
        let res = call_service(&app, req).await;
        assert_eq!(res.status(), StatusCode::NOT_MODIFIED);
    }

    #[actix_web::test]
    async fn test_generates_etag_on_changes() {
        let app = init_service(
            App::new()
                .wrap(Etag::default())
                .route("/", web::get().to(index)),
        )
        .await;
        let match_header = IfNoneMatch::Items(vec![EntityTag::new_weak(
            "3-UDkviZRfr3iFYTpztlqwBg==".to_string(),
        )]);
        let req = TestRequest::default()
            .append_header(match_header)
            .to_request();
        let res = call_service(&app, req).await;
        let etag = res.headers().get(ETag::name()).unwrap();
        assert_eq!(etag.to_str().unwrap(), r#"W/"4-PTWx0eye5xvCkPo9OGBrjQ==""#);
        assert!(res.status().is_success());
    }

    #[actix_web::test]
    async fn test_body_gets_preserved() {
        let app = init_service(
            App::new()
                .wrap(Etag::default())
                .route("/", web::get().to(index)),
        )
        .await;
        let match_header = IfNoneMatch::Items(vec![EntityTag::new_weak(
            "UDkviZRfr3iFYTpztlqwBg==".to_string(),
        )]);
        let req = TestRequest::default()
            .append_header(match_header)
            .to_request();
        let res = call_service(&app, req).await;
        assert!(res.status().is_success());
        let body: Bytes = res.into_body().try_into_bytes().unwrap();
        // Compare the whole body: zipping iterators would accept an empty or truncated body.
        assert_eq!(body, Bytes::from_static(b"abcd"));
    }

    #[actix_web::test]
    async fn test_favicon_generates_correct_status_code_on_etag_match() {
        let app = init_service(
            App::new()
                .wrap(Etag::default())
                .route("/", web::get().to(image)),
        )
        .await;
        let match_header = IfNoneMatch::Items(vec![EntityTag::new_weak(
            "3aee-m0RKLkLoLS6kJ1N8xt0D5A==".to_string(),
        )]);
        let req = TestRequest::default()
            .append_header(match_header)
            .to_request();
        let res = call_service(&app, req).await;
        assert_eq!(res.status(), StatusCode::NOT_MODIFIED);
        assert_eq!(res.into_body().size(), BodySize::None);
    }

    #[actix_web::test]
    async fn test_favicon_data_works() {
        let app = init_service(
            App::new()
                .wrap(Etag::default())
                .route("/", web::get().to(image)),
        )
        .await;

        let match_header = IfNoneMatch::Items(vec![EntityTag::new_weak(
            "UDkviZRfr3iFYTpztlqwBg==".to_string(),
        )]);
        let req = TestRequest::default()
            .append_header(match_header)
            .to_request();
        let res = call_service(&app, req).await;

        let etag = res.headers().get(ETag::name()).unwrap();
        assert_eq!(
            etag.to_str().unwrap(),
            r#"W/"3aee-m0RKLkLoLS6kJ1N8xt0D5A==""#
        );
    }

    #[actix_web::test]
    async fn does_not_add_etag_header_to_post_request() {
        let app = init_service(
            App::new()
                .wrap(Etag::default())
                .route("/", web::post().to(image)),
        )
        .await;

        let req = TestRequest::default().method(Method::POST).to_request();
        let res = call_service(&app, req).await;

        assert_eq!(res.headers().get(ETag::name()), None);
    }

    #[actix_web::test]
    async fn still_empty_body_when_compress_middleware_is_added() {
        let app = init_service(
            App::new()
                .wrap(Etag::default())
                .wrap(actix_web::middleware::Compress::default())
                .route("/", web::get().to(image)),
        )
        .await;
        let match_header = IfNoneMatch::Items(vec![EntityTag::new_weak(
            "3aee-m0RKLkLoLS6kJ1N8xt0D5A==".to_string(),
        )]);
        let req = TestRequest::default()
            .append_header(match_header)
            .append_header(("Accept-Encoding", "gzip"))
            .to_request();
        let res = call_service(&app, req).await;

        assert_eq!(res.status(), StatusCode::NOT_MODIFIED);
        assert_eq!(res.into_body().size(), BodySize::None);
    }

    #[actix_web::test]
    async fn test_explicit_etag_matches_if_none_match() {
        let app = init_service(App::new().wrap(Etag::default()).route(
            "/",
            web::get().to(|| async {
                HttpResponse::Ok()
                    .insert_header((ETag::name(), "123"))
                    .body("Response with a custom ETag")
            }),
        ))
        .await;
        let match_header = IfNoneMatch::Items(vec![EntityTag::new_strong("123".to_string())]);

        let req = TestRequest::default()
            .append_header(match_header)
            .to_request();
        let res = call_service(&app, req).await;

        assert_eq!(res.status(), StatusCode::NOT_MODIFIED);
        assert_eq!(res.into_body().size(), BodySize::None);
    }

    #[actix_web::test]
    async fn test_explicit_etag_does_not_match_if_none_match() {
        let app = init_service(App::new().wrap(Etag::default()).route(
            "/",
            web::get().to(|| async {
                HttpResponse::Ok()
                    .insert_header((ETag::name(), "123"))
                    .body("Response with a custom ETag")
            }),
        ))
        .await;
        let match_header = IfNoneMatch::Items(vec![EntityTag::new_weak("124".to_string())]);

        let req = TestRequest::default()
            .append_header(match_header)
            .to_request();
        let res = call_service(&app, req).await;

        let etag_header = res.headers().get(ETag::name()).unwrap();

        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(etag_header.to_str().unwrap(), "\"123\"");
    }

    #[actix_web::test]
    async fn test_force_strong_etag() {
        let srv = |req: ServiceRequest| {
            ok(req.into_response(HttpResponse::build(StatusCode::OK).body("abc")))
        };
        let etag_service = Etag {
            force_strong_etag: true,
        };
        let srv = etag_service
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let req = TestRequest::default().to_srv_request();
        let response = srv.call(req).await.expect("No response was generated!");
        assert_eq!(response.status(), StatusCode::OK);
        let etag = HeaderName::from_lowercase(b"etag").unwrap();
        assert_eq!(
            response.headers().get(etag).unwrap().to_str().unwrap(),
            r#""3-UDkviZRfr3iFYTpztlqwBg==""#
        );
    }
}