1#![cfg_attr(not(feature = "std"), no_std)]
3
4#[cfg(feature = "alloc")]
7extern crate alloc;
8
9mod error;
10mod hash_algorithm;
11
12pub use error::{KbsTypesError, Result};
13pub use hash_algorithm::HashAlgorithm;
14
15#[cfg(all(feature = "alloc", not(feature = "std")))]
16use alloc::{string::String, vec::Vec};
17use base64::{prelude::BASE64_URL_SAFE_NO_PAD, Engine};
18#[cfg(feature = "std")]
19use ear::{self, RawValue};
20use serde::{Deserialize, Serialize};
21use serde_json::{Map, Value};
22#[cfg(all(feature = "std", not(feature = "alloc")))]
23use std::string::String;
24
/// Type of Trusted Execution Environment, as carried in [`Request::tee`].
///
/// On the wire each variant is the lowercase form of its name (for example
/// `Sev` -> `"sev"`, `HygonDcu` -> `"hygondcu"`, `SampleDevice` ->
/// `"sampledevice"`). The exceptions are the two Azure vTPM variants, which
/// carry explicit kebab-case names.
#[derive(Serialize, Clone, Copy, Deserialize, Debug, Eq, Hash, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Tee {
    /// Serialized as `"az-snp-vtpm"`.
    // NOTE(review): presumably an Azure SEV-SNP VM with vTPM — confirm.
    #[serde(rename = "az-snp-vtpm")]
    AzSnpVtpm,
    /// Serialized as `"az-tdx-vtpm"`.
    // NOTE(review): presumably an Azure TDX VM with vTPM — confirm.
    #[serde(rename = "az-tdx-vtpm")]
    AzTdxVtpm,
    Nvidia,
    Sev,
    Sgx,
    Snp,
    Tdx,
    // NOTE(review): likely Arm Confidential Compute Architecture — confirm.
    Cca,
    // NOTE(review): likely China Secure Virtualization — confirm.
    Csv,
    // NOTE(review): likely IBM Z Secure Execution — confirm.
    Se,

    /// Serialized as `"hygondcu"`.
    HygonDcu,

    Tpm,

    // NOTE(review): the `Sample*` variants look intended for exercising an
    // attestation server in tests rather than real deployments — confirm.
    /// Serialized as `"sample"`.
    Sample,
    /// Serialized as `"sampledevice"`.
    SampleDevice,
}
56
57#[derive(Clone, Serialize, Deserialize, Debug)]
58pub struct Request {
59 pub version: String,
60 pub tee: Tee,
61 #[serde(rename = "extra-params")]
62 pub extra_params: Value,
63}
64
65#[derive(Clone, Serialize, Deserialize, Debug)]
66pub struct Challenge {
67 pub nonce: String,
68 #[serde(rename = "extra-params")]
69 pub extra_params: Value,
70}
71
/// Public key reported in [`RuntimeData::tee_pubkey`].
///
/// Internally tagged by the `kty` member (`"RSA"` or `"EC"`). Member names
/// follow JSON Web Key conventions (`kty`, `alg`, `n`, `e`, `crv`, `x`, `y`).
/// All values are kept as opaque strings; nothing here decodes or validates
/// them.
// NOTE(review): values are presumably base64url as in JWK — not checked here.
#[derive(Clone, Serialize, Deserialize, Debug)]
#[serde(tag = "kty")]
pub enum TeePubKey {
    /// RSA public key (`"kty": "RSA"`).
    RSA {
        /// Algorithm identifier.
        alg: String,
        /// Modulus, serialized as `n`.
        #[serde(rename = "n")]
        k_mod: String,
        /// Public exponent, serialized as `e`.
        #[serde(rename = "e")]
        k_exp: String,
    },
    /// Elliptic-curve public key (`"kty": "EC"`).
    EC {
        /// Curve name.
        crv: String,
        /// Algorithm identifier.
        alg: String,
        /// X coordinate.
        x: String,
        /// Y coordinate.
        y: String,
    },
}
92
#[cfg(feature = "std")]
impl From<&TeePubKey> for ear::RawValue {
    /// Converts a [`TeePubKey`] into an EAR map of string pairs.
    ///
    /// The map starts with `kty` and then lists the variant's members under
    /// their JWK names, in the order used by the key's serde representation,
    /// so both serialize to the same JSON text.
    fn from(tpk: &TeePubKey) -> RawValue {
        // One `"name": "value"` entry of the resulting map.
        fn pair(name: &str, value: &str) -> (RawValue, RawValue) {
            (
                RawValue::String(name.to_owned()),
                RawValue::String(value.to_owned()),
            )
        }

        let entries = match tpk {
            TeePubKey::RSA { alg, k_mod, k_exp } => vec![
                pair("kty", "RSA"),
                pair("alg", alg),
                pair("n", k_mod),
                pair("e", k_exp),
            ],
            TeePubKey::EC { crv, alg, x, y } => vec![
                pair("kty", "EC"),
                pair("crv", crv),
                pair("alg", alg),
                pair("x", x),
                pair("y", y),
            ],
        };

        RawValue::Map(entries)
    }
}
144
145#[derive(Clone, Debug, Deserialize, Serialize)]
148pub struct RuntimeData {
149 pub nonce: String,
151
152 #[serde(rename = "tee-pubkey")]
154 pub tee_pubkey: TeePubKey,
155}
156
/// Evidence section of an [`Attestation`] (wire name `tee-evidence`).
///
/// Unlike most types in this file, field names stay snake_case on the wire
/// (`primary_evidence`, `additional_evidence`).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CompositeEvidence {
    /// Primary evidence, kept as arbitrary JSON.
    pub primary_evidence: Value,

    /// Additional evidence, carried as an opaque string.
    // NOTE(review): presumably a serialized JSON document for further
    // evidence sources — confirm with producers.
    pub additional_evidence: String,
}
170
/// Init data optionally attached to an [`Attestation`] (wire name `init-data`).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct InitData {
    /// Identifier of the format of `body`; not interpreted by this crate.
    pub format: String,

    /// Init data payload, kept as an opaque string.
    pub body: String,
}
180
/// Attestation message: optional init data, runtime data and TEE evidence.
///
/// Field names are kebab-case on the wire: `init-data`, `runtime-data`,
/// `tee-evidence`.
#[derive(Clone, Serialize, Deserialize, Debug)]
#[serde(rename_all = "kebab-case")]
pub struct Attestation {
    /// Optional init data; a missing `init-data` member deserializes to `None`.
    // NOTE(review): with no `skip_serializing_if`, `None` is serialized as
    // `"init-data": null` — confirm peers accept that.
    pub init_data: Option<InitData>,
    /// Nonce and TEE public key.
    pub runtime_data: RuntimeData,
    /// Evidence gathered from the TEE.
    pub tee_evidence: CompositeEvidence,
}
188
/// Protected header of a [`Response`].
///
/// `alg` and `enc` are mandatory. Every other member (for example `epk`) is
/// collected into `other_fields` through `#[serde(flatten)]`. When there are
/// no such members, nothing extra is serialized.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct ProtectedHeader {
    /// Algorithm identifier (`alg`; in JWE this names the key-management
    /// algorithm).
    pub alg: String,
    /// Encryption identifier (`enc`; in JWE this names the content-encryption
    /// algorithm).
    pub enc: String,

    /// Header members other than `alg` and `enc`, kept as raw JSON.
    #[serde(skip_serializing_if = "Map::is_empty", flatten)]
    pub other_fields: Map<String, Value>,
}
200
201impl ProtectedHeader {
202 pub fn generate_aad(&self) -> Result<Vec<u8>> {
204 let protected_utf8 = serde_json::to_string(&self).map_err(|_| KbsTypesError::Serde)?;
205 let aad = BASE64_URL_SAFE_NO_PAD.encode(protected_utf8);
206 Ok(aad.into_bytes())
207 }
208}
209
210fn serialize_base64_protected_header<S>(
211 sub: &ProtectedHeader,
212 serializer: S,
213) -> core::result::Result<S::Ok, S::Error>
214where
215 S: serde::Serializer,
216{
217 let protected_header_json = serde_json::to_string(sub).map_err(serde::ser::Error::custom)?;
218 let encoded = BASE64_URL_SAFE_NO_PAD.encode(protected_header_json);
219 serializer.serialize_str(&encoded)
220}
221
222fn deserialize_base64_protected_header<'de, D>(
223 deserializer: D,
224) -> core::result::Result<ProtectedHeader, D::Error>
225where
226 D: serde::Deserializer<'de>,
227{
228 let encoded = String::deserialize(deserializer)?;
229 let decoded = BASE64_URL_SAFE_NO_PAD
230 .decode(encoded)
231 .map_err(serde::de::Error::custom)?;
232 let protected_header = serde_json::from_slice(&decoded).map_err(serde::de::Error::custom)?;
233
234 Ok(protected_header)
235}
236
237fn serialize_base64<S>(sub: &Vec<u8>, serializer: S) -> core::result::Result<S::Ok, S::Error>
238where
239 S: serde::Serializer,
240{
241 let encoded = BASE64_URL_SAFE_NO_PAD.encode(sub);
242 serializer.serialize_str(&encoded)
243}
244
245fn deserialize_base64<'de, D>(deserializer: D) -> core::result::Result<Vec<u8>, D::Error>
246where
247 D: serde::Deserializer<'de>,
248{
249 let encoded = String::deserialize(deserializer)?;
250 let decoded = BASE64_URL_SAFE_NO_PAD
251 .decode(encoded)
252 .map_err(serde::de::Error::custom)?;
253
254 Ok(decoded)
255}
256
257fn serialize_base64_vec<S>(
258 sub: &Option<Vec<u8>>,
259 serializer: S,
260) -> core::result::Result<S::Ok, S::Error>
261where
262 S: serde::Serializer,
263{
264 match sub {
265 Some(value) => {
266 let encoded = String::from_utf8(value.clone()).map_err(serde::ser::Error::custom)?;
267 serializer.serialize_str(&encoded)
268 }
269 None => serializer.serialize_none(),
270 }
271}
272
273fn deserialize_base64_vec<'de, D>(
274 deserializer: D,
275) -> core::result::Result<Option<Vec<u8>>, D::Error>
276where
277 D: serde::Deserializer<'de>,
278{
279 let string = String::deserialize(deserializer)?;
280 let bytes = string.into_bytes();
281
282 Ok(Some(bytes))
283}
284
/// Encrypted response, laid out like a JWE JSON serialization.
///
/// Each binary member (`encrypted_key`, `iv`, `ciphertext`, `tag`) is a
/// base64url string without padding on the wire. `protected` is the base64url
/// encoding of the JSON-serialized [`ProtectedHeader`]. `aad` is the exception:
/// its bytes are written verbatim as a UTF-8 string, not base64-encoded, and
/// the member is omitted when `None`.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Response {
    /// Protected header, base64url-encoded JSON on the wire.
    #[serde(
        serialize_with = "serialize_base64_protected_header",
        deserialize_with = "deserialize_base64_protected_header"
    )]
    pub protected: ProtectedHeader,

    /// Encrypted key bytes.
    #[serde(
        serialize_with = "serialize_base64",
        deserialize_with = "deserialize_base64"
    )]
    pub encrypted_key: Vec<u8>,

    /// Optional additional authenticated data, carried as a raw string.
    // `default` makes a missing member deserialize to `None`, and
    // `skip_serializing_if` keeps `None` out of the output entirely.
    #[serde(
        skip_serializing_if = "Option::is_none",
        default = "Option::default",
        serialize_with = "serialize_base64_vec",
        deserialize_with = "deserialize_base64_vec"
    )]
    pub aad: Option<Vec<u8>>,

    /// Initialization vector bytes.
    #[serde(
        serialize_with = "serialize_base64",
        deserialize_with = "deserialize_base64"
    )]
    pub iv: Vec<u8>,

    /// Ciphertext bytes.
    #[serde(
        serialize_with = "serialize_base64",
        deserialize_with = "deserialize_base64"
    )]
    pub ciphertext: Vec<u8>,

    /// Authentication tag bytes.
    #[serde(
        serialize_with = "serialize_base64",
        deserialize_with = "deserialize_base64"
    )]
    pub tag: Vec<u8>,
}
325
/// Error payload consisting of a `type` identifier and a `detail` string.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct ErrorInformation {
    /// Error type identifier; serialized as `type` (a Rust keyword, hence the
    /// rename).
    #[serde(rename = "type")]
    pub error_type: String,
    /// Free-form description of the error.
    pub detail: String,
}
332
#[cfg(test)]
mod tests {
    // Unit tests pinning the JSON wire format of every message type.
    use serde_json::json;

