iota_sdk_types/crypto/
zklogin.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2025 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use super::SimpleSignature;
6use crate::{checkpoint::EpochId, u256::U256};
7
/// A zklogin authenticator
///
/// Bundles the zklogin proof inputs, the maximum epoch for which the proof is
/// valid, and the user's signature.
///
/// # BCS
///
/// The BCS serialized form for this type is defined by the following ABNF:
///
/// ```text
/// zklogin-bcs = bytes             ; contents are defined by <zklogin>
/// zklogin     = zklogin-flag
///               zklogin-inputs
///               u64               ; max epoch
///               simple-signature
/// ```
///
/// Note: Due to historical reasons, signatures are serialized slightly
/// differently from the majority of the types in IOTA. In particular, if a
/// signature is ever embedded in another structure it is generally serialized
/// as `bytes`, meaning it has a length prefix that defines the length of
/// the completely serialized signature.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
pub struct ZkLoginAuthenticator {
    /// Zklogin proof and inputs required to perform proof verification.
    pub inputs: ZkLoginInputs,
    /// Maximum epoch for which the proof is valid.
    #[cfg_attr(feature = "schemars", schemars(with = "crate::_schemars::U64"))]
    pub max_epoch: EpochId,
    /// User signature with the pubkey attested to by the provided proof.
    pub signature: SimpleSignature,
}
39
/// A zklogin groth16 proof and the required inputs to perform proof
/// verification.
///
/// Only the four raw inputs listed in the ABNF below are serialized. The
/// remaining fields are derived from them and validated when the value is
/// constructed.
///
/// # BCS
///
/// The BCS serialized form for this type is defined by the following ABNF:
///
/// ```text
/// zklogin-inputs = zklogin-proof
///                  zklogin-claim
///                  string              ; base64url-unpadded encoded JwtHeader
///                  bn254-field-element ; address_seed
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ZkLoginInputs {
    /// Groth16 proof points.
    proof_points: ZkLoginProof,
    /// Base64url fragment containing the `iss` claim.
    iss_base64_details: ZkLoginClaim,
    /// Unpadded base64url-encoded JWT header.
    header_base64: String,
    // Validated types, derived from the raw inputs above; not serialized.
    /// Decoded form of `header_base64`.
    jwt_header: JwtHeader,
    /// The issuer paired with the JWT header's `kid`.
    jwk_id: JwkId,
    /// The issuer paired with the address seed.
    public_identifier: ZkLoginPublicIdentifier,
}
63
64impl ZkLoginInputs {
65    #[cfg(feature = "serde")]
66    #[cfg_attr(doc_cfg, doc(cfg(feature = "serde")))]
67    pub fn new(
68        proof_points: ZkLoginProof,
69        iss_base64_details: ZkLoginClaim,
70        header_base64: String,
71        address_seed: Bn254FieldElement,
72    ) -> Result<Self, InvalidZkLoginAuthenticatorError> {
73        let iss = {
74            const ISS: &str = "iss";
75
76            let iss = iss_base64_details.verify_extended_claim(ISS)?;
77
78            if iss.len() > 255 {
79                return Err(InvalidZkLoginAuthenticatorError::new(
80                    "invalid iss: too long",
81                ));
82            }
83            iss
84        };
85
86        let jwt_header = JwtHeader::from_base64(&header_base64)?;
87        let jwk_id = JwkId {
88            iss: iss.clone(),
89            kid: jwt_header.kid.clone(),
90        };
91
92        let public_identifier = ZkLoginPublicIdentifier { iss, address_seed };
93
94        Ok(Self {
95            proof_points,
96            iss_base64_details,
97            header_base64,
98            jwt_header,
99            jwk_id,
100            public_identifier,
101        })
102    }
103
104    pub fn proof_points(&self) -> &ZkLoginProof {
105        &self.proof_points
106    }
107
108    pub fn iss_base64_details(&self) -> &ZkLoginClaim {
109        &self.iss_base64_details
110    }
111
112    pub fn header_base64(&self) -> &str {
113        &self.header_base64
114    }
115
116    pub fn address_seed(&self) -> &Bn254FieldElement {
117        &self.public_identifier.address_seed
118    }
119
120    pub fn jwk_id(&self) -> &JwkId {
121        &self.jwk_id
122    }
123
124    pub fn iss(&self) -> &str {
125        &self.public_identifier.iss
126    }
127
128    pub fn public_identifier(&self) -> &ZkLoginPublicIdentifier {
129        &self.public_identifier
130    }
131}
132
// The JSON schema mirrors the serialized form of `ZkLoginInputs`. Only the
// four raw inputs are described; the derived, validated fields are never
// serialized.
#[cfg(feature = "schemars")]
impl schemars::JsonSchema for ZkLoginInputs {
    fn schema_name() -> String {
        "ZkLoginInputs".to_owned()
    }

