1#![cfg_attr(not(feature = "std"), no_std)]
3
4#[cfg(feature = "alloc")]
7extern crate alloc;
8
9mod error;
10mod hash_algorithm;
11
12pub use error::{KbsTypesError, Result};
13pub use hash_algorithm::HashAlgorithm;
14
15#[cfg(all(feature = "alloc", not(feature = "std")))]
16use alloc::{string::String, vec::Vec};
17use base64::{prelude::BASE64_URL_SAFE_NO_PAD, Engine};
18#[cfg(feature = "std")]
19use ear::{self, RawValue};
20use serde::{Deserialize, Serialize};
21use serde_json::{Map, Value};
22#[cfg(all(feature = "std", not(feature = "alloc")))]
23use std::string::String;
24
/// Attestation platform (TEE / attester) identifiers exchanged in [`Request::tee`].
///
/// Variants serialize as their lowercased Rust name (e.g. `Tdx` <-> `"tdx"`,
/// `HygonDcu` <-> `"hygondcu"`, `SampleDevice` <-> `"sampledevice"`), except
/// where an explicit `#[serde(rename = ...)]` overrides it. The type is `Copy`,
/// `Eq` and `Hash`, so it can be used directly as a map/set key.
#[derive(Serialize, Clone, Copy, Deserialize, Debug, Eq, Hash, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Tee {
    /// Wire name: `"az-snp-vtpm"` (explicit rename).
    #[serde(rename = "az-snp-vtpm")]
    AzSnpVtpm,
    /// Wire name: `"az-tdx-vtpm"` (explicit rename).
    #[serde(rename = "az-tdx-vtpm")]
    AzTdxVtpm,
    /// Wire name: `"nvidia"`.
    Nvidia,
    /// Wire name: `"sgx"`.
    Sgx,
    /// Wire name: `"snp"`.
    Snp,
    /// Wire name: `"tdx"`.
    Tdx,
    /// Wire name: `"cca"`.
    Cca,
    /// Wire name: `"csv"`.
    Csv,
    /// Wire name: `"se"`.
    Se,

    /// Wire name: `"hygondcu"` (no hyphen, because `rename_all = "lowercase"`).
    HygonDcu,

    /// Wire name: `"tpm"`.
    Tpm,

    /// Wire name: `"sample"`.
    /// NOTE(review): name suggests a test-only attester; confirm it is never
    /// accepted in production deployments.
    Sample,
    /// Wire name: `"sampledevice"` (no hyphen, because `rename_all = "lowercase"`).
    SampleDevice,
}
55
/// Initial request sent by the client: protocol version, platform kind and
/// free-form extra parameters.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Request {
    /// Protocol version string (opaque here; no format is validated).
    pub version: String,
    /// Platform kind; see [`Tee`] for the accepted wire names.
    pub tee: Tee,
    /// Serialized as `extra-params`; any JSON value is accepted (string, object, ...).
    #[serde(rename = "extra-params")]
    pub extra_params: Value,
}
63
/// Challenge returned to the client, carrying a nonce and free-form extras.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Challenge {
    /// Nonce string (opaque here; no encoding is enforced).
    pub nonce: String,
    /// Serialized as `extra-params`; any JSON value is accepted.
    #[serde(rename = "extra-params")]
    pub extra_params: Value,
}
70
/// Public key of the TEE in a JWK-like JSON shape.
///
/// Internally tagged on `kty`, whose value is the variant name (`"RSA"` or
/// `"EC"`). Field declaration order is also the JSON output order.
/// NOTE(review): all components are opaque strings; their encoding (e.g.
/// base64url as in JWK) is not validated by this type.
#[derive(Clone, Serialize, Deserialize, Debug)]
#[serde(tag = "kty")]
pub enum TeePubKey {
    /// RSA key: JSON keys `alg`, `n`, `e` (after `kty`).
    RSA {
        /// Algorithm identifier.
        alg: String,
        /// Modulus, serialized under the JSON key `n`.
        #[serde(rename = "n")]
        k_mod: String,
        /// Public exponent, serialized under the JSON key `e`.
        #[serde(rename = "e")]
        k_exp: String,
    },
    /// Elliptic-curve key: JSON keys `crv`, `alg`, `x`, `y` (after `kty`).
    EC {
        /// Curve identifier.
        crv: String,
        /// Algorithm identifier.
        alg: String,
        /// X coordinate.
        x: String,
        /// Y coordinate.
        y: String,
    },
}
91
#[cfg(feature = "std")]
impl From<&TeePubKey> for ear::RawValue {
    /// Converts a [`TeePubKey`] into an EAR string-keyed map.
    ///
    /// Keys and their order mirror the key's serde JSON form: `kty` first, then
    /// the variant's fields in declaration order (`alg`, `n`, `e` for RSA;
    /// `crv`, `alg`, `x`, `y` for EC).
    fn from(tpk: &TeePubKey) -> RawValue {
        // One `"key": "value"` entry of the resulting map.
        fn pair(key: &str, value: &str) -> (RawValue, RawValue) {
            (
                RawValue::String(key.to_string()),
                RawValue::String(value.to_string()),
            )
        }

        let entries = match tpk {
            TeePubKey::RSA { alg, k_mod, k_exp } => vec![
                pair("kty", "RSA"),
                pair("alg", alg),
                pair("n", k_mod),
                pair("e", k_exp),
            ],
            TeePubKey::EC { crv, alg, x, y } => vec![
                pair("kty", "EC"),
                pair("crv", crv),
                pair("alg", alg),
                pair("x", x),
                pair("y", y),
            ],
        };

        RawValue::Map(entries)
    }
}
143
/// Runtime data bound into an [`Attestation`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RuntimeData {
    /// Nonce string.
    /// NOTE(review): presumably the nonce from the preceding [`Challenge`]; confirm.
    pub nonce: String,

    /// Public key of the TEE, serialized under `tee-pubkey`.
    #[serde(rename = "tee-pubkey")]
    pub tee_pubkey: TeePubKey,
}
155
/// Evidence payload of an [`Attestation`].
///
/// Field names are serialized as written (snake_case: `primary_evidence`,
/// `additional_evidence`), unlike the kebab-case keys of the enclosing struct.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CompositeEvidence {
    /// Primary evidence; any JSON value is accepted.
    pub primary_evidence: Value,

    /// Additional evidence as an opaque string.
    pub additional_evidence: String,
}
169
/// Initialization data attached to an [`Attestation`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct InitData {
    /// Format identifier of `body` (opaque string; not validated here).
    pub format: String,

    /// Initialization data content, interpreted according to `format`.
    pub body: String,
}
179
/// Attestation message: optional init data, runtime data and TEE evidence.
///
/// Keys are kebab-case on the wire: `init-data`, `runtime-data`, `tee-evidence`.
#[derive(Clone, Serialize, Deserialize, Debug)]
#[serde(rename_all = "kebab-case")]
pub struct Attestation {
    /// Optional init data. A missing `init-data` key deserializes to `None`.
    /// With no `skip_serializing_if`, `None` is serialized as `null`.
    pub init_data: Option<InitData>,
    /// Runtime data (nonce and TEE public key).
    pub runtime_data: RuntimeData,
    /// Evidence produced by the TEE.
    pub tee_evidence: CompositeEvidence,
}
187
/// Protected header of a [`Response`].
///
/// `alg` and `enc` are explicit fields. Any other JSON keys are captured in
/// `other_fields` (flattened), so they round-trip through (de)serialization.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct ProtectedHeader {
    /// Key-management algorithm identifier.
    pub alg: String,
    /// Content-encryption algorithm identifier.
    pub enc: String,

    /// All remaining header members (e.g. an `epk` object), emitted after `alg`/`enc`.
    #[serde(skip_serializing_if = "Map::is_empty", flatten)]
    pub other_fields: Map<String, Value>,
}
199
200impl ProtectedHeader {
201 pub fn generate_aad(&self) -> Result<Vec<u8>> {
203 let protected_utf8 = serde_json::to_string(&self).map_err(|_| KbsTypesError::Serde)?;
204 let aad = BASE64_URL_SAFE_NO_PAD.encode(protected_utf8);
205 Ok(aad.into_bytes())
206 }
207}
208
209fn serialize_base64_protected_header<S>(
210 sub: &ProtectedHeader,
211 serializer: S,
212) -> core::result::Result<S::Ok, S::Error>
213where
214 S: serde::Serializer,
215{
216 let protected_header_json = serde_json::to_string(sub).map_err(serde::ser::Error::custom)?;
217 let encoded = BASE64_URL_SAFE_NO_PAD.encode(protected_header_json);
218 serializer.serialize_str(&encoded)
219}
220
221fn deserialize_base64_protected_header<'de, D>(
222 deserializer: D,
223) -> core::result::Result<ProtectedHeader, D::Error>
224where
225 D: serde::Deserializer<'de>,
226{
227 let encoded = String::deserialize(deserializer)?;
228 let decoded = BASE64_URL_SAFE_NO_PAD
229 .decode(encoded)
230 .map_err(serde::de::Error::custom)?;
231 let protected_header = serde_json::from_slice(&decoded).map_err(serde::de::Error::custom)?;
232
233 Ok(protected_header)
234}
235
236fn serialize_base64<S>(sub: &Vec<u8>, serializer: S) -> core::result::Result<S::Ok, S::Error>
237where
238 S: serde::Serializer,
239{
240 let encoded = BASE64_URL_SAFE_NO_PAD.encode(sub);
241 serializer.serialize_str(&encoded)
242}
243
244fn deserialize_base64<'de, D>(deserializer: D) -> core::result::Result<Vec<u8>, D::Error>
245where
246 D: serde::Deserializer<'de>,
247{
248 let encoded = String::deserialize(deserializer)?;
249 let decoded = BASE64_URL_SAFE_NO_PAD
250 .decode(encoded)
251 .map_err(serde::de::Error::custom)?;
252
253 Ok(decoded)
254}
255
256fn serialize_base64_vec<S>(
257 sub: &Option<Vec<u8>>,
258 serializer: S,
259) -> core::result::Result<S::Ok, S::Error>
260where
261 S: serde::Serializer,
262{
263 match sub {
264 Some(value) => {
265 let encoded = String::from_utf8(value.clone()).map_err(serde::ser::Error::custom)?;
266 serializer.serialize_str(&encoded)
267 }
268 None => serializer.serialize_none(),
269 }
270}
271
272fn deserialize_base64_vec<'de, D>(
273 deserializer: D,
274) -> core::result::Result<Option<Vec<u8>>, D::Error>
275where
276 D: serde::Deserializer<'de>,
277{
278 let string = String::deserialize(deserializer)?;
279 let bytes = string.into_bytes();
280
281 Ok(Some(bytes))
282}
283
/// Encrypted resource response with a JWE-like field layout.
///
/// Field names are serialized as written (snake_case). `protected`,
/// `encrypted_key`, `iv`, `ciphertext` and `tag` travel as unpadded base64url
/// strings. `aad` is carried verbatim as a UTF-8 string, with no base64
/// applied despite the `*_base64_vec` helper names. It is omitted from output
/// when `None` and defaults to `None` when absent from input.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Response {
    /// Protected header; on the wire, unpadded base64url of its JSON form.
    #[serde(
        serialize_with = "serialize_base64_protected_header",
        deserialize_with = "deserialize_base64_protected_header"
    )]
    pub protected: ProtectedHeader,

    /// Encrypted key bytes (unpadded base64url on the wire).
    #[serde(
        serialize_with = "serialize_base64",
        deserialize_with = "deserialize_base64"
    )]
    pub encrypted_key: Vec<u8>,

    /// Optional additional authenticated data, carried as a raw UTF-8 string.
    #[serde(
        skip_serializing_if = "Option::is_none",
        default = "Option::default",
        serialize_with = "serialize_base64_vec",
        deserialize_with = "deserialize_base64_vec"
    )]
    pub aad: Option<Vec<u8>>,

    /// Initialization vector bytes (unpadded base64url on the wire).
    #[serde(
        serialize_with = "serialize_base64",
        deserialize_with = "deserialize_base64"
    )]
    pub iv: Vec<u8>,

    /// Ciphertext bytes (unpadded base64url on the wire).
    #[serde(
        serialize_with = "serialize_base64",
        deserialize_with = "deserialize_base64"
    )]
    pub ciphertext: Vec<u8>,

    /// Authentication tag bytes (unpadded base64url on the wire).
    #[serde(
        serialize_with = "serialize_base64",
        deserialize_with = "deserialize_base64"
    )]
    pub tag: Vec<u8>,
}
324
/// Error payload with a problem type and a human-readable detail.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct ErrorInformation {
    /// Problem type, serialized under the JSON key `type` (a Rust keyword).
    #[serde(rename = "type")]
    pub error_type: String,
    /// Human-readable detail message.
    pub detail: String,
}
331
#[cfg(test)]
mod tests {
    use serde_json::json;

