1use base64::Engine;
6use hmac::{Hmac, Mac};
7use jerrycan_core::{Error, Result};
8use serde::{Serialize, de::DeserializeOwned};
9use sha2::Sha256;
10
/// HMAC keyed with SHA-256: the MAC primitive behind the `HS256` algorithm
/// that `encode` advertises in the token header and `decode` verifies.
type HmacSha256 = Hmac<Sha256>;
12
13fn b64(bytes: &[u8]) -> String {
14 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
15}
16
17fn unb64(s: &str) -> Result<Vec<u8>> {
18 base64::engine::general_purpose::URL_SAFE_NO_PAD
19 .decode(s)
20 .map_err(|_| Error::unauthorized())
21}
22
23fn sign(message: &str, key: &[u8]) -> String {
24 let mut mac = HmacSha256::new_from_slice(key).expect("HMAC accepts any key length");
25 mac.update(message.as_bytes());
26 b64(&mac.finalize().into_bytes())
27}
28
29pub fn encode<T: Serialize>(claims: &T, key: &[u8]) -> Result<String> {
32 let header = b64(br#"{"alg":"HS256","typ":"JWT"}"#);
33 let payload_json =
34 serde_json::to_vec(claims).map_err(|e| Error::internal(format!("jwt serialize: {e}")))?;
35 let payload = b64(&payload_json);
36 let message = format!("{header}.{payload}");
37 let signature = sign(&message, key);
38 Ok(format!("{message}.{signature}"))
39}
40
41pub fn decode<T: DeserializeOwned>(token: &str, key: &[u8]) -> Result<T> {
44 let parts: Vec<&str> = token.split('.').collect();
45 if parts.len() != 3 {
46 return Err(Error::unauthorized());
47 }
48 let message = format!("{}.{}", parts[0], parts[1]);
49 let mut mac = HmacSha256::new_from_slice(key).expect("HMAC accepts any key length");
50 mac.update(message.as_bytes());
51 let provided = unb64(parts[2])?;
52 mac.verify_slice(&provided)
53 .map_err(|_| Error::unauthorized())?;
54
55 let payload = unb64(parts[1])?;
56 if let Ok(map) = serde_json::from_slice::<serde_json::Value>(&payload)
58 && let Some(exp) = map.get("exp").and_then(|v| v.as_u64())
59 && exp <= now_unix()
60 {
61 return Err(Error::unauthorized());
62 }
63 serde_json::from_slice(&payload).map_err(|_| Error::unauthorized())
64}
65
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, this yields `0`
/// instead of failing.
fn now_unix() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
72
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    #[derive(Serialize, Deserialize, PartialEq, Debug)]
    struct Claims {
        sub: String,
        role: String,
        exp: u64,
    }

    /// An expiry far enough in the future that it never trips during tests.
    const NEVER: u64 = 9999999999;

    fn key() -> [u8; 32] {
        *crate::derive_key(b"a-very-long-development-secret-string!!", "jwt")
    }

    /// Claims for subject `u1` with the given role and expiry.
    fn claims(role: &str, exp: u64) -> Claims {
        Claims {
            sub: "u1".into(),
            role: role.into(),
            exp,
        }
    }

    /// Signs `claims(role, exp)` under the shared test key.
    fn issue(role: &str, exp: u64) -> String {
        encode(&claims(role, exp), &key()).unwrap()
    }

    #[test]
    fn encode_then_decode_round_trips() {
        let token = issue("admin", NEVER);
        assert_eq!(token.split('.').count(), 3, "header.payload.signature");
        let decoded: Claims = decode(&token, &key()).unwrap();
        assert_eq!(decoded, claims("admin", NEVER));
    }

    #[test]
    fn a_tampered_payload_fails_signature_verification() {
        let token = issue("user", NEVER);
        let forged = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(br#"{"sub":"u1","role":"admin","exp":9999999999}"#);
        // Keep the original header and signature, swap in the forged payload.
        let (header, rest) = token.split_once('.').unwrap();
        let (_, signature) = rest.split_once('.').unwrap();
        let tampered = format!("{header}.{forged}.{signature}");
        assert!(decode::<Claims>(&tampered, &key()).is_err());
    }

    #[test]
    fn a_wrong_key_is_rejected() {
        let token = issue("user", NEVER);
        let other = *crate::derive_key(b"different-secret-of-at-least-32-bytes!!", "jwt");
        assert!(decode::<Claims>(&token, &other).is_err());
    }

    #[test]
    fn expired_tokens_are_rejected() {
        let token = issue("user", 1);
        let err = decode::<Claims>(&token, &key()).unwrap_err();
        assert_eq!(err.code(), "JC0401");
    }
}