actix_security_core/http/security/
channel.rs1use std::future::{ready, Ready};
45use std::sync::Arc;
46
47use actix_service::{Service, Transform};
48use actix_web::{
49 body::EitherBody,
50 dev::{ServiceRequest, ServiceResponse},
51 http::{header, StatusCode},
52 Error, HttpResponse,
53};
54use futures_util::future::LocalBoxFuture;
55
56use super::ant_matcher::AntMatcher;
57
/// Which transport channel a request path must use.
///
/// The middleware redirects a request that arrived on the wrong channel; see
/// the variant docs for which direction each requirement enforces.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ChannelRequirement {
    /// The path must be served over HTTPS; plain-HTTP requests are redirected to HTTPS.
    Secure,
    /// The path must be served over plain HTTP; HTTPS requests are redirected to HTTP.
    Insecure,
    /// Either channel is accepted and no redirect is issued. This is the default.
    #[default]
    Any,
}
69
/// The ports written into channel-redirect URLs, one per scheme.
#[derive(Debug, Clone)]
pub struct PortMapper {
    /// Port used when redirecting to `http://`.
    http_port: u16,
    /// Port used when redirecting to `https://`.
    https_port: u16,
}

impl Default for PortMapper {
    /// The standard ports: 80 for HTTP and 443 for HTTPS.
    fn default() -> Self {
        Self::new(80, 443)
    }
}

impl PortMapper {
    /// Builds a mapper from an explicit HTTP port and HTTPS port.
    pub fn new(http_port: u16, https_port: u16) -> Self {
        PortMapper {
            https_port,
            http_port,
        }
    }

    /// Port used for `https://` redirect targets.
    pub fn get_https_port(&self) -> u16 {
        let PortMapper { https_port, .. } = self;
        *https_port
    }

