1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
//! Default response headers
use http::header::{HeaderName, HeaderValue, CONTENT_TYPE};
use http::{HeaderMap, HttpTryFrom};

use error::Result;
use httprequest::HttpRequest;
use httpresponse::HttpResponse;
use middleware::{Middleware, Response};

/// `Middleware` for setting default response headers.
///
/// A default header is only set if the response does not already contain a
/// header with the same name.
///
/// ```rust
/// # extern crate actix_web;
/// use actix_web::{http, middleware, App, HttpResponse};
///
/// fn main() {
///     let app = App::new()
///         .middleware(
///             middleware::DefaultHeaders::new()
///                 .header("X-Version", "0.2"))
///         .resource("/test", |r| {
///             r.method(http::Method::GET).f(|_| HttpResponse::Ok());
///             r.method(http::Method::HEAD).f(|_| HttpResponse::MethodNotAllowed());
///         })
///         .finish();
/// }
/// ```
pub struct DefaultHeaders {
    /// Whether to add a default `Content-Type` header (enabled by `content_type()`).
    ct: bool,
    /// Headers to apply to responses that lack a header of the same name.
    headers: HeaderMap,
}

impl Default for DefaultHeaders {
    /// An empty configuration: no default headers and no default `Content-Type`.
    fn default() -> Self {
        Self {
            headers: HeaderMap::new(),
            ct: false,
        }
    }
}

impl DefaultHeaders {
    /// Construct `DefaultHeaders` middleware.
    ///
    /// The returned middleware has no default headers and does not add a
    /// `Content-Type`; configure it with `header()` and `content_type()`.
    pub fn new() -> DefaultHeaders {
        DefaultHeaders::default()
    }

    /// Add a default header.
    ///
    /// The header is applied to responses that do not already contain a header
    /// with the same name. Registering the same name more than once keeps every
    /// value (values are appended to the stored set, not replaced).
    ///
    /// # Panics
    ///
    /// Panics if `key` cannot be converted into a valid `HeaderName`, or if
    /// `value` cannot be converted into a valid `HeaderValue`.
    pub fn header<K, V>(mut self, key: K, value: V) -> Self
    where
        HeaderName: HttpTryFrom<K>,
        HeaderValue: HttpTryFrom<V>,
    {
        // The name is converted first, so an invalid name is reported even when
        // the value is invalid as well (and the value is never converted).
        let key = HeaderName::try_from(key)
            .unwrap_or_else(|_| panic!("Can not create header name"));
        let value = HeaderValue::try_from(value)
            .unwrap_or_else(|_| panic!("Can not create header value"));
        self.headers.append(key, value);
        self
    }

    /// Also set `Content-Type: application/octet-stream` on responses that do
    /// not already contain a `Content-Type` header.
    pub fn content_type(mut self) -> Self {
        self.ct = true;
        self
    }
}

impl<S> Middleware<S> for DefaultHeaders {
    /// Fill in every configured default header that the response lacks, then
    /// (if enabled) a default `Content-Type`.
    ///
    /// Presence is decided per header *name*: if the response already has any
    /// value for a name, none of the configured values for that name are added.
    /// Otherwise all configured values for that name are appended, so a header
    /// registered several times through `header()` keeps every value.
    fn response(
        &self, _: &mut HttpRequest<S>, mut resp: HttpResponse,
    ) -> Result<Response> {
        // Walk distinct names rather than (name, value) pairs: with pairs, the
        // first value of a multi-valued default gets inserted and then makes
        // `contains_key` true for that name, silently dropping the other values.
        for key in self.headers.keys() {
            if !resp.headers().contains_key(key) {
                for value in self.headers.get_all(key).iter() {
                    resp.headers_mut().append(key, value.clone());
                }
            }
        }
        // default content-type, only when none was set by the handler or above
        if self.ct && !resp.headers().contains_key(CONTENT_TYPE) {
            resp.headers_mut().insert(
                CONTENT_TYPE,
                HeaderValue::from_static("application/octet-stream"),
            );
        }
        Ok(Response::Done(resp))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use http::header::CONTENT_TYPE;

    #[test]
    fn test_default_headers() {
        let mw = DefaultHeaders::new().header(CONTENT_TYPE, "0001");

        let mut req = HttpRequest::default();

        // Header missing from the response: the default is applied.
        let resp = HttpResponse::Ok().finish();
        let resp = match mw.response(&mut req, resp) {
            Ok(Response::Done(resp)) => resp,
            _ => panic!(),
        };
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");

        // Header already present: the response's own value wins.
        let resp = HttpResponse::Ok()
            .header(CONTENT_TYPE, "0002")
            .finish();
        let resp = match mw.response(&mut req, resp) {
            Ok(Response::Done(resp)) => resp,
            _ => panic!(),
        };
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0002");
    }

    #[test]
    fn test_default_content_type() {
        let mw = DefaultHeaders::new().content_type();

        let mut req = HttpRequest::default();

        // No Content-Type on the response: the octet-stream default is added.
        let resp = HttpResponse::Ok().finish();
        let resp = match mw.response(&mut req, resp) {
            Ok(Response::Done(resp)) => resp,
            _ => panic!(),
        };
        assert_eq!(
            resp.headers().get(CONTENT_TYPE).unwrap(),
            "application/octet-stream"
        );

        // An explicit Content-Type is left untouched.
        let resp = HttpResponse::Ok()
            .header(CONTENT_TYPE, "text/plain")
            .finish();
        let resp = match mw.response(&mut req, resp) {
            Ok(Response::Done(resp)) => resp,
            _ => panic!(),
        };
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "text/plain");
    }
}