use proc_macro::TokenStream;
use quote::quote;
use syn::{DeriveInput, LitStr, parse_macro_input};
mod generators;
mod plugins;
mod schema_reader;
mod types;
use generators::{
generate_enum_module, generate_model_module_with_style, generate_type_module,
generate_view_module,
};
/// Function-like macro that expands a Prax schema file into Rust items.
///
/// The macro input must be a single string literal holding the schema path, e.g.
/// `prax_schema!("schema.prax")`. The path is handed to `generate_from_schema`. Any
/// error it returns is turned into compile errors at the invocation rather than a
/// panic inside the macro.
///
/// NOTE(review): the schema is presumably read from disk by `read_schema_with_config`.
/// If so, Cargo is not told that the expansion depends on that file, and editing the
/// schema alone may not trigger a rebuild. Consider also emitting an
/// `include_bytes!`/`include_str!` of the resolved path. Confirm how the path is
/// resolved first.
#[proc_macro]
pub fn prax_schema(input: TokenStream) -> TokenStream {
    let path_lit = parse_macro_input!(input as LitStr);
    // `to_compile_error` emits one fully qualified `::core::compile_error!` per message
    // carried by the `syn::Error`, each with its own span. Stringifying the error would
    // keep only the first message, and an unqualified `compile_error!` can be shadowed
    // by a user macro of the same name.
    generate_from_schema(&path_lit.value())
        .unwrap_or_else(|err| err.to_compile_error())
        .into()
}
/// `#[derive(Model)]` entry point; helper attributes are accepted under `#[prax(...)]`.
///
/// The annotated item is parsed as a `DeriveInput` and passed to
/// `generators::derive_model_impl`. A failure is reported through the error's
/// `to_compile_error()` output instead of panicking.
#[proc_macro_derive(Model, attributes(prax))]
pub fn derive_model(input: TokenStream) -> TokenStream {
    let item = parse_macro_input!(input as DeriveInput);
    let expanded = match generators::derive_model_impl(&item) {
        Ok(expanded) => expanded,
        // Bail out early with the diagnostic tokens.
        Err(error) => return error.to_compile_error().into(),
    };
    expanded.into()
}
fn generate_from_schema(schema_path: &str) -> Result<proc_macro2::TokenStream, syn::Error> {
use plugins::{PluginConfig, PluginContext, PluginRegistry};
use schema_reader::read_schema_with_config;
let schema_with_config = read_schema_with_config(schema_path).map_err(|e| {
syn::Error::new(
proc_macro2::Span::call_site(),
format!("Failed to parse schema: {}", e),
)
})?;
let schema = schema_with_config.schema;
let model_style = schema_with_config.model_style;
let plugin_config = PluginConfig::with_model_style(model_style);
let plugin_registry = PluginRegistry::with_builtins();
let plugin_ctx = PluginContext::new(&schema, &plugin_config);
let mut output = proc_macro2::TokenStream::new();
output.extend(generate_prelude());
let start_output = plugin_registry.run_start(&plugin_ctx);
output.extend(start_output.tokens);
output.extend(start_output.root_items);
for (_, enum_def) in &schema.enums {
output.extend(generate_enum_module(enum_def)?);
let plugin_output = plugin_registry.run_enum(&plugin_ctx, enum_def);
if !plugin_output.is_empty() {
output.extend(plugin_output.tokens);
}
}
for (_, type_def) in &schema.types {
output.extend(generate_type_module(type_def)?);
let plugin_output = plugin_registry.run_type(&plugin_ctx, type_def);
if !plugin_output.is_empty() {
output.extend(plugin_output.tokens);
}
}
for (_, view_def) in &schema.views {
output.extend(generate_view_module(view_def)?);
let plugin_output = plugin_registry.run_view(&plugin_ctx, view_def);
if !plugin_output.is_empty() {
output.extend(plugin_output.tokens);
}
}
for (_, model_def) in &schema.models {
output.extend(generate_model_module_with_style(
model_def,
&schema,
model_style,
)?);
let plugin_output = plugin_registry.run_model(&plugin_ctx, model_def);
if !plugin_output.is_empty() {
output.extend(plugin_output.tokens);
}
}
let finish_output = plugin_registry.run_finish(&plugin_ctx);
output.extend(finish_output.tokens);
output.extend(finish_output.root_items);
output.extend(plugins::generate_plugin_docs(&plugin_registry));
Ok(output)
}
/// Builds the `_prax_prelude` module emitted at the top of every `prax_schema!` expansion.
///
/// The module contains, in order:
/// - re-exports of `Future`, `Pin` and `Arc` from `std`,
/// - the `PraxModel` and `ToSqlParam` traits,
/// - the `Unset` marker and the two-state `SetParam<T>` (`Set(value)` or `Unset`,
///   which is also its `Default`),
/// - the `SortOrder` and `NullsOrder` enums,
/// - `Cursor<T>` together with its `CursorDirection`.
///
/// Each section is built separately and spliced into the module in that order.
fn generate_prelude() -> proc_macro2::TokenStream {
    let std_reexports = quote! {
        pub use std::future::Future;
        pub use std::pin::Pin;
        pub use std::sync::Arc;
    };

    let core_traits = quote! {
        pub trait PraxModel {
            const TABLE_NAME: &'static str;
            const PRIMARY_KEY: &'static [&'static str];
        }
        pub trait ToSqlParam {
            fn to_sql_param(&self) -> Box<dyn std::any::Any + Send + Sync>;
        }
    };

    let set_param_items = quote! {
        #[derive(Debug, Clone, Default)]
        pub struct Unset;
        #[derive(Debug, Clone)]
        pub enum SetParam<T> {
            Set(T),
            Unset,
        }
        impl<T> Default for SetParam<T> {
            fn default() -> Self {
                Self::Unset
            }
        }
        impl<T> SetParam<T> {
            pub fn is_set(&self) -> bool {
                matches!(self, Self::Set(_))
            }
            pub fn get(&self) -> Option<&T> {
                match self {
                    Self::Set(v) => Some(v),
                    Self::Unset => None,
                }
            }
            pub fn take(self) -> Option<T> {
                match self {
                    Self::Set(v) => Some(v),
                    Self::Unset => None,
                }
            }
        }
    };

    let ordering_items = quote! {
        #[derive(Debug, Clone, Copy, PartialEq, Eq)]
        pub enum SortOrder {
            Asc,
            Desc,
        }
        #[derive(Debug, Clone, Copy, PartialEq, Eq)]
        pub enum NullsOrder {
            First,
            Last,
        }
    };

    let cursor_items = quote! {
        #[derive(Debug, Clone)]
        pub struct Cursor<T> {
            pub value: T,
            pub direction: CursorDirection,
        }
        #[derive(Debug, Clone, Copy, PartialEq, Eq)]
        pub enum CursorDirection {
            After,
            Before,
        }
    };

    quote! {
        pub mod _prax_prelude {
            #std_reexports
            #core_traits
            #set_param_items
            #ordering_items
            #cursor_items
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Renders the generated prelude as a string.
    ///
    /// The printed form separates tokens with single spaces (hence expectations such as
    /// `"impl < T > Default"`), so every expected substring below uses that spacing.
    fn prelude_code() -> String {
        generate_prelude().to_string()
    }

    #[test]
    fn test_prelude_generation() {
        let code = prelude_code();
        assert!(code.contains("pub mod _prax_prelude"));
        assert!(code.contains("pub trait PraxModel"));
        assert!(code.contains("pub enum SortOrder"));
        assert!(code.contains("pub enum SetParam"));
    }

    #[test]
    fn test_prelude_contains_table_name_const() {
        let code = prelude_code();
        assert!(code.contains("TABLE_NAME"));
        assert!(code.contains("PRIMARY_KEY"));
    }

    #[test]
    fn test_prelude_contains_to_sql_param_trait() {
        let code = prelude_code();
        assert!(code.contains("ToSqlParam"));
        assert!(code.contains("to_sql_param"));
    }

    #[test]
    fn test_prelude_contains_unset_type() {
        let code = prelude_code();
        assert!(code.contains("pub struct Unset"));
    }

    #[test]
    fn test_prelude_contains_set_param_methods() {
        let code = prelude_code();
        // Pin whole signatures, not just names, so a changed receiver or return type
        // is caught too.
        assert!(code.contains("pub fn is_set (& self) -> bool"));
        assert!(code.contains("pub fn get (& self) -> Option < & T >"));
        assert!(code.contains("pub fn take (self) -> Option < T >"));
    }

    #[test]
    fn test_prelude_contains_sort_order_variants() {
        let code = prelude_code();
        assert!(code.contains("Asc"));
        assert!(code.contains("Desc"));
    }

    #[test]
    fn test_prelude_contains_nulls_order() {
        let code = prelude_code();
        assert!(code.contains("pub enum NullsOrder"));
        assert!(code.contains("First"));
        assert!(code.contains("Last"));
    }

    #[test]
    fn test_prelude_contains_cursor_types() {
        let code = prelude_code();
        assert!(code.contains("pub struct Cursor"));
        assert!(code.contains("pub enum CursorDirection"));
        assert!(code.contains("After"));
        assert!(code.contains("Before"));
    }

    #[test]
    fn test_prelude_contains_std_imports() {
        let code = prelude_code();
        assert!(code.contains("std :: future :: Future"));
        assert!(code.contains("std :: pin :: Pin"));
        assert!(code.contains("std :: sync :: Arc"));
    }

    #[test]
    fn test_prelude_derive_macros() {
        let code = prelude_code();
        // Bare "Clone"/"Debug" substrings would match almost anywhere. Check that each
        // generated type carries its exact derive list instead.
        assert!(code.contains("# [derive (Debug , Clone , Default)] pub struct Unset"));
        assert!(code.contains("# [derive (Debug , Clone)] pub enum SetParam"));
        assert!(code.contains(
            "# [derive (Debug , Clone , Copy , PartialEq , Eq)] pub enum SortOrder"
        ));
        assert!(code.contains(
            "# [derive (Debug , Clone , Copy , PartialEq , Eq)] pub enum NullsOrder"
        ));
        assert!(code.contains("# [derive (Debug , Clone)] pub struct Cursor"));
        assert!(code.contains(
            "# [derive (Debug , Clone , Copy , PartialEq , Eq)] pub enum CursorDirection"
        ));
    }

    #[test]
    fn test_prelude_setparam_default_impl() {
        let code = prelude_code();
        assert!(code.contains("impl < T > Default for SetParam"));
        assert!(code.contains("Self :: Unset"));
    }
}