1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::parse::{Parse, ParseStream};
4use syn::{parse_macro_input, DeriveInput, Ident, ItemFn, Token};
5
6#[proc_macro_attribute]
18pub fn attest(_attr: TokenStream, item: TokenStream) -> TokenStream {
19 let input_fn = parse_macro_input!(item as ItemFn);
20 let fn_name = &input_fn.sig.ident;
21 let fn_vis = &input_fn.vis;
22 let fn_inputs = &input_fn.sig.inputs;
23 let fn_body = &input_fn.block;
24 let fn_attrs = &input_fn.attrs;
25
26 let expanded = quote! {
27 #(#fn_attrs)*
28 #fn_vis async fn #fn_name(
29 ::axum::extract::Extension(tee_state): ::axum::extract::Extension<::std::sync::Arc<TeeState>>,
30 #fn_inputs
31 ) -> impl ::axum::response::IntoResponse {
32 use ::axum::response::IntoResponse;
33 use ::axum::http::header::HeaderValue;
34
35 let request_id = ::uuid::Uuid::new_v4().to_string();
37
38 let inner_response = {
40 #fn_body
41 };
42
43 let response = inner_response.into_response();
45 let (mut parts, body) = response.into_parts();
46
47 let body_bytes = match ::axum::body::to_bytes(body, usize::MAX).await {
49 Ok(bytes) => bytes,
50 Err(_) => {
51 let error_response = ::axum::response::Response::builder()
52 .status(::axum::http::StatusCode::INTERNAL_SERVER_ERROR)
53 .header("content-type", "application/json")
54 .body(::axum::body::Body::from(
55 r#"{"error":{"code":"body_read_failed","message":"Failed to read response body for attestation"}}"#
56 ))
57 .expect("failed to build error response");
58 return error_response.into_response();
59 }
60 };
61
62 let header = tee_state.sign_response(&body_bytes, &request_id);
64
65 if let Ok(val) = HeaderValue::from_str(&header.to_header_value()) {
67 parts.headers.insert("X-TEE-Attestation", val);
68 }
69 if let Ok(val) = HeaderValue::from_str("true") {
70 parts.headers.insert("X-TEE-Verified", val);
71 }
72 if let Ok(val) = HeaderValue::from_str(&request_id) {
73 parts.headers.insert("X-TEE-Request-Id", val);
74 }
75
76 ::axum::response::Response::from_parts(parts, ::axum::body::Body::from(body_bytes))
77 }
78 };
79
80 TokenStream::from(expanded)
81}
82
/// Converts a `CamelCase` type name into the `snake_case` form used for generated
/// field names, accessor/helper names and `external:` purpose strings.
///
/// Every uppercase character except one in the very first position is preceded by
/// `_` and lowercased, so consecutive capitals are split individually
/// (`HTTPServer` becomes `h_t_t_p_server`). Because the result ends up in serialized
/// field names and key-derivation inputs, its output must stay stable.
fn to_snake_case(name: &str) -> String {
    let mut snake = String::with_capacity(name.len());
    for (position, c) in name.chars().enumerate() {
        if !c.is_uppercase() {
            snake.push(c);
            continue;
        }
        if position != 0 {
            snake.push('_');
        }
        snake.extend(c.to_lowercase());
    }
    snake
}
102
/// Section marker (`#[mrenclave]`, `#[mrsigner]` or `#[external]`) inside a `state!`
/// invocation; type names that follow a marker belong to that section.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SealSection {
    /// Types stored in `EnclaveState`, sealed with `SealMode::MrEnclave`.
    MrEnclave,
    /// Types stored in `SignerState`, sealed with `SealMode::MrSigner`.
    MrSigner,
    /// Types that get `encrypt_*`/`decrypt_*` helpers keyed from the signer master key.
    External,
}
108
109struct StateInput {
110 mrenclave_types: Vec<Ident>,
111 mrsigner_types: Vec<Ident>,
112 external_types: Vec<Ident>,
113}
114
115impl Parse for StateInput {
116 fn parse(input: ParseStream) -> syn::Result<Self> {
117 let mut mrenclave_types = Vec::new();
118 let mut mrsigner_types = Vec::new();
119 let mut external_types = Vec::new();
120 let mut current_section: Option<SealSection> = None;
121
122 while !input.is_empty() {
123 if input.peek(Token![#]) {
124 input.parse::<Token![#]>()?;
125 let content;
126 syn::bracketed!(content in input);
127 let attr_name: Ident = content.parse()?;
128 if attr_name == "mrenclave" {
129 current_section = Some(SealSection::MrEnclave);
130 } else if attr_name == "mrsigner" {
131 current_section = Some(SealSection::MrSigner);
132 } else if attr_name == "external" {
133 current_section = Some(SealSection::External);
134 } else {
135 return Err(syn::Error::new(
136 attr_name.span(),
137 "expected `mrenclave`, `mrsigner`, or `external`",
138 ));
139 }
140 } else {
141 let type_name: Ident = input.parse()?;
142 if input.peek(Token![,]) {
143 input.parse::<Token![,]>()?;
144 }
145 match ¤t_section {
146 Some(SealSection::MrEnclave) => mrenclave_types.push(type_name),
147 Some(SealSection::MrSigner) => mrsigner_types.push(type_name),
148 Some(SealSection::External) => external_types.push(type_name),
149 None => {
150 return Err(syn::Error::new(
151 type_name.span(),
152 "type must be under #[mrenclave], #[mrsigner], or #[external]",
153 ))
154 }
155 }
156 }
157 }
158
159 Ok(StateInput {
160 mrenclave_types,
161 mrsigner_types,
162 external_types,
163 })
164 }
165}
166
167#[proc_macro]
196pub fn state(input: TokenStream) -> TokenStream {
197 let parsed = parse_macro_input!(input as StateInput);
198
199 let has_enclave = !parsed.mrenclave_types.is_empty();
200 let has_signer = !parsed.mrsigner_types.is_empty();
201
202 let enclave_fields: Vec<Ident> = parsed
204 .mrenclave_types
205 .iter()
206 .map(|t| format_ident!("{}", to_snake_case(&t.to_string())))
207 .collect();
208 let enclave_types = &parsed.mrenclave_types;
209
210 let signer_fields: Vec<Ident> = parsed
211 .mrsigner_types
212 .iter()
213 .map(|t| format_ident!("{}", to_snake_case(&t.to_string())))
214 .collect();
215 let signer_types = &parsed.mrsigner_types;
216
217 let enclave_state = if has_enclave {
219 quote! {
220 #[derive(::serde::Serialize, ::serde::Deserialize)]
223 pub struct EnclaveState {
224 #[serde(with = "::guarantee::seal::signing_key_serde")]
225 signing_key: ::ed25519_dalek::SigningKey,
226 #(pub #enclave_fields: #enclave_types,)*
227 }
228
229 impl EnclaveState {
230 #(
231 pub fn #enclave_fields(&self) -> &#enclave_types {
233 &self.#enclave_fields
234 }
235 )*
236 }
237 }
238 } else {
239 quote! {}
240 };
241
242 let signer_state = if has_signer {
244 quote! {
245 #[derive(::serde::Serialize, ::serde::Deserialize)]
248 pub struct SignerState {
249 master_key: [u8; 32],
250 #(pub #signer_fields: #signer_types,)*
251 }
252
253 impl SignerState {
254 #(
255 pub fn #signer_fields(&self) -> &#signer_types {
257 &self.#signer_fields
258 }
259 )*
260 }
261 }
262 } else {
263 quote! {}
264 };
265
266 let enclave_field_def = if has_enclave {
268 quote! { enclave: EnclaveState, }
269 } else {
270 quote! {}
271 };
272 let signer_field_def = if has_signer {
273 quote! { signer: SignerState, }
274 } else {
275 quote! {}
276 };
277
278 let enclave_init = if has_enclave {
280 quote! {
281 let enclave: EnclaveState = match ::guarantee::seal::unseal_from_file(
282 &enclave_path,
283 ::guarantee::seal::SealMode::MrEnclave,
284 ) {
285 Ok(data) => {
286 ::tracing::info!("Unsealed MRENCLAVE state");
287 ::serde_json::from_slice(&data).map_err(|e| {
288 ::guarantee::SdkError::SealError(format!("Deserialize enclave state: {e}"))
289 })?
290 }
291 Err(_) => {
292 ::tracing::info!("No existing MRENCLAVE state -- generating fresh signing key");
293 let signing_key =
294 ::ed25519_dalek::SigningKey::generate(&mut ::rand::rngs::OsRng);
295 let state = EnclaveState {
296 signing_key,
297 #(#enclave_fields: Default::default(),)*
298 };
299 let data = ::serde_json::to_vec(&state).map_err(|e| {
300 ::guarantee::SdkError::SealError(format!("Serialize enclave state: {e}"))
301 })?;
302 ::guarantee::seal::seal_to_file(
303 &data,
304 &enclave_path,
305 ::guarantee::seal::SealMode::MrEnclave,
306 )?;
307 state
308 }
309 };
310 }
311 } else {
312 quote! {}
313 };
314
315 let signer_init = if has_signer {
317 quote! {
318 let signer: SignerState = match ::guarantee::seal::unseal_from_file(
319 &signer_path,
320 ::guarantee::seal::SealMode::MrSigner,
321 ) {
322 Ok(data) => {
323 ::tracing::info!("Unsealed MRSIGNER state");
324 ::serde_json::from_slice(&data).map_err(|e| {
325 ::guarantee::SdkError::SealError(format!("Deserialize signer state: {e}"))
326 })?
327 }
328 Err(_) => {
329 ::tracing::info!("No existing MRSIGNER state -- generating fresh master key");
330 let mut master_key = [0u8; 32];
331 ::rand::RngCore::fill_bytes(&mut ::rand::rngs::OsRng, &mut master_key);
332 let state = SignerState {
333 master_key,
334 #(#signer_fields: Default::default(),)*
335 };
336 let data = ::serde_json::to_vec(&state).map_err(|e| {
337 ::guarantee::SdkError::SealError(format!("Serialize signer state: {e}"))
338 })?;
339 ::guarantee::seal::seal_to_file(
340 &data,
341 &signer_path,
342 ::guarantee::seal::SealMode::MrSigner,
343 )?;
344 state
345 }
346 };
347 }
348 } else {
349 quote! {}
350 };
351
352 let tee_state_construct = match (has_enclave, has_signer) {
354 (true, true) => quote! { TeeState { enclave, signer } },
355 (true, false) => quote! { TeeState { enclave } },
356 (false, true) => quote! { TeeState { signer } },
357 (false, false) => quote! { TeeState {} },
358 };
359
360 let enclave_accessor = if has_enclave {
362 quote! {
363 pub fn enclave(&self) -> &EnclaveState {
365 &self.enclave
366 }
367 pub fn enclave_mut(&mut self) -> &mut EnclaveState {
369 &mut self.enclave
370 }
371 }
372 } else {
373 quote! {}
374 };
375
376 let attestation_methods = if has_enclave {
378 quote! {
379 pub fn sign_response(&self, body: &[u8], request_id: &str) -> ::guarantee::AttestationHeader {
382 ::guarantee::seal::sign_with_enclave_key(&self.enclave.signing_key, body, request_id)
383 }
384
385 pub fn public_key(&self) -> ::ed25519_dalek::VerifyingKey {
387 self.enclave.signing_key.verifying_key()
388 }
389
390 pub fn attestation_json(&self) -> ::serde_json::Value {
392 let pub_key = self.enclave.signing_key.verifying_key();
393 ::serde_json::json!({
394 "public_key": ::guarantee::response::hex_encode(pub_key.as_bytes()),
395 "tee_type": if ::std::env::var("GUARANTEE_ENCLAVE").map(|v| v == "1").unwrap_or(false) {
396 "intel-sgx"
397 } else {
398 "dev-mode"
399 },
400 })
401 }
402 }
403 } else {
404 quote! {}
405 };
406
407 let signer_accessor = if has_signer {
408 quote! {
409 pub fn signer(&self) -> &SignerState {
411 &self.signer
412 }
413 pub fn signer_mut(&mut self) -> &mut SignerState {
415 &mut self.signer
416 }
417 }
418 } else {
419 quote! {}
420 };
421
422 let seal_enclave = if has_enclave {
424 quote! {
425 let enclave_data = ::serde_json::to_vec(&self.enclave).map_err(|e| {
426 ::guarantee::SdkError::SealError(format!("Serialize enclave state: {e}"))
427 })?;
428 ::guarantee::seal::seal_to_file(
429 &enclave_data,
430 &enclave_path,
431 ::guarantee::seal::SealMode::MrEnclave,
432 )?;
433 }
434 } else {
435 quote! {}
436 };
437
438 let seal_signer = if has_signer {
439 quote! {
440 let signer_data = ::serde_json::to_vec(&self.signer).map_err(|e| {
441 ::guarantee::SdkError::SealError(format!("Serialize signer state: {e}"))
442 })?;
443 ::guarantee::seal::seal_to_file(
444 &signer_data,
445 &signer_path,
446 ::guarantee::seal::SealMode::MrSigner,
447 )?;
448 }
449 } else {
450 quote! {}
451 };
452
453 let external_snake_names: Vec<Ident> = parsed
456 .external_types
457 .iter()
458 .map(|t| format_ident!("{}", to_snake_case(&t.to_string())))
459 .collect();
460 let external_types_ref = &parsed.external_types;
461 let external_encrypted_names: Vec<Ident> = parsed
462 .external_types
463 .iter()
464 .map(|t| format_ident!("Encrypted{}", t))
465 .collect();
466 let external_purpose_strings: Vec<String> = parsed
467 .external_types
468 .iter()
469 .map(|t| format!("external:{}", to_snake_case(&t.to_string())))
470 .collect();
471
472 let encrypt_method_names: Vec<Ident> = external_snake_names
473 .iter()
474 .map(|s| format_ident!("encrypt_{}", s))
475 .collect();
476 let decrypt_method_names: Vec<Ident> = external_snake_names
477 .iter()
478 .map(|s| format_ident!("decrypt_{}", s))
479 .collect();
480
481 let encryption_methods = if has_signer && !parsed.external_types.is_empty() {
482 quote! {
483 #(
484 pub fn #encrypt_method_names(&self, value: &#external_types_ref) -> Result<#external_encrypted_names, ::guarantee::SdkError> {
487 let key = ::guarantee::crypto::derive_key(&self.signer.master_key, #external_purpose_strings.as_bytes());
488 value.encrypt(&key)
489 }
490
491 pub fn #decrypt_method_names(&self, encrypted: &#external_encrypted_names) -> Result<#external_types_ref, ::guarantee::SdkError> {
493 let key = ::guarantee::crypto::derive_key(&self.signer.master_key, #external_purpose_strings.as_bytes());
494 #external_types_ref::decrypt_from(encrypted, &key)
495 }
496 )*
497 }
498 } else {
499 quote! {}
500 };
501
502 let output = quote! {
503 #enclave_state
504 #signer_state
505
506 pub struct TeeState {
511 #enclave_field_def
512 #signer_field_def
513 }
514
515 impl TeeState {
516 pub fn initialize(
519 seal_dir: &::std::path::Path,
520 ) -> Result<Self, ::guarantee::SdkError> {
521 let enclave_path = seal_dir.join("enclave.sealed");
522 let signer_path = seal_dir.join("signer.sealed");
523
524 #enclave_init
525 #signer_init
526
527 Ok(#tee_state_construct)
528 }
529
530 pub fn seal(
532 &self,
533 seal_dir: &::std::path::Path,
534 ) -> Result<(), ::guarantee::SdkError> {
535 let enclave_path = seal_dir.join("enclave.sealed");
536 let signer_path = seal_dir.join("signer.sealed");
537
538 #seal_enclave
539 #seal_signer
540
541 Ok(())
542 }
543
544 #enclave_accessor
545 #signer_accessor
546 #attestation_methods
547 #encryption_methods
548 }
549 };
550
551 TokenStream::from(output)
552}
553
554#[proc_macro_derive(Encrypted, attributes(encrypt))]
579pub fn derive_encrypted(input: TokenStream) -> TokenStream {
580 let input = parse_macro_input!(input as DeriveInput);
581 match impl_encrypted(&input) {
582 Ok(tokens) => tokens,
583 Err(err) => err.to_compile_error().into(),
584 }
585}
586
587fn impl_encrypted(input: &DeriveInput) -> syn::Result<TokenStream> {
588 let name = &input.ident;
589 let encrypted_name = format_ident!("Encrypted{}", name);
590
591 let fields = match &input.data {
592 syn::Data::Struct(data) => match &data.fields {
593 syn::Fields::Named(fields) => &fields.named,
594 _ => {
595 return Err(syn::Error::new_spanned(
596 input,
597 "Encrypted can only be derived for structs with named fields",
598 ))
599 }
600 },
601 _ => {
602 return Err(syn::Error::new_spanned(
603 input,
604 "Encrypted can only be derived for structs",
605 ))
606 }
607 };
608
609 let mut encrypted_field_defs = Vec::new();
610 let mut encrypt_exprs = Vec::new();
611 let mut decrypt_exprs = Vec::new();
612
613 for field in fields {
614 let field_name = field.ident.as_ref().ok_or_else(|| {
615 syn::Error::new_spanned(field, "expected named field")
616 })?;
617 let field_ty = &field.ty;
618 let has_encrypt = field.attrs.iter().any(|a| a.path().is_ident("encrypt"));
619
620 if has_encrypt {
621 encrypted_field_defs.push(quote! {
623 pub #field_name: String
624 });
625 encrypt_exprs.push(quote! {
626 #field_name: ::guarantee::crypto::encrypt_field(&self.#field_name, key)?
627 });
628 decrypt_exprs.push(quote! {
629 #field_name: ::guarantee::crypto::decrypt_field(&encrypted.#field_name, key)?
630 });
631 } else {
632 encrypted_field_defs.push(quote! {
634 pub #field_name: #field_ty
635 });
636 encrypt_exprs.push(quote! {
637 #field_name: self.#field_name.clone()
638 });
639 decrypt_exprs.push(quote! {
640 #field_name: encrypted.#field_name.clone()
641 });
642 }
643 }
644
645 let output = quote! {
646 #[derive(::serde::Serialize, ::serde::Deserialize, Debug, Clone)]
649 pub struct #encrypted_name {
650 #(#encrypted_field_defs,)*
651 }
652
653 impl ::guarantee::crypto::Encryptable for #name {
654 type Encrypted = #encrypted_name;
655
656 fn encrypt(&self, key: &[u8; 32]) -> Result<#encrypted_name, ::guarantee::SdkError> {
657 Ok(#encrypted_name {
658 #(#encrypt_exprs,)*
659 })
660 }
661
662 fn decrypt_from(encrypted: &#encrypted_name, key: &[u8; 32]) -> Result<Self, ::guarantee::SdkError> {
663 Ok(#name {
664 #(#decrypt_exprs,)*
665 })
666 }
667 }
668 };
669
670 Ok(output.into())
671}