1use std::collections::HashSet;
17
18use ferriorm_core::ast;
19use ferriorm_core::error::CoreError;
20use ferriorm_core::schema::{
21 DatasourceConfig, Enum, Field, FieldKind, GeneratorConfig, Index, Model, PrimaryKey,
22 RelationType, ResolvedRelation, Schema, UniqueConstraint,
23};
24use ferriorm_core::types::{DatabaseProvider, ScalarType};
25use ferriorm_core::utils::to_snake_case;
26
27pub fn validate(ast: &ast::SchemaFile) -> Result<Schema, CoreError> {
34 let datasource = validate_datasource(ast)?;
35 let generators = validate_generators(ast)?;
36 let enums = validate_enums(ast)?;
37 let models = validate_models(ast, &enums)?;
38
39 validate_unique_db_names(&models)?;
40 validate_relation_disambiguation(&models)?;
41
42 Ok(Schema {
43 datasource,
44 generators,
45 enums,
46 models,
47 })
48}
49
50fn validate_unique_db_names(models: &[Model]) -> Result<(), CoreError> {
54 use std::collections::HashMap;
55 let mut seen: HashMap<&str, &str> = HashMap::new();
56 for m in models {
57 if let Some(existing) = seen.get(m.db_name.as_str()) {
58 return Err(CoreError::Validation {
59 message: format!(
60 "Duplicate table name `{}` (used by models `{}` and `{}`). \
61 Each model must map to a distinct table; use `@@map(\"...\")` to disambiguate.",
62 m.db_name, existing, m.name,
63 ),
64 });
65 }
66 seen.insert(&m.db_name, &m.name);
67 }
68 Ok(())
69}
70
/// Returns `true` if `s` cannot be used as a plain Rust identifier for a
/// generated struct field.
///
/// The list covers the strict keywords, the keywords added in the 2018
/// edition, and the keywords reserved for future use. That includes `gen`,
/// which is reserved starting with the Rust 2024 edition. Weak
/// (contextual) keywords such as `union` or `macro_rules` are legal field
/// names and are intentionally not listed.
fn is_rust_keyword(s: &str) -> bool {
    matches!(
        s,
        // Strict keywords.
        "as" | "break" | "const" | "continue" | "crate" | "else" | "enum" | "extern"
            | "false" | "fn" | "for" | "if" | "impl" | "in" | "let" | "loop" | "match"
            | "mod" | "move" | "mut" | "pub" | "ref" | "return" | "self" | "Self"
            | "static" | "struct" | "super" | "trait" | "true" | "type" | "unsafe"
            | "use" | "where" | "while"
            // Strict keywords introduced by the 2018 edition.
            | "async" | "await" | "dyn"
            // Reserved for future use (`try` since 2018, `gen` since 2024).
            | "abstract" | "become" | "box" | "do" | "final" | "macro" | "override"
            | "priv" | "typeof" | "unsized" | "virtual" | "yield" | "try" | "gen"
    )
}
89
90fn validate_relation_disambiguation(models: &[Model]) -> Result<(), CoreError> {
97 use std::collections::{HashMap, HashSet};
98
99 for model in models {
100 let mut groups: HashMap<(&str, bool), Vec<&Field>> = HashMap::new();
102
103 for field in &model.fields {
104 let target = match &field.field_type {
105 FieldKind::Model(name) => name.as_str(),
106 _ => continue,
107 };
108 let is_fk_owner = field
109 .relation
110 .as_ref()
111 .is_some_and(|r| !r.fields.is_empty());
112 groups.entry((target, is_fk_owner)).or_default().push(field);
113 }
114
115 for ((target, _), group) in &groups {
116 if group.len() < 2 {
117 continue;
118 }
119
120 let mut seen_names: HashSet<&str> = HashSet::new();
121 for field in group {
122 let name = field.relation.as_ref().and_then(|r| r.name.as_deref());
123 let Some(n) = name else {
124 return Err(CoreError::Validation {
125 message: format!(
126 "Multiple relations from `{}` to `{}` require disambiguation. \
127 Add `@relation(\"<Name>\", ...)` to each related field on both sides.",
128 model.name, target,
129 ),
130 });
131 };
132 if !seen_names.insert(n) {
133 return Err(CoreError::Validation {
134 message: format!(
135 "Duplicate relation name `{}` between `{}` and `{}`. \
136 Each relation between the same pair of models must have a unique name.",
137 n, model.name, target,
138 ),
139 });
140 }
141 }
142 }
143 }
144 Ok(())
145}
146
147fn validate_datasource(ast: &ast::SchemaFile) -> Result<DatasourceConfig, CoreError> {
148 let ds = ast.datasource.as_ref().ok_or(CoreError::Validation {
149 message: "Missing datasource block".into(),
150 })?;
151
152 let provider =
153 ds.provider
154 .parse::<DatabaseProvider>()
155 .map_err(|_| CoreError::UnknownProvider {
156 provider: ds.provider.clone(),
157 })?;
158
159 let url = match &ds.url {
160 ast::StringOrEnv::Literal(s) => s.clone(),
161 ast::StringOrEnv::Env(var) => format!("${{env:{var}}}"),
162 };
163
164 Ok(DatasourceConfig {
165 name: ds.name.clone(),
166 provider,
167 url,
168 })
169}
170
171fn validate_generators(ast: &ast::SchemaFile) -> Result<Vec<GeneratorConfig>, CoreError> {
172 ast.generators
173 .iter()
174 .map(|g| {
175 Ok(GeneratorConfig {
176 name: g.name.clone(),
177 output: g.output.clone().unwrap_or_else(|| "./src/generated".into()),
178 })
179 })
180 .collect()
181}
182
183fn validate_enums(ast: &ast::SchemaFile) -> Result<Vec<Enum>, CoreError> {
184 let mut names = HashSet::new();
185 let mut result = Vec::new();
186
187 for e in &ast.enums {
188 if !names.insert(&e.name) {
189 return Err(CoreError::DuplicateName {
190 name: e.name.clone(),
191 kind: "enum",
192 });
193 }
194
195 result.push(Enum {
196 name: e.name.clone(),
197 db_name: e.db_name.clone().unwrap_or_else(|| to_snake_case(&e.name)),
198 variants: e.variants.clone(),
199 });
200 }
201
202 Ok(result)
203}
204
205fn validate_models(ast: &ast::SchemaFile, enums: &[Enum]) -> Result<Vec<Model>, CoreError> {
206 let enum_names: HashSet<&str> = enums.iter().map(|e| e.name.as_str()).collect();
207 let model_names: HashSet<&str> = ast.models.iter().map(|m| m.name.as_str()).collect();
208 let mut seen_names = HashSet::new();
209
210 let mut result = Vec::new();
211
212 for model_def in &ast.models {
213 if !seen_names.insert(&model_def.name) {
214 return Err(CoreError::DuplicateName {
215 name: model_def.name.clone(),
216 kind: "model",
217 });
218 }
219
220 if enum_names.contains(model_def.name.as_str()) {
222 return Err(CoreError::DuplicateName {
223 name: model_def.name.clone(),
224 kind: "model/enum",
225 });
226 }
227
228 let model = validate_model(model_def, &enum_names, &model_names)?;
229 result.push(model);
230 }
231
232 Ok(result)
233}
234
235#[allow(clippy::too_many_lines)] fn validate_model(
237 model_def: &ast::ModelDef,
238 enum_names: &HashSet<&str>,
239 model_names: &HashSet<&str>,
240) -> Result<Model, CoreError> {
241 let db_name = model_def
243 .attributes
244 .iter()
245 .find_map(|a| match a {
246 ast::BlockAttribute::Map(name) => Some(name.clone()),
247 _ => None,
248 })
249 .unwrap_or_else(|| to_snake_case(&model_def.name) + "s");
250
251 let mut fields = Vec::new();
252 let mut has_id_field = false;
253
254 for field_def in &model_def.fields {
255 let field = validate_field(field_def, &model_def.name, enum_names, model_names)?;
256 if field.is_id {
257 has_id_field = true;
258 }
259 fields.push(field);
260 }
261
262 let composite_id: Option<Vec<String>> = model_def.attributes.iter().find_map(|a| match a {
264 ast::BlockAttribute::Id(fields) => Some(fields.clone()),
265 _ => None,
266 });
267
268 if !has_id_field && composite_id.is_none() {
269 return Err(CoreError::MissingPrimaryKey {
270 model_name: model_def.name.clone(),
271 });
272 }
273
274 let field_name_set: HashSet<&str> = fields.iter().map(|f| f.name.as_str()).collect();
277 let field_db_set: HashSet<&str> = fields.iter().map(|f| f.db_name.as_str()).collect();
278 let field_resolver = |needle: &str| -> Option<&Field> {
279 fields
280 .iter()
281 .find(|f| f.name == needle || f.db_name == needle || to_snake_case(&f.name) == needle)
282 };
283
284 let primary_key = if let Some(composite_fields) = composite_id {
285 for f in &composite_fields {
288 let Some(resolved) = field_resolver(f) else {
289 return Err(CoreError::Validation {
290 message: format!(
291 "`@@id` on model `{}` references unknown field `{}`",
292 model_def.name, f,
293 ),
294 });
295 };
296 if matches!(resolved.field_type, FieldKind::Scalar(ScalarType::Json)) {
297 return Err(CoreError::Validation {
298 message: format!(
299 "Field `{}.{}` of type `Json` cannot be part of a composite primary key.",
300 model_def.name, resolved.name,
301 ),
302 });
303 }
304 }
305 PrimaryKey {
306 fields: composite_fields,
307 }
308 } else {
309 let id_fields: Vec<String> = fields
310 .iter()
311 .filter(|f| f.is_id)
312 .map(|f| f.name.clone())
313 .collect();
314 PrimaryKey { fields: id_fields }
315 };
316
317 for attr in &model_def.attributes {
321 let (kind, fs) = match attr {
322 ast::BlockAttribute::Index(idx) => ("@@index", &idx.fields),
323 ast::BlockAttribute::Unique(idx) => ("@@unique", &idx.fields),
324 _ => continue,
325 };
326 for f in fs {
327 if !field_name_set.contains(f.as_str())
328 && !field_db_set.contains(f.as_str())
329 && field_resolver(f).is_none()
330 {
331 return Err(CoreError::Validation {
332 message: format!(
333 "`{}` on model `{}` references unknown field `{}`",
334 kind, model_def.name, f,
335 ),
336 });
337 }
338 }
339 }
340
341 let indexes = model_def
343 .attributes
344 .iter()
345 .filter_map(|a| match a {
346 ast::BlockAttribute::Index(idx) => Some(Index {
347 fields: idx.fields.clone(),
348 name: idx.name.clone(),
349 }),
350 _ => None,
351 })
352 .collect();
353
354 let unique_constraints = model_def
356 .attributes
357 .iter()
358 .filter_map(|a| match a {
359 ast::BlockAttribute::Unique(idx) => Some(UniqueConstraint {
360 fields: idx.fields.clone(),
361 name: idx.name.clone(),
362 }),
363 _ => None,
364 })
365 .collect();
366
367 Ok(Model {
368 name: model_def.name.clone(),
369 db_name,
370 fields,
371 primary_key,
372 indexes,
373 unique_constraints,
374 })
375}
376
377#[allow(clippy::too_many_lines)] fn validate_field(
379 field_def: &ast::FieldDef,
380 model_name: &str,
381 enum_names: &HashSet<&str>,
382 model_names: &HashSet<&str>,
383) -> Result<Field, CoreError> {
384 let type_name = &field_def.field_type.name;
385
386 if is_rust_keyword(&field_def.name) {
390 return Err(CoreError::Validation {
391 message: format!(
392 "Field name `{}.{}` is a Rust keyword and cannot be used as a struct field. \
393 Rename the field and use `@map(\"{}\")` if you need that database column name.",
394 model_name, field_def.name, field_def.name,
395 ),
396 });
397 }
398
399 let field_type = if let Ok(scalar) = type_name.parse::<ScalarType>() {
400 FieldKind::Scalar(scalar)
401 } else if enum_names.contains(type_name.as_str()) {
402 FieldKind::Enum(type_name.clone())
403 } else if model_names.contains(type_name.as_str()) {
404 FieldKind::Model(type_name.clone())
405 } else {
406 return Err(CoreError::UnknownType {
407 model_name: model_name.to_string(),
408 field_name: field_def.name.clone(),
409 type_name: type_name.clone(),
410 });
411 };
412
413 let is_id = field_def
414 .attributes
415 .iter()
416 .any(|a| matches!(a, ast::FieldAttribute::Id));
417 let is_unique = field_def
418 .attributes
419 .iter()
420 .any(|a| matches!(a, ast::FieldAttribute::Unique));
421 let is_updated_at = field_def
422 .attributes
423 .iter()
424 .any(|a| matches!(a, ast::FieldAttribute::UpdatedAt));
425 let default = field_def.attributes.iter().find_map(|a| match a {
426 ast::FieldAttribute::Default(d) => Some(d.clone()),
427 _ => None,
428 });
429
430 if is_id && field_def.field_type.is_optional {
432 return Err(CoreError::Validation {
433 message: format!(
434 "Field `{}.{}` is marked `@id` but is optional; primary key columns cannot be NULL.",
435 model_name, field_def.name,
436 ),
437 });
438 }
439
440 if matches!(default, Some(ast::DefaultValue::AutoIncrement)) {
442 let is_int_scalar = matches!(
443 field_type,
444 FieldKind::Scalar(ScalarType::Int | ScalarType::BigInt)
445 );
446 if !is_int_scalar {
447 return Err(CoreError::InvalidDefault {
448 model_name: model_name.to_string(),
449 field_name: field_def.name.clone(),
450 message: format!(
451 "`@default(autoincrement())` requires an integer field, got `{type_name}`",
452 ),
453 });
454 }
455 }
456
457 for attr in &field_def.attributes {
459 if let ast::FieldAttribute::Relation(rel) = attr
460 && rel.fields.len() != rel.references.len()
461 {
462 return Err(CoreError::InvalidRelationFields {
463 model_name: model_name.to_string(),
464 field_name: field_def.name.clone(),
465 message: format!(
466 "`@relation` `fields` (length {}) and `references` (length {}) must have the same length",
467 rel.fields.len(),
468 rel.references.len(),
469 ),
470 });
471 }
472 }
473
474 let db_name = field_def
476 .attributes
477 .iter()
478 .find_map(|a| match a {
479 ast::FieldAttribute::Map(name) => Some(name.clone()),
480 _ => None,
481 })
482 .unwrap_or_else(|| to_snake_case(&field_def.name));
483
484 let relation = field_def.attributes.iter().find_map(|a| match a {
486 ast::FieldAttribute::Relation(rel) => {
487 let relation_type = if field_def.field_type.is_list {
488 RelationType::OneToMany
489 } else if field_def.field_type.is_optional {
490 RelationType::OneToOne
491 } else {
492 RelationType::ManyToOne
493 };
494
495 Some(ResolvedRelation {
496 name: rel.name.clone(),
497 related_model: type_name.clone(),
498 relation_type,
499 fields: rel.fields.clone(),
500 references: rel.references.clone(),
501 on_delete: rel.on_delete.unwrap_or(ast::ReferentialAction::Restrict),
502 on_update: rel.on_update.unwrap_or(ast::ReferentialAction::Cascade),
503 })
504 }
505 _ => None,
506 });
507
508 let db_type = field_def.attributes.iter().find_map(|a| match a {
510 ast::FieldAttribute::DbType(ty, args) => Some((ty.clone(), args.clone())),
511 _ => None,
512 });
513
514 Ok(Field {
515 name: field_def.name.clone(),
516 db_name,
517 field_type,
518 is_optional: field_def.field_type.is_optional,
519 is_list: field_def.field_type.is_list,
520 is_id,
521 is_unique,
522 is_updated_at,
523 default,
524 relation,
525 db_type,
526 })
527}
528
#[cfg(test)]
#[allow(clippy::pedantic)]
mod tests {
    use super::*;
    use crate::parser::parse;
    use ferriorm_core::utils::to_snake_case;

