1use std::collections::HashSet;
17
18use ferriorm_core::ast;
19use ferriorm_core::error::CoreError;
20use ferriorm_core::schema::{
21 DatasourceConfig, Enum, Field, FieldKind, GeneratorConfig, Index, Model, PrimaryKey,
22 RelationType, ResolvedRelation, Schema, UniqueConstraint,
23};
24use ferriorm_core::types::{DatabaseProvider, ScalarType};
25use ferriorm_core::utils::to_snake_case;
26
27pub fn validate(ast: &ast::SchemaFile) -> Result<Schema, CoreError> {
34 let datasource = validate_datasource(ast)?;
35 let generators = validate_generators(ast)?;
36 let enums = validate_enums(ast)?;
37 let models = validate_models(ast, &enums)?;
38
39 Ok(Schema {
40 datasource,
41 generators,
42 enums,
43 models,
44 })
45}
46
47fn validate_datasource(ast: &ast::SchemaFile) -> Result<DatasourceConfig, CoreError> {
48 let ds = ast.datasource.as_ref().ok_or(CoreError::Validation {
49 message: "Missing datasource block".into(),
50 })?;
51
52 let provider =
53 ds.provider
54 .parse::<DatabaseProvider>()
55 .map_err(|_| CoreError::UnknownProvider {
56 provider: ds.provider.clone(),
57 })?;
58
59 let url = match &ds.url {
60 ast::StringOrEnv::Literal(s) => s.clone(),
61 ast::StringOrEnv::Env(var) => format!("${{env:{var}}}"),
62 };
63
64 Ok(DatasourceConfig {
65 name: ds.name.clone(),
66 provider,
67 url,
68 })
69}
70
71fn validate_generators(ast: &ast::SchemaFile) -> Result<Vec<GeneratorConfig>, CoreError> {
72 ast.generators
73 .iter()
74 .map(|g| {
75 Ok(GeneratorConfig {
76 name: g.name.clone(),
77 output: g.output.clone().unwrap_or_else(|| "./src/generated".into()),
78 })
79 })
80 .collect()
81}
82
83fn validate_enums(ast: &ast::SchemaFile) -> Result<Vec<Enum>, CoreError> {
84 let mut names = HashSet::new();
85 let mut result = Vec::new();
86
87 for e in &ast.enums {
88 if !names.insert(&e.name) {
89 return Err(CoreError::DuplicateName {
90 name: e.name.clone(),
91 kind: "enum",
92 });
93 }
94
95 result.push(Enum {
96 name: e.name.clone(),
97 db_name: e.db_name.clone().unwrap_or_else(|| to_snake_case(&e.name)),
98 variants: e.variants.clone(),
99 });
100 }
101
102 Ok(result)
103}
104
105fn validate_models(ast: &ast::SchemaFile, enums: &[Enum]) -> Result<Vec<Model>, CoreError> {
106 let enum_names: HashSet<&str> = enums.iter().map(|e| e.name.as_str()).collect();
107 let model_names: HashSet<&str> = ast.models.iter().map(|m| m.name.as_str()).collect();
108 let mut seen_names = HashSet::new();
109
110 let mut result = Vec::new();
111
112 for model_def in &ast.models {
113 if !seen_names.insert(&model_def.name) {
114 return Err(CoreError::DuplicateName {
115 name: model_def.name.clone(),
116 kind: "model",
117 });
118 }
119
120 if enum_names.contains(model_def.name.as_str()) {
122 return Err(CoreError::DuplicateName {
123 name: model_def.name.clone(),
124 kind: "model/enum",
125 });
126 }
127
128 let model = validate_model(model_def, &enum_names, &model_names)?;
129 result.push(model);
130 }
131
132 Ok(result)
133}
134
135fn validate_model(
136 model_def: &ast::ModelDef,
137 enum_names: &HashSet<&str>,
138 model_names: &HashSet<&str>,
139) -> Result<Model, CoreError> {
140 let db_name = model_def
142 .attributes
143 .iter()
144 .find_map(|a| match a {
145 ast::BlockAttribute::Map(name) => Some(name.clone()),
146 _ => None,
147 })
148 .unwrap_or_else(|| to_snake_case(&model_def.name) + "s");
149
150 let mut fields = Vec::new();
151 let mut has_id_field = false;
152
153 for field_def in &model_def.fields {
154 let field = validate_field(field_def, &model_def.name, enum_names, model_names)?;
155 if field.is_id {
156 has_id_field = true;
157 }
158 fields.push(field);
159 }
160
161 let composite_id: Option<Vec<String>> = model_def.attributes.iter().find_map(|a| match a {
163 ast::BlockAttribute::Id(fields) => Some(fields.clone()),
164 _ => None,
165 });
166
167 if !has_id_field && composite_id.is_none() {
168 return Err(CoreError::MissingPrimaryKey {
169 model_name: model_def.name.clone(),
170 });
171 }
172
173 let primary_key = if let Some(composite_fields) = composite_id {
174 PrimaryKey {
175 fields: composite_fields,
176 }
177 } else {
178 let id_fields: Vec<String> = fields
179 .iter()
180 .filter(|f| f.is_id)
181 .map(|f| f.name.clone())
182 .collect();
183 PrimaryKey { fields: id_fields }
184 };
185
186 let indexes = model_def
188 .attributes
189 .iter()
190 .filter_map(|a| match a {
191 ast::BlockAttribute::Index(fields) => Some(Index {
192 fields: fields.clone(),
193 }),
194 _ => None,
195 })
196 .collect();
197
198 let unique_constraints = model_def
200 .attributes
201 .iter()
202 .filter_map(|a| match a {
203 ast::BlockAttribute::Unique(fields) => Some(UniqueConstraint {
204 fields: fields.clone(),
205 }),
206 _ => None,
207 })
208 .collect();
209
210 Ok(Model {
211 name: model_def.name.clone(),
212 db_name,
213 fields,
214 primary_key,
215 indexes,
216 unique_constraints,
217 })
218}
219
220fn validate_field(
221 field_def: &ast::FieldDef,
222 model_name: &str,
223 enum_names: &HashSet<&str>,
224 model_names: &HashSet<&str>,
225) -> Result<Field, CoreError> {
226 let type_name = &field_def.field_type.name;
227
228 let field_type = if let Ok(scalar) = type_name.parse::<ScalarType>() {
229 FieldKind::Scalar(scalar)
230 } else if enum_names.contains(type_name.as_str()) {
231 FieldKind::Enum(type_name.clone())
232 } else if model_names.contains(type_name.as_str()) {
233 FieldKind::Model(type_name.clone())
234 } else {
235 return Err(CoreError::UnknownType {
236 model_name: model_name.to_string(),
237 field_name: field_def.name.clone(),
238 type_name: type_name.clone(),
239 });
240 };
241
242 let is_id = field_def
243 .attributes
244 .iter()
245 .any(|a| matches!(a, ast::FieldAttribute::Id));
246 let is_unique = field_def
247 .attributes
248 .iter()
249 .any(|a| matches!(a, ast::FieldAttribute::Unique));
250 let is_updated_at = field_def
251 .attributes
252 .iter()
253 .any(|a| matches!(a, ast::FieldAttribute::UpdatedAt));
254 let default = field_def.attributes.iter().find_map(|a| match a {
255 ast::FieldAttribute::Default(d) => Some(d.clone()),
256 _ => None,
257 });
258
259 let db_name = field_def
261 .attributes
262 .iter()
263 .find_map(|a| match a {
264 ast::FieldAttribute::Map(name) => Some(name.clone()),
265 _ => None,
266 })
267 .unwrap_or_else(|| to_snake_case(&field_def.name));
268
269 let relation = field_def.attributes.iter().find_map(|a| match a {
271 ast::FieldAttribute::Relation(rel) => {
272 let relation_type = if field_def.field_type.is_list {
273 RelationType::OneToMany
274 } else if field_def.field_type.is_optional {
275 RelationType::OneToOne
276 } else {
277 RelationType::ManyToOne
278 };
279
280 Some(ResolvedRelation {
281 related_model: type_name.clone(),
282 relation_type,
283 fields: rel.fields.clone(),
284 references: rel.references.clone(),
285 on_delete: rel.on_delete.unwrap_or(ast::ReferentialAction::Restrict),
286 on_update: rel.on_update.unwrap_or(ast::ReferentialAction::Cascade),
287 })
288 }
289 _ => None,
290 });
291
292 let db_type = field_def.attributes.iter().find_map(|a| match a {
294 ast::FieldAttribute::DbType(ty, args) => Some((ty.clone(), args.clone())),
295 _ => None,
296 });
297
298 Ok(Field {
299 name: field_def.name.clone(),
300 db_name,
301 field_type,
302 is_optional: field_def.field_type.is_optional,
303 is_list: field_def.field_type.is_list,
304 is_id,
305 is_unique,
306 is_updated_at,
307 default,
308 relation,
309 db_type,
310 })
311}
312
#[cfg(test)]
// NOTE(review): pedantic lints are relaxed for the whole test module; consider
// stating which lint motivated this.
#[allow(clippy::pedantic)]
mod tests {
    // Each test runs schema source text through `crate::parser::parse` and then
    // `validate`, asserting on the resolved `Schema` or on the returned error.
    use super::*;
    use crate::parser::parse;
    use ferriorm_core::utils::to_snake_case;