    use crate::*;

    // `ToString` is not in the `no_std` prelude; import it for `.to_string()`.
    #[cfg(all(feature = "alloc", not(feature = "std")))]
    use alloc::string::ToString;

    /// `Request` parses its kebab-case `extra-params` key and a lowercase TEE name.
    #[test]
    fn parse_request() {
        let data = r#"
        {
            "version": "0.0.0",
            "tee": "tdx",
            "extra-params": ""
        }"#;

        let request: Request = serde_json::from_str(data).unwrap();

        assert_eq!(request.version, "0.0.0");
        assert_eq!(request.tee, Tee::Tdx);
        assert_eq!(request.extra_params, "");
    }

    /// Pins the wire name of every `Tee` variant in both directions, so a
    /// variant rename or a `rename_all` change cannot silently break peers.
    #[test]
    fn tee_wire_names() {
        let cases = [
            (Tee::AzSnpVtpm, "az-snp-vtpm"),
            (Tee::AzTdxVtpm, "az-tdx-vtpm"),
            (Tee::Nvidia, "nvidia"),
            (Tee::Sgx, "sgx"),
            (Tee::Snp, "snp"),
            (Tee::Tdx, "tdx"),
            (Tee::Cca, "cca"),
            (Tee::Csv, "csv"),
            (Tee::Se, "se"),
            (Tee::HygonDcu, "hygondcu"),
            (Tee::Tpm, "tpm"),
            (Tee::Sample, "sample"),
            (Tee::SampleDevice, "sampledevice"),
        ];

        for (tee, name) in cases {
            assert_eq!(serde_json::to_value(tee).unwrap(), name);
            let parsed: Tee = serde_json::from_value(name.into()).unwrap();
            assert_eq!(parsed, tee);
        }
    }

