ssi_sd_jwt/
lib.rs

1//! Selective Disclosure for JWTs ([SD-JWT]).
2//!
3//! [SD-JWT]: <https://datatracker.ietf.org/doc/draft-ietf-oauth-selective-disclosure-jwt/>
4//!
5//! # Usage
6//!
//! Contrary to regular JWTs or JWSs, which can be verified directly after
//! being decoded, SD-JWT claims need to be revealed before being validated.
9//! The standard path looks like this:
10//! ```text
11//! ┌───────┐                     ┌──────────────┐                            ┌───────────────┐
12//! │       │                     │              │                            │               │
13//! │ SdJwt │ ─► SdJwt::decode ─► │ DecodedSdJwt │ ─► DecodedSdJwt::reveal ─► │ RevealedSdJwt │
14//! │       │                     │              │                            │               │
15//! └───────┘                     └──────────────┘                            └───────────────┘
16//! ```
17//!
18//! The base SD-JWT type is [`SdJwt`] (or [`SdJwtBuf`] if you want to own the
19//! SD-JWT). The [`SdJwt::decode`] function decodes the SD-JWT header, payload
20//! and disclosures into a [`DecodedSdJwt`]. At this point the payload claims
21//! are still concealed and cannot be validated. The [`DecodedSdJwt::reveal`]
22//! function uses the disclosures to reveal the disclosed claims and discard
23//! the non-disclosed claims. The result is a [`RevealedSdJwt`] containing the
24//! revealed JWT, and a set of JSON pointers ([`JsonPointerBuf`]) mapping each
25//! revealed claim to its disclosure. The [`RevealedSdJwt::verify`] function
26//! can then be used to verify the JWT as usual.
27//!
28//! Alternatively, if you don't care about the byproducts of decoding and
29//! revealing the claims, a [`SdJwt::decode_reveal_verify`] function is provided
30//! to decode, reveal and verify the claims directly.
31#![warn(missing_docs)]
32use rand::{CryptoRng, RngCore};
33use serde::{de::DeserializeOwned, Deserialize, Serialize};
34use serde_json::Value;
35use ssi_claims_core::{
36    DateTimeProvider, ProofValidationError, ResolverProvider, SignatureError, ValidateClaims,
37    Verification,
38};
39use ssi_core::BytesBuf;
40use ssi_jwk::JWKResolver;
41use ssi_jws::{DecodedJws, Jws, JwsPayload, JwsSignature, JwsSigner, ValidateJwsHeader};
42use ssi_jwt::{AnyClaims, ClaimSet, DecodedJwt, JWTClaims};
43use std::{
44    borrow::{Borrow, Cow},
45    collections::BTreeMap,
46    fmt::{self, Write},
47    ops::Deref,
48    str::FromStr,
49};
50
51pub use ssi_core::{json_pointer, JsonPointer, JsonPointerBuf};
52
53pub(crate) mod utils;
54use utils::is_url_safe_base64_char;
55
56mod digest;
57pub use digest::*;
58
59mod decode;
60pub use decode::*;
61
62mod disclosure;
63pub use disclosure::*;
64
65mod conceal;
66pub use conceal::*;
67
68mod reveal;
69pub use reveal::*;
70
// Reserved SD-JWT claim/property names.
// NOTE(review): the meanings below follow the SD-JWT specification; the code
// that actually reads/writes them lives in sibling modules not shown here.

