Skip to main content

ferriorm_parser/
parser.rs

1//! PEG-based parser that turns a `.ferriorm` schema string into a raw AST.
2//!
3//! Uses the `pest` parser generator with the grammar defined in
4//! `grammar.pest`. The public entry point is [`parse`], which returns an
5//! [`ferriorm_core::ast::SchemaFile`] on success or a [`ParseError`] on failure.
6//!
7//! This module only handles syntactic parsing. Semantic validation (type
8//! resolution, constraint checking) is performed by [`crate::validator`].
9
10use ferriorm_core::ast::{
11    BlockAttribute, DefaultValue, EnumDef, FieldAttribute, FieldDef, FieldType, Generator,
12    IndexAttribute, LiteralValue, ModelDef, ReferentialAction, RelationAttribute, SchemaFile, Span,
13    StringOrEnv,
14};
15use pest::Parser;
16use pest_derive::Parser;
17
18use crate::error::ParseError;
19
20#[derive(Parser)]
21#[grammar = "grammar.pest"]
22struct FerriormParser;
23
24/// Parse a `.ferriorm` schema string into an AST.
25///
26/// # Errors
27///
28/// Returns a [`ParseError`] if the source does not conform to the grammar.
29///
30/// # Panics
31///
32/// Panics if the PEG grammar produces no top-level pair, which indicates
33/// a bug in the grammar definition.
34pub fn parse(source: &str) -> Result<SchemaFile, ParseError> {
35    let pairs = FerriormParser::parse(Rule::schema, source)
36        .map_err(|e| ParseError::Syntax(e.to_string()))?;
37
38    let mut schema = SchemaFile {
39        datasource: None,
40        generators: Vec::new(),
41        enums: Vec::new(),
42        models: Vec::new(),
43    };
44
45    // The top-level parse result contains a single `schema` pair; iterate its inner pairs.
46    let schema_pair = pairs.into_iter().next().unwrap();
47    for pair in schema_pair.into_inner() {
48        match pair.as_rule() {
49            Rule::datasource_block => {
50                schema.datasource = Some(parse_datasource(pair)?);
51            }
52            Rule::generator_block => {
53                schema.generators.push(parse_generator(pair));
54            }
55            Rule::enum_block => {
56                schema.enums.push(parse_enum(pair));
57            }
58            Rule::model_block => {
59                schema.models.push(parse_model(pair)?);
60            }
61            _ => {}
62        }
63    }
64
65    Ok(schema)
66}
67
68fn span_from(pair: &pest::iterators::Pair<'_, Rule>) -> Span {
69    let span = pair.as_span();
70    Span {
71        start: span.start(),
72        end: span.end(),
73    }
74}
75
76fn parse_datasource(
77    pair: pest::iterators::Pair<'_, Rule>,
78) -> Result<ferriorm_core::ast::Datasource, ParseError> {
79    let span = span_from(&pair);
80    let mut inner = pair.into_inner();
81    let name = inner.next().unwrap().as_str().to_string();
82
83    let mut provider = String::new();
84    let mut url = StringOrEnv::Literal(String::new());
85
86    for kv in inner {
87        if kv.as_rule() != Rule::kv_pair {
88            continue;
89        }
90        let mut kv_inner = kv.into_inner();
91        let key = kv_inner.next().unwrap().as_str();
92        let value_pair = kv_inner.next().unwrap();
93
94        match key {
95            "provider" => {
96                provider = parse_string_value(&value_pair);
97            }
98            "url" => {
99                url = parse_string_or_env(&value_pair)?;
100            }
101            _ => {}
102        }
103    }
104
105    Ok(ferriorm_core::ast::Datasource {
106        name,
107        provider,
108        url,
109        span,
110    })
111}
112
113fn parse_generator(pair: pest::iterators::Pair<'_, Rule>) -> Generator {
114    let span = span_from(&pair);
115    let mut inner = pair.into_inner();
116    let name = inner.next().unwrap().as_str().to_string();
117
118    let mut output = None;
119
120    for kv in inner {
121        if kv.as_rule() != Rule::kv_pair {
122            continue;
123        }
124        let mut kv_inner = kv.into_inner();
125        let key = kv_inner.next().unwrap().as_str();
126        let value_pair = kv_inner.next().unwrap();
127
128        if key == "output" {
129            output = Some(parse_string_value(&value_pair));
130        }
131    }
132
133    Generator { name, output, span }
134}
135
136fn parse_enum(pair: pest::iterators::Pair<'_, Rule>) -> EnumDef {
137    let span = span_from(&pair);
138    let mut inner = pair.into_inner();
139    let name = inner.next().unwrap().as_str().to_string();
140
141    let mut variants = Vec::new();
142    let mut db_name = None;
143    for member in inner {
144        match member.as_rule() {
145            Rule::enum_variant => {
146                let variant_name = member.into_inner().next().unwrap().as_str().to_string();
147                variants.push(variant_name);
148            }
149            Rule::enum_block_attr_map => {
150                let s = member.into_inner().next().unwrap().as_str();
151                db_name = Some(unquote(s));
152            }
153            _ => {}
154        }
155    }
156
157    EnumDef {
158        name,
159        variants,
160        db_name,
161        span,
162    }
163}
164
165fn parse_model(pair: pest::iterators::Pair<'_, Rule>) -> Result<ModelDef, ParseError> {
166    let span = span_from(&pair);
167    let mut inner = pair.into_inner();
168    let name = inner.next().unwrap().as_str().to_string();
169
170    let mut fields = Vec::new();
171    let mut attributes = Vec::new();
172
173    for member in inner {
174        match member.as_rule() {
175            Rule::field_def => {
176                fields.push(parse_field(member)?);
177            }
178            Rule::block_attr_index => {
179                attributes.push(BlockAttribute::Index(parse_index_attribute(member)));
180            }
181            Rule::block_attr_unique => {
182                attributes.push(BlockAttribute::Unique(parse_index_attribute(member)));
183            }
184            Rule::block_attr_map => {
185                let s = member.into_inner().next().unwrap().as_str();
186                attributes.push(BlockAttribute::Map(unquote(s)));
187            }
188            Rule::block_attr_id => {
189                attributes.push(BlockAttribute::Id(parse_field_list_from_block_attr(member)));
190            }
191            _ => {}
192        }
193    }
194
195    Ok(ModelDef {
196        name,
197        fields,
198        attributes,
199        span,
200    })
201}
202
203fn parse_field(pair: pest::iterators::Pair<'_, Rule>) -> Result<FieldDef, ParseError> {
204    let span = span_from(&pair);
205    let mut inner = pair.into_inner();
206
207    let name = inner.next().unwrap().as_str().to_string();
208    let field_type_pair = inner.next().unwrap();
209    let field_type = parse_field_type(field_type_pair);
210
211    let mut attributes = Vec::new();
212    for attr_pair in inner {
213        if let Some(attr) = parse_field_attribute(attr_pair)? {
214            attributes.push(attr);
215        }
216    }
217
218    Ok(FieldDef {
219        name,
220        field_type,
221        attributes,
222        span,
223    })
224}
225
226fn parse_field_type(pair: pest::iterators::Pair<'_, Rule>) -> FieldType {
227    let mut inner = pair.into_inner();
228    let name = inner.next().unwrap().as_str().to_string();
229
230    let mut is_list = false;
231    let mut is_optional = false;
232
233    for modifier in inner {
234        match modifier.as_rule() {
235            Rule::list_modifier => is_list = true,
236            Rule::optional_modifier => is_optional = true,
237            _ => {}
238        }
239    }
240
241    FieldType {
242        name,
243        is_list,
244        is_optional,
245    }
246}
247
248fn parse_field_attribute(
249    pair: pest::iterators::Pair<'_, Rule>,
250) -> Result<Option<FieldAttribute>, ParseError> {
251    match pair.as_rule() {
252        Rule::attr_id => Ok(Some(FieldAttribute::Id)),
253        Rule::attr_unique => Ok(Some(FieldAttribute::Unique)),
254        Rule::attr_updated_at => Ok(Some(FieldAttribute::UpdatedAt)),
255        Rule::attr_default => {
256            let value_pair = pair.into_inner().next().unwrap();
257            let default = parse_default_value(value_pair)?;
258            Ok(Some(FieldAttribute::Default(default)))
259        }
260        Rule::attr_map => {
261            let s = pair.into_inner().next().unwrap().as_str();
262            Ok(Some(FieldAttribute::Map(unquote(s))))
263        }
264        Rule::attr_relation => {
265            let relation = parse_relation_attribute(pair);
266            Ok(Some(FieldAttribute::Relation(relation)))
267        }
268        Rule::attr_db_type => {
269            let mut inner = pair.into_inner();
270            let type_name = inner.next().unwrap().as_str().to_string();
271            let args: Vec<String> = inner.map(|p| parse_string_value(&p)).collect();
272            Ok(Some(FieldAttribute::DbType(type_name, args)))
273        }
274        _ => Ok(None),
275    }
276}
277
278fn parse_default_value(pair: pest::iterators::Pair<'_, Rule>) -> Result<DefaultValue, ParseError> {
279    match pair.as_rule() {
280        Rule::func_call => {
281            let mut inner = pair.into_inner();
282            let func_name = inner.next().unwrap().as_str();
283            match func_name {
284                "uuid" => Ok(DefaultValue::Uuid),
285                "cuid" => Ok(DefaultValue::Cuid),
286                "autoincrement" => Ok(DefaultValue::AutoIncrement),
287                "now" => Ok(DefaultValue::Now),
288                other => Err(ParseError::Syntax(format!(
289                    "Unknown default function: {other}()"
290                ))),
291            }
292        }
293        Rule::string_literal => Ok(DefaultValue::Literal(LiteralValue::String(unquote(
294            pair.as_str(),
295        )))),
296        Rule::number_literal => {
297            let s = pair.as_str();
298            if s.contains('.') {
299                Ok(DefaultValue::Literal(LiteralValue::Float(
300                    s.parse()
301                        .map_err(|e| ParseError::Syntax(format!("Invalid float: {e}")))?,
302                )))
303            } else {
304                Ok(DefaultValue::Literal(LiteralValue::Int(
305                    s.parse()
306                        .map_err(|e| ParseError::Syntax(format!("Invalid int: {e}")))?,
307                )))
308            }
309        }
310        Rule::boolean_literal => {
311            let b = pair.as_str() == "true";
312            Ok(DefaultValue::Literal(LiteralValue::Bool(b)))
313        }
314        Rule::identifier_value => {
315            let name = pair.into_inner().next().unwrap().as_str().to_string();
316            Ok(DefaultValue::EnumVariant(name))
317        }
318        _ => Err(ParseError::Syntax(format!(
319            "Unexpected default value: {:?}",
320            pair.as_rule()
321        ))),
322    }
323}
324
325fn parse_relation_attribute(pair: pest::iterators::Pair<'_, Rule>) -> RelationAttribute {
326    let args_pair = pair.into_inner().next().unwrap(); // relation_args
327    let mut fields = Vec::new();
328    let mut references = Vec::new();
329    let mut on_delete = None;
330    let mut on_update = None;
331    let mut name = None;
332
333    for arg in args_pair.into_inner() {
334        match arg.as_rule() {
335            // Positional name as first arg: @relation("Authored", ...)
336            Rule::string_literal => {
337                name = Some(parse_string_value(&arg));
338                continue;
339            }
340            Rule::relation_arg | Rule::named_arg => {}
341            _ => continue,
342        }
343
344        // relation_arg = { named_arg }
345        let named_arg = if arg.as_rule() == Rule::relation_arg {
346            arg.into_inner().next().unwrap()
347        } else {
348            arg
349        };
350
351        // named_arg = { identifier ~ ":" ~ (field_list | value) }
352        let mut named = named_arg.into_inner();
353        let key = named.next().unwrap().as_str();
354        let value_pair = named.next().unwrap();
355
356        match key {
357            "fields" => fields = parse_field_list(&value_pair),
358            "references" => references = parse_field_list(&value_pair),
359            "onDelete" => on_delete = parse_referential_action(&value_pair),
360            "onUpdate" => on_update = parse_referential_action(&value_pair),
361            "name" => name = Some(parse_string_value(&value_pair)),
362            _ => {}
363        }
364    }
365
366    RelationAttribute {
367        name,
368        fields,
369        references,
370        on_delete,
371        on_update,
372    }
373}
374
375fn parse_referential_action(pair: &pest::iterators::Pair<'_, Rule>) -> Option<ReferentialAction> {
376    let s = pair.as_str().trim_matches('"');
377    match s {
378        "Cascade" => Some(ReferentialAction::Cascade),
379        "Restrict" => Some(ReferentialAction::Restrict),
380        "NoAction" => Some(ReferentialAction::NoAction),
381        "SetNull" => Some(ReferentialAction::SetNull),
382        "SetDefault" => Some(ReferentialAction::SetDefault),
383        _ => None,
384    }
385}
386
387fn parse_field_list(pair: &pest::iterators::Pair<'_, Rule>) -> Vec<String> {
388    pair.clone()
389        .into_inner()
390        .filter(|p| p.as_rule() == Rule::identifier)
391        .map(|p| p.as_str().to_string())
392        .collect()
393}
394
395fn parse_field_list_from_block_attr(pair: pest::iterators::Pair<'_, Rule>) -> Vec<String> {
396    let field_list = pair.into_inner().next().unwrap();
397    parse_field_list(&field_list)
398}
399
400/// Parse a `@@index` / `@@unique` block attribute body:
401/// the leading field list followed by zero or more named args.
402/// Currently only `name: "..."` is consumed.
403fn parse_index_attribute(pair: pest::iterators::Pair<'_, Rule>) -> IndexAttribute {
404    let mut inner = pair.into_inner();
405    let field_list = inner.next().unwrap();
406    let fields = parse_field_list(&field_list);
407    let mut name = None;
408
409    for arg in inner {
410        if arg.as_rule() != Rule::named_arg {
411            continue;
412        }
413        let mut named = arg.into_inner();
414        let key = named.next().unwrap().as_str();
415        let value_pair = named.next().unwrap();
416        if key == "name" || key == "map" {
417            name = Some(parse_string_value(&value_pair));
418        }
419    }
420
421    IndexAttribute { fields, name }
422}
423
424fn parse_string_or_env(pair: &pest::iterators::Pair<'_, Rule>) -> Result<StringOrEnv, ParseError> {
425    match pair.as_rule() {
426        Rule::func_call => {
427            let mut inner = pair.clone().into_inner();
428            let func_name = inner.next().unwrap().as_str();
429            if func_name == "env" {
430                let arg = inner
431                    .next()
432                    .ok_or_else(|| ParseError::Syntax("env() requires a string argument".into()))?;
433                Ok(StringOrEnv::Env(unquote(arg.as_str())))
434            } else {
435                Err(ParseError::Syntax(format!(
436                    "Expected env(), got {func_name}()"
437                )))
438            }
439        }
440        Rule::string_literal => Ok(StringOrEnv::Literal(unquote(pair.as_str()))),
441        _ => Err(ParseError::Syntax(format!(
442            "Expected string or env(), got {:?}",
443            pair.as_rule()
444        ))),
445    }
446}
447
448fn parse_string_value(pair: &pest::iterators::Pair<'_, Rule>) -> String {
449    unquote(pair.as_str())
450}
451
/// Strip one pair of enclosing double quotes from `s`.
///
/// If `s` is not wrapped in quotes, it is returned unchanged as an owned
/// `String`. A lone `"` is also returned unchanged: the same character is not
/// treated as both the opening and the closing quote, which would make the
/// slice bounds invalid. Escape sequences inside the quotes are not processed.
fn unquote(s: &str) -> String {
    s.strip_prefix('"')
        .and_then(|rest| rest.strip_suffix('"'))
        .unwrap_or(s)
        .to_string()
}
459
/// Unit tests that run [`parse`] on complete schema strings and inspect the
/// resulting AST.
#[cfg(test)]
// NOTE(review): presumably here so test code need not satisfy the crate's
// pedantic clippy lints — confirm.
#[allow(clippy::pedantic)]
mod tests {
    use super::*;