    /// Parses `source` (panicking on parse errors) and validates it.
    fn validate_source(source: &str) -> Result<Schema, CoreError> {
        let ast = parse(source).expect("parse");
        validate(&ast)
    }

    #[test]
    fn test_validate_basic_schema() {
        let source = r#"
datasource db {
  provider = "postgresql"
  url = env("DATABASE_URL")
}

generator client {
  output = "./src/generated"
}

enum Role {
  User
  Admin
}

model User {
  id String @id @default(uuid())
  email String @unique
  name String?
  role Role @default(User)

  @@map("users")
}
"#;

        let ast = parse(source).expect("parse");
        let schema = validate(&ast).expect("validate");

        assert_eq!(schema.datasource.provider, DatabaseProvider::PostgreSQL);
        // `env(...)` URLs are kept as a placeholder, not expanded.
        assert_eq!(schema.datasource.url, "${env:DATABASE_URL}");
        assert_eq!(schema.enums.len(), 1);
        assert_eq!(schema.enums[0].name, "Role");
        assert_eq!(schema.enums[0].db_name, "role");

        let user = &schema.models[0];
        assert_eq!(user.name, "User");
        assert_eq!(user.db_name, "users");
        assert_eq!(user.primary_key.fields, vec!["id"]);

        let id_field = &user.fields[0];
        assert!(id_field.is_id);
        assert_eq!(id_field.field_type, FieldKind::Scalar(ScalarType::String));

        let name_field = &user.fields[2];
        assert!(name_field.is_optional);
        assert_eq!(name_field.db_name, "name");

        let role_field = &user.fields[3];
        assert_eq!(role_field.field_type, FieldKind::Enum("Role".into()));
    }

