Skip to main content

ferriorm_parser/
parser.rs

//! PEG-based parser that turns a `.ferriorm` schema string into a raw AST.
//!
//! Uses the `pest` parser generator with the grammar defined in
//! `grammar.pest`. The public entry point is [`parse`], which returns a
//! [`ferriorm_core::ast::SchemaFile`] on success or a [`ParseError`] on failure.
//!
//! This module only handles syntactic parsing. Semantic validation (type
//! resolution, constraint checking) is performed by [`crate::validator`].
9
10use ferriorm_core::ast::*;
11use pest::Parser;
12use pest_derive::Parser;
13
14use crate::error::ParseError;
15
16#[derive(Parser)]
17#[grammar = "grammar.pest"]
18struct FerriormParser;
19
20/// Parse a `.ferriorm` schema string into an AST.
21pub fn parse(source: &str) -> Result<SchemaFile, ParseError> {
22    let pairs = FerriormParser::parse(Rule::schema, source)
23        .map_err(|e| ParseError::Syntax(e.to_string()))?;
24
25    let mut schema = SchemaFile {
26        datasource: None,
27        generators: Vec::new(),
28        enums: Vec::new(),
29        models: Vec::new(),
30    };
31
32    // The top-level parse result contains a single `schema` pair; iterate its inner pairs.
33    let schema_pair = pairs.into_iter().next().unwrap();
34    for pair in schema_pair.into_inner() {
35        match pair.as_rule() {
36            Rule::datasource_block => {
37                schema.datasource = Some(parse_datasource(pair)?);
38            }
39            Rule::generator_block => {
40                schema.generators.push(parse_generator(pair)?);
41            }
42            Rule::enum_block => {
43                schema.enums.push(parse_enum(pair)?);
44            }
45            Rule::model_block => {
46                schema.models.push(parse_model(pair)?);
47            }
48            Rule::EOI => {}
49            _ => {}
50        }
51    }
52
53    Ok(schema)
54}
55
56fn span_from(pair: &pest::iterators::Pair<'_, Rule>) -> Span {
57    let span = pair.as_span();
58    Span {
59        start: span.start(),
60        end: span.end(),
61    }
62}
63
64fn parse_datasource(pair: pest::iterators::Pair<'_, Rule>) -> Result<Datasource, ParseError> {
65    let span = span_from(&pair);
66    let mut inner = pair.into_inner();
67    let name = inner.next().unwrap().as_str().to_string();
68
69    let mut provider = String::new();
70    let mut url = StringOrEnv::Literal(String::new());
71
72    for kv in inner {
73        if kv.as_rule() != Rule::kv_pair {
74            continue;
75        }
76        let mut kv_inner = kv.into_inner();
77        let key = kv_inner.next().unwrap().as_str();
78        let value_pair = kv_inner.next().unwrap();
79
80        match key {
81            "provider" => {
82                provider = parse_string_value(&value_pair);
83            }
84            "url" => {
85                url = parse_string_or_env(&value_pair)?;
86            }
87            _ => {}
88        }
89    }
90
91    Ok(Datasource {
92        name,
93        provider,
94        url,
95        span,
96    })
97}
98
99fn parse_generator(pair: pest::iterators::Pair<'_, Rule>) -> Result<Generator, ParseError> {
100    let span = span_from(&pair);
101    let mut inner = pair.into_inner();
102    let name = inner.next().unwrap().as_str().to_string();
103
104    let mut output = None;
105
106    for kv in inner {
107        if kv.as_rule() != Rule::kv_pair {
108            continue;
109        }
110        let mut kv_inner = kv.into_inner();
111        let key = kv_inner.next().unwrap().as_str();
112        let value_pair = kv_inner.next().unwrap();
113
114        if key == "output" {
115            output = Some(parse_string_value(&value_pair));
116        }
117    }
118
119    Ok(Generator { name, output, span })
120}
121
122fn parse_enum(pair: pest::iterators::Pair<'_, Rule>) -> Result<EnumDef, ParseError> {
123    let span = span_from(&pair);
124    let mut inner = pair.into_inner();
125    let name = inner.next().unwrap().as_str().to_string();
126
127    let mut variants = Vec::new();
128    for variant_pair in inner {
129        if variant_pair.as_rule() == Rule::enum_variant {
130            let variant_name = variant_pair
131                .into_inner()
132                .next()
133                .unwrap()
134                .as_str()
135                .to_string();
136            variants.push(variant_name);
137        }
138    }
139
140    Ok(EnumDef {
141        name,
142        variants,
143        db_name: None,
144        span,
145    })
146}
147
148fn parse_model(pair: pest::iterators::Pair<'_, Rule>) -> Result<ModelDef, ParseError> {
149    let span = span_from(&pair);
150    let mut inner = pair.into_inner();
151    let name = inner.next().unwrap().as_str().to_string();
152
153    let mut fields = Vec::new();
154    let mut attributes = Vec::new();
155
156    for member in inner {
157        match member.as_rule() {
158            Rule::field_def => {
159                fields.push(parse_field(member)?);
160            }
161            Rule::block_attr_index => {
162                attributes.push(BlockAttribute::Index(parse_field_list_from_block_attr(
163                    member,
164                )));
165            }
166            Rule::block_attr_unique => {
167                attributes.push(BlockAttribute::Unique(parse_field_list_from_block_attr(
168                    member,
169                )));
170            }
171            Rule::block_attr_map => {
172                let s = member.into_inner().next().unwrap().as_str();
173                attributes.push(BlockAttribute::Map(unquote(s)));
174            }
175            Rule::block_attr_id => {
176                attributes.push(BlockAttribute::Id(parse_field_list_from_block_attr(member)));
177            }
178            _ => {}
179        }
180    }
181
182    Ok(ModelDef {
183        name,
184        fields,
185        attributes,
186        span,
187    })
188}
189
190fn parse_field(pair: pest::iterators::Pair<'_, Rule>) -> Result<FieldDef, ParseError> {
191    let span = span_from(&pair);
192    let mut inner = pair.into_inner();
193
194    let name = inner.next().unwrap().as_str().to_string();
195    let field_type_pair = inner.next().unwrap();
196    let field_type = parse_field_type(field_type_pair);
197
198    let mut attributes = Vec::new();
199    for attr_pair in inner {
200        if let Some(attr) = parse_field_attribute(attr_pair)? {
201            attributes.push(attr);
202        }
203    }
204
205    Ok(FieldDef {
206        name,
207        field_type,
208        attributes,
209        span,
210    })
211}
212
213fn parse_field_type(pair: pest::iterators::Pair<'_, Rule>) -> FieldType {
214    let mut inner = pair.into_inner();
215    let name = inner.next().unwrap().as_str().to_string();
216
217    let mut is_list = false;
218    let mut is_optional = false;
219
220    for modifier in inner {
221        match modifier.as_rule() {
222            Rule::list_modifier => is_list = true,
223            Rule::optional_modifier => is_optional = true,
224            _ => {}
225        }
226    }
227
228    FieldType {
229        name,
230        is_list,
231        is_optional,
232    }
233}
234
235fn parse_field_attribute(
236    pair: pest::iterators::Pair<'_, Rule>,
237) -> Result<Option<FieldAttribute>, ParseError> {
238    match pair.as_rule() {
239        Rule::attr_id => Ok(Some(FieldAttribute::Id)),
240        Rule::attr_unique => Ok(Some(FieldAttribute::Unique)),
241        Rule::attr_updated_at => Ok(Some(FieldAttribute::UpdatedAt)),
242        Rule::attr_default => {
243            let value_pair = pair.into_inner().next().unwrap();
244            let default = parse_default_value(value_pair)?;
245            Ok(Some(FieldAttribute::Default(default)))
246        }
247        Rule::attr_map => {
248            let s = pair.into_inner().next().unwrap().as_str();
249            Ok(Some(FieldAttribute::Map(unquote(s))))
250        }
251        Rule::attr_relation => {
252            let relation = parse_relation_attribute(pair)?;
253            Ok(Some(FieldAttribute::Relation(relation)))
254        }
255        Rule::attr_db_type => {
256            let mut inner = pair.into_inner();
257            let type_name = inner.next().unwrap().as_str().to_string();
258            let args: Vec<String> = inner.map(|p| parse_string_value(&p)).collect();
259            Ok(Some(FieldAttribute::DbType(type_name, args)))
260        }
261        _ => Ok(None),
262    }
263}
264
265fn parse_default_value(pair: pest::iterators::Pair<'_, Rule>) -> Result<DefaultValue, ParseError> {
266    match pair.as_rule() {
267        Rule::func_call => {
268            let mut inner = pair.into_inner();
269            let func_name = inner.next().unwrap().as_str();
270            match func_name {
271                "uuid" => Ok(DefaultValue::Uuid),
272                "cuid" => Ok(DefaultValue::Cuid),
273                "autoincrement" => Ok(DefaultValue::AutoIncrement),
274                "now" => Ok(DefaultValue::Now),
275                other => Err(ParseError::Syntax(format!(
276                    "Unknown default function: {other}()"
277                ))),
278            }
279        }
280        Rule::string_literal => Ok(DefaultValue::Literal(LiteralValue::String(unquote(
281            pair.as_str(),
282        )))),
283        Rule::number_literal => {
284            let s = pair.as_str();
285            if s.contains('.') {
286                Ok(DefaultValue::Literal(LiteralValue::Float(
287                    s.parse()
288                        .map_err(|e| ParseError::Syntax(format!("Invalid float: {e}")))?,
289                )))
290            } else {
291                Ok(DefaultValue::Literal(LiteralValue::Int(
292                    s.parse()
293                        .map_err(|e| ParseError::Syntax(format!("Invalid int: {e}")))?,
294                )))
295            }
296        }
297        Rule::boolean_literal => {
298            let b = pair.as_str() == "true";
299            Ok(DefaultValue::Literal(LiteralValue::Bool(b)))
300        }
301        Rule::identifier_value => {
302            let name = pair.into_inner().next().unwrap().as_str().to_string();
303            Ok(DefaultValue::EnumVariant(name))
304        }
305        _ => Err(ParseError::Syntax(format!(
306            "Unexpected default value: {:?}",
307            pair.as_rule()
308        ))),
309    }
310}
311
312fn parse_relation_attribute(
313    pair: pest::iterators::Pair<'_, Rule>,
314) -> Result<RelationAttribute, ParseError> {
315    let args_pair = pair.into_inner().next().unwrap(); // relation_args
316    let mut fields = Vec::new();
317    let mut references = Vec::new();
318    let mut on_delete = None;
319    let mut on_update = None;
320    let mut name = None;
321
322    for arg in args_pair.into_inner() {
323        // relation_arg = { named_arg }
324        // Each arg is a relation_arg containing a named_arg
325        let named_arg = if arg.as_rule() == Rule::relation_arg {
326            arg.into_inner().next().unwrap()
327        } else if arg.as_rule() == Rule::named_arg {
328            arg
329        } else {
330            continue;
331        };
332
333        // named_arg = { identifier ~ ":" ~ (field_list | value) }
334        let mut named = named_arg.into_inner();
335        let key = named.next().unwrap().as_str();
336        let value_pair = named.next().unwrap();
337
338        match key {
339            "fields" => fields = parse_field_list(&value_pair),
340            "references" => references = parse_field_list(&value_pair),
341            "onDelete" => on_delete = parse_referential_action(&value_pair),
342            "onUpdate" => on_update = parse_referential_action(&value_pair),
343            "name" => name = Some(parse_string_value(&value_pair)),
344            _ => {}
345        }
346    }
347
348    Ok(RelationAttribute {
349        name,
350        fields,
351        references,
352        on_delete,
353        on_update,
354    })
355}
356
357fn parse_referential_action(pair: &pest::iterators::Pair<'_, Rule>) -> Option<ReferentialAction> {
358    let s = pair.as_str().trim_matches('"');
359    match s {
360        "Cascade" => Some(ReferentialAction::Cascade),
361        "Restrict" => Some(ReferentialAction::Restrict),
362        "NoAction" => Some(ReferentialAction::NoAction),
363        "SetNull" => Some(ReferentialAction::SetNull),
364        "SetDefault" => Some(ReferentialAction::SetDefault),
365        _ => None,
366    }
367}
368
369fn parse_field_list(pair: &pest::iterators::Pair<'_, Rule>) -> Vec<String> {
370    pair.clone()
371        .into_inner()
372        .filter(|p| p.as_rule() == Rule::identifier)
373        .map(|p| p.as_str().to_string())
374        .collect()
375}
376
377fn parse_field_list_from_block_attr(pair: pest::iterators::Pair<'_, Rule>) -> Vec<String> {
378    let field_list = pair.into_inner().next().unwrap();
379    parse_field_list(&field_list)
380}
381
382fn parse_string_or_env(pair: &pest::iterators::Pair<'_, Rule>) -> Result<StringOrEnv, ParseError> {
383    match pair.as_rule() {
384        Rule::func_call => {
385            let mut inner = pair.clone().into_inner();
386            let func_name = inner.next().unwrap().as_str();
387            if func_name == "env" {
388                let arg = inner
389                    .next()
390                    .ok_or_else(|| ParseError::Syntax("env() requires a string argument".into()))?;
391                Ok(StringOrEnv::Env(unquote(arg.as_str())))
392            } else {
393                Err(ParseError::Syntax(format!(
394                    "Expected env(), got {func_name}()"
395                )))
396            }
397        }
398        Rule::string_literal => Ok(StringOrEnv::Literal(unquote(pair.as_str()))),
399        _ => Err(ParseError::Syntax(format!(
400            "Expected string or env(), got {:?}",
401            pair.as_rule()
402        ))),
403    }
404}
405
406fn parse_string_value(pair: &pest::iterators::Pair<'_, Rule>) -> String {
407    unquote(pair.as_str())
408}
409
/// Strips one pair of surrounding double quotes from `s`, if present.
///
/// Text that does not both start and end with `"` is returned unchanged. That includes a lone
/// `"`, whose opening and closing quote would be the same character. Escape sequences inside
/// the quotes are not interpreted.
fn unquote(s: &str) -> String {
    // `strip_prefix`/`strip_suffix` never slice out of range. By contrast, `s[1..s.len() - 1]`
    // panics on the one-character input `"`, where start (1) > end (0).
    s.strip_prefix('"')
        .and_then(|rest| rest.strip_suffix('"'))
        .unwrap_or(s)
        .to_string()
}
417
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises every top-level block kind plus the common field and block attributes.
    const BASIC_SCHEMA: &str = r#"
