1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
use actix_web::{
    body::MessageBody,
    dev::{ServiceRequest, ServiceResponse},
    web::Redirect,
    Error, Responder,
};

use crate::middleware_from_fn::Next;

/// A function middleware to redirect traffic away from `www.` if it's present.
///
/// Requests whose host starts with `www.` are answered directly with a redirect response
/// (307 Temporary Redirect, the [`Redirect`] default). The target keeps the request's scheme,
/// host (minus the leading `www.`), path, and query string. All other requests are passed
/// through to the next service unchanged.
///
/// # Examples
///
/// ```
/// # use actix_web::App;
/// use actix_web_lab::middleware::{from_fn, redirect_to_non_www};
///
/// App::new().wrap(from_fn(redirect_to_non_www))
///     # ;
/// ```
///
/// # Errors
///
/// Returns any error produced by the wrapped service. Building the redirect itself never fails.
pub async fn redirect_to_non_www(
    req: ServiceRequest,
    next: Next<impl MessageBody + 'static>,
) -> Result<ServiceResponse<impl MessageBody>, Error> {
    let (req, pl) = req.into_parts();

    // `connection_info()` returns a `RefCell` borrow guard. Confining it to this block releases
    // it before `req` is moved below and before the next service is awaited.
    let redirect_uri = {
        let conn_info = req.connection_info();

        conn_info.host().strip_prefix("www.").map(|host_no_www| {
            let scheme = conn_info.scheme();
            let path = req.uri().path();

            // Keep the query string so the redirect does not silently drop request parameters.
            match req.uri().query() {
                Some(query) => format!("{scheme}://{host_no_www}{path}?{query}"),
                None => format!("{scheme}://{host_no_www}{path}"),
            }
        })
    };

    if let Some(uri) = redirect_uri {
        let res = Redirect::to(uri).respond_to(&req);
        return Ok(ServiceResponse::new(req, res).map_into_right_body());
    }

    let req = ServiceRequest::from_parts(req, pl);
    Ok(next.call(req).await?.map_into_left_body())
}

#[cfg(test)]
mod tests {
    use actix_web::{
        dev::ServiceFactory,
        http::{header, StatusCode},
        test, web, App, HttpResponse,
    };

    use super::*;
    use crate::middleware::from_fn;

    /// Builds an app wrapped in `redirect_to_non_www` that serves a single `GET /` route
    /// answering `200 OK` with the body `content`.
    fn app_with_middleware() -> App<
        impl ServiceFactory<
            ServiceRequest,
            Response = ServiceResponse<impl MessageBody>,
            Config = (),
            InitError = (),
            Error = Error,
        >,
    > {
        let handler = web::get().to(|| async { HttpResponse::Ok().body("content") });

        App::new()
            .wrap(from_fn(redirect_to_non_www))
            .route("/", handler)
    }

    /// A `www.` host must be answered by the middleware itself with an empty-bodied
    /// temporary redirect whose Location no longer carries the `www.` prefix.
    #[actix_web::test]
    async fn redirect_non_www() {
        let srv = test::init_service(app_with_middleware()).await;

        let request = test::TestRequest::get()
            .uri("http://www.localhost/")
            .to_request();
        let response = test::call_service(&srv, request).await;
        assert_eq!(response.status(), StatusCode::TEMPORARY_REDIRECT);

        let location = response
            .headers()
            .get(header::LOCATION)
            .expect("redirect response should carry a Location header");
        assert!(!location.as_bytes().starts_with(b"http://www."));

        let payload = test::read_body(response).await;
        assert!(payload.is_empty());
    }

    /// A host without `www.` must reach the wrapped route untouched.
    #[actix_web::test]
    async fn do_not_redirect_already_non_www() {
        let srv = test::init_service(app_with_middleware()).await;

        let request = test::TestRequest::default()
            .uri("http://localhost/")
            .to_request();
        let response = test::call_service(&srv, request).await;
        assert_eq!(response.status(), StatusCode::OK);
        assert!(!response.headers().contains_key(header::LOCATION));

        let payload = test::read_body(response).await;
        assert_eq!(payload, "content");
    }
}