prax_schema/ast/
schema.rs

//! Top-level schema definition: the [`Schema`] container, raw SQL blocks, and schema statistics.
2
use indexmap::IndexMap;
use serde::{Deserialize, Serialize};
use smol_str::SmolStr;

use super::{CompositeType, Enum, Model, Relation, ServerGroup, View};
8
9/// A complete Prax schema.
10#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
11pub struct Schema {
12    /// All models in the schema.
13    pub models: IndexMap<SmolStr, Model>,
14    /// All enums in the schema.
15    pub enums: IndexMap<SmolStr, Enum>,
16    /// All composite types in the schema.
17    pub types: IndexMap<SmolStr, CompositeType>,
18    /// All views in the schema.
19    pub views: IndexMap<SmolStr, View>,
20    /// Server groups for multi-server configurations.
21    pub server_groups: IndexMap<SmolStr, ServerGroup>,
22    /// Raw SQL definitions.
23    pub raw_sql: Vec<RawSql>,
24    /// Resolved relations (populated after validation).
25    pub relations: Vec<Relation>,
26}
27
28impl Schema {
29    /// Create a new empty schema.
30    pub fn new() -> Self {
31        Self::default()
32    }
33
34    /// Add a model to the schema.
35    pub fn add_model(&mut self, model: Model) {
36        self.models.insert(model.name.name.clone(), model);
37    }
38
39    /// Add an enum to the schema.
40    pub fn add_enum(&mut self, e: Enum) {
41        self.enums.insert(e.name.name.clone(), e);
42    }
43
44    /// Add a composite type to the schema.
45    pub fn add_type(&mut self, t: CompositeType) {
46        self.types.insert(t.name.name.clone(), t);
47    }
48
49    /// Add a view to the schema.
50    pub fn add_view(&mut self, v: View) {
51        self.views.insert(v.name.name.clone(), v);
52    }
53
54    /// Add a server group to the schema.
55    pub fn add_server_group(&mut self, sg: ServerGroup) {
56        self.server_groups.insert(sg.name.name.clone(), sg);
57    }
58
59    /// Add a raw SQL definition.
60    pub fn add_raw_sql(&mut self, sql: RawSql) {
61        self.raw_sql.push(sql);
62    }
63
64    /// Get a model by name.
65    pub fn get_model(&self, name: &str) -> Option<&Model> {
66        self.models.get(name)
67    }
68
69    /// Get a mutable model by name.
70    pub fn get_model_mut(&mut self, name: &str) -> Option<&mut Model> {
71        self.models.get_mut(name)
72    }
73
74    /// Get an enum by name.
75    pub fn get_enum(&self, name: &str) -> Option<&Enum> {
76        self.enums.get(name)
77    }
78
79    /// Get a composite type by name.
80    pub fn get_type(&self, name: &str) -> Option<&CompositeType> {
81        self.types.get(name)
82    }
83
84    /// Get a view by name.
85    pub fn get_view(&self, name: &str) -> Option<&View> {
86        self.views.get(name)
87    }
88
89    /// Get a server group by name.
90    pub fn get_server_group(&self, name: &str) -> Option<&ServerGroup> {
91        self.server_groups.get(name)
92    }
93
94    /// Get all server group names.
95    pub fn server_group_names(&self) -> impl Iterator<Item = &str> {
96        self.server_groups.keys().map(|s| s.as_str())
97    }
98
99    /// Check if a type name exists (model, enum, type, or view).
100    pub fn type_exists(&self, name: &str) -> bool {
101        self.models.contains_key(name)
102            || self.enums.contains_key(name)
103            || self.types.contains_key(name)
104            || self.views.contains_key(name)
105    }
106
107    /// Get all model names.
108    pub fn model_names(&self) -> impl Iterator<Item = &str> {
109        self.models.keys().map(|s| s.as_str())
110    }
111
112    /// Get all enum names.
113    pub fn enum_names(&self) -> impl Iterator<Item = &str> {
114        self.enums.keys().map(|s| s.as_str())
115    }
116
117    /// Get relations for a specific model.
118    pub fn relations_for(&self, model: &str) -> Vec<&Relation> {
119        self.relations
120            .iter()
121            .filter(|r| r.from_model == model || r.to_model == model)
122            .collect()
123    }
124
125    /// Get relations originating from a specific model.
126    pub fn relations_from(&self, model: &str) -> Vec<&Relation> {
127        self.relations
128            .iter()
129            .filter(|r| r.from_model == model)
130            .collect()
131    }
132
133    /// Merge another schema into this one.
134    pub fn merge(&mut self, other: Schema) {
135        self.models.extend(other.models);
136        self.enums.extend(other.enums);
137        self.types.extend(other.types);
138        self.views.extend(other.views);
139        self.server_groups.extend(other.server_groups);
140        self.raw_sql.extend(other.raw_sql);
141    }
142}
143
144/// A raw SQL definition.
145#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
146pub struct RawSql {
147    /// Name/identifier for the SQL (e.g., view name).
148    pub name: SmolStr,
149    /// The raw SQL content.
150    pub sql: String,
151}
152
153impl RawSql {
154    /// Create a new raw SQL definition.
155    pub fn new(name: impl Into<SmolStr>, sql: impl Into<String>) -> Self {
156        Self {
157            name: name.into(),
158            sql: sql.into(),
159        }
160    }
161}
162
/// Summary counts describing a schema, intended for debugging and diagnostics.
///
/// Implements `PartialEq`/`Eq` so a whole snapshot can be compared at once,
/// for example in tests or when checking whether a schema changed.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SchemaStats {
    /// Number of models.
    pub model_count: usize,
    /// Number of enums.
    pub enum_count: usize,
    /// Number of composite types.
    pub type_count: usize,
    /// Number of views.
    pub view_count: usize,
    /// Number of server groups.
    pub server_group_count: usize,
    /// Total number of fields across all models.
    pub field_count: usize,
    /// Number of relations.
    pub relation_count: usize,
}
181
182impl Schema {
183    /// Get statistics about the schema.
184    pub fn stats(&self) -> SchemaStats {
185        SchemaStats {
186            model_count: self.models.len(),
187            enum_count: self.enums.len(),
188            type_count: self.types.len(),
189            view_count: self.views.len(),
190            server_group_count: self.server_groups.len(),
191            field_count: self.models.values().map(|m| m.fields.len()).sum(),
192            relation_count: self.relations.len(),
193        }
194    }
195}
196
197impl std::fmt::Display for Schema {
198    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
199        let stats = self.stats();
200        write!(
201            f,
202            "Schema({} models, {} enums, {} types, {} views, {} server groups, {} fields, {} relations)",
203            stats.model_count,
204            stats.enum_count,
205            stats.type_count,
206            stats.view_count,
207            stats.server_group_count,
208            stats.field_count,
209            stats.relation_count
210        )
211    }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{
        Attribute, EnumVariant, Field, FieldType, Ident, RelationType, ScalarType, Span,
        TypeModifier,
    };

    // ==================== Helpers ====================

    fn make_span() -> Span {
        Span::new(0, 10)
    }

    fn make_ident(name: &str) -> Ident {
        Ident::new(name, make_span())
    }

    /// A model named `name` whose only field is the `@id` integer from `make_id_field`.
    fn make_model(name: &str) -> Model {
        let mut model = Model::new(make_ident(name), make_span());
        model.add_field(make_id_field());
        model
    }

    /// A required `Int` field named `id` carrying a bare `@id` attribute.
    fn make_id_field() -> Field {
        let mut field = make_field("id", FieldType::Scalar(ScalarType::Int));
        field
            .attributes
            .push(Attribute::simple(make_ident("id"), make_span()));
        field
    }

    /// A required field with no attributes.
    fn make_field(name: &str, field_type: FieldType) -> Field {
        Field::new(
            make_ident(name),
            field_type,
            TypeModifier::Required,
            vec![],
            make_span(),
        )
    }

    fn make_enum(name: &str, variants: &[&str]) -> Enum {
        let mut e = Enum::new(make_ident(name), make_span());
        for v in variants {
            e.add_variant(EnumVariant::new(make_ident(v), make_span()));
        }
        e
    }

    // ==================== Schema Tests ====================

    #[test]
    fn test_schema_new() {
        let schema = Schema::new();
        assert!(schema.models.is_empty());
        assert!(schema.enums.is_empty());
        assert!(schema.types.is_empty());
        assert!(schema.views.is_empty());
        assert!(schema.server_groups.is_empty());
        assert!(schema.raw_sql.is_empty());
        assert!(schema.relations.is_empty());
    }

    #[test]
    fn test_schema_default() {
        let schema = Schema::default();
        assert!(schema.models.is_empty());
        assert_eq!(schema, Schema::new());
    }

    #[test]
    fn test_schema_add_model() {
        let mut schema = Schema::new();
        schema.add_model(make_model("User"));

        assert_eq!(schema.models.len(), 1);
        assert!(schema.models.contains_key("User"));
    }

    #[test]
    fn test_schema_add_multiple_models() {
        let mut schema = Schema::new();
        schema.add_model(make_model("User"));
        schema.add_model(make_model("Post"));
        schema.add_model(make_model("Comment"));

        assert_eq!(schema.models.len(), 3);
        // Models iterate in insertion order.
        assert_eq!(
            schema.model_names().collect::<Vec<_>>(),
            vec!["User", "Post", "Comment"]
        );
    }

    #[test]
    fn test_schema_add_model_replaces_duplicate() {
        let mut schema = Schema::new();
        schema.add_model(make_model("User"));

        let mut user = make_model("User");
        user.add_field(make_field("email", FieldType::Scalar(ScalarType::String)));
        schema.add_model(user);

        assert_eq!(schema.models.len(), 1);
        let stored = schema.get_model("User").expect("User model should exist");
        assert_eq!(stored.fields.len(), 2);
    }

    #[test]
    fn test_schema_add_enum() {
        let mut schema = Schema::new();
        schema.add_enum(make_enum("Role", &["User", "Admin"]));

        assert_eq!(schema.enums.len(), 1);
        assert!(schema.enums.contains_key("Role"));
    }

    #[test]
    fn test_schema_add_type() {
        let mut schema = Schema::new();
        schema.add_type(CompositeType::new(make_ident("Address"), make_span()));

        assert_eq!(schema.types.len(), 1);
        assert!(schema.types.contains_key("Address"));
    }

    #[test]
    fn test_schema_add_view() {
        let mut schema = Schema::new();
        schema.add_view(View::new(make_ident("UserStats"), make_span()));

        assert_eq!(schema.views.len(), 1);
        assert!(schema.views.contains_key("UserStats"));
    }

    #[test]
    fn test_schema_add_raw_sql() {
        let mut schema = Schema::new();
        schema.add_raw_sql(RawSql::new("migration_1", "CREATE TABLE test ();"));

        assert_eq!(schema.raw_sql.len(), 1);
        assert_eq!(schema.raw_sql[0].name.as_str(), "migration_1");
    }

    #[test]
    fn test_schema_get_model() {
        let mut schema = Schema::new();
        schema.add_model(make_model("User"));

        let model = schema.get_model("User").expect("User model should exist");
        assert_eq!(model.name(), "User");

        assert!(schema.get_model("NonExistent").is_none());
    }

    #[test]
    fn test_schema_get_model_mut() {
        let mut schema = Schema::new();
        schema.add_model(make_model("User"));

        let model = schema
            .get_model_mut("User")
            .expect("User model should exist");
        model.add_field(make_field("email", FieldType::Scalar(ScalarType::String)));

        // The mutation must be visible through the schema afterwards.
        assert_eq!(schema.get_model("User").unwrap().fields.len(), 2);
    }

    #[test]
    fn test_schema_get_enum() {
        let mut schema = Schema::new();
        schema.add_enum(make_enum("Role", &["User", "Admin"]));

        let e = schema.get_enum("Role").expect("Role enum should exist");
        assert_eq!(e.name(), "Role");

        assert!(schema.get_enum("NonExistent").is_none());
    }

    #[test]
    fn test_schema_get_type() {
        let mut schema = Schema::new();
        schema.add_type(CompositeType::new(make_ident("Address"), make_span()));

        assert!(schema.get_type("Address").is_some());
        assert!(schema.get_type("NonExistent").is_none());
    }

    #[test]
    fn test_schema_get_view() {
        let mut schema = Schema::new();
        schema.add_view(View::new(make_ident("Stats"), make_span()));

        assert!(schema.get_view("Stats").is_some());
        assert!(schema.get_view("NonExistent").is_none());
    }

    #[test]
    fn test_schema_get_server_group_missing() {
        let schema = Schema::new();
        assert!(schema.get_server_group("primary").is_none());
        assert_eq!(schema.server_group_names().count(), 0);
    }

    #[test]
    fn test_schema_type_exists() {
        let mut schema = Schema::new();
        schema.add_model(make_model("User"));
        schema.add_enum(make_enum("Role", &["User"]));
        schema.add_type(CompositeType::new(make_ident("Address"), make_span()));
        schema.add_view(View::new(make_ident("Stats"), make_span()));

        assert!(schema.type_exists("User")); // model
        assert!(schema.type_exists("Role")); // enum
        assert!(schema.type_exists("Address")); // type
        assert!(schema.type_exists("Stats")); // view
        assert!(!schema.type_exists("NonExistent"));
    }

    #[test]
    fn test_schema_model_names() {
        let mut schema = Schema::new();
        schema.add_model(make_model("User"));
        schema.add_model(make_model("Post"));

        let names: Vec<_> = schema.model_names().collect();
        assert_eq!(names, vec!["User", "Post"]);
    }

    #[test]
    fn test_schema_enum_names() {
        let mut schema = Schema::new();
        schema.add_enum(make_enum("Role", &["User"]));
        schema.add_enum(make_enum("Status", &["Active"]));

        let names: Vec<_> = schema.enum_names().collect();
        assert_eq!(names, vec!["Role", "Status"]);
    }

    #[test]
    fn test_schema_relations_for() {
        let mut schema = Schema::new();
        schema.relations.push(Relation::new(
            "Post",
            "author",
            "User",
            RelationType::ManyToOne,
        ));
        schema.relations.push(Relation::new(
            "Comment",
            "user",
            "User",
            RelationType::ManyToOne,
        ));
        schema.relations.push(Relation::new(
            "Post",
            "tags",
            "Tag",
            RelationType::ManyToMany,
        ));

        assert_eq!(schema.relations_for("User").len(), 2);
        assert_eq!(schema.relations_for("Post").len(), 2);
        assert_eq!(schema.relations_for("Tag").len(), 1);
        assert!(schema.relations_for("Unknown").is_empty());
    }

    #[test]
    fn test_schema_relations_from() {
        let mut schema = Schema::new();
        schema.relations.push(Relation::new(
            "Post",
            "author",
            "User",
            RelationType::ManyToOne,
        ));
        schema.relations.push(Relation::new(
            "Post",
            "tags",
            "Tag",
            RelationType::ManyToMany,
        ));
        schema.relations.push(Relation::new(
            "User",
            "posts",
            "Post",
            RelationType::OneToMany,
        ));

        assert_eq!(schema.relations_from("Post").len(), 2);
        assert_eq!(schema.relations_from("User").len(), 1);
        assert!(schema.relations_from("Tag").is_empty());
    }

    #[test]
    fn test_schema_merge() {
        let mut schema1 = Schema::new();
        schema1.add_model(make_model("User"));
        schema1.add_enum(make_enum("Role", &["User"]));

        let mut schema2 = Schema::new();
        schema2.add_model(make_model("Post"));
        schema2.add_enum(make_enum("Status", &["Active"]));
        schema2.add_raw_sql(RawSql::new("init", "-- init"));

        schema1.merge(schema2);

        assert_eq!(schema1.models.len(), 2);
        assert_eq!(schema1.enums.len(), 2);
        assert_eq!(schema1.raw_sql.len(), 1);
        assert_eq!(
            schema1.model_names().collect::<Vec<_>>(),
            vec!["User", "Post"]
        );
        assert_eq!(
            schema1.enum_names().collect::<Vec<_>>(),
            vec!["Role", "Status"]
        );
    }

    #[test]
    fn test_schema_merge_overwrites_duplicates() {
        let mut schema1 = Schema::new();
        schema1.add_model(make_model("User"));
        schema1.add_model(make_model("Post"));

        let mut user = make_model("User");
        user.add_field(make_field("email", FieldType::Scalar(ScalarType::String)));
        let mut schema2 = Schema::new();
        schema2.add_model(user);

        schema1.merge(schema2);

        // The incoming definition wins, but the name keeps its original position.
        assert_eq!(schema1.models.len(), 2);
        let merged = schema1.get_model("User").expect("User model should exist");
        assert_eq!(merged.fields.len(), 2);
        assert_eq!(
            schema1.model_names().collect::<Vec<_>>(),
            vec!["User", "Post"]
        );
    }

    #[test]
    fn test_schema_stats() {
        let mut schema = Schema::new();

        let mut user = make_model("User");
        user.add_field(make_field("email", FieldType::Scalar(ScalarType::String)));
        user.add_field(make_field("name", FieldType::Scalar(ScalarType::String)));
        schema.add_model(user);

        let mut post = make_model("Post");
        post.add_field(make_field("title", FieldType::Scalar(ScalarType::String)));
        schema.add_model(post);

        schema.add_enum(make_enum("Role", &["User", "Admin"]));
        schema.add_type(CompositeType::new(make_ident("Address"), make_span()));
        schema.add_view(View::new(make_ident("Stats"), make_span()));
        schema.relations.push(Relation::new(
            "Post",
            "author",
            "User",
            RelationType::ManyToOne,
        ));

        let stats = schema.stats();
        assert_eq!(stats.model_count, 2);
        assert_eq!(stats.enum_count, 1);
        assert_eq!(stats.type_count, 1);
        assert_eq!(stats.view_count, 1);
        assert_eq!(stats.server_group_count, 0);
        assert_eq!(stats.field_count, 5); // 3 in User + 2 in Post
        assert_eq!(stats.relation_count, 1);
    }

    #[test]
    fn test_schema_display() {
        let mut schema = Schema::new();
        schema.add_model(make_model("User"));
        schema.add_enum(make_enum("Role", &["User"]));

        // `make_model` contributes exactly one field (`id`).
        assert_eq!(
            schema.to_string(),
            "Schema(1 models, 1 enums, 0 types, 0 views, 0 server groups, 1 fields, 0 relations)"
        );
    }

    #[test]
    fn test_schema_equality() {
        let schema1 = Schema::new();
        let schema2 = Schema::new();
        assert_eq!(schema1, schema2);

        let mut schema3 = Schema::new();
        schema3.add_model(make_model("User"));
        assert_ne!(schema1, schema3);
    }

    #[test]
    fn test_schema_clone() {
        let mut schema = Schema::new();
        schema.add_model(make_model("User"));

        let cloned = schema.clone();
        assert_eq!(cloned.models.len(), 1);
        assert_eq!(cloned, schema);
    }

    // ==================== RawSql Tests ====================

    #[test]
    fn test_raw_sql_new() {
        let sql = RawSql::new("create_users", "CREATE TABLE users ();");

        assert_eq!(sql.name.as_str(), "create_users");
        assert_eq!(sql.sql, "CREATE TABLE users ();");
    }

    #[test]
    fn test_raw_sql_from_strings() {
        let name = String::from("migration");
        let content = String::from("ALTER TABLE users ADD COLUMN age INT;");
        let sql = RawSql::new(name, content);

        assert_eq!(sql.name.as_str(), "migration");
        assert_eq!(sql.sql, "ALTER TABLE users ADD COLUMN age INT;");
    }

    #[test]
    fn test_raw_sql_equality() {
        let sql1 = RawSql::new("test", "SELECT 1;");
        let sql2 = RawSql::new("test", "SELECT 1;");
        let sql3 = RawSql::new("test", "SELECT 2;");

        assert_eq!(sql1, sql2);
        assert_ne!(sql1, sql3);
    }

    #[test]
    fn test_raw_sql_clone() {
        let sql = RawSql::new("test", "SELECT 1;");
        let cloned = sql.clone();
        assert_eq!(sql, cloned);
    }

    // ==================== SchemaStats Tests ====================

    #[test]
    fn test_schema_stats_default() {
        let stats = SchemaStats::default();
        assert_eq!(stats.model_count, 0);
        assert_eq!(stats.enum_count, 0);
        assert_eq!(stats.type_count, 0);
        assert_eq!(stats.view_count, 0);
        assert_eq!(stats.server_group_count, 0);
        assert_eq!(stats.field_count, 0);
        assert_eq!(stats.relation_count, 0);
    }

    #[test]
    fn test_schema_stats_debug() {
        let stats = SchemaStats::default();
        let debug = format!("{:?}", stats);
        assert!(debug.contains("SchemaStats"));
    }

    #[test]
    fn test_schema_stats_clone() {
        let stats = SchemaStats {
            model_count: 5,
            enum_count: 2,
            type_count: 1,
            view_count: 3,
            server_group_count: 2,
            field_count: 25,
            relation_count: 10,
        };
        let cloned = stats.clone();
        assert_eq!(cloned.model_count, 5);
        assert_eq!(cloned.server_group_count, 2);
        assert_eq!(cloned.field_count, 25);
    }
}