1use indexmap::IndexMap;
4use serde::{Deserialize, Serialize};
5use smol_str::SmolStr;
6
7use super::{CompositeType, Datasource, Enum, Model, Policy, Relation, ServerGroup, View};
8
9#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
11pub struct Schema {
12 pub datasource: Option<Datasource>,
14 pub models: IndexMap<SmolStr, Model>,
16 pub enums: IndexMap<SmolStr, Enum>,
18 pub types: IndexMap<SmolStr, CompositeType>,
20 pub views: IndexMap<SmolStr, View>,
22 pub server_groups: IndexMap<SmolStr, ServerGroup>,
24 pub policies: Vec<Policy>,
26 pub raw_sql: Vec<RawSql>,
28 pub relations: Vec<Relation>,
30}
31
32impl Schema {
33 pub fn new() -> Self {
35 Self::default()
36 }
37
38 pub fn set_datasource(&mut self, datasource: Datasource) {
40 self.datasource = Some(datasource);
41 }
42
43 pub fn datasource(&self) -> Option<&Datasource> {
45 self.datasource.as_ref()
46 }
47
48 pub fn has_vector_support(&self) -> bool {
50 self.datasource
51 .as_ref()
52 .is_some_and(|ds| ds.has_vector_support())
53 }
54
55 pub fn required_extensions(&self) -> Vec<&super::PostgresExtension> {
57 self.datasource
58 .as_ref()
59 .map(|ds| ds.extensions.iter().collect())
60 .unwrap_or_default()
61 }
62
63 pub fn add_model(&mut self, model: Model) {
65 self.models.insert(model.name.name.clone(), model);
66 }
67
68 pub fn add_enum(&mut self, e: Enum) {
70 self.enums.insert(e.name.name.clone(), e);
71 }
72
73 pub fn add_type(&mut self, t: CompositeType) {
75 self.types.insert(t.name.name.clone(), t);
76 }
77
78 pub fn add_view(&mut self, v: View) {
80 self.views.insert(v.name.name.clone(), v);
81 }
82
83 pub fn add_server_group(&mut self, sg: ServerGroup) {
85 self.server_groups.insert(sg.name.name.clone(), sg);
86 }
87
88 pub fn add_policy(&mut self, policy: Policy) {
90 self.policies.push(policy);
91 }
92
93 pub fn add_raw_sql(&mut self, sql: RawSql) {
95 self.raw_sql.push(sql);
96 }
97
98 pub fn get_model(&self, name: &str) -> Option<&Model> {
100 self.models.get(name)
101 }
102
103 pub fn get_model_mut(&mut self, name: &str) -> Option<&mut Model> {
105 self.models.get_mut(name)
106 }
107
108 pub fn get_enum(&self, name: &str) -> Option<&Enum> {
110 self.enums.get(name)
111 }
112
113 pub fn get_type(&self, name: &str) -> Option<&CompositeType> {
115 self.types.get(name)
116 }
117
118 pub fn get_view(&self, name: &str) -> Option<&View> {
120 self.views.get(name)
121 }
122
123 pub fn get_server_group(&self, name: &str) -> Option<&ServerGroup> {
125 self.server_groups.get(name)
126 }
127
128 pub fn server_group_names(&self) -> impl Iterator<Item = &str> {
130 self.server_groups.keys().map(|s| s.as_str())
131 }
132
133 pub fn get_policy(&self, name: &str) -> Option<&Policy> {
135 self.policies.iter().find(|p| p.name() == name)
136 }
137
138 pub fn policies_for(&self, model: &str) -> Vec<&Policy> {
140 self.policies
141 .iter()
142 .filter(|p| p.table() == model)
143 .collect()
144 }
145
146 pub fn has_policies(&self, model: &str) -> bool {
148 self.policies.iter().any(|p| p.table() == model)
149 }
150
151 pub fn policy_names(&self) -> impl Iterator<Item = &str> {
153 self.policies.iter().map(|p| p.name())
154 }
155
156 pub fn type_exists(&self, name: &str) -> bool {
158 self.models.contains_key(name)
159 || self.enums.contains_key(name)
160 || self.types.contains_key(name)
161 || self.views.contains_key(name)
162 }
163
164 pub fn model_names(&self) -> impl Iterator<Item = &str> {
166 self.models.keys().map(|s| s.as_str())
167 }
168
169 pub fn enum_names(&self) -> impl Iterator<Item = &str> {
171 self.enums.keys().map(|s| s.as_str())
172 }
173
174 pub fn relations_for(&self, model: &str) -> Vec<&Relation> {
176 self.relations
177 .iter()
178 .filter(|r| r.from_model == model || r.to_model == model)
179 .collect()
180 }
181
182 pub fn relations_from(&self, model: &str) -> Vec<&Relation> {
184 self.relations
185 .iter()
186 .filter(|r| r.from_model == model)
187 .collect()
188 }
189
190 pub fn merge(&mut self, other: Schema) {
192 self.models.extend(other.models);
193 self.enums.extend(other.enums);
194 self.types.extend(other.types);
195 self.views.extend(other.views);
196 self.server_groups.extend(other.server_groups);
197 self.policies.extend(other.policies);
198 self.raw_sql.extend(other.raw_sql);
199 }
200}
201
202#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
204pub struct RawSql {
205 pub name: SmolStr,
207 pub sql: String,
209}
210
211impl RawSql {
212 pub fn new(name: impl Into<SmolStr>, sql: impl Into<String>) -> Self {
214 Self {
215 name: name.into(),
216 sql: sql.into(),
217 }
218 }
219}
220
/// Summary counts describing the contents of a `Schema`.
///
/// All counts are zero by default. Derives `PartialEq`/`Eq` so that stats can
/// be compared directly, e.g. with `assert_eq!` in tests.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SchemaStats {
    /// Number of models.
    pub model_count: usize,
    /// Number of enums.
    pub enum_count: usize,
    /// Number of composite types.
    pub type_count: usize,
    /// Number of views.
    pub view_count: usize,
    /// Number of server groups.
    pub server_group_count: usize,
    /// Number of policies.
    pub policy_count: usize,
    /// Total number of fields, summed over all models.
    pub field_count: usize,
    /// Number of relations.
    pub relation_count: usize,
}
241
242impl Schema {
243 pub fn stats(&self) -> SchemaStats {
245 SchemaStats {
246 model_count: self.models.len(),
247 enum_count: self.enums.len(),
248 type_count: self.types.len(),
249 view_count: self.views.len(),
250 server_group_count: self.server_groups.len(),
251 policy_count: self.policies.len(),
252 field_count: self.models.values().map(|m| m.fields.len()).sum(),
253 relation_count: self.relations.len(),
254 }
255 }
256}
257
258impl std::fmt::Display for Schema {
259 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
260 let stats = self.stats();
261 write!(
262 f,
263 "Schema({} models, {} enums, {} types, {} views, {} server groups, {} policies, {} fields, {} relations)",
264 stats.model_count,
265 stats.enum_count,
266 stats.type_count,
267 stats.view_count,
268 stats.server_group_count,
269 stats.policy_count,
270 stats.field_count,
271 stats.relation_count
272 )
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{
        Attribute, EnumVariant, Field, FieldType, Ident, Policy, RelationType, ScalarType, Span,
        TypeModifier,
    };

    /// Shared dummy span; the tests never inspect positions.
    fn make_span() -> Span {
        Span::new(0, 10)
    }

    fn make_ident(name: &str) -> Ident {
        Ident::new(name, make_span())
    }

    /// Builds a model named `name` that already contains an `id` field, so
    /// every fixture model starts with exactly one field.
    fn make_model(name: &str) -> Model {
        let mut model = Model::new(make_ident(name), make_span());
        model.add_field(make_id_field());
        model
    }

    /// A required `Int` field named `id` carrying an `@id` attribute.
    fn make_id_field() -> Field {
        let mut field = Field::new(
            make_ident("id"),
            FieldType::Scalar(ScalarType::Int),
            TypeModifier::Required,
            vec![],
            make_span(),
        );
        field
            .attributes
            .push(Attribute::simple(make_ident("id"), make_span()));
        field
    }

    /// A required field with no attributes.
    fn make_field(name: &str, field_type: FieldType) -> Field {
        Field::new(
            make_ident(name),
            field_type,
            TypeModifier::Required,
            vec![],
            make_span(),
        )
    }

    fn make_enum(name: &str, variants: &[&str]) -> Enum {
        let mut e = Enum::new(make_ident(name), make_span());
        for v in variants {
            e.add_variant(EnumVariant::new(make_ident(v), make_span()));
        }
        e
    }

    // ---------------------------------------------------------------- Schema

    #[test]
    fn test_schema_new() {
        let schema = Schema::new();
        assert!(schema.datasource.is_none());
        assert!(schema.models.is_empty());
        assert!(schema.enums.is_empty());
        assert!(schema.types.is_empty());
        assert!(schema.views.is_empty());
        assert!(schema.server_groups.is_empty());
        assert!(schema.policies.is_empty());
        assert!(schema.raw_sql.is_empty());
        assert!(schema.relations.is_empty());
    }

    #[test]
    fn test_schema_default() {
        let schema = Schema::default();
        assert!(schema.models.is_empty());
    }

    #[test]
    fn test_schema_add_model() {
        let mut schema = Schema::new();
        let model = make_model("User");

        schema.add_model(model);

        assert_eq!(schema.models.len(), 1);
        assert!(schema.models.contains_key("User"));
    }

    #[test]
    fn test_schema_add_multiple_models() {
        let mut schema = Schema::new();
        schema.add_model(make_model("User"));
        schema.add_model(make_model("Post"));
        schema.add_model(make_model("Comment"));

        assert_eq!(schema.models.len(), 3);
    }

    #[test]
    fn test_schema_add_enum() {
        let mut schema = Schema::new();
        let e = make_enum("Role", &["User", "Admin"]);

        schema.add_enum(e);

        assert_eq!(schema.enums.len(), 1);
        assert!(schema.enums.contains_key("Role"));
    }

    #[test]
    fn test_schema_add_type() {
        let mut schema = Schema::new();
        let ct = CompositeType::new(make_ident("Address"), make_span());

        schema.add_type(ct);

        assert_eq!(schema.types.len(), 1);
        assert!(schema.types.contains_key("Address"));
    }

    #[test]
    fn test_schema_add_view() {
        let mut schema = Schema::new();
        let view = View::new(make_ident("UserStats"), make_span());

        schema.add_view(view);

        assert_eq!(schema.views.len(), 1);
        assert!(schema.views.contains_key("UserStats"));
    }

    #[test]
    fn test_schema_add_raw_sql() {
        let mut schema = Schema::new();
        let sql = RawSql::new("migration_1", "CREATE TABLE test ();");

        schema.add_raw_sql(sql);

        assert_eq!(schema.raw_sql.len(), 1);
    }

    #[test]
    fn test_schema_get_model() {
        let mut schema = Schema::new();
        schema.add_model(make_model("User"));

        let model = schema.get_model("User");
        assert!(model.is_some());
        assert_eq!(model.unwrap().name(), "User");

        assert!(schema.get_model("NonExistent").is_none());
    }

    #[test]
    fn test_schema_get_model_mut() {
        let mut schema = Schema::new();
        schema.add_model(make_model("User"));

        let model = schema.get_model_mut("User");
        assert!(model.is_some());

        // Mutating through the returned reference must be visible afterwards.
        let model = model.unwrap();
        model.add_field(make_field("email", FieldType::Scalar(ScalarType::String)));

        // `id` from the fixture plus the newly added `email`.
        assert_eq!(schema.get_model("User").unwrap().fields.len(), 2);
    }

    #[test]
    fn test_schema_get_enum() {
        let mut schema = Schema::new();
        schema.add_enum(make_enum("Role", &["User", "Admin"]));

        let e = schema.get_enum("Role");
        assert!(e.is_some());
        assert_eq!(e.unwrap().name(), "Role");

        assert!(schema.get_enum("NonExistent").is_none());
    }

    #[test]
    fn test_schema_get_type() {
        let mut schema = Schema::new();
        schema.add_type(CompositeType::new(make_ident("Address"), make_span()));

        let ct = schema.get_type("Address");
        assert!(ct.is_some());

        assert!(schema.get_type("NonExistent").is_none());
    }

    #[test]
    fn test_schema_get_view() {
        let mut schema = Schema::new();
        schema.add_view(View::new(make_ident("Stats"), make_span()));

        let v = schema.get_view("Stats");
        assert!(v.is_some());

        assert!(schema.get_view("NonExistent").is_none());
    }

    #[test]
    fn test_schema_type_exists() {
        let mut schema = Schema::new();
        schema.add_model(make_model("User"));
        schema.add_enum(make_enum("Role", &["User"]));
        schema.add_type(CompositeType::new(make_ident("Address"), make_span()));
        schema.add_view(View::new(make_ident("Stats"), make_span()));

        assert!(schema.type_exists("User"));
        assert!(schema.type_exists("Role"));
        assert!(schema.type_exists("Address"));
        assert!(schema.type_exists("Stats"));
        assert!(!schema.type_exists("NonExistent"));
    }

    #[test]
    fn test_schema_model_names() {
        let mut schema = Schema::new();
        schema.add_model(make_model("User"));
        schema.add_model(make_model("Post"));

        let names: Vec<_> = schema.model_names().collect();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"User"));
        assert!(names.contains(&"Post"));
    }

    #[test]
    fn test_schema_enum_names() {
        let mut schema = Schema::new();
        schema.add_enum(make_enum("Role", &["User"]));
        schema.add_enum(make_enum("Status", &["Active"]));

        let names: Vec<_> = schema.enum_names().collect();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"Role"));
        assert!(names.contains(&"Status"));
    }

    #[test]
    fn test_schema_relations_for() {
        let mut schema = Schema::new();
        schema.relations.push(Relation::new(
            "Post",
            "author",
            "User",
            RelationType::ManyToOne,
        ));
        schema.relations.push(Relation::new(
            "Comment",
            "user",
            "User",
            RelationType::ManyToOne,
        ));
        schema.relations.push(Relation::new(
            "Post",
            "tags",
            "Tag",
            RelationType::ManyToMany,
        ));

        // `relations_for` matches either endpoint.
        let user_relations = schema.relations_for("User");
        assert_eq!(user_relations.len(), 2);

        let post_relations = schema.relations_for("Post");
        assert_eq!(post_relations.len(), 2);

        let tag_relations = schema.relations_for("Tag");
        assert_eq!(tag_relations.len(), 1);
    }

    #[test]
    fn test_schema_relations_from() {
        let mut schema = Schema::new();
        schema.relations.push(Relation::new(
            "Post",
            "author",
            "User",
            RelationType::ManyToOne,
        ));
        schema.relations.push(Relation::new(
            "Post",
            "tags",
            "Tag",
            RelationType::ManyToMany,
        ));
        schema.relations.push(Relation::new(
            "User",
            "posts",
            "Post",
            RelationType::OneToMany,
        ));

        // `relations_from` only matches the source side.
        let post_relations = schema.relations_from("Post");
        assert_eq!(post_relations.len(), 2);

        let user_relations = schema.relations_from("User");
        assert_eq!(user_relations.len(), 1);

        let tag_relations = schema.relations_from("Tag");
        assert_eq!(tag_relations.len(), 0);
    }

    #[test]
    fn test_schema_merge() {
        let mut schema1 = Schema::new();
        schema1.add_model(make_model("User"));
        schema1.add_enum(make_enum("Role", &["User"]));

        let mut schema2 = Schema::new();
        schema2.add_model(make_model("Post"));
        schema2.add_enum(make_enum("Status", &["Active"]));
        schema2.add_raw_sql(RawSql::new("init", "-- init"));

        schema1.merge(schema2);

        assert_eq!(schema1.models.len(), 2);
        assert_eq!(schema1.enums.len(), 2);
        assert_eq!(schema1.raw_sql.len(), 1);

        // Merged entries are appended after the existing ones.
        let names: Vec<_> = schema1.model_names().collect();
        assert_eq!(names, ["User", "Post"]);
    }

    #[test]
    fn test_schema_stats() {
        let mut schema = Schema::new();

        let mut user = make_model("User");
        user.add_field(make_field("email", FieldType::Scalar(ScalarType::String)));
        user.add_field(make_field("name", FieldType::Scalar(ScalarType::String)));
        schema.add_model(user);

        let mut post = make_model("Post");
        post.add_field(make_field("title", FieldType::Scalar(ScalarType::String)));
        schema.add_model(post);

        schema.add_enum(make_enum("Role", &["User", "Admin"]));
        schema.add_type(CompositeType::new(make_ident("Address"), make_span()));
        schema.add_view(View::new(make_ident("Stats"), make_span()));
        schema.relations.push(Relation::new(
            "Post",
            "author",
            "User",
            RelationType::ManyToOne,
        ));

        let stats = schema.stats();
        assert_eq!(stats.model_count, 2);
        assert_eq!(stats.enum_count, 1);
        assert_eq!(stats.type_count, 1);
        assert_eq!(stats.view_count, 1);
        assert_eq!(stats.server_group_count, 0);
        assert_eq!(stats.policy_count, 0);
        // User: id, email, name (3) + Post: id, title (2).
        assert_eq!(stats.field_count, 5);
        assert_eq!(stats.relation_count, 1);
    }

    #[test]
    fn test_schema_display() {
        let mut schema = Schema::new();
        schema.add_model(make_model("User"));
        schema.add_enum(make_enum("Role", &["User"]));

        let display = format!("{}", schema);
        assert!(display.contains("1 models"));
        assert!(display.contains("1 enums"));
        assert!(display.contains("0 policies"));
    }

    #[test]
    fn test_schema_equality() {
        let schema1 = Schema::new();
        let schema2 = Schema::new();
        assert_eq!(schema1, schema2);
    }

    #[test]
    fn test_schema_clone() {
        let mut schema = Schema::new();
        schema.add_model(make_model("User"));

        let cloned = schema.clone();
        assert_eq!(cloned.models.len(), 1);
    }

    // ---------------------------------------------------------------- RawSql

    #[test]
    fn test_raw_sql_new() {
        let sql = RawSql::new("create_users", "CREATE TABLE users ();");

        assert_eq!(sql.name.as_str(), "create_users");
        assert_eq!(sql.sql, "CREATE TABLE users ();");
    }

    #[test]
    fn test_raw_sql_from_strings() {
        let name = String::from("migration");
        let content = String::from("ALTER TABLE users ADD COLUMN age INT;");
        let sql = RawSql::new(name, content);

        assert_eq!(sql.name.as_str(), "migration");
    }

    #[test]
    fn test_raw_sql_equality() {
        let sql1 = RawSql::new("test", "SELECT 1;");
        let sql2 = RawSql::new("test", "SELECT 1;");
        let sql3 = RawSql::new("test", "SELECT 2;");

        assert_eq!(sql1, sql2);
        assert_ne!(sql1, sql3);
    }

    #[test]
    fn test_raw_sql_clone() {
        let sql = RawSql::new("test", "SELECT 1;");
        let cloned = sql.clone();
        assert_eq!(sql, cloned);
    }

    // ----------------------------------------------------------- SchemaStats

    #[test]
    fn test_schema_stats_default() {
        let stats = SchemaStats::default();
        assert_eq!(stats.model_count, 0);
        assert_eq!(stats.enum_count, 0);
        assert_eq!(stats.type_count, 0);
        assert_eq!(stats.view_count, 0);
        assert_eq!(stats.server_group_count, 0);
        assert_eq!(stats.policy_count, 0);
        assert_eq!(stats.field_count, 0);
        assert_eq!(stats.relation_count, 0);
    }

    #[test]
    fn test_schema_stats_debug() {
        let stats = SchemaStats::default();
        let debug = format!("{:?}", stats);
        assert!(debug.contains("SchemaStats"));
    }

    #[test]
    fn test_schema_stats_clone() {
        let stats = SchemaStats {
            model_count: 5,
            enum_count: 2,
            type_count: 1,
            view_count: 3,
            server_group_count: 2,
            policy_count: 4,
            field_count: 25,
            relation_count: 10,
        };
        let cloned = stats.clone();
        assert_eq!(cloned.model_count, 5);
        assert_eq!(cloned.field_count, 25);
        assert_eq!(cloned.policy_count, 4);
    }

    // -------------------------------------------------------------- Policies

    #[test]
    fn test_schema_add_policy() {
        let mut schema = Schema::new();
        let policy = Policy::new(make_ident("read_own"), make_ident("User"), make_span());

        schema.add_policy(policy);

        assert_eq!(schema.policies.len(), 1);
    }

    #[test]
    fn test_schema_get_policy() {
        let mut schema = Schema::new();
        schema.add_policy(Policy::new(
            make_ident("read_own"),
            make_ident("User"),
            make_span(),
        ));

        let policy = schema.get_policy("read_own");
        assert!(policy.is_some());
        assert_eq!(policy.unwrap().name(), "read_own");

        assert!(schema.get_policy("nonexistent").is_none());
    }

    #[test]
    fn test_schema_policies_for_model() {
        let mut schema = Schema::new();
        schema.add_policy(Policy::new(
            make_ident("user_read"),
            make_ident("User"),
            make_span(),
        ));
        schema.add_policy(Policy::new(
            make_ident("user_write"),
            make_ident("User"),
            make_span(),
        ));
        schema.add_policy(Policy::new(
            make_ident("post_read"),
            make_ident("Post"),
            make_span(),
        ));

        let user_policies = schema.policies_for("User");
        assert_eq!(user_policies.len(), 2);

        let post_policies = schema.policies_for("Post");
        assert_eq!(post_policies.len(), 1);

        let comment_policies = schema.policies_for("Comment");
        assert!(comment_policies.is_empty());
    }

    #[test]
    fn test_schema_has_policies() {
        let mut schema = Schema::new();
        schema.add_policy(Policy::new(
            make_ident("test"),
            make_ident("User"),
            make_span(),
        ));

        assert!(schema.has_policies("User"));
        assert!(!schema.has_policies("Post"));
    }

    #[test]
    fn test_schema_policy_names() {
        let mut schema = Schema::new();
        schema.add_policy(Policy::new(
            make_ident("policy1"),
            make_ident("User"),
            make_span(),
        ));
        schema.add_policy(Policy::new(
            make_ident("policy2"),
            make_ident("Post"),
            make_span(),
        ));

        let names: Vec<_> = schema.policy_names().collect();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"policy1"));
        assert!(names.contains(&"policy2"));
    }

    #[test]
    fn test_schema_merge_with_policies() {
        let mut schema1 = Schema::new();
        schema1.add_policy(Policy::new(
            make_ident("policy1"),
            make_ident("User"),
            make_span(),
        ));

        let mut schema2 = Schema::new();
        schema2.add_policy(Policy::new(
            make_ident("policy2"),
            make_ident("Post"),
            make_span(),
        ));

        schema1.merge(schema2);

        assert_eq!(schema1.policies.len(), 2);
    }

    #[test]
    fn test_schema_stats_with_policies() {
        let mut schema = Schema::new();
        schema.add_model(make_model("User"));
        schema.add_policy(Policy::new(
            make_ident("policy1"),
            make_ident("User"),
            make_span(),
        ));
        schema.add_policy(Policy::new(
            make_ident("policy2"),
            make_ident("User"),
            make_span(),
        ));

        let stats = schema.stats();
        assert_eq!(stats.model_count, 1);
        assert_eq!(stats.policy_count, 2);
    }
}