Skip to main content

ferriorm_parser/
validator.rs

1//! Semantic validation of the raw AST into the resolved Schema IR.
2//!
3//! The validator walks the [`ferriorm_core::ast::SchemaFile`] produced by the
4//! parser and performs the following:
5//!
6//! - Resolves field type names to scalars, enums, or model references.
7//! - Infers database table and column names from `@@map`/`@map` or snake_case
8//!   conventions.
9//! - Checks that every model has a primary key (`@id` or `@@id`).
10//! - Detects duplicate model/enum names and unknown type references.
11//! - Resolves relation cardinality and referential actions.
12//!
13//! The output is a [`ferriorm_core::schema::Schema`], the canonical IR consumed
14//! by codegen and the migration engine.
15
16use std::collections::HashSet;
17
18use ferriorm_core::ast;
19use ferriorm_core::error::CoreError;
20use ferriorm_core::schema::*;
21use ferriorm_core::types::{DatabaseProvider, ScalarType};
22use ferriorm_core::utils::to_snake_case;
23
24/// Validate a parsed AST and produce a resolved Schema IR.
25pub fn validate(ast: &ast::SchemaFile) -> Result<Schema, CoreError> {
26    let datasource = validate_datasource(ast)?;
27    let generators = validate_generators(ast)?;
28    let enums = validate_enums(ast)?;
29    let models = validate_models(ast, &enums)?;
30
31    Ok(Schema {
32        datasource,
33        generators,
34        enums,
35        models,
36    })
37}
38
39fn validate_datasource(ast: &ast::SchemaFile) -> Result<DatasourceConfig, CoreError> {
40    let ds = ast.datasource.as_ref().ok_or(CoreError::Validation {
41        message: "Missing datasource block".into(),
42    })?;
43
44    let provider =
45        ds.provider
46            .parse::<DatabaseProvider>()
47            .map_err(|_| CoreError::UnknownProvider {
48                provider: ds.provider.clone(),
49            })?;
50
51    let url = match &ds.url {
52        ast::StringOrEnv::Literal(s) => s.clone(),
53        ast::StringOrEnv::Env(var) => format!("${{env:{var}}}"),
54    };
55
56    Ok(DatasourceConfig {
57        name: ds.name.clone(),
58        provider,
59        url,
60    })
61}
62
63fn validate_generators(ast: &ast::SchemaFile) -> Result<Vec<GeneratorConfig>, CoreError> {
64    ast.generators
65        .iter()
66        .map(|g| {
67            Ok(GeneratorConfig {
68                name: g.name.clone(),
69                output: g.output.clone().unwrap_or_else(|| "./src/generated".into()),
70            })
71        })
72        .collect()
73}
74
75fn validate_enums(ast: &ast::SchemaFile) -> Result<Vec<Enum>, CoreError> {
76    let mut names = HashSet::new();
77    let mut result = Vec::new();
78
79    for e in &ast.enums {
80        if !names.insert(&e.name) {
81            return Err(CoreError::DuplicateName {
82                name: e.name.clone(),
83                kind: "enum",
84            });
85        }
86
87        result.push(Enum {
88            name: e.name.clone(),
89            db_name: e.db_name.clone().unwrap_or_else(|| to_snake_case(&e.name)),
90            variants: e.variants.clone(),
91        });
92    }
93
94    Ok(result)
95}
96
97fn validate_models(ast: &ast::SchemaFile, enums: &[Enum]) -> Result<Vec<Model>, CoreError> {
98    let enum_names: HashSet<&str> = enums.iter().map(|e| e.name.as_str()).collect();
99    let model_names: HashSet<&str> = ast.models.iter().map(|m| m.name.as_str()).collect();
100    let mut seen_names = HashSet::new();
101
102    let mut result = Vec::new();
103
104    for model_def in &ast.models {
105        if !seen_names.insert(&model_def.name) {
106            return Err(CoreError::DuplicateName {
107                name: model_def.name.clone(),
108                kind: "model",
109            });
110        }
111
112        // Check for name collision with enums
113        if enum_names.contains(model_def.name.as_str()) {
114            return Err(CoreError::DuplicateName {
115                name: model_def.name.clone(),
116                kind: "model/enum",
117            });
118        }
119
120        let model = validate_model(model_def, &enum_names, &model_names)?;
121        result.push(model);
122    }
123
124    Ok(result)
125}
126
127fn validate_model(
128    model_def: &ast::ModelDef,
129    enum_names: &HashSet<&str>,
130    model_names: &HashSet<&str>,
131) -> Result<Model, CoreError> {
132    // Resolve @@map
133    let db_name = model_def
134        .attributes
135        .iter()
136        .find_map(|a| match a {
137            ast::BlockAttribute::Map(name) => Some(name.clone()),
138            _ => None,
139        })
140        .unwrap_or_else(|| to_snake_case(&model_def.name) + "s");
141
142    let mut fields = Vec::new();
143    let mut has_id_field = false;
144
145    for field_def in &model_def.fields {
146        let field = validate_field(field_def, &model_def.name, enum_names, model_names)?;
147        if field.is_id {
148            has_id_field = true;
149        }
150        fields.push(field);
151    }
152
153    // Check @@id for composite primary key
154    let composite_id: Option<Vec<String>> = model_def.attributes.iter().find_map(|a| match a {
155        ast::BlockAttribute::Id(fields) => Some(fields.clone()),
156        _ => None,
157    });
158
159    if !has_id_field && composite_id.is_none() {
160        return Err(CoreError::MissingPrimaryKey {
161            model_name: model_def.name.clone(),
162        });
163    }
164
165    let primary_key = if let Some(composite_fields) = composite_id {
166        PrimaryKey {
167            fields: composite_fields,
168        }
169    } else {
170        let id_fields: Vec<String> = fields
171            .iter()
172            .filter(|f| f.is_id)
173            .map(|f| f.name.clone())
174            .collect();
175        PrimaryKey { fields: id_fields }
176    };
177
178    // Indexes
179    let indexes = model_def
180        .attributes
181        .iter()
182        .filter_map(|a| match a {
183            ast::BlockAttribute::Index(fields) => Some(Index {
184                fields: fields.clone(),
185            }),
186            _ => None,
187        })
188        .collect();
189
190    // Unique constraints (from @@unique)
191    let unique_constraints = model_def
192        .attributes
193        .iter()
194        .filter_map(|a| match a {
195            ast::BlockAttribute::Unique(fields) => Some(UniqueConstraint {
196                fields: fields.clone(),
197            }),
198            _ => None,
199        })
200        .collect();
201
202    Ok(Model {
203        name: model_def.name.clone(),
204        db_name,
205        fields,
206        primary_key,
207        indexes,
208        unique_constraints,
209    })
210}
211
212fn validate_field(
213    field_def: &ast::FieldDef,
214    model_name: &str,
215    enum_names: &HashSet<&str>,
216    model_names: &HashSet<&str>,
217) -> Result<Field, CoreError> {
218    let type_name = &field_def.field_type.name;
219
220    let field_type = if let Ok(scalar) = type_name.parse::<ScalarType>() {
221        FieldKind::Scalar(scalar)
222    } else if enum_names.contains(type_name.as_str()) {
223        FieldKind::Enum(type_name.clone())
224    } else if model_names.contains(type_name.as_str()) {
225        FieldKind::Model(type_name.clone())
226    } else {
227        return Err(CoreError::UnknownType {
228            model_name: model_name.to_string(),
229            field_name: field_def.name.clone(),
230            type_name: type_name.clone(),
231        });
232    };
233
234    let is_id = field_def
235        .attributes
236        .iter()
237        .any(|a| matches!(a, ast::FieldAttribute::Id));
238    let is_unique = field_def
239        .attributes
240        .iter()
241        .any(|a| matches!(a, ast::FieldAttribute::Unique));
242    let is_updated_at = field_def
243        .attributes
244        .iter()
245        .any(|a| matches!(a, ast::FieldAttribute::UpdatedAt));
246    let default = field_def.attributes.iter().find_map(|a| match a {
247        ast::FieldAttribute::Default(d) => Some(d.clone()),
248        _ => None,
249    });
250
251    // Resolve @map
252    let db_name = field_def
253        .attributes
254        .iter()
255        .find_map(|a| match a {
256            ast::FieldAttribute::Map(name) => Some(name.clone()),
257            _ => None,
258        })
259        .unwrap_or_else(|| to_snake_case(&field_def.name));
260
261    // Resolve @relation
262    let relation = field_def.attributes.iter().find_map(|a| match a {
263        ast::FieldAttribute::Relation(rel) => {
264            let relation_type = if field_def.field_type.is_list {
265                RelationType::OneToMany
266            } else if field_def.field_type.is_optional {
267                RelationType::OneToOne
268            } else {
269                RelationType::ManyToOne
270            };
271
272            Some(ResolvedRelation {
273                related_model: type_name.clone(),
274                relation_type,
275                fields: rel.fields.clone(),
276                references: rel.references.clone(),
277                on_delete: rel.on_delete.unwrap_or(ast::ReferentialAction::Restrict),
278                on_update: rel.on_update.unwrap_or(ast::ReferentialAction::Cascade),
279            })
280        }
281        _ => None,
282    });
283
284    Ok(Field {
285        name: field_def.name.clone(),
286        db_name,
287        field_type,
288        is_optional: field_def.field_type.is_optional,
289        is_list: field_def.field_type.is_list,
290        is_id,
291        is_unique,
292        is_updated_at,
293        default,
294        relation,
295    })
296}
297
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse;
    use ferriorm_core::utils::to_snake_case;

    /// Datasource, generator, enum, and one model with an explicit `@@map`.
    const BASIC_SCHEMA: &str = r#"
