celers_protocol/
security.rs1use std::collections::HashSet;
22
/// Allow/block list of message content types.
///
/// Every content type — both the ones configured through the constructors,
/// [`ContentTypeWhitelist::allow`] and [`ContentTypeWhitelist::block`], and the
/// ones checked with [`ContentTypeWhitelist::is_allowed`] — is normalized first
/// (parameters after `;` dropped, whitespace trimmed, lowercased). The blocklist
/// always takes precedence over the allowlist.
#[derive(Debug, Clone)]
pub struct ContentTypeWhitelist {
    /// Explicitly allowed (normalized) content types.
    ///
    /// `None` means "allow every content type that is not blocked".
    /// `Some(set)` means "allow only the types in `set`" — even when `set`
    /// becomes empty, so blocking the last allowed type can never silently
    /// turn a restrictive whitelist into an allow-everything list.
    allowed: Option<HashSet<String>>,
    /// Explicitly blocked (normalized) content types; checked before `allowed`.
    blocked: HashSet<String>,
}

impl Default for ContentTypeWhitelist {
    /// Returns [`ContentTypeWhitelist::safe`].
    fn default() -> Self {
        Self::safe()
    }
}

impl ContentTypeWhitelist {
    /// Content types whose payloads are pickle-serialized and can execute
    /// arbitrary code when deserialized. `application/x-python-serialize` is
    /// the content type Celery/kombu use for their `pickle` serializer.
    const PICKLE_TYPES: [&'static str; 4] = [
        "application/x-python-serialize",
        "application/x-python-pickle",
        "application/python-pickle",
        "application/x-pickle",
    ];

    /// Builds a set of normalized content types.
    fn normalized_set(types: &[&str]) -> HashSet<String> {
        types.iter().map(|&ct| normalize_content_type(ct)).collect()
    }

    /// Creates a whitelist with no allowlist and no blocklist.
    ///
    /// Every content type is allowed until [`ContentTypeWhitelist::allow`]
    /// restricts it or [`ContentTypeWhitelist::block`] rejects it.
    pub fn new() -> Self {
        Self {
            allowed: None,
            blocked: HashSet::new(),
        }
    }

    /// Allows JSON, MessagePack and raw octet streams; blocks all pickle
    /// content types.
    pub fn safe() -> Self {
        Self {
            allowed: Some(Self::normalized_set(&[
                "application/json",
                "application/x-msgpack",
                "application/octet-stream",
            ])),
            blocked: Self::normalized_set(&Self::PICKLE_TYPES),
        }
    }

    /// Allows every content type except the pickle content types.
    pub fn permissive() -> Self {
        Self {
            allowed: None,
            blocked: Self::normalized_set(&Self::PICKLE_TYPES),
        }
    }

    /// Allows only `application/json`.
    pub fn strict() -> Self {
        Self {
            allowed: Some(Self::normalized_set(&["application/json"])),
            blocked: HashSet::new(),
        }
    }

    /// Adds `content_type` to the allowlist and removes it from the blocklist.
    ///
    /// The first call on an allow-everything whitelist (see
    /// [`ContentTypeWhitelist::new`]) switches it into allowlist mode: from then
    /// on only explicitly allowed types pass.
    #[must_use]
    pub fn allow(mut self, content_type: impl Into<String>) -> Self {
        // Normalize so the stored key matches what `is_allowed` looks up.
        let ct = normalize_content_type(&content_type.into());
        self.blocked.remove(&ct);
        self.allowed.get_or_insert_with(HashSet::new).insert(ct);
        self
    }

    /// Adds `content_type` to the blocklist and removes it from the allowlist.
    ///
    /// Removing the last allowed type leaves an empty allowlist, which allows
    /// nothing; it does not revert to allowing everything.
    #[must_use]
    pub fn block(mut self, content_type: impl Into<String>) -> Self {
        // Normalize so the stored key matches what `is_allowed` looks up.
        let ct = normalize_content_type(&content_type.into());
        if let Some(allowed) = self.allowed.as_mut() {
            allowed.remove(&ct);
        }
        self.blocked.insert(ct);
        self
    }

    /// Returns `true` if `content_type` (after normalization) is not blocked
    /// and is either explicitly allowed or no allowlist is configured.
    pub fn is_allowed(&self, content_type: &str) -> bool {
        let normalized = normalize_content_type(content_type);

        // The blocklist always wins.
        if self.blocked.contains(&normalized) {
            return false;
        }

        match &self.allowed {
            None => true,
            Some(allowed) => allowed.contains(&normalized),
        }
    }

