// by_macros/lib.rs

1extern crate proc_macro;
2
3mod action;
4mod api_model;
5mod api_model_struct;
6mod dynamo_entity;
7mod dynamo_enum;
8mod enum_prop;
9mod mcp_tool;
10pub(crate) mod parse_queryable_fields;
11mod qdrant_entity;
12mod server_fn;
13#[cfg(feature = "server")]
14mod query_builder_functions;
15mod query_display;
16mod rest_error;
17#[cfg(feature = "server")]
18mod sql_model;
19mod sub_partition;
20mod write_file;
21
22use api_model::api_model_impl;
23use dynamo_entity::dynamo_entity_impl;
24use dynamo_enum::dynamo_enum_impl;
25use enum_prop::enum_prop_impl;
26use proc_macro::TokenStream;
27use query_display::query_display_impl;
28use quote::{quote, ToTokens};
29use rest_error::rest_error_impl;
30use sub_partition::sub_partition_impl;
31use syn::{parse_macro_input, Data, DataEnum, DeriveInput, Fields};
32
33#[proc_macro_derive(QueryDisplay)]
34pub fn query_display_derive(input: TokenStream) -> TokenStream {
35    let _ = tracing_subscriber::fmt()
36        .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
37        .with_file(true)
38        .with_line_number(true)
39        .with_thread_ids(true)
40        .with_target(false)
41        .try_init();
42    query_display_impl(input)
43}
44
45#[proc_macro_attribute]
46pub fn api_model(attr: TokenStream, item: TokenStream) -> TokenStream {
47    let _ = tracing_subscriber::fmt()
48        .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
49        .with_file(true)
50        .with_line_number(true)
51        .with_thread_ids(true)
52        .with_target(false)
53        .try_init();
54    api_model_impl(attr.into(), item.into()).into()
55}
56
57#[proc_macro_derive(EnumProp)]
58pub fn enum_prop_derive(input: TokenStream) -> TokenStream {
59    let _ = tracing_subscriber::fmt()
60        .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
61        .with_file(true)
62        .with_line_number(true)
63        .with_thread_ids(true)
64        .with_target(false)
65        .try_init();
66    enum_prop_impl(input)
67}
68
69#[proc_macro_derive(DynamoEntity, attributes(dynamo))]
70pub fn dynamo_entity_derive(input: TokenStream) -> TokenStream {
71    let _ = tracing_subscriber::fmt()
72        .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
73        .with_file(true)
74        .with_line_number(true)
75        .with_thread_ids(true)
76        .with_target(false)
77        .try_init();
78    dynamo_entity_impl(input)
79}
80
81#[proc_macro_derive(QdrantEntity, attributes(qdrant))]
82pub fn qdrant_entity_derive(input: TokenStream) -> TokenStream {
83    let _ = tracing_subscriber::fmt()
84        .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
85        .with_file(true)
86        .with_line_number(true)
87        .with_thread_ids(true)
88        .with_target(false)
89        .try_init();
90    qdrant_entity::qdrant_entity_impl(input)
91}
92
93#[proc_macro_derive(DummyDynamoEntity, attributes(dynamo))]
94pub fn dummy_dynamo_entity_derive(_input: TokenStream) -> TokenStream {
95    TokenStream::new()
96}
97
98#[proc_macro_derive(DummyJsonSchema)]
99pub fn dummy_json_schema_derive(_input: TokenStream) -> TokenStream {
100    TokenStream::new()
101}
102
103#[proc_macro_derive(DummyOperationIo)]
104pub fn dummy_operation_io_derive(_input: TokenStream) -> TokenStream {
105    TokenStream::new()
106}
107
108#[proc_macro_derive(SubPartition)]
109pub fn sub_partition_derive(input: TokenStream) -> TokenStream {
110    let _ = tracing_subscriber::fmt()
111        .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
112        .with_file(true)
113        .with_line_number(true)
114        .with_thread_ids(true)
115        .with_target(false)
116        .try_init();
117    sub_partition_impl(input)
118}
119
120#[proc_macro_derive(DynamoEnum, attributes(dynamo_enum))]
121pub fn dynamo_enum_derive(input: TokenStream) -> TokenStream {
122    let _ = tracing_subscriber::fmt()
123        .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
124        .with_file(true)
125        .with_line_number(true)
126        .with_thread_ids(true)
127        .with_target(false)
128        .try_init();
129    dynamo_enum_impl(input)
130}
131
132/// #[derive(RestError)]
133/// #[rest_error(status = 401, code = 1000)]
134#[proc_macro_derive(RestError, attributes(rest_error))]
135pub fn rest_error_derive(input: TokenStream) -> TokenStream {
136    let _ = tracing_subscriber::fmt()
137        .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
138        .with_file(true)
139        .with_line_number(true)
140        .with_thread_ids(true)
141        .with_target(false)
142        .try_init();
143    rest_error_impl(input)
144}
145
146#[proc_macro_derive(ApiModel)]
147pub fn derive_api_model(input: TokenStream) -> TokenStream {
148    let _ = tracing_subscriber::fmt()
149        .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
150        .with_file(true)
151        .with_line_number(true)
152        .with_thread_ids(true)
153        .with_target(false)
154        .try_init();
155
156    let input = parse_macro_input!(input as DeriveInput);
157    let name = &input.ident;
158
159    let Data::Enum(DataEnum { variants, .. }) = &input.data else {
160        return syn::Error::new_spanned(input.ident, "ApiModel can only be derived for enums")
161            .to_compile_error()
162            .into();
163    };
164
165    let try_from_arms = variants.iter().map(|v| {
166        let ident = &v.ident;
167        let discriminant = match &v.discriminant {
168            Some((_, expr)) => quote! { #expr },
169            None => quote! { compile_error!("Enum variants must have explicit discriminants"); },
170        };
171        tracing::trace!("discriminant: {}", discriminant.to_string());
172        quote! { val if val == #discriminant => Ok(#name::#ident), }
173    });
174
175    let expanded = quote! {
176        impl std::convert::TryFrom<i32> for #name {
177            type Error = String;
178
179            fn try_from(value: i32) -> std::result::Result<Self, Self::Error> {
180                match value {
181                    #(#try_from_arms)*
182                    _ => Err(format!("Invalid {}: {}", stringify!(#name), value)),
183                }
184            }
185        }
186
187        impl std::convert::Into<i32> for #name {
188            fn into(self) -> i32 {
189                self as i32
190            }
191        }
192
193        #[cfg(feature = "server")]
194        impl sqlx::Type<sqlx::Postgres> for #name {
195            fn type_info() -> sqlx::postgres::PgTypeInfo {
196                <i32 as sqlx::Type<sqlx::Postgres>>::type_info()
197            }
198        }
199
200        #[cfg(feature = "server")]
201        impl sqlx::Encode<'_, sqlx::Postgres> for #name {
202            fn encode_by_ref(
203                &self,
204                buf: &mut sqlx::postgres::PgArgumentBuffer,
205            ) -> std::result::Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> {
206                let value: i32 = (*self).clone().into();
207                <i32 as sqlx::Encode<sqlx::Postgres>>::encode_by_ref(&value, buf)
208            }
209        }
210
211        #[cfg(feature = "server")]
212        impl<'r> sqlx::Decode<'r, sqlx::Postgres> for #name {
213            fn decode(
214                value: sqlx::postgres::PgValueRef<'r>,
215            ) -> std::result::Result<Self, sqlx::error::BoxDynError> {
216                let int_value: i32 = <i32 as sqlx::Decode<sqlx::Postgres>>::decode(value)?;
217                #name::try_from(int_value)
218                    .map_err(|e| Box::new(std::io::Error::new(std::io::ErrorKind::InvalidData, e)).into())
219            }
220        }
221
222        impl serde::Serialize for #name {
223            fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
224            where
225                S: serde::Serializer,
226            {
227                serializer.serialize_i32(self.clone() as i32)
228            }
229        }
230
231        impl<'de> serde::Deserialize<'de> for #name {
232            fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
233            where
234                D: serde::Deserializer<'de>,
235            {
236                let value = i32::deserialize(deserializer)?;
237                Self::try_from(value)
238                    .map_err(|v| serde::de::Error::custom(format!("Failed to parse ApiModel: {}", v)))
239            }
240        }
241    };
242
243    tracing::trace!("ApiModel expanded: {}", expanded.to_string());
244
245    TokenStream::from(expanded)
246}
247
248#[proc_macro_derive(DioxusController)]
249pub fn derive_dioxus_controller(input: TokenStream) -> TokenStream {
250    let _ = tracing_subscriber::fmt()
251        .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
252        .with_file(true)
253        .with_line_number(true)
254        .with_thread_ids(true)
255        .with_target(false)
256        .try_init();
257
258    tracing::trace!("starting derive_dioxus_controller");
259    let input = parse_macro_input!(input as DeriveInput);
260    let struct_name = &input.ident;
261
262    let mut generated_methods = vec![];
263
264    if let Data::Struct(data_struct) = input.data {
265        if let Fields::Named(fields) = data_struct.fields {
266            tracing::trace!("starting parsing fields");
267            for field in fields.named {
268                let field_name = &field.ident.unwrap();
269                let field_type = field.ty.to_token_stream().to_string();
270                let field_type = field_type.trim().replace(" ", "");
271
272                tracing::trace!(
273                    "field_name: {}, field_type: {}",
274                    field_name.to_string(),
275                    field_type
276                );
277
278                let method: proc_macro2::TokenStream = if field_type.starts_with("Signal") {
279                    let t = field_type.trim_start_matches("Signal<");
280                    let t: proc_macro2::TokenStream = t[..t.len() - 1].parse().unwrap();
281                    quote! {
282                        pub fn #field_name(&self) -> #t {
283                            (self.#field_name)()
284                        }
285                    }
286                } else if field_type.starts_with("ReadSignal") {
287                    let t = field_type.trim_start_matches("ReadSignal<");
288                    let t: proc_macro2::TokenStream = t[..t.len() - 1].parse().unwrap();
289                    quote! {
290                        pub fn #field_name(&self) -> #t {
291                            (self.#field_name)()
292                        }
293                    }
294                } else if field_type.starts_with("ReadOnlySignal") {
295                    let t = field_type.trim_start_matches("ReadOnlySignal<");
296                    let t: proc_macro2::TokenStream = t[..t.len() - 1].parse().unwrap();
297                    quote! {
298                        pub fn #field_name(&self) -> #t {
299                            (self.#field_name)()
300                        }
301                    }
302                } else if field_type.starts_with("Memo") {
303                    let t = field_type.trim_start_matches("Memo<");
304                    let t: proc_macro2::TokenStream = t[..t.len() - 1].parse().unwrap();
305                    quote! {
306                        pub fn #field_name(&self) -> #t {
307                            (self.#field_name)()
308                        }
309                    }
310                } else if field_type.starts_with("Resource<") {
311                    let t = field_type.trim_start_matches("Resource<");
312                    let t: proc_macro2::TokenStream = t[..t.len() - 1].parse().unwrap();
313
314                    quote! {
315                        pub fn #field_name(&self) -> std::result::Result<#t, RenderError> {
316                            Ok(self.#field_name.suspend()?())
317                        }
318                    }
319                } else if field_type.starts_with("Loader<") {
320                    let t = field_type.trim_start_matches("Loader<");
321                    let t: proc_macro2::TokenStream = t[..t.len() - 1].parse().unwrap();
322
323                    quote! {
324                        pub fn #field_name(&self) -> #t {
325                            (self.#field_name)()
326                        }
327                    }
328                } else {
329                    continue;
330                };
331
332                tracing::trace!("method: {}", method.to_string());
333
334                generated_methods.push(method);
335            }
336        }
337    }
338
339    let expanded = quote! {
340        impl #struct_name {
341            #(#generated_methods)*
342        }
343    };
344
345    save_file(struct_name.to_string().as_str(), &expanded.to_string());
346
347    expanded.into()
348}
349
350/// Marks a Dioxus server function handler as an MCP tool.
351///
352/// This macro extracts the function body into a `{name}_mcp_impl` function that can be
353/// called directly from the MCP server, while keeping the original handler intact for
354/// the `#[post]`/`#[get]` macro to process.
355///
356/// # Usage
357///
358/// ```rust,ignore
359/// #[mcp_tool(name = "create_post", description = "Create and publish a new post.")]
360/// #[post("/api/posts", user: User)]
361/// pub async fn create_post_handler(team_id: Option<TeamPartition>) -> Result<CreatePostResponse> {
362///     // body is extracted into create_post_handler_mcp_impl(user, team_id)
363/// }
364/// ```
365///
366/// The generated `create_post_handler_mcp_impl` has the extracted params (e.g., `user: User`)
367/// as explicit arguments, so the MCP server can call it with `self.user`.
368#[proc_macro_attribute]
369pub fn mcp_tool(attr: TokenStream, item: TokenStream) -> TokenStream {
370    mcp_tool::mcp_tool_impl(attr.into(), item.into()).into()
371}
372
373// HTTP method attribute macros — shadow `dioxus::fullstack::{get,post,...}`.
374// Under `cfg(not(tauri-web))` they forward unchanged to dioxus's macro so
375// server + browser-client code is generated the same way. Under
376// `cfg(tauri-web)` they emit a reqwest stub that calls
377// `crate::common::fullstack::server_fn::<method>` so the bundle no longer
378// depends on dioxus-fullstack's RPC transport (and therefore not on
379// dioxus-web/hydrate).
380
381#[proc_macro_attribute]
382pub fn get(attr: TokenStream, item: TokenStream) -> TokenStream {
383    server_fn::server_fn_impl("GET", attr, item)
384}
385
386#[proc_macro_attribute]
387pub fn post(attr: TokenStream, item: TokenStream) -> TokenStream {
388    server_fn::server_fn_impl("POST", attr, item)
389}
390
391#[proc_macro_attribute]
392pub fn put(attr: TokenStream, item: TokenStream) -> TokenStream {
393    server_fn::server_fn_impl("PUT", attr, item)
394}
395
396#[proc_macro_attribute]
397pub fn patch(attr: TokenStream, item: TokenStream) -> TokenStream {
398    server_fn::server_fn_impl("PATCH", attr, item)
399}
400
401#[proc_macro_attribute]
402pub fn delete(attr: TokenStream, item: TokenStream) -> TokenStream {
403    server_fn::server_fn_impl("DELETE", attr, item)
404}
405
406pub(crate) fn save_file(st_name: &str, output: &str) {
407    if option_env!("WRITE_OUTPUT").is_none() {
408        return;
409    }
410
411    let dir_path = match option_env!("API_MODEL_ARTIFACT_DIR") {
412        Some(dir) => dir.to_string(),
413        None => {
414            let current_dir = std::env::current_dir().unwrap();
415            format!(
416                "{}",
417                current_dir
418                    .join(".build/generated_api_models")
419                    .to_str()
420                    .unwrap()
421            )
422        }
423    };
424    use convert_case::Casing;
425
426    let file_path = format!(
427        "{}/{}.rs",
428        dir_path,
429        st_name.to_case(convert_case::Case::Snake)
430    );
431
432    let dir = std::path::Path::new(&dir_path);
433
434    use std::fs;
435
436    if !dir.exists() {
437        if let Err(e) = fs::create_dir_all(dir) {
438            tracing::error!("Failed to create directory: {}", e);
439        }
440    }
441
442    if let Err(e) = fs::write(&file_path, output.to_string()) {
443        tracing::error!("Failed to write file: {}", e);
444    } else {
445        tracing::info!("generated code {} into {}", st_name, file_path);
446    }
447}