    // End-to-end happy path. Covers: env-based url, a generator, an enum whose
    // db_name defaults to snake_case, `@@map` overriding the table name, a
    // single-field `@id` primary key, an optional field, and an enum-typed field.
    #[test]
    fn test_validate_basic_schema() {
        let source = r#"
datasource db {
    provider = "postgresql"
    url = env("DATABASE_URL")
}

generator client {
    output = "./src/generated"
}

enum Role {
    User
    Admin
}

model User {
    id String @id @default(uuid())
    email String @unique
    name String?
    role Role @default(User)

    @@map("users")
}
"#;

        let ast = parse(source).expect("parse");
        let schema = validate(&ast).expect("validate");

        assert_eq!(schema.datasource.provider, DatabaseProvider::PostgreSQL);
        assert_eq!(schema.enums.len(), 1);
        assert_eq!(schema.enums[0].name, "Role");
        assert_eq!(schema.enums[0].db_name, "role");

        let user = &schema.models[0];
        assert_eq!(user.name, "User");
        assert_eq!(user.db_name, "users");
        assert_eq!(user.primary_key.fields, vec!["id"]);

        let id_field = &user.fields[0];
        assert!(id_field.is_id);
        assert_eq!(id_field.field_type, FieldKind::Scalar(ScalarType::String));

        let name_field = &user.fields[2];
        assert!(name_field.is_optional);
        assert_eq!(name_field.db_name, "name");

        let role_field = &user.fields[3];
        assert_eq!(role_field.field_type, FieldKind::Enum("Role".into()));
    }

