common_access_token/
cat.rs

1use ciborium::value::Value;
2use serde_json::json;
3use std::collections::HashMap;
4use std::time::{SystemTime, UNIX_EPOCH};
5
// Unit tests live in the sibling file `cat.test.rs`, spliced in verbatim.
// The inner `cfg` attribute removes the whole module from non-test builds.
mod test {
    #![cfg(test)]
    include!("cat.test.rs");
}
10
11use crate::claims::{ClaimLabel, ClaimValue, Claims};
12use crate::cose::{CoseMac, COSE_MAC0_TAG, CWT_TAG};
13use crate::error::Error;
14use crate::util::{from_base64_url, generate_random_hex, to_base64_no_padding};
15
/// The cryptographic protection applied to (or expected on) a Common Access Token.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CatValidationTypes {
    /// Protected by a message authentication code (HMAC).
    Mac,

    /// Protected by a digital signature.
    Sign,

    /// Carries no cryptographic protection.
    None,
}

// MAC is the default; it is also the only type `Cat` currently generates or validates.
impl Default for CatValidationTypes {
    fn default() -> CatValidationTypes {
        CatValidationTypes::Mac
    }
}
36
/// Settings used to construct a [`Cat`] validator/generator.
#[derive(Debug, Clone)]
pub struct CatOptions {
    /// Key material indexed by key ID.
    ///
    /// Keys are raw bytes. For HMAC-SHA256, keys should be at least 32 bytes.
    pub keys: HashMap<String, Vec<u8>>,

    /// Whether tokens are wrapped in a CWT tag.
    ///
    /// A CWT (CBOR Web Token) tag (61) wraps the COSE_Mac0 structure. The same
    /// flag also decides whether generated tokens are tagged. Most CAT
    /// implementations expect the tag, so it defaults to `true`.
    pub expect_cwt_tag: bool,
}

impl Default for CatOptions {
    /// An empty key store with the CWT tag expected.
    fn default() -> Self {
        CatOptions {
            expect_cwt_tag: true,
            keys: HashMap::default(),
        }
    }
}
61
62/// Options for token generation.
63#[derive(Debug, Clone)]
64pub struct CatGenerateOptions {
65    /// Type of cryptographic protection to apply to the token.
66    pub token_type: CatValidationTypes,
67
68    /// Algorithm identifier (e.g., "HS256" for HMAC-SHA256).
69    pub alg: String,
70
71    /// ID of the key to use from the key store.
72    pub kid: String,
73
74    /// Whether to automatically generate a random CWT ID claim.
75    ///
76    /// When true, a random identifier will be added as the "cti" claim
77    /// if one does not already exist in the claims.
78    pub generate_cwt_id: bool,
79}
80
81impl Default for CatGenerateOptions {
82    fn default() -> Self {
83        Self {
84            token_type: CatValidationTypes::Mac,
85            alg: "HS256".to_string(),
86            kid: "default".to_string(),
87            generate_cwt_id: true,
88        }
89    }
90}
91
/// Options for token validation.
#[derive(Debug, Clone)]
pub struct CatValidationOptions {
    /// Expected token issuer.
    ///
    /// If the token carries a string `iss` claim, that claim must equal this
    /// value. A token that has no `iss` claim is not rejected by this check.
    pub issuer: String,