    #[test]
    fn test_validate_missing_primary_key() {
        let source = r#"
datasource db {
  provider = "postgresql"
  url = "postgres://localhost/test"
}

model User {
  email String
  name String
}
"#;

        let ast = parse(source).expect("parse");
        let err = validate(&ast).unwrap_err();
        assert!(matches!(err, CoreError::MissingPrimaryKey { .. }));
    }

    #[test]
    fn test_validate_unknown_type() {
        let source = r#"
datasource db {
  provider = "postgresql"
  url = "postgres://localhost/test"
}

model User {
  id String @id
  role Nonexistent
}
"#;

        let ast = parse(source).expect("parse");
        let err = validate(&ast).unwrap_err();
        assert!(matches!(err, CoreError::UnknownType { .. }));
    }

    #[test]
    fn test_validate_composite_primary_key() {
        let source = r#"
datasource db {
  provider = "sqlite"
  url = "file:./dev.db"
}

model PostTag {
  postId String
  tagId String

  @@id([postId, tagId])
}
"#;

        let ast = parse(source).expect("parse");
        let schema = validate(&ast).expect("validate");
        let model = &schema.models[0];
        assert_eq!(model.primary_key.fields, vec!["postId", "tagId"]);
        assert!(model.primary_key.is_composite());
    }

