jwt_verify/common/
token.rs

1use base64::Engine;
2use jsonwebtoken::{decode_header, Algorithm, DecodingKey};
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use std::collections::HashMap;
6use std::fmt;
7
8use crate::common::error::JwtError;
9
/// Classification of a JWT as seen by the verifier.
///
/// The enum is a plain, copyable tag: it carries no payload and can be
/// compared for equality.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TokenType {
    /// The token is an ID token.
    IdToken,
    /// The token is an access token.
    AccessToken,
    /// A token is present but its type is not recognized.
    Unknown,
    /// No token was supplied at all.
    None,
}
22
23impl fmt::Display for TokenType {
24    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
25        match self {
26            TokenType::IdToken => write!(f, "ID token"),
27            TokenType::AccessToken => write!(f, "Access token"),
28            TokenType::Unknown => write!(f, "Unknown token"),
29            TokenType::None => write!(f, "No token"),
30        }
31    }
32}
33
34/// JWT header
35#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct JwtHeader {
37    /// Algorithm
38    pub alg: String,
39    /// Key ID
40    pub kid: String,
41    /// Token type
42    #[serde(rename = "typ")]
43    pub token_type: Option<String>,
44    /// Additional headers
45    #[serde(flatten)]
46    pub additional_headers: HashMap<String, Value>,
47}
48
/// Stateless namespace for JWT parsing helpers.
///
/// All functionality is exposed through associated functions; the type holds
/// no data.
pub struct TokenParser;
51
52// TokenPayload moved to cognito/token.rs
53
54impl TokenParser {
55    /// Parse JWT header
56    pub fn parse_token_header(token: &str) -> Result<JwtHeader, JwtError> {
57        // Check if token is empty
58        if token.is_empty() {
59            return Err(JwtError::MissingToken);
60        }
61
62        // Validate token format (should have 3 parts separated by dots)
63        if !token.contains('.') || token.matches('.').count() != 2 {
64            return Err(JwtError::ParseError {
65                part: Some("token".to_string()),
66                error: "Invalid token format: expected 3 parts separated by dots".to_string(),
67            });
68        }
69
70        // Decode header
71        let header = decode_header(token).map_err(|e| JwtError::ParseError {
72            part: Some("header".to_string()),
73            error: format!("Failed to decode header: {}", e),
74        })?;
75
76        // Extract kid
77        let kid = header.kid.ok_or_else(|| JwtError::ParseError {
78            part: Some("header".to_string()),
79            error: "Missing 'kid' in token header".to_string(),
80        })?;
81
82        // Validate algorithm (Cognito uses RS256)
83        if header.alg != Algorithm::RS256 {
84            return Err(JwtError::InvalidClaim {
85                claim: "alg".to_string(),
86                reason: format!("Unsupported algorithm: {:?}, expected RS256", header.alg),
87                value: Some(format!("{:?}", header.alg)),
88            });
89        }
90
91        // Extract algorithm
92        let alg = match header.alg {
93            Algorithm::RS256 => "RS256".to_string(),
94            Algorithm::RS384 => "RS384".to_string(),
95            Algorithm::RS512 => "RS512".to_string(),
96            Algorithm::HS256 => "HS256".to_string(),
97            Algorithm::HS384 => "HS384".to_string(),
98            Algorithm::HS512 => "HS512".to_string(),
99            Algorithm::ES256 => "ES256".to_string(),
100            Algorithm::ES384 => "ES384".to_string(),
101            _ => {
102                return Err(JwtError::ParseError {
103                    part: Some("header".to_string()),
104                    error: "Unsupported algorithm".to_string(),
105                })
106            }
107        };
108
109        // Extract token type from header if available
110        let token_type = header.typ.clone();
111
112        // Create header with additional headers
113        let mut additional_headers = HashMap::new();
114
115        // Add any other fields from the header that we're not explicitly handling
116        if let Some(cty) = &header.cty {
117            additional_headers.insert("cty".to_string(), Value::String(cty.clone()));
118        }
119
120        // Create header
121        let jwt_header = JwtHeader {
122            alg,
123            kid,
124            token_type,
125            additional_headers,
126        };
127
128        Ok(jwt_header)
129    }
130
131    /// Parse token claims
132    pub fn parse_token_claims<T: for<'de> Deserialize<'de>>(
133        token: &str,
134        key: &DecodingKey,
135        validation: &jsonwebtoken::Validation,
136    ) -> Result<T, JwtError> {
137        // Check if token is empty
138        if token.is_empty() {
139            return Err(JwtError::MissingToken);
140        }
141
142        // Validate token format (should have 3 parts separated by dots)
143        if !token.contains('.') || token.matches('.').count() != 2 {
144            return Err(JwtError::ParseError {
145                part: Some("token".to_string()),
146                error: "Invalid token format: expected 3 parts separated by dots".to_string(),
147            });
148        }
149
150        // Decode and validate the token
151        let token_data = jsonwebtoken::decode::<T>(token, key, validation).map_err(|e| {
152            // Convert jsonwebtoken errors to our custom error types with more context
153            use jsonwebtoken::errors::ErrorKind;
154            match e.kind() {
155                ErrorKind::ExpiredSignature => {
156                    // Try to extract the expiration time from the token
157                    if let Ok(exp) = Self::extract_claim_from_token::<u64>(token, "exp") {
158                        let now = std::time::SystemTime::now()
159                            .duration_since(std::time::UNIX_EPOCH)
160                            .unwrap_or_default()
161                            .as_secs();
162
163                        JwtError::ExpiredToken {
164                            exp: Some(exp),
165                            current_time: Some(now),
166                        }
167                    } else {
168                        JwtError::ExpiredToken {
169                            exp: None,
170                            current_time: None,
171                        }
172                    }
173                }
174                ErrorKind::InvalidSignature => JwtError::InvalidSignature,
175                ErrorKind::InvalidIssuer => {
176                    // Try to extract the issuer from the token
177                    if let Ok(iss) = Self::extract_claim_from_token::<String>(token, "iss") {
178                        JwtError::InvalidIssuer {
179                            expected: validation
180                                .iss
181                                .as_ref()
182                                .and_then(|iss_set| iss_set.iter().next())
183                                .cloned()
184                                .unwrap_or_default(),
185                            actual: iss,
186                        }
187                    } else {
188                        JwtError::InvalidClaim {
189                            claim: "iss".to_string(),
190                            reason: "Invalid issuer".to_string(),
191                            value: None,
192                        }
193                    }
194                }
195                ErrorKind::InvalidAudience => JwtError::InvalidClaim {
196                    claim: "aud".to_string(),
197                    reason: "Invalid audience".to_string(),
198                    value: None,
199                },
200                ErrorKind::InvalidSubject => JwtError::InvalidClaim {
201                    claim: "sub".to_string(),
202                    reason: "Invalid subject".to_string(),
203                    value: None,
204                },
205                ErrorKind::ImmatureSignature => {
206                    // Try to extract the not before time from the token
207                    if let Ok(nbf) = Self::extract_claim_from_token::<u64>(token, "nbf") {
208                        let now = std::time::SystemTime::now()
209                            .duration_since(std::time::UNIX_EPOCH)
210                            .unwrap_or_default()
211                            .as_secs();
212
213                        JwtError::TokenNotYetValid {
214                            nbf: Some(nbf),
215                            current_time: Some(now),
216                        }
217                    } else {
218                        JwtError::TokenNotYetValid {
219                            nbf: None,
220                            current_time: None,
221                        }
222                    }
223                }
224                ErrorKind::InvalidAlgorithm => JwtError::InvalidClaim {
225                    claim: "alg".to_string(),
226                    reason: "Invalid algorithm".to_string(),
227                    value: None,
228                },
229                _ => JwtError::ParseError {
230                    part: Some("claims".to_string()),
231                    error: format!("Failed to decode token: {}", e),
232                },
233            }
234        })?;
235
236        Ok(token_data.claims)
237    }
238
239    // parse_token_payload moved to cognito/token.rs
240
241    /// Extract a specific claim from a token without validating the signature
242    pub fn extract_claim_from_token<T: for<'de> Deserialize<'de>>(
243        token: &str,
244        claim_name: &str,
245    ) -> Result<T, JwtError> {
246        // Split the token
247        let parts: Vec<&str> = token.split('.').collect();
248        if parts.len() != 3 {
249            return Err(JwtError::ParseError {
250                part: Some("token".to_string()),
251                error: "Invalid token format: expected 3 parts separated by dots".to_string(),
252            });
253        }
254
255        // Decode the payload (second part)
256        let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
257            .decode(parts[1])
258            .map_err(|e| JwtError::ParseError {
259                part: Some("payload".to_string()),
260                error: format!("Invalid base64 in payload: {}", e),
261            })?;
262
263        // Parse the payload
264        let payload: serde_json::Value =
265            serde_json::from_slice(&payload).map_err(|e| JwtError::ParseError {
266                part: Some("payload".to_string()),
267                error: format!("Invalid JSON in payload: {}", e),
268            })?;
269
270        // Extract the claim
271        let claim = payload
272            .get(claim_name)
273            .ok_or_else(|| JwtError::ParseError {
274                part: Some("payload".to_string()),
275                error: format!("Claim '{}' not found in payload", claim_name),
276            })?;
277
278        // Deserialize the claim
279        serde_json::from_value(claim.clone()).map_err(|e| JwtError::ParseError {
280            part: Some("payload".to_string()),
281            error: format!("Failed to deserialize claim '{}': {}", claim_name, e),
282        })
283    }
284
285    /// Extract the issuer from a token without validating the signature
286    ///
287    /// This is a convenience method that extracts the issuer claim from a token.
288    ///
289    /// # Parameters
290    ///
291    /// * `token` - The JWT token
292    ///
293    /// # Returns
294    ///
295    /// Returns a `Result` containing the issuer if successful, or a `JwtError`
296    /// if the issuer could not be extracted.
297    pub fn extract_issuer(token: &str) -> Result<String, JwtError> {
298        Self::extract_claim_from_token(token, "iss")
299    }
300}