    /// `Challenge` parses its nonce and kebab-case `extra-params` key.
    #[test]
    fn parse_challenge() {
        let data = r#"
        {
            "nonce": "42",
            "extra-params": ""
        }"#;

        let challenge: Challenge = serde_json::from_str(data).unwrap();

        assert_eq!(challenge.nonce, "42");
        assert_eq!(challenge.extra_params, "");
    }

    /// The AAD is the unpadded base64url of the header's JSON serialization.
    #[test]
    fn protected_header_generate_aad() {
        let protected_header = ProtectedHeader {
            alg: "fakealg".to_string(),
            enc: "fakeenc".to_string(),
            other_fields: Map::new(),
        };

        let aad = protected_header.generate_aad().unwrap();

        assert_eq!(
            aad,
            "eyJhbGciOiJmYWtlYWxnIiwiZW5jIjoiZmFrZWVuYyJ9".as_bytes()
        );
    }

    /// Base64url-encoded fields decode to their raw bytes; a missing `aad` is `None`.
    #[test]
    fn parse_response() {
        let data = r#"
        {
            "protected": "eyJhbGciOiJmYWtlYWxnIiwiZW5jIjoiZmFrZWVuYyJ9",
            "encrypted_key": "ZmFrZWtleQ",
            "iv": "cmFuZG9tZGF0YQ",
            "ciphertext": "ZmFrZWVuY291dHB1dA",
            "tag": "ZmFrZXRhZw"
        }"#;

        let response: Response = serde_json::from_str(data).unwrap();

        assert_eq!(response.protected.alg, "fakealg");
        assert_eq!(response.protected.enc, "fakeenc");
        assert!(response.protected.other_fields.is_empty());
        assert_eq!(response.encrypted_key, "fakekey".as_bytes());
        assert_eq!(response.iv, "randomdata".as_bytes());
        assert_eq!(response.ciphertext, "fakeencoutput".as_bytes());
        assert_eq!(response.tag, "faketag".as_bytes());
        assert_eq!(response.aad, None);
    }

