// guarantee_macros/lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::parse::{Parse, ParseStream};
4use syn::{parse_macro_input, DeriveInput, Ident, ItemFn, Token};
5
6/// Transforms an axum handler to automatically sign responses with TEE attestation.
7///
8/// The macro:
9/// 1. Extracts `Arc<TeeState>` from axum Extension
10/// 2. Generates a request ID
11/// 3. Runs the original handler
12/// 4. Signs the response body via `TeeState::sign_response`
13/// 5. Attaches X-TEE-Attestation and X-TEE-Verified headers
14///
15/// The signing key is never exposed -- `TeeState::sign_response` is the only way
16/// to produce attestation signatures.
17#[proc_macro_attribute]
18pub fn attest(_attr: TokenStream, item: TokenStream) -> TokenStream {
19    let input_fn = parse_macro_input!(item as ItemFn);
20    let fn_name = &input_fn.sig.ident;
21    let fn_vis = &input_fn.vis;
22    let fn_inputs = &input_fn.sig.inputs;
23    let fn_body = &input_fn.block;
24    let fn_attrs = &input_fn.attrs;
25
26    let expanded = quote! {
27        #(#fn_attrs)*
28        #fn_vis async fn #fn_name(
29            ::axum::extract::Extension(tee_state): ::axum::extract::Extension<::std::sync::Arc<TeeState>>,
30            #fn_inputs
31        ) -> impl ::axum::response::IntoResponse {
32            use ::axum::response::IntoResponse;
33            use ::axum::http::header::HeaderValue;
34
35            // Generate request ID
36            let request_id = ::uuid::Uuid::new_v4().to_string();
37
38            // Execute original handler
39            let inner_response = {
40                #fn_body
41            };
42
43            // Convert to axum response
44            let response = inner_response.into_response();
45            let (mut parts, body) = response.into_parts();
46
47            // Read body bytes -- return 500 if body cannot be read
48            let body_bytes = match ::axum::body::to_bytes(body, usize::MAX).await {
49                Ok(bytes) => bytes,
50                Err(_) => {
51                    let error_response = ::axum::response::Response::builder()
52                        .status(::axum::http::StatusCode::INTERNAL_SERVER_ERROR)
53                        .header("content-type", "application/json")
54                        .body(::axum::body::Body::from(
55                            r#"{"error":{"code":"body_read_failed","message":"Failed to read response body for attestation"}}"#
56                        ))
57                        .expect("failed to build error response");
58                    return error_response.into_response();
59                }
60            };
61
62            // Sign using TeeState -- signing key is never exposed
63            let header = tee_state.sign_response(&body_bytes, &request_id);
64
65            // Insert attestation headers
66            if let Ok(val) = HeaderValue::from_str(&header.to_header_value()) {
67                parts.headers.insert("X-TEE-Attestation", val);
68            }
69            if let Ok(val) = HeaderValue::from_str("true") {
70                parts.headers.insert("X-TEE-Verified", val);
71            }
72            if let Ok(val) = HeaderValue::from_str(&request_id) {
73                parts.headers.insert("X-TEE-Request-Id", val);
74            }
75
76            ::axum::response::Response::from_parts(parts, ::axum::body::Body::from(body_bytes))
77        }
78    };
79
80    TokenStream::from(expanded)
81}
82
// --- state! proc macro ---

/// Lower-cases a PascalCase identifier into snake_case.
///
/// Every uppercase character becomes its lowercase form, preceded by `_` unless it is the
/// first character. Consecutive capitals are therefore split individually
/// (`HTTPState` -> `h_t_t_p_state`). All other characters are copied unchanged.
///
/// The result names the generated state fields/accessors and the `"external:<name>"`
/// key-derivation purpose strings, so its output must stay stable.
fn to_snake_case(name: &str) -> String {
    let mut snake = String::with_capacity(name.len());
    for (byte_pos, c) in name.char_indices() {
        if !c.is_uppercase() {
            snake.push(c);
            continue;
        }
        // Only the very first character (byte offset 0) is not preceded by a separator.
        if byte_pos != 0 {
            snake.push('_');
        }
        snake.extend(c.to_lowercase());
    }
    snake
}

