Skip to main content

auth_framework/protocols/
sd_jwt.rs

1//! SD-JWT (Selective Disclosure JWT) implementation.
2//!
3//! Implements the IETF SD-JWT specification (draft-ietf-oauth-selective-disclosure-jwt)
4//! for creating JWTs whose claims can be selectively disclosed by the holder.
5//!
6//! # Architecture
7//!
8//! - **Issuer**: Creates an SD-JWT with selectively disclosable claims hashed into
9//!   the `_sd` array. Each claim becomes a separate disclosure.
10//! - **Holder**: Receives the full SD-JWT and can present a subset of disclosures
11//!   to a verifier, revealing only the claims they choose.
12//! - **Verifier**: Validates the JWT signature and reconstructs disclosed claims
13//!   by matching disclosure hashes against the `_sd` array.
14//!
15//! # Example
16//!
17//! ```rust,no_run
18//! use auth_framework::protocols::sd_jwt::{SdJwtIssuer, SdJwtConfig};
19//!
20//! let config = SdJwtConfig::default();
21//! let issuer = SdJwtIssuer::new(config);
22//!
23//! let mut claims = serde_json::Map::new();
24//! claims.insert("sub".into(), serde_json::json!("user-42"));
25//! claims.insert("email".into(), serde_json::json!("user@example.com"));
26//!
27//! // "email" is selectively disclosable; "sub" stays in the clear
28//! let sd_jwt = issuer.issue(
29//!     &claims,
30//!     &["email"],
31//!     "signing-secret-key",
32//! ).unwrap();
33//! ```
34
35use crate::errors::{AuthError, Result};
36use base64::engine::general_purpose::URL_SAFE_NO_PAD;
37use base64::Engine;
38use serde::{Deserialize, Serialize};
39use sha2::{Digest, Sha256};
40use std::collections::HashMap;
41
/// Hash algorithm used for disclosure digests.
///
/// The serde representation of each variant is the same identifier that
/// `as_str` returns for the `_sd_alg` claim (currently only `"sha-256"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SdHashAlgorithm {
    /// SHA-256 (default, recommended). Serialized as `"sha-256"`.
    #[serde(rename = "sha-256")]
    Sha256,
}
49
50impl SdHashAlgorithm {
51    /// Return the `_sd_alg` string value.
52    pub fn as_str(&self) -> &'static str {
53        match self {
54            Self::Sha256 => "sha-256",
55        }
56    }
57
58    /// Compute the digest of `input` using this algorithm.
59    fn digest(&self, input: &[u8]) -> Vec<u8> {
60        match self {
61            Self::Sha256 => Sha256::digest(input).to_vec(),
62        }
63    }
64}
65
66impl Default for SdHashAlgorithm {
67    fn default() -> Self {
68        Self::Sha256
69    }
70}
71
/// Configuration for SD-JWT operations.
///
/// The same configuration type is consumed by both `SdJwtIssuer` and
/// `SdJwtVerifier`; the verifier checks the token's `iss` claim and signing
/// algorithm against the values stored here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SdJwtConfig {
    /// Hash algorithm for disclosure digests (also emitted as `_sd_alg`).
    pub hash_algorithm: SdHashAlgorithm,
    /// JWT signing algorithm placed in the JWT header and enforced on verification.
    pub signing_algorithm: jsonwebtoken::Algorithm,
    /// Issuer claim value written as `iss` and required by the verifier.
    pub issuer: String,
    /// Token lifetime in seconds; `exp` is computed as `iat + lifetime_secs`.
    pub lifetime_secs: u64,
    /// Salt length in raw bytes, before base64url encoding
    /// (minimum 16 recommended by spec).
    pub salt_length: usize,
}
86
87impl Default for SdJwtConfig {
88    fn default() -> Self {
89        Self {
90            hash_algorithm: SdHashAlgorithm::default(),
91            signing_algorithm: jsonwebtoken::Algorithm::HS256,
92            issuer: "auth-framework".to_string(),
93            lifetime_secs: 3600,
94            salt_length: 16,
95        }
96    }
97}
98
/// A single disclosure: the base64url-encoded `[salt, claim_name, claim_value]` array.
///
/// Besides the wire form (`encoded`), the decoded name/value and the digest
/// are kept alongside so callers can filter disclosures by claim name without
/// re-decoding them.
#[derive(Debug, Clone)]
pub struct Disclosure {
    /// The base64url-encoded (unpadded) disclosure string, as it appears
    /// between `~` separators in the serialized SD-JWT.
    pub encoded: String,
    /// The claim name this disclosure reveals.
    pub claim_name: String,
    /// The claim value.
    pub claim_value: serde_json::Value,
    /// The base64url-encoded hash of `encoded` (for inclusion in `_sd`).
    pub digest: String,
}
111
/// The issued SD-JWT: a compact JWT, the tilde-separated disclosures, and
/// an optional key-binding JWT.
///
/// Use `serialize` to emit every disclosure or `present` to emit only a
/// chosen subset.
#[derive(Debug, Clone)]
pub struct SdJwt {
    /// The signed JWT containing `_sd` digests.
    pub jwt: String,
    /// All disclosures produced by the issuer.
    pub disclosures: Vec<Disclosure>,
    /// Optional holder key-binding JWT; appended after the final `~` when set.
    pub key_binding_jwt: Option<String>,
}
123
124impl SdJwt {
125    /// Serialize to the SD-JWT compact format: `<JWT>~<Disclosure1>~...~<DisclosureN>~[KB-JWT]`.
126    pub fn serialize(&self) -> String {
127        let mut out = self.jwt.clone();
128        for d in &self.disclosures {
129            out.push('~');
130            out.push_str(&d.encoded);
131        }
132        out.push('~');
133        if let Some(ref kb) = self.key_binding_jwt {
134            out.push_str(kb);
135        }
136        out
137    }
138
139    /// Create a presentation with only the selected claim names disclosed.
140    pub fn present(&self, claims_to_disclose: &[&str]) -> String {
141        let mut out = self.jwt.clone();
142        for d in &self.disclosures {
143            if claims_to_disclose.contains(&d.claim_name.as_str()) {
144                out.push('~');
145                out.push_str(&d.encoded);
146            }
147        }
148        out.push('~');
149        if let Some(ref kb) = self.key_binding_jwt {
150            out.push_str(kb);
151        }
152        out
153    }
154}
155
/// SD-JWT issuer: creates SD-JWTs with selectively disclosable claims.
///
/// Holds the configuration (hash algorithm, signing algorithm, issuer name,
/// lifetime, salt length) applied to every token it issues.
pub struct SdJwtIssuer {
    /// Settings applied to every issued token.
    config: SdJwtConfig,
}
160
161impl SdJwtIssuer {
162    /// Create a new issuer with the given configuration.
163    pub fn new(config: SdJwtConfig) -> Self {
164        Self { config }
165    }
166
167    /// Generate a cryptographically random salt.
168    fn generate_salt(&self) -> Result<String> {
169        let mut salt = vec![0u8; self.config.salt_length];
170        ring::rand::SecureRandom::fill(
171            &ring::rand::SystemRandom::new(),
172            &mut salt,
173        )
174        .map_err(|_| AuthError::crypto("Failed to generate random salt"))?;
175        Ok(URL_SAFE_NO_PAD.encode(&salt))
176    }
177
178    /// Build a disclosure for a single claim and return its digest.
179    fn create_disclosure(
180        &self,
181        claim_name: &str,
182        claim_value: &serde_json::Value,
183    ) -> Result<Disclosure> {
184        let salt = self.generate_salt()?;
185        let array = serde_json::json!([salt, claim_name, claim_value]);
186        let encoded = URL_SAFE_NO_PAD.encode(array.to_string().as_bytes());
187        let hash = self.config.hash_algorithm.digest(encoded.as_bytes());
188        let digest = URL_SAFE_NO_PAD.encode(&hash);
189
190        Ok(Disclosure {
191            encoded,
192            claim_name: claim_name.to_string(),
193            claim_value: claim_value.clone(),
194            digest,
195        })
196    }
197
198    /// Issue an SD-JWT.
199    ///
200    /// * `claims` — all claims to include in the token.
201    /// * `sd_claims` — names of claims that should be selectively disclosable.
202    ///   Claims not listed here are included in the JWT payload in the clear.
203    /// * `signing_key` — the symmetric key (for HMAC algorithms) or PEM-encoded
204    ///   private key (for RSA/EC algorithms).
205    pub fn issue(
206        &self,
207        claims: &serde_json::Map<String, serde_json::Value>,
208        sd_claims: &[&str],
209        signing_key: &str,
210    ) -> Result<SdJwt> {
211        if claims.is_empty() {
212            return Err(AuthError::validation("Claims map cannot be empty"));
213        }
214
215        let mut payload = serde_json::Map::new();
216        let mut disclosures = Vec::new();
217        let mut sd_digests: Vec<serde_json::Value> = Vec::new();
218
219        // Separate plaintext claims from selectively-disclosable claims.
220        for (name, value) in claims {
221            if sd_claims.contains(&name.as_str()) {
222                let disclosure = self.create_disclosure(name, value)?;
223                sd_digests.push(serde_json::Value::String(disclosure.digest.clone()));
224                disclosures.push(disclosure);
225            } else {
226                payload.insert(name.clone(), value.clone());
227            }
228        }
229
230        // Add standard JWT claims.
231        let now = chrono::Utc::now().timestamp() as u64;
232        payload.insert("iss".to_string(), serde_json::json!(self.config.issuer));
233        payload.insert("iat".to_string(), serde_json::json!(now));
234        payload.insert(
235            "exp".to_string(),
236            serde_json::json!(now + self.config.lifetime_secs),
237        );
238
239        // Add the `_sd` array and `_sd_alg`.
240        if !sd_digests.is_empty() {
241            payload.insert("_sd".to_string(), serde_json::Value::Array(sd_digests));
242            payload.insert(
243                "_sd_alg".to_string(),
244                serde_json::json!(self.config.hash_algorithm.as_str()),
245            );
246        }
247
248        // Sign the JWT.
249        let header = jsonwebtoken::Header::new(self.config.signing_algorithm);
250        let key = jsonwebtoken::EncodingKey::from_secret(signing_key.as_bytes());
251        let jwt = jsonwebtoken::encode(&header, &payload, &key)
252            .map_err(|e| AuthError::crypto(format!("SD-JWT signing failed: {e}")))?;
253
254        Ok(SdJwt {
255            jwt,
256            disclosures,
257            key_binding_jwt: None,
258        })
259    }
260}
261
/// SD-JWT verifier: validates SD-JWTs and reconstructs disclosed claims.
///
/// The configuration supplies the expected signing algorithm, issuer and
/// disclosure hash algorithm.
pub struct SdJwtVerifier {
    /// Expectations (algorithm, issuer, hash algorithm) tokens are checked against.
    config: SdJwtConfig,
}
266
267impl SdJwtVerifier {
268    /// Create a new verifier.
269    pub fn new(config: SdJwtConfig) -> Self {
270        Self { config }
271    }
272
273    /// Parse a serialized SD-JWT string into its components.
274    pub fn parse(input: &str) -> Result<(String, Vec<String>, Option<String>)> {
275        let parts: Vec<&str> = input.split('~').collect();
276        if parts.len() < 2 {
277            return Err(AuthError::validation(
278                "Invalid SD-JWT format: must contain at least JWT~",
279            ));
280        }
281
282        let jwt = parts[0].to_string();
283        let last = *parts.last().unwrap();
284
285        // If the last part is empty, there is no key-binding JWT.
286        // If the last part looks like a JWT (has dots), treat it as KB-JWT.
287        let (disclosure_parts, kb_jwt) = if last.is_empty() {
288            (&parts[1..parts.len() - 1], None)
289        } else if last.chars().filter(|&c| c == '.').count() == 2 {
290            (
291                &parts[1..parts.len() - 1],
292                Some(last.to_string()),
293            )
294        } else {
295            (&parts[1..], None)
296        };
297
298        let disclosures = disclosure_parts
299            .iter()
300            .filter(|s| !s.is_empty())
301            .map(|s| s.to_string())
302            .collect();
303
304        Ok((jwt, disclosures, kb_jwt))
305    }
306
307    /// Verify an SD-JWT and return the disclosed claims.
308    ///
309    /// * `sd_jwt_str` — the compact SD-JWT string.
310    /// * `verification_key` — the symmetric key or public key for signature verification.
311    pub fn verify(
312        &self,
313        sd_jwt_str: &str,
314        verification_key: &str,
315    ) -> Result<VerifiedSdJwt> {
316        let (jwt, disclosure_strings, kb_jwt) = Self::parse(sd_jwt_str)?;
317
318        // Verify JWT signature and decode payload.
319        let key = jsonwebtoken::DecodingKey::from_secret(verification_key.as_bytes());
320        let mut validation = jsonwebtoken::Validation::new(self.config.signing_algorithm);
321        validation.set_required_spec_claims::<&str>(&[]);
322        validation.validate_exp = true;
323        validation.set_issuer(&[&self.config.issuer]);
324
325        let token_data = jsonwebtoken::decode::<serde_json::Map<String, serde_json::Value>>(
326            &jwt,
327            &key,
328            &validation,
329        )
330        .map_err(|e| AuthError::token(format!("SD-JWT signature verification failed: {e}")))?;
331
332        let mut payload = token_data.claims;
333
334        // Extract `_sd` digests and `_sd_alg`.
335        let sd_digests: Vec<String> = payload
336            .remove("_sd")
337            .map(|v| {
338                v.as_array()
339                    .unwrap_or(&vec![])
340                    .iter()
341                    .filter_map(|item| item.as_str().map(|s| s.to_string()))
342                    .collect()
343            })
344            .unwrap_or_default();
345
346        let _sd_alg = payload.remove("_sd_alg");
347
348        // Process disclosures: decode, hash, and match against `_sd`.
349        let mut disclosed_claims = HashMap::new();
350        for disclosure_str in &disclosure_strings {
351            let decoded_bytes = URL_SAFE_NO_PAD
352                .decode(disclosure_str.as_bytes())
353                .map_err(|e| {
354                    AuthError::validation(format!("Invalid disclosure encoding: {e}"))
355                })?;
356
357            let disclosure_array: serde_json::Value =
358                serde_json::from_slice(&decoded_bytes).map_err(|e| {
359                    AuthError::validation(format!("Invalid disclosure JSON: {e}"))
360                })?;
361
362            let arr = disclosure_array.as_array().ok_or_else(|| {
363                AuthError::validation("Disclosure must be a JSON array")
364            })?;
365
366            if arr.len() != 3 {
367                return Err(AuthError::validation(
368                    "Disclosure array must have exactly 3 elements [salt, name, value]",
369                ));
370            }
371
372            let claim_name = arr[1].as_str().ok_or_else(|| {
373                AuthError::validation("Disclosure claim name must be a string")
374            })?;
375            let claim_value = &arr[2];
376
377            // Verify the disclosure hash is in the `_sd` array.
378            let hash = self.config.hash_algorithm.digest(disclosure_str.as_bytes());
379            let digest = URL_SAFE_NO_PAD.encode(&hash);
380
381            if !sd_digests.contains(&digest) {
382                return Err(AuthError::validation(format!(
383                    "Disclosure for '{}' does not match any _sd digest",
384                    claim_name,
385                )));
386            }
387
388            disclosed_claims.insert(claim_name.to_string(), claim_value.clone());
389        }
390
391        Ok(VerifiedSdJwt {
392            plaintext_claims: payload,
393            disclosed_claims,
394            key_binding_jwt: kb_jwt,
395        })
396    }
397}
398
/// Result of verifying an SD-JWT.
///
/// Claims are kept in two groups so callers can tell which values were always
/// visible in the signed payload and which were revealed through disclosures.
#[derive(Debug, Clone)]
pub struct VerifiedSdJwt {
    /// Claims that were in the JWT payload in the clear
    /// (with `_sd` and `_sd_alg` removed).
    pub plaintext_claims: serde_json::Map<String, serde_json::Value>,
    /// Claims reconstructed from the presented disclosures.
    pub disclosed_claims: HashMap<String, serde_json::Value>,
    /// The optional key-binding JWT, if present, exactly as it appeared in
    /// the input.
    pub key_binding_jwt: Option<String>,
}
409
410impl VerifiedSdJwt {
411    /// Get a claim by name, checking disclosed claims first, then plaintext.
412    pub fn get_claim(&self, name: &str) -> Option<&serde_json::Value> {
413        self.disclosed_claims
414            .get(name)
415            .or_else(|| self.plaintext_claims.get(name))
416    }
417
418    /// Merge all claims into a single map.
419    pub fn all_claims(&self) -> serde_json::Map<String, serde_json::Value> {
420        let mut merged = self.plaintext_claims.clone();
421        for (k, v) in &self.disclosed_claims {
422            merged.insert(k.clone(), v.clone());
423        }
424        merged
425    }
426}
427
#[cfg(test)]
mod tests {
    use super::*;

