1use proc_macro::TokenStream;
54use quote::quote;
55use syn::{DeriveInput, LitStr, parse_macro_input};
56
57mod generators;
58mod plugins;
59mod schema_reader;
60mod types;
61
62use generators::{
63 generate_enum_module, generate_model_module_with_style, generate_type_module,
64 generate_view_module,
65};
66
67#[proc_macro]
93pub fn prax_schema(input: TokenStream) -> TokenStream {
94 let input = parse_macro_input!(input as LitStr);
95 let schema_path = input.value();
96
97 match generate_from_schema(&schema_path) {
98 Ok(tokens) => tokens.into(),
99 Err(err) => {
100 let err_msg = err.to_string();
101 quote! {
102 compile_error!(#err_msg);
103 }
104 .into()
105 }
106 }
107}
108
109#[proc_macro_derive(Model, attributes(prax))]
149pub fn derive_model(input: TokenStream) -> TokenStream {
150 let input = parse_macro_input!(input as DeriveInput);
151
152 match generators::derive_model_impl(&input) {
153 Ok(tokens) => tokens.into(),
154 Err(err) => err.to_compile_error().into(),
155 }
156}
157
158fn generate_from_schema(schema_path: &str) -> Result<proc_macro2::TokenStream, syn::Error> {
160 use plugins::{PluginConfig, PluginContext, PluginRegistry};
161 use schema_reader::read_schema_with_config;
162
163 let schema_with_config = read_schema_with_config(schema_path).map_err(|e| {
165 syn::Error::new(
166 proc_macro2::Span::call_site(),
167 format!("Failed to parse schema: {}", e),
168 )
169 })?;
170
171 let schema = schema_with_config.schema;
172 let model_style = schema_with_config.model_style;
173
174 let plugin_config = PluginConfig::with_model_style(model_style);
177 let plugin_registry = PluginRegistry::with_builtins();
178 let plugin_ctx = PluginContext::new(&schema, &plugin_config);
179
180 let mut output = proc_macro2::TokenStream::new();
181
182 output.extend(generate_prelude());
184
185 let start_output = plugin_registry.run_start(&plugin_ctx);
187 output.extend(start_output.tokens);
188 output.extend(start_output.root_items);
189
190 for (_, enum_def) in &schema.enums {
192 output.extend(generate_enum_module(enum_def)?);
193
194 let plugin_output = plugin_registry.run_enum(&plugin_ctx, enum_def);
196 if !plugin_output.is_empty() {
197 output.extend(plugin_output.tokens);
199 }
200 }
201
202 for (_, type_def) in &schema.types {
204 output.extend(generate_type_module(type_def)?);
205
206 let plugin_output = plugin_registry.run_type(&plugin_ctx, type_def);
208 if !plugin_output.is_empty() {
209 output.extend(plugin_output.tokens);
210 }
211 }
212
213 for (_, view_def) in &schema.views {
215 output.extend(generate_view_module(view_def)?);
216
217 let plugin_output = plugin_registry.run_view(&plugin_ctx, view_def);
219 if !plugin_output.is_empty() {
220 output.extend(plugin_output.tokens);
221 }
222 }
223
224 for (_, model_def) in &schema.models {
226 output.extend(generate_model_module_with_style(model_def, &schema, model_style)?);
227
228 let plugin_output = plugin_registry.run_model(&plugin_ctx, model_def);
230 if !plugin_output.is_empty() {
231 output.extend(plugin_output.tokens);
232 }
233 }
234
235 let finish_output = plugin_registry.run_finish(&plugin_ctx);
237 output.extend(finish_output.tokens);
238 output.extend(finish_output.root_items);
239
240 output.extend(plugins::generate_plugin_docs(&plugin_registry));
242
243 Ok(output)
244}
245
246fn generate_prelude() -> proc_macro2::TokenStream {
248 quote! {
249 pub mod _prax_prelude {
251 pub use std::future::Future;
252 pub use std::pin::Pin;
253 pub use std::sync::Arc;
254
255 pub trait PraxModel {
257 const TABLE_NAME: &'static str;
259
260 const PRIMARY_KEY: &'static [&'static str];
262 }
263
264 pub trait ToSqlParam {
266 fn to_sql_param(&self) -> Box<dyn std::any::Any + Send + Sync>;
268 }
269
270 #[derive(Debug, Clone, Default)]
272 pub struct Unset;
273
274 #[derive(Debug, Clone)]
276 pub enum SetParam<T> {
277 Set(T),
279 Unset,
281 }
282
283 impl<T> Default for SetParam<T> {
284 fn default() -> Self {
285 Self::Unset
286 }
287 }
288
289 impl<T> SetParam<T> {
290 pub fn is_set(&self) -> bool {
292 matches!(self, Self::Set(_))
293 }
294
295 pub fn get(&self) -> Option<&T> {
297 match self {
298 Self::Set(v) => Some(v),
299 Self::Unset => None,
300 }
301 }
302
303 pub fn take(self) -> Option<T> {
305 match self {
306 Self::Set(v) => Some(v),
307 Self::Unset => None,
308 }
309 }
310 }
311
312 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
314 pub enum SortOrder {
315 Asc,
317 Desc,
319 }
320
321 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
323 pub enum NullsOrder {
324 First,
326 Last,
328 }
329
330 #[derive(Debug, Clone)]
332 pub struct Cursor<T> {
333 pub value: T,
335 pub direction: CursorDirection,
337 }
338
339 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
341 pub enum CursorDirection {
342 After,
344 Before,
346 }
347 }
348 }
349}
350
#[cfg(test)]
mod tests {
    use super::*;