/// The `state!` section whose types are currently being collected by the parser.
///
/// Selected by the most recent `#[mrenclave]`, `#[mrsigner]` or `#[external]` marker.
enum SealSection {
    /// Types stored in `EnclaveState` and sealed with MRENCLAVE.
    MrEnclave,
    /// Types stored in `SignerState` and sealed with MRSIGNER.
    MrSigner,
    /// Types that get `encrypt_*` / `decrypt_*` methods on `TeeState` instead of being stored.
    External,
}

109struct StateInput {
110    mrenclave_types: Vec<Ident>,
111    mrsigner_types: Vec<Ident>,
112    external_types: Vec<Ident>,
113}
114
115impl Parse for StateInput {
116    fn parse(input: ParseStream) -> syn::Result<Self> {
117        let mut mrenclave_types = Vec::new();
118        let mut mrsigner_types = Vec::new();
119        let mut external_types = Vec::new();
120        let mut current_section: Option<SealSection> = None;
121
122        while !input.is_empty() {
123            if input.peek(Token![#]) {
124                input.parse::<Token![#]>()?;
125                let content;
126                syn::bracketed!(content in input);
127                let attr_name: Ident = content.parse()?;
128                if attr_name == "mrenclave" {
129                    current_section = Some(SealSection::MrEnclave);
130                } else if attr_name == "mrsigner" {
131                    current_section = Some(SealSection::MrSigner);
132                } else if attr_name == "external" {
133                    current_section = Some(SealSection::External);
134                } else {
135                    return Err(syn::Error::new(
136                        attr_name.span(),
137                        "expected `mrenclave`, `mrsigner`, or `external`",
138                    ));
139                }
140            } else {
141                let type_name: Ident = input.parse()?;
142                if input.peek(Token![,]) {
143                    input.parse::<Token![,]>()?;
144                }
145                match &current_section {
146                    Some(SealSection::MrEnclave) => mrenclave_types.push(type_name),
147                    Some(SealSection::MrSigner) => mrsigner_types.push(type_name),
148                    Some(SealSection::External) => external_types.push(type_name),
149                    None => {
150                        return Err(syn::Error::new(
151                            type_name.span(),
152                            "type must be under #[mrenclave], #[mrsigner], or #[external]",
153                        ))
154                    }
155                }
156            }
157        }
158
159        Ok(StateInput {
160            mrenclave_types,
161            mrsigner_types,
162            external_types,
163        })
164    }
165}
166
167/// Declare TEE state with automatic key management and sealing.
168///
169/// Types listed under `#[mrenclave]` are sealed with MRENCLAVE (reset on redeploy).
170/// An Ed25519 `signing_key` is auto-generated and included.
171///
172/// Types listed under `#[mrsigner]` are sealed with MRSIGNER (persist across redeploys).
173/// A 256-bit `master_key` is auto-generated and included.
174///
175/// # Example
176///
177/// ```rust,ignore
178/// use guarantee::state;
179/// use serde::{Serialize, Deserialize};
180///
181/// #[derive(Serialize, Deserialize, Default, Clone, Debug)]
182/// struct SessionState { user_id: String }
183///
184/// #[derive(Serialize, Deserialize, Default, Clone, Debug)]
185/// struct UserSecrets { api_key: String }
186///
187/// state! {
188///     #[mrenclave]
189///     SessionState,
190///
191///     #[mrsigner]
192///     UserSecrets,
193/// }
194/// ```
195#[proc_macro]
196pub fn state(input: TokenStream) -> TokenStream {
197    let parsed = parse_macro_input!(input as StateInput);
198
199    let has_enclave = !parsed.mrenclave_types.is_empty();
200    let has_signer = !parsed.mrsigner_types.is_empty();
201
202    // Generate field names (snake_case) from type names
203    let enclave_fields: Vec<Ident> = parsed
204        .mrenclave_types
205        .iter()
206        .map(|t| format_ident!("{}", to_snake_case(&t.to_string())))
207        .collect();
208    let enclave_types = &parsed.mrenclave_types;
209
210    let signer_fields: Vec<Ident> = parsed
211        .mrsigner_types
212        .iter()
213        .map(|t| format_ident!("{}", to_snake_case(&t.to_string())))
214        .collect();
215    let signer_types = &parsed.mrsigner_types;
216
217    // Generate EnclaveState struct + impl
218    let enclave_state = if has_enclave {
219        quote! {
220            /// MRENCLAVE-sealed state. Reset on redeploy (new binary = new measurement).
221            /// Contains an auto-generated Ed25519 signing key for per-response attestation.
222            #[derive(::serde::Serialize, ::serde::Deserialize)]
223            pub struct EnclaveState {
224                #[serde(with = "::guarantee::seal::signing_key_serde")]
225                signing_key: ::ed25519_dalek::SigningKey,
226                #(pub #enclave_fields: #enclave_types,)*
227            }
228
229            impl EnclaveState {
230                #(
231                    /// Read-only accessor for the `#enclave_fields` component.
232                    pub fn #enclave_fields(&self) -> &#enclave_types {
233                        &self.#enclave_fields
234                    }
235                )*
236            }
237        }
238    } else {
239        quote! {}
240    };
241
242    // Generate SignerState struct + impl
243    let signer_state = if has_signer {
244        quote! {
245            /// MRSIGNER-sealed state. Persists across redeploys (same signing key = same MRSIGNER).
246            /// Contains an auto-generated 256-bit master key for encrypting user data at rest.
247            #[derive(::serde::Serialize, ::serde::Deserialize)]
248            pub struct SignerState {
249                master_key: [u8; 32],
250                #(pub #signer_fields: #signer_types,)*
251            }
252
253            impl SignerState {
254                #(
255                    /// Read-only accessor for the `#signer_fields` component.
256                    pub fn #signer_fields(&self) -> &#signer_types {
257                        &self.#signer_fields
258                    }
259                )*
260            }
261        }
262    } else {
263        quote! {}
264    };
265
266    // TeeState struct fields
267    let enclave_field_def = if has_enclave {
268        quote! { enclave: EnclaveState, }
269    } else {
270        quote! {}
271    };
272    let signer_field_def = if has_signer {
273        quote! { signer: SignerState, }
274    } else {
275        quote! {}
276    };
277
278    // Initialization code for enclave state
279    let enclave_init = if has_enclave {
280        quote! {
281            let enclave: EnclaveState = match ::guarantee::seal::unseal_from_file(
282                &enclave_path,
283                ::guarantee::seal::SealMode::MrEnclave,
284            ) {
285                Ok(data) => {
286                    ::tracing::info!("Unsealed MRENCLAVE state");
287                    ::serde_json::from_slice(&data).map_err(|e| {
288                        ::guarantee::SdkError::SealError(format!("Deserialize enclave state: {e}"))
289                    })?
290                }
291                Err(_) => {
292                    ::tracing::info!("No existing MRENCLAVE state -- generating fresh signing key");
293                    let signing_key =
294                        ::ed25519_dalek::SigningKey::generate(&mut ::rand::rngs::OsRng);
295                    let state = EnclaveState {
296                        signing_key,
297                        #(#enclave_fields: Default::default(),)*
298                    };
299                    let data = ::serde_json::to_vec(&state).map_err(|e| {
300                        ::guarantee::SdkError::SealError(format!("Serialize enclave state: {e}"))
301                    })?;
302                    ::guarantee::seal::seal_to_file(
303                        &data,
304                        &enclave_path,
305                        ::guarantee::seal::SealMode::MrEnclave,
306                    )?;
307                    state
308                }
309            };
310        }
311    } else {
312        quote! {}
313    };
314
315    // Initialization code for signer state
316    let signer_init = if has_signer {
317        quote! {
318            let signer: SignerState = match ::guarantee::seal::unseal_from_file(
319                &signer_path,
320                ::guarantee::seal::SealMode::MrSigner,
321            ) {
322                Ok(data) => {
323                    ::tracing::info!("Unsealed MRSIGNER state");
324                    ::serde_json::from_slice(&data).map_err(|e| {
325                        ::guarantee::SdkError::SealError(format!("Deserialize signer state: {e}"))
326                    })?
327                }
328                Err(_) => {
329                    ::tracing::info!("No existing MRSIGNER state -- generating fresh master key");
330                    let mut master_key = [0u8; 32];
331                    ::rand::RngCore::fill_bytes(&mut ::rand::rngs::OsRng, &mut master_key);
332                    let state = SignerState {
333                        master_key,
334                        #(#signer_fields: Default::default(),)*
335                    };
336                    let data = ::serde_json::to_vec(&state).map_err(|e| {
337                        ::guarantee::SdkError::SealError(format!("Serialize signer state: {e}"))
338                    })?;
339                    ::guarantee::seal::seal_to_file(
340                        &data,
341                        &signer_path,
342                        ::guarantee::seal::SealMode::MrSigner,
343                    )?;
344                    state
345                }
346            };
347        }
348    } else {
349        quote! {}
350    };
351
352    // TeeState constructor expression
353    let tee_state_construct = match (has_enclave, has_signer) {
354        (true, true) => quote! { TeeState { enclave, signer } },
355        (true, false) => quote! { TeeState { enclave } },
356        (false, true) => quote! { TeeState { signer } },
357        (false, false) => quote! { TeeState {} },
358    };
359
360    // Accessors on TeeState
361    let enclave_accessor = if has_enclave {
362        quote! {
363            /// Access the MRENCLAVE-sealed state (read-only).
364            pub fn enclave(&self) -> &EnclaveState {
365                &self.enclave
366            }
367            /// Access the MRENCLAVE-sealed state (mutable).
368            pub fn enclave_mut(&mut self) -> &mut EnclaveState {
369                &mut self.enclave
370            }
371        }
372    } else {
373        quote! {}
374    };
375
376    // Attestation methods on TeeState (only when mrenclave section exists)
377    let attestation_methods = if has_enclave {
378        quote! {
379            /// Sign a response body for attestation. Used by the `#[attest]` macro.
380            /// The signing key is never exposed -- this is the only way to produce signatures.
381            pub fn sign_response(&self, body: &[u8], request_id: &str) -> ::guarantee::AttestationHeader {
382                ::guarantee::seal::sign_with_enclave_key(&self.enclave.signing_key, body, request_id)
383            }
384
385            /// Get the public verifying key for the attestation endpoint.
386            pub fn public_key(&self) -> ::ed25519_dalek::VerifyingKey {
387                self.enclave.signing_key.verifying_key()
388            }
389
390            /// Get startup attestation JSON for `/.well-known/tee-attestation`.
391            pub fn attestation_json(&self) -> ::serde_json::Value {
392                let pub_key = self.enclave.signing_key.verifying_key();
393                ::serde_json::json!({
394                    "public_key": ::guarantee::response::hex_encode(pub_key.as_bytes()),
395                    "tee_type": if ::std::env::var("GUARANTEE_ENCLAVE").map(|v| v == "1").unwrap_or(false) {
396                        "intel-sgx"
397                    } else {
398                        "dev-mode"
399                    },
400                })
401            }
402        }
403    } else {
404        quote! {}
405    };
406
407    let signer_accessor = if has_signer {
408        quote! {
409            /// Access the MRSIGNER-sealed state (read-only).
410            pub fn signer(&self) -> &SignerState {
411                &self.signer
412            }
413            /// Access the MRSIGNER-sealed state (mutable).
414            pub fn signer_mut(&mut self) -> &mut SignerState {
415                &mut self.signer
416            }
417        }
418    } else {
419        quote! {}
420    };
421
422    // Seal logic
423    let seal_enclave = if has_enclave {
424        quote! {
425            let enclave_data = ::serde_json::to_vec(&self.enclave).map_err(|e| {
426                ::guarantee::SdkError::SealError(format!("Serialize enclave state: {e}"))
427            })?;
428            ::guarantee::seal::seal_to_file(
429                &enclave_data,
430                &enclave_path,
431                ::guarantee::seal::SealMode::MrEnclave,
432            )?;
433        }
434    } else {
435        quote! {}
436    };
437
438    let seal_signer = if has_signer {
439        quote! {
440            let signer_data = ::serde_json::to_vec(&self.signer).map_err(|e| {
441                ::guarantee::SdkError::SealError(format!("Serialize signer state: {e}"))
442            })?;
443            ::guarantee::seal::seal_to_file(
444                &signer_data,
445                &signer_path,
446                ::guarantee::seal::SealMode::MrSigner,
447            )?;
448        }
449    } else {
450        quote! {}
451    };
452
453    // Per-type encryption methods on TeeState for each #[external] type.
454    // Each type gets its own derived key from master_key + "external:<snake_case_type>".
455    let external_snake_names: Vec<Ident> = parsed
456        .external_types
457        .iter()
458        .map(|t| format_ident!("{}", to_snake_case(&t.to_string())))
459        .collect();
460    let external_types_ref = &parsed.external_types;
461    let external_encrypted_names: Vec<Ident> = parsed
462        .external_types
463        .iter()
464        .map(|t| format_ident!("Encrypted{}", t))
465        .collect();
466    let external_purpose_strings: Vec<String> = parsed
467        .external_types
468        .iter()
469        .map(|t| format!("external:{}", to_snake_case(&t.to_string())))
470        .collect();
471
472    let encrypt_method_names: Vec<Ident> = external_snake_names
473        .iter()
474        .map(|s| format_ident!("encrypt_{}", s))
475        .collect();
476    let decrypt_method_names: Vec<Ident> = external_snake_names
477        .iter()
478        .map(|s| format_ident!("decrypt_{}", s))
479        .collect();
480
481    let encryption_methods = if has_signer && !parsed.external_types.is_empty() {
482        quote! {
483            #(
484                /// Encrypt a value using a per-type derived key from the MRSIGNER-bound master key.
485                /// The key is derived at runtime via HKDF-SHA256 with purpose `"external:<type>"`.
486                pub fn #encrypt_method_names(&self, value: &#external_types_ref) -> Result<#external_encrypted_names, ::guarantee::SdkError> {
487                    let key = ::guarantee::crypto::derive_key(&self.signer.master_key, #external_purpose_strings.as_bytes());
488                    value.encrypt(&key)
489                }
490
491                /// Decrypt a value using a per-type derived key from the MRSIGNER-bound master key.
492                pub fn #decrypt_method_names(&self, encrypted: &#external_encrypted_names) -> Result<#external_types_ref, ::guarantee::SdkError> {
493                    let key = ::guarantee::crypto::derive_key(&self.signer.master_key, #external_purpose_strings.as_bytes());
494                    #external_types_ref::decrypt_from(encrypted, &key)
495                }
496            )*
497        }
498    } else {
499        quote! {}
500    };
501
502    let output = quote! {
503        #enclave_state
504        #signer_state
505
506        /// Unified TEE state container. Holds both MRENCLAVE-sealed and MRSIGNER-sealed state.
507        ///
508        /// - `enclave()` -- state that resets on redeploy (bound to binary measurement)
509        /// - `signer()` -- state that persists across redeploys (bound to signing key)
510        pub struct TeeState {
511            #enclave_field_def
512            #signer_field_def
513        }
514
515        impl TeeState {
516            /// Initialize TEE state. Attempts to unseal existing state from `seal_dir`.
517            /// If no sealed state exists (first boot or MRENCLAVE changed), generates fresh keys.
518            pub fn initialize(
519                seal_dir: &::std::path::Path,
520            ) -> Result<Self, ::guarantee::SdkError> {
521                let enclave_path = seal_dir.join("enclave.sealed");
522                let signer_path = seal_dir.join("signer.sealed");
523
524                #enclave_init
525                #signer_init
526
527                Ok(#tee_state_construct)
528            }
529
530            /// Seal all state to disk. Call after mutating state to persist changes.
531            pub fn seal(
532                &self,
533                seal_dir: &::std::path::Path,
534            ) -> Result<(), ::guarantee::SdkError> {
535                let enclave_path = seal_dir.join("enclave.sealed");
536                let signer_path = seal_dir.join("signer.sealed");
537
538                #seal_enclave
539                #seal_signer
540
541                Ok(())
542            }
543
544            #enclave_accessor
545            #signer_accessor
546            #attestation_methods
547            #encryption_methods
548        }
549    };
550
551    TokenStream::from(output)
552}
553
554// --- Encrypted derive macro ---
555
556/// Derive macro that generates an encrypted version of a struct and `Encryptable` trait impl.
557///
558/// Fields annotated with `#[encrypt]` will be encrypted using AES-256-GCM when
559/// `encrypt()` is called. Non-annotated fields are copied as-is.
560///
561/// `#[encrypt]` only works on `String` fields.
562///
563/// # Example
564///
565/// ```rust,ignore
566/// #[derive(Encrypted, Serialize, Deserialize, Clone, Debug, PartialEq)]
567/// struct UserRecord {
568///     user_id: String,
569///     #[encrypt]
570///     ssn: String,
571///     #[encrypt]
572///     bank_account: String,
573///     email: String,
574/// }
575/// ```
576///
577/// This generates `EncryptedUserRecord` and an `Encryptable` impl on `UserRecord`.
578#[proc_macro_derive(Encrypted, attributes(encrypt))]
579pub fn derive_encrypted(input: TokenStream) -> TokenStream {
580    let input = parse_macro_input!(input as DeriveInput);
581    match impl_encrypted(&input) {
582        Ok(tokens) => tokens,
583        Err(err) => err.to_compile_error().into(),
584    }
585}
586
587fn impl_encrypted(input: &DeriveInput) -> syn::Result<TokenStream> {
588    let name = &input.ident;
589    let encrypted_name = format_ident!("Encrypted{}", name);
590
591    let fields = match &input.data {
592        syn::Data::Struct(data) => match &data.fields {
593            syn::Fields::Named(fields) => &fields.named,
594            _ => {
595                return Err(syn::Error::new_spanned(
596                    input,
597                    "Encrypted can only be derived for structs with named fields",
598                ))
599            }
600        },
601        _ => {
602            return Err(syn::Error::new_spanned(
603                input,
604                "Encrypted can only be derived for structs",
605            ))
606        }
607    };
608
609    let mut encrypted_field_defs = Vec::new();
610    let mut encrypt_exprs = Vec::new();
611    let mut decrypt_exprs = Vec::new();
612
613    for field in fields {
614        let field_name = field.ident.as_ref().ok_or_else(|| {
615            syn::Error::new_spanned(field, "expected named field")
616        })?;
617        let field_ty = &field.ty;
618        let has_encrypt = field.attrs.iter().any(|a| a.path().is_ident("encrypt"));
619
620        if has_encrypt {
621            // Encrypted field: type becomes String in the encrypted struct
622            encrypted_field_defs.push(quote! {
623                pub #field_name: String
624            });
625            encrypt_exprs.push(quote! {
626                #field_name: ::guarantee::crypto::encrypt_field(&self.#field_name, key)?
627            });
628            decrypt_exprs.push(quote! {
629                #field_name: ::guarantee::crypto::decrypt_field(&encrypted.#field_name, key)?
630            });
631        } else {
632            // Non-encrypted field: keep the same type, clone the value
633            encrypted_field_defs.push(quote! {
634                pub #field_name: #field_ty
635            });
636            encrypt_exprs.push(quote! {
637                #field_name: self.#field_name.clone()
638            });
639            decrypt_exprs.push(quote! {
640                #field_name: encrypted.#field_name.clone()
641            });
642        }
643    }
644
645    let output = quote! {
646        /// Encrypted version of [`#name`]. Fields marked `#[encrypt]` contain
647        /// AES-256-GCM ciphertext in the format `enc:v1:<nonce_hex>:<ciphertext_hex>`.
648        #[derive(::serde::Serialize, ::serde::Deserialize, Debug, Clone)]
649        pub struct #encrypted_name {
650            #(#encrypted_field_defs,)*
651        }
652
653        impl ::guarantee::crypto::Encryptable for #name {
654            type Encrypted = #encrypted_name;
655
656            fn encrypt(&self, key: &[u8; 32]) -> Result<#encrypted_name, ::guarantee::SdkError> {
657                Ok(#encrypted_name {
658                    #(#encrypt_exprs,)*
659                })
660            }
661
662            fn decrypt_from(encrypted: &#encrypted_name, key: &[u8; 32]) -> Result<Self, ::guarantee::SdkError> {
663                Ok(#name {
664                    #(#decrypt_exprs,)*
665                })
666            }
667        }
668    };
669
670    Ok(output.into())
671}