    /// Nested objects in the protected header (here `epk`) land in `other_fields`.
    #[test]
    fn parse_response_nested_protected_header() {
        let data = r#"
        {
            "protected": "eyJhbGciOiJmYWtlYWxnIiwiZW5jIjoiZmFrZWVuYyIsImVwayI6eyJrdHkiOiJPS1AiLCJjcnYiOiJYMjU1MTkiLCJ4IjoiaFNEd0NZa3dwMVIwaTMzY3RENzNXZzJfT2cwbU9CcjA2NlNwanFxYlRtbyJ9fQo",
            "encrypted_key": "ZmFrZWtleQ",
            "iv": "cmFuZG9tZGF0YQ",
            "ciphertext": "ZmFrZWVuY291dHB1dA",
            "tag": "ZmFrZXRhZw"
        }"#;

        let response: Response = serde_json::from_str(data).unwrap();

        assert_eq!(response.protected.alg, "fakealg");
        assert_eq!(response.protected.enc, "fakeenc");

        let expected_other_fields = json!({
            "epk": {
                "kty" : "OKP",
                "crv": "X25519",
                "x": "hSDwCYkwp1R0i33ctD73Wg2_Og0mOBr066SpjqqbTmo"
            }
        })
        .as_object()
        .unwrap()
        .clone();

        assert_eq!(response.protected.other_fields, expected_other_fields);
        assert_eq!(response.encrypted_key, "fakekey".as_bytes());
        assert_eq!(response.iv, "randomdata".as_bytes());
        assert_eq!(response.ciphertext, "fakeencoutput".as_bytes());
        assert_eq!(response.tag, "faketag".as_bytes());
        assert_eq!(response.aad, None);
    }

