Skip to main content

prax_schema/ast/
model.rs

//! Model definitions for the Prax schema AST.
2
3use indexmap::IndexMap;
4use serde::{Deserialize, Serialize};
5use smol_str::SmolStr;
6
7use super::{Attribute, Documentation, Field, Ident, Span};
8
9/// A model definition (maps to a database table).
10#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
11pub struct Model {
12    /// Model name.
13    pub name: Ident,
14    /// Model fields.
15    pub fields: IndexMap<SmolStr, Field>,
16    /// Model-level attributes (prefixed with `@@`).
17    pub attributes: Vec<Attribute>,
18    /// Documentation comment.
19    pub documentation: Option<Documentation>,
20    /// Source location.
21    pub span: Span,
22    /// Source file this model was parsed from (None for single-file path).
23    #[serde(default, skip_serializing_if = "Option::is_none")]
24    pub source_id: Option<crate::loader::SourceId>,
25}
26
27impl Model {
28    /// Create a new model.
29    pub fn new(name: Ident, span: Span) -> Self {
30        Self {
31            name,
32            fields: IndexMap::new(),
33            attributes: vec![],
34            documentation: None,
35            span,
36            source_id: None,
37        }
38    }
39
40    /// Get the model name as a string.
41    pub fn name(&self) -> &str {
42        self.name.as_str()
43    }
44
45    /// Add a field to the model.
46    pub fn add_field(&mut self, field: Field) {
47        self.fields.insert(field.name.name.clone(), field);
48    }
49
50    /// Get a field by name.
51    pub fn get_field(&self, name: &str) -> Option<&Field> {
52        self.fields.get(name)
53    }
54
55    /// Get the primary key field(s).
56    pub fn id_fields(&self) -> Vec<&Field> {
57        self.fields.values().filter(|f| f.is_id()).collect()
58    }
59
60    /// Get all relation fields.
61    pub fn relation_fields(&self) -> Vec<&Field> {
62        self.fields.values().filter(|f| f.is_relation()).collect()
63    }
64
65    /// Get all scalar (non-relation) fields.
66    pub fn scalar_fields(&self) -> Vec<&Field> {
67        self.fields.values().filter(|f| !f.is_relation()).collect()
68    }
69
70    /// Check if this model has a specific model-level attribute.
71    pub fn has_attribute(&self, name: &str) -> bool {
72        self.attributes.iter().any(|a| a.is(name))
73    }
74
75    /// Get a model-level attribute by name.
76    pub fn get_attribute(&self, name: &str) -> Option<&Attribute> {
77        self.attributes.iter().find(|a| a.is(name))
78    }
79
80    /// Get the database table name (from `@@map` or model name).
81    pub fn table_name(&self) -> &str {
82        self.get_attribute("map")
83            .and_then(|a| a.first_arg())
84            .and_then(|v| v.as_string())
85            .unwrap_or_else(|| self.name())
86    }
87
88    /// Set documentation.
89    pub fn with_documentation(mut self, doc: Documentation) -> Self {
90        self.documentation = Some(doc);
91        self
92    }
93}
94
95/// An enum definition.
96#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
97pub struct Enum {
98    /// Enum name.
99    pub name: Ident,
100    /// Enum variants.
101    pub variants: Vec<EnumVariant>,
102    /// Enum-level attributes.
103    pub attributes: Vec<Attribute>,
104    /// Documentation comment.
105    pub documentation: Option<Documentation>,
106    /// Source location.
107    pub span: Span,
108    /// Source file this enum was parsed from (None for single-file path).
109    #[serde(default, skip_serializing_if = "Option::is_none")]
110    pub source_id: Option<crate::loader::SourceId>,
111}
112
113impl Enum {
114    /// Create a new enum.
115    pub fn new(name: Ident, span: Span) -> Self {
116        Self {
117            name,
118            variants: vec![],
119            attributes: vec![],
120            documentation: None,
121            span,
122            source_id: None,
123        }
124    }
125
126    /// Get the enum name as a string.
127    pub fn name(&self) -> &str {
128        self.name.as_str()
129    }
130
131    /// Add a variant to the enum.
132    pub fn add_variant(&mut self, variant: EnumVariant) {
133        self.variants.push(variant);
134    }
135
136    /// Get a variant by name.
137    pub fn get_variant(&self, name: &str) -> Option<&EnumVariant> {
138        self.variants.iter().find(|v| v.name.as_str() == name)
139    }
140
141    /// Get the database type name (from `@@map` or enum name).
142    pub fn db_name(&self) -> &str {
143        self.attributes
144            .iter()
145            .find(|a| a.is("map"))
146            .and_then(|a| a.first_arg())
147            .and_then(|v| v.as_string())
148            .unwrap_or_else(|| self.name())
149    }
150
151    /// Set documentation.
152    pub fn with_documentation(mut self, doc: Documentation) -> Self {
153        self.documentation = Some(doc);
154        self
155    }
156}
157
158/// An enum variant.
159#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
160pub struct EnumVariant {
161    /// Variant name.
162    pub name: Ident,
163    /// Variant-level attributes.
164    pub attributes: Vec<Attribute>,
165    /// Documentation comment.
166    pub documentation: Option<Documentation>,
167    /// Source location.
168    pub span: Span,
169}
170
171impl EnumVariant {
172    /// Create a new enum variant.
173    pub fn new(name: Ident, span: Span) -> Self {
174        Self {
175            name,
176            attributes: vec![],
177            documentation: None,
178            span,
179        }
180    }
181
182    /// Get the variant name as a string.
183    pub fn name(&self) -> &str {
184        self.name.as_str()
185    }
186
187    /// Get the database value (from `@map` or variant name).
188    pub fn db_value(&self) -> &str {
189        self.attributes
190            .iter()
191            .find(|a| a.is("map"))
192            .and_then(|a| a.first_arg())
193            .and_then(|v| v.as_string())
194            .unwrap_or_else(|| self.name())
195    }
196}
197
198/// A composite type definition (for embedded documents / JSON).
199#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
200pub struct CompositeType {
201    /// Type name.
202    pub name: Ident,
203    /// Type fields.
204    pub fields: IndexMap<SmolStr, Field>,
205    /// Documentation comment.
206    pub documentation: Option<Documentation>,
207    /// Source location.
208    pub span: Span,
209    /// Source file this type was parsed from (None for single-file path).
210    #[serde(default, skip_serializing_if = "Option::is_none")]
211    pub source_id: Option<crate::loader::SourceId>,
212}
213
214impl CompositeType {
215    /// Create a new composite type.
216    pub fn new(name: Ident, span: Span) -> Self {
217        Self {
218            name,
219            fields: IndexMap::new(),
220            documentation: None,
221            span,
222            source_id: None,
223        }
224    }
225
226    /// Get the type name as a string.
227    pub fn name(&self) -> &str {
228        self.name.as_str()
229    }
230
231    /// Add a field to the type.
232    pub fn add_field(&mut self, field: Field) {
233        self.fields.insert(field.name.name.clone(), field);
234    }
235
236    /// Get a field by name.
237    pub fn get_field(&self, name: &str) -> Option<&Field> {
238        self.fields.get(name)
239    }
240
241    /// Set documentation.
242    pub fn with_documentation(mut self, doc: Documentation) -> Self {
243        self.documentation = Some(doc);
244        self
245    }
246}
247
248/// A view definition (read-only model mapping to a database view).
249#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
250pub struct View {
251    /// View name.
252    pub name: Ident,
253    /// View fields.
254    pub fields: IndexMap<SmolStr, Field>,
255    /// View-level attributes.
256    pub attributes: Vec<Attribute>,
257    /// Documentation comment.
258    pub documentation: Option<Documentation>,
259    /// Source location.
260    pub span: Span,
261    /// Source file this view was parsed from (None for single-file path).
262    #[serde(default, skip_serializing_if = "Option::is_none")]
263    pub source_id: Option<crate::loader::SourceId>,
264}
265
266impl View {
267    /// Create a new view.
268    pub fn new(name: Ident, span: Span) -> Self {
269        Self {
270            name,
271            fields: IndexMap::new(),
272            attributes: vec![],
273            documentation: None,
274            span,
275            source_id: None,
276        }
277    }
278
279    /// Get the view name as a string.
280    pub fn name(&self) -> &str {
281        self.name.as_str()
282    }
283
284    /// Add a field to the view.
285    pub fn add_field(&mut self, field: Field) {
286        self.fields.insert(field.name.name.clone(), field);
287    }
288
289    /// Get the database view name (from `@@map` or view name).
290    pub fn view_name(&self) -> &str {
291        self.attributes
292            .iter()
293            .find(|a| a.is("map"))
294            .and_then(|a| a.first_arg())
295            .and_then(|v| v.as_string())
296            .unwrap_or_else(|| self.name())
297    }
298
299    /// Get the SQL query that defines the view (from `@@sql` attribute).
300    pub fn sql_query(&self) -> Option<&str> {
301        self.attributes
302            .iter()
303            .find(|a| a.is("sql"))
304            .and_then(|a| a.first_arg())
305            .and_then(|v| v.as_string())
306    }
307
308    /// Check if the view is materialized (has `@@materialized` attribute).
309    pub fn is_materialized(&self) -> bool {
310        self.attributes.iter().any(|a| a.is("materialized"))
311    }
312
313    /// Get the refresh interval for materialized views (from `@@refreshInterval`).
314    pub fn refresh_interval(&self) -> Option<&str> {
315        self.attributes
316            .iter()
317            .find(|a| a.is("refreshInterval"))
318            .and_then(|a| a.first_arg())
319            .and_then(|v| v.as_string())
320    }
321
322    /// Set documentation.
323    pub fn with_documentation(mut self, doc: Documentation) -> Self {
324        self.documentation = Some(doc);
325        self
326    }
327}
328
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{
        Attribute, AttributeArg, AttributeValue, FieldType, ScalarType, TypeModifier,
    };

    // ==================== Fixtures ====================

    // All fixtures share one fixed span; spans are not under test in this module.
    fn make_span() -> Span {
        Span::new(0, 10)
    }

    fn make_ident(name: &str) -> Ident {
        Ident::new(name, make_span())
    }

    fn make_field(name: &str, field_type: FieldType, modifier: TypeModifier) -> Field {
        Field::new(make_ident(name), field_type, modifier, vec![], make_span())
    }

    // Shorthand for the most common fixture: a required `String` scalar field.
    fn make_string_field(name: &str) -> Field {
        make_field(
            name,
            FieldType::Scalar(ScalarType::String),
            TypeModifier::Required,
        )
    }

    // An `Int` field named `id` carrying `@id` and `@auto`.
    fn make_id_field() -> Field {
        let mut field = make_field(
            "id",
            FieldType::Scalar(ScalarType::Int),
            TypeModifier::Required,
        );
        field
            .attributes
            .extend([make_attribute("id"), make_attribute("auto")]);
        field
    }

    fn make_attribute(name: &str) -> Attribute {
        Attribute::simple(make_ident(name), make_span())
    }

    fn make_attribute_with_string(name: &str, value: &str) -> Attribute {
        Attribute::new(
            make_ident(name),
            vec![AttributeArg::positional(
                AttributeValue::String(value.into()),
                make_span(),
            )],
            make_span(),
        )
    }

    // ==================== Model Tests ====================

    #[test]
    fn test_model_new() {
        let model = Model::new(make_ident("User"), make_span());

        assert_eq!(model.name(), "User");
        assert!(model.fields.is_empty());
        assert!(model.attributes.is_empty());
        assert!(model.documentation.is_none());
    }

    #[test]
    fn test_model_name() {
        let model = Model::new(make_ident("BlogPost"), make_span());
        assert_eq!(model.name(), "BlogPost");
    }

    #[test]
    fn test_model_add_field() {
        let mut model = Model::new(make_ident("User"), make_span());
        model.add_field(make_string_field("email"));

        assert_eq!(model.fields.len(), 1);
        assert!(model.fields.contains_key("email"));
    }

    #[test]
    fn test_model_add_multiple_fields() {
        let mut model = Model::new(make_ident("User"), make_span());
        model.add_field(make_id_field());
        model.add_field(make_string_field("email"));
        model.add_field(make_field(
            "name",
            FieldType::Scalar(ScalarType::String),
            TypeModifier::Optional,
        ));

        assert_eq!(model.fields.len(), 3);
    }

    #[test]
    fn test_model_add_field_replaces_duplicate_in_place() {
        let mut model = Model::new(make_ident("User"), make_span());
        model.add_field(make_string_field("email"));
        model.add_field(make_id_field());
        model.add_field(make_field(
            "email",
            FieldType::Scalar(ScalarType::String),
            TypeModifier::Optional,
        ));

        // The duplicate overwrites rather than appends, and keeps its original slot.
        assert_eq!(model.fields.len(), 2);
        assert_eq!(model.fields.get_index_of("email"), Some(0));
        assert_eq!(model.fields.get_index_of("id"), Some(1));
    }

    #[test]
    fn test_model_get_field() {
        let mut model = Model::new(make_ident("User"), make_span());
        model.add_field(make_string_field("email"));

        let field = model
            .get_field("email")
            .expect("`email` field should be present");
        assert_eq!(field.name(), "email");

        assert!(model.get_field("nonexistent").is_none());
    }

    #[test]
    fn test_model_id_fields() {
        let mut model = Model::new(make_ident("User"), make_span());
        model.add_field(make_id_field());
        model.add_field(make_string_field("email"));

        let id_fields = model.id_fields();
        assert_eq!(id_fields.len(), 1);
        assert_eq!(id_fields[0].name(), "id");
    }

    #[test]
    fn test_model_id_fields_none() {
        let mut model = Model::new(make_ident("User"), make_span());
        model.add_field(make_string_field("email"));

        assert!(model.id_fields().is_empty());
    }

    #[test]
    fn test_model_relation_fields() {
        let mut model = Model::new(make_ident("Post"), make_span());
        model.add_field(make_id_field());
        model.add_field(make_string_field("title"));
        model.add_field(make_field(
            "author",
            FieldType::Model("User".into()),
            TypeModifier::Required,
        ));

        let rel_fields = model.relation_fields();
        assert_eq!(rel_fields.len(), 1);
        assert_eq!(rel_fields[0].name(), "author");
    }

    #[test]
    fn test_model_scalar_fields() {
        let mut model = Model::new(make_ident("Post"), make_span());
        model.add_field(make_id_field());
        model.add_field(make_string_field("title"));
        model.add_field(make_field(
            "author",
            FieldType::Model("User".into()),
            TypeModifier::Required,
        ));

        let scalar_fields = model.scalar_fields();
        assert_eq!(scalar_fields.len(), 2);
        // Scalar and relation fields together cover every field exactly once.
        assert_eq!(
            scalar_fields.len() + model.relation_fields().len(),
            model.fields.len()
        );
    }

    #[test]
    fn test_model_has_attribute() {
        let mut model = Model::new(make_ident("User"), make_span());
        model.attributes.push(make_attribute("map"));

        assert!(model.has_attribute("map"));
        assert!(!model.has_attribute("index"));
    }

    #[test]
    fn test_model_get_attribute() {
        let mut model = Model::new(make_ident("User"), make_span());
        model
            .attributes
            .push(make_attribute_with_string("map", "users"));

        let attr = model
            .get_attribute("map")
            .expect("`@@map` attribute should be present");
        assert!(attr.is("map"));

        assert!(model.get_attribute("index").is_none());
    }

    #[test]
    fn test_model_table_name_default() {
        let model = Model::new(make_ident("User"), make_span());
        assert_eq!(model.table_name(), "User");
    }

    #[test]
    fn test_model_table_name_mapped() {
        let mut model = Model::new(make_ident("User"), make_span());
        model
            .attributes
            .push(make_attribute_with_string("map", "app_users"));

        assert_eq!(model.table_name(), "app_users");
    }

    #[test]
    fn test_model_table_name_map_without_argument_falls_back() {
        let mut model = Model::new(make_ident("User"), make_span());
        model.attributes.push(make_attribute("map"));

        assert_eq!(model.table_name(), "User");
    }

    #[test]
    fn test_model_with_documentation() {
        let model = Model::new(make_ident("User"), make_span())
            .with_documentation(Documentation::new("Represents a user", make_span()));

        let doc = model
            .documentation
            .expect("documentation should have been set");
        assert_eq!(doc.text, "Represents a user");
    }

    // ==================== Enum Tests ====================

    #[test]
    fn test_enum_new() {
        let e = Enum::new(make_ident("Role"), make_span());

        assert_eq!(e.name(), "Role");
        assert!(e.variants.is_empty());
        assert!(e.attributes.is_empty());
        assert!(e.documentation.is_none());
    }

    #[test]
    fn test_enum_add_variant() {
        let mut e = Enum::new(make_ident("Role"), make_span());
        e.add_variant(EnumVariant::new(make_ident("Admin"), make_span()));
        e.add_variant(EnumVariant::new(make_ident("User"), make_span()));

        assert_eq!(e.variants.len(), 2);
        assert_eq!(e.variants[0].name(), "Admin");
        assert_eq!(e.variants[1].name(), "User");
    }

    #[test]
    fn test_enum_get_variant() {
        let mut e = Enum::new(make_ident("Role"), make_span());
        e.add_variant(EnumVariant::new(make_ident("Admin"), make_span()));
        e.add_variant(EnumVariant::new(make_ident("User"), make_span()));

        let variant = e
            .get_variant("Admin")
            .expect("`Admin` variant should be present");
        assert_eq!(variant.name(), "Admin");

        assert!(e.get_variant("Moderator").is_none());
    }

    #[test]
    fn test_enum_db_name_default() {
        let e = Enum::new(make_ident("Role"), make_span());
        assert_eq!(e.db_name(), "Role");
    }

    #[test]
    fn test_enum_db_name_mapped() {
        let mut e = Enum::new(make_ident("Role"), make_span());
        e.attributes
            .push(make_attribute_with_string("map", "user_role"));

        assert_eq!(e.db_name(), "user_role");
    }

    #[test]
    fn test_enum_with_documentation() {
        let e = Enum::new(make_ident("Role"), make_span())
            .with_documentation(Documentation::new("User roles", make_span()));

        let doc = e.documentation.expect("documentation should have been set");
        assert_eq!(doc.text, "User roles");
    }

    // ==================== EnumVariant Tests ====================

    #[test]
    fn test_enum_variant_new() {
        let variant = EnumVariant::new(make_ident("Admin"), make_span());

        assert_eq!(variant.name(), "Admin");
        assert!(variant.attributes.is_empty());
        assert!(variant.documentation.is_none());
    }

    #[test]
    fn test_enum_variant_db_value_default() {
        let variant = EnumVariant::new(make_ident("Admin"), make_span());
        assert_eq!(variant.db_value(), "Admin");
    }

    #[test]
    fn test_enum_variant_db_value_mapped() {
        let mut variant = EnumVariant::new(make_ident("Admin"), make_span());
        variant
            .attributes
            .push(make_attribute_with_string("map", "ADMIN_USER"));

        assert_eq!(variant.db_value(), "ADMIN_USER");
    }

    // ==================== CompositeType Tests ====================

    #[test]
    fn test_composite_type_new() {
        let ct = CompositeType::new(make_ident("Address"), make_span());

        assert_eq!(ct.name(), "Address");
        assert!(ct.fields.is_empty());
        assert!(ct.documentation.is_none());
    }

    #[test]
    fn test_composite_type_add_field() {
        let mut ct = CompositeType::new(make_ident("Address"), make_span());
        ct.add_field(make_string_field("street"));
        ct.add_field(make_string_field("city"));

        assert_eq!(ct.fields.len(), 2);
    }

    #[test]
    fn test_composite_type_get_field() {
        let mut ct = CompositeType::new(make_ident("Address"), make_span());
        ct.add_field(make_string_field("city"));

        let field = ct.get_field("city").expect("`city` field should be present");
        assert_eq!(field.name(), "city");

        assert!(ct.get_field("country").is_none());
    }

    #[test]
    fn test_composite_type_with_documentation() {
        let ct = CompositeType::new(make_ident("Address"), make_span())
            .with_documentation(Documentation::new("Mailing address", make_span()));

        let doc = ct.documentation.expect("documentation should have been set");
        assert_eq!(doc.text, "Mailing address");
    }

    // ==================== View Tests ====================

    #[test]
    fn test_view_new() {
        let view = View::new(make_ident("UserStats"), make_span());

        assert_eq!(view.name(), "UserStats");
        assert!(view.fields.is_empty());
        assert!(view.attributes.is_empty());
        assert!(view.documentation.is_none());
    }

    #[test]
    fn test_view_add_field() {
        let mut view = View::new(make_ident("UserStats"), make_span());
        view.add_field(make_field(
            "user_id",
            FieldType::Scalar(ScalarType::Int),
            TypeModifier::Required,
        ));
        view.add_field(make_field(
            "post_count",
            FieldType::Scalar(ScalarType::Int),
            TypeModifier::Required,
        ));

        assert_eq!(view.fields.len(), 2);
    }

    #[test]
    fn test_view_view_name_default() {
        let view = View::new(make_ident("UserStats"), make_span());
        assert_eq!(view.view_name(), "UserStats");
    }

    #[test]
    fn test_view_view_name_mapped() {
        let mut view = View::new(make_ident("UserStats"), make_span());
        view.attributes
            .push(make_attribute_with_string("map", "v_user_statistics"));

        assert_eq!(view.view_name(), "v_user_statistics");
    }

    #[test]
    fn test_view_sql_query() {
        let mut view = View::new(make_ident("UserStats"), make_span());
        assert!(view.sql_query().is_none());

        view.attributes
            .push(make_attribute_with_string("sql", "SELECT 1"));
        assert_eq!(view.sql_query(), Some("SELECT 1"));
    }

    #[test]
    fn test_view_is_materialized() {
        let mut view = View::new(make_ident("UserStats"), make_span());
        assert!(!view.is_materialized());

        view.attributes.push(make_attribute("materialized"));
        assert!(view.is_materialized());
    }

    #[test]
    fn test_view_refresh_interval() {
        let mut view = View::new(make_ident("UserStats"), make_span());
        assert!(view.refresh_interval().is_none());

        view.attributes
            .push(make_attribute_with_string("refreshInterval", "1h"));
        assert_eq!(view.refresh_interval(), Some("1h"));
    }

    #[test]
    fn test_view_with_documentation() {
        let view = View::new(make_ident("UserStats"), make_span()).with_documentation(
            Documentation::new("Aggregated user statistics", make_span()),
        );

        let doc = view
            .documentation
            .expect("documentation should have been set");
        assert_eq!(doc.text, "Aggregated user statistics");
    }

    // ==================== Equality Tests ====================

    #[test]
    fn test_model_equality() {
        let model1 = Model::new(make_ident("User"), make_span());
        let model2 = Model::new(make_ident("User"), make_span());

        assert_eq!(model1, model2);
    }

    #[test]
    fn test_model_inequality() {
        let model1 = Model::new(make_ident("User"), make_span());
        let model2 = Model::new(make_ident("Post"), make_span());

        assert_ne!(model1, model2);
    }

    #[test]
    fn test_enum_equality() {
        let enum1 = Enum::new(make_ident("Role"), make_span());
        let enum2 = Enum::new(make_ident("Role"), make_span());

        assert_eq!(enum1, enum2);
    }

    #[test]
    fn test_enum_inequality() {
        let enum1 = Enum::new(make_ident("Role"), make_span());
        let mut enum2 = Enum::new(make_ident("Role"), make_span());
        enum2.add_variant(EnumVariant::new(make_ident("Admin"), make_span()));

        assert_ne!(enum1, enum2);
    }

    #[test]
    fn test_enum_variant_equality() {
        let v1 = EnumVariant::new(make_ident("Admin"), make_span());
        let v2 = EnumVariant::new(make_ident("Admin"), make_span());
        let v3 = EnumVariant::new(make_ident("User"), make_span());

        assert_eq!(v1, v2);
        assert_ne!(v1, v3);
    }

    #[test]
    fn test_composite_type_equality() {
        let ct1 = CompositeType::new(make_ident("Address"), make_span());
        let ct2 = CompositeType::new(make_ident("Address"), make_span());

        assert_eq!(ct1, ct2);
    }

    #[test]
    fn test_view_equality() {
        let v1 = View::new(make_ident("Stats"), make_span());
        let v2 = View::new(make_ident("Stats"), make_span());

        assert_eq!(v1, v2);
    }
}