    /// Port used for `http://` redirect targets.
    pub fn get_http_port(&self) -> u16 {
        let PortMapper { http_port, .. } = self;
        *http_port
    }
}
105
106#[derive(Debug, Clone)]
108struct ChannelRule {
109 matcher: AntMatcher,
110 requirement: ChannelRequirement,
111}
112
113#[derive(Debug, Clone)]
115pub struct ChannelSecurityConfig {
116 rules: Vec<ChannelRule>,
117 port_mapper: PortMapper,
118 default_requirement: ChannelRequirement,
119 redirect_status: StatusCode,
120 preserve_host: bool,
121}
122
123impl Default for ChannelSecurityConfig {
124 fn default() -> Self {
125 Self::new()
126 }
127}
128
129impl ChannelSecurityConfig {
130 pub fn new() -> Self {
132 Self {
133 rules: Vec::new(),
134 port_mapper: PortMapper::default(),
135 default_requirement: ChannelRequirement::Any,
136 redirect_status: StatusCode::MOVED_PERMANENTLY,
137 preserve_host: true,
138 }
139 }
140
141 pub fn require_https_everywhere() -> Self {
143 Self::new().default_requirement(ChannelRequirement::Secure)
144 }
145
146 pub fn require_https(mut self, patterns: &[&str]) -> Self {
148 for pattern in patterns {
149 self.rules.push(ChannelRule {
150 matcher: AntMatcher::new(pattern),
151 requirement: ChannelRequirement::Secure,
152 });
153 }
154 self
155 }
156
157 pub fn allow_http(mut self, patterns: &[&str]) -> Self {
159 for pattern in patterns {
160 self.rules.push(ChannelRule {
161 matcher: AntMatcher::new(pattern),
162 requirement: ChannelRequirement::Insecure,
163 });
164 }
165 self
166 }
167
168 pub fn allow_any(mut self, patterns: &[&str]) -> Self {
170 for pattern in patterns {
171 self.rules.push(ChannelRule {
172 matcher: AntMatcher::new(pattern),
173 requirement: ChannelRequirement::Any,
174 });
175 }
176 self
177 }
178
179 pub fn default_requirement(mut self, requirement: ChannelRequirement) -> Self {
181 self.default_requirement = requirement;
182 self
183 }
184
185 pub fn port_mapper(mut self, http_port: u16, https_port: u16) -> Self {
187 self.port_mapper = PortMapper::new(http_port, https_port);
188 self
189 }
190
191 pub fn redirect_status(mut self, status: StatusCode) -> Self {
199 self.redirect_status = status;
200 self
201 }
202
203 pub fn temporary_redirect(self) -> Self {
205 self.redirect_status(StatusCode::FOUND)
206 }
207
208 pub fn permanent_redirect_preserve_method(self) -> Self {
210 self.redirect_status(StatusCode::PERMANENT_REDIRECT)
211 }
212
213 pub fn preserve_host(mut self, preserve: bool) -> Self {
215 self.preserve_host = preserve;
216 self
217 }
218
219 fn get_requirement(&self, path: &str) -> ChannelRequirement {
221 for rule in &self.rules {
222 if rule.matcher.matches(path) {
223 return rule.requirement;
224 }
225 }
226 self.default_requirement
227 }
228}
229
230#[derive(Clone)]
232pub struct ChannelSecurity {
233 config: Arc<ChannelSecurityConfig>,
234}
235
236impl ChannelSecurity {
237 pub fn new(config: ChannelSecurityConfig) -> Self {
239 Self {
240 config: Arc::new(config),
241 }
242 }
243
244 pub fn https_everywhere() -> Self {
246 Self::new(ChannelSecurityConfig::require_https_everywhere())
247 }
248
249 pub fn config(&self) -> &ChannelSecurityConfig {
251 &self.config
252 }
253}
254
255impl<S, B> Transform<S, ServiceRequest> for ChannelSecurity
256where
257 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
258 B: 'static,
259{
260 type Response = ServiceResponse<EitherBody<B>>;
261 type Error = Error;
262 type Transform = ChannelSecurityService<S>;
263 type InitError = ();
264 type Future = Ready<Result<Self::Transform, Self::InitError>>;
265
266 fn new_transform(&self, service: S) -> Self::Future {
267 ready(Ok(ChannelSecurityService {
268 service,
269 config: Arc::clone(&self.config),
270 }))
271 }
272}
273
274pub struct ChannelSecurityService<S> {
276 service: S,
277 config: Arc<ChannelSecurityConfig>,
278}
279
280impl<S, B> Service<ServiceRequest> for ChannelSecurityService<S>
281where
282 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
283 B: 'static,
284{
285 type Response = ServiceResponse<EitherBody<B>>;
286 type Error = Error;
287 type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
288
289 fn poll_ready(
290 &self,
291 cx: &mut std::task::Context<'_>,
292 ) -> std::task::Poll<Result<(), Self::Error>> {
293 self.service.poll_ready(cx)
294 }
295
296 fn call(&self, req: ServiceRequest) -> Self::Future {
297 let path = req.path().to_string();
298 let requirement = self.config.get_requirement(&path);
299
300 let is_secure = req.connection_info().scheme() == "https";
302
303 let redirect_url = match requirement {
305 ChannelRequirement::Secure if !is_secure => {
306 Some(self.build_redirect_url(&req, true))
307 }
308 ChannelRequirement::Insecure if is_secure => {
309 Some(self.build_redirect_url(&req, false))
310 }
311 _ => None,
312 };
313
314 if let Some(url) = redirect_url {
315 let response = HttpResponse::build(self.config.redirect_status)
316 .insert_header((header::LOCATION, url))
317 .finish()
318 .map_into_right_body();
319
320 let (http_req, _) = req.into_parts();
321 return Box::pin(async move { Ok(ServiceResponse::new(http_req, response)) });
322 }
323
324 let fut = self.service.call(req);
326 Box::pin(async move {
327 let res = fut.await?;
328 Ok(res.map_into_left_body())
329 })
330 }
331}
332
333impl<S> ChannelSecurityService<S> {
334 fn build_redirect_url(&self, req: &ServiceRequest, to_https: bool) -> String {
336 let conn_info = req.connection_info();
337 let scheme = if to_https { "https" } else { "http" };
338
339 let host = if self.config.preserve_host {
341 conn_info.host().to_string()
342 } else {
343 conn_info
345 .host()
346 .split(':')
347 .next()
348 .unwrap_or("localhost")
349 .to_string()
350 };
351
352 let port = if to_https {
354 self.config.port_mapper.get_https_port()
355 } else {
356 self.config.port_mapper.get_http_port()
357 };
358
359 let path_and_query = req.uri().path_and_query().map(|pq| pq.as_str()).unwrap_or("/");
361
362 let host_without_port = host.split(':').next().unwrap_or(&host);
364
365 let port_str = if (to_https && port == 443) || (!to_https && port == 80) {
367 String::new()
368 } else {
369 format!(":{}", port)
370 };
371
372 format!("{}://{}{}{}", scheme, host_without_port, port_str, path_and_query)
373 }
374}
375
/// Extension hook for attaching a channel requirement to a value and getting it back.
///
/// NOTE(review): no implementations are visible in this file. The exact effect
/// of each method is up to the implementor; the wording below follows the
/// method names.
pub trait ChannelSecurityExt {
    /// Returns `self`, presumably marked as requiring a secure (HTTPS) channel.
    fn requires_secure(self) -> Self;

