Skip to main content

aptos_sdk/account/
keyless.rs

1//! Keyless (OIDC-based) account support.
2
3use crate::account::account::{Account, AuthenticationKey};
4use crate::crypto::{Ed25519PrivateKey, Ed25519PublicKey, KEYLESS_SCHEME};
5use crate::error::{AptosError, AptosResult};
6use crate::types::AccountAddress;
7use async_trait::async_trait;
8use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
9use rand::RngCore;
10use serde::{Deserialize, Serialize};
11use sha3::{Digest, Sha3_256};
12use std::fmt;
13use std::time::{Duration, SystemTime, UNIX_EPOCH};
14use url::Url;
15
16// Re-export JwkSet for use with from_jwt_with_jwks and refresh_proof_with_jwks
17pub use jsonwebtoken::jwk::JwkSet;
18
/// Keyless signature payload for transaction authentication.
///
/// Built by `KeylessAccount::sign_keyless` from the ephemeral key pair and the current
/// proof. BCS encodes struct fields in declaration order, so the order of the fields
/// below is part of the byte format returned by [`KeylessSignature::to_bcs`]; do not
/// reorder them.
///
/// NOTE(review): confirm this layout matches what the Aptos node expects for keyless
/// authenticators.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct KeylessSignature {
    /// Ephemeral public key bytes.
    pub ephemeral_public_key: Vec<u8>,
    /// Signature over the message, produced by the ephemeral private key.
    pub ephemeral_signature: Vec<u8>,
    /// Zero-knowledge proof bytes.
    pub proof: Vec<u8>,
}

impl KeylessSignature {
    /// Serializes the signature using BCS.
    ///
    /// # Errors
    ///
    /// Returns an error if BCS serialization fails.
    pub fn to_bcs(&self) -> AptosResult<Vec<u8>> {
        aptos_bcs::to_bytes(self).map_err(AptosError::bcs)
    }
}
40
41/// Short-lived key pair used for keyless signing.
42#[derive(Clone)]
43pub struct EphemeralKeyPair {
44    private_key: Ed25519PrivateKey,
45    public_key: Ed25519PublicKey,
46    expiry: SystemTime,
47    nonce: String,
48}
49
50impl EphemeralKeyPair {
51    /// Generates a new ephemeral key pair with the given expiry (in seconds).
52    pub fn generate(expiry_secs: u64) -> Self {
53        let private_key = Ed25519PrivateKey::generate();
54        let public_key = private_key.public_key();
55        let nonce = {
56            let mut bytes = [0u8; 16];
57            rand::rngs::OsRng.fill_bytes(&mut bytes);
58            hex::encode(bytes)
59        };
60        Self {
61            private_key,
62            public_key,
63            expiry: SystemTime::now() + Duration::from_secs(expiry_secs),
64            nonce,
65        }
66    }
67
68    /// Returns true if the key pair has expired.
69    pub fn is_expired(&self) -> bool {
70        SystemTime::now() >= self.expiry
71    }
72
73    /// Returns the nonce associated with this key pair.
74    pub fn nonce(&self) -> &str {
75        &self.nonce
76    }
77
78    /// Returns the public key.
79    pub fn public_key(&self) -> &Ed25519PublicKey {
80        &self.public_key
81    }
82}
83
84impl fmt::Debug for EphemeralKeyPair {
85    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
86        f.debug_struct("EphemeralKeyPair")
87            .field("public_key", &self.public_key)
88            .field("expiry", &self.expiry)
89            .field("nonce", &self.nonce)
90            .finish_non_exhaustive()
91    }
92}
93
/// Supported OIDC providers.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum OidcProvider {
    /// Google identity provider.
    Google,
    /// Apple identity provider.
    Apple,
    /// Microsoft identity provider.
    Microsoft,
    /// Custom OIDC provider.
    Custom {
        /// Issuer URL.
        issuer: String,
        /// JWKS URL.
        jwks_url: String,
    },
}

impl OidcProvider {
    /// Returns the issuer URL.
    pub fn issuer(&self) -> &str {
        match self {
            OidcProvider::Google => "https://accounts.google.com",
            OidcProvider::Apple => "https://appleid.apple.com",
            OidcProvider::Microsoft => "https://login.microsoftonline.com/common/v2.0",
            OidcProvider::Custom { issuer, .. } => issuer,
        }
    }

    /// Returns the JWKS URL.
    pub fn jwks_url(&self) -> &str {
        match self {
            OidcProvider::Google => "https://www.googleapis.com/oauth2/v3/certs",
            OidcProvider::Apple => "https://appleid.apple.com/auth/keys",
            OidcProvider::Microsoft => {
                "https://login.microsoftonline.com/common/discovery/v2.0/keys"
            }
            OidcProvider::Custom { jwks_url, .. } => jwks_url,
        }
    }