    const TEST_KEY: &str = "test-signing-key-at-least-256-bits-long!!";

    /// Default configuration with an explicit one-hour lifetime.
    fn test_config() -> SdJwtConfig {
        SdJwtConfig {
            lifetime_secs: 3600,
            ..SdJwtConfig::default()
        }
    }

    /// An issuer and a verifier sharing the same test configuration.
    fn parties() -> (SdJwtIssuer, SdJwtVerifier) {
        let shared = test_config();
        (SdJwtIssuer::new(shared.clone()), SdJwtVerifier::new(shared))
    }

    /// Four representative claims: three strings and one nested object.
    fn sample_claims() -> serde_json::Map<String, serde_json::Value> {
        vec![
            ("sub", serde_json::json!("user-42")),
            ("email", serde_json::json!("user@example.com")),
            ("name", serde_json::json!("Alice")),
            (
                "address",
                serde_json::json!({"street": "123 Main St", "city": "Springfield"}),
            ),
        ]
        .into_iter()
        .map(|(name, value)| (name.to_string(), value))
        .collect()
    }

    /// Assert that `name` resolves to the string `expected`.
    fn expect_claim(verified: &VerifiedSdJwt, name: &str, expected: &str) {
        assert_eq!(
            verified.get_claim(name),
            Some(&serde_json::Value::from(expected))
        );
    }

