actix_security_core/http/security/
channel.rs1use std::future::{ready, Ready};
45use std::sync::Arc;
46
47use actix_service::{Service, Transform};
48use actix_web::{
49 body::EitherBody,
50 dev::{ServiceRequest, ServiceResponse},
51 http::{header, StatusCode},
52 Error, HttpResponse,
53};
54use futures_util::future::LocalBoxFuture;
55
56use super::ant_matcher::AntMatcher;
57
/// Channel (transport scheme) a request path must be served over.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ChannelRequirement {
    /// The request must arrive over HTTPS; plain-HTTP requests get redirected to HTTPS.
    Secure,
    /// The request must arrive over plain HTTP; HTTPS requests get redirected to HTTP.
    Insecure,
    /// Either scheme is accepted and no redirect is issued. This is the default.
    #[default]
    Any,
}
69
/// Ports placed into redirect URLs: one for plain HTTP, one for HTTPS.
///
/// The default mapping uses the standard ports, 80 for HTTP and 443 for HTTPS.
#[derive(Debug, Clone)]
pub struct PortMapper {
    http_port: u16,
    https_port: u16,
}

impl PortMapper {
    /// Creates a mapper that sends HTTP redirects to `http_port` and HTTPS
    /// redirects to `https_port`.
    pub fn new(http_port: u16, https_port: u16) -> Self {
        Self { http_port, https_port }
    }

    /// Port used when redirecting a client to HTTPS.
    pub fn get_https_port(&self) -> u16 {
        let Self { https_port, .. } = self;
        *https_port
    }

    /// Port used when redirecting a client to plain HTTP.
    pub fn get_http_port(&self) -> u16 {
        let Self { http_port, .. } = self;
        *http_port
    }
}

