Skip to main content

ferriorm_parser/
parser.rs

1//! PEG-based parser that turns a `.ferriorm` schema string into a raw AST.
2//!
3//! Uses the `pest` parser generator with the grammar defined in
4//! `grammar.pest`. The public entry point is [`parse`], which returns an
5//! [`ferriorm_core::ast::SchemaFile`] on success or a [`ParseError`] on failure.
6//!
7//! This module only handles syntactic parsing. Semantic validation (type
8//! resolution, constraint checking) is performed by [`crate::validator`].
9
10use ferriorm_core::ast::{
11    BlockAttribute, DefaultValue, EnumDef, FieldAttribute, FieldDef, FieldType, Generator,
12    LiteralValue, ModelDef, ReferentialAction, RelationAttribute, SchemaFile, Span, StringOrEnv,
13};
14use pest::Parser;
15use pest_derive::Parser;
16
17use crate::error::ParseError;
18
/// Parser type whose implementation is derived by `pest_derive` from the
/// rules in `grammar.pest`.
#[derive(Parser)]
#[grammar = "grammar.pest"]
struct FerriormParser;
22
23/// Parse a `.ferriorm` schema string into an AST.
24///
25/// # Errors
26///
27/// Returns a [`ParseError`] if the source does not conform to the grammar.
28///
29/// # Panics
30///
31/// Panics if the PEG grammar produces no top-level pair, which indicates
32/// a bug in the grammar definition.
33pub fn parse(source: &str) -> Result<SchemaFile, ParseError> {
34    let pairs = FerriormParser::parse(Rule::schema, source)
35        .map_err(|e| ParseError::Syntax(e.to_string()))?;
36
37    let mut schema = SchemaFile {
38        datasource: None,
39        generators: Vec::new(),
40        enums: Vec::new(),
41        models: Vec::new(),
42    };
43
44    // The top-level parse result contains a single `schema` pair; iterate its inner pairs.
45    let schema_pair = pairs.into_iter().next().unwrap();
46    for pair in schema_pair.into_inner() {
47        match pair.as_rule() {
48            Rule::datasource_block => {
49                schema.datasource = Some(parse_datasource(pair)?);
50            }
51            Rule::generator_block => {
52                schema.generators.push(parse_generator(pair));
53            }
54            Rule::enum_block => {
55                schema.enums.push(parse_enum(pair));
56            }
57            Rule::model_block => {
58                schema.models.push(parse_model(pair)?);
59            }
60            _ => {}
61        }
62    }
63
64    Ok(schema)
65}
66
67fn span_from(pair: &pest::iterators::Pair<'_, Rule>) -> Span {
68    let span = pair.as_span();
69    Span {
70        start: span.start(),
71        end: span.end(),
72    }
73}
74
75fn parse_datasource(
76    pair: pest::iterators::Pair<'_, Rule>,
77) -> Result<ferriorm_core::ast::Datasource, ParseError> {
78    let span = span_from(&pair);
79    let mut inner = pair.into_inner();
80    let name = inner.next().unwrap().as_str().to_string();
81
82    let mut provider = String::new();
83    let mut url = StringOrEnv::Literal(String::new());
84
85    for kv in inner {
86        if kv.as_rule() != Rule::kv_pair {
87            continue;
88        }
89        let mut kv_inner = kv.into_inner();
90        let key = kv_inner.next().unwrap().as_str();
91        let value_pair = kv_inner.next().unwrap();
92
93        match key {
94            "provider" => {
95                provider = parse_string_value(&value_pair);
96            }
97            "url" => {
98                url = parse_string_or_env(&value_pair)?;
99            }
100            _ => {}
101        }
102    }
103
104    Ok(ferriorm_core::ast::Datasource {
105        name,
106        provider,
107        url,
108        span,
109    })
110}
111
112fn parse_generator(pair: pest::iterators::Pair<'_, Rule>) -> Generator {
113    let span = span_from(&pair);
114    let mut inner = pair.into_inner();
115    let name = inner.next().unwrap().as_str().to_string();
116
117    let mut output = None;
118
119    for kv in inner {
120        if kv.as_rule() != Rule::kv_pair {
121            continue;
122        }
123        let mut kv_inner = kv.into_inner();
124        let key = kv_inner.next().unwrap().as_str();
125        let value_pair = kv_inner.next().unwrap();
126
127        if key == "output" {
128            output = Some(parse_string_value(&value_pair));
129        }
130    }
131
132    Generator { name, output, span }
133}
134
135fn parse_enum(pair: pest::iterators::Pair<'_, Rule>) -> EnumDef {
136    let span = span_from(&pair);
137    let mut inner = pair.into_inner();
138    let name = inner.next().unwrap().as_str().to_string();
139
140    let mut variants = Vec::new();
141    for variant_pair in inner {
142        if variant_pair.as_rule() == Rule::enum_variant {
143            let variant_name = variant_pair
144                .into_inner()
145                .next()
146                .unwrap()
147                .as_str()
148                .to_string();
149            variants.push(variant_name);
150        }
151    }
152
153    EnumDef {
154        name,
155        variants,
156        db_name: None,
157        span,
158    }
159}
160
161fn parse_model(pair: pest::iterators::Pair<'_, Rule>) -> Result<ModelDef, ParseError> {
162    let span = span_from(&pair);
163    let mut inner = pair.into_inner();
164    let name = inner.next().unwrap().as_str().to_string();
165
166    let mut fields = Vec::new();
167    let mut attributes = Vec::new();
168
169    for member in inner {
170        match member.as_rule() {
171            Rule::field_def => {
172                fields.push(parse_field(member)?);
173            }
174            Rule::block_attr_index => {
175                attributes.push(BlockAttribute::Index(parse_field_list_from_block_attr(
176                    member,
177                )));
178            }
179            Rule::block_attr_unique => {
180                attributes.push(BlockAttribute::Unique(parse_field_list_from_block_attr(
181                    member,
182                )));
183            }
184            Rule::block_attr_map => {
185                let s = member.into_inner().next().unwrap().as_str();
186                attributes.push(BlockAttribute::Map(unquote(s)));
187            }
188            Rule::block_attr_id => {
189                attributes.push(BlockAttribute::Id(parse_field_list_from_block_attr(member)));
190            }
191            _ => {}
192        }
193    }
194
195    Ok(ModelDef {
196        name,
197        fields,
198        attributes,
199        span,
200    })
201}
202
203fn parse_field(pair: pest::iterators::Pair<'_, Rule>) -> Result<FieldDef, ParseError> {
204    let span = span_from(&pair);
205    let mut inner = pair.into_inner();
206
207    let name = inner.next().unwrap().as_str().to_string();
208    let field_type_pair = inner.next().unwrap();
209    let field_type = parse_field_type(field_type_pair);
210
211    let mut attributes = Vec::new();
212    for attr_pair in inner {
213        if let Some(attr) = parse_field_attribute(attr_pair)? {
214            attributes.push(attr);
215        }
216    }
217
218    Ok(FieldDef {
219        name,
220        field_type,
221        attributes,
222        span,
223    })
224}
225
226fn parse_field_type(pair: pest::iterators::Pair<'_, Rule>) -> FieldType {
227    let mut inner = pair.into_inner();
228    let name = inner.next().unwrap().as_str().to_string();
229
230    let mut is_list = false;
231    let mut is_optional = false;
232
233    for modifier in inner {
234        match modifier.as_rule() {
235            Rule::list_modifier => is_list = true,
236            Rule::optional_modifier => is_optional = true,
237            _ => {}
238        }
239    }
240
241    FieldType {
242        name,
243        is_list,
244        is_optional,
245    }
246}
247
248fn parse_field_attribute(
249    pair: pest::iterators::Pair<'_, Rule>,
250) -> Result<Option<FieldAttribute>, ParseError> {
251    match pair.as_rule() {
252        Rule::attr_id => Ok(Some(FieldAttribute::Id)),
253        Rule::attr_unique => Ok(Some(FieldAttribute::Unique)),
254        Rule::attr_updated_at => Ok(Some(FieldAttribute::UpdatedAt)),
255        Rule::attr_default => {
256            let value_pair = pair.into_inner().next().unwrap();
257            let default = parse_default_value(value_pair)?;
258            Ok(Some(FieldAttribute::Default(default)))
259        }
260        Rule::attr_map => {
261            let s = pair.into_inner().next().unwrap().as_str();
262            Ok(Some(FieldAttribute::Map(unquote(s))))
263        }
264        Rule::attr_relation => {
265            let relation = parse_relation_attribute(pair);
266            Ok(Some(FieldAttribute::Relation(relation)))
267        }
268        Rule::attr_db_type => {
269            let mut inner = pair.into_inner();
270            let type_name = inner.next().unwrap().as_str().to_string();
271            let args: Vec<String> = inner.map(|p| parse_string_value(&p)).collect();
272            Ok(Some(FieldAttribute::DbType(type_name, args)))
273        }
274        _ => Ok(None),
275    }
276}
277
278fn parse_default_value(pair: pest::iterators::Pair<'_, Rule>) -> Result<DefaultValue, ParseError> {
279    match pair.as_rule() {
280        Rule::func_call => {
281            let mut inner = pair.into_inner();
282            let func_name = inner.next().unwrap().as_str();
283            match func_name {
284                "uuid" => Ok(DefaultValue::Uuid),
285                "cuid" => Ok(DefaultValue::Cuid),
286                "autoincrement" => Ok(DefaultValue::AutoIncrement),
287                "now" => Ok(DefaultValue::Now),
288                other => Err(ParseError::Syntax(format!(
289                    "Unknown default function: {other}()"
290                ))),
291            }
292        }
293        Rule::string_literal => Ok(DefaultValue::Literal(LiteralValue::String(unquote(
294            pair.as_str(),
295        )))),
296        Rule::number_literal => {
297            let s = pair.as_str();
298            if s.contains('.') {
299                Ok(DefaultValue::Literal(LiteralValue::Float(
300                    s.parse()
301                        .map_err(|e| ParseError::Syntax(format!("Invalid float: {e}")))?,
302                )))
303            } else {
304                Ok(DefaultValue::Literal(LiteralValue::Int(
305                    s.parse()
306                        .map_err(|e| ParseError::Syntax(format!("Invalid int: {e}")))?,
307                )))
308            }
309        }
310        Rule::boolean_literal => {
311            let b = pair.as_str() == "true";
312            Ok(DefaultValue::Literal(LiteralValue::Bool(b)))
313        }
314        Rule::identifier_value => {
315            let name = pair.into_inner().next().unwrap().as_str().to_string();
316            Ok(DefaultValue::EnumVariant(name))
317        }
318        _ => Err(ParseError::Syntax(format!(
319            "Unexpected default value: {:?}",
320            pair.as_rule()
321        ))),
322    }
323}
324
325fn parse_relation_attribute(pair: pest::iterators::Pair<'_, Rule>) -> RelationAttribute {
326    let args_pair = pair.into_inner().next().unwrap(); // relation_args
327    let mut fields = Vec::new();
328    let mut references = Vec::new();
329    let mut on_delete = None;
330    let mut on_update = None;
331    let mut name = None;
332
333    for arg in args_pair.into_inner() {
334        // relation_arg = { named_arg }
335        // Each arg is a relation_arg containing a named_arg
336        let named_arg = if arg.as_rule() == Rule::relation_arg {
337            arg.into_inner().next().unwrap()
338        } else if arg.as_rule() == Rule::named_arg {
339            arg
340        } else {
341            continue;
342        };
343
344        // named_arg = { identifier ~ ":" ~ (field_list | value) }
345        let mut named = named_arg.into_inner();
346        let key = named.next().unwrap().as_str();
347        let value_pair = named.next().unwrap();
348
349        match key {
350            "fields" => fields = parse_field_list(&value_pair),
351            "references" => references = parse_field_list(&value_pair),
352            "onDelete" => on_delete = parse_referential_action(&value_pair),
353            "onUpdate" => on_update = parse_referential_action(&value_pair),
354            "name" => name = Some(parse_string_value(&value_pair)),
355            _ => {}
356        }
357    }
358
359    RelationAttribute {
360        name,
361        fields,
362        references,
363        on_delete,
364        on_update,
365    }
366}
367
368fn parse_referential_action(pair: &pest::iterators::Pair<'_, Rule>) -> Option<ReferentialAction> {
369    let s = pair.as_str().trim_matches('"');
370    match s {
371        "Cascade" => Some(ReferentialAction::Cascade),
372        "Restrict" => Some(ReferentialAction::Restrict),
373        "NoAction" => Some(ReferentialAction::NoAction),
374        "SetNull" => Some(ReferentialAction::SetNull),
375        "SetDefault" => Some(ReferentialAction::SetDefault),
376        _ => None,
377    }
378}
379
380fn parse_field_list(pair: &pest::iterators::Pair<'_, Rule>) -> Vec<String> {
381    pair.clone()
382        .into_inner()
383        .filter(|p| p.as_rule() == Rule::identifier)
384        .map(|p| p.as_str().to_string())
385        .collect()
386}
387
388fn parse_field_list_from_block_attr(pair: pest::iterators::Pair<'_, Rule>) -> Vec<String> {
389    let field_list = pair.into_inner().next().unwrap();
390    parse_field_list(&field_list)
391}
392
393fn parse_string_or_env(pair: &pest::iterators::Pair<'_, Rule>) -> Result<StringOrEnv, ParseError> {
394    match pair.as_rule() {
395        Rule::func_call => {
396            let mut inner = pair.clone().into_inner();
397            let func_name = inner.next().unwrap().as_str();
398            if func_name == "env" {
399                let arg = inner
400                    .next()
401                    .ok_or_else(|| ParseError::Syntax("env() requires a string argument".into()))?;
402                Ok(StringOrEnv::Env(unquote(arg.as_str())))
403            } else {
404                Err(ParseError::Syntax(format!(
405                    "Expected env(), got {func_name}()"
406                )))
407            }
408        }
409        Rule::string_literal => Ok(StringOrEnv::Literal(unquote(pair.as_str()))),
410        _ => Err(ParseError::Syntax(format!(
411            "Expected string or env(), got {:?}",
412            pair.as_rule()
413        ))),
414    }
415}
416
417fn parse_string_value(pair: &pest::iterators::Pair<'_, Rule>) -> String {
418    unquote(pair.as_str())
419}
420
/// Strip one pair of surrounding double quotes from `s`, if present.
///
/// Text that does not both start and end with `"` is returned unchanged.
/// Escape sequences inside the quotes are not interpreted.
fn unquote(s: &str) -> String {
    // Stripping the prefix first and the suffix from what remains requires
    // two distinct quote characters, so a lone `"` is returned as-is rather
    // than being sliced as `1..0`, which would panic.
    s.strip_prefix('"')
        .and_then(|rest| rest.strip_suffix('"'))
        .unwrap_or(s)
        .to_string()
}
428
// Unit tests for the parser: each test feeds a schema string to `parse` and
// inspects the resulting AST.
#[cfg(test)]
// Pedantic clippy lints are switched off for the test module only.
#[allow(clippy::pedantic)]
mod tests {
    use super::*;