datasource db {
  provider = "postgresql"
  url      = env("DATABASE_URL")
}

generator client {
  output = "./src/generated"
}

enum Role {
  User
  Admin
}

model User {
  id    String @id @default(uuid())
  email String @unique
  name  String?
  role  Role   @default(User)

  @@map("users")
}
"#;

    /// A model with neither `@id` nor `@@id`.
    const NO_PRIMARY_KEY_SCHEMA: &str = r#"
datasource db {
  provider = "postgresql"
  url      = "postgres://localhost/test"
}

model User {
  email String
  name  String
}
"#;

    /// A model whose `role` field uses a type that is declared nowhere.
    const UNKNOWN_TYPE_SCHEMA: &str = r#"
datasource db {
  provider = "postgresql"
  url      = "postgres://localhost/test"
}

model User {
  id   String @id
  role Nonexistent
}
"#;

    /// A model keyed by `@@id` over two fields.
    const COMPOSITE_KEY_SCHEMA: &str = r#"
datasource db {
  provider = "sqlite"
  url      = "file:./dev.db"
}

model PostTag {
  postId String
  tagId  String

  @@id([postId, tagId])
}
"#;

    /// A model without `@@map`, so the table name is derived.
    const AUTO_TABLE_NAME_SCHEMA: &str = r#"
