Skip to main content

http_security_headers/policy/
csp.rs

1//! Content-Security-Policy (CSP) header configuration.
2//!
3//! CSP helps prevent cross-site scripting (XSS), clickjacking, and other code injection
4//! attacks by specifying which dynamic resources are allowed to load.
5
6use crate::error::{Error, Result};
7use std::collections::HashMap;
8
9/// Content-Security-Policy configuration.
10///
11/// # Examples
12///
13/// ```
14/// use http_security_headers::ContentSecurityPolicy;
15///
16/// let csp = ContentSecurityPolicy::new()
17///     .default_src(vec!["'self'"])
18///     .script_src(vec!["'self'", "'unsafe-inline'"])
19///     .style_src(vec!["'self'", "https://fonts.googleapis.com"])
20///     .img_src(vec!["'self'", "data:", "https:"]);
21/// ```
22#[derive(Debug, Clone, PartialEq, Eq)]
23pub struct ContentSecurityPolicy {
24    directives: HashMap<String, Vec<String>>,
25}
26
27impl ContentSecurityPolicy {
28    /// Creates a new empty CSP policy.
29    pub fn new() -> Self {
30        Self {
31            directives: HashMap::new(),
32        }
33    }
34
35    /// Sets the `default-src` directive.
36    ///
37    /// This serves as a fallback for other fetch directives.
38    pub fn default_src<I, S>(mut self, sources: I) -> Self
39    where
40        I: IntoIterator<Item = S>,
41        S: Into<String>,
42    {
43        self.set_directive("default-src", sources);
44        self
45    }
46
47    /// Sets the `script-src` directive.
48    ///
49    /// Specifies valid sources for JavaScript.
50    pub fn script_src<I, S>(mut self, sources: I) -> Self
51    where
52        I: IntoIterator<Item = S>,
53        S: Into<String>,
54    {
55        self.set_directive("script-src", sources);
56        self
57    }
58
59    /// Sets the `style-src` directive.
60    ///
61    /// Specifies valid sources for stylesheets.
62    pub fn style_src<I, S>(mut self, sources: I) -> Self
63    where
64        I: IntoIterator<Item = S>,
65        S: Into<String>,
66    {
67        self.set_directive("style-src", sources);
68        self
69    }
70
71    /// Sets the `img-src` directive.
72    ///
73    /// Specifies valid sources for images.
74    pub fn img_src<I, S>(mut self, sources: I) -> Self
75    where
76        I: IntoIterator<Item = S>,
77        S: Into<String>,
78    {
79        self.set_directive("img-src", sources);
80        self
81    }
82
83    /// Sets the `font-src` directive.
84    ///
85    /// Specifies valid sources for fonts.
86    pub fn font_src<I, S>(mut self, sources: I) -> Self
87    where
88        I: IntoIterator<Item = S>,
89        S: Into<String>,
90    {
91        self.set_directive("font-src", sources);
92        self
93    }
94
95    /// Sets the `connect-src` directive.
96    ///
97    /// Restricts URLs that can be loaded using script interfaces (fetch, XHR, WebSocket, etc.).
98    pub fn connect_src<I, S>(mut self, sources: I) -> Self
99    where
100        I: IntoIterator<Item = S>,
101        S: Into<String>,
102    {
103        self.set_directive("connect-src", sources);
104        self
105    }
106
107    /// Sets the `object-src` directive.
108    ///
109    /// Specifies valid sources for `<object>`, `<embed>`, and `<applet>` elements.
110    pub fn object_src<I, S>(mut self, sources: I) -> Self
111    where
112        I: IntoIterator<Item = S>,
113        S: Into<String>,
114    {
115        self.set_directive("object-src", sources);
116        self
117    }
118
119    /// Sets the `frame-src` directive.
120    ///
121    /// Specifies valid sources for nested browsing contexts loaded using `<frame>` and `<iframe>`.
122    pub fn frame_src<I, S>(mut self, sources: I) -> Self
123    where
124        I: IntoIterator<Item = S>,
125        S: Into<String>,
126    {
127        self.set_directive("frame-src", sources);
128        self
129    }
130
131    /// Sets the `base-uri` directive.
132    ///
133    /// Restricts the URLs that can be used in a document's `<base>` element.
134    pub fn base_uri<I, S>(mut self, sources: I) -> Self
135    where
136        I: IntoIterator<Item = S>,
137        S: Into<String>,
138    {
139        self.set_directive("base-uri", sources);
140        self
141    }
142
143    /// Sets the `form-action` directive.
144    ///
145    /// Restricts the URLs which can be used as the target of form submissions.
146    pub fn form_action<I, S>(mut self, sources: I) -> Self
147    where
148        I: IntoIterator<Item = S>,
149        S: Into<String>,
150    {
151        self.set_directive("form-action", sources);
152        self
153    }
154
155    /// Sets the `frame-ancestors` directive.
156    ///
157    /// Specifies valid parents that may embed a page using `<frame>`, `<iframe>`, etc.
158    pub fn frame_ancestors<I, S>(mut self, sources: I) -> Self
159    where
160        I: IntoIterator<Item = S>,
161        S: Into<String>,
162    {
163        self.set_directive("frame-ancestors", sources);
164        self
165    }
166
167    /// Sets the `upgrade-insecure-requests` directive (valueless).
168    ///
169    /// Instructs browsers to upgrade all insecure requests to HTTPS.
170    pub fn upgrade_insecure_requests(mut self) -> Self {
171        self.directives
172            .insert("upgrade-insecure-requests".to_string(), vec![]);
173        self
174    }
175
176    /// Sets the `block-all-mixed-content` directive (valueless).
177    ///
178    /// Prevents loading any mixed content (HTTP resources on HTTPS pages).
179    pub fn block_all_mixed_content(mut self) -> Self {
180        self.directives
181            .insert("block-all-mixed-content".to_string(), vec![]);
182        self
183    }
184
185    /// Sets a custom directive.
186    ///
187    /// This allows setting directives not covered by the convenience methods.
188    pub fn directive<I, S>(mut self, name: &str, sources: I) -> Self
189    where
190        I: IntoIterator<Item = S>,
191        S: Into<String>,
192    {
193        self.set_directive(name, sources);
194        self
195    }
196
197    /// Helper method to set a directive.
198    fn set_directive<I, S>(&mut self, name: &str, sources: I)
199    where
200        I: IntoIterator<Item = S>,
201        S: Into<String>,
202    {
203        let sources: Vec<String> = sources.into_iter().map(|s| s.into()).collect();
204        self.directives.insert(name.to_string(), sources);
205    }
206
207    /// Converts the policy to its header value string.
208    pub fn to_header_value(&self) -> Result<String> {
209        if self.directives.is_empty() {
210            return Err(Error::InvalidCsp("CSP policy is empty".to_string()));
211        }
212
213        let mut parts = Vec::new();
214        let mut keys: Vec<&String> = self.directives.keys().collect();
215        keys.sort();
216
217        for directive in keys {
218            let sources = &self.directives[directive];
219            if sources.is_empty() {
220                // Valueless directives (upgrade-insecure-requests, block-all-mixed-content)
221                parts.push(directive.clone());
222            } else {
223                parts.push(format!("{} {}", directive, sources.join(" ")));
224            }
225        }
226
227        Ok(parts.join("; "))
228    }
229
230    /// Parses a CSP policy from a header value string.
231    ///
232    /// # Examples
233    ///
234    /// ```
235    /// use http_security_headers::ContentSecurityPolicy;
236    ///
237    /// let csp = ContentSecurityPolicy::parse("default-src 'self'; script-src 'unsafe-inline'").unwrap();
238    /// ```
239    pub fn parse(value: &str) -> Result<Self> {
240        let mut csp = Self::new();
241
242        for directive_str in value.split(';').map(|s| s.trim()) {
243            if directive_str.is_empty() {
244                continue;
245            }
246
247            let parts: Vec<&str> = directive_str.split_whitespace().collect();
248            if parts.is_empty() {
249                continue;
250            }
251
252            let directive_name = parts[0];
253            let sources: Vec<String> = parts[1..].iter().map(|s| s.to_string()).collect();
254
255            csp.directives.insert(directive_name.to_string(), sources);
256        }
257
258        if csp.directives.is_empty() {
259            return Err(Error::InvalidCsp("No directives found".to_string()));
260        }
261
262        Ok(csp)
263    }
264}
265
266impl Default for ContentSecurityPolicy {
267    fn default() -> Self {
268        Self::new()
269    }
270}
271
272impl std::fmt::Display for ContentSecurityPolicy {
273    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
274        write!(f, "{}", self.to_header_value().unwrap_or_default())
275    }
276}
277
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed policy holds no directives.
    #[test]
    fn test_new() {
        assert!(ContentSecurityPolicy::new().directives.is_empty());
    }

    /// Chained builder calls each record their own directive.
    #[test]
    fn test_builder() {
        let policy = ContentSecurityPolicy::new()
            .default_src(vec!["'self'"])
            .script_src(vec!["'self'", "'unsafe-inline'"])
            .style_src(vec!["'self'", "https://fonts.googleapis.com"]);

        assert_eq!(policy.directives.len(), 3);
        assert_eq!(policy.directives["default-src"], ["'self'"]);
        assert_eq!(
            policy.directives["script-src"],
            ["'self'", "'unsafe-inline'"]
        );
    }

    /// Every configured directive shows up in the serialized header.
    #[test]
    fn test_to_header_value() {
        let header = ContentSecurityPolicy::new()
            .default_src(vec!["'self'"])
            .script_src(vec!["'self'", "'unsafe-inline'"])
            .to_header_value()
            .unwrap();

        for expected in &["default-src 'self'", "script-src 'self' 'unsafe-inline'"] {
            assert!(header.contains(expected));
        }
    }

    /// Valueless directives are emitted alongside regular ones.
    #[test]
    fn test_valueless_directives() {
        let header = ContentSecurityPolicy::new()
            .default_src(vec!["'self'"])
            .upgrade_insecure_requests()
            .to_header_value()
            .unwrap();

        for expected in &["upgrade-insecure-requests", "default-src 'self'"] {
            assert!(header.contains(expected));
        }
    }

    /// Serializing a policy without directives is rejected.
    #[test]
    fn test_empty_policy_error() {
        assert!(ContentSecurityPolicy::new().to_header_value().is_err());
    }

    /// A two-directive header parses into two entries with the right sources.
    #[test]
    fn test_parse() {
        let parsed =
            ContentSecurityPolicy::parse("default-src 'self'; script-src 'unsafe-inline'")
                .unwrap();

        assert_eq!(parsed.directives.len(), 2);
        assert_eq!(parsed.directives["default-src"], ["'self'"]);
        assert_eq!(parsed.directives["script-src"], ["'unsafe-inline'"]);
    }

    /// Empty and whitespace-only inputs contain no directives and fail to parse.
    #[test]
    fn test_parse_empty() {
        for input in &["", "   "] {
            assert!(ContentSecurityPolicy::parse(input).is_err());
        }
    }

    /// `directive` stores arbitrary directive names with their sources.
    #[test]
    fn test_custom_directive() {
        let policy = ContentSecurityPolicy::new().directive("worker-src", vec!["'self'", "blob:"]);

        assert_eq!(policy.directives["worker-src"], ["'self'", "blob:"]);
    }
}