    #[test]
    fn test_issue_and_serialize() {
        let issuer = SdJwtIssuer::new(test_config());
        let token = issuer
            .issue(&sample_claims(), &["email", "address"], TEST_KEY)
            .unwrap();

        assert!(!token.jwt.is_empty());
        assert_eq!(token.disclosures.len(), 2);

        // JWT + 2 disclosures + trailing tilde
        let tilde_count = token.serialize().matches('~').count();
        assert_eq!(tilde_count, 3);
    }

    #[test]
    fn test_issue_no_sd_claims() {
        let issuer = SdJwtIssuer::new(test_config());
        let token = issuer.issue(&sample_claims(), &[], TEST_KEY).unwrap();

        assert!(token.disclosures.is_empty());
        assert!(token.serialize().ends_with('~'));
    }

    #[test]
    fn test_full_disclosure_roundtrip() {
        let (issuer, verifier) = parties();

        let wire = issuer
            .issue(&sample_claims(), &["email", "name"], TEST_KEY)
            .unwrap()
            .serialize();
        let verified = verifier.verify(&wire, TEST_KEY).unwrap();

        expect_claim(&verified, "sub", "user-42");
        expect_claim(&verified, "email", "user@example.com");
        expect_claim(&verified, "name", "Alice");
    }

    #[test]
    fn test_selective_disclosure() {
        let (issuer, verifier) = parties();

        let token = issuer
            .issue(&sample_claims(), &["email", "name", "address"], TEST_KEY)
            .unwrap();

        // Present only "email", omitting "name" and "address"
        let presentation = token.present(&["email"]);
        let verified = verifier.verify(&presentation, TEST_KEY).unwrap();

        // "sub" is plaintext — always visible
        expect_claim(&verified, "sub", "user-42");
        // "email" was disclosed
        expect_claim(&verified, "email", "user@example.com");
        // "name" and "address" were NOT disclosed
        for hidden in ["name", "address"].iter() {
            assert!(verified.get_claim(hidden).is_none());
        }
    }