datasource db {
  provider = "postgresql"
  url      = env("DATABASE_URL")
}

generator client {
  output = "./src/generated"
}

enum Role {
  User
  Admin
  Moderator
}

model User {
  id        String   @id @default(uuid())
  email     String   @unique
  name      String?
  role      Role     @default(User)
  createdAt DateTime @default(now())
  updatedAt DateTime @updatedAt

  @@index([email])
  @@map("users")
}
"#;

    /// Two models linked by a one-to-many relation.
    const MULTI_MODEL_SCHEMA: &str = r#"
datasource db {
  provider = "postgresql"
  url      = "postgres://localhost/test"
}

model User {
  id    String @id @default(uuid())
  email String @unique
  posts Post[]
}

model Post {
  id       String  @id @default(uuid())
  title    String
  content  String?
  author   User    @relation(fields: [authorId], references: [id])
  authorId String

  @@index([authorId])
}
"#;

    /// A join table keyed by a composite `@@id`.
    const COMPOSITE_ID_SCHEMA: &str = r#"
datasource db {
  provider = "sqlite"
  url      = "file:./dev.db"
}

model PostTag {
  postId String
  tagId  String

  @@id([postId, tagId])
}
"#;

    /// Line comments at top level, after a value, and between fields.
    const COMMENTED_SCHEMA: &str = r#"