    /// Infers a provider from an issuer string (typically a JWT `iss` claim).
    ///
    /// Google documents that its ID tokens may carry either `https://accounts.google.com`
    /// or the bare `accounts.google.com` as `iss`; both map to [`OidcProvider::Google`].
    ///
    /// Any unrecognised issuer becomes [`OidcProvider::Custom`]. Its JWKS URL is guessed
    /// as `{issuer}/.well-known/jwks.json`. Trailing `/` characters on the issuer are
    /// dropped when building that URL, so `https://example.com/` does not produce a
    /// `//` path segment. The stored `issuer` is kept exactly as given.
    pub fn from_issuer(issuer: &str) -> Self {
        match issuer {
            "https://accounts.google.com" | "accounts.google.com" => OidcProvider::Google,
            "https://appleid.apple.com" => OidcProvider::Apple,
            "https://login.microsoftonline.com/common/v2.0" => OidcProvider::Microsoft,
            _ => OidcProvider::Custom {
                issuer: issuer.to_string(),
                jwks_url: format!("{}/.well-known/jwks.json", issuer.trim_end_matches('/')),
            },
        }
    }
}
148
149/// Pepper bytes used in keyless address derivation.
150#[derive(Clone, Debug, PartialEq, Eq)]
151pub struct Pepper(Vec<u8>);
152
153impl Pepper {
154    /// Creates a new pepper from raw bytes.
155    pub fn new(bytes: Vec<u8>) -> Self {
156        Self(bytes)
157    }
158
159    /// Returns the pepper as bytes.
160    pub fn as_bytes(&self) -> &[u8] {
161        &self.0
162    }
163
164    /// Creates a pepper from hex.
165    ///
166    /// # Errors
167    ///
168    /// Returns an error if the hex string is invalid or cannot be decoded.
169    pub fn from_hex(hex_str: &str) -> AptosResult<Self> {
170        let hex_str = hex_str.strip_prefix("0x").unwrap_or(hex_str);
171        Ok(Self(hex::decode(hex_str)?))
172    }
173
174    /// Returns the pepper as hex.
175    pub fn to_hex(&self) -> String {
176        format!("0x{}", hex::encode(&self.0))
177    }
178}
179
180/// Zero-knowledge proof bytes.
181#[derive(Clone, Debug, PartialEq, Eq)]
182pub struct ZkProof(Vec<u8>);
183
184impl ZkProof {
185    /// Creates a new proof from raw bytes.
186    pub fn new(bytes: Vec<u8>) -> Self {
187        Self(bytes)
188    }
189
190    /// Returns the proof as bytes.
191    pub fn as_bytes(&self) -> &[u8] {
192        &self.0
193    }
194
195    /// Creates a proof from hex.
196    ///
197    /// # Errors
198    ///
199    /// Returns an error if the hex string is invalid or cannot be decoded.
200    pub fn from_hex(hex_str: &str) -> AptosResult<Self> {
201        let hex_str = hex_str.strip_prefix("0x").unwrap_or(hex_str);
202        Ok(Self(hex::decode(hex_str)?))
203    }
204
205    /// Returns the proof as hex.
206    pub fn to_hex(&self) -> String {
207        format!("0x{}", hex::encode(&self.0))
208    }
209}
210
/// Service for obtaining pepper values.
///
/// The pepper is hashed into the keyless account address together with the JWT
/// issuer, audience, and subject. The same JWT identity with a different pepper
/// therefore yields a different address.
#[async_trait]
pub trait PepperService: Send + Sync {
    /// Fetches the pepper for a JWT.
    ///
    /// # Errors
    ///
    /// Implementations return an error when no pepper can be obtained for `jwt`.
    async fn get_pepper(&self, jwt: &str) -> AptosResult<Pepper>;
}
217
/// Service for generating zero-knowledge proofs.
#[async_trait]
pub trait ProverService: Send + Sync {
    /// Generates the proof for keyless authentication.
    ///
    /// Receives the raw JWT, the ephemeral key pair whose nonce the JWT carries, and
    /// the pepper for this identity. Callers in this module invoke it on account
    /// creation and again on every proof refresh.
    ///
    /// # Errors
    ///
    /// Implementations return an error when a proof cannot be produced.
    async fn generate_proof(
        &self,
        jwt: &str,
        ephemeral_key: &EphemeralKeyPair,
        pepper: &Pepper,
    ) -> AptosResult<ZkProof>;
}
229
230/// HTTP pepper service client.
231#[derive(Clone, Debug)]
232pub struct HttpPepperService {
233    url: Url,
234    client: reqwest::Client,
235}
236
237impl HttpPepperService {
238    /// Creates a new HTTP pepper service client.
239    pub fn new(url: Url) -> Self {
240        Self {
241            url,
242            client: reqwest::Client::new(),
243        }
244    }
245}
246
/// JSON body POSTed to the pepper service: `{"jwt": "<token>"}`.
///
/// The field name is the wire format; do not rename it.
#[derive(Serialize)]
struct PepperRequest<'a> {
    jwt: &'a str,
}

/// JSON reply from the pepper service.
///
/// `pepper` is a hex string, optionally `0x`-prefixed, parsed with `Pepper::from_hex`.
#[derive(Deserialize)]
struct PepperResponse {
    pepper: String,
}
256
257#[async_trait]
258impl PepperService for HttpPepperService {
259    async fn get_pepper(&self, jwt: &str) -> AptosResult<Pepper> {
260        let response = self
261            .client
262            .post(self.url.clone())
263            .json(&PepperRequest { jwt })
264            .send()
265            .await?
266            .error_for_status()?;
267
268        let payload: PepperResponse = response.json().await?;
269        Pepper::from_hex(&payload.pepper)
270    }
271}
272
273/// HTTP prover service client.
274#[derive(Clone, Debug)]
275pub struct HttpProverService {
276    url: Url,
277    client: reqwest::Client,
278}
279
280impl HttpProverService {
281    /// Creates a new HTTP prover service client.
282    pub fn new(url: Url) -> Self {
283        Self {
284            url,
285            client: reqwest::Client::new(),
286        }
287    }
288}
289
/// JSON body POSTed to the prover service.
///
/// The field names are the wire format; do not rename them.
#[derive(Serialize)]
struct ProverRequest<'a> {
    /// The raw JWT.
    jwt: &'a str,
    /// `0x`-prefixed hex of the ephemeral public key bytes.
    ephemeral_public_key: String,
    /// Nonce of the ephemeral key pair.
    nonce: &'a str,
    /// `0x`-prefixed hex pepper, as produced by `Pepper::to_hex`.
    pepper: String,
}

