Skip to main content

aizu_macros/
lib.rs

1//! Proc macros for Aizu SDK.
2//!
3//! This crate provides the `#[query]`, `#[mutation]`, `#[action]`, `#[http]`, and `#[ws]` macros.
4
5use proc_macro::TokenStream;
6use proc_macro2::TokenStream as TokenStream2;
7use quote::{format_ident, quote};
8use syn::{parse_macro_input, FnArg, ItemFn, Pat, PatType, ReturnType, LitStr, Token, Ident};
9use syn::parse::{Parse, ParseStream};
10
11/// **Query** - Read-only database access.
12///
13/// Queries are deterministic and cacheable. Use for fetching data without side effects.
14///
15/// | Capability | Allowed |
16/// |------------|---------|
17/// | Read DB    | Yes     |
18/// | Write DB   | No      |
19/// | HTTP       | No      |
20/// | Auth       | Yes     |
21///
22/// # Example
23///
24/// ```ignore
25/// #[query]
26/// pub fn get_user(ctx: &Ctx, id: Id<User>) -> Option<User> {
27///     ctx.db.get(id)
28/// }
29/// ```
30#[proc_macro_attribute]
31pub fn query(_attr: TokenStream, item: TokenStream) -> TokenStream {
32    let input = parse_macro_input!(item as ItemFn);
33    generate_export(input, "query").into()
34}
35
36/// **Mutation** - Read and write database access.
37///
38/// Mutations are transactional and automatically retried on conflict.
39/// Use for creating, updating, or deleting data.
40///
41/// | Capability | Allowed |
42/// |------------|---------|
43/// | Read DB    | Yes     |
44/// | Write DB   | Yes     |
45/// | HTTP       | No      |
46/// | Auth       | Yes     |
47///
48/// # Example
49///
50/// ```ignore
51/// #[mutation]
52/// pub fn create_todo(ctx: &Ctx, title: String) -> Id<Todo> {
53///     ctx.db.insert(&Todo { title, completed: false }).unwrap()
54/// }
55/// ```
56#[proc_macro_attribute]
57pub fn mutation(_attr: TokenStream, item: TokenStream) -> TokenStream {
58    let input = parse_macro_input!(item as ItemFn);
59    generate_export(input, "mutation").into()
60}
61
62/// **Action** - Full access: database + HTTP.
63///
64/// Actions can perform external I/O like HTTP requests. Not automatically retried.
65/// Use for workflows that need to call external APIs.
66///
67/// | Capability | Allowed |
68/// |------------|---------|
69/// | Read DB    | Yes     |
70/// | Write DB   | Yes     |
71/// | HTTP       | Yes     |
72/// | Auth       | Yes     |
73///
74/// # Example
75///
76/// ```ignore
77/// #[action]
78/// pub fn process_payment(ctx: &Ctx, order_id: Id<Order>) -> Result<bool, String> {
79///     let order = ctx.db.get(order_id).ok_or("Order not found")?;
80///     ctx.http.post("https://api.stripe.com/charge", order.amount)?;
81///     ctx.db.insert(&PaymentRecord { order_id, status: "paid" })?;
82///     Ok(true)
83/// }
84/// ```
85#[proc_macro_attribute]
86pub fn action(_attr: TokenStream, item: TokenStream) -> TokenStream {
87    let input = parse_macro_input!(item as ItemFn);
88    generate_export(input, "action").into()
89}
90
91/// **HTTP** - HTTP route handler.
92///
93/// Registers a function as an HTTP endpoint. The function receives an `HttpRequest`
94/// and should return a `RouteResponse`.
95///
96/// | Capability | Allowed |
97/// |------------|---------|
98/// | Read DB    | Yes     |
99/// | Write DB   | Yes     |
100/// | HTTP       | Yes     |
101/// | Auth       | Yes     |
102///
103/// # Example
104///
105/// ```ignore
106/// #[http(GET, "/api/todos")]
107/// pub fn list_todos(ctx: &Ctx, req: HttpRequest) -> RouteResponse {
108///     let todos = ctx.db.query::<Todo>().collect::<Vec<_>>();
109///     RouteResponse::json(&todos)
110/// }
111///
112/// #[http(GET, "/api/todos/:id")]
113/// pub fn get_todo(ctx: &Ctx, req: HttpRequest) -> RouteResponse {
114///     let id = req.param("id").unwrap();
115///     match ctx.db.get::<Todo>(Id::from_str(id)) {
116///         Some(todo) => RouteResponse::json(&todo),
117///         None => RouteResponse::not_found("Todo not found"),
118///     }
119/// }
120///
121/// #[http(POST, "/api/todos")]
122/// pub fn create_todo(ctx: &Ctx, req: HttpRequest) -> RouteResponse {
123///     let body: CreateTodo = req.json().unwrap();
124///     let id = ctx.db.insert(&Todo { title: body.title, completed: false }).unwrap();
125///     RouteResponse::json(&id).with_status(201)
126/// }
127/// ```
128#[proc_macro_attribute]
129pub fn http(attr: TokenStream, item: TokenStream) -> TokenStream {
130    let args = parse_macro_input!(attr as HttpArgs);
131    let input = parse_macro_input!(item as ItemFn);
132    generate_http_export(input, args).into()
133}
134
135/// Arguments for the `#[http]` macro: `#[http(METHOD, "/path")]`
136struct HttpArgs {
137    method: Ident,
138    path: LitStr,
139}
140
141impl Parse for HttpArgs {
142    fn parse(input: ParseStream) -> syn::Result<Self> {
143        let method: Ident = input.parse()?;
144        input.parse::<Token![,]>()?;
145        let path: LitStr = input.parse()?;
146        Ok(HttpArgs { method, path })
147    }
148}
149
/// Normalize a URL path to a valid identifier suffix.
///
/// Leading slashes are stripped. Every character other than an ASCII letter, ASCII digit
/// or `_` then becomes `_`, so the result can always be appended to an `__aizu_..._`
/// export prefix. The suffix is kept ASCII-only because `#[no_mangle]` items require
/// ASCII identifiers. The bare root path `/` maps to `root`.
///
/// - `/api/todos/:id` → `api_todos__id`
/// - `/api/v1.0/*` → `api_v1_0__`
/// - `/` → `root`
///
/// Distinct paths can normalize to the same suffix (e.g. `/a-b` and `/a_b`). Two such
/// routes would generate the same export name.
fn normalize_path(path: &str) -> String {
    if path == "/" {
        return "root".to_string();
    }

    path.trim_start_matches('/')
        .chars()
        .map(|c| if c.is_ascii_alphanumeric() || c == '_' { c } else { '_' })
        .collect()
}
164
165/// Generate the HTTP handler export.
166fn generate_http_export(func: ItemFn, args: HttpArgs) -> TokenStream2 {
167    let fn_name = &func.sig.ident;
168    let fn_vis = &func.vis;
169    let fn_block = &func.block;
170    let fn_output = &func.sig.output;
171    let fn_attrs = &func.attrs;
172
173    let method = args.method.to_string().to_uppercase();
174    let path = args.path.value();
175    let normalized_path = normalize_path(&path);
176
177    // Export name: __aizu_http_{METHOD}_{path_normalized}
178    let export_name = format_ident!("__aizu_http_{}_{}", method.to_lowercase(), normalized_path);
179
180    // The function signature should be: fn(ctx: &Ctx, req: HttpRequest) -> RouteResponse
181    // We'll pass the serialized HttpRequest as args
182
183    // Determine if return type is Result-wrapped
184    let is_result = match fn_output {
185        ReturnType::Default => false,
186        ReturnType::Type(_, ty) => {
187            let ty_str = quote!(#ty).to_string();
188            ty_str.starts_with("Result") || ty_str.contains(":: Result")
189        }
190    };
191
192    // Generate result handling that serializes RouteResponse as JSON
193    let result_handling = if is_result {
194        quote! {
195            match result {
196                Ok(response) => {
197                    match aizu::serde_json::to_vec(&response) {
198                        Ok(data) => {
199                            aizu::set_result(&data);
200                            0
201                        }
202                        Err(_) => -3,
203                    }
204                }
205                Err(e) => {
206                    let err_response = aizu::RouteResponse::error(500, &aizu::format!("{:?}", e));
207                    match aizu::serde_json::to_vec(&err_response) {
208                        Ok(data) => {
209                            aizu::set_result(&data);
210                            -100
211                        }
212                        Err(_) => -3,
213                    }
214                }
215            }
216        }
217    } else {
218        quote! {
219            match aizu::serde_json::to_vec(&result) {
220                Ok(data) => {
221                    aizu::set_result(&data);
222                    0
223                }
224                Err(_) => -3,
225            }
226        }
227    };
228
229    let expanded = quote! {
230        // Original function (kept for direct calls in tests)
231        #(#fn_attrs)*
232        #fn_vis fn #fn_name(ctx: &aizu::Ctx, req: aizu::HttpRequest) #fn_output
233        #fn_block
234
235        // WASM export for HTTP handler
236        #[unsafe(no_mangle)]
237        pub extern "C" fn #export_name(
238            args_ptr: i32,
239            args_len: i32,
240        ) -> i32 {
241            aizu::with_ctx(|ctx| {
242                // Deserialize HttpRequest from JSON
243                let args_slice = if args_ptr == 0 || args_len <= 0 {
244                    &[]
245                } else {
246                    unsafe {
247                        core::slice::from_raw_parts(args_ptr as *const u8, args_len as usize)
248                    }
249                };
250
251                let req: aizu::HttpRequest = match aizu::serde_json::from_slice(args_slice) {
252                    Ok(r) => r,
253                    Err(_) => return -1,
254                };
255
256                // Call the user function
257                let result = #fn_name(ctx, req);
258
259                // Serialize RouteResponse and return
260                #result_handling
261            })
262        }
263    };
264
265    expanded
266}
267
268/// Generate the WASM export wrapper for a function.
269///
270/// Generates exports compatible with Tatara's calling convention:
271/// - Function: `__aizu_{kind}_{name}(args_ptr, args_len) -> i32`
272/// - Returns 0 on success, negative error code on failure
273/// - Results written to static buffer, read via get_result_ptr/get_result_len
274fn generate_export(func: ItemFn, kind: &str) -> TokenStream2 {
275    let fn_name = &func.sig.ident;
276    let fn_vis = &func.vis;
277    let fn_block = &func.block;
278    let fn_output = &func.sig.output;
279    let fn_attrs = &func.attrs;
280
281    // Generate export name: __aizu_{kind}_{name}
282    let export_name = format_ident!("__aizu_{}_{}", kind, fn_name);
283
284    // Extract function arguments (skip first `ctx: &Ctx`)
285    let args: Vec<_> = func.sig.inputs.iter().skip(1).collect();
286
287    // Extract argument names and types
288    let mut arg_names = Vec::new();
289    let mut arg_types = Vec::new();
290
291    for arg in &args {
292        if let FnArg::Typed(PatType { pat, ty, .. }) = arg
293            && let Pat::Ident(pat_ident) = &**pat
294        {
295            arg_names.push(&pat_ident.ident);
296            arg_types.push(&**ty);
297        }
298    }
299
300    // Generate args struct name
301    let args_struct_name = format_ident!("__{}_Args", fn_name);
302
303    // Determine return type for serialization
304    let (_return_type, is_result) = match fn_output {
305        ReturnType::Default => (quote! { () }, false),
306        ReturnType::Type(_, ty) => {
307            // Check if it's a Result type
308            let ty_str = quote!(#ty).to_string();
309            if ty_str.starts_with("Result") || ty_str.contains(":: Result") {
310                (quote! { #ty }, true)
311            } else {
312                (quote! { #ty }, false)
313            }
314        }
315    };
316
317    // Generate the result handling code
318    // Use to_vec_named to serialize structs as maps with named keys (not arrays)
319    let result_handling = if is_result {
320        quote! {
321            match result {
322                Ok(value) => {
323                    match aizu::rmp_serde::to_vec_named(&value) {
324                        Ok(data) => {
325                            aizu::set_result(&data);
326                            0
327                        }
328                        Err(_) => -3, // Serialization error
329                    }
330                }
331                Err(e) => {
332                    // Serialize error as string (use Display, not Debug, to avoid quotes)
333                    let err_msg = aizu::format!("{}", e);
334                    aizu::set_result(err_msg.as_bytes());
335                    -100 // Error indicator
336                }
337            }
338        }
339    } else {
340        quote! {
341            match aizu::rmp_serde::to_vec_named(&result) {
342                Ok(data) => {
343                    aizu::set_result(&data);
344                    0
345                }
346                Err(_) => -3, // Serialization error
347            }
348        }
349    };
350
351    // Handle the case of zero arguments (no args struct needed)
352    let args_handling = if arg_names.is_empty() {
353        quote! {
354            // No arguments to deserialize
355            let _ = args_slice;
356        }
357    } else {
358        quote! {
359            let args: #args_struct_name = match aizu::rmp_serde::from_slice(args_slice) {
360                Ok(a) => a,
361                Err(_) => return -1, // Deserialization error
362            };
363        }
364    };
365
366    // Generate the function call with or without args
367    let fn_call = if arg_names.is_empty() {
368        quote! { #fn_name(ctx) }
369    } else {
370        quote! { #fn_name(ctx, #(args.#arg_names),*) }
371    };
372
373    // Generate the args struct only if there are arguments
374    let args_struct = if arg_names.is_empty() {
375        quote! {}
376    } else {
377        quote! {
378            #[derive(aizu::serde::Deserialize)]
379            #[allow(non_camel_case_types)]
380            struct #args_struct_name {
381                #(#arg_names: #arg_types),*
382            }
383        }
384    };
385
386    // Generate the wrapper
387    let expanded = quote! {
388        // Original function (kept for direct calls in tests)
389        #(#fn_attrs)*
390        #fn_vis fn #fn_name(ctx: &aizu::Ctx, #(#arg_names: #arg_types),*) #fn_output
391        #fn_block
392
393        // Args struct for deserialization
394        #args_struct
395
396        // WASM export (Tatara calling convention: args_ptr, args_len -> status)
397        #[unsafe(no_mangle)]
398        pub extern "C" fn #export_name(
399            args_ptr: i32,
400            args_len: i32,
401        ) -> i32 {
402            aizu::with_ctx(|ctx| {
403                // Get args slice
404                let args_slice = if args_ptr == 0 || args_len <= 0 {
405                    &[]
406                } else {
407                    unsafe {
408                        core::slice::from_raw_parts(args_ptr as *const u8, args_len as usize)
409                    }
410                };
411
412                #args_handling
413
414                // Call the user function
415                let result = #fn_call;
416
417                // Serialize and return
418                #result_handling
419            })
420        }
421    };
422
423    expanded
424}
425
426/// **WebSocket** - WebSocket route handler.
427///
428/// Registers a function as a WebSocket endpoint. The function receives a `WsMessage`
429/// and should return a `WsResponse`.
430///
431/// | Capability | Allowed |
432/// |------------|---------|
433/// | Read DB    | Yes     |
434/// | Write DB   | Yes     |
435/// | HTTP       | Yes     |
436/// | Auth       | Yes     |
437///
438/// # Example
439///
440/// ```ignore
441/// #[ws("/ws/chat")]
442/// pub fn chat_handler(ctx: &Ctx, msg: WsMessage) -> WsResponse {
443///     match msg {
444///         WsMessage::Connect { params, .. } => {
445///             // Client connected
446///             WsResponse::text("Welcome!")
447///         }
448///         WsMessage::Text(text) => {
449///             // Echo the message back
450///             WsResponse::text(format!("You said: {}", text))
451///         }
452///         WsMessage::Close { .. } => {
453///             // Client disconnecting
454///             WsResponse::ok()
455///         }
456///         _ => WsResponse::ok(),
457///     }
458/// }
459/// ```
460#[proc_macro_attribute]
461pub fn ws(attr: TokenStream, item: TokenStream) -> TokenStream {
462    let args = parse_macro_input!(attr as WsArgs);
463    let input = parse_macro_input!(item as ItemFn);
464    generate_ws_export(input, args).into()
465}
466
467/// Arguments for the `#[ws]` macro: `#[ws("/path")]`
468struct WsArgs {
469    path: LitStr,
470}
471
472impl Parse for WsArgs {
473    fn parse(input: ParseStream) -> syn::Result<Self> {
474        let path: LitStr = input.parse()?;
475        Ok(WsArgs { path })
476    }
477}
478
479/// Generate the WebSocket handler export.
480fn generate_ws_export(func: ItemFn, args: WsArgs) -> TokenStream2 {
481    let fn_name = &func.sig.ident;
482    let fn_vis = &func.vis;
483    let fn_block = &func.block;
484    let fn_output = &func.sig.output;
485    let fn_attrs = &func.attrs;
486
487    let path = args.path.value();
488    let normalized_path = normalize_path(&path);
489
490    // Export name: __aizu_ws_{path_normalized}
491    let export_name = format_ident!("__aizu_ws_{}", normalized_path);
492
493    // The function signature should be: fn(ctx: &Ctx, msg: WsMessage) -> WsResponse
494
495    // Determine if return type is Result-wrapped
496    let is_result = match fn_output {
497        ReturnType::Default => false,
498        ReturnType::Type(_, ty) => {
499            let ty_str = quote!(#ty).to_string();
500            ty_str.starts_with("Result") || ty_str.contains(":: Result")
501        }
502    };
503
504    // Generate result handling that serializes WsResponse as JSON
505    let result_handling = if is_result {
506        quote! {
507            match result {
508                Ok(response) => {
509                    match aizu::serde_json::to_vec(&response) {
510                        Ok(data) => {
511                            aizu::set_result(&data);
512                            0
513                        }
514                        Err(_) => -3,
515                    }
516                }
517                Err(e) => {
518                    let err_response = aizu::WsResponse::close_with_error(&aizu::format!("{:?}", e));
519                    match aizu::serde_json::to_vec(&err_response) {
520                        Ok(data) => {
521                            aizu::set_result(&data);
522                            -100
523                        }
524                        Err(_) => -3,
525                    }
526                }
527            }
528        }
529    } else {
530        quote! {
531            match aizu::serde_json::to_vec(&result) {
532                Ok(data) => {
533                    aizu::set_result(&data);
534                    0
535                }
536                Err(_) => -3,
537            }
538        }
539    };
540
541    let expanded = quote! {
542        // Original function (kept for direct calls in tests)
543        #(#fn_attrs)*
544        #fn_vis fn #fn_name(ctx: &aizu::Ctx, msg: aizu::WsMessage) #fn_output
545        #fn_block
546
547        // WASM export for WebSocket handler
548        #[unsafe(no_mangle)]
549        pub extern "C" fn #export_name(
550            args_ptr: i32,
551            args_len: i32,
552        ) -> i32 {
553            aizu::with_ctx(|ctx| {
554                // Deserialize WsMessage from JSON
555                let args_slice = if args_ptr == 0 || args_len <= 0 {
556                    &[]
557                } else {
558                    unsafe {
559                        core::slice::from_raw_parts(args_ptr as *const u8, args_len as usize)
560                    }
561                };
562
563                let msg: aizu::WsMessage = match aizu::serde_json::from_slice(args_slice) {
564                    Ok(m) => m,
565                    Err(_) => return -1,
566                };
567
568                // Call the user function
569                let result = #fn_name(ctx, msg);
570
571                // Serialize WsResponse and return
572                #result_handling
573            })
574        }
575    };
576
577    expanded
578}