db_schema/
pg.rs

1#[cfg(feature = "db-postgres")]
2use paste::paste;
3#[cfg(feature = "db-postgres")]
4use sqlx::PgPool;
5
/// A PostgreSQL schema (namespace) whose objects can be dumped as DDL.
///
/// Every SQL-building method returns a catalog query (not DDL itself); each row
/// of that query has a single `sql` column holding one DDL statement.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct PgSchema {
    /// Schema name, compared against `pg_namespace.nspname`.
    namespace: String,
}

impl PgSchema {
    /// Create a new `PgSchema` instance for the schema named `namespace`.
    pub fn new(namespace: impl Into<String>) -> Self {
        Self {
            namespace: namespace.into(),
        }
    }

    /// Renders the schema name as a PostgreSQL string literal, including quotes.
    ///
    /// The name is interpolated into the generated queries, so embedded single
    /// quotes are doubled. If the name contains backslashes, the `E'...'` form is
    /// used with escaped backslashes, so the literal means the same thing whatever
    /// the server's `standard_conforming_strings` setting is.
    fn namespace_literal(&self) -> String {
        let escaped = self.namespace.replace('\'', "''");
        if escaped.contains('\\') {
            format!("E'{}'", escaped.replace('\\', "\\\\"))
        } else {
            format!("'{}'", escaped)
        }
    }

    /// Generates a catalog query that yields one `CREATE TYPE ... AS ENUM`
    /// statement per enum type in the schema.
    ///
    /// Labels are emitted in declaration order (`pg_enum.enumsortorder`),
    /// because an enum's label order defines how its values compare.
    pub fn enums(&self) -> String {
        format!(
            "SELECT
      'CREATE TYPE ' || quote_ident(n.nspname) || '.' || quote_ident(t.typname) || ' AS ENUM (' || string_agg(quote_literal(e.enumlabel), ', ' ORDER BY e.enumsortorder) || ');' AS sql
    FROM
      pg_catalog.pg_type t
      JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid
      JOIN pg_catalog.pg_enum e ON t.oid = e.enumtypid
    WHERE
      n.nspname = {}
      AND t.typtype = 'e'
    GROUP BY
      n.nspname, t.typname;",
            self.namespace_literal()
        )
    }

    /// Generates a catalog query that yields one `CREATE TYPE ... AS (...)`
    /// statement per standalone composite type in the schema.
    ///
    /// Only live attributes (`attnum > 0`, not dropped) are included, listed in
    /// their declared (`attnum`) order.
    pub fn types(&self) -> String {
        format!(
            "SELECT
      'CREATE TYPE ' || quote_ident(n.nspname) || '.' || quote_ident(t.typname) || ' AS (' || string_agg(quote_ident(a.attname) || ' ' || pg_catalog.format_type(a.atttypid, a.atttypmod), ', ' ORDER BY a.attnum) || ');' AS sql
    FROM
      pg_catalog.pg_type t
      JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid
      JOIN pg_catalog.pg_class c ON t.typrelid = c.oid
      JOIN pg_catalog.pg_attribute a ON t.typrelid = a.attrelid
    WHERE
      n.nspname = {}
      AND t.typtype = 'c'
      AND c.relkind = 'c'
      AND a.attnum > 0
      AND NOT a.attisdropped
    GROUP BY
      n.nspname, t.typname;",
            self.namespace_literal()
        )
    }

