celers_protocol/
security.rs

1//! Security utilities for protocol handling
2//!
3//! This module provides security-related utilities including content-type
4//! whitelisting, message validation, and safety checks.
5//!
6//! # Content-Type Whitelist
7//!
8//! By default, only safe serialization formats are allowed. Pickle is
9//! explicitly blocked due to security concerns (arbitrary code execution).
10//!
11//! # Example
12//!
13//! ```
14//! use celers_protocol::security::{ContentTypeWhitelist, SecurityPolicy};
15//!
16//! let policy = SecurityPolicy::strict();
17//! assert!(policy.is_content_type_allowed("application/json"));
18//! assert!(!policy.is_content_type_allowed("application/x-python-pickle"));
19//! ```
20
21use std::collections::HashSet;
22
/// Content-type whitelist for allowed serialization formats
///
/// Lookups go through [`ContentTypeWhitelist::is_allowed`]: a blocked entry
/// always wins, an empty allowed set means "anything that is not blocked", and
/// otherwise the content type must be present in the allowed set.
///
/// Entries are stored in normalized form (parameters stripped, surrounding
/// whitespace trimmed, lowercased), which is the same form lookups are
/// compared in.
#[derive(Debug, Clone)]
pub struct ContentTypeWhitelist {
    /// Allowed content types; an empty set means "allow anything not blocked"
    allowed: HashSet<String>,
    /// Blocked content types (takes precedence)
    blocked: HashSet<String>,
}

impl Default for ContentTypeWhitelist {
    /// The default whitelist is the [`safe`](ContentTypeWhitelist::safe) preset.
    fn default() -> Self {
        Self::safe()
    }
}

impl ContentTypeWhitelist {
    /// Build the set of pickle content types blocked by the presets.
    ///
    /// `application/x-python-serialize` is the identifier Celery/Kombu use for
    /// pickle payloads, and `is_unsafe_content_type` in this module already
    /// classifies it as unsafe; the other entries are aliases for the same
    /// format. Keeping the list in one place stops the presets from drifting.
    fn pickle_blocklist() -> HashSet<String> {
        [
            "application/x-python-serialize",
            "application/x-python-pickle",
            "application/python-pickle",
            "application/x-pickle",
        ]
        .iter()
        .copied()
        .map(String::from)
        .collect()
    }

    /// Create a new empty whitelist
    ///
    /// With both sets empty, [`is_allowed`](Self::is_allowed) accepts every
    /// content type that has not been blocked, until an entry is added with
    /// [`allow`](Self::allow).
    pub fn new() -> Self {
        Self {
            allowed: HashSet::new(),
            blocked: HashSet::new(),
        }
    }

    /// Create a whitelist with safe defaults (JSON, MessagePack, octet-stream)
    ///
    /// All known pickle content types are blocked.
    pub fn safe() -> Self {
        let allowed = [
            "application/json",
            "application/x-msgpack",
            "application/octet-stream",
        ]
        .iter()
        .copied()
        .map(String::from)
        .collect();

        Self {
            allowed,
            // Block pickle - security risk (arbitrary code execution)
            blocked: Self::pickle_blocklist(),
        }
    }

    /// Create a permissive whitelist (allows all except blocked)
    ///
    /// Pickle content types are still blocked.
    pub fn permissive() -> Self {
        Self {
            allowed: HashSet::new(), // Empty means check blocked list only
            blocked: Self::pickle_blocklist(),
        }
    }

    /// Create a strict whitelist (JSON only)
    pub fn strict() -> Self {
        let mut allowed = HashSet::new();
        allowed.insert("application/json".to_string());

        Self {
            allowed,
            blocked: HashSet::new(),
        }
    }

    /// Allow a content type
    ///
    /// The value is normalized before it is stored, so `"Text/Plain;
    /// charset=utf-8"` and `"text/plain"` register the same entry. Any
    /// matching blocked entry is removed.
    #[must_use]
    pub fn allow(mut self, content_type: impl Into<String>) -> Self {
        let raw: String = content_type.into();
        // Store the same form `is_allowed` compares against; otherwise an
        // entry with uppercase letters or parameters could never match.
        let ct = normalize_content_type(&raw);
        self.blocked.remove(&ct);
        self.allowed.insert(ct);
        self
    }

    /// Block a content type
    ///
    /// The value is normalized before it is stored. Any matching allowed
    /// entry is removed.
    ///
    /// Note: if this removes the last allowed entry, the allowed set becomes
    /// empty and the whitelist falls back to "anything not blocked".
    #[must_use]
    pub fn block(mut self, content_type: impl Into<String>) -> Self {
        let raw: String = content_type.into();
        // Without normalization a block such as "Application/X-Pickle" would
        // silently have no effect on lookups.
        let ct = normalize_content_type(&raw);
        self.allowed.remove(&ct);
        self.blocked.insert(ct);
        self
    }