impl Default for PortMapper {
    /// Standard ports: 80 for HTTP, 443 for HTTPS.
    fn default() -> Self {
        Self::new(80, 443)
    }
}
105
106#[derive(Debug, Clone)]
108struct ChannelRule {
109 matcher: AntMatcher,
110 requirement: ChannelRequirement,
111}
112
113#[derive(Debug, Clone)]
115pub struct ChannelSecurityConfig {
116 rules: Vec<ChannelRule>,
117 port_mapper: PortMapper,
118 default_requirement: ChannelRequirement,
119 redirect_status: StatusCode,
120 preserve_host: bool,
121}
122
123impl Default for ChannelSecurityConfig {
124 fn default() -> Self {
125 Self::new()
126 }
127}
128
129impl ChannelSecurityConfig {
130 pub fn new() -> Self {
132 Self {
133 rules: Vec::new(),
134 port_mapper: PortMapper::default(),
135 default_requirement: ChannelRequirement::Any,
136 redirect_status: StatusCode::MOVED_PERMANENTLY,
137 preserve_host: true,
138 }
139 }
140
141 pub fn require_https_everywhere() -> Self {
143 Self::new().default_requirement(ChannelRequirement::Secure)
144 }
145
146 pub fn require_https(mut self, patterns: &[&str]) -> Self {
148 for pattern in patterns {
149 self.rules.push(ChannelRule {
150 matcher: AntMatcher::new(pattern),
151 requirement: ChannelRequirement::Secure,
152 });
153 }
154 self
155 }
156
157 pub fn allow_http(mut self, patterns: &[&str]) -> Self {
159 for pattern in patterns {
160 self.rules.push(ChannelRule {
161 matcher: AntMatcher::new(pattern),
162 requirement: ChannelRequirement::Insecure,
163 });
164 }
165 self
166 }
167
168 pub fn allow_any(mut self, patterns: &[&str]) -> Self {
170 for pattern in patterns {
171 self.rules.push(ChannelRule {
172 matcher: AntMatcher::new(pattern),
173 requirement: ChannelRequirement::Any,
174 });
175 }
176 self
177 }
178
179 pub fn default_requirement(mut self, requirement: ChannelRequirement) -> Self {
181 self.default_requirement = requirement;
182 self
183 }
184
185 pub fn port_mapper(mut self, http_port: u16, https_port: u16) -> Self {
187 self.port_mapper = PortMapper::new(http_port, https_port);
188 self
189 }
190
191 pub fn redirect_status(mut self, status: StatusCode) -> Self {
199 self.redirect_status = status;
200 self
201 }
202
203 pub fn temporary_redirect(self) -> Self {
205 self.redirect_status(StatusCode::FOUND)
206 }
207
208 pub fn permanent_redirect_preserve_method(self) -> Self {
210 self.redirect_status(StatusCode::PERMANENT_REDIRECT)
211 }
212
213 pub fn preserve_host(mut self, preserve: bool) -> Self {
215 self.preserve_host = preserve;
216 self
217 }
218
219 fn get_requirement(&self, path: &str) -> ChannelRequirement {
221 for rule in &self.rules {
222 if rule.matcher.matches(path) {
223 return rule.requirement;
224 }
225 }
226 self.default_requirement
227 }
228}
229
230#[derive(Clone)]
232pub struct ChannelSecurity {
233 config: Arc<ChannelSecurityConfig>,
234}
235
236impl ChannelSecurity {
237 pub fn new(config: ChannelSecurityConfig) -> Self {
239 Self {
240 config: Arc::new(config),
241 }
242 }
243
244 pub fn https_everywhere() -> Self {
246 Self::new(ChannelSecurityConfig::require_https_everywhere())
247 }
248
249 pub fn config(&self) -> &ChannelSecurityConfig {
251 &self.config
252 }
253}
254
255impl<S, B> Transform<S, ServiceRequest> for ChannelSecurity
256where
257 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
258 B: 'static,
259{
260 type Response = ServiceResponse<EitherBody<B>>;
261 type Error = Error;
262 type Transform = ChannelSecurityService<S>;
263 type InitError = ();
264 type Future = Ready<Result<Self::Transform, Self::InitError>>;
265
266 fn new_transform(&self, service: S) -> Self::Future {
267 ready(Ok(ChannelSecurityService {
268 service,
269 config: Arc::clone(&self.config),
270 }))
271 }
272}
273
274pub struct ChannelSecurityService<S> {
276 service: S,
277 config: Arc<ChannelSecurityConfig>,
278}
279
280impl<S, B> Service<ServiceRequest> for ChannelSecurityService<S>
281where
282 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
283 B: 'static,
284{
285 type Response = ServiceResponse<EitherBody<B>>;
286 type Error = Error;
287 type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
288
289 fn poll_ready(
290 &self,
291 cx: &mut std::task::Context<'_>,
292 ) -> std::task::Poll<Result<(), Self::Error>> {
293 self.service.poll_ready(cx)
294 }
295
296 fn call(&self, req: ServiceRequest) -> Self::Future {
297 let path = req.path().to_string();
298 let requirement = self.config.get_requirement(&path);
299
300 let is_secure = req.connection_info().scheme() == "https";
302
303 let redirect_url = match requirement {
305 ChannelRequirement::Secure if !is_secure => Some(self.build_redirect_url(&req, true)),
306 ChannelRequirement::Insecure if is_secure => Some(self.build_redirect_url(&req, false)),
307 _ => None,
308 };
309
310 if let Some(url) = redirect_url {
311 let response = HttpResponse::build(self.config.redirect_status)
312 .insert_header((header::LOCATION, url))
313 .finish()
314 .map_into_right_body();
315
316 let (http_req, _) = req.into_parts();
317 return Box::pin(async move { Ok(ServiceResponse::new(http_req, response)) });
318 }
319
320 let fut = self.service.call(req);
322 Box::pin(async move {
323 let res = fut.await?;
324 Ok(res.map_into_left_body())
325 })
326 }
327}
328
329impl<S> ChannelSecurityService<S> {
330 fn build_redirect_url(&self, req: &ServiceRequest, to_https: bool) -> String {
332 let conn_info = req.connection_info();
333 let scheme = if to_https { "https" } else { "http" };
334
335 let host = if self.config.preserve_host {
337 conn_info.host().to_string()
338 } else {
339 conn_info
341 .host()
342 .split(':')
343 .next()
344 .unwrap_or("localhost")
345 .to_string()
346 };
347
348 let port = if to_https {
350 self.config.port_mapper.get_https_port()
351 } else {
352 self.config.port_mapper.get_http_port()
353 };
354
355 let path_and_query = req
357 .uri()
358 .path_and_query()
359 .map(|pq| pq.as_str())
360 .unwrap_or("/");
361
362 let host_without_port = host.split(':').next().unwrap_or(&host);
364
365 let port_str = if (to_https && port == 443) || (!to_https && port == 80) {
367 String::new()
368 } else {
369 format!(":{}", port)
370 };
371
372 format!(
373 "{}://{}{}{}",
374 scheme, host_without_port, port_str, path_and_query
375 )
376 }
377}
378
/// Extension trait whose methods, judging by their names, attach a channel
/// requirement to `self` and return it.
///
/// NOTE(review): no implementation of this trait is visible in this file, so
/// the exact semantics depend on the implementor — confirm before relying on it.
pub trait ChannelSecurityExt {
    /// Presumably marks `self` as requiring a secure (HTTPS) channel,
    /// i.e. [`ChannelRequirement::Secure`] — TODO confirm with implementors.
    fn requires_secure(self) -> Self;

