1use crate::{all_or_some::AllOrSome, http::HttpResponse, method::Method};
2use std::ops::Deref;
3
4#[derive(Eq, PartialEq, Debug)]
5pub struct Cors {
6 allow_origin: Option<AllOrSome<String>>,
7 allow_methods: Vec<Method>,
8 allow_headers: Vec<String>,
9 allow_credentials: bool,
10 expose_headers: Vec<String>,
11 max_age: Option<usize>,
12 vary_origin: bool,
13}
14
15impl Cors {
16 pub fn new() -> Self {
18 Self {
19 allow_origin: None,
20 allow_headers: vec![],
21 allow_methods: vec![],
22 allow_credentials: false,
23 expose_headers: vec![],
24 max_age: None,
25 vary_origin: false,
26 }
27 }
28
29 pub fn allow_origin(mut self, origin: &str) -> Self {
31 self.allow_origin = Some(AllOrSome::Some(origin.to_string()));
32 self
33 }
34
35 pub fn any(mut self) -> Self {
37 self.allow_origin = Some(AllOrSome::All);
38 self
39 }
40
41 pub fn credentials(mut self, value: bool) -> Self {
43 self.allow_credentials = value;
44 self
45 }
46
47 pub fn exposed_headers(mut self, headers: Vec<&str>) -> Self {
50 self.expose_headers = headers.iter().map(|s| (*s).to_string().into()).collect();
51 self
52 }
53
54 pub fn allow_headers(mut self, headers: Vec<&str>) -> Self {
57 self.allow_headers = headers.iter().map(|s| (*s).to_string().into()).collect();
58 self
59 }
60
61 pub fn max_age(mut self, value: Option<usize>) -> Self {
64 self.max_age = value;
65 self
66 }
67
68 pub fn allow_methods(mut self, methods: Vec<Method>) -> Self {
71 self.allow_methods = methods.clone();
72 self
73 }
74
75 pub fn merge(&self, response: &mut HttpResponse) {
79 let origin = match self.allow_origin {
80 None => {
81 return;
83 }
84 Some(ref origin) => origin,
85 };
86
87 let origin = match *origin {
88 AllOrSome::All => "*".to_string(),
89 AllOrSome::Some(ref origin) => origin.to_string(),
90 };
91
92 response.add_raw_header("Access-Control-Allow-Origin", origin);
93
94 if self.allow_credentials {
95 response.add_raw_header("Access-Control-Allow-Credentials", "true".to_string());
96 }
97
98 if !self.expose_headers.is_empty() {
99 let headers: Vec<String> = self
100 .expose_headers
101 .iter()
102 .map(|s| s.deref().to_string())
103 .collect();
104 let headers = headers.join(", ");
105
106 response.add_raw_header("Access-Control-Expose-Headers", headers);
107 }
108
109 if !self.allow_headers.is_empty() {
110 let headers: Vec<String> = self
111 .allow_headers
112 .iter()
113 .map(|s| s.deref().to_string())
114 .collect();
115 let headers = headers.join(", ");
116
117 response.add_raw_header("Access-Control-Allow-Headers", headers);
118 }
119
120 if !self.allow_methods.is_empty() {
121 let methods: Vec<_> = self.allow_methods.iter().map(|m| m.as_str()).collect();
122 let methods = methods.join(", ");
123
124 response.add_raw_header("Access-Control-Allow-Methods", methods);
125 }
126
127 if self.max_age.is_some() {
128 let max_age = self.max_age.unwrap();
129 response.add_raw_header("Access-Control-Max-Age", max_age.to_string());
130 }
131
132 if self.vary_origin {
133 response.add_raw_header("Vary", "Origin".to_string());
134 }
135 }
136}