    use crate::*;

    #[cfg(all(feature = "alloc", not(feature = "std")))]
    use alloc::string::ToString;

    #[test]
    fn parse_request() {
        let data = r#"
        {
            "version": "0.0.0",
            "tee": "sev",
            "extra-params": ""
        }"#;

        let request: Request = serde_json::from_str(data).unwrap();

        assert_eq!(request.version, "0.0.0");
        assert_eq!(request.tee, Tee::Sev);
        assert_eq!(request.extra_params, "");
    }

    #[test]
    fn parse_challenge() {
        let data = r#"
        {
            "nonce": "42",
            "extra-params": ""
        }"#;

        let challenge: Challenge = serde_json::from_str(data).unwrap();

        assert_eq!(challenge.nonce, "42");
        assert_eq!(challenge.extra_params, "");
    }

    #[test]
    fn protected_header_generate_aad() {
        let protected_header = ProtectedHeader {
            alg: "fakealg".to_string(),
            enc: "fakeenc".to_string(),
            other_fields: Map::new(),
        };

        let aad = protected_header.generate_aad().unwrap();

        // base64url(no pad) of `{"alg":"fakealg","enc":"fakeenc"}`.
        assert_eq!(
            aad,
            "eyJhbGciOiJmYWtlYWxnIiwiZW5jIjoiZmFrZWVuYyJ9".as_bytes()
        );
    }

