1use ferriorm_core::ast::{
11 BlockAttribute, DefaultValue, EnumDef, FieldAttribute, FieldDef, FieldType, Generator,
12 IndexAttribute, LiteralValue, ModelDef, ReferentialAction, RelationAttribute, SchemaFile, Span,
13 StringOrEnv,
14};
15use pest::Parser;
16use pest_derive::Parser;
17
18use crate::error::ParseError;
19
/// Parser generated by `pest_derive` from the `grammar.pest` grammar.
///
/// The derive also generates the `Rule` enum used throughout this module
/// (`Rule::schema`, `Rule::model_block`, `Rule::field_def`, ...).
#[derive(Parser)]
#[grammar = "grammar.pest"]
struct FerriormParser;
23
24pub fn parse(source: &str) -> Result<SchemaFile, ParseError> {
35 let pairs = FerriormParser::parse(Rule::schema, source)
36 .map_err(|e| ParseError::Syntax(e.to_string()))?;
37
38 let mut schema = SchemaFile {
39 datasource: None,
40 generators: Vec::new(),
41 enums: Vec::new(),
42 models: Vec::new(),
43 };
44
45 let schema_pair = pairs.into_iter().next().unwrap();
47 for pair in schema_pair.into_inner() {
48 match pair.as_rule() {
49 Rule::datasource_block => {
50 schema.datasource = Some(parse_datasource(pair)?);
51 }
52 Rule::generator_block => {
53 schema.generators.push(parse_generator(pair));
54 }
55 Rule::enum_block => {
56 schema.enums.push(parse_enum(pair));
57 }
58 Rule::model_block => {
59 schema.models.push(parse_model(pair)?);
60 }
61 _ => {}
62 }
63 }
64
65 Ok(schema)
66}
67
68fn span_from(pair: &pest::iterators::Pair<'_, Rule>) -> Span {
69 let span = pair.as_span();
70 Span {
71 start: span.start(),
72 end: span.end(),
73 }
74}
75
76fn parse_datasource(
77 pair: pest::iterators::Pair<'_, Rule>,
78) -> Result<ferriorm_core::ast::Datasource, ParseError> {
79 let span = span_from(&pair);
80 let mut inner = pair.into_inner();
81 let name = inner.next().unwrap().as_str().to_string();
82
83 let mut provider = String::new();
84 let mut url = StringOrEnv::Literal(String::new());
85
86 for kv in inner {
87 if kv.as_rule() != Rule::kv_pair {
88 continue;
89 }
90 let mut kv_inner = kv.into_inner();
91 let key = kv_inner.next().unwrap().as_str();
92 let value_pair = kv_inner.next().unwrap();
93
94 match key {
95 "provider" => {
96 provider = parse_string_value(&value_pair);
97 }
98 "url" => {
99 url = parse_string_or_env(&value_pair)?;
100 }
101 _ => {}
102 }
103 }
104
105 Ok(ferriorm_core::ast::Datasource {
106 name,
107 provider,
108 url,
109 span,
110 })
111}
112
113fn parse_generator(pair: pest::iterators::Pair<'_, Rule>) -> Generator {
114 let span = span_from(&pair);
115 let mut inner = pair.into_inner();
116 let name = inner.next().unwrap().as_str().to_string();
117
118 let mut output = None;
119
120 for kv in inner {
121 if kv.as_rule() != Rule::kv_pair {
122 continue;
123 }
124 let mut kv_inner = kv.into_inner();
125 let key = kv_inner.next().unwrap().as_str();
126 let value_pair = kv_inner.next().unwrap();
127
128 if key == "output" {
129 output = Some(parse_string_value(&value_pair));
130 }
131 }
132
133 Generator { name, output, span }
134}
135
136fn parse_enum(pair: pest::iterators::Pair<'_, Rule>) -> EnumDef {
137 let span = span_from(&pair);
138 let mut inner = pair.into_inner();
139 let name = inner.next().unwrap().as_str().to_string();
140
141 let mut variants = Vec::new();
142 let mut db_name = None;
143 for member in inner {
144 match member.as_rule() {
145 Rule::enum_variant => {
146 let variant_name = member.into_inner().next().unwrap().as_str().to_string();
147 variants.push(variant_name);
148 }
149 Rule::enum_block_attr_map => {
150 let s = member.into_inner().next().unwrap().as_str();
151 db_name = Some(unquote(s));
152 }
153 _ => {}
154 }
155 }
156
157 EnumDef {
158 name,
159 variants,
160 db_name,
161 span,
162 }
163}
164
165fn parse_model(pair: pest::iterators::Pair<'_, Rule>) -> Result<ModelDef, ParseError> {
166 let span = span_from(&pair);
167 let mut inner = pair.into_inner();
168 let name = inner.next().unwrap().as_str().to_string();
169
170 let mut fields = Vec::new();
171 let mut attributes = Vec::new();
172
173 for member in inner {
174 match member.as_rule() {
175 Rule::field_def => {
176 fields.push(parse_field(member)?);
177 }
178 Rule::block_attr_index => {
179 attributes.push(BlockAttribute::Index(parse_index_attribute(member)));
180 }
181 Rule::block_attr_unique => {
182 attributes.push(BlockAttribute::Unique(parse_index_attribute(member)));
183 }
184 Rule::block_attr_map => {
185 let s = member.into_inner().next().unwrap().as_str();
186 attributes.push(BlockAttribute::Map(unquote(s)));
187 }
188 Rule::block_attr_id => {
189 attributes.push(BlockAttribute::Id(parse_field_list_from_block_attr(member)));
190 }
191 _ => {}
192 }
193 }
194
195 Ok(ModelDef {
196 name,
197 fields,
198 attributes,
199 span,
200 })
201}
202
203fn parse_field(pair: pest::iterators::Pair<'_, Rule>) -> Result<FieldDef, ParseError> {
204 let span = span_from(&pair);
205 let mut inner = pair.into_inner();
206
207 let name = inner.next().unwrap().as_str().to_string();
208 let field_type_pair = inner.next().unwrap();
209 let field_type = parse_field_type(field_type_pair);
210
211 let mut attributes = Vec::new();
212 for attr_pair in inner {
213 if let Some(attr) = parse_field_attribute(attr_pair)? {
214 attributes.push(attr);
215 }
216 }
217
218 Ok(FieldDef {
219 name,
220 field_type,
221 attributes,
222 span,
223 })
224}
225
226fn parse_field_type(pair: pest::iterators::Pair<'_, Rule>) -> FieldType {
227 let mut inner = pair.into_inner();
228 let name = inner.next().unwrap().as_str().to_string();
229
230 let mut is_list = false;
231 let mut is_optional = false;
232
233 for modifier in inner {
234 match modifier.as_rule() {
235 Rule::list_modifier => is_list = true,
236 Rule::optional_modifier => is_optional = true,
237 _ => {}
238 }
239 }
240
241 FieldType {
242 name,
243 is_list,
244 is_optional,
245 }
246}
247
248fn parse_field_attribute(
249 pair: pest::iterators::Pair<'_, Rule>,
250) -> Result<Option<FieldAttribute>, ParseError> {
251 match pair.as_rule() {
252 Rule::attr_id => Ok(Some(FieldAttribute::Id)),
253 Rule::attr_unique => Ok(Some(FieldAttribute::Unique)),
254 Rule::attr_updated_at => Ok(Some(FieldAttribute::UpdatedAt)),
255 Rule::attr_default => {
256 let value_pair = pair.into_inner().next().unwrap();
257 let default = parse_default_value(value_pair)?;
258 Ok(Some(FieldAttribute::Default(default)))
259 }
260 Rule::attr_map => {
261 let s = pair.into_inner().next().unwrap().as_str();
262 Ok(Some(FieldAttribute::Map(unquote(s))))
263 }
264 Rule::attr_relation => {
265 let relation = parse_relation_attribute(pair);
266 Ok(Some(FieldAttribute::Relation(relation)))
267 }
268 Rule::attr_db_type => {
269 let mut inner = pair.into_inner();
270 let type_name = inner.next().unwrap().as_str().to_string();
271 let args: Vec<String> = inner.map(|p| parse_string_value(&p)).collect();
272 Ok(Some(FieldAttribute::DbType(type_name, args)))
273 }
274 _ => Ok(None),
275 }
276}
277
278fn parse_default_value(pair: pest::iterators::Pair<'_, Rule>) -> Result<DefaultValue, ParseError> {
279 match pair.as_rule() {
280 Rule::func_call => {
281 let mut inner = pair.into_inner();
282 let func_name = inner.next().unwrap().as_str();
283 match func_name {
284 "uuid" => Ok(DefaultValue::Uuid),
285 "cuid" => Ok(DefaultValue::Cuid),
286 "autoincrement" => Ok(DefaultValue::AutoIncrement),
287 "now" => Ok(DefaultValue::Now),
288 other => Err(ParseError::Syntax(format!(
289 "Unknown default function: {other}()"
290 ))),
291 }
292 }
293 Rule::string_literal => Ok(DefaultValue::Literal(LiteralValue::String(unquote(
294 pair.as_str(),
295 )))),
296 Rule::number_literal => {
297 let s = pair.as_str();
298 if s.contains('.') {
299 Ok(DefaultValue::Literal(LiteralValue::Float(
300 s.parse()
301 .map_err(|e| ParseError::Syntax(format!("Invalid float: {e}")))?,
302 )))
303 } else {
304 Ok(DefaultValue::Literal(LiteralValue::Int(
305 s.parse()
306 .map_err(|e| ParseError::Syntax(format!("Invalid int: {e}")))?,
307 )))
308 }
309 }
310 Rule::boolean_literal => {
311 let b = pair.as_str() == "true";
312 Ok(DefaultValue::Literal(LiteralValue::Bool(b)))
313 }
314 Rule::identifier_value => {
315 let name = pair.into_inner().next().unwrap().as_str().to_string();
316 Ok(DefaultValue::EnumVariant(name))
317 }
318 _ => Err(ParseError::Syntax(format!(
319 "Unexpected default value: {:?}",
320 pair.as_rule()
321 ))),
322 }
323}
324
325fn parse_relation_attribute(pair: pest::iterators::Pair<'_, Rule>) -> RelationAttribute {
326 let args_pair = pair.into_inner().next().unwrap(); let mut fields = Vec::new();
328 let mut references = Vec::new();
329 let mut on_delete = None;
330 let mut on_update = None;
331 let mut name = None;
332
333 for arg in args_pair.into_inner() {
334 match arg.as_rule() {
335 Rule::string_literal => {
337 name = Some(parse_string_value(&arg));
338 continue;
339 }
340 Rule::relation_arg | Rule::named_arg => {}
341 _ => continue,
342 }
343
344 let named_arg = if arg.as_rule() == Rule::relation_arg {
346 arg.into_inner().next().unwrap()
347 } else {
348 arg
349 };
350
351 let mut named = named_arg.into_inner();
353 let key = named.next().unwrap().as_str();
354 let value_pair = named.next().unwrap();
355
356 match key {
357 "fields" => fields = parse_field_list(&value_pair),
358 "references" => references = parse_field_list(&value_pair),
359 "onDelete" => on_delete = parse_referential_action(&value_pair),
360 "onUpdate" => on_update = parse_referential_action(&value_pair),
361 "name" => name = Some(parse_string_value(&value_pair)),
362 _ => {}
363 }
364 }
365
366 RelationAttribute {
367 name,
368 fields,
369 references,
370 on_delete,
371 on_update,
372 }
373}
374
375fn parse_referential_action(pair: &pest::iterators::Pair<'_, Rule>) -> Option<ReferentialAction> {
376 let s = pair.as_str().trim_matches('"');
377 match s {
378 "Cascade" => Some(ReferentialAction::Cascade),
379 "Restrict" => Some(ReferentialAction::Restrict),
380 "NoAction" => Some(ReferentialAction::NoAction),
381 "SetNull" => Some(ReferentialAction::SetNull),
382 "SetDefault" => Some(ReferentialAction::SetDefault),
383 _ => None,
384 }
385}
386
387fn parse_field_list(pair: &pest::iterators::Pair<'_, Rule>) -> Vec<String> {
388 pair.clone()
389 .into_inner()
390 .filter(|p| p.as_rule() == Rule::identifier)
391 .map(|p| p.as_str().to_string())
392 .collect()
393}
394
395fn parse_field_list_from_block_attr(pair: pest::iterators::Pair<'_, Rule>) -> Vec<String> {
396 let field_list = pair.into_inner().next().unwrap();
397 parse_field_list(&field_list)
398}
399
400fn parse_index_attribute(pair: pest::iterators::Pair<'_, Rule>) -> IndexAttribute {
404 let mut inner = pair.into_inner();
405 let field_list = inner.next().unwrap();
406 let fields = parse_field_list(&field_list);
407 let mut name = None;
408
409 for arg in inner {
410 if arg.as_rule() != Rule::named_arg {
411 continue;
412 }
413 let mut named = arg.into_inner();
414 let key = named.next().unwrap().as_str();
415 let value_pair = named.next().unwrap();
416 if key == "name" || key == "map" {
417 name = Some(parse_string_value(&value_pair));
418 }
419 }
420
421 IndexAttribute { fields, name }
422}
423
424fn parse_string_or_env(pair: &pest::iterators::Pair<'_, Rule>) -> Result<StringOrEnv, ParseError> {
425 match pair.as_rule() {
426 Rule::func_call => {
427 let mut inner = pair.clone().into_inner();
428 let func_name = inner.next().unwrap().as_str();
429 if func_name == "env" {
430 let arg = inner
431 .next()
432 .ok_or_else(|| ParseError::Syntax("env() requires a string argument".into()))?;
433 Ok(StringOrEnv::Env(unquote(arg.as_str())))
434 } else {
435 Err(ParseError::Syntax(format!(
436 "Expected env(), got {func_name}()"
437 )))
438 }
439 }
440 Rule::string_literal => Ok(StringOrEnv::Literal(unquote(pair.as_str()))),
441 _ => Err(ParseError::Syntax(format!(
442 "Expected string or env(), got {:?}",
443 pair.as_rule()
444 ))),
445 }
446}
447
448fn parse_string_value(pair: &pest::iterators::Pair<'_, Rule>) -> String {
449 unquote(pair.as_str())
450}
451
452fn unquote(s: &str) -> String {
453 if s.starts_with('"') && s.ends_with('"') {
454 s[1..s.len() - 1].to_string()
455 } else {
456 s.to_string()
457 }
458}
459
460#[cfg(test)]
461#[allow(clippy::pedantic)]
462mod tests {
463 use super::*;
464
465 const BASIC_SCHEMA: &str = r#"
466datasource db {
467 provider = "postgresql"
468 url = env("DATABASE_URL")
469}
470
471generator client {
472 output = "./src/generated"
473}
474
475enum Role {
476 User
477 Admin
478 Moderator
479}
480
481model User {
482 id String @id @default(uuid())
483 email String @unique
484 name String?
485 role Role @default(User)
486 createdAt DateTime @default(now())
487 updatedAt DateTime @updatedAt
488
489 @@index([email])
490 @@map("users")
491}
492"#;
493
494 #[test]
495 fn test_parse_basic_schema() {
496 let schema = parse(BASIC_SCHEMA).expect("should parse");
497
498 let ds = schema.datasource.expect("should have datasource");
500 assert_eq!(ds.name, "db");
501 assert_eq!(ds.provider, "postgresql");
502 match &ds.url {
503 StringOrEnv::Env(var) => assert_eq!(var, "DATABASE_URL"),
504 _ => panic!("expected env()"),
505 }
506
507 assert_eq!(schema.generators.len(), 1);
509 assert_eq!(schema.generators[0].name, "client");
510 assert_eq!(
511 schema.generators[0].output.as_deref(),
512 Some("./src/generated")
513 );
514
515 assert_eq!(schema.enums.len(), 1);
517 assert_eq!(schema.enums[0].name, "Role");
518 assert_eq!(schema.enums[0].variants, vec!["User", "Admin", "Moderator"]);
519
520 assert_eq!(schema.models.len(), 1);
522 let user = &schema.models[0];
523 assert_eq!(user.name, "User");
524 assert_eq!(user.fields.len(), 6);
525
526 let id_field = &user.fields[0];
528 assert_eq!(id_field.name, "id");
529 assert_eq!(id_field.field_type.name, "String");
530 assert!(!id_field.field_type.is_optional);
531 assert!(
532 id_field
533 .attributes
534 .iter()
535 .any(|a| matches!(a, FieldAttribute::Id))
536 );
537 assert!(
538 id_field
539 .attributes
540 .iter()
541 .any(|a| matches!(a, FieldAttribute::Default(DefaultValue::Uuid)))
542 );
543
544 let name_field = &user.fields[2];
546 assert_eq!(name_field.name, "name");
547 assert!(name_field.field_type.is_optional);
548
549 let role_field = &user.fields[3];
551 assert_eq!(role_field.name, "role");
552 assert!(role_field.attributes.iter().any(
553 |a| matches!(a, FieldAttribute::Default(DefaultValue::EnumVariant(v)) if v == "User")
554 ));
555
556 let updated_field = &user.fields[5];
558 assert_eq!(updated_field.name, "updatedAt");
559 assert!(
560 updated_field
561 .attributes
562 .iter()
563 .any(|a| matches!(a, FieldAttribute::UpdatedAt))
564 );
565
566 assert_eq!(user.attributes.len(), 2);
568 assert!(
569 user.attributes
570 .iter()
571 .any(|a| matches!(a, BlockAttribute::Index(idx) if idx.fields == ["email"]))
572 );
573 assert!(
574 user.attributes
575 .iter()
576 .any(|a| matches!(a, BlockAttribute::Map(name) if name == "users"))
577 );
578 }
579
580 #[test]
581 fn test_parse_multiple_models() {
582 let schema_str = r#"
583datasource db {
584 provider = "postgresql"
585 url = "postgres://localhost/test"
586}
587
588model User {
589 id String @id @default(uuid())
590 email String @unique
591 posts Post[]
592}
593
594model Post {
595 id String @id @default(uuid())
596 title String
597 content String?
598 author User @relation(fields: [authorId], references: [id])
599 authorId String
600
601 @@index([authorId])
602}
603"#;
604
605 let schema = parse(schema_str).expect("should parse");
606 assert_eq!(schema.models.len(), 2);
607 assert_eq!(schema.models[0].name, "User");
608 assert_eq!(schema.models[1].name, "Post");
609
610 let author_field = &schema.models[1].fields[3];
612 assert_eq!(author_field.name, "author");
613 let rel = author_field.attributes.iter().find_map(|a| match a {
614 FieldAttribute::Relation(r) => Some(r),
615 _ => None,
616 });
617 let rel = rel.expect("should have @relation");
618 assert_eq!(rel.fields, vec!["authorId"]);
619 assert_eq!(rel.references, vec!["id"]);
620
621 let posts_field = &schema.models[0].fields[2];
623 assert_eq!(posts_field.name, "posts");
624 assert!(posts_field.field_type.is_list);
625 }
626
627 #[test]
628 fn test_parse_composite_id() {
629 let schema_str = r#"
630datasource db {
631 provider = "sqlite"
632 url = "file:./dev.db"
633}
634
635model PostTag {
636 postId String
637 tagId String
638
639 @@id([postId, tagId])
640}
641"#;
642
643 let schema = parse(schema_str).expect("should parse");
644 let model = &schema.models[0];
645 assert!(
646 model
647 .attributes
648 .iter()
649 .any(|a| matches!(a, BlockAttribute::Id(fields) if fields == &["postId", "tagId"]))
650 );
651 }
652
653 #[test]
654 fn test_parse_error_invalid_syntax() {
655 let bad = "model { broken }";
656 assert!(parse(bad).is_err());
657 }
658
659 #[test]
660 fn test_parse_with_comments() {
661 let schema_str = r#"
662// This is a comment
663datasource db {
664 provider = "postgresql" // inline comment
665 url = env("DATABASE_URL")
666}
667
668// Another comment
669model User {
670 id String @id @default(uuid())
671 // A commented field
672 name String?
673}
674"#;
675
676 let schema = parse(schema_str).expect("should parse with comments");
677 assert!(schema.datasource.is_some());
678 assert_eq!(schema.models[0].fields.len(), 2);
679 }
680}