    /// Schema containing one block of every top-level kind: a datasource
    /// using `env(...)`, a generator, an enum, and a model with field and
    /// block attributes.
    const BASIC_SCHEMA: &str = r#"
datasource db {
  provider = "postgresql"
  url      = env("DATABASE_URL")
}

generator client {
  output = "./src/generated"
}

enum Role {
  User
  Admin
  Moderator
}

model User {
  id        String   @id @default(uuid())
  email     String   @unique
  name      String?
  role      Role     @default(User)
  createdAt DateTime @default(now())
  updatedAt DateTime @updatedAt

  @@index([email])
  @@map("users")
}
"#;

    /// Parses `BASIC_SCHEMA` and checks every top-level block plus selected
    /// field and block attributes of the `User` model.
    #[test]
    fn test_parse_basic_schema() {
        let schema = parse(BASIC_SCHEMA).expect("should parse");

        // Datasource
        let ds = schema.datasource.expect("should have datasource");
        assert_eq!(ds.name, "db");
        assert_eq!(ds.provider, "postgresql");
        match &ds.url {
            StringOrEnv::Env(var) => assert_eq!(var, "DATABASE_URL"),
            _ => panic!("expected env()"),
        }

        // Generator
        assert_eq!(schema.generators.len(), 1);
        assert_eq!(schema.generators[0].name, "client");
        assert_eq!(
            schema.generators[0].output.as_deref(),
            Some("./src/generated")
        );

        // Enum
        assert_eq!(schema.enums.len(), 1);
        assert_eq!(schema.enums[0].name, "Role");
        assert_eq!(schema.enums[0].variants, vec!["User", "Admin", "Moderator"]);

        // Model
        assert_eq!(schema.models.len(), 1);
        let user = &schema.models[0];
        assert_eq!(user.name, "User");
        assert_eq!(user.fields.len(), 6);

        // id field
        let id_field = &user.fields[0];
        assert_eq!(id_field.name, "id");
        assert_eq!(id_field.field_type.name, "String");
        assert!(!id_field.field_type.is_optional);
        assert!(
            id_field
                .attributes
                .iter()
                .any(|a| matches!(a, FieldAttribute::Id))
        );
        assert!(
            id_field
                .attributes
                .iter()
                .any(|a| matches!(a, FieldAttribute::Default(DefaultValue::Uuid)))
        );

        // name field is optional
        let name_field = &user.fields[2];
        assert_eq!(name_field.name, "name");
        assert!(name_field.field_type.is_optional);

        // role field has enum default
        let role_field = &user.fields[3];
        assert_eq!(role_field.name, "role");
        assert!(role_field.attributes.iter().any(
            |a| matches!(a, FieldAttribute::Default(DefaultValue::EnumVariant(v)) if v == "User")
        ));

        // updatedAt has @updatedAt
        let updated_field = &user.fields[5];
        assert_eq!(updated_field.name, "updatedAt");
        assert!(
            updated_field
                .attributes
                .iter()
                .any(|a| matches!(a, FieldAttribute::UpdatedAt))
        );

        // Block attributes
        assert_eq!(user.attributes.len(), 2);
        assert!(
            user.attributes
                .iter()
                .any(|a| matches!(a, BlockAttribute::Index(fields) if fields == &["email"]))
        );
        assert!(
            user.attributes
                .iter()
                .any(|a| matches!(a, BlockAttribute::Map(name) if name == "users"))
        );
    }

