common_access_token/
token.rs

1//! Token implementation for Common Access Token
2
3use crate::claims::{Claims, RegisteredClaims};
4use crate::error::Error;
5use crate::header::{Algorithm, CborValue, Header, HeaderMap, KeyId};
6use crate::utils::{compute_hmac_sha256, current_timestamp, verify_hmac_sha256};
7use minicbor::{Decoder, Encoder};
8use std::collections::BTreeMap;
9
10/// Common Access Token structure
11#[derive(Debug, Clone)]
12pub struct Token {
13    /// Token header
14    pub header: Header,
15    /// Token claims
16    pub claims: Claims,
17    /// Token signature
18    pub signature: Vec<u8>,
19    /// Original payload bytes (for verification)
20    original_payload_bytes: Option<Vec<u8>>,
21}
22
23impl Token {
24    /// Create a new token with the given header, claims, and signature
25    pub fn new(header: Header, claims: Claims, signature: Vec<u8>) -> Self {
26        Self {
27            header,
28            claims,
29            signature,
30            original_payload_bytes: None,
31        }
32    }
33
34    /// Encode the token to CBOR bytes
35    pub fn to_bytes(&self) -> Result<Vec<u8>, Error> {
36        let mut buf = Vec::new();
37        let mut enc = Encoder::new(&mut buf);
38
39        // For HMAC algorithms, use COSE_Mac0 format with CWT tag
40        if let Some(Algorithm::HmacSha256) = self.header.algorithm() {
41            // Apply CWT tag (61)
42            enc.tag(minicbor::data::Tag::new(61))?;
43            // Apply COSE_Mac0 tag (17)
44            enc.tag(minicbor::data::Tag::new(17))?;
45        }
46
47        // COSE structure array with 4 items
48        enc.array(4)?;
49
50        // 1. Protected header (encoded as CBOR and then as bstr)
51        let protected_bytes = encode_map(&self.header.protected)?;
52        enc.bytes(&protected_bytes)?;
53
54        // 2. Unprotected header
55        encode_map_direct(&self.header.unprotected, &mut enc)?;
56
57        // 3. Payload (encoded as CBOR and then as bstr)
58        let claims_map = self.claims.to_map();
59        let claims_bytes = encode_map(&claims_map)?;
60        enc.bytes(&claims_bytes)?;
61
62        // 4. Signature/MAC
63        enc.bytes(&self.signature)?;
64
65        Ok(buf)
66    }
67
68    /// Decode a token from CBOR bytes
69    ///
70    /// This function supports both COSE_Sign1 (tag 18) and COSE_Mac0 (tag 17) structures,
71    /// as well as custom tags. It will automatically skip any tags and process the underlying
72    /// CBOR array.
73    pub fn from_bytes(bytes: &[u8]) -> Result<Self, Error> {
74        let mut dec = Decoder::new(bytes);
75
76        // Check if the token starts with a tag (COSE_Sign1 tag = 18, COSE_Mac0 tag = 17, or custom tag = 61)
77        if dec.datatype()? == minicbor::data::Type::Tag {
78            // Skip the tag
79            let _ = dec.tag()?;
80
81            // Check for a second tag
82            if dec.datatype()? == minicbor::data::Type::Tag {
83                let _ = dec.tag()?;
84            }
85        }
86
87        // Expect array with 4 items
88        let array_len = dec.array()?.unwrap_or(0);
89        if array_len != 4 {
90            return Err(Error::InvalidFormat(format!(
91                "Expected array of length 4, got {array_len}"
92            )));
93        }
94
95        // 1. Protected header
96        let protected_bytes = dec.bytes()?;
97        let protected = decode_map(protected_bytes)?;
98
99        // 2. Unprotected header
100        let unprotected = decode_map_direct(&mut dec)?;
101
102        // Create header
103        let header = Header {
104            protected,
105            unprotected,
106        };
107
108        // 3. Payload
109        let claims_bytes = dec.bytes()?;
110        let claims_map = decode_map(claims_bytes)?;
111        let claims = Claims::from_map(&claims_map);
112
113        // 4. Signature
114        let signature = dec.bytes()?.to_vec();
115
116        Ok(Self {
117            header,
118            claims,
119            signature,
120            original_payload_bytes: Some(claims_bytes.to_vec()),
121        })
122    }
123
124    /// Verify the token signature
125    ///
126    /// This function supports both COSE_Sign1 and COSE_Mac0 structures.
127    /// It will first try to verify the signature using the COSE_Sign1 structure,
128    /// and if that fails, it will try the COSE_Mac0 structure.
129    pub fn verify(&self, key: &[u8]) -> Result<(), Error> {
130        let alg = self.header.algorithm().ok_or_else(|| {
131            Error::InvalidFormat("Missing algorithm in protected header".to_string())
132        })?;
133
134        match alg {
135            Algorithm::HmacSha256 => {
136                // Try with COSE_Sign1 structure first
137                let sign1_input = self.sign1_input()?;
138                let sign1_result = verify_hmac_sha256(key, &sign1_input, &self.signature);
139
140                if sign1_result.is_ok() {
141                    return Ok(());
142                }
143
144                // If COSE_Sign1 verification fails, try COSE_Mac0 structure
145                let mac0_input = self.mac0_input()?;
146                verify_hmac_sha256(key, &mac0_input, &self.signature)
147            }
148        }
149    }
150
151    /// Verify the token claims
152    pub fn verify_claims(&self, options: &VerificationOptions) -> Result<(), Error> {
153        let now = current_timestamp();
154
155        // Check expiration
156        if options.verify_exp {
157            if let Some(exp) = self.claims.registered.exp {
158                if now >= exp {
159                    return Err(Error::Expired);
160                }
161            } else if options.require_exp {
162                return Err(Error::MissingClaim("exp".to_string()));
163            }
164        }
165
166        // Check not before
167        if options.verify_nbf {
168            if let Some(nbf) = self.claims.registered.nbf {
169                if now < nbf {
170                    return Err(Error::NotYetValid);
171                }
172            }
173        }
174
175        // Check issuer
176        if let Some(expected_iss) = &options.expected_issuer {
177            if let Some(iss) = &self.claims.registered.iss {
178                if iss != expected_iss {
179                    return Err(Error::InvalidIssuer);
180                }
181            } else if options.require_iss {
182                return Err(Error::MissingClaim("iss".to_string()));
183            }
184        }
185
186        // Check audience
187        if let Some(expected_aud) = &options.expected_audience {
188            if let Some(aud) = &self.claims.registered.aud {
189                if aud != expected_aud {
190                    return Err(Error::InvalidAudience);
191                }
192            } else if options.require_aud {
193                return Err(Error::MissingClaim("aud".to_string()));
194            }
195        }
196
197        // Check CAT-specific claims
198        if options.verify_catu {
199            self.verify_catu_claim(options)?;
200        }
201
202        if options.verify_catm {
203            self.verify_catm_claim(options)?;
204        }
205
206        if options.verify_catreplay {
207            self.verify_catreplay_claim(options)?;
208        }
209
210        Ok(())
211    }
212
213    /// Verify the CATU (URI) claim against the provided URI
214    fn verify_catu_claim(&self, options: &VerificationOptions) -> Result<(), Error> {
215        use crate::constants::{cat_keys, uri_components};
216        use url::Url;
217
218        // Get the URI to verify against
219        let uri = match &options.uri {
220            Some(uri) => uri,
221            None => {
222                return Err(Error::InvalidClaimValue(
223                    "No URI provided for CATU verification".to_string(),
224                ))
225            }
226        };
227
228        // Parse the URI
229        let parsed_uri = match Url::parse(uri) {
230            Ok(url) => url,
231            Err(_) => {
232                return Err(Error::InvalidClaimValue(format!(
233                    "Invalid URI format: {uri}"
234                )))
235            }
236        };
237
238        // Check if token has CATU claim
239        let catu_claim = match self.claims.custom.get(&cat_keys::CATU) {
240            Some(claim) => claim,
241            None => return Ok(()), // No CATU claim, so nothing to verify
242        };
243
244        // CATU claim should be a map
245        let component_map = match catu_claim {
246            CborValue::Map(map) => map,
247            _ => {
248                return Err(Error::InvalidUriClaim(
249                    "CATU claim is not a map".to_string(),
250                ))
251            }
252        };
253
254        // Verify each component in the CATU claim
255        for (component_key, component_value) in component_map {
256            match *component_key {
257                uri_components::SCHEME => {
258                    self.verify_uri_component(
259                        &parsed_uri.scheme().to_string(),
260                        component_value,
261                        "scheme",
262                    )?;
263                }
264                uri_components::HOST => {
265                    self.verify_uri_component(
266                        &parsed_uri.host_str().unwrap_or("").to_string(),
267                        component_value,
268                        "host",
269                    )?;
270                }
271                uri_components::PORT => {
272                    let port = parsed_uri.port().map(|p| p.to_string()).unwrap_or_default();
273                    self.verify_uri_component(&port, component_value, "port")?;
274                }
275                uri_components::PATH => {
276                    self.verify_uri_component(
277                        &parsed_uri.path().to_string(),
278                        component_value,
279                        "path",
280                    )?;
281                }
282                uri_components::QUERY => {
283                    let query = parsed_uri.query().unwrap_or("").to_string();
284                    self.verify_uri_component(&query, component_value, "query")?;
285                }
286                uri_components::EXTENSION => {
287                    // Extract file extension from path
288                    let path = parsed_uri.path();
289                    let extension = path.split('.').next_back().unwrap_or("").to_string();
290                    if !path.contains('.') || path.ends_with('.') {
291                        // No extension or ends with dot
292                        self.verify_uri_component(&"".to_string(), component_value, "extension")?;
293                    } else {
294                        self.verify_uri_component(
295                            &format!(".{extension}"),
296                            component_value,
297                            "extension",
298                        )?;
299                    }
300                }
301                _ => {
302                    // Ignore unsupported components
303                }
304            }
305        }
306
307        Ok(())
308    }
309
310    /// Verify a URI component against match conditions
311    fn verify_uri_component(
312        &self,
313        component: &String,
314        match_conditions: &CborValue,
315        component_name: &str,
316    ) -> Result<(), Error> {
317        use crate::constants::match_types;
318        use regex::Regex;
319        use sha2::{Digest, Sha256, Sha512};
320
321        // Match conditions should be a map
322        let match_map = match match_conditions {
323            CborValue::Map(map) => map,
324            _ => {
325                return Err(Error::InvalidUriClaim(format!(
326                    "Match conditions for {component_name} is not a map"
327                )))
328            }
329        };
330
331        for (match_type, match_value) in match_map {
332            match *match_type {
333                match_types::EXACT => {
334                    if let CborValue::Text(text) = match_value {
335                        if component != text {
336                            return Err(Error::InvalidUriClaim(format!(
337                                "URI component {component_name} '{component}' does not exactly match required value '{text}'"
338                            )));
339                        }
340                    }
341                }
342                match_types::PREFIX => {
343                    if let CborValue::Text(prefix) = match_value {
344                        if !component.starts_with(prefix) {
345                            return Err(Error::InvalidUriClaim(format!(
346                                "URI component {component_name} '{component}' does not start with required prefix '{prefix}'"
347                            )));
348                        }
349                    }
350                }
351                match_types::SUFFIX => {
352                    if let CborValue::Text(suffix) = match_value {
353                        if !component.ends_with(suffix) {
354                            return Err(Error::InvalidUriClaim(format!(
355                                "URI component {component_name} '{component}' does not end with required suffix '{suffix}'"
356                            )));
357                        }
358                    }
359                }
360                match_types::CONTAINS => {
361                    if let CborValue::Text(contained) = match_value {
362                        if !component.contains(contained) {
363                            return Err(Error::InvalidUriClaim(format!(
364                                "URI component {component_name} '{component}' does not contain required text '{contained}'"
365                            )));
366                        }
367                    }
368                }
369                match_types::REGEX => {
370                    if let CborValue::Array(array) = match_value {
371                        if let Some(CborValue::Text(pattern)) = array.first() {
372                            match Regex::new(pattern) {
373                                Ok(regex) => {
374                                    if !regex.is_match(component) {
375                                        return Err(Error::InvalidUriClaim(format!(
376                                            "URI component {component_name} '{component}' does not match required regex pattern '{pattern}'"
377                                        )));
378                                    }
379                                }
380                                Err(_) => {
381                                    return Err(Error::InvalidUriClaim(format!(
382                                        "Invalid regex pattern: {pattern}"
383                                    )))
384                                }
385                            }
386                        }
387                    }
388                }
389                match_types::SHA256 => {
390                    if let CborValue::Bytes(expected_hash) = match_value {
391                        let mut hasher = Sha256::new();
392                        hasher.update(component.as_bytes());
393                        let hash = hasher.finalize();
394
395                        if hash.as_slice() != expected_hash.as_slice() {
396                            return Err(Error::InvalidUriClaim(format!(
397                                "URI component {component_name} '{component}' SHA-256 hash does not match expected value"
398                            )));
399                        }
400                    }
401                }
402                match_types::SHA512_256 => {
403                    if let CborValue::Bytes(expected_hash) = match_value {
404                        let mut hasher = Sha512::new();
405                        hasher.update(component.as_bytes());
406                        let hash = hasher.finalize();
407                        let truncated_hash = &hash[0..32]; // Take first 256 bits (32 bytes)
408
409                        if truncated_hash != expected_hash.as_slice() {
410                            return Err(Error::InvalidUriClaim(format!(
411                                "URI component {component_name} '{component}' SHA-512/256 hash does not match expected value"
412                            )));
413                        }
414                    }
415                }
416                _ => {
417                    // Ignore unsupported match types
418                }
419            }
420        }
421
422        Ok(())
423    }
424
425    /// Verify the CATM (HTTP method) claim against the provided method
426    fn verify_catm_claim(&self, options: &VerificationOptions) -> Result<(), Error> {
427        use crate::constants::cat_keys;
428
429        // Get the HTTP method to verify against
430        let method = match &options.http_method {
431            Some(method) => method,
432            None => {
433                return Err(Error::InvalidClaimValue(
434                    "No HTTP method provided for CATM verification".to_string(),
435                ))
436            }
437        };
438
439        // Check if token has CATM claim
440        let catm_claim = match self.claims.custom.get(&cat_keys::CATM) {
441            Some(claim) => claim,
442            None => return Ok(()), // No CATM claim, so nothing to verify
443        };
444
445        // CATM claim should be an array of allowed methods
446        let allowed_methods = match catm_claim {
447            CborValue::Array(methods) => methods,
448            _ => {
449                return Err(Error::InvalidMethodClaim(
450                    "CATM claim is not an array".to_string(),
451                ))
452            }
453        };
454
455        // Check if the provided method is in the allowed methods list
456        let method_upper = method.to_uppercase();
457        let method_allowed = allowed_methods.iter().any(|m| {
458            if let CborValue::Text(allowed) = m {
459                allowed.to_uppercase() == method_upper
460            } else {
461                false
462            }
463        });
464
465        if !method_allowed {
466            return Err(Error::InvalidMethodClaim(format!(
467                "HTTP method '{}' is not allowed. Permitted methods: {:?}",
468                method,
469                allowed_methods
470                    .iter()
471                    .filter_map(|m| if let CborValue::Text(t) = m {
472                        Some(t.as_str())
473                    } else {
474                        None
475                    })
476                    .collect::<Vec<&str>>()
477            )));
478        }
479
480        Ok(())
481    }
482
483    /// Verify the CATREPLAY claim for token replay protection
484    fn verify_catreplay_claim(&self, options: &VerificationOptions) -> Result<(), Error> {
485        use crate::constants::{cat_keys, replay_values};
486
487        // Check if token has CATREPLAY claim
488        let catreplay_claim = match self.claims.custom.get(&cat_keys::CATREPLAY) {
489            Some(claim) => claim,
490            None => return Ok(()), // No CATREPLAY claim, so nothing to verify
491        };
492
493        // Get the replay protection value
494        let replay_value = match catreplay_claim {
495            CborValue::Integer(value) => *value as i32,
496            _ => {
497                return Err(Error::InvalidClaimValue(
498                    "CATREPLAY claim is not an integer".to_string(),
499                ))
500            }
501        };
502
503        match replay_value {
504            replay_values::PERMITTED => {
505                // Replay is permitted, no verification needed
506                Ok(())
507            }
508            replay_values::PROHIBITED => {
509                // Replay is prohibited, check if token has been seen before
510                if options.token_seen_before {
511                    Err(Error::ReplayViolation(
512                        "Token replay is prohibited".to_string(),
513                    ))
514                } else {
515                    Ok(())
516                }
517            }
518            replay_values::REUSE_DETECTION => {
519                // Reuse is detected but allowed, no error returned
520                // Implementations should log or notify about reuse
521                Ok(())
522            }
523            _ => Err(Error::InvalidClaimValue(format!(
524                "Invalid CATREPLAY value: {replay_value}"
525            ))),
526        }
527    }
528
529    // Note: signature_input method removed as we now use mac0_input for HMAC algorithms
530
531    /// Get the encoded payload bytes, using original bytes if available
532    fn get_payload_bytes(&self) -> Result<Vec<u8>, Error> {
533        if let Some(ref original) = self.original_payload_bytes {
534            // Use original bytes for verification
535            Ok(original.clone())
536        } else {
537            // Encode claims for newly created tokens
538            let claims_map = self.claims.to_map();
539            encode_map(&claims_map)
540        }
541    }
542
543    /// Get the COSE_Sign1 signature input
544    fn sign1_input(&self) -> Result<Vec<u8>, Error> {
545        // Sig_structure = [
546        //   context : "Signature1",
547        //   protected : bstr .cbor header_map,
548        //   external_aad : bstr,
549        //   payload : bstr .cbor claims
550        // ]
551
552        let mut buf = Vec::new();
553        let mut enc = Encoder::new(&mut buf);
554
555        // Start array with 4 items
556        enc.array(4)?;
557
558        // 1. Context
559        enc.str("Signature1")?;
560
561        // 2. Protected header
562        let protected_bytes = encode_map(&self.header.protected)?;
563        enc.bytes(&protected_bytes)?;
564
565        // 3. External AAD (empty in our case)
566        enc.bytes(&[])?;
567
568        // 4. Payload
569        let claims_bytes = self.get_payload_bytes()?;
570        enc.bytes(&claims_bytes)?;
571
572        Ok(buf)
573    }
574
575    /// Get the COSE_Mac0 signature input
576    fn mac0_input(&self) -> Result<Vec<u8>, Error> {
577        // Mac_structure = [
578        //   context : "MAC0",
579        //   protected : bstr .cbor header_map,
580        //   external_aad : bstr,
581        //   payload : bstr .cbor claims
582        // ]
583
584        let mut buf = Vec::new();
585        let mut enc = Encoder::new(&mut buf);
586
587        // Start array with 4 items
588        enc.array(4)?;
589
590        // 1. Context
591        enc.str("MAC0")?;
592
593        // 2. Protected header
594        let protected_bytes = encode_map(&self.header.protected)?;
595        enc.bytes(&protected_bytes)?;
596
597        // 3. External AAD (empty in our case)
598        enc.bytes(&[])?;
599
600        // 4. Payload
601        let claims_bytes = self.get_payload_bytes()?;
602        enc.bytes(&claims_bytes)?;
603
604        Ok(buf)
605    }
606}
607
/// Options for token verification.
///
/// `Default::default()` and `VerificationOptions::new()` produce the same
/// settings. The `exp` and `nbf` time checks are enabled; every other check
/// is opt-in.
#[derive(Debug, Clone)]
pub struct VerificationOptions {
    /// Verify expiration claim
    pub verify_exp: bool,
    /// Require expiration claim (only consulted when `verify_exp` is set)
    pub require_exp: bool,
    /// Verify not before claim
    pub verify_nbf: bool,
    /// Expected issuer
    pub expected_issuer: Option<String>,
    /// Require issuer claim (only consulted when `expected_issuer` is set)
    pub require_iss: bool,
    /// Expected audience
    pub expected_audience: Option<String>,
    /// Require audience claim (only consulted when `expected_audience` is set)
    pub require_aud: bool,
    /// Verify CAT-specific URI claim (CATU) against provided URI
    pub verify_catu: bool,
    /// URI to verify against CATU claim
    pub uri: Option<String>,
    /// Verify CAT-specific HTTP methods claim (CATM) against provided method
    pub verify_catm: bool,
    /// HTTP method to verify against CATM claim
    pub http_method: Option<String>,
    /// Verify CAT-specific replay protection (CATREPLAY)
    pub verify_catreplay: bool,
    /// Whether the token has been seen before (for replay protection);
    /// supplied by the caller
    pub token_seen_before: bool,
}