    /// Returns the explicitly allowed (normalized) content types, sorted.
    ///
    /// Empty both when every non-blocked type is allowed and when nothing is.
    pub fn allowed_types(&self) -> Vec<&str> {
        let mut types: Vec<&str> = self
            .allowed
            .iter()
            .flatten()
            .map(String::as_str)
            .collect();
        types.sort_unstable();
        types
    }

    /// Returns the blocked (normalized) content types, sorted.
    pub fn blocked_types(&self) -> Vec<&str> {
        let mut types: Vec<&str> = self.blocked.iter().map(String::as_str).collect();
        types.sort_unstable();
        types
    }
}

/// Normalizes a content type for comparison.
///
/// Drops everything from the first `;` on (e.g. `; charset=utf-8`), trims
/// surrounding whitespace and lowercases the result.
fn normalize_content_type(content_type: &str) -> String {
    content_type
        .split_once(';')
        .map_or(content_type, |(essence, _params)| essence)
        .trim()
        .to_lowercase()
}
148
/// Reasons a message can be rejected by the security checks.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare errors directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SecurityError {
    /// The content type is not permitted; carries the content type as given.
    ContentTypeBlocked(String),
    /// The message body is larger than the configured limit (both in bytes).
    MessageTooLarge { size: usize, limit: usize },
    /// The task name failed validation; carries a human-readable detail.
    InvalidTaskName(String),
    /// The input contains something that looks like an injection attempt;
    /// carries a human-readable description.
    PotentialInjection(String),
}

impl std::fmt::Display for SecurityError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SecurityError::ContentTypeBlocked(ct) => {
                write!(f, "Content type '{}' is not allowed", ct)
            }
            SecurityError::MessageTooLarge { size, limit } => write!(
                f,
                "Message size {} bytes exceeds limit of {} bytes",
                size, limit
            ),
            SecurityError::InvalidTaskName(name) => write!(f, "Invalid task name: {}", name),
            SecurityError::PotentialInjection(desc) => {
                write!(f, "Potential injection detected: {}", desc)
            }
        }
    }
}