    #[test]
    fn parse_response() {
        let data = r#"
        {
            "protected": "eyJhbGciOiJmYWtlYWxnIiwiZW5jIjoiZmFrZWVuYyJ9",
            "encrypted_key": "ZmFrZWtleQ",
            "iv": "cmFuZG9tZGF0YQ",
            "ciphertext": "ZmFrZWVuY291dHB1dA",
            "tag": "ZmFrZXRhZw"
        }"#;

        let response: Response = serde_json::from_str(data).unwrap();

        assert_eq!(response.protected.alg, "fakealg");
        assert_eq!(response.protected.enc, "fakeenc");
        assert!(response.protected.other_fields.is_empty());
        assert_eq!(response.encrypted_key, "fakekey".as_bytes());
        assert_eq!(response.iv, "randomdata".as_bytes());
        assert_eq!(response.ciphertext, "fakeencoutput".as_bytes());
        assert_eq!(response.tag, "faketag".as_bytes());
        assert_eq!(response.aad, None);
    }

    #[test]
    fn parse_response_nested_protected_header() {
        // The protected header carries a nested `epk` object, which must land
        // in `other_fields` unchanged.
        let data = r#"
        {
            "protected": "eyJhbGciOiJmYWtlYWxnIiwiZW5jIjoiZmFrZWVuYyIsImVwayI6eyJrdHkiOiJPS1AiLCJjcnYiOiJYMjU1MTkiLCJ4IjoiaFNEd0NZa3dwMVIwaTMzY3RENzNXZzJfT2cwbU9CcjA2NlNwanFxYlRtbyJ9fQo",
            "encrypted_key": "ZmFrZWtleQ",
            "iv": "cmFuZG9tZGF0YQ",
            "ciphertext": "ZmFrZWVuY291dHB1dA",
            "tag": "ZmFrZXRhZw"
        }"#;

        let response: Response = serde_json::from_str(data).unwrap();

        assert_eq!(response.protected.alg, "fakealg");
        assert_eq!(response.protected.enc, "fakeenc");

        let expected_other_fields = json!({
            "epk": {
                "kty" : "OKP",
                "crv": "X25519",
                "x": "hSDwCYkwp1R0i33ctD73Wg2_Og0mOBr066SpjqqbTmo"
            }
        })
        .as_object()
        .unwrap()
        .clone();

        assert_eq!(response.protected.other_fields, expected_other_fields);
        assert_eq!(response.encrypted_key, "fakekey".as_bytes());
        assert_eq!(response.iv, "randomdata".as_bytes());
        assert_eq!(response.ciphertext, "fakeencoutput".as_bytes());
        assert_eq!(response.tag, "faketag".as_bytes());
        assert_eq!(response.aad, None);
    }

