1use alloy_primitives::Signature as EvmSignature;
11use base64::Engine;
12use base64::engine::general_purpose::URL_SAFE_NO_PAD;
13use solana_signature::Signature as SolanaSignature;
14
15use crate::auth::JwtError;
16use crate::auth::chain::{Caip10, JwtAlgorithm};
17use crate::auth::claims::BitrouterClaims;
18use crate::auth::keys::MasterKeypair;
19
20pub fn sign(claims: &BitrouterClaims, keypair: &MasterKeypair) -> Result<String, JwtError> {
26 let caip10 = Caip10::parse(&claims.iss)?;
27 let alg = caip10.chain.jwt_algorithm();
28
29 let expected_chain = caip10.chain.caip2();
31 if claims.chain != expected_chain {
32 return Err(JwtError::Verification(format!(
33 "chain mismatch: claims.chain is {}, iss implies {}",
34 claims.chain, expected_chain
35 )));
36 }
37
38 let header_b64 = URL_SAFE_NO_PAD.encode(alg.header_json().as_bytes());
39 let payload = serde_json::to_vec(claims).map_err(|e| JwtError::Signing(e.to_string()))?;
40 let payload_b64 = URL_SAFE_NO_PAD.encode(&payload);
41
42 let message = format!("{header_b64}.{payload_b64}");
43
44 let sig_bytes = match alg {
45 JwtAlgorithm::SolEdDsa => keypair.sign_ed25519(message.as_bytes()),
46 JwtAlgorithm::Eip191K => keypair.sign_eip191(message.as_bytes())?,
47 };
48
49 let sig_b64 = URL_SAFE_NO_PAD.encode(&sig_bytes);
50 Ok(format!("{message}.{sig_b64}"))
51}
52
53pub fn verify(token: &str) -> Result<BitrouterClaims, JwtError> {
60 let (message, sig_b64) = token
61 .rsplit_once('.')
62 .ok_or_else(|| JwtError::MalformedToken("expected header.payload.signature".into()))?;
63
64 let sig_bytes = URL_SAFE_NO_PAD
65 .decode(sig_b64)
66 .map_err(|e| JwtError::MalformedToken(format!("bad signature encoding: {e}")))?;
67
68 let (_, payload_b64) = message
70 .split_once('.')
71 .ok_or_else(|| JwtError::MalformedToken("expected header.payload".into()))?;
72 let payload = URL_SAFE_NO_PAD
73 .decode(payload_b64)
74 .map_err(|e| JwtError::MalformedToken(format!("bad payload encoding: {e}")))?;
75 let claims: BitrouterClaims =
76 serde_json::from_slice(&payload).map_err(|e| JwtError::MalformedToken(e.to_string()))?;
77
78 let alg = decode_algorithm(message)?;
80
81 let caip10 = Caip10::parse(&claims.iss)?;
83
84 let expected_alg = caip10.chain.jwt_algorithm();
86 if alg != expected_alg {
87 return Err(JwtError::Verification(format!(
88 "algorithm mismatch: header says {alg}, chain expects {expected_alg}"
89 )));
90 }
91
92 let expected_chain = caip10.chain.caip2();
94 if claims.chain != expected_chain {
95 return Err(JwtError::Verification(format!(
96 "chain mismatch: claims.chain is {}, iss implies {}",
97 claims.chain, expected_chain
98 )));
99 }
100
101 match alg {
102 JwtAlgorithm::SolEdDsa => {
103 verify_sol_eddsa(message.as_bytes(), &sig_bytes, &caip10.address)?;
104 }
105 JwtAlgorithm::Eip191K => {
106 verify_eip191k(message.as_bytes(), &sig_bytes, &caip10.address)?;
107 }
108 }
109
110 Ok(claims)
111}
112
113pub fn decode_unverified(token: &str) -> Result<BitrouterClaims, JwtError> {
119 let parts: Vec<&str> = token.split('.').collect();
120 if parts.len() != 3 {
121 return Err(JwtError::MalformedToken(
122 "expected exactly 3 segments (header.payload.signature)".into(),
123 ));
124 }
125 let payload_b64 = parts[1];
126
127 let payload = URL_SAFE_NO_PAD
128 .decode(payload_b64)
129 .map_err(|e| JwtError::MalformedToken(format!("bad payload encoding: {e}")))?;
130 serde_json::from_slice(&payload).map_err(|e| JwtError::MalformedToken(e.to_string()))
131}
132
133pub fn check_expiration(claims: &BitrouterClaims) -> Result<(), JwtError> {
138 if let Some(exp) = claims.exp {
139 let now = std::time::SystemTime::now()
140 .duration_since(std::time::UNIX_EPOCH)
141 .map_err(|_| JwtError::Expired)?
142 .as_secs();
143 if now >= exp {
144 return Err(JwtError::Expired);
145 }
146 }
147 Ok(())
148}
149
150fn decode_algorithm(header_dot_payload: &str) -> Result<JwtAlgorithm, JwtError> {
154 let header_b64 = header_dot_payload
155 .split_once('.')
156 .map(|(h, _)| h)
157 .ok_or_else(|| JwtError::MalformedToken("expected header.payload".into()))?;
158
159 let header_bytes = URL_SAFE_NO_PAD
160 .decode(header_b64)
161 .map_err(|e| JwtError::MalformedToken(format!("bad header encoding: {e}")))?;
162
163 #[derive(serde::Deserialize)]
164 struct Header {
165 alg: String,
166 }
167
168 let header: Header = serde_json::from_slice(&header_bytes)
169 .map_err(|e| JwtError::MalformedToken(format!("bad header JSON: {e}")))?;
170
171 JwtAlgorithm::from_header(&header.alg)
172}
173
174fn verify_sol_eddsa(message: &[u8], sig_bytes: &[u8], address_b58: &str) -> Result<(), JwtError> {
176 let pubkey = crate::auth::keys::decode_solana_pubkey(address_b58)?;
177
178 let sig = SolanaSignature::try_from(sig_bytes)
179 .map_err(|_| JwtError::Verification("invalid Ed25519 signature length".into()))?;
180
181 if !sig.verify(pubkey.as_ref(), message) {
182 return Err(JwtError::Verification("invalid Ed25519 signature".into()));
183 }
184
185 Ok(())
186}
187
188fn verify_eip191k(
193 message: &[u8],
194 sig_bytes: &[u8],
195 expected_address: &str,
196) -> Result<(), JwtError> {
197 let sig = EvmSignature::try_from(sig_bytes)
198 .map_err(|_| JwtError::Verification("invalid secp256k1 signature".into()))?;
199
200 let recovered = sig
201 .recover_address_from_msg(message)
202 .map_err(|e| JwtError::Verification(format!("ecrecover failed: {e}")))?;
203
204 let expected = expected_address
205 .parse::<alloy_primitives::Address>()
206 .map_err(|e| JwtError::InvalidCaip10(format!("invalid EVM address: {e}")))?;
207
208 if recovered != expected {
209 return Err(JwtError::AddressMismatch);
210 }
211
212 Ok(())
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::chain::Chain;
    use crate::auth::claims::TokenScope;
    use crate::auth::keys::MasterKeypair;

    /// Well-formed, non-expiring claims issued by `kp`'s account on `chain`.
    fn claims_for(kp: &MasterKeypair, chain: &Chain) -> BitrouterClaims {
        let caip10 = kp.caip10(chain).expect("caip10");
        BitrouterClaims {
            iss: caip10.format(),
            chain: chain.caip2(),
            iat: Some(1_700_000_000),
            exp: None,
            scope: TokenScope::Api,
            models: None,
            budget: None,
            budget_scope: None,
            budget_range: None,
        }
    }

    fn test_claims_solana(kp: &MasterKeypair) -> BitrouterClaims {
        claims_for(kp, &Chain::solana_mainnet())
    }

    fn test_claims_evm(kp: &MasterKeypair) -> BitrouterClaims {
        claims_for(kp, &Chain::base())
    }

    /// Claims that only matter to `check_expiration`; issuer and chain are left empty.
    fn claims_expiring_at(exp: Option<u64>) -> BitrouterClaims {
        BitrouterClaims {
            iss: String::new(),
            chain: String::new(),
            iat: None,
            exp,
            scope: TokenScope::Api,
            models: None,
            budget: None,
            budget_scope: None,
            budget_range: None,
        }
    }

    /// Rebuilds `token` with its payload swapped for `claims`, keeping the original header and
    /// signature, so the signature no longer matches the payload.
    fn with_payload(token: &str, claims: &BitrouterClaims) -> String {
        let parts: Vec<&str> = token.split('.').collect();
        let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_vec(claims).expect("ser"));
        format!("{}.{}.{}", parts[0], payload_b64, parts[2])
    }

    /// Decoded JSON text of the token's header segment.
    fn header_json_of(token: &str) -> String {
        let header_b64 = token.split('.').next().expect("header");
        let header = URL_SAFE_NO_PAD.decode(header_b64).expect("decode");
        String::from_utf8(header).expect("utf8")
    }

    #[test]
    fn sign_and_verify_solana() {
        let kp = MasterKeypair::generate();
        let claims = test_claims_solana(&kp);
        let token = sign(&claims, &kp).expect("sign");
        let decoded = verify(&token).expect("verify");
        assert_eq!(decoded.iss, claims.iss);
        assert_eq!(decoded.scope, TokenScope::Api);
    }

    #[test]
    fn sign_and_verify_evm() {
        let kp = MasterKeypair::generate();
        let claims = test_claims_evm(&kp);
        let token = sign(&claims, &kp).expect("sign");
        let decoded = verify(&token).expect("verify");
        assert_eq!(decoded.iss, claims.iss);
        assert_eq!(decoded.scope, TokenScope::Api);
    }

    #[test]
    fn verify_rejects_wrong_key_solana() {
        let kp1 = MasterKeypair::generate();
        let kp2 = MasterKeypair::generate();
        let token = sign(&test_claims_solana(&kp1), &kp1).expect("sign");
        // Claims now name kp2 as issuer, but the signature is still kp1's.
        let tampered = with_payload(&token, &test_claims_solana(&kp2));
        assert!(verify(&tampered).is_err());
    }

    #[test]
    fn verify_rejects_wrong_key_evm() {
        let kp1 = MasterKeypair::generate();
        let kp2 = MasterKeypair::generate();
        let token = sign(&test_claims_evm(&kp1), &kp1).expect("sign");
        let tampered = with_payload(&token, &test_claims_evm(&kp2));
        assert!(verify(&tampered).is_err());
    }

    #[test]
    fn verify_rejects_signature_from_other_token() {
        for chain in [Chain::solana_mainnet(), Chain::base()] {
            let kp = MasterKeypair::generate();
            let claims = claims_for(&kp, &chain);
            let mut other = claims.clone();
            other.iat = Some(1_700_000_001);

            let token = sign(&claims, &kp).expect("sign");
            let other_token = sign(&other, &kp).expect("sign");
            let (message, _) = token.rsplit_once('.').expect("signature segment");
            let (_, other_sig) = other_token.rsplit_once('.').expect("signature segment");

            // Same key, but the signature covers a different payload.
            assert!(verify(&format!("{message}.{other_sig}")).is_err());
        }
    }

    #[test]
    fn verify_rejects_algorithm_substitution() {
        let kp = MasterKeypair::generate();
        let token = sign(&test_claims_solana(&kp), &kp).expect("sign");
        let parts: Vec<&str> = token.split('.').collect();
        // Claim the EVM algorithm for a Solana-issued token.
        let forged_header =
            URL_SAFE_NO_PAD.encode(JwtAlgorithm::Eip191K.header_json().as_bytes());
        let forged = format!("{}.{}.{}", forged_header, parts[1], parts[2]);
        assert!(verify(&forged).is_err());
    }

    #[test]
    fn verify_rejects_chain_mismatch_in_payload() {
        let kp = MasterKeypair::generate();
        let claims = test_claims_solana(&kp);
        let token = sign(&claims, &kp).expect("sign");

        let mut tampered_claims = claims.clone();
        tampered_claims.chain = Chain::base().caip2();
        assert!(verify(&with_payload(&token, &tampered_claims)).is_err());
    }

    #[test]
    fn decode_unverified_extracts_claims() {
        let kp = MasterKeypair::generate();
        let claims = test_claims_solana(&kp);
        let token = sign(&claims, &kp).expect("sign");
        let decoded = decode_unverified(&token).expect("decode");
        assert_eq!(decoded.iss, claims.iss);
        assert_eq!(decoded.chain, claims.chain);
    }

    #[test]
    fn check_expiration_passes_for_future() {
        check_expiration(&claims_expiring_at(Some(u64::MAX))).expect("not expired");
    }

    #[test]
    fn check_expiration_fails_for_past() {
        assert!(check_expiration(&claims_expiring_at(Some(1))).is_err());
    }

    #[test]
    fn check_expiration_fails_at_exact_expiry() {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("clock after epoch")
            .as_secs();
        // `exp` is exclusive: a token is already expired during its `exp` second.
        assert!(check_expiration(&claims_expiring_at(Some(now))).is_err());
    }

    #[test]
    fn check_expiration_passes_for_none() {
        check_expiration(&claims_expiring_at(None)).expect("no exp means valid");
    }

    #[test]
    fn token_has_three_base64url_parts() {
        let kp = MasterKeypair::generate();
        let token = sign(&test_claims_solana(&kp), &kp).expect("sign");
        let parts: Vec<&str> = token.split('.').collect();
        assert_eq!(parts.len(), 3);
        for part in &parts {
            URL_SAFE_NO_PAD.decode(part).expect("segment is unpadded base64url");
        }
    }

    #[test]
    fn solana_header_is_sol_eddsa() {
        let kp = MasterKeypair::generate();
        let token = sign(&test_claims_solana(&kp), &kp).expect("sign");
        assert!(header_json_of(&token).contains("SOL_EDDSA"));
    }

    #[test]
    fn evm_header_is_eip191k() {
        let kp = MasterKeypair::generate();
        let token = sign(&test_claims_evm(&kp), &kp).expect("sign");
        assert!(header_json_of(&token).contains("EIP191K"));
    }

    #[test]
    fn malformed_token_rejected() {
        for bad in ["", "not-a-jwt", "a.b", "a.b.c.d"] {
            assert!(decode_unverified(bad).is_err(), "decode_unverified accepted {bad:?}");
            assert!(verify(bad).is_err(), "verify accepted {bad:?}");
        }
    }

    #[test]
    fn sign_rejects_chain_mismatch() {
        let kp = MasterKeypair::generate();
        // Issuer is a Solana account, but the claims name Base as the chain.
        let mut bad_claims = claims_for(&kp, &Chain::solana_mainnet());
        bad_claims.chain = Chain::base().caip2();
        assert!(sign(&bad_claims, &kp).is_err());
    }
}