1use crate::errors::{AuthError, Result};
36use base64::engine::general_purpose::URL_SAFE_NO_PAD;
37use base64::Engine;
38use serde::{Deserialize, Serialize};
39use sha2::{Digest, Sha256};
40use std::collections::HashMap;
41
42#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
44pub enum SdHashAlgorithm {
45 #[serde(rename = "sha-256")]
47 Sha256,
48}
49
50impl SdHashAlgorithm {
51 pub fn as_str(&self) -> &'static str {
53 match self {
54 Self::Sha256 => "sha-256",
55 }
56 }
57
58 fn digest(&self, input: &[u8]) -> Vec<u8> {
60 match self {
61 Self::Sha256 => Sha256::digest(input).to_vec(),
62 }
63 }
64}
65
66impl Default for SdHashAlgorithm {
67 fn default() -> Self {
68 Self::Sha256
69 }
70}
71
72#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct SdJwtConfig {
75 pub hash_algorithm: SdHashAlgorithm,
77 pub signing_algorithm: jsonwebtoken::Algorithm,
79 pub issuer: String,
81 pub lifetime_secs: u64,
83 pub salt_length: usize,
85}
86
87impl Default for SdJwtConfig {
88 fn default() -> Self {
89 Self {
90 hash_algorithm: SdHashAlgorithm::default(),
91 signing_algorithm: jsonwebtoken::Algorithm::HS256,
92 issuer: "auth-framework".to_string(),
93 lifetime_secs: 3600,
94 salt_length: 16,
95 }
96 }
97}
98
99#[derive(Debug, Clone)]
101pub struct Disclosure {
102 pub encoded: String,
104 pub claim_name: String,
106 pub claim_value: serde_json::Value,
108 pub digest: String,
110}
111
112#[derive(Debug, Clone)]
115pub struct SdJwt {
116 pub jwt: String,
118 pub disclosures: Vec<Disclosure>,
120 pub key_binding_jwt: Option<String>,
122}
123
124impl SdJwt {
125 pub fn serialize(&self) -> String {
127 let mut out = self.jwt.clone();
128 for d in &self.disclosures {
129 out.push('~');
130 out.push_str(&d.encoded);
131 }
132 out.push('~');
133 if let Some(ref kb) = self.key_binding_jwt {
134 out.push_str(kb);
135 }
136 out
137 }
138
139 pub fn present(&self, claims_to_disclose: &[&str]) -> String {
141 let mut out = self.jwt.clone();
142 for d in &self.disclosures {
143 if claims_to_disclose.contains(&d.claim_name.as_str()) {
144 out.push('~');
145 out.push_str(&d.encoded);
146 }
147 }
148 out.push('~');
149 if let Some(ref kb) = self.key_binding_jwt {
150 out.push_str(kb);
151 }
152 out
153 }
154}
155
156pub struct SdJwtIssuer {
158 config: SdJwtConfig,
159}
160
161impl SdJwtIssuer {
162 pub fn new(config: SdJwtConfig) -> Self {
164 Self { config }
165 }
166
167 fn generate_salt(&self) -> Result<String> {
169 let mut salt = vec![0u8; self.config.salt_length];
170 ring::rand::SecureRandom::fill(
171 &ring::rand::SystemRandom::new(),
172 &mut salt,
173 )
174 .map_err(|_| AuthError::crypto("Failed to generate random salt"))?;
175 Ok(URL_SAFE_NO_PAD.encode(&salt))
176 }
177
178 fn create_disclosure(
180 &self,
181 claim_name: &str,
182 claim_value: &serde_json::Value,
183 ) -> Result<Disclosure> {
184 let salt = self.generate_salt()?;
185 let array = serde_json::json!([salt, claim_name, claim_value]);
186 let encoded = URL_SAFE_NO_PAD.encode(array.to_string().as_bytes());
187 let hash = self.config.hash_algorithm.digest(encoded.as_bytes());
188 let digest = URL_SAFE_NO_PAD.encode(&hash);
189
190 Ok(Disclosure {
191 encoded,
192 claim_name: claim_name.to_string(),
193 claim_value: claim_value.clone(),
194 digest,
195 })
196 }
197
198 pub fn issue(
206 &self,
207 claims: &serde_json::Map<String, serde_json::Value>,
208 sd_claims: &[&str],
209 signing_key: &str,
210 ) -> Result<SdJwt> {
211 if claims.is_empty() {
212 return Err(AuthError::validation("Claims map cannot be empty"));
213 }
214
215 let mut payload = serde_json::Map::new();
216 let mut disclosures = Vec::new();
217 let mut sd_digests: Vec<serde_json::Value> = Vec::new();
218
219 for (name, value) in claims {
221 if sd_claims.contains(&name.as_str()) {
222 let disclosure = self.create_disclosure(name, value)?;
223 sd_digests.push(serde_json::Value::String(disclosure.digest.clone()));
224 disclosures.push(disclosure);
225 } else {
226 payload.insert(name.clone(), value.clone());
227 }
228 }
229
230 let now = chrono::Utc::now().timestamp() as u64;
232 payload.insert("iss".to_string(), serde_json::json!(self.config.issuer));
233 payload.insert("iat".to_string(), serde_json::json!(now));
234 payload.insert(
235 "exp".to_string(),
236 serde_json::json!(now + self.config.lifetime_secs),
237 );
238
239 if !sd_digests.is_empty() {
241 payload.insert("_sd".to_string(), serde_json::Value::Array(sd_digests));
242 payload.insert(
243 "_sd_alg".to_string(),
244 serde_json::json!(self.config.hash_algorithm.as_str()),
245 );
246 }
247
248 let header = jsonwebtoken::Header::new(self.config.signing_algorithm);
250 let key = jsonwebtoken::EncodingKey::from_secret(signing_key.as_bytes());
251 let jwt = jsonwebtoken::encode(&header, &payload, &key)
252 .map_err(|e| AuthError::crypto(format!("SD-JWT signing failed: {e}")))?;
253
254 Ok(SdJwt {
255 jwt,
256 disclosures,
257 key_binding_jwt: None,
258 })
259 }
260}
261
262pub struct SdJwtVerifier {
264 config: SdJwtConfig,
265}
266
267impl SdJwtVerifier {
268 pub fn new(config: SdJwtConfig) -> Self {
270 Self { config }
271 }
272
273 pub fn parse(input: &str) -> Result<(String, Vec<String>, Option<String>)> {
275 let parts: Vec<&str> = input.split('~').collect();
276 if parts.len() < 2 {
277 return Err(AuthError::validation(
278 "Invalid SD-JWT format: must contain at least JWT~",
279 ));
280 }
281
282 let jwt = parts[0].to_string();
283 let last = *parts.last().unwrap();
284
285 let (disclosure_parts, kb_jwt) = if last.is_empty() {
288 (&parts[1..parts.len() - 1], None)
289 } else if last.chars().filter(|&c| c == '.').count() == 2 {
290 (
291 &parts[1..parts.len() - 1],
292 Some(last.to_string()),
293 )
294 } else {
295 (&parts[1..], None)
296 };
297
298 let disclosures = disclosure_parts
299 .iter()
300 .filter(|s| !s.is_empty())
301 .map(|s| s.to_string())
302 .collect();
303
304 Ok((jwt, disclosures, kb_jwt))
305 }
306
307 pub fn verify(
312 &self,
313 sd_jwt_str: &str,
314 verification_key: &str,
315 ) -> Result<VerifiedSdJwt> {
316 let (jwt, disclosure_strings, kb_jwt) = Self::parse(sd_jwt_str)?;
317
318 let key = jsonwebtoken::DecodingKey::from_secret(verification_key.as_bytes());
320 let mut validation = jsonwebtoken::Validation::new(self.config.signing_algorithm);
321 validation.set_required_spec_claims::<&str>(&[]);
322 validation.validate_exp = true;
323 validation.set_issuer(&[&self.config.issuer]);
324
325 let token_data = jsonwebtoken::decode::<serde_json::Map<String, serde_json::Value>>(
326 &jwt,
327 &key,
328 &validation,
329 )
330 .map_err(|e| AuthError::token(format!("SD-JWT signature verification failed: {e}")))?;
331
332 let mut payload = token_data.claims;
333
334 let sd_digests: Vec<String> = payload
336 .remove("_sd")
337 .map(|v| {
338 v.as_array()
339 .unwrap_or(&vec![])
340 .iter()
341 .filter_map(|item| item.as_str().map(|s| s.to_string()))
342 .collect()
343 })
344 .unwrap_or_default();
345
346 let _sd_alg = payload.remove("_sd_alg");
347
348 let mut disclosed_claims = HashMap::new();
350 for disclosure_str in &disclosure_strings {
351 let decoded_bytes = URL_SAFE_NO_PAD
352 .decode(disclosure_str.as_bytes())
353 .map_err(|e| {
354 AuthError::validation(format!("Invalid disclosure encoding: {e}"))
355 })?;
356
357 let disclosure_array: serde_json::Value =
358 serde_json::from_slice(&decoded_bytes).map_err(|e| {
359 AuthError::validation(format!("Invalid disclosure JSON: {e}"))
360 })?;
361
362 let arr = disclosure_array.as_array().ok_or_else(|| {
363 AuthError::validation("Disclosure must be a JSON array")
364 })?;
365
366 if arr.len() != 3 {
367 return Err(AuthError::validation(
368 "Disclosure array must have exactly 3 elements [salt, name, value]",
369 ));
370 }
371
372 let claim_name = arr[1].as_str().ok_or_else(|| {
373 AuthError::validation("Disclosure claim name must be a string")
374 })?;
375 let claim_value = &arr[2];
376
377 let hash = self.config.hash_algorithm.digest(disclosure_str.as_bytes());
379 let digest = URL_SAFE_NO_PAD.encode(&hash);
380
381 if !sd_digests.contains(&digest) {
382 return Err(AuthError::validation(format!(
383 "Disclosure for '{}' does not match any _sd digest",
384 claim_name,
385 )));
386 }
387
388 disclosed_claims.insert(claim_name.to_string(), claim_value.clone());
389 }
390
391 Ok(VerifiedSdJwt {
392 plaintext_claims: payload,
393 disclosed_claims,
394 key_binding_jwt: kb_jwt,
395 })
396 }
397}
398
399#[derive(Debug, Clone)]
401pub struct VerifiedSdJwt {
402 pub plaintext_claims: serde_json::Map<String, serde_json::Value>,
404 pub disclosed_claims: HashMap<String, serde_json::Value>,
406 pub key_binding_jwt: Option<String>,
408}
409
410impl VerifiedSdJwt {
411 pub fn get_claim(&self, name: &str) -> Option<&serde_json::Value> {
413 self.disclosed_claims
414 .get(name)
415 .or_else(|| self.plaintext_claims.get(name))
416 }
417
418 pub fn all_claims(&self) -> serde_json::Map<String, serde_json::Value> {
420 let mut merged = self.plaintext_claims.clone();
421 for (k, v) in &self.disclosed_claims {
422 merged.insert(k.clone(), v.clone());
423 }
424 merged
425 }
426}
427
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared secret used to sign and verify every test token.
    const TEST_KEY: &str = "test-signing-key-at-least-256-bits-long!!";

    /// Default configuration with an explicit one-hour lifetime.
    fn test_config() -> SdJwtConfig {
        SdJwtConfig {
            lifetime_secs: 3600,
            ..Default::default()
        }
    }

    /// Issuer and verifier built from the same configuration.
    fn issuer_and_verifier() -> (SdJwtIssuer, SdJwtVerifier) {
        let config = test_config();
        (SdJwtIssuer::new(config.clone()), SdJwtVerifier::new(config))
    }

    /// Claim set with three scalar claims and one nested object.
    fn sample_claims() -> serde_json::Map<String, serde_json::Value> {
        vec![
            ("sub", serde_json::json!("user-42")),
            ("email", serde_json::json!("user@example.com")),
            ("name", serde_json::json!("Alice")),
            (
                "address",
                serde_json::json!({"street": "123 Main St", "city": "Springfield"}),
            ),
        ]
        .into_iter()
        .map(|(name, value)| (name.to_string(), value))
        .collect()
    }

    /// Asserts that `name` resolves to the string `expected`.
    fn assert_claim(verified: &VerifiedSdJwt, name: &str, expected: &str) {
        assert_eq!(verified.get_claim(name), Some(&serde_json::json!(expected)));
    }

    #[test]
    fn test_issue_and_serialize() {
        let token = SdJwtIssuer::new(test_config())
            .issue(&sample_claims(), &["email", "address"], TEST_KEY)
            .unwrap();

        assert!(!token.jwt.is_empty());
        assert_eq!(token.disclosures.len(), 2);

        // jwt~d1~d2~ contains one separator per disclosure plus the trailing one.
        let tildes = token.serialize().chars().filter(|&c| c == '~').count();
        assert_eq!(tildes, 3);
    }

    #[test]
    fn test_issue_no_sd_claims() {
        let token = SdJwtIssuer::new(test_config())
            .issue(&sample_claims(), &[], TEST_KEY)
            .unwrap();

        assert!(token.disclosures.is_empty());
        assert!(token.serialize().ends_with('~'));
    }

    #[test]
    fn test_full_disclosure_roundtrip() {
        let (issuer, verifier) = issuer_and_verifier();

        let wire = issuer
            .issue(&sample_claims(), &["email", "name"], TEST_KEY)
            .unwrap()
            .serialize();
        let verified = verifier.verify(&wire, TEST_KEY).unwrap();

        assert_claim(&verified, "sub", "user-42");
        assert_claim(&verified, "email", "user@example.com");
        assert_claim(&verified, "name", "Alice");
    }

    #[test]
    fn test_selective_disclosure() {
        let (issuer, verifier) = issuer_and_verifier();

        let token = issuer
            .issue(&sample_claims(), &["email", "name", "address"], TEST_KEY)
            .unwrap();
        let presentation = token.present(&["email"]);
        let verified = verifier.verify(&presentation, TEST_KEY).unwrap();

        // Plaintext and disclosed claims are visible.
        assert_claim(&verified, "sub", "user-42");
        assert_claim(&verified, "email", "user@example.com");
        // Withheld claims are not.
        assert!(verified.get_claim("name").is_none());
        assert!(verified.get_claim("address").is_none());
    }

    #[test]
    fn test_all_claims_merged() {
        let (issuer, verifier) = issuer_and_verifier();

        let wire = issuer
            .issue(&sample_claims(), &["email"], TEST_KEY)
            .unwrap()
            .serialize();
        let merged = verifier.verify(&wire, TEST_KEY).unwrap().all_claims();

        for expected_key in &["sub", "email", "iss", "iat", "exp"] {
            assert!(merged.contains_key(*expected_key));
        }
    }

    #[test]
    fn test_reject_empty_claims() {
        let empty = serde_json::Map::new();
        let outcome = SdJwtIssuer::new(test_config()).issue(&empty, &[], TEST_KEY);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_reject_wrong_key() {
        let (issuer, verifier) = issuer_and_verifier();

        let wire = issuer
            .issue(&sample_claims(), &["email"], TEST_KEY)
            .unwrap()
            .serialize();

        let outcome = verifier.verify(&wire, "wrong-key-wrong-key-wrong-key!!!");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_reject_forged_disclosure() {
        let (issuer, verifier) = issuer_and_verifier();
        let token = issuer.issue(&sample_claims(), &["email"], TEST_KEY).unwrap();

        // A disclosure the issuer never hashed must not be accepted.
        let forged_json = serde_json::json!(["fakesalt", "role", "admin"]).to_string();
        let forged_encoded = URL_SAFE_NO_PAD.encode(forged_json.as_bytes());
        let forged_wire = format!("{}~{}~", token.jwt, forged_encoded);

        assert!(verifier.verify(&forged_wire, TEST_KEY).is_err());
    }

    #[test]
    fn test_parse_components() {
        let parsed = SdJwtVerifier::parse("eyJ0eXAi.payload.sig~disc1~disc2~").unwrap();
        let (jwt, disclosures, kb) = parsed;

        assert_eq!(jwt, "eyJ0eXAi.payload.sig");
        assert_eq!(disclosures.len(), 2);
        assert!(kb.is_none());
    }

    #[test]
    fn test_parse_with_kb_jwt() {
        let parsed =
            SdJwtVerifier::parse("eyJ0eXAi.payload.sig~disc1~header.payload.signature").unwrap();
        let (jwt, disclosures, kb) = parsed;

        assert_eq!(jwt, "eyJ0eXAi.payload.sig");
        assert_eq!(disclosures.len(), 1);
        assert_eq!(kb.unwrap(), "header.payload.signature");
    }

    #[test]
    fn test_disclosure_uniqueness() {
        let issuer = SdJwtIssuer::new(test_config());
        let claims = sample_claims();

        let first = issuer.issue(&claims, &["email"], TEST_KEY).unwrap();
        let second = issuer.issue(&claims, &["email"], TEST_KEY).unwrap();

        // Fresh salts make both the encoded disclosure and its digest differ.
        assert_ne!(first.disclosures[0].encoded, second.disclosures[0].encoded);
        assert_ne!(first.disclosures[0].digest, second.disclosures[0].digest);
    }

    #[test]
    fn test_complex_claim_value() {
        let (issuer, verifier) = issuer_and_verifier();

        let claims: serde_json::Map<String, serde_json::Value> = vec![
            ("sub", serde_json::json!("user-1")),
            (
                "address",
                serde_json::json!({
                    "street": "123 Main St",
                    "city": "Springfield",
                    "zip": "62701"
                }),
            ),
        ]
        .into_iter()
        .map(|(name, value)| (name.to_string(), value))
        .collect();

        let wire = issuer
            .issue(&claims, &["address"], TEST_KEY)
            .unwrap()
            .serialize();
        let verified = verifier.verify(&wire, TEST_KEY).unwrap();

        let address = verified.get_claim("address").unwrap();
        let field = |key: &str| address.get(key).and_then(serde_json::Value::as_str);
        assert_eq!(field("city"), Some("Springfield"));
        assert_eq!(field("zip"), Some("62701"));
    }
}