    /// Schema with every top-level block kind: a datasource whose `url` uses
    /// `env()`, a generator, an enum, and a model. The model has field
    /// attributes (`@id`, `@default`, `@unique`, `@updatedAt`), an optional
    /// field, and the block attributes `@@index` and `@@map`.
    const BASIC_SCHEMA: &str = r#"
datasource db {
  provider = "postgresql"
  url      = env("DATABASE_URL")
}

generator client {
  output = "./src/generated"
}

enum Role {
  User
  Admin
  Moderator
}

model User {
  id        String   @id @default(uuid())
  email     String   @unique
  name      String?
  role      Role     @default(User)
  createdAt DateTime @default(now())
  updatedAt DateTime @updatedAt

  @@index([email])
  @@map("users")
}
"#;

    /// Checks each block kind in `BASIC_SCHEMA`, plus field order, optionality,
    /// default values and block attributes.
    #[test]
    fn test_parse_basic_schema() {
        let schema = parse(BASIC_SCHEMA).expect("should parse");

        // Datasource
        let ds = schema.datasource.expect("should have datasource");
        assert_eq!(ds.name, "db");
        assert_eq!(ds.provider, "postgresql");
        match &ds.url {
            StringOrEnv::Env(var) => assert_eq!(var, "DATABASE_URL"),
            _ => panic!("expected env()"),
        }

        // Generator
        assert_eq!(schema.generators.len(), 1);
        assert_eq!(schema.generators[0].name, "client");
        assert_eq!(
            schema.generators[0].output.as_deref(),
            Some("./src/generated")
        );

        // Enum
        assert_eq!(schema.enums.len(), 1);
        assert_eq!(schema.enums[0].name, "Role");
        assert_eq!(schema.enums[0].variants, vec!["User", "Admin", "Moderator"]);

        // Model
        assert_eq!(schema.models.len(), 1);
        let user = &schema.models[0];
        assert_eq!(user.name, "User");
        assert_eq!(user.fields.len(), 6);

        // id field
        let id_field = &user.fields[0];
        assert_eq!(id_field.name, "id");
        assert_eq!(id_field.field_type.name, "String");
        assert!(!id_field.field_type.is_optional);
        assert!(
            id_field
                .attributes
                .iter()
                .any(|a| matches!(a, FieldAttribute::Id))
        );
        assert!(
            id_field
                .attributes
                .iter()
                .any(|a| matches!(a, FieldAttribute::Default(DefaultValue::Uuid)))
        );

        // name field is optional
        let name_field = &user.fields[2];
        assert_eq!(name_field.name, "name");
        assert!(name_field.field_type.is_optional);

        // role field has enum default (a bare identifier becomes EnumVariant)
        let role_field = &user.fields[3];
        assert_eq!(role_field.name, "role");
        assert!(role_field.attributes.iter().any(
            |a| matches!(a, FieldAttribute::Default(DefaultValue::EnumVariant(v)) if v == "User")
        ));

        // updatedAt has @updatedAt
        let updated_field = &user.fields[5];
        assert_eq!(updated_field.name, "updatedAt");
        assert!(
            updated_field
                .attributes
                .iter()
                .any(|a| matches!(a, FieldAttribute::UpdatedAt))
        );

        // Block attributes
        assert_eq!(user.attributes.len(), 2);
        assert!(
            user.attributes
                .iter()
                .any(|a| matches!(a, BlockAttribute::Index(idx) if idx.fields == ["email"]))
        );
        assert!(
            user.attributes
                .iter()
                .any(|a| matches!(a, BlockAttribute::Map(name) if name == "users"))
        );
    }

