/// Exported, intentionally empty `__rust_probestack` symbol (x86_64 only).
///
/// NOTE(review): in Rust's `compiler_builtins`, `__rust_probestack` is the
/// stack-probe routine that large stack frames may call to touch each guard
/// page. An empty body performs no probing, so if anything still calls this
/// symbol, a stack overflow in a large frame may not hit the guard page.
/// Presumably this stub exists only to satisfy an unresolved-symbol link error
/// in some build/embedding configuration — TODO document which one, and
/// re-check whether it is still needed on current toolchains.
#[cfg(target_arch = "x86_64")]
#[unsafe(no_mangle)]
pub extern "C" fn __rust_probestack() {}
4
5pub mod backends;
6pub mod config;
7pub mod frontend;
8pub mod ir;
9pub mod lint;
10pub mod passes;
11pub mod prisma;
12pub mod provider;
13pub mod test_runner;
14
15use anyhow::Result;
16use std::path::Path;
18
19use crate::frontend::env::EnvVars;
21pub use ir::{
22 AggregateSpec, CollationSpec, CompositeTypeSpec, Config, DomainSpec, EnumSpec,
23 EventTriggerSpec, ExtensionSpec, FunctionSpec, GrantSpec, MaterializedViewSpec, OutputSpec,
24 PolicySpec, ProcedureSpec, RoleSpec, SchemaSpec, SequenceSpec, TableSpec, TablespaceSpec,
25 TriggerSpec, ViewSpec,
26};
27
/// Source of configuration file contents.
///
/// Abstracting file access lets callers decide where file text comes from;
/// for example, the tests in this file serve files from an in-memory map.
pub trait Loader {
    /// Returns the full text of the file at `path`, or an error if it cannot
    /// be provided.
    fn load(&self, path: &Path) -> Result<String>;
}
32
33pub fn load_config(root_path: &Path, loader: &dyn Loader, env: EnvVars) -> Result<Config> {
35 frontend::load_root_with_loader(root_path, loader, env)
36}
37
38pub fn validate(cfg: &Config, strict: bool) -> Result<()> {
40 passes::validate(cfg, strict)
41}
42
43fn filter_config_with<F>(cfg: &Config, predicate: F) -> Config
44where
45 F: Fn(crate::config::ResourceKind) -> bool,
46{
47 use crate::config::ResourceKind as R;
48
49 macro_rules! maybe {
50 ($kind:ident, $field:ident) => {
51 predicate(R::$kind)
52 .then(|| cfg.$field.clone())
53 .unwrap_or_default()
54 };
55 }
56
57 Config {
58 functions: maybe!(Functions, functions),
59 procedures: maybe!(Procedures, procedures),
60 aggregates: maybe!(Aggregates, aggregates),
61 operators: maybe!(Operators, operators),
62 triggers: maybe!(Triggers, triggers),
63 rules: maybe!(Rules, rules),
64 event_triggers: maybe!(EventTriggers, event_triggers),
65 extensions: maybe!(Extensions, extensions),
66 collations: maybe!(Collations, collations),
67 sequences: maybe!(Sequences, sequences),
68 schemas: maybe!(Schemas, schemas),
69 enums: maybe!(Enums, enums),
70 domains: maybe!(Domains, domains),
71 types: maybe!(Types, types),
72 tables: maybe!(Tables, tables),
73 indexes: maybe!(Indexes, indexes),
74 statistics: maybe!(Statistics, statistics),
75 views: maybe!(Views, views),
76 materialized: maybe!(Materialized, materialized),
77 policies: maybe!(Policies, policies),
78 roles: maybe!(Roles, roles),
79 tablespaces: maybe!(Tablespaces, tablespaces),
80 grants: maybe!(Grants, grants),
81 foreign_data_wrappers: maybe!(ForeignDataWrappers, foreign_data_wrappers),
82 foreign_servers: maybe!(ForeignServers, foreign_servers),
83 foreign_tables: maybe!(ForeignTables, foreign_tables),
84 text_search_dictionaries: maybe!(TextSearchDictionaries, text_search_dictionaries),
85 text_search_configurations: maybe!(TextSearchConfigurations, text_search_configurations),
86 text_search_templates: maybe!(TextSearchTemplates, text_search_templates),
87 text_search_parsers: maybe!(TextSearchParsers, text_search_parsers),
88 publications: maybe!(Publications, publications),
89 subscriptions: maybe!(Subscriptions, subscriptions),
90 tests: maybe!(Tests, tests),
91 outputs: cfg.outputs.clone(),
92 ..Default::default()
93 }
94}
95
96pub fn apply_filters(
98 cfg: &Config,
99 include: &std::collections::HashSet<crate::config::ResourceKind>,
100 exclude: &std::collections::HashSet<crate::config::ResourceKind>,
101) -> Config {
102 filter_config_with(cfg, |kind| {
103 include.contains(&kind) && !exclude.contains(&kind)
104 })
105}
106
107pub fn apply_resource_filters(cfg: &Config, include: &[String], exclude: &[String]) -> Config {
109 use std::collections::HashSet;
110
111 let include_all = include.is_empty() && exclude.is_empty();
113
114 let include_set: HashSet<String> = include.iter().cloned().collect();
116 let exclude_set: HashSet<String> = exclude.iter().cloned().collect();
117
118 filter_config_with(cfg, |kind| {
119 let resource_type = kind.to_string();
120 if include_all {
121 true
122 } else if !include_set.is_empty() {
123 include_set.contains(resource_type.as_str())
124 } else {
125 !exclude_set.contains(resource_type.as_str())
126 }
127 })
128}
129
130pub fn generate_with_backend(backend: &str, cfg: &Config, strict: bool) -> Result<String> {
131 let be = backends::get_backend(backend)
132 .ok_or_else(|| anyhow::anyhow!(format!("unknown backend '{backend}'")))?;
133 be.generate(cfg, strict)
134}
135
#[cfg(test)]
mod tests {
    //! End-to-end tests: HCL (and Prisma) sources are served from an in-memory
    //! `MapLoader`, loaded with `load_config`, then checked directly on the
    //! resulting `Config` or through a generation backend.

