1use crate::claims::{Claims, RegisteredClaims};
4use crate::error::Error;
5use crate::header::{Algorithm, CborValue, Header, HeaderMap, KeyId};
6use crate::utils::{compute_hmac_sha256, current_timestamp, verify_hmac_sha256};
7use minicbor::{Decoder, Encoder};
8use std::collections::BTreeMap;
9
/// A MAC-authenticated token: header parameters, a claims set, and the tag over them.
///
/// Serialized by `Token::to_bytes` as the CBOR array
/// `[protected (bstr), unprotected (map), claims (bstr), signature (bstr)]`.
#[derive(Debug, Clone)]
pub struct Token {
    /// Protected and unprotected header parameters.
    pub header: Header,
    /// Registered and custom claims; serialized as a CBOR map wrapped in a byte string.
    pub claims: Claims,
    /// Authentication tag over the header and claims (an HMAC-SHA256 output when the
    /// token was produced by `TokenBuilder::sign` with `Algorithm::HmacSha256`).
    pub signature: Vec<u8>,
}
20
21impl Token {
22 pub fn new(header: Header, claims: Claims, signature: Vec<u8>) -> Self {
24 Self {
25 header,
26 claims,
27 signature,
28 }
29 }
30
31 pub fn to_bytes(&self) -> Result<Vec<u8>, Error> {
33 let mut buf = Vec::new();
41 let mut enc = Encoder::new(&mut buf);
42
43 enc.array(4)?;
45
46 let protected_bytes = encode_map(&self.header.protected)?;
48 enc.bytes(&protected_bytes)?;
49
50 encode_map_direct(&self.header.unprotected, &mut enc)?;
52
53 let claims_map = self.claims.to_map();
55 let claims_bytes = encode_map(&claims_map)?;
56 enc.bytes(&claims_bytes)?;
57
58 enc.bytes(&self.signature)?;
60
61 Ok(buf)
62 }
63
64 pub fn from_bytes(bytes: &[u8]) -> Result<Self, Error> {
70 let mut dec = Decoder::new(bytes);
71
72 if dec.datatype()? == minicbor::data::Type::Tag {
74 let _ = dec.tag()?;
76
77 if dec.datatype()? == minicbor::data::Type::Tag {
79 let _ = dec.tag()?;
80 }
81 }
82
83 let array_len = dec.array()?.unwrap_or(0);
85 if array_len != 4 {
86 return Err(Error::InvalidFormat(format!(
87 "Expected array of length 4, got {}",
88 array_len
89 )));
90 }
91
92 let protected_bytes = dec.bytes()?;
94 let protected = decode_map(protected_bytes)?;
95
96 let unprotected = decode_map_direct(&mut dec)?;
98
99 let header = Header {
101 protected,
102 unprotected,
103 };
104
105 let claims_bytes = dec.bytes()?;
107 let claims_map = decode_map(claims_bytes)?;
108 let claims = Claims::from_map(&claims_map);
109
110 let signature = dec.bytes()?.to_vec();
112
113 Ok(Self {
114 header,
115 claims,
116 signature,
117 })
118 }
119
120 pub fn verify(&self, key: &[u8]) -> Result<(), Error> {
126 let alg = self.header.algorithm().ok_or_else(|| {
127 Error::InvalidFormat("Missing algorithm in protected header".to_string())
128 })?;
129
130 match alg {
131 Algorithm::HmacSha256 => {
132 let sign1_input = self.sign1_input()?;
134 let sign1_result = verify_hmac_sha256(key, &sign1_input, &self.signature);
135
136 if sign1_result.is_ok() {
137 return Ok(());
138 }
139
140 let mac0_input = self.mac0_input()?;
142 verify_hmac_sha256(key, &mac0_input, &self.signature)
143 }
144 }
145 }
146
147 pub fn verify_claims(&self, options: &VerificationOptions) -> Result<(), Error> {
149 let now = current_timestamp();
150
151 if options.verify_exp {
153 if let Some(exp) = self.claims.registered.exp {
154 if now >= exp {
155 return Err(Error::Expired);
156 }
157 } else if options.require_exp {
158 return Err(Error::MissingClaim("exp".to_string()));
159 }
160 }
161
162 if options.verify_nbf {
164 if let Some(nbf) = self.claims.registered.nbf {
165 if now < nbf {
166 return Err(Error::NotYetValid);
167 }
168 }
169 }
170
171 if let Some(expected_iss) = &options.expected_issuer {
173 if let Some(iss) = &self.claims.registered.iss {
174 if iss != expected_iss {
175 return Err(Error::InvalidIssuer);
176 }
177 } else if options.require_iss {
178 return Err(Error::MissingClaim("iss".to_string()));
179 }
180 }
181
182 if let Some(expected_aud) = &options.expected_audience {
184 if let Some(aud) = &self.claims.registered.aud {
185 if aud != expected_aud {
186 return Err(Error::InvalidAudience);
187 }
188 } else if options.require_aud {
189 return Err(Error::MissingClaim("aud".to_string()));
190 }
191 }
192
193 if options.verify_catu {
195 self.verify_catu_claim(options)?;
196 }
197
198 if options.verify_catm {
199 self.verify_catm_claim(options)?;
200 }
201
202 if options.verify_catreplay {
203 self.verify_catreplay_claim(options)?;
204 }
205
206 Ok(())
207 }
208
209 fn verify_catu_claim(&self, options: &VerificationOptions) -> Result<(), Error> {
211 use crate::constants::{cat_keys, uri_components};
212 use url::Url;
213
214 let uri = match &options.uri {
216 Some(uri) => uri,
217 None => {
218 return Err(Error::InvalidClaimValue(
219 "No URI provided for CATU verification".to_string(),
220 ))
221 }
222 };
223
224 let parsed_uri = match Url::parse(uri) {
226 Ok(url) => url,
227 Err(_) => {
228 return Err(Error::InvalidClaimValue(format!(
229 "Invalid URI format: {}",
230 uri
231 )))
232 }
233 };
234
235 let catu_claim = match self.claims.custom.get(&cat_keys::CATU) {
237 Some(claim) => claim,
238 None => return Ok(()), };
240
241 let component_map = match catu_claim {
243 CborValue::Map(map) => map,
244 _ => {
245 return Err(Error::InvalidUriClaim(
246 "CATU claim is not a map".to_string(),
247 ))
248 }
249 };
250
251 for (component_key, component_value) in component_map {
253 match *component_key {
254 uri_components::SCHEME => {
255 self.verify_uri_component(
256 &parsed_uri.scheme().to_string(),
257 component_value,
258 "scheme",
259 )?;
260 }
261 uri_components::HOST => {
262 self.verify_uri_component(
263 &parsed_uri.host_str().unwrap_or("").to_string(),
264 component_value,
265 "host",
266 )?;
267 }
268 uri_components::PORT => {
269 let port = parsed_uri.port().map(|p| p.to_string()).unwrap_or_default();
270 self.verify_uri_component(&port, component_value, "port")?;
271 }
272 uri_components::PATH => {
273 self.verify_uri_component(
274 &parsed_uri.path().to_string(),
275 component_value,
276 "path",
277 )?;
278 }
279 uri_components::QUERY => {
280 let query = parsed_uri.query().unwrap_or("").to_string();
281 self.verify_uri_component(&query, component_value, "query")?;
282 }
283 uri_components::EXTENSION => {
284 let path = parsed_uri.path();
286 let extension = path.split('.').next_back().unwrap_or("").to_string();
287 if !path.contains('.') || path.ends_with('.') {
288 self.verify_uri_component(&"".to_string(), component_value, "extension")?;
290 } else {
291 self.verify_uri_component(
292 &format!(".{}", extension),
293 component_value,
294 "extension",
295 )?;
296 }
297 }
298 _ => {
299 }
301 }
302 }
303
304 Ok(())
305 }
306
307 fn verify_uri_component(
309 &self,
310 component: &String,
311 match_conditions: &CborValue,
312 component_name: &str,
313 ) -> Result<(), Error> {
314 use crate::constants::match_types;
315 use regex::Regex;
316 use sha2::{Digest, Sha256, Sha512};
317
318 let match_map = match match_conditions {
320 CborValue::Map(map) => map,
321 _ => {
322 return Err(Error::InvalidUriClaim(format!(
323 "Match conditions for {} is not a map",
324 component_name
325 )))
326 }
327 };
328
329 for (match_type, match_value) in match_map {
330 match *match_type {
331 match_types::EXACT => {
332 if let CborValue::Text(text) = match_value {
333 if component != text {
334 return Err(Error::InvalidUriClaim(format!(
335 "URI component {} '{}' does not exactly match required value '{}'",
336 component_name, component, text
337 )));
338 }
339 }
340 }
341 match_types::PREFIX => {
342 if let CborValue::Text(prefix) = match_value {
343 if !component.starts_with(prefix) {
344 return Err(Error::InvalidUriClaim(format!(
345 "URI component {} '{}' does not start with required prefix '{}'",
346 component_name, component, prefix
347 )));
348 }
349 }
350 }
351 match_types::SUFFIX => {
352 if let CborValue::Text(suffix) = match_value {
353 if !component.ends_with(suffix) {
354 return Err(Error::InvalidUriClaim(format!(
355 "URI component {} '{}' does not end with required suffix '{}'",
356 component_name, component, suffix
357 )));
358 }
359 }
360 }
361 match_types::CONTAINS => {
362 if let CborValue::Text(contained) = match_value {
363 if !component.contains(contained) {
364 return Err(Error::InvalidUriClaim(format!(
365 "URI component {} '{}' does not contain required text '{}'",
366 component_name, component, contained
367 )));
368 }
369 }
370 }
371 match_types::REGEX => {
372 if let CborValue::Array(array) = match_value {
373 if let Some(CborValue::Text(pattern)) = array.first() {
374 match Regex::new(pattern) {
375 Ok(regex) => {
376 if !regex.is_match(component) {
377 return Err(Error::InvalidUriClaim(format!(
378 "URI component {} '{}' does not match required regex pattern '{}'",
379 component_name, component, pattern
380 )));
381 }
382 }
383 Err(_) => {
384 return Err(Error::InvalidUriClaim(format!(
385 "Invalid regex pattern: {}",
386 pattern
387 )))
388 }
389 }
390 }
391 }
392 }
393 match_types::SHA256 => {
394 if let CborValue::Bytes(expected_hash) = match_value {
395 let mut hasher = Sha256::new();
396 hasher.update(component.as_bytes());
397 let hash = hasher.finalize();
398
399 if hash.as_slice() != expected_hash.as_slice() {
400 return Err(Error::InvalidUriClaim(format!(
401 "URI component {} '{}' SHA-256 hash does not match expected value",
402 component_name, component
403 )));
404 }
405 }
406 }
407 match_types::SHA512_256 => {
408 if let CborValue::Bytes(expected_hash) = match_value {
409 let mut hasher = Sha512::new();
410 hasher.update(component.as_bytes());
411 let hash = hasher.finalize();
412 let truncated_hash = &hash[0..32]; if truncated_hash != expected_hash.as_slice() {
415 return Err(Error::InvalidUriClaim(format!(
416 "URI component {} '{}' SHA-512/256 hash does not match expected value",
417 component_name, component
418 )));
419 }
420 }
421 }
422 _ => {
423 }
425 }
426 }
427
428 Ok(())
429 }
430
431 fn verify_catm_claim(&self, options: &VerificationOptions) -> Result<(), Error> {
433 use crate::constants::cat_keys;
434
435 let method = match &options.http_method {
437 Some(method) => method,
438 None => {
439 return Err(Error::InvalidClaimValue(
440 "No HTTP method provided for CATM verification".to_string(),
441 ))
442 }
443 };
444
445 let catm_claim = match self.claims.custom.get(&cat_keys::CATM) {
447 Some(claim) => claim,
448 None => return Ok(()), };
450
451 let allowed_methods = match catm_claim {
453 CborValue::Array(methods) => methods,
454 _ => {
455 return Err(Error::InvalidMethodClaim(
456 "CATM claim is not an array".to_string(),
457 ))
458 }
459 };
460
461 let method_upper = method.to_uppercase();
463 let method_allowed = allowed_methods.iter().any(|m| {
464 if let CborValue::Text(allowed) = m {
465 allowed.to_uppercase() == method_upper
466 } else {
467 false
468 }
469 });
470
471 if !method_allowed {
472 return Err(Error::InvalidMethodClaim(format!(
473 "HTTP method '{}' is not allowed. Permitted methods: {:?}",
474 method,
475 allowed_methods
476 .iter()
477 .filter_map(|m| if let CborValue::Text(t) = m {
478 Some(t.as_str())
479 } else {
480 None
481 })
482 .collect::<Vec<&str>>()
483 )));
484 }
485
486 Ok(())
487 }
488
489 fn verify_catreplay_claim(&self, options: &VerificationOptions) -> Result<(), Error> {
491 use crate::constants::{cat_keys, replay_values};
492
493 let catreplay_claim = match self.claims.custom.get(&cat_keys::CATREPLAY) {
495 Some(claim) => claim,
496 None => return Ok(()), };
498
499 let replay_value = match catreplay_claim {
501 CborValue::Integer(value) => *value as i32,
502 _ => {
503 return Err(Error::InvalidClaimValue(
504 "CATREPLAY claim is not an integer".to_string(),
505 ))
506 }
507 };
508
509 match replay_value {
510 replay_values::PERMITTED => {
511 Ok(())
513 }
514 replay_values::PROHIBITED => {
515 if options.token_seen_before {
517 Err(Error::ReplayViolation(
518 "Token replay is prohibited".to_string(),
519 ))
520 } else {
521 Ok(())
522 }
523 }
524 replay_values::REUSE_DETECTION => {
525 Ok(())
528 }
529 _ => Err(Error::InvalidClaimValue(format!(
530 "Invalid CATREPLAY value: {}",
531 replay_value
532 ))),
533 }
534 }
535
536 fn signature_input(&self) -> Result<Vec<u8>, Error> {
538 self.sign1_input()
540 }
541
542 fn sign1_input(&self) -> Result<Vec<u8>, Error> {
544 let mut buf = Vec::new();
552 let mut enc = Encoder::new(&mut buf);
553
554 enc.array(4)?;
556
557 enc.str("Signature1")?;
559
560 let protected_bytes = encode_map(&self.header.protected)?;
562 enc.bytes(&protected_bytes)?;
563
564 enc.bytes(&[])?;
566
567 let claims_map = self.claims.to_map();
569 let claims_bytes = encode_map(&claims_map)?;
570 enc.bytes(&claims_bytes)?;
571
572 Ok(buf)
573 }
574
575 fn mac0_input(&self) -> Result<Vec<u8>, Error> {
577 let mut buf = Vec::new();
585 let mut enc = Encoder::new(&mut buf);
586
587 enc.array(4)?;
589
590 enc.str("MAC0")?;
592
593 let protected_bytes = encode_map(&self.header.protected)?;
595 enc.bytes(&protected_bytes)?;
596
597 enc.bytes(&[])?;
599
600 let claims_map = self.claims.to_map();
602 let claims_bytes = encode_map(&claims_map)?;
603 enc.bytes(&claims_bytes)?;
604
605 Ok(buf)
606 }
607}
608
/// Options controlling `Token::verify_claims`.
///
/// `VerificationOptions::new()` and `VerificationOptions::default()` produce the same
/// configuration: `exp` and `nbf` checks enabled, everything else disabled. (Previously
/// the derived `Default` silently disabled the `exp` and `nbf` checks.)
#[derive(Debug, Clone)]
pub struct VerificationOptions {
    /// Reject the token when the current time is at or past `exp`.
    pub verify_exp: bool,
    /// With `verify_exp`, treat a missing `exp` claim as an error.
    pub require_exp: bool,
    /// Reject the token when the current time is before `nbf`.
    pub verify_nbf: bool,
    /// When set, a present `iss` claim must equal this value.
    pub expected_issuer: Option<String>,
    /// Treat a missing `iss` claim as an error (see `Token::verify_claims` for when it applies).
    pub require_iss: bool,
    /// When set, a present `aud` claim must equal this value.
    pub expected_audience: Option<String>,
    /// Treat a missing `aud` claim as an error (see `Token::verify_claims` for when it applies).
    pub require_aud: bool,
    /// Check the CATU (URI) claim against `uri`.
    pub verify_catu: bool,
    /// Request URI used by the CATU check.
    pub uri: Option<String>,
    /// Check the CATM (HTTP method) claim against `http_method`.
    pub verify_catm: bool,
    /// Request HTTP method used by the CATM check.
    pub http_method: Option<String>,
    /// Apply the CATREPLAY claim.
    pub verify_catreplay: bool,
    /// Whether the caller has already seen this token (consulted by the CATREPLAY check).
    pub token_seen_before: bool,
}