    /// Renders the generated prelude to a string and asserts that every
    /// expected fragment occurs in it.
    fn assert_prelude_contains(expected: &[&str]) {
        let rendered = generate_prelude().to_string();
        for &fragment in expected {
            assert!(
                rendered.contains(fragment),
                "generated prelude is missing `{}`",
                fragment
            );
        }
    }

    #[test]
    fn test_prelude_generation() {
        assert_prelude_contains(&[
            "pub mod _prax_prelude",
            "pub trait PraxModel",
            "pub enum SortOrder",
            "pub enum SetParam",
        ]);
    }

    #[test]
    fn test_prelude_contains_table_name_const() {
        assert_prelude_contains(&["TABLE_NAME", "PRIMARY_KEY"]);
    }

    #[test]
    fn test_prelude_contains_to_sql_param_trait() {
        assert_prelude_contains(&["ToSqlParam", "to_sql_param"]);
    }

    #[test]
    fn test_prelude_contains_unset_type() {
        assert_prelude_contains(&["pub struct Unset"]);
    }

    #[test]
    fn test_prelude_contains_set_param_methods() {
        assert_prelude_contains(&["fn is_set", "fn get", "fn take"]);
    }

    #[test]
    fn test_prelude_contains_sort_order_variants() {
        assert_prelude_contains(&["Asc", "Desc"]);
    }

    #[test]
    fn test_prelude_contains_nulls_order() {
        assert_prelude_contains(&["pub enum NullsOrder", "First", "Last"]);
    }

    #[test]
    fn test_prelude_contains_cursor_types() {
        assert_prelude_contains(&[
            "pub struct Cursor",
            "pub enum CursorDirection",
            "After",
            "Before",
        ]);
    }

    #[test]
    fn test_prelude_contains_std_imports() {
        // `TokenStream::to_string` puts spaces around `::`, hence the spaced paths.
        assert_prelude_contains(&[
            "std :: future :: Future",
            "std :: pin :: Pin",
            "std :: sync :: Arc",
        ]);
    }

    #[test]
    fn test_prelude_derive_macros() {
        assert_prelude_contains(&["Clone", "Debug"]);
    }

    #[test]
    fn test_prelude_setparam_default_impl() {
        assert_prelude_contains(&["impl < T > Default for SetParam", "Self :: Unset"]);
    }
}