    /// `aad` is taken verbatim (no base64 decoding). The protected header here
    /// decodes to JSON with a trailing newline, which must still parse.
    #[test]
    fn parse_response_with_aad() {
        let data = r#"
        {
            "protected": "eyJhbGciOiJmYWtlYWxnIiwiZW5jIjoiZmFrZWVuYyJ9Cg",
            "encrypted_key": "ZmFrZWtleQ",
            "iv": "cmFuZG9tZGF0YQ",
            "aad": "fakeaad",
            "ciphertext": "ZmFrZWVuY291dHB1dA",
            "tag": "ZmFrZXRhZw"
        }"#;

        let response: Response = serde_json::from_str(data).unwrap();

        assert_eq!(response.protected.alg, "fakealg");
        assert_eq!(response.protected.enc, "fakeenc");
        assert!(response.protected.other_fields.is_empty());
        assert_eq!(response.encrypted_key, "fakekey".as_bytes());
        assert_eq!(response.iv, "randomdata".as_bytes());
        assert_eq!(response.ciphertext, "fakeencoutput".as_bytes());
        assert_eq!(response.tag, "faketag".as_bytes());
        assert_eq!(response.aad, Some("fakeaad".into()));
    }

    /// Extra scalar header members are captured in `other_fields`.
    #[test]
    fn parse_response_with_protectedheader() {
        let data = r#"
        {
            "protected": "eyJhbGciOiJmYWtlYWxnIiwiZW5jIjoiZmFrZWVuYyIsImZha2VmaWVsZCI6ImZha2V2YWx1ZSJ9",
            "encrypted_key": "ZmFrZWtleQ",
            "iv": "cmFuZG9tZGF0YQ",
            "aad": "fakeaad",
            "ciphertext": "ZmFrZWVuY291dHB1dA",
            "tag": "ZmFrZXRhZw"
        }"#;

        let response: Response = serde_json::from_str(data).unwrap();

        assert_eq!(response.protected.alg, "fakealg");
        assert_eq!(response.protected.enc, "fakeenc");
        assert_eq!(response.protected.other_fields["fakefield"], "fakevalue");
        assert_eq!(response.encrypted_key, "fakekey".as_bytes());
        assert_eq!(response.iv, "randomdata".as_bytes());
        assert_eq!(response.ciphertext, "fakeencoutput".as_bytes());
        assert_eq!(response.tag, "faketag".as_bytes());
        assert_eq!(response.aad, Some("fakeaad".into()));
    }