    #[test]
    fn test_all_claims_merged() {
        let (issuer, verifier) = parties();

        let wire = issuer
            .issue(&sample_claims(), &["email"], TEST_KEY)
            .unwrap()
            .serialize();
        let merged = verifier.verify(&wire, TEST_KEY).unwrap().all_claims();

        for expected in ["sub", "email", "iss", "iat", "exp"].iter() {
            assert!(merged.contains_key(*expected));
        }
    }

    #[test]
    fn test_reject_empty_claims() {
        let issuer = SdJwtIssuer::new(test_config());
        let no_claims = serde_json::Map::new();

        let outcome = issuer.issue(&no_claims, &[], TEST_KEY);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_reject_wrong_key() {
        let (issuer, verifier) = parties();

        let wire = issuer
            .issue(&sample_claims(), &["email"], TEST_KEY)
            .unwrap()
            .serialize();

        let outcome = verifier.verify(&wire, "wrong-key-wrong-key-wrong-key!!!");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_reject_forged_disclosure() {
        let (issuer, verifier) = parties();
        let token = issuer.issue(&sample_claims(), &["email"], TEST_KEY).unwrap();

        // Forge a disclosure that isn't in the _sd array
        let forged_json = serde_json::json!(["fakesalt", "role", "admin"]).to_string();
        let forged_encoded = URL_SAFE_NO_PAD.encode(forged_json.as_bytes());
        let tampered = format!("{}~{}~", token.jwt, forged_encoded);

        assert!(verifier.verify(&tampered, TEST_KEY).is_err());
    }

    #[test]
    fn test_parse_components() {
        let (jwt, disclosures, kb) =
            SdJwtVerifier::parse("eyJ0eXAi.payload.sig~disc1~disc2~").unwrap();

        assert_eq!(jwt, "eyJ0eXAi.payload.sig");
        assert_eq!(disclosures.len(), 2);
        assert!(kb.is_none());
    }

    #[test]
    fn test_parse_with_kb_jwt() {
        let (jwt, disclosures, kb) =
            SdJwtVerifier::parse("eyJ0eXAi.payload.sig~disc1~header.payload.signature").unwrap();

        assert_eq!(jwt, "eyJ0eXAi.payload.sig");
        assert_eq!(disclosures.len(), 1);
        assert_eq!(kb.unwrap(), "header.payload.signature");
    }

    #[test]
    fn test_disclosure_uniqueness() {
        let issuer = SdJwtIssuer::new(test_config());
        let claims = sample_claims();

        let first = issuer.issue(&claims, &["email"], TEST_KEY).unwrap();
        let second = issuer.issue(&claims, &["email"], TEST_KEY).unwrap();

        // Different salts produce different disclosures
        let (a, b) = (&first.disclosures[0], &second.disclosures[0]);
        assert_ne!(a.encoded, b.encoded);
        assert_ne!(a.digest, b.digest);
    }

    #[test]
    fn test_complex_claim_value() {
        let (issuer, verifier) = parties();

        let mut claims = serde_json::Map::new();
        claims.insert("sub".into(), serde_json::json!("user-1"));
        claims.insert(
            "address".into(),
            serde_json::json!({
                "street": "123 Main St",
                "city": "Springfield",
                "zip": "62701"
            }),
        );

        let wire = issuer
            .issue(&claims, &["address"], TEST_KEY)
            .unwrap()
            .serialize();
        let verified = verifier.verify(&wire, TEST_KEY).unwrap();

        let address = verified.get_claim("address").unwrap();
        assert_eq!(address["city"], "Springfield");
        assert_eq!(address["zip"], "62701");
    }
}
633}