1use ferriorm_core::ast::*;
11use pest::Parser;
12use pest_derive::Parser;
13
14use crate::error::ParseError;
15
/// Pest parser for the schema language, generated by `pest_derive` from the
/// `grammar.pest` grammar file.
///
/// The derive also generates the `Rule` enum that every parsing function in this
/// module matches on.
#[derive(Parser)]
#[grammar = "grammar.pest"]
struct FerriormParser;
19
20pub fn parse(source: &str) -> Result<SchemaFile, ParseError> {
22 let pairs = FerriormParser::parse(Rule::schema, source)
23 .map_err(|e| ParseError::Syntax(e.to_string()))?;
24
25 let mut schema = SchemaFile {
26 datasource: None,
27 generators: Vec::new(),
28 enums: Vec::new(),
29 models: Vec::new(),
30 };
31
32 let schema_pair = pairs.into_iter().next().unwrap();
34 for pair in schema_pair.into_inner() {
35 match pair.as_rule() {
36 Rule::datasource_block => {
37 schema.datasource = Some(parse_datasource(pair)?);
38 }
39 Rule::generator_block => {
40 schema.generators.push(parse_generator(pair)?);
41 }
42 Rule::enum_block => {
43 schema.enums.push(parse_enum(pair)?);
44 }
45 Rule::model_block => {
46 schema.models.push(parse_model(pair)?);
47 }
48 Rule::EOI => {}
49 _ => {}
50 }
51 }
52
53 Ok(schema)
54}
55
56fn span_from(pair: &pest::iterators::Pair<'_, Rule>) -> Span {
57 let span = pair.as_span();
58 Span {
59 start: span.start(),
60 end: span.end(),
61 }
62}
63
64fn parse_datasource(pair: pest::iterators::Pair<'_, Rule>) -> Result<Datasource, ParseError> {
65 let span = span_from(&pair);
66 let mut inner = pair.into_inner();
67 let name = inner.next().unwrap().as_str().to_string();
68
69 let mut provider = String::new();
70 let mut url = StringOrEnv::Literal(String::new());
71
72 for kv in inner {
73 if kv.as_rule() != Rule::kv_pair {
74 continue;
75 }
76 let mut kv_inner = kv.into_inner();
77 let key = kv_inner.next().unwrap().as_str();
78 let value_pair = kv_inner.next().unwrap();
79
80 match key {
81 "provider" => {
82 provider = parse_string_value(&value_pair);
83 }
84 "url" => {
85 url = parse_string_or_env(&value_pair)?;
86 }
87 _ => {}
88 }
89 }
90
91 Ok(Datasource {
92 name,
93 provider,
94 url,
95 span,
96 })
97}
98
99fn parse_generator(pair: pest::iterators::Pair<'_, Rule>) -> Result<Generator, ParseError> {
100 let span = span_from(&pair);
101 let mut inner = pair.into_inner();
102 let name = inner.next().unwrap().as_str().to_string();
103
104 let mut output = None;
105
106 for kv in inner {
107 if kv.as_rule() != Rule::kv_pair {
108 continue;
109 }
110 let mut kv_inner = kv.into_inner();
111 let key = kv_inner.next().unwrap().as_str();
112 let value_pair = kv_inner.next().unwrap();
113
114 if key == "output" {
115 output = Some(parse_string_value(&value_pair));
116 }
117 }
118
119 Ok(Generator { name, output, span })
120}
121
122fn parse_enum(pair: pest::iterators::Pair<'_, Rule>) -> Result<EnumDef, ParseError> {
123 let span = span_from(&pair);
124 let mut inner = pair.into_inner();
125 let name = inner.next().unwrap().as_str().to_string();
126
127 let mut variants = Vec::new();
128 for variant_pair in inner {
129 if variant_pair.as_rule() == Rule::enum_variant {
130 let variant_name = variant_pair
131 .into_inner()
132 .next()
133 .unwrap()
134 .as_str()
135 .to_string();
136 variants.push(variant_name);
137 }
138 }
139
140 Ok(EnumDef {
141 name,
142 variants,
143 db_name: None,
144 span,
145 })
146}
147
148fn parse_model(pair: pest::iterators::Pair<'_, Rule>) -> Result<ModelDef, ParseError> {
149 let span = span_from(&pair);
150 let mut inner = pair.into_inner();
151 let name = inner.next().unwrap().as_str().to_string();
152
153 let mut fields = Vec::new();
154 let mut attributes = Vec::new();
155
156 for member in inner {
157 match member.as_rule() {
158 Rule::field_def => {
159 fields.push(parse_field(member)?);
160 }
161 Rule::block_attr_index => {
162 attributes.push(BlockAttribute::Index(parse_field_list_from_block_attr(
163 member,
164 )));
165 }
166 Rule::block_attr_unique => {
167 attributes.push(BlockAttribute::Unique(parse_field_list_from_block_attr(
168 member,
169 )));
170 }
171 Rule::block_attr_map => {
172 let s = member.into_inner().next().unwrap().as_str();
173 attributes.push(BlockAttribute::Map(unquote(s)));
174 }
175 Rule::block_attr_id => {
176 attributes.push(BlockAttribute::Id(parse_field_list_from_block_attr(member)));
177 }
178 _ => {}
179 }
180 }
181
182 Ok(ModelDef {
183 name,
184 fields,
185 attributes,
186 span,
187 })
188}
189
190fn parse_field(pair: pest::iterators::Pair<'_, Rule>) -> Result<FieldDef, ParseError> {
191 let span = span_from(&pair);
192 let mut inner = pair.into_inner();
193
194 let name = inner.next().unwrap().as_str().to_string();
195 let field_type_pair = inner.next().unwrap();
196 let field_type = parse_field_type(field_type_pair);
197
198 let mut attributes = Vec::new();
199 for attr_pair in inner {
200 if let Some(attr) = parse_field_attribute(attr_pair)? {
201 attributes.push(attr);
202 }
203 }
204
205 Ok(FieldDef {
206 name,
207 field_type,
208 attributes,
209 span,
210 })
211}
212
213fn parse_field_type(pair: pest::iterators::Pair<'_, Rule>) -> FieldType {
214 let mut inner = pair.into_inner();
215 let name = inner.next().unwrap().as_str().to_string();
216
217 let mut is_list = false;
218 let mut is_optional = false;
219
220 for modifier in inner {
221 match modifier.as_rule() {
222 Rule::list_modifier => is_list = true,
223 Rule::optional_modifier => is_optional = true,
224 _ => {}
225 }
226 }
227
228 FieldType {
229 name,
230 is_list,
231 is_optional,
232 }
233}
234
235fn parse_field_attribute(
236 pair: pest::iterators::Pair<'_, Rule>,
237) -> Result<Option<FieldAttribute>, ParseError> {
238 match pair.as_rule() {
239 Rule::attr_id => Ok(Some(FieldAttribute::Id)),
240 Rule::attr_unique => Ok(Some(FieldAttribute::Unique)),
241 Rule::attr_updated_at => Ok(Some(FieldAttribute::UpdatedAt)),
242 Rule::attr_default => {
243 let value_pair = pair.into_inner().next().unwrap();
244 let default = parse_default_value(value_pair)?;
245 Ok(Some(FieldAttribute::Default(default)))
246 }
247 Rule::attr_map => {
248 let s = pair.into_inner().next().unwrap().as_str();
249 Ok(Some(FieldAttribute::Map(unquote(s))))
250 }
251 Rule::attr_relation => {
252 let relation = parse_relation_attribute(pair)?;
253 Ok(Some(FieldAttribute::Relation(relation)))
254 }
255 Rule::attr_db_type => {
256 let mut inner = pair.into_inner();
257 let type_name = inner.next().unwrap().as_str().to_string();
258 let args: Vec<String> = inner.map(|p| parse_string_value(&p)).collect();
259 Ok(Some(FieldAttribute::DbType(type_name, args)))
260 }
261 _ => Ok(None),
262 }
263}
264
265fn parse_default_value(pair: pest::iterators::Pair<'_, Rule>) -> Result<DefaultValue, ParseError> {
266 match pair.as_rule() {
267 Rule::func_call => {
268 let mut inner = pair.into_inner();
269 let func_name = inner.next().unwrap().as_str();
270 match func_name {
271 "uuid" => Ok(DefaultValue::Uuid),
272 "cuid" => Ok(DefaultValue::Cuid),
273 "autoincrement" => Ok(DefaultValue::AutoIncrement),
274 "now" => Ok(DefaultValue::Now),
275 other => Err(ParseError::Syntax(format!(
276 "Unknown default function: {other}()"
277 ))),
278 }
279 }
280 Rule::string_literal => Ok(DefaultValue::Literal(LiteralValue::String(unquote(
281 pair.as_str(),
282 )))),
283 Rule::number_literal => {
284 let s = pair.as_str();
285 if s.contains('.') {
286 Ok(DefaultValue::Literal(LiteralValue::Float(
287 s.parse()
288 .map_err(|e| ParseError::Syntax(format!("Invalid float: {e}")))?,
289 )))
290 } else {
291 Ok(DefaultValue::Literal(LiteralValue::Int(
292 s.parse()
293 .map_err(|e| ParseError::Syntax(format!("Invalid int: {e}")))?,
294 )))
295 }
296 }
297 Rule::boolean_literal => {
298 let b = pair.as_str() == "true";
299 Ok(DefaultValue::Literal(LiteralValue::Bool(b)))
300 }
301 Rule::identifier_value => {
302 let name = pair.into_inner().next().unwrap().as_str().to_string();
303 Ok(DefaultValue::EnumVariant(name))
304 }
305 _ => Err(ParseError::Syntax(format!(
306 "Unexpected default value: {:?}",
307 pair.as_rule()
308 ))),
309 }
310}
311
312fn parse_relation_attribute(
313 pair: pest::iterators::Pair<'_, Rule>,
314) -> Result<RelationAttribute, ParseError> {
315 let args_pair = pair.into_inner().next().unwrap(); let mut fields = Vec::new();
317 let mut references = Vec::new();
318 let mut on_delete = None;
319 let mut on_update = None;
320 let mut name = None;
321
322 for arg in args_pair.into_inner() {
323 let named_arg = if arg.as_rule() == Rule::relation_arg {
326 arg.into_inner().next().unwrap()
327 } else if arg.as_rule() == Rule::named_arg {
328 arg
329 } else {
330 continue;
331 };
332
333 let mut named = named_arg.into_inner();
335 let key = named.next().unwrap().as_str();
336 let value_pair = named.next().unwrap();
337
338 match key {
339 "fields" => fields = parse_field_list(&value_pair),
340 "references" => references = parse_field_list(&value_pair),
341 "onDelete" => on_delete = parse_referential_action(&value_pair),
342 "onUpdate" => on_update = parse_referential_action(&value_pair),
343 "name" => name = Some(parse_string_value(&value_pair)),
344 _ => {}
345 }
346 }
347
348 Ok(RelationAttribute {
349 name,
350 fields,
351 references,
352 on_delete,
353 on_update,
354 })
355}
356
357fn parse_referential_action(pair: &pest::iterators::Pair<'_, Rule>) -> Option<ReferentialAction> {
358 let s = pair.as_str().trim_matches('"');
359 match s {
360 "Cascade" => Some(ReferentialAction::Cascade),
361 "Restrict" => Some(ReferentialAction::Restrict),
362 "NoAction" => Some(ReferentialAction::NoAction),
363 "SetNull" => Some(ReferentialAction::SetNull),
364 "SetDefault" => Some(ReferentialAction::SetDefault),
365 _ => None,
366 }
367}
368
369fn parse_field_list(pair: &pest::iterators::Pair<'_, Rule>) -> Vec<String> {
370 pair.clone()
371 .into_inner()
372 .filter(|p| p.as_rule() == Rule::identifier)
373 .map(|p| p.as_str().to_string())
374 .collect()
375}
376
377fn parse_field_list_from_block_attr(pair: pest::iterators::Pair<'_, Rule>) -> Vec<String> {
378 let field_list = pair.into_inner().next().unwrap();
379 parse_field_list(&field_list)
380}
381
382fn parse_string_or_env(pair: &pest::iterators::Pair<'_, Rule>) -> Result<StringOrEnv, ParseError> {
383 match pair.as_rule() {
384 Rule::func_call => {
385 let mut inner = pair.clone().into_inner();
386 let func_name = inner.next().unwrap().as_str();
387 if func_name == "env" {
388 let arg = inner
389 .next()
390 .ok_or_else(|| ParseError::Syntax("env() requires a string argument".into()))?;
391 Ok(StringOrEnv::Env(unquote(arg.as_str())))
392 } else {
393 Err(ParseError::Syntax(format!(
394 "Expected env(), got {func_name}()"
395 )))
396 }
397 }
398 Rule::string_literal => Ok(StringOrEnv::Literal(unquote(pair.as_str()))),
399 _ => Err(ParseError::Syntax(format!(
400 "Expected string or env(), got {:?}",
401 pair.as_rule()
402 ))),
403 }
404}
405
406fn parse_string_value(pair: &pest::iterators::Pair<'_, Rule>) -> String {
407 unquote(pair.as_str())
408}
409
/// Strips one pair of surrounding double quotes from `s`, if present.
///
/// Input not wrapped in `"` on both ends is returned unchanged. A lone `"` is also
/// returned unchanged. It both starts and ends with a quote, but it has no second
/// quote to strip, and the previous `s[1..len - 1]` slice panicked on it. Escape
/// sequences inside the string are not processed.
fn unquote(s: &str) -> String {
    s.strip_prefix('"')
        .and_then(|inner| inner.strip_suffix('"'))
        .unwrap_or(s)
        .to_string()
}
417
/// Unit tests for `parse`, each driven by an inline schema string.
#[cfg(test)]
mod tests {
    use super::*;