    /// Generates a catalog query that yields one `CREATE TABLE` statement per
    /// ordinary table (`relkind = 'r'`) in the schema.
    ///
    /// Columns appear in `attnum` order with their type and `NOT NULL` flag.
    /// Primary-key, foreign-key and unique constraints are appended to each
    /// column they cover as `<CONSTRAINT TYPE> (<column>)`.
    ///
    /// NOTE(review): that constraint suffix is descriptive only. PostgreSQL does
    /// not accept `PRIMARY KEY (col)` as a column constraint, and `FOREIGN KEY`
    /// needs a `REFERENCES` clause, so statements for constrained tables are not
    /// directly executable. Column defaults are not emitted either.
    pub fn tables(&self) -> String {
        // Constraint names are only guaranteed unique per table (foreign keys on
        // different tables may share a name), so key_column_usage is matched on
        // the table as well as on the constraint name.
        format!(
            "WITH table_columns AS (
          SELECT
            n.nspname AS schema_name,
            c.relname AS table_name,
            a.attname AS column_name,
            pg_catalog.format_type(a.atttypid, a.atttypmod) AS column_type,
            a.attnotnull AS is_not_null,
            a.attnum AS column_position
          FROM
            pg_catalog.pg_attribute a
            JOIN pg_catalog.pg_class c ON a.attrelid = c.oid
            JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid
          WHERE
            a.attnum > 0
            AND NOT a.attisdropped
            AND n.nspname = {0}
            AND c.relkind = 'r'
        ),
        table_constraints AS (
          SELECT
            tc.constraint_name,
            tc.table_schema,
            tc.table_name,
            kcu.column_name,
            tc.constraint_type
          FROM
            information_schema.table_constraints tc
            JOIN information_schema.key_column_usage kcu
              ON tc.constraint_catalog = kcu.constraint_catalog
              AND tc.constraint_schema = kcu.constraint_schema
              AND tc.constraint_name = kcu.constraint_name
              AND tc.table_schema = kcu.table_schema
              AND tc.table_name = kcu.table_name
          WHERE
            tc.constraint_type IN ('PRIMARY KEY', 'FOREIGN KEY', 'UNIQUE')
            AND tc.table_schema = {0}
        ),
        formatted_columns AS (
          SELECT
            tc.schema_name,
            tc.table_name,
            tc.column_name,
            tc.column_type,
            tc.is_not_null,
            tc.column_position,
            STRING_AGG(
              tcs.constraint_type || ' (' || quote_ident(tc.column_name) || ')',
              ', '
              ORDER BY tcs.constraint_type DESC
            ) AS column_constraints
          FROM
            table_columns tc
            LEFT JOIN table_constraints tcs
              ON tc.schema_name = tcs.table_schema
              AND tc.table_name = tcs.table_name
              AND tc.column_name = tcs.column_name
          GROUP BY
            tc.schema_name,
            tc.table_name,
            tc.column_name,
            tc.column_type,
            tc.is_not_null,
            tc.column_position
        ),
        create_table_statements AS (
          SELECT
            fc.schema_name,
            fc.table_name,
            STRING_AGG(
              quote_ident(fc.column_name) || ' ' || fc.column_type || (CASE WHEN fc.is_not_null THEN ' NOT NULL' ELSE '' END) || COALESCE(' ' || fc.column_constraints, ''),
              ', '
              ORDER BY fc.column_position
            ) AS formatted_columns
          FROM
            formatted_columns fc
          GROUP BY
            fc.schema_name,
            fc.table_name
        )
        SELECT
          'CREATE TABLE ' || quote_ident(schema_name) || '.' || quote_ident(table_name) || ' (' || formatted_columns || ');' AS sql
        FROM
          create_table_statements;",
            self.namespace_literal()
        )
    }

    /// Generates a catalog query that yields one `CREATE VIEW` statement per
    /// view in the schema, using the body reported by `pg_get_viewdef`.
    pub fn views(&self) -> String {
        format!(
            "SELECT
      'CREATE VIEW ' || quote_ident(n.nspname) || '.' || quote_ident(c.relname) || ' AS ' || pg_get_viewdef(c.oid) AS sql
    FROM
      pg_catalog.pg_class c
      JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid
    WHERE
      c.relkind = 'v' -- Select views
      AND n.nspname = {};",
            self.namespace_literal()
        )
    }

    /// Generates a catalog query that yields one `CREATE MATERIALIZED VIEW`
    /// statement per materialized view in the schema, using the body reported
    /// by `pg_get_viewdef`.
    pub fn mviews(&self) -> String {
        format!(
            "SELECT
        'CREATE MATERIALIZED VIEW ' || quote_ident(n.nspname) || '.' || quote_ident(c.relname) || ' AS ' || pg_get_viewdef(c.oid) AS sql
      FROM
        pg_catalog.pg_class c
        JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid
      WHERE
        c.relkind = 'm' -- Select materialized views
        AND n.nspname = {};",
            self.namespace_literal()
        )
    }

    /// Generates a catalog query that yields the definition of every plain
    /// function (`prokind = 'f'`) in the schema.
    ///
    /// `pg_get_functiondef` already returns a complete
    /// `CREATE OR REPLACE FUNCTION` statement. It is emitted as-is, with its
    /// trailing newline trimmed and a `;` appended, rather than being wrapped in
    /// another `CREATE FUNCTION`.
    pub fn functions(&self) -> String {
        format!(
            "SELECT
      rtrim(pg_get_functiondef(p.oid), E'\\n') || ';' AS sql
    FROM
      pg_catalog.pg_proc p
      JOIN pg_catalog.pg_namespace n ON p.pronamespace = n.oid
    WHERE
      n.nspname = {}
      AND p.prokind = 'f';",
            self.namespace_literal()
        )
    }

    /// Generates a catalog query that yields one `CREATE TRIGGER` statement per
    /// user-defined (non-internal) trigger on relations in the schema.
    ///
    /// The statement comes from `pg_get_triggerdef`, which reconstructs the full
    /// definition: timing, every event, `FOR EACH` level and function call.
    /// Decoding the `tgtype` bit mask by hand is easy to get wrong.
    pub fn triggers(&self) -> String {
        format!(
            "SELECT
      pg_get_triggerdef(t.oid) || ';' AS sql
    FROM
      pg_catalog.pg_trigger t
      JOIN pg_catalog.pg_class c ON t.tgrelid = c.oid
      JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid
    WHERE
      n.nspname = {}
      AND NOT t.tgisinternal;",
            self.namespace_literal()
        )
    }

    /// Generates a catalog query that yields every index definition in the
    /// schema (from `pg_indexes.indexdef`), ordered by table then index name.
    pub fn indexes(&self) -> String {
        format!(
            "SELECT indexdef || ';' AS sql FROM pg_indexes WHERE schemaname = {} ORDER BY tablename, indexname;",
            self.namespace_literal()
        )
    }
}