    /// Check if a content type is allowed
    ///
    /// The input is normalized (parameters stripped, trimmed, lowercased)
    /// before it is compared against the blocked and allowed sets.
    pub fn is_allowed(&self, content_type: &str) -> bool {
        // Normalize content type (lowercase, strip parameters)
        let normalized = normalize_content_type(content_type);

        // Blocked takes precedence
        if self.blocked.contains(&normalized) {
            return false;
        }

        // If allowed list is empty, allow anything not blocked
        if self.allowed.is_empty() {
            return true;
        }

        // Check allowed list
        self.allowed.contains(&normalized)
    }

    /// Get all allowed content types (in unspecified order)
    #[inline]
    pub fn allowed_types(&self) -> Vec<&str> {
        self.allowed.iter().map(|s| s.as_str()).collect()
    }

    /// Get all blocked content types (in unspecified order)
    #[inline]
    pub fn blocked_types(&self) -> Vec<&str> {
        self.blocked.iter().map(|s| s.as_str()).collect()
    }
}

/// Normalize a content type string
///
/// Keeps only the part before the first `;` (dropping parameters such as
/// `charset`), trims surrounding whitespace and lowercases the result.
fn normalize_content_type(content_type: &str) -> String {
    // Extract main type (before any parameters like charset)
    content_type
        .split(';')
        .next()
        .unwrap_or(content_type)
        .trim()
        .to_lowercase()
}
148
/// Security validation error
///
/// Implements `PartialEq`/`Eq` so callers and tests can compare errors
/// directly (for example with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SecurityError {
    /// Content type is not allowed; carries the content type as supplied
    ContentTypeBlocked(String),
    /// Message size exceeds limit
    MessageTooLarge {
        /// Observed size in bytes
        size: usize,
        /// Configured maximum in bytes
        limit: usize,
    },
    /// Task name is invalid; carries a human-readable reason
    InvalidTaskName(String),
    /// Potential injection detected; carries a human-readable description
    PotentialInjection(String),
}

impl std::fmt::Display for SecurityError {
    /// Render a human-readable, single-line description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ContentTypeBlocked(ct) => {
                write!(f, "Content type '{}' is not allowed", ct)
            }
            Self::MessageTooLarge { size, limit } => {
                write!(
                    f,
                    "Message size {} bytes exceeds limit of {} bytes",
                    size, limit
                )
            }
            Self::InvalidTaskName(name) => {
                write!(f, "Invalid task name: {}", name)
            }
            Self::PotentialInjection(desc) => {
                write!(f, "Potential injection detected: {}", desc)
            }
        }
    }
}

