Skip to main content

auth_framework/server/core/
common_validation.rs

//! Common Validation Utilities
//!
//! This module provides shared validation functions to eliminate
//! duplication across server modules.

6use crate::errors::{AuthError, Result};
7use std::collections::HashMap;
8use std::time::{SystemTime, UNIX_EPOCH};
9
10/// Common JWT validation utilities
11pub mod jwt {
12    use super::*;
13    use jsonwebtoken::decode_header;
14
15    /// Validate JWT structure and format
16    pub fn validate_jwt_format(token: &str) -> Result<()> {
17        if token.is_empty() {
18            return Err(AuthError::validation("JWT token is empty"));
19        }
20
21        let parts: Vec<&str> = token.split('.').collect();
22        if parts.len() != 3 {
23            return Err(AuthError::validation(
24                "Invalid JWT format: must have 3 parts",
25            ));
26        }
27
28        // Validate header can be decoded
29        decode_header(token)
30            .map_err(|e| AuthError::validation(format!("Invalid JWT header: {}", e)))?;
31
32        Ok(())
33    }
34
35    /// Extract claims without signature validation (for inspection ONLY)
36    ///
37    /// # Security Warning
38    /// This function does NOT validate the JWT signature, making it vulnerable to:
39    /// - Token forgery
40    /// - Data tampering
41    /// - Man-in-the-middle attacks
42    ///
43    /// Only use for:
44    /// - Token inspection/debugging
45    /// - Extracting metadata before validation
46    /// - Non-security-critical operations
47    ///
48    /// Never use for authentication or authorization decisions!
49    pub fn extract_claims_unsafe(token: &str) -> Result<serde_json::Value> {
50        validate_jwt_format(token)?;
51
52        let parts: Vec<&str> = token.split('.').collect();
53        let payload = parts[1];
54
55        use base64::Engine as _;
56        use base64::engine::general_purpose::URL_SAFE_NO_PAD;
57
58        let decoded = URL_SAFE_NO_PAD
59            .decode(payload)
60            .map_err(|e| AuthError::validation(format!("Invalid JWT payload encoding: {}", e)))?;
61
62        let claims: serde_json::Value = serde_json::from_slice(&decoded)
63            .map_err(|e| AuthError::validation(format!("Invalid JWT payload JSON: {}", e)))?;
64
65        Ok(claims)
66    }
67
68    /// Validate JWT timestamp claims (exp, iat, nbf)
69    pub fn validate_time_claims(claims: &serde_json::Value) -> Result<()> {
70        let now = SystemTime::now()
71            .duration_since(UNIX_EPOCH)
72            .unwrap_or_default()
73            .as_secs() as i64;
74
75        // Check expiration
76        if let Some(exp) = claims.get("exp").and_then(|v| v.as_i64())
77            && now >= exp
78        {
79            return Err(AuthError::validation("Token has expired"));
80        }
81
82        // Check not before
83        if let Some(nbf) = claims.get("nbf").and_then(|v| v.as_i64())
84            && now < nbf
85        {
86            return Err(AuthError::validation("Token not yet valid (nbf)"));
87        }
88
89        // Check issued at (reasonable bounds)
90        if let Some(iat) = claims.get("iat").and_then(|v| v.as_i64()) {
91            let max_age = 24 * 60 * 60; // 24 hours
92            if now - iat > max_age {
93                return Err(AuthError::validation("Token too old"));
94            }
95        }
96
97        Ok(())
98    }
99
100    /// Validate required JWT claims
101    pub fn validate_required_claims(claims: &serde_json::Value, required: &[&str]) -> Result<()> {
102        for claim in required {
103            if claims.get(claim).is_none() {
104                return Err(AuthError::validation(format!(
105                    "Missing required claim: {}",
106                    claim
107                )));
108            }
109        }
110        Ok(())
111    }
112}
113
114/// Common token validation utilities
115pub mod token {
116    use super::*;
117
118    /// Token type validation
119    pub fn validate_token_type(token_type: &str, allowed_types: &[&str]) -> Result<()> {
120        if !allowed_types.contains(&token_type) {
121            return Err(AuthError::validation(format!(
122                "Unsupported token type: {}",
123                token_type
124            )));
125        }
126        Ok(())
127    }
128
129    /// Validate token format (basic structure)
130    pub fn validate_token_format(token: &str, token_type: &str) -> Result<()> {
131        if token.is_empty() {
132            return Err(AuthError::validation("Token is empty"));
133        }
134
135        match token_type {
136            "urn:ietf:params:oauth:token-type:jwt" => jwt::validate_jwt_format(token),
137            "urn:ietf:params:oauth:token-type:access_token" => {
138                // Bearer token validation
139                if token.len() < 10 {
140                    return Err(AuthError::validation("Access token too short"));
141                }
142                Ok(())
143            }
144            "urn:ietf:params:oauth:token-type:refresh_token" => {
145                // Refresh token validation
146                if token.len() < 20 {
147                    return Err(AuthError::validation("Refresh token too short"));
148                }
149                Ok(())
150            }
151            _ => {
152                // Reject unrecognized token types rather than silently accepting them.
153                Err(AuthError::validation(format!(
154                    "Unsupported token type: {}",
155                    token_type
156                )))
157            }
158        }
159    }
160
161    /// Validate scope format
162    pub fn validate_scope(scope: &str) -> Result<Vec<String>> {
163        if scope.is_empty() {
164            return Ok(vec![]);
165        }
166
167        let scopes: Vec<String> = scope.split_whitespace().map(|s| s.to_string()).collect();
168
169        // Validate each scope
170        for scope in &scopes {
171            if scope.is_empty() {
172                return Err(AuthError::validation("Empty scope value"));
173            }
174
175            // Basic scope format validation
176            if !scope.chars().all(|c| {
177                c.is_alphanumeric() || c == ':' || c == '/' || c == '.' || c == '-' || c == '_'
178            }) {
179                return Err(AuthError::validation(format!(
180                    "Invalid scope format: {}",
181                    scope
182                )));
183            }
184        }
185
186        Ok(scopes)
187    }
188}
189
190/// Common client validation utilities
191pub mod client {
192    use super::*;
193
194    /// Validate client ID format
195    pub fn validate_client_id(client_id: &str) -> Result<()> {
196        if client_id.is_empty() {
197            return Err(AuthError::validation("Client ID is empty"));
198        }
199
200        if client_id.len() < 3 {
201            return Err(AuthError::validation("Client ID too short"));
202        }
203
204        if client_id.len() > 255 {
205            return Err(AuthError::validation("Client ID too long"));
206        }
207
208        // Validate character set
209        if !client_id
210            .chars()
211            .all(|c| c.is_alphanumeric() || c == '-' || c == '_' || c == '.')
212        {
213            return Err(AuthError::validation(
214                "Client ID contains invalid characters",
215            ));
216        }
217
218        Ok(())
219    }
220
221    /// Validate redirect URI
222    pub fn validate_redirect_uri(uri: &str) -> Result<()> {
223        if uri.is_empty() {
224            return Err(AuthError::validation("Redirect URI is empty"));
225        }
226
227        // Must be absolute URI
228        if !uri.starts_with("http://")
229            && !uri.starts_with("https://")
230            && !uri.starts_with("custom://")
231        {
232            return Err(AuthError::validation("Redirect URI must be absolute"));
233        }
234
235        // No fragments allowed
236        if uri.contains('#') {
237            return Err(AuthError::validation(
238                "Redirect URI cannot contain fragments",
239            ));
240        }
241
242        Ok(())
243    }
244
245    /// Validate grant type
246    pub fn validate_grant_type(grant_type: &str, allowed_grants: &[&str]) -> Result<()> {
247        if !allowed_grants.contains(&grant_type) {
248            return Err(AuthError::validation(format!(
249                "Unsupported grant type: {}",
250                grant_type
251            )));
252        }
253        Ok(())
254    }
255}
256
257/// Common request validation utilities
258pub mod request {
259    use super::*;
260
261    /// Validate required parameters
262    pub fn validate_required_params(
263        params: &HashMap<String, String>,
264        required: &[&str],
265    ) -> Result<()> {
266        for param in required {
267            if !params.contains_key(*param) || params[*param].trim().is_empty() {
268                return Err(AuthError::validation(format!(
269                    "Missing parameter: {}",
270                    param
271                )));
272            }
273        }
274        Ok(())
275    }
276
277    /// Validate parameter format
278    pub fn validate_param_format(value: &str, param_name: &str, pattern: &str) -> Result<()> {
279        // Basic validation without regex for now
280        if value.is_empty() {
281            return Err(AuthError::validation(format!(
282                "Parameter {} cannot be empty",
283                param_name
284            )));
285        }
286
287        // Basic pattern checks
288        match pattern {
289            "alphanum" => {
290                if !value.chars().all(|c| c.is_alphanumeric()) {
291                    return Err(AuthError::validation(format!(
292                        "Parameter {} must be alphanumeric",
293                        param_name
294                    )));
295                }
296            }
297            _ => {
298                // For now, just check it's not empty
299                if value.trim().is_empty() {
300                    return Err(AuthError::validation(format!(
301                        "Parameter {} has invalid format",
302                        param_name
303                    )));
304                }
305            }
306        }
307
308        Ok(())
309    }
310
311    /// Validate code challenge method
312    pub fn validate_code_challenge_method(method: &str) -> Result<()> {
313        match method {
314            "plain" | "S256" => Ok(()),
315            _ => Err(AuthError::validation("Invalid code challenge method")),
316        }
317    }
318
319    /// Validate response type
320    pub fn validate_response_type(response_type: &str, allowed_types: &[&str]) -> Result<()> {
321        let types: Vec<&str> = response_type.split_whitespace().collect();
322
323        for response_type in &types {
324            if !allowed_types.contains(response_type) {
325                return Err(AuthError::validation(format!(
326                    "Unsupported response type: {}",
327                    response_type
328                )));
329            }
330        }
331
332        Ok(())
333    }
334}
335
336/// Common URL validation utilities
337pub mod url {
338    use super::*;
339
340    /// Validate URL format and accessibility
341    pub fn validate_url_format(url: &str) -> Result<()> {
342        if url.is_empty() {
343            return Err(AuthError::validation("URL is empty"));
344        }
345
346        if !url.starts_with("http://") && !url.starts_with("https://") {
347            return Err(AuthError::validation("URL must use HTTP or HTTPS scheme"));
348        }
349
350        // Basic URL parsing validation - simplified without url crate for now
351        if !url.contains("://") {
352            return Err(AuthError::validation("Invalid URL format"));
353        }
354
355        Ok(())
356    }
357
358    /// Validate HTTPS requirement
359    pub fn validate_https_required(url: &str) -> Result<()> {
360        validate_url_format(url)?;
361
362        if !url.starts_with("https://") {
363            return Err(AuthError::validation("HTTPS is required"));
364        }
365
366        Ok(())
367    }
368}
369
370/// Collects and aggregates validation errors from multiple validation operations.
371///
372/// This function takes a vector of validation results and combines any errors
373/// into a single error message. If all validations pass, returns Ok(()).
374/// If any validations fail, returns an error containing all error messages.
375///
376/// # Arguments
377///
378/// * `validations` - A vector of validation results to aggregate
379///
380/// # Returns
381///
382/// * `Ok(())` if all validations passed
383/// * `Err(AuthError)` containing aggregated error messages if any validations failed
384///
385/// # Example
386///
387/// ```rust,no_run
388/// use auth_framework::server::core::common_validation::collect_validation_errors;
389/// use auth_framework::errors::Result;
390///
391/// # fn validate_client_id(_: &str) -> Result<()> { Ok(()) }
392/// # fn validate_scope(_: &str) -> Result<()> { Ok(()) }
393/// # fn validate_redirect_uri(_: &str) -> Result<()> { Ok(()) }
394/// let validations = vec![
395///     validate_client_id("valid_client"),
396///     validate_scope("read write"),
397///     validate_redirect_uri("https://example.com/callback"),
398/// ];
399///
400/// let result = collect_validation_errors(validations);
401/// match result {
402///     Ok(()) => println!("All validations passed"),
403///     Err(e) => println!("Validation errors: {}", e),
404/// }
405/// ```
406pub fn collect_validation_errors(validations: Vec<Result<()>>) -> Result<()> {
407    let errors: Vec<String> = validations
408        .into_iter()
409        .filter_map(|result| result.err())
410        .map(|e| format!("{}", e))
411        .collect();
412
413    if errors.is_empty() {
414        Ok(())
415    } else {
416        Err(AuthError::validation(errors.join("; ")))
417    }
418}
419
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Current time as whole seconds since the Unix epoch.
    fn unix_now() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    // ── JWT validation ──────────────────────────────────────────────────

    #[test]
    fn test_validate_jwt_format_empty() {
        let outcome = jwt::validate_jwt_format("");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_validate_jwt_format_wrong_parts() {
        for malformed in ["one.two", "a.b.c.d"] {
            assert!(jwt::validate_jwt_format(malformed).is_err());
        }
    }

    #[test]
    fn test_validate_required_claims_missing() {
        let claims = json!({ "sub": "user1" });
        let with_sub = jwt::validate_required_claims(&claims, &["sub"]);
        let with_aud = jwt::validate_required_claims(&claims, &["aud"]);
        assert!(with_sub.is_ok());
        assert!(with_aud.is_err());
    }

    #[test]
    fn test_validate_time_claims_expired() {
        let long_expired = json!({ "exp": 1_000_000 });
        assert!(jwt::validate_time_claims(&long_expired).is_err());
    }

    #[test]
    fn test_validate_time_claims_future_nbf() {
        let not_before = (unix_now() + 999_999) as i64;
        let claims = json!({ "nbf": not_before });
        assert!(jwt::validate_time_claims(&claims).is_err());
    }

    #[test]
    fn test_validate_time_claims_valid_no_claims() {
        assert!(jwt::validate_time_claims(&json!({})).is_ok());
    }

    // ── Token validation ────────────────────────────────────────────────

    #[test]
    fn test_validate_token_type_success() {
        let allowed = ["bearer", "dpop"];
        assert!(token::validate_token_type("bearer", &allowed).is_ok());
    }

    #[test]
    fn test_validate_token_type_unsupported() {
        let allowed = ["bearer"];
        assert!(token::validate_token_type("mac", &allowed).is_err());
    }

    #[test]
    fn test_validate_token_format_empty() {
        let outcome = token::validate_token_format("", "anything");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_validate_token_format_access_token_too_short() {
        const ACCESS: &str = "urn:ietf:params:oauth:token-type:access_token";
        assert!(token::validate_token_format("short", ACCESS).is_err());
    }

    #[test]
    fn test_validate_token_format_refresh_token_too_short() {
        const REFRESH: &str = "urn:ietf:params:oauth:token-type:refresh_token";
        assert!(token::validate_token_format("shorttoken", REFRESH).is_err());
    }

    #[test]
    fn test_validate_scope_empty() {
        let parsed = token::validate_scope("").unwrap();
        assert_eq!(parsed.len(), 0);
    }

    #[test]
    fn test_validate_scope_valid() {
        let parsed = token::validate_scope("read write openid").unwrap();
        assert_eq!(parsed, ["read", "write", "openid"]);
    }

    #[test]
    fn test_validate_scope_invalid_chars() {
        let outcome = token::validate_scope("read <script>");
        assert!(outcome.is_err());
    }

    // ── Client validation ───────────────────────────────────────────────

    #[test]
    fn test_validate_client_id_valid() {
        assert!(client::validate_client_id("my-client.app_01").is_ok());
    }

    #[test]
    fn test_validate_client_id_empty() {
        assert!(client::validate_client_id("").is_err());
    }

    #[test]
    fn test_validate_client_id_too_short() {
        assert!(client::validate_client_id("ab").is_err());
    }

    #[test]
    fn test_validate_client_id_too_long() {
        let oversized = "a".repeat(256);
        assert!(client::validate_client_id(&oversized).is_err());
    }

    #[test]
    fn test_validate_client_id_invalid_chars() {
        assert!(client::validate_client_id("my client!").is_err());
    }

    #[test]
    fn test_validate_redirect_uri_valid() {
        let accepted = [
            "https://example.com/callback",
            "http://localhost:8080/cb",
            "custom://app/callback",
        ];
        for uri in accepted {
            assert!(client::validate_redirect_uri(uri).is_ok());
        }
    }

    #[test]
    fn test_validate_redirect_uri_empty() {
        assert!(client::validate_redirect_uri("").is_err());
    }

    #[test]
    fn test_validate_redirect_uri_not_absolute() {
        assert!(client::validate_redirect_uri("/callback").is_err());
    }

    #[test]
    fn test_validate_redirect_uri_with_fragment() {
        let with_fragment = "https://example.com/cb#section";
        assert!(client::validate_redirect_uri(with_fragment).is_err());
    }

    #[test]
    fn test_validate_grant_type_success() {
        let allowed = ["authorization_code", "refresh_token"];
        assert!(client::validate_grant_type("authorization_code", &allowed).is_ok());
    }

    #[test]
    fn test_validate_grant_type_unsupported() {
        let allowed = ["authorization_code"];
        assert!(client::validate_grant_type("implicit", &allowed).is_err());
    }

    // ── Request validation ──────────────────────────────────────────────

    #[test]
    fn test_validate_required_params() {
        let params = HashMap::from([("code".to_string(), "abc123".to_string())]);
        let only_code = request::validate_required_params(&params, &["code"]);
        let code_and_state = request::validate_required_params(&params, &["code", "state"]);
        assert!(only_code.is_ok());
        assert!(code_and_state.is_err());
    }

    #[test]
    fn test_validate_required_params_empty_value() {
        let params = HashMap::from([("code".to_string(), "  ".to_string())]);
        assert!(request::validate_required_params(&params, &["code"]).is_err());
    }

    #[test]
    fn test_validate_param_format_alphanum() {
        let clean = request::validate_param_format("abc123", "nonce", "alphanum");
        let dashed = request::validate_param_format("abc-123", "nonce", "alphanum");
        assert!(clean.is_ok());
        assert!(dashed.is_err());
    }

    #[test]
    fn test_validate_param_format_empty() {
        assert!(request::validate_param_format("", "nonce", "alphanum").is_err());
    }

    #[test]
    fn test_validate_code_challenge_method() {
        for supported in ["S256", "plain"] {
            assert!(request::validate_code_challenge_method(supported).is_ok());
        }
        assert!(request::validate_code_challenge_method("S512").is_err());
    }

    #[test]
    fn test_validate_response_type() {
        let allows_code = request::validate_response_type("code", &["code", "token"]);
        let rejects_id_token = request::validate_response_type("id_token", &["code"]);
        assert!(allows_code.is_ok());
        assert!(rejects_id_token.is_err());
    }

    // ── URL validation ──────────────────────────────────────────────────

    #[test]
    fn test_validate_url_format_valid() {
        for candidate in ["https://example.com", "http://localhost:8080"] {
            assert!(url::validate_url_format(candidate).is_ok());
        }
    }

    #[test]
    fn test_validate_url_format_empty() {
        assert!(url::validate_url_format("").is_err());
    }

    #[test]
    fn test_validate_url_format_no_scheme() {
        assert!(url::validate_url_format("example.com").is_err());
    }

    #[test]
    fn test_validate_https_required() {
        let secure = url::validate_https_required("https://example.com");
        let plain = url::validate_https_required("http://example.com");
        assert!(secure.is_ok());
        assert!(plain.is_err());
    }

    // ── collect_validation_errors ────────────────────────────────────────

    #[test]
    fn test_collect_validation_errors_all_ok() {
        let outcome = collect_validation_errors(vec![Ok(()), Ok(())]);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_collect_validation_errors_some_fail() {
        let validations = vec![
            Ok(()),
            Err(AuthError::validation("err1")),
            Err(AuthError::validation("err2")),
        ];
        let message = collect_validation_errors(validations)
            .unwrap_err()
            .to_string();
        for needle in ["err1", "err2"] {
            assert!(message.contains(needle));
        }
    }
}