/// Row shape shared by every catalog query built by [`PgSchema`].
#[cfg(feature = "db-postgres")]
#[derive(sqlx::FromRow)]
struct SchemaRet {
    sql: String,
}

// For each listed SQL-builder method `name`, generates an async `get_<name>`
// method that runs the builder's query on a pool and collects the `sql` column.
// Docs use paste's `#[doc = ...]` concatenation so `$name` is substituted.
#[cfg(feature = "db-postgres")]
macro_rules! gen_fn {
    ($($name:ident),*) => {
        $(
            paste! {
                #[doc = "Runs the query built by [`PgSchema::" $name "`] and returns the `sql` column of every row."]
                #[doc = ""]
                #[doc = "# Errors"]
                #[doc = ""]
                #[doc = "Returns the [`sqlx::Error`] raised while executing the query or decoding its rows."]
                #[doc = ""]
                #[doc = "# Examples"]
                #[doc = ""]
                #[doc = "```ignore"]
                #[doc = "let schema = PgSchema::new(\"my_schema\");"]
                #[doc = "let sqls = schema.get_" $name "(&pool).await?;"]
                #[doc = "```"]
                pub async fn [<get_ $name>](&self, pool: &PgPool) -> Result<Vec<String>, sqlx::Error> {
                    let sql = self.$name();
                    let rows: Vec<SchemaRet> = sqlx::query_as(&sql).fetch_all(pool).await?;
                    Ok(rows.into_iter().map(|row| row.sql).collect())
                }
            }
        )*
    };
}

#[cfg(feature = "db-postgres")]
impl PgSchema {
    gen_fn!(enums, types, tables, views, mviews, functions, triggers, indexes);
}