    // A model with neither an `@id` field nor an `@@id` block attribute is rejected.
    #[test]
    fn test_validate_missing_primary_key() {
        let source = r#"
datasource db {
    provider = "postgresql"
    url = "postgres://localhost/test"
}

model User {
    email String
    name String
}
"#;

        let ast = parse(source).expect("parse");
        let err = validate(&ast).unwrap_err();
        assert!(matches!(err, CoreError::MissingPrimaryKey { .. }));
    }

    // A field type that is not a scalar, a declared enum, or a declared model is rejected.
    #[test]
    fn test_validate_unknown_type() {
        let source = r#"
datasource db {
    provider = "postgresql"
    url = "postgres://localhost/test"
}

model User {
    id String @id
    role Nonexistent
}
"#;

        let ast = parse(source).expect("parse");
        let err = validate(&ast).unwrap_err();
        assert!(matches!(err, CoreError::UnknownType { .. }));
    }

    // `@@id([...])` yields a multi-column primary key in the listed order.
    #[test]
    fn test_validate_composite_primary_key() {
        let source = r#"
datasource db {
    provider = "sqlite"
    url = "file:./dev.db"
}

model PostTag {
    postId String
    tagId String

    @@id([postId, tagId])
}
"#;

        let ast = parse(source).expect("parse");
        let schema = validate(&ast).expect("validate");
        let model = &schema.models[0];
        assert_eq!(model.primary_key.fields, vec!["postId", "tagId"]);
        assert!(model.primary_key.is_composite());
    }

    // Pins `to_snake_case` behaviour that the validator relies on for default names.
    // Note that each capital in an acronym gets its own underscore ("HTMLParser").
    #[test]
    fn test_snake_case() {
        assert_eq!(to_snake_case("User"), "user");
        assert_eq!(to_snake_case("PostTag"), "post_tag");
        assert_eq!(to_snake_case("createdAt"), "created_at");
        assert_eq!(to_snake_case("HTMLParser"), "h_t_m_l_parser");
    }

    // Without `@@map`, the table name is the snake_cased model name plus "s".
    #[test]
    fn test_validate_auto_table_name() {
        let source = r#"
datasource db {
    provider = "postgresql"
    url = "postgres://localhost/test"
}

model BlogPost {
    id String @id
}
"#;

        let ast = parse(source).expect("parse");
        let schema = validate(&ast).expect("validate");
        assert_eq!(schema.models[0].db_name, "blog_posts");
    }
}