Skip to main content

ferriorm_parser/
validator.rs

1//! Semantic validation of the raw AST into the resolved Schema IR.
2//!
3//! The validator walks the [`ferriorm_core::ast::SchemaFile`] produced by the
4//! parser and performs the following:
5//!
6//! - Resolves field type names to scalars, enums, or model references.
7//! - Infers database table and column names from `@@map`/`@map` or `snake_case`
8//!   conventions.
9//! - Checks that every model has a primary key (`@id` or `@@id`).
10//! - Detects duplicate model/enum names and unknown type references.
11//! - Resolves relation cardinality and referential actions.
12//!
13//! The output is an [`ferriorm_core::schema::Schema`], the canonical IR consumed
14//! by codegen and the migration engine.
15
16use std::collections::HashSet;
17
18use ferriorm_core::ast;
19use ferriorm_core::error::CoreError;
20use ferriorm_core::schema::{
21    DatasourceConfig, Enum, Field, FieldKind, GeneratorConfig, Index, Model, PrimaryKey,
22    RelationType, ResolvedRelation, Schema, UniqueConstraint,
23};
24use ferriorm_core::types::{DatabaseProvider, ScalarType};
25use ferriorm_core::utils::to_snake_case;
26
27/// Validate a parsed AST and produce a resolved Schema IR.
28///
29/// # Errors
30///
31/// Returns a [`CoreError`] if the AST has validation problems such as
32/// missing primary keys, unknown types, or duplicate names.
33pub fn validate(ast: &ast::SchemaFile) -> Result<Schema, CoreError> {
34    let datasource = validate_datasource(ast)?;
35    let generators = validate_generators(ast)?;
36    let enums = validate_enums(ast)?;
37    let models = validate_models(ast, &enums)?;
38
39    validate_unique_db_names(&models)?;
40    validate_relation_disambiguation(&models)?;
41
42    Ok(Schema {
43        datasource,
44        generators,
45        enums,
46        models,
47    })
48}
49
50/// Two models cannot map to the same database table (`@@map("..."` /
51/// implicit snake_case-plural). Catching this here prevents conflicting
52/// CREATE TABLE statements at migration time.
53fn validate_unique_db_names(models: &[Model]) -> Result<(), CoreError> {
54    use std::collections::HashMap;
55    let mut seen: HashMap<&str, &str> = HashMap::new();
56    for m in models {
57        if let Some(existing) = seen.get(m.db_name.as_str()) {
58            return Err(CoreError::Validation {
59                message: format!(
60                    "Duplicate table name `{}` (used by models `{}` and `{}`). \
61                     Each model must map to a distinct table; use `@@map(\"...\")` to disambiguate.",
62                    m.db_name, existing, m.name,
63                ),
64            });
65        }
66        seen.insert(&m.db_name, &m.name);
67    }
68    Ok(())
69}
70
/// Is `s` a Rust keyword (strict or reserved)? Used to reject schema field
/// names that cannot be emitted as plain struct-field identifiers by
/// codegen.
///
/// The comparison is case-sensitive (`Self` and `self` are both keywords,
/// `Type` is not). Weak/contextual keywords such as `union` are valid
/// identifiers and are deliberately not listed.
///
/// `gen` is included because it is reserved starting with the 2024 edition;
/// a field with that name would make the generated code fail to compile on
/// 2024-edition crates.
fn is_rust_keyword(s: &str) -> bool {
    matches!(
        s,
        // Strict keywords
        "as" | "break" | "const" | "continue" | "crate" | "else" | "enum" | "extern"
        | "false" | "fn" | "for" | "if" | "impl" | "in" | "let" | "loop" | "match"
        | "mod" | "move" | "mut" | "pub" | "ref" | "return" | "self" | "Self"
        | "static" | "struct" | "super" | "trait" | "true" | "type" | "unsafe"
        | "use" | "where" | "while"
        // 2018+ keywords
        | "async" | "await" | "dyn"
        // Reserved (might become keywords)
        | "abstract" | "become" | "box" | "do" | "final" | "macro" | "override"
        | "priv" | "typeof" | "unsized" | "virtual" | "yield" | "try"
        // Reserved since the 2024 edition
        | "gen"
    )
}
89
90/// When two or more fields on the same model are *forward* FKs to the
91/// same target, OR two or more are *back-references* (implicit lists
92/// or `@relation` without `fields:`) from the same target, each must
93/// use `@relation("Name", ...)` to disambiguate. The forward and
94/// back-reference sides are tracked separately so a `parent` + `children`
95/// self-reference (one forward, one back) is unambiguous.
96fn validate_relation_disambiguation(models: &[Model]) -> Result<(), CoreError> {
97    use std::collections::{HashMap, HashSet};
98
99    for model in models {
100        // (target_model_name, is_fk_owner) -> fields in that group.
101        let mut groups: HashMap<(&str, bool), Vec<&Field>> = HashMap::new();
102
103        for field in &model.fields {
104            let target = match &field.field_type {
105                FieldKind::Model(name) => name.as_str(),
106                _ => continue,
107            };
108            let is_fk_owner = field
109                .relation
110                .as_ref()
111                .is_some_and(|r| !r.fields.is_empty());
112            groups.entry((target, is_fk_owner)).or_default().push(field);
113        }
114
115        for ((target, _), group) in &groups {
116            if group.len() < 2 {
117                continue;
118            }
119
120            let mut seen_names: HashSet<&str> = HashSet::new();
121            for field in group {
122                let name = field.relation.as_ref().and_then(|r| r.name.as_deref());
123                let Some(n) = name else {
124                    return Err(CoreError::Validation {
125                        message: format!(
126                            "Multiple relations from `{}` to `{}` require disambiguation. \
127                             Add `@relation(\"<Name>\", ...)` to each related field on both sides.",
128                            model.name, target,
129                        ),
130                    });
131                };
132                if !seen_names.insert(n) {
133                    return Err(CoreError::Validation {
134                        message: format!(
135                            "Duplicate relation name `{}` between `{}` and `{}`. \
136                             Each relation between the same pair of models must have a unique name.",
137                            n, model.name, target,
138                        ),
139                    });
140                }
141            }
142        }
143    }
144    Ok(())
145}
146
147fn validate_datasource(ast: &ast::SchemaFile) -> Result<DatasourceConfig, CoreError> {
148    let ds = ast.datasource.as_ref().ok_or(CoreError::Validation {
149        message: "Missing datasource block".into(),
150    })?;
151
152    let provider =
153        ds.provider
154            .parse::<DatabaseProvider>()
155            .map_err(|_| CoreError::UnknownProvider {
156                provider: ds.provider.clone(),
157            })?;
158
159    let url = match &ds.url {
160        ast::StringOrEnv::Literal(s) => s.clone(),
161        ast::StringOrEnv::Env(var) => format!("${{env:{var}}}"),
162    };
163
164    Ok(DatasourceConfig {
165        name: ds.name.clone(),
166        provider,
167        url,
168    })
169}
170
171fn validate_generators(ast: &ast::SchemaFile) -> Result<Vec<GeneratorConfig>, CoreError> {
172    ast.generators
173        .iter()
174        .map(|g| {
175            Ok(GeneratorConfig {
176                name: g.name.clone(),
177                output: g.output.clone().unwrap_or_else(|| "./src/generated".into()),
178            })
179        })
180        .collect()
181}
182
183fn validate_enums(ast: &ast::SchemaFile) -> Result<Vec<Enum>, CoreError> {
184    let mut names = HashSet::new();
185    let mut result = Vec::new();
186
187    for e in &ast.enums {
188        if !names.insert(&e.name) {
189            return Err(CoreError::DuplicateName {
190                name: e.name.clone(),
191                kind: "enum",
192            });
193        }
194
195        result.push(Enum {
196            name: e.name.clone(),
197            db_name: e.db_name.clone().unwrap_or_else(|| to_snake_case(&e.name)),
198            variants: e.variants.clone(),
199        });
200    }
201
202    Ok(result)
203}
204
205fn validate_models(ast: &ast::SchemaFile, enums: &[Enum]) -> Result<Vec<Model>, CoreError> {
206    let enum_names: HashSet<&str> = enums.iter().map(|e| e.name.as_str()).collect();
207    let model_names: HashSet<&str> = ast.models.iter().map(|m| m.name.as_str()).collect();
208    let mut seen_names = HashSet::new();
209
210    let mut result = Vec::new();
211
212    for model_def in &ast.models {
213        if !seen_names.insert(&model_def.name) {
214            return Err(CoreError::DuplicateName {
215                name: model_def.name.clone(),
216                kind: "model",
217            });
218        }
219
220        // Check for name collision with enums
221        if enum_names.contains(model_def.name.as_str()) {
222            return Err(CoreError::DuplicateName {
223                name: model_def.name.clone(),
224                kind: "model/enum",
225            });
226        }
227
228        let model = validate_model(model_def, &enum_names, &model_names)?;
229        result.push(model);
230    }
231
232    Ok(result)
233}
234
235#[allow(clippy::too_many_lines)] // sequential validation passes; splitting hides the order
236fn validate_model(
237    model_def: &ast::ModelDef,
238    enum_names: &HashSet<&str>,
239    model_names: &HashSet<&str>,
240) -> Result<Model, CoreError> {
241    // Resolve @@map
242    let db_name = model_def
243        .attributes
244        .iter()
245        .find_map(|a| match a {
246            ast::BlockAttribute::Map(name) => Some(name.clone()),
247            _ => None,
248        })
249        .unwrap_or_else(|| to_snake_case(&model_def.name) + "s");
250
251    let mut fields = Vec::new();
252    let mut has_id_field = false;
253
254    for field_def in &model_def.fields {
255        let field = validate_field(field_def, &model_def.name, enum_names, model_names)?;
256        if field.is_id {
257            has_id_field = true;
258        }
259        fields.push(field);
260    }
261
262    // Check @@id for composite primary key
263    let composite_id: Option<Vec<String>> = model_def.attributes.iter().find_map(|a| match a {
264        ast::BlockAttribute::Id(fields) => Some(fields.clone()),
265        _ => None,
266    });
267
268    if !has_id_field && composite_id.is_none() {
269        return Err(CoreError::MissingPrimaryKey {
270            model_name: model_def.name.clone(),
271        });
272    }
273
274    // Field-name set for B4 (block-attribute field-existence checks).
275    // We accept either the schema field name or the snake_case form.
276    let field_name_set: HashSet<&str> = fields.iter().map(|f| f.name.as_str()).collect();
277    let field_db_set: HashSet<&str> = fields.iter().map(|f| f.db_name.as_str()).collect();
278    let field_resolver = |needle: &str| -> Option<&Field> {
279        fields
280            .iter()
281            .find(|f| f.name == needle || f.db_name == needle || to_snake_case(&f.name) == needle)
282    };
283
284    let primary_key = if let Some(composite_fields) = composite_id {
285        // B4 (PK): all named fields must exist on the model.
286        // B7: PK fields cannot be Json (uncomparable / unhashable in DBs).
287        for f in &composite_fields {
288            let Some(resolved) = field_resolver(f) else {
289                return Err(CoreError::Validation {
290                    message: format!(
291                        "`@@id` on model `{}` references unknown field `{}`",
292                        model_def.name, f,
293                    ),
294                });
295            };
296            if matches!(resolved.field_type, FieldKind::Scalar(ScalarType::Json)) {
297                return Err(CoreError::Validation {
298                    message: format!(
299                        "Field `{}.{}` of type `Json` cannot be part of a composite primary key.",
300                        model_def.name, resolved.name,
301                    ),
302                });
303            }
304        }
305        PrimaryKey {
306            fields: composite_fields,
307        }
308    } else {
309        let id_fields: Vec<String> = fields
310            .iter()
311            .filter(|f| f.is_id)
312            .map(|f| f.name.clone())
313            .collect();
314        PrimaryKey { fields: id_fields }
315    };
316
317    // B4: @@index / @@unique field-existence checks. Each named field
318    // must exist on the model; otherwise the migration would emit a
319    // CREATE INDEX referencing a non-existent column.
320    for attr in &model_def.attributes {
321        let (kind, fs) = match attr {
322            ast::BlockAttribute::Index(idx) => ("@@index", &idx.fields),
323            ast::BlockAttribute::Unique(idx) => ("@@unique", &idx.fields),
324            _ => continue,
325        };
326        for f in fs {
327            if !field_name_set.contains(f.as_str())
328                && !field_db_set.contains(f.as_str())
329                && field_resolver(f).is_none()
330            {
331                return Err(CoreError::Validation {
332                    message: format!(
333                        "`{}` on model `{}` references unknown field `{}`",
334                        kind, model_def.name, f,
335                    ),
336                });
337            }
338        }
339    }
340
341    // Indexes
342    let indexes = model_def
343        .attributes
344        .iter()
345        .filter_map(|a| match a {
346            ast::BlockAttribute::Index(idx) => Some(Index {
347                fields: idx.fields.clone(),
348                name: idx.name.clone(),
349            }),
350            _ => None,
351        })
352        .collect();
353
354    // Unique constraints (from @@unique)
355    let unique_constraints = model_def
356        .attributes
357        .iter()
358        .filter_map(|a| match a {
359            ast::BlockAttribute::Unique(idx) => Some(UniqueConstraint {
360                fields: idx.fields.clone(),
361                name: idx.name.clone(),
362            }),
363            _ => None,
364        })
365        .collect();
366
367    Ok(Model {
368        name: model_def.name.clone(),
369        db_name,
370        fields,
371        primary_key,
372        indexes,
373        unique_constraints,
374    })
375}
376
377#[allow(clippy::too_many_lines)] // sequential per-field checks; splitting hides the order
378fn validate_field(
379    field_def: &ast::FieldDef,
380    model_name: &str,
381    enum_names: &HashSet<&str>,
382    model_names: &HashSet<&str>,
383) -> Result<Field, CoreError> {
384    let type_name = &field_def.field_type.name;
385
386    // B1: reject Rust keywords as field names. Codegen would otherwise
387    // panic in `format_ident!`. Suggest `@map` for users who need the
388    // database column to keep that name.
389    if is_rust_keyword(&field_def.name) {
390        return Err(CoreError::Validation {
391            message: format!(
392                "Field name `{}.{}` is a Rust keyword and cannot be used as a struct field. \
393                 Rename the field and use `@map(\"{}\")` if you need that database column name.",
394                model_name, field_def.name, field_def.name,
395            ),
396        });
397    }
398
399    let field_type = if let Ok(scalar) = type_name.parse::<ScalarType>() {
400        FieldKind::Scalar(scalar)
401    } else if enum_names.contains(type_name.as_str()) {
402        FieldKind::Enum(type_name.clone())
403    } else if model_names.contains(type_name.as_str()) {
404        FieldKind::Model(type_name.clone())
405    } else {
406        return Err(CoreError::UnknownType {
407            model_name: model_name.to_string(),
408            field_name: field_def.name.clone(),
409            type_name: type_name.clone(),
410        });
411    };
412
413    let is_id = field_def
414        .attributes
415        .iter()
416        .any(|a| matches!(a, ast::FieldAttribute::Id));
417    let is_unique = field_def
418        .attributes
419        .iter()
420        .any(|a| matches!(a, ast::FieldAttribute::Unique));
421    let is_updated_at = field_def
422        .attributes
423        .iter()
424        .any(|a| matches!(a, ast::FieldAttribute::UpdatedAt));
425    let default = field_def.attributes.iter().find_map(|a| match a {
426        ast::FieldAttribute::Default(d) => Some(d.clone()),
427        _ => None,
428    });
429
430    // B2: @id cannot appear on an optional field — primary keys are NOT NULL.
431    if is_id && field_def.field_type.is_optional {
432        return Err(CoreError::Validation {
433            message: format!(
434                "Field `{}.{}` is marked `@id` but is optional; primary key columns cannot be NULL.",
435                model_name, field_def.name,
436            ),
437        });
438    }
439
440    // B3: @default(autoincrement()) only applies to integer scalars.
441    if matches!(default, Some(ast::DefaultValue::AutoIncrement)) {
442        let is_int_scalar = matches!(
443            field_type,
444            FieldKind::Scalar(ScalarType::Int | ScalarType::BigInt)
445        );
446        if !is_int_scalar {
447            return Err(CoreError::InvalidDefault {
448                model_name: model_name.to_string(),
449                field_name: field_def.name.clone(),
450                message: format!(
451                    "`@default(autoincrement())` requires an integer field, got `{type_name}`",
452                ),
453            });
454        }
455    }
456
457    // B5: @relation `fields` and `references` lists must have the same length.
458    for attr in &field_def.attributes {
459        if let ast::FieldAttribute::Relation(rel) = attr
460            && rel.fields.len() != rel.references.len()
461        {
462            return Err(CoreError::InvalidRelationFields {
463                model_name: model_name.to_string(),
464                field_name: field_def.name.clone(),
465                message: format!(
466                    "`@relation` `fields` (length {}) and `references` (length {}) must have the same length",
467                    rel.fields.len(),
468                    rel.references.len(),
469                ),
470            });
471        }
472    }
473
474    // Resolve @map
475    let db_name = field_def
476        .attributes
477        .iter()
478        .find_map(|a| match a {
479            ast::FieldAttribute::Map(name) => Some(name.clone()),
480            _ => None,
481        })
482        .unwrap_or_else(|| to_snake_case(&field_def.name));
483
484    // Resolve @relation
485    let relation = field_def.attributes.iter().find_map(|a| match a {
486        ast::FieldAttribute::Relation(rel) => {
487            let relation_type = if field_def.field_type.is_list {
488                RelationType::OneToMany
489            } else if field_def.field_type.is_optional {
490                RelationType::OneToOne
491            } else {
492                RelationType::ManyToOne
493            };
494
495            Some(ResolvedRelation {
496                name: rel.name.clone(),
497                related_model: type_name.clone(),
498                relation_type,
499                fields: rel.fields.clone(),
500                references: rel.references.clone(),
501                on_delete: rel.on_delete.unwrap_or(ast::ReferentialAction::Restrict),
502                on_update: rel.on_update.unwrap_or(ast::ReferentialAction::Cascade),
503            })
504        }
505        _ => None,
506    });
507
508    // Resolve @db.* type hint (e.g. @db.BigInt)
509    let db_type = field_def.attributes.iter().find_map(|a| match a {
510        ast::FieldAttribute::DbType(ty, args) => Some((ty.clone(), args.clone())),
511        _ => None,
512    });
513
514    Ok(Field {
515        name: field_def.name.clone(),
516        db_name,
517        field_type,
518        is_optional: field_def.field_type.is_optional,
519        is_list: field_def.field_type.is_list,
520        is_id,
521        is_unique,
522        is_updated_at,
523        default,
524        relation,
525        db_type,
526    })
527}
528
#[cfg(test)]
#[allow(clippy::pedantic)] // test code favours readability over pedantic lints
mod tests {
    use super::*;
    use crate::parser::parse;
    use ferriorm_core::utils::to_snake_case;