#[cfg(feature = "db-postgres")]
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use sqlx_db_tester::TestPg;

    #[tokio::test]
    async fn get_tables_should_work() -> Result<()> {
        let schema = PgSchema::new("gpt");
        let tdb = TestPg::default();
        let pool = tdb.get_pool().await;
        let items = schema.get_tables(&pool).await?;
        assert_eq!(items.len(), 4);
        assert_eq!(
            items[0],
            "CREATE TABLE gpt.comments (id integer NOT NULL PRIMARY KEY (id), post_id integer NOT NULL FOREIGN KEY (post_id), user_id integer NOT NULL FOREIGN KEY (user_id), content text NOT NULL, created_at timestamp with time zone NOT NULL, updated_at timestamp with time zone NOT NULL);"
        );

        Ok(())
    }

    #[tokio::test]
    async fn get_enums_should_work() -> Result<()> {
        let schema = PgSchema::new("gpt");
        let tdb = TestPg::default();
        let pool = tdb.get_pool().await;
        let items = schema.get_enums(&pool).await?;
        assert_eq!(items.len(), 2);
        assert_eq!(
            items[0],
            "CREATE TYPE gpt.login_method AS ENUM ('email', 'google', 'github');"
        );

        Ok(())
    }

    #[tokio::test]
    async fn get_types_should_work() -> Result<()> {
        let schema = PgSchema::new("gpt");
        let tdb = TestPg::default();
        let pool = tdb.get_pool().await;
        let items = schema.get_types(&pool).await?;
        assert_eq!(items.len(), 1);
        assert_eq!(
            items[0],
            "CREATE TYPE gpt.address AS (street character varying(255), city character varying(100), state character(2), postal_code character(5));"
        );

        Ok(())
    }

    #[tokio::test]
    async fn get_views_should_work() -> Result<()> {
        let schema = PgSchema::new("gpt");
        let tdb = TestPg::default();
        let pool = tdb.get_pool().await;
        let items = schema.get_views(&pool).await?;
        assert_eq!(items.len(), 1);
        assert_eq!(
            items[0],
            "CREATE VIEW gpt.posts_with_comments AS  SELECT p.id,\n    p.user_id,\n    p.title,\n    p.content,\n    p.status,\n    p.published_at,\n    p.created_at,\n    p.updated_at,\n    json_agg(json_build_object('id', c.id, 'user_id', c.user_id, 'content', c.content, 'created_at', c.created_at, 'updated_at', c.updated_at)) AS comments\n   FROM (gpt.posts p\n     LEFT JOIN gpt.comments c ON ((c.post_id = p.id)))\n  GROUP BY p.id;"
        );

        Ok(())
    }

    #[tokio::test]
    async fn get_mviews_should_work() -> Result<()> {
        let schema = PgSchema::new("gpt");
        let tdb = TestPg::default();
        let pool = tdb.get_pool().await;
        let items = schema.get_mviews(&pool).await?;
        assert_eq!(items.len(), 1);
        assert_eq!(
            items[0],
            "CREATE MATERIALIZED VIEW gpt.users_with_posts AS  SELECT u.id,\n    u.username,\n    u.email,\n    u.first_name,\n    u.last_name,\n    u.created_at,\n    u.updated_at,\n    json_agg(json_build_object('id', p.id, 'title', p.title, 'content', p.content, 'status', p.status, 'published_at', p.published_at, 'created_at', p.created_at, 'updated_at', p.updated_at)) AS posts\n   FROM (gpt.users u\n     LEFT JOIN gpt.posts p ON ((p.user_id = u.id)))\n  GROUP BY u.id;"
        );

        Ok(())
    }

    #[tokio::test]
    async fn get_functions_should_work() -> Result<()> {
        let schema = PgSchema::new("gpt");
        let tdb = TestPg::default();
        let pool = tdb.get_pool().await;
        let items = schema.get_functions(&pool).await?;
        assert_eq!(items.len(), 1);
        assert_eq!(
            items[0],
            "CREATE OR REPLACE FUNCTION gpt.refresh_users_with_posts()\n RETURNS trigger\n LANGUAGE plpgsql\nAS $function$\nBEGIN\n  REFRESH MATERIALIZED VIEW gpt.users_with_posts;\n  RETURN NULL;\nEND;\n$function$;"
        );

        Ok(())
    }

    #[tokio::test]
    async fn get_triggers_should_work() -> Result<()> {
        let schema = PgSchema::new("gpt");
        let tdb = TestPg::default();
        let pool = tdb.get_pool().await;
        let items = schema.get_triggers(&pool).await?;
        assert_eq!(items.len(), 1);
        // pg_get_triggerdef lists every event the fixture trigger fires on, so
        // only the stable prefix and suffix of the statement are pinned here.
        assert!(items[0].starts_with("CREATE TRIGGER refresh_users_with_posts AFTER INSERT"));
        assert!(items[0].ends_with(
            " ON gpt.posts FOR EACH STATEMENT EXECUTE FUNCTION gpt.refresh_users_with_posts();"
        ));

        Ok(())
    }

    #[tokio::test]
    async fn get_indexes_should_work() -> Result<()> {
        let schema = PgSchema::new("gpt");
        let tdb = TestPg::default();
        let pool = tdb.get_pool().await;
        let items = schema.get_indexes(&pool).await?;
        assert_eq!(items.len(), 8);
        assert_eq!(
            items[0],
            "CREATE UNIQUE INDEX comments_pkey ON gpt.comments USING btree (id);"
        );

        Ok(())
    }
}