impl Default for VerificationOptions {
    /// Same as `VerificationOptions::new()`, so `exp`/`nbf` checks are on by default.
    fn default() -> Self {
        Self::new()
    }
}

impl VerificationOptions {
    /// Creates options with `exp` and `nbf` checks enabled and everything else disabled.
    pub fn new() -> Self {
        Self {
            verify_exp: true,
            require_exp: false,
            verify_nbf: true,
            expected_issuer: None,
            require_iss: false,
            expected_audience: None,
            require_aud: false,
            verify_catu: false,
            uri: None,
            verify_catm: false,
            http_method: None,
            verify_catreplay: false,
            token_seen_before: false,
        }
    }

    /// Enables or disables the `exp` check.
    pub fn verify_exp(mut self, verify: bool) -> Self {
        self.verify_exp = verify;
        self
    }

    /// Sets whether a missing `exp` claim is an error.
    pub fn require_exp(mut self, require: bool) -> Self {
        self.require_exp = require;
        self
    }

    /// Enables or disables the `nbf` check.
    pub fn verify_nbf(mut self, verify: bool) -> Self {
        self.verify_nbf = verify;
        self
    }

    /// Sets the issuer a present `iss` claim must equal.
    pub fn expected_issuer<S: Into<String>>(mut self, issuer: S) -> Self {
        self.expected_issuer = Some(issuer.into());
        self
    }