    /// Serialization produces the exact expected wire form.
    #[test]
    fn serialize_response() {
        let response = Response {
            protected: ProtectedHeader {
                alg: "fakealg".into(),
                enc: "fakeenc".into(),
                other_fields: [("fakefield".into(), "fakevalue".into())]
                    .into_iter()
                    .collect(),
            },
            encrypted_key: "fakekey".as_bytes().to_vec(),
            iv: "randomdata".as_bytes().to_vec(),
            aad: Some("fakeaad".into()),
            tag: "faketag".as_bytes().to_vec(),
            ciphertext: "fakeencoutput".as_bytes().to_vec(),
        };

        let expected = json!({
            "protected": "eyJhbGciOiJmYWtlYWxnIiwiZW5jIjoiZmFrZWVuYyIsImZha2VmaWVsZCI6ImZha2V2YWx1ZSJ9",
            "encrypted_key": "ZmFrZWtleQ",
            "iv": "cmFuZG9tZGF0YQ",
            "aad": "fakeaad",
            "ciphertext": "ZmFrZWVuY291dHB1dA",
            "tag": "ZmFrZXRhZw"
        });

        let serialized = serde_json::to_value(&response).unwrap();
        assert_eq!(serialized, expected);
    }

    /// Serialize -> deserialize reproduces every field. The serialized
    /// `protected` string equals the AAD computed from the same header.
    #[test]
    fn response_round_trip() {
        let response = Response {
            protected: ProtectedHeader {
                alg: "fakealg".into(),
                enc: "fakeenc".into(),
                other_fields: [("fakefield".into(), "fakevalue".into())]
                    .into_iter()
                    .collect(),
            },
            encrypted_key: "fakekey".as_bytes().to_vec(),
            iv: "randomdata".as_bytes().to_vec(),
            aad: Some("fakeaad".into()),
            tag: "faketag".as_bytes().to_vec(),
            ciphertext: "fakeencoutput".as_bytes().to_vec(),
        };

        let serialized = serde_json::to_value(&response).unwrap();

        let aad = response.protected.generate_aad().unwrap();
        assert_eq!(
            serialized["protected"].as_str().unwrap().as_bytes(),
            aad.as_slice()
        );

        let parsed: Response = serde_json::from_value(serialized).unwrap();
        assert_eq!(parsed.protected.alg, response.protected.alg);
        assert_eq!(parsed.protected.enc, response.protected.enc);
        assert_eq!(
            parsed.protected.other_fields,
            response.protected.other_fields
        );
        assert_eq!(parsed.encrypted_key, response.encrypted_key);
        assert_eq!(parsed.iv, response.iv);
        assert_eq!(parsed.aad, response.aad);
        assert_eq!(parsed.ciphertext, response.ciphertext);
        assert_eq!(parsed.tag, response.tag);
    }

