Skip to main content

ferriorm_parser/
validator.rs

1//! Semantic validation of the raw AST into the resolved Schema IR.
2//!
3//! The validator walks the [`ferriorm_core::ast::SchemaFile`] produced by the
4//! parser and performs the following:
5//!
6//! - Resolves field type names to scalars, enums, or model references.
7//! - Infers database table and column names from `@@map`/`@map` or `snake_case`
8//!   conventions.
9//! - Checks that every model has a primary key (`@id` or `@@id`).
10//! - Detects duplicate model/enum names and unknown type references.
11//! - Resolves relation cardinality and referential actions.
12//!
//! The output is a [`ferriorm_core::schema::Schema`], the canonical IR consumed
14//! by codegen and the migration engine.
15
16use std::collections::HashSet;
17
18use ferriorm_core::ast;
19use ferriorm_core::error::CoreError;
20use ferriorm_core::schema::{
21    DatasourceConfig, Enum, Field, FieldKind, GeneratorConfig, Index, Model, PrimaryKey,
22    RelationType, ResolvedRelation, Schema, UniqueConstraint,
23};
24use ferriorm_core::types::{DatabaseProvider, ScalarType};
25use ferriorm_core::utils::to_snake_case;
26
27/// Validate a parsed AST and produce a resolved Schema IR.
28///
29/// # Errors
30///
31/// Returns a [`CoreError`] if the AST has validation problems such as
32/// missing primary keys, unknown types, or duplicate names.
33pub fn validate(ast: &ast::SchemaFile) -> Result<Schema, CoreError> {
34    let datasource = validate_datasource(ast)?;
35    let generators = validate_generators(ast)?;
36    let enums = validate_enums(ast)?;
37    let models = validate_models(ast, &enums)?;
38
39    Ok(Schema {
40        datasource,
41        generators,
42        enums,
43        models,
44    })
45}
46
47fn validate_datasource(ast: &ast::SchemaFile) -> Result<DatasourceConfig, CoreError> {
48    let ds = ast.datasource.as_ref().ok_or(CoreError::Validation {
49        message: "Missing datasource block".into(),
50    })?;
51
52    let provider =
53        ds.provider
54            .parse::<DatabaseProvider>()
55            .map_err(|_| CoreError::UnknownProvider {
56                provider: ds.provider.clone(),
57            })?;
58
59    let url = match &ds.url {
60        ast::StringOrEnv::Literal(s) => s.clone(),
61        ast::StringOrEnv::Env(var) => format!("${{env:{var}}}"),
62    };
63
64    Ok(DatasourceConfig {
65        name: ds.name.clone(),
66        provider,
67        url,
68    })
69}
70
71fn validate_generators(ast: &ast::SchemaFile) -> Result<Vec<GeneratorConfig>, CoreError> {
72    ast.generators
73        .iter()
74        .map(|g| {
75            Ok(GeneratorConfig {
76                name: g.name.clone(),
77                output: g.output.clone().unwrap_or_else(|| "./src/generated".into()),
78            })
79        })
80        .collect()
81}
82
83fn validate_enums(ast: &ast::SchemaFile) -> Result<Vec<Enum>, CoreError> {
84    let mut names = HashSet::new();
85    let mut result = Vec::new();
86
87    for e in &ast.enums {
88        if !names.insert(&e.name) {
89            return Err(CoreError::DuplicateName {
90                name: e.name.clone(),
91                kind: "enum",
92            });
93        }
94
95        result.push(Enum {
96            name: e.name.clone(),
97            db_name: e.db_name.clone().unwrap_or_else(|| to_snake_case(&e.name)),
98            variants: e.variants.clone(),
99        });
100    }
101
102    Ok(result)
103}
104
105fn validate_models(ast: &ast::SchemaFile, enums: &[Enum]) -> Result<Vec<Model>, CoreError> {
106    let enum_names: HashSet<&str> = enums.iter().map(|e| e.name.as_str()).collect();
107    let model_names: HashSet<&str> = ast.models.iter().map(|m| m.name.as_str()).collect();
108    let mut seen_names = HashSet::new();
109
110    let mut result = Vec::new();
111
112    for model_def in &ast.models {
113        if !seen_names.insert(&model_def.name) {
114            return Err(CoreError::DuplicateName {
115                name: model_def.name.clone(),
116                kind: "model",
117            });
118        }
119
120        // Check for name collision with enums
121        if enum_names.contains(model_def.name.as_str()) {
122            return Err(CoreError::DuplicateName {
123                name: model_def.name.clone(),
124                kind: "model/enum",
125            });
126        }
127
128        let model = validate_model(model_def, &enum_names, &model_names)?;
129        result.push(model);
130    }
131
132    Ok(result)
133}
134
135fn validate_model(
136    model_def: &ast::ModelDef,
137    enum_names: &HashSet<&str>,
138    model_names: &HashSet<&str>,
139) -> Result<Model, CoreError> {
140    // Resolve @@map
141    let db_name = model_def
142        .attributes
143        .iter()
144        .find_map(|a| match a {
145            ast::BlockAttribute::Map(name) => Some(name.clone()),
146            _ => None,
147        })
148        .unwrap_or_else(|| to_snake_case(&model_def.name) + "s");
149
150    let mut fields = Vec::new();
151    let mut has_id_field = false;
152
153    for field_def in &model_def.fields {
154        let field = validate_field(field_def, &model_def.name, enum_names, model_names)?;
155        if field.is_id {
156            has_id_field = true;
157        }
158        fields.push(field);
159    }
160
161    // Check @@id for composite primary key
162    let composite_id: Option<Vec<String>> = model_def.attributes.iter().find_map(|a| match a {
163        ast::BlockAttribute::Id(fields) => Some(fields.clone()),
164        _ => None,
165    });
166
167    if !has_id_field && composite_id.is_none() {
168        return Err(CoreError::MissingPrimaryKey {
169            model_name: model_def.name.clone(),
170        });
171    }
172
173    let primary_key = if let Some(composite_fields) = composite_id {
174        PrimaryKey {
175            fields: composite_fields,
176        }
177    } else {
178        let id_fields: Vec<String> = fields
179            .iter()
180            .filter(|f| f.is_id)
181            .map(|f| f.name.clone())
182            .collect();
183        PrimaryKey { fields: id_fields }
184    };
185
186    // Indexes
187    let indexes = model_def
188        .attributes
189        .iter()
190        .filter_map(|a| match a {
191            ast::BlockAttribute::Index(fields) => Some(Index {
192                fields: fields.clone(),
193            }),
194            _ => None,
195        })
196        .collect();
197
198    // Unique constraints (from @@unique)
199    let unique_constraints = model_def
200        .attributes
201        .iter()
202        .filter_map(|a| match a {
203            ast::BlockAttribute::Unique(fields) => Some(UniqueConstraint {
204                fields: fields.clone(),
205            }),
206            _ => None,
207        })
208        .collect();
209
210    Ok(Model {
211        name: model_def.name.clone(),
212        db_name,
213        fields,
214        primary_key,
215        indexes,
216        unique_constraints,
217    })
218}
219
220fn validate_field(
221    field_def: &ast::FieldDef,
222    model_name: &str,
223    enum_names: &HashSet<&str>,
224    model_names: &HashSet<&str>,
225) -> Result<Field, CoreError> {
226    let type_name = &field_def.field_type.name;
227
228    let field_type = if let Ok(scalar) = type_name.parse::<ScalarType>() {
229        FieldKind::Scalar(scalar)
230    } else if enum_names.contains(type_name.as_str()) {
231        FieldKind::Enum(type_name.clone())
232    } else if model_names.contains(type_name.as_str()) {
233        FieldKind::Model(type_name.clone())
234    } else {
235        return Err(CoreError::UnknownType {
236            model_name: model_name.to_string(),
237            field_name: field_def.name.clone(),
238            type_name: type_name.clone(),
239        });
240    };
241
242    let is_id = field_def
243        .attributes
244        .iter()
245        .any(|a| matches!(a, ast::FieldAttribute::Id));
246    let is_unique = field_def
247        .attributes
248        .iter()
249        .any(|a| matches!(a, ast::FieldAttribute::Unique));
250    let is_updated_at = field_def
251        .attributes
252        .iter()
253        .any(|a| matches!(a, ast::FieldAttribute::UpdatedAt));
254    let default = field_def.attributes.iter().find_map(|a| match a {
255        ast::FieldAttribute::Default(d) => Some(d.clone()),
256        _ => None,
257    });
258
259    // Resolve @map
260    let db_name = field_def
261        .attributes
262        .iter()
263        .find_map(|a| match a {
264            ast::FieldAttribute::Map(name) => Some(name.clone()),
265            _ => None,
266        })
267        .unwrap_or_else(|| to_snake_case(&field_def.name));
268
269    // Resolve @relation
270    let relation = field_def.attributes.iter().find_map(|a| match a {
271        ast::FieldAttribute::Relation(rel) => {
272            let relation_type = if field_def.field_type.is_list {
273                RelationType::OneToMany
274            } else if field_def.field_type.is_optional {
275                RelationType::OneToOne
276            } else {
277                RelationType::ManyToOne
278            };
279
280            Some(ResolvedRelation {
281                related_model: type_name.clone(),
282                relation_type,
283                fields: rel.fields.clone(),
284                references: rel.references.clone(),
285                on_delete: rel.on_delete.unwrap_or(ast::ReferentialAction::Restrict),
286                on_update: rel.on_update.unwrap_or(ast::ReferentialAction::Cascade),
287            })
288        }
289        _ => None,
290    });
291
292    // Resolve @db.* type hint (e.g. @db.BigInt)
293    let db_type = field_def.attributes.iter().find_map(|a| match a {
294        ast::FieldAttribute::DbType(ty, args) => Some((ty.clone(), args.clone())),
295        _ => None,
296    });
297
298    Ok(Field {
299        name: field_def.name.clone(),
300        db_name,
301        field_type,
302        is_optional: field_def.field_type.is_optional,
303        is_list: field_def.field_type.is_list,
304        is_id,
305        is_unique,
306        is_updated_at,
307        default,
308        relation,
309        db_type,
310    })
311}
312
#[cfg(test)]
#[allow(clippy::pedantic)]
mod tests {
    use super::*;
    use crate::parser::parse;
    use ferriorm_core::utils::to_snake_case;