    fn json_schema(generator: &mut schemars::r#gen::SchemaGenerator) -> schemars::schema::Schema {
        // Stand-in with the same fields (and order) as the serialized form.
        // `//` comments are used here because `///` docs would become schema
        // descriptions.
        #[derive(schemars::JsonSchema)]
        #[expect(unused)]
        struct Inputs {
            proof_points: ZkLoginProof,
            iss_base64_details: ZkLoginClaim,
            header_base64: String,
            address_seed: Bn254FieldElement,
        }

        Inputs::json_schema(generator)
    }
}
152
#[cfg(feature = "proptest")]
impl proptest::arbitrary::Arbitrary for ZkLoginInputs {
    type Parameters = ();
    type Strategy = proptest::strategy::BoxedStrategy<Self>;

    /// Generates inputs with arbitrary proof points and address seed.
    ///
    /// The `iss` claim and the JWT header are fixed values that
    /// `ZkLoginInputs::new` is expected to accept; the strategy panics
    /// otherwise.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        // TODO implement Arbitrary for real for ZkLoginClaim and header_base64 values
        const ISS_CLAIM_BASE64: &str = "wiaXNzIjoiaHR0cHM6Ly9pZC50d2l0Y2gudHYvb2F1dGgyIiw";
        const ISS_CLAIM_INDEX_MOD_4: u8 = 2;
        const HEADER_BASE64: &str = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6IjEifQ";

        let build = |(proof_points, address_seed): (ZkLoginProof, Bn254FieldElement)| {
            let claim = ZkLoginClaim {
                value: ISS_CLAIM_BASE64.to_owned(),
                index_mod_4: ISS_CLAIM_INDEX_MOD_4,
            };
            Self::new(proof_points, claim, HEADER_BASE64.to_owned(), address_seed).unwrap()
        };

        (any::<ZkLoginProof>(), any::<Bn254FieldElement>())
            .prop_map(build)
            .boxed()
    }
}
180
/// A claim of the iss in a zklogin proof
///
/// Holds a base64url-encoded fragment cut out of a larger base64url string
/// (presumably the JWT payload — TODO confirm), together with the position of
/// the fragment's first character within that string, reduced mod 4.
///
/// # BCS
///
/// The BCS serialized form for this type is defined by the following ABNF:
///
/// ```text
/// zklogin-claim = string u8
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
pub struct ZkLoginClaim {
    /// Base64url characters of the claim fragment.
    pub value: String,
    /// Offset (mod 4) of `value`'s first character in the full base64url
    /// string; only 0, 1 and 2 are accepted when the claim is decoded.
    pub index_mod_4: u8,
}
198
/// Error returned when zklogin inputs fail validation.
///
/// Holds a human-readable description of the failed check. [`Display`]
/// renders it with an `invalid zklogin claim: ` prefix.
///
/// [`Display`]: std::fmt::Display
#[derive(Debug)]
pub struct InvalidZkLoginAuthenticatorError(String);

#[cfg(feature = "serde")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "serde")))]
impl InvalidZkLoginAuthenticatorError {
    /// Creates an error carrying `err` as its description.
    fn new<T: Into<String>>(err: T) -> Self {
        let description: String = err.into();
        Self(description)
    }
}

impl std::fmt::Display for InvalidZkLoginAuthenticatorError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(description) = self;
        f.write_str("invalid zklogin claim: ")?;
        f.write_str(description)
    }
}

