Skip to main content

bitrouter_core/auth/
token.rs

1//! JWT signing and verification for the BitRouter protocol.
2//!
3//! Supports two web3 wallet signing schemes:
4//!
5//! - **SOL_EDDSA** — Solana-style Ed25519 over raw message bytes.
6//! - **EIP191K** — EVM-style EIP-191 prefixed secp256k1 ECDSA.
7//!
8//! Token format: `base64url(header).base64url(claims).base64url(signature)`
9
10use alloy_primitives::Signature as EvmSignature;
11use base64::Engine;
12use base64::engine::general_purpose::URL_SAFE_NO_PAD;
13use solana_signature::Signature as SolanaSignature;
14
15use crate::auth::JwtError;
16use crate::auth::chain::{Caip10, JwtAlgorithm};
17use crate::auth::claims::BitrouterClaims;
18use crate::auth::keys::MasterKeypair;
19
20/// Sign a set of claims into a JWT string using the master keypair.
21///
22/// The algorithm and signing method are determined by the chain in the claims:
23/// - Solana → SOL_EDDSA (Ed25519 over raw message)
24/// - EVM → EIP191K (EIP-191 prefixed secp256k1 ECDSA)
25pub fn sign(claims: &BitrouterClaims, keypair: &MasterKeypair) -> Result<String, JwtError> {
26    let caip10 = Caip10::parse(&claims.iss)?;
27    let alg = caip10.chain.jwt_algorithm();
28
29    // Reject tokens where the explicit `chain` field contradicts `iss`.
30    let expected_chain = caip10.chain.caip2();
31    if claims.chain != expected_chain {
32        return Err(JwtError::Verification(format!(
33            "chain mismatch: claims.chain is {}, iss implies {}",
34            claims.chain, expected_chain
35        )));
36    }
37
38    let header_b64 = URL_SAFE_NO_PAD.encode(alg.header_json().as_bytes());
39    let payload = serde_json::to_vec(claims).map_err(|e| JwtError::Signing(e.to_string()))?;
40    let payload_b64 = URL_SAFE_NO_PAD.encode(&payload);
41
42    let message = format!("{header_b64}.{payload_b64}");
43
44    let sig_bytes = match alg {
45        JwtAlgorithm::SolEdDsa => keypair.sign_ed25519(message.as_bytes()),
46        JwtAlgorithm::Eip191K => keypair.sign_eip191(message.as_bytes())?,
47    };
48
49    let sig_b64 = URL_SAFE_NO_PAD.encode(&sig_bytes);
50    Ok(format!("{message}.{sig_b64}"))
51}
52
53/// Verify a JWT string and extract the claims.
54///
55/// Determines the algorithm from the JWT header, then:
56/// - SOL_EDDSA: extracts the base58 pubkey from CAIP-10 `iss`, verifies Ed25519.
57/// - EIP191K: recovers the EVM address from the EIP-191 signature, compares
58///   with the address in CAIP-10 `iss`.
59pub fn verify(token: &str) -> Result<BitrouterClaims, JwtError> {
60    let (message, sig_b64) = token
61        .rsplit_once('.')
62        .ok_or_else(|| JwtError::MalformedToken("expected header.payload.signature".into()))?;
63
64    let sig_bytes = URL_SAFE_NO_PAD
65        .decode(sig_b64)
66        .map_err(|e| JwtError::MalformedToken(format!("bad signature encoding: {e}")))?;
67
68    // Decode claims (unverified) to determine chain.
69    let (_, payload_b64) = message
70        .split_once('.')
71        .ok_or_else(|| JwtError::MalformedToken("expected header.payload".into()))?;
72    let payload = URL_SAFE_NO_PAD
73        .decode(payload_b64)
74        .map_err(|e| JwtError::MalformedToken(format!("bad payload encoding: {e}")))?;
75    let claims: BitrouterClaims =
76        serde_json::from_slice(&payload).map_err(|e| JwtError::MalformedToken(e.to_string()))?;
77
78    // Parse algorithm from header.
79    let alg = decode_algorithm(message)?;
80
81    // Parse CAIP-10 identity from iss.
82    let caip10 = Caip10::parse(&claims.iss)?;
83
84    // Verify the algorithm matches the chain.
85    let expected_alg = caip10.chain.jwt_algorithm();
86    if alg != expected_alg {
87        return Err(JwtError::Verification(format!(
88            "algorithm mismatch: header says {alg}, chain expects {expected_alg}"
89        )));
90    }
91
92    // Ensure the CAIP-2 chain in the claims matches the chain implied by iss.
93    let expected_chain = caip10.chain.caip2();
94    if claims.chain != expected_chain {
95        return Err(JwtError::Verification(format!(
96            "chain mismatch: claims.chain is {}, iss implies {}",
97            claims.chain, expected_chain
98        )));
99    }
100
101    match alg {
102        JwtAlgorithm::SolEdDsa => {
103            verify_sol_eddsa(message.as_bytes(), &sig_bytes, &caip10.address)?;
104        }
105        JwtAlgorithm::Eip191K => {
106            verify_eip191k(message.as_bytes(), &sig_bytes, &caip10.address)?;
107        }
108    }
109
110    Ok(claims)
111}
112
113/// Decode a JWT without verifying the signature.
114///
115/// Used to extract claims before verification (e.g., to read `iss` for
116/// account lookup). **Never trust claims from this function without a
117/// subsequent `verify()` call.**
118pub fn decode_unverified(token: &str) -> Result<BitrouterClaims, JwtError> {
119    let parts: Vec<&str> = token.split('.').collect();
120    if parts.len() != 3 {
121        return Err(JwtError::MalformedToken(
122            "expected exactly 3 segments (header.payload.signature)".into(),
123        ));
124    }
125    let payload_b64 = parts[1];
126
127    let payload = URL_SAFE_NO_PAD
128        .decode(payload_b64)
129        .map_err(|e| JwtError::MalformedToken(format!("bad payload encoding: {e}")))?;
130    serde_json::from_slice(&payload).map_err(|e| JwtError::MalformedToken(e.to_string()))
131}
132
133/// Check whether a token's `exp` claim has passed.
134///
135/// Returns `Ok(())` if the token is still valid (or has no `exp`).
136/// Returns `Err(JwtError::Expired)` if the token is expired.
137pub fn check_expiration(claims: &BitrouterClaims) -> Result<(), JwtError> {
138    if let Some(exp) = claims.exp {
139        let now = std::time::SystemTime::now()
140            .duration_since(std::time::UNIX_EPOCH)
141            .map_err(|_| JwtError::Expired)?
142            .as_secs();
143        if now >= exp {
144            return Err(JwtError::Expired);
145        }
146    }
147    Ok(())
148}
149
150// ── internal helpers ──────────────────────────────────────────
151
152/// Extract the algorithm from the JWT header segment.
153fn decode_algorithm(header_dot_payload: &str) -> Result<JwtAlgorithm, JwtError> {
154    let header_b64 = header_dot_payload
155        .split_once('.')
156        .map(|(h, _)| h)
157        .ok_or_else(|| JwtError::MalformedToken("expected header.payload".into()))?;
158
159    let header_bytes = URL_SAFE_NO_PAD
160        .decode(header_b64)
161        .map_err(|e| JwtError::MalformedToken(format!("bad header encoding: {e}")))?;
162
163    #[derive(serde::Deserialize)]
164    struct Header {
165        alg: String,
166    }
167
168    let header: Header = serde_json::from_slice(&header_bytes)
169        .map_err(|e| JwtError::MalformedToken(format!("bad header JSON: {e}")))?;
170
171    JwtAlgorithm::from_header(&header.alg)
172}
173
174/// Verify a SOL_EDDSA (Ed25519) signature.
175fn verify_sol_eddsa(message: &[u8], sig_bytes: &[u8], address_b58: &str) -> Result<(), JwtError> {
176    let pubkey = crate::auth::keys::decode_solana_pubkey(address_b58)?;
177
178    let sig = SolanaSignature::try_from(sig_bytes)
179        .map_err(|_| JwtError::Verification("invalid Ed25519 signature length".into()))?;
180
181    if !sig.verify(pubkey.as_ref(), message) {
182        return Err(JwtError::Verification("invalid Ed25519 signature".into()));
183    }
184
185    Ok(())
186}
187
188/// Verify an EIP191K (EIP-191 + secp256k1) signature.
189///
190/// Recovers the signer address from the EIP-191 prefixed message and
191/// compares it with the expected address from the CAIP-10 `iss`.
192fn verify_eip191k(
193    message: &[u8],
194    sig_bytes: &[u8],
195    expected_address: &str,
196) -> Result<(), JwtError> {
197    let sig = EvmSignature::try_from(sig_bytes)
198        .map_err(|_| JwtError::Verification("invalid secp256k1 signature".into()))?;
199
200    let recovered = sig
201        .recover_address_from_msg(message)
202        .map_err(|e| JwtError::Verification(format!("ecrecover failed: {e}")))?;
203
204    let expected = expected_address
205        .parse::<alloy_primitives::Address>()
206        .map_err(|e| JwtError::InvalidCaip10(format!("invalid EVM address: {e}")))?;
207
208    if recovered != expected {
209        return Err(JwtError::AddressMismatch);
210    }
211
212    Ok(())
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::chain::Chain;
    use crate::auth::claims::TokenScope;
    use crate::auth::keys::MasterKeypair;

    // ── fixtures ──────────────────────────────────────────────

    /// API-scoped claims issued by `kp` on `chain`, with a fixed `iat` and no expiry.
    fn claims_on(kp: &MasterKeypair, chain: Chain) -> BitrouterClaims {
        let issuer = kp.caip10(&chain).expect("caip10");
        BitrouterClaims {
            iss: issuer.format(),
            chain: chain.caip2(),
            iat: Some(1_700_000_000),
            exp: None,
            scope: TokenScope::Api,
            models: None,
            budget: None,
            budget_scope: None,
            budget_range: None,
        }
    }

    fn test_claims_solana(kp: &MasterKeypair) -> BitrouterClaims {
        claims_on(kp, Chain::solana_mainnet())
    }

    fn test_claims_evm(kp: &MasterKeypair) -> BitrouterClaims {
        claims_on(kp, Chain::base())
    }

    /// Claims with empty identity fields, used only to exercise `check_expiration`.
    fn claims_with_exp(exp: Option<u64>) -> BitrouterClaims {
        BitrouterClaims {
            iss: String::new(),
            chain: String::new(),
            iat: None,
            exp,
            scope: TokenScope::Api,
            models: None,
            budget: None,
            budget_scope: None,
            budget_range: None,
        }
    }

    /// Rebuild `token` with its payload swapped for `claims`, keeping the original
    /// header and signature segments.
    fn with_payload(token: &str, claims: &BitrouterClaims) -> String {
        let parts: Vec<&str> = token.split('.').collect();
        let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_vec(claims).expect("ser"));
        format!("{}.{}.{}", parts[0], payload_b64, parts[2])
    }

    /// Decode the header segment of `token` into its JSON text.
    fn header_text(token: &str) -> String {
        let header_b64 = token.split('.').next().expect("header");
        let raw = URL_SAFE_NO_PAD.decode(header_b64).expect("decode");
        String::from_utf8(raw).expect("utf8")
    }

    // ── round trips ───────────────────────────────────────────

    #[test]
    fn sign_and_verify_solana() {
        let kp = MasterKeypair::generate();
        let claims = test_claims_solana(&kp);
        let token = sign(&claims, &kp).expect("sign");
        let decoded = verify(&token).expect("verify");
        assert_eq!(decoded.iss, claims.iss);
        assert_eq!(decoded.scope, TokenScope::Api);
    }

    #[test]
    fn sign_and_verify_evm() {
        let kp = MasterKeypair::generate();
        let claims = test_claims_evm(&kp);
        let token = sign(&claims, &kp).expect("sign");
        let decoded = verify(&token).expect("verify");
        assert_eq!(decoded.iss, claims.iss);
        assert_eq!(decoded.scope, TokenScope::Api);
    }

    // ── tampering ─────────────────────────────────────────────

    #[test]
    fn verify_rejects_wrong_key_solana() {
        let signer = MasterKeypair::generate();
        let impostor = MasterKeypair::generate();
        let token = sign(&test_claims_solana(&signer), &signer).expect("sign");

        // Tamper: payload names the impostor as issuer, signature is still the signer's.
        let tampered = with_payload(&token, &test_claims_solana(&impostor));
        assert!(verify(&tampered).is_err());
    }

    #[test]
    fn verify_rejects_wrong_key_evm() {
        let signer = MasterKeypair::generate();
        let impostor = MasterKeypair::generate();
        let token = sign(&test_claims_evm(&signer), &signer).expect("sign");

        // Tamper: payload names the impostor as issuer, signature is still the signer's.
        let tampered = with_payload(&token, &test_claims_evm(&impostor));
        assert!(verify(&tampered).is_err());
    }

    // ── unverified decoding ───────────────────────────────────

    #[test]
    fn decode_unverified_extracts_claims() {
        let kp = MasterKeypair::generate();
        let claims = test_claims_solana(&kp);
        let token = sign(&claims, &kp).expect("sign");
        let decoded = decode_unverified(&token).expect("decode");
        assert_eq!(decoded.iss, claims.iss);
        assert_eq!(decoded.chain, claims.chain);
    }

    // ── expiration ────────────────────────────────────────────

    #[test]
    fn check_expiration_passes_for_future() {
        check_expiration(&claims_with_exp(Some(u64::MAX))).expect("not expired");
    }

    #[test]
    fn check_expiration_fails_for_past() {
        assert!(check_expiration(&claims_with_exp(Some(1))).is_err());
    }

    #[test]
    fn check_expiration_passes_for_none() {
        check_expiration(&claims_with_exp(None)).expect("no exp means valid");
    }

    // ── token shape ───────────────────────────────────────────

    #[test]
    fn token_has_three_base64url_parts() {
        let kp = MasterKeypair::generate();
        let token = sign(&test_claims_solana(&kp), &kp).expect("sign");
        assert_eq!(token.split('.').count(), 3);
    }

    #[test]
    fn solana_header_is_sol_eddsa() {
        let kp = MasterKeypair::generate();
        let token = sign(&test_claims_solana(&kp), &kp).expect("sign");
        assert!(header_text(&token).contains("SOL_EDDSA"));
    }

    #[test]
    fn evm_header_is_eip191k() {
        let kp = MasterKeypair::generate();
        let token = sign(&test_claims_evm(&kp), &kp).expect("sign");
        assert!(header_text(&token).contains("EIP191K"));
    }

    #[test]
    fn malformed_token_rejected() {
        for bad in ["not-a-jwt", "a.b.c.d"] {
            assert!(decode_unverified(bad).is_err());
        }
    }

    // ── chain consistency ─────────────────────────────────────

    #[test]
    fn sign_rejects_chain_mismatch() {
        let kp = MasterKeypair::generate();
        // iss is Solana but the chain field claims EVM.
        let bad_claims = BitrouterClaims {
            chain: Chain::base().caip2(),
            iat: None,
            ..test_claims_solana(&kp)
        };
        assert!(sign(&bad_claims, &kp).is_err());
    }

    #[test]
    fn verify_rejects_chain_mismatch_in_payload() {
        let kp = MasterKeypair::generate();
        // Sign a valid Solana token, then swap in a payload whose chain field is EVM
        // while the Solana iss is kept.
        let claims = test_claims_solana(&kp);
        let token = sign(&claims, &kp).expect("sign");

        let tampered_claims = BitrouterClaims {
            chain: Chain::base().caip2(),
            ..claims
        };
        let tampered = with_payload(&token, &tampered_claims);
        assert!(verify(&tampered).is_err());
    }
}