    use super::*;
    use crate::frontend::env::EnvVars;
    use std::collections::{HashMap, HashSet};
    use std::path::PathBuf;

    /// In-memory `Loader` backed by a path -> contents map.
    struct MapLoader {
        files: HashMap<PathBuf, String>,
    }

    impl Loader for MapLoader {
        fn load(&self, path: &Path) -> Result<String> {
            self.files
                .get(path)
                .cloned()
                .ok_or_else(|| anyhow::anyhow!("missing file: {}", path.display()))
        }
    }

    /// Shorthand for building a `PathBuf` from a string literal.
    fn p(s: &str) -> PathBuf {
        PathBuf::from(s)
    }

    /// Minimal `FunctionSpec` fixture shared by the filtering tests.
    fn sample_function() -> FunctionSpec {
        FunctionSpec {
            name: "f".into(),
            alt_name: None,
            schema: None,
            language: "sql".into(),
            parameters: vec![],
            returns: "void".into(),
            replace: false,
            volatility: None,
            strict: false,
            security: None,
            cost: None,
            body: String::new(),
            comment: None,
        }
    }

    /// Minimal `TableSpec` fixture shared by the filtering tests.
    fn sample_table() -> TableSpec {
        TableSpec {
            name: "t".into(),
            alt_name: None,
            schema: None,
            if_not_exists: false,
            columns: vec![],
            primary_key: None,
            indexes: vec![],
            checks: vec![],
            foreign_keys: vec![],
            partition_by: None,
            partitions: vec![],
            back_references: vec![],
            lint_ignore: vec![],
            comment: None,
            map: None,
        }
    }

    /// A `Config` holding exactly one function and one table.
    fn function_and_table_config() -> Config {
        Config {
            functions: vec![sample_function()],
            tables: vec![sample_table()],
            ..Default::default()
        }
    }

    #[test]
    fn parse_simple_function_and_trigger() {
        let mut files = HashMap::new();
        files.insert(
            p("/root/main.hcl"),
            r#"
            function "set_updated_at" {
              schema = "public"
              language = "plpgsql"
              returns = "trigger"
              body = <<-SQL
                BEGIN
                  NEW.updated_at = now();
                  RETURN NEW;
                END;
              SQL
            }
            trigger "users_upd" {
              schema = "public"
              table = "users"
              function = "set_updated_at"
              events = ["UPDATE"]
            }
            "#
            .to_string(),
        );

        let loader = MapLoader { files };
        let cfg = load_config(&p("/root/main.hcl"), &loader, EnvVars::default()).unwrap();
        assert_eq!(cfg.functions.len(), 1);
        assert_eq!(cfg.triggers.len(), 1);
        validate(&cfg, false).unwrap();
        let sql = generate_with_backend("postgres", &cfg, false).unwrap();
        assert!(sql.contains("CREATE OR REPLACE FUNCTION \"public\".\"set_updated_at\""));
        assert!(sql.contains("CREATE TRIGGER \"users_upd\""));
    }

