// prax_schema/ast/model.rs

1//! Model definitions for the Prax schema AST.
2
3use indexmap::IndexMap;
4use serde::{Deserialize, Serialize};
5use smol_str::SmolStr;
6
7use super::{Attribute, Documentation, Field, Ident, Span};
8
9/// A model definition (maps to a database table).
10#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
11pub struct Model {
12    /// Model name.
13    pub name: Ident,
14    /// Model fields.
15    pub fields: IndexMap<SmolStr, Field>,
16    /// Model-level attributes (prefixed with `@@`).
17    pub attributes: Vec<Attribute>,
18    /// Documentation comment.
19    pub documentation: Option<Documentation>,
20    /// Source location.
21    pub span: Span,
22}
23
24impl Model {
25    /// Create a new model.
26    pub fn new(name: Ident, span: Span) -> Self {
27        Self {
28            name,
29            fields: IndexMap::new(),
30            attributes: vec![],
31            documentation: None,
32            span,
33        }
34    }
35
36    /// Get the model name as a string.
37    pub fn name(&self) -> &str {
38        self.name.as_str()
39    }
40
41    /// Add a field to the model.
42    pub fn add_field(&mut self, field: Field) {
43        self.fields.insert(field.name.name.clone(), field);
44    }
45
46    /// Get a field by name.
47    pub fn get_field(&self, name: &str) -> Option<&Field> {
48        self.fields.get(name)
49    }
50
51    /// Get the primary key field(s).
52    pub fn id_fields(&self) -> Vec<&Field> {
53        self.fields.values().filter(|f| f.is_id()).collect()
54    }
55
56    /// Get all relation fields.
57    pub fn relation_fields(&self) -> Vec<&Field> {
58        self.fields.values().filter(|f| f.is_relation()).collect()
59    }
60
61    /// Get all scalar (non-relation) fields.
62    pub fn scalar_fields(&self) -> Vec<&Field> {
63        self.fields.values().filter(|f| !f.is_relation()).collect()
64    }
65
66    /// Check if this model has a specific model-level attribute.
67    pub fn has_attribute(&self, name: &str) -> bool {
68        self.attributes.iter().any(|a| a.is(name))
69    }
70
71    /// Get a model-level attribute by name.
72    pub fn get_attribute(&self, name: &str) -> Option<&Attribute> {
73        self.attributes.iter().find(|a| a.is(name))
74    }
75
76    /// Get the database table name (from `@@map` or model name).
77    pub fn table_name(&self) -> &str {
78        self.get_attribute("map")
79            .and_then(|a| a.first_arg())
80            .and_then(|v| v.as_string())
81            .unwrap_or_else(|| self.name())
82    }
83
84    /// Set documentation.
85    pub fn with_documentation(mut self, doc: Documentation) -> Self {
86        self.documentation = Some(doc);
87        self
88    }
89}
90
91/// An enum definition.
92#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
93pub struct Enum {
94    /// Enum name.
95    pub name: Ident,
96    /// Enum variants.
97    pub variants: Vec<EnumVariant>,
98    /// Enum-level attributes.
99    pub attributes: Vec<Attribute>,
100    /// Documentation comment.
101    pub documentation: Option<Documentation>,
102    /// Source location.
103    pub span: Span,
104}
105
106impl Enum {
107    /// Create a new enum.
108    pub fn new(name: Ident, span: Span) -> Self {
109        Self {
110            name,
111            variants: vec![],
112            attributes: vec![],
113            documentation: None,
114            span,
115        }
116    }
117
118    /// Get the enum name as a string.
119    pub fn name(&self) -> &str {
120        self.name.as_str()
121    }
122
123    /// Add a variant to the enum.
124    pub fn add_variant(&mut self, variant: EnumVariant) {
125        self.variants.push(variant);
126    }
127
128    /// Get a variant by name.
129    pub fn get_variant(&self, name: &str) -> Option<&EnumVariant> {
130        self.variants.iter().find(|v| v.name.as_str() == name)
131    }
132
133    /// Get the database type name (from `@@map` or enum name).
134    pub fn db_name(&self) -> &str {
135        self.attributes
136            .iter()
137            .find(|a| a.is("map"))
138            .and_then(|a| a.first_arg())
139            .and_then(|v| v.as_string())
140            .unwrap_or_else(|| self.name())
141    }
142
143    /// Set documentation.
144    pub fn with_documentation(mut self, doc: Documentation) -> Self {
145        self.documentation = Some(doc);
146        self
147    }
148}
149
150/// An enum variant.
151#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
152pub struct EnumVariant {
153    /// Variant name.
154    pub name: Ident,
155    /// Variant-level attributes.
156    pub attributes: Vec<Attribute>,
157    /// Documentation comment.
158    pub documentation: Option<Documentation>,
159    /// Source location.
160    pub span: Span,
161}
162
163impl EnumVariant {
164    /// Create a new enum variant.
165    pub fn new(name: Ident, span: Span) -> Self {
166        Self {
167            name,
168            attributes: vec![],
169            documentation: None,
170            span,
171        }
172    }
173
174    /// Get the variant name as a string.
175    pub fn name(&self) -> &str {
176        self.name.as_str()
177    }
178
179    /// Get the database value (from `@map` or variant name).
180    pub fn db_value(&self) -> &str {
181        self.attributes
182            .iter()
183            .find(|a| a.is("map"))
184            .and_then(|a| a.first_arg())
185            .and_then(|v| v.as_string())
186            .unwrap_or_else(|| self.name())
187    }
188}
189
190/// A composite type definition (for embedded documents / JSON).
191#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
192pub struct CompositeType {
193    /// Type name.
194    pub name: Ident,
195    /// Type fields.
196    pub fields: IndexMap<SmolStr, Field>,
197    /// Documentation comment.
198    pub documentation: Option<Documentation>,
199    /// Source location.
200    pub span: Span,
201}
202
203impl CompositeType {
204    /// Create a new composite type.
205    pub fn new(name: Ident, span: Span) -> Self {
206        Self {
207            name,
208            fields: IndexMap::new(),
209            documentation: None,
210            span,
211        }
212    }
213
214    /// Get the type name as a string.
215    pub fn name(&self) -> &str {
216        self.name.as_str()
217    }
218
219    /// Add a field to the type.
220    pub fn add_field(&mut self, field: Field) {
221        self.fields.insert(field.name.name.clone(), field);
222    }
223
224    /// Get a field by name.
225    pub fn get_field(&self, name: &str) -> Option<&Field> {
226        self.fields.get(name)
227    }
228
229    /// Set documentation.
230    pub fn with_documentation(mut self, doc: Documentation) -> Self {
231        self.documentation = Some(doc);
232        self
233    }
234}
235
236/// A view definition (read-only model mapping to a database view).
237#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
238pub struct View {
239    /// View name.
240    pub name: Ident,
241    /// View fields.
242    pub fields: IndexMap<SmolStr, Field>,
243    /// View-level attributes.
244    pub attributes: Vec<Attribute>,
245    /// Documentation comment.
246    pub documentation: Option<Documentation>,
247    /// Source location.
248    pub span: Span,
249}
250
251impl View {
252    /// Create a new view.
253    pub fn new(name: Ident, span: Span) -> Self {
254        Self {
255            name,
256            fields: IndexMap::new(),
257            attributes: vec![],
258            documentation: None,
259            span,
260        }
261    }
262
263    /// Get the view name as a string.
264    pub fn name(&self) -> &str {
265        self.name.as_str()
266    }
267
268    /// Add a field to the view.
269    pub fn add_field(&mut self, field: Field) {
270        self.fields.insert(field.name.name.clone(), field);
271    }
272
273    /// Get the database view name (from `@@map` or view name).
274    pub fn view_name(&self) -> &str {
275        self.attributes
276            .iter()
277            .find(|a| a.is("map"))
278            .and_then(|a| a.first_arg())
279            .and_then(|v| v.as_string())
280            .unwrap_or_else(|| self.name())
281    }
282
283    /// Get the SQL query that defines the view (from `@@sql` attribute).
284    pub fn sql_query(&self) -> Option<&str> {
285        self.attributes
286            .iter()
287            .find(|a| a.is("sql"))
288            .and_then(|a| a.first_arg())
289            .and_then(|v| v.as_string())
290    }
291
292    /// Check if the view is materialized (has `@@materialized` attribute).
293    pub fn is_materialized(&self) -> bool {
294        self.attributes.iter().any(|a| a.is("materialized"))
295    }
296
297    /// Get the refresh interval for materialized views (from `@@refreshInterval`).
298    pub fn refresh_interval(&self) -> Option<&str> {
299        self.attributes
300            .iter()
301            .find(|a| a.is("refreshInterval"))
302            .and_then(|a| a.first_arg())
303            .and_then(|v| v.as_string())
304    }
305
306    /// Set documentation.
307    pub fn with_documentation(mut self, doc: Documentation) -> Self {
308        self.documentation = Some(doc);
309        self
310    }
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{
        Attribute, AttributeArg, AttributeValue, FieldType, ScalarType, TypeModifier,
    };

    fn make_span() -> Span {
        Span::new(0, 10)
    }

    fn make_ident(name: &str) -> Ident {
        Ident::new(name, make_span())
    }

    fn make_field(name: &str, field_type: FieldType, modifier: TypeModifier) -> Field {
        Field::new(make_ident(name), field_type, modifier, vec![], make_span())
    }

    /// An `Int` field named `id` carrying `@id` and `@auto`.
    fn make_id_field() -> Field {
        let mut field = make_field(
            "id",
            FieldType::Scalar(ScalarType::Int),
            TypeModifier::Required,
        );
        field
            .attributes
            .push(Attribute::simple(make_ident("id"), make_span()));
        field
            .attributes
            .push(Attribute::simple(make_ident("auto"), make_span()));
        field
    }

    /// An attribute with no arguments.
    fn make_attribute(name: &str) -> Attribute {
        Attribute::simple(make_ident(name), make_span())
    }

    /// An attribute with a single positional string argument.
    fn make_attribute_with_string(name: &str, value: &str) -> Attribute {
        Attribute::new(
            make_ident(name),
            vec![AttributeArg::positional(
                AttributeValue::String(value.into()),
                make_span(),
            )],
            make_span(),
        )
    }

    // ==================== Model Tests ====================

    #[test]
    fn test_model_new() {
        let model = Model::new(make_ident("User"), make_span());

        assert_eq!(model.name(), "User");
        assert!(model.fields.is_empty());
        assert!(model.attributes.is_empty());
        assert!(model.documentation.is_none());
    }

    #[test]
    fn test_model_name() {
        let model = Model::new(make_ident("BlogPost"), make_span());
        assert_eq!(model.name(), "BlogPost");
    }

    #[test]
    fn test_model_add_field() {
        let mut model = Model::new(make_ident("User"), make_span());
        let field = make_field(
            "email",
            FieldType::Scalar(ScalarType::String),
            TypeModifier::Required,
        );

        model.add_field(field);

        assert_eq!(model.fields.len(), 1);
        assert!(model.fields.contains_key("email"));
    }

    #[test]
    fn test_model_add_multiple_fields() {
        let mut model = Model::new(make_ident("User"), make_span());
        model.add_field(make_id_field());
        model.add_field(make_field(
            "email",
            FieldType::Scalar(ScalarType::String),
            TypeModifier::Required,
        ));
        model.add_field(make_field(
            "name",
            FieldType::Scalar(ScalarType::String),
            TypeModifier::Optional,
        ));

        assert_eq!(model.fields.len(), 3);
    }

    #[test]
    fn test_model_add_field_replaces_duplicate_in_place() {
        let mut model = Model::new(make_ident("User"), make_span());
        model.add_field(make_field(
            "email",
            FieldType::Scalar(ScalarType::String),
            TypeModifier::Required,
        ));
        model.add_field(make_field(
            "name",
            FieldType::Scalar(ScalarType::String),
            TypeModifier::Optional,
        ));

        let mut replacement = make_field(
            "email",
            FieldType::Scalar(ScalarType::String),
            TypeModifier::Required,
        );
        replacement.attributes.push(make_attribute("unique"));
        model.add_field(replacement);

        // The duplicate replaces the original entry without moving it.
        assert_eq!(model.fields.len(), 2);
        let keys: Vec<&str> = model.fields.keys().map(|k| k.as_str()).collect();
        assert_eq!(keys, vec!["email", "name"]);
        assert_eq!(model.get_field("email").unwrap().attributes.len(), 1);
    }

    #[test]
    fn test_model_get_field() {
        let mut model = Model::new(make_ident("User"), make_span());
        model.add_field(make_field(
            "email",
            FieldType::Scalar(ScalarType::String),
            TypeModifier::Required,
        ));

        let field = model.get_field("email");
        assert!(field.is_some());
        assert_eq!(field.unwrap().name(), "email");

        assert!(model.get_field("nonexistent").is_none());
    }

    #[test]
    fn test_model_id_fields() {
        let mut model = Model::new(make_ident("User"), make_span());
        model.add_field(make_id_field());
        model.add_field(make_field(
            "email",
            FieldType::Scalar(ScalarType::String),
            TypeModifier::Required,
        ));

        let id_fields = model.id_fields();
        assert_eq!(id_fields.len(), 1);
        assert_eq!(id_fields[0].name(), "id");
    }

    #[test]
    fn test_model_id_fields_none() {
        let mut model = Model::new(make_ident("User"), make_span());
        model.add_field(make_field(
            "email",
            FieldType::Scalar(ScalarType::String),
            TypeModifier::Required,
        ));

        let id_fields = model.id_fields();
        assert!(id_fields.is_empty());
    }

    #[test]
    fn test_model_relation_fields() {
        let mut model = Model::new(make_ident("Post"), make_span());
        model.add_field(make_id_field());
        model.add_field(make_field(
            "title",
            FieldType::Scalar(ScalarType::String),
            TypeModifier::Required,
        ));
        model.add_field(make_field(
            "author",
            FieldType::Model("User".into()),
            TypeModifier::Required,
        ));

        let rel_fields = model.relation_fields();
        assert_eq!(rel_fields.len(), 1);
        assert_eq!(rel_fields[0].name(), "author");
    }

    #[test]
    fn test_model_scalar_fields() {
        let mut model = Model::new(make_ident("Post"), make_span());
        model.add_field(make_id_field());
        model.add_field(make_field(
            "title",
            FieldType::Scalar(ScalarType::String),
            TypeModifier::Required,
        ));
        model.add_field(make_field(
            "author",
            FieldType::Model("User".into()),
            TypeModifier::Required,
        ));

        let scalar_fields = model.scalar_fields();
        assert_eq!(scalar_fields.len(), 2);
    }

    #[test]
    fn test_model_has_attribute() {
        let mut model = Model::new(make_ident("User"), make_span());
        model.attributes.push(make_attribute("map"));

        assert!(model.has_attribute("map"));
        assert!(!model.has_attribute("index"));
    }

    #[test]
    fn test_model_get_attribute() {
        let mut model = Model::new(make_ident("User"), make_span());
        model
            .attributes
            .push(make_attribute_with_string("map", "users"));

        let attr = model.get_attribute("map");
        assert!(attr.is_some());
        assert!(attr.unwrap().is("map"));

        assert!(model.get_attribute("index").is_none());
    }

    #[test]
    fn test_model_table_name_default() {
        let model = Model::new(make_ident("User"), make_span());
        assert_eq!(model.table_name(), "User");
    }

    #[test]
    fn test_model_table_name_mapped() {
        let mut model = Model::new(make_ident("User"), make_span());
        model
            .attributes
            .push(make_attribute_with_string("map", "app_users"));

        assert_eq!(model.table_name(), "app_users");
    }

    #[test]
    fn test_model_table_name_map_without_argument_falls_back() {
        let mut model = Model::new(make_ident("User"), make_span());
        model.attributes.push(make_attribute("map"));

        assert_eq!(model.table_name(), "User");
    }

    #[test]
    fn test_model_with_documentation() {
        let model = Model::new(make_ident("User"), make_span())
            .with_documentation(Documentation::new("Represents a user", make_span()));

        assert!(model.documentation.is_some());
        assert_eq!(model.documentation.unwrap().text, "Represents a user");
    }

    // ==================== Enum Tests ====================

    #[test]
    fn test_enum_new() {
        let e = Enum::new(make_ident("Role"), make_span());

        assert_eq!(e.name(), "Role");
        assert!(e.variants.is_empty());
        assert!(e.attributes.is_empty());
        assert!(e.documentation.is_none());
    }

    #[test]
    fn test_enum_add_variant() {
        let mut e = Enum::new(make_ident("Role"), make_span());
        e.add_variant(EnumVariant::new(make_ident("Admin"), make_span()));
        e.add_variant(EnumVariant::new(make_ident("User"), make_span()));

        assert_eq!(e.variants.len(), 2);
    }

    #[test]
    fn test_enum_get_variant() {
        let mut e = Enum::new(make_ident("Role"), make_span());
        e.add_variant(EnumVariant::new(make_ident("Admin"), make_span()));
        e.add_variant(EnumVariant::new(make_ident("User"), make_span()));

        let variant = e.get_variant("Admin");
        assert!(variant.is_some());
        assert_eq!(variant.unwrap().name(), "Admin");

        assert!(e.get_variant("Moderator").is_none());
    }

    #[test]
    fn test_enum_db_name_default() {
        let e = Enum::new(make_ident("Role"), make_span());
        assert_eq!(e.db_name(), "Role");
    }

    #[test]
    fn test_enum_db_name_mapped() {
        let mut e = Enum::new(make_ident("Role"), make_span());
        e.attributes
            .push(make_attribute_with_string("map", "user_role"));

        assert_eq!(e.db_name(), "user_role");
    }

    #[test]
    fn test_enum_with_documentation() {
        let e = Enum::new(make_ident("Role"), make_span())
            .with_documentation(Documentation::new("User roles", make_span()));

        assert!(e.documentation.is_some());
    }

    // ==================== EnumVariant Tests ====================

    #[test]
    fn test_enum_variant_new() {
        let variant = EnumVariant::new(make_ident("Admin"), make_span());

        assert_eq!(variant.name(), "Admin");
        assert!(variant.attributes.is_empty());
        assert!(variant.documentation.is_none());
    }

    #[test]
    fn test_enum_variant_db_value_default() {
        let variant = EnumVariant::new(make_ident("Admin"), make_span());
        assert_eq!(variant.db_value(), "Admin");
    }

    #[test]
    fn test_enum_variant_db_value_mapped() {
        let mut variant = EnumVariant::new(make_ident("Admin"), make_span());
        variant
            .attributes
            .push(make_attribute_with_string("map", "ADMIN_USER"));

        assert_eq!(variant.db_value(), "ADMIN_USER");
    }

    // ==================== CompositeType Tests ====================

    #[test]
    fn test_composite_type_new() {
        let ct = CompositeType::new(make_ident("Address"), make_span());

        assert_eq!(ct.name(), "Address");
        assert!(ct.fields.is_empty());
        assert!(ct.documentation.is_none());
    }

    #[test]
    fn test_composite_type_add_field() {
        let mut ct = CompositeType::new(make_ident("Address"), make_span());
        ct.add_field(make_field(
            "street",
            FieldType::Scalar(ScalarType::String),
            TypeModifier::Required,
        ));
        ct.add_field(make_field(
            "city",
            FieldType::Scalar(ScalarType::String),
            TypeModifier::Required,
        ));

        assert_eq!(ct.fields.len(), 2);
    }

    #[test]
    fn test_composite_type_get_field() {
        let mut ct = CompositeType::new(make_ident("Address"), make_span());
        ct.add_field(make_field(
            "city",
            FieldType::Scalar(ScalarType::String),
            TypeModifier::Required,
        ));

        let field = ct.get_field("city");
        assert!(field.is_some());
        assert_eq!(field.unwrap().name(), "city");

        assert!(ct.get_field("country").is_none());
    }

    #[test]
    fn test_composite_type_with_documentation() {
        let ct = CompositeType::new(make_ident("Address"), make_span())
            .with_documentation(Documentation::new("Mailing address", make_span()));

        assert!(ct.documentation.is_some());
    }

    // ==================== View Tests ====================

    #[test]
    fn test_view_new() {
        let view = View::new(make_ident("UserStats"), make_span());

        assert_eq!(view.name(), "UserStats");
        assert!(view.fields.is_empty());
        assert!(view.attributes.is_empty());
        assert!(view.documentation.is_none());
    }

    #[test]
    fn test_view_add_field() {
        let mut view = View::new(make_ident("UserStats"), make_span());
        view.add_field(make_field(
            "user_id",
            FieldType::Scalar(ScalarType::Int),
            TypeModifier::Required,
        ));
        view.add_field(make_field(
            "post_count",
            FieldType::Scalar(ScalarType::Int),
            TypeModifier::Required,
        ));

        assert_eq!(view.fields.len(), 2);
    }

    #[test]
    fn test_view_view_name_default() {
        let view = View::new(make_ident("UserStats"), make_span());
        assert_eq!(view.view_name(), "UserStats");
    }

    #[test]
    fn test_view_view_name_mapped() {
        let mut view = View::new(make_ident("UserStats"), make_span());
        view.attributes
            .push(make_attribute_with_string("map", "v_user_statistics"));

        assert_eq!(view.view_name(), "v_user_statistics");
    }

    #[test]
    fn test_view_sql_query() {
        let mut view = View::new(make_ident("UserStats"), make_span());
        assert!(view.sql_query().is_none());

        view.attributes
            .push(make_attribute_with_string("sql", "SELECT id FROM users"));
        assert_eq!(view.sql_query(), Some("SELECT id FROM users"));
    }

    #[test]
    fn test_view_is_materialized() {
        let mut view = View::new(make_ident("UserStats"), make_span());
        assert!(!view.is_materialized());

        view.attributes.push(make_attribute("materialized"));
        assert!(view.is_materialized());
    }

    #[test]
    fn test_view_refresh_interval() {
        let mut view = View::new(make_ident("UserStats"), make_span());
        assert!(view.refresh_interval().is_none());

        view.attributes
            .push(make_attribute_with_string("refreshInterval", "1h"));
        assert_eq!(view.refresh_interval(), Some("1h"));
    }

    #[test]
    fn test_view_with_documentation() {
        let view = View::new(make_ident("UserStats"), make_span()).with_documentation(
            Documentation::new("Aggregated user statistics", make_span()),
        );

        assert!(view.documentation.is_some());
    }

    // ==================== Equality Tests ====================

    #[test]
    fn test_model_equality() {
        let model1 = Model::new(make_ident("User"), make_span());
        let model2 = Model::new(make_ident("User"), make_span());

        assert_eq!(model1, model2);
    }

    #[test]
    fn test_model_inequality() {
        let model1 = Model::new(make_ident("User"), make_span());
        let model2 = Model::new(make_ident("Post"), make_span());

        assert_ne!(model1, model2);
    }

    #[test]
    fn test_enum_equality() {
        let enum1 = Enum::new(make_ident("Role"), make_span());
        let enum2 = Enum::new(make_ident("Role"), make_span());

        assert_eq!(enum1, enum2);
    }

    #[test]
    fn test_enum_variant_equality() {
        let v1 = EnumVariant::new(make_ident("Admin"), make_span());
        let v2 = EnumVariant::new(make_ident("Admin"), make_span());
        let v3 = EnumVariant::new(make_ident("User"), make_span());

        assert_eq!(v1, v2);
        assert_ne!(v1, v3);
    }

    #[test]
    fn test_composite_type_equality() {
        let ct1 = CompositeType::new(make_ident("Address"), make_span());
        let ct2 = CompositeType::new(make_ident("Address"), make_span());

        assert_eq!(ct1, ct2);
    }

    #[test]
    fn test_view_equality() {
        let v1 = View::new(make_ident("Stats"), make_span());
        let v2 = View::new(make_ident("Stats"), make_span());

        assert_eq!(v1, v2);
    }
}