impl std::error::Error for SecurityError {}
186
/// Security policy for message handling
///
/// Bundles the content-type whitelist with the message-size and task-name
/// limits. All fields are public, so a policy can be created from one of the
/// presets and then adjusted field by field or through the `with_*` builders.
#[derive(Debug, Clone)]
pub struct SecurityPolicy {
    /// Content type whitelist
    pub content_types: ContentTypeWhitelist,
    /// Maximum message size in bytes (a message of exactly this size passes)
    pub max_message_size: usize,
    /// Maximum task name length, measured in bytes (`str::len`)
    pub max_task_name_length: usize,
    /// Allowed task name pattern (regex-like)
    ///
    /// NOTE(review): nothing in this file reads this field; the strict-mode
    /// check in `validate_task_name` is hard-coded. Confirm whether another
    /// module consumes it.
    pub task_name_pattern: Option<String>,
    /// Enable strict validation (turns on the task-name character check)
    pub strict_validation: bool,
}
201
202impl Default for SecurityPolicy {
203    fn default() -> Self {
204        Self::standard()
205    }
206}
207
208impl SecurityPolicy {
209    /// Create a standard security policy
210    pub fn standard() -> Self {
211        Self {
212            content_types: ContentTypeWhitelist::safe(),
213            max_message_size: 10 * 1024 * 1024, // 10 MB
214            max_task_name_length: 256,
215            task_name_pattern: None,
216            strict_validation: false,
217        }
218    }
219
220    /// Create a strict security policy
221    pub fn strict() -> Self {
222        Self {
223            content_types: ContentTypeWhitelist::strict(),
224            max_message_size: 1024 * 1024, // 1 MB
225            max_task_name_length: 128,
226            task_name_pattern: Some(r"^[a-zA-Z_][a-zA-Z0-9_.]*$".to_string()),
227            strict_validation: true,
228        }
229    }
230
231    /// Create a permissive security policy
232    pub fn permissive() -> Self {
233        Self {
234            content_types: ContentTypeWhitelist::permissive(),
235            max_message_size: 100 * 1024 * 1024, // 100 MB
236            max_task_name_length: 512,
237            task_name_pattern: None,
238            strict_validation: false,
239        }
240    }
241
242    /// Check if a content type is allowed
243    pub fn is_content_type_allowed(&self, content_type: &str) -> bool {
244        self.content_types.is_allowed(content_type)
245    }
246
247    /// Validate content type
248    pub fn validate_content_type(&self, content_type: &str) -> Result<(), SecurityError> {
249        if self.content_types.is_allowed(content_type) {
250            Ok(())
251        } else {
252            Err(SecurityError::ContentTypeBlocked(content_type.to_string()))
253        }
254    }
255
256    /// Validate message size
257    pub fn validate_message_size(&self, size: usize) -> Result<(), SecurityError> {
258        if size <= self.max_message_size {
259            Ok(())
260        } else {
261            Err(SecurityError::MessageTooLarge {
262                size,
263                limit: self.max_message_size,
264            })
265        }
266    }
267
268    /// Validate task name
269    pub fn validate_task_name(&self, name: &str) -> Result<(), SecurityError> {
270        // Check length
271        if name.len() > self.max_task_name_length {
272            return Err(SecurityError::InvalidTaskName(format!(
273                "Task name too long: {} > {}",
274                name.len(),
275                self.max_task_name_length
276            )));
277        }
278
279        // Check for empty name
280        if name.is_empty() {
281            return Err(SecurityError::InvalidTaskName(
282                "Task name cannot be empty".to_string(),
283            ));
284        }
285
286        // Check for null bytes
287        if name.contains('\0') {
288            return Err(SecurityError::PotentialInjection(
289                "Task name contains null bytes".to_string(),
290            ));
291        }
292
293        // In strict mode, validate pattern
294        if self.strict_validation {
295            // Simple pattern check: must start with letter/underscore,
296            // contain only alphanumeric, underscore, or dot
297            let is_valid = name.chars().enumerate().all(|(i, c)| {
298                if i == 0 {
299                    c.is_ascii_alphabetic() || c == '_'
300                } else {
301                    c.is_ascii_alphanumeric() || c == '_' || c == '.'
302                }
303            });
304
305            if !is_valid {
306                return Err(SecurityError::InvalidTaskName(format!(
307                    "Task name '{}' contains invalid characters",
308                    name
309                )));
310            }
311        }
312
313        Ok(())
314    }
315
316    /// Validate a complete message
317    pub fn validate_message(
318        &self,
319        content_type: &str,
320        body_size: usize,
321        task_name: &str,
322    ) -> Result<(), SecurityError> {
323        self.validate_content_type(content_type)?;
324        self.validate_message_size(body_size)?;
325        self.validate_task_name(task_name)?;
326        Ok(())
327    }
328
329    /// Set maximum message size
330    pub fn with_max_message_size(mut self, size: usize) -> Self {
331        self.max_message_size = size;
332        self
333    }
334
335    /// Set maximum task name length
336    pub fn with_max_task_name_length(mut self, length: usize) -> Self {
337        self.max_task_name_length = length;
338        self
339    }
340
341    /// Enable strict validation
342    pub fn with_strict_validation(mut self, strict: bool) -> Self {
343        self.strict_validation = strict;
344        self
345    }
346
347    /// Set content type whitelist
348    pub fn with_content_types(mut self, whitelist: ContentTypeWhitelist) -> Self {
349        self.content_types = whitelist;
350        self
351    }
352}
353
354/// Check if a content type is known to be unsafe
355pub fn is_unsafe_content_type(content_type: &str) -> bool {
356    let normalized = normalize_content_type(content_type);
357    matches!(
358        normalized.as_str(),
359        "application/x-python-pickle"
360            | "application/python-pickle"
361            | "application/x-pickle"
362            | "application/x-python-serialize"
363    )
364}
365
#[cfg(test)]
mod tests {
    use super::*;

    // ContentTypeWhitelist presets

    #[test]
    fn test_content_type_whitelist_safe() {
        let safe = ContentTypeWhitelist::safe();
        for accepted in &["application/json", "application/x-msgpack"] {
            assert!(safe.is_allowed(accepted));
        }
        assert!(!safe.is_allowed("application/x-python-pickle"));
    }

