1use indexmap::IndexMap;
4use serde::{Deserialize, Serialize};
5use smol_str::SmolStr;
6
7use super::{CompositeType, Enum, Model, Relation, ServerGroup, View};
8
/// Top-level container for every declaration in a schema.
///
/// Name-keyed collections are [`IndexMap`]s, so iteration follows insertion order.
/// The `add_*` methods key each entry by its declared name.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct Schema {
    /// Models, keyed by model name.
    pub models: IndexMap<SmolStr, Model>,
    /// Enums, keyed by enum name.
    pub enums: IndexMap<SmolStr, Enum>,
    /// Composite types, keyed by type name.
    pub types: IndexMap<SmolStr, CompositeType>,
    /// Views, keyed by view name.
    pub views: IndexMap<SmolStr, View>,
    /// Server groups, keyed by group name.
    pub server_groups: IndexMap<SmolStr, ServerGroup>,
    /// Raw SQL blocks, in the order they were added.
    pub raw_sql: Vec<RawSql>,
    /// Relations between models. There is no `add_*` helper for these; they are
    /// pushed onto the vector directly (presumably by a later analysis pass — TODO confirm).
    pub relations: Vec<Relation>,
}
27
28impl Schema {
29 pub fn new() -> Self {
31 Self::default()
32 }
33
34 pub fn add_model(&mut self, model: Model) {
36 self.models.insert(model.name.name.clone(), model);
37 }
38
39 pub fn add_enum(&mut self, e: Enum) {
41 self.enums.insert(e.name.name.clone(), e);
42 }
43
44 pub fn add_type(&mut self, t: CompositeType) {
46 self.types.insert(t.name.name.clone(), t);
47 }
48
49 pub fn add_view(&mut self, v: View) {
51 self.views.insert(v.name.name.clone(), v);
52 }
53
54 pub fn add_server_group(&mut self, sg: ServerGroup) {
56 self.server_groups.insert(sg.name.name.clone(), sg);
57 }
58
59 pub fn add_raw_sql(&mut self, sql: RawSql) {
61 self.raw_sql.push(sql);
62 }
63
64 pub fn get_model(&self, name: &str) -> Option<&Model> {
66 self.models.get(name)
67 }
68
69 pub fn get_model_mut(&mut self, name: &str) -> Option<&mut Model> {
71 self.models.get_mut(name)
72 }
73
74 pub fn get_enum(&self, name: &str) -> Option<&Enum> {
76 self.enums.get(name)
77 }
78
79 pub fn get_type(&self, name: &str) -> Option<&CompositeType> {
81 self.types.get(name)
82 }
83
84 pub fn get_view(&self, name: &str) -> Option<&View> {
86 self.views.get(name)
87 }
88
89 pub fn get_server_group(&self, name: &str) -> Option<&ServerGroup> {
91 self.server_groups.get(name)
92 }
93
94 pub fn server_group_names(&self) -> impl Iterator<Item = &str> {
96 self.server_groups.keys().map(|s| s.as_str())
97 }
98
99 pub fn type_exists(&self, name: &str) -> bool {
101 self.models.contains_key(name)
102 || self.enums.contains_key(name)
103 || self.types.contains_key(name)
104 || self.views.contains_key(name)
105 }
106
107 pub fn model_names(&self) -> impl Iterator<Item = &str> {
109 self.models.keys().map(|s| s.as_str())
110 }
111
112 pub fn enum_names(&self) -> impl Iterator<Item = &str> {
114 self.enums.keys().map(|s| s.as_str())
115 }
116
117 pub fn relations_for(&self, model: &str) -> Vec<&Relation> {
119 self.relations
120 .iter()
121 .filter(|r| r.from_model == model || r.to_model == model)
122 .collect()
123 }
124
125 pub fn relations_from(&self, model: &str) -> Vec<&Relation> {
127 self.relations
128 .iter()
129 .filter(|r| r.from_model == model)
130 .collect()
131 }
132
133 pub fn merge(&mut self, other: Schema) {
135 self.models.extend(other.models);
136 self.enums.extend(other.enums);
137 self.types.extend(other.types);
138 self.views.extend(other.views);
139 self.server_groups.extend(other.server_groups);
140 self.raw_sql.extend(other.raw_sql);
141 }
142}
143
/// A named block of raw SQL carried alongside the schema declarations.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RawSql {
    /// Identifier for this SQL block (e.g. a migration name).
    pub name: SmolStr,
    /// The SQL text, stored as given (no parsing or validation happens here).
    pub sql: String,
}
152
153impl RawSql {
154 pub fn new(name: impl Into<SmolStr>, sql: impl Into<String>) -> Self {
156 Self {
157 name: name.into(),
158 sql: sql.into(),
159 }
160 }
161}
162
/// Summary counts for a [`Schema`], produced by `Schema::stats`.
///
/// All fields are plain counters, so the type is comparable. That lets callers
/// and tests compare whole snapshots with `==`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SchemaStats {
    /// Number of models.
    pub model_count: usize,
    /// Number of enums.
    pub enum_count: usize,
    /// Number of composite types.
    pub type_count: usize,
    /// Number of views.
    pub view_count: usize,
    /// Number of server groups.
    pub server_group_count: usize,
    /// Total number of fields across all models (composite-type and view fields are not counted).
    pub field_count: usize,
    /// Number of relations.
    pub relation_count: usize,
}
181
182impl Schema {
183 pub fn stats(&self) -> SchemaStats {
185 SchemaStats {
186 model_count: self.models.len(),
187 enum_count: self.enums.len(),
188 type_count: self.types.len(),
189 view_count: self.views.len(),
190 server_group_count: self.server_groups.len(),
191 field_count: self.models.values().map(|m| m.fields.len()).sum(),
192 relation_count: self.relations.len(),
193 }
194 }
195}
196
197impl std::fmt::Display for Schema {
198 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
199 let stats = self.stats();
200 write!(
201 f,
202 "Schema({} models, {} enums, {} types, {} views, {} server groups, {} fields, {} relations)",
203 stats.model_count,
204 stats.enum_count,
205 stats.type_count,
206 stats.view_count,
207 stats.server_group_count,
208 stats.field_count,
209 stats.relation_count
210 )
211 }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{
        Attribute, EnumVariant, Field, FieldType, Ident, RelationType, ScalarType, Span,
        TypeModifier,
    };

    /// All fixtures share one dummy span; source positions are irrelevant here.
    fn make_span() -> Span {
        Span::new(0, 10)
    }

    fn make_ident(name: &str) -> Ident {
        Ident::new(name, make_span())
    }

    /// A model containing exactly one field: the `id` field from `make_id_field`.
    fn make_model(name: &str) -> Model {
        let mut model = Model::new(make_ident(name), make_span());
        model.add_field(make_id_field());
        model
    }

    /// A required `Int` field named `id` carrying a simple `id` attribute.
    fn make_id_field() -> Field {
        let mut field = Field::new(
            make_ident("id"),
            FieldType::Scalar(ScalarType::Int),
            TypeModifier::Required,
            vec![],
            make_span(),
        );
        field
            .attributes
            .push(Attribute::simple(make_ident("id"), make_span()));
        field
    }

    fn make_field(name: &str, field_type: FieldType) -> Field {
        Field::new(
            make_ident(name),
            field_type,
            TypeModifier::Required,
            vec![],
            make_span(),
        )
    }

    fn make_enum(name: &str, variants: &[&str]) -> Enum {
        let mut e = Enum::new(make_ident(name), make_span());
        for v in variants {
            e.add_variant(EnumVariant::new(make_ident(v), make_span()));
        }
        e
    }

    #[test]
    fn test_schema_new() {
        let schema = Schema::new();
        assert!(schema.models.is_empty());
        assert!(schema.enums.is_empty());
        assert!(schema.types.is_empty());
        assert!(schema.views.is_empty());
        assert!(schema.server_groups.is_empty());
        assert!(schema.raw_sql.is_empty());
        assert!(schema.relations.is_empty());
    }

    #[test]
    fn test_schema_default() {
        let schema = Schema::default();
        assert!(schema.models.is_empty());
        assert_eq!(schema, Schema::new());
    }

    #[test]
    fn test_schema_add_model() {
        let mut schema = Schema::new();
        let model = make_model("User");

        schema.add_model(model);

        assert_eq!(schema.models.len(), 1);
        assert!(schema.models.contains_key("User"));
    }

    #[test]
    fn test_schema_add_multiple_models() {
        let mut schema = Schema::new();
        schema.add_model(make_model("User"));
        schema.add_model(make_model("Post"));
        schema.add_model(make_model("Comment"));

        assert_eq!(schema.models.len(), 3);
    }

    #[test]
    fn test_schema_add_model_replaces_duplicate_name() {
        let mut schema = Schema::new();
        schema.add_model(make_model("User"));
        schema.add_model(make_model("Post"));

        let mut user = make_model("User");
        user.add_field(make_field("email", FieldType::Scalar(ScalarType::String)));
        schema.add_model(user);

        assert_eq!(schema.models.len(), 2);
        assert_eq!(schema.get_model("User").unwrap().fields.len(), 2);
        // Replacing a value keeps the name's original insertion position.
        let names: Vec<_> = schema.model_names().collect();
        assert_eq!(names, vec!["User", "Post"]);
    }

    #[test]
    fn test_schema_add_enum() {
        let mut schema = Schema::new();
        let e = make_enum("Role", &["User", "Admin"]);

        schema.add_enum(e);

        assert_eq!(schema.enums.len(), 1);
        assert!(schema.enums.contains_key("Role"));
    }

    #[test]
    fn test_schema_add_type() {
        let mut schema = Schema::new();
        let ct = CompositeType::new(make_ident("Address"), make_span());

        schema.add_type(ct);

        assert_eq!(schema.types.len(), 1);
        assert!(schema.types.contains_key("Address"));
    }

    #[test]
    fn test_schema_add_view() {
        let mut schema = Schema::new();
        let view = View::new(make_ident("UserStats"), make_span());

        schema.add_view(view);

        assert_eq!(schema.views.len(), 1);
        assert!(schema.views.contains_key("UserStats"));
    }

    #[test]
    fn test_schema_add_raw_sql() {
        let mut schema = Schema::new();
        let sql = RawSql::new("migration_1", "CREATE TABLE test ();");

        schema.add_raw_sql(sql);

        assert_eq!(schema.raw_sql.len(), 1);
    }

    #[test]
    fn test_schema_get_model() {
        let mut schema = Schema::new();
        schema.add_model(make_model("User"));

        let model = schema.get_model("User");
        assert!(model.is_some());
        assert_eq!(model.unwrap().name(), "User");

        assert!(schema.get_model("NonExistent").is_none());
    }

    #[test]
    fn test_schema_get_model_mut() {
        let mut schema = Schema::new();
        schema.add_model(make_model("User"));

        let model = schema
            .get_model_mut("User")
            .expect("User was just added");
        model.add_field(make_field("email", FieldType::Scalar(ScalarType::String)));

        // The fixture's `id` field plus the newly added `email` field.
        assert_eq!(schema.get_model("User").unwrap().fields.len(), 2);
        assert!(schema.get_model_mut("NonExistent").is_none());
    }

    #[test]
    fn test_schema_get_enum() {
        let mut schema = Schema::new();
        schema.add_enum(make_enum("Role", &["User", "Admin"]));

        let e = schema.get_enum("Role");
        assert!(e.is_some());
        assert_eq!(e.unwrap().name(), "Role");

        assert!(schema.get_enum("NonExistent").is_none());
    }

    #[test]
    fn test_schema_get_type() {
        let mut schema = Schema::new();
        schema.add_type(CompositeType::new(make_ident("Address"), make_span()));

        let ct = schema.get_type("Address");
        assert!(ct.is_some());

        assert!(schema.get_type("NonExistent").is_none());
    }

    #[test]
    fn test_schema_get_view() {
        let mut schema = Schema::new();
        schema.add_view(View::new(make_ident("Stats"), make_span()));

        let v = schema.get_view("Stats");
        assert!(v.is_some());

        assert!(schema.get_view("NonExistent").is_none());
    }

    #[test]
    fn test_schema_type_exists() {
        let mut schema = Schema::new();
        schema.add_model(make_model("User"));
        schema.add_enum(make_enum("Role", &["User"]));
        schema.add_type(CompositeType::new(make_ident("Address"), make_span()));
        schema.add_view(View::new(make_ident("Stats"), make_span()));

        assert!(schema.type_exists("User"));
        assert!(schema.type_exists("Role"));
        assert!(schema.type_exists("Address"));
        assert!(schema.type_exists("Stats"));
        assert!(!schema.type_exists("NonExistent"));
    }

    #[test]
    fn test_schema_model_names() {
        let mut schema = Schema::new();
        schema.add_model(make_model("User"));
        schema.add_model(make_model("Post"));

        let names: Vec<_> = schema.model_names().collect();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"User"));
        assert!(names.contains(&"Post"));
    }

    #[test]
    fn test_schema_enum_names() {
        let mut schema = Schema::new();
        schema.add_enum(make_enum("Role", &["User"]));
        schema.add_enum(make_enum("Status", &["Active"]));

        let names: Vec<_> = schema.enum_names().collect();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"Role"));
        assert!(names.contains(&"Status"));
    }

    #[test]
    fn test_schema_relations_for() {
        let mut schema = Schema::new();
        schema.relations.push(Relation::new(
            "Post",
            "author",
            "User",
            RelationType::ManyToOne,
        ));
        schema.relations.push(Relation::new(
            "Comment",
            "user",
            "User",
            RelationType::ManyToOne,
        ));
        schema.relations.push(Relation::new(
            "Post",
            "tags",
            "Tag",
            RelationType::ManyToMany,
        ));

        let user_relations = schema.relations_for("User");
        assert_eq!(user_relations.len(), 2);

        let post_relations = schema.relations_for("Post");
        assert_eq!(post_relations.len(), 2);

        let tag_relations = schema.relations_for("Tag");
        assert_eq!(tag_relations.len(), 1);

        assert!(schema.relations_for("NonExistent").is_empty());
    }

    #[test]
    fn test_schema_relations_for_self_relation_counted_once() {
        let mut schema = Schema::new();
        schema.relations.push(Relation::new(
            "Category",
            "parent",
            "Category",
            RelationType::ManyToOne,
        ));

        assert_eq!(schema.relations_for("Category").len(), 1);
        assert_eq!(schema.relations_from("Category").len(), 1);
    }

    #[test]
    fn test_schema_relations_from() {
        let mut schema = Schema::new();
        schema.relations.push(Relation::new(
            "Post",
            "author",
            "User",
            RelationType::ManyToOne,
        ));
        schema.relations.push(Relation::new(
            "Post",
            "tags",
            "Tag",
            RelationType::ManyToMany,
        ));
        schema.relations.push(Relation::new(
            "User",
            "posts",
            "Post",
            RelationType::OneToMany,
        ));

        let post_relations = schema.relations_from("Post");
        assert_eq!(post_relations.len(), 2);

        let user_relations = schema.relations_from("User");
        assert_eq!(user_relations.len(), 1);

        let tag_relations = schema.relations_from("Tag");
        assert_eq!(tag_relations.len(), 0);
    }

    #[test]
    fn test_schema_merge() {
        let mut schema1 = Schema::new();
        schema1.add_model(make_model("User"));
        schema1.add_enum(make_enum("Role", &["User"]));

        let mut schema2 = Schema::new();
        schema2.add_model(make_model("Post"));
        schema2.add_enum(make_enum("Status", &["Active"]));
        schema2.add_raw_sql(RawSql::new("init", "-- init"));

        schema1.merge(schema2);

        assert_eq!(schema1.models.len(), 2);
        assert_eq!(schema1.enums.len(), 2);
        assert_eq!(schema1.raw_sql.len(), 1);
    }

    #[test]
    fn test_schema_merge_overwrites_duplicate_names() {
        let mut schema1 = Schema::new();
        schema1.add_model(make_model("User"));
        schema1.add_model(make_model("Post"));

        let mut schema2 = Schema::new();
        let mut user = make_model("User");
        user.add_field(make_field("email", FieldType::Scalar(ScalarType::String)));
        schema2.add_model(user);

        schema1.merge(schema2);

        assert_eq!(schema1.models.len(), 2);
        assert_eq!(schema1.get_model("User").unwrap().fields.len(), 2);
        let names: Vec<_> = schema1.model_names().collect();
        assert_eq!(names, vec!["User", "Post"]);
    }

    #[test]
    fn test_schema_stats() {
        let mut schema = Schema::new();

        let mut user = make_model("User");
        user.add_field(make_field("email", FieldType::Scalar(ScalarType::String)));
        user.add_field(make_field("name", FieldType::Scalar(ScalarType::String)));
        schema.add_model(user);

        let mut post = make_model("Post");
        post.add_field(make_field("title", FieldType::Scalar(ScalarType::String)));
        schema.add_model(post);

        schema.add_enum(make_enum("Role", &["User", "Admin"]));
        schema.add_type(CompositeType::new(make_ident("Address"), make_span()));
        schema.add_view(View::new(make_ident("Stats"), make_span()));
        schema.relations.push(Relation::new(
            "Post",
            "author",
            "User",
            RelationType::ManyToOne,
        ));

        let stats = schema.stats();
        assert_eq!(stats.model_count, 2);
        assert_eq!(stats.enum_count, 1);
        assert_eq!(stats.type_count, 1);
        assert_eq!(stats.view_count, 1);
        assert_eq!(stats.server_group_count, 0);
        // User: id, email, name; Post: id, title.
        assert_eq!(stats.field_count, 5);
        assert_eq!(stats.relation_count, 1);
    }

    #[test]
    fn test_schema_display() {
        let mut schema = Schema::new();
        schema.add_model(make_model("User"));
        schema.add_enum(make_enum("Role", &["User"]));

        let display = schema.to_string();
        assert!(display.contains("1 models"));
        assert!(display.contains("1 enums"));
        assert_eq!(
            display,
            "Schema(1 models, 1 enums, 0 types, 0 views, 0 server groups, 1 fields, 0 relations)"
        );
    }

    #[test]
    fn test_schema_equality() {
        let schema1 = Schema::new();
        let schema2 = Schema::new();
        assert_eq!(schema1, schema2);
    }

    #[test]
    fn test_schema_clone() {
        let mut schema = Schema::new();
        schema.add_model(make_model("User"));

        let cloned = schema.clone();
        assert_eq!(cloned.models.len(), 1);
        assert_eq!(cloned, schema);
    }

    #[test]
    fn test_raw_sql_new() {
        let sql = RawSql::new("create_users", "CREATE TABLE users ();");

        assert_eq!(sql.name.as_str(), "create_users");
        assert_eq!(sql.sql, "CREATE TABLE users ();");
    }

    #[test]
    fn test_raw_sql_from_strings() {
        let name = String::from("migration");
        let content = String::from("ALTER TABLE users ADD COLUMN age INT;");
        let sql = RawSql::new(name, content);

        assert_eq!(sql.name.as_str(), "migration");
        assert_eq!(sql.sql, "ALTER TABLE users ADD COLUMN age INT;");
    }

    #[test]
    fn test_raw_sql_equality() {
        let sql1 = RawSql::new("test", "SELECT 1;");
        let sql2 = RawSql::new("test", "SELECT 1;");
        let sql3 = RawSql::new("test", "SELECT 2;");

        assert_eq!(sql1, sql2);
        assert_ne!(sql1, sql3);
    }

    #[test]
    fn test_raw_sql_clone() {
        let sql = RawSql::new("test", "SELECT 1;");
        let cloned = sql.clone();
        assert_eq!(sql, cloned);
    }

    #[test]
    fn test_schema_stats_default() {
        let stats = SchemaStats::default();
        assert_eq!(stats.model_count, 0);
        assert_eq!(stats.enum_count, 0);
        assert_eq!(stats.type_count, 0);
        assert_eq!(stats.view_count, 0);
        assert_eq!(stats.server_group_count, 0);
        assert_eq!(stats.field_count, 0);
        assert_eq!(stats.relation_count, 0);
    }

    #[test]
    fn test_schema_stats_debug() {
        let stats = SchemaStats::default();
        let debug = format!("{:?}", stats);
        assert!(debug.contains("SchemaStats"));
    }

    #[test]
    fn test_schema_stats_clone() {
        let stats = SchemaStats {
            model_count: 5,
            enum_count: 2,
            type_count: 1,
            view_count: 3,
            server_group_count: 2,
            field_count: 25,
            relation_count: 10,
        };
        let cloned = stats.clone();
        assert_eq!(cloned.model_count, 5);
        assert_eq!(cloned.server_group_count, 2);
        assert_eq!(cloned.field_count, 25);
    }
}