Skip to main content

actix_security_core/http/security/
channel.rs

1//! Channel Security Module
2//!
3//! Provides channel security enforcement including HTTPS redirection
4//! and port mapping. Similar to Spring Security's channel security.
5//!
6//! # Features
7//!
8//! - **HTTPS Enforcement**: Redirect HTTP requests to HTTPS
9//! - **Port Mapping**: Configure custom HTTP/HTTPS port mappings
10//! - **Path-based Rules**: Apply different security to different paths
11//! - **Flexible Configuration**: Customize redirect behavior
12//!
13//! # Example
14//!
15//! ```rust,ignore
16//! use actix_security::http::security::channel::{ChannelSecurity, ChannelSecurityConfig};
17//! use actix_web::{App, HttpServer};
18//!
19//! let channel_security = ChannelSecurity::new(
20//!     ChannelSecurityConfig::new()
21//!         .require_https(&["/login", "/api/**"])
22//!         .allow_http(&["/health", "/public/**"])
23//! );
24//!
25//! HttpServer::new(move || {
26//!     App::new()
27//!         .wrap(channel_security.clone())
28//!         // ... routes
29//! })
30//! .bind("0.0.0.0:80")?
31//! .bind_rustls("0.0.0.0:443", config)?
32//! .run()
33//! .await
34//! ```
35//!
36//! # Spring Equivalent
37//!
38//! ```java
39//! http.requiresChannel()
40//!     .requestMatchers("/login", "/api/**").requiresSecure()
41//!     .requestMatchers("/public/**").requiresInsecure();
42//! ```
43
44use std::future::{ready, Ready};
45use std::sync::Arc;
46
47use actix_service::{Service, Transform};
48use actix_web::{
49    body::EitherBody,
50    dev::{ServiceRequest, ServiceResponse},
51    http::{header, StatusCode},
52    Error, HttpResponse,
53};
54use futures_util::future::LocalBoxFuture;
55
56use super::ant_matcher::AntMatcher;
57
/// Which transport channel a request path must be served over.
///
/// The default is [`ChannelRequirement::Any`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ChannelRequirement {
    /// The path must be reached over HTTPS.
    Secure,
    /// The path must be reached over plain HTTP.
    Insecure,
    /// Either channel is acceptable.
    #[default]
    Any,
}
69
/// Pair of ports used when a redirect switches between HTTP and HTTPS.
///
/// [`PortMapper::default`] yields the conventional ports 80 (HTTP) and
/// 443 (HTTPS).
#[derive(Clone, Debug)]
pub struct PortMapper {
    http_port: u16,
    https_port: u16,
}

impl PortMapper {
    /// Builds a mapper from an explicit HTTP port and HTTPS port.
    pub fn new(http_port: u16, https_port: u16) -> Self {
        PortMapper {
            http_port,
            https_port,
        }
    }

    /// Port to put into a redirect that targets HTTPS.
    pub fn get_https_port(&self) -> u16 {
        let PortMapper { https_port, .. } = *self;
        https_port
    }

    /// Port to put into a redirect that targets HTTP.
    pub fn get_http_port(&self) -> u16 {
        let PortMapper { http_port, .. } = *self;
        http_port
    }
}