/// JSON reply from the prover service.
///
/// `proof` is a hex string, optionally `0x`-prefixed, parsed with `ZkProof::from_hex`.
#[derive(Deserialize)]
struct ProverResponse {
    proof: String,
}
302
303#[async_trait]
304impl ProverService for HttpProverService {
305    async fn generate_proof(
306        &self,
307        jwt: &str,
308        ephemeral_key: &EphemeralKeyPair,
309        pepper: &Pepper,
310    ) -> AptosResult<ZkProof> {
311        let request = ProverRequest {
312            jwt,
313            ephemeral_public_key: format!("0x{}", hex::encode(ephemeral_key.public_key.to_bytes())),
314            nonce: ephemeral_key.nonce(),
315            pepper: pepper.to_hex(),
316        };
317
318        let response = self
319            .client
320            .post(self.url.clone())
321            .json(&request)
322            .send()
323            .await?
324            .error_for_status()?;
325
326        let payload: ProverResponse = response.json().await?;
327        ZkProof::from_hex(&payload.proof)
328    }
329}
330
/// Account authenticated via OIDC.
///
/// Messages are signed with a short-lived [`EphemeralKeyPair`], and the signature is
/// accompanied by a zero-knowledge proof. The account address is derived from the JWT
/// identity (`iss`, `aud`, `sub`) plus the pepper.
pub struct KeylessAccount {
    /// Short-lived key pair that signs messages.
    ephemeral_key: EphemeralKeyPair,
    /// Provider inferred from `issuer`; its JWKS URL is used by `refresh_proof`.
    provider: OidcProvider,
    /// `iss` claim of the JWT the account was created from.
    issuer: String,
    /// `aud` claim (the first entry when the JWT carried an array).
    audience: String,
    /// `sub` claim identifying the user at the provider.
    user_id: String,
    /// Pepper mixed into address derivation.
    pepper: Pepper,
    /// Current proof; replaced by `refresh_proof` / `refresh_proof_with_jwks`.
    proof: ZkProof,
    /// Address derived from issuer, audience, user id, and pepper.
    address: AccountAddress,
    /// Authentication key built from the address bytes.
    auth_key: AuthenticationKey,
    /// JWT `exp` as a `SystemTime`, or `None` if the token had no `exp` claim.
    jwt_expiration: Option<SystemTime>,
}
344
345impl KeylessAccount {
346    /// Creates a keyless account from an OIDC JWT token.
347    ///
348    /// This method verifies the JWT signature using the OIDC provider's JWKS endpoint
349    /// before extracting claims and creating the account.
350    ///
351    /// # Network Requests
352    ///
353    /// This method makes HTTP requests to:
354    /// - The OIDC provider's JWKS endpoint to fetch signing keys
355    /// - The pepper service to obtain the pepper
356    /// - The prover service to generate a ZK proof
357    ///
358    /// For more control over network calls and caching, use [`Self::from_jwt_with_jwks`]
359    /// with pre-fetched JWKS.
360    ///
361    /// # Errors
362    ///
363    /// This function will return an error if:
364    /// - The JWT signature verification fails
365    /// - The JWT cannot be decoded or is missing required claims (iss, aud, sub, nonce)
366    /// - The JWT nonce doesn't match the ephemeral key's nonce
367    /// - The JWT is expired
368    /// - The JWKS cannot be fetched from the provider (network timeout, DNS failure,
369    ///   connection errors, HTTP errors, or invalid JWKS response)
370    /// - The pepper service fails to return a pepper
371    /// - The prover service fails to generate a proof
372    pub async fn from_jwt(
373        jwt: &str,
374        ephemeral_key: EphemeralKeyPair,
375        pepper_service: &dyn PepperService,
376        prover_service: &dyn ProverService,
377    ) -> AptosResult<Self> {
378        // First, decode without verification to get the issuer for JWKS lookup
379        let unverified_claims = decode_claims_unverified(jwt)?;
380        let issuer = unverified_claims
381            .iss
382            .as_ref()
383            .ok_or_else(|| AptosError::InvalidJwt("missing iss claim".into()))?;
384
385        // Determine provider and fetch JWKS
386        let provider = OidcProvider::from_issuer(issuer);
387        let client = reqwest::Client::builder()
388            .timeout(JWKS_FETCH_TIMEOUT)
389            .build()
390            .map_err(|e| AptosError::InvalidJwt(format!("failed to create HTTP client: {e}")))?;
391        let jwks = fetch_jwks(&client, provider.jwks_url()).await?;
392
393        // Now verify and decode the JWT properly
394        let claims = decode_and_verify_jwt(jwt, &jwks)?;
395        let (issuer, audience, user_id, exp, nonce) = extract_claims(&claims)?;
396
397        if nonce != ephemeral_key.nonce() {
398            return Err(AptosError::InvalidJwt("JWT nonce mismatch".into()));
399        }
400
401        let pepper = pepper_service.get_pepper(jwt).await?;
402        let proof = prover_service
403            .generate_proof(jwt, &ephemeral_key, &pepper)
404            .await?;
405
406        let address = derive_keyless_address(&issuer, &audience, &user_id, &pepper);
407        let auth_key = AuthenticationKey::new(address.to_bytes());
408
409        Ok(Self {
410            provider: OidcProvider::from_issuer(&issuer),
411            issuer,
412            audience,
413            user_id,
414            pepper,
415            proof,
416            address,
417            auth_key,
418            jwt_expiration: exp,
419            ephemeral_key,
420        })
421    }
422
423    /// Creates a keyless account from a JWT with pre-fetched JWKS.
424    ///
425    /// This method is useful when you want to:
426    /// - Cache the JWKS to avoid repeated network requests
427    /// - Have more control over HTTP client configuration
428    /// - Implement custom caching strategies based on HTTP cache headers
429    ///
430    /// # Errors
431    ///
432    /// This function will return an error if:
433    /// - The JWT signature verification fails
434    /// - The JWT cannot be decoded or is missing required claims (iss, aud, sub, nonce)
435    /// - The JWT nonce doesn't match the ephemeral key's nonce
436    /// - The JWT is expired
437    /// - The pepper service fails to return a pepper
438    /// - The prover service fails to generate a proof
439    pub async fn from_jwt_with_jwks(
440        jwt: &str,
441        jwks: &JwkSet,
442        ephemeral_key: EphemeralKeyPair,
443        pepper_service: &dyn PepperService,
444        prover_service: &dyn ProverService,
445    ) -> AptosResult<Self> {
446        // Verify and decode the JWT using the provided JWKS
447        let claims = decode_and_verify_jwt(jwt, jwks)?;
448        let (issuer, audience, user_id, exp, nonce) = extract_claims(&claims)?;
449
450        if nonce != ephemeral_key.nonce() {
451            return Err(AptosError::InvalidJwt("JWT nonce mismatch".into()));
452        }
453
454        let pepper = pepper_service.get_pepper(jwt).await?;
455        let proof = prover_service
456            .generate_proof(jwt, &ephemeral_key, &pepper)
457            .await?;
458
459        let address = derive_keyless_address(&issuer, &audience, &user_id, &pepper);
460        let auth_key = AuthenticationKey::new(address.to_bytes());
461
462        Ok(Self {
463            provider: OidcProvider::from_issuer(&issuer),
464            issuer,
465            audience,
466            user_id,
467            pepper,
468            proof,
469            address,
470            auth_key,
471            jwt_expiration: exp,
472            ephemeral_key,
473        })
474    }
475
476    /// Returns the OIDC provider.
477    pub fn provider(&self) -> &OidcProvider {
478        &self.provider
479    }
480
481    /// Returns the issuer.
482    pub fn issuer(&self) -> &str {
483        &self.issuer
484    }
485
486    /// Returns the audience.
487    pub fn audience(&self) -> &str {
488        &self.audience
489    }
490
491    /// Returns the user identifier (sub claim).
492    pub fn user_id(&self) -> &str {
493        &self.user_id
494    }
495
496    /// Returns the proof.
497    pub fn proof(&self) -> &ZkProof {
498        &self.proof
499    }
500
501    /// Returns true if the account is still valid.
502    pub fn is_valid(&self) -> bool {
503        if self.ephemeral_key.is_expired() {
504            return false;
505        }
506
507        match self.jwt_expiration {
508            Some(exp) => SystemTime::now() < exp,
509            None => true,
510        }
511    }
512
513    /// Refreshes the proof using a new JWT.
514    ///
515    /// This method verifies the JWT signature using the OIDC provider's JWKS endpoint.
516    ///
517    /// # Network Requests
518    ///
519    /// This method makes HTTP requests to fetch the JWKS from the OIDC provider.
520    /// For more control over network calls and caching, use [`Self::refresh_proof_with_jwks`].
521    ///
522    /// # Errors
523    ///
524    /// Returns an error if:
525    /// - The JWKS cannot be fetched (network timeout, DNS failure, connection errors)
526    /// - The JWT signature verification fails
527    /// - The JWT cannot be decoded
528    /// - The JWT nonce does not match the ephemeral key
529    /// - The JWT identity does not match the account
530    /// - The prover service fails to generate a new proof
531    pub async fn refresh_proof(
532        &mut self,
533        jwt: &str,
534        prover_service: &dyn ProverService,
535    ) -> AptosResult<()> {
536        // Fetch JWKS and verify JWT
537        let client = reqwest::Client::builder()
538            .timeout(JWKS_FETCH_TIMEOUT)
539            .build()
540            .map_err(|e| AptosError::InvalidJwt(format!("failed to create HTTP client: {e}")))?;
541        let jwks = fetch_jwks(&client, self.provider.jwks_url()).await?;
542        self.refresh_proof_with_jwks(jwt, &jwks, prover_service)
543            .await
544    }
545
546    /// Refreshes the proof using a new JWT with pre-fetched JWKS.
547    ///
548    /// This method is useful for caching the JWKS or using a custom HTTP client.
549    ///
550    /// # Errors
551    ///
552    /// Returns an error if:
553    /// - The JWT signature verification fails
554    /// - The JWT cannot be decoded
555    /// - The JWT nonce does not match the ephemeral key
556    /// - The JWT identity does not match the account
557    /// - The prover service fails to generate a new proof
558    pub async fn refresh_proof_with_jwks(
559        &mut self,
560        jwt: &str,
561        jwks: &JwkSet,
562        prover_service: &dyn ProverService,
563    ) -> AptosResult<()> {
564        let claims = decode_and_verify_jwt(jwt, jwks)?;
565        let (issuer, audience, user_id, exp, nonce) = extract_claims(&claims)?;
566
567        if nonce != self.ephemeral_key.nonce() {
568            return Err(AptosError::InvalidJwt("JWT nonce mismatch".into()));
569        }
570
571        if issuer != self.issuer || audience != self.audience || user_id != self.user_id {
572            return Err(AptosError::InvalidJwt(
573                "JWT identity does not match account".into(),
574            ));
575        }
576
577        let proof = prover_service
578            .generate_proof(jwt, &self.ephemeral_key, &self.pepper)
579            .await?;
580        self.proof = proof;
581        self.jwt_expiration = exp;
582        Ok(())
583    }
584
585    /// Signs a message and returns the structured keyless signature.
586    pub fn sign_keyless(&self, message: &[u8]) -> KeylessSignature {
587        let signature = self.ephemeral_key.private_key.sign(message).to_bytes();
588        KeylessSignature {
589            ephemeral_public_key: self.ephemeral_key.public_key.to_bytes().to_vec(),
590            ephemeral_signature: signature.to_vec(),
591            proof: self.proof.as_bytes().to_vec(),
592        }
593    }
594
595    /// Creates a keyless account from pre-verified JWT claims.
596    ///
597    /// This is useful for testing or when JWT verification is handled externally.
598    /// The caller is responsible for ensuring the JWT was properly verified.
599    ///
600    /// # Errors
601    ///
602    /// This function will return an error if:
603    /// - The nonce doesn't match the ephemeral key's nonce
604    /// - The pepper service fails to return a pepper
605    /// - The prover service fails to generate a proof
606    #[doc(hidden)]
607    #[allow(clippy::too_many_arguments)]
608    pub async fn from_verified_claims(
609        issuer: String,
610        audience: String,
611        user_id: String,
612        nonce: String,
613        exp: Option<SystemTime>,
614        ephemeral_key: EphemeralKeyPair,
615        pepper_service: &dyn PepperService,
616        prover_service: &dyn ProverService,
617        jwt_for_services: &str,
618    ) -> AptosResult<Self> {
619        if nonce != ephemeral_key.nonce() {
620            return Err(AptosError::InvalidJwt("nonce mismatch".into()));
621        }
622
623        let pepper = pepper_service.get_pepper(jwt_for_services).await?;
624        let proof = prover_service
625            .generate_proof(jwt_for_services, &ephemeral_key, &pepper)
626            .await?;
627
628        let address = derive_keyless_address(&issuer, &audience, &user_id, &pepper);
629        let auth_key = AuthenticationKey::new(address.to_bytes());
630
631        Ok(Self {
632            provider: OidcProvider::from_issuer(&issuer),
633            issuer,
634            audience,
635            user_id,
636            pepper,
637            proof,
638            address,
639            auth_key,
640            jwt_expiration: exp,
641            ephemeral_key,
642        })
643    }
644}
645
646impl Account for KeylessAccount {
647    fn address(&self) -> AccountAddress {
648        self.address
649    }
650
651    fn authentication_key(&self) -> AuthenticationKey {
652        self.auth_key
653    }
654
655    fn sign(&self, message: &[u8]) -> crate::error::AptosResult<Vec<u8>> {
656        let signature = self.sign_keyless(message);
657        signature
658            .to_bcs()
659            .map_err(|e| crate::error::AptosError::Bcs(e.to_string()))
660    }
661
662    fn public_key_bytes(&self) -> Vec<u8> {
663        self.ephemeral_key.public_key.to_bytes().to_vec()
664    }
665
666    fn signature_scheme(&self) -> u8 {
667        KEYLESS_SCHEME
668    }
669}
670
671impl fmt::Debug for KeylessAccount {
672    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
673        f.debug_struct("KeylessAccount")
674            .field("address", &self.address)
675            .field("provider", &self.provider)
676            .field("issuer", &self.issuer)
677            .field("audience", &self.audience)
678            .field("user_id", &self.user_id)
679            .finish_non_exhaustive()
680    }
681}
682
/// Subset of JWT claims read by this module.
///
/// Every field is an `Option`, so decoding succeeds even when a claim is absent.
/// `extract_claims` then reports which required claim is missing.
#[derive(Debug, Deserialize)]
struct JwtClaims {
    /// Issuer (`iss`).
    iss: Option<String>,
    /// Audience (`aud`): either a single string or an array of strings.
    aud: Option<AudClaim>,
    /// Subject (`sub`), used as the account's user id.
    sub: Option<String>,
    /// Expiration (`exp`) in seconds since the Unix epoch.
    exp: Option<u64>,
    /// Nonce (`nonce`), compared against the ephemeral key pair's nonce.
    nonce: Option<String>,
}
691
692#[derive(Debug, Deserialize)]
693#[serde(untagged)]
694enum AudClaim {
695    Single(String),
696    Multiple(Vec<String>),
697}
698
699impl AudClaim {
700    fn first(&self) -> Option<&str> {
701        match self {
702            AudClaim::Single(value) => Some(value.as_str()),
703            AudClaim::Multiple(values) => values.first().map(std::string::String::as_str),
704        }
705    }
706}
707
/// Timeout (10 seconds) configured on the `reqwest` client built for JWKS fetches.
const JWKS_FETCH_TIMEOUT: Duration = Duration::from_secs(10);
710
711/// Fetches the JWKS (JSON Web Key Set) from an OIDC provider.
712///
713/// # Errors
714///
715/// Returns an error if:
716/// - The JWKS cannot be fetched (network timeouts, DNS resolution failures,
717///   TLS/connection errors, or HTTP errors)
718/// - The JWKS endpoint returns a non-success status code
719/// - The response cannot be parsed as valid JWKS JSON
720async fn fetch_jwks(client: &reqwest::Client, jwks_url: &str) -> AptosResult<JwkSet> {
721    // Note: timeout is configured on the client, not per-request
722    let response = client.get(jwks_url).send().await?;
723
724    if !response.status().is_success() {
725        return Err(AptosError::InvalidJwt(format!(
726            "JWKS endpoint returned status: {}",
727            response.status()
728        )));
729    }
730
731    let jwks: JwkSet = response.json().await?;
732    Ok(jwks)
733}
734
735/// Decodes and verifies a JWT using the provided JWKS.
736///
737/// This function:
738/// 1. Extracts the `kid` (key ID) from the JWT header
739/// 2. Finds the matching key in the JWKS
740/// 3. Verifies the signature and decodes the claims
741///
742/// # Errors
743///
744/// Returns an error if:
745/// - The JWT header cannot be decoded
746/// - No matching key is found in the JWKS
747/// - The signature verification fails
748/// - The claims cannot be decoded
749fn decode_and_verify_jwt(jwt: &str, jwks: &JwkSet) -> AptosResult<JwtClaims> {
750    // Decode header to get the key ID
751    let header = decode_header(jwt)
752        .map_err(|e| AptosError::InvalidJwt(format!("failed to decode JWT header: {e}")))?;
753
754    let kid = header
755        .kid
756        .as_ref()
757        .ok_or_else(|| AptosError::InvalidJwt("JWT header missing 'kid' field".into()))?;
758
759    // Find the matching key in the JWKS
760    let signing_key = jwks.find(kid).ok_or_else(|| {
761        AptosError::InvalidJwt("no matching key found for provided key identifier".into())
762    })?;
763
764    // Create decoding key from JWK
765    let decoding_key = DecodingKey::from_jwk(signing_key)
766        .map_err(|e| AptosError::InvalidJwt(format!("failed to create decoding key: {e}")))?;
767
768    // Determine the algorithm strictly from the JWK to prevent algorithm substitution attacks
769    let jwk_alg = signing_key
770        .common
771        .key_algorithm
772        .ok_or_else(|| AptosError::InvalidJwt("JWK missing 'alg' (key_algorithm) field".into()))?;
773
774    let algorithm = match jwk_alg {
775        // RSA algorithms
776        jsonwebtoken::jwk::KeyAlgorithm::RS256 => Algorithm::RS256,
777        jsonwebtoken::jwk::KeyAlgorithm::RS384 => Algorithm::RS384,
778        jsonwebtoken::jwk::KeyAlgorithm::RS512 => Algorithm::RS512,
779        // RSA-PSS algorithms
780        jsonwebtoken::jwk::KeyAlgorithm::PS256 => Algorithm::PS256,
781        jsonwebtoken::jwk::KeyAlgorithm::PS384 => Algorithm::PS384,
782        jsonwebtoken::jwk::KeyAlgorithm::PS512 => Algorithm::PS512,
783        // ECDSA algorithms
784        jsonwebtoken::jwk::KeyAlgorithm::ES256 => Algorithm::ES256,
785        jsonwebtoken::jwk::KeyAlgorithm::ES384 => Algorithm::ES384,
786        // EdDSA algorithm
787        jsonwebtoken::jwk::KeyAlgorithm::EdDSA => Algorithm::EdDSA,
788        _ => {
789            return Err(AptosError::InvalidJwt(format!(
790                "unsupported JWK algorithm: {jwk_alg:?}"
791            )));
792        }
793    };
794
795    // Ensure the JWT header algorithm matches the JWK algorithm to prevent substitution
796    if header.alg != algorithm {
797        return Err(AptosError::InvalidJwt(format!(
798            "JWT header algorithm ({:?}) does not match JWK algorithm ({:?})",
799            header.alg, algorithm
800        )));
801    }
802
803    // Configure validation - we'll validate exp ourselves with more detailed errors
804    let mut validation = Validation::new(algorithm);
805    validation.validate_exp = false;
806    validation.validate_aud = false; // We'll check aud after decoding
807    validation.set_required_spec_claims::<String>(&[]);
808
809    let data = decode::<JwtClaims>(jwt, &decoding_key, &validation)
810        .map_err(|e| AptosError::InvalidJwt(format!("JWT verification failed: {e}")))?;
811
812    Ok(data.claims)
813}
814
/// Decodes JWT claims without signature verification.
///
/// In this file it is called from `KeylessAccount::from_jwt`, only to read the `iss`
/// claim so a JWKS endpoint can be chosen before the token is verified. The token is
/// then fully verified with `decode_and_verify_jwt` against that key set.
///
/// Security caveat: the issuer read here is untrusted input, and it *does* select the
/// trust anchor.
/// - For the built-in providers in `OidcProvider::from_issuer`, the JWKS URL is fixed.
/// - For any other issuer, the URL is derived from `iss` itself.
///
/// The later verification therefore only shows the token was signed by a key served
/// from that issuer-chosen URL, not that the issuer is trusted. Callers that need an
/// issuer allow-list must enforce it themselves.
///
/// # Errors
///
/// Returns `AptosError::InvalidJwt` if the token cannot be decoded into `JwtClaims`.
fn decode_claims_unverified(jwt: &str) -> AptosResult<JwtClaims> {
    // `insecure_decode` performs no signature check. Nothing read here may be trusted
    // beyond choosing which JWKS to fetch (see the caveat above); the signature and
    // claims are checked by `decode_and_verify_jwt` once that JWKS is available.
    let data = jsonwebtoken::dangerous::insecure_decode::<JwtClaims>(jwt)
        .map_err(|e| AptosError::InvalidJwt(format!("failed to decode JWT claims: {e}")))?;
    Ok(data.claims)
}
832
833fn extract_claims(
834    claims: &JwtClaims,
835) -> AptosResult<(String, String, String, Option<SystemTime>, String)> {
836    let issuer = claims
837        .iss
838        .clone()
839        .ok_or_else(|| AptosError::InvalidJwt("missing iss claim".into()))?;
840    let audience = claims
841        .aud
842        .as_ref()
843        .and_then(|aud| aud.first())
844        .map(std::string::ToString::to_string)
845        .ok_or_else(|| AptosError::InvalidJwt("missing aud claim".into()))?;
846    let user_id = claims
847        .sub
848        .clone()
849        .ok_or_else(|| AptosError::InvalidJwt("missing sub claim".into()))?;
850    let nonce = claims
851        .nonce
852        .clone()
853        .ok_or_else(|| AptosError::InvalidJwt("missing nonce claim".into()))?;
854
855    let exp_time = claims.exp.map(|exp| UNIX_EPOCH + Duration::from_secs(exp));
856    if let Some(exp) = exp_time
857        && SystemTime::now() >= exp
858    {
859        let exp_secs = claims.exp.unwrap_or(0);
860        return Err(AptosError::InvalidJwt(format!(
861            "JWT is expired (exp: {exp_secs} seconds since UNIX_EPOCH)"
862        )));
863    }
864
865    Ok((issuer, audience, user_id, exp_time, nonce))
866}
867
868fn derive_keyless_address(
869    issuer: &str,
870    audience: &str,
871    user_id: &str,
872    pepper: &Pepper,
873) -> AccountAddress {
874    let issuer_hash = sha3_256_bytes(issuer.as_bytes());
875    let audience_hash = sha3_256_bytes(audience.as_bytes());
876    let user_hash = sha3_256_bytes(user_id.as_bytes());
877
878    let mut hasher = Sha3_256::new();
879    hasher.update(issuer_hash);
880    hasher.update(audience_hash);
881    hasher.update(user_hash);
882    hasher.update(pepper.as_bytes());
883    hasher.update([KEYLESS_SCHEME]);
884    let result = hasher.finalize();
885
886    let mut address = [0u8; 32];
887    address.copy_from_slice(&result);
888    AccountAddress::new(address)
889}
890
891fn sha3_256_bytes(data: &[u8]) -> [u8; 32] {
892    let mut hasher = Sha3_256::new();
893    hasher.update(data);
894    let result = hasher.finalize();
895    let mut output = [0u8; 32];
896    output.copy_from_slice(&result);
897    output
898}
899
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{Algorithm, EncodingKey, Header, encode};

    /// Issuer used by every JWT fixture in this module.
    const GOOGLE_ISSUER: &str = "https://accounts.google.com";

    /// Pepper service that ignores the JWT and always returns a fixed pepper.
    struct StaticPepperService {
        pepper: Pepper,
    }

    #[async_trait]
    impl PepperService for StaticPepperService {
        async fn get_pepper(&self, _jwt: &str) -> AptosResult<Pepper> {
            Ok(self.pepper.clone())
        }
    }

    /// Prover service that ignores its inputs and always returns a fixed proof.
    struct StaticProverService {
        proof: ZkProof,
    }

    #[async_trait]
    impl ProverService for StaticProverService {
        async fn generate_proof(
            &self,
            _jwt: &str,
            _ephemeral_key: &EphemeralKeyPair,
            _pepper: &Pepper,
        ) -> AptosResult<ZkProof> {
            Ok(self.proof.clone())
        }
    }

    /// Minimal OIDC claim set serialized into the test JWTs.
    #[derive(Serialize, Deserialize)]
    struct TestClaims {
        iss: String,
        aud: String,
        sub: String,
        exp: u64,
        nonce: String,
    }

    impl TestClaims {
        /// Claims issued by `GOOGLE_ISSUER` that expire one hour after `now`
        /// (seconds since the UNIX epoch).
        fn google(aud: &str, sub: &str, nonce: &str, now: u64) -> Self {
            Self {
                iss: GOOGLE_ISSUER.to_string(),
                aud: aud.to_string(),
                sub: sub.to_string(),
                exp: now + 3600,
                nonce: nonce.to_string(),
            }
        }
    }

    /// Current wall-clock time in whole seconds since the UNIX epoch.
    fn now_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("time went backwards")
            .as_secs()
    }

    /// Encodes `claims` as a JWT signed with HS256 under a fixed shared secret.
    fn sign_hs256(header: &Header, claims: &TestClaims) -> String {
        encode(header, claims, &EncodingKey::from_secret(b"secret"))
            .expect("failed to encode test JWT")
    }

    /// Fixed pepper/prover services shared by the account-construction tests.
    fn static_services() -> (StaticPepperService, StaticProverService) {
        (
            StaticPepperService {
                pepper: Pepper::new(vec![1, 2, 3, 4]),
            },
            StaticProverService {
                proof: ZkProof::new(vec![9, 9, 9]),
            },
        )
    }

    #[tokio::test]
    async fn test_keyless_account_creation() {
        let ephemeral = EphemeralKeyPair::generate(3600);
        let nonce = ephemeral.nonce().to_string();
        let now = now_secs();

        // Token handed to the services (they don't validate it).
        let jwt = sign_hs256(
            &Header::new(Algorithm::HS256),
            &TestClaims::google("client-id", "user-123", &nonce, now),
        );
        let (pepper_service, prover_service) = static_services();

        // Use from_verified_claims for unit testing since we can't mock JWKS.
        let exp_time = UNIX_EPOCH + Duration::from_secs(now + 3600);
        let account = KeylessAccount::from_verified_claims(
            GOOGLE_ISSUER.to_string(),
            "client-id".to_string(),
            "user-123".to_string(),
            nonce,
            Some(exp_time),
            ephemeral,
            &pepper_service,
            &prover_service,
            &jwt,
        )
        .await
        .expect("keyless account creation should succeed");

        assert_eq!(account.issuer(), GOOGLE_ISSUER);
        assert_eq!(account.audience(), "client-id");
        assert_eq!(account.user_id(), "user-123");
        assert!(account.is_valid());
        assert!(!account.address().is_zero());
    }

    #[tokio::test]
    async fn test_keyless_account_nonce_mismatch() {
        let ephemeral = EphemeralKeyPair::generate(3600);
        let jwt = sign_hs256(
            &Header::new(Algorithm::HS256),
            &TestClaims::google("client-id", "user-123", ephemeral.nonce(), now_secs()),
        );
        let (pepper_service, prover_service) = static_services();

        // Pass a nonce that does not match ephemeral.nonce() to trigger the mismatch.
        let result = KeylessAccount::from_verified_claims(
            GOOGLE_ISSUER.to_string(),
            "client-id".to_string(),
            "user-123".to_string(),
            "wrong-nonce".to_string(),
            None,
            ephemeral,
            &pepper_service,
            &prover_service,
            &jwt,
        )
        .await;

        assert!(
            matches!(result, Err(AptosError::InvalidJwt(_))),
            "expected AptosError::InvalidJwt for a nonce mismatch"
        );
    }

    #[test]
    fn test_decode_claims_unverified() {
        let jwt = sign_hs256(
            &Header::new(Algorithm::HS256),
            &TestClaims::google("test-aud", "test-sub", "test-nonce", now_secs()),
        );

        let decoded = decode_claims_unverified(&jwt).unwrap();
        assert_eq!(decoded.iss.unwrap(), GOOGLE_ISSUER);
        assert_eq!(decoded.sub.unwrap(), "test-sub");
        assert_eq!(decoded.nonce.unwrap(), "test-nonce");
    }

    #[test]
    fn test_oidc_provider_detection() {
        assert!(matches!(
            OidcProvider::from_issuer(GOOGLE_ISSUER),
            OidcProvider::Google
        ));
        assert!(matches!(
            OidcProvider::from_issuer("https://appleid.apple.com"),
            OidcProvider::Apple
        ));
        assert!(matches!(
            OidcProvider::from_issuer("https://unknown.example.com"),
            OidcProvider::Custom { .. }
        ));
    }

    #[test]
    fn test_decode_and_verify_jwt_missing_kid() {
        // HS256 JWT whose header carries no kid.
        let jwt = sign_hs256(
            &Header::new(Algorithm::HS256),
            &TestClaims::google("test-aud", "test-sub", "test-nonce", now_secs()),
        );

        // Empty JWKS.
        let jwks = JwkSet { keys: vec![] };

        let err = decode_and_verify_jwt(&jwt, &jwks)
            .expect_err("a JWT without a kid must be rejected");
        assert!(
            matches!(&err, AptosError::InvalidJwt(msg) if msg.contains("kid")),
            "Expected error about missing kid, got: {err:?}"
        );
    }

    #[test]
    fn test_decode_and_verify_jwt_no_matching_key() {
        // Header carries a kid, but the JWKS below has no keys at all.
        let mut header = Header::new(Algorithm::HS256);
        header.kid = Some("test-kid-123".to_string());
        let jwt = sign_hs256(
            &header,
            &TestClaims::google("test-aud", "test-sub", "test-nonce", now_secs()),
        );

        // Empty JWKS - no matching key.
        let jwks = JwkSet { keys: vec![] };

        let err = decode_and_verify_jwt(&jwt, &jwks)
            .expect_err("a JWT whose kid is absent from the JWKS must be rejected");
        assert!(
            matches!(&err, AptosError::InvalidJwt(msg) if msg.contains("no matching key")),
            "Expected error about no matching key, got: {err:?}"
        );
    }

    #[test]
    fn test_decode_and_verify_jwt_invalid_jwt_format() {
        let jwks = JwkSet { keys: vec![] };

        // Completely invalid JWT.
        assert!(decode_and_verify_jwt("not-a-valid-jwt", &jwks).is_err());

        // JWT with invalid base64.
        assert!(decode_and_verify_jwt("aaa.bbb.ccc", &jwks).is_err());
    }
}