    /// Two models with a one-to-many relation: checks model order, the
    /// `@relation(fields, references)` arguments, and the `[]` list modifier.
    #[test]
    fn test_parse_multiple_models() {
        let schema_str = r#"
datasource db {
  provider = "postgresql"
  url      = "postgres://localhost/test"
}

model User {
  id    String @id @default(uuid())
  email String @unique
  posts Post[]
}

model Post {
  id       String  @id @default(uuid())
  title    String
  content  String?
  author   User    @relation(fields: [authorId], references: [id])
  authorId String

  @@index([authorId])
}
"#;

        let schema = parse(schema_str).expect("should parse");
        assert_eq!(schema.models.len(), 2);
        assert_eq!(schema.models[0].name, "User");
        assert_eq!(schema.models[1].name, "Post");

        // Check relation attribute on Post.author
        let author_field = &schema.models[1].fields[3];
        assert_eq!(author_field.name, "author");
        let rel = author_field.attributes.iter().find_map(|a| match a {
            FieldAttribute::Relation(r) => Some(r),
            _ => None,
        });
        let rel = rel.expect("should have @relation");
        assert_eq!(rel.fields, vec!["authorId"]);
        assert_eq!(rel.references, vec!["id"]);

        // Check Post[] is a list
        let posts_field = &schema.models[0].fields[2];
        assert_eq!(posts_field.name, "posts");
        assert!(posts_field.field_type.is_list);
    }