impl Default for PortMapper {
    /// Standard web ports: HTTP on 80, HTTPS on 443.
    fn default() -> Self {
        Self::new(80, 443)
    }
}
105
106/// Rule for channel security
107#[derive(Debug, Clone)]
108struct ChannelRule {
109    matcher: AntMatcher,
110    requirement: ChannelRequirement,
111}
112
113/// Configuration for channel security
114#[derive(Debug, Clone)]
115pub struct ChannelSecurityConfig {
116    rules: Vec<ChannelRule>,
117    port_mapper: PortMapper,
118    default_requirement: ChannelRequirement,
119    redirect_status: StatusCode,
120    preserve_host: bool,
121}
122
123impl Default for ChannelSecurityConfig {
124    fn default() -> Self {
125        Self::new()
126    }
127}
128
129impl ChannelSecurityConfig {
130    /// Create a new configuration
131    pub fn new() -> Self {
132        Self {
133            rules: Vec::new(),
134            port_mapper: PortMapper::default(),
135            default_requirement: ChannelRequirement::Any,
136            redirect_status: StatusCode::MOVED_PERMANENTLY,
137            preserve_host: true,
138        }
139    }
140
141    /// Require HTTPS for all paths (strict mode)
142    pub fn require_https_everywhere() -> Self {
143        Self::new().default_requirement(ChannelRequirement::Secure)
144    }
145
146    /// Add paths that require HTTPS
147    pub fn require_https(mut self, patterns: &[&str]) -> Self {
148        for pattern in patterns {
149            self.rules.push(ChannelRule {
150                matcher: AntMatcher::new(pattern),
151                requirement: ChannelRequirement::Secure,
152            });
153        }
154        self
155    }
156
157    /// Add paths that allow HTTP
158    pub fn allow_http(mut self, patterns: &[&str]) -> Self {
159        for pattern in patterns {
160            self.rules.push(ChannelRule {
161                matcher: AntMatcher::new(pattern),
162                requirement: ChannelRequirement::Insecure,
163            });
164        }
165        self
166    }
167
168    /// Add paths that allow any channel
169    pub fn allow_any(mut self, patterns: &[&str]) -> Self {
170        for pattern in patterns {
171            self.rules.push(ChannelRule {
172                matcher: AntMatcher::new(pattern),
173                requirement: ChannelRequirement::Any,
174            });
175        }
176        self
177    }
178
179    /// Set the default requirement for paths not matching any rule
180    pub fn default_requirement(mut self, requirement: ChannelRequirement) -> Self {
181        self.default_requirement = requirement;
182        self
183    }
184
185    /// Set custom port mapping
186    pub fn port_mapper(mut self, http_port: u16, https_port: u16) -> Self {
187        self.port_mapper = PortMapper::new(http_port, https_port);
188        self
189    }
190
191    /// Set the HTTP redirect status code (default: 301 Moved Permanently)
192    ///
193    /// Common values:
194    /// - 301: Moved Permanently (cached by browsers)
195    /// - 302: Found (temporary redirect)
196    /// - 307: Temporary Redirect (preserves method)
197    /// - 308: Permanent Redirect (preserves method, cached)
198    pub fn redirect_status(mut self, status: StatusCode) -> Self {
199        self.redirect_status = status;
200        self
201    }
202
203    /// Use temporary redirect (302 Found)
204    pub fn temporary_redirect(self) -> Self {
205        self.redirect_status(StatusCode::FOUND)
206    }
207
208    /// Use permanent redirect that preserves HTTP method (308)
209    pub fn permanent_redirect_preserve_method(self) -> Self {
210        self.redirect_status(StatusCode::PERMANENT_REDIRECT)
211    }
212
213    /// Set whether to preserve the Host header in redirects
214    pub fn preserve_host(mut self, preserve: bool) -> Self {
215        self.preserve_host = preserve;
216        self
217    }
218
219    /// Get the requirement for a given path
220    fn get_requirement(&self, path: &str) -> ChannelRequirement {
221        for rule in &self.rules {
222            if rule.matcher.matches(path) {
223                return rule.requirement;
224            }
225        }
226        self.default_requirement
227    }
228}
229
230/// Channel security middleware
231#[derive(Clone)]
232pub struct ChannelSecurity {
233    config: Arc<ChannelSecurityConfig>,
234}
235
236impl ChannelSecurity {
237    /// Create a new channel security middleware
238    pub fn new(config: ChannelSecurityConfig) -> Self {
239        Self {
240            config: Arc::new(config),
241        }
242    }
243
244    /// Create with HTTPS required everywhere
245    pub fn https_everywhere() -> Self {
246        Self::new(ChannelSecurityConfig::require_https_everywhere())
247    }
248
249    /// Get the configuration
250    pub fn config(&self) -> &ChannelSecurityConfig {
251        &self.config
252    }
253}
254
255impl<S, B> Transform<S, ServiceRequest> for ChannelSecurity
256where
257    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
258    B: 'static,
259{
260    type Response = ServiceResponse<EitherBody<B>>;
261    type Error = Error;
262    type Transform = ChannelSecurityService<S>;
263    type InitError = ();
264    type Future = Ready<Result<Self::Transform, Self::InitError>>;
265
266    fn new_transform(&self, service: S) -> Self::Future {
267        ready(Ok(ChannelSecurityService {
268            service,
269            config: Arc::clone(&self.config),
270        }))
271    }
272}
273
274/// Channel security service
275pub struct ChannelSecurityService<S> {
276    service: S,
277    config: Arc<ChannelSecurityConfig>,
278}
279
280impl<S, B> Service<ServiceRequest> for ChannelSecurityService<S>
281where
282    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
283    B: 'static,
284{
285    type Response = ServiceResponse<EitherBody<B>>;
286    type Error = Error;
287    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
288
289    fn poll_ready(
290        &self,
291        cx: &mut std::task::Context<'_>,
292    ) -> std::task::Poll<Result<(), Self::Error>> {
293        self.service.poll_ready(cx)
294    }
295
296    fn call(&self, req: ServiceRequest) -> Self::Future {
297        let path = req.path().to_string();
298        let requirement = self.config.get_requirement(&path);
299
300        // Determine if request is secure
301        let is_secure = req.connection_info().scheme() == "https";
302
303        // Check if redirect is needed
304        let redirect_url = match requirement {
305            ChannelRequirement::Secure if !is_secure => Some(self.build_redirect_url(&req, true)),
306            ChannelRequirement::Insecure if is_secure => Some(self.build_redirect_url(&req, false)),
307            _ => None,
308        };
309
310        if let Some(url) = redirect_url {
311            let response = HttpResponse::build(self.config.redirect_status)
312                .insert_header((header::LOCATION, url))
313                .finish()
314                .map_into_right_body();
315
316            let (http_req, _) = req.into_parts();
317            return Box::pin(async move { Ok(ServiceResponse::new(http_req, response)) });
318        }
319
320        // No redirect needed, proceed with request
321        let fut = self.service.call(req);
322        Box::pin(async move {
323            let res = fut.await?;
324            Ok(res.map_into_left_body())
325        })
326    }
327}
328
329impl<S> ChannelSecurityService<S> {
330    /// Build the redirect URL
331    fn build_redirect_url(&self, req: &ServiceRequest, to_https: bool) -> String {
332        let conn_info = req.connection_info();
333        let scheme = if to_https { "https" } else { "http" };
334
335        // Get host
336        let host = if self.config.preserve_host {
337            conn_info.host().to_string()
338        } else {
339            // Strip port from host if present
340            conn_info
341                .host()
342                .split(':')
343                .next()
344                .unwrap_or("localhost")
345                .to_string()
346        };
347
348        // Get appropriate port
349        let port = if to_https {
350            self.config.port_mapper.get_https_port()
351        } else {
352            self.config.port_mapper.get_http_port()
353        };
354
355        // Build URL
356        let path_and_query = req
357            .uri()
358            .path_and_query()
359            .map(|pq| pq.as_str())
360            .unwrap_or("/");
361
362        // Strip existing port from host
363        let host_without_port = host.split(':').next().unwrap_or(&host);
364
365        // Only include port if non-standard
366        let port_str = if (to_https && port == 443) || (!to_https && port == 80) {
367            String::new()
368        } else {
369            format!(":{}", port)
370        };
371
372        format!(
373            "{}://{}{}{}",
374            scheme, host_without_port, port_str, path_and_query
375        )
376    }
377}
378
/// Extension point for builders that attach a channel requirement to a
/// path pattern.
///
/// Both methods consume the receiver and hand it back, so calls can be
/// chained. No implementation is visible alongside this definition.
pub trait ChannelSecurityExt {
    /// Marks the receiver's path pattern as requiring the secure (HTTPS)
    /// channel.
    fn requires_secure(self) -> Self;