    #[test]
    fn test_snake_case() {
        assert_eq!(to_snake_case("User"), "user");
        assert_eq!(to_snake_case("PostTag"), "post_tag");
        assert_eq!(to_snake_case("createdAt"), "created_at");
        assert_eq!(to_snake_case("HTMLParser"), "h_t_m_l_parser");
    }

    #[test]
    fn test_validate_auto_table_name() {
        let source = r#"
datasource db {
  provider = "postgresql"
  url = "postgres://localhost/test"
}

model BlogPost {
  id String @id
}
"#;

        let ast = parse(source).expect("parse");
        let schema = validate(&ast).expect("validate");
        assert_eq!(schema.models[0].db_name, "blog_posts");
    }

    #[test]
    fn test_validate_missing_datasource() {
        let source = r#"
model User {
  id String @id
}
"#;

        let err = validate_source(source).unwrap_err();
        assert!(matches!(err, CoreError::Validation { .. }));
    }

    #[test]
    fn test_validate_unknown_provider() {
        let source = r#"
datasource db {
  provider = "oracle"
  url = "oracle://localhost/test"
}

model User {
  id String @id
}
"#;

        let err = validate_source(source).unwrap_err();
        assert!(matches!(err, CoreError::UnknownProvider { .. }));
    }

    #[test]
    fn test_validate_optional_id_rejected() {
        let source = r#"
datasource db {
  provider = "postgresql"
  url = "postgres://localhost/test"
}

model User {
  id String? @id
}
"#;

        let err = validate_source(source).unwrap_err();
        assert!(matches!(err, CoreError::Validation { .. }));
    }