    /// Parse `source` (which must be syntactically valid) and run the validator.
    fn resolve(source: &str) -> Result<Schema, CoreError> {
        let ast = parse(source).expect("parse");
        validate(&ast)
    }

    /// A full schema resolves datasource, enum, model and field details.
    #[test]
    fn test_validate_basic_schema() {
        let schema = resolve(
            r#"
datasource db {
  provider = "postgresql"
  url      = env("DATABASE_URL")
}

generator client {
  output = "./src/generated"
}

enum Role {
  User
  Admin
}

model User {
  id    String @id @default(uuid())
  email String @unique
  name  String?
  role  Role   @default(User)

  @@map("users")
}
"#,
        )
        .expect("validate");

        assert_eq!(schema.datasource.provider, DatabaseProvider::PostgreSQL);

        assert_eq!(schema.enums.len(), 1);
        let role = &schema.enums[0];
        assert_eq!(role.name, "Role");
        assert_eq!(role.db_name, "role");

        let user = &schema.models[0];
        assert_eq!(user.name, "User");
        assert_eq!(user.db_name, "users");
        assert_eq!(user.primary_key.fields, vec!["id"]);

        let id_field = &user.fields[0];
        assert!(id_field.is_id);
        assert_eq!(id_field.field_type, FieldKind::Scalar(ScalarType::String));

        let name_field = &user.fields[2];
        assert!(name_field.is_optional);
        assert_eq!(name_field.db_name, "name");

        let role_field = &user.fields[3];
        assert_eq!(role_field.field_type, FieldKind::Enum("Role".into()));
    }

