1#![deny(unsafe_code)]
3
4use axum::http::HeaderValue;
5use regex::RegexSet;
6use std::collections::HashMap;
7use std::fmt::{Debug, Display, Formatter};
8
/// The Content-Security-Policy directive names this module can emit.
///
/// The derived `Ord` follows declaration order, so reordering the variants
/// changes how directives compare and sort.
#[derive(Hash, Eq, PartialEq, Debug, Clone, Copy, Ord, PartialOrd)]
pub enum CspDirectiveType {
    /// Rendered as `base-uri`.
    BaseUri,
    /// Rendered as `child-src`.
    ChildSrc,
    /// Rendered as `connect-src`.
    ConnectSrc,
    /// Rendered as `default-src`.
    DefaultSrc,
    /// Rendered as `fenced-frame-src`.
    FencedFrameSrc,
    /// Rendered as `font-src`.
    FontSrc,
    /// Rendered as `form-action`.
    FormAction,
    /// Rendered as `frame-ancestors`.
    FrameAncestors,
    /// Rendered as `frame-src`.
    FrameSrc,
    /// Rendered as `img-src`.
    ImgSrc,
    /// Rendered as `manifest-src`.
    ManifestSrc,
    /// Rendered as `media-src`.
    MediaSrc,
    /// Rendered as `navigate-to`.
    NavigateTo,
    /// Rendered as `object-src`.
    ObjectSrc,
    /// Rendered as `prefetch-src`.
    PrefetchSrc,
    /// Rendered as `report-to`.
    ReportTo,
    /// Rendered as `report-uri`.
    ReportUri,
    /// Rendered as `require-trusted-types-for`.
    RequireTrustedTypesFor,
    /// Rendered as `sandbox`.
    Sandbox,
    /// Rendered as `script-src`.
    ScriptSource,
    /// Rendered as `script-src-attr`.
    ScriptSourceAttr,
    /// Rendered as `script-src-elem`.
    ScriptSourceElem,
    /// Rendered as `style-src`.
    StyleSource,
    /// Rendered as `style-src-attr`.
    StyleSourceAttr,
    /// Rendered as `style-src-elem`.
    StyleSourceElem,
    /// Rendered as `trusted-types`.
    TrustedTypes,
    /// Rendered as `upgrade-insecure-requests`.
    UpgradeInsecureRequests,
    /// Rendered as `worker-src`.
    WorkerSource,
}

impl AsRef<str> for CspDirectiveType {
    /// Returns the directive name exactly as it appears in a policy string.
    fn as_ref(&self) -> &str {
        match self {
            Self::BaseUri => "base-uri",
            Self::ChildSrc => "child-src",
            Self::ConnectSrc => "connect-src",
            Self::DefaultSrc => "default-src",
            Self::FencedFrameSrc => "fenced-frame-src",
            Self::FontSrc => "font-src",
            Self::FormAction => "form-action",
            Self::FrameAncestors => "frame-ancestors",
            Self::FrameSrc => "frame-src",
            Self::ImgSrc => "img-src",
            Self::ManifestSrc => "manifest-src",
            Self::MediaSrc => "media-src",
            Self::NavigateTo => "navigate-to",
            Self::ObjectSrc => "object-src",
            Self::PrefetchSrc => "prefetch-src",
            Self::ReportTo => "report-to",
            Self::ReportUri => "report-uri",
            Self::RequireTrustedTypesFor => "require-trusted-types-for",
            Self::Sandbox => "sandbox",
            Self::ScriptSource => "script-src",
            Self::ScriptSourceAttr => "script-src-attr",
            Self::ScriptSourceElem => "script-src-elem",
            Self::StyleSource => "style-src",
            Self::StyleSourceAttr => "style-src-attr",
            Self::StyleSourceElem => "style-src-elem",
            Self::TrustedTypes => "trusted-types",
            Self::UpgradeInsecureRequests => "upgrade-insecure-requests",
            Self::WorkerSource => "worker-src",
        }
    }
}

impl Display for CspDirectiveType {
    /// Writes the directive name; formatter width/fill options are not applied.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_ref())
    }
}

