1use proc_macro::TokenStream;
54use quote::quote;
55use syn::{DeriveInput, LitStr, parse_macro_input};
56
57mod generators;
58mod plugins;
59mod schema_reader;
60mod types;
61
62use generators::{
63 generate_enum_module, generate_model_module, generate_type_module, generate_view_module,
64};
65use schema_reader::read_and_parse_schema;
66
67#[proc_macro]
93pub fn prax_schema(input: TokenStream) -> TokenStream {
94 let input = parse_macro_input!(input as LitStr);
95 let schema_path = input.value();
96
97 match generate_from_schema(&schema_path) {
98 Ok(tokens) => tokens.into(),
99 Err(err) => {
100 let err_msg = err.to_string();
101 quote! {
102 compile_error!(#err_msg);
103 }
104 .into()
105 }
106 }
107}
108
109#[proc_macro_derive(Model, attributes(prax))]
149pub fn derive_model(input: TokenStream) -> TokenStream {
150 let input = parse_macro_input!(input as DeriveInput);
151
152 match generators::derive_model_impl(&input) {
153 Ok(tokens) => tokens.into(),
154 Err(err) => err.to_compile_error().into(),
155 }
156}
157
158fn generate_from_schema(schema_path: &str) -> Result<proc_macro2::TokenStream, syn::Error> {
160 use plugins::{PluginConfig, PluginContext, PluginRegistry};
161
162 let schema = read_and_parse_schema(schema_path).map_err(|e| {
164 syn::Error::new(
165 proc_macro2::Span::call_site(),
166 format!("Failed to parse schema: {}", e),
167 )
168 })?;
169
170 let plugin_config = PluginConfig::from_env();
172 let plugin_registry = PluginRegistry::with_builtins();
173 let plugin_ctx = PluginContext::new(&schema, &plugin_config);
174
175 let mut output = proc_macro2::TokenStream::new();
176
177 output.extend(generate_prelude());
179
180 let start_output = plugin_registry.run_start(&plugin_ctx);
182 output.extend(start_output.tokens);
183 output.extend(start_output.root_items);
184
185 for (_, enum_def) in &schema.enums {
187 output.extend(generate_enum_module(enum_def)?);
188
189 let plugin_output = plugin_registry.run_enum(&plugin_ctx, enum_def);
191 if !plugin_output.is_empty() {
192 output.extend(plugin_output.tokens);
194 }
195 }
196
197 for (_, type_def) in &schema.types {
199 output.extend(generate_type_module(type_def)?);
200
201 let plugin_output = plugin_registry.run_type(&plugin_ctx, type_def);
203 if !plugin_output.is_empty() {
204 output.extend(plugin_output.tokens);
205 }
206 }
207
208 for (_, view_def) in &schema.views {
210 output.extend(generate_view_module(view_def)?);
211
212 let plugin_output = plugin_registry.run_view(&plugin_ctx, view_def);
214 if !plugin_output.is_empty() {
215 output.extend(plugin_output.tokens);
216 }
217 }
218
219 for (_, model_def) in &schema.models {
221 output.extend(generate_model_module(model_def, &schema)?);
222
223 let plugin_output = plugin_registry.run_model(&plugin_ctx, model_def);
225 if !plugin_output.is_empty() {
226 output.extend(plugin_output.tokens);
227 }
228 }
229
230 let finish_output = plugin_registry.run_finish(&plugin_ctx);
232 output.extend(finish_output.tokens);
233 output.extend(finish_output.root_items);
234
235 output.extend(plugins::generate_plugin_docs(&plugin_registry));
237
238 Ok(output)
239}
240
241fn generate_prelude() -> proc_macro2::TokenStream {
243 quote! {
244 pub mod _prax_prelude {
246 pub use std::future::Future;
247 pub use std::pin::Pin;
248 pub use std::sync::Arc;
249
250 pub trait PraxModel {
252 const TABLE_NAME: &'static str;
254
255 const PRIMARY_KEY: &'static [&'static str];
257 }
258
259 pub trait ToSqlParam {
261 fn to_sql_param(&self) -> Box<dyn std::any::Any + Send + Sync>;
263 }
264
265 #[derive(Debug, Clone, Default)]
267 pub struct Unset;
268
269 #[derive(Debug, Clone)]
271 pub enum SetParam<T> {
272 Set(T),
274 Unset,
276 }
277
278 impl<T> Default for SetParam<T> {
279 fn default() -> Self {
280 Self::Unset
281 }
282 }
283
284 impl<T> SetParam<T> {
285 pub fn is_set(&self) -> bool {
287 matches!(self, Self::Set(_))
288 }
289
290 pub fn get(&self) -> Option<&T> {
292 match self {
293 Self::Set(v) => Some(v),
294 Self::Unset => None,
295 }
296 }
297
298 pub fn take(self) -> Option<T> {
300 match self {
301 Self::Set(v) => Some(v),
302 Self::Unset => None,
303 }
304 }
305 }
306
307 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
309 pub enum SortOrder {
310 Asc,
312 Desc,
314 }
315
316 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
318 pub enum NullsOrder {
319 First,
321 Last,
323 }
324
325 #[derive(Debug, Clone)]
327 pub struct Cursor<T> {
328 pub value: T,
330 pub direction: CursorDirection,
332 }
333
334 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
336 pub enum CursorDirection {
337 After,
339 Before,
341 }
342 }
343 }
344}
345
#[cfg(test)]
mod tests {
    use super::*;