    #[test]
    fn parse_response_with_aad() {
        let data = r#"
        {
            "protected": "eyJhbGciOiJmYWtlYWxnIiwiZW5jIjoiZmFrZWVuYyJ9Cg",
            "encrypted_key": "ZmFrZWtleQ",
            "iv": "cmFuZG9tZGF0YQ",
            "aad": "fakeaad",
            "ciphertext": "ZmFrZWVuY291dHB1dA",
            "tag": "ZmFrZXRhZw"
        }"#;

        let response: Response = serde_json::from_str(data).unwrap();

        assert_eq!(response.protected.alg, "fakealg");
        assert_eq!(response.protected.enc, "fakeenc");
        assert!(response.protected.other_fields.is_empty());
        assert_eq!(response.encrypted_key, "fakekey".as_bytes());
        assert_eq!(response.iv, "randomdata".as_bytes());
        assert_eq!(response.ciphertext, "fakeencoutput".as_bytes());
        assert_eq!(response.tag, "faketag".as_bytes());
        // `aad` is taken verbatim, not base64-decoded.
        assert_eq!(response.aad, Some("fakeaad".into()));
    }

    #[test]
    fn parse_response_with_protectedheader() {
        let data = r#"
        {
            "protected": "eyJhbGciOiJmYWtlYWxnIiwiZW5jIjoiZmFrZWVuYyIsImZha2VmaWVsZCI6ImZha2V2YWx1ZSJ9",
            "encrypted_key": "ZmFrZWtleQ",
            "iv": "cmFuZG9tZGF0YQ",
            "aad": "fakeaad",
            "ciphertext": "ZmFrZWVuY291dHB1dA",
            "tag": "ZmFrZXRhZw"
        }"#;

        let response: Response = serde_json::from_str(data).unwrap();

        assert_eq!(response.protected.alg, "fakealg");
        assert_eq!(response.protected.enc, "fakeenc");
        assert_eq!(response.protected.other_fields["fakefield"], "fakevalue");
        assert_eq!(response.encrypted_key, "fakekey".as_bytes());
        assert_eq!(response.iv, "randomdata".as_bytes());
        assert_eq!(response.ciphertext, "fakeencoutput".as_bytes());
        assert_eq!(response.tag, "faketag".as_bytes());
        assert_eq!(response.aad, Some("fakeaad".into()));
    }