    /// `@@id([a, b])` produces a `BlockAttribute::Id` listing the fields in order.
    #[test]
    fn test_parse_composite_id() {
        let schema_str = r#"
datasource db {
  provider = "sqlite"
  url      = "file:./dev.db"
}

model PostTag {
  postId String
  tagId  String

  @@id([postId, tagId])
}
"#;

        let schema = parse(schema_str).expect("should parse");
        let model = &schema.models[0];
        assert!(
            model
                .attributes
                .iter()
                .any(|a| matches!(a, BlockAttribute::Id(fields) if fields == &["postId", "tagId"]))
        );
    }

    /// A model with no name must be rejected rather than parsed.
    #[test]
    fn test_parse_error_invalid_syntax() {
        let bad = "model { broken }";
        assert!(parse(bad).is_err());
    }

    /// Line comments — standalone, trailing, and inside a model — are accepted
    /// and do not produce extra fields.
    #[test]
    fn test_parse_with_comments() {
        let schema_str = r#"
// This is a comment
datasource db {
  provider = "postgresql" // inline comment
  url      = env("DATABASE_URL")
}

// Another comment
model User {
  id String @id @default(uuid())
  // A commented field
  name String?
}
"#;

        let schema = parse(schema_str).expect("should parse with comments");
        assert!(schema.datasource.is_some());
        assert_eq!(schema.models[0].fields.len(), 2);
    }
}