impl From<CspDirectiveType> for String {
    /// Produces an owned copy of the directive name.
    fn from(input: CspDirectiveType) -> String {
        let name: &str = input.as_ref();
        name.to_owned()
    }
}
99
100#[derive(Debug, Clone)]
101pub struct CspDirective {
102 pub directive_type: CspDirectiveType,
103 pub values: Vec<CspValue>,
104}
105
106impl CspDirective {
107 #[must_use]
108 pub fn from(directive_type: CspDirectiveType, values: Vec<CspValue>) -> Self {
109 Self {
110 directive_type,
111 values,
112 }
113 }
114
115 pub fn default_self() -> Self {
117 Self {
118 directive_type: CspDirectiveType::DefaultSrc,
119 values: vec![CspValue::SelfSite],
120 }
121 }
122}
123
124impl Display for CspDirective {
125 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
126 let space = if self.values.is_empty() { "" } else { " " };
127 f.write_fmt(format_args!(
128 "{}{}{}",
129 self.directive_type.as_ref(),
130 space,
131 self.values
132 .iter()
133 .map(|v| String::from(v.to_owned()))
134 .collect::<Vec<String>>()
135 .join(" ")
136 ))
137 }
138}
139
140impl From<CspDirective> for HeaderValue {
141 fn from(input: CspDirective) -> HeaderValue {
142 match HeaderValue::from_str(&input.to_string()) {
143 Ok(val) => val,
144 Err(e) => panic!("Failed to build HeaderValue from CspDirective: {}", e),
145 }
146 }
147}
148
149#[derive(Clone, Debug)]
151pub struct CspUrlMatcher {
152 pub matcher: RegexSet,
153 pub directives: Vec<CspDirective>,
154}
155
156impl CspUrlMatcher {
157 #[must_use]
158 pub fn new(matcher: RegexSet) -> Self {
159 Self {
160 matcher,
161 directives: vec![],
162 }
163 }
164 pub fn with_directive(&mut self, directive: CspDirective) -> &mut Self {
165 self.directives.push(directive);
166 self
167 }
168
169 pub fn is_match(&self, text: &str) -> bool {
171 self.matcher.is_match(text)
172 }
173
174 pub fn default_all_self() -> Self {
176 Self {
177 matcher: RegexSet::new([r#".*"#]).unwrap(),
178 directives: vec![CspDirective {
179 directive_type: CspDirectiveType::DefaultSrc,
180 values: vec![CspValue::SelfSite],
181 }],
182 }
183 }
184
185 pub fn default_self(matcher: RegexSet) -> Self {
187 Self {
188 matcher,
189 directives: vec![CspDirective {
190 directive_type: CspDirectiveType::DefaultSrc,
191 values: vec![CspValue::SelfSite],
192 }],
193 }
194 }
195}
196
197impl From<CspUrlMatcher> for HeaderValue {
199 fn from(input: CspUrlMatcher) -> HeaderValue {
200 let mut res = String::new();
201 for directive in input.directives {
202 res.push_str(directive.directive_type.as_ref());
203 for val in directive.values {
204 res.push_str(&format!(" {}", String::from(val)));
205 }
206 res.push_str("; ");
207 }
208 HeaderValue::from_str(res.trim()).unwrap()
209 }
210}
211
/// A single value that can follow a directive name in a policy.
///
/// The derived `Ord` compares by declaration order first and by the contained
/// string second, so reordering variants changes how values sort.
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub enum CspValue {
    /// Rendered as `'none'`.
    None,
    /// Rendered as `'self'`.
    SelfSite,
    /// Rendered as `'strict-dynamic'`.
    StrictDynamic,
    /// Rendered as `'report-sample'`.
    ReportSample,

    /// Rendered as `'unsafe-inline'`.
    UnsafeInline,
    /// Rendered as `'unsafe-eval'`.
    UnsafeEval,
    /// Rendered as `'wasm-unsafe-eval'`.
    WasmUnsafeEval,
    /// Rendered as `'unsafe-hashes'`.
    UnsafeHashes,
    /// Rendered as `'unsafe-allow-redirects'`.
    UnsafeAllowRedirects,
    /// A host source; `value` is emitted verbatim.
    Host {
        value: String,
    },
    /// Rendered as `https:`.
    SchemeHttps,
    /// Rendered as `http:`.
    SchemeHttp,
    /// Rendered as `data:`.
    SchemeData,
    /// Any other scheme source; `value` is emitted verbatim.
    /// NOTE(review): presumably the caller includes the trailing `:` — confirm.
    SchemeOther {
        value: String,
    },
    /// A nonce source. `value` is the bare nonce, without the `nonce-` prefix
    /// or quotes; it is rendered as `'nonce-<value>'`.
    Nonce {
        value: String,
    },
    /// A SHA-256 hash source. `value` is the bare digest; it is rendered as
    /// `'sha256-<value>'`.
    Sha256 {
        value: String,
    },
    /// A SHA-384 hash source. `value` is the bare digest; it is rendered as
    /// `'sha384-<value>'`.
    Sha384 {
        value: String,
    },
    /// A SHA-512 hash source. `value` is the bare digest; it is rendered as
    /// `'sha512-<value>'`.
    Sha512 {
        value: String,
    },
}

impl From<CspValue> for String {
    /// Renders the value as it should appear in a policy string.
    ///
    /// Nonce and hash sources are wrapped in single quotes, as the CSP source
    /// expression grammar requires. Without the quotes a browser parses the
    /// token as a host source and the nonce or hash has no effect.
    fn from(input: CspValue) -> String {
        let keyword = match input {
            // Caller-supplied text is passed through untouched.
            CspValue::Host { value } | CspValue::SchemeOther { value } => return value,
            CspValue::Nonce { value } => return format!("'nonce-{value}'"),
            CspValue::Sha256 { value } => return format!("'sha256-{value}'"),
            CspValue::Sha384 { value } => return format!("'sha384-{value}'"),
            CspValue::Sha512 { value } => return format!("'sha512-{value}'"),
            CspValue::None => "'none'",
            CspValue::SelfSite => "'self'",
            CspValue::StrictDynamic => "'strict-dynamic'",
            CspValue::ReportSample => "'report-sample'",
            CspValue::UnsafeInline => "'unsafe-inline'",
            CspValue::UnsafeEval => "'unsafe-eval'",
            CspValue::WasmUnsafeEval => "'wasm-unsafe-eval'",
            CspValue::UnsafeHashes => "'unsafe-hashes'",
            CspValue::UnsafeAllowRedirects => "'unsafe-allow-redirects'",
            CspValue::SchemeHttps => "https:",
            CspValue::SchemeHttp => "http:",
            CspValue::SchemeData => "data:",
        };
        keyword.to_owned()
    }
}
273
274#[derive(Clone, Debug, Default)]
275pub struct CspHeaderBuilder {
277 pub directive_map: HashMap<CspDirectiveType, Vec<CspValue>>,
278}
279
280impl CspHeaderBuilder {
281 pub fn new() -> Self {
282 Self {
283 directive_map: HashMap::new(),
284 }
285 }
286
287 pub fn add(mut self, directive: CspDirectiveType, values: Vec<CspValue>) -> Self {
288 self.directive_map.entry(directive).or_default();
289
290 values.into_iter().for_each(|val| {
291 if !self.directive_map.get(&directive).unwrap().contains(&val) {
292 self.directive_map.get_mut(&directive).unwrap().push(val);
293 }
294 });
295 self
296 }
297
298 pub fn finish(self) -> HeaderValue {
299 let mut keys = self
300 .directive_map
301 .keys()
302 .collect::<Vec<&CspDirectiveType>>();
303 keys.sort();
304
305 let directive_strings: Vec<String> = keys
306 .iter()
307 .map(|directive| {
308 let mut directive_string = String::new();
309 directive_string.push_str(&format!(" {}", directive));
310 let mut values = match self.directive_map.get(directive) {
311 Some(val) => val.to_owned(),
312 None => vec![],
313 };
314 values.sort();
315 values.into_iter().for_each(|val| {
316 directive_string.push_str(&format!(" {}", String::from(val)));
317 });
318 directive_string.trim().to_string()
319 })
320 .collect();
321
322 HeaderValue::from_str(&directive_strings.join("; "))
323 .expect("Failed to build header value from directive strings")
324 }
325}