1use ferriorm_core::ast::{
11 BlockAttribute, DefaultValue, EnumDef, FieldAttribute, FieldDef, FieldType, Generator,
12 LiteralValue, ModelDef, ReferentialAction, RelationAttribute, SchemaFile, Span, StringOrEnv,
13};
14use pest::Parser;
15use pest_derive::Parser;
16
17use crate::error::ParseError;
18
/// Parser type for schema files.
///
/// The implementation is derived by `pest_derive` from the grammar in
/// `grammar.pest`; the rest of this file drives it through
/// `FerriormParser::parse` and matches on the resulting `Rule` values.
#[derive(Parser)]
#[grammar = "grammar.pest"]
struct FerriormParser;
22
23pub fn parse(source: &str) -> Result<SchemaFile, ParseError> {
34 let pairs = FerriormParser::parse(Rule::schema, source)
35 .map_err(|e| ParseError::Syntax(e.to_string()))?;
36
37 let mut schema = SchemaFile {
38 datasource: None,
39 generators: Vec::new(),
40 enums: Vec::new(),
41 models: Vec::new(),
42 };
43
44 let schema_pair = pairs.into_iter().next().unwrap();
46 for pair in schema_pair.into_inner() {
47 match pair.as_rule() {
48 Rule::datasource_block => {
49 schema.datasource = Some(parse_datasource(pair)?);
50 }
51 Rule::generator_block => {
52 schema.generators.push(parse_generator(pair));
53 }
54 Rule::enum_block => {
55 schema.enums.push(parse_enum(pair));
56 }
57 Rule::model_block => {
58 schema.models.push(parse_model(pair)?);
59 }
60 _ => {}
61 }
62 }
63
64 Ok(schema)
65}
66
67fn span_from(pair: &pest::iterators::Pair<'_, Rule>) -> Span {
68 let span = pair.as_span();
69 Span {
70 start: span.start(),
71 end: span.end(),
72 }
73}
74
75fn parse_datasource(
76 pair: pest::iterators::Pair<'_, Rule>,
77) -> Result<ferriorm_core::ast::Datasource, ParseError> {
78 let span = span_from(&pair);
79 let mut inner = pair.into_inner();
80 let name = inner.next().unwrap().as_str().to_string();
81
82 let mut provider = String::new();
83 let mut url = StringOrEnv::Literal(String::new());
84
85 for kv in inner {
86 if kv.as_rule() != Rule::kv_pair {
87 continue;
88 }
89 let mut kv_inner = kv.into_inner();
90 let key = kv_inner.next().unwrap().as_str();
91 let value_pair = kv_inner.next().unwrap();
92
93 match key {
94 "provider" => {
95 provider = parse_string_value(&value_pair);
96 }
97 "url" => {
98 url = parse_string_or_env(&value_pair)?;
99 }
100 _ => {}
101 }
102 }
103
104 Ok(ferriorm_core::ast::Datasource {
105 name,
106 provider,
107 url,
108 span,
109 })
110}
111
112fn parse_generator(pair: pest::iterators::Pair<'_, Rule>) -> Generator {
113 let span = span_from(&pair);
114 let mut inner = pair.into_inner();
115 let name = inner.next().unwrap().as_str().to_string();
116
117 let mut output = None;
118
119 for kv in inner {
120 if kv.as_rule() != Rule::kv_pair {
121 continue;
122 }
123 let mut kv_inner = kv.into_inner();
124 let key = kv_inner.next().unwrap().as_str();
125 let value_pair = kv_inner.next().unwrap();
126
127 if key == "output" {
128 output = Some(parse_string_value(&value_pair));
129 }
130 }
131
132 Generator { name, output, span }
133}
134
135fn parse_enum(pair: pest::iterators::Pair<'_, Rule>) -> EnumDef {
136 let span = span_from(&pair);
137 let mut inner = pair.into_inner();
138 let name = inner.next().unwrap().as_str().to_string();
139
140 let mut variants = Vec::new();
141 for variant_pair in inner {
142 if variant_pair.as_rule() == Rule::enum_variant {
143 let variant_name = variant_pair
144 .into_inner()
145 .next()
146 .unwrap()
147 .as_str()
148 .to_string();
149 variants.push(variant_name);
150 }
151 }
152
153 EnumDef {
154 name,
155 variants,
156 db_name: None,
157 span,
158 }
159}
160
161fn parse_model(pair: pest::iterators::Pair<'_, Rule>) -> Result<ModelDef, ParseError> {
162 let span = span_from(&pair);
163 let mut inner = pair.into_inner();
164 let name = inner.next().unwrap().as_str().to_string();
165
166 let mut fields = Vec::new();
167 let mut attributes = Vec::new();
168
169 for member in inner {
170 match member.as_rule() {
171 Rule::field_def => {
172 fields.push(parse_field(member)?);
173 }
174 Rule::block_attr_index => {
175 attributes.push(BlockAttribute::Index(parse_field_list_from_block_attr(
176 member,
177 )));
178 }
179 Rule::block_attr_unique => {
180 attributes.push(BlockAttribute::Unique(parse_field_list_from_block_attr(
181 member,
182 )));
183 }
184 Rule::block_attr_map => {
185 let s = member.into_inner().next().unwrap().as_str();
186 attributes.push(BlockAttribute::Map(unquote(s)));
187 }
188 Rule::block_attr_id => {
189 attributes.push(BlockAttribute::Id(parse_field_list_from_block_attr(member)));
190 }
191 _ => {}
192 }
193 }
194
195 Ok(ModelDef {
196 name,
197 fields,
198 attributes,
199 span,
200 })
201}
202
203fn parse_field(pair: pest::iterators::Pair<'_, Rule>) -> Result<FieldDef, ParseError> {
204 let span = span_from(&pair);
205 let mut inner = pair.into_inner();
206
207 let name = inner.next().unwrap().as_str().to_string();
208 let field_type_pair = inner.next().unwrap();
209 let field_type = parse_field_type(field_type_pair);
210
211 let mut attributes = Vec::new();
212 for attr_pair in inner {
213 if let Some(attr) = parse_field_attribute(attr_pair)? {
214 attributes.push(attr);
215 }
216 }
217
218 Ok(FieldDef {
219 name,
220 field_type,
221 attributes,
222 span,
223 })
224}
225
226fn parse_field_type(pair: pest::iterators::Pair<'_, Rule>) -> FieldType {
227 let mut inner = pair.into_inner();
228 let name = inner.next().unwrap().as_str().to_string();
229
230 let mut is_list = false;
231 let mut is_optional = false;
232
233 for modifier in inner {
234 match modifier.as_rule() {
235 Rule::list_modifier => is_list = true,
236 Rule::optional_modifier => is_optional = true,
237 _ => {}
238 }
239 }
240
241 FieldType {
242 name,
243 is_list,
244 is_optional,
245 }
246}
247
248fn parse_field_attribute(
249 pair: pest::iterators::Pair<'_, Rule>,
250) -> Result<Option<FieldAttribute>, ParseError> {
251 match pair.as_rule() {
252 Rule::attr_id => Ok(Some(FieldAttribute::Id)),
253 Rule::attr_unique => Ok(Some(FieldAttribute::Unique)),
254 Rule::attr_updated_at => Ok(Some(FieldAttribute::UpdatedAt)),
255 Rule::attr_default => {
256 let value_pair = pair.into_inner().next().unwrap();
257 let default = parse_default_value(value_pair)?;
258 Ok(Some(FieldAttribute::Default(default)))
259 }
260 Rule::attr_map => {
261 let s = pair.into_inner().next().unwrap().as_str();
262 Ok(Some(FieldAttribute::Map(unquote(s))))
263 }
264 Rule::attr_relation => {
265 let relation = parse_relation_attribute(pair);
266 Ok(Some(FieldAttribute::Relation(relation)))
267 }
268 Rule::attr_db_type => {
269 let mut inner = pair.into_inner();
270 let type_name = inner.next().unwrap().as_str().to_string();
271 let args: Vec<String> = inner.map(|p| parse_string_value(&p)).collect();
272 Ok(Some(FieldAttribute::DbType(type_name, args)))
273 }
274 _ => Ok(None),
275 }
276}
277
278fn parse_default_value(pair: pest::iterators::Pair<'_, Rule>) -> Result<DefaultValue, ParseError> {
279 match pair.as_rule() {
280 Rule::func_call => {
281 let mut inner = pair.into_inner();
282 let func_name = inner.next().unwrap().as_str();
283 match func_name {
284 "uuid" => Ok(DefaultValue::Uuid),
285 "cuid" => Ok(DefaultValue::Cuid),
286 "autoincrement" => Ok(DefaultValue::AutoIncrement),
287 "now" => Ok(DefaultValue::Now),
288 other => Err(ParseError::Syntax(format!(
289 "Unknown default function: {other}()"
290 ))),
291 }
292 }
293 Rule::string_literal => Ok(DefaultValue::Literal(LiteralValue::String(unquote(
294 pair.as_str(),
295 )))),
296 Rule::number_literal => {
297 let s = pair.as_str();
298 if s.contains('.') {
299 Ok(DefaultValue::Literal(LiteralValue::Float(
300 s.parse()
301 .map_err(|e| ParseError::Syntax(format!("Invalid float: {e}")))?,
302 )))
303 } else {
304 Ok(DefaultValue::Literal(LiteralValue::Int(
305 s.parse()
306 .map_err(|e| ParseError::Syntax(format!("Invalid int: {e}")))?,
307 )))
308 }
309 }
310 Rule::boolean_literal => {
311 let b = pair.as_str() == "true";
312 Ok(DefaultValue::Literal(LiteralValue::Bool(b)))
313 }
314 Rule::identifier_value => {
315 let name = pair.into_inner().next().unwrap().as_str().to_string();
316 Ok(DefaultValue::EnumVariant(name))
317 }
318 _ => Err(ParseError::Syntax(format!(
319 "Unexpected default value: {:?}",
320 pair.as_rule()
321 ))),
322 }
323}
324
325fn parse_relation_attribute(pair: pest::iterators::Pair<'_, Rule>) -> RelationAttribute {
326 let args_pair = pair.into_inner().next().unwrap(); let mut fields = Vec::new();
328 let mut references = Vec::new();
329 let mut on_delete = None;
330 let mut on_update = None;
331 let mut name = None;
332
333 for arg in args_pair.into_inner() {
334 let named_arg = if arg.as_rule() == Rule::relation_arg {
337 arg.into_inner().next().unwrap()
338 } else if arg.as_rule() == Rule::named_arg {
339 arg
340 } else {
341 continue;
342 };
343
344 let mut named = named_arg.into_inner();
346 let key = named.next().unwrap().as_str();
347 let value_pair = named.next().unwrap();
348
349 match key {
350 "fields" => fields = parse_field_list(&value_pair),
351 "references" => references = parse_field_list(&value_pair),
352 "onDelete" => on_delete = parse_referential_action(&value_pair),
353 "onUpdate" => on_update = parse_referential_action(&value_pair),
354 "name" => name = Some(parse_string_value(&value_pair)),
355 _ => {}
356 }
357 }
358
359 RelationAttribute {
360 name,
361 fields,
362 references,
363 on_delete,
364 on_update,
365 }
366}
367
368fn parse_referential_action(pair: &pest::iterators::Pair<'_, Rule>) -> Option<ReferentialAction> {
369 let s = pair.as_str().trim_matches('"');
370 match s {
371 "Cascade" => Some(ReferentialAction::Cascade),
372 "Restrict" => Some(ReferentialAction::Restrict),
373 "NoAction" => Some(ReferentialAction::NoAction),
374 "SetNull" => Some(ReferentialAction::SetNull),
375 "SetDefault" => Some(ReferentialAction::SetDefault),
376 _ => None,
377 }
378}
379
380fn parse_field_list(pair: &pest::iterators::Pair<'_, Rule>) -> Vec<String> {
381 pair.clone()
382 .into_inner()
383 .filter(|p| p.as_rule() == Rule::identifier)
384 .map(|p| p.as_str().to_string())
385 .collect()
386}
387
388fn parse_field_list_from_block_attr(pair: pest::iterators::Pair<'_, Rule>) -> Vec<String> {
389 let field_list = pair.into_inner().next().unwrap();
390 parse_field_list(&field_list)
391}
392
393fn parse_string_or_env(pair: &pest::iterators::Pair<'_, Rule>) -> Result<StringOrEnv, ParseError> {
394 match pair.as_rule() {
395 Rule::func_call => {
396 let mut inner = pair.clone().into_inner();
397 let func_name = inner.next().unwrap().as_str();
398 if func_name == "env" {
399 let arg = inner
400 .next()
401 .ok_or_else(|| ParseError::Syntax("env() requires a string argument".into()))?;
402 Ok(StringOrEnv::Env(unquote(arg.as_str())))
403 } else {
404 Err(ParseError::Syntax(format!(
405 "Expected env(), got {func_name}()"
406 )))
407 }
408 }
409 Rule::string_literal => Ok(StringOrEnv::Literal(unquote(pair.as_str()))),
410 _ => Err(ParseError::Syntax(format!(
411 "Expected string or env(), got {:?}",
412 pair.as_rule()
413 ))),
414 }
415}
416
417fn parse_string_value(pair: &pest::iterators::Pair<'_, Rule>) -> String {
418 unquote(pair.as_str())
419}
420
/// Removes one pair of surrounding double quotes from `s`.
///
/// Text that is not wrapped in both a leading and a trailing `"` is returned
/// unchanged. This includes a lone `"`: the single character cannot serve as
/// both the opening and the closing quote, so it is passed through rather
/// than sliced (which previously panicked with an inverted range).
///
/// Escape sequences inside the quotes are not interpreted.
fn unquote(s: &str) -> String {
    s.strip_prefix('"')
        .and_then(|rest| rest.strip_suffix('"'))
        .unwrap_or(s)
        .to_string()
}
428
// Unit tests for the schema parser. Pedantic clippy lints are switched off
// for this module.
#[cfg(test)]
#[allow(clippy::pedantic)]
mod tests {
    use super::*;