// This is a comment
datasource db {
  provider = "postgresql" // inline comment
  url      = env("DATABASE_URL")
}

// Another comment
model User {
  id String @id @default(uuid())
  // A commented field
  name String?
}
"#;

    /// Returns `true` if any attribute of `field` satisfies `pred`.
    fn field_has(field: &FieldDef, pred: impl Fn(&FieldAttribute) -> bool) -> bool {
        field.attributes.iter().any(pred)
    }

    /// Returns `true` if any block attribute of `model` satisfies `pred`.
    fn model_has(model: &ModelDef, pred: impl Fn(&BlockAttribute) -> bool) -> bool {
        model.attributes.iter().any(pred)
    }

    #[test]
    fn test_parse_basic_schema() {
        let schema = parse(BASIC_SCHEMA).expect("should parse");

        // Datasource: the url is read from env().
        let ds = schema.datasource.expect("should have datasource");
        assert_eq!(ds.name, "db");
        assert_eq!(ds.provider, "postgresql");
        if let StringOrEnv::Env(var) = &ds.url {
            assert_eq!(var, "DATABASE_URL");
        } else {
            panic!("expected env()");
        }

        // Generator
        let generators = &schema.generators;
        assert_eq!(generators.len(), 1);
        assert_eq!(generators[0].name, "client");
        assert_eq!(generators[0].output.as_deref(), Some("./src/generated"));

        // Enum
        assert_eq!(schema.enums.len(), 1);
        let role_enum = &schema.enums[0];
        assert_eq!(role_enum.name, "Role");
        assert_eq!(role_enum.variants, vec!["User", "Admin", "Moderator"]);

        // Model
        assert_eq!(schema.models.len(), 1);
        let user = &schema.models[0];
        assert_eq!(user.name, "User");
        assert_eq!(user.fields.len(), 6);

        // `id`: required String with @id and @default(uuid())
        let id_field = &user.fields[0];
        assert_eq!(id_field.name, "id");
        assert_eq!(id_field.field_type.name, "String");
        assert!(!id_field.field_type.is_optional);
        assert!(field_has(id_field, |a| matches!(a, FieldAttribute::Id)));
        assert!(field_has(id_field, |a| {
            matches!(a, FieldAttribute::Default(DefaultValue::Uuid))
        }));

        // `name`: optional
        let name_field = &user.fields[2];
        assert_eq!(name_field.name, "name");
        assert!(name_field.field_type.is_optional);

        // `role`: enum-variant default
        let role_field = &user.fields[3];
        assert_eq!(role_field.name, "role");
        assert!(field_has(role_field, |a| {
            matches!(a, FieldAttribute::Default(DefaultValue::EnumVariant(v)) if v == "User")
        }));

        // `updatedAt`: @updatedAt
        let updated_field = &user.fields[5];
        assert_eq!(updated_field.name, "updatedAt");
        assert!(field_has(updated_field, |a| matches!(a, FieldAttribute::UpdatedAt)));

        // Block attributes
        assert_eq!(user.attributes.len(), 2);
        assert!(model_has(user, |a| {
            matches!(a, BlockAttribute::Index(fields) if fields == &["email"])
        }));
        assert!(model_has(user, |a| {
            matches!(a, BlockAttribute::Map(name) if name == "users")
        }));
    }

    #[test]
    fn test_parse_multiple_models() {
        let schema = parse(MULTI_MODEL_SCHEMA).expect("should parse");
        assert_eq!(schema.models.len(), 2);
        let (user, post) = (&schema.models[0], &schema.models[1]);
        assert_eq!(user.name, "User");
        assert_eq!(post.name, "Post");

        // Post.author carries the foreign-key mapping.
        let author_field = &post.fields[3];
        assert_eq!(author_field.name, "author");
        let rel = author_field
            .attributes
            .iter()
            .find_map(|a| match a {
                FieldAttribute::Relation(r) => Some(r),
                _ => None,
            })
            .expect("should have @relation");
        assert_eq!(rel.fields, vec!["authorId"]);
        assert_eq!(rel.references, vec!["id"]);

        // `Post[]` marks a list.
        let posts_field = &user.fields[2];
        assert_eq!(posts_field.name, "posts");
        assert!(posts_field.field_type.is_list);
    }

    #[test]
    fn test_parse_composite_id() {
        let schema = parse(COMPOSITE_ID_SCHEMA).expect("should parse");
        let model = &schema.models[0];
        assert!(model_has(model, |a| {
            matches!(a, BlockAttribute::Id(fields) if fields == &["postId", "tagId"])
        }));
    }

    #[test]
    fn test_parse_error_invalid_syntax() {
        assert!(parse("model { broken }").is_err());
    }

    #[test]
    fn test_parse_with_comments() {
        let schema = parse(COMMENTED_SCHEMA).expect("should parse with comments");
        assert!(schema.datasource.is_some());
        assert_eq!(schema.models[0].fields.len(), 2);
    }
}
633        let schema = parse(schema_str).expect("should parse with comments");
634        assert!(schema.datasource.is_some());
635        assert_eq!(schema.models[0].fields.len(), 2);
636    }
637}