// Claim holding the array of disclosure digests.
const SD_CLAIM_NAME: &str = "_sd";
// Claim naming the hash algorithm used to compute the digests.
const SD_ALG_CLAIM_NAME: &str = "_sd_alg";
// Property wrapping the digest of a concealed array item.
const ARRAY_CLAIM_ITEM_PROPERTY_NAME: &str = "...";
74
/// Invalid SD-JWT error.
///
/// Wraps the rejected input (of type `T`, `String` by default) so the caller
/// can recover it. The `Display` message is the fixed string
/// `invalid SD-JWT`; it does not include the wrapped value.
#[derive(Debug, thiserror::Error)]
#[error("invalid SD-JWT")]
pub struct InvalidSdJwt<T = String>(pub T);
79
80impl<T: ?Sized + ToOwned> InvalidSdJwt<&T> {
81    /// Takes ownership of the inner value.
82    pub fn into_owned(self) -> InvalidSdJwt<T::Owned> {
83        InvalidSdJwt(self.0.to_owned())
84    }
85}
86
/// Creates a new static SD-JWT reference from a string literal.
///
/// The literal is checked with [`SdJwt::from_str_const`], which is a
/// `const fn`, so the check can be evaluated in constant contexts.
///
/// # Panics
///
/// Panics with the message `invalid SD-JWT` if the literal is not a valid
/// SD-JWT.
#[macro_export]
#[collapse_debuginfo(no)]
macro_rules! sd_jwt {
    ($value:literal) => {
        match $crate::SdJwt::from_str_const($value) {
            Ok(value) => value,
            Err(_) => panic!("invalid SD-JWT"),
        }
    };
}
98
/// SD-JWT in compact form.
///
/// This is an unsized newtype around `[u8]`, meant to be used behind a
/// reference (`&SdJwt`). It is declared `#[repr(transparent)]` so that it is
/// guaranteed to have the same layout as `[u8]`, which is required for the
/// `&[u8]` to `&SdJwt` reinterpretation to be sound.
///
/// # Grammar
///
/// ```abnf
/// ALPHA = %x41-5A / %x61-7A ; A-Z / a-z
/// DIGIT = %x30-39 ; 0-9
/// BASE64URL = 1*(ALPHA / DIGIT / "-" / "_")
/// JWT = BASE64URL "." BASE64URL "." BASE64URL
/// DISCLOSURE = BASE64URL
/// SD-JWT = JWT "~" *(DISCLOSURE "~")
/// ```
#[derive(PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct SdJwt([u8]);
113
114impl SdJwt {
115    /// Parses the given `input` as an SD-JWT.
116    ///
117    /// Returns an error if it is not a valid SD-JWT.
118    pub fn new<T: ?Sized + AsRef<[u8]>>(input: &T) -> Result<&Self, InvalidSdJwt<&T>> {
119        let bytes = input.as_ref();
120        if Self::validate(bytes) {
121            Ok(unsafe { Self::new_unchecked(bytes) })
122        } else {
123            Err(InvalidSdJwt(input))
124        }
125    }
126
127    /// Parses the given `input` string as an SD-JWT.
128    ///
129    /// Returns an error if it is not a valid SD-JWT.
130    pub const fn from_str_const(input: &str) -> Result<&Self, InvalidSdJwt<&str>> {
131        let bytes = input.as_bytes();
132        if Self::validate(bytes) {
133            Ok(unsafe { Self::new_unchecked(bytes) })
134        } else {
135            Err(InvalidSdJwt(input))
136        }
137    }
138
139    /// Checks that the given input is a SD-JWT.
140    pub const fn validate(bytes: &[u8]) -> bool {
141        let mut i = 0;
142
143        // Find the first `~`.
144        loop {
145            if i >= bytes.len() {
146                return false;
147            }
148
149            if bytes[i] == b'~' {
150                break;
151            }
152
153            i += 1
154        }
155
156        // Validate the JWS.
157        if !Jws::validate_range(bytes, 0, i) {
158            return false;
159        }
160
161        // Parse disclosures.
162        loop {
163            // Skip the `~`
164            i += 1;
165
166            // No more disclosures.
167            if i >= bytes.len() {
168                break true;
169            }
170
171            loop {
172                if i >= bytes.len() {
173                    // Missing terminating `~`.
174                    return false;
175                }
176
177                // End of disclosure.
178                if bytes[i] == b'~' {
179                    break;
180                }
181
182                // Not a disclosure.
183                if !is_url_safe_base64_char(bytes[i]) {
184                    return false;
185                }
186
187                i += 1
188            }
189        }
190    }
191
192    /// Creates a new SD-JWT from the given `input` without validation.
193    ///
194    /// # Safety
195    ///
196    /// The input value **must** be a valid SD-JWT.
197    pub const unsafe fn new_unchecked(input: &[u8]) -> &Self {
198        std::mem::transmute(input)
199    }
200
201    /// Returns the underlying bytes of the SD-JWT.
202    pub fn as_bytes(&self) -> &[u8] {
203        &self.0
204    }
205
206    /// Returns this SD-JWT as a string.
207    pub fn as_str(&self) -> &str {
208        unsafe {
209            // SAFETY: SD-JWT are valid UTF-8 strings by definition.
210            std::str::from_utf8_unchecked(&self.0)
211        }
212    }
213
214    /// Returns the byte-position just after the issuer-signed JWT.
215    fn jwt_end(&self) -> usize {
216        self.0.iter().copied().position(|c| c == b'~').unwrap()
217    }
218
219    /// Returns the issuer-signed JWT.
220    pub fn jwt(&self) -> &Jws {
221        unsafe {
222            // SAFETY: we already validated the SD-JWT and know it
223            // starts with a valid JWT.
224            Jws::new_unchecked(&self.0[..self.jwt_end()])
225        }
226    }
227
228    /// Returns an iterator over the disclosures of the SD-JWT.
229    pub fn disclosures(&self) -> Disclosures {
230        Disclosures {
231            bytes: &self.0,
232            offset: self.jwt_end() + 1,
233        }
234    }
235
236    /// Returns references to each part of this SD-JWT.
237    pub fn parts(&self) -> PartsRef {
238        PartsRef {
239            jwt: self.jwt(),
240            disclosures: self.disclosures().collect(),
241        }
242    }
243
244    /// Decode a compact SD-JWT.
245    pub fn decode(&self) -> Result<DecodedSdJwt, DecodeError> {
246        self.parts().decode()
247    }
248
249    /// Decodes and reveals the SD-JWT.
250    pub fn decode_reveal<T: DeserializeOwned>(&self) -> Result<RevealedSdJwt<T>, RevealError> {
251        self.parts().decode_reveal()
252    }
253
254    /// Decodes and reveals the SD-JWT.
255    pub fn decode_reveal_any(&self) -> Result<RevealedSdJwt, RevealError> {
256        self.parts().decode_reveal_any()
257    }
258
259    /// Decode a compact SD-JWT.
260    pub async fn decode_verify_concealed<P>(
261        &self,
262        params: P,
263    ) -> Result<(DecodedSdJwt, Verification), ProofValidationError>
264    where
265        P: ResolverProvider<Resolver: JWKResolver>,
266    {
267        self.parts().decode_verify_concealed(params).await
268    }
269
270    /// Decodes, reveals and verify a compact SD-JWT.
271    ///
272    /// Only the registered JWT claims will be validated.
273    /// If you need to validate custom claims, use the
274    /// [`Self::decode_reveal_verify`] method with `T` defining the custom
275    /// claims.
276    ///
277    /// Returns the decoded JWT with the verification status.
278    pub async fn decode_reveal_verify_any<P>(
279        &self,
280        params: P,
281    ) -> Result<(RevealedSdJwt, Verification), ProofValidationError>
282    where
283        P: ResolverProvider<Resolver: JWKResolver> + DateTimeProvider,
284    {
285        self.parts().decode_reveal_verify_any(params).await
286    }
287
288    /// Decodes, reveals and verify a compact SD-JWT.
289    ///
290    /// The type parameter `T` corresponds to the set of private JWT claims
291    /// contained in the encoded SD-JWT. If you don't know what value to use
292    /// for this parameter, you can use the [`Self::decode_reveal_verify_any`]
293    /// function instead.
294    ///
295    /// Returns the decoded JWT with the verification status.
296    pub async fn decode_reveal_verify<T, P>(
297        &self,
298        params: P,
299    ) -> Result<(RevealedSdJwt<T>, Verification), ProofValidationError>
300    where
301        T: ClaimSet + DeserializeOwned + ValidateClaims<P, JwsSignature>,
302        P: ResolverProvider<Resolver: JWKResolver> + DateTimeProvider,
303    {
304        self.parts().decode_reveal_verify(params).await
305    }
306}
307
308impl AsRef<str> for SdJwt {
309    fn as_ref(&self) -> &str {
310        self.as_str()
311    }
312}
313
314impl AsRef<[u8]> for SdJwt {
315    fn as_ref(&self) -> &[u8] {
316        self.as_bytes()
317    }
318}
319
320impl fmt::Display for SdJwt {
321    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
322        self.as_str().fmt(f)
323    }
324}
325
326impl fmt::Debug for SdJwt {
327    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
328        self.as_str().fmt(f)
329    }
330}
331
332impl serde::Serialize for SdJwt {
333    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
334    where
335        S: serde::Serializer,
336    {
337        self.as_str().serialize(serializer)
338    }
339}
340
341impl<'de> serde::Deserialize<'de> for &'de SdJwt {
342    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
343    where
344        D: serde::Deserializer<'de>,
345    {
346        SdJwt::new(<&'de str>::deserialize(deserializer)?).map_err(serde::de::Error::custom)
347    }
348}
349
/// Owned SD-JWT.
///
/// Owns the bytes of a compact SD-JWT and dereferences to [`SdJwt`].
/// Values built with [`SdJwtBuf::new`] are validated with
/// [`SdJwt::validate`].
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SdJwtBuf(Vec<u8>);
353
354impl SdJwtBuf {
355    /// Creates a new owned SD-JWT.
356    pub fn new<B: BytesBuf>(bytes: B) -> Result<Self, InvalidSdJwt<B>> {
357        if SdJwt::validate(bytes.as_ref()) {
358            Ok(Self(bytes.into()))
359        } else {
360            Err(InvalidSdJwt(bytes))
361        }
362    }
363
364    /// Creates a new owned SD-JWT without validating the input bytes.
365    ///
366    /// # Safety
367    ///
368    /// The input `bytes` **must** represent an SD-JWT.
369    pub unsafe fn new_unchecked(bytes: Vec<u8>) -> Self {
370        Self(bytes)
371    }
372
373    /// Conceals and sign the given claims.
374    pub async fn conceal_and_sign(
375        claims: &JWTClaims<impl Serialize>,
376        sd_alg: SdAlg,
377        pointers: &[impl Borrow<JsonPointer>],
378        signer: impl JwsSigner,
379    ) -> Result<Self, SignatureError> {
380        DecodedSdJwt::conceal_and_sign(claims, sd_alg, pointers, signer)
381            .await
382            .map(DecodedSdJwt::into_encoded)
383    }
384
385    /// Conceals and sign the given claims.
386    pub async fn conceal_and_sign_with(
387        claims: &JWTClaims<impl Serialize>,
388        sd_alg: SdAlg,
389        pointers: &[impl Borrow<JsonPointer>],
390        signer: impl JwsSigner,
391        rng: impl CryptoRng + RngCore,
392    ) -> Result<Self, SignatureError> {
393        DecodedSdJwt::conceal_and_sign_with(claims, sd_alg, pointers, signer, rng)
394            .await
395            .map(DecodedSdJwt::into_encoded)
396    }
397
398    /// Borrows the SD-JWT.
399    pub fn as_sd_jwt(&self) -> &SdJwt {
400        unsafe { SdJwt::new_unchecked(&self.0) }
401    }
402
403    /// Turns this SD-JWT into a byte string.
404    pub fn into_bytes(self) -> Vec<u8> {
405        self.0
406    }
407
408    /// Turns this SD-JWT into a string.
409    pub fn into_string(self) -> String {
410        unsafe {
411            // SAFETY: SD-JWTs are valid UTF-8 strings.
412            String::from_utf8_unchecked(self.0)
413        }
414    }
415}
416
417impl Deref for SdJwtBuf {
418    type Target = SdJwt;
419
420    fn deref(&self) -> &Self::Target {
421        self.as_sd_jwt()
422    }
423}
424
425impl Borrow<SdJwt> for SdJwtBuf {
426    fn borrow(&self) -> &SdJwt {
427        self.as_sd_jwt()
428    }
429}
430
431impl AsRef<SdJwt> for SdJwtBuf {
432    fn as_ref(&self) -> &SdJwt {
433        self.as_sd_jwt()
434    }
435}
436
437impl AsRef<str> for SdJwtBuf {
438    fn as_ref(&self) -> &str {
439        self.as_str()
440    }
441}
442
443impl AsRef<[u8]> for SdJwtBuf {
444    fn as_ref(&self) -> &[u8] {
445        self.as_bytes()
446    }
447}
448
449impl fmt::Display for SdJwtBuf {
450    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
451        self.as_str().fmt(f)
452    }
453}
454
455impl fmt::Debug for SdJwtBuf {
456    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
457        self.as_str().fmt(f)
458    }
459}
460
461impl FromStr for SdJwtBuf {
462    type Err = InvalidSdJwt;
463
464    fn from_str(s: &str) -> Result<Self, Self::Err> {
465        Self::new(s.to_owned())
466    }
467}
468
469impl serde::Serialize for SdJwtBuf {
470    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
471    where
472        S: serde::Serializer,
473    {
474        self.as_str().serialize(serializer)
475    }
476}
477
478impl<'de> serde::Deserialize<'de> for SdJwtBuf {
479    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
480    where
481        D: serde::Deserializer<'de>,
482    {
483        String::deserialize(deserializer)?
484            .parse()
485            .map_err(serde::de::Error::custom)
486    }
487}
488
/// Iterator over the disclosures of an SD-JWT.
///
/// Created by [`SdJwt::disclosures`]. Yields each `~`-terminated disclosure,
/// in the order they appear in the SD-JWT.
pub struct Disclosures<'a> {
    /// SD-JWT bytes.
    bytes: &'a [u8],

    /// Offset of the beginning of the next disclosure (if any).
    offset: usize,
}
497
498impl<'a> Iterator for Disclosures<'a> {
499    type Item = &'a Disclosure;
500
501    fn next(&mut self) -> Option<Self::Item> {
502        let mut i = self.offset;
503
504        while i < self.bytes.len() {
505            if self.bytes[i] == b'~' {
506                let disclosure = unsafe {
507                    // SAFETY: we already validated the SD-JWT and know
508                    // it is composed of valid disclosures.
509                    Disclosure::new_unchecked(&self.bytes[self.offset..i])
510                };
511
512                self.offset = i + 1;
513                return Some(disclosure);
514            }
515
516            i += 1
517        }
518
519        None
520    }
521}
522
/// SD-JWT components to be presented for decoding and validation whether coming
/// from a compact representation, enveloping JWT, etc.
///
/// Its `Display` implementation writes the compact form: the JWT followed by
/// `~`, then each disclosure followed by `~`.
#[derive(Debug, PartialEq)]
pub struct PartsRef<'a> {
    /// JWT whose claims can be selectively disclosed.
    pub jwt: &'a Jws,

    /// Disclosures for the associated JWT.
    pub disclosures: Vec<&'a Disclosure>,
}
533
534impl<'a> PartsRef<'a> {
535    /// Creates a new `PartsRef`.
536    pub fn new(jwt: &'a Jws, disclosures: Vec<&'a Disclosure>) -> Self {
537        Self { jwt, disclosures }
538    }
539
540    /// Decodes and reveals the SD-JWT.
541    pub fn decode_reveal<T: DeserializeOwned>(self) -> Result<RevealedSdJwt<'a, T>, RevealError> {
542        let decoded = self.decode()?;
543        decoded.reveal()
544    }
545
546    /// Decodes and reveals the SD-JWT.
547    pub fn decode_reveal_any(self) -> Result<RevealedSdJwt<'a>, RevealError> {
548        let decoded = self.decode()?;
549        decoded.reveal_any()
550    }
551
552    /// Decode a compact SD-JWT.
553    pub async fn decode_verify_concealed<P>(
554        self,
555        params: P,
556    ) -> Result<(DecodedSdJwt<'a>, Verification), ProofValidationError>
557    where
558        P: ResolverProvider<Resolver: JWKResolver>,
559    {
560        let decoded = self.decode().map_err(ProofValidationError::input_data)?;
561        let verification = decoded.verify_concealed(params).await?;
562        Ok((decoded, verification))
563    }
564
565    /// Decodes, reveals and verify a compact SD-JWT.
566    ///
567    /// Only the registered JWT claims will be validated.
568    /// If you need to validate custom claims, use the
569    /// [`Self::decode_reveal_verify`] method with `T` defining the custom
570    /// claims.
571    ///
572    /// Returns the decoded JWT with the verification status.
573    pub async fn decode_reveal_verify_any<P>(
574        self,
575        params: P,
576    ) -> Result<(RevealedSdJwt<'a>, Verification), ProofValidationError>
577    where
578        P: ResolverProvider<Resolver: JWKResolver> + DateTimeProvider,
579    {
580        let decoded = self.decode().map_err(ProofValidationError::input_data)?;
581        decoded.reveal_verify_any(params).await
582    }
583
584    /// Decodes, reveals and verify a compact SD-JWT.
585    ///
586    /// The type parameter `T` corresponds to the set of private JWT claims
587    /// contained in the encoded SD-JWT. If you don't know what value to use
588    /// for this parameter, you can use the [`Self::decode_reveal_verify_any`]
589    /// function instead.
590    ///
591    /// Returns the decoded JWT with the verification status.
592    pub async fn decode_reveal_verify<T, P>(
593        self,
594        params: P,
595    ) -> Result<(RevealedSdJwt<'a, T>, Verification), ProofValidationError>
596    where
597        T: ClaimSet + DeserializeOwned + ValidateClaims<P, JwsSignature>,
598        P: ResolverProvider<Resolver: JWKResolver> + DateTimeProvider,
599    {
600        let decoded = self.decode().map_err(ProofValidationError::input_data)?;
601        decoded.reveal_verify(params).await
602    }
603}
604
605impl fmt::Display for PartsRef<'_> {
606    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
607        self.jwt.fmt(f)?;
608        f.write_char('~')?;
609
610        for d in &self.disclosures {
611            d.fmt(f)?;
612            f.write_char('~')?;
613        }
614
615        Ok(())
616    }
617}
618
/// Undisclosed SD-JWT payload.
///
/// The `_sd_alg` claim is extracted into a dedicated field; every other
/// claim of the payload is kept as raw JSON in `claims`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SdJwtPayload {
    /// Hash algorithm used by the Issuer to generate the digests.
    #[serde(rename = "_sd_alg")]
    pub sd_alg: SdAlg,

    /// Other claims.
    #[serde(flatten)]
    pub claims: serde_json::Map<String, Value>,
}
630
631impl JwsPayload for SdJwtPayload {
632    fn payload_bytes(&self) -> Cow<[u8]> {
633        Cow::Owned(serde_json::to_vec(self).unwrap())
634    }
635}
636
// Empty impl: the default `ValidateJwsHeader` behavior is used as-is, with no
// SD-JWT specific header check added here.
impl<E> ValidateJwsHeader<E> for SdJwtPayload {}