    #[test]
    fn parse_simple_event_trigger() {
        let mut files = HashMap::new();
        files.insert(
            p("/root/main.hcl"),
            r#"
            function "ddl_logger" {
              language = "plpgsql"
              returns = "event_trigger"
              body = "BEGIN END;"
            }
            event_trigger "log_ddl" {
              event = "ddl_command_start"
              function = "ddl_logger"
            }
            "#
            .to_string(),
        );

        let loader = MapLoader { files };
        let cfg = load_config(&p("/root/main.hcl"), &loader, EnvVars::default()).unwrap();
        assert_eq!(cfg.event_triggers.len(), 1);
        validate(&cfg, false).unwrap();
        let sql = generate_with_backend("postgres", &cfg, false).unwrap();
        assert!(sql.contains("CREATE EVENT TRIGGER"));
    }

    #[test]
    fn parse_with_module_and_vars() {
        let mut files = HashMap::new();
        files.insert(
            p("/root/main.hcl"),
            r#"
            variable "schema" { default = "public" }
            variable "tables" { default = ["users", "orders"] }
            module "mod1" {
              source = "/root/mod"
              schema = var.schema
              for_each = var.tables
              table = each.value
            }
            "#
            .to_string(),
        );
        files.insert(
            p("/root/mod/main.hcl"),
            r#"
            variable "schema" { default = "public" }
            variable "table" {}
            function "f" {
              schema = var.schema
              language = "plpgsql"
              returns = "trigger"
              body = "BEGIN NEW.updated_at = now(); RETURN NEW; END;"
            }
            trigger "t" {
              schema = var.schema
              table = var.table
              function = "f"
              events = ["INSERT"]
            }
            "#
            .to_string(),
        );

        let loader = MapLoader { files };
        let cfg = load_config(&p("/root/main.hcl"), &loader, EnvVars::default()).unwrap();
        assert!(!cfg.functions.is_empty());
        assert_eq!(cfg.triggers.len(), 2);
        validate(&cfg, false).unwrap();
    }

    #[test]
    fn module_outputs_resolved_and_printed() {
        let mut files = HashMap::new();
        files.insert(
            p("/root/main.hcl"),
            r#"
            module "child" { source = "/root/mod" }
            output "answer" { value = module.child.value }
            "#
            .to_string(),
        );
        files.insert(
            p("/root/mod/main.hcl"),
            r#"
            output "value" { value = 42 }
            "#
            .to_string(),
        );

        let loader = MapLoader { files };
        let cfg = load_config(&p("/root/main.hcl"), &loader, EnvVars::default()).unwrap();
        assert_eq!(cfg.outputs.len(), 1);
        assert_eq!(cfg.outputs[0].name, "answer");
        assert_eq!(
            cfg.outputs[0].value,
            hcl::Value::Number(hcl::Number::from(42))
        );
    }

    #[test]
    fn module_for_each_can_use_data_sources() {
        let mut files = HashMap::new();
        files.insert(
            p("/root/main.hcl"),
            r#"
            data "prisma_schema" "app" {
              file = "/root/schema.prisma"
            }

            module "mirror" {
              source = "/root/mod"
              for_each = data.prisma_schema.app.models
              name = each.key
            }
            "#
            .to_string(),
        );
        files.insert(
            p("/root/mod/main.hcl"),
            r#"
            variable "name" {}

            table "mirror" {
              comment = var.name

              column "id" {
                type = "text"
              }
            }
            "#
            .to_string(),
        );
        files.insert(
            p("/root/schema.prisma"),
            r#"
            model User {
              id Int @id
            }

            model Post {
              id Int @id
            }
            "#
            .to_string(),
        );

        let loader = MapLoader { files };
        let cfg = load_config(&p("/root/main.hcl"), &loader, EnvVars::default()).unwrap();

        assert_eq!(cfg.tables.len(), 2);
        let mut comments: Vec<Option<String>> =
            cfg.tables.iter().map(|t| t.comment.clone()).collect();
        comments.sort();
        assert_eq!(
            comments,
            vec![Some("Post".to_string()), Some("User".to_string())]
        );
    }

    #[test]
    fn variable_type_and_validation() {
        use hcl::Value;

        let mut files = HashMap::new();
        files.insert(
            p("/root/main.hcl"),
            r#"
            variable "count" {
              type = "number"
              validation {
                condition = var.count > 0
                error_message = "count must be > 0"
              }
            }
            "#
            .to_string(),
        );
        let loader = MapLoader { files };

        // Wrong type is rejected.
        let env = EnvVars {
            vars: HashMap::from([("count".into(), Value::String("x".into()))]),
            ..EnvVars::default()
        };
        let err = load_config(&p("/root/main.hcl"), &loader, env).unwrap_err();
        assert!(err.to_string().contains("expected type number"));

        // Right type, failing validation condition.
        let env = EnvVars {
            vars: HashMap::from([("count".into(), Value::from(0))]),
            ..EnvVars::default()
        };
        let err = load_config(&p("/root/main.hcl"), &loader, env).unwrap_err();
        assert!(err.to_string().contains("count must be > 0"));

        // Valid value.
        let env = EnvVars {
            vars: HashMap::from([("count".into(), Value::from(2))]),
            ..EnvVars::default()
        };
        load_config(&p("/root/main.hcl"), &loader, env).unwrap();
    }

