http_tunnel_common/
validation.rs

1//! Input validation for security-critical data
2//!
3//! This module provides validation functions for user-supplied input to prevent
4//! injection attacks, log poisoning, and system crashes from malformed data.
5
6use once_cell::sync::Lazy;
7use regex::Regex;
8use thiserror::Error;
9
10/// Regex for validating tunnel IDs (12 lowercase alphanumeric characters)
11static TUNNEL_ID_REGEX: Lazy<Regex> = Lazy::new(|| Regex::new(r"^[a-z0-9]{12}$").unwrap());
12
13/// Regex for validating request IDs (req_ prefix + UUID format)
14static REQUEST_ID_REGEX: Lazy<Regex> = Lazy::new(|| {
15    Regex::new(r"^req_[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$").unwrap()
16});
17
18/// Regex for validating connection IDs (AWS API Gateway format)
19static CONNECTION_ID_REGEX: Lazy<Regex> =
20    Lazy::new(|| Regex::new(r"^[A-Za-z0-9_=-]{1,128}$").unwrap());
21
/// Largest HTTP header value, in bytes, that `sanitize_header_value` accepts (8 KiB).
pub const MAX_HEADER_VALUE_LENGTH: usize = 8 * 1024;
24
/// Largest HTTP path, in bytes, that `validate_path` accepts (2 KiB).
pub const MAX_PATH_LENGTH: usize = 2 * 1024;
27
28/// Validation errors
29#[derive(Debug, Error)]
30pub enum ValidationError {
31    #[error("Invalid tunnel ID format: {0}")]
32    InvalidTunnelId(String),
33
34    #[error("Invalid request ID format: {0}")]
35    InvalidRequestId(String),
36
37    #[error("Invalid connection ID format: {0}")]
38    InvalidConnectionId(String),
39
40    #[error("Path too long: {0} bytes (max: {1})")]
41    PathTooLong(usize, usize),
42
43    #[error("Header value too long: {0} bytes (max: {1})")]
44    HeaderValueTooLong(usize, usize),
45
46    #[error("Invalid header value contains control characters")]
47    InvalidHeaderValue,
48}
49
50/// Validate tunnel ID format
51///
52/// Tunnel IDs must be exactly 12 lowercase alphanumeric characters.
53///
54/// # Examples
55///
56/// ```
57/// use http_tunnel_common::validation::validate_tunnel_id;
58///
59/// assert!(validate_tunnel_id("abc123def456").is_ok());
60/// assert!(validate_tunnel_id("INVALID").is_err());
61/// assert!(validate_tunnel_id("abc123").is_err()); // too short
62/// ```
63pub fn validate_tunnel_id(id: &str) -> Result<(), ValidationError> {
64    if !TUNNEL_ID_REGEX.is_match(id) {
65        return Err(ValidationError::InvalidTunnelId(
66            id.chars().take(50).collect::<String>(), // Limit error message
67        ));
68    }
69    Ok(())
70}
71
72/// Validate request ID format
73///
74/// Request IDs must start with "req_" followed by a UUID.
75///
76/// # Examples
77///
78/// ```
79/// use http_tunnel_common::validation::validate_request_id;
80///
81/// assert!(validate_request_id("req_550e8400-e29b-41d4-a716-446655440000").is_ok());
82/// assert!(validate_request_id("invalid").is_err());
83/// ```
84pub fn validate_request_id(id: &str) -> Result<(), ValidationError> {
85    if !REQUEST_ID_REGEX.is_match(id) {
86        return Err(ValidationError::InvalidRequestId(
87            id.chars().take(50).collect::<String>(), // Limit error message
88        ));
89    }
90    Ok(())
91}
92
93/// Validate connection ID format
94///
95/// Connection IDs are AWS API Gateway WebSocket connection IDs.
96pub fn validate_connection_id(id: &str) -> Result<(), ValidationError> {
97    if !CONNECTION_ID_REGEX.is_match(id) {
98        return Err(ValidationError::InvalidConnectionId(
99            id.chars().take(50).collect::<String>(), // Limit error message
100        ));
101    }
102    Ok(())
103}
104
105/// Validate and sanitize HTTP path
106///
107/// - Removes control characters
108/// - Enforces length limits
109/// - Ensures path starts with /
110pub fn validate_path(path: &str) -> Result<String, ValidationError> {
111    // Check length
112    if path.len() > MAX_PATH_LENGTH {
113        return Err(ValidationError::PathTooLong(path.len(), MAX_PATH_LENGTH));
114    }
115
116    // Remove control characters and ensure valid UTF-8
117    let sanitized: String = path
118        .chars()
119        .filter(|c| !c.is_control() || *c == '\t')
120        .collect();
121
122    // Ensure path starts with /
123    if sanitized.is_empty() {
124        Ok("/".to_string())
125    } else if sanitized.starts_with('/') {
126        Ok(sanitized)
127    } else {
128        Ok(format!("/{}", sanitized))
129    }
130}
131
132/// Sanitize HTTP header value
133///
134/// - Removes dangerous control characters (except tab)
135/// - Enforces length limits
136/// - Returns sanitized value
137pub fn sanitize_header_value(value: &str) -> Result<String, ValidationError> {
138    // Check length
139    if value.len() > MAX_HEADER_VALUE_LENGTH {
140        return Err(ValidationError::HeaderValueTooLong(
141            value.len(),
142            MAX_HEADER_VALUE_LENGTH,
143        ));
144    }
145
146    // Remove control characters except tab (which is allowed in HTTP headers)
147    let sanitized: String = value
148        .chars()
149        .filter(|c| !c.is_control() || *c == '\t')
150        .collect();
151
152    Ok(sanitized)
153}
154
155/// Sanitize header name
156///
157/// Header names must be ASCII and contain no control characters.
158pub fn sanitize_header_name(name: &str) -> Result<String, ValidationError> {
159    // Header names should be ASCII
160    if !name.is_ascii() {
161        return Err(ValidationError::InvalidHeaderValue);
162    }
163
164    // Remove control characters
165    let sanitized: String = name.chars().filter(|c| !c.is_control()).collect();
166
167    if sanitized.is_empty() {
168        return Err(ValidationError::InvalidHeaderValue);
169    }
170
171    Ok(sanitized.to_lowercase())
172}
173
#[cfg(test)]
mod tests {
    //! Unit tests for the validators and sanitizers.
    //!
    //! Negative cases isolate one failure reason each where possible, and the
    //! sanitizer tests pin the exact output rather than only checking that
    //! certain characters are absent.