    /// Marks the receiver's path pattern as using the insecure (HTTP)
    /// channel.
    ///
    /// NOTE(review): presumably the counterpart of
    /// `ChannelRequirement::Insecure`; confirm against an implementation.
    fn requires_insecure(self) -> Self;
}
387
388// ============================================================================
389// Tests
390// ============================================================================
391
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `config` resolves every listed path to its expected
    /// requirement.
    fn assert_requirements(
        config: &ChannelSecurityConfig,
        expectations: &[(&str, ChannelRequirement)],
    ) {
        for (path, expected) in expectations.iter() {
            assert_eq!(config.get_requirement(path), *expected, "path {}", path);
        }
    }

    /// `Any` is the default requirement.
    #[test]
    fn test_channel_requirement() {
        assert_eq!(ChannelRequirement::default(), ChannelRequirement::Any);
    }

    /// The default mapper uses the standard web ports.
    #[test]
    fn test_port_mapper_default() {
        let mapper = PortMapper::default();
        assert_eq!(
            (mapper.get_http_port(), mapper.get_https_port()),
            (80, 443)
        );
    }

    /// Custom ports are returned unchanged.
    #[test]
    fn test_port_mapper_custom() {
        let mapper = PortMapper::new(8080, 8443);
        assert_eq!(
            (mapper.get_http_port(), mapper.get_https_port()),
            (8080, 8443)
        );
    }

    /// A fresh configuration accepts any channel and redirects with 301.
    #[test]
    fn test_config_default() {
        let config = ChannelSecurityConfig::new();
        assert_eq!(config.default_requirement, ChannelRequirement::Any);
        assert_eq!(config.redirect_status, StatusCode::MOVED_PERMANENTLY);
    }

    /// HTTPS rules apply to matching paths only.
    #[test]
    fn test_config_require_https() {
        let config = ChannelSecurityConfig::new().require_https(&["/login", "/api/**"]);

        assert_requirements(
            &config,
            &[
                ("/login", ChannelRequirement::Secure),
                ("/api/users", ChannelRequirement::Secure),
                ("/public", ChannelRequirement::Any),
            ],
        );
    }

    /// `allow_http` rules take precedence over a secure default.
    #[test]
    fn test_config_allow_http() {
        let config = ChannelSecurityConfig::new()
            .default_requirement(ChannelRequirement::Secure)
            .allow_http(&["/health", "/public/**"]);

        assert_requirements(
            &config,
            &[
                ("/health", ChannelRequirement::Insecure),
                ("/public/images/logo.png", ChannelRequirement::Insecure),
                ("/api/users", ChannelRequirement::Secure),
            ],
        );
    }

    /// Strict mode makes `Secure` the default requirement.
    #[test]
    fn test_config_https_everywhere() {
        let config = ChannelSecurityConfig::require_https_everywhere();
        assert_eq!(config.default_requirement, ChannelRequirement::Secure);
    }

    /// The redirect shorthands select 302 and 308 respectively.
    #[test]
    fn test_config_redirect_status() {
        let temporary = ChannelSecurityConfig::new().temporary_redirect();
        assert_eq!(temporary.redirect_status, StatusCode::FOUND);

        let permanent = ChannelSecurityConfig::new().permanent_redirect_preserve_method();
        assert_eq!(permanent.redirect_status, StatusCode::PERMANENT_REDIRECT);
    }

    /// The builder stores a custom port mapping.
    #[test]
    fn test_config_port_mapper() {
        let config = ChannelSecurityConfig::new().port_mapper(8080, 8443);
        let mapper = &config.port_mapper;

        assert_eq!(mapper.get_http_port(), 8080);
        assert_eq!(mapper.get_https_port(), 8443);
    }

    /// `https_everywhere` exposes a strict configuration through `config()`.
    #[test]
    fn test_channel_security_creation() {
        let middleware = ChannelSecurity::https_everywhere();
        assert_eq!(
            middleware.config().default_requirement,
            ChannelRequirement::Secure
        );
    }

    /// Rules of different kinds coexist and each matches its own paths.
    #[test]
    fn test_mixed_rules() {
        let config = ChannelSecurityConfig::new()
            .require_https(&["/admin/**", "/login"])
            .allow_any(&["/public/**"])
            .allow_http(&["/health"]);

        assert_requirements(
            &config,
            &[
                ("/admin/dashboard", ChannelRequirement::Secure),
                ("/login", ChannelRequirement::Secure),
                ("/public/css/style.css", ChannelRequirement::Any),
                ("/health", ChannelRequirement::Insecure),
            ],
        );
    }
}
504}