impl std::error::Error for SecurityError {}
186
/// Limits and checks applied to incoming task messages.
///
/// Start from one of the presets (`SecurityPolicy::standard`,
/// `SecurityPolicy::strict`, `SecurityPolicy::permissive`) and adjust it with
/// the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct SecurityPolicy {
    /// Content-type allow/block list consulted by `validate_content_type`.
    pub content_types: ContentTypeWhitelist,
    /// Maximum accepted message body size in bytes (a body of exactly this size passes).
    pub max_message_size: usize,
    /// Maximum task-name length in bytes as measured by `str::len` (inclusive).
    pub max_task_name_length: usize,
    /// Regex describing valid task names.
    ///
    /// NOTE(review): within this module the pattern is only recorded (see
    /// `SecurityPolicy::strict`); `validate_task_name` never reads it — the
    /// strict character check is driven by `strict_validation`. Confirm whether
    /// any other code enforces it.
    pub task_name_pattern: Option<String>,
    /// When `true`, `validate_task_name` also requires identifier-like names:
    /// first character an ASCII letter or `_`, the rest ASCII alphanumerics,
    /// `_` or `.`.
    pub strict_validation: bool,
}
201
202impl Default for SecurityPolicy {
203 fn default() -> Self {
204 Self::standard()
205 }
206}
207
208impl SecurityPolicy {
209 pub fn standard() -> Self {
211 Self {
212 content_types: ContentTypeWhitelist::safe(),
213 max_message_size: 10 * 1024 * 1024, max_task_name_length: 256,
215 task_name_pattern: None,
216 strict_validation: false,
217 }
218 }
219
220 pub fn strict() -> Self {
222 Self {
223 content_types: ContentTypeWhitelist::strict(),
224 max_message_size: 1024 * 1024, max_task_name_length: 128,
226 task_name_pattern: Some(r"^[a-zA-Z_][a-zA-Z0-9_.]*$".to_string()),
227 strict_validation: true,
228 }
229 }
230
231 pub fn permissive() -> Self {
233 Self {
234 content_types: ContentTypeWhitelist::permissive(),
235 max_message_size: 100 * 1024 * 1024, max_task_name_length: 512,
237 task_name_pattern: None,
238 strict_validation: false,
239 }
240 }
241
242 pub fn is_content_type_allowed(&self, content_type: &str) -> bool {
244 self.content_types.is_allowed(content_type)
245 }
246
247 pub fn validate_content_type(&self, content_type: &str) -> Result<(), SecurityError> {
249 if self.content_types.is_allowed(content_type) {
250 Ok(())
251 } else {
252 Err(SecurityError::ContentTypeBlocked(content_type.to_string()))
253 }
254 }
255
256 pub fn validate_message_size(&self, size: usize) -> Result<(), SecurityError> {
258 if size <= self.max_message_size {
259 Ok(())
260 } else {
261 Err(SecurityError::MessageTooLarge {
262 size,
263 limit: self.max_message_size,
264 })
265 }
266 }
267
268 pub fn validate_task_name(&self, name: &str) -> Result<(), SecurityError> {
270 if name.len() > self.max_task_name_length {
272 return Err(SecurityError::InvalidTaskName(format!(
273 "Task name too long: {} > {}",
274 name.len(),
275 self.max_task_name_length
276 )));
277 }
278
279 if name.is_empty() {
281 return Err(SecurityError::InvalidTaskName(
282 "Task name cannot be empty".to_string(),
283 ));
284 }
285
286 if name.contains('\0') {
288 return Err(SecurityError::PotentialInjection(
289 "Task name contains null bytes".to_string(),
290 ));
291 }
292
293 if self.strict_validation {
295 let is_valid = name.chars().enumerate().all(|(i, c)| {
298 if i == 0 {
299 c.is_ascii_alphabetic() || c == '_'
300 } else {
301 c.is_ascii_alphanumeric() || c == '_' || c == '.'
302 }
303 });
304
305 if !is_valid {
306 return Err(SecurityError::InvalidTaskName(format!(
307 "Task name '{}' contains invalid characters",
308 name
309 )));
310 }
311 }
312
313 Ok(())
314 }
315
316 pub fn validate_message(
318 &self,
319 content_type: &str,
320 body_size: usize,
321 task_name: &str,
322 ) -> Result<(), SecurityError> {
323 self.validate_content_type(content_type)?;
324 self.validate_message_size(body_size)?;
325 self.validate_task_name(task_name)?;
326 Ok(())
327 }
328
329 pub fn with_max_message_size(mut self, size: usize) -> Self {
331 self.max_message_size = size;
332 self
333 }
334
335 pub fn with_max_task_name_length(mut self, length: usize) -> Self {
337 self.max_task_name_length = length;
338 self
339 }
340
341 pub fn with_strict_validation(mut self, strict: bool) -> Self {
343 self.strict_validation = strict;
344 self
345 }
346
347 pub fn with_content_types(mut self, whitelist: ContentTypeWhitelist) -> Self {
349 self.content_types = whitelist;
350 self
351 }
352}
353
354pub fn is_unsafe_content_type(content_type: &str) -> bool {
356 let normalized = normalize_content_type(content_type);
357 matches!(
358 normalized.as_str(),
359 "application/x-python-pickle"
360 | "application/python-pickle"
361 | "application/x-pickle"
362 | "application/x-python-serialize"
363 )
364}
365
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts `whitelist.is_allowed(ct) == expected` for every `(ct, expected)` case.
    fn check_whitelist(whitelist: &ContentTypeWhitelist, cases: &[(&str, bool)]) {
        for &(content_type, expected) in cases {
            assert_eq!(
                whitelist.is_allowed(content_type),
                expected,
                "is_allowed({:?})",
                content_type
            );
        }
    }

    /// Asserts `policy.validate_task_name(name).is_ok() == expected` for every case.
    fn check_task_names(policy: &SecurityPolicy, cases: &[(&str, bool)]) {
        for &(name, expected) in cases {
            assert_eq!(
                policy.validate_task_name(name).is_ok(),
                expected,
                "validate_task_name({:?})",
                name
            );
        }
    }

    /// Asserts that the rendered error text contains every needle.
    fn assert_mentions(err: &SecurityError, needles: &[&str]) {
        let rendered = err.to_string();
        for needle in needles {
            assert!(
                rendered.contains(needle),
                "{:?} missing from {:?}",
                needle,
                rendered
            );
        }
    }

    #[test]
    fn test_content_type_whitelist_safe() {
        check_whitelist(
            &ContentTypeWhitelist::safe(),
            &[
                ("application/json", true),
                ("application/x-msgpack", true),
                ("application/x-python-pickle", false),
            ],
        );
    }

    #[test]
    fn test_content_type_whitelist_strict() {
        check_whitelist(
            &ContentTypeWhitelist::strict(),
            &[("application/json", true), ("application/x-msgpack", false)],
        );
    }

    #[test]
    fn test_content_type_whitelist_permissive() {
        check_whitelist(
            &ContentTypeWhitelist::permissive(),
            &[
                ("application/json", true),
                ("application/x-msgpack", true),
                ("text/plain", true),
                ("application/x-python-pickle", false),
            ],
        );
    }

    #[test]
    fn test_content_type_normalization() {
        check_whitelist(
            &ContentTypeWhitelist::safe(),
            &[
                ("application/json; charset=utf-8", true),
                ("APPLICATION/JSON", true),
            ],
        );
    }

    #[test]
    fn test_content_type_whitelist_allow_block() {
        let custom = ContentTypeWhitelist::new()
            .allow("text/plain")
            .block("text/html");

        check_whitelist(
            &custom,
            &[
                ("text/plain", true),
                ("text/html", false),
                ("application/json", false),
            ],
        );
    }

    #[test]
    fn test_security_policy_standard() {
        let standard = SecurityPolicy::standard();
        for &(content_type, expected) in &[
            ("application/json", true),
            ("application/x-python-pickle", false),
        ] {
            assert_eq!(
                standard.is_content_type_allowed(content_type),
                expected,
                "{}",
                content_type
            );
        }
    }

    #[test]
    fn test_security_policy_strict() {
        let strict = SecurityPolicy::strict();
        for &(content_type, expected) in &[
            ("application/json", true),
            ("application/x-msgpack", false),
        ] {
            assert_eq!(
                strict.is_content_type_allowed(content_type),
                expected,
                "{}",
                content_type
            );
        }
    }

    #[test]
    fn test_validate_message_size() {
        let policy = SecurityPolicy::standard().with_max_message_size(100);
        for &(size, fits) in &[(50usize, true), (100, true), (101, false)] {
            assert_eq!(
                policy.validate_message_size(size).is_ok(),
                fits,
                "size {}",
                size
            );
        }
    }

    #[test]
    fn test_validate_task_name() {
        check_task_names(
            &SecurityPolicy::standard(),
            &[("tasks.add", true), ("my_task", true), ("", false)],
        );
    }

    #[test]
    fn test_validate_task_name_strict() {
        check_task_names(
            &SecurityPolicy::strict(),
            &[
                ("tasks.add", true),
                ("_private_task", true),
                ("123_invalid", false),
                ("task-with-dash", false),
            ],
        );
    }

    #[test]
    fn test_validate_task_name_null_bytes() {
        check_task_names(&SecurityPolicy::standard(), &[("task\0name", false)]);
    }

    #[test]
    fn test_validate_task_name_length() {
        check_task_names(
            &SecurityPolicy::standard().with_max_task_name_length(10),
            &[("short", true), ("this_is_too_long", false)],
        );
    }

    #[test]
    fn test_validate_message() {
        let outcome =
            SecurityPolicy::standard().validate_message("application/json", 1000, "tasks.add");
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_is_unsafe_content_type() {
        for &(content_type, flagged) in &[
            ("application/x-python-pickle", true),
            ("application/python-pickle", true),
            ("application/json", false),
        ] {
            assert_eq!(
                is_unsafe_content_type(content_type),
                flagged,
                "{}",
                content_type
            );
        }
    }

    #[test]
    fn test_security_error_display() {
        assert_mentions(
            &SecurityError::ContentTypeBlocked("pickle".to_string()),
            &["pickle"],
        );
        assert_mentions(
            &SecurityError::MessageTooLarge {
                size: 100,
                limit: 50,
            },
            &["100", "50"],
        );
        assert_mentions(&SecurityError::InvalidTaskName("bad".to_string()), &["bad"]);
        assert_mentions(
            &SecurityError::PotentialInjection("null".to_string()),
            &["null"],
        );
    }

    #[test]
    fn test_allowed_blocked_types() {
        let safe = ContentTypeWhitelist::safe();
        assert!(safe.allowed_types().contains(&"application/json"));
        assert!(safe.blocked_types().contains(&"application/x-python-pickle"));
    }
}