1use proc_macro::TokenStream;
54use quote::quote;
55use syn::{DeriveInput, LitStr, parse_macro_input};
56
57mod generators;
58mod plugins;
59mod schema_reader;
60mod types;
61
62use generators::{
63 generate_enum_module, generate_model_module_with_style, generate_type_module,
64 generate_view_module,
65};
66
67#[proc_macro]
93pub fn prax_schema(input: TokenStream) -> TokenStream {
94 let input = parse_macro_input!(input as LitStr);
95 let schema_path = input.value();
96
97 match generate_from_schema(&schema_path) {
98 Ok(tokens) => tokens.into(),
99 Err(err) => {
100 let err_msg = err.to_string();
101 quote! {
102 compile_error!(#err_msg);
103 }
104 .into()
105 }
106 }
107}
108
109#[proc_macro_derive(Model, attributes(prax))]
149pub fn derive_model(input: TokenStream) -> TokenStream {
150 let input = parse_macro_input!(input as DeriveInput);
151
152 match generators::derive_model_impl(&input) {
153 Ok(tokens) => tokens.into(),
154 Err(err) => err.to_compile_error().into(),
155 }
156}
157
158fn generate_from_schema(schema_path: &str) -> Result<proc_macro2::TokenStream, syn::Error> {
160 use plugins::{PluginConfig, PluginContext, PluginRegistry};
161 use schema_reader::read_schema_with_config;
162
163 let schema_with_config = read_schema_with_config(schema_path).map_err(|e| {
165 syn::Error::new(
166 proc_macro2::Span::call_site(),
167 format!("Failed to parse schema: {}", e),
168 )
169 })?;
170
171 let schema = schema_with_config.schema;
172 let model_style = schema_with_config.model_style;
173
174 let plugin_config = PluginConfig::with_model_style(model_style);
177 let plugin_registry = PluginRegistry::with_builtins();
178 let plugin_ctx = PluginContext::new(&schema, &plugin_config);
179
180 let mut output = proc_macro2::TokenStream::new();
181
182 output.extend(generate_prelude());
184
185 let start_output = plugin_registry.run_start(&plugin_ctx);
187 output.extend(start_output.tokens);
188 output.extend(start_output.root_items);
189
190 for (_, enum_def) in &schema.enums {
192 output.extend(generate_enum_module(enum_def)?);
193
194 let plugin_output = plugin_registry.run_enum(&plugin_ctx, enum_def);
196 if !plugin_output.is_empty() {
197 output.extend(plugin_output.tokens);
199 }
200 }
201
202 for (_, type_def) in &schema.types {
204 output.extend(generate_type_module(type_def)?);
205
206 let plugin_output = plugin_registry.run_type(&plugin_ctx, type_def);
208 if !plugin_output.is_empty() {
209 output.extend(plugin_output.tokens);
210 }
211 }
212
213 for (_, view_def) in &schema.views {
215 output.extend(generate_view_module(view_def)?);
216
217 let plugin_output = plugin_registry.run_view(&plugin_ctx, view_def);
219 if !plugin_output.is_empty() {
220 output.extend(plugin_output.tokens);
221 }
222 }
223
224 for (_, model_def) in &schema.models {
226 output.extend(generate_model_module_with_style(
227 model_def,
228 &schema,
229 model_style,
230 )?);
231
232 let plugin_output = plugin_registry.run_model(&plugin_ctx, model_def);
234 if !plugin_output.is_empty() {
235 output.extend(plugin_output.tokens);
236 }
237 }
238
239 let finish_output = plugin_registry.run_finish(&plugin_ctx);
241 output.extend(finish_output.tokens);
242 output.extend(finish_output.root_items);
243
244 output.extend(plugins::generate_plugin_docs(&plugin_registry));
246
247 Ok(output)
248}
249
250fn generate_prelude() -> proc_macro2::TokenStream {
252 quote! {
253 pub mod _prax_prelude {
255 pub use std::future::Future;
256 pub use std::pin::Pin;
257 pub use std::sync::Arc;
258
259 pub trait PraxModel {
261 const TABLE_NAME: &'static str;
263
264 const PRIMARY_KEY: &'static [&'static str];
266 }
267
268 pub trait ToSqlParam {
270 fn to_sql_param(&self) -> Box<dyn std::any::Any + Send + Sync>;
272 }
273
274 #[derive(Debug, Clone, Default)]
276 pub struct Unset;
277
278 #[derive(Debug, Clone)]
280 pub enum SetParam<T> {
281 Set(T),
283 Unset,
285 }
286
287 impl<T> Default for SetParam<T> {
288 fn default() -> Self {
289 Self::Unset
290 }
291 }
292
293 impl<T> SetParam<T> {
294 pub fn is_set(&self) -> bool {
296 matches!(self, Self::Set(_))
297 }
298
299 pub fn get(&self) -> Option<&T> {
301 match self {
302 Self::Set(v) => Some(v),
303 Self::Unset => None,
304 }
305 }
306
307 pub fn take(self) -> Option<T> {
309 match self {
310 Self::Set(v) => Some(v),
311 Self::Unset => None,
312 }
313 }
314 }
315
316 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
318 pub enum SortOrder {
319 Asc,
321 Desc,
323 }
324
325 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
327 pub enum NullsOrder {
328 First,
330 Last,
332 }
333
334 #[derive(Debug, Clone)]
336 pub struct Cursor<T> {
337 pub value: T,
339 pub direction: CursorDirection,
341 }
342
343 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
345 pub enum CursorDirection {
346 After,
348 Before,
350 }
351 }
352 }
353}
354
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that each needle appears in the `Display` form of a fresh prelude.
    fn assert_prelude_contains(needles: &[&str]) {
        let rendered = generate_prelude().to_string();
        for &needle in needles {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn test_prelude_generation() {
        assert_prelude_contains(&[
            "pub mod _prax_prelude",
            "pub trait PraxModel",
            "pub enum SortOrder",
            "pub enum SetParam",
        ]);
    }

    #[test]
    fn test_prelude_contains_table_name_const() {
        assert_prelude_contains(&["TABLE_NAME", "PRIMARY_KEY"]);
    }

    #[test]
    fn test_prelude_contains_to_sql_param_trait() {
        assert_prelude_contains(&["ToSqlParam", "to_sql_param"]);
    }

    #[test]
    fn test_prelude_contains_unset_type() {
        assert_prelude_contains(&["pub struct Unset"]);
    }

    #[test]
    fn test_prelude_contains_set_param_methods() {
        assert_prelude_contains(&["fn is_set", "fn get", "fn take"]);
    }

    #[test]
    fn test_prelude_contains_sort_order_variants() {
        assert_prelude_contains(&["Asc", "Desc"]);
    }

    #[test]
    fn test_prelude_contains_nulls_order() {
        assert_prelude_contains(&["pub enum NullsOrder", "First", "Last"]);
    }

    #[test]
    fn test_prelude_contains_cursor_types() {
        assert_prelude_contains(&[
            "pub struct Cursor",
            "pub enum CursorDirection",
            "After",
            "Before",
        ]);
    }

    // proc_macro2 renders paths with spaces around `::`, hence the spaced needles.
    #[test]
    fn test_prelude_contains_std_imports() {
        assert_prelude_contains(&[
            "std :: future :: Future",
            "std :: pin :: Pin",
            "std :: sync :: Arc",
        ]);
    }

    #[test]
    fn test_prelude_derive_macros() {
        assert_prelude_contains(&["Clone", "Debug"]);
    }

    #[test]
    fn test_prelude_setparam_default_impl() {
        assert_prelude_contains(&["impl < T > Default for SetParam", "Self :: Unset"]);
    }
}