    /// Schema containing one datasource, one generator, one enum and one
    /// model, used by `test_parse_basic_schema`.
    const BASIC_SCHEMA: &str = r#"
datasource db {
  provider = "postgresql"
  url = env("DATABASE_URL")
}

generator client {
  output = "./src/generated"
}

enum Role {
  User
  Admin
  Moderator
}

model User {
  id String @id @default(uuid())
  email String @unique
  name String?
  role Role @default(User)
  createdAt DateTime @default(now())
  updatedAt DateTime @updatedAt

  @@index([email])
  @@map("users")
}
"#;

    /// Parses `BASIC_SCHEMA` and checks every top-level block kind.
    #[test]
    fn test_parse_basic_schema() {
        let schema = parse(BASIC_SCHEMA).expect("should parse");

        // Datasource: name, provider and an env()-based URL.
        let ds = schema.datasource.expect("should have datasource");
        assert_eq!(ds.name, "db");
        assert_eq!(ds.provider, "postgresql");
        match &ds.url {
            StringOrEnv::Env(var) => assert_eq!(var, "DATABASE_URL"),
            _ => panic!("expected env()"),
        }

        // Generator: name and output path.
        assert_eq!(schema.generators.len(), 1);
        assert_eq!(schema.generators[0].name, "client");
        assert_eq!(
            schema.generators[0].output.as_deref(),
            Some("./src/generated")
        );

        // Enum: name and variants in declaration order.
        assert_eq!(schema.enums.len(), 1);
        assert_eq!(schema.enums[0].name, "Role");
        assert_eq!(schema.enums[0].variants, vec!["User", "Admin", "Moderator"]);

        // Model: name and field count.
        assert_eq!(schema.models.len(), 1);
        let user = &schema.models[0];
        assert_eq!(user.name, "User");
        assert_eq!(user.fields.len(), 6);

        // `id` field: required String carrying @id and @default(uuid()).
        let id_field = &user.fields[0];
        assert_eq!(id_field.name, "id");
        assert_eq!(id_field.field_type.name, "String");
        assert!(!id_field.field_type.is_optional);
        assert!(
            id_field
                .attributes
                .iter()
                .any(|a| matches!(a, FieldAttribute::Id))
        );
        assert!(
            id_field
                .attributes
                .iter()
                .any(|a| matches!(a, FieldAttribute::Default(DefaultValue::Uuid)))
        );

        // `name` field: the `?` suffix marks it optional.
        let name_field = &user.fields[2];
        assert_eq!(name_field.name, "name");
        assert!(name_field.field_type.is_optional);

        // `role` field: a bare identifier default is an enum variant.
        let role_field = &user.fields[3];
        assert_eq!(role_field.name, "role");
        assert!(role_field.attributes.iter().any(
            |a| matches!(a, FieldAttribute::Default(DefaultValue::EnumVariant(v)) if v == "User")
        ));

        // `updatedAt` field: carries @updatedAt.
        let updated_field = &user.fields[5];
        assert_eq!(updated_field.name, "updatedAt");
        assert!(
            updated_field
                .attributes
                .iter()
                .any(|a| matches!(a, FieldAttribute::UpdatedAt))
        );

        // Block attributes: @@index([email]) and @@map("users").
        assert_eq!(user.attributes.len(), 2);
        assert!(
            user.attributes
                .iter()
                .any(|a| matches!(a, BlockAttribute::Index(fields) if fields == &["email"]))
        );
        assert!(
            user.attributes
                .iter()
                .any(|a| matches!(a, BlockAttribute::Map(name) if name == "users"))
        );
    }

