prax_codegen/
lib.rs

1//! Procedural macros for the Prax ORM.
2//!
3//! This crate provides compile-time code generation for Prax, transforming
4//! schema definitions into type-safe Rust code.
5//!
6//! # Macros
7//!
8//! - [`prax_schema!`] - Generate models from a `.prax` schema file
9//! - [`Model`] - Derive macro for manual model definition
10//!
11//! # Plugins
12//!
13//! Code generation can be extended with plugins enabled via environment variables:
14//!
15//! ```bash
16//! # Enable debug information
17//! PRAX_PLUGIN_DEBUG=1 cargo build
18//!
19//! # Enable JSON Schema generation
20//! PRAX_PLUGIN_JSON_SCHEMA=1 cargo build
21//!
22//! # Enable GraphQL SDL generation
23//! PRAX_PLUGIN_GRAPHQL=1 cargo build
24//!
25//! # Enable custom serialization helpers
26//! PRAX_PLUGIN_SERDE=1 cargo build
27//!
28//! # Enable runtime validation
29//! PRAX_PLUGIN_VALIDATOR=1 cargo build
30//!
31//! # Enable all plugins
32//! PRAX_PLUGINS_ALL=1 cargo build
33//! ```
34//!
35//! # Example
36//!
37//! ```rust,ignore
38//! // Generate models from schema file
39//! prax::prax_schema!("schema.prax");
40//!
41//! // Or manually define with derive macro
42//! #[derive(prax::Model)]
43//! #[prax(table = "users")]
44//! struct User {
45//!     #[prax(id, auto)]
46//!     id: i32,
47//!     #[prax(unique)]
48//!     email: String,
49//!     name: Option<String>,
50//! }
51//! ```
52
53use proc_macro::TokenStream;
54use quote::quote;
55use syn::{DeriveInput, LitStr, parse_macro_input};
56
57mod generators;
58mod plugins;
59mod schema_reader;
60mod types;
61
62use generators::{
63    generate_enum_module, generate_model_module_with_style, generate_type_module,
64    generate_view_module,
65};
66
67/// Generate models from a Prax schema file.
68///
69/// This macro reads a `.prax` schema file at compile time and generates
70/// type-safe Rust code for all models, enums, and types defined in the schema.
71///
72/// # Example
73///
74/// ```rust,ignore
75/// prax::prax_schema!("schema.prax");
76///
77/// // Now you can use the generated types:
78/// let user = client.user().find_unique(user::id::equals(1)).exec().await?;
79/// ```
80///
81/// # Generated Code
82///
83/// For each model in the schema, this macro generates:
84/// - A module with the model name (snake_case)
85/// - A `Data` struct representing a row from the database
86/// - A `CreateInput` struct for creating new records
87/// - A `UpdateInput` struct for updating records
88/// - Field modules with filter operations (`equals`, `contains`, `in_`, etc.)
89/// - A `WhereParam` enum for type-safe filtering
90/// - An `OrderByParam` enum for sorting
91/// - Select and Include builders for partial queries
92#[proc_macro]
93pub fn prax_schema(input: TokenStream) -> TokenStream {
94    let input = parse_macro_input!(input as LitStr);
95    let schema_path = input.value();
96
97    match generate_from_schema(&schema_path) {
98        Ok(tokens) => tokens.into(),
99        Err(err) => {
100            let err_msg = err.to_string();
101            quote! {
102                compile_error!(#err_msg);
103            }
104            .into()
105        }
106    }
107}
108
109/// Derive macro for defining Prax models manually.
110///
111/// This derive macro allows you to define models in Rust code instead of
112/// using a `.prax` schema file. It generates the same query builder methods
113/// and type-safe operations.
114///
115/// # Attributes
116///
117/// ## Struct-level
118/// - `#[prax(table = "table_name")]` - Map to a different table name
119/// - `#[prax(schema = "schema_name")]` - Specify database schema
120///
121/// ## Field-level
122/// - `#[prax(id)]` - Mark as primary key
123/// - `#[prax(auto)]` - Auto-increment field
124/// - `#[prax(unique)]` - Unique constraint
125/// - `#[prax(default = value)]` - Default value
126/// - `#[prax(column = "col_name")]` - Map to different column
127/// - `#[prax(relation(...))]` - Define relation
128///
129/// # Example
130///
131/// ```rust,ignore
132/// #[derive(prax::Model)]
133/// #[prax(table = "users")]
134/// struct User {
135///     #[prax(id, auto)]
136///     id: i32,
137///
138///     #[prax(unique)]
139///     email: String,
140///
141///     #[prax(column = "display_name")]
142///     name: Option<String>,
143///
144///     #[prax(default = "now()")]
145///     created_at: chrono::DateTime<chrono::Utc>,
146/// }
147/// ```
148#[proc_macro_derive(Model, attributes(prax))]
149pub fn derive_model(input: TokenStream) -> TokenStream {
150    let input = parse_macro_input!(input as DeriveInput);
151
152    match generators::derive_model_impl(&input) {
153        Ok(tokens) => tokens.into(),
154        Err(err) => err.to_compile_error().into(),
155    }
156}
157
158/// Internal function to generate code from a schema file.
159fn generate_from_schema(schema_path: &str) -> Result<proc_macro2::TokenStream, syn::Error> {
160    use plugins::{PluginConfig, PluginContext, PluginRegistry};
161    use schema_reader::read_schema_with_config;
162
163    // Read and parse the schema file along with prax.toml configuration
164    let schema_with_config = read_schema_with_config(schema_path).map_err(|e| {
165        syn::Error::new(
166            proc_macro2::Span::call_site(),
167            format!("Failed to parse schema: {}", e),
168        )
169    })?;
170
171    let schema = schema_with_config.schema;
172    let model_style = schema_with_config.model_style;
173
174    // Initialize plugin system with model_style from prax.toml
175    // This auto-enables graphql plugins when model_style is GraphQL
176    let plugin_config = PluginConfig::with_model_style(model_style);
177    let plugin_registry = PluginRegistry::with_builtins();
178    let plugin_ctx = PluginContext::new(&schema, &plugin_config);
179
180    let mut output = proc_macro2::TokenStream::new();
181
182    // Generate prelude with common imports
183    output.extend(generate_prelude());
184
185    // Run plugin start hooks
186    let start_output = plugin_registry.run_start(&plugin_ctx);
187    output.extend(start_output.tokens);
188    output.extend(start_output.root_items);
189
190    // Generate enums first (models may reference them)
191    for (_, enum_def) in &schema.enums {
192        output.extend(generate_enum_module(enum_def)?);
193
194        // Run plugin enum hooks
195        let plugin_output = plugin_registry.run_enum(&plugin_ctx, enum_def);
196        if !plugin_output.is_empty() {
197            // Add plugin output to the enum module
198            output.extend(plugin_output.tokens);
199        }
200    }
201
202    // Generate composite types
203    for (_, type_def) in &schema.types {
204        output.extend(generate_type_module(type_def)?);
205
206        // Run plugin type hooks
207        let plugin_output = plugin_registry.run_type(&plugin_ctx, type_def);
208        if !plugin_output.is_empty() {
209            output.extend(plugin_output.tokens);
210        }
211    }
212
213    // Generate views
214    for (_, view_def) in &schema.views {
215        output.extend(generate_view_module(view_def)?);
216
217        // Run plugin view hooks
218        let plugin_output = plugin_registry.run_view(&plugin_ctx, view_def);
219        if !plugin_output.is_empty() {
220            output.extend(plugin_output.tokens);
221        }
222    }
223
224    // Generate models with the configured model style
225    for (_, model_def) in &schema.models {
226        output.extend(generate_model_module_with_style(
227            model_def,
228            &schema,
229            model_style,
230        )?);
231
232        // Run plugin model hooks
233        let plugin_output = plugin_registry.run_model(&plugin_ctx, model_def);
234        if !plugin_output.is_empty() {
235            output.extend(plugin_output.tokens);
236        }
237    }
238
239    // Run plugin finish hooks
240    let finish_output = plugin_registry.run_finish(&plugin_ctx);
241    output.extend(finish_output.tokens);
242    output.extend(finish_output.root_items);
243
244    // Generate plugin documentation
245    output.extend(plugins::generate_plugin_docs(&plugin_registry));
246
247    Ok(output)
248}
249
250/// Generate the prelude module with common types and imports.
251fn generate_prelude() -> proc_macro2::TokenStream {
252    quote! {
253        /// Common types used by generated Prax models.
254        pub mod _prax_prelude {
255            pub use std::future::Future;
256            pub use std::pin::Pin;
257            pub use std::sync::Arc;
258
259            /// Marker trait for Prax models.
260            pub trait PraxModel {
261                /// The table name in the database.
262                const TABLE_NAME: &'static str;
263
264                /// The primary key column(s).
265                const PRIMARY_KEY: &'static [&'static str];
266            }
267
268            /// Trait for types that can be converted to SQL parameters.
269            pub trait ToSqlParam {
270                /// Convert to a boxed SQL parameter.
271                fn to_sql_param(&self) -> Box<dyn std::any::Any + Send + Sync>;
272            }
273
274            /// Marker for optional fields in queries.
275            #[derive(Debug, Clone, Default)]
276            pub struct Unset;
277
278            /// Set or unset field wrapper for updates.
279            #[derive(Debug, Clone)]
280            pub enum SetParam<T> {
281                /// Set the field to a value.
282                Set(T),
283                /// Leave the field unchanged.
284                Unset,
285            }
286
287            impl<T> Default for SetParam<T> {
288                fn default() -> Self {
289                    Self::Unset
290                }
291            }
292
293            impl<T> SetParam<T> {
294                /// Check if the value is set.
295                pub fn is_set(&self) -> bool {
296                    matches!(self, Self::Set(_))
297                }
298
299                /// Get the inner value if set.
300                pub fn get(&self) -> Option<&T> {
301                    match self {
302                        Self::Set(v) => Some(v),
303                        Self::Unset => None,
304                    }
305                }
306
307                /// Take the inner value if set.
308                pub fn take(self) -> Option<T> {
309                    match self {
310                        Self::Set(v) => Some(v),
311                        Self::Unset => None,
312                    }
313                }
314            }
315
316            /// Sort direction for order by clauses.
317            #[derive(Debug, Clone, Copy, PartialEq, Eq)]
318            pub enum SortOrder {
319                /// Ascending order (A-Z, 0-9).
320                Asc,
321                /// Descending order (Z-A, 9-0).
322                Desc,
323            }
324
325            /// Null handling in sorting.
326            #[derive(Debug, Clone, Copy, PartialEq, Eq)]
327            pub enum NullsOrder {
328                /// Nulls first in the result.
329                First,
330                /// Nulls last in the result.
331                Last,
332            }
333
334            /// Pagination cursor for cursor-based pagination.
335            #[derive(Debug, Clone)]
336            pub struct Cursor<T> {
337                /// The field value to start from.
338                pub value: T,
339                /// The direction of pagination.
340                pub direction: CursorDirection,
341            }
342
343            /// Direction for cursor-based pagination.
344            #[derive(Debug, Clone, Copy, PartialEq, Eq)]
345            pub enum CursorDirection {
346                /// Get records after the cursor.
347                After,
348                /// Get records before the cursor.
349                Before,
350            }
351        }
352    }
353}
354
#[cfg(test)]
mod tests {
    use super::*;