impl Default for VerificationOptions {
    /// Secure defaults matching `VerificationOptions::new()`.
    ///
    /// A derived `Default` would leave `verify_exp`/`verify_nbf` off and
    /// silently accept expired tokens. Keep these values in sync with `new()`.
    fn default() -> Self {
        Self {
            verify_exp: true,
            require_exp: false,
            verify_nbf: true,
            expected_issuer: None,
            require_iss: false,
            expected_audience: None,
            require_aud: false,
            verify_catu: false,
            uri: None,
            verify_catm: false,
            http_method: None,
            verify_catreplay: false,
            token_seen_before: false,
        }
    }
}
638
639impl VerificationOptions {
640    /// Create new default verification options
641    pub fn new() -> Self {
642        Self {
643            verify_exp: true,
644            require_exp: false,
645            verify_nbf: true,
646            expected_issuer: None,
647            require_iss: false,
648            expected_audience: None,
649            require_aud: false,
650            verify_catu: false,
651            uri: None,
652            verify_catm: false,
653            http_method: None,
654            verify_catreplay: false,
655            token_seen_before: false,
656        }
657    }
658
659    /// Set whether to verify expiration
660    pub fn verify_exp(mut self, verify: bool) -> Self {
661        self.verify_exp = verify;
662        self
663    }
664
665    /// Set whether to require expiration
666    pub fn require_exp(mut self, require: bool) -> Self {
667        self.require_exp = require;
668        self
669    }
670
671    /// Set whether to verify not before
672    pub fn verify_nbf(mut self, verify: bool) -> Self {
673        self.verify_nbf = verify;
674        self
675    }
676
677    /// Set expected issuer
678    pub fn expected_issuer<S: Into<String>>(mut self, issuer: S) -> Self {
679        self.expected_issuer = Some(issuer.into());
680        self
681    }
682
683    /// Set whether to require issuer
684    pub fn require_iss(mut self, require: bool) -> Self {
685        self.require_iss = require;
686        self
687    }
688
689    /// Set expected audience
690    pub fn expected_audience<S: Into<String>>(mut self, audience: S) -> Self {
691        self.expected_audience = Some(audience.into());
692        self
693    }
694
695    /// Set whether to require audience
696    pub fn require_aud(mut self, require: bool) -> Self {
697        self.require_aud = require;
698        self
699    }
700
701    /// Set whether to verify CAT-specific URI claim (CATU)
702    pub fn verify_catu(mut self, verify: bool) -> Self {
703        self.verify_catu = verify;
704        self
705    }
706
707    /// Set URI to verify against CATU claim
708    pub fn uri<S: Into<String>>(mut self, uri: S) -> Self {
709        self.uri = Some(uri.into());
710        self
711    }
712
713    /// Set whether to verify CAT-specific HTTP methods claim (CATM)
714    pub fn verify_catm(mut self, verify: bool) -> Self {
715        self.verify_catm = verify;
716        self
717    }
718
719    /// Set HTTP method to verify against CATM claim
720    pub fn http_method<S: Into<String>>(mut self, method: S) -> Self {
721        self.http_method = Some(method.into());
722        self
723    }
724
725    /// Set whether to verify CAT-specific replay protection (CATREPLAY)
726    pub fn verify_catreplay(mut self, verify: bool) -> Self {
727        self.verify_catreplay = verify;
728        self
729    }
730
731    /// Set whether the token has been seen before (for replay protection)
732    pub fn token_seen_before(mut self, seen: bool) -> Self {
733        self.token_seen_before = seen;
734        self
735    }
736}
737
738/// Builder for creating tokens
739#[derive(Debug, Clone, Default)]
740pub struct TokenBuilder {
741    header: Header,
742    claims: Claims,
743}
744
745impl TokenBuilder {
746    /// Create a new token builder
747    pub fn new() -> Self {
748        Self::default()
749    }
750
751    /// Set the algorithm
752    pub fn algorithm(mut self, alg: Algorithm) -> Self {
753        self.header = self.header.with_algorithm(alg);
754        self
755    }
756
757    /// Set the key identifier in the protected header
758    pub fn protected_key_id(mut self, kid: KeyId) -> Self {
759        self.header = self.header.with_protected_key_id(kid);
760        self
761    }
762
763    /// Set the key identifier in the unprotected header
764    pub fn unprotected_key_id(mut self, kid: KeyId) -> Self {
765        self.header = self.header.with_unprotected_key_id(kid);
766        self
767    }
768
769    /// Set the registered claims
770    pub fn registered_claims(mut self, claims: RegisteredClaims) -> Self {
771        self.claims = self.claims.with_registered_claims(claims);
772        self
773    }
774
775    /// Add a custom claim with a string value
776    pub fn custom_string<S: Into<String>>(mut self, key: i32, value: S) -> Self {
777        self.claims = self.claims.with_custom_string(key, value);
778        self
779    }
780
781    /// Add a custom claim with a binary value
782    pub fn custom_binary<B: Into<Vec<u8>>>(mut self, key: i32, value: B) -> Self {
783        self.claims = self.claims.with_custom_binary(key, value);
784        self
785    }
786
787    /// Add a custom claim with an integer value
788    pub fn custom_int(mut self, key: i32, value: i64) -> Self {
789        self.claims = self.claims.with_custom_int(key, value);
790        self
791    }
792
793    /// Add a custom claim with a nested map value
794    pub fn custom_map(mut self, key: i32, value: BTreeMap<i32, CborValue>) -> Self {
795        self.claims = self.claims.with_custom_map(key, value);
796        self
797    }
798
799    /// Add a custom claim with a CborValue directly
800    pub fn custom_cbor(mut self, key: i32, value: CborValue) -> Self {
801        self.claims.custom.insert(key, value);
802        self
803    }
804
805    /// Add a custom claim with an array value
806    pub fn custom_array(mut self, key: i32, value: Vec<CborValue>) -> Self {
807        self.claims.custom.insert(key, CborValue::Array(value));
808        self
809    }
810
811    /// Build and sign the token
812    pub fn sign(self, key: &[u8]) -> Result<Token, Error> {
813        // Ensure we have an algorithm
814        let alg = self.header.algorithm().ok_or_else(|| {
815            Error::InvalidFormat("Missing algorithm in protected header".to_string())
816        })?;
817
818        // Create token without signature
819        let token = Token {
820            header: self.header,
821            claims: self.claims,
822            signature: Vec::new(),
823            original_payload_bytes: None,
824        };
825
826        // Compute signature input based on algorithm
827        // HMAC algorithms use COSE_Mac0 structure, others use COSE_Sign1
828        let (_signature_input, signature) = match alg {
829            Algorithm::HmacSha256 => {
830                let mac_input = token.mac0_input()?;
831                let mac = compute_hmac_sha256(key, &mac_input);
832                (mac_input, mac)
833            }
834        };
835
836        // Create final token with signature
837        Ok(Token {
838            header: token.header,
839            claims: token.claims,
840            signature,
841            original_payload_bytes: None,
842        })
843    }
844}
845
846// Helper functions for CBOR encoding/decoding
847
848fn encode_map(map: &HeaderMap) -> Result<Vec<u8>, Error> {
849    let mut buf = Vec::new();
850    let mut enc = Encoder::new(&mut buf);
851
852    encode_map_direct(map, &mut enc)?;
853
854    Ok(buf)
855}
856
857/// Encode a CBOR value directly to the encoder
858fn encode_cbor_value(value: &CborValue, enc: &mut Encoder<&mut Vec<u8>>) -> Result<(), Error> {
859    match value {
860        CborValue::Integer(i) => {
861            enc.i64(*i)?;
862        }
863        CborValue::Bytes(b) => {
864            enc.bytes(b)?;
865        }
866        CborValue::Text(s) => {
867            enc.str(s)?;
868        }
869        CborValue::Map(nested_map) => {
870            // Create a nested encoder for the map
871            encode_map_direct(nested_map, enc)?;
872        }
873        CborValue::Array(arr) => {
874            // Create a nested encoder for the array
875            enc.array(arr.len() as u64)?;
876            for item in arr {
877                encode_cbor_value(item, enc)?;
878            }
879        }
880        CborValue::Null => {
881            enc.null()?;
882        }
883    }
884    Ok(())
885}
886
887fn encode_map_direct(map: &HeaderMap, enc: &mut Encoder<&mut Vec<u8>>) -> Result<(), Error> {
888    enc.map(map.len() as u64)?;
889
890    for (key, value) in map {
891        enc.i32(*key)?;
892        encode_cbor_value(value, enc)?;
893    }
894
895    Ok(())
896}
897
898fn decode_map(bytes: &[u8]) -> Result<HeaderMap, Error> {
899    let mut dec = Decoder::new(bytes);
900    decode_map_direct(&mut dec)
901}
902
903/// Decode a CBOR array
904fn decode_array(dec: &mut Decoder<'_>) -> Result<Vec<CborValue>, Error> {
905    let array_len = dec.array()?.unwrap_or(0);
906    let mut array = Vec::with_capacity(array_len as usize);
907
908    for _ in 0..array_len {
909        // Try to decode based on the datatype
910        let datatype = dec.datatype()?;
911
912        // Handle each type separately
913        let value = if datatype == minicbor::data::Type::Int {
914            // Integer value
915            let i = dec.i64()?;
916            CborValue::Integer(i)
917        } else if datatype == minicbor::data::Type::U8
918            || datatype == minicbor::data::Type::U16
919            || datatype == minicbor::data::Type::U32
920            || datatype == minicbor::data::Type::U64
921        {
922            // Unsigned integer value
923            let i = dec.u64()? as i64;
924            CborValue::Integer(i)
925        } else if datatype == minicbor::data::Type::Bytes {
926            // Byte string
927            let b = dec.bytes()?;
928            CborValue::Bytes(b.to_vec())
929        } else if datatype == minicbor::data::Type::String {
930            // Text string
931            let s = dec.str()?;
932            CborValue::Text(s.to_string())
933        } else if datatype == minicbor::data::Type::Map {
934            // Nested map
935            let nested_map = decode_map_direct(dec)?;
936            CborValue::Map(nested_map)
937        } else if datatype == minicbor::data::Type::Array {
938            // Nested array
939            let nested_array = decode_array(dec)?;
940            CborValue::Array(nested_array)
941        } else if datatype == minicbor::data::Type::Null {
942            // Null value
943            dec.null()?;
944            CborValue::Null
945        } else {
946            // Unsupported type
947            return Err(Error::InvalidFormat(format!(
948                "Unsupported CBOR type in array: {datatype:?}"
949            )));
950        };
951
952        array.push(value);
953    }
954
955    Ok(array)
956}
957
958fn decode_map_direct(dec: &mut Decoder<'_>) -> Result<HeaderMap, Error> {
959    let map_len = dec.map()?.unwrap_or(0);
960    let mut map = HeaderMap::new();
961
962    for _ in 0..map_len {
963        let key = dec.i32()?;
964
965        // Try to decode based on the datatype
966        let datatype = dec.datatype()?;
967
968        // Handle each type separately
969        let value = if datatype == minicbor::data::Type::Int {
970            // Integer value
971            let i = dec.i64()?;
972            CborValue::Integer(i)
973        } else if datatype == minicbor::data::Type::U8
974            || datatype == minicbor::data::Type::U16
975            || datatype == minicbor::data::Type::U32
976            || datatype == minicbor::data::Type::U64
977        {
978            // Unsigned integer value
979            let i = dec.u64()? as i64;
980            CborValue::Integer(i)
981        } else if datatype == minicbor::data::Type::Bytes {
982            // Byte string
983            let b = dec.bytes()?;
984            CborValue::Bytes(b.to_vec())
985        } else if datatype == minicbor::data::Type::String {
986            // Text string
987            let s = dec.str()?;
988            CborValue::Text(s.to_string())
989        } else if datatype == minicbor::data::Type::Map {
990            // Nested map
991            let nested_map = decode_map_direct(dec)?;
992            CborValue::Map(nested_map)
993        } else if datatype == minicbor::data::Type::Array {
994            // Array
995            let array = decode_array(dec)?;
996            CborValue::Array(array)
997        } else if datatype == minicbor::data::Type::Null {
998            // Null value
999            dec.null()?;
1000            CborValue::Null
1001        } else {
1002            // Unsupported type
1003            return Err(Error::InvalidFormat(format!(
1004                "Unsupported CBOR type: {datatype:?}"
1005            )));
1006        };
1007
1008        map.insert(key, value);
1009    }
1010
1011    Ok(map)
1012}