Skip to main content

robinpath_modules/modules/
jwt_mod.rs

1use robinpath::{RobinPath, Value};
2
3pub fn register(rp: &mut RobinPath) {
4    rp.register_builtin("jwt.sign", |args, _| {
5        let payload = args.first().cloned().unwrap_or(Value::Null);
6        let secret = args.get(1).map(|v| v.to_display_string()).unwrap_or_default();
7        let options = args.get(2).cloned().unwrap_or(Value::Null);
8
9        let algorithm = if let Value::Object(ref opts) = options {
10            opts.get("algorithm")
11                .map(|v| v.to_display_string())
12                .unwrap_or_else(|| "HS256".to_string())
13        } else {
14            "HS256".to_string()
15        };
16
17        let expires_in = if let Value::Object(ref opts) = options {
18            opts.get("expiresIn").map(|v| v.to_number() as u64)
19        } else {
20            None
21        };
22
23        // Build header
24        let mut header = indexmap::IndexMap::new();
25        header.insert("alg".to_string(), Value::String(algorithm.clone()));
26        header.insert("typ".to_string(), Value::String("JWT".to_string()));
27
28        // Possibly add exp to payload
29        let mut payload_map = if let Value::Object(obj) = &payload {
30            obj.clone()
31        } else {
32            indexmap::IndexMap::new()
33        };
34
35        if let Some(exp_secs) = expires_in {
36            let now = std::time::SystemTime::now()
37                .duration_since(std::time::UNIX_EPOCH)
38                .unwrap_or_default()
39                .as_secs();
40            payload_map.insert("exp".to_string(), Value::Number((now + exp_secs) as f64));
41        }
42
43        let header_json: serde_json::Value = Value::Object(header).into();
44        let payload_json: serde_json::Value = Value::Object(payload_map).into();
45
46        let header_b64 = base64url_encode(&serde_json::to_vec(&header_json).unwrap_or_default());
47        let payload_b64 = base64url_encode(&serde_json::to_vec(&payload_json).unwrap_or_default());
48
49        let signing_input = format!("{}.{}", header_b64, payload_b64);
50        let signature = hmac_sign(&algorithm, signing_input.as_bytes(), secret.as_bytes());
51        let sig_b64 = base64url_encode(&signature);
52
53        Ok(Value::String(format!("{}.{}", signing_input, sig_b64)))
54    });
55
56    rp.register_builtin("jwt.verify", |args, _| {
57        let token = args.first().map(|v| v.to_display_string()).unwrap_or_default();
58        let secret = args.get(1).map(|v| v.to_display_string()).unwrap_or_default();
59
60        let parts: Vec<&str> = token.split('.').collect();
61        if parts.len() != 3 {
62            return Err("jwt.verify: invalid token format".to_string());
63        }
64
65        let header_bytes = base64url_decode(parts[0])
66            .map_err(|e| format!("jwt.verify: invalid header: {}", e))?;
67        let header: serde_json::Value = serde_json::from_slice(&header_bytes)
68            .map_err(|e| format!("jwt.verify: invalid header JSON: {}", e))?;
69
70        let algorithm = header
71            .get("alg")
72            .and_then(|v| v.as_str())
73            .unwrap_or("HS256")
74            .to_string();
75
76        let signing_input = format!("{}.{}", parts[0], parts[1]);
77        let expected_sig = hmac_sign(&algorithm, signing_input.as_bytes(), secret.as_bytes());
78        let expected_b64 = base64url_encode(&expected_sig);
79
80        if expected_b64 != parts[2] {
81            return Err("jwt.verify: invalid signature".to_string());
82        }
83
84        let payload_bytes = base64url_decode(parts[1])
85            .map_err(|e| format!("jwt.verify: invalid payload: {}", e))?;
86        let payload_json: serde_json::Value = serde_json::from_slice(&payload_bytes)
87            .map_err(|e| format!("jwt.verify: invalid payload JSON: {}", e))?;
88
89        // Check expiration
90        if let Some(exp) = payload_json.get("exp").and_then(|v| v.as_f64()) {
91            let now = std::time::SystemTime::now()
92                .duration_since(std::time::UNIX_EPOCH)
93                .unwrap_or_default()
94                .as_secs() as f64;
95            if now > exp {
96                return Err("jwt.verify: token expired".to_string());
97            }
98        }
99
100        Ok(Value::from(payload_json))
101    });
102
103    rp.register_builtin("jwt.decode", |args, _| {
104        let token = args.first().map(|v| v.to_display_string()).unwrap_or_default();
105        let parts: Vec<&str> = token.split('.').collect();
106        if parts.len() != 3 {
107            return Err("jwt.decode: invalid token format".to_string());
108        }
109
110        let header_bytes = base64url_decode(parts[0])
111            .map_err(|e| format!("jwt.decode: invalid header: {}", e))?;
112        let payload_bytes = base64url_decode(parts[1])
113            .map_err(|e| format!("jwt.decode: invalid payload: {}", e))?;
114
115        let header_json: serde_json::Value = serde_json::from_slice(&header_bytes)
116            .map_err(|e| format!("jwt.decode: {}", e))?;
117        let payload_json: serde_json::Value = serde_json::from_slice(&payload_bytes)
118            .map_err(|e| format!("jwt.decode: {}", e))?;
119
120        let mut result = indexmap::IndexMap::new();
121        result.insert("header".to_string(), Value::from(header_json));
122        result.insert("payload".to_string(), Value::from(payload_json));
123        result.insert("signature".to_string(), Value::String(parts[2].to_string()));
124        Ok(Value::Object(result))
125    });
126
127    rp.register_builtin("jwt.getHeader", |args, _| {
128        let token = args.first().map(|v| v.to_display_string()).unwrap_or_default();
129        let parts: Vec<&str> = token.split('.').collect();
130        if parts.is_empty() {
131            return Err("jwt.getHeader: invalid token".to_string());
132        }
133        let header_bytes = base64url_decode(parts[0])
134            .map_err(|e| format!("jwt.getHeader: {}", e))?;
135        let header_json: serde_json::Value = serde_json::from_slice(&header_bytes)
136            .map_err(|e| format!("jwt.getHeader: {}", e))?;
137        Ok(Value::from(header_json))
138    });
139
140    rp.register_builtin("jwt.getPayload", |args, _| {
141        let token = args.first().map(|v| v.to_display_string()).unwrap_or_default();
142        let parts: Vec<&str> = token.split('.').collect();
143        if parts.len() < 2 {
144            return Err("jwt.getPayload: invalid token".to_string());
145        }
146        let payload_bytes = base64url_decode(parts[1])
147            .map_err(|e| format!("jwt.getPayload: {}", e))?;
148        let payload_json: serde_json::Value = serde_json::from_slice(&payload_bytes)
149            .map_err(|e| format!("jwt.getPayload: {}", e))?;
150        Ok(Value::from(payload_json))
151    });
152
153    rp.register_builtin("jwt.isExpired", |args, _| {
154        let token = args.first().map(|v| v.to_display_string()).unwrap_or_default();
155        let parts: Vec<&str> = token.split('.').collect();
156        if parts.len() < 2 {
157            return Ok(Value::Bool(true));
158        }
159        let payload_bytes = match base64url_decode(parts[1]) {
160            Ok(b) => b,
161            Err(_) => return Ok(Value::Bool(true)),
162        };
163        let payload_json: serde_json::Value = match serde_json::from_slice(&payload_bytes) {
164            Ok(j) => j,
165            Err(_) => return Ok(Value::Bool(true)),
166        };
167
168        if let Some(exp) = payload_json.get("exp").and_then(|v| v.as_f64()) {
169            let now = std::time::SystemTime::now()
170                .duration_since(std::time::UNIX_EPOCH)
171                .unwrap_or_default()
172                .as_secs() as f64;
173            Ok(Value::Bool(now > exp))
174        } else {
175            // No exp claim means it doesn't expire
176            Ok(Value::Bool(false))
177        }
178    });
179
180    rp.register_builtin("jwt.getExpiration", |args, _| {
181        let token = args.first().map(|v| v.to_display_string()).unwrap_or_default();
182        let parts: Vec<&str> = token.split('.').collect();
183        if parts.len() < 2 {
184            return Ok(Value::Null);
185        }
186        let payload_bytes = match base64url_decode(parts[1]) {
187            Ok(b) => b,
188            Err(_) => return Ok(Value::Null),
189        };
190        let payload_json: serde_json::Value = match serde_json::from_slice(&payload_bytes) {
191            Ok(j) => j,
192            Err(_) => return Ok(Value::Null),
193        };
194        match payload_json.get("exp").and_then(|v| v.as_f64()) {
195            Some(exp) => Ok(Value::Number(exp)),
196            None => Ok(Value::Null),
197        }
198    });
199}
200
201fn base64url_encode(data: &[u8]) -> String {
202    use base64::Engine;
203    base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(data)
204}
205
206fn base64url_decode(s: &str) -> Result<Vec<u8>, String> {
207    use base64::Engine;
208    base64::engine::general_purpose::URL_SAFE_NO_PAD
209        .decode(s)
210        .map_err(|e| e.to_string())
211}
212
213fn hmac_sign(algorithm: &str, data: &[u8], key: &[u8]) -> Vec<u8> {
214    use hmac::{Hmac, Mac};
215
216    match algorithm {
217        "HS256" => {
218            type HmacSha256 = Hmac<sha2::Sha256>;
219            let mut mac = HmacSha256::new_from_slice(key).expect("HMAC key");
220            mac.update(data);
221            mac.finalize().into_bytes().to_vec()
222        }
223        "HS384" => {
224            type HmacSha384 = Hmac<sha2::Sha384>;
225            let mut mac = HmacSha384::new_from_slice(key).expect("HMAC key");
226            mac.update(data);
227            mac.finalize().into_bytes().to_vec()
228        }
229        "HS512" => {
230            type HmacSha512 = Hmac<sha2::Sha512>;
231            let mut mac = HmacSha512::new_from_slice(key).expect("HMAC key");
232            mac.update(data);
233            mac.finalize().into_bytes().to_vec()
234        }
235        _ => {
236            // Default to HS256
237            type HmacSha256 = Hmac<sha2::Sha256>;
238            let mut mac = HmacSha256::new_from_slice(key).expect("HMAC key");
239            mac.update(data);
240            mac.finalize().into_bytes().to_vec()
241        }
242    }
243}