    /// Parse `source` (panicking on a syntax error) and run the validator on it.
    fn check(source: &str) -> Result<Schema, CoreError> {
        let ast = parse(source).expect("parse");
        validate(&ast)
    }

    /// A complete schema resolves the provider, enums, table name, primary key
    /// and field kinds.
    #[test]
    fn test_validate_basic_schema() {
        let schema = check(
            r#"
datasource db {
  provider = "postgresql"
  url      = env("DATABASE_URL")
}

generator client {
  output = "./src/generated"
}

enum Role {
  User
  Admin
}

model User {
  id    String @id @default(uuid())
  email String @unique
  name  String?
  role  Role   @default(User)

  @@map("users")
}
"#,
        )
        .expect("validate");

        assert_eq!(schema.datasource.provider, DatabaseProvider::PostgreSQL);

        assert_eq!(schema.enums.len(), 1);
        let role_enum = &schema.enums[0];
        assert_eq!(role_enum.name, "Role");
        assert_eq!(role_enum.db_name, "role");

        let user = &schema.models[0];
        assert_eq!(user.name, "User");
        assert_eq!(user.db_name, "users");
        assert_eq!(user.primary_key.fields, vec!["id"]);

        let id = &user.fields[0];
        assert!(id.is_id);
        assert_eq!(id.field_type, FieldKind::Scalar(ScalarType::String));

        let name = &user.fields[2];
        assert!(name.is_optional);
        assert_eq!(name.db_name, "name");

        let role = &user.fields[3];
        assert_eq!(role.field_type, FieldKind::Enum("Role".into()));
    }