    #[test]
    fn test_content_type_whitelist_strict() {
        let strict = ContentTypeWhitelist::strict();
        let json_ok = strict.is_allowed("application/json");
        let msgpack_ok = strict.is_allowed("application/x-msgpack");
        assert!(json_ok);
        assert!(!msgpack_ok);
    }

    #[test]
    fn test_content_type_whitelist_permissive() {
        let permissive = ContentTypeWhitelist::permissive();
        for accepted in &["application/json", "application/x-msgpack", "text/plain"] {
            assert!(permissive.is_allowed(accepted));
        }
        assert!(!permissive.is_allowed("application/x-python-pickle"));
    }

    // Parameters and letter case must not influence the lookup.
    #[test]
    fn test_content_type_normalization() {
        let safe = ContentTypeWhitelist::safe();
        for variant in &["application/json; charset=utf-8", "APPLICATION/JSON"] {
            assert!(safe.is_allowed(variant));
        }
    }

    #[test]
    fn test_content_type_whitelist_allow_block() {
        let custom = ContentTypeWhitelist::new()
            .allow("text/plain")
            .block("text/html");

        assert!(custom.is_allowed("text/plain"));
        for rejected in &["text/html", "application/json"] {
            assert!(!custom.is_allowed(rejected));
        }
    }

    // SecurityPolicy presets

    #[test]
    fn test_security_policy_standard() {
        let standard = SecurityPolicy::standard();
        let json_ok = standard.is_content_type_allowed("application/json");
        let pickle_ok = standard.is_content_type_allowed("application/x-python-pickle");
        assert!(json_ok);
        assert!(!pickle_ok);
    }

    #[test]
    fn test_security_policy_strict() {
        let strict = SecurityPolicy::strict();
        let json_ok = strict.is_content_type_allowed("application/json");
        let msgpack_ok = strict.is_content_type_allowed("application/x-msgpack");
        assert!(json_ok);
        assert!(!msgpack_ok);
    }

    // Individual validators

    #[test]
    fn test_validate_message_size() {
        let capped = SecurityPolicy::standard().with_max_message_size(100);
        for within in &[50usize, 100] {
            assert!(capped.validate_message_size(*within).is_ok());
        }
        assert!(capped.validate_message_size(101).is_err());
    }

    #[test]
    fn test_validate_task_name() {
        let standard = SecurityPolicy::standard();
        for good in &["tasks.add", "my_task"] {
            assert!(standard.validate_task_name(good).is_ok());
        }
        assert!(standard.validate_task_name("").is_err());
    }

    #[test]
    fn test_validate_task_name_strict() {
        let strict = SecurityPolicy::strict();
        for good in &["tasks.add", "_private_task"] {
            assert!(strict.validate_task_name(good).is_ok());
        }
        for bad in &["123_invalid", "task-with-dash"] {
            assert!(strict.validate_task_name(bad).is_err());
        }
    }

    #[test]
    fn test_validate_task_name_null_bytes() {
        let outcome = SecurityPolicy::standard().validate_task_name("task\0name");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_validate_task_name_length() {
        let short_names = SecurityPolicy::standard().with_max_task_name_length(10);
        let fits = short_names.validate_task_name("short");
        let overflows = short_names.validate_task_name("this_is_too_long");
        assert!(fits.is_ok());
        assert!(overflows.is_err());
    }

    #[test]
    fn test_validate_message() {
        let standard = SecurityPolicy::standard();
        let outcome = standard.validate_message("application/json", 1000, "tasks.add");
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_is_unsafe_content_type() {
        for pickle in &["application/x-python-pickle", "application/python-pickle"] {
            assert!(is_unsafe_content_type(pickle));
        }
        assert!(!is_unsafe_content_type("application/json"));
    }

    // Every variant's message must mention the data it carries.
    #[test]
    fn test_security_error_display() {
        let blocked = SecurityError::ContentTypeBlocked("pickle".to_string()).to_string();
        assert!(blocked.contains("pickle"));

        let too_large = SecurityError::MessageTooLarge {
            size: 100,
            limit: 50,
        }
        .to_string();
        assert!(too_large.contains("100"));
        assert!(too_large.contains("50"));

        let bad_name = SecurityError::InvalidTaskName("bad".to_string()).to_string();
        assert!(bad_name.contains("bad"));

        let injection = SecurityError::PotentialInjection("null".to_string()).to_string();
        assert!(injection.contains("null"));
    }

    #[test]
    fn test_allowed_blocked_types() {
        let safe = ContentTypeWhitelist::safe();

        assert!(safe.allowed_types().contains(&"application/json"));
        assert!(safe
            .blocked_types()
            .contains(&"application/x-python-pickle"));
    }
}