    /// Renders the generated prelude with `TokenStream::to_string`.
    ///
    /// In that form tokens are space-separated (e.g. `std :: pin :: Pin`), which
    /// is the spelling the expected substrings below use.
    fn prelude_source() -> String {
        generate_prelude().to_string()
    }

    /// Panics unless every entry of `needles` occurs as a substring of `haystack`.
    fn assert_contains_all(haystack: &str, needles: &[&str]) {
        for &needle in needles {
            assert!(haystack.contains(needle));
        }
    }

    #[test]
    fn test_prelude_generation() {
        assert_contains_all(
            &prelude_source(),
            &[
                "pub mod _prax_prelude",
                "pub trait PraxModel",
                "pub enum SortOrder",
                "pub enum SetParam",
            ],
        );
    }

    #[test]
    fn test_prelude_contains_table_name_const() {
        assert_contains_all(&prelude_source(), &["TABLE_NAME", "PRIMARY_KEY"]);
    }

    #[test]
    fn test_prelude_contains_to_sql_param_trait() {
        assert_contains_all(&prelude_source(), &["ToSqlParam", "to_sql_param"]);
    }

    #[test]
    fn test_prelude_contains_unset_type() {
        assert_contains_all(&prelude_source(), &["pub struct Unset"]);
    }

    #[test]
    fn test_prelude_contains_set_param_methods() {
        assert_contains_all(&prelude_source(), &["fn is_set", "fn get", "fn take"]);
    }

    #[test]
    fn test_prelude_contains_sort_order_variants() {
        assert_contains_all(&prelude_source(), &["Asc", "Desc"]);
    }

    #[test]
    fn test_prelude_contains_nulls_order() {
        assert_contains_all(&prelude_source(), &["pub enum NullsOrder", "First", "Last"]);
    }

    #[test]
    fn test_prelude_contains_cursor_types() {
        assert_contains_all(
            &prelude_source(),
            &["pub struct Cursor", "pub enum CursorDirection", "After", "Before"],
        );
    }

    #[test]
    fn test_prelude_contains_std_imports() {
        assert_contains_all(
            &prelude_source(),
            &["std :: future :: Future", "std :: pin :: Pin", "std :: sync :: Arc"],
        );
    }

    #[test]
    fn test_prelude_derive_macros() {
        assert_contains_all(&prelude_source(), &["Clone", "Debug"]);
    }

    #[test]
    fn test_prelude_setparam_default_impl() {
        assert_contains_all(
            &prelude_source(),
            &["impl < T > Default for SetParam", "Self :: Unset"],
        );
    }
}