    /// Sets whether a missing `iss` claim is an error.
    pub fn require_iss(mut self, require: bool) -> Self {
        self.require_iss = require;
        self
    }

    /// Sets the audience a present `aud` claim must equal.
    pub fn expected_audience<S: Into<String>>(mut self, audience: S) -> Self {
        self.expected_audience = Some(audience.into());
        self
    }

    /// Sets whether a missing `aud` claim is an error.
    pub fn require_aud(mut self, require: bool) -> Self {
        self.require_aud = require;
        self
    }

    /// Enables or disables the CATU (URI) check.
    pub fn verify_catu(mut self, verify: bool) -> Self {
        self.verify_catu = verify;
        self
    }

    /// Sets the request URI for the CATU check.
    pub fn uri<S: Into<String>>(mut self, uri: S) -> Self {
        self.uri = Some(uri.into());
        self
    }

    /// Enables or disables the CATM (HTTP method) check.
    pub fn verify_catm(mut self, verify: bool) -> Self {
        self.verify_catm = verify;
        self
    }

    /// Sets the request HTTP method for the CATM check.
    pub fn http_method<S: Into<String>>(mut self, method: S) -> Self {
        self.http_method = Some(method.into());
        self
    }

    /// Enables or disables the CATREPLAY check.
    pub fn verify_catreplay(mut self, verify: bool) -> Self {
        self.verify_catreplay = verify;
        self
    }