    /// Two models linked by an `@relation` and a list-typed back-reference.
    #[test]
    fn test_parse_multiple_models() {
        let schema_str = r#"
datasource db {
  provider = "postgresql"
  url = "postgres://localhost/test"
}

model User {
  id String @id @default(uuid())
  email String @unique
  posts Post[]
}

model Post {
  id String @id @default(uuid())
  title String
  content String?
  author User @relation(fields: [authorId], references: [id])
  authorId String

  @@index([authorId])
}
"#;

        let schema = parse(schema_str).expect("should parse");
        assert_eq!(schema.models.len(), 2);
        assert_eq!(schema.models[0].name, "User");
        assert_eq!(schema.models[1].name, "Post");

        // `Post.author` carries the relation's fields and references lists.
        let author_field = &schema.models[1].fields[3];
        assert_eq!(author_field.name, "author");
        let rel = author_field.attributes.iter().find_map(|a| match a {
            FieldAttribute::Relation(r) => Some(r),
            _ => None,
        });
        let rel = rel.expect("should have @relation");
        assert_eq!(rel.fields, vec!["authorId"]);
        assert_eq!(rel.references, vec!["id"]);

        // `User.posts` is declared as `Post[]`, so it must be a list.
        let posts_field = &schema.models[0].fields[2];
        assert_eq!(posts_field.name, "posts");
        assert!(posts_field.field_type.is_list);
    }

