use std::collections::HashSet;
use proc_macro2::Span;
use quote::ToTokens;
use syn::{Error, GenericArgument, PathArguments, Type};
use crate::parse::{FieldDef, ModelDef};
/// Runs every compile-time check on a parsed model definition.
///
/// Every check runs even after an earlier one has failed, so a single compile
/// reports all detected problems. The collected diagnostics are merged with
/// [`Error::combine`] into one `syn::Error`, in the order they were produced.
///
/// # Errors
///
/// Returns the combined error when at least one check failed.
pub fn validate_model(model: &ModelDef) -> Result<(), Error> {
    let mut problems = Vec::new();

    validate_has_fields(model, &mut problems);
    validate_table_name(&model.table_name, model.name.span(), &mut problems);
    validate_no_duplicate_columns(model, &mut problems);
    for each_field in model.fields.iter() {
        validate_field(each_field, &mut problems);
    }
    validate_auto_increment_has_pk(model, &mut problems);
    validate_joined_inheritance_parent_field(model, &mut problems);

    // Fold every later diagnostic into the first so they surface together.
    let merged = problems.into_iter().reduce(|mut head, next| {
        head.combine(next);
        head
    });
    match merged {
        Some(err) => Err(err),
        None => Ok(()),
    }
}
/// Checks that a joined-table inheritance child embeds its parent exactly once.
///
/// The check applies only when all of the following hold:
/// - the model is a table (`config.table`);
/// - it names a parent via `config.inherits`;
/// - it has no `discriminator_value`;
/// - it uses `InheritanceStrategy::Joined`.
///
/// Such a model must have exactly one field marked `#[sqlmodel(parent)]`. Zero
/// or several such fields produce an error attached to the struct name.
fn validate_joined_inheritance_parent_field(model: &ModelDef, errors: &mut Vec<Error>) {
    let config = &model.config;
    let is_joined_child = config.table
        && config.inherits.is_some()
        && config.discriminator_value.is_none()
        && config.inheritance == crate::parse::InheritanceStrategy::Joined;
    if !is_joined_child {
        return;
    }

    let parent_field_count = model.fields.iter().filter(|field| field.parent).count();
    if parent_field_count == 1 {
        return;
    }
    errors.push(Error::new(
        model.name.span(),
        "joined-table inheritance child models must include exactly one `#[sqlmodel(parent)]` field to embed the parent model",
    ));
}
/// Rejects a model whose struct declares no fields, reporting on the struct name.
fn validate_has_fields(model: &ModelDef, errors: &mut Vec<Error>) {
    let has_any_field = !model.fields.is_empty();
    if has_any_field {
        return;
    }
    errors.push(Error::new(
        model.name.span(),
        "Model struct must have at least one field",
    ));
}
/// Validates the SQL table name a model maps to.
///
/// Checks run in this order and stop at the first problem found:
/// 1. Reject names containing a control character (NUL, CR, LF, tab, ...) or
///    one of the characters in `DANGEROUS_CHARS`.
/// 2. Reject names that are empty or consist only of whitespace.
/// 3. Reject names whose first character is neither a letter nor `_`.
///
/// At most one error is pushed per call, attached to `span`.
fn validate_table_name(table_name: &str, span: Span, errors: &mut Vec<Error>) {
    // Characters that could end a statement, open a quoted identifier/literal,
    // or start a SQL comment. Control characters (including NUL, CR and LF) are
    // rejected separately through `char::is_control`.
    const DANGEROUS_CHARS: &[char] = &[';', '\'', '"', '`', '-', '/', '*', '\\'];

    let offending = table_name
        .chars()
        .find(|ch| ch.is_control() || DANGEROUS_CHARS.contains(ch));
    if let Some(ch) = offending {
        // `{ch:?}` quotes and escapes the character, so NUL/newline appear as
        // '\0' / '\n' in the diagnostic instead of being emitted raw.
        errors.push(Error::new(
            span,
            format!(
                "table name contains invalid character {ch:?}; \
                 table names should only contain alphanumeric characters and underscores"
            ),
        ));
        return;
    }

    if table_name.trim().is_empty() {
        // Stop here: a whitespace-only name would otherwise also fail the
        // first-character check below and report the same problem twice.
        errors.push(Error::new(span, "table name cannot be empty or whitespace"));
        return;
    }

    if let Some(first) = table_name.chars().next() {
        if !first.is_alphabetic() && first != '_' {
            errors.push(Error::new(
                span,
                format!("table name must start with a letter or underscore, got '{first}'"),
            ));
        }
    }
}
/// Reports each non-skipped field whose column name was already claimed by an
/// earlier non-skipped field. The error points at the later (duplicate) field.
fn validate_no_duplicate_columns(model: &ModelDef, errors: &mut Vec<Error>) {
    let mut claimed = HashSet::new();
    for field in model.fields.iter().filter(|f| !f.skip) {
        let column: &str = &field.column_name;
        if claimed.insert(column) {
            // First field to use this column name.
            continue;
        }
        errors.push(Error::new(
            field.name.span(),
            format!(
                "duplicate column name '{}'; another field already maps to this column",
                field.column_name
            ),
        ));
    }
}
fn validate_auto_increment_has_pk(model: &ModelDef, errors: &mut Vec<Error>) {
for field in &model.fields {
if field.auto_increment && !field.primary_key {
errors.push(Error::new(
field.name.span(),
"auto_increment requires primary_key; add #[sqlmodel(primary_key)] to this field",
));
}
}
}
/// Runs the per-field checks: type support first, then `skip` attribute conflicts.
fn validate_field(field: &FieldDef, errors: &mut Vec<Error>) {
    let field_span = field.name.span();
    validate_type(&field.ty, field_span, errors);
    validate_skip_conflicts(field, errors);
}
fn validate_type(ty: &Type, span: Span, errors: &mut Vec<Error>) {
if is_nested_option(ty) {
errors.push(Error::new(
span,
"nested Option<Option<T>> is ambiguous and not supported; \
use a single Option<T> or a custom type",
));
}
if is_reference_type(ty) {
errors.push(Error::new(
span,
"reference types (&T) are not supported; use owned types instead",
));
}
if is_raw_pointer(ty) {
errors.push(Error::new(
span,
"raw pointer types (*const T, *mut T) are not supported; use owned types instead",
));
}
}
fn validate_skip_conflicts(field: &FieldDef, errors: &mut Vec<Error>) {
if field.skip && field.unique {
errors.push(Error::new(
field.name.span(),
"cannot use both #[sqlmodel(skip)] and #[sqlmodel(unique)] on the same field; \
skipped fields are excluded from database operations",
));
}
if field.skip && field.foreign_key.is_some() {
errors.push(Error::new(
field.name.span(),
"cannot use both #[sqlmodel(skip)] and #[sqlmodel(foreign_key)] on the same field; \
skipped fields are excluded from database operations",
));
}
if field.skip && field.index.is_some() {
errors.push(Error::new(
field.name.span(),
"cannot use both #[sqlmodel(skip)] and #[sqlmodel(index)] on the same field; \
skipped fields are excluded from database operations",
));
}
}
/// Returns `true` when `ty` is an `Option` whose first generic argument is
/// itself an `Option`.
///
/// Only the last path segment of each type is compared, so qualified forms such
/// as `std::option::Option<Option<T>>` also match. An `Option` nested inside
/// another generic (e.g. `Option<Vec<Option<T>>>`) does not.
fn is_nested_option(ty: &Type) -> bool {
    match ty {
        Type::Path(outer) => outer
            .path
            .segments
            .last()
            .filter(|seg| seg.ident == "Option")
            .and_then(|seg| match &seg.arguments {
                PathArguments::AngleBracketed(generics) => generics.args.first(),
                _ => None,
            })
            .and_then(|arg| match arg {
                GenericArgument::Type(Type::Path(inner)) => inner.path.segments.last(),
                _ => None,
            })
            .map_or(false, |inner_seg| inner_seg.ident == "Option"),
        _ => false,
    }
}
/// Returns `true` for reference types such as `&T`, `&'a T` or `&mut T`.
fn is_reference_type(ty: &Type) -> bool {
    matches!(ty, &Type::Reference(..))
}
/// Returns `true` for raw pointer types (`*const T` / `*mut T`).
fn is_raw_pointer(ty: &Type) -> bool {
    matches!(ty, &Type::Ptr(..))
}
/// Renders a type as the string form of its token stream.
// No callers are visible in this module; the allow keeps the helper available
// without triggering a dead-code warning.
#[allow(dead_code)]
fn type_to_string(ty: &Type) -> String {
    let tokens = ty.to_token_stream();
    tokens.to_string()
}
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    #[test]
    fn test_is_nested_option() {
        let ty: Type = parse_quote!(Option<Option<i32>>);
        assert!(is_nested_option(&ty));
        let ty: Type = parse_quote!(Option<i32>);
        assert!(!is_nested_option(&ty));
        let ty: Type = parse_quote!(i32);
        assert!(!is_nested_option(&ty));
    }

    #[test]
    fn test_is_nested_option_path_forms() {
        // Only the last path segment is inspected, so qualified paths match too.
        let ty: Type = parse_quote!(std::option::Option<Option<u8>>);
        assert!(is_nested_option(&ty));
        let ty: Type = parse_quote!(Option<std::option::Option<u8>>);
        assert!(is_nested_option(&ty));
        // An Option buried inside another generic is not directly nested.
        let ty: Type = parse_quote!(Option<Vec<Option<u8>>>);
        assert!(!is_nested_option(&ty));
        let ty: Type = parse_quote!(Vec<Option<Option<u8>>>);
        assert!(!is_nested_option(&ty));
    }

    #[test]
    fn test_is_reference_type() {
        let ty: Type = parse_quote!(&str);
        assert!(is_reference_type(&ty));
        let ty: Type = parse_quote!(&'static str);
        assert!(is_reference_type(&ty));
        let ty: Type = parse_quote!(&mut String);
        assert!(is_reference_type(&ty));
        let ty: Type = parse_quote!(String);
        assert!(!is_reference_type(&ty));
    }

    #[test]
    fn test_is_raw_pointer() {
        let ty: Type = parse_quote!(*const i32);
        assert!(is_raw_pointer(&ty));
        let ty: Type = parse_quote!(*mut u8);
        assert!(is_raw_pointer(&ty));
        let ty: Type = parse_quote!(Box<i32>);
        assert!(!is_raw_pointer(&ty));
    }

    #[test]
    fn test_validate_type() {
        let rejected: [Type; 4] = [
            parse_quote!(Option<Option<i32>>),
            parse_quote!(&str),
            parse_quote!(*const u8),
            parse_quote!(&mut Vec<u8>),
        ];
        for ty in &rejected {
            let mut errors = Vec::new();
            validate_type(ty, Span::call_site(), &mut errors);
            assert_eq!(errors.len(), 1);
        }

        let accepted: [Type; 3] = [
            parse_quote!(String),
            parse_quote!(Option<i32>),
            parse_quote!(Vec<Option<i32>>),
        ];
        for ty in &accepted {
            let mut errors = Vec::new();
            validate_type(ty, Span::call_site(), &mut errors);
            assert!(errors.is_empty());
        }
    }

    #[test]
    fn test_validate_table_name_valid() {
        let mut errors = Vec::new();
        validate_table_name("users", Span::call_site(), &mut errors);
        assert!(errors.is_empty());
        validate_table_name("user_accounts", Span::call_site(), &mut errors);
        assert!(errors.is_empty());
        validate_table_name("_internal", Span::call_site(), &mut errors);
        assert!(errors.is_empty());
    }

    #[test]
    fn test_validate_table_name_invalid() {
        let mut errors = Vec::new();
        validate_table_name("users; DROP TABLE users", Span::call_site(), &mut errors);
        assert!(!errors.is_empty());
        let mut errors = Vec::new();
        validate_table_name("user's", Span::call_site(), &mut errors);
        assert!(!errors.is_empty());
        let mut errors = Vec::new();
        validate_table_name("123users", Span::call_site(), &mut errors);
        assert!(!errors.is_empty());
    }

    #[test]
    fn test_validate_table_name_rejects_empty_and_special() {
        for name in ["", "   ", "users\n", "us\0ers", "my-table", "a/*b*/"] {
            let mut errors = Vec::new();
            validate_table_name(name, Span::call_site(), &mut errors);
            assert!(!errors.is_empty(), "expected {name:?} to be rejected");
        }
    }
}