// Empty impl: the default `ValidateClaims` behavior is used as-is for the
// concealed payload.
// NOTE(review): presumably the defaults accept everything — confirm against
// the trait definitions in `ssi_jws`/`ssi_claims_core`.
impl<E, P> ValidateClaims<E, P> for SdJwtPayload {}
640
/// Decoded SD-JWT.
///
/// The JWT and the disclosures are decoded, but the payload claims are still
/// concealed.
pub struct DecodedSdJwt<'a> {
    /// JWT whose claims can be selectively disclosed.
    pub jwt: DecodedJws<'a, SdJwtPayload>,

    /// Disclosures for associated JWT.
    pub disclosures: Vec<DecodedDisclosure<'a>>,
}
649
650impl<'a> DecodedSdJwt<'a> {
651    /// Verifies the decoded SD-JWT without revealing the concealed claims.
652    ///
653    /// No revealing the claims means only the registered JWT claims will be
654    /// validated.
655    pub async fn verify_concealed<P>(&self, params: P) -> Result<Verification, ProofValidationError>
656    where
657        P: ResolverProvider<Resolver: JWKResolver>,
658    {
659        self.jwt.verify(params).await
660    }
661
662    /// Verifies the decoded SD-JWT after revealing the claims.
663    ///
664    /// Only the registered JWT claims will be validated.
665    /// If you need to validate custom claims, use the [`Self::reveal_verify`]
666    /// method with `T` defining the custom claims.
667    ///
668    /// Returns the decoded JWT with the verification status.
669    pub async fn reveal_verify_any<P>(
670        self,
671        params: P,
672    ) -> Result<(RevealedSdJwt<'a>, Verification), ProofValidationError>
673    where
674        P: ResolverProvider<Resolver: JWKResolver> + DateTimeProvider,
675    {
676        let revealed = self
677            .reveal_any()
678            .map_err(ProofValidationError::input_data)?;
679        let verification = revealed.verify(params).await?;
680        Ok((revealed, verification))
681    }
682
683    /// Verifies the decoded SD-JWT after revealing the claims.
684    ///
685    /// The type parameter `T` corresponds to the set of private JWT claims.
686    /// If you don't know what value to use for this parameter, you can use the
687    /// [`Self::reveal_verify_any`] function instead.
688    ///
689    /// The `T` type parameter is the type of private claims.
690    pub async fn reveal_verify<T, P>(
691        self,
692        params: P,
693    ) -> Result<(RevealedSdJwt<'a, T>, Verification), ProofValidationError>
694    where
695        T: ClaimSet + DeserializeOwned + ValidateClaims<P, JwsSignature>,
696        P: ResolverProvider<Resolver: JWKResolver> + DateTimeProvider,
697    {
698        let revealed = self
699            .reveal::<T>()
700            .map_err(ProofValidationError::input_data)?;
701        let verification = revealed.verify(params).await?;
702        Ok((revealed, verification))
703    }
704}
705
706impl DecodedSdJwt<'static> {
707    /// Conceal and sign the given claims.
708    pub async fn conceal_and_sign(
709        claims: &JWTClaims<impl Serialize>,
710        sd_alg: SdAlg,
711        pointers: &[impl Borrow<JsonPointer>],
712        signer: impl JwsSigner,
713    ) -> Result<Self, SignatureError> {
714        let (payload, disclosures) =
715            SdJwtPayload::conceal(claims, sd_alg, pointers).map_err(SignatureError::other)?;
716
717        Ok(Self {
718            jwt: signer.sign_into_decoded(payload).await?,
719            disclosures,
720        })
721    }
722
723    /// Conceal and sign the given claims with a custom rng.
724    pub async fn conceal_and_sign_with(
725        claims: &JWTClaims<impl Serialize>,
726        sd_alg: SdAlg,
727        pointers: &[impl Borrow<JsonPointer>],
728        signer: impl JwsSigner,
729        rng: impl CryptoRng + RngCore,
730    ) -> Result<Self, SignatureError> {
731        let (payload, disclosures) = SdJwtPayload::conceal_with(claims, sd_alg, pointers, rng)
732            .map_err(SignatureError::other)?;
733
734        Ok(Self {
735            jwt: signer.sign_into_decoded(payload).await?,
736            disclosures,
737        })
738    }
739
740    /// Encodes the SD-JWT.
741    pub fn into_encoded(self) -> SdJwtBuf {
742        let mut bytes = self.jwt.into_encoded().into_bytes();
743        bytes.push(b'~');
744
745        for d in self.disclosures {
746            bytes.extend_from_slice(d.encoded.as_bytes());
747            bytes.push(b'~');
748        }
749
750        unsafe {
751            // SAFETY: we just constructed those bytes following the SD-JWT
752            // syntax.
753            SdJwtBuf::new_unchecked(bytes)
754        }
755    }
756}
757
/// Revealed SD-JWT.
///
/// This is similar to a [`DecodedSdJwt`] but with the JWT claims revealed.
/// You can use this type to access the revealed claims, and filter the
/// disclosures.
///
/// The type parameter `T` is the type of the private claims ([`AnyClaims`]
/// by default).
#[derive(Debug, Clone)]
pub struct RevealedSdJwt<'a, T = AnyClaims> {
    /// Decoded JWT.
    ///
    /// The JWT bytes still contain the concealed SD-JWT claims, but the
    /// decoded payload is revealed.
    pub jwt: DecodedJwt<'a, T>,

    /// Disclosures bound to their JSON pointers.
    pub disclosures: BTreeMap<JsonPointerBuf, DecodedDisclosure<'a>>,
}
774
775impl<'a, T> RevealedSdJwt<'a, T> {
776    /// Returns a reference to the revealed JWT claims.
777    pub fn claims(&self) -> &JWTClaims<T> {
778        &self.jwt.signing_bytes.payload
779    }
780
781    /// Turns this SD-JWT into its revealed JWT claims.
782    pub fn into_claims(self) -> JWTClaims<T> {
783        self.jwt.signing_bytes.payload
784    }
785
786    /// Verifies the SD-JWT, validating the revealed claims.
787    pub async fn verify<P>(&self, params: P) -> Result<Verification, ProofValidationError>
788    where
789        T: ClaimSet + ValidateClaims<P, JwsSignature>,
790        P: ResolverProvider<Resolver: JWKResolver> + DateTimeProvider,
791    {
792        self.jwt.verify(params).await
793    }
794
795    /// Removes all the disclosures.
796    pub fn clear(&mut self) {
797        self.disclosures.clear()
798    }
799
800    /// Removes all the disclosures.
801    pub fn cleared(mut self) -> Self {
802        self.clear();
803        self
804    }
805
806    /// Filter the disclosures, leaving only the ones targeting the given
807    /// JSON pointers.
808    ///
809    /// Returns a map containing the filtered-out disclosures and their
810    /// pointers.
811    pub fn retain(
812        &mut self,
813        pointers: &[impl Borrow<JsonPointer>],
814    ) -> BTreeMap<JsonPointerBuf, DecodedDisclosure<'a>> {
815        let mut disclosures = BTreeMap::new();
816
817        for p in pointers {
818            if let Some((p, d)) = self.disclosures.remove_entry(p.borrow()) {
819                disclosures.insert(p, d);
820            }
821        }
822
823        std::mem::swap(&mut disclosures, &mut self.disclosures);
824        disclosures
825    }
826
827    /// Filter the disclosures, leaving only the ones targeting the given
828    /// JSON pointers.
829    ///
830    /// Returns a map containing the filtered-out disclosures and their
831    /// pointers.
832    pub fn retaining(mut self, pointers: &[impl Borrow<JsonPointer>]) -> Self {
833        self.retain(pointers);
834        self
835    }
836
837    /// Filter the disclosures, removing the ones targeting the given JSON
838    /// pointers.
839    ///
840    /// Returns a map containing the filtered-out disclosures and their
841    /// pointers.
842    pub fn reject(
843        &mut self,
844        pointers: &[impl Borrow<JsonPointer>],
845    ) -> BTreeMap<JsonPointerBuf, DecodedDisclosure<'a>> {
846        let mut disclosures = BTreeMap::new();
847
848        for p in pointers {
849            if let Some((p, d)) = self.disclosures.remove_entry(p.borrow()) {
850                disclosures.insert(p, d);
851            }
852        }
853
854        disclosures
855    }
856
857    /// Filter the disclosures, removing the ones targeting the given JSON
858    /// pointers.
859    pub fn rejecting(mut self, pointers: &[impl Borrow<JsonPointer>]) -> Self {
860        self.reject(pointers);
861        self
862    }
863
864    /// Encodes the SD-JWT, re-concealing the claims.
865    pub fn into_encoded(self) -> SdJwtBuf {
866        let mut bytes = self.jwt.into_encoded().into_bytes();
867        bytes.push(b'~');
868
869        for d in self.disclosures.into_values() {
870            bytes.extend_from_slice(d.encoded.as_bytes());
871            bytes.push(b'~');
872        }
873
874        unsafe {
875            // SAFETY: we just constructed those bytes following the SD-JWT
876            // syntax.
877            SdJwtBuf::new_unchecked(bytes)
878        }
879    }
880}
881
#[cfg(test)]
mod tests {
    use super::*;