    /// List of allowed audiences for the token.
    ///
    /// If `Some` and the token carries an `aud` claim, the claim must match at
    /// least one of these values. The claim can be a single string, or an
    /// array in which any string element may match. An `aud` claim of any
    /// other type is rejected as an invalid claim type. A token that has no
    /// `aud` claim is not rejected by this check. If `None`, audience
    /// validation is skipped entirely.
    pub audience: Option<Vec<String>>,
}
106
107/// Result of token validation.
108#[derive(Debug, Clone)]
109pub struct CatValidationResult {
110    /// The validated token, if successful.
111    ///
112    /// This will be Some if the token was parsed successfully, even if
113    /// validation failed due to expired token, invalid issuer, etc.
114    pub cat: Option<CommonAccessToken>,
115
116    /// Error that occurred during validation, if any.
117    ///
118    /// This will be None if validation succeeded completely.
119    pub error: Option<Error>,
120}
121
122impl CatValidationResult {
123    /// Returns true if the token is valid (no validation errors).
124    pub fn is_valid(&self) -> bool {
125        self.error.is_none() && self.cat.is_some()
126    }
127
128    /// Returns the claims from the token if validation succeeded.
129    pub fn claims(&self) -> Option<HashMap<String, serde_json::Value>> {
130        self.cat.as_ref().map(|cat| cat.get_claims())
131    }
132}
133
134/// Common Access Token
135#[derive(Debug, Clone)]
136pub struct CommonAccessToken {
137    /// The claims in the token
138    claims: Claims,
139
140    /// The raw token data
141    data: Option<Vec<u8>>,
142
143    /// The key ID used to sign/MAC the token
144    kid: Option<String>,
145}
146
147impl CommonAccessToken {
148    /// Create a new CAT with the given claims
149    pub fn new(claims_map: HashMap<String, serde_json::Value>) -> Self {
150        let claims = Claims::from_map(claims_map);
151
152        CommonAccessToken {
153            claims,
154            data: None,
155            kid: None,
156        }
157    }
158
159    /// Create a MAC for the token
160    pub fn mac(
161        &mut self,
162        key: &[u8],
163        kid: &str,
164        _alg: &str,
165        no_cwt_tag: bool,
166    ) -> Result<(), Error> {
167        // Serialize the claims to CBOR
168        let mut payload_bytes = Vec::new();
169        ciborium::ser::into_writer(&self.claims, &mut payload_bytes)
170            .map_err(|_| Error::CborEncoding)?;
171
172        // Create the COSE MAC
173        let cose_mac_bytes = CoseMac::create_hmac_sha256(kid, &payload_bytes, key)?;
174
175        if !no_cwt_tag {
176            // Add the COSE_Mac0 tag
177            let cose_mac_tagged = Value::Tag(COSE_MAC0_TAG, Box::new(Value::Bytes(cose_mac_bytes)));
178
179            // Add the CWT tag
180            let cwt_tagged = Value::Tag(CWT_TAG, Box::new(cose_mac_tagged));
181
182            // Serialize the tagged value
183            let mut cwt_bytes = Vec::new();
184            ciborium::ser::into_writer(&cwt_tagged, &mut cwt_bytes)
185                .map_err(|_| Error::CborEncoding)?;
186
187            self.data = Some(cwt_bytes);
188        } else {
189            self.data = Some(cose_mac_bytes);
190        }
191
192        self.kid = Some(kid.to_string());
193        Ok(())
194    }
195
196    /// Parse a token
197    pub fn parse(
198        &mut self,
199        token: &[u8],
200        key: &[u8],
201        kid: &str,
202        expect_cwt_tag: bool,
203    ) -> Result<(), Error> {
204        // Decode the token
205        let value: Value = ciborium::de::from_reader(token).map_err(|_| Error::CborDecoding)?;
206
207        // Check if it has a CWT tag
208        if expect_cwt_tag {
209            if let Value::Tag(tag, _) = &value {
210                if *tag != CWT_TAG {
211                    return Err(Error::ExpectedCwtTag);
212                }
213            } else {
214                return Err(Error::ExpectedCwtTag);
215            }
216        }
217
218        // Extract the COSE_Mac0 structure
219        let cose_mac_bytes = if let Value::Tag(tag, content) = value {
220            if tag == CWT_TAG {
221                // Extract the COSE_Mac0 structure from the CWT
222                if let Value::Tag(inner_tag, inner_content) = *content {
223                    if inner_tag == COSE_MAC0_TAG {
224                        if let Value::Bytes(bytes) = *inner_content {
225                            bytes
226                        } else {
227                            return Err(Error::UnableToParseToken);
228                        }
229                    } else {
230                        return Err(Error::UnableToParseToken);
231                    }
232                } else {
233                    return Err(Error::UnableToParseToken);
234                }
235            } else {
236                // Direct COSE_Mac0 structure
237                if let Value::Bytes(bytes) = *content {
238                    bytes
239                } else {
240                    return Err(Error::UnableToParseToken);
241                }
242            }
243        } else {
244            // Raw COSE_Mac0 structure
245            let mut bytes = Vec::new();
246            ciborium::ser::into_writer(&value, &mut bytes).map_err(|_| Error::CborEncoding)?;
247            bytes
248        };
249
250        // Verify the MAC and get the payload
251        let payload = CoseMac::verify_hmac_sha256(&cose_mac_bytes, key)?;
252
253        // Decode the payload
254        let claims: Claims =
255            ciborium::de::from_reader(&payload[..]).map_err(|_| Error::CborDecoding)?;
256
257        self.claims = claims;
258        self.kid = Some(kid.to_string());
259        self.data = Some(token.to_vec());
260
261        Ok(())
262    }
263
264    /// Check if the token is acceptable according to the validation options
265    pub fn is_acceptable(&self, opts: &CatValidationOptions) -> Result<bool, Error> {
266        // Check issuer
267        if let Some(ClaimValue::String(iss)) = self.claims.get(ClaimLabel::Iss) {
268            if iss != &opts.issuer {
269                return Err(Error::InvalidIssuer {
270                    expected: opts.issuer.clone(),
271                    actual: iss.clone(),
272                });
273            }
274        }
275
276        // Check expiration
277        if let Some(ClaimValue::Integer(exp)) = self.claims.get(ClaimLabel::Exp) {
278            let now = SystemTime::now()
279                .duration_since(UNIX_EPOCH)
280                .unwrap()
281                .as_secs() as i64;
282
283            if *exp < now {
284                return Err(Error::TokenExpired);
285            }
286        }
287
288        // Check audience
289        if let Some(audiences) = &opts.audience {
290            if let Some(claim_value) = self.claims.get(ClaimLabel::Aud) {
291                match claim_value {
292                    ClaimValue::String(aud) => {
293                        if !audiences.contains(aud) {
294                            return Err(Error::InvalidAudience);
295                        }
296                    }
297                    ClaimValue::Array(auds) => {
298                        let mut found = false;
299                        for aud in auds {
300                            if let ClaimValue::String(aud_str) = aud {
301                                if audiences.contains(aud_str) {
302                                    found = true;
303                                    break;
304                                }
305                            }
306                        }
307                        if !found {
308                            return Err(Error::InvalidAudience);
309                        }
310                    }
311                    _ => return Err(Error::InvalidClaimType),
312                }
313            }
314        }
315
316        // Check not before
317        if let Some(ClaimValue::Integer(nbf)) = self.claims.get(ClaimLabel::Nbf) {
318            let now = SystemTime::now()
319                .duration_since(UNIX_EPOCH)
320                .unwrap()
321                .as_secs() as i64;
322
323            if *nbf > now {
324                return Err(Error::TokenNotActive);
325            }
326        }
327
328        Ok(true)
329    }
330
331    /// Get the raw token data
332    pub fn raw(&self) -> Option<&[u8]> {
333        self.data.as_deref()
334    }
335
336    /// Get the claims as a JSON-like map
337    pub fn get_claims(&self) -> HashMap<String, serde_json::Value> {
338        self.claims.to_map()
339    }
340}
341
/// Validator and generator for Common Access Tokens (CAT).
///
/// Holds a key store and the CWT-tagging convention, both shared by
/// validation and generation.
#[derive(Debug, Clone)]
pub struct Cat {
    /// Key material indexed by key ID.
    keys: HashMap<String, Vec<u8>>,

