1use std::collections::HashSet;
17
18use ferriorm_core::ast;
19use ferriorm_core::error::CoreError;
20use ferriorm_core::schema::*;
21use ferriorm_core::types::{DatabaseProvider, ScalarType};
22use ferriorm_core::utils::to_snake_case;
23
24pub fn validate(ast: &ast::SchemaFile) -> Result<Schema, CoreError> {
26 let datasource = validate_datasource(ast)?;
27 let generators = validate_generators(ast)?;
28 let enums = validate_enums(ast)?;
29 let models = validate_models(ast, &enums)?;
30
31 Ok(Schema {
32 datasource,
33 generators,
34 enums,
35 models,
36 })
37}
38
39fn validate_datasource(ast: &ast::SchemaFile) -> Result<DatasourceConfig, CoreError> {
40 let ds = ast.datasource.as_ref().ok_or(CoreError::Validation {
41 message: "Missing datasource block".into(),
42 })?;
43
44 let provider =
45 ds.provider
46 .parse::<DatabaseProvider>()
47 .map_err(|_| CoreError::UnknownProvider {
48 provider: ds.provider.clone(),
49 })?;
50
51 let url = match &ds.url {
52 ast::StringOrEnv::Literal(s) => s.clone(),
53 ast::StringOrEnv::Env(var) => format!("${{env:{var}}}"),
54 };
55
56 Ok(DatasourceConfig {
57 name: ds.name.clone(),
58 provider,
59 url,
60 })
61}
62
63fn validate_generators(ast: &ast::SchemaFile) -> Result<Vec<GeneratorConfig>, CoreError> {
64 ast.generators
65 .iter()
66 .map(|g| {
67 Ok(GeneratorConfig {
68 name: g.name.clone(),
69 output: g.output.clone().unwrap_or_else(|| "./src/generated".into()),
70 })
71 })
72 .collect()
73}
74
75fn validate_enums(ast: &ast::SchemaFile) -> Result<Vec<Enum>, CoreError> {
76 let mut names = HashSet::new();
77 let mut result = Vec::new();
78
79 for e in &ast.enums {
80 if !names.insert(&e.name) {
81 return Err(CoreError::DuplicateName {
82 name: e.name.clone(),
83 kind: "enum",
84 });
85 }
86
87 result.push(Enum {
88 name: e.name.clone(),
89 db_name: e.db_name.clone().unwrap_or_else(|| to_snake_case(&e.name)),
90 variants: e.variants.clone(),
91 });
92 }
93
94 Ok(result)
95}
96
97fn validate_models(ast: &ast::SchemaFile, enums: &[Enum]) -> Result<Vec<Model>, CoreError> {
98 let enum_names: HashSet<&str> = enums.iter().map(|e| e.name.as_str()).collect();
99 let model_names: HashSet<&str> = ast.models.iter().map(|m| m.name.as_str()).collect();
100 let mut seen_names = HashSet::new();
101
102 let mut result = Vec::new();
103
104 for model_def in &ast.models {
105 if !seen_names.insert(&model_def.name) {
106 return Err(CoreError::DuplicateName {
107 name: model_def.name.clone(),
108 kind: "model",
109 });
110 }
111
112 if enum_names.contains(model_def.name.as_str()) {
114 return Err(CoreError::DuplicateName {
115 name: model_def.name.clone(),
116 kind: "model/enum",
117 });
118 }
119
120 let model = validate_model(model_def, &enum_names, &model_names)?;
121 result.push(model);
122 }
123
124 Ok(result)
125}
126
127fn validate_model(
128 model_def: &ast::ModelDef,
129 enum_names: &HashSet<&str>,
130 model_names: &HashSet<&str>,
131) -> Result<Model, CoreError> {
132 let db_name = model_def
134 .attributes
135 .iter()
136 .find_map(|a| match a {
137 ast::BlockAttribute::Map(name) => Some(name.clone()),
138 _ => None,
139 })
140 .unwrap_or_else(|| to_snake_case(&model_def.name) + "s");
141
142 let mut fields = Vec::new();
143 let mut has_id_field = false;
144
145 for field_def in &model_def.fields {
146 let field = validate_field(field_def, &model_def.name, enum_names, model_names)?;
147 if field.is_id {
148 has_id_field = true;
149 }
150 fields.push(field);
151 }
152
153 let composite_id: Option<Vec<String>> = model_def.attributes.iter().find_map(|a| match a {
155 ast::BlockAttribute::Id(fields) => Some(fields.clone()),
156 _ => None,
157 });
158
159 if !has_id_field && composite_id.is_none() {
160 return Err(CoreError::MissingPrimaryKey {
161 model_name: model_def.name.clone(),
162 });
163 }
164
165 let primary_key = if let Some(composite_fields) = composite_id {
166 PrimaryKey {
167 fields: composite_fields,
168 }
169 } else {
170 let id_fields: Vec<String> = fields
171 .iter()
172 .filter(|f| f.is_id)
173 .map(|f| f.name.clone())
174 .collect();
175 PrimaryKey { fields: id_fields }
176 };
177
178 let indexes = model_def
180 .attributes
181 .iter()
182 .filter_map(|a| match a {
183 ast::BlockAttribute::Index(fields) => Some(Index {
184 fields: fields.clone(),
185 }),
186 _ => None,
187 })
188 .collect();
189
190 let unique_constraints = model_def
192 .attributes
193 .iter()
194 .filter_map(|a| match a {
195 ast::BlockAttribute::Unique(fields) => Some(UniqueConstraint {
196 fields: fields.clone(),
197 }),
198 _ => None,
199 })
200 .collect();
201
202 Ok(Model {
203 name: model_def.name.clone(),
204 db_name,
205 fields,
206 primary_key,
207 indexes,
208 unique_constraints,
209 })
210}
211
212fn validate_field(
213 field_def: &ast::FieldDef,
214 model_name: &str,
215 enum_names: &HashSet<&str>,
216 model_names: &HashSet<&str>,
217) -> Result<Field, CoreError> {
218 let type_name = &field_def.field_type.name;
219
220 let field_type = if let Ok(scalar) = type_name.parse::<ScalarType>() {
221 FieldKind::Scalar(scalar)
222 } else if enum_names.contains(type_name.as_str()) {
223 FieldKind::Enum(type_name.clone())
224 } else if model_names.contains(type_name.as_str()) {
225 FieldKind::Model(type_name.clone())
226 } else {
227 return Err(CoreError::UnknownType {
228 model_name: model_name.to_string(),
229 field_name: field_def.name.clone(),
230 type_name: type_name.clone(),
231 });
232 };
233
234 let is_id = field_def
235 .attributes
236 .iter()
237 .any(|a| matches!(a, ast::FieldAttribute::Id));
238 let is_unique = field_def
239 .attributes
240 .iter()
241 .any(|a| matches!(a, ast::FieldAttribute::Unique));
242 let is_updated_at = field_def
243 .attributes
244 .iter()
245 .any(|a| matches!(a, ast::FieldAttribute::UpdatedAt));
246 let default = field_def.attributes.iter().find_map(|a| match a {
247 ast::FieldAttribute::Default(d) => Some(d.clone()),
248 _ => None,
249 });
250
251 let db_name = field_def
253 .attributes
254 .iter()
255 .find_map(|a| match a {
256 ast::FieldAttribute::Map(name) => Some(name.clone()),
257 _ => None,
258 })
259 .unwrap_or_else(|| to_snake_case(&field_def.name));
260
261 let relation = field_def.attributes.iter().find_map(|a| match a {
263 ast::FieldAttribute::Relation(rel) => {
264 let relation_type = if field_def.field_type.is_list {
265 RelationType::OneToMany
266 } else if field_def.field_type.is_optional {
267 RelationType::OneToOne
268 } else {
269 RelationType::ManyToOne
270 };
271
272 Some(ResolvedRelation {
273 related_model: type_name.clone(),
274 relation_type,
275 fields: rel.fields.clone(),
276 references: rel.references.clone(),
277 on_delete: rel.on_delete.unwrap_or(ast::ReferentialAction::Restrict),
278 on_update: rel.on_update.unwrap_or(ast::ReferentialAction::Cascade),
279 })
280 }
281 _ => None,
282 });
283
284 Ok(Field {
285 name: field_def.name.clone(),
286 db_name,
287 field_type,
288 is_optional: field_def.field_type.is_optional,
289 is_list: field_def.field_type.is_list,
290 is_id,
291 is_unique,
292 is_updated_at,
293 default,
294 relation,
295 })
296}
297
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse;
    use ferriorm_core::utils::to_snake_case;

    /// Parses `source` (panicking if it is not syntactically valid) and validates it.
    fn parse_and_validate(source: &str) -> Result<Schema, CoreError> {
        let ast = parse(source).expect("parse");
        validate(&ast)
    }

    #[test]
    fn test_validate_basic_schema() {
        let schema = parse_and_validate(
            r#"
datasource db {
    provider = "postgresql"
    url      = env("DATABASE_URL")
}

generator client {
    output = "./src/generated"
}

enum Role {
    User
    Admin
}

model User {
    id    String  @id @default(uuid())
    email String  @unique
    name  String?
    role  Role    @default(User)

    @@map("users")
}
"#,
        )
        .expect("validate");

        assert_eq!(schema.datasource.provider, DatabaseProvider::PostgreSQL);

        assert_eq!(schema.enums.len(), 1);
        let role_enum = &schema.enums[0];
        assert_eq!(role_enum.name, "Role");
        assert_eq!(role_enum.db_name, "role");

        let user = &schema.models[0];
        assert_eq!(user.name, "User");
        assert_eq!(user.db_name, "users");
        assert_eq!(user.primary_key.fields, vec!["id"]);

        let fields = &user.fields;
        assert!(fields[0].is_id);
        assert_eq!(fields[0].field_type, FieldKind::Scalar(ScalarType::String));

        assert!(fields[2].is_optional);
        assert_eq!(fields[2].db_name, "name");

        assert_eq!(fields[3].field_type, FieldKind::Enum("Role".into()));
    }

    #[test]
    fn test_validate_missing_primary_key() {
        let err = parse_and_validate(
            r#"
datasource db {
    provider = "postgresql"
    url      = "postgres://localhost/test"
}

model User {
    email String
    name  String
}
"#,
        )
        .unwrap_err();

        assert!(matches!(err, CoreError::MissingPrimaryKey { .. }));
    }

    #[test]
    fn test_validate_unknown_type() {
        let err = parse_and_validate(
            r#"
datasource db {
    provider = "postgresql"
    url      = "postgres://localhost/test"
}

model User {
    id   String @id
    role Nonexistent
}
"#,
        )
        .unwrap_err();

        assert!(matches!(err, CoreError::UnknownType { .. }));
    }

    #[test]
    fn test_validate_composite_primary_key() {
        let schema = parse_and_validate(
            r#"
datasource db {
    provider = "sqlite"
    url      = "file:./dev.db"
}

model PostTag {
    postId String
    tagId  String

    @@id([postId, tagId])
}
"#,
        )
        .expect("validate");

        let post_tag = &schema.models[0];
        assert_eq!(post_tag.primary_key.fields, vec!["postId", "tagId"]);
        assert!(post_tag.primary_key.is_composite());
    }

    /// Pins the current `to_snake_case` behavior, including splitting acronyms
    /// one capital at a time (`HTMLParser` → `h_t_m_l_parser`).
    #[test]
    fn test_snake_case() {
        let cases = [
            ("User", "user"),
            ("PostTag", "post_tag"),
            ("createdAt", "created_at"),
            ("HTMLParser", "h_t_m_l_parser"),
        ];
        for (input, expected) in cases {
            assert_eq!(to_snake_case(input), expected, "to_snake_case({input:?})");
        }
    }

    #[test]
    fn test_validate_auto_table_name() {
        let schema = parse_and_validate(
            r#"
datasource db {
    provider = "postgresql"
    url      = "postgres://localhost/test"
}

model BlogPost {
    id String @id
}
"#,
        )
        .expect("validate");

        assert_eq!(schema.models[0].db_name, "blog_posts");
    }
}