1use base64::Engine;
2use jsonwebtoken::{decode_header, Algorithm, DecodingKey};
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use std::collections::HashMap;
6use std::fmt;
7
8use crate::common::error::JwtError;
9
/// Kind of token a value represents.
///
/// Human-readable labels (via `Display`) are `"ID token"`, `"Access token"`,
/// `"Unknown token"` and `"No token"`, in variant order.
///
/// NOTE(review): the `None` variant shares its name with `Option::None`; avoid
/// glob-importing these variants (`use TokenType::*`) to prevent shadowing.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TokenType {
    /// An ID token (displayed as "ID token").
    IdToken,
    /// An access token (displayed as "Access token").
    AccessToken,
    /// A token whose kind is not one of the above (displayed as "Unknown token").
    Unknown,
    /// Absence of a token (displayed as "No token").
    None,
}
22
23impl fmt::Display for TokenType {
24 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
25 match self {
26 TokenType::IdToken => write!(f, "ID token"),
27 TokenType::AccessToken => write!(f, "Access token"),
28 TokenType::Unknown => write!(f, "Unknown token"),
29 TokenType::None => write!(f, "No token"),
30 }
31 }
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct JwtHeader {
37 pub alg: String,
39 pub kid: String,
41 #[serde(rename = "typ")]
43 pub token_type: Option<String>,
44 #[serde(flatten)]
46 pub additional_headers: HashMap<String, Value>,
47}
48
/// Stateless namespace for the JWT parsing helpers below.
///
/// All of its functions are associated functions, so no instance is needed. The
/// derives let the type be debug-printed, copied, and default-constructed like
/// any other public value type.
#[derive(Debug, Clone, Copy, Default)]
pub struct TokenParser;
51
52impl TokenParser {
55 pub fn parse_token_header(token: &str) -> Result<JwtHeader, JwtError> {
57 if token.is_empty() {
59 return Err(JwtError::MissingToken);
60 }
61
62 if !token.contains('.') || token.matches('.').count() != 2 {
64 return Err(JwtError::ParseError {
65 part: Some("token".to_string()),
66 error: "Invalid token format: expected 3 parts separated by dots".to_string(),
67 });
68 }
69
70 let header = decode_header(token).map_err(|e| JwtError::ParseError {
72 part: Some("header".to_string()),
73 error: format!("Failed to decode header: {}", e),
74 })?;
75
76 let kid = header.kid.ok_or_else(|| JwtError::ParseError {
78 part: Some("header".to_string()),
79 error: "Missing 'kid' in token header".to_string(),
80 })?;
81
82 if header.alg != Algorithm::RS256 {
84 return Err(JwtError::InvalidClaim {
85 claim: "alg".to_string(),
86 reason: format!("Unsupported algorithm: {:?}, expected RS256", header.alg),
87 value: Some(format!("{:?}", header.alg)),
88 });
89 }
90
91 let alg = match header.alg {
93 Algorithm::RS256 => "RS256".to_string(),
94 Algorithm::RS384 => "RS384".to_string(),
95 Algorithm::RS512 => "RS512".to_string(),
96 Algorithm::HS256 => "HS256".to_string(),
97 Algorithm::HS384 => "HS384".to_string(),
98 Algorithm::HS512 => "HS512".to_string(),
99 Algorithm::ES256 => "ES256".to_string(),
100 Algorithm::ES384 => "ES384".to_string(),
101 _ => {
102 return Err(JwtError::ParseError {
103 part: Some("header".to_string()),
104 error: "Unsupported algorithm".to_string(),
105 })
106 }
107 };
108
109 let token_type = header.typ.clone();
111
112 let mut additional_headers = HashMap::new();
114
115 if let Some(cty) = &header.cty {
117 additional_headers.insert("cty".to_string(), Value::String(cty.clone()));
118 }
119
120 let jwt_header = JwtHeader {
122 alg,
123 kid,
124 token_type,
125 additional_headers,
126 };
127
128 Ok(jwt_header)
129 }
130
131 pub fn parse_token_claims<T: for<'de> Deserialize<'de>>(
133 token: &str,
134 key: &DecodingKey,
135 validation: &jsonwebtoken::Validation,
136 ) -> Result<T, JwtError> {
137 if token.is_empty() {
139 return Err(JwtError::MissingToken);
140 }
141
142 if !token.contains('.') || token.matches('.').count() != 2 {
144 return Err(JwtError::ParseError {
145 part: Some("token".to_string()),
146 error: "Invalid token format: expected 3 parts separated by dots".to_string(),
147 });
148 }
149
150 let token_data = jsonwebtoken::decode::<T>(token, key, validation).map_err(|e| {
152 use jsonwebtoken::errors::ErrorKind;
154 match e.kind() {
155 ErrorKind::ExpiredSignature => {
156 if let Ok(exp) = Self::extract_claim_from_token::<u64>(token, "exp") {
158 let now = std::time::SystemTime::now()
159 .duration_since(std::time::UNIX_EPOCH)
160 .unwrap_or_default()
161 .as_secs();
162
163 JwtError::ExpiredToken {
164 exp: Some(exp),
165 current_time: Some(now),
166 }
167 } else {
168 JwtError::ExpiredToken {
169 exp: None,
170 current_time: None,
171 }
172 }
173 }
174 ErrorKind::InvalidSignature => JwtError::InvalidSignature,
175 ErrorKind::InvalidIssuer => {
176 if let Ok(iss) = Self::extract_claim_from_token::<String>(token, "iss") {
178 JwtError::InvalidIssuer {
179 expected: validation
180 .iss
181 .as_ref()
182 .and_then(|iss_set| iss_set.iter().next())
183 .cloned()
184 .unwrap_or_default(),
185 actual: iss,
186 }
187 } else {
188 JwtError::InvalidClaim {
189 claim: "iss".to_string(),
190 reason: "Invalid issuer".to_string(),
191 value: None,
192 }
193 }
194 }
195 ErrorKind::InvalidAudience => JwtError::InvalidClaim {
196 claim: "aud".to_string(),
197 reason: "Invalid audience".to_string(),
198 value: None,
199 },
200 ErrorKind::InvalidSubject => JwtError::InvalidClaim {
201 claim: "sub".to_string(),
202 reason: "Invalid subject".to_string(),
203 value: None,
204 },
205 ErrorKind::ImmatureSignature => {
206 if let Ok(nbf) = Self::extract_claim_from_token::<u64>(token, "nbf") {
208 let now = std::time::SystemTime::now()
209 .duration_since(std::time::UNIX_EPOCH)
210 .unwrap_or_default()
211 .as_secs();
212
213 JwtError::TokenNotYetValid {
214 nbf: Some(nbf),
215 current_time: Some(now),
216 }
217 } else {
218 JwtError::TokenNotYetValid {
219 nbf: None,
220 current_time: None,
221 }
222 }
223 }
224 ErrorKind::InvalidAlgorithm => JwtError::InvalidClaim {
225 claim: "alg".to_string(),
226 reason: "Invalid algorithm".to_string(),
227 value: None,
228 },
229 _ => JwtError::ParseError {
230 part: Some("claims".to_string()),
231 error: format!("Failed to decode token: {}", e),
232 },
233 }
234 })?;
235
236 Ok(token_data.claims)
237 }
238
239 pub fn extract_claim_from_token<T: for<'de> Deserialize<'de>>(
243 token: &str,
244 claim_name: &str,
245 ) -> Result<T, JwtError> {
246 let parts: Vec<&str> = token.split('.').collect();
248 if parts.len() != 3 {
249 return Err(JwtError::ParseError {
250 part: Some("token".to_string()),
251 error: "Invalid token format: expected 3 parts separated by dots".to_string(),
252 });
253 }
254
255 let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
257 .decode(parts[1])
258 .map_err(|e| JwtError::ParseError {
259 part: Some("payload".to_string()),
260 error: format!("Invalid base64 in payload: {}", e),
261 })?;
262
263 let payload: serde_json::Value =
265 serde_json::from_slice(&payload).map_err(|e| JwtError::ParseError {
266 part: Some("payload".to_string()),
267 error: format!("Invalid JSON in payload: {}", e),
268 })?;
269
270 let claim = payload
272 .get(claim_name)
273 .ok_or_else(|| JwtError::ParseError {
274 part: Some("payload".to_string()),
275 error: format!("Claim '{}' not found in payload", claim_name),
276 })?;
277
278 serde_json::from_value(claim.clone()).map_err(|e| JwtError::ParseError {
280 part: Some("payload".to_string()),
281 error: format!("Failed to deserialize claim '{}': {}", claim_name, e),
282 })
283 }
284
285 pub fn extract_issuer(token: &str) -> Result<String, JwtError> {
298 Self::extract_claim_from_token(token, "iss")
299 }
300}