ic_pluto/
cors.rs

1use crate::{all_or_some::AllOrSome, http::HttpResponse, method::Method};
2use std::ops::Deref;
3
4#[derive(Eq, PartialEq, Debug)]
5pub struct Cors {
6    allow_origin: Option<AllOrSome<String>>,
7    allow_methods: Vec<Method>,
8    allow_headers: Vec<String>,
9    allow_credentials: bool,
10    expose_headers: Vec<String>,
11    max_age: Option<usize>,
12    vary_origin: bool,
13}
14
15impl Cors {
16    /// Create an empty `Cors`
17    pub fn new() -> Self {
18        Self {
19            allow_origin: None,
20            allow_headers: vec![],
21            allow_methods: vec![],
22            allow_credentials: false,
23            expose_headers: vec![],
24            max_age: None,
25            vary_origin: false,
26        }
27    }
28
29    /// Consumes the `Response` and return an altered response with origin and `vary_origin` set
30    pub fn allow_origin(mut self, origin: &str) -> Self {
31        self.allow_origin = Some(AllOrSome::Some(origin.to_string()));
32        self
33    }
34
35    /// Consumes the `Response` and return an altered response with origin set to "*"
36    pub fn any(mut self) -> Self {
37        self.allow_origin = Some(AllOrSome::All);
38        self
39    }
40
41    /// Consumes the Response and set credentials
42    pub fn credentials(mut self, value: bool) -> Self {
43        self.allow_credentials = value;
44        self
45    }
46
47    /// Consumes the CORS, set expose_headers to
48    /// passed headers and returns changed CORS
49    pub fn exposed_headers(mut self, headers: Vec<&str>) -> Self {
50        self.expose_headers = headers.iter().map(|s| (*s).to_string().into()).collect();
51        self
52    }
53
54    /// Consumes the CORS, set allow_headers to
55    /// passed headers and returns changed CORS
56    pub fn allow_headers(mut self, headers: Vec<&str>) -> Self {
57        self.allow_headers = headers.iter().map(|s| (*s).to_string().into()).collect();
58        self
59    }
60
61    /// Consumes the CORS, set max_age to
62    /// passed value and returns changed CORS
63    pub fn max_age(mut self, value: Option<usize>) -> Self {
64        self.max_age = value;
65        self
66    }
67
68    /// Consumes the CORS, set allow_methods to
69    /// passed methods and returns changed CORS
70    pub fn allow_methods(mut self, methods: Vec<Method>) -> Self {
71        self.allow_methods = methods.clone();
72        self
73    }
74
75    /// Merge CORS headers with an existing `rocket::Response`.
76    ///
77    /// This will overwrite any existing CORS headers
78    pub fn merge(&self, response: &mut HttpResponse) {
79        let origin = match self.allow_origin {
80            None => {
81                // This is not a CORS response
82                return;
83            }
84            Some(ref origin) => origin,
85        };
86
87        let origin = match *origin {
88            AllOrSome::All => "*".to_string(),
89            AllOrSome::Some(ref origin) => origin.to_string(),
90        };
91
92        response.add_raw_header("Access-Control-Allow-Origin", origin);
93
94        if self.allow_credentials {
95            response.add_raw_header("Access-Control-Allow-Credentials", "true".to_string());
96        }
97
98        if !self.expose_headers.is_empty() {
99            let headers: Vec<String> = self
100                .expose_headers
101                .iter()
102                .map(|s| s.deref().to_string())
103                .collect();
104            let headers = headers.join(", ");
105
106            response.add_raw_header("Access-Control-Expose-Headers", headers);
107        }
108
109        if !self.allow_headers.is_empty() {
110            let headers: Vec<String> = self
111                .allow_headers
112                .iter()
113                .map(|s| s.deref().to_string())
114                .collect();
115            let headers = headers.join(", ");
116
117            response.add_raw_header("Access-Control-Allow-Headers", headers);
118        }
119
120        if !self.allow_methods.is_empty() {
121            let methods: Vec<_> = self.allow_methods.iter().map(|m| m.as_str()).collect();
122            let methods = methods.join(", ");
123
124            response.add_raw_header("Access-Control-Allow-Methods", methods);
125        }
126
127        if self.max_age.is_some() {
128            let max_age = self.max_age.unwrap();
129            response.add_raw_header("Access-Control-Max-Age", max_age.to_string());
130        }
131
132        if self.vary_origin {
133            response.add_raw_header("Vary", "Origin".to_string());
134        }
135    }
136}