    /// Compact SD-JWT test vector: [`JWT`] followed by `~`, then
    /// [`DISCLOSURE_0`] and [`DISCLOSURE_1`], each followed by `~`.
    const ENCODED: &str = concat!(
        "eyJhbGciOiAiRVMyNTYifQ.eyJfc2QiOiBbIkM5aW5wNllvUmFFWFI0Mjd6WUpQN1Fya",
        "zFXSF84YmR3T0FfWVVyVW5HUVUiLCAiS3VldDF5QWEwSElRdlluT1ZkNTloY1ZpTzlVZ",
        "zZKMmtTZnFZUkJlb3d2RSIsICJNTWxkT0ZGekIyZDB1bWxtcFRJYUdlcmhXZFVfUHBZZ",
        "kx2S2hoX2ZfOWFZIiwgIlg2WkFZT0lJMnZQTjQwVjd4RXhad1Z3ejd5Um1MTmNWd3Q1R",
        "Ew4Ukx2NGciLCAiWTM0em1JbzBRTExPdGRNcFhHd2pCZ0x2cjE3eUVoaFlUMEZHb2ZSL",
        "WFJRSIsICJmeUdwMFdUd3dQdjJKRFFsbjFsU2lhZW9iWnNNV0ExMGJRNTk4OS05RFRzI",
        "iwgIm9tbUZBaWNWVDhMR0hDQjB1eXd4N2ZZdW8zTUhZS08xNWN6LVJaRVlNNVEiLCAic",
        "zBCS1lzTFd4UVFlVTh0VmxsdE03TUtzSVJUckVJYTFQa0ptcXhCQmY1VSJdLCAiaXNzI",
        "jogImh0dHBzOi8vZXhhbXBsZS5jb20vaXNzdWVyIiwgImlhdCI6IDE2ODMwMDAwMDAsI",
        "CJleHAiOiAxODgzMDAwMDAwLCAiYWRkcmVzcyI6IHsiX3NkIjogWyI2YVVoelloWjdTS",
        "jFrVm1hZ1FBTzN1MkVUTjJDQzFhSGhlWnBLbmFGMF9FIiwgIkF6TGxGb2JrSjJ4aWF1c",
        "FJFUHlvSnotOS1OU2xkQjZDZ2pyN2ZVeW9IemciLCAiUHp6Y1Z1MHFiTXVCR1NqdWxmZ",
        "Xd6a2VzRDl6dXRPRXhuNUVXTndrclEtayIsICJiMkRrdzBqY0lGOXJHZzhfUEY4WmN2b",
        "mNXN3p3Wmo1cnlCV3ZYZnJwemVrIiwgImNQWUpISVo4VnUtZjlDQ3lWdWIyVWZnRWs4a",
        "nZ2WGV6d0sxcF9KbmVlWFEiLCAiZ2xUM2hyU1U3ZlNXZ3dGNVVEWm1Xd0JUdzMyZ25Vb",
        "GRJaGk4aEdWQ2FWNCIsICJydkpkNmlxNlQ1ZWptc0JNb0d3dU5YaDlxQUFGQVRBY2k0M",
        "G9pZEVlVnNBIiwgInVOSG9XWWhYc1poVkpDTkUyRHF5LXpxdDd0NjlnSkt5NVFhRnY3R",
        "3JNWDQiXX0sICJfc2RfYWxnIjogInNoYS0yNTYifQ.rFsowW-KSZe7EITlWsGajR9nnG",
        "BLlQ78qgtdGIZg3FZuZnxtapP0H8CUMnffJAwPQJmGnpFpulTkLWHiI1kMmw~WyJHMDJ",
        "OU3JRZmpGWFE3SW8wOXN5YWpBIiwgInJlZ2lvbiIsICJcdTZlMmZcdTUzM2EiXQ~WyJs",
        "a2x4RjVqTVlsR1RQVW92TU5JdkNBIiwgImNvdW50cnkiLCAiSlAiXQ~"
    );

