1use indexmap::IndexMap;
4use serde::{Deserialize, Serialize};
5use smol_str::SmolStr;
6
7use super::{
8 CompositeType, Datasource, Enum, Generator, Model, Policy, Relation, ServerGroup, View,
9};
10
11#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
13pub struct Schema {
14 pub datasource: Option<Datasource>,
16 pub generators: IndexMap<SmolStr, Generator>,
18 pub models: IndexMap<SmolStr, Model>,
20 pub enums: IndexMap<SmolStr, Enum>,
22 pub types: IndexMap<SmolStr, CompositeType>,
24 pub views: IndexMap<SmolStr, View>,
26 pub server_groups: IndexMap<SmolStr, ServerGroup>,
28 pub policies: Vec<Policy>,
30 pub raw_sql: Vec<RawSql>,
32 pub relations: Vec<Relation>,
34}
35
36impl Schema {
37 pub fn new() -> Self {
39 Self::default()
40 }
41
42 pub fn set_datasource(&mut self, datasource: Datasource) {
44 self.datasource = Some(datasource);
45 }
46
47 pub fn datasource(&self) -> Option<&Datasource> {
49 self.datasource.as_ref()
50 }
51
52 pub fn has_vector_support(&self) -> bool {
54 self.datasource
55 .as_ref()
56 .is_some_and(|ds| ds.has_vector_support())
57 }
58
59 pub fn required_extensions(&self) -> Vec<&super::PostgresExtension> {
61 self.datasource
62 .as_ref()
63 .map(|ds| ds.extensions.iter().collect())
64 .unwrap_or_default()
65 }
66
67 pub fn add_generator(&mut self, generator: Generator) {
69 self.generators.insert(generator.name.clone(), generator);
70 }
71
72 pub fn get_generator(&self, name: &str) -> Option<&Generator> {
74 self.generators.get(name)
75 }
76
77 pub fn enabled_generators(&self) -> Vec<&Generator> {
79 self.generators
80 .values()
81 .filter(|g| g.is_enabled())
82 .collect()
83 }
84
85 pub fn add_model(&mut self, model: Model) {
87 self.models.insert(model.name.name.clone(), model);
88 }
89
90 pub fn add_enum(&mut self, e: Enum) {
92 self.enums.insert(e.name.name.clone(), e);
93 }
94
95 pub fn add_type(&mut self, t: CompositeType) {
97 self.types.insert(t.name.name.clone(), t);
98 }
99
100 pub fn add_view(&mut self, v: View) {
102 self.views.insert(v.name.name.clone(), v);
103 }
104
105 pub fn add_server_group(&mut self, sg: ServerGroup) {
107 self.server_groups.insert(sg.name.name.clone(), sg);
108 }
109
110 pub fn add_policy(&mut self, policy: Policy) {
112 self.policies.push(policy);
113 }
114
115 pub fn add_raw_sql(&mut self, sql: RawSql) {
117 self.raw_sql.push(sql);
118 }
119
120 pub fn get_model(&self, name: &str) -> Option<&Model> {
122 self.models.get(name)
123 }
124
125 pub fn get_model_mut(&mut self, name: &str) -> Option<&mut Model> {
127 self.models.get_mut(name)
128 }
129
130 pub fn get_enum(&self, name: &str) -> Option<&Enum> {
132 self.enums.get(name)
133 }
134
135 pub fn get_type(&self, name: &str) -> Option<&CompositeType> {
137 self.types.get(name)
138 }
139
140 pub fn get_view(&self, name: &str) -> Option<&View> {
142 self.views.get(name)
143 }
144
145 pub fn get_server_group(&self, name: &str) -> Option<&ServerGroup> {
147 self.server_groups.get(name)
148 }
149
150 pub fn server_group_names(&self) -> impl Iterator<Item = &str> {
152 self.server_groups.keys().map(|s| s.as_str())
153 }
154
155 pub fn get_policy(&self, name: &str) -> Option<&Policy> {
157 self.policies.iter().find(|p| p.name() == name)
158 }
159
160 pub fn policies_for(&self, model: &str) -> Vec<&Policy> {
162 self.policies
163 .iter()
164 .filter(|p| p.table() == model)
165 .collect()
166 }
167
168 pub fn has_policies(&self, model: &str) -> bool {
170 self.policies.iter().any(|p| p.table() == model)
171 }
172
173 pub fn policy_names(&self) -> impl Iterator<Item = &str> {
175 self.policies.iter().map(|p| p.name())
176 }
177
178 pub fn type_exists(&self, name: &str) -> bool {
180 self.models.contains_key(name)
181 || self.enums.contains_key(name)
182 || self.types.contains_key(name)
183 || self.views.contains_key(name)
184 }
185
186 pub fn model_names(&self) -> impl Iterator<Item = &str> {
188 self.models.keys().map(|s| s.as_str())
189 }
190
191 pub fn enum_names(&self) -> impl Iterator<Item = &str> {
193 self.enums.keys().map(|s| s.as_str())
194 }
195
196 pub fn relations_for(&self, model: &str) -> Vec<&Relation> {
198 self.relations
199 .iter()
200 .filter(|r| r.from_model == model || r.to_model == model)
201 .collect()
202 }
203
204 pub fn relations_from(&self, model: &str) -> Vec<&Relation> {
206 self.relations
207 .iter()
208 .filter(|r| r.from_model == model)
209 .collect()
210 }
211
212 pub fn merge(&mut self, other: Schema) {
214 self.models.extend(other.models);
215 self.enums.extend(other.enums);
216 self.types.extend(other.types);
217 self.views.extend(other.views);
218 self.server_groups.extend(other.server_groups);
219 self.policies.extend(other.policies);
220 self.raw_sql.extend(other.raw_sql);
221 }
222}
223
224#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
226pub struct RawSql {
227 pub name: SmolStr,
229 pub sql: String,
231}
232
233impl RawSql {
234 pub fn new(name: impl Into<SmolStr>, sql: impl Into<String>) -> Self {
236 Self {
237 name: name.into(),
238 sql: sql.into(),
239 }
240 }
241}
242
/// Aggregate counts describing a schema, as produced by `Schema::stats`.
///
/// All fields are plain counts, so two snapshots can be compared by value
/// (`PartialEq`/`Eq`), for example to detect whether a schema changed shape.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SchemaStats {
    /// Number of models.
    pub model_count: usize,
    /// Number of enums.
    pub enum_count: usize,
    /// Number of composite types.
    pub type_count: usize,
    /// Number of views.
    pub view_count: usize,
    /// Number of server groups.
    pub server_group_count: usize,
    /// Number of policies.
    pub policy_count: usize,
    /// Total number of fields across all models.
    pub field_count: usize,
    /// Number of relations.
    pub relation_count: usize,
}
263
264impl Schema {
265 pub fn stats(&self) -> SchemaStats {
267 SchemaStats {
268 model_count: self.models.len(),
269 enum_count: self.enums.len(),
270 type_count: self.types.len(),
271 view_count: self.views.len(),
272 server_group_count: self.server_groups.len(),
273 policy_count: self.policies.len(),
274 field_count: self.models.values().map(|m| m.fields.len()).sum(),
275 relation_count: self.relations.len(),
276 }
277 }
278}
279
280impl std::fmt::Display for Schema {
281 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
282 let stats = self.stats();
283 write!(
284 f,
285 "Schema({} models, {} enums, {} types, {} views, {} server groups, {} policies, {} fields, {} relations)",
286 stats.model_count,
287 stats.enum_count,
288 stats.type_count,
289 stats.view_count,
290 stats.server_group_count,
291 stats.policy_count,
292 stats.field_count,
293 stats.relation_count
294 )
295 }
296}
297
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{
        Attribute, EnumVariant, Field, FieldType, Ident, Policy, RelationType, ScalarType, Span,
        TypeModifier,
    };

    fn make_span() -> Span {
        Span::new(0, 10)
    }

    fn make_ident(name: &str) -> Ident {
        Ident::new(name, make_span())
    }

    /// Builds a model named `name` whose only field is an `@id` field.
    fn make_model(name: &str) -> Model {
        let mut model = Model::new(make_ident(name), make_span());
        model.add_field(make_id_field());
        model
    }

    /// Builds a required `Int` field named `id` carrying an `@id` attribute.
    fn make_id_field() -> Field {
        let mut field = Field::new(
            make_ident("id"),
            FieldType::Scalar(ScalarType::Int),
            TypeModifier::Required,
            vec![],
            make_span(),
        );
        field
            .attributes
            .push(Attribute::simple(make_ident("id"), make_span()));
        field
    }

    /// Builds a required field with no attributes.
    fn make_field(name: &str, field_type: FieldType) -> Field {
        Field::new(
            make_ident(name),
            field_type,
            TypeModifier::Required,
            vec![],
            make_span(),
        )
    }

    /// Builds an enum with one variant per entry in `variants`.
    fn make_enum(name: &str, variants: &[&str]) -> Enum {
        let mut e = Enum::new(make_ident(name), make_span());
        for v in variants {
            e.add_variant(EnumVariant::new(make_ident(v), make_span()));
        }
        e
    }

    // ==================== Schema tests ====================

    #[test]
    fn test_schema_new() {
        let schema = Schema::new();
        assert!(schema.datasource.is_none());
        assert!(schema.generators.is_empty());
        assert!(schema.models.is_empty());
        assert!(schema.enums.is_empty());
        assert!(schema.types.is_empty());
        assert!(schema.views.is_empty());
        assert!(schema.server_groups.is_empty());
        assert!(schema.policies.is_empty());
        assert!(schema.raw_sql.is_empty());
        assert!(schema.relations.is_empty());
    }

    #[test]
    fn test_schema_default() {
        let schema = Schema::default();
        assert!(schema.models.is_empty());
        assert_eq!(schema, Schema::new());
    }

    #[test]
    fn test_schema_without_datasource() {
        let schema = Schema::new();
        assert!(schema.datasource().is_none());
        assert!(!schema.has_vector_support());
        assert!(schema.required_extensions().is_empty());
        assert!(schema.enabled_generators().is_empty());
        assert!(schema.get_generator("client").is_none());
        assert!(schema.get_server_group("main").is_none());
        assert_eq!(schema.server_group_names().count(), 0);
    }

    #[test]
    fn test_schema_add_model() {
        let mut schema = Schema::new();
        let model = make_model("User");

        schema.add_model(model);

        assert_eq!(schema.models.len(), 1);
        assert!(schema.models.contains_key("User"));
    }

    #[test]
    fn test_schema_add_multiple_models() {
        let mut schema = Schema::new();
        schema.add_model(make_model("User"));
        schema.add_model(make_model("Post"));
        schema.add_model(make_model("Comment"));

        assert_eq!(schema.models.len(), 3);
    }

    #[test]
    fn test_schema_add_enum() {
        let mut schema = Schema::new();
        let e = make_enum("Role", &["User", "Admin"]);

        schema.add_enum(e);

        assert_eq!(schema.enums.len(), 1);
        assert!(schema.enums.contains_key("Role"));
    }

    #[test]
    fn test_schema_add_type() {
        let mut schema = Schema::new();
        let ct = CompositeType::new(make_ident("Address"), make_span());

        schema.add_type(ct);

        assert_eq!(schema.types.len(), 1);
        assert!(schema.types.contains_key("Address"));
    }

    #[test]
    fn test_schema_add_view() {
        let mut schema = Schema::new();
        let view = View::new(make_ident("UserStats"), make_span());

        schema.add_view(view);

        assert_eq!(schema.views.len(), 1);
        assert!(schema.views.contains_key("UserStats"));
    }

    #[test]
    fn test_schema_add_raw_sql() {
        let mut schema = Schema::new();
        let sql = RawSql::new("migration_1", "CREATE TABLE test ();");

        schema.add_raw_sql(sql);

        assert_eq!(schema.raw_sql.len(), 1);
    }

    #[test]
    fn test_schema_get_model() {
        let mut schema = Schema::new();
        schema.add_model(make_model("User"));

        let model = schema.get_model("User");
        assert!(model.is_some());
        assert_eq!(model.unwrap().name(), "User");

        assert!(schema.get_model("NonExistent").is_none());
    }

    #[test]
    fn test_schema_get_model_mut() {
        let mut schema = Schema::new();
        schema.add_model(make_model("User"));

        let model = schema.get_model_mut("User");
        assert!(model.is_some());

        // Mutations through the returned reference must be visible afterwards.
        let model = model.unwrap();
        model.add_field(make_field("email", FieldType::Scalar(ScalarType::String)));

        assert_eq!(schema.get_model("User").unwrap().fields.len(), 2);
    }

    #[test]
    fn test_schema_get_enum() {
        let mut schema = Schema::new();
        schema.add_enum(make_enum("Role", &["User", "Admin"]));

        let e = schema.get_enum("Role");
        assert!(e.is_some());
        assert_eq!(e.unwrap().name(), "Role");

        assert!(schema.get_enum("NonExistent").is_none());
    }

    #[test]
    fn test_schema_get_type() {
        let mut schema = Schema::new();
        schema.add_type(CompositeType::new(make_ident("Address"), make_span()));

        let ct = schema.get_type("Address");
        assert!(ct.is_some());

        assert!(schema.get_type("NonExistent").is_none());
    }

    #[test]
    fn test_schema_get_view() {
        let mut schema = Schema::new();
        schema.add_view(View::new(make_ident("Stats"), make_span()));

        let v = schema.get_view("Stats");
        assert!(v.is_some());

        assert!(schema.get_view("NonExistent").is_none());
    }

    #[test]
    fn test_schema_type_exists() {
        let mut schema = Schema::new();
        schema.add_model(make_model("User"));
        schema.add_enum(make_enum("Role", &["User"]));
        schema.add_type(CompositeType::new(make_ident("Address"), make_span()));
        schema.add_view(View::new(make_ident("Stats"), make_span()));

        assert!(schema.type_exists("User"));
        assert!(schema.type_exists("Role"));
        assert!(schema.type_exists("Address"));
        assert!(schema.type_exists("Stats"));
        assert!(!schema.type_exists("NonExistent"));
    }

    #[test]
    fn test_schema_model_names() {
        let mut schema = Schema::new();
        schema.add_model(make_model("User"));
        schema.add_model(make_model("Post"));

        let names: Vec<_> = schema.model_names().collect();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"User"));
        assert!(names.contains(&"Post"));
    }

    #[test]
    fn test_schema_enum_names() {
        let mut schema = Schema::new();
        schema.add_enum(make_enum("Role", &["User"]));
        schema.add_enum(make_enum("Status", &["Active"]));

        let names: Vec<_> = schema.enum_names().collect();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"Role"));
        assert!(names.contains(&"Status"));
    }

    #[test]
    fn test_schema_relations_for() {
        let mut schema = Schema::new();
        schema.relations.push(Relation::new(
            "Post",
            "author",
            "User",
            RelationType::ManyToOne,
        ));
        schema.relations.push(Relation::new(
            "Comment",
            "user",
            "User",
            RelationType::ManyToOne,
        ));
        schema.relations.push(Relation::new(
            "Post",
            "tags",
            "Tag",
            RelationType::ManyToMany,
        ));

        let user_relations = schema.relations_for("User");
        assert_eq!(user_relations.len(), 2);

        let post_relations = schema.relations_for("Post");
        assert_eq!(post_relations.len(), 2);

        let tag_relations = schema.relations_for("Tag");
        assert_eq!(tag_relations.len(), 1);
    }

    #[test]
    fn test_schema_relations_from() {
        let mut schema = Schema::new();
        schema.relations.push(Relation::new(
            "Post",
            "author",
            "User",
            RelationType::ManyToOne,
        ));
        schema.relations.push(Relation::new(
            "Post",
            "tags",
            "Tag",
            RelationType::ManyToMany,
        ));
        schema.relations.push(Relation::new(
            "User",
            "posts",
            "Post",
            RelationType::OneToMany,
        ));

        let post_relations = schema.relations_from("Post");
        assert_eq!(post_relations.len(), 2);

        let user_relations = schema.relations_from("User");
        assert_eq!(user_relations.len(), 1);

        let tag_relations = schema.relations_from("Tag");
        assert_eq!(tag_relations.len(), 0);
    }

    #[test]
    fn test_schema_merge() {
        let mut schema1 = Schema::new();
        schema1.add_model(make_model("User"));
        schema1.add_enum(make_enum("Role", &["User"]));

        let mut schema2 = Schema::new();
        schema2.add_model(make_model("Post"));
        schema2.add_enum(make_enum("Status", &["Active"]));
        schema2.add_raw_sql(RawSql::new("init", "-- init"));

        schema1.merge(schema2);

        assert_eq!(schema1.models.len(), 2);
        assert_eq!(schema1.enums.len(), 2);
        assert_eq!(schema1.raw_sql.len(), 1);
    }

    #[test]
    fn test_schema_stats() {
        let mut schema = Schema::new();

        let mut user = make_model("User");
        user.add_field(make_field("email", FieldType::Scalar(ScalarType::String)));
        user.add_field(make_field("name", FieldType::Scalar(ScalarType::String)));
        schema.add_model(user);

        let mut post = make_model("Post");
        post.add_field(make_field("title", FieldType::Scalar(ScalarType::String)));
        schema.add_model(post);

        schema.add_enum(make_enum("Role", &["User", "Admin"]));
        schema.add_type(CompositeType::new(make_ident("Address"), make_span()));
        schema.add_view(View::new(make_ident("Stats"), make_span()));
        schema.relations.push(Relation::new(
            "Post",
            "author",
            "User",
            RelationType::ManyToOne,
        ));

        let stats = schema.stats();
        assert_eq!(stats.model_count, 2);
        assert_eq!(stats.enum_count, 1);
        assert_eq!(stats.type_count, 1);
        assert_eq!(stats.view_count, 1);
        assert_eq!(stats.server_group_count, 0);
        assert_eq!(stats.policy_count, 0);
        // User: id + email + name; Post: id + title.
        assert_eq!(stats.field_count, 5);
        assert_eq!(stats.relation_count, 1);
    }

    #[test]
    fn test_schema_display() {
        let mut schema = Schema::new();
        schema.add_model(make_model("User"));
        schema.add_enum(make_enum("Role", &["User"]));

        let display = format!("{}", schema);
        assert!(display.contains("1 models"));
        assert!(display.contains("1 enums"));
        assert!(display.contains("0 server groups"));
        assert!(display.contains("0 policies"));
    }

    #[test]
    fn test_schema_equality() {
        let schema1 = Schema::new();
        let schema2 = Schema::new();
        assert_eq!(schema1, schema2);
    }

    #[test]
    fn test_schema_clone() {
        let mut schema = Schema::new();
        schema.add_model(make_model("User"));

        let cloned = schema.clone();
        assert_eq!(cloned.models.len(), 1);
    }

    // ==================== RawSql tests ====================

    #[test]
    fn test_raw_sql_new() {
        let sql = RawSql::new("create_users", "CREATE TABLE users ();");

        assert_eq!(sql.name.as_str(), "create_users");
        assert_eq!(sql.sql, "CREATE TABLE users ();");
    }

    #[test]
    fn test_raw_sql_from_strings() {
        let name = String::from("migration");
        let content = String::from("ALTER TABLE users ADD COLUMN age INT;");
        let sql = RawSql::new(name, content);

        assert_eq!(sql.name.as_str(), "migration");
        assert_eq!(sql.sql, "ALTER TABLE users ADD COLUMN age INT;");
    }

    #[test]
    fn test_raw_sql_equality() {
        let sql1 = RawSql::new("test", "SELECT 1;");
        let sql2 = RawSql::new("test", "SELECT 1;");
        let sql3 = RawSql::new("test", "SELECT 2;");

        assert_eq!(sql1, sql2);
        assert_ne!(sql1, sql3);
    }

    #[test]
    fn test_raw_sql_clone() {
        let sql = RawSql::new("test", "SELECT 1;");
        let cloned = sql.clone();
        assert_eq!(sql, cloned);
    }

    // ==================== SchemaStats tests ====================

    #[test]
    fn test_schema_stats_default() {
        let stats = SchemaStats::default();
        assert_eq!(stats.model_count, 0);
        assert_eq!(stats.enum_count, 0);
        assert_eq!(stats.type_count, 0);
        assert_eq!(stats.view_count, 0);
        assert_eq!(stats.server_group_count, 0);
        assert_eq!(stats.policy_count, 0);
        assert_eq!(stats.field_count, 0);
        assert_eq!(stats.relation_count, 0);
    }

    #[test]
    fn test_schema_stats_debug() {
        let stats = SchemaStats::default();
        let debug = format!("{:?}", stats);
        assert!(debug.contains("SchemaStats"));
    }

    #[test]
    fn test_schema_stats_clone() {
        let stats = SchemaStats {
            model_count: 5,
            enum_count: 2,
            type_count: 1,
            view_count: 3,
            server_group_count: 2,
            policy_count: 4,
            field_count: 25,
            relation_count: 10,
        };
        let cloned = stats.clone();
        assert_eq!(cloned.model_count, 5);
        assert_eq!(cloned.server_group_count, 2);
        assert_eq!(cloned.field_count, 25);
        assert_eq!(cloned.policy_count, 4);
    }

    // ==================== Policy tests ====================

    #[test]
    fn test_schema_add_policy() {
        let mut schema = Schema::new();
        let policy = Policy::new(make_ident("read_own"), make_ident("User"), make_span());

        schema.add_policy(policy);

        assert_eq!(schema.policies.len(), 1);
    }

    #[test]
    fn test_schema_get_policy() {
        let mut schema = Schema::new();
        schema.add_policy(Policy::new(
            make_ident("read_own"),
            make_ident("User"),
            make_span(),
        ));

        let policy = schema.get_policy("read_own");
        assert!(policy.is_some());
        assert_eq!(policy.unwrap().name(), "read_own");

        assert!(schema.get_policy("nonexistent").is_none());
    }

    #[test]
    fn test_schema_policies_for_model() {
        let mut schema = Schema::new();
        schema.add_policy(Policy::new(
            make_ident("user_read"),
            make_ident("User"),
            make_span(),
        ));
        schema.add_policy(Policy::new(
            make_ident("user_write"),
            make_ident("User"),
            make_span(),
        ));
        schema.add_policy(Policy::new(
            make_ident("post_read"),
            make_ident("Post"),
            make_span(),
        ));

        let user_policies = schema.policies_for("User");
        assert_eq!(user_policies.len(), 2);

        let post_policies = schema.policies_for("Post");
        assert_eq!(post_policies.len(), 1);

        let comment_policies = schema.policies_for("Comment");
        assert!(comment_policies.is_empty());
    }

    #[test]
    fn test_schema_has_policies() {
        let mut schema = Schema::new();
        schema.add_policy(Policy::new(
            make_ident("test"),
            make_ident("User"),
            make_span(),
        ));

        assert!(schema.has_policies("User"));
        assert!(!schema.has_policies("Post"));
    }

    #[test]
    fn test_schema_policy_names() {
        let mut schema = Schema::new();
        schema.add_policy(Policy::new(
            make_ident("policy1"),
            make_ident("User"),
            make_span(),
        ));
        schema.add_policy(Policy::new(
            make_ident("policy2"),
            make_ident("Post"),
            make_span(),
        ));

        let names: Vec<_> = schema.policy_names().collect();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"policy1"));
        assert!(names.contains(&"policy2"));
    }

    #[test]
    fn test_schema_merge_with_policies() {
        let mut schema1 = Schema::new();
        schema1.add_policy(Policy::new(
            make_ident("policy1"),
            make_ident("User"),
            make_span(),
        ));

        let mut schema2 = Schema::new();
        schema2.add_policy(Policy::new(
            make_ident("policy2"),
            make_ident("Post"),
            make_span(),
        ));

        schema1.merge(schema2);

        assert_eq!(schema1.policies.len(), 2);
    }

    #[test]
    fn test_schema_stats_with_policies() {
        let mut schema = Schema::new();
        schema.add_model(make_model("User"));
        schema.add_policy(Policy::new(
            make_ident("policy1"),
            make_ident("User"),
            make_span(),
        ));
        schema.add_policy(Policy::new(
            make_ident("policy2"),
            make_ident("User"),
            make_span(),
        ));

        let stats = schema.stats();
        assert_eq!(stats.model_count, 1);
        assert_eq!(stats.policy_count, 2);
    }
}