datasource db {
  provider = "postgresql"
  url      = "postgres://localhost/test"
}

model BlogPost {
  id String @id
}
"#;

    /// Parse `source` (panicking on a syntax error) and run the validator on
    /// the resulting AST.
    fn validate_source(source: &str) -> Result<Schema, CoreError> {
        let ast = parse(source).expect("parse");
        validate(&ast)
    }

    #[test]
    fn test_validate_basic_schema() {
        let schema = validate_source(BASIC_SCHEMA).expect("validate");

        assert_eq!(schema.datasource.provider, DatabaseProvider::PostgreSQL);

        assert_eq!(schema.enums.len(), 1);
        let role = &schema.enums[0];
        assert_eq!(role.name, "Role");
        assert_eq!(role.db_name, "role");

        let user = &schema.models[0];
        assert_eq!(user.name, "User");
        assert_eq!(user.db_name, "users");
        assert_eq!(user.primary_key.fields, vec!["id"]);

        let id_field = &user.fields[0];
        assert!(id_field.is_id);
        assert_eq!(id_field.field_type, FieldKind::Scalar(ScalarType::String));

        let name_field = &user.fields[2];
        assert!(name_field.is_optional);
        assert_eq!(name_field.db_name, "name");

        let role_field = &user.fields[3];
        assert_eq!(role_field.field_type, FieldKind::Enum("Role".into()));
    }

    #[test]
    fn test_validate_missing_primary_key() {
        let err = validate_source(NO_PRIMARY_KEY_SCHEMA).unwrap_err();
        assert!(matches!(err, CoreError::MissingPrimaryKey { .. }));
    }

    #[test]
    fn test_validate_unknown_type() {
        let err = validate_source(UNKNOWN_TYPE_SCHEMA).unwrap_err();
        assert!(matches!(err, CoreError::UnknownType { .. }));
    }

    #[test]
    fn test_validate_composite_primary_key() {
        let schema = validate_source(COMPOSITE_KEY_SCHEMA).expect("validate");

        let primary_key = &schema.models[0].primary_key;
        assert_eq!(primary_key.fields, vec!["postId", "tagId"]);
        assert!(primary_key.is_composite());
    }

    #[test]
    fn test_snake_case() {
        // (input, expected); the last case shows that each capital in an
        // acronym gets its own underscore.
        let cases = [
            ("User", "user"),
            ("PostTag", "post_tag"),
            ("createdAt", "created_at"),
            ("HTMLParser", "h_t_m_l_parser"),
        ];

        for (input, expected) in cases {
            assert_eq!(to_snake_case(input), expected);
        }
    }

    #[test]
    fn test_validate_auto_table_name() {
        let schema = validate_source(AUTO_TABLE_NAME_SCHEMA).expect("validate");

        // Auto-generated: snake_case + "s"
        assert_eq!(schema.models[0].db_name, "blog_posts");
    }
}