    /// A model with neither `@id` nor `@@id` is rejected.
    #[test]
    fn test_validate_missing_primary_key() {
        let outcome = check(
            r#"
datasource db {
  provider = "postgresql"
  url      = "postgres://localhost/test"
}

model User {
  email String
  name  String
}
"#,
        );

        assert!(matches!(
            outcome,
            Err(CoreError::MissingPrimaryKey { .. })
        ));
    }

    /// A field whose type is not a scalar, enum or model is rejected.
    #[test]
    fn test_validate_unknown_type() {
        let outcome = check(
            r#"
datasource db {
  provider = "postgresql"
  url      = "postgres://localhost/test"
}

model User {
  id   String @id
  role Nonexistent
}
"#,
        );

        assert!(matches!(outcome, Err(CoreError::UnknownType { .. })));
    }

    /// `@@id([...])` yields a composite primary key in declaration order.
    #[test]
    fn test_validate_composite_primary_key() {
        let schema = check(
            r#"
datasource db {
  provider = "sqlite"
  url      = "file:./dev.db"
}

model PostTag {
  postId String
  tagId  String

  @@id([postId, tagId])
}
"#,
        )
        .expect("validate");

        let key = &schema.models[0].primary_key;
        assert_eq!(key.fields, vec!["postId", "tagId"]);
        assert!(key.is_composite());
    }

    /// Pins the snake_case conversion used for default database names,
    /// including the per-letter split of acronyms.
    #[test]
    fn test_snake_case() {
        let cases = [
            ("User", "user"),
            ("PostTag", "post_tag"),
            ("createdAt", "created_at"),
            ("HTMLParser", "h_t_m_l_parser"),
        ];

        for (input, expected) in cases {
            assert_eq!(to_snake_case(input), expected);
        }
    }

    /// Without `@@map`, the table name is the snake_case model name plus "s".
    #[test]
    fn test_validate_auto_table_name() {
        let schema = check(
            r#"
datasource db {
  provider = "postgresql"
  url      = "postgres://localhost/test"
}

model BlogPost {
  id String @id
}
"#,
        )
        .expect("validate");

        assert_eq!(schema.models[0].db_name, "blog_posts");
    }
}