Skip to main content

nimble_http/
cors.rs

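//! cors.rs — Cross-Origin Resource Sharing (CORS) helpers: building responses to preflight requests and adding CORS headers to regular responses.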
2
3use hyper::{header, Response, Body};
4use std::collections::HashSet;
5use std::time::Duration;
6
/// Cross-Origin Resource Sharing (CORS) policy.
///
/// Construct it with [`Cors::new`] plus the builder methods. Then use it to answer
/// preflight requests ([`Cors::handle_preflight`]) and to add CORS headers to
/// regular responses ([`Cors::apply_headers`]).
#[derive(Clone, Debug, Default)]
pub struct Cors {
    /// Origins allowed to make cross-origin requests.
    ///
    /// `None` (the default) accepts any origin. `Some(set)` accepts only origins
    /// contained in `set`, compared as exact strings.
    pub allow_origins: Option<HashSet<String>>,
    /// Methods listed in `Access-Control-Allow-Methods` on preflight responses;
    /// the header is omitted when empty.
    pub allow_methods: HashSet<String>,
    /// Request headers listed in `Access-Control-Allow-Headers` on preflight
    /// responses; the header is omitted when empty.
    pub allow_headers: HashSet<String>,
    /// Response headers listed in `Access-Control-Expose-Headers` on regular
    /// responses; the header is omitted when empty.
    pub expose_headers: HashSet<String>,
    /// When `true`, `Access-Control-Allow-Credentials: true` is sent.
    pub allow_credentials: bool,
    /// Preflight cache lifetime, sent as `Access-Control-Max-Age` in whole seconds.
    pub max_age: Option<Duration>,
}
16
17impl Cors {
18    pub fn new() -> Self {
19        Self::default()
20    }
21    
22    pub fn allow_origins<T, I>(mut self, origins: I) -> Self
23    where
24        T: AsRef<str> + Into<String>,
25        I: IntoIterator<Item = T>,
26    {
27        self.allow_origins = Some(origins.into_iter().map(|s| s.into()).collect());
28        self
29    }
30    
31    pub fn allow_methods<T, I>(mut self, methods: I) -> Self
32    where
33        T: AsRef<str> + Into<String>,
34        I: IntoIterator<Item = T>,
35    {
36        self.allow_methods = methods.into_iter().map(|s| s.into()).collect();
37        self
38    }
39    
40    pub fn allow_headers<T, I>(mut self, headers: I) -> Self
41    where
42        T: AsRef<str> + Into<String>,
43        I: IntoIterator<Item = T>,
44    {
45        self.allow_headers = headers.into_iter().map(|s| s.into()).collect();
46        self
47    }
48    
49    pub fn allow_credentials(mut self, allow: bool) -> Self {
50        self.allow_credentials = allow;
51        self
52    }
53    
54    pub fn max_age<T: Into<Duration>>(mut self, duration: T) -> Self {
55        self.max_age = Some(duration.into());
56        self
57    }
58    
59    /// 处理 OPTIONS 预检请求
60    pub fn handle_preflight(&self, origin: Option<&str>) -> Response<Body> {  // 改这里
61        let mut response = Response::new(Body::empty());
62        let headers = response.headers_mut();
63        
64        // 设置 Allow-Origin
65        match &self.allow_origins {
66            Some(origins) => {
67                if let Some(origin) = origin {
68                    if origins.contains(origin) {
69                        headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.parse().unwrap());
70                    }
71                }
72            }
73            None => {
74                if let Some(origin) = origin {
75                    headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.parse().unwrap());
76                } else {
77                    headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*".parse().unwrap());
78                }
79            }
80        }
81        
82        // 设置 Allow-Methods
83        if !self.allow_methods.is_empty() {
84            let methods = self.allow_methods.iter()
85                .map(|s| s.as_str())
86                .collect::<Vec<_>>()
87                .join(", ");
88            headers.insert(header::ACCESS_CONTROL_ALLOW_METHODS, methods.parse().unwrap());
89        }
90        
91        // 设置 Allow-Headers
92        if !self.allow_headers.is_empty() {
93            let headers_str = self.allow_headers.iter()
94                .map(|s| s.as_str())
95                .collect::<Vec<_>>()
96                .join(", ");
97            headers.insert(header::ACCESS_CONTROL_ALLOW_HEADERS, headers_str.parse().unwrap());
98        }
99        
100        // 设置 Max-Age
101        if let Some(max_age) = self.max_age {
102            headers.insert(header::ACCESS_CONTROL_MAX_AGE, max_age.as_secs().to_string().parse().unwrap());
103        }
104        
105        // 设置 Credentials
106        if self.allow_credentials {
107            headers.insert(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true".parse().unwrap());
108        }
109        
110        response
111    }
112    
113    /// 给正常响应添加 CORS 头
114    pub fn apply_headers(&self, response: &mut Response<Body>, origin: Option<&str>) {  // 改这里
115        let headers = response.headers_mut();
116        
117        // 设置 Allow-Origin
118        match &self.allow_origins {
119            Some(origins) => {
120                if let Some(origin) = origin {
121                    if origins.contains(origin) {
122                        headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.parse().unwrap());
123                    }
124                }
125            }
126            None => {
127                if let Some(origin) = origin {
128                    headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.parse().unwrap());
129                } else {
130                    headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*".parse().unwrap());
131                }
132            }
133        }
134        
135        // 设置 Expose-Headers
136        if !self.expose_headers.is_empty() {
137            let expose = self.expose_headers.iter()
138                .map(|s| s.as_str())
139                .collect::<Vec<_>>()
140                .join(", ");
141            headers.insert(header::ACCESS_CONTROL_EXPOSE_HEADERS, expose.parse().unwrap());
142        }
143        
144        // 设置 Credentials
145        if self.allow_credentials {
146            headers.insert(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true".parse().unwrap());
147        }
148    }
149}