use std::collections::HashMap;
use proc_macro2::TokenStream;
use quote::quote;
use syn::{Data, DeriveInput, Fields, GenericArgument, PathArguments, Result, Type, parse_quote};
use syn::{Ident, LitStr, bracketed, parenthesized};
use crate::crate_paths::{
get_linkme_crate, get_reinhardt_core_crate, get_reinhardt_crate,
get_reinhardt_migrations_crate, get_reinhardt_orm_crate,
};
use crate::rel::RelAttribute;
/// A table-level constraint declared on a model through
/// `#[model(constraints = [...])]` or `#[model(unique_together = (...))]`.
#[derive(Debug, Clone)]
enum ConstraintSpec {
    /// A UNIQUE constraint over one or more fields.
    Unique {
        /// Field names covered by the constraint, in declaration order.
        fields: Vec<String>,
        /// Optional explicit constraint name (`name = "..."`); always `None` for
        /// entries that come from `unique_together`.
        name: Option<String>,
        /// Optional SQL condition (`condition = "..."`), checked with
        /// `validate_sql_expression` when it is parsed.
        condition: Option<String>,
    },
}
/// Raw result of parsing the arguments of a single `#[model(...)]` /
/// `#[model_config(...)]` attribute; keys that were not present stay `None` / empty.
struct ModelAttributesParsed {
    /// `app_label = "..."`
    app_label: Option<String>,
    /// `table_name = "..."`
    table_name: Option<String>,
    /// `constraints = [unique(...), ...]`; `None` when the key was absent.
    constraints: Option<Vec<ConstraintSpec>>,
    /// One entry per `unique_together = ("a", "b", ...)` occurrence.
    unique_together: Vec<Vec<String>>,
    /// `manager = some::Path`
    manager: Option<syn::Path>,
}
/// Rejects user-supplied SQL fragments that could smuggle extra statements into
/// generated DDL.
///
/// Used for the `check`, `generated`, and constraint `condition` attributes. This is a
/// deny-list, not a SQL parser. It blocks statement separators (`;`), data- and
/// schema-modifying keywords, and SQL comments (`--`, `/*`).
///
/// A keyword is flagged only when it appears as a whole word (not preceded by a letter,
/// digit, or `_`) followed by any whitespace character. So an identifier such as
/// `last_update` is accepted, while `DROP\tTABLE` is still rejected.
///
/// # Errors
/// Returns a `syn::Error` at `Span::call_site()` describing the first violation found:
/// semicolons are checked first, then keywords, then comments.
fn validate_sql_expression(sql: &str, attr_name: &str) -> Result<()> {
    if sql.contains(';') {
        return Err(syn::Error::new(
            proc_macro2::Span::call_site(),
            format!(
                "Semicolons are not allowed in {} expressions: {:?}",
                attr_name, sql
            ),
        ));
    }

    const BLOCKED_KEYWORDS: &[&str] = &[
        "DROP", "DELETE", "INSERT", "UPDATE", "ALTER", "TRUNCATE", "EXEC", "EXECUTE", "CREATE",
        "GRANT", "REVOKE",
    ];
    let upper = sql.to_uppercase();
    if let Some(keyword) = BLOCKED_KEYWORDS
        .iter()
        .copied()
        .find(|keyword| contains_sql_keyword(&upper, keyword))
    {
        return Err(syn::Error::new(
            proc_macro2::Span::call_site(),
            format!(
                "Dangerous SQL keyword {:?} detected in {} expression: {:?}",
                keyword, attr_name, sql
            ),
        ));
    }

    if sql.contains("--") || sql.contains("/*") {
        return Err(syn::Error::new(
            proc_macro2::Span::call_site(),
            format!(
                "SQL comments are not allowed in {} expressions: {:?}",
                attr_name, sql
            ),
        ));
    }
    Ok(())
}

/// Returns `true` if the ASCII `keyword` occurs in the upper-cased `upper` as a
/// standalone word (start of string or a non-identifier character before it) that is
/// immediately followed by a whitespace character.
fn contains_sql_keyword(upper: &str, keyword: &str) -> bool {
    upper.match_indices(keyword).any(|(start, _)| {
        // `start` and `start + keyword.len()` are char boundaries because `keyword`
        // is ASCII and was found inside `upper`.
        let starts_word = !matches!(
            upper[..start].chars().next_back(),
            Some(c) if c.is_alphanumeric() || c == '_'
        );
        let followed_by_whitespace = upper[start + keyword.len()..]
            .chars()
            .next()
            .is_some_and(char::is_whitespace);
        starts_word && followed_by_whitespace
    })
}
/// Model-level configuration merged from all `#[model(...)]` / `#[model_config(...)]`
/// attributes on the struct (see `ModelConfig::from_attrs`).
#[derive(Debug, Clone)]
struct ModelConfig {
    /// Application label; `from_attrs` falls back to `"default"` when none is given.
    app_label: String,
    /// Database table name; required.
    table_name: String,
    /// Explicit constraints plus those derived from `unique_together`.
    constraints: Vec<ConstraintSpec>,
    /// Optional custom manager path (`manager = path::To::Manager`).
    manager: Option<syn::Path>,
}
impl ModelConfig {
    /// Builds the model configuration from the struct's `#[model(...)]` and
    /// `#[model_config(...)]` attributes; any other attribute is ignored.
    ///
    /// Later attributes override `app_label`, `table_name`, and the explicit
    /// `constraints` list. `unique_together` entries accumulate across all attributes and
    /// are appended after the explicit constraints. `app_label` defaults to `"default"`.
    ///
    /// # Errors
    /// Fails when an attribute cannot be parsed, when `manager` is given more than once,
    /// or when no `table_name` is provided.
    fn from_attrs(attrs: &[syn::Attribute], struct_name: &syn::Ident) -> Result<Self> {
        let mut app_label = None;
        let mut table_name = None;
        // Explicit `constraints = [...]` lists follow last-wins semantics, whereas
        // `unique_together` entries accumulate. They are kept apart so that a later
        // `constraints` list cannot discard `unique_together` from an earlier attribute.
        let mut explicit_constraints = Vec::new();
        let mut unique_together_constraints = Vec::new();
        let mut manager: Option<syn::Path> = None;
        for attr in attrs {
            if !attr.path().is_ident("model") && !attr.path().is_ident("model_config") {
                continue;
            }
            let model_attr = attr
                .parse_args_with(|input: syn::parse::ParseStream| {
                    Self::parse_model_attributes(input)
                })
                .map_err(|e| {
                    syn::Error::new_spanned(attr, format!("parse_args_with failed: {}", e))
                })?;
            if let Some(c) = model_attr.constraints {
                explicit_constraints = c;
            }
            unique_together_constraints.extend(model_attr.unique_together.into_iter().map(
                |fields| ConstraintSpec::Unique {
                    fields,
                    name: None,
                    condition: None,
                },
            ));
            if let Some(al) = model_attr.app_label {
                app_label = Some(al);
            }
            if let Some(tn) = model_attr.table_name {
                table_name = Some(tn);
            }
            if let Some(m) = model_attr.manager {
                if manager.is_some() {
                    return Err(syn::Error::new_spanned(
                        struct_name,
                        "#[model(manager = ...)] specified more than once",
                    ));
                }
                manager = Some(m);
            }
        }
        let table_name = table_name.ok_or_else(|| {
            syn::Error::new_spanned(
                struct_name,
                "table_name attribute is required in #[model(...)]",
            )
        })?;
        let mut constraints = explicit_constraints;
        constraints.extend(unique_together_constraints);
        Ok(Self {
            app_label: app_label.unwrap_or_else(|| "default".to_string()),
            table_name,
            constraints,
            manager,
        })
    }

    /// Parses the comma-separated `key = value` list inside one `#[model(...)]`.
    ///
    /// Supported keys:
    /// - `app_label = "..."`, `table_name = "..."`
    /// - `manager = path`
    /// - `unique_together = ("a", "b")`, which may be repeated
    /// - `constraints = [unique(...), ...]`
    ///
    /// A trailing comma is allowed.
    ///
    /// # Errors
    /// Fails on unknown keys, malformed values, and empty `unique_together` lists.
    fn parse_model_attributes(input: syn::parse::ParseStream) -> Result<ModelAttributesParsed> {
        use syn::Token;
        let mut app_label = None;
        let mut table_name = None;
        let mut constraints = None;
        let mut unique_together = Vec::new();
        let mut manager: Option<syn::Path> = None;
        while !input.is_empty() {
            let ident: Ident = input.parse()?;
            input.parse::<Token![=]>()?;
            if ident == "app_label" {
                let value: LitStr = input.parse()?;
                app_label = Some(value.value());
            } else if ident == "table_name" {
                let value: LitStr = input.parse()?;
                table_name = Some(value.value());
            } else if ident == "manager" {
                let path: syn::Path = input.parse()?;
                manager = Some(path);
            } else if ident == "unique_together" {
                use syn::punctuated::Punctuated;
                let content;
                parenthesized!(content in input);
                let fields: Punctuated<LitStr, Token![,]> =
                    content.call(Punctuated::parse_terminated)?;
                // A UNIQUE constraint over zero columns is not valid SQL.
                if fields.is_empty() {
                    return Err(syn::Error::new_spanned(
                        &ident,
                        "unique_together requires at least one field",
                    ));
                }
                unique_together.push(fields.iter().map(|lit| lit.value()).collect());
            } else if ident == "constraints" {
                let array_content;
                bracketed!(array_content in input);
                let mut specs = Vec::new();
                while !array_content.is_empty() {
                    specs.push(Self::parse_constraint(&array_content)?);
                    if array_content.peek(Token![,]) {
                        array_content.parse::<Token![,]>()?;
                    } else {
                        break;
                    }
                }
                constraints = Some(specs);
            } else {
                return Err(syn::Error::new_spanned(
                    &ident,
                    format!("Unknown model attribute: {}", ident),
                ));
            }
            if input.peek(Token![,]) {
                input.parse::<Token![,]>()?;
            } else {
                break;
            }
        }
        Ok(ModelAttributesParsed {
            app_label,
            table_name,
            constraints,
            unique_together,
            manager,
        })
    }