    /// A model without `@id` or `@@id` is rejected.
    #[test]
    fn test_validate_missing_primary_key() {
        let err = resolve(
            r#"
datasource db {
  provider = "postgresql"
  url      = "postgres://localhost/test"
}

model User {
  email String
  name  String
}
"#,
        )
        .unwrap_err();

        assert!(matches!(err, CoreError::MissingPrimaryKey { .. }));
    }

    /// A field whose type is neither scalar, enum nor model is rejected.
    #[test]
    fn test_validate_unknown_type() {
        let err = resolve(
            r#"
datasource db {
  provider = "postgresql"
  url      = "postgres://localhost/test"
}

model User {
  id   String @id
  role Nonexistent
}
"#,
        )
        .unwrap_err();

        assert!(matches!(err, CoreError::UnknownType { .. }));
    }

    /// `@@id([...])` produces a composite primary key in the listed order.
    #[test]
    fn test_validate_composite_primary_key() {
        let schema = resolve(
            r#"
datasource db {
  provider = "sqlite"
  url      = "file:./dev.db"
}

model PostTag {
  postId String
  tagId  String

  @@id([postId, tagId])
}
"#,
        )
        .expect("validate");

        let primary_key = &schema.models[0].primary_key;
        assert_eq!(primary_key.fields, vec!["postId", "tagId"]);
        assert!(primary_key.is_composite());
    }

    /// Pins the naming helper the validator relies on for default db names.
    #[test]
    fn test_snake_case() {
        let cases = [
            ("User", "user"),
            ("PostTag", "post_tag"),
            ("createdAt", "created_at"),
            ("HTMLParser", "h_t_m_l_parser"),
        ];
        for (input, expected) in cases {
            assert_eq!(to_snake_case(input), expected);
        }
    }

    /// Without `@@map`, the table name is the snake_case model name plus "s".
    #[test]
    fn test_validate_auto_table_name() {
        let schema = resolve(
            r#"
datasource db {
  provider = "postgresql"
  url      = "postgres://localhost/test"
}

model BlogPost {
  id String @id
}
"#,
        )
        .expect("validate");

        // Auto-generated: snake_case + "s"
        assert_eq!(schema.models[0].db_name, "blog_posts");
    }
}