    use super::*;

    #[test]
    fn test_validate_tunnel_id_valid() {
        assert!(validate_tunnel_id("abc123def456").is_ok());
        assert!(validate_tunnel_id("000000000000").is_ok());
        assert!(validate_tunnel_id("zzz999yyy888").is_ok());
        assert!(validate_tunnel_id("abcdefghijkl").is_ok()); // letters only
    }

    #[test]
    fn test_validate_tunnel_id_invalid() {
        assert!(validate_tunnel_id("ABC123").is_err()); // uppercase and too short
        assert!(validate_tunnel_id("ABC123DEF456").is_err()); // right length, uppercase
        assert!(validate_tunnel_id("abc123").is_err()); // too short
        assert!(validate_tunnel_id("abc123def45").is_err()); // one short of 12
        assert!(validate_tunnel_id("abc123def4567").is_err()); // one over 12
        assert!(validate_tunnel_id("abc123def456extra").is_err()); // too long
        assert!(validate_tunnel_id("abc-123-def").is_err()); // special chars, 11 long
        assert!(validate_tunnel_id("abc-123-def4").is_err()); // right length, special chars
        assert!(validate_tunnel_id("abc123def456\n").is_err()); // trailing newline
        assert!(validate_tunnel_id("").is_err()); // empty
    }

    #[test]
    fn test_validate_request_id_valid() {
        assert!(validate_request_id("req_550e8400-e29b-41d4-a716-446655440000").is_ok());
        assert!(validate_request_id("req_00000000-0000-0000-0000-000000000000").is_ok());
    }

    #[test]
    fn test_validate_request_id_invalid() {
        assert!(validate_request_id("invalid").is_err());
        assert!(validate_request_id("req_12345").is_err());
        assert!(validate_request_id("550e8400-e29b-41d4-a716-446655440000").is_err()); // no prefix
        assert!(validate_request_id("req_550E8400-E29B-41D4-A716-446655440000").is_err()); // uppercase hex
        assert!(validate_request_id("req_550e8400e29b41d4a716446655440000").is_err()); // no hyphens
        assert!(validate_request_id("req_550e8400-e29b-41d4-a716-446655440000\n").is_err()); // trailing newline
        assert!(validate_request_id("").is_err());
    }

