1use std::collections::HashSet;
17
18use ferriorm_core::ast;
19use ferriorm_core::error::CoreError;
20use ferriorm_core::schema::{
21 DatasourceConfig, Enum, Field, FieldKind, GeneratorConfig, Index, Model, PrimaryKey,
22 RelationType, ResolvedRelation, Schema, UniqueConstraint,
23};
24use ferriorm_core::types::{DatabaseProvider, ScalarType};
25use ferriorm_core::utils::to_snake_case;
26
27pub fn validate(ast: &ast::SchemaFile) -> Result<Schema, CoreError> {
34 let datasource = validate_datasource(ast)?;
35 let generators = validate_generators(ast)?;
36 let enums = validate_enums(ast)?;
37 let models = validate_models(ast, &enums)?;
38
39 Ok(Schema {
40 datasource,
41 generators,
42 enums,
43 models,
44 })
45}
46
47fn validate_datasource(ast: &ast::SchemaFile) -> Result<DatasourceConfig, CoreError> {
48 let ds = ast.datasource.as_ref().ok_or(CoreError::Validation {
49 message: "Missing datasource block".into(),
50 })?;
51
52 let provider =
53 ds.provider
54 .parse::<DatabaseProvider>()
55 .map_err(|_| CoreError::UnknownProvider {
56 provider: ds.provider.clone(),
57 })?;
58
59 let url = match &ds.url {
60 ast::StringOrEnv::Literal(s) => s.clone(),
61 ast::StringOrEnv::Env(var) => format!("${{env:{var}}}"),
62 };
63
64 Ok(DatasourceConfig {
65 name: ds.name.clone(),
66 provider,
67 url,
68 })
69}
70
71fn validate_generators(ast: &ast::SchemaFile) -> Result<Vec<GeneratorConfig>, CoreError> {
72 ast.generators
73 .iter()
74 .map(|g| {
75 Ok(GeneratorConfig {
76 name: g.name.clone(),
77 output: g.output.clone().unwrap_or_else(|| "./src/generated".into()),
78 })
79 })
80 .collect()
81}
82
83fn validate_enums(ast: &ast::SchemaFile) -> Result<Vec<Enum>, CoreError> {
84 let mut names = HashSet::new();
85 let mut result = Vec::new();
86
87 for e in &ast.enums {
88 if !names.insert(&e.name) {
89 return Err(CoreError::DuplicateName {
90 name: e.name.clone(),
91 kind: "enum",
92 });
93 }
94
95 result.push(Enum {
96 name: e.name.clone(),
97 db_name: e.db_name.clone().unwrap_or_else(|| to_snake_case(&e.name)),
98 variants: e.variants.clone(),
99 });
100 }
101
102 Ok(result)
103}
104
105fn validate_models(ast: &ast::SchemaFile, enums: &[Enum]) -> Result<Vec<Model>, CoreError> {
106 let enum_names: HashSet<&str> = enums.iter().map(|e| e.name.as_str()).collect();
107 let model_names: HashSet<&str> = ast.models.iter().map(|m| m.name.as_str()).collect();
108 let mut seen_names = HashSet::new();
109
110 let mut result = Vec::new();
111
112 for model_def in &ast.models {
113 if !seen_names.insert(&model_def.name) {
114 return Err(CoreError::DuplicateName {
115 name: model_def.name.clone(),
116 kind: "model",
117 });
118 }
119
120 if enum_names.contains(model_def.name.as_str()) {
122 return Err(CoreError::DuplicateName {
123 name: model_def.name.clone(),
124 kind: "model/enum",
125 });
126 }
127
128 let model = validate_model(model_def, &enum_names, &model_names)?;
129 result.push(model);
130 }
131
132 Ok(result)
133}
134
135fn validate_model(
136 model_def: &ast::ModelDef,
137 enum_names: &HashSet<&str>,
138 model_names: &HashSet<&str>,
139) -> Result<Model, CoreError> {
140 let db_name = model_def
142 .attributes
143 .iter()
144 .find_map(|a| match a {
145 ast::BlockAttribute::Map(name) => Some(name.clone()),
146 _ => None,
147 })
148 .unwrap_or_else(|| to_snake_case(&model_def.name) + "s");
149
150 let mut fields = Vec::new();
151 let mut has_id_field = false;
152
153 for field_def in &model_def.fields {
154 let field = validate_field(field_def, &model_def.name, enum_names, model_names)?;
155 if field.is_id {
156 has_id_field = true;
157 }
158 fields.push(field);
159 }
160
161 let composite_id: Option<Vec<String>> = model_def.attributes.iter().find_map(|a| match a {
163 ast::BlockAttribute::Id(fields) => Some(fields.clone()),
164 _ => None,
165 });
166
167 if !has_id_field && composite_id.is_none() {
168 return Err(CoreError::MissingPrimaryKey {
169 model_name: model_def.name.clone(),
170 });
171 }
172
173 let primary_key = if let Some(composite_fields) = composite_id {
174 PrimaryKey {
175 fields: composite_fields,
176 }
177 } else {
178 let id_fields: Vec<String> = fields
179 .iter()
180 .filter(|f| f.is_id)
181 .map(|f| f.name.clone())
182 .collect();
183 PrimaryKey { fields: id_fields }
184 };
185
186 let indexes = model_def
188 .attributes
189 .iter()
190 .filter_map(|a| match a {
191 ast::BlockAttribute::Index(fields) => Some(Index {
192 fields: fields.clone(),
193 }),
194 _ => None,
195 })
196 .collect();
197
198 let unique_constraints = model_def
200 .attributes
201 .iter()
202 .filter_map(|a| match a {
203 ast::BlockAttribute::Unique(fields) => Some(UniqueConstraint {
204 fields: fields.clone(),
205 }),
206 _ => None,
207 })
208 .collect();
209
210 Ok(Model {
211 name: model_def.name.clone(),
212 db_name,
213 fields,
214 primary_key,
215 indexes,
216 unique_constraints,
217 })
218}
219
220fn validate_field(
221 field_def: &ast::FieldDef,
222 model_name: &str,
223 enum_names: &HashSet<&str>,
224 model_names: &HashSet<&str>,
225) -> Result<Field, CoreError> {
226 let type_name = &field_def.field_type.name;
227
228 let field_type = if let Ok(scalar) = type_name.parse::<ScalarType>() {
229 FieldKind::Scalar(scalar)
230 } else if enum_names.contains(type_name.as_str()) {
231 FieldKind::Enum(type_name.clone())
232 } else if model_names.contains(type_name.as_str()) {
233 FieldKind::Model(type_name.clone())
234 } else {
235 return Err(CoreError::UnknownType {
236 model_name: model_name.to_string(),
237 field_name: field_def.name.clone(),
238 type_name: type_name.clone(),
239 });
240 };
241
242 let is_id = field_def
243 .attributes
244 .iter()
245 .any(|a| matches!(a, ast::FieldAttribute::Id));
246 let is_unique = field_def
247 .attributes
248 .iter()
249 .any(|a| matches!(a, ast::FieldAttribute::Unique));
250 let is_updated_at = field_def
251 .attributes
252 .iter()
253 .any(|a| matches!(a, ast::FieldAttribute::UpdatedAt));
254 let default = field_def.attributes.iter().find_map(|a| match a {
255 ast::FieldAttribute::Default(d) => Some(d.clone()),
256 _ => None,
257 });
258
259 let db_name = field_def
261 .attributes
262 .iter()
263 .find_map(|a| match a {
264 ast::FieldAttribute::Map(name) => Some(name.clone()),
265 _ => None,
266 })
267 .unwrap_or_else(|| to_snake_case(&field_def.name));
268
269 let relation = field_def.attributes.iter().find_map(|a| match a {
271 ast::FieldAttribute::Relation(rel) => {
272 let relation_type = if field_def.field_type.is_list {
273 RelationType::OneToMany
274 } else if field_def.field_type.is_optional {
275 RelationType::OneToOne
276 } else {
277 RelationType::ManyToOne
278 };
279
280 Some(ResolvedRelation {
281 related_model: type_name.clone(),
282 relation_type,
283 fields: rel.fields.clone(),
284 references: rel.references.clone(),
285 on_delete: rel.on_delete.unwrap_or(ast::ReferentialAction::Restrict),
286 on_update: rel.on_update.unwrap_or(ast::ReferentialAction::Cascade),
287 })
288 }
289 _ => None,
290 });
291
292 Ok(Field {
293 name: field_def.name.clone(),
294 db_name,
295 field_type,
296 is_optional: field_def.field_type.is_optional,
297 is_list: field_def.field_type.is_list,
298 is_id,
299 is_unique,
300 is_updated_at,
301 default,
302 relation,
303 })
304}
305
#[cfg(test)]
#[allow(clippy::pedantic)]
mod tests {
    use super::*;
    use crate::parser::parse;
    use ferriorm_core::utils::to_snake_case;

