prax_codegen/
lib.rs

1//! Procedural macros for the Prax ORM.
2//!
3//! This crate provides compile-time code generation for Prax, transforming
4//! schema definitions into type-safe Rust code.
5//!
6//! # Macros
7//!
8//! - [`prax_schema!`] - Generate models from a `.prax` schema file
9//! - [`Model`] - Derive macro for manual model definition
10//!
11//! # Plugins
12//!
13//! Code generation can be extended with plugins enabled via environment variables:
14//!
15//! ```bash
16//! # Enable debug information
17//! PRAX_PLUGIN_DEBUG=1 cargo build
18//!
19//! # Enable JSON Schema generation
20//! PRAX_PLUGIN_JSON_SCHEMA=1 cargo build
21//!
22//! # Enable GraphQL SDL generation
23//! PRAX_PLUGIN_GRAPHQL=1 cargo build
24//!
25//! # Enable custom serialization helpers
26//! PRAX_PLUGIN_SERDE=1 cargo build
27//!
28//! # Enable runtime validation
29//! PRAX_PLUGIN_VALIDATOR=1 cargo build
30//!
31//! # Enable all plugins
32//! PRAX_PLUGINS_ALL=1 cargo build
33//! ```
34//!
35//! # Example
36//!
37//! ```rust,ignore
38//! // Generate models from schema file
39//! prax::prax_schema!("schema.prax");
40//!
41//! // Or manually define with derive macro
42//! #[derive(prax::Model)]
43//! #[prax(table = "users")]
44//! struct User {
45//!     #[prax(id, auto)]
46//!     id: i32,
47//!     #[prax(unique)]
48//!     email: String,
49//!     name: Option<String>,
50//! }
51//! ```
52
53use proc_macro::TokenStream;
54use quote::quote;
55use syn::{DeriveInput, LitStr, parse_macro_input};
56
57mod generators;
58mod plugins;
59mod schema_reader;
60mod types;
61
62use generators::{
63    generate_enum_module, generate_model_module_with_style, generate_type_module,
64    generate_view_module,
65};
66
67/// Generate models from a Prax schema file.
68///
69/// This macro reads a `.prax` schema file at compile time and generates
70/// type-safe Rust code for all models, enums, and types defined in the schema.
71///
72/// # Example
73///
74/// ```rust,ignore
75/// prax::prax_schema!("schema.prax");
76///
77/// // Now you can use the generated types:
78/// let user = client.user().find_unique(user::id::equals(1)).exec().await?;
79/// ```
80///
81/// # Generated Code
82///
83/// For each model in the schema, this macro generates:
84/// - A module with the model name (snake_case)
85/// - A `Data` struct representing a row from the database
86/// - A `CreateInput` struct for creating new records
87/// - A `UpdateInput` struct for updating records
88/// - Field modules with filter operations (`equals`, `contains`, `in_`, etc.)
89/// - A `WhereParam` enum for type-safe filtering
90/// - An `OrderByParam` enum for sorting
91/// - Select and Include builders for partial queries
92#[proc_macro]
93pub fn prax_schema(input: TokenStream) -> TokenStream {
94    let input = parse_macro_input!(input as LitStr);
95    let schema_path = input.value();
96
97    match generate_from_schema(&schema_path) {
98        Ok(tokens) => tokens.into(),
99        Err(err) => {
100            let err_msg = err.to_string();
101            quote! {
102                compile_error!(#err_msg);
103            }
104            .into()
105        }
106    }
107}
108
109/// Derive macro for defining Prax models manually.
110///
111/// This derive macro allows you to define models in Rust code instead of
112/// using a `.prax` schema file. It generates the same query builder methods
113/// and type-safe operations.
114///
115/// # Attributes
116///
117/// ## Struct-level
118/// - `#[prax(table = "table_name")]` - Map to a different table name
119/// - `#[prax(schema = "schema_name")]` - Specify database schema
120///
121/// ## Field-level
122/// - `#[prax(id)]` - Mark as primary key
123/// - `#[prax(auto)]` - Auto-increment field
124/// - `#[prax(unique)]` - Unique constraint
125/// - `#[prax(default = value)]` - Default value
126/// - `#[prax(column = "col_name")]` - Map to different column
127/// - `#[prax(relation(...))]` - Define relation
128///
129/// # Example
130///
131/// ```rust,ignore
132/// #[derive(prax::Model)]
133/// #[prax(table = "users")]
134/// struct User {
135///     #[prax(id, auto)]
136///     id: i32,
137///
138///     #[prax(unique)]
139///     email: String,
140///
141///     #[prax(column = "display_name")]
142///     name: Option<String>,
143///
144///     #[prax(default = "now()")]
145///     created_at: chrono::DateTime<chrono::Utc>,
146/// }
147/// ```
148#[proc_macro_derive(Model, attributes(prax))]
149pub fn derive_model(input: TokenStream) -> TokenStream {
150    let input = parse_macro_input!(input as DeriveInput);
151
152    match generators::derive_model_impl(&input) {
153        Ok(tokens) => tokens.into(),
154        Err(err) => err.to_compile_error().into(),
155    }
156}
157
158/// Internal function to generate code from a schema file.
159fn generate_from_schema(schema_path: &str) -> Result<proc_macro2::TokenStream, syn::Error> {
160    use plugins::{PluginConfig, PluginContext, PluginRegistry};
161    use schema_reader::read_schema_with_config;
162
163    // Read and parse the schema file along with prax.toml configuration
164    let schema_with_config = read_schema_with_config(schema_path).map_err(|e| {
165        syn::Error::new(
166            proc_macro2::Span::call_site(),
167            format!("Failed to parse schema: {}", e),
168        )
169    })?;
170
171    let schema = schema_with_config.schema;
172    let model_style = schema_with_config.model_style;
173
174    // Initialize plugin system with model_style from prax.toml
175    // This auto-enables graphql plugins when model_style is GraphQL
176    let plugin_config = PluginConfig::with_model_style(model_style);
177    let plugin_registry = PluginRegistry::with_builtins();
178    let plugin_ctx = PluginContext::new(&schema, &plugin_config);
179
180    let mut output = proc_macro2::TokenStream::new();
181
182    // Generate prelude with common imports
183    output.extend(generate_prelude());
184
185    // Run plugin start hooks
186    let start_output = plugin_registry.run_start(&plugin_ctx);
187    output.extend(start_output.tokens);
188    output.extend(start_output.root_items);
189
190    // Generate enums first (models may reference them)
191    for (_, enum_def) in &schema.enums {
192        output.extend(generate_enum_module(enum_def)?);
193
194        // Run plugin enum hooks
195        let plugin_output = plugin_registry.run_enum(&plugin_ctx, enum_def);
196        if !plugin_output.is_empty() {
197            // Add plugin output to the enum module
198            output.extend(plugin_output.tokens);
199        }
200    }
201
202    // Generate composite types
203    for (_, type_def) in &schema.types {
204        output.extend(generate_type_module(type_def)?);
205
206        // Run plugin type hooks
207        let plugin_output = plugin_registry.run_type(&plugin_ctx, type_def);
208        if !plugin_output.is_empty() {
209            output.extend(plugin_output.tokens);
210        }
211    }
212
213    // Generate views
214    for (_, view_def) in &schema.views {
215        output.extend(generate_view_module(view_def)?);
216
217        // Run plugin view hooks
218        let plugin_output = plugin_registry.run_view(&plugin_ctx, view_def);
219        if !plugin_output.is_empty() {
220            output.extend(plugin_output.tokens);
221        }
222    }
223
224    // Generate models with the configured model style
225    for (_, model_def) in &schema.models {
226        output.extend(generate_model_module_with_style(model_def, &schema, model_style)?);
227
228        // Run plugin model hooks
229        let plugin_output = plugin_registry.run_model(&plugin_ctx, model_def);
230        if !plugin_output.is_empty() {
231            output.extend(plugin_output.tokens);
232        }
233    }
234
235    // Run plugin finish hooks
236    let finish_output = plugin_registry.run_finish(&plugin_ctx);
237    output.extend(finish_output.tokens);
238    output.extend(finish_output.root_items);
239
240    // Generate plugin documentation
241    output.extend(plugins::generate_plugin_docs(&plugin_registry));
242
243    Ok(output)
244}
245
246/// Generate the prelude module with common types and imports.
247fn generate_prelude() -> proc_macro2::TokenStream {
248    quote! {
249        /// Common types used by generated Prax models.
250        pub mod _prax_prelude {
251            pub use std::future::Future;
252            pub use std::pin::Pin;
253            pub use std::sync::Arc;
254
255            /// Marker trait for Prax models.
256            pub trait PraxModel {
257                /// The table name in the database.
258                const TABLE_NAME: &'static str;
259
260                /// The primary key column(s).
261                const PRIMARY_KEY: &'static [&'static str];
262            }
263
264            /// Trait for types that can be converted to SQL parameters.
265            pub trait ToSqlParam {
266                /// Convert to a boxed SQL parameter.
267                fn to_sql_param(&self) -> Box<dyn std::any::Any + Send + Sync>;
268            }
269
270            /// Marker for optional fields in queries.
271            #[derive(Debug, Clone, Default)]
272            pub struct Unset;
273
274            /// Set or unset field wrapper for updates.
275            #[derive(Debug, Clone)]
276            pub enum SetParam<T> {
277                /// Set the field to a value.
278                Set(T),
279                /// Leave the field unchanged.
280                Unset,
281            }
282
283            impl<T> Default for SetParam<T> {
284                fn default() -> Self {
285                    Self::Unset
286                }
287            }
288
289            impl<T> SetParam<T> {
290                /// Check if the value is set.
291                pub fn is_set(&self) -> bool {
292                    matches!(self, Self::Set(_))
293                }
294
295                /// Get the inner value if set.
296                pub fn get(&self) -> Option<&T> {
297                    match self {
298                        Self::Set(v) => Some(v),
299                        Self::Unset => None,
300                    }
301                }
302
303                /// Take the inner value if set.
304                pub fn take(self) -> Option<T> {
305                    match self {
306                        Self::Set(v) => Some(v),
307                        Self::Unset => None,
308                    }
309                }
310            }
311
312            /// Sort direction for order by clauses.
313            #[derive(Debug, Clone, Copy, PartialEq, Eq)]
314            pub enum SortOrder {
315                /// Ascending order (A-Z, 0-9).
316                Asc,
317                /// Descending order (Z-A, 9-0).
318                Desc,
319            }
320
321            /// Null handling in sorting.
322            #[derive(Debug, Clone, Copy, PartialEq, Eq)]
323            pub enum NullsOrder {
324                /// Nulls first in the result.
325                First,
326                /// Nulls last in the result.
327                Last,
328            }
329
330            /// Pagination cursor for cursor-based pagination.
331            #[derive(Debug, Clone)]
332            pub struct Cursor<T> {
333                /// The field value to start from.
334                pub value: T,
335                /// The direction of pagination.
336                pub direction: CursorDirection,
337            }
338
339            /// Direction for cursor-based pagination.
340            #[derive(Debug, Clone, Copy, PartialEq, Eq)]
341            pub enum CursorDirection {
342                /// Get records after the cursor.
343                After,
344                /// Get records before the cursor.
345                Before,
346            }
347        }
348    }
349}
350
#[cfg(test)]
mod tests {
    use super::*;