    #[test]
    fn for_each_array_in_trigger_uses_each_value() {
        let mut files = HashMap::new();
        files.insert(
            p("/root/main.hcl"),
            r#"
            variable "schema" { default = "public" }
            variable "tables" { default = ["users", "orders"] }

            function "f" {
              schema = var.schema
              language = "plpgsql"
              returns = "trigger"
              body = "BEGIN RETURN NEW; END;"
            }

            trigger "upd" {
              schema = var.schema
              for_each = var.tables
              table = each.value
              function = "f"
              events = ["UPDATE"]
            }
            "#
            .to_string(),
        );
        let loader = MapLoader { files };
        let cfg = load_config(&p("/root/main.hcl"), &loader, EnvVars::default()).unwrap();
        assert_eq!(cfg.triggers.len(), 2);
        validate(&cfg, false).unwrap();
        let sql = generate_with_backend("postgres", &cfg, false).unwrap();
        assert!(sql.contains("\"users\""));
        assert!(sql.contains("\"orders\""));
    }

    #[test]
    fn count_creates_multiple_triggers() {
        let mut files = HashMap::new();
        files.insert(
            p("/root/main.hcl"),
            r#"
            function "f" {
              schema = "public"
              language = "plpgsql"
              returns = "trigger"
              body = "BEGIN RETURN NEW; END;"
            }

            trigger "t" {
              count = 2
              schema = "public"
              table = "users"
              function = "f"
              events = ["INSERT"]
              name = "t_${count.index}"
            }
            "#
            .to_string(),
        );
        let loader = MapLoader { files };
        let cfg = load_config(&p("/root/main.hcl"), &loader, EnvVars::default()).unwrap();
        assert_eq!(cfg.triggers.len(), 2);
    }

    #[test]
    fn dynamic_block_expands_columns() {
        let mut files = HashMap::new();
        files.insert(
            p("/root/main.hcl"),
            r#"
            variable "cols" {
              default = {
                id = { type = "serial", nullable = false },
                name = { type = "text", nullable = true }
              }
            }

            table "users" {
              dynamic "column" {
                for_each = var.cols
                labels = [each.key]
                content {
                  type = each.value.type
                  nullable = each.value.nullable
                }
              }
            }
            "#
            .to_string(),
        );
        let loader = MapLoader { files };
        let cfg = load_config(&p("/root/main.hcl"), &loader, EnvVars::default()).unwrap();
        assert_eq!(cfg.tables.len(), 1);
        let cols = &cfg.tables[0].columns;
        assert_eq!(cols.len(), 2);
        assert_eq!(cols[0].name, "id");
        assert!(!cols[0].nullable);
        assert_eq!(cols[1].name, "name");
        assert!(cols[1].nullable);
    }

    #[test]
    fn data_prisma_schema_exposes_models_and_enums() {
        let mut files = HashMap::new();
        files.insert(
            p("/root/main.hcl"),
            r#"
            data "prisma_schema" "app" {
              file = "/root/schema.prisma"
            }

            table "audit_log" {
              schema = "public"

              column "user_id" {
                type = "bigint"
                nullable = data.prisma_schema.app.models.User.fields.id.type.optional
                comment = data.prisma_schema.app.models.User.fields.id.type.name
              }

              column "middle_name" {
                type = "text"
                nullable = data.prisma_schema.app.models.User.fields.middleName.type.optional
                comment = data.prisma_schema.app.models.User.fields.middleName.type.name
              }

              column "status" {
                type = "text"
                comment = data.prisma_schema.app.enums.Status.values.ACTIVE.name
              }

              column "inactive_label" {
                type = "text"
                comment = data.prisma_schema.app.enums.Status.values.INACTIVE.mapped_name
              }
            }
            "#
            .to_string(),
        );
        files.insert(
            p("/root/schema.prisma"),
            r#"
            model User {
              id Int @id @default(autoincrement())
              email String @unique
              middleName String?
              status Status @default(ACTIVE)
            }

            enum Status {
              ACTIVE
              INACTIVE @map("inactive")
            }
            "#
            .to_string(),
        );

        let loader = MapLoader { files };
        let cfg = load_config(&p("/root/main.hcl"), &loader, EnvVars::default()).unwrap();
        assert_eq!(cfg.tables.len(), 1);
        let table = &cfg.tables[0];
        assert_eq!(table.columns.len(), 4);
        assert!(!table.columns[0].nullable);
        assert_eq!(table.columns[0].comment.as_deref(), Some("Int"));
        assert!(table.columns[1].nullable);
        assert_eq!(table.columns[1].comment.as_deref(), Some("String"));
        assert_eq!(table.columns[2].comment.as_deref(), Some("ACTIVE"));
        assert_eq!(table.columns[3].comment.as_deref(), Some("inactive"));
    }