    /// Returns `self`, presumably marked as requiring an insecure (plain HTTP) channel.
    fn requires_insecure(self) -> Self;
}
384
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks each `(path, expected)` pair against `config.get_requirement`, in order.
    fn assert_resolves(config: &ChannelSecurityConfig, cases: &[(&str, ChannelRequirement)]) {
        for &(path, expected) in cases {
            assert_eq!(config.get_requirement(path), expected, "requirement for {path}");
        }
    }

    #[test]
    fn test_channel_requirement() {
        let fallback: ChannelRequirement = Default::default();
        assert_eq!(fallback, ChannelRequirement::Any);
    }

    #[test]
    fn test_port_mapper_default() {
        let mapper = PortMapper::default();
        assert_eq!(mapper.get_http_port(), 80);
        assert_eq!(mapper.get_https_port(), 443);
    }

    #[test]
    fn test_port_mapper_custom() {
        let mapper = PortMapper::new(8080, 8443);
        assert_eq!(mapper.get_http_port(), 8080);
        assert_eq!(mapper.get_https_port(), 8443);
    }

    #[test]
    fn test_config_default() {
        let config = ChannelSecurityConfig::new();
        assert_eq!(config.default_requirement, ChannelRequirement::Any);
        assert_eq!(config.redirect_status, StatusCode::MOVED_PERMANENTLY);
    }

    #[test]
    fn test_config_require_https() {
        let config = ChannelSecurityConfig::new().require_https(&["/login", "/api/**"]);
        assert_resolves(
            &config,
            &[
                ("/login", ChannelRequirement::Secure),
                ("/api/users", ChannelRequirement::Secure),
                ("/public", ChannelRequirement::Any),
            ],
        );
    }

    #[test]
    fn test_config_allow_http() {
        let config = ChannelSecurityConfig::new()
            .default_requirement(ChannelRequirement::Secure)
            .allow_http(&["/health", "/public/**"]);
        assert_resolves(
            &config,
            &[
                ("/health", ChannelRequirement::Insecure),
                ("/public/images/logo.png", ChannelRequirement::Insecure),
                ("/api/users", ChannelRequirement::Secure),
            ],
        );
    }

    #[test]
    fn test_config_https_everywhere() {
        let config = ChannelSecurityConfig::require_https_everywhere();
        assert_eq!(config.default_requirement, ChannelRequirement::Secure);
    }

    #[test]
    fn test_config_redirect_status() {
        let temporary = ChannelSecurityConfig::new().temporary_redirect();
        assert_eq!(temporary.redirect_status, StatusCode::FOUND);

        let permanent = ChannelSecurityConfig::new().permanent_redirect_preserve_method();
        assert_eq!(permanent.redirect_status, StatusCode::PERMANENT_REDIRECT);
    }

    #[test]
    fn test_config_port_mapper() {
        let config = ChannelSecurityConfig::new().port_mapper(8080, 8443);
        assert_eq!(config.port_mapper.get_http_port(), 8080);
        assert_eq!(config.port_mapper.get_https_port(), 8443);
    }

    #[test]
    fn test_channel_security_creation() {
        let middleware = ChannelSecurity::https_everywhere();
        assert_eq!(
            middleware.config().default_requirement,
            ChannelRequirement::Secure
        );
    }

    #[test]
    fn test_mixed_rules() {
        let config = ChannelSecurityConfig::new()
            .require_https(&["/admin/**", "/login"])
            .allow_any(&["/public/**"])
            .allow_http(&["/health"]);
        assert_resolves(
            &config,
            &[
                ("/admin/dashboard", ChannelRequirement::Secure),
                ("/login", ChannelRequirement::Secure),
                ("/public/css/style.css", ChannelRequirement::Any),
                ("/health", ChannelRequirement::Insecure),
            ],
        );
    }
}