1extern crate proc_macro;
5
6use proc_macro::TokenStream;
7use proc_macro2::TokenStream as TokenStream2;
8use quote::{format_ident, quote};
9use syn::{FnArg, Item, ItemFn, ItemMod, parse_macro_input, spanned::Spanned};
10
11#[proc_macro_attribute]
12pub fn ta_create(_args: TokenStream, input: TokenStream) -> TokenStream {
13 input
14}
15
16#[proc_macro_attribute]
17pub fn ta_open_session(_args: TokenStream, input: TokenStream) -> TokenStream {
18 input
19}
20
21#[proc_macro_attribute]
22pub fn ta_close_session(_args: TokenStream, input: TokenStream) -> TokenStream {
23 input
24}
25
26#[proc_macro_attribute]
27pub fn ta_destroy(_args: TokenStream, input: TokenStream) -> TokenStream {
28 input
29}
30
31#[proc_macro_attribute]
32pub fn ta_invoke_command(_args: TokenStream, input: TokenStream) -> TokenStream {
33 input
34}
35
36#[proc_macro_attribute]
37pub fn ta_acl_check(_args: TokenStream, input: TokenStream) -> TokenStream {
38 input
39}
40
41#[proc_macro]
43pub fn xtee_ta(input: TokenStream) -> TokenStream {
44 let file = parse_macro_input!(input as syn::File);
45 match expand_xtee_ta_items(file.items) {
46 Ok(ts) => ts.into(),
47 Err(e) => e.to_compile_error().into(),
48 }
49}
50
51#[proc_macro_attribute]
53pub fn xtee_ta_module(_args: TokenStream, input: TokenStream) -> TokenStream {
54 let mut item_mod = parse_macro_input!(input as ItemMod);
55 let ta_struct_ident = format_ident!("Ta");
56
57 let Some((_, items)) = &mut item_mod.content else {
58 return syn::Error::new(
59 item_mod.span(),
60 "#[xtee_ta_module] only supports inline modules",
61 )
62 .to_compile_error()
63 .into();
64 };
65
66 match expand_xtee_ta_items_mut(items, ta_struct_ident) {
67 Ok(()) => quote!(#item_mod).into(),
68 Err(e) => e.to_compile_error().into(),
69 }
70}
71
72fn expand_xtee_ta_items(items: Vec<Item>) -> Result<TokenStream2, syn::Error> {
73 let ta_struct_ident = format_ident!("Ta");
74 let mut items = items;
75 expand_xtee_ta_items_mut(&mut items, ta_struct_ident)?;
76 Ok(quote! { #(#items)* })
77}
78
79fn expand_xtee_ta_items_mut(
80 items: &mut Vec<Item>,
81 ta_struct_ident: syn::Ident,
82) -> Result<(), syn::Error> {
83 let mut create_fn: Option<ItemFn> = None;
84 let mut open_fn: Option<ItemFn> = None;
85 let mut close_fn: Option<ItemFn> = None;
86 let mut destroy_fn: Option<ItemFn> = None;
87 let mut invoke_fn: Option<ItemFn> = None;
88 let mut acl_fn: Option<ItemFn> = None;
89
90 for item in items.iter_mut() {
91 let Item::Fn(func) = item else {
92 return Err(syn::Error::new(
93 item.span(),
94 "xtee_ta! only supports `fn` items (put `use` at crate root)",
95 ));
96 };
97 let marker = extract_marker_and_strip(func);
98 let Some(marker) = marker else {
99 return Err(syn::Error::new(
100 func.span(),
101 "xtee_ta!: each `fn` must carry one of #[ta_create], #[ta_open_session], #[ta_close_session], #[ta_destroy], #[ta_invoke_command], #[ta_acl_check]",
102 ));
103 };
104 match marker.as_str() {
105 "ta_create" => {
106 if create_fn.replace(func.clone()).is_some() {
107 return Err(syn::Error::new(
108 func.span(),
109 "duplicate #[ta_create] function",
110 ));
111 }
112 }
113 "ta_open_session" => {
114 if open_fn.replace(func.clone()).is_some() {
115 return Err(syn::Error::new(
116 func.span(),
117 "duplicate #[ta_open_session] function",
118 ));
119 }
120 }
121 "ta_close_session" => {
122 if close_fn.replace(func.clone()).is_some() {
123 return Err(syn::Error::new(
124 func.span(),
125 "duplicate #[ta_close_session] function",
126 ));
127 }
128 }
129 "ta_destroy" => {
130 if destroy_fn.replace(func.clone()).is_some() {
131 return Err(syn::Error::new(
132 func.span(),
133 "duplicate #[ta_destroy] function",
134 ));
135 }
136 }
137 "ta_invoke_command" => {
138 if invoke_fn.replace(func.clone()).is_some() {
139 return Err(syn::Error::new(
140 func.span(),
141 "duplicate #[ta_invoke_command] function",
142 ));
143 }
144 }
145 "ta_acl_check" => {
146 if acl_fn.replace(func.clone()).is_some() {
147 return Err(syn::Error::new(
148 func.span(),
149 "duplicate #[ta_acl_check] function",
150 ));
151 }
152 }
153 _ => {
154 return Err(syn::Error::new(
155 func.span(),
156 "unknown #[ta_*] attribute for xtee_ta!",
157 ));
158 }
159 }
160 }
161
162 let Some(create_fn) = create_fn else {
163 return Err(syn::Error::new(
164 proc_macro2::Span::call_site(),
165 "missing #[ta_create] function",
166 ));
167 };
168 let Some(open_fn) = open_fn else {
169 return Err(syn::Error::new(
170 proc_macro2::Span::call_site(),
171 "missing #[ta_open_session] function",
172 ));
173 };
174 let Some(close_fn) = close_fn else {
175 return Err(syn::Error::new(
176 proc_macro2::Span::call_site(),
177 "missing #[ta_close_session] function",
178 ));
179 };
180 let Some(destroy_fn) = destroy_fn else {
181 return Err(syn::Error::new(
182 proc_macro2::Span::call_site(),
183 "missing #[ta_destroy] function",
184 ));
185 };
186 let Some(invoke_fn) = invoke_fn else {
187 return Err(syn::Error::new(
188 proc_macro2::Span::call_site(),
189 "missing #[ta_invoke_command] function",
190 ));
191 };
192
193 let create_ident = create_fn.sig.ident.clone();
194 let destroy_ident = destroy_fn.sig.ident.clone();
195
196 let (session_ctx_ty, open_call, close_call, invoke_call) =
197 build_context_and_calls(&open_fn, &close_fn, &invoke_fn)?;
198
199 let acl_check_impl = if let Some(acl_fn) = &acl_fn {
200 if acl_fn.sig.inputs.len() != 1 {
201 return Err(syn::Error::new(
202 acl_fn.sig.span(),
203 "#[ta_acl_check] expects fn(ca_auth_info: Option<&CaAuthInfo>)",
204 ));
205 }
206 let acl_ident = &acl_fn.sig.ident;
207 quote! {
208 fn acl_check(
209 &self,
210 ca_auth_info: Option<&teec_protocol::CaAuthInfo>,
211 ) -> xtee_utee::error::Result<()> {
212 __XteeIntoTaResult::into_ta_result(#acl_ident(ca_auth_info))
213 }
214 }
215 } else {
216 quote! {}
217 };
218
219 let impl_block = quote! {
220 use teec_protocol::Parameters;
221
222 pub struct #ta_struct_ident;
223
224 trait __XteeIntoTaResult {
225 fn into_ta_result(self) -> xtee_utee::error::Result<()>;
226 }
227
228 impl __XteeIntoTaResult for () {
229 fn into_ta_result(self) -> xtee_utee::error::Result<()> {
230 Ok(())
231 }
232 }
233
234 impl __XteeIntoTaResult for xtee_utee::error::Result<()> {
235 fn into_ta_result(self) -> xtee_utee::error::Result<()> {
236 self
237 }
238 }
239
240 impl xtee_utee::ta_manager::TrustedApplication for #ta_struct_ident {
241 type SessionContext = #session_ctx_ty;
242
243 fn create(&self) -> xtee_utee::error::Result<()> {
244 __XteeIntoTaResult::into_ta_result(#create_ident())
245 }
246
247 #acl_check_impl
248
249 fn open_session(
250 &self,
251 params: &mut Parameters,
252 ) -> xtee_utee::error::Result<Self::SessionContext> {
253 #open_call
254 }
255
256 fn close_session(
257 &self,
258 ctx: &mut Self::SessionContext,
259 ) -> xtee_utee::error::Result<()> {
260 __XteeIntoTaResult::into_ta_result(#close_call)
261 }
262
263 fn destroy(&self) -> xtee_utee::error::Result<()> {
264 __XteeIntoTaResult::into_ta_result(#destroy_ident())
265 }
266
267 fn invoke_command(
268 &self,
269 cmd_id: u32,
270 params: &mut Parameters,
271 ctx: &mut Self::SessionContext,
272 ) -> xtee_utee::error::Result<()> {
273 __XteeIntoTaResult::into_ta_result(#invoke_call)
274 }
275 }
276 };
277
278 let parsed: syn::File = syn::parse2(impl_block).map_err(|e| {
279 syn::Error::new(
280 proc_macro2::Span::call_site(),
281 format!("xtee_ta: failed to parse generated items: {e}"),
282 )
283 })?;
284 for item in parsed.items {
285 items.push(item);
286 }
287
288 Ok(())
289}
290
291fn extract_marker_and_strip(func: &mut ItemFn) -> Option<String> {
292 let mut marker: Option<String> = None;
293 func.attrs.retain(|attr| {
294 let Some(last) = attr.path().segments.last() else {
295 return true;
296 };
297 let name = last.ident.to_string();
298 let is_marker = matches!(
299 name.as_str(),
300 "ta_create"
301 | "ta_open_session"
302 | "ta_close_session"
303 | "ta_destroy"
304 | "ta_invoke_command"
305 | "ta_acl_check"
306 );
307 if is_marker {
308 marker = Some(name);
309 false
310 } else {
311 true
312 }
313 });
314 marker
315}
316
317fn build_context_and_calls(
318 open_fn: &ItemFn,
319 close_fn: &ItemFn,
320 invoke_fn: &ItemFn,
321) -> Result<(TokenStream2, TokenStream2, TokenStream2, TokenStream2), syn::Error> {
322 let open_arg_count = open_fn.sig.inputs.len();
323 let close_arg_count = close_fn.sig.inputs.len();
324 let invoke_arg_count = invoke_fn.sig.inputs.len();
325
326 if !(open_arg_count == 1 || open_arg_count == 2) {
327 return Err(syn::Error::new(
328 open_fn.sig.span(),
329 "#[ta_open_session] expects fn(&mut Parameters) or fn(&mut Parameters, &mut T)",
330 ));
331 }
332
333 if !(close_arg_count == 0 || close_arg_count == 1) {
334 return Err(syn::Error::new(
335 close_fn.sig.span(),
336 "#[ta_close_session] expects fn() or fn(&mut T)",
337 ));
338 }
339
340 if !(invoke_arg_count == 2 || invoke_arg_count == 3) {
341 return Err(syn::Error::new(
342 invoke_fn.sig.span(),
343 "#[ta_invoke_command] expects fn(cmd_id, &mut Parameters) or fn(&mut T, cmd_id, &mut Parameters)",
344 ));
345 }
346
347 let open_ident = &open_fn.sig.ident;
348 let close_ident = &close_fn.sig.ident;
349 let invoke_ident = &invoke_fn.sig.ident;
350
351 if open_arg_count == 1 {
352 if close_arg_count != 0 || invoke_arg_count != 2 {
353 return Err(syn::Error::new(
354 open_fn.sig.span(),
355 "no-session-context mode requires close_session() and invoke_command(cmd_id, params)",
356 ));
357 }
358 let session_ctx_ty = quote! { () };
359 let open_call = quote! {
360 __XteeIntoTaResult::into_ta_result(#open_ident(params))?;
361 Ok(())
362 };
363 let close_call = quote! { #close_ident() };
364 let invoke_call = quote! { #invoke_ident(cmd_id, params) };
365 return Ok((session_ctx_ty, open_call, close_call, invoke_call));
366 }
367
368 let ctx_ty = extract_mut_ref_type(
369 open_fn
370 .sig
371 .inputs
372 .iter()
373 .nth(1)
374 .expect("checked arg count above"),
375 )?;
376 if close_arg_count != 1 || invoke_arg_count != 3 {
377 return Err(syn::Error::new(
378 open_fn.sig.span(),
379 "session-context mode requires close_session(&mut T) and invoke_command(&mut T, cmd_id, params)",
380 ));
381 }
382 let close_ctx_ty = extract_mut_ref_type(
383 close_fn
384 .sig
385 .inputs
386 .iter()
387 .next()
388 .expect("checked arg count above"),
389 )?;
390 let invoke_ctx_ty = extract_mut_ref_type(
391 invoke_fn
392 .sig
393 .inputs
394 .iter()
395 .next()
396 .expect("checked arg count above"),
397 )?;
398 if quote!(#ctx_ty).to_string() != quote!(#close_ctx_ty).to_string()
399 || quote!(#ctx_ty).to_string() != quote!(#invoke_ctx_ty).to_string()
400 {
401 return Err(syn::Error::new(
402 open_fn.sig.span(),
403 "session context type T must be consistent across ta_open_session/ta_close_session/ta_invoke_command",
404 ));
405 }
406
407 let session_ctx_ty = quote! { #ctx_ty };
408 let open_call = quote! {
409 let mut ctx: #ctx_ty = Default::default();
410 __XteeIntoTaResult::into_ta_result(#open_ident(params, &mut ctx))?;
411 Ok(ctx)
412 };
413 let close_call = quote! { #close_ident(ctx) };
414 let invoke_call = quote! { #invoke_ident(ctx, cmd_id, params) };
415 Ok((session_ctx_ty, open_call, close_call, invoke_call))
416}
417
418fn extract_mut_ref_type(arg: &FnArg) -> Result<&syn::Type, syn::Error> {
419 if let FnArg::Typed(pat_ty) = arg {
420 if let syn::Type::Reference(type_ref) = pat_ty.ty.as_ref() {
421 if type_ref.mutability.is_some() {
422 return Ok(type_ref.elem.as_ref());
423 }
424 }
425 }
426 Err(syn::Error::new(arg.span(), "argument must be &mut T"))
427}