    /// `@@id([...])` yields a `BlockAttribute::Id` with both field names.
    #[test]
    fn test_parse_composite_id() {
        let schema_str = r#"
datasource db {
  provider = "sqlite"
  url = "file:./dev.db"
}

model PostTag {
  postId String
  tagId String

  @@id([postId, tagId])
}
"#;

        let schema = parse(schema_str).expect("should parse");
        let model = &schema.models[0];
        assert!(
            model
                .attributes
                .iter()
                .any(|a| matches!(a, BlockAttribute::Id(fields) if fields == &["postId", "tagId"]))
        );
    }

    /// A model block without a name must be rejected.
    #[test]
    fn test_parse_error_invalid_syntax() {
        let bad = "model { broken }";
        assert!(parse(bad).is_err());
    }

    /// Line comments (standalone and trailing) must not affect the result.
    #[test]
    fn test_parse_with_comments() {
        let schema_str = r#"
// This is a comment
datasource db {
  provider = "postgresql" // inline comment
  url = env("DATABASE_URL")
}

// Another comment
model User {
  id String @id @default(uuid())
  // A commented field
  name String?
}
"#;

        let schema = parse(schema_str).expect("should parse with comments");
        assert!(schema.datasource.is_some());
        assert_eq!(schema.models[0].fields.len(), 2);
    }
}