    /// Presumably marks `self` as requiring an insecure (plain HTTP) channel,
    /// i.e. [`ChannelRequirement::Insecure`] — TODO confirm with implementors.
    fn requires_insecure(self) -> Self;
}
387
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_channel_requirement() {
        assert_eq!(ChannelRequirement::default(), ChannelRequirement::Any);
    }

    #[test]
    fn test_port_mapper_default() {
        let mapper = PortMapper::default();
        assert_eq!(mapper.get_http_port(), 80);
        assert_eq!(mapper.get_https_port(), 443);
    }

    #[test]
    fn test_port_mapper_custom() {
        let mapper = PortMapper::new(8080, 8443);
        assert_eq!(mapper.get_http_port(), 8080);
        assert_eq!(mapper.get_https_port(), 8443);
    }

    #[test]
    fn test_config_default() {
        let config = ChannelSecurityConfig::new();
        assert_eq!(config.default_requirement, ChannelRequirement::Any);
        assert_eq!(config.redirect_status, StatusCode::MOVED_PERMANENTLY);
    }

    #[test]
    fn test_config_require_https() {
        let config = ChannelSecurityConfig::new().require_https(&["/login", "/api/**"]);

        assert_eq!(config.get_requirement("/login"), ChannelRequirement::Secure);
        assert_eq!(
            config.get_requirement("/api/users"),
            ChannelRequirement::Secure
        );
        assert_eq!(config.get_requirement("/public"), ChannelRequirement::Any);
    }

    #[test]
    fn test_config_allow_http() {
        let config = ChannelSecurityConfig::new()
            .default_requirement(ChannelRequirement::Secure)
            .allow_http(&["/health", "/public/**"]);

        assert_eq!(
            config.get_requirement("/health"),
            ChannelRequirement::Insecure
        );
        assert_eq!(
            config.get_requirement("/public/images/logo.png"),
            ChannelRequirement::Insecure
        );
        assert_eq!(
            config.get_requirement("/api/users"),
            ChannelRequirement::Secure
        );
    }

    #[test]
    fn test_config_https_everywhere() {
        let config = ChannelSecurityConfig::require_https_everywhere();
        assert_eq!(config.default_requirement, ChannelRequirement::Secure);
    }

    #[test]
    fn test_config_redirect_status() {
        let config = ChannelSecurityConfig::new().temporary_redirect();
        assert_eq!(config.redirect_status, StatusCode::FOUND);

        let config = ChannelSecurityConfig::new().permanent_redirect_preserve_method();
        assert_eq!(config.redirect_status, StatusCode::PERMANENT_REDIRECT);
    }

    #[test]
    fn test_config_port_mapper() {
        let config = ChannelSecurityConfig::new().port_mapper(8080, 8443);

        assert_eq!(config.port_mapper.get_http_port(), 8080);
        assert_eq!(config.port_mapper.get_https_port(), 8443);
    }

    #[test]
    fn test_channel_security_creation() {
        let cs = ChannelSecurity::https_everywhere();
        assert_eq!(cs.config().default_requirement, ChannelRequirement::Secure);
    }

    #[test]
    fn test_mixed_rules() {
        let config = ChannelSecurityConfig::new()
            .require_https(&["/admin/**", "/login"])
            .allow_any(&["/public/**"])
            .allow_http(&["/health"]);

        assert_eq!(
            config.get_requirement("/admin/dashboard"),
            ChannelRequirement::Secure
        );
        assert_eq!(config.get_requirement("/login"), ChannelRequirement::Secure);
        assert_eq!(
            config.get_requirement("/public/css/style.css"),
            ChannelRequirement::Any
        );
        assert_eq!(
            config.get_requirement("/health"),
            ChannelRequirement::Insecure
        );
    }

    /// Rules are checked in insertion order: a later, more specific pattern
    /// does not override an earlier, broader one.
    #[test]
    fn test_first_matching_rule_wins() {
        let config = ChannelSecurityConfig::new()
            .require_https(&["/api/**"])
            .allow_http(&["/api/public"]);

        assert_eq!(
            config.get_requirement("/api/public"),
            ChannelRequirement::Secure
        );
    }

    #[test]
    fn test_allow_any_overrides_secure_default() {
        let config =
            ChannelSecurityConfig::require_https_everywhere().allow_any(&["/static/**"]);

        assert_eq!(
            config.get_requirement("/static/app.js"),
            ChannelRequirement::Any
        );
        assert_eq!(
            config.get_requirement("/account"),
            ChannelRequirement::Secure
        );
    }

    #[test]
    fn test_config_defaults_and_preserve_host() {
        let config = ChannelSecurityConfig::default();
        assert!(config.rules.is_empty());
        assert!(config.preserve_host);
        assert_eq!(config.port_mapper.get_http_port(), 80);
        assert_eq!(config.port_mapper.get_https_port(), 443);

        let config = config.preserve_host(false);
        assert!(!config.preserve_host);
    }
}