1extern crate proc_macro;
2
3mod action;
4mod api_model;
5mod api_model_struct;
6mod dynamo_entity;
7mod dynamo_enum;
8mod enum_prop;
9mod mcp_tool;
10pub(crate) mod parse_queryable_fields;
11mod qdrant_entity;
12mod server_fn;
13#[cfg(feature = "server")]
14mod query_builder_functions;
15mod query_display;
16mod rest_error;
17#[cfg(feature = "server")]
18mod sql_model;
19mod sub_partition;
20mod write_file;
21
22use api_model::api_model_impl;
23use dynamo_entity::dynamo_entity_impl;
24use dynamo_enum::dynamo_enum_impl;
25use enum_prop::enum_prop_impl;
26use proc_macro::TokenStream;
27use query_display::query_display_impl;
28use quote::{quote, ToTokens};
29use rest_error::rest_error_impl;
30use sub_partition::sub_partition_impl;
31use syn::{parse_macro_input, Data, DataEnum, DeriveInput, Fields};
32
33#[proc_macro_derive(QueryDisplay)]
34pub fn query_display_derive(input: TokenStream) -> TokenStream {
35 let _ = tracing_subscriber::fmt()
36 .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
37 .with_file(true)
38 .with_line_number(true)
39 .with_thread_ids(true)
40 .with_target(false)
41 .try_init();
42 query_display_impl(input)
43}
44
45#[proc_macro_attribute]
46pub fn api_model(attr: TokenStream, item: TokenStream) -> TokenStream {
47 let _ = tracing_subscriber::fmt()
48 .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
49 .with_file(true)
50 .with_line_number(true)
51 .with_thread_ids(true)
52 .with_target(false)
53 .try_init();
54 api_model_impl(attr.into(), item.into()).into()
55}
56
57#[proc_macro_derive(EnumProp)]
58pub fn enum_prop_derive(input: TokenStream) -> TokenStream {
59 let _ = tracing_subscriber::fmt()
60 .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
61 .with_file(true)
62 .with_line_number(true)
63 .with_thread_ids(true)
64 .with_target(false)
65 .try_init();
66 enum_prop_impl(input)
67}
68
69#[proc_macro_derive(DynamoEntity, attributes(dynamo))]
70pub fn dynamo_entity_derive(input: TokenStream) -> TokenStream {
71 let _ = tracing_subscriber::fmt()
72 .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
73 .with_file(true)
74 .with_line_number(true)
75 .with_thread_ids(true)
76 .with_target(false)
77 .try_init();
78 dynamo_entity_impl(input)
79}
80
81#[proc_macro_derive(QdrantEntity, attributes(qdrant))]
82pub fn qdrant_entity_derive(input: TokenStream) -> TokenStream {
83 let _ = tracing_subscriber::fmt()
84 .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
85 .with_file(true)
86 .with_line_number(true)
87 .with_thread_ids(true)
88 .with_target(false)
89 .try_init();
90 qdrant_entity::qdrant_entity_impl(input)
91}
92
93#[proc_macro_derive(DummyDynamoEntity, attributes(dynamo))]
94pub fn dummy_dynamo_entity_derive(_input: TokenStream) -> TokenStream {
95 TokenStream::new()
96}
97
98#[proc_macro_derive(DummyJsonSchema)]
99pub fn dummy_json_schema_derive(_input: TokenStream) -> TokenStream {
100 TokenStream::new()
101}
102
103#[proc_macro_derive(DummyOperationIo)]
104pub fn dummy_operation_io_derive(_input: TokenStream) -> TokenStream {
105 TokenStream::new()
106}
107
108#[proc_macro_derive(SubPartition)]
109pub fn sub_partition_derive(input: TokenStream) -> TokenStream {
110 let _ = tracing_subscriber::fmt()
111 .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
112 .with_file(true)
113 .with_line_number(true)
114 .with_thread_ids(true)
115 .with_target(false)
116 .try_init();
117 sub_partition_impl(input)
118}
119
120#[proc_macro_derive(DynamoEnum, attributes(dynamo_enum))]
121pub fn dynamo_enum_derive(input: TokenStream) -> TokenStream {
122 let _ = tracing_subscriber::fmt()
123 .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
124 .with_file(true)
125 .with_line_number(true)
126 .with_thread_ids(true)
127 .with_target(false)
128 .try_init();
129 dynamo_enum_impl(input)
130}
131
132#[proc_macro_derive(RestError, attributes(rest_error))]
135pub fn rest_error_derive(input: TokenStream) -> TokenStream {
136 let _ = tracing_subscriber::fmt()
137 .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
138 .with_file(true)
139 .with_line_number(true)
140 .with_thread_ids(true)
141 .with_target(false)
142 .try_init();
143 rest_error_impl(input)
144}
145
146#[proc_macro_derive(ApiModel)]
147pub fn derive_api_model(input: TokenStream) -> TokenStream {
148 let _ = tracing_subscriber::fmt()
149 .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
150 .with_file(true)
151 .with_line_number(true)
152 .with_thread_ids(true)
153 .with_target(false)
154 .try_init();
155
156 let input = parse_macro_input!(input as DeriveInput);
157 let name = &input.ident;
158
159 let Data::Enum(DataEnum { variants, .. }) = &input.data else {
160 return syn::Error::new_spanned(input.ident, "ApiModel can only be derived for enums")
161 .to_compile_error()
162 .into();
163 };
164
165 let try_from_arms = variants.iter().map(|v| {
166 let ident = &v.ident;
167 let discriminant = match &v.discriminant {
168 Some((_, expr)) => quote! { #expr },
169 None => quote! { compile_error!("Enum variants must have explicit discriminants"); },
170 };
171 tracing::trace!("discriminant: {}", discriminant.to_string());
172 quote! { val if val == #discriminant => Ok(#name::#ident), }
173 });
174
175 let expanded = quote! {
176 impl std::convert::TryFrom<i32> for #name {
177 type Error = String;
178
179 fn try_from(value: i32) -> std::result::Result<Self, Self::Error> {
180 match value {
181 #(#try_from_arms)*
182 _ => Err(format!("Invalid {}: {}", stringify!(#name), value)),
183 }
184 }
185 }
186
187 impl std::convert::Into<i32> for #name {
188 fn into(self) -> i32 {
189 self as i32
190 }
191 }
192
193 #[cfg(feature = "server")]
194 impl sqlx::Type<sqlx::Postgres> for #name {
195 fn type_info() -> sqlx::postgres::PgTypeInfo {
196 <i32 as sqlx::Type<sqlx::Postgres>>::type_info()
197 }
198 }
199
200 #[cfg(feature = "server")]
201 impl sqlx::Encode<'_, sqlx::Postgres> for #name {
202 fn encode_by_ref(
203 &self,
204 buf: &mut sqlx::postgres::PgArgumentBuffer,
205 ) -> std::result::Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> {
206 let value: i32 = (*self).clone().into();
207 <i32 as sqlx::Encode<sqlx::Postgres>>::encode_by_ref(&value, buf)
208 }
209 }
210
211 #[cfg(feature = "server")]
212 impl<'r> sqlx::Decode<'r, sqlx::Postgres> for #name {
213 fn decode(
214 value: sqlx::postgres::PgValueRef<'r>,
215 ) -> std::result::Result<Self, sqlx::error::BoxDynError> {
216 let int_value: i32 = <i32 as sqlx::Decode<sqlx::Postgres>>::decode(value)?;
217 #name::try_from(int_value)
218 .map_err(|e| Box::new(std::io::Error::new(std::io::ErrorKind::InvalidData, e)).into())
219 }
220 }
221
222 impl serde::Serialize for #name {
223 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
224 where
225 S: serde::Serializer,
226 {
227 serializer.serialize_i32(self.clone() as i32)
228 }
229 }
230
231 impl<'de> serde::Deserialize<'de> for #name {
232 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
233 where
234 D: serde::Deserializer<'de>,
235 {
236 let value = i32::deserialize(deserializer)?;
237 Self::try_from(value)
238 .map_err(|v| serde::de::Error::custom(format!("Failed to parse ApiModel: {}", v)))
239 }
240 }
241 };
242
243 tracing::trace!("ApiModel expanded: {}", expanded.to_string());
244
245 TokenStream::from(expanded)
246}
247
248#[proc_macro_derive(DioxusController)]
249pub fn derive_dioxus_controller(input: TokenStream) -> TokenStream {
250 let _ = tracing_subscriber::fmt()
251 .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
252 .with_file(true)
253 .with_line_number(true)
254 .with_thread_ids(true)
255 .with_target(false)
256 .try_init();
257
258 tracing::trace!("starting derive_dioxus_controller");
259 let input = parse_macro_input!(input as DeriveInput);
260 let struct_name = &input.ident;
261
262 let mut generated_methods = vec![];
263
264 if let Data::Struct(data_struct) = input.data {
265 if let Fields::Named(fields) = data_struct.fields {
266 tracing::trace!("starting parsing fields");
267 for field in fields.named {
268 let field_name = &field.ident.unwrap();
269 let field_type = field.ty.to_token_stream().to_string();
270 let field_type = field_type.trim().replace(" ", "");
271
272 tracing::trace!(
273 "field_name: {}, field_type: {}",
274 field_name.to_string(),
275 field_type
276 );
277
278 let method: proc_macro2::TokenStream = if field_type.starts_with("Signal") {
279 let t = field_type.trim_start_matches("Signal<");
280 let t: proc_macro2::TokenStream = t[..t.len() - 1].parse().unwrap();
281 quote! {
282 pub fn #field_name(&self) -> #t {
283 (self.#field_name)()
284 }
285 }
286 } else if field_type.starts_with("ReadSignal") {
287 let t = field_type.trim_start_matches("ReadSignal<");
288 let t: proc_macro2::TokenStream = t[..t.len() - 1].parse().unwrap();
289 quote! {
290 pub fn #field_name(&self) -> #t {
291 (self.#field_name)()
292 }
293 }
294 } else if field_type.starts_with("ReadOnlySignal") {
295 let t = field_type.trim_start_matches("ReadOnlySignal<");
296 let t: proc_macro2::TokenStream = t[..t.len() - 1].parse().unwrap();
297 quote! {
298 pub fn #field_name(&self) -> #t {
299 (self.#field_name)()
300 }
301 }
302 } else if field_type.starts_with("Memo") {
303 let t = field_type.trim_start_matches("Memo<");
304 let t: proc_macro2::TokenStream = t[..t.len() - 1].parse().unwrap();
305 quote! {
306 pub fn #field_name(&self) -> #t {
307 (self.#field_name)()
308 }
309 }
310 } else if field_type.starts_with("Resource<") {
311 let t = field_type.trim_start_matches("Resource<");
312 let t: proc_macro2::TokenStream = t[..t.len() - 1].parse().unwrap();
313
314 quote! {
315 pub fn #field_name(&self) -> std::result::Result<#t, RenderError> {
316 Ok(self.#field_name.suspend()?())
317 }
318 }
319 } else if field_type.starts_with("Loader<") {
320 let t = field_type.trim_start_matches("Loader<");
321 let t: proc_macro2::TokenStream = t[..t.len() - 1].parse().unwrap();
322
323 quote! {
324 pub fn #field_name(&self) -> #t {
325 (self.#field_name)()
326 }
327 }
328 } else {
329 continue;
330 };
331
332 tracing::trace!("method: {}", method.to_string());
333
334 generated_methods.push(method);
335 }
336 }
337 }
338
339 let expanded = quote! {
340 impl #struct_name {
341 #(#generated_methods)*
342 }
343 };
344
345 save_file(struct_name.to_string().as_str(), &expanded.to_string());
346
347 expanded.into()
348}
349
350#[proc_macro_attribute]
369pub fn mcp_tool(attr: TokenStream, item: TokenStream) -> TokenStream {
370 mcp_tool::mcp_tool_impl(attr.into(), item.into()).into()
371}
372
373#[proc_macro_attribute]
382pub fn get(attr: TokenStream, item: TokenStream) -> TokenStream {
383 server_fn::server_fn_impl("GET", attr, item)
384}
385
386#[proc_macro_attribute]
387pub fn post(attr: TokenStream, item: TokenStream) -> TokenStream {
388 server_fn::server_fn_impl("POST", attr, item)
389}
390
391#[proc_macro_attribute]
392pub fn put(attr: TokenStream, item: TokenStream) -> TokenStream {
393 server_fn::server_fn_impl("PUT", attr, item)
394}
395
396#[proc_macro_attribute]
397pub fn patch(attr: TokenStream, item: TokenStream) -> TokenStream {
398 server_fn::server_fn_impl("PATCH", attr, item)
399}
400
401#[proc_macro_attribute]
402pub fn delete(attr: TokenStream, item: TokenStream) -> TokenStream {
403 server_fn::server_fn_impl("DELETE", attr, item)
404}
405
406pub(crate) fn save_file(st_name: &str, output: &str) {
407 if option_env!("WRITE_OUTPUT").is_none() {
408 return;
409 }
410
411 let dir_path = match option_env!("API_MODEL_ARTIFACT_DIR") {
412 Some(dir) => dir.to_string(),
413 None => {
414 let current_dir = std::env::current_dir().unwrap();
415 format!(
416 "{}",
417 current_dir
418 .join(".build/generated_api_models")
419 .to_str()
420 .unwrap()
421 )
422 }
423 };
424 use convert_case::Casing;
425
426 let file_path = format!(
427 "{}/{}.rs",
428 dir_path,
429 st_name.to_case(convert_case::Case::Snake)
430 );
431
432 let dir = std::path::Path::new(&dir_path);
433
434 use std::fs;
435
436 if !dir.exists() {
437 if let Err(e) = fs::create_dir_all(dir) {
438 tracing::error!("Failed to create directory: {}", e);
439 }
440 }
441
442 if let Err(e) = fs::write(&file_path, output.to_string()) {
443 tracing::error!("Failed to write file: {}", e);
444 } else {
445 tracing::info!("generated code {} into {}", st_name, file_path);
446 }
447}