    /// EC-keyed attestation. `init-data` is absent and must parse as `None`.
    #[test]
    fn parse_attestation_ec() {
        let data = r#"
        {
            "runtime-data": {
                "nonce": "test_nonce",
                "tee-pubkey": {
                    "kty": "EC",
                    "crv": "fakecrv",
                    "alg": "fakealgorithm",
                    "x": "fakex",
                    "y": "fakey"
                }
            },
            "tee-evidence": {
                "primary_evidence": "test_primary_evidence",
                "additional_evidence": "test_additional_evidence"
            }
        }"#;

        let attestation: Attestation = serde_json::from_str(data).unwrap();
        let tee_pubkey = attestation.runtime_data.tee_pubkey;

        let TeePubKey::EC { alg, crv, x, y } = tee_pubkey else {
            panic!("Must be an EC key");
        };

        assert!(attestation.init_data.is_none());
        assert_eq!(attestation.runtime_data.nonce, "test_nonce");
        assert_eq!(alg, "fakealgorithm");
        assert_eq!(crv, "fakecrv");
        assert_eq!(x, "fakex");
        assert_eq!(y, "fakey");
        assert_eq!(
            attestation.tee_evidence.primary_evidence,
            "test_primary_evidence"
        );
        assert_eq!(
            attestation.tee_evidence.additional_evidence,
            "test_additional_evidence"
        );
    }

    /// RSA-keyed attestation: `n`/`e` map to `k_mod`/`k_exp`.
    #[test]
    fn parse_attestation_rsa() {
        let data = r#"
        {
            "runtime-data": {
                "nonce": "test_nonce",
                "tee-pubkey": {
                    "kty": "RSA",
                    "alg": "fakealgorithm",
                    "n": "fakemodulus",
                    "e": "fakeexponent"
                }
            },
            "tee-evidence": {
                "primary_evidence": "test_primary_evidence",
                "additional_evidence": "test_additional_evidence"
            }
        }"#;

        let attestation: Attestation = serde_json::from_str(data).unwrap();
        let tee_pubkey = attestation.runtime_data.tee_pubkey;

        let TeePubKey::RSA { alg, k_mod, k_exp } = tee_pubkey else {
            panic!("Must be a RSA key");
        };

        assert_eq!(attestation.runtime_data.nonce, "test_nonce");
        assert_eq!(alg, "fakealgorithm");
        assert_eq!(k_mod, "fakemodulus");
        assert_eq!(k_exp, "fakeexponent");
        assert_eq!(
            attestation.tee_evidence.primary_evidence,
            "test_primary_evidence"
        );
        assert_eq!(
            attestation.tee_evidence.additional_evidence,
            "test_additional_evidence"
        );
    }

    /// `type` maps to `error_type`.
    #[test]
    fn parse_error_information() {
        let data = r#"
        {
            "type": "problemtype",
            "detail": "problemdetail"
        }"#;

        let info: ErrorInformation = serde_json::from_str(data).unwrap();

        assert_eq!(info.error_type, "problemtype");
        assert_eq!(info.detail, "problemdetail");
    }

    /// The EAR `RawValue` conversion must serialize to exactly the same JSON
    /// text as the key's own serde form (same keys, same order), for both variants.
    #[test]
    #[cfg(feature = "std")]
    fn tee_pubkey_ear_json_deserialize() {
        let tpk = TeePubKey::RSA {
            alg: "test".to_string(),
            k_mod: "test".to_string(),
            k_exp: "test".to_string(),
        };
        let ear_raw: RawValue = (&tpk).into();
        let json_str = serde_json::to_string(&ear_raw).unwrap();
        assert_eq!(json_str, serde_json::to_string(&tpk).unwrap());

        let tpk = TeePubKey::EC {
            crv: "test".to_string(),
            alg: "test".to_string(),
            x: "test".to_string(),
            y: "test".to_string(),
        };
        let ear_raw: RawValue = (&tpk).into();
        let json_str = serde_json::to_string(&ear_raw).unwrap();
        assert_eq!(json_str, serde_json::to_string(&tpk).unwrap());
    }
}