    #[test]
    fn serialize_response() {
        let response = Response {
            protected: ProtectedHeader {
                alg: "fakealg".into(),
                enc: "fakeenc".into(),
                other_fields: [("fakefield".into(), "fakevalue".into())]
                    .into_iter()
                    .collect(),
            },
            encrypted_key: "fakekey".as_bytes().to_vec(),
            iv: "randomdata".as_bytes().to_vec(),
            aad: Some("fakeaad".into()),
            tag: "faketag".as_bytes().to_vec(),
            ciphertext: "fakeencoutput".as_bytes().to_vec(),
        };

        let expected = json!({
            "protected": "eyJhbGciOiJmYWtlYWxnIiwiZW5jIjoiZmFrZWVuYyIsImZha2VmaWVsZCI6ImZha2V2YWx1ZSJ9",
            "encrypted_key": "ZmFrZWtleQ",
            "iv": "cmFuZG9tZGF0YQ",
            "aad": "fakeaad",
            "ciphertext": "ZmFrZWVuY291dHB1dA",
            "tag": "ZmFrZXRhZw"
        });

        let serialized = serde_json::to_value(&response).unwrap();
        assert_eq!(serialized, expected);
    }

    #[test]
    fn response_without_aad_round_trips() {
        let response = Response {
            protected: ProtectedHeader {
                alg: "fakealg".into(),
                enc: "fakeenc".into(),
                other_fields: Map::new(),
            },
            encrypted_key: "fakekey".as_bytes().to_vec(),
            aad: None,
            iv: "randomdata".as_bytes().to_vec(),
            ciphertext: "fakeencoutput".as_bytes().to_vec(),
            tag: "faketag".as_bytes().to_vec(),
        };

        let serialized = serde_json::to_value(&response).unwrap();
        // `None` must be omitted rather than written as `null`.
        assert!(serialized.get("aad").is_none());

        let parsed: Response = serde_json::from_value(serialized).unwrap();
        assert_eq!(parsed.protected.alg, response.protected.alg);
        assert_eq!(parsed.protected.enc, response.protected.enc);
        assert_eq!(parsed.protected.other_fields, response.protected.other_fields);
        assert_eq!(parsed.encrypted_key, response.encrypted_key);
        assert_eq!(parsed.iv, response.iv);
        assert_eq!(parsed.ciphertext, response.ciphertext);
        assert_eq!(parsed.tag, response.tag);
        assert_eq!(parsed.aad, None);
    }