impl std::error::Error for InvalidZkLoginAuthenticatorError {}
217
#[cfg(feature = "serde")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "serde")))]
impl ZkLoginClaim {
    /// Decodes this claim fragment and returns the string stored under
    /// `expected_key`.
    ///
    /// `value` is a base64url fragment of a larger base64url string, and
    /// `index_mod_4` is the position (mod 4) of its first character in that
    /// string. Bits at either end that belong to bytes only partially covered
    /// by the fragment are discarded. The remaining whole bytes are decoded as
    /// UTF-8. The decoded text must end with `}` or `,`; that last character is
    /// dropped and the rest is wrapped in braces and parsed as a JSON object.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidZkLoginAuthenticatorError`] in any of these cases:
    ///
    /// - `value` is shorter than two characters or contains non-base64url
    ///   characters.
    /// - `index_mod_4` is not 0, 1 or 2.
    /// - The fragment does not end on a byte boundary.
    /// - The decoded bytes are not UTF-8.
    /// - The decoded text does not end with `}` or `,`.
    /// - The wrapped text is not valid JSON.
    /// - `expected_key` is missing or its value is not a string.
    fn verify_extended_claim(
        &self,
        expected_key: &str,
    ) -> Result<String, InvalidZkLoginAuthenticatorError> {
        /// Map a base64 string to a bit array by taking each char's index and
        /// converting it to binary form with one bit per u8 element in the
        /// output (6 bits per character, most significant first). Returns an
        /// error if one of the characters is not in the base64url charset.
        fn base64_to_bitarray(input: &str) -> Result<Vec<u8>, InvalidZkLoginAuthenticatorError> {
            use itertools::Itertools;

            const BASE64_URL_CHARSET: &str =
                "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";

            input
                .chars()
                .map(|c| {
                    BASE64_URL_CHARSET
                        .find(c)
                        .map(|index| index as u8)
                        .map(|index| (0..6).rev().map(move |i| (index >> i) & 1))
                        .ok_or_else(|| {
                            InvalidZkLoginAuthenticatorError::new(
                                "base64_to_bitarray invalid input",
                            )
                        })
                })
                .flatten_ok()
                .collect()
        }

        /// Convert a bitarray (each bit is represented by a u8) to a byte array
        /// by taking each 8 bits as a byte in big-endian format.
        fn bitarray_to_bytearray(bits: &[u8]) -> Result<Vec<u8>, InvalidZkLoginAuthenticatorError> {
            if !bits.len().is_multiple_of(8) {
                return Err(InvalidZkLoginAuthenticatorError::new(
                    "bitarray_to_bytearray invalid input",
                ));
            }
            Ok(bits
                .chunks_exact(8)
                .map(|chunk| chunk.iter().fold(0u8, |byte, bit| (byte << 1) | bit))
                .collect())
        }

        /// Decode a base64url fragment whose first character sits at offset
        /// `index_mod_4` within its 4-character group. Partial bytes at either
        /// end are discarded, and the remaining bytes are returned as a UTF-8
        /// string.
        fn decode_base64_url(
            s: &str,
            index_mod_4: u8,
        ) -> Result<String, InvalidZkLoginAuthenticatorError> {
            if s.len() < 2 {
                return Err(InvalidZkLoginAuthenticatorError::new(
                    "Base64 string smaller than 2",
                ));
            }
            let mut bits = base64_to_bitarray(s)?;

            // A fragment starting at offset 1 or 2 begins 2 or 4 bits into a
            // byte that started before the fragment; drop those bits.
            match index_mod_4 {
                0 => {}
                1 => {
                    bits.drain(..2);
                }
                2 => {
                    bits.drain(..4);
                }
                _ => {
                    return Err(InvalidZkLoginAuthenticatorError::new(
                        "Invalid first_char_offset",
                    ));
                }
            }

            // Offset of the last character within its 4-character group.
            // Computed in `usize` so that long fragments neither truncate
            // `s.len()` nor overflow (which would panic with overflow checks on).
            // `s.len() >= 2`, so the subtraction cannot underflow.
            let last_char_offset = (usize::from(index_mod_4) + s.len() - 1) % 4;
            match last_char_offset {
                3 => {}
                2 => bits.truncate(bits.len() - 2),
                1 => bits.truncate(bits.len() - 4),
                _ => {
                    return Err(InvalidZkLoginAuthenticatorError::new(
                        "Invalid last_char_offset",
                    ));
                }
            }

            if !bits.len().is_multiple_of(8) {
                return Err(InvalidZkLoginAuthenticatorError::new("Invalid bits length"));
            }

            String::from_utf8(bitarray_to_bytearray(&bits)?)
                .map_err(|_| InvalidZkLoginAuthenticatorError::new("Invalid UTF8 string"))
        }

        let extended_claim = decode_base64_url(&self.value, self.index_mod_4)?;

        // The last character of each extracted claim must be '}' or ','. Drop it
        // and wrap the rest in braces so the fragment parses as a JSON object.
        let claim_body = extended_claim
            .strip_suffix(['}', ','])
            .ok_or_else(|| InvalidZkLoginAuthenticatorError::new("Invalid extended claim"))?;
        let json_str = format!("{{{claim_body}}}");

        serde_json::from_str::<serde_json::Value>(&json_str)
            .map_err(|e| InvalidZkLoginAuthenticatorError::new(e.to_string()))?
            .as_object_mut()
            .and_then(|o| o.get_mut(expected_key))
            .map(serde_json::Value::take)
            .and_then(|v| match v {
                serde_json::Value::String(s) => Some(s),
                _ => None,
            })
            .ok_or_else(|| InvalidZkLoginAuthenticatorError::new("invalid extended claim"))
    }
}
347
/// Struct that represents a standard JWT header according to
/// <https://openid.net/specs/openid-connect-core-1_0.html>.
#[derive(Debug, Clone, PartialEq, Eq)]
struct JwtHeader {
    /// Signing algorithm; `from_base64` only accepts `RS256`.
    alg: String,
    /// Key id; paired with the issuer to form a `JwkId`.
    kid: String,
    /// Optional token type.
    typ: Option<String>,
}
356
357impl JwtHeader {
358    #[cfg(feature = "serde")]
359    fn from_base64(s: &str) -> Result<Self, InvalidZkLoginAuthenticatorError> {
360        use base64ct::{Base64UrlUnpadded, Encoding};
361
362        #[derive(serde::Serialize, serde::Deserialize)]
363        struct Header {
364            alg: String,
365            kid: String,
366            #[serde(skip_serializing_if = "Option::is_none")]
367            typ: Option<String>,
368        }
369
370        let header_bytes = Base64UrlUnpadded::decode_vec(s)
371            .map_err(|e| InvalidZkLoginAuthenticatorError::new(format!("invalid base64: {e}")))?;
372        let Header { alg, kid, typ } = serde_json::from_slice(&header_bytes)
373            .map_err(|e| InvalidZkLoginAuthenticatorError::new(format!("invalid json: {e}")))?;
374        if alg != "RS256" {
375            return Err(InvalidZkLoginAuthenticatorError::new(
376                "jwt alg must be RS256",
377            ));
378        }
379        Ok(Self { alg, kid, typ })
380    }
381}
382
/// A zklogin groth16 proof
///
/// Consists of two G1 points (`a`, `c`) and one G2 point (`b`).
///
/// # BCS
///
/// The BCS serialized form for this type is defined by the following ABNF:
///
/// ```text
/// zklogin-proof = circom-g1 circom-g2 circom-g1
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
pub struct ZkLoginProof {
    /// The proof's A element (a G1 point).
    pub a: CircomG1,
    /// The proof's B element (a G2 point).
    pub b: CircomG2,
    /// The proof's C element (a G1 point).
    pub c: CircomG1,
}
401
/// A G1 point
///
/// This represents the canonical decimal representation of the projective
/// coordinates in Fq: three field elements, one per coordinate.
///
/// # BCS
///
/// The BCS serialized form for this type is defined by the following ABNF:
///
/// ```text
/// circom-g1 = %x03 3(bn254-field-element)
/// ```
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
pub struct CircomG1(pub [Bn254FieldElement; 3]);
418
/// A G2 point
///
/// This represents the canonical decimal representation of the coefficients of
/// the projective coordinates in Fq2: three coordinates, each given as its two
/// Fq coefficients.
///
/// # BCS
///
/// The BCS serialized form for this type is defined by the following ABNF:
///
/// ```text
/// circom-g2 = %x03 3(%x02 2(bn254-field-element))
/// ```
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
pub struct CircomG2(pub [[Bn254FieldElement; 2]; 3]);
435
/// Public Key equivalent for Zklogin authenticators
///
/// A `ZkLoginPublicIdentifier` is the equivalent of a public key for other
/// account authenticators, and contains the information required to derive the
/// onchain account [`Address`] for a Zklogin authenticator.
///
/// ## Note
///
/// Due to a historical bug that was introduced in the IOTA Typescript SDK when
/// the zklogin authenticator was first introduced, there are now possibly two
/// "valid" addresses for each zklogin authenticator depending on the
/// bit-pattern of the `address_seed` value.
///
/// The original bug incorrectly derived a zklogin's address by stripping any
/// leading zero-bytes that could have been present in the 32-byte length
/// `address_seed` value prior to hashing, leading to a different derived
/// address. This incorrectly derived address was presented to users of various
/// wallets, leading them to send funds to these addresses that they couldn't
/// access. Instead of letting these users lose any assets that were sent to
/// these addresses, the IOTA network decided to change the protocol to allow
/// a zklogin authenticator whose `address_seed` value had leading zero-bytes
/// to be authorized to sign for the addresses derived from both the unpadded
/// and padded `address_seed` value.
///
/// # BCS
///
/// The BCS serialized form for this type is defined by the following ABNF:
///
/// ```text
/// zklogin-public-identifier-bcs = bytes ; where the contents are defined by
///                                       ; <zklogin-public-identifier>
///
/// zklogin-public-identifier = zklogin-public-identifier-iss
///                             address-seed
///
/// zklogin-public-identifier-unpadded = zklogin-public-identifier-iss
///                                      address-seed-unpadded
///
/// ; The iss, or issuer, is a utf8 string that is at most 255 bytes long
/// ; and is serialized with the iss's length in bytes as a u8 followed by
/// ; the bytes of the iss
/// zklogin-public-identifier-iss = u8 *255(OCTET)
///
/// ; A Bn254FieldElement serialized as a 32-byte big-endian value
/// address-seed = 32(OCTET)
///
/// ; A Bn254FieldElement serialized as a 32-byte big-endian value
/// ; with any leading zero bytes stripped
/// address-seed-unpadded = %x00 / %x01-ff *31(OCTET)
/// ```
///
/// [`Address`]: crate::Address
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
pub struct ZkLoginPublicIdentifier {
    /// The issuer; `ZkLoginPublicIdentifier::new` limits it to 255 bytes.
    iss: String,
    /// The 32-byte address seed.
    address_seed: Bn254FieldElement,
}
495
496impl ZkLoginPublicIdentifier {
497    pub fn new(iss: String, address_seed: Bn254FieldElement) -> Option<Self> {
498        if iss.len() > 255 {
499            None
500        } else {
501            Some(Self { iss, address_seed })
502        }
503    }
504
505    pub fn iss(&self) -> &str {
506        &self.iss
507    }
508
509    pub fn address_seed(&self) -> &Bn254FieldElement {
510        &self.address_seed
511    }
512}
513
/// A JSON Web Key
///
/// Describes a single JWK. A list of them, one per key id (`kid`), can be
/// retrieved from a provider's JWK endpoint (e.g.
/// <https://www.googleapis.com/oauth2/v3/certs>). The JWK is used to verify
/// the JWT token.
///
/// # BCS
///
/// The BCS serialized form for this type is defined by the following ABNF:
///
/// ```text
/// jwk = string string string string
/// ```
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
pub struct Jwk {
    /// Key type parameter, <https://datatracker.ietf.org/doc/html/rfc7517#section-4.1>
    pub kty: String,
    /// RSA public exponent, <https://datatracker.ietf.org/doc/html/rfc7517#section-9.3>
    pub e: String,
    /// RSA modulus, <https://datatracker.ietf.org/doc/html/rfc7517#section-9.3>
    pub n: String,
    /// Algorithm parameter, <https://datatracker.ietf.org/doc/html/rfc7517#section-4.4>
    pub alg: String,
}
541
/// Key to uniquely identify a JWK
///
/// # BCS
///
/// The BCS serialized form for this type is defined by the following ABNF:
///
/// ```text
/// jwk-id = string string
/// ```
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
pub struct JwkId {
    /// The issuer or identity of the OIDC provider.
    pub iss: String,
    /// A key id used to uniquely identify a key from an OIDC provider.
    pub kid: String,
}
561
/// An element of a BN254 field.
///
/// Used both for the coordinates of circom proof points and for the zklogin
/// address seed.
///
/// This is a 32-byte, or 256-bit, value that is generally represented as
/// radix10 when a human-readable display format is needed, and is represented
/// as a 32-byte big-endian value while in memory.
///
/// NOTE(review): the constructors visible here accept any 256-bit value; no
/// check against the BN254 field modulus is performed.
///
/// # BCS
///
/// The BCS serialized form for this type is defined by the following ABNF:
///
/// ```text
/// bn254-field-element = *DIGIT ; which is then interpreted as a radix10 encoded 32-byte value
/// ```
#[derive(Clone, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[cfg_attr(feature = "proptest", derive(test_strategy::Arbitrary))]
pub struct Bn254FieldElement(
    #[cfg_attr(feature = "schemars", schemars(with = "crate::_schemars::U256"))] [u8; 32],
);
581
582impl Bn254FieldElement {
583    pub const fn new(bytes: [u8; 32]) -> Self {
584        Self(bytes)
585    }
586
587    pub const fn from_str_radix_10(s: &str) -> Result<Self, Bn254FieldElementParseError> {
588        let u256 = match U256::from_str_radix(s, 10) {
589            Ok(u256) => u256,
590            Err(e) => return Err(Bn254FieldElementParseError(e)),
591        };
592        let be = u256.to_be();
593        Ok(Self(*be.digits()))
594    }
595
596    pub fn unpadded(&self) -> &[u8] {
597        let mut buf = self.0.as_slice();
598
599        while !buf.is_empty() && buf[0] == 0 {
600            buf = &buf[1..];
601        }
602
603        // If the value is '0' then just return a slice of length 1 of the final byte
604        if buf.is_empty() { &self.0[31..] } else { buf }
605    }
606
607    pub fn padded(&self) -> &[u8] {
608        &self.0
609    }
610}
611
612impl std::fmt::Display for Bn254FieldElement {
613    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
614        let u256 = U256::from_be(U256::from_digits(self.0));
615        let radix10 = u256.to_str_radix(10);
616        f.write_str(&radix10)
617    }
618}
619
620#[derive(Debug)]
621pub struct Bn254FieldElementParseError(bnum::errors::ParseIntError);
622
623impl std::fmt::Display for Bn254FieldElementParseError {
624    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
625        write!(f, "unable to parse radix10 encoded value {}", self.0)
626    }
627}
628
629impl std::error::Error for Bn254FieldElementParseError {}
630
631impl std::str::FromStr for Bn254FieldElement {
632    type Err = Bn254FieldElementParseError;
633
634    fn from_str(s: &str) -> Result<Self, Self::Err> {
635        let u256 = U256::from_str_radix(s, 10).map_err(Bn254FieldElementParseError)?;
636        let be = u256.to_be();
637        Ok(Self(*be.digits()))
638    }
639}
640
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use num_bigint::BigUint;
    use proptest::prelude::*;
    use test_strategy::proptest;
    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    use super::Bn254FieldElement;

    /// A zero value unpads to a single zero byte; otherwise only the leading
    /// zero bytes are stripped.
    #[test]
    fn unpadded_slice() {
        let zero = Bn254FieldElement([0; 32]);
        assert_eq!(zero.unpadded(), &[0u8][..]);

        let mut bytes = [1u8; 32];
        bytes[0] = 0;
        let one_leading_zero = Bn254FieldElement(bytes);
        assert_eq!(one_leading_zero.unpadded(), &[1u8; 31][..]);
    }

    /// Parsing decimal strings built from 33..1024 random bytes must not
    /// panic, whether or not parsing succeeds.
    #[proptest]
    fn dont_crash_on_large_inputs(
        #[strategy(proptest::collection::vec(any::<u8>(), 33..1024))] bytes: Vec<u8>,
    ) {
        let decimal = BigUint::from_bytes_be(&bytes).to_str_radix(10);

        // doesn't crash
        let _ = Bn254FieldElement::from_str(&decimal);
    }

    /// Any value of at most 32 bytes parses and round-trips through `Display`.
    #[proptest]
    fn valid_address_seeds(
        #[strategy(proptest::collection::vec(any::<u8>(), 1..=32))] bytes: Vec<u8>,
    ) {
        let decimal = BigUint::from_bytes_be(&bytes).to_str_radix(10);

        let seed = Bn254FieldElement::from_str(&decimal).unwrap();
        assert_eq!(decimal, seed.to_string());
        // Ensure unpadded doesn't crash
        seed.unpadded();
    }
}
688
689#[cfg(feature = "serde")]
690#[cfg_attr(doc_cfg, doc(cfg(feature = "serde")))]
691mod serialization {
692    use std::borrow::Cow;
693
694    use serde::{Deserialize, Deserializer, Serialize, Serializer};
695    use serde_with::{Bytes, DeserializeAs, SerializeAs};
696
697    use super::*;
698    use crate::{SignatureScheme, crypto::SignatureFromBytesError};
699
700    // Serialized format is: iss_bytes_len || iss_bytes ||
701    // padded_32_byte_address_seed.
702    impl Serialize for ZkLoginPublicIdentifier {
703        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
704        where
705            S: Serializer,
706        {
707            if serializer.is_human_readable() {
708                #[derive(serde::Serialize)]
709                struct Readable<'a> {
710                    iss: &'a str,
711                    address_seed: &'a Bn254FieldElement,
712                }
713                let readable = Readable {
714                    iss: &self.iss,
715                    address_seed: &self.address_seed,
716                };
717                readable.serialize(serializer)
718            } else {
719                let mut buf = Vec::new();
720                let iss_bytes = self.iss.as_bytes();
721                buf.push(iss_bytes.len() as u8);
722                buf.extend(iss_bytes);
723
724                buf.extend(&self.address_seed.0);
725
726                serializer.serialize_bytes(&buf)
727            }
728        }
729    }
730
731    impl<'de> Deserialize<'de> for ZkLoginPublicIdentifier {
732        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
733        where
734            D: Deserializer<'de>,
735        {
736            if deserializer.is_human_readable() {
737                #[derive(serde::Deserialize)]
738                struct Readable {
739                    iss: String,
740                    address_seed: Bn254FieldElement,
741                }
742
743                let Readable { iss, address_seed } = Deserialize::deserialize(deserializer)?;
744                Self::new(iss, address_seed)
745                    .ok_or_else(|| serde::de::Error::custom("invalid zklogin public identifier"))
746            } else {
747                let bytes: Cow<'de, [u8]> = Bytes::deserialize_as(deserializer)?;
748                let iss_len = *bytes
749                    .first()
750                    .ok_or_else(|| serde::de::Error::custom("invalid zklogin public identifier"))?;
751                let iss_bytes = bytes
752                    .get(1..(1 + iss_len as usize))
753                    .ok_or_else(|| serde::de::Error::custom("invalid zklogin public identifier"))?;
754                let iss = std::str::from_utf8(iss_bytes).map_err(serde::de::Error::custom)?;
755                let address_seed_bytes = bytes
756                    .get((1 + iss_len as usize)..)
757                    .ok_or_else(|| serde::de::Error::custom("invalid zklogin public identifier"))?;
758
759                let address_seed = <[u8; 32]>::try_from(address_seed_bytes)
760                    .map_err(serde::de::Error::custom)
761                    .map(Bn254FieldElement)?;
762
763                Self::new(iss.into(), address_seed)
764                    .ok_or_else(|| serde::de::Error::custom("invalid zklogin public identifier"))
765            }
766        }
767    }
768
    /// Borrowed view of a `ZkLoginAuthenticator`'s fields, used for serialization.
    ///
    /// Field order defines the BCS layout.
    #[derive(serde::Serialize)]
    struct AuthenticatorRef<'a> {
        inputs: &'a ZkLoginInputs,
        #[cfg_attr(feature = "serde", serde(with = "crate::_serde::ReadableDisplay"))]
        max_epoch: EpochId,
        signature: &'a SimpleSignature,
    }

    /// Owned counterpart of `AuthenticatorRef`, used for deserialization.
    ///
    /// Must keep the same field order and attributes as `AuthenticatorRef`.
    #[derive(serde::Deserialize)]
    struct Authenticator {
        inputs: ZkLoginInputs,
        #[cfg_attr(feature = "serde", serde(with = "crate::_serde::ReadableDisplay"))]
        max_epoch: EpochId,
        signature: SimpleSignature,
    }