    /// Whether tokens are expected to carry, and are generated with, a CWT tag.
    expect_cwt_tag: bool,
}
351
352impl Cat {
353    /// Creates a new CAT token validator and generator.
354    ///
355    /// # Arguments
356    ///
357    /// * `opts` - Configuration options including keys and token format settings
358    ///
359    /// # Examples
360    ///
361    /// ```
362    /// use common_access_token::{Cat, CatOptions};
363    /// use std::collections::HashMap;
364    ///
365    /// // Create a key
366    /// let key = hex::decode("403697de87af64611c1d32a05dab0fe1fcb715a86ab435f1ec99192d79569388").unwrap();
367    ///
368    /// // Set up the key store
369    /// let mut keys = HashMap::new();
370    /// keys.insert("my-key".to_string(), key);
371    ///
372    /// // Create the CAT handler
373    /// let cat = Cat::new(CatOptions {
374    ///     keys,
375    ///     expect_cwt_tag: true,
376    /// });
377    /// ```
378    pub fn new(opts: CatOptions) -> Self {
379        Cat {
380            keys: opts.keys,
381            expect_cwt_tag: opts.expect_cwt_tag,
382        }
383    }
384
385    /// Adds a key to the key store.
386    ///
387    /// # Arguments
388    ///
389    /// * `kid` - Key ID to associate with the key
390    /// * `key` - The key material as raw bytes
391    ///
392    /// # Examples
393    ///
394    /// ```
395    /// use common_access_token::{Cat, CatOptions};
396    /// use std::collections::HashMap;
397    ///
398    /// let mut cat = Cat::new(CatOptions::default());
399    /// let key = hex::decode("403697de87af64611c1d32a05dab0fe1fcb715a86ab435f1ec99192d79569388").unwrap();
400    /// cat.add_key("my-key", key);
401    /// ```
402    pub fn add_key(&mut self, kid: impl Into<String>, key: Vec<u8>) {
403        self.keys.insert(kid.into(), key);
404    }
405
406    /// Validates a CAT token.
407    ///
408    /// This function attempts to parse and validate the token using the provided options.
409    /// It will try each available key in the key store until one succeeds or all fail.
410    ///
411    /// # Arguments
412    ///
413    /// * `token` - The base64-encoded token to validate
414    /// * `validation_type` - The type of validation to perform (Mac, Sign, None)
415    /// * `opts` - Validation options including issuer and audience requirements
416    ///
417    /// # Returns
418    ///
419    /// A `Result` containing a `CatValidationResult` with the parsed token and any validation
420    /// error that occurred, or an `Error` if token parsing completely failed.
421    ///
422    /// # Examples
423    ///
424    /// ```
425    /// use common_access_token::{Cat, CatOptions, CatValidationOptions, CatValidationTypes};
426    /// use std::collections::HashMap;
427    ///
428    /// // Create a CAT validator with keys
429    /// let mut keys = HashMap::new();
430    /// keys.insert("Symmetric256".to_string(),
431    ///     hex::decode("403697de87af64611c1d32a05dab0fe1fcb715a86ab435f1ec99192d79569388").unwrap());
432    ///
433    /// let cat = Cat::new(CatOptions {
434    ///     keys,
435    ///     expect_cwt_tag: true,
436    /// });
437    ///
438    /// // Validate a token
439    /// // In a real scenario, you would use an actual token, not a placeholder
440    /// # // We skip the actual validation in the doctest to avoid errors
441    /// # let token_example = || {
442    /// let token = "eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJ1c2VyLTEyMyJ9.SIGNATURE";
443    /// let result = cat.validate(
444    ///     token,
445    ///     CatValidationTypes::Mac,
446    ///     CatValidationOptions {
447    ///         issuer: "example-issuer".to_string(),
448    ///         audience: Some(vec!["api".to_string()]),
449    ///     },
450    /// );
451    /// #     // Example of how to check the result (not actually run in doctest)
452    /// #     match result {
453    /// #         Ok(validation_result) => {
454    /// #             if validation_result.is_valid() {
455    /// #                 println!("Token is valid!");
456    /// #                 if let Some(claims) = validation_result.claims() {
457    /// #                     println!("Subject: {}", claims.get("sub").unwrap_or(&serde_json::json!("none")));
458    /// #                 }
459    /// #             } else {
460    /// #                 println!("Validation failed: {:?}", validation_result.error);
461    /// #             }
462    /// #         },
463    /// #         Err(err) => println!("Error parsing token: {:?}", err),
464    /// #     }
465    /// # };
466    /// ```
467    pub fn validate(
468        &self,
469        token: &str,
470        validation_type: CatValidationTypes,
471        opts: CatValidationOptions,
472    ) -> Result<CatValidationResult, Error> {
473        let token_bytes = from_base64_url(token)?;
474
475        match validation_type {
476            CatValidationTypes::Mac => {
477                let mut cat = None;
478                let mut error = None;
479
480                // Try each key
481                for (kid, key) in &self.keys {
482                    let mut token = CommonAccessToken::new(HashMap::new());
483                    match token.parse(&token_bytes, key, kid, self.expect_cwt_tag) {
484                        Ok(_) => {
485                            cat = Some(token);
486                            break;
487                        }
488                        Err(err) => {
489                            error = Some(err);
490                        }
491                    }
492                }
493
494                // If we found a valid token, check if it's acceptable
495                if let Some(ref token) = cat {
496                    match token.is_acceptable(&opts) {
497                        Ok(_) => {
498                            return Ok(CatValidationResult { cat, error: None });
499                        }
500                        Err(err) => {
501                            return Ok(CatValidationResult {
502                                cat,
503                                error: Some(err),
504                            });
505                        }
506                    }
507                }
508
509                // If we didn't find a valid token, return the last error
510                Ok(CatValidationResult { cat: None, error })
511            }
512            _ => Err(Error::UnsupportedValidationType),
513        }
514    }
515
516    /// Generates a CAT token with the provided claims.
517    ///
518    /// # Arguments
519    ///
520    /// * `claims` - A map of claim names to values
521    /// * `opts` - Options for token generation
522    ///
523    /// # Returns
524    ///
525    /// A base64-encoded token string, or an error if generation fails.
526    ///
527    /// # Examples
528    ///
529    /// ```
530    /// use common_access_token::{Cat, CatOptions, CatGenerateOptions, CatValidationTypes};
531    /// use std::collections::HashMap;
532    /// use std::time::{SystemTime, UNIX_EPOCH};
533    ///
534    /// // Create a CAT generator with keys
535    /// let mut keys = HashMap::new();
536    /// keys.insert("Symmetric256".to_string(),
537    ///     hex::decode("403697de87af64611c1d32a05dab0fe1fcb715a86ab435f1ec99192d79569388").unwrap());
538    ///
539    /// let cat = Cat::new(CatOptions {
540    ///     keys,
541    ///     expect_cwt_tag: true,
542    /// });
543    ///
544    /// // Create claims
545    /// let now = SystemTime::now()
546    ///     .duration_since(UNIX_EPOCH)
547    ///     .unwrap()
548    ///     .as_secs() as i64;
549    ///
550    /// let mut claims = HashMap::new();
551    /// claims.insert("iss".to_string(), serde_json::json!("example-issuer"));
552    /// claims.insert("sub".to_string(), serde_json::json!("user-123"));
553    /// claims.insert("exp".to_string(), serde_json::json!(now + 3600));
554    ///
555    /// // Generate the token
556    /// let token = cat.generate(
557    ///     claims,
558    ///     CatGenerateOptions {
559    ///         token_type: CatValidationTypes::Mac,
560    ///         alg: "HS256".to_string(),
561    ///         kid: "Symmetric256".to_string(),
562    ///         generate_cwt_id: true,
563    ///     },
564    /// ).unwrap();
565    ///
566    /// println!("Generated token: {}", token);
567    /// ```
568    pub fn generate(
569        &self,
570        claims: HashMap<String, serde_json::Value>,
571        opts: CatGenerateOptions,
572    ) -> Result<String, Error> {
573        let mut claims = claims.clone();
574
575        // Generate a random CWT ID if requested
576        if opts.generate_cwt_id && !claims.contains_key("cti") {
577            claims.insert("cti".to_string(), json!(generate_random_hex(16)));
578        }
579
580        // Create the token
581        let mut token = CommonAccessToken::new(claims);
582
583        match opts.token_type {
584            CatValidationTypes::Mac => {
585                let key = self
586                    .keys
587                    .get(&opts.kid)
588                    .ok_or_else(|| Error::KeyNotFound(opts.kid.clone()))?;
589
590                token.mac(key, &opts.kid, &opts.alg, !self.expect_cwt_tag)?;
591
592                if let Some(raw) = token.raw() {
593                    Ok(to_base64_no_padding(raw))
594                } else {
595                    Err(Error::MacFailed)
596                }
597            }
598            _ => Err(Error::UnsupportedValidationType),
599        }
600    }
601
602    /// Creates a builder for constructing claims conveniently
603    ///
604    /// # Returns
605    ///
606    /// A new `ClaimsBuilder` instance that can be used to build token claims
607    ///
608    /// # Examples
609    ///
610    /// ```
611    /// use common_access_token::{Cat, CatOptions, CatGenerateOptions, CatValidationTypes};
612    /// use std::collections::HashMap;
613    /// use std::time::{SystemTime, UNIX_EPOCH};
614    ///
615    /// let cat = Cat::new(CatOptions::default());
616    ///
617    /// let now = SystemTime::now()
618    ///     .duration_since(UNIX_EPOCH)
619    ///     .unwrap()
620    ///     .as_secs() as i64;
621    ///
622    /// // Build claims using the builder pattern
623    /// let claims = cat.claims_builder()
624    ///     .issuer("example-issuer")
625    ///     .subject("user-123")
626    ///     .audience("api-service")
627    ///     .expiration(now + 3600)
628    ///     .issued_at(now)
629    ///     .build();
630    /// ```
631    pub fn claims_builder(&self) -> ClaimsBuilder {
632        ClaimsBuilder::new()
633    }
634}
635
636/// Builder for constructing Common Access Token claims
637pub struct ClaimsBuilder {
638    claims: HashMap<String, serde_json::Value>,
639}
640
641impl ClaimsBuilder {
642    /// Creates a new ClaimsBuilder
643    pub fn new() -> Self {
644        Self {
645            claims: HashMap::new(),
646        }
647    }
648
649    /// Sets the token issuer (iss claim)
650    pub fn issuer(mut self, issuer: impl Into<String>) -> Self {
651        self.claims.insert("iss".to_string(), json!(issuer.into()));
652        self
653    }
654
655    /// Sets the token subject (sub claim)
656    pub fn subject(mut self, subject: impl Into<String>) -> Self {
657        self.claims.insert("sub".to_string(), json!(subject.into()));
658        self
659    }
660
661    /// Sets the token audience (aud claim)
662    pub fn audience(mut self, audience: impl Into<String>) -> Self {
663        self.claims
664            .insert("aud".to_string(), json!(audience.into()));
665        self
666    }
667
668    /// Sets multiple audiences for the token (aud claim as array)
669    pub fn audiences(mut self, audiences: impl IntoIterator<Item = impl Into<String>>) -> Self {
670        let audiences: Vec<String> = audiences.into_iter().map(Into::into).collect();
671        self.claims.insert("aud".to_string(), json!(audiences));
672        self
673    }
674
675    /// Sets the token expiration time (exp claim)
676    pub fn expiration(mut self, exp: i64) -> Self {
677        self.claims.insert("exp".to_string(), json!(exp));
678        self
679    }
680
681    /// Sets the token not-before time (nbf claim)
682    pub fn not_before(mut self, nbf: i64) -> Self {
683        self.claims.insert("nbf".to_string(), json!(nbf));
684        self
685    }
686
687    /// Sets the token issued-at time (iat claim)
688    pub fn issued_at(mut self, iat: i64) -> Self {
689        self.claims.insert("iat".to_string(), json!(iat));
690        self
691    }
692
693    /// Sets the token ID (cti claim)
694    pub fn token_id(mut self, cti: impl Into<String>) -> Self {
695        self.claims.insert("cti".to_string(), json!(cti.into()));
696        self
697    }
698
699    /// Sets the CAT version (catv claim)
700    pub fn cat_version(mut self, version: u8) -> Self {
701        self.claims.insert("catv".to_string(), json!(version));
702        self
703    }
704
705    /// Sets a custom claim with any JSON-serializable value
706    pub fn claim(mut self, name: impl Into<String>, value: impl Into<serde_json::Value>) -> Self {
707        self.claims.insert(name.into(), value.into());
708        self
709    }
710
711    /// Builds the final claims map
712    pub fn build(self) -> HashMap<String, serde_json::Value> {
713        self.claims
714    }
715}
716
717impl Default for ClaimsBuilder {
718    fn default() -> Self {
719        Self::new()
720    }
721}