    #[test]
    fn clone_prisma_table_with_dynamic_columns() {
        let mut files = HashMap::new();
        files.insert(
            p("/root/main.hcl"),
            r#"
            data "prisma_schema" "source" {
              file = "/root/schema.prisma"
            }

            table "user_clone" {
              schema = "public"

              dynamic "column" {
                for_each = data.prisma_schema.source.models.User.fields
                labels = [each.key]

                content {
                  type = each.value.type.name == "Int" ? "integer" : each.value.type.name == "String" ? "text" : each.value.type.name == "DateTime" ? "timestamptz" : "text"
                  nullable = each.value.type.optional
                }
              }

              primary_key {
                columns = ["id"]
              }
            }
            "#
            .to_string(),
        );
        files.insert(
            p("/root/schema.prisma"),
            r#"
            model User {
              id Int @id @default(autoincrement())
              email String @unique
              name String?
              createdAt DateTime @default(now())
              updatedAt DateTime @updatedAt
            }
            "#
            .to_string(),
        );

        let loader = MapLoader { files };
        let cfg = load_config(&p("/root/main.hcl"), &loader, EnvVars::default()).unwrap();

        assert_eq!(cfg.tables.len(), 1);
        let table = &cfg.tables[0];
        assert_eq!(table.name, "user_clone");

        assert_eq!(table.columns.len(), 5);

        let col_names: Vec<&str> = table.columns.iter().map(|c| c.name.as_str()).collect();
        assert!(col_names.contains(&"id"));
        assert!(col_names.contains(&"email"));
        assert!(col_names.contains(&"name"));
        assert!(col_names.contains(&"createdAt"));
        assert!(col_names.contains(&"updatedAt"));

        let id_col = table.columns.iter().find(|c| c.name == "id").unwrap();
        assert_eq!(id_col.r#type, "integer");
        assert!(!id_col.nullable);

        let email_col = table.columns.iter().find(|c| c.name == "email").unwrap();
        assert_eq!(email_col.r#type, "text");
        assert!(!email_col.nullable);

        let name_col = table.columns.iter().find(|c| c.name == "name").unwrap();
        assert_eq!(name_col.r#type, "text");
        assert!(name_col.nullable);

        let created_at_col = table
            .columns
            .iter()
            .find(|c| c.name == "createdAt")
            .unwrap();
        assert_eq!(created_at_col.r#type, "timestamptz");
        assert!(!created_at_col.nullable);

        let updated_at_col = table
            .columns
            .iter()
            .find(|c| c.name == "updatedAt")
            .unwrap();
        assert_eq!(updated_at_col.r#type, "timestamptz");
        assert!(!updated_at_col.nullable);

        let pk = table
            .primary_key
            .as_ref()
            .expect("primary_key block should be parsed");
        assert_eq!(pk.columns.len(), 1);
        assert_eq!(pk.columns[0], "id");

        validate(&cfg, false).unwrap();
    }

    #[test]
    fn parse_extension_and_generate_sql() {
        let mut files = HashMap::new();
        files.insert(
            p("/root/main.hcl"),
            r#"
            extension "pgcrypto" {}
            "#
            .to_string(),
        );
        let loader = MapLoader { files };
        let cfg = load_config(&p("/root/main.hcl"), &loader, EnvVars::default()).unwrap();
        assert_eq!(cfg.extensions.len(), 1);
        let sql = generate_with_backend("postgres", &cfg, false).unwrap();
        assert!(sql.contains("CREATE EXTENSION IF NOT EXISTS \"pgcrypto\";"));
    }

    #[test]
    fn generate_json_backend() {
        let mut files = HashMap::new();
        files.insert(
            p("/root/main.hcl"),
            r#"
            function "f" {
              schema = "public"
              language = "plpgsql"
              returns = "trigger"
              body = "BEGIN RETURN NEW; END;"
            }
            trigger "t" {
              schema = "public"
              table = "users"
              timing = "BEFORE"
              events = ["UPDATE"]
              level = "ROW"
              function = "f"
            }
            extension "pgcrypto" {}
            "#
            .to_string(),
        );
        let loader = MapLoader { files };
        let cfg = load_config(&p("/root/main.hcl"), &loader, EnvVars::default()).unwrap();
        let json = generate_with_backend("json", &cfg, false).unwrap();
        assert!(json.contains("\"backend\": \"json\""));
        assert!(json.contains("\"functions\""));
        assert!(json.contains("\"triggers\""));
        assert!(json.contains("\"extensions\""));
    }

    #[test]
    fn parse_view_and_generate_sql_and_json() {
        let mut files = HashMap::new();
        files.insert(
            p("/root/main.hcl"),
            r#"
            view "v_users" {
              schema = "public"
              replace = true
              sql = "SELECT 1 as x"
            }
            "#
            .to_string(),
        );
        let loader = MapLoader { files };
        let cfg = load_config(&p("/root/main.hcl"), &loader, EnvVars::default()).unwrap();
        assert_eq!(cfg.views.len(), 1);
        let sql = generate_with_backend("postgres", &cfg, false).unwrap();
        assert!(sql.contains("CREATE OR REPLACE VIEW \"public\".\"v_users\" AS"));
        let json = generate_with_backend("json", &cfg, false).unwrap();
        assert!(json.contains("\"views\""));
    }

    #[test]
    fn parse_materialized_and_generate_sql_and_json() {
        let mut files = HashMap::new();
        files.insert(
            p("/root/main.hcl"),
            r#"
            materialized "mv" {
              schema = "public"
              with_data = false
              sql = "SELECT 42 as x"
            }
            "#
            .to_string(),
        );
        let loader = MapLoader { files };
        let cfg = load_config(&p("/root/main.hcl"), &loader, EnvVars::default()).unwrap();
        assert_eq!(cfg.materialized.len(), 1);
        let sql = generate_with_backend("postgres", &cfg, false).unwrap();
        assert!(sql.contains("CREATE MATERIALIZED VIEW \"public\".\"mv\" AS"));
        assert!(sql.contains("WITH NO DATA"));
        let json = generate_with_backend("json", &cfg, false).unwrap();
        assert!(json.contains("\"materialized\""));
    }

    #[test]
    fn parse_enum_and_generate_sql_json_prisma() {
        let mut files = HashMap::new();
        files.insert(
            p("/root/main.hcl"),
            r#"
            enum "status" { values = ["active", "disabled"] }
            table "users" {
              column "id" {
                type = "serial"
                nullable = false
              }
              column "status" {
                type = "status"
                nullable = false
              }
              primary_key { columns = ["id"] }
            }
            "#
            .to_string(),
        );
        let loader = MapLoader { files };
        let cfg = load_config(&p("/root/main.hcl"), &loader, EnvVars::default()).unwrap();
        assert_eq!(cfg.enums.len(), 1);
        let sql = generate_with_backend("postgres", &cfg, false).unwrap();
        assert!(sql.contains("CREATE TYPE \"public\".\"status\" AS ENUM"));
        let json = generate_with_backend("json", &cfg, false).unwrap();
        assert!(json.contains("\"enums\""));
        let prisma = generate_with_backend("prisma", &cfg, false).unwrap();
        assert!(prisma.contains("enum status"));
        assert!(prisma.contains("status status"));
    }

    #[test]
    fn prisma_back_reference_relations_have_names() {
        let mut files = HashMap::new();
        files.insert(
            p("/root/main.hcl"),
            r#"
            table "blob" {
              schema = "public"
              column "blobId" {
                type = "text"
                nullable = false
              }
              primary_key { columns = ["blobId"] }
            }

            table "commit" {
              schema = "public"
              column "commitId" {
                type = "text"
                nullable = false
              }
              column "blobId" {
                type = "text"
                nullable = false
              }
              primary_key { columns = ["commitId"] }
              foreign_key {
                name = "blob_fk"
                columns = ["blobId"]
                ref {
                  schema = "public"
                  table = "blob"
                  columns = ["blobId"]
                }
                back_reference_name = "commits"
                on_delete = "RESTRICT"
              }
            }
            "#
            .to_string(),
        );
        let loader = MapLoader { files };
        let cfg = load_config(&p("/root/main.hcl"), &loader, EnvVars::default()).unwrap();
        let prisma = generate_with_backend("prisma", &cfg, false).unwrap();
        assert!(
            prisma.contains("commits Commit[] @relation(name: \"commits\")"),
            "expected commits relation field to include relation name:\n{prisma}"
        );
        assert!(
            prisma.contains(
                "blob_fk Blob @relation(name: \"commits\", fields: [blobId], references: [blobId]"
            ),
            "expected blob relation field to reference relation name:\n{prisma}"
        );
    }

    #[test]
    fn parse_table_and_generate_sql() {
        let mut files = HashMap::new();
        files.insert(
            p("/root/main.hcl"),
            r#"
            table "users" {
              schema = "public"
              column "id" {
                type = "serial"
                nullable = false
              }
              column "email" {
                type = "text"
                nullable = false
              }
              primary_key { columns = ["id"] }
              unique "users_email_key" { columns = ["email"] }
            }
            "#
            .to_string(),
        );
        let loader = MapLoader { files };
        let cfg = load_config(&p("/root/main.hcl"), &loader, EnvVars::default()).unwrap();
        assert_eq!(cfg.tables.len(), 1);
        let sql = generate_with_backend("postgres", &cfg, false).unwrap();
        assert!(sql.contains("CREATE TABLE IF NOT EXISTS \"public\".\"users\""));
        assert!(sql.contains(
            "CREATE UNIQUE INDEX IF NOT EXISTS \"users_email_key\" ON \"public\".\"users\" (\"email\");"
        ));
    }

    #[test]
    fn parse_policy_and_generate_sql_and_json() {
        let mut files = HashMap::new();
        files.insert(
            p("/root/main.hcl"),
            r#"
            table "users" {
              schema = "public"
              column "id" {
                type = "serial"
                nullable = false
              }
              primary_key { columns = ["id"] }
            }
            policy "p_users_select" {
              schema = "public"
              table = "users"
              as = "permissive"
              command = "select"
              roles = ["app_user"]
              using = "true"
            }
            "#
            .to_string(),
        );
        let loader = MapLoader { files };
        let cfg = load_config(&p("/root/main.hcl"), &loader, EnvVars::default()).unwrap();
        assert_eq!(cfg.policies.len(), 1);
        let sql = generate_with_backend("postgres", &cfg, false).unwrap();
        assert!(sql.contains("CREATE POLICY \"p_users_select\" ON \"public\".\"users\""));
        let json = generate_with_backend("json", &cfg, false).unwrap();
        assert!(json.contains("\"policies\""));
        assert!(json.contains("\"p_users_select\""));
    }

    #[test]
    fn apply_filters_excludes_resources() {
        use crate::config::ResourceKind as R;

        let cfg = function_and_table_config();

        let include = HashSet::from([R::Functions, R::Tables]);
        let exclude = HashSet::from([R::Functions]);

        let filtered = apply_filters(&cfg, &include, &exclude);
        assert_eq!(filtered.functions.len(), 0);
        assert_eq!(filtered.tables.len(), 1);
    }

    #[test]
    fn apply_resource_filters_handles_strings() {
        let cfg = function_and_table_config();

        let filtered = apply_resource_filters(&cfg, &["tables".to_string()], &[]);
        assert_eq!(filtered.functions.len(), 0);
        assert_eq!(filtered.tables.len(), 1);
    }

    #[test]
    fn apply_filters_preserves_extended_resources() {
        use crate::config::ResourceKind as R;
        use crate::ir::{
            ColumnSpec, ForeignDataWrapperSpec, ForeignServerSpec, ForeignTableSpec,
            PublicationSpec, PublicationTableSpec, StandaloneIndexSpec, StatisticsSpec,
            SubscriptionSpec, TextSearchConfigurationMappingSpec, TextSearchConfigurationSpec,
            TextSearchDictionarySpec, TextSearchParserSpec, TextSearchTemplateSpec,
        };

        let cfg = Config {
            indexes: vec![StandaloneIndexSpec {
                name: "idx".into(),
                table: "t".into(),
                schema: Some("public".into()),
                columns: vec!["col".into()],
                expressions: vec![],
                r#where: None,
                orders: vec![],
                operator_classes: vec![],
                unique: false,
            }],
            statistics: vec![StatisticsSpec {
                name: "stats".into(),
                alt_name: None,
                schema: Some("public".into()),
                table: "t".into(),
                columns: vec!["col".into()],
                kinds: vec![],
                comment: None,
            }],
            foreign_data_wrappers: vec![ForeignDataWrapperSpec {
                name: "fdw".into(),
                alt_name: None,
                handler: None,
                validator: None,
                options: vec![],
                comment: None,
            }],
            foreign_servers: vec![ForeignServerSpec {
                name: "server".into(),
                alt_name: None,
                wrapper: "fdw".into(),
                r#type: None,
                version: None,
                options: vec![],
                comment: None,
            }],
            foreign_tables: vec![ForeignTableSpec {
                name: "foreign_table".into(),
                alt_name: None,
                schema: Some("public".into()),
                server: "server".into(),
                columns: vec![ColumnSpec {
                    name: "col".into(),
                    r#type: "text".into(),
                    nullable: true,
                    default: None,
                    db_type: None,
                    lint_ignore: vec![],
                    comment: None,
                    count: 0,
                }],
                options: vec![],
                comment: None,
            }],
            text_search_dictionaries: vec![TextSearchDictionarySpec {
                name: "dict".into(),
                alt_name: None,
                schema: Some("public".into()),
                template: "simple".into(),
                options: vec![],
                comment: None,
            }],
            text_search_configurations: vec![TextSearchConfigurationSpec {
                name: "cfg".into(),
                alt_name: None,
                schema: Some("public".into()),
                parser: "default".into(),
                mappings: vec![TextSearchConfigurationMappingSpec {
                    tokens: vec!["asciiword".into()],
                    dictionaries: vec!["dict".into()],
                }],
                comment: None,
            }],
            text_search_templates: vec![TextSearchTemplateSpec {
                name: "tmpl".into(),
                alt_name: None,
                schema: Some("public".into()),
                init: None,
                lexize: "lexize".into(),
                comment: None,
            }],
            text_search_parsers: vec![TextSearchParserSpec {
                name: "parser".into(),
                alt_name: None,
                schema: Some("public".into()),
                start: "start".into(),
                gettoken: "get".into(),
                end: "end".into(),
                headline: None,
                lextypes: "lex".into(),
                comment: None,
            }],
            publications: vec![PublicationSpec {
                name: "pub".into(),
                alt_name: None,
                all_tables: false,
                tables: vec![PublicationTableSpec {
                    schema: Some("public".into()),
                    table: "t".into(),
                }],
                publish: vec!["insert".into()],
                comment: None,
            }],
            subscriptions: vec![SubscriptionSpec {
                name: "sub".into(),
                alt_name: None,
                connection: "dbname=app".into(),
                publications: vec!["pub".into()],
                comment: None,
            }],
            ..Default::default()
        };

        // The default include set keeps every extended resource kind.
        let include_all = R::default_include_set();
        let filtered = apply_filters(&cfg, &include_all, &HashSet::new());
        assert_eq!(filtered.indexes.len(), 1);
        assert_eq!(filtered.statistics.len(), 1);
        assert_eq!(filtered.foreign_data_wrappers.len(), 1);
        assert_eq!(filtered.foreign_servers.len(), 1);
        assert_eq!(filtered.foreign_tables.len(), 1);
        assert_eq!(filtered.text_search_dictionaries.len(), 1);
        assert_eq!(filtered.text_search_configurations.len(), 1);
        assert_eq!(filtered.text_search_templates.len(), 1);
        assert_eq!(filtered.text_search_parsers.len(), 1);
        assert_eq!(filtered.publications.len(), 1);
        assert_eq!(filtered.subscriptions.len(), 1);

        // Including only indexes drops everything else.
        let only_indexes = HashSet::from([R::Indexes]);
        let filtered_only_indexes = apply_filters(&cfg, &only_indexes, &HashSet::new());
        assert_eq!(filtered_only_indexes.indexes.len(), 1);
        assert_eq!(filtered_only_indexes.statistics.len(), 0);
        assert_eq!(filtered_only_indexes.foreign_tables.len(), 0);

        // Excluding indexes drops only indexes.
        let exclude_indexes = HashSet::from([R::Indexes]);
        let filtered_without_indexes = apply_filters(&cfg, &include_all, &exclude_indexes);
        assert_eq!(filtered_without_indexes.indexes.len(), 0);
        assert_eq!(filtered_without_indexes.statistics.len(), 1);
    }

    #[test]
    fn parse_role_and_grant() {
        let mut files = HashMap::new();
        files.insert(
            p("/root/main.hcl"),
            r#"
            role "app" {
              login = true
              createdb = true
            }
            grant "g" {
              role = "app"
              privileges = ["ALL"]
              database = "appdb"
            }
            "#
            .to_string(),
        );
        let loader = MapLoader { files };
        let cfg = load_config(&p("/root/main.hcl"), &loader, EnvVars::default()).unwrap();
        assert_eq!(cfg.roles.len(), 1);
        assert_eq!(cfg.grants.len(), 1);
        let sql = generate_with_backend("postgres", &cfg, false).unwrap();
        assert!(sql.contains("CREATE ROLE \"app\" LOGIN CREATEDB;"));
        assert!(sql.contains("GRANT ALL PRIVILEGES ON DATABASE \"appdb\" TO \"app\";"));
    }
}