Skip to main content

ferriorm_parser/
validator.rs

1//! Semantic validation of the raw AST into the resolved Schema IR.
2//!
3//! The validator walks the [`ferriorm_core::ast::SchemaFile`] produced by the
4//! parser and performs the following:
5//!
6//! - Resolves field type names to scalars, enums, or model references.
7//! - Infers database table and column names from `@@map`/`@map` or `snake_case`
8//!   conventions.
9//! - Checks that every model has a primary key (`@id` or `@@id`).
10//! - Detects duplicate model/enum names and unknown type references.
11//! - Resolves relation cardinality and referential actions.
12//!
13//! The output is an [`ferriorm_core::schema::Schema`], the canonical IR consumed
14//! by codegen and the migration engine.
15
16use std::collections::HashSet;
17
18use ferriorm_core::ast;
19use ferriorm_core::error::CoreError;
20use ferriorm_core::schema::{
21    DatasourceConfig, Enum, Field, FieldKind, GeneratorConfig, Index, Model, PrimaryKey,
22    RelationType, ResolvedRelation, Schema, UniqueConstraint,
23};
24use ferriorm_core::types::{DatabaseProvider, ScalarType};
25use ferriorm_core::utils::to_snake_case;
26
27/// Validate a parsed AST and produce a resolved Schema IR.
28///
29/// # Errors
30///
31/// Returns a [`CoreError`] if the AST has validation problems such as
32/// missing primary keys, unknown types, or duplicate names.
33pub fn validate(ast: &ast::SchemaFile) -> Result<Schema, CoreError> {
34    let datasource = validate_datasource(ast)?;
35    let generators = validate_generators(ast)?;
36    let enums = validate_enums(ast)?;
37    let models = validate_models(ast, &enums)?;
38
39    Ok(Schema {
40        datasource,
41        generators,
42        enums,
43        models,
44    })
45}
46
47fn validate_datasource(ast: &ast::SchemaFile) -> Result<DatasourceConfig, CoreError> {
48    let ds = ast.datasource.as_ref().ok_or(CoreError::Validation {
49        message: "Missing datasource block".into(),
50    })?;
51
52    let provider =
53        ds.provider
54            .parse::<DatabaseProvider>()
55            .map_err(|_| CoreError::UnknownProvider {
56                provider: ds.provider.clone(),
57            })?;
58
59    let url = match &ds.url {
60        ast::StringOrEnv::Literal(s) => s.clone(),
61        ast::StringOrEnv::Env(var) => format!("${{env:{var}}}"),
62    };
63
64    Ok(DatasourceConfig {
65        name: ds.name.clone(),
66        provider,
67        url,
68    })
69}
70
71fn validate_generators(ast: &ast::SchemaFile) -> Result<Vec<GeneratorConfig>, CoreError> {
72    ast.generators
73        .iter()
74        .map(|g| {
75            Ok(GeneratorConfig {
76                name: g.name.clone(),
77                output: g.output.clone().unwrap_or_else(|| "./src/generated".into()),
78            })
79        })
80        .collect()
81}
82
83fn validate_enums(ast: &ast::SchemaFile) -> Result<Vec<Enum>, CoreError> {
84    let mut names = HashSet::new();
85    let mut result = Vec::new();
86
87    for e in &ast.enums {
88        if !names.insert(&e.name) {
89            return Err(CoreError::DuplicateName {
90                name: e.name.clone(),
91                kind: "enum",
92            });
93        }
94
95        result.push(Enum {
96            name: e.name.clone(),
97            db_name: e.db_name.clone().unwrap_or_else(|| to_snake_case(&e.name)),
98            variants: e.variants.clone(),
99        });
100    }
101
102    Ok(result)
103}
104
105fn validate_models(ast: &ast::SchemaFile, enums: &[Enum]) -> Result<Vec<Model>, CoreError> {
106    let enum_names: HashSet<&str> = enums.iter().map(|e| e.name.as_str()).collect();
107    let model_names: HashSet<&str> = ast.models.iter().map(|m| m.name.as_str()).collect();
108    let mut seen_names = HashSet::new();
109
110    let mut result = Vec::new();
111
112    for model_def in &ast.models {
113        if !seen_names.insert(&model_def.name) {
114            return Err(CoreError::DuplicateName {
115                name: model_def.name.clone(),
116                kind: "model",
117            });
118        }
119
120        // Check for name collision with enums
121        if enum_names.contains(model_def.name.as_str()) {
122            return Err(CoreError::DuplicateName {
123                name: model_def.name.clone(),
124                kind: "model/enum",
125            });
126        }
127
128        let model = validate_model(model_def, &enum_names, &model_names)?;
129        result.push(model);
130    }
131
132    Ok(result)
133}
134
135fn validate_model(
136    model_def: &ast::ModelDef,
137    enum_names: &HashSet<&str>,
138    model_names: &HashSet<&str>,
139) -> Result<Model, CoreError> {
140    // Resolve @@map
141    let db_name = model_def
142        .attributes
143        .iter()
144        .find_map(|a| match a {
145            ast::BlockAttribute::Map(name) => Some(name.clone()),
146            _ => None,
147        })
148        .unwrap_or_else(|| to_snake_case(&model_def.name) + "s");
149
150    let mut fields = Vec::new();
151    let mut has_id_field = false;
152
153    for field_def in &model_def.fields {
154        let field = validate_field(field_def, &model_def.name, enum_names, model_names)?;
155        if field.is_id {
156            has_id_field = true;
157        }
158        fields.push(field);
159    }
160
161    // Check @@id for composite primary key
162    let composite_id: Option<Vec<String>> = model_def.attributes.iter().find_map(|a| match a {
163        ast::BlockAttribute::Id(fields) => Some(fields.clone()),
164        _ => None,
165    });
166
167    if !has_id_field && composite_id.is_none() {
168        return Err(CoreError::MissingPrimaryKey {
169            model_name: model_def.name.clone(),
170        });
171    }
172
173    let primary_key = if let Some(composite_fields) = composite_id {
174        PrimaryKey {
175            fields: composite_fields,
176        }
177    } else {
178        let id_fields: Vec<String> = fields
179            .iter()
180            .filter(|f| f.is_id)
181            .map(|f| f.name.clone())
182            .collect();
183        PrimaryKey { fields: id_fields }
184    };
185
186    // Indexes
187    let indexes = model_def
188        .attributes
189        .iter()
190        .filter_map(|a| match a {
191            ast::BlockAttribute::Index(fields) => Some(Index {
192                fields: fields.clone(),
193            }),
194            _ => None,
195        })
196        .collect();
197
198    // Unique constraints (from @@unique)
199    let unique_constraints = model_def
200        .attributes
201        .iter()
202        .filter_map(|a| match a {
203            ast::BlockAttribute::Unique(fields) => Some(UniqueConstraint {
204                fields: fields.clone(),
205            }),
206            _ => None,
207        })
208        .collect();
209
210    Ok(Model {
211        name: model_def.name.clone(),
212        db_name,
213        fields,
214        primary_key,
215        indexes,
216        unique_constraints,
217    })
218}
219
220fn validate_field(
221    field_def: &ast::FieldDef,
222    model_name: &str,
223    enum_names: &HashSet<&str>,
224    model_names: &HashSet<&str>,
225) -> Result<Field, CoreError> {
226    let type_name = &field_def.field_type.name;
227
228    let field_type = if let Ok(scalar) = type_name.parse::<ScalarType>() {
229        FieldKind::Scalar(scalar)
230    } else if enum_names.contains(type_name.as_str()) {
231        FieldKind::Enum(type_name.clone())
232    } else if model_names.contains(type_name.as_str()) {
233        FieldKind::Model(type_name.clone())
234    } else {
235        return Err(CoreError::UnknownType {
236            model_name: model_name.to_string(),
237            field_name: field_def.name.clone(),
238            type_name: type_name.clone(),
239        });
240    };
241
242    let is_id = field_def
243        .attributes
244        .iter()
245        .any(|a| matches!(a, ast::FieldAttribute::Id));
246    let is_unique = field_def
247        .attributes
248        .iter()
249        .any(|a| matches!(a, ast::FieldAttribute::Unique));
250    let is_updated_at = field_def
251        .attributes
252        .iter()
253        .any(|a| matches!(a, ast::FieldAttribute::UpdatedAt));
254    let default = field_def.attributes.iter().find_map(|a| match a {
255        ast::FieldAttribute::Default(d) => Some(d.clone()),
256        _ => None,
257    });
258
259    // Resolve @map
260    let db_name = field_def
261        .attributes
262        .iter()
263        .find_map(|a| match a {
264            ast::FieldAttribute::Map(name) => Some(name.clone()),
265            _ => None,
266        })
267        .unwrap_or_else(|| to_snake_case(&field_def.name));
268
269    // Resolve @relation
270    let relation = field_def.attributes.iter().find_map(|a| match a {
271        ast::FieldAttribute::Relation(rel) => {
272            let relation_type = if field_def.field_type.is_list {
273                RelationType::OneToMany
274            } else if field_def.field_type.is_optional {
275                RelationType::OneToOne
276            } else {
277                RelationType::ManyToOne
278            };
279
280            Some(ResolvedRelation {
281                related_model: type_name.clone(),
282                relation_type,
283                fields: rel.fields.clone(),
284                references: rel.references.clone(),
285                on_delete: rel.on_delete.unwrap_or(ast::ReferentialAction::Restrict),
286                on_update: rel.on_update.unwrap_or(ast::ReferentialAction::Cascade),
287            })
288        }
289        _ => None,
290    });
291
292    Ok(Field {
293        name: field_def.name.clone(),
294        db_name,
295        field_type,
296        is_optional: field_def.field_type.is_optional,
297        is_list: field_def.field_type.is_list,
298        is_id,
299        is_unique,
300        is_updated_at,
301        default,
302        relation,
303    })
304}
305
#[cfg(test)]
#[allow(clippy::pedantic)]
mod tests {
    use super::*;
    use crate::parser::parse;
    use ferriorm_core::utils::to_snake_case;