784
785    impl Serialize for ZkLoginAuthenticator {
786        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
787        where
788            S: Serializer,
789        {
790            if serializer.is_human_readable() {
791                let authenticator_ref = AuthenticatorRef {
792                    inputs: &self.inputs,
793                    max_epoch: self.max_epoch,
794                    signature: &self.signature,
795                };
796
797                authenticator_ref.serialize(serializer)
798            } else {
799                let bytes = self.to_bytes();
800                serializer.serialize_bytes(&bytes)
801            }
802        }
803    }
804
805    impl<'de> Deserialize<'de> for ZkLoginAuthenticator {
806        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
807        where
808            D: Deserializer<'de>,
809        {
810            if deserializer.is_human_readable() {
811                let Authenticator {
812                    inputs,
813                    max_epoch,
814                    signature,
815                } = Authenticator::deserialize(deserializer)?;
816                Ok(Self {
817                    inputs,
818                    max_epoch,
819                    signature,
820                })
821            } else {
822                let bytes: Cow<'de, [u8]> = Bytes::deserialize_as(deserializer)?;
823                Self::from_serialized_bytes(bytes).map_err(serde::de::Error::custom)
824            }
825        }
826    }
827
828    impl ZkLoginAuthenticator {
829        pub(crate) fn to_bytes(&self) -> Vec<u8> {
830            let authenticator_ref = AuthenticatorRef {
831                inputs: &self.inputs,
832                max_epoch: self.max_epoch,
833                signature: &self.signature,
834            };
835
836            let mut buf = Vec::new();
837            buf.push(SignatureScheme::ZkLogin as u8);
838
839            bcs::serialize_into(&mut buf, &authenticator_ref).expect("serialization cannot fail");
840            buf
841        }
842
843        pub fn from_serialized_bytes(
844            bytes: impl AsRef<[u8]>,
845        ) -> Result<Self, SignatureFromBytesError> {
846            let bytes = bytes.as_ref();
847            let flag =
848                SignatureScheme::from_byte(*bytes.first().ok_or_else(|| {
849                    SignatureFromBytesError::new("missing signature scheme flag")
850                })?)
851                .map_err(SignatureFromBytesError::new)?;
852            if flag != SignatureScheme::ZkLogin {
853                return Err(SignatureFromBytesError::new("invalid zklogin flag"));
854            }
855            let bcs_bytes = &bytes[1..];
856
857            let Authenticator {
858                inputs,
859                max_epoch,
860                signature,
861            } = bcs::from_bytes(bcs_bytes).map_err(SignatureFromBytesError::new)?;
862            Ok(Self {
863                inputs,
864                max_epoch,
865                signature,
866            })
867        }
868    }
869
870    impl Serialize for ZkLoginInputs {
871        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
872        where
873            S: Serializer,
874        {
875            #[derive(serde::Serialize)]
876            struct Inputs<'a> {
877                proof_points: &'a ZkLoginProof,
878                iss_base64_details: &'a ZkLoginClaim,
879                header_base64: &'a str,
880                address_seed: &'a Bn254FieldElement,
881            }
882
883            Inputs {
884                proof_points: self.proof_points(),
885                iss_base64_details: self.iss_base64_details(),
886                header_base64: self.header_base64(),
887                address_seed: self.address_seed(),
888            }
889            .serialize(serializer)
890        }
891    }
892
893    impl<'de> Deserialize<'de> for ZkLoginInputs {
894        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
895        where
896            D: Deserializer<'de>,
897        {
898            #[derive(serde::Deserialize)]
899            struct Inputs {
900                proof_points: ZkLoginProof,
901                iss_base64_details: ZkLoginClaim,
902                header_base64: String,
903                address_seed: Bn254FieldElement,
904            }
905
906            let Inputs {
907                proof_points,
908                iss_base64_details,
909                header_base64,
910                address_seed,
911            } = Inputs::deserialize(deserializer)?;
912            Self::new(
913                proof_points,
914                iss_base64_details,
915                header_base64,
916                address_seed,
917            )
918            .map_err(serde::de::Error::custom)
919        }
920    }
921
    // A `Bn254FieldElement` (such as the address seed) is serialized as its radix-10 encoded string.