    // Schema exercising every top-level block kind: a datasource whose url uses
    // `env()`, a generator, an enum, and a model with field and block attributes.
    const BASIC_SCHEMA: &str = r#"
datasource db {
  provider = "postgresql"
  url = env("DATABASE_URL")
}

generator client {
  output = "./src/generated"
}

enum Role {
  User
  Admin
  Moderator
}

model User {
  id String @id @default(uuid())
  email String @unique
  name String?
  role Role @default(User)
  createdAt DateTime @default(now())
  updatedAt DateTime @updatedAt

  @@index([email])
  @@map("users")
}
"#;

    /// Parses `BASIC_SCHEMA` and checks every block, selected fields, and the
    /// model-level block attributes.
    #[test]
    fn test_parse_basic_schema() {
        let schema = parse(BASIC_SCHEMA).expect("should parse");

        // Datasource: name, unquoted provider, and the env-var url.
        let ds = schema.datasource.expect("should have datasource");
        assert_eq!(ds.name, "db");
        assert_eq!(ds.provider, "postgresql");
        match &ds.url {
            StringOrEnv::Env(var) => assert_eq!(var, "DATABASE_URL"),
            _ => panic!("expected env()"),
        }

        // Generator with an unquoted output path.
        assert_eq!(schema.generators.len(), 1);
        assert_eq!(schema.generators[0].name, "client");
        assert_eq!(
            schema.generators[0].output.as_deref(),
            Some("./src/generated")
        );

        // Enum variants are kept in declaration order.
        assert_eq!(schema.enums.len(), 1);
        assert_eq!(schema.enums[0].name, "Role");
        assert_eq!(schema.enums[0].variants, vec!["User", "Admin", "Moderator"]);

        // The single model and its six fields.
        assert_eq!(schema.models.len(), 1);
        let user = &schema.models[0];
        assert_eq!(user.name, "User");
        assert_eq!(user.fields.len(), 6);

        // `id String @id @default(uuid())`
        let id_field = &user.fields[0];
        assert_eq!(id_field.name, "id");
        assert_eq!(id_field.field_type.name, "String");
        assert!(!id_field.field_type.is_optional);
        assert!(
            id_field
                .attributes
                .iter()
                .any(|a| matches!(a, FieldAttribute::Id))
        );
        assert!(
            id_field
                .attributes
                .iter()
                .any(|a| matches!(a, FieldAttribute::Default(DefaultValue::Uuid)))
        );

        // `name String?` is optional.
        let name_field = &user.fields[2];
        assert_eq!(name_field.name, "name");
        assert!(name_field.field_type.is_optional);

        // `role Role @default(User)`: a bare identifier default is an enum variant.
        let role_field = &user.fields[3];
        assert_eq!(role_field.name, "role");
        assert!(role_field.attributes.iter().any(
            |a| matches!(a, FieldAttribute::Default(DefaultValue::EnumVariant(v)) if v == "User")
        ));

        // `updatedAt DateTime @updatedAt`
        let updated_field = &user.fields[5];
        assert_eq!(updated_field.name, "updatedAt");
        assert!(
            updated_field
                .attributes
                .iter()
                .any(|a| matches!(a, FieldAttribute::UpdatedAt))
        );

        // Block attributes: `@@index([email])` and `@@map("users")`.
        assert_eq!(user.attributes.len(), 2);
        assert!(
            user.attributes
                .iter()
                .any(|a| matches!(a, BlockAttribute::Index(fields) if fields == &["email"]))
        );
        assert!(
            user.attributes
                .iter()
                .any(|a| matches!(a, BlockAttribute::Map(name) if name == "users"))
        );
    }