    /// Parses and validates `source`, panicking if either step fails.
    fn validated(source: &str) -> Schema {
        let ast = parse(source).expect("parse");
        validate(&ast).expect("validate")
    }

    /// Parses `source` (which must be syntactically valid) and returns the
    /// error reported by the validator.
    fn rejected(source: &str) -> CoreError {
        let ast = parse(source).expect("parse");
        validate(&ast).unwrap_err()
    }

    #[test]
    fn test_validate_basic_schema() {
        let schema = validated(
            r#"
datasource db {
  provider = "postgresql"
  url      = env("DATABASE_URL")
}

generator client {
  output = "./src/generated"
}

enum Role {
  User
  Admin
}

model User {
  id    String @id @default(uuid())
  email String @unique
  name  String?
  role  Role   @default(User)

  @@map("users")
}
"#,
        );

        assert_eq!(schema.datasource.provider, DatabaseProvider::PostgreSQL);

        assert_eq!(schema.enums.len(), 1);
        let role_enum = &schema.enums[0];
        assert_eq!(role_enum.name, "Role");
        assert_eq!(role_enum.db_name, "role");

        let user = &schema.models[0];
        assert_eq!(user.name, "User");
        assert_eq!(user.db_name, "users");
        assert_eq!(user.primary_key.fields, vec!["id"]);

        let id = &user.fields[0];
        assert!(id.is_id);
        assert_eq!(id.field_type, FieldKind::Scalar(ScalarType::String));

        let name = &user.fields[2];
        assert!(name.is_optional);
        assert_eq!(name.db_name, "name");

        let role = &user.fields[3];
        assert_eq!(role.field_type, FieldKind::Enum("Role".into()));
    }