    /// Renders the generated prelude with `proc_macro2`'s token formatting,
    /// in which tokens are space-separated (e.g. `std :: pin :: Pin`).
    fn prelude_code() -> String {
        generate_prelude().to_string()
    }

    #[test]
    fn test_prelude_generation() {
        let code = prelude_code();

        assert!(code.contains("pub mod _prax_prelude"));
        assert!(code.contains("pub trait PraxModel"));
        assert!(code.contains("pub enum SortOrder"));
        assert!(code.contains("pub enum SetParam"));
    }

    #[test]
    fn test_prelude_contains_table_name_const() {
        let code = prelude_code();

        assert!(code.contains("TABLE_NAME"));
        assert!(code.contains("PRIMARY_KEY"));
    }

    #[test]
    fn test_prelude_contains_to_sql_param_trait() {
        let code = prelude_code();

        assert!(code.contains("ToSqlParam"));
        assert!(code.contains("to_sql_param"));
    }

    #[test]
    fn test_prelude_contains_unset_type() {
        let code = prelude_code();

        assert!(code.contains("pub struct Unset"));
    }

    #[test]
    fn test_prelude_contains_set_param_methods() {
        let code = prelude_code();

        assert!(code.contains("fn is_set"));
        assert!(code.contains("fn get"));
        assert!(code.contains("fn take"));
    }