    /// Two related models: checks model order, `@relation(fields, references)`
    /// and a list-typed back-reference field.
    #[test]
    fn test_parse_multiple_models() {
        let schema_str = r#"
datasource db {
  provider = "postgresql"
  url = "postgres://localhost/test"
}

model User {
  id String @id @default(uuid())
  email String @unique
  posts Post[]
}

model Post {
  id String @id @default(uuid())
  title String
  content String?
  author User @relation(fields: [authorId], references: [id])
  authorId String

  @@index([authorId])
}
"#;

        let schema = parse(schema_str).expect("should parse");
        assert_eq!(schema.models.len(), 2);
        assert_eq!(schema.models[0].name, "User");
        assert_eq!(schema.models[1].name, "Post");

        // `author User @relation(fields: [authorId], references: [id])`
        let author_field = &schema.models[1].fields[3];
        assert_eq!(author_field.name, "author");
        let rel = author_field.attributes.iter().find_map(|a| match a {
            FieldAttribute::Relation(r) => Some(r),
            _ => None,
        });
        let rel = rel.expect("should have @relation");
        assert_eq!(rel.fields, vec!["authorId"]);
        assert_eq!(rel.references, vec!["id"]);

        // `posts Post[]` is a list type.
        let posts_field = &schema.models[0].fields[2];
        assert_eq!(posts_field.name, "posts");
        assert!(posts_field.field_type.is_list);
    }

