Skip to main content

prax_schema/parser/
mod.rs

1//! Schema parser for `.prax` files.
2
3mod grammar;
4
5use std::path::Path;
6
7use pest::Parser;
8use smol_str::SmolStr;
9use tracing::{debug, info};
10
11use crate::ast::*;
12use crate::error::{SchemaError, SchemaResult};
13
14pub use grammar::{PraxParser, Rule};
15
16use crate::ast::{
17    MssqlBlockOperation, Policy, PolicyCommand, PolicyType, Server, ServerGroup, ServerProperty,
18    ServerPropertyValue,
19};
20
21/// Parse a schema from a string.
22pub fn parse_schema(input: &str) -> SchemaResult<Schema> {
23    debug!(input_len = input.len(), "parse_schema() starting");
24    let pairs = PraxParser::parse(Rule::schema, input)
25        .map_err(|e| SchemaError::syntax(input.to_string(), 0, input.len(), e.to_string()))?;
26
27    let mut schema = Schema::new();
28    let mut current_doc: Option<Documentation> = None;
29
30    // The top-level parse result contains a single "schema" rule - get its inner pairs
31    let schema_pair = pairs.into_iter().next().unwrap();
32
33    for pair in schema_pair.into_inner() {
34        match pair.as_rule() {
35            Rule::documentation => {
36                let span = pair.as_span();
37                let text = pair
38                    .into_inner()
39                    .map(|p| p.as_str().trim_start_matches("///").trim())
40                    .collect::<Vec<_>>()
41                    .join("\n");
42                current_doc = Some(Documentation::new(
43                    text,
44                    Span::new(span.start(), span.end()),
45                ));
46            }
47            Rule::model_def => {
48                let mut model = parse_model(pair)?;
49                if let Some(doc) = current_doc.take() {
50                    model = model.with_documentation(doc);
51                }
52                schema.add_model(model);
53            }
54            Rule::enum_def => {
55                let mut e = parse_enum(pair)?;
56                if let Some(doc) = current_doc.take() {
57                    e = e.with_documentation(doc);
58                }
59                schema.add_enum(e);
60            }
61            Rule::type_def => {
62                let mut t = parse_composite_type(pair)?;
63                if let Some(doc) = current_doc.take() {
64                    t = t.with_documentation(doc);
65                }
66                schema.add_type(t);
67            }
68            Rule::view_def => {
69                let mut v = parse_view(pair)?;
70                if let Some(doc) = current_doc.take() {
71                    v = v.with_documentation(doc);
72                }
73                schema.add_view(v);
74            }
75            Rule::raw_sql_def => {
76                let sql = parse_raw_sql(pair)?;
77                schema.add_raw_sql(sql);
78            }
79            Rule::server_group_def => {
80                let mut sg = parse_server_group(pair)?;
81                if let Some(doc) = current_doc.take() {
82                    sg.set_documentation(doc);
83                }
84                schema.add_server_group(sg);
85            }
86            Rule::policy_def => {
87                let mut policy = parse_policy(pair)?;
88                if let Some(doc) = current_doc.take() {
89                    policy = policy.with_documentation(doc);
90                }
91                schema.add_policy(policy);
92            }
93            Rule::datasource_def => {
94                let ds = parse_datasource(pair)?;
95                schema.set_datasource(ds);
96                current_doc = None;
97            }
98            Rule::generator_def => {
99                let generator = parse_generator(pair)?;
100                schema.add_generator(generator);
101                current_doc = None;
102            }
103            Rule::EOI => {}
104            _ => {}
105        }
106    }
107
108    info!(
109        models = schema.models.len(),
110        enums = schema.enums.len(),
111        types = schema.types.len(),
112        views = schema.views.len(),
113        generators = schema.generators.len(),
114        policies = schema.policies.len(),
115        "Schema parsed successfully"
116    );
117    Ok(schema)
118}
119
120/// Parse a schema from a file.
121pub fn parse_schema_file(path: impl AsRef<Path>) -> SchemaResult<Schema> {
122    let path = path.as_ref();
123    info!(path = %path.display(), "Loading schema file");
124    let content = std::fs::read_to_string(path).map_err(|e| SchemaError::IoError {
125        path: path.display().to_string(),
126        source: e,
127    })?;
128
129    parse_schema(&content)
130}
131
132/// Parse a model definition.
133fn parse_model(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Model> {
134    let span = pair.as_span();
135    let mut inner = pair.into_inner();
136
137    let name_pair = inner.next().unwrap();
138    let name = Ident::new(
139        name_pair.as_str(),
140        Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
141    );
142
143    let mut model = Model::new(name, Span::new(span.start(), span.end()));
144
145    for item in inner {
146        match item.as_rule() {
147            Rule::field_def => {
148                let field = parse_field(item)?;
149                model.add_field(field);
150            }
151            Rule::model_attribute => {
152                let attr = parse_attribute(item)?;
153                model.attributes.push(attr);
154            }
155            Rule::model_body_item => {
156                // Unwrap the model_body_item to get the actual field_def or model_attribute
157                let inner_item = item.into_inner().next().unwrap();
158                match inner_item.as_rule() {
159                    Rule::field_def => {
160                        let field = parse_field(inner_item)?;
161                        model.add_field(field);
162                    }
163                    Rule::model_attribute => {
164                        let attr = parse_attribute(inner_item)?;
165                        model.attributes.push(attr);
166                    }
167                    _ => {}
168                }
169            }
170            _ => {}
171        }
172    }
173
174    Ok(model)
175}
176
177/// Parse an enum definition.
178fn parse_enum(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Enum> {
179    let span = pair.as_span();
180    let mut inner = pair.into_inner();
181
182    let name_pair = inner.next().unwrap();
183    let name = Ident::new(
184        name_pair.as_str(),
185        Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
186    );
187
188    let mut e = Enum::new(name, Span::new(span.start(), span.end()));
189
190    for item in inner {
191        match item.as_rule() {
192            Rule::enum_variant => {
193                let variant = parse_enum_variant(item)?;
194                e.add_variant(variant);
195            }
196            Rule::model_attribute => {
197                let attr = parse_attribute(item)?;
198                e.attributes.push(attr);
199            }
200            Rule::enum_body_item => {
201                // Unwrap the enum_body_item to get the actual enum_variant or model_attribute
202                let inner_item = item.into_inner().next().unwrap();
203                match inner_item.as_rule() {
204                    Rule::enum_variant => {
205                        let variant = parse_enum_variant(inner_item)?;
206                        e.add_variant(variant);
207                    }
208                    Rule::model_attribute => {
209                        let attr = parse_attribute(inner_item)?;
210                        e.attributes.push(attr);
211                    }
212                    _ => {}
213                }
214            }
215            _ => {}
216        }
217    }
218
219    Ok(e)
220}
221
222/// Parse an enum variant.
223fn parse_enum_variant(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<EnumVariant> {
224    let span = pair.as_span();
225    let mut inner = pair.into_inner();
226
227    let name_pair = inner.next().unwrap();
228    let name = Ident::new(
229        name_pair.as_str(),
230        Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
231    );
232
233    let mut variant = EnumVariant::new(name, Span::new(span.start(), span.end()));
234
235    for item in inner {
236        if item.as_rule() == Rule::field_attribute {
237            let attr = parse_attribute(item)?;
238            variant.attributes.push(attr);
239        }
240    }
241
242    Ok(variant)
243}
244
245/// Parse a composite type definition.
246fn parse_composite_type(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<CompositeType> {
247    let span = pair.as_span();
248    let mut inner = pair.into_inner();
249
250    let name_pair = inner.next().unwrap();
251    let name = Ident::new(
252        name_pair.as_str(),
253        Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
254    );
255
256    let mut t = CompositeType::new(name, Span::new(span.start(), span.end()));
257
258    for item in inner {
259        if item.as_rule() == Rule::field_def {
260            let field = parse_field(item)?;
261            t.add_field(field);
262        }
263    }
264
265    Ok(t)
266}
267
268/// Parse a view definition.
269fn parse_view(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<View> {
270    let span = pair.as_span();
271    let mut inner = pair.into_inner();
272
273    let name_pair = inner.next().unwrap();
274    let name = Ident::new(
275        name_pair.as_str(),
276        Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
277    );
278
279    let mut v = View::new(name, Span::new(span.start(), span.end()));
280
281    for item in inner {
282        match item.as_rule() {
283            Rule::field_def => {
284                let field = parse_field(item)?;
285                v.add_field(field);
286            }
287            Rule::model_attribute => {
288                let attr = parse_attribute(item)?;
289                v.attributes.push(attr);
290            }
291            Rule::model_body_item => {
292                // Unwrap the model_body_item to get the actual field_def or model_attribute
293                let inner_item = item.into_inner().next().unwrap();
294                match inner_item.as_rule() {
295                    Rule::field_def => {
296                        let field = parse_field(inner_item)?;
297                        v.add_field(field);
298                    }
299                    Rule::model_attribute => {
300                        let attr = parse_attribute(inner_item)?;
301                        v.attributes.push(attr);
302                    }
303                    _ => {}
304                }
305            }
306            _ => {}
307        }
308    }
309
310    Ok(v)
311}
312
313/// Parse a field definition.
314fn parse_field(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Field> {
315    let span = pair.as_span();
316    let mut inner = pair.into_inner();
317
318    let name_pair = inner.next().unwrap();
319    let name = Ident::new(
320        name_pair.as_str(),
321        Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
322    );
323
324    let type_pair = inner.next().unwrap();
325    let (field_type, modifier) = parse_field_type(type_pair)?;
326
327    let mut attributes = vec![];
328    for item in inner {
329        if item.as_rule() == Rule::field_attribute {
330            let attr = parse_attribute(item)?;
331            attributes.push(attr);
332        }
333    }
334
335    Ok(Field::new(
336        name,
337        field_type,
338        modifier,
339        attributes,
340        Span::new(span.start(), span.end()),
341    ))
342}
343
344/// Parse a field type with optional modifier.
345fn parse_field_type(
346    pair: pest::iterators::Pair<'_, Rule>,
347) -> SchemaResult<(FieldType, TypeModifier)> {
348    let mut type_name = String::new();
349    let mut modifier = TypeModifier::Required;
350
351    for item in pair.into_inner() {
352        match item.as_rule() {
353            Rule::type_name => {
354                type_name = item.as_str().to_string();
355            }
356            Rule::optional_marker => {
357                modifier = if modifier == TypeModifier::List {
358                    TypeModifier::OptionalList
359                } else {
360                    TypeModifier::Optional
361                };
362            }
363            Rule::list_marker => {
364                modifier = if modifier == TypeModifier::Optional {
365                    TypeModifier::OptionalList
366                } else {
367                    TypeModifier::List
368                };
369            }
370            _ => {}
371        }
372    }
373
374    let field_type = if let Some(scalar) = ScalarType::from_str(&type_name) {
375        FieldType::Scalar(scalar)
376    } else {
377        // Assume it's a reference to a model, enum, or type
378        // This will be validated later
379        FieldType::Model(SmolStr::new(&type_name))
380    };
381
382    Ok((field_type, modifier))
383}
384
385/// Parse an attribute.
386fn parse_attribute(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Attribute> {
387    let span = pair.as_span();
388    let mut inner = pair.into_inner();
389
390    let name_pair = inner.next().unwrap();
391    let name = Ident::new(
392        name_pair.as_str(),
393        Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
394    );
395
396    let mut args = vec![];
397    for item in inner {
398        if item.as_rule() == Rule::attribute_args {
399            args = parse_attribute_args(item)?;
400        }
401    }
402
403    Ok(Attribute::new(
404        name,
405        args,
406        Span::new(span.start(), span.end()),
407    ))
408}
409
410/// Parse attribute arguments.
411fn parse_attribute_args(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Vec<AttributeArg>> {
412    let mut args = vec![];
413
414    for item in pair.into_inner() {
415        if item.as_rule() == Rule::attribute_arg {
416            let arg = parse_attribute_arg(item)?;
417            args.push(arg);
418        }
419    }
420
421    Ok(args)
422}
423
424/// Parse a single attribute argument.
425fn parse_attribute_arg(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<AttributeArg> {
426    let span = pair.as_span();
427    let mut inner = pair.into_inner();
428
429    let first = inner.next().unwrap();
430
431    // Check if this is a named argument (name: value) or positional
432    if let Some(second) = inner.next() {
433        // Named argument
434        let name = Ident::new(
435            first.as_str(),
436            Span::new(first.as_span().start(), first.as_span().end()),
437        );
438        let value = parse_attribute_value(second)?;
439        Ok(AttributeArg::named(
440            name,
441            value,
442            Span::new(span.start(), span.end()),
443        ))
444    } else {
445        // Positional argument
446        let value = parse_attribute_value(first)?;
447        Ok(AttributeArg::positional(
448            value,
449            Span::new(span.start(), span.end()),
450        ))
451    }
452}
453
454/// Parse an attribute value.
455fn parse_attribute_value(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<AttributeValue> {
456    match pair.as_rule() {
457        Rule::string_literal => {
458            let s = pair.as_str();
459            // Remove quotes
460            let unquoted = &s[1..s.len() - 1];
461            Ok(AttributeValue::String(unquoted.to_string()))
462        }
463        Rule::number_literal => {
464            let s = pair.as_str();
465            if s.contains('.') {
466                Ok(AttributeValue::Float(s.parse().unwrap()))
467            } else {
468                Ok(AttributeValue::Int(s.parse().unwrap()))
469            }
470        }
471        Rule::boolean_literal => Ok(AttributeValue::Boolean(pair.as_str() == "true")),
472        Rule::identifier => Ok(AttributeValue::Ident(SmolStr::new(pair.as_str()))),
473        Rule::dotted_identifier => {
474            // Represent "rel.field" as a String so callers can split on '.'
475            Ok(AttributeValue::String(pair.as_str().to_string()))
476        }
477        Rule::function_call => {
478            let mut inner = pair.into_inner();
479            let name = SmolStr::new(inner.next().unwrap().as_str());
480            let mut args = vec![];
481            for item in inner {
482                args.push(parse_attribute_value(item)?);
483            }
484            Ok(AttributeValue::Function(name, args))
485        }
486        Rule::field_ref_list => {
487            let refs: Vec<SmolStr> = pair
488                .into_inner()
489                .map(|p| SmolStr::new(p.as_str()))
490                .collect();
491            Ok(AttributeValue::FieldRefList(refs))
492        }
493        Rule::array_literal => {
494            let values: Result<Vec<_>, _> = pair.into_inner().map(parse_attribute_value).collect();
495            Ok(AttributeValue::Array(values?))
496        }
497        Rule::attribute_value => {
498            // Unwrap nested attribute_value
499            parse_attribute_value(pair.into_inner().next().unwrap())
500        }
501        _ => {
502            // Fallback: treat as identifier
503            Ok(AttributeValue::Ident(SmolStr::new(pair.as_str())))
504        }
505    }
506}
507
508/// Parse a raw SQL definition.
509fn parse_raw_sql(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<RawSql> {
510    let mut inner = pair.into_inner();
511
512    let name = inner.next().unwrap().as_str();
513    let sql = inner.next().unwrap().as_str();
514
515    // Remove triple quotes
516    let sql_content = sql
517        .trim_start_matches("\"\"\"")
518        .trim_end_matches("\"\"\"")
519        .trim();
520
521    Ok(RawSql::new(name, sql_content))
522}
523
524/// Parse a server group definition.
525fn parse_server_group(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<ServerGroup> {
526    let span = pair.as_span();
527    let mut inner = pair.into_inner();
528
529    let name_pair = inner.next().unwrap();
530    let name = Ident::new(
531        name_pair.as_str(),
532        Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
533    );
534
535    let mut server_group = ServerGroup::new(name, Span::new(span.start(), span.end()));
536
537    for item in inner {
538        match item.as_rule() {
539            Rule::server_group_item => {
540                // Unwrap the server_group_item to get the actual server_def or model_attribute
541                let inner_item = item.into_inner().next().unwrap();
542                match inner_item.as_rule() {
543                    Rule::server_def => {
544                        let server = parse_server(inner_item)?;
545                        server_group.add_server(server);
546                    }
547                    Rule::model_attribute => {
548                        let attr = parse_attribute(inner_item)?;
549                        server_group.add_attribute(attr);
550                    }
551                    _ => {}
552                }
553            }
554            Rule::server_def => {
555                let server = parse_server(item)?;
556                server_group.add_server(server);
557            }
558            Rule::model_attribute => {
559                let attr = parse_attribute(item)?;
560                server_group.add_attribute(attr);
561            }
562            _ => {}
563        }
564    }
565
566    Ok(server_group)
567}
568
569/// Parse a server definition within a server group.
570fn parse_server(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Server> {
571    let span = pair.as_span();
572    let mut inner = pair.into_inner();
573
574    let name_pair = inner.next().unwrap();
575    let name = Ident::new(
576        name_pair.as_str(),
577        Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
578    );
579
580    let mut server = Server::new(name, Span::new(span.start(), span.end()));
581
582    for item in inner {
583        if item.as_rule() == Rule::server_property {
584            let prop = parse_server_property(item)?;
585            server.add_property(prop);
586        }
587    }
588
589    Ok(server)
590}
591
592/// Parse a server property (key = value).
593fn parse_server_property(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<ServerProperty> {
594    let span = pair.as_span();
595    let mut inner = pair.into_inner();
596
597    let key_pair = inner.next().unwrap();
598    let key = key_pair.as_str();
599
600    let value_pair = inner.next().unwrap();
601    let value = parse_server_property_value(value_pair)?;
602
603    Ok(ServerProperty::new(
604        key,
605        value,
606        Span::new(span.start(), span.end()),
607    ))
608}
609
610/// Parse a generator definition.
611fn parse_generator(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Generator> {
612    let span = pair.as_span();
613    let mut inner = pair.into_inner();
614
615    let name = inner.next().unwrap().as_str();
616    let mut generator = Generator::new(name, Span::new(span.start(), span.end()));
617
618    for prop in inner {
619        if prop.as_rule() == Rule::datasource_property {
620            let mut prop_inner = prop.into_inner();
621            let key = prop_inner.next().unwrap().as_str();
622            let value_pair = prop_inner.next().unwrap();
623
624            match key {
625                "provider" => {
626                    let s = extract_datasource_string(&value_pair);
627                    generator.provider = Some(SmolStr::new(s));
628                }
629                "output" => {
630                    let s = extract_datasource_string(&value_pair);
631                    generator.output = Some(SmolStr::new(s));
632                }
633                "generate" => {
634                    generator.generate = parse_generator_toggle(&value_pair);
635                }
636                _ => {
637                    let val = parse_generator_value(&value_pair);
638                    generator.properties.insert(SmolStr::new(key), val);
639                }
640            }
641        }
642    }
643
644    Ok(generator)
645}
646
647/// Parse a generator toggle value (bool literal or env() call).
648fn parse_generator_toggle(pair: &pest::iterators::Pair<'_, Rule>) -> GeneratorToggle {
649    match pair.as_rule() {
650        Rule::env_function => {
651            let env_var = pair
652                .clone()
653                .into_inner()
654                .next()
655                .map(|p| {
656                    let s = p.as_str();
657                    SmolStr::new(&s[1..s.len() - 1])
658                })
659                .unwrap_or_default();
660            GeneratorToggle::Env(env_var)
661        }
662        Rule::datasource_value => {
663            let inner = pair.clone().into_inner().next().unwrap();
664            parse_generator_toggle(&inner)
665        }
666        _ => {
667            let s = pair.as_str().trim().trim_matches('"');
668            match s {
669                "true" => GeneratorToggle::Literal(true),
670                "false" => GeneratorToggle::Literal(false),
671                _ => GeneratorToggle::Literal(false),
672            }
673        }
674    }
675}
676
677/// Parse an arbitrary generator property value.
678fn parse_generator_value(pair: &pest::iterators::Pair<'_, Rule>) -> GeneratorValue {
679    match pair.as_rule() {
680        Rule::env_function => {
681            let env_var = pair
682                .clone()
683                .into_inner()
684                .next()
685                .map(|p| {
686                    let s = p.as_str();
687                    SmolStr::new(&s[1..s.len() - 1])
688                })
689                .unwrap_or_default();
690            GeneratorValue::Env(env_var)
691        }
692        Rule::datasource_value => {
693            let inner = pair.clone().into_inner().next().unwrap();
694            parse_generator_value(&inner)
695        }
696        Rule::string_literal => {
697            let s = pair.as_str();
698            GeneratorValue::String(SmolStr::new(&s[1..s.len() - 1]))
699        }
700        _ => {
701            let s = pair.as_str().trim().trim_matches('"');
702            match s {
703                "true" => GeneratorValue::Bool(true),
704                "false" => GeneratorValue::Bool(false),
705                _ => GeneratorValue::Ident(SmolStr::new(s)),
706            }
707        }
708    }
709}
710
711/// Parse a datasource definition.
712fn parse_datasource(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Datasource> {
713    let span = pair.as_span();
714    let mut inner = pair.into_inner();
715
716    let name_pair = inner.next().unwrap();
717    let name = name_pair.as_str();
718
719    let mut datasource = Datasource::new(
720        name,
721        DatabaseProvider::PostgreSQL,
722        Span::new(span.start(), span.end()),
723    );
724
725    for prop in inner {
726        if prop.as_rule() == Rule::datasource_property {
727            let mut prop_inner = prop.into_inner();
728            let key = prop_inner.next().unwrap().as_str();
729            let value_pair = prop_inner.next().unwrap();
730
731            match key {
732                "provider" => {
733                    let provider_str = extract_datasource_string(&value_pair);
734                    if let Some(provider) = DatabaseProvider::from_str(&provider_str) {
735                        datasource.provider = provider;
736                    }
737                }
738                "url" => {
739                    match value_pair.as_rule() {
740                        Rule::env_function => {
741                            // env("DATABASE_URL")
742                            let env_var = value_pair
743                                .into_inner()
744                                .next()
745                                .map(|p| {
746                                    let s = p.as_str();
747                                    s[1..s.len() - 1].to_string()
748                                })
749                                .unwrap_or_default();
750                            datasource.url_env = Some(SmolStr::new(env_var));
751                        }
752                        Rule::string_literal => {
753                            let s = value_pair.as_str();
754                            let url = &s[1..s.len() - 1];
755                            datasource.url = Some(SmolStr::new(url));
756                        }
757                        _ => {}
758                    }
759                }
760                "extensions" => {
761                    if value_pair.as_rule() == Rule::extension_array {
762                        for ext_item in value_pair.into_inner() {
763                            if ext_item.as_rule() == Rule::extension_item {
764                                let ext = parse_extension_item(
765                                    ext_item,
766                                    Span::new(span.start(), span.end()),
767                                )?;
768                                datasource.add_extension(ext);
769                            }
770                        }
771                    }
772                }
773                _ => {
774                    // Store as additional property
775                    let value_str = extract_datasource_string(&value_pair);
776                    datasource.add_property(key, value_str);
777                }
778            }
779        }
780    }
781
782    Ok(datasource)
783}
784
785/// Parse an extension item from the extensions array.
786fn parse_extension_item(
787    pair: pest::iterators::Pair<'_, Rule>,
788    span: Span,
789) -> SchemaResult<PostgresExtension> {
790    let mut inner = pair.into_inner();
791    let name = inner.next().unwrap().as_str();
792    let mut ext = PostgresExtension::new(name, span);
793
794    // Check for extension args like (schema: "public", version: "0.5.0")
795    if let Some(args_pair) = inner.next()
796        && args_pair.as_rule() == Rule::extension_args
797    {
798        for arg in args_pair.into_inner() {
799            if arg.as_rule() == Rule::extension_arg {
800                let mut arg_inner = arg.into_inner();
801                let arg_key = arg_inner.next().unwrap().as_str();
802                let arg_value_pair = arg_inner.next().unwrap();
803                let arg_value = {
804                    let s = arg_value_pair.as_str();
805                    &s[1..s.len() - 1]
806                };
807
808                match arg_key {
809                    "schema" => {
810                        ext = ext.with_schema(arg_value);
811                    }
812                    "version" => {
813                        ext = ext.with_version(arg_value);
814                    }
815                    _ => {}
816                }
817            }
818        }
819    }
820
821    Ok(ext)
822}
823
824/// Extract a string value from a datasource property value.
825fn extract_datasource_string(pair: &pest::iterators::Pair<'_, Rule>) -> String {
826    match pair.as_rule() {
827        Rule::string_literal => {
828            let s = pair.as_str();
829            s[1..s.len() - 1].to_string()
830        }
831        Rule::identifier => pair.as_str().to_string(),
832        Rule::datasource_value => {
833            if let Some(inner) = pair.clone().into_inner().next() {
834                extract_datasource_string(&inner)
835            } else {
836                pair.as_str().to_string()
837            }
838        }
839        _ => pair.as_str().to_string(),
840    }
841}
842
843/// Extract a string value from a pest pair, handling nesting.
844fn extract_string_from_arg(pair: pest::iterators::Pair<'_, Rule>) -> String {
845    match pair.as_rule() {
846        Rule::string_literal => {
847            let s = pair.as_str();
848            s[1..s.len() - 1].to_string()
849        }
850        Rule::attribute_value => {
851            // Unwrap nested attribute_value
852            if let Some(inner) = pair.into_inner().next() {
853                extract_string_from_arg(inner)
854            } else {
855                String::new()
856            }
857        }
858        _ => pair.as_str().to_string(),
859    }
860}
861
862/// Parse a server property value.
863fn parse_server_property_value(
864    pair: pest::iterators::Pair<'_, Rule>,
865) -> SchemaResult<ServerPropertyValue> {
866    match pair.as_rule() {
867        Rule::string_literal => {
868            let s = pair.as_str();
869            // Remove quotes
870            let unquoted = &s[1..s.len() - 1];
871            Ok(ServerPropertyValue::String(unquoted.to_string()))
872        }
873        Rule::number_literal => {
874            let s = pair.as_str();
875            Ok(ServerPropertyValue::Number(s.parse().unwrap_or(0.0)))
876        }
877        Rule::boolean_literal => Ok(ServerPropertyValue::Boolean(pair.as_str() == "true")),
878        Rule::identifier => Ok(ServerPropertyValue::Identifier(pair.as_str().to_string())),
879        Rule::function_call => {
880            // Handle env("VAR") and other function calls
881            let mut inner = pair.into_inner();
882            let func_name = inner.next().unwrap().as_str();
883            if func_name == "env"
884                && let Some(arg) = inner.next()
885            {
886                let var_name = extract_string_from_arg(arg);
887                return Ok(ServerPropertyValue::EnvVar(var_name));
888            }
889            // For other functions, store as identifier
890            Ok(ServerPropertyValue::Identifier(func_name.to_string()))
891        }
892        Rule::array_literal => {
893            let values: Result<Vec<_>, _> =
894                pair.into_inner().map(parse_server_property_value).collect();
895            Ok(ServerPropertyValue::Array(values?))
896        }
897        Rule::attribute_value => {
898            // Unwrap nested attribute_value
899            parse_server_property_value(pair.into_inner().next().unwrap())
900        }
901        _ => {
902            // Fallback: treat as identifier
903            Ok(ServerPropertyValue::Identifier(pair.as_str().to_string()))
904        }
905    }
906}
907
908/// Parse a PostgreSQL Row-Level Security policy definition.
909fn parse_policy(pair: pest::iterators::Pair<'_, Rule>) -> SchemaResult<Policy> {
910    let span = pair.as_span();
911    let mut inner = pair.into_inner();
912
913    // First identifier is the policy name
914    let name_pair = inner.next().unwrap();
915    let name = Ident::new(
916        name_pair.as_str(),
917        Span::new(name_pair.as_span().start(), name_pair.as_span().end()),
918    );
919
920    // Second identifier is the table name
921    let table_pair = inner.next().unwrap();
922    let table = Ident::new(
923        table_pair.as_str(),
924        Span::new(table_pair.as_span().start(), table_pair.as_span().end()),
925    );
926
927    let mut policy = Policy::new(name, table, Span::new(span.start(), span.end()));
928    // Reset commands to empty - will be set by 'for' clause if present
929    policy.commands = vec![];
930
931    for item in inner {
932        match item.as_rule() {
933            Rule::policy_item => {
934                let inner_item = item.into_inner().next().unwrap();
935                parse_policy_item(&mut policy, inner_item)?;
936            }
937            Rule::policy_for
938            | Rule::policy_to
939            | Rule::policy_as
940            | Rule::policy_using
941            | Rule::policy_check => {
942                parse_policy_item(&mut policy, item)?;
943            }
944            _ => {}
945        }
946    }
947
948    // Default to ALL if no commands specified
949    if policy.commands.is_empty() {
950        policy.commands.push(PolicyCommand::All);
951    }
952
953    Ok(policy)
954}
955
956/// Parse a single policy item (for, to, as, using, check, mssqlSchema, mssqlBlock).
957fn parse_policy_item(
958    policy: &mut Policy,
959    pair: pest::iterators::Pair<'_, Rule>,
960) -> SchemaResult<()> {
961    match pair.as_rule() {
962        Rule::policy_for => {
963            let inner = pair.into_inner().next().unwrap();
964            match inner.as_rule() {
965                Rule::policy_command => {
966                    if let Some(cmd) = PolicyCommand::from_str(inner.as_str()) {
967                        policy.add_command(cmd);
968                    }
969                }
970                Rule::policy_command_list => {
971                    for cmd_pair in inner.into_inner() {
972                        if cmd_pair.as_rule() == Rule::policy_command
973                            && let Some(cmd) = PolicyCommand::from_str(cmd_pair.as_str())
974                        {
975                            policy.add_command(cmd);
976                        }
977                    }
978                }
979                _ => {}
980            }
981        }
982        Rule::policy_to => {
983            let inner = pair.into_inner().next().unwrap();
984            match inner.as_rule() {
985                Rule::identifier => {
986                    policy.add_role(inner.as_str());
987                }
988                Rule::policy_role_list => {
989                    for role_pair in inner.into_inner() {
990                        if role_pair.as_rule() == Rule::identifier {
991                            policy.add_role(role_pair.as_str());
992                        }
993                    }
994                }
995                _ => {}
996            }
997        }
998        Rule::policy_as => {
999            let inner = pair.into_inner().next().unwrap();
1000            if inner.as_rule() == Rule::policy_type
1001                && let Some(policy_type) = PolicyType::from_str(inner.as_str())
1002            {
1003                policy.policy_type = policy_type;
1004            }
1005        }
1006        Rule::policy_using => {
1007            let inner = pair.into_inner().next().unwrap();
1008            let expr = extract_policy_expression(&inner);
1009            policy.using_expr = Some(expr);
1010        }
1011        Rule::policy_check => {
1012            let inner = pair.into_inner().next().unwrap();
1013            let expr = extract_policy_expression(&inner);
1014            policy.check_expr = Some(expr);
1015        }
1016        Rule::policy_mssql_schema => {
1017            let inner = pair.into_inner().next().unwrap();
1018            if inner.as_rule() == Rule::string_literal {
1019                let s = inner.as_str();
1020                let schema = &s[1..s.len() - 1]; // Remove quotes
1021                policy.mssql_schema = Some(SmolStr::new(schema));
1022            }
1023        }
1024        Rule::policy_mssql_block => {
1025            let inner = pair.into_inner().next().unwrap();
1026            match inner.as_rule() {
1027                Rule::mssql_block_op => {
1028                    if let Some(op) = MssqlBlockOperation::from_str(inner.as_str()) {
1029                        policy.add_mssql_block_operation(op);
1030                    }
1031                }
1032                Rule::mssql_block_op_list => {
1033                    for op_pair in inner.into_inner() {
1034                        if op_pair.as_rule() == Rule::mssql_block_op
1035                            && let Some(op) = MssqlBlockOperation::from_str(op_pair.as_str())
1036                        {
1037                            policy.add_mssql_block_operation(op);
1038                        }
1039                    }
1040                }
1041                _ => {}
1042            }
1043        }
1044        _ => {}
1045    }
1046    Ok(())
1047}
1048
1049/// Extract the expression from a string literal or multiline string.
1050fn extract_policy_expression(pair: &pest::iterators::Pair<'_, Rule>) -> String {
1051    let s = pair.as_str();
1052    match pair.as_rule() {
1053        Rule::multiline_string => {
1054            // Remove triple quotes
1055            s.trim_start_matches("\"\"\"")
1056                .trim_end_matches("\"\"\"")
1057                .trim()
1058                .to_string()
1059        }
1060        Rule::string_literal => {
1061            // Remove single quotes
1062            s[1..s.len() - 1].to_string()
1063        }
1064        _ => s.to_string(),
1065    }
1066}
1067
1068#[cfg(test)]
1069mod tests {
1070    use super::*;
1071
    // ==================== Basic Model Parsing ====================

    /// A single model with three fields: checks the field count and the `@id`,
    /// `@unique` and optional (`?`) flags.
    #[test]
    fn test_parse_simple_model() {
        let schema = parse_schema(
            r#"
            model User {
                id    Int    @id @auto
                email String @unique
                name  String?
            }
        "#,
        )
        .unwrap();

        assert_eq!(schema.models.len(), 1);
        let user = schema.get_model("User").unwrap();
        assert_eq!(user.fields.len(), 3);
        assert!(user.get_field("id").unwrap().is_id());
        assert!(user.get_field("email").unwrap().is_unique());
        assert!(user.get_field("name").unwrap().is_optional());
    }

    /// A multi-word PascalCase model name is retrievable under its exact spelling.
    #[test]
    fn test_parse_model_name() {
        let schema = parse_schema(
            r#"
            model BlogPost {
                id Int @id
            }
        "#,
        )
        .unwrap();

        assert!(schema.get_model("BlogPost").is_some());
    }

    /// Several model blocks in one schema are all collected and retrievable by name.
    #[test]
    fn test_parse_multiple_models() {
        let schema = parse_schema(
            r#"
            model User {
                id Int @id
            }

            model Post {
                id Int @id
            }

            model Comment {
                id Int @id
            }
        "#,
        )
        .unwrap();

        assert_eq!(schema.models.len(), 3);
        assert!(schema.get_model("User").is_some());
        assert!(schema.get_model("Post").is_some());
        assert!(schema.get_model("Comment").is_some());
    }
1133
    // ==================== Field Type Parsing ====================

    /// A model declaring one field of each built-in scalar type. Checks the total
    /// field count and the `ScalarType` mapping for a subset of the fields. The
    /// `float_f`, `decimal`, `date`, `time`, `json` and `bytes` fields are only
    /// counted, not type-checked.
    #[test]
    fn test_parse_all_scalar_types() {
        let schema = parse_schema(
            r#"
            model AllTypes {
                id       Int      @id
                big      BigInt
                float_f  Float
                decimal  Decimal
                str      String
                bool     Boolean
                datetime DateTime
                date     Date
                time     Time
                json     Json
                bytes    Bytes
                uuid     Uuid
                cuid     Cuid
                cuid2    Cuid2
                nanoid   NanoId
                ulid     Ulid
            }
        "#,
        )
        .unwrap();

        let model = schema.get_model("AllTypes").unwrap();
        assert_eq!(model.fields.len(), 16);

        assert!(matches!(
            model.get_field("id").unwrap().field_type,
            FieldType::Scalar(ScalarType::Int)
        ));
        assert!(matches!(
            model.get_field("big").unwrap().field_type,
            FieldType::Scalar(ScalarType::BigInt)
        ));
        assert!(matches!(
            model.get_field("str").unwrap().field_type,
            FieldType::Scalar(ScalarType::String)
        ));
        assert!(matches!(
            model.get_field("bool").unwrap().field_type,
            FieldType::Scalar(ScalarType::Boolean)
        ));
        assert!(matches!(
            model.get_field("datetime").unwrap().field_type,
            FieldType::Scalar(ScalarType::DateTime)
        ));
        assert!(matches!(
            model.get_field("uuid").unwrap().field_type,
            FieldType::Scalar(ScalarType::Uuid)
        ));
        assert!(matches!(
            model.get_field("cuid").unwrap().field_type,
            FieldType::Scalar(ScalarType::Cuid)
        ));
        assert!(matches!(
            model.get_field("cuid2").unwrap().field_type,
            FieldType::Scalar(ScalarType::Cuid2)
        ));
        assert!(matches!(
            model.get_field("nanoid").unwrap().field_type,
            FieldType::Scalar(ScalarType::NanoId)
        ));
        assert!(matches!(
            model.get_field("ulid").unwrap().field_type,
            FieldType::Scalar(ScalarType::Ulid)
        ));
    }

    /// A trailing `?` marks a field optional. A field without it is not optional.
    #[test]
    fn test_parse_optional_field() {
        let schema = parse_schema(
            r#"
            model User {
                id   Int     @id
                bio  String?
                age  Int?
            }
        "#,
        )
        .unwrap();

        let user = schema.get_model("User").unwrap();
        assert!(!user.get_field("id").unwrap().is_optional());
        assert!(user.get_field("bio").unwrap().is_optional());
        assert!(user.get_field("age").unwrap().is_optional());
    }

    /// A trailing `[]` marks a field as a list, for both a scalar element type
    /// (`String[]`) and a model element type (`Post[]`).
    #[test]
    fn test_parse_list_field() {
        let schema = parse_schema(
            r#"
            model User {
                id    Int      @id
                tags  String[]
                posts Post[]
            }
        "#,
        )
        .unwrap();

        let user = schema.get_model("User").unwrap();
        assert!(user.get_field("tags").unwrap().is_list());
        assert!(user.get_field("posts").unwrap().is_list());
    }

    /// `String[]?` yields a field that is both a list and optional.
    #[test]
    fn test_parse_optional_list_field() {
        let schema = parse_schema(
            r#"
            model User {
                id       Int       @id
                metadata String[]?
            }
        "#,
        )
        .unwrap();

        let user = schema.get_model("User").unwrap();
        let metadata = user.get_field("metadata").unwrap();
        assert!(metadata.is_list());
        assert!(metadata.is_optional());
    }
1261
    // ==================== Attribute Parsing ====================

    /// `@id` on a field is reported by `is_id()`.
    #[test]
    fn test_parse_id_attribute() {
        let schema = parse_schema(
            r#"
            model User {
                id Int @id
            }
        "#,
        )
        .unwrap();

        let user = schema.get_model("User").unwrap();
        assert!(user.get_field("id").unwrap().is_id());
    }

    /// `@unique` on a field is reported by `is_unique()`.
    #[test]
    fn test_parse_unique_attribute() {
        let schema = parse_schema(
            r#"
            model User {
                id    Int    @id
                email String @unique
            }
        "#,
        )
        .unwrap();

        let user = schema.get_model("User").unwrap();
        assert!(user.get_field("email").unwrap().is_unique());
    }

    /// `@default(0)` is exposed as an integer default of `0`.
    #[test]
    fn test_parse_default_int() {
        let schema = parse_schema(
            r#"
            model Counter {
                id    Int @id
                count Int @default(0)
            }
        "#,
        )
        .unwrap();

        let counter = schema.get_model("Counter").unwrap();
        let count_field = counter.get_field("count").unwrap();
        let attrs = count_field.extract_attributes();
        assert!(attrs.default.is_some());
        assert_eq!(attrs.default.unwrap().as_int(), Some(0));
    }

    /// `@default("active")` is exposed as the string default `active` (quotes removed).
    #[test]
    fn test_parse_default_string() {
        let schema = parse_schema(
            r#"
            model User {
                id     Int    @id
                status String @default("active")
            }
        "#,
        )
        .unwrap();

        let user = schema.get_model("User").unwrap();
        let status = user.get_field("status").unwrap();
        let attrs = status.extract_attributes();
        assert!(attrs.default.is_some());
        assert_eq!(attrs.default.unwrap().as_string(), Some("active"));
    }

    /// `@default(false)` is exposed as the boolean default `false`.
    #[test]
    fn test_parse_default_boolean() {
        let schema = parse_schema(
            r#"
            model Post {
                id        Int     @id
                published Boolean @default(false)
            }
        "#,
        )
        .unwrap();

        let post = schema.get_model("Post").unwrap();
        let published = post.get_field("published").unwrap();
        let attrs = published.extract_attributes();
        assert!(attrs.default.is_some());
        assert_eq!(attrs.default.unwrap().as_bool(), Some(false));
    }

    /// `@default(now())` is exposed as an `AttributeValue::Function` named `now`.
    #[test]
    fn test_parse_default_function() {
        let schema = parse_schema(
            r#"
            model User {
                id        Int      @id
                createdAt DateTime @default(now())
            }
        "#,
        )
        .unwrap();

        let user = schema.get_model("User").unwrap();
        let created_at = user.get_field("createdAt").unwrap();
        let attrs = created_at.extract_attributes();
        assert!(attrs.default.is_some());
        if let Some(AttributeValue::Function(name, _)) = attrs.default {
            assert_eq!(name.as_str(), "now");
        } else {
            panic!("Expected function default");
        }
    }

    /// `@updated_at` sets the `is_updated_at` flag in the extracted attributes.
    #[test]
    fn test_parse_updated_at_attribute() {
        let schema = parse_schema(
            r#"
            model User {
                id        Int      @id
                updatedAt DateTime @updated_at
            }
        "#,
        )
        .unwrap();

        let user = schema.get_model("User").unwrap();
        let updated_at = user.get_field("updatedAt").unwrap();
        let attrs = updated_at.extract_attributes();
        assert!(attrs.is_updated_at);
    }

    /// `@map("...")` on a field records the mapped column name.
    #[test]
    fn test_parse_map_attribute() {
        let schema = parse_schema(
            r#"
            model User {
                id    Int    @id
                email String @map("email_address")
            }
        "#,
        )
        .unwrap();

        let user = schema.get_model("User").unwrap();
        let email = user.get_field("email").unwrap();
        let attrs = email.extract_attributes();
        assert_eq!(attrs.map, Some("email_address".to_string()));
    }

    /// Multiple attributes on one field are all recognised
    /// (`@id @auto` and `@unique @index`).
    #[test]
    fn test_parse_multiple_attributes() {
        let schema = parse_schema(
            r#"
            model User {
                id    Int    @id @auto
                email String @unique @index
            }
        "#,
        )
        .unwrap();

        let user = schema.get_model("User").unwrap();
        let id = user.get_field("id").unwrap();
        let email = user.get_field("email").unwrap();

        let id_attrs = id.extract_attributes();
        assert!(id_attrs.is_id);
        assert!(id_attrs.is_auto);

        let email_attrs = email.extract_attributes();
        assert!(email_attrs.is_unique);
        assert!(email_attrs.is_indexed);
    }
1435
    // ==================== Model Attribute Parsing ====================

    /// A model-level `@@map("...")` overrides the model's table name.
    #[test]
    fn test_parse_model_map_attribute() {
        let schema = parse_schema(
            r#"
            model User {
                id Int @id

                @@map("app_users")
            }
        "#,
        )
        .unwrap();

        let user = schema.get_model("User").unwrap();
        assert_eq!(user.table_name(), "app_users");
    }

    /// A model-level `@@index([...])` is stored as a model attribute named `index`.
    #[test]
    fn test_parse_model_index_attribute() {
        let schema = parse_schema(
            r#"
            model User {
                id    Int    @id
                email String
                name  String

                @@index([email, name])
            }
        "#,
        )
        .unwrap();

        let user = schema.get_model("User").unwrap();
        assert!(user.has_attribute("index"));
    }

    /// A model-level composite key `@@id([...])` is stored as a model attribute
    /// named `id`.
    #[test]
    fn test_parse_composite_primary_key() {
        let schema = parse_schema(
            r#"
            model PostTag {
                postId Int
                tagId  Int

                @@id([postId, tagId])
            }
        "#,
        )
        .unwrap();

        let post_tag = schema.get_model("PostTag").unwrap();
        assert!(post_tag.has_attribute("id"));
    }
1491
    // ==================== Enum Parsing ====================

    /// An enum block is collected together with all of its variants.
    #[test]
    fn test_parse_enum() {
        let schema = parse_schema(
            r#"
            enum Role {
                User
                Admin
                Moderator
            }
        "#,
        )
        .unwrap();

        assert_eq!(schema.enums.len(), 1);
        let role = schema.get_enum("Role").unwrap();
        assert_eq!(role.variants.len(), 3);
    }

    /// Each enum variant is retrievable by its declared name.
    #[test]
    fn test_parse_enum_variant_names() {
        let schema = parse_schema(
            r#"
            enum Status {
                Pending
                Active
                Completed
                Cancelled
            }
        "#,
        )
        .unwrap();

        let status = schema.get_enum("Status").unwrap();
        assert!(status.get_variant("Pending").is_some());
        assert!(status.get_variant("Active").is_some());
        assert!(status.get_variant("Completed").is_some());
        assert!(status.get_variant("Cancelled").is_some());
    }

    /// `@map("...")` on an enum variant changes the value returned by `db_value()`.
    #[test]
    fn test_parse_enum_with_map() {
        let schema = parse_schema(
            r#"
            enum Role {
                User  @map("USER")
                Admin @map("ADMINISTRATOR")
            }
        "#,
        )
        .unwrap();

        let role = schema.get_enum("Role").unwrap();
        let user_variant = role.get_variant("User").unwrap();
        assert_eq!(user_variant.db_value(), "USER");

        let admin_variant = role.get_variant("Admin").unwrap();
        assert_eq!(admin_variant.db_value(), "ADMINISTRATOR");
    }
1552
    // ==================== Relation Parsing ====================

    /// One side declares a list field (`Post[]`) and the other an `@relation(...)`
    /// field. The list side reports `is_list()` and the annotated side reports
    /// `is_relation()`.
    #[test]
    fn test_parse_one_to_many_relation() {
        let schema = parse_schema(
            r#"
            model User {
                id    Int    @id
                posts Post[]
            }

            model Post {
                id       Int  @id
                authorId Int
                author   User @relation(fields: [authorId], references: [id])
            }
        "#,
        )
        .unwrap();

        let user = schema.get_model("User").unwrap();
        let post = schema.get_model("Post").unwrap();

        assert!(user.get_field("posts").unwrap().is_list());
        assert!(post.get_field("author").unwrap().is_relation());
    }

    /// The `onDelete` / `onUpdate` arguments of `@relation` map to the matching
    /// `ReferentialAction` variants.
    #[test]
    fn test_parse_relation_with_actions() {
        let schema = parse_schema(
            r#"
            model Post {
                id       Int  @id
                authorId Int
                author   User @relation(fields: [authorId], references: [id], onDelete: Cascade, onUpdate: Restrict)
            }

            model User {
                id    Int    @id
                posts Post[]
            }
        "#,
        )
        .unwrap();

        let post = schema.get_model("Post").unwrap();
        let author = post.get_field("author").unwrap();
        let attrs = author.extract_attributes();

        assert!(attrs.relation.is_some());
        let rel = attrs.relation.unwrap();
        assert_eq!(rel.on_delete, Some(ReferentialAction::Cascade));
        assert_eq!(rel.on_update, Some(ReferentialAction::Restrict));
    }
1607
    // ==================== Documentation Parsing ====================

    /// A `///` comment placed directly above a model.
    ///
    /// NOTE(review): the assertion sits inside `if let`, so this test also passes
    /// when no documentation is attached. It only requires that the input parses.
    /// Consider asserting `documentation.is_some()` once the grammar behaviour is
    /// confirmed.
    #[test]
    fn test_parse_model_documentation() {
        let schema = parse_schema(
            r#"/// Represents a user in the system
model User {
    id Int @id
}"#,
        )
        .unwrap();

        let user = schema.get_model("User").unwrap();
        // Documentation parsing is optional - the model should still parse
        // If documentation is present, it should contain "user"
        if let Some(doc) = &user.documentation {
            assert!(doc.text.contains("user"));
        }
    }
1627
    // ==================== Complete Schema Parsing ====================

    /// End-to-end check on a realistic schema with four models and one enum. It
    /// asserts the model and enum counts, the `@@map` table name, the `User` field
    /// count, a model-level index, and the `Post.author` relation.
    #[test]
    fn test_parse_complete_schema() {
        let schema = parse_schema(
            r#"
            /// User model
            model User {
                id        Int      @id @auto
                email     String   @unique
                name      String?
                role      Role     @default(User)
                posts     Post[]
                profile   Profile?
                createdAt DateTime @default(now())
                updatedAt DateTime @updated_at

                @@map("users")
                @@index([email])
            }

            model Post {
                id        Int      @id @auto
                title     String
                content   String?
                published Boolean  @default(false)
                authorId  Int
                author    User     @relation(fields: [authorId], references: [id])
                tags      Tag[]
                createdAt DateTime @default(now())

                @@index([authorId])
            }

            model Profile {
                id     Int    @id @auto
                bio    String?
                userId Int    @unique
                user   User   @relation(fields: [userId], references: [id])
            }

            model Tag {
                id    Int    @id @auto
                name  String @unique
                posts Post[]
            }

            enum Role {
                User
                Admin
                Moderator
            }
        "#,
        )
        .unwrap();

        // Verify models
        assert_eq!(schema.models.len(), 4);
        assert!(schema.get_model("User").is_some());
        assert!(schema.get_model("Post").is_some());
        assert!(schema.get_model("Profile").is_some());
        assert!(schema.get_model("Tag").is_some());

        // Verify enums
        assert_eq!(schema.enums.len(), 1);
        assert!(schema.get_enum("Role").is_some());

        // Verify User model details
        let user = schema.get_model("User").unwrap();
        assert_eq!(user.table_name(), "users");
        assert_eq!(user.fields.len(), 8);
        assert!(user.has_attribute("index"));

        // Verify relations
        let post = schema.get_model("Post").unwrap();
        assert!(post.get_field("author").unwrap().is_relation());
    }
1705
    // ==================== Error Handling ====================

    /// A model declaration without a name fails to parse.
    #[test]
    fn test_parse_invalid_syntax() {
        let result = parse_schema("model { broken }");
        assert!(result.is_err());
    }

    /// Empty input parses to a schema with no models and no enums.
    #[test]
    fn test_parse_empty_schema() {
        let schema = parse_schema("").unwrap();
        assert!(schema.models.is_empty());
        assert!(schema.enums.is_empty());
    }

    /// Whitespace-only input (spaces, tabs, newlines) parses to a schema with no models.
    #[test]
    fn test_parse_whitespace_only() {
        let schema = parse_schema("   \n\t   \n   ").unwrap();
        assert!(schema.models.is_empty());
    }

    /// Input containing only `//` line comments parses to a schema with no models.
    #[test]
    fn test_parse_comments_only() {
        let schema = parse_schema(
            r#"
            // This is a comment
            // Another comment
        "#,
        )
        .unwrap();
        assert!(schema.models.is_empty());
    }
1738
    // ==================== Edge Cases ====================

    /// Parsing a model with an empty body must not panic.
    ///
    /// NOTE(review): the result is discarded, so the test does not check whether
    /// an empty model is accepted or rejected. It only checks that `parse_schema`
    /// returns.
    #[test]
    fn test_parse_model_with_no_fields() {
        // Models with no fields should still parse (might be invalid semantically but syntactically ok)
        let result = parse_schema(
            r#"
            model Empty {
            }
        "#,
        );
        // This might error or succeed depending on grammar - just verify it doesn't panic
        let _ = result;
    }

    /// Long model and field identifiers parse. Only the model lookup is asserted.
    #[test]
    fn test_parse_long_identifier() {
        let schema = parse_schema(
            r#"
            model VeryLongModelNameThatIsStillValid {
                someVeryLongFieldNameThatShouldWork Int @id
            }
        "#,
        )
        .unwrap();

        assert!(
            schema
                .get_model("VeryLongModelNameThatIsStillValid")
                .is_some()
        );
    }

    /// snake_case model and field names parse and are retrievable by name.
    #[test]
    fn test_parse_underscore_identifiers() {
        let schema = parse_schema(
            r#"
            model user_account {
                user_id     Int @id
                created_at  DateTime
            }
        "#,
        )
        .unwrap();

        let model = schema.get_model("user_account").unwrap();
        assert!(model.get_field("user_id").is_some());
        assert!(model.get_field("created_at").is_some());
    }

    /// `@default(-100)` produces some default value. Its exact value is not asserted.
    #[test]
    fn test_parse_negative_default() {
        let schema = parse_schema(
            r#"
            model Config {
                id       Int @id
                minValue Int @default(-100)
            }
        "#,
        )
        .unwrap();

        let config = schema.get_model("Config").unwrap();
        let min_value = config.get_field("minValue").unwrap();
        let attrs = min_value.extract_attributes();
        assert!(attrs.default.is_some());
    }

    /// `@default(9.99)` produces some default value. Its exact value is not asserted.
    #[test]
    fn test_parse_float_default() {
        let schema = parse_schema(
            r#"
            model Product {
                id    Int   @id
                price Float @default(9.99)
            }
        "#,
        )
        .unwrap();

        let product = schema.get_model("Product").unwrap();
        let price = product.get_field("price").unwrap();
        let attrs = price.extract_attributes();
        assert!(attrs.default.is_some());
    }
1824
    // ==================== Server Group Parsing ====================

    /// A `serverGroup` containing one `server` block is collected. The server is
    /// stored under its declared name.
    #[test]
    fn test_parse_simple_server_group() {
        let schema = parse_schema(
            r#"
            serverGroup MainCluster {
                server primary {
                    url = "postgres://localhost/db"
                    role = "primary"
                }
            }
        "#,
        )
        .unwrap();

        assert_eq!(schema.server_groups.len(), 1);
        let cluster = schema.get_server_group("MainCluster").unwrap();
        assert_eq!(cluster.servers.len(), 1);
        assert!(cluster.servers.contains_key("primary"));
    }

    /// A group with three servers. Checks the server count and that the string
    /// `role` and numeric `weight` properties surface through `role()` and `weight()`.
    #[test]
    fn test_parse_server_group_with_multiple_servers() {
        let schema = parse_schema(
            r#"
            serverGroup ReadReplicas {
                server primary {
                    url = "postgres://primary.db.com/app"
                    role = "primary"
                    weight = 1
                }

                server replica1 {
                    url = "postgres://replica1.db.com/app"
                    role = "replica"
                    weight = 2
                }

                server replica2 {
                    url = "postgres://replica2.db.com/app"
                    role = "replica"
                    weight = 2
                }
            }
        "#,
        )
        .unwrap();

        let cluster = schema.get_server_group("ReadReplicas").unwrap();
        assert_eq!(cluster.servers.len(), 3);

        let primary = cluster.servers.get("primary").unwrap();
        assert_eq!(primary.role(), Some(ServerRole::Primary));
        assert_eq!(primary.weight(), Some(1));

        let replica1 = cluster.servers.get("replica1").unwrap();
        assert_eq!(replica1.role(), Some(ServerRole::Replica));
        assert_eq!(replica1.weight(), Some(2));
    }

    /// Group-level `@@strategy(...)` and `@@loadBalance(...)` attributes are
    /// recorded on the server group by name.
    #[test]
    fn test_parse_server_group_with_attributes() {
        let schema = parse_schema(
            r#"
            serverGroup ProductionCluster {
                @@strategy(ReadReplica)
                @@loadBalance(RoundRobin)

                server main {
                    url = "postgres://main/db"
                    role = "primary"
                }
            }
        "#,
        )
        .unwrap();

        let cluster = schema.get_server_group("ProductionCluster").unwrap();
        assert!(cluster.attributes.iter().any(|a| a.name.name == "strategy"));
        assert!(
            cluster
                .attributes
                .iter()
                .any(|a| a.name.name == "loadBalance")
        );
    }
1912
1913    #[test]
1914    fn test_parse_server_group_with_env_vars() {
1915        let schema = parse_schema(
1916            r#"
1917            serverGroup EnvCluster {
1918                server db1 {
1919                    url = env("PRIMARY_DB_URL")
1920                    role = "primary"
1921                }
1922            }
1923        "#,
1924        )
1925        .unwrap();
1926
1927        let cluster = schema.get_server_group("EnvCluster").unwrap();
1928        let server = cluster.servers.get("db1").unwrap();
1929
1930        // Check that the URL is stored as an env var reference
1931        if let Some(ServerPropertyValue::EnvVar(var)) = server.get_property("url") {
1932            assert_eq!(var, "PRIMARY_DB_URL");
1933        } else {
1934            panic!("Expected env var for url property");
1935        }
1936    }
1937
1938    #[test]
1939    fn test_parse_server_group_with_boolean_property() {
1940        let schema = parse_schema(
1941            r#"
1942            serverGroup TestCluster {
1943                server replica {
1944                    url = "postgres://replica/db"
1945                    role = "replica"
1946                    readOnly = true
1947                }
1948            }
1949        "#,
1950        )
1951        .unwrap();
1952
1953        let cluster = schema.get_server_group("TestCluster").unwrap();
1954        let server = cluster.servers.get("replica").unwrap();
1955        assert!(server.is_read_only());
1956    }
1957
1958    #[test]
1959    fn test_parse_server_group_with_numeric_properties() {
1960        let schema = parse_schema(
1961            r#"
1962            serverGroup NumericCluster {
1963                server db {
1964                    url = "postgres://localhost/db"
1965                    weight = 5
1966                    priority = 1
1967                    maxConnections = 100
1968                }
1969            }
1970        "#,
1971        )
1972        .unwrap();
1973
1974        let cluster = schema.get_server_group("NumericCluster").unwrap();
1975        let server = cluster.servers.get("db").unwrap();
1976
1977        assert_eq!(server.weight(), Some(5));
1978        assert_eq!(server.priority(), Some(1));
1979        assert_eq!(server.max_connections(), Some(100));
1980    }
1981
1982    #[test]
1983    fn test_parse_server_group_with_region() {
1984        let schema = parse_schema(
1985            r#"
1986            serverGroup GeoCluster {
1987                server usEast {
1988                    url = "postgres://us-east.db.com/app"
1989                    role = "replica"
1990                    region = "us-east-1"
1991                }
1992
1993                server usWest {
1994                    url = "postgres://us-west.db.com/app"
1995                    role = "replica"
1996                    region = "us-west-2"
1997                }
1998            }
1999        "#,
2000        )
2001        .unwrap();
2002
2003        let cluster = schema.get_server_group("GeoCluster").unwrap();
2004
2005        let us_east = cluster.servers.get("usEast").unwrap();
2006        assert_eq!(us_east.region(), Some("us-east-1"));
2007
2008        let us_west = cluster.servers.get("usWest").unwrap();
2009        assert_eq!(us_west.region(), Some("us-west-2"));
2010
2011        // Test region filtering
2012        let us_east_servers = cluster.servers_in_region("us-east-1");
2013        assert_eq!(us_east_servers.len(), 1);
2014    }
2015
2016    #[test]
2017    fn test_parse_multiple_server_groups() {
2018        let schema = parse_schema(
2019            r#"
2020            serverGroup Cluster1 {
2021                server db1 {
2022                    url = "postgres://db1/app"
2023                }
2024            }
2025
2026            serverGroup Cluster2 {
2027                server db2 {
2028                    url = "postgres://db2/app"
2029                }
2030            }
2031
2032            serverGroup Cluster3 {
2033                server db3 {
2034                    url = "postgres://db3/app"
2035                }
2036            }
2037        "#,
2038        )
2039        .unwrap();
2040
2041        assert_eq!(schema.server_groups.len(), 3);
2042        assert!(schema.get_server_group("Cluster1").is_some());
2043        assert!(schema.get_server_group("Cluster2").is_some());
2044        assert!(schema.get_server_group("Cluster3").is_some());
2045    }
2046
2047    #[test]
2048    fn test_parse_schema_with_models_and_server_groups() {
2049        let schema = parse_schema(
2050            r#"
2051            model User {
2052                id    Int    @id @auto
2053                email String @unique
2054            }
2055
2056            serverGroup Database {
2057                @@strategy(ReadReplica)
2058
2059                server primary {
2060                    url = env("DATABASE_URL")
2061                    role = "primary"
2062                }
2063            }
2064
2065            model Post {
2066                id       Int    @id @auto
2067                title    String
2068                authorId Int
2069            }
2070        "#,
2071        )
2072        .unwrap();
2073
2074        assert_eq!(schema.models.len(), 2);
2075        assert!(schema.get_model("User").is_some());
2076        assert!(schema.get_model("Post").is_some());
2077
2078        assert_eq!(schema.server_groups.len(), 1);
2079        assert!(schema.get_server_group("Database").is_some());
2080    }
2081
2082    #[test]
2083    fn test_parse_server_group_with_health_check() {
2084        let schema = parse_schema(
2085            r#"
2086            serverGroup HealthyCluster {
2087                server monitored {
2088                    url = "postgres://localhost/db"
2089                    healthCheck = "/health"
2090                }
2091            }
2092        "#,
2093        )
2094        .unwrap();
2095
2096        let cluster = schema.get_server_group("HealthyCluster").unwrap();
2097        let server = cluster.servers.get("monitored").unwrap();
2098        assert_eq!(server.health_check(), Some("/health"));
2099    }
2100
2101    #[test]
2102    fn test_server_group_failover_order() {
2103        let schema = parse_schema(
2104            r#"
2105            serverGroup FailoverCluster {
2106                server db3 {
2107                    url = "postgres://db3/app"
2108                    priority = 3
2109                }
2110
2111                server db1 {
2112                    url = "postgres://db1/app"
2113                    priority = 1
2114                }
2115
2116                server db2 {
2117                    url = "postgres://db2/app"
2118                    priority = 2
2119                }
2120            }
2121        "#,
2122        )
2123        .unwrap();
2124
2125        let cluster = schema.get_server_group("FailoverCluster").unwrap();
2126        let ordered = cluster.failover_order();
2127
2128        assert_eq!(ordered[0].name.name.as_str(), "db1");
2129        assert_eq!(ordered[1].name.name.as_str(), "db2");
2130        assert_eq!(ordered[2].name.name.as_str(), "db3");
2131    }
2132
2133    #[test]
2134    fn test_server_group_names() {
2135        let schema = parse_schema(
2136            r#"
2137            serverGroup Alpha {
2138                server s1 { url = "pg://a" }
2139            }
2140            serverGroup Beta {
2141                server s2 { url = "pg://b" }
2142            }
2143        "#,
2144        )
2145        .unwrap();
2146
2147        let names: Vec<_> = schema.server_group_names().collect();
2148        assert_eq!(names.len(), 2);
2149        assert!(names.contains(&"Alpha"));
2150        assert!(names.contains(&"Beta"));
2151    }
2152
2153    // ==================== Policy Parsing ====================
2154
2155    #[test]
2156    fn test_parse_simple_policy() {
2157        let schema = parse_schema(
2158            r#"
2159            policy UserReadOwn on User {
2160                for SELECT
2161                using "id = current_user_id()"
2162            }
2163        "#,
2164        )
2165        .unwrap();
2166
2167        assert_eq!(schema.policies.len(), 1);
2168        let policy = schema.get_policy("UserReadOwn").unwrap();
2169        assert_eq!(policy.name(), "UserReadOwn");
2170        assert_eq!(policy.table(), "User");
2171        assert!(policy.applies_to(PolicyCommand::Select));
2172        assert!(!policy.applies_to(PolicyCommand::Insert));
2173        assert_eq!(policy.using_expr.as_deref(), Some("id = current_user_id()"));
2174    }
2175
2176    #[test]
2177    fn test_parse_policy_with_multiple_commands() {
2178        let schema = parse_schema(
2179            r#"
2180            policy UserModify on User {
2181                for [SELECT, UPDATE, DELETE]
2182                using "id = auth.uid()"
2183            }
2184        "#,
2185        )
2186        .unwrap();
2187
2188        let policy = schema.get_policy("UserModify").unwrap();
2189        assert!(policy.applies_to(PolicyCommand::Select));
2190        assert!(policy.applies_to(PolicyCommand::Update));
2191        assert!(policy.applies_to(PolicyCommand::Delete));
2192        assert!(!policy.applies_to(PolicyCommand::Insert));
2193    }
2194
2195    #[test]
2196    fn test_parse_policy_with_all_command() {
2197        let schema = parse_schema(
2198            r#"
2199            policy UserAll on User {
2200                for ALL
2201                using "true"
2202            }
2203        "#,
2204        )
2205        .unwrap();
2206
2207        let policy = schema.get_policy("UserAll").unwrap();
2208        assert!(policy.applies_to(PolicyCommand::Select));
2209        assert!(policy.applies_to(PolicyCommand::Insert));
2210        assert!(policy.applies_to(PolicyCommand::Update));
2211        assert!(policy.applies_to(PolicyCommand::Delete));
2212    }
2213
2214    #[test]
2215    fn test_parse_policy_with_roles() {
2216        let schema = parse_schema(
2217            r#"
2218            policy AuthenticatedRead on Document {
2219                for SELECT
2220                to authenticated
2221                using "true"
2222            }
2223        "#,
2224        )
2225        .unwrap();
2226
2227        let policy = schema.get_policy("AuthenticatedRead").unwrap();
2228        let roles = policy.effective_roles();
2229        assert!(roles.contains(&"authenticated"));
2230    }
2231
2232    #[test]
2233    fn test_parse_policy_with_multiple_roles() {
2234        let schema = parse_schema(
2235            r#"
2236            policy AdminModerator on Post {
2237                for [UPDATE, DELETE]
2238                to [admin, moderator]
2239                using "true"
2240            }
2241        "#,
2242        )
2243        .unwrap();
2244
2245        let policy = schema.get_policy("AdminModerator").unwrap();
2246        let roles = policy.effective_roles();
2247        assert!(roles.contains(&"admin"));
2248        assert!(roles.contains(&"moderator"));
2249    }
2250
2251    #[test]
2252    fn test_parse_policy_restrictive() {
2253        let schema = parse_schema(
2254            r#"
2255            policy OrgRestriction on Document {
2256                as RESTRICTIVE
2257                for SELECT
2258                using "org_id = current_org_id()"
2259            }
2260        "#,
2261        )
2262        .unwrap();
2263
2264        let policy = schema.get_policy("OrgRestriction").unwrap();
2265        assert!(policy.is_restrictive());
2266        assert!(!policy.is_permissive());
2267    }
2268
2269    #[test]
2270    fn test_parse_policy_permissive_explicit() {
2271        let schema = parse_schema(
2272            r#"
2273            policy Permissive on User {
2274                as PERMISSIVE
2275                for SELECT
2276                using "true"
2277            }
2278        "#,
2279        )
2280        .unwrap();
2281
2282        let policy = schema.get_policy("Permissive").unwrap();
2283        assert!(policy.is_permissive());
2284    }
2285
2286    #[test]
2287    fn test_parse_policy_with_check() {
2288        let schema = parse_schema(
2289            r#"
2290            policy InsertOwn on Post {
2291                for INSERT
2292                to authenticated
2293                check "author_id = current_user_id()"
2294            }
2295        "#,
2296        )
2297        .unwrap();
2298
2299        let policy = schema.get_policy("InsertOwn").unwrap();
2300        assert!(policy.applies_to(PolicyCommand::Insert));
2301        assert_eq!(
2302            policy.check_expr.as_deref(),
2303            Some("author_id = current_user_id()")
2304        );
2305        assert!(policy.using_expr.is_none());
2306    }
2307
2308    #[test]
2309    fn test_parse_policy_with_both_expressions() {
2310        let schema = parse_schema(
2311            r#"
2312            policy UpdateOwn on Post {
2313                for UPDATE
2314                using "author_id = current_user_id()"
2315                check "author_id = current_user_id()"
2316            }
2317        "#,
2318        )
2319        .unwrap();
2320
2321        let policy = schema.get_policy("UpdateOwn").unwrap();
2322        assert!(policy.using_expr.is_some());
2323        assert!(policy.check_expr.is_some());
2324    }
2325
2326    #[test]
2327    fn test_parse_policy_multiline_expression() {
2328        let schema = parse_schema(
2329            r#"
2330            policy ComplexCheck on Document {
2331                for SELECT
2332                using """
2333                    (is_public = true)
2334                    OR (owner_id = current_user_id())
2335                    OR (id IN (SELECT document_id FROM shares WHERE user_id = current_user_id()))
2336                """
2337            }
2338        "#,
2339        )
2340        .unwrap();
2341
2342        let policy = schema.get_policy("ComplexCheck").unwrap();
2343        assert!(policy.using_expr.is_some());
2344        let expr = policy.using_expr.as_ref().unwrap();
2345        assert!(expr.contains("is_public = true"));
2346        assert!(expr.contains("owner_id = current_user_id()"));
2347        assert!(expr.contains("SELECT document_id FROM shares"));
2348    }
2349
2350    #[test]
2351    fn test_parse_multiple_policies() {
2352        let schema = parse_schema(
2353            r#"
2354            policy UserRead on User {
2355                for SELECT
2356                using "true"
2357            }
2358
2359            policy UserInsert on User {
2360                for INSERT
2361                check "id = current_user_id()"
2362            }
2363
2364            policy PostRead on Post {
2365                for SELECT
2366                using "published = true OR author_id = current_user_id()"
2367            }
2368        "#,
2369        )
2370        .unwrap();
2371
2372        assert_eq!(schema.policies.len(), 3);
2373        assert!(schema.get_policy("UserRead").is_some());
2374        assert!(schema.get_policy("UserInsert").is_some());
2375        assert!(schema.get_policy("PostRead").is_some());
2376    }
2377
2378    #[test]
2379    fn test_parse_policy_with_model() {
2380        let schema = parse_schema(
2381            r#"
2382            model User {
2383                id    Int    @id @auto
2384                email String @unique
2385            }
2386
2387            policy UserReadOwn on User {
2388                for SELECT
2389                to authenticated
2390                using "id = auth.uid()"
2391            }
2392        "#,
2393        )
2394        .unwrap();
2395
2396        assert_eq!(schema.models.len(), 1);
2397        assert_eq!(schema.policies.len(), 1);
2398
2399        let policies = schema.policies_for("User");
2400        assert_eq!(policies.len(), 1);
2401        assert_eq!(policies[0].name(), "UserReadOwn");
2402    }
2403
2404    #[test]
2405    fn test_parse_policies_for_multiple_models() {
2406        let schema = parse_schema(
2407            r#"
2408            policy UserPolicy1 on User {
2409                for SELECT
2410                using "true"
2411            }
2412
2413            policy UserPolicy2 on User {
2414                for INSERT
2415                check "true"
2416            }
2417
2418            policy PostPolicy on Post {
2419                for SELECT
2420                using "true"
2421            }
2422        "#,
2423        )
2424        .unwrap();
2425
2426        assert_eq!(schema.policies_for("User").len(), 2);
2427        assert_eq!(schema.policies_for("Post").len(), 1);
2428        assert!(schema.has_policies("User"));
2429        assert!(schema.has_policies("Post"));
2430        assert!(!schema.has_policies("Comment"));
2431    }
2432
2433    #[test]
2434    fn test_parse_policy_default_all_command() {
2435        let schema = parse_schema(
2436            r#"
2437            policy DefaultAll on User {
2438                using "id = current_user_id()"
2439            }
2440        "#,
2441        )
2442        .unwrap();
2443
2444        let policy = schema.get_policy("DefaultAll").unwrap();
2445        // When no 'for' clause, should default to ALL
2446        assert!(policy.applies_to(PolicyCommand::All));
2447    }
2448
2449    #[test]
2450    fn test_parse_policy_case_insensitive_keywords() {
2451        let schema = parse_schema(
2452            r#"
2453            policy CaseTest on User {
2454                for select
2455                as permissive
2456                using "true"
2457            }
2458        "#,
2459        )
2460        .unwrap();
2461
2462        let policy = schema.get_policy("CaseTest").unwrap();
2463        assert!(policy.applies_to(PolicyCommand::Select));
2464        assert!(policy.is_permissive());
2465    }
2466
2467    #[test]
2468    fn test_parse_policy_sql_generation() {
2469        let schema = parse_schema(
2470            r#"
2471            model User {
2472                id Int @id
2473
2474                @@map("users")
2475            }
2476
2477            policy ReadOwn on User {
2478                for SELECT
2479                to authenticated
2480                using "id = auth.uid()"
2481            }
2482        "#,
2483        )
2484        .unwrap();
2485
2486        let policy = schema.get_policy("ReadOwn").unwrap();
2487        let sql = policy.to_sql("users");
2488
2489        assert!(sql.contains("CREATE POLICY ReadOwn ON users"));
2490        assert!(sql.contains("FOR SELECT"));
2491        assert!(sql.contains("TO authenticated"));
2492        assert!(sql.contains("USING (id = auth.uid())"));
2493    }
2494
2495    #[test]
2496    fn test_parse_policy_restrictive_sql() {
2497        let schema = parse_schema(
2498            r#"
2499            policy OrgBoundary on Document {
2500                as RESTRICTIVE
2501                for ALL
2502                using "org_id = current_org_id()"
2503            }
2504        "#,
2505        )
2506        .unwrap();
2507
2508        let policy = schema.get_policy("OrgBoundary").unwrap();
2509        let sql = policy.to_sql("documents");
2510
2511        assert!(sql.contains("AS RESTRICTIVE"));
2512    }
2513
2514    #[test]
2515    fn test_parse_policy_with_documentation() {
2516        let schema = parse_schema(
2517            r#"
2518            /// Users can only read their own data
2519            policy UserIsolation on User {
2520                for SELECT
2521                using "id = current_user_id()"
2522            }
2523        "#,
2524        )
2525        .unwrap();
2526
2527        let policy = schema.get_policy("UserIsolation").unwrap();
2528        if let Some(doc) = &policy.documentation {
2529            assert!(doc.text.contains("their own data"));
2530        }
2531    }
2532
2533    #[test]
2534    fn test_parse_complex_rls_schema() {
2535        let schema = parse_schema(
2536            r#"
2537            model Organization {
2538                id   Int    @id @auto
2539                name String
2540            }
2541
2542            model User {
2543                id    Int    @id @auto
2544                orgId Int
2545                email String @unique
2546            }
2547
2548            model Document {
2549                id       Int     @id @auto
2550                title    String
2551                ownerId  Int
2552                orgId    Int
2553                isPublic Boolean @default(false)
2554            }
2555
2556            /// Organization-level isolation
2557            policy OrgIsolation on Document {
2558                as RESTRICTIVE
2559                for ALL
2560                using "org_id = current_setting('app.current_org')::int"
2561            }
2562
2563            /// Users can read public documents
2564            policy PublicRead on Document {
2565                for SELECT
2566                using "is_public = true"
2567            }
2568
2569            /// Users can read their own documents
2570            policy OwnerRead on Document {
2571                for SELECT
2572                to authenticated
2573                using "owner_id = auth.uid()"
2574            }
2575
2576            /// Users can only modify their own documents
2577            policy OwnerModify on Document {
2578                for [UPDATE, DELETE]
2579                to authenticated
2580                using "owner_id = auth.uid()"
2581                check "owner_id = auth.uid()"
2582            }
2583
2584            /// Users can create documents in their org
2585            policy OrgInsert on Document {
2586                for INSERT
2587                to authenticated
2588                check "org_id = current_setting('app.current_org')::int"
2589            }
2590        "#,
2591        )
2592        .unwrap();
2593
2594        assert_eq!(schema.models.len(), 3);
2595        assert_eq!(schema.policies.len(), 5);
2596
2597        // Verify org isolation is restrictive
2598        let org_iso = schema.get_policy("OrgIsolation").unwrap();
2599        assert!(org_iso.is_restrictive());
2600
2601        // Verify all Document policies
2602        let doc_policies = schema.policies_for("Document");
2603        assert_eq!(doc_policies.len(), 5);
2604    }
2605
2606    // ==================== MSSQL Policy Parsing ====================
2607
2608    #[test]
2609    fn test_parse_policy_with_mssql_schema() {
2610        let schema = parse_schema(
2611            r#"
2612            policy UserFilter on User {
2613                for SELECT
2614                using "UserId = @UserId"
2615                mssqlSchema "RLS"
2616            }
2617        "#,
2618        )
2619        .unwrap();
2620
2621        let policy = schema.get_policy("UserFilter").unwrap();
2622        assert_eq!(policy.mssql_schema(), "RLS");
2623    }
2624
2625    #[test]
2626    fn test_parse_policy_with_mssql_block_single() {
2627        let schema = parse_schema(
2628            r#"
2629            policy UserInsert on User {
2630                for INSERT
2631                check "UserId = @UserId"
2632                mssqlBlock AFTER_INSERT
2633            }
2634        "#,
2635        )
2636        .unwrap();
2637
2638        let policy = schema.get_policy("UserInsert").unwrap();
2639        assert_eq!(policy.mssql_block_operations.len(), 1);
2640        assert_eq!(
2641            policy.mssql_block_operations[0],
2642            MssqlBlockOperation::AfterInsert
2643        );
2644    }
2645
2646    #[test]
2647    fn test_parse_policy_with_mssql_block_list() {
2648        let schema = parse_schema(
2649            r#"
2650            policy UserModify on User {
2651                for [INSERT, UPDATE, DELETE]
2652                check "UserId = @UserId"
2653                mssqlBlock [AFTER_INSERT, AFTER_UPDATE, BEFORE_DELETE]
2654            }
2655        "#,
2656        )
2657        .unwrap();
2658
2659        let policy = schema.get_policy("UserModify").unwrap();
2660        assert_eq!(policy.mssql_block_operations.len(), 3);
2661        assert!(
2662            policy
2663                .mssql_block_operations
2664                .contains(&MssqlBlockOperation::AfterInsert)
2665        );
2666        assert!(
2667            policy
2668                .mssql_block_operations
2669                .contains(&MssqlBlockOperation::AfterUpdate)
2670        );
2671        assert!(
2672            policy
2673                .mssql_block_operations
2674                .contains(&MssqlBlockOperation::BeforeDelete)
2675        );
2676    }
2677
2678    #[test]
2679    fn test_parse_policy_full_mssql_config() {
2680        let schema = parse_schema(
2681            r#"
2682            policy TenantIsolation on Order {
2683                for ALL
2684                using "TenantId = @TenantId"
2685                check "TenantId = @TenantId"
2686                mssqlSchema "MultiTenant"
2687                mssqlBlock [AFTER_INSERT, BEFORE_UPDATE, AFTER_UPDATE, BEFORE_DELETE]
2688            }
2689        "#,
2690        )
2691        .unwrap();
2692
2693        let policy = schema.get_policy("TenantIsolation").unwrap();
2694
2695        // Verify standard options
2696        assert!(policy.applies_to(PolicyCommand::All));
2697        assert!(policy.using_expr.is_some());
2698        assert!(policy.check_expr.is_some());
2699
2700        // Verify MSSQL options
2701        assert_eq!(policy.mssql_schema(), "MultiTenant");
2702        assert_eq!(policy.mssql_block_operations.len(), 4);
2703
2704        // Test SQL generation
2705        let mssql = policy.to_mssql_sql("dbo.Orders", "TenantId");
2706        assert!(mssql.schema_sql.contains("MultiTenant"));
2707        assert!(mssql.function_sql.contains("fn_TenantIsolation_predicate"));
2708    }
2709
2710    #[test]
2711    fn test_parse_policy_mssql_block_case_variants() {
2712        // Test different case variants for block operations
2713        let schema = parse_schema(
2714            r#"
2715            policy Test1 on User {
2716                for INSERT
2717                check "true"
2718                mssqlBlock after_insert
2719            }
2720        "#,
2721        )
2722        .unwrap();
2723
2724        let policy = schema.get_policy("Test1").unwrap();
2725        assert_eq!(policy.mssql_block_operations.len(), 1);
2726        assert_eq!(
2727            policy.mssql_block_operations[0],
2728            MssqlBlockOperation::AfterInsert
2729        );
2730    }
2731
2732    #[test]
2733    fn test_parse_mixed_postgres_mssql_schema() {
2734        let schema = parse_schema(
2735            r#"
2736            model User {
2737                id    Int    @id @auto
2738                email String @unique
2739            }
2740
2741            // PostgreSQL-style policy (works on both, MSSQL uses defaults)
2742            policy UserReadOwn on User {
2743                for SELECT
2744                to authenticated
2745                using "id = current_user_id()"
2746            }
2747
2748            // MSSQL-optimized policy with explicit settings
2749            policy UserModifyOwn on User {
2750                for [INSERT, UPDATE, DELETE]
2751                to authenticated
2752                using "id = current_user_id()"
2753                check "id = current_user_id()"
2754                mssqlSchema "Security"
2755                mssqlBlock [AFTER_INSERT, BEFORE_UPDATE, AFTER_UPDATE, BEFORE_DELETE]
2756            }
2757        "#,
2758        )
2759        .unwrap();
2760
2761        assert_eq!(schema.policies.len(), 2);
2762
2763        // First policy uses defaults for MSSQL
2764        let read_policy = schema.get_policy("UserReadOwn").unwrap();
2765        assert_eq!(read_policy.mssql_schema(), "Security"); // default
2766        assert!(read_policy.mssql_block_operations.is_empty()); // will use auto-generated
2767
2768        // Second policy has explicit MSSQL config
2769        let modify_policy = schema.get_policy("UserModifyOwn").unwrap();
2770        assert_eq!(modify_policy.mssql_schema(), "Security");
2771        assert_eq!(modify_policy.mssql_block_operations.len(), 4);
2772
2773        // Both should generate valid PostgreSQL SQL
2774        let pg_sql = read_policy.to_postgres_sql("users");
2775        assert!(pg_sql.contains("CREATE POLICY UserReadOwn ON users"));
2776
2777        // Both should generate valid MSSQL SQL
2778        let mssql = modify_policy.to_mssql_sql("dbo.Users", "id");
2779        assert!(mssql.policy_sql.contains("Security.UserModifyOwn"));
2780    }
2781}