    /// Parses `source` (panicking with "parse" on failure) and runs
    /// `validate` on the resulting AST.
    fn validate_source(source: &str) -> Result<Schema, CoreError> {
        let ast = parse(source).expect("parse");
        validate(&ast)
    }

    /// A full schema with datasource, generator, enum and model resolves as expected.
    #[test]
    fn test_validate_basic_schema() {
        let schema_text = r#"
datasource db {
  provider = "postgresql"
  url = env("DATABASE_URL")
}

generator client {
  output = "./src/generated"
}

enum Role {
  User
  Admin
}

model User {
  id String @id @default(uuid())
  email String @unique
  name String?
  role Role @default(User)

  @@map("users")
}
"#;

        let schema = validate_source(schema_text).expect("validate");

        assert_eq!(schema.datasource.provider, DatabaseProvider::PostgreSQL);
        assert_eq!(schema.enums.len(), 1);
        let role_enum = &schema.enums[0];
        assert_eq!(role_enum.name, "Role");
        assert_eq!(role_enum.db_name, "role");

        let user_model = &schema.models[0];
        assert_eq!(user_model.name, "User");
        assert_eq!(user_model.db_name, "users");
        assert_eq!(user_model.primary_key.fields, vec!["id"]);

        let id = &user_model.fields[0];
        assert!(id.is_id);
        assert_eq!(id.field_type, FieldKind::Scalar(ScalarType::String));

        let name = &user_model.fields[2];
        assert!(name.is_optional);
        assert_eq!(name.db_name, "name");

        let role = &user_model.fields[3];
        assert_eq!(role.field_type, FieldKind::Enum("Role".into()));
    }

