1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
use crate::{all_or_some::AllOrSome, http::HttpResponse, method::Method};
use std::ops::Deref;

/// CORS settings assembled with the builder-style methods on `Cors` and written onto an
/// `HttpResponse` by `Cors::merge`.
///
/// The derived `Default` is the same empty configuration that `Cors::new` builds: no origin,
/// no methods/headers, credentials off, no max age, no `Vary: Origin`.
#[derive(Eq, PartialEq, Debug, Default)]
pub struct Cors {
    /// Value for `Access-Control-Allow-Origin`; `None` means `merge` adds no headers at all.
    allow_origin: Option<AllOrSome<String>>,
    /// Methods listed (comma-separated) in `Access-Control-Allow-Methods`.
    allow_methods: Vec<Method>,
    /// Header names listed (comma-separated) in `Access-Control-Allow-Headers`.
    allow_headers: Vec<String>,
    /// When `true`, `Access-Control-Allow-Credentials: true` is emitted.
    allow_credentials: bool,
    /// Header names listed (comma-separated) in `Access-Control-Expose-Headers`.
    expose_headers: Vec<String>,
    /// Value for `Access-Control-Max-Age`; omitted when `None`.
    max_age: Option<usize>,
    /// When `true`, `Vary: Origin` is emitted.
    vary_origin: bool,
}

impl Cors {
    /// Create an empty `Cors`
    pub fn new() -> Self {
        Self {
            allow_origin: None,
            allow_headers: vec![],
            allow_methods: vec![],
            allow_credentials: false,
            expose_headers: vec![],
            max_age: None,
            vary_origin: false,
        }
    }

    /// Consumes the `Response` and return an altered response with origin and `vary_origin` set
    pub fn allow_origin(mut self, origin: &str) -> Self {
        self.allow_origin = Some(AllOrSome::Some(origin.to_string()));
        self
    }

    /// Consumes the `Response` and return an altered response with origin set to "*"
    pub fn any(mut self) -> Self {
        self.allow_origin = Some(AllOrSome::All);
        self
    }

    /// Consumes the Response and set credentials
    pub fn credentials(mut self, value: bool) -> Self {
        self.allow_credentials = value;
        self
    }

    /// Consumes the CORS, set expose_headers to
    /// passed headers and returns changed CORS
    pub fn exposed_headers(mut self, headers: Vec<&str>) -> Self {
        self.expose_headers = headers.iter().map(|s| (*s).to_string().into()).collect();
        self
    }

    /// Consumes the CORS, set allow_headers to
    /// passed headers and returns changed CORS
    pub fn allow_headers(mut self, headers: Vec<&str>) -> Self {
        self.allow_headers = headers.iter().map(|s| (*s).to_string().into()).collect();
        self
    }

    /// Consumes the CORS, set max_age to
    /// passed value and returns changed CORS
    pub fn max_age(mut self, value: Option<usize>) -> Self {
        self.max_age = value;
        self
    }

    /// Consumes the CORS, set allow_methods to
    /// passed methods and returns changed CORS
    pub fn allow_methods(mut self, methods: Vec<Method>) -> Self {
        self.allow_methods = methods.clone();
        self
    }

    /// Merge CORS headers with an existing `rocket::Response`.
    ///
    /// This will overwrite any existing CORS headers
    pub fn merge(&self, response: &mut HttpResponse) {
        let origin = match self.allow_origin {
            None => {
                // This is not a CORS response
                return;
            }
            Some(ref origin) => origin,
        };

        let origin = match *origin {
            AllOrSome::All => "*".to_string(),
            AllOrSome::Some(ref origin) => origin.to_string(),
        };

        response.add_raw_header("Access-Control-Allow-Origin", origin);

        if self.allow_credentials {
            response.add_raw_header("Access-Control-Allow-Credentials", "true".to_string());
        }

        if !self.expose_headers.is_empty() {
            let headers: Vec<String> = self
                .expose_headers
                .iter()
                .map(|s| s.deref().to_string())
                .collect();
            let headers = headers.join(", ");

            response.add_raw_header("Access-Control-Expose-Headers", headers);
        }

        if !self.allow_headers.is_empty() {
            let headers: Vec<String> = self
                .allow_headers
                .iter()
                .map(|s| s.deref().to_string())
                .collect();
            let headers = headers.join(", ");

            response.add_raw_header("Access-Control-Allow-Headers", headers);
        }

        if !self.allow_methods.is_empty() {
            let methods: Vec<_> = self.allow_methods.iter().map(|m| m.as_str()).collect();
            let methods = methods.join(", ");

            response.add_raw_header("Access-Control-Allow-Methods", methods);
        }

        if self.max_age.is_some() {
            let max_age = self.max_age.unwrap();
            response.add_raw_header("Access-Control-Max-Age", max_age.to_string());
        }

        if self.vary_origin {
            response.add_raw_header("Vary", "Origin".to_string());
        }
    }
}