    /// Issuer-signed JWT part of [`ENCODED`].
    const JWT: &str = concat!(
        "eyJhbGciOiAiRVMyNTYifQ.eyJfc2QiOiBbIkM5aW5wNllvUmFFWFI0Mjd6WUpQN1Fya",
        "zFXSF84YmR3T0FfWVVyVW5HUVUiLCAiS3VldDF5QWEwSElRdlluT1ZkNTloY1ZpTzlVZ",
        "zZKMmtTZnFZUkJlb3d2RSIsICJNTWxkT0ZGekIyZDB1bWxtcFRJYUdlcmhXZFVfUHBZZ",
        "kx2S2hoX2ZfOWFZIiwgIlg2WkFZT0lJMnZQTjQwVjd4RXhad1Z3ejd5Um1MTmNWd3Q1R",
        "Ew4Ukx2NGciLCAiWTM0em1JbzBRTExPdGRNcFhHd2pCZ0x2cjE3eUVoaFlUMEZHb2ZSL",
        "WFJRSIsICJmeUdwMFdUd3dQdjJKRFFsbjFsU2lhZW9iWnNNV0ExMGJRNTk4OS05RFRzI",
        "iwgIm9tbUZBaWNWVDhMR0hDQjB1eXd4N2ZZdW8zTUhZS08xNWN6LVJaRVlNNVEiLCAic",
        "zBCS1lzTFd4UVFlVTh0VmxsdE03TUtzSVJUckVJYTFQa0ptcXhCQmY1VSJdLCAiaXNzI",
        "jogImh0dHBzOi8vZXhhbXBsZS5jb20vaXNzdWVyIiwgImlhdCI6IDE2ODMwMDAwMDAsI",
        "CJleHAiOiAxODgzMDAwMDAwLCAiYWRkcmVzcyI6IHsiX3NkIjogWyI2YVVoelloWjdTS",
        "jFrVm1hZ1FBTzN1MkVUTjJDQzFhSGhlWnBLbmFGMF9FIiwgIkF6TGxGb2JrSjJ4aWF1c",
        "FJFUHlvSnotOS1OU2xkQjZDZ2pyN2ZVeW9IemciLCAiUHp6Y1Z1MHFiTXVCR1NqdWxmZ",
        "Xd6a2VzRDl6dXRPRXhuNUVXTndrclEtayIsICJiMkRrdzBqY0lGOXJHZzhfUEY4WmN2b",
        "mNXN3p3Wmo1cnlCV3ZYZnJwemVrIiwgImNQWUpISVo4VnUtZjlDQ3lWdWIyVWZnRWs4a",
        "nZ2WGV6d0sxcF9KbmVlWFEiLCAiZ2xUM2hyU1U3ZlNXZ3dGNVVEWm1Xd0JUdzMyZ25Vb",
        "GRJaGk4aEdWQ2FWNCIsICJydkpkNmlxNlQ1ZWptc0JNb0d3dU5YaDlxQUFGQVRBY2k0M",
        "G9pZEVlVnNBIiwgInVOSG9XWWhYc1poVkpDTkUyRHF5LXpxdDd0NjlnSkt5NVFhRnY3R",
        "3JNWDQiXX0sICJfc2RfYWxnIjogInNoYS0yNTYifQ.rFsowW-KSZe7EITlWsGajR9nnG",
        "BLlQ78qgtdGIZg3FZuZnxtapP0H8CUMnffJAwPQJmGnpFpulTkLWHiI1kMmw"
    );