    /// Parses two related models and checks model order, the `@relation`
    /// arguments on `Post.author`, and the list modifier on `User.posts`.
    #[test]
    fn test_parse_multiple_models() {
        let schema_str = r#"
datasource db {
  provider = "postgresql"
  url      = "postgres://localhost/test"
}

model User {
  id    String @id @default(uuid())
  email String @unique
  posts Post[]
}

model Post {
  id       String  @id @default(uuid())
  title    String
  content  String?
  author   User    @relation(fields: [authorId], references: [id])
  authorId String

  @@index([authorId])
}
"#;

        let schema = parse(schema_str).expect("should parse");
        assert_eq!(schema.models.len(), 2);
        assert_eq!(schema.models[0].name, "User");
        assert_eq!(schema.models[1].name, "Post");

        // Check relation attribute on Post.author
        let author_field = &schema.models[1].fields[3];
        assert_eq!(author_field.name, "author");
        let rel = author_field.attributes.iter().find_map(|a| match a {
            FieldAttribute::Relation(r) => Some(r),
            _ => None,
        });
        let rel = rel.expect("should have @relation");
        assert_eq!(rel.fields, vec!["authorId"]);
        assert_eq!(rel.references, vec!["id"]);

        // Check Post[] is a list
        let posts_field = &schema.models[0].fields[2];
        assert_eq!(posts_field.name, "posts");
        assert!(posts_field.field_type.is_list);
    }