    #[test]
    fn test_validate_connection_id() {
        assert!(validate_connection_id("abc123XYZ").is_ok());
        assert!(validate_connection_id("test-conn_id=123").is_ok());
        assert!(validate_connection_id("a".repeat(128).as_str()).is_ok()); // longest allowed
        assert!(validate_connection_id("").is_err());
        assert!(validate_connection_id("a".repeat(129).as_str()).is_err()); // too long
        assert!(validate_connection_id("abc def").is_err()); // space
        assert!(validate_connection_id("abc/def").is_err()); // slash
    }

    #[test]
    fn test_invalid_id_error_is_truncated() {
        // Only the first 50 characters of a rejected ID are kept in the error.
        let oversized = "X".repeat(500);
        let expected = "X".repeat(50);

        match validate_tunnel_id(&oversized) {
            Err(ValidationError::InvalidTunnelId(shown)) => assert_eq!(shown, expected),
            other => panic!("unexpected result: {:?}", other),
        }
        match validate_request_id(&oversized) {
            Err(ValidationError::InvalidRequestId(shown)) => assert_eq!(shown, expected),
            other => panic!("unexpected result: {:?}", other),
        }
        match validate_connection_id(&oversized) {
            Err(ValidationError::InvalidConnectionId(shown)) => assert_eq!(shown, expected),
            other => panic!("unexpected result: {:?}", other),
        }
    }

    #[test]
    fn test_validate_path() {
        assert_eq!(validate_path("/foo/bar").unwrap(), "/foo/bar");
        assert_eq!(validate_path("foo/bar").unwrap(), "/foo/bar");
        assert_eq!(validate_path("").unwrap(), "/");

        // Control characters removed
        assert_eq!(validate_path("/foo\x00/bar\n/baz").unwrap(), "/foo/bar/baz");
        assert_eq!(validate_path("\r\n").unwrap(), "/");

        // Exactly at the limit
        let max_path = "/".to_string() + &"a".repeat(MAX_PATH_LENGTH - 1);
        assert_eq!(validate_path(&max_path).unwrap(), max_path);

        // Too long
        let long_path = "/".to_string() + &"a".repeat(3000);
        assert!(matches!(
            validate_path(&long_path),
            Err(ValidationError::PathTooLong(3001, MAX_PATH_LENGTH))
        ));
    }

    #[test]
    fn test_sanitize_header_value() {
        assert_eq!(
            sanitize_header_value("normal value").unwrap(),
            "normal value"
        );
        assert_eq!(
            sanitize_header_value("value\twith\ttabs").unwrap(),
            "value\twith\ttabs"
        );
        assert_eq!(sanitize_header_value("").unwrap(), "");

        // Control characters removed
        assert_eq!(
            sanitize_header_value("value\x00with\nnull\rand\rcr").unwrap(),
            "valuewithnullandcr"
        );

        // Exactly at the limit
        let max_value = "a".repeat(MAX_HEADER_VALUE_LENGTH);
        assert_eq!(sanitize_header_value(&max_value).unwrap(), max_value);

        // Too long
        let long_value = "a".repeat(10000);
        assert!(matches!(
            sanitize_header_value(&long_value),
            Err(ValidationError::HeaderValueTooLong(
                10000,
                MAX_HEADER_VALUE_LENGTH
            ))
        ));
    }

    #[test]
    fn test_sanitize_header_name() {
        assert_eq!(
            sanitize_header_name("Content-Type").unwrap(),
            "content-type"
        );
        assert_eq!(
            sanitize_header_name("X-Custom-Header").unwrap(),
            "x-custom-header"
        );

        // Control characters removed
        assert_eq!(sanitize_header_name("header\nname").unwrap(), "headername");

        // Nothing left after sanitizing
        assert!(sanitize_header_name("").is_err());
        assert!(sanitize_header_name("\r\n").is_err());

        // Non-ASCII rejected (U+2122 TRADE MARK SIGN, written as an escape so
        // the test does not depend on the file's encoding)
        assert!(sanitize_header_name("header\u{2122}").is_err());
    }
}
273}