    /// First disclosure of [`ENCODED`].
    const DISCLOSURE_0: &str =
        "WyJHMDJOU3JRZmpGWFE3SW8wOXN5YWpBIiwgInJlZ2lvbiIsICJcdTZlMmZcdTUzM2EiXQ";

    /// Second disclosure of [`ENCODED`].
    const DISCLOSURE_1: &str = "WyJsa2x4RjVqTVlsR1RQVW92TU5JdkNBIiwgImNvdW50cnkiLCAiSlAiXQ";

    /// Builds the parts [`ENCODED`] is expected to be made of.
    fn expected_parts() -> PartsRef<'static> {
        let jwt = Jws::new(JWT).unwrap();
        let disclosures = vec![
            Disclosure::new(DISCLOSURE_0).unwrap(),
            Disclosure::new(DISCLOSURE_1).unwrap(),
        ];

        PartsRef::new(jwt, disclosures)
    }

    /// Splitting the compact form yields the expected JWT and disclosures.
    #[test]
    fn deserialize() {
        let sd_jwt = SdJwt::new(ENCODED).unwrap();
        assert_eq!(sd_jwt.parts(), expected_parts());
    }

    /// The empty string is not an SD-JWT.
    #[test]
    fn deserialize_fails_with_emtpy() {
        let result = SdJwt::new("");
        assert!(result.is_err());
    }

    /// Displaying the parts yields the compact form back.
    #[test]
    fn serialize_parts() {
        let encoded = expected_parts().to_string();
        assert_eq!(encoded, ENCODED);
    }
}