    /// `@@id([a, b])` produces a composite `BlockAttribute::Id` in field order.
    #[test]
    fn test_parse_composite_id() {
        let schema_str = r#"
datasource db {
  provider = "sqlite"
  url = "file:./dev.db"
}

model PostTag {
  postId String
  tagId String

  @@id([postId, tagId])
}
"#;

        let schema = parse(schema_str).expect("should parse");
        let model = &schema.models[0];
        assert!(
            model
                .attributes
                .iter()
                .any(|a| matches!(a, BlockAttribute::Id(fields) if fields == &["postId", "tagId"]))
        );
    }

    /// A model block without a name must be rejected by the grammar.
    #[test]
    fn test_parse_error_invalid_syntax() {
        let bad = "model { broken }";
        assert!(parse(bad).is_err());
    }

    /// Line comments (standalone, trailing, and between fields) do not affect the result.
    #[test]
    fn test_parse_with_comments() {
        let schema_str = r#"
// This is a comment
datasource db {
  provider = "postgresql" // inline comment
  url = env("DATABASE_URL")
}

// Another comment
model User {
  id String @id @default(uuid())
  // A commented field
  name String?
}
"#;

        let schema = parse(schema_str).expect("should parse with comments");
        assert!(schema.datasource.is_some());
        assert_eq!(schema.models[0].fields.len(), 2);
    }
}