    /// Render the prelude once and assert that each fragment occurs in it, in order.
    ///
    /// Fragments are matched against the `Display` form of the token stream.
    /// Path fragments therefore use spaced separators such as `std :: pin :: Pin`.
    fn assert_prelude_contains(fragments: &[&str]) {
        let rendered = generate_prelude().to_string();
        for &fragment in fragments {
            assert!(rendered.contains(fragment));
        }
    }

    #[test]
    fn test_prelude_generation() {
        assert_prelude_contains(&[
            "pub mod _prax_prelude",
            "pub trait PraxModel",
            "pub enum SortOrder",
            "pub enum SetParam",
        ]);
    }

    #[test]
    fn test_prelude_contains_table_name_const() {
        assert_prelude_contains(&["TABLE_NAME", "PRIMARY_KEY"]);
    }

    #[test]
    fn test_prelude_contains_to_sql_param_trait() {
        assert_prelude_contains(&["ToSqlParam", "to_sql_param"]);
    }

    #[test]
    fn test_prelude_contains_unset_type() {
        assert_prelude_contains(&["pub struct Unset"]);
    }

    #[test]
    fn test_prelude_contains_set_param_methods() {
        assert_prelude_contains(&["fn is_set", "fn get", "fn take"]);
    }

    #[test]
    fn test_prelude_contains_sort_order_variants() {
        assert_prelude_contains(&["Asc", "Desc"]);
    }

    #[test]
    fn test_prelude_contains_nulls_order() {
        assert_prelude_contains(&["pub enum NullsOrder", "First", "Last"]);
    }

    #[test]
    fn test_prelude_contains_cursor_types() {
        assert_prelude_contains(&[
            "pub struct Cursor",
            "pub enum CursorDirection",
            "After",
            "Before",
        ]);
    }

    #[test]
    fn test_prelude_contains_std_imports() {
        assert_prelude_contains(&[
            "std :: future :: Future",
            "std :: pin :: Pin",
            "std :: sync :: Arc",
        ]);
    }

    #[test]
    fn test_prelude_derive_macros() {
        // SetParam (among others) derives Clone and Debug.
        assert_prelude_contains(&["Clone", "Debug"]);
    }

    #[test]
    fn test_prelude_setparam_default_impl() {
        // Default for SetParam is hand-written and returns the Unset variant.
        assert_prelude_contains(&["impl < T > Default for SetParam", "Self :: Unset"]);
    }
}