    #[test]
    fn test_validate_autoincrement_requires_integer() {
        let rejected = r#"
datasource db {
  provider = "postgresql"
  url = "postgres://localhost/test"
}

model User {
  id String @id @default(autoincrement())
}
"#;
        let err = validate_source(rejected).unwrap_err();
        assert!(matches!(err, CoreError::InvalidDefault { .. }));

        let accepted = r#"
datasource db {
  provider = "postgresql"
  url = "postgres://localhost/test"
}

model User {
  id Int @id @default(autoincrement())
}
"#;
        validate_source(accepted).expect("validate");
    }

    #[test]
    fn test_validate_duplicate_table_name() {
        // `Account` defaults to table `accounts`, which `User` already maps to.
        let source = r#"
datasource db {
  provider = "postgresql"
  url = "postgres://localhost/test"
}

model User {
  id String @id

  @@map("accounts")
}

model Account {
  id String @id
}
"#;

        let err = validate_source(source).unwrap_err();
        assert!(matches!(err, CoreError::Validation { .. }));
    }

    #[test]
    fn test_validate_duplicate_model_name() {
        let source = r#"
datasource db {
  provider = "postgresql"
  url = "postgres://localhost/test"
}

model User {
  id String @id
}

model User {
  id String @id
}
"#;

        let err = validate_source(source).unwrap_err();
        assert!(matches!(err, CoreError::DuplicateName { .. }));
    }

    #[test]
    fn test_validate_model_named_like_enum() {
        let source = r#"
datasource db {
  provider = "postgresql"
  url = "postgres://localhost/test"
}

enum User {
  Member
}

model User {
  id String @id
}
"#;

        let err = validate_source(source).unwrap_err();
        assert!(matches!(err, CoreError::DuplicateName { .. }));
    }

    #[test]
    fn test_validate_rust_keyword_field_rejected() {
        let source = r#"
datasource db {
  provider = "postgresql"
  url = "postgres://localhost/test"
}

model User {
  id String @id
  loop String
}
"#;

        let err = validate_source(source).unwrap_err();
        assert!(matches!(err, CoreError::Validation { .. }));
    }
}