    /// Parses one `unique(fields = ["a", ...], name = "...", condition = "...")` entry.
    ///
    /// `fields` is required and must be non-empty. `condition` is checked with
    /// `validate_sql_expression`.
    ///
    /// # Errors
    /// Fails on unknown parameters, a missing or empty `fields` list, and rejected
    /// conditions.
    fn parse_constraint(input: syn::parse::ParseStream) -> Result<ConstraintSpec> {
        use syn::Token;
        use syn::punctuated::Punctuated;
        mod kw {
            syn::custom_keyword!(unique);
        }
        let _unique_keyword = input.parse::<kw::unique>()?;
        let content;
        parenthesized!(content in input);
        let mut fields = None;
        let mut name = None;
        let mut condition = None;
        loop {
            if content.is_empty() {
                break;
            }
            let param_name: Ident = content.parse()?;
            content.parse::<Token![=]>()?;
            if param_name == "fields" {
                let array_content;
                bracketed!(array_content in content);
                let field_literals: Punctuated<LitStr, Token![,]> =
                    array_content.call(Punctuated::parse_terminated)?;
                // A UNIQUE constraint over zero columns is not valid SQL.
                if field_literals.is_empty() {
                    return Err(syn::Error::new_spanned(
                        &param_name,
                        "unique constraint 'fields' must contain at least one field",
                    ));
                }
                fields = Some(field_literals.iter().map(|lit| lit.value()).collect());
            } else if param_name == "name" {
                let value: LitStr = content.parse()?;
                name = Some(value.value());
            } else if param_name == "condition" {
                let value: LitStr = content.parse()?;
                let condition_str = value.value();
                validate_sql_expression(&condition_str, "condition")?;
                condition = Some(condition_str);
            } else {
                return Err(syn::Error::new_spanned(
                    param_name,
                    "Unknown parameter. Supported: fields, name, condition",
                ));
            }
            if content.peek(Token![,]) {
                content.parse::<Token![,]>()?;
            } else {
                break;
            }
        }
        let fields = fields.ok_or_else(|| {
            syn::Error::new(
                proc_macro2::Span::call_site(),
                "unique constraint requires 'fields' parameter",
            )
        })?;
        Ok(ConstraintSpec::Unique {
            fields,
            name,
            condition,
        })
    }
}
/// Target of a `#[field(foreign_key = ...)]` attribute.
#[derive(Debug, Clone)]
enum ForeignKeySpec {
    /// The target was given as a Rust type, either directly (`foreign_key = User`) or as a
    /// string without a dot (`foreign_key = "User"`).
    Type(syn::Type),
    /// The target was given as an `"app_label.model_name"` string.
    AppModel {
        app_label: String,
        model_name: String,
    },
}
/// PostgreSQL column storage strategy selected with
/// `#[field(storage = "plain" | "extended" | "external" | "main")]` (case-insensitive).
#[cfg(feature = "db-postgres")]
#[derive(Debug, Clone)]
enum StorageStrategy {
    Plain,
    Extended,
    External,
    Main,
}
/// PostgreSQL column compression method selected with
/// `#[field(compression = "pglz" | "lz4")]` (case-insensitive).
#[cfg(feature = "db-postgres")]
#[derive(Debug, Clone)]
enum CompressionMethod {
    Pglz,
    Lz4,
}
/// Per-field options collected from `#[field(...)]` attributes (see
/// `FieldConfig::from_attrs`).
///
/// `Option` fields are `None` when the key was not written. Backend-specific options exist
/// only when the matching `db-*` feature is enabled.
#[derive(Debug, Clone, Default)]
struct FieldConfig {
    primary_key: bool,
    max_length: Option<u64>,
    null: Option<bool>,
    blank: Option<bool>,
    unique: Option<bool>,
    /// Default value expression, kept as written.
    default: Option<syn::Expr>,
    /// Explicit column name override.
    db_column: Option<String>,
    editable: Option<bool>,
    index: Option<bool>,
    /// SQL CHECK expression, already checked with `validate_sql_expression`.
    check: Option<String>,
    email: Option<bool>,
    url: Option<bool>,
    min_length: Option<u64>,
    min_value: Option<i64>,
    max_value: Option<i64>,
    auto_now_add: Option<bool>,
    auto_now: Option<bool>,
    foreign_key: Option<ForeignKeySpec>,
    /// Generated-column SQL expression, already checked with `validate_sql_expression`.
    generated: Option<String>,
    generated_stored: Option<bool>,
    #[cfg(any(feature = "db-mysql", feature = "db-sqlite"))]
    generated_virtual: Option<bool>,
    #[cfg(feature = "db-postgres")]
    identity_always: Option<bool>,
    #[cfg(feature = "db-postgres")]
    identity_by_default: Option<bool>,
    auto_increment: Option<bool>,
    #[cfg(feature = "db-sqlite")]
    autoincrement: Option<bool>,
    collate: Option<String>,
    #[cfg(feature = "db-mysql")]
    character_set: Option<String>,
    #[cfg(any(feature = "db-postgres", feature = "db-mysql"))]
    comment: Option<String>,
    #[cfg(feature = "db-postgres")]
    storage: Option<StorageStrategy>,
    #[cfg(feature = "db-postgres")]
    compression: Option<CompressionMethod>,
    #[cfg(feature = "db-mysql")]
    on_update_current_timestamp: Option<bool>,
    #[cfg(feature = "db-mysql")]
    invisible: Option<bool>,
    #[cfg(any(feature = "db-postgres", feature = "db-mysql"))]
    fulltext: Option<bool>,
    #[cfg(feature = "db-mysql")]
    unsigned: Option<bool>,
    #[cfg(feature = "db-mysql")]
    zerofill: Option<bool>,
    /// Suppresses the generated getter; forced to `true` when `skip` is set.
    skip_getter: bool,
    /// Excludes the field from generated field accessors (and implies `skip_getter`).
    skip: bool,
    include_in_new: Option<bool>,
    /// Explicit PostgreSQL field type overriding the type inferred from the Rust type.
    #[cfg(feature = "db-postgres")]
    field_type: Option<String>,
    /// Explicit SQL element type for `Vec<_>` array columns.
    #[cfg(feature = "db-postgres")]
    array_base_type: Option<String>,
}
impl FieldConfig {
fn from_attrs(attrs: &[syn::Attribute]) -> Result<Self> {
let mut config = Self::default();
for attr in attrs {
if !attr.path().is_ident("field") {
continue;
}
if matches!(attr.meta, syn::Meta::Path(_)) {
continue;
}
attr.parse_nested_meta(|meta| {
if meta.path.is_ident("primary_key") {
let value: syn::LitBool = meta.value()?.parse()?;
config.primary_key = value.value;
Ok(())
} else if meta.path.is_ident("max_length") {
let value: syn::LitInt = meta.value()?.parse()?;
config.max_length = Some(value.base10_parse()?);
Ok(())
} else if meta.path.is_ident("null") {
let value: syn::LitBool = meta.value()?.parse()?;
config.null = Some(value.value);
Ok(())
} else if meta.path.is_ident("blank") {
let value: syn::LitBool = meta.value()?.parse()?;
config.blank = Some(value.value);
Ok(())
} else if meta.path.is_ident("unique") {
let value: syn::LitBool = meta.value()?.parse()?;
config.unique = Some(value.value);
Ok(())
} else if meta.path.is_ident("default") {
let value: syn::Expr = meta.value()?.parse()?;
config.default = Some(value);
Ok(())
} else if meta.path.is_ident("db_column") {
let value: syn::LitStr = meta.value()?.parse()?;
config.db_column = Some(value.value());
Ok(())
} else if meta.path.is_ident("editable") {
let value: syn::LitBool = meta.value()?.parse()?;
config.editable = Some(value.value);
Ok(())
} else if meta.path.is_ident("index") {
let value: syn::LitBool = meta.value()?.parse()?;
config.index = Some(value.value);
Ok(())
} else if meta.path.is_ident("check") {
let value: syn::LitStr = meta.value()?.parse()?;
let check_str = value.value();
validate_sql_expression(&check_str, "check")?;
config.check = Some(check_str);
Ok(())
} else if meta.path.is_ident("email") {
let value: syn::LitBool = meta.value()?.parse()?;
config.email = Some(value.value);
Ok(())
} else if meta.path.is_ident("url") {
let value: syn::LitBool = meta.value()?.parse()?;
config.url = Some(value.value);
Ok(())
} else if meta.path.is_ident("min_length") {
let value: syn::LitInt = meta.value()?.parse()?;
config.min_length = Some(value.base10_parse()?);
Ok(())
} else if meta.path.is_ident("min_value") {
let value: syn::LitInt = meta.value()?.parse()?;
config.min_value = Some(value.base10_parse()?);
Ok(())
} else if meta.path.is_ident("max_value") {
let value: syn::LitInt = meta.value()?.parse()?;
config.max_value = Some(value.base10_parse()?);
Ok(())
} else if meta.path.is_ident("auto_now_add") {
let value: syn::LitBool = meta.value()?.parse()?;
config.auto_now_add = Some(value.value);
Ok(())
} else if meta.path.is_ident("auto_now") {
let value: syn::LitBool = meta.value()?.parse()?;
config.auto_now = Some(value.value);
Ok(())
} else if meta.path.is_ident("foreign_key") {
if let Ok(ty) = meta.value()?.parse::<syn::Type>() {
config.foreign_key = Some(ForeignKeySpec::Type(ty));
return Ok(());
}
if let Ok(value) = meta.value()?.parse::<syn::LitStr>() {
let spec_str = value.value();
if spec_str.contains('.') {
let parts: Vec<&str> = spec_str.split('.').collect();
if parts.len() == 2 {
config.foreign_key = Some(ForeignKeySpec::AppModel {
app_label: parts[0].to_string(),
model_name: parts[1].to_string(),
});
return Ok(());
} else {
return Err(meta.error(
"foreign_key must be in 'app_label.model_name' format",
));
}
} else {
if let Ok(ty) = syn::parse_str::<syn::Type>(&spec_str) {
config.foreign_key = Some(ForeignKeySpec::Type(ty));
return Ok(());
} else {
return Err(meta.error("Invalid foreign_key specification"));
}
}
}
Err(meta.error("foreign_key must be a type (User) or string (\"users.User\")"))
}
else if meta.path.is_ident("generated") {
let value: syn::LitStr = meta.value()?.parse()?;
let gen_str = value.value();
validate_sql_expression(&gen_str, "generated")?;
config.generated = Some(gen_str);
Ok(())
} else if meta.path.is_ident("generated_stored") {
let value: syn::LitBool = meta.value()?.parse()?;
config.generated_stored = Some(value.value);
Ok(())
} else if meta.path.is_ident("generated_virtual") {
#[cfg(any(feature = "db-mysql", feature = "db-sqlite"))]
{
let value: syn::LitBool = meta.value()?.parse()?;
config.generated_virtual = Some(value.value);
Ok(())
}
#[cfg(not(any(feature = "db-mysql", feature = "db-sqlite")))]
{
Err(meta.error(
"generated_virtual is only available with db-mysql or db-sqlite features",
))
}
}
else if meta.path.is_ident("identity_always") {
#[cfg(feature = "db-postgres")]
{
let value: syn::LitBool = meta.value()?.parse()?;
config.identity_always = Some(value.value);
Ok(())
}
#[cfg(not(feature = "db-postgres"))]
{
Err(meta
.error("identity_always is only available with db-postgres feature"))
}
} else if meta.path.is_ident("identity_by_default") {
#[cfg(feature = "db-postgres")]
{
let value: syn::LitBool = meta.value()?.parse()?;
config.identity_by_default = Some(value.value);
Ok(())
}
#[cfg(not(feature = "db-postgres"))]
{
Err(meta.error(
"identity_by_default is only available with db-postgres feature",
))
}
} else if meta.path.is_ident("auto_increment") {
let value: syn::LitBool = meta.value()?.parse()?;
config.auto_increment = Some(value.value);
Ok(())
} else if meta.path.is_ident("autoincrement") {
#[cfg(feature = "db-sqlite")]
{
let value: syn::LitBool = meta.value()?.parse()?;
config.autoincrement = Some(value.value);
Ok(())
}
#[cfg(not(feature = "db-sqlite"))]
{
Err(meta.error("autoincrement is only available with db-sqlite feature"))
}
}
else if meta.path.is_ident("collate") {
let value: syn::LitStr = meta.value()?.parse()?;
config.collate = Some(value.value());
Ok(())
} else if meta.path.is_ident("character_set") {
#[cfg(feature = "db-mysql")]
{
let value: syn::LitStr = meta.value()?.parse()?;
config.character_set = Some(value.value());
Ok(())
}
#[cfg(not(feature = "db-mysql"))]
{
Err(meta.error("character_set is only available with db-mysql feature"))
}
}
else if meta.path.is_ident("comment") {
#[cfg(any(feature = "db-postgres", feature = "db-mysql"))]
{
let value: syn::LitStr = meta.value()?.parse()?;
config.comment = Some(value.value());
Ok(())
}
#[cfg(not(any(feature = "db-postgres", feature = "db-mysql")))]
{
Err(meta.error(
"comment is only available with db-postgres or db-mysql features",
))
}
}
else if meta.path.is_ident("storage") {
#[cfg(feature = "db-postgres")]
{
let value: syn::LitStr = meta.value()?.parse()?;
let storage_str = value.value();
let storage = match storage_str.to_lowercase().as_str() {
"plain" => StorageStrategy::Plain,
"extended" => StorageStrategy::Extended,
"external" => StorageStrategy::External,
"main" => StorageStrategy::Main,
_ => {
return Err(meta.error(
"storage must be one of: plain, extended, external, main",
));
}
};
config.storage = Some(storage);
Ok(())
}
#[cfg(not(feature = "db-postgres"))]
{
Err(meta.error("storage is only available with db-postgres feature"))
}
} else if meta.path.is_ident("compression") {
#[cfg(feature = "db-postgres")]
{
let value: syn::LitStr = meta.value()?.parse()?;
let compression_str = value.value();
let compression = match compression_str.to_lowercase().as_str() {
"pglz" => CompressionMethod::Pglz,
"lz4" => CompressionMethod::Lz4,
_ => return Err(meta.error("compression must be one of: pglz, lz4")),
};
config.compression = Some(compression);
Ok(())
}
#[cfg(not(feature = "db-postgres"))]
{
Err(meta.error("compression is only available with db-postgres feature"))
}
}
else if meta.path.is_ident("on_update_current_timestamp") {
#[cfg(feature = "db-mysql")]
{
let value: syn::LitBool = meta.value()?.parse()?;
config.on_update_current_timestamp = Some(value.value);
Ok(())
}
#[cfg(not(feature = "db-mysql"))]
{
Err(meta.error(
"on_update_current_timestamp is only available with db-mysql feature",
))
}
}
else if meta.path.is_ident("invisible") {
#[cfg(feature = "db-mysql")]
{
let value: syn::LitBool = meta.value()?.parse()?;
config.invisible = Some(value.value);
Ok(())
}
#[cfg(not(feature = "db-mysql"))]
{
Err(meta.error("invisible is only available with db-mysql feature"))
}
}
else if meta.path.is_ident("fulltext") {
#[cfg(any(feature = "db-postgres", feature = "db-mysql"))]
{
let value: syn::LitBool = meta.value()?.parse()?;
config.fulltext = Some(value.value);
Ok(())
}
#[cfg(not(any(feature = "db-postgres", feature = "db-mysql")))]
{
Err(meta.error(
"fulltext is only available with db-postgres or db-mysql features",
))
}
}
else if meta.path.is_ident("unsigned") {
#[cfg(feature = "db-mysql")]
{
let value: syn::LitBool = meta.value()?.parse()?;
config.unsigned = Some(value.value);
Ok(())
}
#[cfg(not(feature = "db-mysql"))]
{
Err(meta.error("unsigned is only available with db-mysql feature"))
}
} else if meta.path.is_ident("zerofill") {
#[cfg(feature = "db-mysql")]
{
let value: syn::LitBool = meta.value()?.parse()?;
config.zerofill = Some(value.value);
Ok(())
}
#[cfg(not(feature = "db-mysql"))]
{
Err(meta.error("zerofill is only available with db-mysql feature"))
}
}
else if meta.path.is_ident("include_in_new") {
let value: syn::LitBool = meta.value()?.parse()?;
config.include_in_new = Some(value.value);
Ok(())
}
else if meta.path.is_ident("field_type") {
#[cfg(feature = "db-postgres")]
{
let value: syn::LitStr = meta.value()?.parse()?;
config.field_type = Some(value.value());
Ok(())
}
#[cfg(not(feature = "db-postgres"))]
{
Err(meta.error("field_type is only available with db-postgres feature"))
}
} else if meta.path.is_ident("array_base_type") {
#[cfg(feature = "db-postgres")]
{
let value: syn::LitStr = meta.value()?.parse()?;
config.array_base_type = Some(value.value());
Ok(())
}
#[cfg(not(feature = "db-postgres"))]
{
Err(meta
.error("array_base_type is only available with db-postgres feature"))
}
} else if meta.path.is_ident("skip_getter") {
config.skip_getter = meta.value()?.parse::<syn::LitBool>()?.value();
Ok(())
} else if meta.path.is_ident("skip") {
config.skip = meta.value()?.parse::<syn::LitBool>()?.value();
Ok(())
} else {
Err(meta.error("unsupported field attribute"))
}
})?;
}
if config.skip {
config.skip_getter = true;
}
Ok(config)
}
fn validate(&self) -> Result<()> {
#[allow(unused_mut)]
let mut auto_increment_count = 0;
#[cfg(feature = "db-postgres")]
{
if self.identity_always.is_some() {
auto_increment_count += 1;
}
if self.identity_by_default.is_some() {
auto_increment_count += 1;
}
}
#[cfg(feature = "db-mysql")]
{
if self.auto_increment.is_some() {
auto_increment_count += 1;
}
}
#[cfg(feature = "db-sqlite")]
{
if self.autoincrement.is_some() {
auto_increment_count += 1;
}
}
if auto_increment_count > 1 {
return Err(syn::Error::new(
proc_macro2::Span::call_site(),
"Only one auto-increment attribute (identity_always, identity_by_default, auto_increment, autoincrement) can be specified per field",
));
}
if self.generated.is_some() && self.default.is_some() {
return Err(syn::Error::new(
proc_macro2::Span::call_site(),
"Generated columns cannot have default values",
));
}
if self.generated.is_some() {
let has_stored = self.generated_stored.unwrap_or(false);
#[cfg(any(feature = "db-mysql", feature = "db-sqlite"))]
let has_virtual = self.generated_virtual.unwrap_or(false);
#[cfg(not(any(feature = "db-mysql", feature = "db-sqlite")))]
let has_virtual = false;
if !has_stored && !has_virtual {
return Err(syn::Error::new(
proc_macro2::Span::call_site(),
"Generated columns must specify either generated_stored=true or generated_virtual=true",
));
}
if has_stored && has_virtual {
return Err(syn::Error::new(
proc_macro2::Span::call_site(),
"Generated columns cannot be both STORED and VIRTUAL",
));
}
}
Ok(())
}
}
/// A single struct field together with its parsed `#[field(...)]` configuration.
#[derive(Debug, Clone)]
struct FieldInfo {
    name: syn::Ident,
    ty: Type,
    config: FieldConfig,
    /// Relationship attribute, if any. Not read in this part of the file;
    /// NOTE(review): confirm whether it is still needed.
    #[allow(dead_code)]
    rel: Option<RelAttribute>,
    /// Presumably marks a synthesized `<name>_id` column backing a foreign key;
    /// verify against the code that sets it.
    is_fk_id_field: bool,
}
/// Description of a foreign-key or one-to-one relationship field.
// NOTE(review): `dead_code` is allowed because no field is read in the visible part of
// this file; confirm whether all fields are still needed.
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct ForeignKeyFieldInfo {
    field_name: syn::Ident,
    target_type: Type,
    id_column_name: String,
    related_name: Option<String>,
    is_one_to_one: bool,
    rel_attr: RelAttribute,
}
fn field_type_to_metadata_string(ty: &Type, _config: &FieldConfig) -> Result<String> {
let (_is_option, inner_ty) = extract_option_type(ty);
match inner_ty {
Type::Path(type_path) => {
let last_segment = type_path
.path
.segments
.last()
.ok_or_else(|| syn::Error::new_spanned(ty, "Invalid type path"))?;
let type_name = match last_segment.ident.to_string().as_str() {
"i32" => "IntegerField",
"i64" => "BigIntegerField",
"String" => "CharField",
"bool" => "BooleanField",
"f32" | "f64" => "FloatField",
"DateTime" => "DateTimeField",
"Date" => "DateField",
"Time" => "TimeField",
"Decimal" => "DecimalField",
"Uuid" => "UuidField",
"Vec" => "ArrayField",
"Value" => "JsonField",
"HashMap" => "HStoreField",
other => {
return Err(syn::Error::new_spanned(
ty,
format!("Unsupported field type: {}", other),
));
}
};
Ok(format!("reinhardt.orm.models.{}", type_name))
}
_ => Err(syn::Error::new_spanned(ty, "Unsupported field type")),
}
}
/// Renders a `#[field(default = ...)]` expression as a SQL literal string.
///
/// Supported forms:
/// - boolean literals, emitted as `true` / `false`;
/// - integer and float literals, emitted as their digits without any type suffix;
/// - string literals, single-quoted with embedded `'` doubled;
/// - unary negation of a numeric literal, which may be nested.
///
/// Returns `None` for any other expression.
fn serialize_field_default(expr: &syn::Expr) -> Option<String> {
    if let syn::Expr::Unary(unary) = expr
        && matches!(unary.op, syn::UnOp::Neg(_))
    {
        let inner = serialize_field_default(&unary.expr)?;
        return match inner.strip_prefix('-') {
            // `- -5` must cancel out. Emitting `--5` would start a SQL line comment
            // and swallow the rest of the generated statement.
            Some(positive) => Some(positive.to_string()),
            // Only numeric literals can be negated; `-true` / `-'x'` are not valid.
            None if inner.starts_with(|c: char| c.is_ascii_digit()) => {
                Some(format!("-{}", inner))
            }
            None => None,
        };
    }
    let syn::Expr::Lit(expr_lit) = expr else {
        return None;
    };
    match &expr_lit.lit {
        syn::Lit::Bool(b) => Some(if b.value {
            "true".into()
        } else {
            "false".into()
        }),
        syn::Lit::Int(i) => Some(i.base10_digits().to_string()),
        syn::Lit::Float(f) => Some(f.base10_digits().to_string()),
        syn::Lit::Str(s) => Some(format!("'{}'", s.value().replace('\'', "''"))),
        _ => None,
    }
}
fn map_type_to_field_type(ty: &Type, config: &FieldConfig) -> Result<TokenStream> {
let migrations_crate = get_reinhardt_migrations_crate();
#[cfg(feature = "db-postgres")]
if let Some(explicit_type) = &config.field_type {
return map_explicit_field_type(explicit_type, &migrations_crate);
}
let (_is_option, inner_ty) = extract_option_type(ty);
let field_type = match inner_ty {
Type::Path(type_path) => {
let last_segment = type_path
.path
.segments
.last()
.ok_or_else(|| syn::Error::new_spanned(ty, "Invalid type path"))?;
match last_segment.ident.to_string().as_str() {
"i32" => {
quote! { #migrations_crate::FieldType::Integer }
}
"i64" => {
quote! { #migrations_crate::FieldType::BigInteger }
}
"String" => {
let max_length = config.max_length.ok_or_else(|| {
syn::Error::new_spanned(ty, "String fields require max_length attribute")
})? as u32;
quote! { #migrations_crate::FieldType::VarChar(#max_length) }
}
"bool" => {
quote! { #migrations_crate::FieldType::Boolean }
}
"DateTime" => {
quote! { #migrations_crate::FieldType::TimestampTz }
}
"Date" => {
quote! { #migrations_crate::FieldType::Date }
}
"Time" => {
quote! { #migrations_crate::FieldType::Time }
}
"f32" => {
quote! { #migrations_crate::FieldType::Float }
}
"f64" => {
quote! { #migrations_crate::FieldType::Double }
}
"Uuid" => {
quote! { #migrations_crate::FieldType::Uuid }
}
#[cfg(feature = "db-postgres")]
"Vec" => {
return map_vec_to_array_type(ty, last_segment, config, &migrations_crate);
}
#[cfg(feature = "db-postgres")]
"Value" => {
quote! { #migrations_crate::FieldType::Jsonb }
}
#[cfg(feature = "db-postgres")]
"HashMap" => {
quote! { #migrations_crate::FieldType::HStore }
}
_ => {
return Err(syn::Error::new_spanned(
ty,
format!("Unsupported field type: {}", last_segment.ident),
));
}
}
}
_ => {
return Err(syn::Error::new_spanned(ty, "Unsupported field type"));
}
};
Ok(field_type)
}
/// Maps an explicit `#[field(field_type = "...")]` name (case-insensitive) to the
/// corresponding PostgreSQL `FieldType` variant of the migrations crate.
///
/// # Errors
/// Fails for names outside the supported list; the message lists the accepted names.
#[cfg(feature = "db-postgres")]
fn map_explicit_field_type(
    field_type_str: &str,
    migrations_crate: &proc_macro2::TokenStream,
) -> Result<TokenStream> {
    let lowered = field_type_str.to_lowercase();
    let variant_name = match lowered.as_str() {
        "jsonb" => "Jsonb",
        "json" => "Json",
        "hstore" => "HStore",
        "citext" => "CIText",
        "int4range" | "integer_range" => "Int4Range",
        "int8range" | "bigint_range" => "Int8Range",
        "numrange" | "decimal_range" => "NumRange",
        "daterange" | "date_range" => "DateRange",
        "tsrange" | "timestamp_range" => "TsRange",
        "tstzrange" | "timestamptz_range" => "TsTzRange",
        "tsvector" => "TsVector",
        "tsquery" => "TsQuery",
        "uuid" => "Uuid",
        "text" => "Text",
        unknown => {
            return Err(syn::Error::new(
                proc_macro2::Span::call_site(),
                format!(
                    "Unknown PostgreSQL field type: '{}'. Supported types: jsonb, json, hstore, \
                     citext, int4range, int8range, numrange, daterange, tsrange, tstzrange, \
                     tsvector, tsquery, uuid, text",
                    unknown
                ),
            ));
        }
    };
    // `quote!` gives its own identifiers call-site spans, so building the variant
    // identifier by hand yields the same tokens.
    let variant = Ident::new(variant_name, proc_macro2::Span::call_site());
    Ok(quote! { #migrations_crate::FieldType::#variant })
}
/// Maps a `Vec<T>` field to a PostgreSQL `FieldType::Array(...)` expression.
///
/// An explicit `array_base_type` in `config` wins. Otherwise the element type is
/// inferred from `T`. `Vec<String>` becomes `VarChar(max_length)` when `max_length` is
/// set, and `Text` otherwise.
///
/// # Errors
/// Fails in these cases:
/// - the element type cannot be inferred;
/// - the explicit base type is unknown;
/// - `max_length` does not fit in a `u32`. It used to be truncated silently by an `as`
///   cast.
#[cfg(feature = "db-postgres")]
fn map_vec_to_array_type(
    ty: &Type,
    segment: &syn::PathSegment,
    config: &FieldConfig,
    migrations_crate: &proc_macro2::TokenStream,
) -> Result<TokenStream> {
    if let Some(base_type) = &config.array_base_type {
        let inner_field_type = parse_base_type_string(base_type, migrations_crate)?;
        return Ok(quote! {
            #migrations_crate::FieldType::Array(Box::new(#inner_field_type))
        });
    }
    if let syn::PathArguments::AngleBracketed(args) = &segment.arguments
        && let Some(syn::GenericArgument::Type(Type::Path(inner_path))) = args.args.first()
        && let Some(inner_segment) = inner_path.path.segments.last()
    {
        let inner_type_name = inner_segment.ident.to_string();
        let inner_field_type = match inner_type_name.as_str() {
            "String" => {
                if let Some(max_length) = config.max_length {
                    let ml = u32::try_from(max_length).map_err(|_| {
                        syn::Error::new_spanned(
                            ty,
                            format!(
                                "max_length {} exceeds the maximum supported value {}",
                                max_length,
                                u32::MAX
                            ),
                        )
                    })?;
                    quote! { #migrations_crate::FieldType::VarChar(#ml) }
                } else {
                    quote! { #migrations_crate::FieldType::Text }
                }
            }
            "i32" => quote! { #migrations_crate::FieldType::Integer },
            "i64" => quote! { #migrations_crate::FieldType::BigInteger },
            "f32" => quote! { #migrations_crate::FieldType::Float },
            "f64" => quote! { #migrations_crate::FieldType::Double },
            "bool" => quote! { #migrations_crate::FieldType::Boolean },
            "Uuid" => quote! { #migrations_crate::FieldType::Uuid },
            _ => {
                return Err(syn::Error::new_spanned(
                    ty,
                    format!(
                        "Cannot infer array element type for Vec<{}>. \
                         Use #[field(array_base_type = \"...\")] to specify explicitly.",
                        inner_type_name
                    ),
                ));
            }
        };
        return Ok(quote! {
            #migrations_crate::FieldType::Array(Box::new(#inner_field_type))
        });
    }
    Err(syn::Error::new_spanned(
        ty,
        "Cannot infer Vec element type. Use #[field(array_base_type = \"...\")] to specify explicitly.",
    ))
}
/// Parses a SQL element-type name from `#[field(array_base_type = "...")]` into a
/// `FieldType` expression.
///
/// Matching is case-insensitive. Leading and trailing whitespace is ignored, and runs of
/// inner whitespace are collapsed, so `" double  precision "` and `"VARCHAR (64)"` are
/// accepted. `VARCHAR(n)` and `CHAR(n)` take a `u32` length. `TIMESTAMPTZ` /
/// `TIMESTAMP WITH TIME ZONE` map to `TimestampTz`.
///
/// # Errors
/// Fails for unrecognised type names; the message quotes the original input.
#[cfg(feature = "db-postgres")]
fn parse_base_type_string(
    base_type: &str,
    migrations_crate: &proc_macro2::TokenStream,
) -> Result<TokenStream> {
    /// Extracts `n` from `NAME(n)` / `NAME (n)`; returns `None` if `upper` has another shape.
    fn sized_length(upper: &str, name: &str) -> Option<u32> {
        upper
            .strip_prefix(name)?
            .trim_start()
            .strip_prefix('(')?
            .strip_suffix(')')?
            .trim()
            .parse()
            .ok()
    }

    let upper = base_type
        .split_whitespace()
        .collect::<Vec<_>>()
        .join(" ")
        .to_uppercase();
    if let Some(length) = sized_length(&upper, "VARCHAR") {
        return Ok(quote! { #migrations_crate::FieldType::VarChar(#length) });
    }
    if let Some(length) = sized_length(&upper, "CHAR") {
        return Ok(quote! { #migrations_crate::FieldType::Char(#length) });
    }
    let field_type = match upper.as_str() {
        "INTEGER" | "INT" | "INT4" => quote! { #migrations_crate::FieldType::Integer },
        "BIGINT" | "INT8" => quote! { #migrations_crate::FieldType::BigInteger },
        "SMALLINT" | "INT2" => quote! { #migrations_crate::FieldType::SmallInteger },
        "TEXT" => quote! { #migrations_crate::FieldType::Text },
        "BOOLEAN" | "BOOL" => quote! { #migrations_crate::FieldType::Boolean },
        "REAL" | "FLOAT4" => quote! { #migrations_crate::FieldType::Float },
        "DOUBLE PRECISION" | "FLOAT8" => quote! { #migrations_crate::FieldType::Double },
        "UUID" => quote! { #migrations_crate::FieldType::Uuid },
        "DATE" => quote! { #migrations_crate::FieldType::Date },
        "TIME" => quote! { #migrations_crate::FieldType::Time },
        "TIMESTAMP" | "TIMESTAMP WITHOUT TIME ZONE" => {
            quote! { #migrations_crate::FieldType::DateTime }
        }
        "TIMESTAMPTZ" | "TIMESTAMP WITH TIME ZONE" => {
            quote! { #migrations_crate::FieldType::TimestampTz }
        }
        "JSONB" => quote! { #migrations_crate::FieldType::Jsonb },
        "JSON" => quote! { #migrations_crate::FieldType::Json },
        _ => {
            return Err(syn::Error::new(
                proc_macro2::Span::call_site(),
                format!(
                    "Unknown base type for array: '{}'. Use standard SQL types like \
                     INTEGER, BIGINT, VARCHAR(n), TEXT, BOOLEAN, etc.",
                    base_type
                ),
            ));
        }
    };
    Ok(field_type)
}
/// Unwraps `Option<T>` by looking at the last path segment.
///
/// Returns `(true, T)` when `ty` is `...::Option<T>` with a type as its first generic
/// argument. Otherwise returns `(false, ty)` unchanged.
fn extract_option_type(ty: &Type) -> (bool, &Type) {
    let wrapped = match ty {
        Type::Path(type_path) => type_path
            .path
            .segments
            .last()
            .filter(|segment| segment.ident == "Option")
            .and_then(|segment| match &segment.arguments {
                syn::PathArguments::AngleBracketed(generics) => match generics.args.first() {
                    Some(syn::GenericArgument::Type(inner)) => Some(inner),
                    _ => None,
                },
                _ => None,
            }),
        _ => None,
    };
    wrapped.map_or((false, ty), |inner| (true, inner))
}
fn generate_field_accessors(struct_name: &syn::Ident, field_infos: &[FieldInfo]) -> TokenStream {
let orm_crate = get_reinhardt_orm_crate();
let accessor_methods: Vec<_> = field_infos
.iter()
.filter(|field| !field.config.skip)
.map(|field| {
let field_name = &field.name;
let field_type = &field.ty;
let method_name = syn::Ident::new(&format!("field_{}", field_name), field_name.span());
let field_name_str = field_name.to_string();
quote! {
pub const fn #method_name() -> #orm_crate::expressions::FieldRef<#struct_name, #field_type> {
#orm_crate::expressions::FieldRef::new(#field_name_str)
}
}
})
.collect();
quote! {
impl #struct_name {
#(#accessor_methods)*
}
}
}
/// Generates `<field>_accessor(&self, db)` methods that return a `ManyToManyAccessor`
/// for each many-to-many field whose target type can be extracted.
///
/// Returns an empty token stream when no such field exists.
fn generate_m2m_accessor_methods(
    struct_name: &syn::Ident,
    field_infos: &[FieldInfo],
) -> TokenStream {
    let orm_crate = get_reinhardt_orm_crate();
    let mut accessor_methods = Vec::new();
    for field in field_infos {
        if !is_many_to_many_field_type(&field.ty) {
            continue;
        }
        let relation_name = &field.name;
        let relation_label = relation_name.to_string();
        let accessor_ident =
            syn::Ident::new(&format!("{}_accessor", relation_name), relation_name.span());
        let Some(target_ty) = extract_m2m_target_type(&field.ty) else {
            continue;
        };
        let doc_comment = format!(
            "Create a ManyToManyAccessor for the '{}' relationship",
            relation_label
        );
        accessor_methods.push(quote! {
            #[doc = #doc_comment]
            pub fn #accessor_ident(
                &self,
                db: #orm_crate::connection::DatabaseConnection
            ) -> #orm_crate::ManyToManyAccessor<#struct_name, #target_ty> {
                #orm_crate::ManyToManyAccessor::new(
                    self,
                    #relation_label,
                    db
                )
            }
        });
    }
    if accessor_methods.is_empty() {
        return quote! {};
    }
    quote! {
        impl #struct_name {
            #(#accessor_methods)*
        }
    }
}
/// Generates an async `<field>(&self, db)` loader for each foreign-key / one-to-one field.
///
/// Each loader looks up the related instance whose `id` equals `self.<field>_id()`. The
/// id is compared as a string filter value, and the first match is returned.
///
/// The target type is always written in qualified form (`<T>::objects()`), so targets
/// with generic arguments still expand to valid syntax; plain `T::objects()` would not.
/// Returns an empty token stream when no such field exists.
fn generate_fk_accessor_methods(
    struct_name: &syn::Ident,
    field_infos: &[FieldInfo],
) -> TokenStream {
    let orm_crate = get_reinhardt_orm_crate();
    let core_crate = get_reinhardt_core_crate();
    let accessor_methods: Vec<_> = field_infos
        .iter()
        .filter(|field| {
            is_foreign_key_field_type(&field.ty) || is_one_to_one_field_type(&field.ty)
        })
        .map(|field| {
            let field_name = &field.name;
            let field_name_str = field_name.to_string();
            // Relies on a `<field>_id` getter generated elsewhere for the FK column.
            let fk_id_field_name =
                syn::Ident::new(&format!("{}_id", field_name), field_name.span());
            let method_name = field_name;
            let target_ty = extract_foreign_key_target_type(&field.ty);
            let doc_comment = format!(
                "Load the related '{}' instance from the database",
                field_name_str
            );
            quote! {
                #[doc = #doc_comment]
                pub async fn #method_name(
                    &self,
                    db: &#orm_crate::connection::DatabaseConnection
                ) -> #core_crate::exception::Result<Option<#target_ty>> {
                    use #orm_crate::Model;
                    use #orm_crate::{FilterOperator, FilterValue};
                    let fk_id = self.#fk_id_field_name();
                    <#target_ty>::objects()
                        .filter(
                            <#target_ty>::field_id(),
                            FilterOperator::Eq,
                            FilterValue::String(fk_id.to_string())
                        )
                        .first_with_db(db)
                        .await
                }
            }
        })
        .collect();
    if accessor_methods.is_empty() {
        quote! {}
    } else {
        quote! {
            impl #struct_name {
                #(#accessor_methods)*
            }
        }
    }
}
/// Generates a static `<field>_accessor()` returning a `ForeignKeyAccessor` bound to the
/// `<field>_id` column, for each foreign-key / one-to-one field.
///
/// Returns an empty token stream when no such field exists.
fn generate_fk_static_accessor_methods(
    struct_name: &syn::Ident,
    field_infos: &[FieldInfo],
) -> TokenStream {
    let orm_crate = get_reinhardt_orm_crate();
    let mut accessor_methods = Vec::new();
    for field in field_infos {
        let is_relation =
            is_foreign_key_field_type(&field.ty) || is_one_to_one_field_type(&field.ty);
        if !is_relation {
            continue;
        }
        let relation_ident = &field.name;
        let relation_label = relation_ident.to_string();
        let id_column = format!("{}_id", relation_label);
        let accessor_ident =
            syn::Ident::new(&format!("{}_accessor", relation_ident), relation_ident.span());
        let target_ty = extract_foreign_key_target_type(&field.ty);
        let doc_comment = format!(
            "Get the ForeignKey accessor for the '{}' relationship",
            relation_label
        );
        accessor_methods.push(quote! {
            #[doc = #doc_comment]
            pub fn #accessor_ident() -> #orm_crate::ForeignKeyAccessor<#struct_name, #target_ty> {
                #orm_crate::ForeignKeyAccessor::new(#id_column)
            }
        });
    }
    if accessor_methods.is_empty() {
        return quote! {};
    }
    quote! {
        impl #struct_name {
            #(#accessor_methods)*
        }
    }
}
/// Demotes every named field of the derive input to private (inherited) visibility.
///
/// Only structs with named fields are touched. Tuple structs, unit structs, enums and
/// unions are left as they are.
fn make_fields_private(input: &mut DeriveInput) {
    let Data::Struct(data_struct) = &mut input.data else {
        return;
    };
    let Fields::Named(named_fields) = &mut data_struct.fields else {
        return;
    };
    for field in &mut named_fields.named {
        field.vis = syn::Visibility::Inherited;
    }
}
/// Decides whether a generated getter for `ty` returns the value by copy (`T`) rather
/// than by reference (`&T`).
///
/// A type qualifies when it is a path type and either of these holds:
/// - its final segment is one of the listed primitive scalars or `Uuid`;
/// - any segment of its path is named `DateTime`.
///
/// Every other type is reported as non-copy, including non-path types such as
/// references, tuples and arrays.
///
/// NOTE(review): `usize`/`isize` are not in the list, so their getters return `&T`.
/// Adding them would change generated getter signatures.
fn is_copy_type(ty: &Type) -> bool {
    const COPY_TYPE_NAMES: [&str; 15] = [
        "i8", "i16", "i32", "i64", "i128", "u8", "u16", "u32", "u64", "u128", "f32", "f64",
        "bool", "char", "Uuid",
    ];
    let Type::Path(type_path) = ty else {
        return false;
    };
    let segments = &type_path.path.segments;
    let last_is_copy = segments
        .last()
        .map(|segment| segment.ident.to_string())
        .is_some_and(|name| COPY_TYPE_NAMES.contains(&name.as_str()));
    last_is_copy || segments.iter().any(|segment| segment.ident == "DateTime")
}
/// Generates one getter per field, named after the field.
///
/// Skipped fields:
/// - `ForeignKeyField` / `OneToOneField` fields, whose name is used by the async
///   relationship loaders instead;
/// - fields whose config sets `skip_getter`.
///
/// Types accepted by `is_copy_type` are returned by value; all others are returned by
/// shared reference. Always emits an `impl` block, possibly empty.
fn generate_getter_methods(struct_name: &syn::Ident, field_infos: &[FieldInfo]) -> TokenStream {
    let mut getter_methods = Vec::new();
    for field in field_infos {
        let is_relation =
            is_foreign_key_field_type(&field.ty) || is_one_to_one_field_type(&field.ty);
        if is_relation || field.config.skip_getter {
            continue;
        }
        let field_name = &field.name;
        let field_type = &field.ty;
        let getter = if is_copy_type(field_type) {
            quote! {
                #[doc = concat!("Get ", stringify!(#field_name))]
                pub fn #field_name(&self) -> #field_type {
                    self.#field_name
                }
            }
        } else {
            quote! {
                #[doc = concat!("Get reference to ", stringify!(#field_name))]
                pub fn #field_name(&self) -> &#field_type {
                    &self.#field_name
                }
            }
        };
        getter_methods.push(getter);
    }
    quote! {
        impl #struct_name {
            #(#getter_methods)*
        }
    }
}
/// Generates a `set_<field>(&mut self, value)` method per field.
///
/// Auto-generated fields (per `is_auto_generated_field`) and fields whose config sets
/// `skip_getter` get no setter. Always emits an `impl` block, possibly empty.
fn generate_setter_methods(struct_name: &syn::Ident, field_infos: &[FieldInfo]) -> TokenStream {
    let mut setter_methods = Vec::new();
    for field in field_infos {
        if is_auto_generated_field(field) || field.config.skip_getter {
            continue;
        }
        let field_name = &field.name;
        let field_type = &field.ty;
        let setter_name = syn::Ident::new(&format!("set_{}", field_name), field_name.span());
        setter_methods.push(quote! {
            #[doc = concat!("Set ", stringify!(#field_name))]
            pub fn #setter_name(&mut self, value: #field_type) {
                self.#field_name = value;
            }
        });
    }
    quote! {
        impl #struct_name {
            #(#setter_methods)*
        }
    }
}
pub(crate) fn model_derive_impl(mut input: DeriveInput) -> Result<TokenStream> {
let _reinhardt = get_reinhardt_crate();
let orm_crate = get_reinhardt_orm_crate();
make_fields_private(&mut input);
let struct_name = &input.ident;
let generics = &input.generics;
let where_clause = &generics.where_clause;
let model_config = ModelConfig::from_attrs(&input.attrs, struct_name)?;
let app_label = &model_config.app_label;
let table_name = &model_config.table_name;
let fields = match &input.data {
Data::Struct(data_struct) => match &data_struct.fields {
Fields::Named(fields) => &fields.named,
_ => {
return Err(syn::Error::new_spanned(
struct_name,
"Model can only be derived for structs with named fields",
));
}
},
_ => {
return Err(syn::Error::new_spanned(
struct_name,
"Model can only be derived for structs",
));
}
};
let mut field_infos = Vec::new();
let mut rel_fields = Vec::new();
let mut fk_id_field_names: Vec<syn::Ident> = Vec::new();
for field in fields {
let is_fk_id_field = if let Some(field_name) = &field.ident {
let name_str = field_name.to_string();
let field_ty = &field.ty;
let type_str = quote!(#field_ty).to_string();
name_str.ends_with("_id")
&& type_str.contains("Model")
&& type_str.contains("PrimaryKey")
} else {
false
};
if is_fk_id_field {
if let Some(field_name) = &field.ident {
fk_id_field_names.push(field_name.clone());
}
}
let name = field
.ident
.clone()
.ok_or_else(|| syn::Error::new_spanned(field, "Field must have a name"))?;
let ty = field.ty.clone();
let config = FieldConfig::from_attrs(&field.attrs)?;
config.validate()?;
let rel = field
.attrs
.iter()
.find(|attr| attr.path().is_ident("rel"))
.map(RelAttribute::from_attribute)
.transpose()?;
if let Some(ref rel_attr) = rel {
rel_fields.push((name.clone(), rel_attr.clone()));
}
field_infos.push(FieldInfo {
name,
ty,
config,
rel,
is_fk_id_field,
});
}
let mut fk_field_infos: Vec<ForeignKeyFieldInfo> = Vec::new();
for field_info in &field_infos {
if let Some(ref rel_attr) = field_info.rel {
if let Some(target_type) = extract_fk_target_type(&field_info.ty) {
let is_one_to_one = is_one_to_one_field_type(&field_info.ty);
if is_one_to_one && rel_attr.rel_type != crate::rel::RelationType::OneToOne {
return Err(syn::Error::new(
rel_attr.span,
"OneToOneField must use #[rel(one_to_one, ...)]",
));
}
if is_foreign_key_field_type(&field_info.ty)
&& rel_attr.rel_type != crate::rel::RelationType::ForeignKey
{
return Err(syn::Error::new(
rel_attr.span,
"ForeignKeyField must use #[rel(foreign_key, ...)]",
));
}
let id_column_name = rel_attr
.db_column
.clone()
.unwrap_or_else(|| format!("{}_id", field_info.name));
fk_field_infos.push(ForeignKeyFieldInfo {
field_name: field_info.name.clone(),
target_type: target_type.clone(),
id_column_name,
related_name: rel_attr.related_name.clone(),
is_one_to_one,
rel_attr: rel_attr.clone(),
});
}
}
}
let pk_fields: Vec<_> = field_infos
.iter()
.filter(|f| f.config.primary_key)
.collect();
if pk_fields.is_empty() {
return Err(syn::Error::new_spanned(
struct_name,
"Model must have at least one primary key field",
));
}
let is_composite_pk = pk_fields.len() > 1;
let indexed_fields: Vec<_> = field_infos
.iter()
.filter(|f| f.config.index.unwrap_or(false))
.map(|f| f.name.to_string())
.collect();
let check_constraints: Vec<(String, String)> = field_infos
.iter()
.filter_map(|f| {
f.config
.check
.as_ref()
.map(|expr| (f.name.to_string(), expr.clone()))
})
.collect();
let check_constraint_names: Vec<String> = check_constraints
.iter()
.map(|(field_name, _)| format!("{}_check", field_name))
.collect();
let check_constraint_expressions: Vec<String> = check_constraints
.iter()
.map(|(_, expr)| expr.clone())
.collect();
let unique_constraints: Vec<(Vec<String>, Option<String>, Option<String>)> = model_config
.constraints
.iter()
.map(|c| match c {
ConstraintSpec::Unique {
fields,
name,
condition,
} => (fields.clone(), name.clone(), condition.clone()),
})
.collect();
let unique_constraint_names: Vec<String> = unique_constraints
.iter()
.map(|(fields, name, _)| {
if let Some(n) = name {
n.clone()
} else {
format!("{}_{}_uniq", table_name, fields.join("_"))
}
})
.collect();
let unique_constraint_definitions: Vec<String> = unique_constraints
.iter()
.map(|(fields, _, condition)| {
let fields_str = fields.join(", ");
if let Some(cond) = condition {
format!("UNIQUE ({}) WHERE {}", fields_str, cond)
} else {
format!("UNIQUE ({})", fields_str)
}
})
.collect();
let unique_constraint_field_lists: Vec<Vec<String>> = unique_constraints
.iter()
.map(|(fields, _, _)| fields.clone())
.collect();
let composite_pk_type_def: Option<TokenStream>;
#[allow(unused_assignments)]
let mut composite_pk_type_holder: Option<Type> = None;
let (pk_name, _pk_ty, pk_is_option, pk_type) = if !is_composite_pk {
composite_pk_type_def = None;
let pk_field = pk_fields[0];
let pk_name = &pk_field.name;
let pk_ty = &pk_field.ty;
let (pk_is_option, pk_inner_ty) = extract_option_type(pk_ty);
let pk_type = if pk_is_option { pk_inner_ty } else { pk_ty };
(pk_name, pk_ty, pk_is_option, pk_type)
} else {
let composite_pk_name =
syn::Ident::new(&format!("{}CompositePk", struct_name), struct_name.span());
composite_pk_type_def = Some(generate_composite_pk_type(struct_name, &pk_fields));
composite_pk_type_holder = Some(parse_quote! { #composite_pk_name });
let composite_pk_type_ref = composite_pk_type_holder.as_ref().unwrap();
let first_pk_name = &pk_fields[0].name;
(
first_pk_name,
composite_pk_type_ref,
false,
composite_pk_type_ref,
)
};
let field_metadata_items = generate_field_metadata(&field_infos, &fk_field_infos)?;
let registration_code = generate_registration_code(
struct_name,
app_label,
table_name,
&field_infos,
&fk_field_infos,
&unique_constraint_names,
&unique_constraint_field_lists,
)?;
let relationship_registrations =
generate_relationship_registrations(struct_name, app_label, &field_infos, &fk_field_infos);
let (pk_impl, set_pk_impl, composite_pk_impl) = if is_composite_pk {
let composite_impl = generate_composite_pk_impl(&pk_fields);
let pk_field_names: Vec<_> = pk_fields.iter().map(|f| &f.name).collect();
let has_option_fields = pk_fields.iter().any(|f| {
let (is_option, _) = extract_option_type(&f.ty);
is_option
});
let pk_getter = if has_option_fields {
quote! {
fn primary_key(&self) -> Option<Self::PrimaryKey> {
if #(self.#pk_field_names.is_some())&&* {
Some(Self::PrimaryKey::new(
#(self.#pk_field_names.clone().unwrap()),*
))
} else {
None
}
}
}
} else {
quote! {
fn primary_key(&self) -> Option<Self::PrimaryKey> {
Some(Self::PrimaryKey::new(
#(self.#pk_field_names.clone()),*
))
}
}
};
let pk_setter = if has_option_fields {
quote! {
fn set_primary_key(&mut self, value: Self::PrimaryKey) {
#(
self.#pk_field_names = Some(value.#pk_field_names);
)*
}
}
} else {
quote! {
fn set_primary_key(&mut self, value: Self::PrimaryKey) {
#(
self.#pk_field_names = value.#pk_field_names;
)*
}
}
};
(pk_getter, pk_setter, composite_impl)
} else {
let (pk_getter, pk_setter) = if pk_is_option {
(
quote! {
fn primary_key(&self) -> Option<Self::PrimaryKey> {
self.#pk_name.clone()
}
},
quote! {
fn set_primary_key(&mut self, value: Self::PrimaryKey) {
self.#pk_name = Some(value);
}
},
)
} else {
(
quote! {
fn primary_key(&self) -> Option<Self::PrimaryKey> {
Some(self.#pk_name.clone())
}
},
quote! {
fn set_primary_key(&mut self, value: Self::PrimaryKey) {
self.#pk_name = value;
}
},
)
};
(pk_getter, pk_setter, quote! {})
};
let field_accessors = generate_field_accessors(struct_name, &field_infos);
let m2m_accessor_methods = generate_m2m_accessor_methods(struct_name, &field_infos);
let fk_accessor_methods = generate_fk_accessor_methods(struct_name, &field_infos);
let relationship_metadata = generate_relationship_metadata(&rel_fields, app_label, struct_name);
let new_fn_impl = generate_new_function(struct_name, &field_infos, &fk_id_field_names);
let build_fn_impl = generate_build_function(struct_name, &field_infos, &fk_id_field_names);
let getters = generate_getter_methods(struct_name, &field_infos);
let setters = generate_setter_methods(struct_name, &field_infos);
let fk_static_accessor_methods = generate_fk_static_accessor_methods(struct_name, &field_infos);
let field_selector_name =
syn::Ident::new(&format!("{}Fields", struct_name), struct_name.span());
let field_selector_struct = generate_field_selector_struct(struct_name, &field_infos);
let custom_manager_impl = match &model_config.manager {
Some(path) => quote! {
impl #generics #orm_crate::HasCustomManager for #struct_name #generics #where_clause {
type Manager = #path;
}
},
None => quote! {},
};
let expanded = quote! {
#composite_pk_type_def
#new_fn_impl
#build_fn_impl
#getters
#setters
#field_accessors
#m2m_accessor_methods
#fk_accessor_methods
#fk_static_accessor_methods
impl #generics #orm_crate::Model for #struct_name #generics #where_clause {
type PrimaryKey = #pk_type;
type Fields = #field_selector_name;
fn table_name() -> &'static str {
#table_name
}
fn new_fields() -> Self::Fields {
#field_selector_name::new()
}
fn app_label() -> &'static str {
#app_label
}
fn primary_key_field() -> &'static str {
stringify!(#pk_name)
}
#pk_impl
#set_pk_impl
#composite_pk_impl
fn field_metadata() -> Vec<#orm_crate::inspection::FieldInfo> {
vec![
#(#field_metadata_items),*
]
}
fn index_metadata() -> Vec<#orm_crate::inspection::IndexInfo> {
vec![
#(
#orm_crate::inspection::IndexInfo {
name: format!("{}_{}_idx", <Self as #orm_crate::Model>::table_name(), #indexed_fields),
fields: vec![#indexed_fields.to_string()],
unique: false,
condition: None,
}
),*
]
}
fn constraint_metadata() -> Vec<#orm_crate::inspection::ConstraintInfo> {
let mut constraints = Vec::new();
#(
constraints.push(#orm_crate::inspection::ConstraintInfo {
name: #check_constraint_names.to_string(),
constraint_type: #orm_crate::inspection::ConstraintType::Check,
definition: #check_constraint_expressions.to_string(),
});
)*
#(
constraints.push(#orm_crate::inspection::ConstraintInfo {
name: #unique_constraint_names.to_string(),
constraint_type: #orm_crate::inspection::ConstraintType::Unique,
definition: #unique_constraint_definitions.to_string(),
});
)*
constraints
}
#relationship_metadata
}
#custom_manager_impl
#registration_code
#relationship_registrations
#field_selector_struct
};
Ok(expanded)
}
/// Builds the `inspection::FieldInfo` constructor expressions that the generated
/// `Model::field_metadata()` returns.
///
/// Regular columns come first, in declaration order. Fields that are skipped, flagged as
/// raw FK id fields, many-to-many, or relationship-typed are excluded from that list.
/// They are followed by one synthetic entry per foreign-key / one-to-one id column in
/// `fk_field_infos`.
///
/// # Errors
///
/// Propagates errors from `field_type_to_metadata_string` and `map_type_to_field_type`.
fn generate_field_metadata(
    field_infos: &[FieldInfo],
    fk_field_infos: &[ForeignKeyFieldInfo],
) -> Result<Vec<TokenStream>> {
    let orm_crate = get_reinhardt_orm_crate();
    // Renders one `attributes.insert("<key>", FieldKwarg::<kwarg>);` statement for the
    // generated code. `kwarg` is the variant plus payload, e.g. `Bool(true)`.
    let insert_kwarg = |key: &str, kwarg: TokenStream| -> TokenStream {
        quote! {
            attributes.insert(
                #key.to_string(),
                #orm_crate::fields::FieldKwarg::#kwarg
            );
        }
    };
    let mut items = Vec::with_capacity(field_infos.len() + fk_field_infos.len());
    let regular_fields = field_infos.iter().filter(|f| {
        let is_many_to_many = f
            .rel
            .as_ref()
            .is_some_and(|r| matches!(r.rel_type, crate::rel::RelationType::ManyToMany));
        !f.config.skip
            && !f.is_fk_id_field
            && !is_many_to_many
            && !is_relationship_field_type(&f.ty)
    });
    for field_info in regular_fields {
        let name = field_info.name.to_string();
        let field_type_path = field_type_to_metadata_string(&field_info.ty, &field_info.config)?;
        // The mapped type is unused here; the call is kept so that any error it reports
        // still aborts the derive.
        let _ = map_type_to_field_type(&field_info.ty, &field_info.config)?;
        let config = &field_info.config;
        let (is_option, _) = extract_option_type(&field_info.ty);
        let nullable = config.null.unwrap_or(is_option);
        let primary_key = config.primary_key;
        let unique = config.unique.unwrap_or(false);
        let blank = config.blank.unwrap_or(false);
        let editable = config.editable.unwrap_or(true);
        let mut attrs = Vec::new();
        if let Some(max_length) = config.max_length {
            attrs.push(insert_kwarg("max_length", quote! { Uint(#max_length) }));
        }
        if config.email == Some(true) {
            attrs.push(insert_kwarg("email", quote! { Bool(true) }));
        }
        if config.url == Some(true) {
            attrs.push(insert_kwarg("url", quote! { Bool(true) }));
        }
        if let Some(min_length) = config.min_length {
            attrs.push(insert_kwarg("min_length", quote! { Uint(#min_length) }));
        }
        if let Some(min_value) = config.min_value {
            attrs.push(insert_kwarg("min_value", quote! { Int(#min_value) }));
        }
        if let Some(max_value) = config.max_value {
            attrs.push(insert_kwarg("max_value", quote! { Int(#max_value) }));
        }
        if let Some(ref generated_expr) = config.generated {
            attrs.push(insert_kwarg(
                "generated",
                quote! { String(#generated_expr.to_string()) },
            ));
        }
        if let Some(generated_stored) = config.generated_stored {
            attrs.push(insert_kwarg(
                "generated_stored",
                quote! { Bool(#generated_stored) },
            ));
        }
        // Backend-specific options are only recorded when this macro crate is built with
        // the matching `db-*` feature.
        #[cfg(any(feature = "db-mysql", feature = "db-sqlite"))]
        if let Some(generated_virtual) = config.generated_virtual {
            attrs.push(insert_kwarg(
                "generated_virtual",
                quote! { Bool(#generated_virtual) },
            ));
        }
        #[cfg(feature = "db-postgres")]
        if let Some(identity_always) = config.identity_always {
            attrs.push(insert_kwarg(
                "identity_always",
                quote! { Bool(#identity_always) },
            ));
        }
        #[cfg(feature = "db-postgres")]
        if let Some(identity_by_default) = config.identity_by_default {
            attrs.push(insert_kwarg(
                "identity_by_default",
                quote! { Bool(#identity_by_default) },
            ));
        }
        #[cfg(feature = "db-mysql")]
        if let Some(auto_increment) = config.auto_increment {
            attrs.push(insert_kwarg(
                "auto_increment",
                quote! { Bool(#auto_increment) },
            ));
        }
        #[cfg(feature = "db-sqlite")]
        if let Some(autoincrement) = config.autoincrement {
            attrs.push(insert_kwarg("autoincrement", quote! { Bool(#autoincrement) }));
        }
        if let Some(ref collate) = config.collate {
            attrs.push(insert_kwarg("collate", quote! { String(#collate.to_string()) }));
        }
        #[cfg(feature = "db-mysql")]
        if let Some(ref character_set) = config.character_set {
            attrs.push(insert_kwarg(
                "character_set",
                quote! { String(#character_set.to_string()) },
            ));
        }
        #[cfg(any(feature = "db-postgres", feature = "db-mysql"))]
        if let Some(ref comment) = config.comment {
            attrs.push(insert_kwarg("comment", quote! { String(#comment.to_string()) }));
        }
        #[cfg(feature = "db-postgres")]
        if let Some(ref storage) = config.storage {
            let storage_str = match storage {
                StorageStrategy::Plain => "plain",
                StorageStrategy::Extended => "extended",
                StorageStrategy::External => "external",
                StorageStrategy::Main => "main",
            };
            attrs.push(insert_kwarg(
                "storage",
                quote! { String(#storage_str.to_string()) },
            ));
        }
        #[cfg(feature = "db-postgres")]
        if let Some(ref compression) = config.compression {
            let compression_str = match compression {
                CompressionMethod::Pglz => "pglz",
                CompressionMethod::Lz4 => "lz4",
            };
            attrs.push(insert_kwarg(
                "compression",
                quote! { String(#compression_str.to_string()) },
            ));
        }
        #[cfg(feature = "db-mysql")]
        if let Some(on_update_current_timestamp) = config.on_update_current_timestamp {
            attrs.push(insert_kwarg(
                "on_update_current_timestamp",
                quote! { Bool(#on_update_current_timestamp) },
            ));
        }
        #[cfg(feature = "db-mysql")]
        if let Some(invisible) = config.invisible {
            attrs.push(insert_kwarg("invisible", quote! { Bool(#invisible) }));
        }
        #[cfg(any(feature = "db-postgres", feature = "db-mysql"))]
        if let Some(fulltext) = config.fulltext {
            attrs.push(insert_kwarg("fulltext", quote! { Bool(#fulltext) }));
        }
        #[cfg(feature = "db-mysql")]
        if let Some(unsigned) = config.unsigned {
            attrs.push(insert_kwarg("unsigned", quote! { Bool(#unsigned) }));
        }
        #[cfg(feature = "db-mysql")]
        if let Some(zerofill) = config.zerofill {
            attrs.push(insert_kwarg("zerofill", quote! { Bool(#zerofill) }));
        }
        let db_column_value = match &config.db_column {
            Some(col) => quote! { Some(#col.to_string()) },
            None => quote! { None },
        };
        items.push(quote! {
            {
                let mut attributes = ::std::collections::HashMap::new();
                #(#attrs)*
                #orm_crate::inspection::FieldInfo {
                    name: #name.to_string(),
                    field_type: #field_type_path.to_string(),
                    nullable: #nullable,
                    primary_key: #primary_key,
                    unique: #unique,
                    blank: #blank,
                    editable: #editable,
                    default: None,
                    db_default: None,
                    db_column: #db_column_value,
                    choices: None,
                    attributes,
                }
            }
        });
    }
    for fk_info in fk_field_infos {
        let name = &fk_info.id_column_name;
        let nullable = fk_info.rel_attr.null.unwrap_or(false);
        // A one-to-one relation allows each target row to be referenced at most once.
        let unique = fk_info.is_one_to_one;
        let db_index = fk_info.rel_attr.db_index.unwrap_or(true);
        // NOTE(review): the migration registry registers this column as
        // `FieldType::Uuid`, while inspection reports `IntegerField`. Confirm which is
        // intended.
        let field_type_path = "IntegerField";
        // Resolved at expansion time, so the generated code carries no constant branch.
        let db_index_attr = if db_index {
            insert_kwarg("db_index", quote! { Bool(true) })
        } else {
            quote! {}
        };
        items.push(quote! {
            {
                let mut attributes = ::std::collections::HashMap::new();
                #db_index_attr
                #orm_crate::inspection::FieldInfo {
                    name: #name.to_string(),
                    field_type: #field_type_path.to_string(),
                    nullable: #nullable,
                    primary_key: false,
                    unique: #unique,
                    blank: false,
                    editable: true,
                    default: None,
                    db_default: None,
                    db_column: None,
                    choices: None,
                    attributes,
                }
            }
        });
    }
    Ok(items)
}
/// Generates a `#[ctor]` function that registers the model at program start-up.
///
/// The generated function first builds a migrations `ModelMetadata` from:
/// - regular columns, including any `foreign_key` attribute;
/// - FK / one-to-one id columns;
/// - many-to-many relations;
/// - named unique constraints.
///
/// It registers that metadata in the migrations global registry, and then registers a
/// `ModelInfo` entry in the ORM's global model registry.
///
/// # Errors
///
/// Propagates errors from `map_type_to_field_type`.
fn generate_registration_code(
    struct_name: &syn::Ident,
    app_label: &str,
    table_name: &str,
    field_infos: &[FieldInfo],
    fk_field_infos: &[ForeignKeyFieldInfo],
    unique_constraint_names: &[String],
    unique_constraint_field_lists: &[Vec<String>],
) -> Result<TokenStream> {
    let migrations_crate = get_reinhardt_migrations_crate();
    let orm_crate = get_reinhardt_orm_crate();
    let model_name = struct_name.to_string();
    let register_fn_name = syn::Ident::new(
        &format!("__register_{}_model", model_name.to_lowercase()),
        struct_name.span(),
    );
    // Many-to-many and relationship-typed fields go through `add_many_to_many`.
    // Everything else is a candidate plain column.
    let (m2m_fields, regular_fields_with_fk_id): (Vec<_>, Vec<_>) =
        field_infos.iter().partition(|f| {
            let is_many_to_many = f
                .rel
                .as_ref()
                .is_some_and(|r| matches!(r.rel_type, crate::rel::RelationType::ManyToMany));
            is_many_to_many || is_relationship_field_type(&f.ty)
        });
    let regular_fields: Vec<_> = regular_fields_with_fk_id
        .into_iter()
        .filter(|f| !f.is_fk_id_field && !f.config.skip)
        .collect();
    let mut field_registrations = Vec::new();
    for field_info in &regular_fields {
        let field_name = field_info.name.to_string();
        let field_type = map_type_to_field_type(&field_info.ty, &field_info.config)?;
        let config = &field_info.config;
        let mut params = Vec::new();
        if config.primary_key {
            params.push(quote! { .with_param("primary_key", "true") });
        }
        let (is_option, _) = extract_option_type(&field_info.ty);
        // An explicit `null` wins. Otherwise primary keys are NOT NULL, and other columns
        // are NOT NULL unless declared as `Option<_>`.
        let is_not_null = if let Some(null) = config.null {
            !null
        } else if config.primary_key {
            true
        } else {
            !is_option
        };
        if is_not_null {
            params.push(quote! { .with_param("not_null", "true") });
        }
        if let Some(max_length) = config.max_length {
            let ml_str = max_length.to_string();
            params.push(quote! { .with_param("max_length", #ml_str) });
        }
        if let Some(null) = config.null {
            let null_str = null.to_string();
            params.push(quote! { .with_param("null", #null_str) });
        }
        if config.unique == Some(true) {
            params.push(quote! { .with_param("unique", "true") });
        }
        if config.null.is_none() {
            let nullable = !config.primary_key && is_option;
            let null_str = nullable.to_string();
            params.push(quote! { .with_param("null", #null_str) });
        }
        // Integer primary keys auto-increment unless explicitly disabled.
        if config.primary_key && is_integer_primary_key_type(&field_info.ty) {
            let auto_inc = config.auto_increment.unwrap_or(true);
            let auto_inc_str = auto_inc.to_string();
            params.push(quote! { .with_param("auto_increment", #auto_inc_str) });
        } else if let Some(auto_increment) = config.auto_increment {
            let auto_inc_str = auto_increment.to_string();
            params.push(quote! { .with_param("auto_increment", #auto_inc_str) });
        }
        if config.auto_now == Some(true) {
            params.push(quote! { .with_param("auto_now", "true") });
        }
        if config.auto_now_add == Some(true) {
            params.push(quote! { .with_param("auto_now_add", "true") });
        }
        if let Some(ref default_expr) = config.default
            && let Some(serialized) = serialize_field_default(default_expr)
        {
            params.push(quote! { .with_param("default", #serialized) });
        }
        let fk_registration = if let Some(fk_spec) = &config.foreign_key {
            match fk_spec {
                ForeignKeySpec::Type(ty) => {
                    // `TokenStream::to_string` renders paths as `a :: b :: C`. The last
                    // `::` segment must be trimmed, or the referenced table name would be
                    // derived from `" C"` for any path-qualified target type.
                    let type_name_str = quote! { #ty }.to_string();
                    let last_segment = type_name_str
                        .rsplit("::")
                        .next()
                        .unwrap_or(&type_name_str)
                        .trim()
                        .to_string();
                    quote! {
                        .with_foreign_key(#migrations_crate::ForeignKeyInfo {
                            referenced_table: #migrations_crate::to_snake_case(#last_segment),
                            referenced_column: "id".to_string(),
                            on_delete: #migrations_crate::ForeignKeyAction::Cascade,
                            on_update: #migrations_crate::ForeignKeyAction::Cascade,
                        })
                    }
                }
                ForeignKeySpec::AppModel {
                    app_label,
                    model_name,
                } => {
                    let table_name_str = format!("{}_{}", app_label, model_name.to_lowercase());
                    quote! {
                        .with_foreign_key(#migrations_crate::ForeignKeyInfo {
                            referenced_table: #table_name_str.to_string(),
                            referenced_column: "id".to_string(),
                            on_delete: #migrations_crate::ForeignKeyAction::Cascade,
                            on_update: #migrations_crate::ForeignKeyAction::Cascade,
                        })
                    }
                }
            }
        } else {
            quote! {}
        };
        field_registrations.push(quote! {
            metadata.add_field(
                #field_name.to_string(),
                #migrations_crate::model_registry::FieldMetadata::new(#field_type)
                    #(#params)*
                    #fk_registration
            );
        });
    }
    let mut m2m_registrations = Vec::new();
    for field_info in &m2m_fields {
        let field_name = field_info.name.to_string();
        // Target model: an explicit `to` on `#[rel]`, otherwise the last segment of the
        // many-to-many field's target type. Fields with neither are not registered here.
        let to_model = if let Some(rel) = &field_info.rel
            && let Some(to_type) = &rel.to
        {
            quote! { #to_type }.to_string()
        } else if let Some(target_ty) = extract_m2m_target_type(&field_info.ty) {
            if let Type::Path(type_path) = target_ty
                && let Some(last_segment) = type_path.path.segments.last()
            {
                last_segment.ident.to_string()
            } else {
                continue;
            }
        } else {
            continue;
        };
        let related_name = field_info
            .rel
            .as_ref()
            .and_then(|r| r.related_name.as_ref())
            .map(|r| quote! { Some(#r.to_string()) })
            .unwrap_or(quote! { None });
        let through = field_info
            .rel
            .as_ref()
            .and_then(|r| r.through.as_ref())
            .map(|t| quote! { Some(#t.to_string()) })
            .unwrap_or(quote! { None });
        let source_field = field_info
            .rel
            .as_ref()
            .and_then(|r| r.source_field.as_ref())
            .map(|s| quote! { Some(#s.to_string()) })
            .unwrap_or(quote! { None });
        let target_field = field_info
            .rel
            .as_ref()
            .and_then(|r| r.target_field.as_ref())
            .map(|t| quote! { Some(#t.to_string()) })
            .unwrap_or(quote! { None });
        m2m_registrations.push(quote! {
            metadata.add_many_to_many(
                #migrations_crate::model_registry::ManyToManyMetadata {
                    field_name: #field_name.to_string(),
                    to_model: #to_model.to_string(),
                    related_name: #related_name,
                    through: #through,
                    source_field: #source_field,
                    target_field: #target_field,
                    db_constraint_prefix: None,
                }
            );
        });
    }
    let mut fk_id_registrations = Vec::new();
    for fk_info in fk_field_infos {
        let id_column_name = &fk_info.id_column_name;
        let nullable = fk_info.rel_attr.null.unwrap_or(false);
        // One-to-one relations get a unique id column.
        let unique = fk_info.is_one_to_one;
        let db_index = fk_info.rel_attr.db_index.unwrap_or(true);
        let not_null_str = (!nullable).to_string();
        let unique_str = unique.to_string();
        let db_index_str = db_index.to_string();
        let target_model_name = if let Type::Path(type_path) = &fk_info.target_type {
            type_path
                .path
                .segments
                .last()
                .map(|seg| seg.ident.to_string())
                .unwrap_or_else(|| "Unknown".to_string())
        } else {
            "Unknown".to_string()
        };
        // The target's app label is only known at runtime, via its own `Model` impl.
        let fk_target_app_chain = if let Type::Path(_) = &fk_info.target_type {
            let target_ty = &fk_info.target_type;
            quote! {
                .with_param(
                    "fk_target_app",
                    <#target_ty as #orm_crate::Model>::app_label(),
                )
            }
        } else {
            quote! {}
        };
        // NOTE(review): the id column is always registered as `FieldType::Uuid`, while
        // `field_metadata()` reports `IntegerField`. Confirm against the target PK type.
        fk_id_registrations.push(quote! {
            metadata.add_field(
                #id_column_name.to_string(),
                #migrations_crate::model_registry::FieldMetadata::new(
                    #migrations_crate::FieldType::Uuid
                )
                .with_nullable(#nullable)
                .with_param("not_null", #not_null_str)
                .with_param("unique", #unique_str)
                .with_param("db_index", #db_index_str)
                .with_param("fk_target", #target_model_name)
                #fk_target_app_chain
            );
        });
    }
    let type_path = quote! { #struct_name }.to_string();
    let constraint_registrations: Vec<TokenStream> = unique_constraint_names
        .iter()
        .zip(unique_constraint_field_lists.iter())
        .map(|(name, fields)| {
            let field_lits = fields.iter().map(|f| quote! { #f.to_string() });
            quote! {
                metadata.add_constraint(
                    #migrations_crate::ConstraintDefinition {
                        name: #name.to_string(),
                        constraint_type: "unique".to_string(),
                        fields: vec![ #(#field_lits),* ],
                        expression: None,
                        foreign_key_info: None,
                    }
                );
            }
        })
        .collect();
    let code = quote! {
        #[::ctor::ctor]
        fn #register_fn_name() {
            use #migrations_crate::model_registry::ModelMetadata;
            let mut metadata = ModelMetadata::new(
                #app_label,
                #model_name,
                #table_name,
            );
            #(#field_registrations)*
            #(#fk_id_registrations)*
            #(#m2m_registrations)*
            #(#constraint_registrations)*
            #migrations_crate::model_registry::global_registry().register_model(metadata);
            #orm_crate::registry::global_model_registry().register(
                #orm_crate::registry::ModelInfo {
                    app_label: #app_label.to_string(),
                    model_name: #model_name.to_string(),
                    type_path: #type_path.to_string(),
                    table_name: #table_name.to_string(),
                }
            );
        }
    };
    Ok(code)
}
fn generate_relationship_registrations(
struct_name: &syn::Ident,
app_label: &str,
field_infos: &[FieldInfo],
fk_field_infos: &[ForeignKeyFieldInfo],
) -> TokenStream {
let reinhardt = get_reinhardt_crate();
let _orm_crate = get_reinhardt_orm_crate();
let linkme = get_linkme_crate();
let mut registrations = Vec::new();
let model_name = struct_name.to_string();
for fk_info in fk_field_infos {
let field_name = &fk_info.field_name;
let field_name_str = field_name.to_string();
let is_one_to_one = fk_info.is_one_to_one;
let target_model_name = if let Type::Path(type_path) = &fk_info.target_type {
type_path
.path
.segments
.last()
.map(|seg| seg.ident.to_string())
.unwrap_or_else(|| "Unknown".to_string())
} else {
"Unknown".to_string()
};
let related_name_opt = fk_info.rel_attr.related_name.as_ref();
let related_name = related_name_opt
.map(|r| quote! { Some(#r) })
.unwrap_or(quote! { None });
let db_column = fk_info
.rel_attr
.db_column
.as_ref()
.map(|c| quote! { Some(#c) })
.unwrap_or_else(|| {
let default_db_column = format!("{}_id", field_name_str);
quote! { Some(#default_db_column) }
});
let relationship_type = if is_one_to_one {
quote! { #reinhardt::apps::registry::RelationshipType::OneToOne }
} else {
quote! { #reinhardt::apps::registry::RelationshipType::ForeignKey }
};
let static_var_name = syn::Ident::new(
&format!(
"__REL_{}_{}_TO_{}",
model_name.to_uppercase(),
field_name_str.to_uppercase(),
target_model_name.to_uppercase()
),
struct_name.span(),
);
registrations.push(quote! {
#[#linkme::distributed_slice(#reinhardt::apps::registry::RELATIONSHIPS)]
static #static_var_name: #reinhardt::apps::registry::RelationshipMetadata =
#reinhardt::apps::registry::RelationshipMetadata {
from_model: concat!(#app_label, ".", #model_name),
to_model: #target_model_name,
relationship_type: #relationship_type,
field_name: #field_name_str,
related_name: #related_name,
db_column: #db_column,
through_table: None,
};
});
if let Some(related_name_str) = related_name_opt {
let reverse_relationship_type = if is_one_to_one {
quote! { #reinhardt::apps::registry::RelationshipType::OneToOne }
} else {
quote! { #reinhardt::apps::registry::RelationshipType::ForeignKey }
};
let reverse_static_var_name = syn::Ident::new(
&format!(
"__REL_REVERSE_{}_{}_TO_{}",
target_model_name.to_uppercase(),
related_name_str.to_uppercase(),
model_name.to_uppercase()
),
struct_name.span(),
);
registrations.push(quote! {
#[#linkme::distributed_slice(#reinhardt::apps::registry::RELATIONSHIPS)]
static #reverse_static_var_name: #reinhardt::apps::registry::RelationshipMetadata =
#reinhardt::apps::registry::RelationshipMetadata {
from_model: #target_model_name,
to_model: concat!(#app_label, ".", #model_name),
relationship_type: #reverse_relationship_type,
field_name: #related_name_str,
related_name: Some(#field_name_str),
db_column: None,
through_table: None,
};
});
}
}
for field_info in field_infos {
if !is_many_to_many_field_type(&field_info.ty) {
continue;
}
let field_name = &field_info.name;
let field_name_str = field_name.to_string();
let target_model_name = if let Some(target_ty) = extract_m2m_target_type(&field_info.ty) {
if let Type::Path(type_path) = target_ty {
type_path
.path
.segments
.last()
.map(|seg| seg.ident.to_string())
.unwrap_or_else(|| "Unknown".to_string())
} else {
continue; }
} else {
continue; };
let (related_name, through_table, related_name_opt) = if let Some(rel) = &field_info.rel {
let related_name_str = rel.related_name.as_ref();
let related_name = related_name_str
.map(|r| quote! { Some(#r) })
.unwrap_or(quote! { None });
let through_table = rel
.through
.as_ref()
.map(|t| {
let through_str = quote! { #t }.to_string();
quote! { Some(#through_str) }
})
.unwrap_or(quote! { None });
(related_name, through_table, related_name_str)
} else {
(quote! { None }, quote! { None }, None)
};
let static_var_name = syn::Ident::new(
&format!(
"__REL_M2M_{}_{}_TO_{}",
model_name.to_uppercase(),
field_name_str.to_uppercase(),
target_model_name.to_uppercase()
),
struct_name.span(),
);
registrations.push(quote! {
#[#linkme::distributed_slice(#reinhardt::apps::registry::RELATIONSHIPS)]
static #static_var_name: #reinhardt::apps::registry::RelationshipMetadata =
#reinhardt::apps::registry::RelationshipMetadata {
from_model: concat!(#app_label, ".", #model_name),
to_model: #target_model_name,
relationship_type: #reinhardt::apps::registry::RelationshipType::ManyToMany,
field_name: #field_name_str,
related_name: #related_name,
db_column: None,
through_table: #through_table,
};
});
if let Some(related_name_str) = related_name_opt {
let reverse_static_var_name = syn::Ident::new(
&format!(
"__REL_M2M_REVERSE_{}_{}_TO_{}",
target_model_name.to_uppercase(),
related_name_str.to_uppercase(),
model_name.to_uppercase()
),
struct_name.span(),
);
registrations.push(quote! {
#[#linkme::distributed_slice(#reinhardt::apps::registry::RELATIONSHIPS)]
static #reverse_static_var_name: #reinhardt::apps::registry::RelationshipMetadata =
#reinhardt::apps::registry::RelationshipMetadata {
from_model: #target_model_name,
to_model: concat!(#app_label, ".", #model_name),
relationship_type: #reinhardt::apps::registry::RelationshipType::ManyToMany,
field_name: #related_name_str,
related_name: Some(#field_name_str),
db_column: None,
through_table: #through_table,
};
});
}
}
quote! {
#(#registrations)*
}
}
/// Emits the `composite_primary_key()` and `get_composite_pk_values()` methods
/// for a model whose primary key spans several fields.
///
/// `composite_primary_key()` lists the PK column names in the order of
/// `pk_fields`. `get_composite_pk_values()` converts whatever `self.primary_key()`
/// returns via `to_pk_values()` (presumably the generated `<Model>CompositePk`
/// type — confirm), and yields an empty map when there is no key.
fn generate_composite_pk_impl(pk_fields: &[&FieldInfo]) -> TokenStream {
    let orm = get_reinhardt_orm_crate();
    // Column names as owned strings, spliced into a `vec![...]` literal below.
    let column_names = pk_fields.iter().map(|field| field.name.to_string());
    let column_list = quote! { vec![#(#column_names.to_string()),*] };
    quote! {
        fn composite_primary_key() -> Option<#orm::composite_pk::CompositePrimaryKey> {
            Some(
                #orm::composite_pk::CompositePrimaryKey::new(
                    #column_list
                )
                .expect("Invalid composite primary key")
            )
        }
        fn get_composite_pk_values(&self) -> ::std::collections::HashMap<String, #orm::composite_pk::PkValue> {
            if let Some(pk) = self.primary_key() {
                pk.to_pk_values()
            } else {
                ::std::collections::HashMap::new()
            }
        }
    }
}
/// Generates `<Model>CompositePk`, a plain value type holding every primary-key
/// field of the model, plus its conversions.
///
/// Generated items:
/// * the struct itself (`Debug, Clone, PartialEq, Eq, Hash`) with a `new(...)`
///   constructor taking the fields in declaration order;
/// * `to_pk_values()`, mapping each field name to an ORM `PkValue`;
/// * `From` conversions to and from the field tuple. A single-field key
///   converts to and from the bare field type, without a 1-tuple;
/// * `Display`, rendering `(a=1, b=2)`.
///
/// `Option<T>` PK fields are stored as plain `T` in the composite key.
fn generate_composite_pk_type(struct_name: &syn::Ident, pk_fields: &[&FieldInfo]) -> TokenStream {
    let orm_crate = get_reinhardt_orm_crate();
    let composite_pk_name =
        syn::Ident::new(&format!("{}CompositePk", struct_name), struct_name.span());
    let field_names: Vec<_> = pk_fields.iter().map(|f| &f.name).collect();
    let field_types: Vec<_> = pk_fields
        .iter()
        .map(|f| {
            let ty = &f.ty;
            let (is_option, inner_ty) = extract_option_type(ty);
            if is_option { inner_ty } else { ty }
        })
        .collect();
    // For a single-field key, the "tuple" is just the bare type/pattern/expr.
    // Parenthesising it would add nothing but `unused_parens` noise in the
    // expanded code.
    let (tuple_type, tuple_pattern, tuple_expr) = if field_types.len() == 1 {
        (
            quote! { #(#field_types),* },
            quote! { #(#field_names),* },
            quote! { #(pk.#field_names),* },
        )
    } else {
        (
            quote! { (#(#field_types),*) },
            quote! { (#(#field_names),*) },
            quote! { (#(pk.#field_names),*) },
        )
    };
    let pk_value_conversions: Vec<_> = field_names
        .iter()
        .map(|name| {
            quote! {
                values.insert(
                    stringify!(#name).to_string(),
                    #orm_crate::composite_pk::PkValue::from(&self.#name)
                );
            }
        })
        .collect();
    // Decide separators at expansion time. The generated `fmt` then needs no
    // runtime "first" flag, and has no dead final assignment to one.
    let display_parts: Vec<TokenStream> = field_names
        .iter()
        .enumerate()
        .map(|(index, name)| {
            let separator = if index == 0 {
                quote! {}
            } else {
                quote! { f.write_str(", ")?; }
            };
            quote! {
                #separator
                write!(f, "{}={}", stringify!(#name), self.#name)?;
            }
        })
        .collect();
    quote! {
        #[derive(Debug, Clone, PartialEq, Eq, Hash)]
        pub struct #composite_pk_name {
            #(pub #field_names: #field_types),*
        }
        impl #composite_pk_name {
            pub fn new(#(#field_names: #field_types),*) -> Self {
                Self {
                    #(#field_names),*
                }
            }
            pub fn to_pk_values(&self) -> ::std::collections::HashMap<String, #orm_crate::composite_pk::PkValue> {
                let mut values = ::std::collections::HashMap::new();
                #(#pk_value_conversions)*
                values
            }
        }
        impl ::std::convert::From<#tuple_type> for #composite_pk_name {
            fn from(tuple: #tuple_type) -> Self {
                let #tuple_pattern = tuple;
                Self {
                    #(#field_names),*
                }
            }
        }
        impl ::std::convert::From<#composite_pk_name> for #tuple_type {
            fn from(pk: #composite_pk_name) -> Self {
                #tuple_expr
            }
        }
        impl ::std::fmt::Display for #composite_pk_name {
            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
                f.write_str("(")?;
                #(#display_parts)*
                f.write_str(")")
            }
        }
    }
}
/// Generates the `relationship_metadata()` method. It returns one
/// `RelationInfo` per `(field, #[rel(...)])` pair in `rel_fields`, in order.
///
/// `_app_label` and `_struct_name` are part of the signature but unused here.
fn generate_relationship_metadata(
    rel_fields: &[(Ident, RelAttribute)],
    _app_label: &str,
    _struct_name: &Ident,
) -> TokenStream {
    use crate::rel::RelationType;
    let orm = get_reinhardt_orm_crate();

    // With no relationships, the method simply returns an empty Vec.
    if rel_fields.is_empty() {
        return quote! {
            fn relationship_metadata() -> Vec<#orm::inspection::RelationInfo> {
                Vec::new()
            }
        };
    }

    let mut items: Vec<TokenStream> = Vec::with_capacity(rel_fields.len());
    for (field_ident, rel) in rel_fields {
        let name_lit = field_ident.to_string();

        // Collapse the macro-level relation kinds onto the ORM's four
        // runtime `RelationshipType` variants.
        let variant = match rel.rel_type {
            RelationType::ForeignKey
            | RelationType::Polymorphic
            | RelationType::GenericForeignKey => quote! { ManyToOne },
            RelationType::OneToOne => quote! { OneToOne },
            RelationType::OneToMany | RelationType::GenericRelation => quote! { OneToMany },
            RelationType::ManyToMany | RelationType::PolymorphicManyToMany => {
                quote! { ManyToMany }
            }
        };
        let relationship_type = quote! { #orm::relationship::RelationshipType::#variant };

        // The `to = ...` target rendered as a string literal, or "" if absent.
        let related_model = match rel.to.as_ref() {
            Some(path) => {
                let rendered = quote! { #path }.to_string();
                quote! { #rendered }
            }
            None => quote! { "" },
        };

        let back_populates = match rel.related_name.as_ref() {
            Some(name) => quote! { Some(#name.to_string()) },
            None => quote! { None },
        };

        // FK / O2O relations report their own field name as the foreign key.
        // O2M uses the explicit `foreign_key` attribute when given. All other
        // kinds report none.
        let foreign_key = match rel.rel_type {
            RelationType::ForeignKey | RelationType::OneToOne => {
                quote! { Some(#name_lit.to_string()) }
            }
            RelationType::OneToMany => match rel.foreign_key.as_ref() {
                Some(fk) => quote! { Some(#fk.to_string()) },
                None => quote! { None },
            },
            _ => quote! { None },
        };

        let through_table = match rel.through.as_ref() {
            Some(through) => quote! { Some(#through.to_string()) },
            None => quote! { None },
        };
        let source_field = match rel.source_field.as_ref() {
            Some(source) => quote! { Some(#source.to_string()) },
            None => quote! { None },
        };
        let target_field = match rel.target_field.as_ref() {
            Some(target) => quote! { Some(#target.to_string()) },
            None => quote! { None },
        };

        items.push(quote! {
            #orm::inspection::RelationInfo {
                name: #name_lit.to_string(),
                relationship_type: #relationship_type,
                foreign_key: #foreign_key,
                related_model: #related_model.to_string(),
                back_populates: #back_populates,
                through_table: #through_table,
                source_field: #source_field,
                target_field: #target_field,
            }
        });
    }

    quote! {
        fn relationship_metadata() -> Vec<#orm::inspection::RelationInfo> {
            vec![
                #(#items),*
            ]
        }
    }
}
/// Returns `true` when `crate::pk_shape::pk_uuid_shape` classifies `ty` as a
/// UUID shape. The flag is the first element of the returned tuple.
fn is_uuid_type(ty: &Type) -> bool {
    let shape = crate::pk_shape::pk_uuid_shape(ty);
    shape.0
}
/// Returns `true` when `ty` is a path type whose last segment is `String`,
/// after `extract_option_type` has looked through an outer `Option<...>`.
/// Examples: `String`, `std::string::String`, `Option<String>`.
fn is_string_type(ty: &Type) -> bool {
    let (_, unwrapped) = extract_option_type(ty);
    match unwrapped {
        Type::Path(type_path) => type_path
            .path
            .segments
            .last()
            .is_some_and(|segment| segment.ident == "String"),
        _ => false,
    }
}
/// Returns `true` when `ty`, after `extract_option_type` has looked through an
/// outer `Option<...>`, is a path ending in one of the built-in integer types
/// (`i8`..`i64`, `isize`, `u8`..`u64`, `usize`).
fn is_integer_primary_key_type(ty: &Type) -> bool {
    const INTEGER_IDENTS: [&str; 10] = [
        "i8", "i16", "i32", "i64", "isize", "u8", "u16", "u32", "u64", "usize",
    ];
    let (_, unwrapped) = extract_option_type(ty);
    let Type::Path(type_path) = unwrapped else {
        return false;
    };
    let Some(segment) = type_path.path.segments.last() else {
        return false;
    };
    let ident = segment.ident.to_string();
    INTEGER_IDENTS.contains(&ident.as_str())
}
/// Returns `true` when `ty` (looking through an outer `Option<...>` via
/// `extract_option_type`) is a `DateTime` path that should be treated as UTC.
///
/// `DateTime<X>` matches only when the last segment of `X` is `Utc`. A bare
/// `DateTime`, or one whose first generic argument is not a readable type
/// path, is assumed to be UTC.
fn is_datetime_utc_type(ty: &Type) -> bool {
    let (_, unwrapped) = extract_option_type(ty);
    let Type::Path(type_path) = unwrapped else {
        return false;
    };
    let Some(segment) = type_path.path.segments.last() else {
        return false;
    };
    if segment.ident != "DateTime" {
        return false;
    }
    let timezone = match &segment.arguments {
        PathArguments::AngleBracketed(generics) => match generics.args.first() {
            Some(GenericArgument::Type(Type::Path(tz_path))) => {
                tz_path.path.segments.last().map(|tz| &tz.ident)
            }
            _ => None,
        },
        _ => None,
    };
    match timezone {
        Some(tz_ident) => tz_ident == "Utc",
        None => true,
    }
}
/// Returns `true` when `ty` is a path whose last segment is `ManyToManyField`.
/// Generic arguments are ignored, and there is no `Option` unwrapping.
fn is_many_to_many_field_type(ty: &Type) -> bool {
    match ty {
        Type::Path(type_path) => type_path
            .path
            .segments
            .last()
            .is_some_and(|segment| segment.ident == "ManyToManyField"),
        _ => false,
    }
}
/// Returns `true` when `ty` is a path whose last segment is `ForeignKeyField`.
/// Generic arguments are ignored, and there is no `Option` unwrapping.
fn is_foreign_key_field_type(ty: &Type) -> bool {
    match ty {
        Type::Path(type_path) => type_path
            .path
            .segments
            .last()
            .is_some_and(|segment| segment.ident == "ForeignKeyField"),
        _ => false,
    }
}
/// Returns `true` when `ty` is a path whose last segment is `OneToOneField`.
/// Generic arguments are ignored, and there is no `Option` unwrapping.
fn is_one_to_one_field_type(ty: &Type) -> bool {
    match ty {
        Type::Path(type_path) => type_path
            .path
            .segments
            .last()
            .is_some_and(|segment| segment.ident == "OneToOneField"),
        _ => false,
    }
}
/// For `ForeignKeyField<T, ...>` or `OneToOneField<T, ...>`, returns the first
/// generic argument `T`. Any other type, or a missing or non-type first
/// argument, yields `None`.
fn extract_fk_target_type(ty: &Type) -> Option<&Type> {
    let Type::Path(type_path) = ty else {
        return None;
    };
    let segment = type_path.path.segments.last()?;
    if segment.ident != "ForeignKeyField" && segment.ident != "OneToOneField" {
        return None;
    }
    match &segment.arguments {
        syn::PathArguments::AngleBracketed(generics) => match generics.args.first() {
            Some(syn::GenericArgument::Type(target)) => Some(target),
            _ => None,
        },
        _ => None,
    }
}
/// For `ManyToManyField<A, B, ...>`, returns the second generic argument `B`.
/// Any other type, fewer than two generic arguments, or a non-type second
/// argument yields `None`.
fn extract_m2m_target_type(ty: &Type) -> Option<&Type> {
    let Type::Path(type_path) = ty else {
        return None;
    };
    let segment = type_path
        .path
        .segments
        .last()
        .filter(|segment| segment.ident == "ManyToManyField")?;
    let syn::PathArguments::AngleBracketed(generics) = &segment.arguments else {
        return None;
    };
    // `nth(1)` is `None` when there are fewer than two arguments.
    match generics.args.iter().nth(1) {
        Some(syn::GenericArgument::Type(target)) => Some(target),
        _ => None,
    }
}
/// Returns `true` for single-row relationship wrappers, i.e. `ForeignKeyField`
/// or `OneToOneField`. `ManyToManyField` is not included.
fn is_relationship_field_type(ty: &Type) -> bool {
    if is_foreign_key_field_type(ty) {
        return true;
    }
    is_one_to_one_field_type(ty)
}
/// Returns `true` when the field is auto-maintained as a timestamp, meaning
/// `auto_now_add = true` or `auto_now = true`. With the `db-mysql` feature,
/// `on_update_current_timestamp = true` also counts.
fn is_timestamp_field(field: &FieldInfo) -> bool {
    let cfg = &field.config;
    if cfg.auto_now_add == Some(true) || cfg.auto_now == Some(true) {
        return true;
    }
    // This config field only exists when the MySQL backend is compiled in.
    #[cfg(feature = "db-mysql")]
    {
        if cfg.on_update_current_timestamp == Some(true) {
            return true;
        }
    }
    false
}
/// Returns (a clone of) the first generic type argument of `ty`'s last path
/// segment, e.g. `User` for `ForeignKeyField<User>`. When there is no such
/// argument, `ty` itself is cloned and returned unchanged.
///
/// The wrapper name is not checked, so any `Wrapper<T, ...>` yields `T`.
fn extract_foreign_key_target_type(ty: &Type) -> Type {
    let first_generic = match ty {
        Type::Path(type_path) => type_path.path.segments.last().and_then(|segment| {
            match &segment.arguments {
                PathArguments::AngleBracketed(generics) => match generics.args.first() {
                    Some(GenericArgument::Type(inner)) => Some(inner),
                    _ => None,
                },
                _ => None,
            }
        }),
        _ => None,
    };
    first_generic.unwrap_or(ty).clone()
}
fn is_option_type(ty: &syn::Type) -> bool {
if let syn::Type::Path(type_path) = ty
&& let Some(segment) = type_path.path.segments.last()
{
return segment.ident == "Option";
}
false
}
/// Decides whether `field` is filled in automatically by the generated
/// constructors (`new()` / `build()`). If not, the caller must supply it.
///
/// The checks run in a fixed priority order, and the first match wins:
/// 1. `skip` fields and FK id fields (`is_fk_id_field`) are always auto.
/// 2. An explicit `include_in_new = false` / `true` forces auto / caller-supplied.
/// 3. Timestamp fields (`is_timestamp_field`) and `generated` columns are auto.
/// 4. Backend identity / auto-increment flags are auto. Each is only compiled
///    in with its matching `db-*` feature.
/// 5. Many-to-many and FK / one-to-one wrapper fields are auto.
/// 6. UUID primary keys are auto. Integer primary keys are auto unless
///    `auto_increment = false`.
///
/// Any field matching none of the above is caller-supplied.
fn is_auto_generated_field(field: &FieldInfo) -> bool {
    if field.config.skip {
        return true;
    }
    if field.is_fk_id_field {
        return true;
    }
    let config = &field.config;
    // Explicit user override; takes precedence over every heuristic below.
    if config.include_in_new == Some(false) {
        return true;
    }
    if config.include_in_new == Some(true) {
        return false;
    }
    if is_timestamp_field(field) {
        return true;
    }
    if config.generated.is_some() {
        return true;
    }
    #[cfg(feature = "db-postgres")]
    {
        if config.identity_always == Some(true) || config.identity_by_default == Some(true) {
            return true;
        }
    }
    #[cfg(feature = "db-mysql")]
    {
        if config.auto_increment == Some(true) {
            return true;
        }
    }
    #[cfg(feature = "db-sqlite")]
    {
        if config.autoincrement == Some(true) {
            return true;
        }
    }
    if is_many_to_many_field_type(&field.ty) {
        return true;
    }
    if is_relationship_field_type(&field.ty) {
        return true;
    }
    // Also covers M2M fields declared via `#[rel(...)]` rather than by wrapper type.
    if let Some(rel) = &field.rel
        && matches!(rel.rel_type, crate::rel::RelationType::ManyToMany)
    {
        return true;
    }
    if config.primary_key && is_uuid_type(&field.ty) {
        return true;
    }
    if config.primary_key && is_integer_primary_key_type(&field.ty) {
        // NOTE(review): `auto_increment` is read here unconditionally, but
        // above it is only read under `cfg(feature = "db-mysql")`. Presumably
        // the config field exists in every build — confirm.
        if config.auto_increment == Some(false) {
            return false;
        }
        return true;
    }
    false
}
/// Returns the initializer expression for an auto-generated field in the
/// generated `new()` / `build().finish()` constructors.
///
/// * Skipped, many-to-many and FK / one-to-one wrapper fields use
///   `Default::default()`.
/// * Timestamp fields (`is_timestamp_field`) of type `DateTime<Utc>` use
///   `Utc::now()`, wrapped in `Some` for `Option<_>`.
/// * UUID primary keys use a fresh `Uuid::now_v7()`, wrapped in `Some` for
///   `Option<_>`.
/// * Integer primary keys use `None` for `Option<_>` and `0` otherwise.
/// * Any other field uses `Default::default()`.
///
/// Every emitted path is fully qualified, so user items in scope at the derive
/// site (e.g. a glob-imported enum variant named `Some`) cannot capture it.
fn get_auto_field_default_value(field: &FieldInfo) -> TokenStream {
    let config = &field.config;
    if config.skip {
        return quote! { ::std::default::Default::default() };
    }
    if is_many_to_many_field_type(&field.ty) {
        return quote! { ::std::default::Default::default() };
    }
    if let Some(rel) = &field.rel
        && matches!(rel.rel_type, crate::rel::RelationType::ManyToMany)
    {
        return quote! { ::std::default::Default::default() };
    }
    if is_relationship_field_type(&field.ty) {
        return quote! { ::std::default::Default::default() };
    }
    if is_timestamp_field(field) && is_datetime_utc_type(&field.ty) {
        if is_option_type(&field.ty) {
            return quote! { ::std::option::Option::Some(::chrono::Utc::now()) };
        }
        return quote! { ::chrono::Utc::now() };
    }
    if config.primary_key && is_uuid_type(&field.ty) {
        let (is_option, _) = extract_option_type(&field.ty);
        return if is_option {
            // Fully qualified, like the other `Some`/`None` emitted by this function.
            quote! { ::std::option::Option::Some(::uuid::Uuid::now_v7()) }
        } else {
            quote! { ::uuid::Uuid::now_v7() }
        };
    }
    if config.primary_key && is_integer_primary_key_type(&field.ty) {
        let (is_option, inner_ty) = extract_option_type(&field.ty);
        return if is_option {
            quote! { ::std::option::Option::None }
        } else {
            quote! { 0 as #inner_ty }
        };
    }
    quote! { ::std::default::Default::default() }
}
fn generate_new_function(
struct_name: &syn::Ident,
field_infos: &[FieldInfo],
fk_id_field_names: &[syn::Ident],
) -> TokenStream {
let orm_crate = get_reinhardt_orm_crate();
let user_fields: Vec<_> = field_infos
.iter()
.filter(|f| !is_auto_generated_field(f))
.collect();
let auto_fields: Vec<_> = field_infos
.iter()
.filter(|f| is_auto_generated_field(f))
.collect();
let fk_id_to_fk_field: HashMap<String, String> = fk_id_field_names
.iter()
.filter_map(|id_name| {
let id_str = id_name.to_string();
if id_str.ends_with("_id") {
let fk_name = id_str.trim_end_matches("_id").to_string();
Some((id_str, fk_name))
} else {
None
}
})
.collect();
let mut params = Vec::new();
let mut where_clauses = Vec::new();
let mut generic_params = Vec::new();
let mut fk_field_assignments = Vec::new();
let mut fk_id_assignments = Vec::new();
let mut generic_counter = 0;
let mut string_fields: HashMap<String, bool> = HashMap::new();
for f in user_fields.iter() {
let field_name = &f.name;
let field_name_str = field_name.to_string();
if let Some(fk_field_name) = fk_id_to_fk_field.get(&field_name_str) {
let generic_param =
syn::Ident::new(&format!("F{}", generic_counter), field_name.span());
generic_counter += 1;
let fk_field_info = field_infos.iter().find(|fi| fi.name == fk_field_name);
if let Some(fk_info) = fk_field_info {
let related_model_type = extract_foreign_key_target_type(&fk_info.ty);
let fk_field_ident = syn::Ident::new(fk_field_name, field_name.span());
params.push(quote! { #fk_field_ident: #generic_param });
where_clauses.push(quote! {
#generic_param: #orm_crate::IntoPrimaryKey<#related_model_type>
});
generic_params.push(quote! { #generic_param });
fk_id_assignments.push(quote! {
#field_name: #fk_field_ident.into_primary_key()
});
}
} else {
let ty = &f.ty;
let (is_option, _) = extract_option_type(ty);
if is_string_type(ty) && !is_option {
let generic_param =
syn::Ident::new(&format!("S{}", generic_counter), field_name.span());
generic_counter += 1;
params.push(quote! { #field_name: #generic_param });
where_clauses
.push(quote! { #generic_param: ::std::convert::Into<::std::string::String> });
generic_params.push(quote! { #generic_param });
string_fields.insert(field_name_str.clone(), false);
} else {
params.push(quote! { #field_name: #ty });
}
}
}
for (_fk_id_str, fk_name_str) in fk_id_to_fk_field.iter() {
let fk_name = syn::Ident::new(fk_name_str, proc_macro2::Span::call_site());
fk_field_assignments.push(quote! {
#fk_name: ::std::default::Default::default()
});
}
for fk_id_name in fk_id_field_names.iter() {
let fk_id_str = fk_id_name.to_string();
if let Some(fk_field_name) = fk_id_to_fk_field.get(&fk_id_str) {
let fk_field_info = field_infos.iter().find(|fi| fi.name == fk_field_name);
if let Some(fk_info) = fk_field_info {
let related_model_type = extract_foreign_key_target_type(&fk_info.ty);
let generic_param =
syn::Ident::new(&format!("F{}", generic_counter), fk_id_name.span());
generic_counter += 1;
let fk_field_ident = syn::Ident::new(fk_field_name, fk_id_name.span());
params.push(quote! { #fk_field_ident: #generic_param });
where_clauses.push(quote! {
#generic_param: #orm_crate::IntoPrimaryKey<#related_model_type>
});
generic_params.push(quote! { #generic_param });
fk_id_assignments.push(quote! {
#fk_id_name: #fk_field_ident.into_primary_key()
});
} else {
fk_id_assignments.push(quote! {
#fk_id_name: ::std::default::Default::default()
});
}
} else {
fk_id_assignments.push(quote! {
#fk_id_name: ::std::default::Default::default()
});
}
}
let fk_field_names: std::collections::HashSet<String> =
fk_id_to_fk_field.values().cloned().collect();
let fk_id_field_names_set: std::collections::HashSet<String> =
fk_id_to_fk_field.keys().cloned().collect();
let user_field_assignments: Vec<_> = user_fields
.iter()
.filter(|f| {
!fk_field_names.contains(&f.name.to_string())
&& !fk_id_field_names_set.contains(&f.name.to_string())
})
.map(|f| {
let name = &f.name;
let name_str = name.to_string();
if string_fields.contains_key(&name_str) {
quote! { #name: #name.into() }
} else {
quote! { #name }
}
})
.collect();
let auto_field_assignments: Vec<_> = auto_fields
.iter()
.filter(|f| {
!fk_field_names.contains(&f.name.to_string())
&& !fk_id_field_names_set.contains(&f.name.to_string())
})
.map(|f| {
let name = &f.name;
let default_value = get_auto_field_default_value(f);
quote! { #name: #default_value }
})
.collect();
let generic_signature = if generic_params.is_empty() {
quote! {}
} else {
quote! { <#(#generic_params),*> }
};
let where_clause = if where_clauses.is_empty() {
quote! {}
} else {
quote! { where #(#where_clauses),* }
};
quote! {
impl #struct_name {
#[allow(clippy::too_many_arguments)]
pub fn new #generic_signature(#(#params),*) -> Self
#where_clause
{
Self {
#(#user_field_assignments,)*
#(#fk_id_assignments,)*
#(#fk_field_assignments,)*
#(#auto_field_assignments,)*
}
}
}
}
}
/// Generates a typestate builder for the model.
///
/// `Model::build()` returns `ModelBuilder<ModelBuilderUnset, ...>`, with one
/// state parameter per required slot. Each setter consumes the builder and
/// flips its own slot to `ModelBuilderSet`. `finish()` is only implemented when
/// every slot is `ModelBuilderSet`, so forgetting a required field is a compile
/// error.
///
/// The required slots are, in order:
/// 1. every non-auto-generated field (see `is_auto_generated_field`);
/// 2. every `<fk>_id` field in `fk_id_field_names` whose FK name can be derived.
///
/// FK id slots are set through a setter named after the FK field. That setter
/// accepts any `IntoPrimaryKey<Target>`. Non-`Option` `String` slots accept
/// `impl Into<String>`. Other slots take their declared type.
///
/// Returns a `compile_error!` token stream instead when an implied FK setter
/// name would be `self`, `Self`, `super` or `crate`.
fn generate_build_function(
    struct_name: &syn::Ident,
    field_infos: &[FieldInfo],
    fk_id_field_names: &[syn::Ident],
) -> TokenStream {
    let orm_crate = get_reinhardt_orm_crate();
    let user_fields: Vec<_> = field_infos
        .iter()
        .filter(|f| !is_auto_generated_field(f))
        .collect();
    let auto_fields: Vec<_> = field_infos
        .iter()
        .filter(|f| is_auto_generated_field(f))
        .collect();
    // `<fk>_id` -> `<fk>`. `strip_suffix` removes exactly one trailing `_id`,
    // so `parent_id_id` maps to `parent_id`. (`trim_end_matches` would give
    // `parent`, and would make `a_id` and `a_id_id` collide.)
    let fk_id_to_fk_field: HashMap<String, String> = fk_id_field_names
        .iter()
        .filter_map(|id_name| {
            let id_str = id_name.to_string();
            let fk_name = id_str.strip_suffix("_id")?.to_string();
            Some((id_str, fk_name))
        })
        .collect();
    // How the setter for one required slot accepts its value.
    enum SetterKind {
        ForeignKey {
            related_type: Box<Type>,
            setter_name: syn::Ident,
        },
        String,
        Plain,
    }
    // One required builder slot: the struct field it fills, that field's type,
    // and the kind of setter generated for it.
    struct Required<'a> {
        storage_name: syn::Ident,
        storage_ty: &'a Type,
        kind: SetterKind,
    }
    let mut required: Vec<Required> =
        Vec::with_capacity(user_fields.len() + fk_id_field_names.len());
    for f in user_fields.iter() {
        let name_str = f.name.to_string();
        if let Some(fk_field_name) = fk_id_to_fk_field.get(&name_str) {
            let fk_field_info = field_infos.iter().find(|fi| fi.name == *fk_field_name);
            let related_type = match fk_field_info {
                Some(info) => extract_foreign_key_target_type(&info.ty),
                None => f.ty.clone(),
            };
            let setter_name = syn::Ident::new(fk_field_name, f.name.span());
            required.push(Required {
                storage_name: f.name.clone(),
                storage_ty: &f.ty,
                kind: SetterKind::ForeignKey {
                    related_type: Box::new(related_type),
                    setter_name,
                },
            });
        } else if is_string_type(&f.ty) && !extract_option_type(&f.ty).0 {
            required.push(Required {
                storage_name: f.name.clone(),
                storage_ty: &f.ty,
                kind: SetterKind::String,
            });
        } else {
            required.push(Required {
                storage_name: f.name.clone(),
                storage_ty: &f.ty,
                kind: SetterKind::Plain,
            });
        }
    }
    for fk_id_name in fk_id_field_names.iter() {
        let fk_id_str = fk_id_name.to_string();
        let Some(fk_field_name) = fk_id_to_fk_field.get(&fk_id_str) else {
            continue;
        };
        let id_field_info = field_infos
            .iter()
            .find(|fi| fi.name == *fk_id_name)
            .unwrap_or_else(|| {
                panic!(
                    "internal macro invariant: `{}` is in fk_id_field_names but missing from field_infos",
                    fk_id_str
                )
            });
        let fk_field_info = field_infos.iter().find(|fi| fi.name == *fk_field_name);
        let related_type = match fk_field_info {
            Some(info) => extract_foreign_key_target_type(&info.ty),
            None => id_field_info.ty.clone(),
        };
        let setter_name = match fk_field_info {
            Some(info) => info.name.clone(),
            None => {
                // No FK field with that name, so derive the setter from the id
                // name. It is emitted as a raw identifier so keywords such as
                // `type` still work. `self`/`Self`/`super`/`crate` cannot be raw
                // identifiers, so they are reported as an error instead.
                let bare = fk_field_name
                    .strip_prefix("r#")
                    .unwrap_or(fk_field_name.as_str());
                if matches!(bare, "self" | "Self" | "super" | "crate") {
                    return syn::Error::new(
                        fk_id_name.span(),
                        format!(
                            "cannot derive builder setter for FK field `{fk_id_str}`: \
                             the implied setter name `{bare}` is a reserved identifier; \
                             rename the related model-typed field or the `*_id` field"
                        ),
                    )
                    .to_compile_error();
                }
                syn::Ident::new_raw(bare, fk_id_name.span())
            }
        };
        required.push(Required {
            storage_name: id_field_info.name.clone(),
            storage_ty: &id_field_info.ty,
            kind: SetterKind::ForeignKey {
                related_type: Box::new(related_type),
                setter_name,
            },
        });
    }
    let builder_name = syn::Ident::new(&format!("{}Builder", struct_name), struct_name.span());
    let set_marker = syn::Ident::new(&format!("{}BuilderSet", struct_name), struct_name.span());
    let unset_marker = syn::Ident::new(&format!("{}BuilderUnset", struct_name), struct_name.span());
    // One typestate parameter (`B0`, `B1`, ...) per required slot.
    let state_params: Vec<syn::Ident> = (0..required.len())
        .map(|i| syn::Ident::new(&format!("B{}", i), struct_name.span()))
        .collect();
    let builder_struct_fields: Vec<TokenStream> = required
        .iter()
        .map(|r| {
            let name = &r.storage_name;
            let ty = r.storage_ty;
            quote! { #name: ::std::option::Option<#ty> }
        })
        .collect();
    let init_struct_field_assignments: Vec<TokenStream> = required
        .iter()
        .map(|r| {
            let name = &r.storage_name;
            quote! { #name: ::std::option::Option::None }
        })
        .collect();
    let mut setter_impls: Vec<TokenStream> = Vec::with_capacity(required.len());
    for (idx, r) in required.iter().enumerate() {
        // Each setter is generic over every state parameter except its own:
        // its own slot goes Unset -> Set, and all others pass through.
        let other_params: Vec<&syn::Ident> = state_params
            .iter()
            .enumerate()
            .filter_map(|(i, p)| if i == idx { None } else { Some(p) })
            .collect();
        let input_states: Vec<TokenStream> = state_params
            .iter()
            .enumerate()
            .map(|(i, p)| {
                if i == idx {
                    quote! { #unset_marker }
                } else {
                    quote! { #p }
                }
            })
            .collect();
        let output_states: Vec<TokenStream> = state_params
            .iter()
            .enumerate()
            .map(|(i, p)| {
                if i == idx {
                    quote! { #set_marker }
                } else {
                    quote! { #p }
                }
            })
            .collect();
        let copy_fields: Vec<TokenStream> = required
            .iter()
            .enumerate()
            .map(|(i, other)| {
                let n = &other.storage_name;
                if i == idx {
                    quote! {}
                } else {
                    quote! { #n: self.#n, }
                }
            })
            .collect();
        let storage_name = &r.storage_name;
        let storage_ty = r.storage_ty;
        let (setter_sig, value_expr): (TokenStream, TokenStream) = match &r.kind {
            SetterKind::ForeignKey {
                related_type,
                setter_name,
            } => {
                let sig = quote! {
                    pub fn #setter_name<__FkArg>(self, value: __FkArg)
                        -> #builder_name<#(#output_states),*>
                    where
                        __FkArg: #orm_crate::IntoPrimaryKey<#related_type>,
                };
                let expr = quote! { value.into_primary_key() };
                (sig, expr)
            }
            SetterKind::String => {
                let sig = quote! {
                    pub fn #storage_name<__StrArg>(self, value: __StrArg)
                        -> #builder_name<#(#output_states),*>
                    where
                        __StrArg: ::std::convert::Into<::std::string::String>,
                };
                let expr = quote! { value.into() };
                (sig, expr)
            }
            SetterKind::Plain => {
                let sig = quote! {
                    pub fn #storage_name(self, value: #storage_ty)
                        -> #builder_name<#(#output_states),*>
                };
                let expr = quote! { value };
                (sig, expr)
            }
        };
        let other_param_list = if other_params.is_empty() {
            quote! {}
        } else {
            quote! { <#(#other_params),*> }
        };
        setter_impls.push(quote! {
            impl #other_param_list #builder_name<#(#input_states),*> {
                #setter_sig
                {
                    #builder_name {
                        #(#copy_fields)*
                        #storage_name: ::std::option::Option::Some(#value_expr),
                        __state: ::std::marker::PhantomData,
                    }
                }
            }
        });
    }
    // FK fields and their id fields are initialised separately below; exclude
    // them from the generic user/auto assignment lists so no field is set twice.
    let fk_id_field_names_set: std::collections::HashSet<String> =
        fk_id_to_fk_field.keys().cloned().collect();
    let fk_field_names: std::collections::HashSet<String> =
        fk_id_to_fk_field.values().cloned().collect();
    let user_field_assignments: Vec<TokenStream> = user_fields
        .iter()
        .filter(|f| {
            !fk_field_names.contains(&f.name.to_string())
                && !fk_id_field_names_set.contains(&f.name.to_string())
        })
        .map(|f| {
            let name = &f.name;
            quote! {
                #name: self
                    .#name
                    .expect(concat!(
                        "build() typestate guarantees ",
                        stringify!(#name),
                        " is set before finish() is callable"
                    ))
            }
        })
        .collect();
    let fk_id_assignments: Vec<TokenStream> = fk_id_field_names
        .iter()
        .map(|fk_id_name| {
            let name = fk_id_name.clone();
            quote! {
                #name: self
                    .#name
                    .expect(concat!(
                        "build() typestate guarantees ",
                        stringify!(#name),
                        " is set before finish() is callable"
                    ))
            }
        })
        .collect();
    // Walk the ordered slice rather than the HashMap. HashMap iteration order
    // is randomised per process, which would make the macro output differ
    // between compilations.
    let fk_field_assignments: Vec<TokenStream> = fk_id_field_names
        .iter()
        .filter_map(|fk_id_name| fk_id_to_fk_field.get(&fk_id_name.to_string()))
        .map(|fk_name_str| {
            let fk_name = syn::Ident::new(fk_name_str, proc_macro2::Span::call_site());
            quote! { #fk_name: ::std::default::Default::default() }
        })
        .collect();
    let auto_field_assignments: Vec<TokenStream> = auto_fields
        .iter()
        .filter(|f| {
            !fk_field_names.contains(&f.name.to_string())
                && !fk_id_field_names_set.contains(&f.name.to_string())
        })
        .map(|f| {
            let name = &f.name;
            let default_value = get_auto_field_default_value(f);
            quote! { #name: #default_value }
        })
        .collect();
    let all_set_states: Vec<TokenStream> = state_params
        .iter()
        .map(|_| quote! { #set_marker })
        .collect();
    let state_param_list = if state_params.is_empty() {
        quote! {}
    } else {
        quote! { <#(#state_params),*> }
    };
    let initial_unset_states: Vec<TokenStream> = state_params
        .iter()
        .map(|_| quote! { #unset_marker })
        .collect();
    let phantom_tuple_ty = if state_params.is_empty() {
        quote! { () }
    } else {
        quote! { ( #(#state_params,)* ) }
    };
    let allow_dead = quote! { #[allow(dead_code)] };
    quote! {
        #allow_dead
        pub struct #set_marker;
        #allow_dead
        pub struct #unset_marker;
        #allow_dead
        pub struct #builder_name #state_param_list {
            #(#builder_struct_fields,)*
            __state: ::std::marker::PhantomData<#phantom_tuple_ty>,
        }
        impl #struct_name {
            pub fn build() -> #builder_name<#(#initial_unset_states),*> {
                #builder_name {
                    #(#init_struct_field_assignments,)*
                    __state: ::std::marker::PhantomData,
                }
            }
        }
        #(#setter_impls)*
        impl #builder_name<#(#all_set_states),*> {
            pub fn finish(self) -> #struct_name {
                #struct_name {
                    #(#user_field_assignments,)*
                    #(#fk_id_assignments,)*
                    #(#fk_field_assignments,)*
                    #(#auto_field_assignments,)*
                }
            }
        }
    }
}
/// Generates `<Model>Fields`, a struct with one typed
/// `query_fields::Field<Model, T>` per selectable column.
///
/// Skipped fields and relationship wrappers (`ForeignKeyField`,
/// `OneToOneField`, `ManyToManyField`) are left out. The struct also gets
/// `new()`, `Default`, and a `FieldSelector::with_alias` impl that applies the
/// alias to every field.
fn generate_field_selector_struct(
    struct_name: &syn::Ident,
    field_infos: &[FieldInfo],
) -> TokenStream {
    let orm = get_reinhardt_orm_crate();
    let selector_ident = syn::Ident::new(&format!("{}Fields", struct_name), struct_name.span());

    let selectable: Vec<&FieldInfo> = field_infos
        .iter()
        .filter(|f| !f.config.skip)
        .filter(|f| {
            !(is_foreign_key_field_type(&f.ty)
                || is_one_to_one_field_type(&f.ty)
                || is_many_to_many_field_type(&f.ty))
        })
        .collect();

    let mut declarations: Vec<TokenStream> = Vec::with_capacity(selectable.len());
    let mut initializers: Vec<TokenStream> = Vec::with_capacity(selectable.len());
    let mut names: Vec<&syn::Ident> = Vec::with_capacity(selectable.len());
    for field in &selectable {
        let ident = &field.name;
        let ty = &field.ty;
        // The column path starts as just the field's own name.
        let column = ident.to_string();
        declarations.push(quote! {
            #ident: #orm::query_fields::Field<#struct_name, #ty>
        });
        initializers.push(quote! {
            #ident: #orm::query_fields::Field::new(vec![#column])
        });
        names.push(ident);
    }

    quote! {
        #[derive(Debug, Clone)]
        pub struct #selector_ident {
            #(#declarations),*
        }
        impl #selector_ident {
            pub fn new() -> Self {
                Self {
                    #(#initializers),*
                }
            }
        }
        impl #orm::FieldSelector for #selector_ident {
            fn with_alias(mut self, alias: &str) -> Self {
                #(self.#names = self.#names.with_alias(alias);)*
                self
            }
        }
        impl ::std::default::Default for #selector_ident {
            fn default() -> Self {
                Self::new()
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the `Model` derive on `input` and returns the expansion rendered
    /// as a string, panicking if parsing or expansion fails.
    fn expand_to_string(input: TokenStream) -> String {
        let derive_input = syn::parse2(input).unwrap();
        model_derive_impl(derive_input).unwrap().to_string()
    }

    #[test]
    fn test_fields_are_private() {
        let expanded = expand_to_string(quote! {
            #[model(app_label = "test", table_name = "test")]
            pub struct TestModel {
                #[field(primary_key = true)]
                pub id: i64,
                #[field(max_length = 255)]
                pub name: String,
            }
        });
        // The derive re-emits the struct with its fields made private.
        assert!(!expanded.contains("pub id"));
        assert!(!expanded.contains("pub name"));
    }

    #[test]
    fn test_getter_methods_generated() {
        let expanded = expand_to_string(quote! {
            #[model(app_label = "test", table_name = "test")]
            pub struct TestModel {
                #[field(primary_key = true)]
                pub id: i64,
                #[field(max_length = 255)]
                pub name: String,
            }
        });
        assert!(expanded.contains("pub fn id"));
        assert!(expanded.contains("pub fn name"));
    }

    #[test]
    fn test_setter_methods_exclude_auto_fields() {
        let expanded = expand_to_string(quote! {
            #[model(app_label = "test", table_name = "test")]
            pub struct TestModel {
                #[field(primary_key = true)]
                pub id: i64,
                #[field(max_length = 255)]
                pub name: String,
                #[field(auto_now_add = true)]
                pub created_at: DateTime<Utc>,
            }
        });
        // Only caller-supplied fields get setters; the PK and timestamp are auto.
        assert!(expanded.contains("pub fn set_name"));
        assert!(!expanded.contains("pub fn set_id"));
        assert!(!expanded.contains("pub fn set_created_at"));
    }
}