    /// Records whether this token has been seen before (for CATREPLAY).
    pub fn token_seen_before(mut self, seen: bool) -> Self {
        self.token_seen_before = seen;
        self
    }
}
738
/// Fluent builder that assembles a header and a claims set, then signs them into a `Token`.
#[derive(Debug, Clone, Default)]
pub struct TokenBuilder {
    /// Header under construction; `sign` fails if no algorithm has been set on it.
    header: Header,
    /// Claims under construction.
    claims: Claims,
}
745
746impl TokenBuilder {
747 pub fn new() -> Self {
749 Self::default()
750 }
751
752 pub fn algorithm(mut self, alg: Algorithm) -> Self {
754 self.header = self.header.with_algorithm(alg);
755 self
756 }
757
758 pub fn protected_key_id(mut self, kid: KeyId) -> Self {
760 self.header = self.header.with_protected_key_id(kid);
761 self
762 }
763
764 pub fn unprotected_key_id(mut self, kid: KeyId) -> Self {
766 self.header = self.header.with_unprotected_key_id(kid);
767 self
768 }
769
770 pub fn registered_claims(mut self, claims: RegisteredClaims) -> Self {
772 self.claims = self.claims.with_registered_claims(claims);
773 self
774 }
775
776 pub fn custom_string<S: Into<String>>(mut self, key: i32, value: S) -> Self {
778 self.claims = self.claims.with_custom_string(key, value);
779 self
780 }
781
782 pub fn custom_binary<B: Into<Vec<u8>>>(mut self, key: i32, value: B) -> Self {
784 self.claims = self.claims.with_custom_binary(key, value);
785 self
786 }
787
788 pub fn custom_int(mut self, key: i32, value: i64) -> Self {
790 self.claims = self.claims.with_custom_int(key, value);
791 self
792 }
793
794 pub fn custom_map(mut self, key: i32, value: BTreeMap<i32, CborValue>) -> Self {
796 self.claims = self.claims.with_custom_map(key, value);
797 self
798 }
799
800 pub fn custom_cbor(mut self, key: i32, value: CborValue) -> Self {
802 self.claims.custom.insert(key, value);
803 self
804 }
805
806 pub fn custom_array(mut self, key: i32, value: Vec<CborValue>) -> Self {
808 self.claims.custom.insert(key, CborValue::Array(value));
809 self
810 }
811
812 pub fn sign(self, key: &[u8]) -> Result<Token, Error> {
814 let alg = self.header.algorithm().ok_or_else(|| {
816 Error::InvalidFormat("Missing algorithm in protected header".to_string())
817 })?;
818
819 let token = Token {
821 header: self.header,
822 claims: self.claims,
823 signature: Vec::new(),
824 };
825
826 let signature_input = token.signature_input()?;
828
829 let signature = match alg {
831 Algorithm::HmacSha256 => compute_hmac_sha256(key, &signature_input),
832 };
833
834 Ok(Token {
836 header: token.header,
837 claims: token.claims,
838 signature,
839 })
840 }
841}
842
843fn encode_map(map: &HeaderMap) -> Result<Vec<u8>, Error> {
846 let mut buf = Vec::new();
847 let mut enc = Encoder::new(&mut buf);
848
849 encode_map_direct(map, &mut enc)?;
850
851 Ok(buf)
852}
853
854fn encode_cbor_value(value: &CborValue, enc: &mut Encoder<&mut Vec<u8>>) -> Result<(), Error> {
856 match value {
857 CborValue::Integer(i) => {
858 enc.i64(*i)?;
859 }
860 CborValue::Bytes(b) => {
861 enc.bytes(b)?;
862 }
863 CborValue::Text(s) => {
864 enc.str(s)?;
865 }
866 CborValue::Map(nested_map) => {
867 encode_map_direct(nested_map, enc)?;
869 }
870 CborValue::Array(arr) => {
871 enc.array(arr.len() as u64)?;
873 for item in arr {
874 encode_cbor_value(item, enc)?;
875 }
876 }
877 CborValue::Null => {
878 enc.null()?;
879 }
880 }
881 Ok(())
882}
883
884fn encode_map_direct(map: &HeaderMap, enc: &mut Encoder<&mut Vec<u8>>) -> Result<(), Error> {
885 enc.map(map.len() as u64)?;
886
887 for (key, value) in map {
888 enc.i32(*key)?;
889 encode_cbor_value(value, enc)?;
890 }
891
892 Ok(())
893}
894
895fn decode_map(bytes: &[u8]) -> Result<HeaderMap, Error> {
896 let mut dec = Decoder::new(bytes);
897 decode_map_direct(&mut dec)
898}
899
900fn decode_array(dec: &mut Decoder<'_>) -> Result<Vec<CborValue>, Error> {
902 let array_len = dec.array()?.unwrap_or(0);
903 let mut array = Vec::with_capacity(array_len as usize);
904
905 for _ in 0..array_len {
906 let datatype = dec.datatype()?;
908
909 let value = if datatype == minicbor::data::Type::Int {
911 let i = dec.i64()?;
913 CborValue::Integer(i)
914 } else if datatype == minicbor::data::Type::U8
915 || datatype == minicbor::data::Type::U16
916 || datatype == minicbor::data::Type::U32
917 || datatype == minicbor::data::Type::U64
918 {
919 let i = dec.u64()? as i64;
921 CborValue::Integer(i)
922 } else if datatype == minicbor::data::Type::Bytes {
923 let b = dec.bytes()?;
925 CborValue::Bytes(b.to_vec())
926 } else if datatype == minicbor::data::Type::String {
927 let s = dec.str()?;
929 CborValue::Text(s.to_string())
930 } else if datatype == minicbor::data::Type::Map {
931 let nested_map = decode_map_direct(dec)?;
933 CborValue::Map(nested_map)
934 } else if datatype == minicbor::data::Type::Array {
935 let nested_array = decode_array(dec)?;
937 CborValue::Array(nested_array)
938 } else if datatype == minicbor::data::Type::Null {
939 dec.null()?;
941 CborValue::Null
942 } else {
943 return Err(Error::InvalidFormat(format!(
945 "Unsupported CBOR type in array: {:?}",
946 datatype
947 )));
948 };
949
950 array.push(value);
951 }
952
953 Ok(array)
954}
955
956fn decode_map_direct(dec: &mut Decoder<'_>) -> Result<HeaderMap, Error> {
957 let map_len = dec.map()?.unwrap_or(0);
958 let mut map = HeaderMap::new();
959
960 for _ in 0..map_len {
961 let key = dec.i32()?;
962
963 let datatype = dec.datatype()?;
965
966 let value = if datatype == minicbor::data::Type::Int {
968 let i = dec.i64()?;
970 CborValue::Integer(i)
971 } else if datatype == minicbor::data::Type::U8
972 || datatype == minicbor::data::Type::U16
973 || datatype == minicbor::data::Type::U32
974 || datatype == minicbor::data::Type::U64
975 {
976 let i = dec.u64()? as i64;
978 CborValue::Integer(i)
979 } else if datatype == minicbor::data::Type::Bytes {
980 let b = dec.bytes()?;
982 CborValue::Bytes(b.to_vec())
983 } else if datatype == minicbor::data::Type::String {
984 let s = dec.str()?;
986 CborValue::Text(s.to_string())
987 } else if datatype == minicbor::data::Type::Map {
988 let nested_map = decode_map_direct(dec)?;
990 CborValue::Map(nested_map)
991 } else if datatype == minicbor::data::Type::Array {
992 let array = decode_array(dec)?;
994 CborValue::Array(array)
995 } else if datatype == minicbor::data::Type::Null {
996 dec.null()?;
998 CborValue::Null
999 } else {
1000 return Err(Error::InvalidFormat(format!(
1002 "Unsupported CBOR type: {:?}",
1003 datatype
1004 )));
1005 };
1006
1007 map.insert(key, value);
1008 }
1009
1010 Ok(map)
1011}