    /// Checks that a model-level `@@id([...])` is parsed into
    /// `BlockAttribute::Id` with both field names in order.
    #[test]
    fn test_parse_composite_id() {
        let schema_str = r#"
datasource db {
  provider = "sqlite"
  url      = "file:./dev.db"
}

model PostTag {
  postId String
  tagId  String

  @@id([postId, tagId])
}
"#;

        let schema = parse(schema_str).expect("should parse");
        let model = &schema.models[0];
        assert!(
            model
                .attributes
                .iter()
                .any(|a| matches!(a, BlockAttribute::Id(fields) if fields == &["postId", "tagId"]))
        );
    }

    /// Checks that a model block without a name is rejected with an error.
    #[test]
    fn test_parse_error_invalid_syntax() {
        let bad = "model { broken }";
        assert!(parse(bad).is_err());
    }

    /// Checks that `//` comments on their own line, after a value, and
    /// between fields do not affect the parsed result.
    #[test]
    fn test_parse_with_comments() {
        let schema_str = r#"
// This is a comment
datasource db {
  provider = "postgresql" // inline comment
  url      = env("DATABASE_URL")
}

// Another comment
model User {
  id String @id @default(uuid())
  // A commented field
  name String?
}
"#;

        let schema = parse(schema_str).expect("should parse with comments");
        assert!(schema.datasource.is_some());
        assert_eq!(schema.models[0].fields.len(), 2);
    }
}