923    impl Serialize for Bn254FieldElement {
924        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
925        where
926            S: serde::Serializer,
927        {
928            serde_with::DisplayFromStr::serialize_as(self, serializer)
929        }
930    }
931
932    impl<'de> Deserialize<'de> for Bn254FieldElement {
933        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
934        where
935            D: Deserializer<'de>,
936        {
937            serde_with::DisplayFromStr::deserialize_as(deserializer)
938        }
939    }
940
941    impl Serialize for CircomG1 {
942        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
943        where
944            S: serde::Serializer,
945        {
946            use serde::ser::SerializeSeq;
947            let mut seq = serializer.serialize_seq(Some(self.0.len()))?;
948            for element in &self.0 {
949                seq.serialize_element(element)?;
950            }
951            seq.end()
952        }
953    }
954
955    impl<'de> Deserialize<'de> for CircomG1 {
956        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
957        where
958            D: Deserializer<'de>,
959        {
960            let inner = <Vec<_>>::deserialize(deserializer)?;
961            Ok(Self(inner.try_into().map_err(|_| {
962                serde::de::Error::custom("expected array of length 3")
963            })?))
964        }
965    }
966
967    impl Serialize for CircomG2 {
968        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
969        where
970            S: serde::Serializer,
971        {
972            use serde::ser::SerializeSeq;
973
974            struct Inner<'a>(&'a [Bn254FieldElement; 2]);
975
976            impl Serialize for Inner<'_> {
977                fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
978                where
979                    S: serde::Serializer,
980                {
981                    let mut seq = serializer.serialize_seq(Some(self.0.len()))?;
982                    for element in self.0 {
983                        seq.serialize_element(element)?;
984                    }
985                    seq.end()
986                }
987            }
988
989            let mut seq = serializer.serialize_seq(Some(self.0.len()))?;
990            for element in &self.0 {
991                seq.serialize_element(&Inner(element))?;
992            }
993            seq.end()
994        }
995    }
996
997    impl<'de> Deserialize<'de> for CircomG2 {
998        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
999        where
1000            D: Deserializer<'de>,
1001        {
1002            let vecs = <Vec<Vec<Bn254FieldElement>>>::deserialize(deserializer)?;
1003            let mut inner: [[Bn254FieldElement; 2]; 3] = Default::default();
1004
1005            if vecs.len() != 3 {
1006                return Err(serde::de::Error::custom(
1007                    "vector of three vectors each being a vector of two strings",
1008                ));
1009            }
1010
1011            for (i, v) in vecs.into_iter().enumerate() {
1012                if v.len() != 2 {
1013                    return Err(serde::de::Error::custom(
1014                        "vector of three vectors each being a vector of two strings",
1015                    ));
1016                }
1017
1018                for (j, point) in v.into_iter().enumerate() {
1019                    inner[i][j] = point;
1020                }
1021            }
1022
1023            Ok(Self(inner))
1024        }
1025    }
1026}