wae_authentication/jwt/
codec.rs1use base64::{Engine as _, engine::general_purpose};
4use hmac::{Hmac, Mac as _};
5use serde::{Deserialize, Serialize};
6use sha2::{Sha256, Sha384, Sha512};
7use std::fmt;
8use wae_types::{WaeError, WaeErrorKind};
9
10#[derive(Debug)]
12pub enum JwtCodecError {
13 InvalidFormat,
15
16 Base64Error(base64::DecodeError),
18
19 JsonError(serde_json::Error),
21
22 InvalidSignature,
24
25 InvalidAlgorithm,
27
28 KeyError,
30}
31
32impl fmt::Display for JwtCodecError {
33 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34 match self {
35 JwtCodecError::InvalidFormat => write!(f, "invalid token format"),
36 JwtCodecError::Base64Error(e) => write!(f, "base64 decode error: {}", e),
37 JwtCodecError::JsonError(e) => write!(f, "json error: {}", e),
38 JwtCodecError::InvalidSignature => write!(f, "invalid signature"),
39 JwtCodecError::InvalidAlgorithm => write!(f, "invalid algorithm"),
40 JwtCodecError::KeyError => write!(f, "key error"),
41 }
42 }
43}
44
45impl std::error::Error for JwtCodecError {}
46
47impl From<base64::DecodeError> for JwtCodecError {
48 fn from(err: base64::DecodeError) -> Self {
49 JwtCodecError::Base64Error(err)
50 }
51}
52
53impl From<serde_json::Error> for JwtCodecError {
54 fn from(err: serde_json::Error) -> Self {
55 JwtCodecError::JsonError(err)
56 }
57}
58
59impl From<JwtCodecError> for WaeError {
60 fn from(err: JwtCodecError) -> Self {
61 match err {
62 JwtCodecError::InvalidFormat => WaeError::invalid_token("malformed token"),
63 JwtCodecError::Base64Error(_) => WaeError::invalid_token("invalid base64"),
64 JwtCodecError::JsonError(_) => WaeError::invalid_token("invalid json"),
65 JwtCodecError::InvalidSignature => WaeError::invalid_signature(),
66 JwtCodecError::InvalidAlgorithm => WaeError::new(WaeErrorKind::InvalidAlgorithm),
67 JwtCodecError::KeyError => WaeError::new(WaeErrorKind::KeyError),
68 }
69 }
70}
71
72pub type JwtCodecResult<T> = Result<T, JwtCodecError>;
74
75#[derive(Debug, Clone, Serialize, Deserialize)]
77pub struct JwtHeader {
78 pub alg: String,
80 pub typ: String,
82}
83
84impl JwtHeader {
85 pub fn new(alg: impl Into<String>) -> Self {
87 Self { alg: alg.into(), typ: "JWT".to_string() }
88 }
89}
90
91pub fn base64url_encode(input: &[u8]) -> String {
93 general_purpose::URL_SAFE_NO_PAD.encode(input)
94}
95
96pub fn base64url_decode(input: &str) -> JwtCodecResult<Vec<u8>> {
98 Ok(general_purpose::URL_SAFE_NO_PAD.decode(input)?)
99}
100
101pub fn hmac_sign(algorithm: &str, secret: &[u8], data: &[u8]) -> JwtCodecResult<Vec<u8>> {
103 match algorithm {
104 "HS256" => {
105 let mut mac = Hmac::<Sha256>::new_from_slice(secret).map_err(|_| JwtCodecError::KeyError)?;
106 mac.update(data);
107 Ok(mac.finalize().into_bytes().to_vec())
108 }
109 "HS384" => {
110 let mut mac = Hmac::<Sha384>::new_from_slice(secret).map_err(|_| JwtCodecError::KeyError)?;
111 mac.update(data);
112 Ok(mac.finalize().into_bytes().to_vec())
113 }
114 "HS512" => {
115 let mut mac = Hmac::<Sha512>::new_from_slice(secret).map_err(|_| JwtCodecError::KeyError)?;
116 mac.update(data);
117 Ok(mac.finalize().into_bytes().to_vec())
118 }
119 _ => Err(JwtCodecError::InvalidAlgorithm),
120 }
121}
122
123pub fn hmac_verify(algorithm: &str, secret: &[u8], data: &[u8], signature: &[u8]) -> JwtCodecResult<bool> {
125 match algorithm {
126 "HS256" => {
127 let mut mac = Hmac::<Sha256>::new_from_slice(secret).map_err(|_| JwtCodecError::KeyError)?;
128 mac.update(data);
129 mac.verify_slice(signature).map_err(|_| JwtCodecError::InvalidSignature)?;
130 Ok(true)
131 }
132 "HS384" => {
133 let mut mac = Hmac::<Sha384>::new_from_slice(secret).map_err(|_| JwtCodecError::KeyError)?;
134 mac.update(data);
135 mac.verify_slice(signature).map_err(|_| JwtCodecError::InvalidSignature)?;
136 Ok(true)
137 }
138 "HS512" => {
139 let mut mac = Hmac::<Sha512>::new_from_slice(secret).map_err(|_| JwtCodecError::KeyError)?;
140 mac.update(data);
141 mac.verify_slice(signature).map_err(|_| JwtCodecError::InvalidSignature)?;
142 Ok(true)
143 }
144 _ => Err(JwtCodecError::InvalidAlgorithm),
145 }
146}
147
148pub fn encode_jwt<T: Serialize>(header: &JwtHeader, claims: &T, secret: &[u8]) -> JwtCodecResult<String> {
150 let header_json = serde_json::to_string(header)?;
151 let claims_json = serde_json::to_string(claims)?;
152
153 let header_b64 = base64url_encode(header_json.as_bytes());
154 let claims_b64 = base64url_encode(claims_json.as_bytes());
155
156 let message = format!("{}.{}", header_b64, claims_b64);
157 let signature = hmac_sign(&header.alg, secret, message.as_bytes())?;
158 let signature_b64 = base64url_encode(&signature);
159
160 Ok(format!("{}.{}", message, signature_b64))
161}
162
163pub fn decode_jwt<T: for<'de> Deserialize<'de>>(token: &str, secret: &[u8], validate_signature: bool) -> JwtCodecResult<T> {
165 let parts: Vec<&str> = token.split('.').collect();
166 if parts.len() != 3 {
167 return Err(JwtCodecError::InvalidFormat);
168 }
169
170 let header_b64 = parts[0];
171 let claims_b64 = parts[1];
172 let signature_b64 = parts[2];
173
174 let header_bytes = base64url_decode(header_b64)?;
175 let header: JwtHeader = serde_json::from_slice(&header_bytes)?;
176
177 let claims_bytes = base64url_decode(claims_b64)?;
178 let claims: T = serde_json::from_slice(&claims_bytes)?;
179
180 if validate_signature {
181 let message = format!("{}.{}", header_b64, claims_b64);
182 let signature = base64url_decode(signature_b64)?;
183 hmac_verify(&header.alg, secret, message.as_bytes(), &signature)?;
184 }
185
186 Ok(claims)
187}