    #[test]
    fn test_validate_missing_primary_key() {
        let err = rejected(
            r#"
datasource db {
  provider = "postgresql"
  url      = "postgres://localhost/test"
}

model User {
  email String
  name  String
}
"#,
        );
        assert!(matches!(err, CoreError::MissingPrimaryKey { .. }));
    }

    #[test]
    fn test_validate_unknown_type() {
        let err = rejected(
            r#"
datasource db {
  provider = "postgresql"
  url      = "postgres://localhost/test"
}

model User {
  id   String @id
  role Nonexistent
}
"#,
        );
        assert!(matches!(err, CoreError::UnknownType { .. }));
    }

    #[test]
    fn test_validate_composite_primary_key() {
        let schema = validated(
            r#"
datasource db {
  provider = "sqlite"
  url      = "file:./dev.db"
}

model PostTag {
  postId String
  tagId  String

  @@id([postId, tagId])
}
"#,
        );
        let pk = &schema.models[0].primary_key;
        assert_eq!(pk.fields, vec!["postId", "tagId"]);
        assert!(pk.is_composite());
    }

    #[test]
    fn test_snake_case() {
        let cases = [
            ("User", "user"),
            ("PostTag", "post_tag"),
            ("createdAt", "created_at"),
            ("HTMLParser", "h_t_m_l_parser"),
        ];
        for (input, expected) in cases {
            assert_eq!(to_snake_case(input), expected);
        }
    }

    #[test]
    fn test_validate_auto_table_name() {
        let schema = validated(
            r#"
datasource db {
  provider = "postgresql"
  url      = "postgres://localhost/test"
}

model BlogPost {
  id String @id
}
"#,
        );
        // Auto-generated: snake_case + "s"
        assert_eq!(schema.models[0].db_name, "blog_posts");
    }
}