    #[test]
    fn parse_attestation_ec() {
        let data = r#"
        {
            "runtime-data": {
                "nonce": "test_nonce",
                "tee-pubkey": {
                    "kty": "EC",
                    "crv": "fakecrv",
                    "alg": "fakealgorithm",
                    "x": "fakex",
                    "y": "fakey"
                }
            },
            "tee-evidence": {
                "primary_evidence": "test_primary_evidence",
                "additional_evidence": "test_additional_evidence"
            }
        }"#;

        let attestation: Attestation = serde_json::from_str(data).unwrap();
        let tee_pubkey = attestation.runtime_data.tee_pubkey;

        let TeePubKey::EC { alg, crv, x, y } = tee_pubkey else {
            panic!("Must be an EC key");
        };

        // `init-data` is absent from the input, so it must default to `None`.
        assert!(attestation.init_data.is_none());
        assert_eq!(attestation.runtime_data.nonce, "test_nonce");
        assert_eq!(alg, "fakealgorithm");
        assert_eq!(crv, "fakecrv");
        assert_eq!(x, "fakex");
        assert_eq!(y, "fakey");
        assert_eq!(
            attestation.tee_evidence.primary_evidence,
            "test_primary_evidence"
        );
        assert_eq!(
            attestation.tee_evidence.additional_evidence,
            "test_additional_evidence"
        );
    }

    #[test]
    fn parse_attestation_rsa() {
        let data = r#"
        {
            "runtime-data": {
                "nonce": "test_nonce",
                "tee-pubkey": {
                    "kty": "RSA",
                    "alg": "fakealgorithm",
                    "n": "fakemodulus",
                    "e": "fakeexponent"
                }
            },
            "tee-evidence": {
                "primary_evidence": "test_primary_evidence",
                "additional_evidence": "test_additional_evidence"
            }
        }"#;

        let attestation: Attestation = serde_json::from_str(data).unwrap();
        let tee_pubkey = attestation.runtime_data.tee_pubkey;

        let TeePubKey::RSA { alg, k_mod, k_exp } = tee_pubkey else {
            panic!("Must be a RSA key");
        };

        assert_eq!(attestation.runtime_data.nonce, "test_nonce");
        assert_eq!(alg, "fakealgorithm");
        assert_eq!(k_mod, "fakemodulus");
        assert_eq!(k_exp, "fakeexponent");
        assert_eq!(
            attestation.tee_evidence.primary_evidence,
            "test_primary_evidence"
        );
        assert_eq!(
            attestation.tee_evidence.additional_evidence,
            "test_additional_evidence"
        );
    }

    #[test]
    fn parse_error_information() {
        let data = r#"
        {
            "type": "problemtype",
            "detail": "problemdetail"
        }"#;

        let info: ErrorInformation = serde_json::from_str(data).unwrap();

        assert_eq!(info.error_type, "problemtype");
        assert_eq!(info.detail, "problemdetail");
    }

    #[test]
    #[cfg(feature = "std")]
    fn tee_pubkey_ear_json_deserialize() {
        // The EAR map built from a key must serialize to the same JSON as
        // the key's own serde representation.
        let tpk = TeePubKey::RSA {
            alg: "test".to_string(),
            k_mod: "test".to_string(),
            k_exp: "test".to_string(),
        };
        let ear_raw: RawValue = (&tpk).into();
        let json_str = serde_json::to_string(&ear_raw).unwrap();
        assert_eq!(json_str, serde_json::to_string(&tpk).unwrap());

        let tpk = TeePubKey::EC {
            crv: "test".to_string(),
            alg: "test".to_string(),
            x: "test".to_string(),
            y: "test".to_string(),
        };
        let ear_raw: RawValue = (&tpk).into();
        let json_str = serde_json::to_string(&ear_raw).unwrap();
        assert_eq!(json_str, serde_json::to_string(&tpk).unwrap());
    }
}