    /// A model without `@id` or `@@id` is rejected.
    #[test]
    fn test_validate_missing_primary_key() {
        let schema_text = r#"
datasource db {
  provider = "postgresql"
  url = "postgres://localhost/test"
}

model User {
  email String
  name String
}
"#;

        let failure = validate_source(schema_text).unwrap_err();
        assert!(matches!(failure, CoreError::MissingPrimaryKey { .. }));
    }

    /// A field whose type is not a scalar, enum or model is rejected.
    #[test]
    fn test_validate_unknown_type() {
        let schema_text = r#"
datasource db {
  provider = "postgresql"
  url = "postgres://localhost/test"
}

model User {
  id String @id
  role Nonexistent
}
"#;

        let failure = validate_source(schema_text).unwrap_err();
        assert!(matches!(failure, CoreError::UnknownType { .. }));
    }

    /// `@@id([...])` produces a composite primary key in the listed order.
    #[test]
    fn test_validate_composite_primary_key() {
        let schema_text = r#"
datasource db {
  provider = "sqlite"
  url = "file:./dev.db"
}

model PostTag {
  postId String
  tagId String

  @@id([postId, tagId])
}
"#;

        let schema = validate_source(schema_text).expect("validate");
        let post_tag = &schema.models[0];
        assert_eq!(post_tag.primary_key.fields, vec!["postId", "tagId"]);
        assert!(post_tag.primary_key.is_composite());
    }

    /// Pins the current `to_snake_case` output, including the
    /// letter-by-letter split of an all-caps prefix.
    #[test]
    fn test_snake_case() {
        let cases = [
            ("User", "user"),
            ("PostTag", "post_tag"),
            ("createdAt", "created_at"),
            ("HTMLParser", "h_t_m_l_parser"),
        ];

        for (input, expected) in cases {
            assert_eq!(to_snake_case(input), expected);
        }
    }

    /// Without `@@map`, the table name is the snake_case model name plus `s`.
    #[test]
    fn test_validate_auto_table_name() {
        let schema_text = r#"
datasource db {
  provider = "postgresql"
  url = "postgres://localhost/test"
}

model BlogPost {
  id String @id
}
"#;

        let schema = validate_source(schema_text).expect("validate");
        assert_eq!(schema.models[0].db_name, "blog_posts");
    }
}