    #[test]
    fn test_prelude_contains_sort_order_variants() {
        let code = prelude_code();

        assert!(code.contains("Asc"));
        assert!(code.contains("Desc"));
    }

    #[test]
    fn test_prelude_contains_nulls_order() {
        let code = prelude_code();

        assert!(code.contains("pub enum NullsOrder"));
        assert!(code.contains("First"));
        assert!(code.contains("Last"));
    }

    #[test]
    fn test_prelude_contains_cursor_types() {
        let code = prelude_code();

        assert!(code.contains("pub struct Cursor"));
        assert!(code.contains("pub enum CursorDirection"));
        assert!(code.contains("After"));
        assert!(code.contains("Before"));
    }

    #[test]
    fn test_prelude_contains_std_imports() {
        let code = prelude_code();

        assert!(code.contains("std :: future :: Future"));
        assert!(code.contains("std :: pin :: Pin"));
        assert!(code.contains("std :: sync :: Arc"));
    }

    #[test]
    fn test_prelude_derive_macros() {
        // Check the derive attached to `SetParam` itself. A bare
        // `contains("Clone")` is also satisfied by the derives on `Unset`,
        // `SortOrder`, etc., so it cannot detect `SetParam` losing its derive.
        // Whitespace is stripped so the match does not depend on token spacing.
        let compact: String = prelude_code().split_whitespace().collect();
        assert!(compact.contains("#[derive(Debug,Clone)]pubenumSetParam"));
    }

    #[test]
    fn test_prelude_setparam_default_impl() {
        let code = prelude_code();

        // Should have Default implementation
        assert!(code.contains("impl < T > Default for SetParam"));
        assert!(code.contains("Self :: Unset"));
    }
}