awto_compile/
database.rs

1use std::{borrow::Cow, env, fmt::Write, io};
2
3use awto::{
4    database::{DatabaseColumn, DatabaseDefault, DatabaseTable, DatabaseType},
5    schema::{Model, Role},
6};
7use proc_macro2::Literal;
8use quote::{format_ident, quote};
9use sqlx::{Executor, PgPool};
10use tokio_stream::StreamExt;
11
12use crate::{
13    error::Error,
14    util::{is_ty_option, is_ty_vec, strip_ty_option},
15};
16
/// Name of the generated Rust source file; it is joined onto `OUT_DIR` wherever it is used.
const COMPILED_RUST_FILE: &str = "app.rs";
18
/// Summary of the SQL that `compile_database` executed against the database.
///
/// The `Default` value (all zeroes) is returned when there was no SQL to run.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
pub struct CompileDatabaseResult {
    /// Number of query results returned by executing the compiled SQL.
    pub queries_executed: usize,
    /// Sum of the rows affected across all of those query results.
    pub rows_affected: u64,
}
24
/// Generates the Rust glue code for `models` and syncs the database at `uri` with them.
///
/// The steps are, in order:
/// 1. connect to the database at `uri`;
/// 2. write the generated conversion code to `$OUT_DIR/app.rs` (skipped when there is none);
/// 3. append the `sea_orm` model modules to that file;
/// 4. compile the migration SQL and, if it is non-empty, execute it.
///
/// # Errors
/// Returns an error when `OUT_DIR` is not set, the database connection fails, the generated
/// file cannot be written, or compiling/executing the SQL fails.
#[cfg(feature = "async")]
pub async fn compile_database(
    uri: &str,
    models: Vec<Model>,
) -> Result<CompileDatabaseResult, Box<dyn std::error::Error>> {
    use std::path::Path;

    use tokio::fs;

    // Report a missing `OUT_DIR` to the caller instead of panicking.
    let out_dir = env::var("OUT_DIR")?;
    let pool = PgPool::connect(uri).await?;
    let compiler = DatabaseCompiler::from_pool(&pool, models);

    let generated_code = compiler.compile_generated_code();
    if !generated_code.is_empty() {
        let rs_path = Path::new(&out_dir).join(COMPILED_RUST_FILE);
        fs::write(rs_path, generated_code).await?;
    }

    compiler.append_sea_orm_models().await?;

    let sql = compiler.compile().await?;
    if sql.is_empty() {
        // Nothing to migrate.
        return Ok(CompileDatabaseResult::default());
    }

    let results = pool
        .execute_many(sql.as_str())
        .collect::<Result<Vec<_>, _>>()
        .await?;

    Ok(CompileDatabaseResult {
        queries_executed: results.len(),
        rows_affected: results.iter().map(|result| result.rows_affected()).sum(),
    })
}
63
64#[cfg(not(feature = "async"))]
65pub async fn compile_database(
66    uri: &str,
67    models: Vec<Model>,
68) -> Result<CompileDatabaseResult, Box<dyn std::error::Error>> {
69    use std::fs;
70
71    let out_dir = env::var("OUT_DIR").unwrap();
72    let pool = PgPool::connect(uri).await?;
73    let compiler = DatabaseCompiler::from_pool(&pool, models);
74
75    let generated_code = compiler.compile_generated_code();
76    if !generated_code.is_empty() {
77        let rs_path = format!("{}/{}", out_dir, COMPILED_RUST_FILE);
78        fs::write(rs_path, generated_code)?;
79    }
80
81    compiler.append_sea_orm_models()?;
82
83    let sql = compiler.compile().await?;
84    if !sql.is_empty() {
85        let results = pool
86            .execute_many(sql.as_str())
87            .collect::<Result<Vec<_>, _>>()
88            .await?;
89        let queries_executed = results.len();
90        let rows_affected = results
91            .iter()
92            .fold(0, |acc, result| result.rows_affected() + acc);
93
94        Ok(CompileDatabaseResult {
95            queries_executed,
96            rows_affected,
97        })
98    } else {
99        Ok(CompileDatabaseResult::default())
100    }
101}
102
/// Compiles a set of schema models into generated Rust code and SQL statements.
pub struct DatabaseCompiler<'pool> {
    /// Connection pool: borrowed when built with `from_pool`, owned when built with `connect`.
    pool: Cow<'pool, PgPool>,
    /// Models whose database roles drive the generated code and SQL.
    models: Vec<Model>,
}
107
108impl<'pool> DatabaseCompiler<'pool> {
109    pub async fn connect(
110        uri: &str,
111        models: Vec<Model>,
112    ) -> Result<DatabaseCompiler<'_>, sqlx::Error> {
113        let pool = sqlx::PgPool::connect(uri).await?;
114
115        Ok(DatabaseCompiler {
116            pool: Cow::Owned(pool),
117            models,
118        })
119    }
120
121    pub fn from_pool(pool: &'pool PgPool, models: Vec<Model>) -> DatabaseCompiler<'pool> {
122        DatabaseCompiler {
123            pool: Cow::Borrowed(pool),
124            models,
125        }
126    }
127
128    pub async fn compile(&self) -> Result<String, Error> {
129        let mut sql = String::new();
130
131        for (_, table) in self.database_tables() {
132            let db_columns = self.fetch_table(table).await?;
133
134            match db_columns {
135                Some(db_columns) => {
136                    writeln!(sql, "{}", self.write_sync_sql(table, &db_columns).await).unwrap();
137                }
138                None => {
139                    writeln!(sql, "{}", self.write_table_create_sql(table)).unwrap();
140                }
141            }
142        }
143
144        Ok(sql.trim().to_string())
145    }
146
147    /// Compiles generated Rust code from schemas and services.
148    pub fn compile_generated_code(&self) -> String {
149        let mut code = String::new();
150
151        for (model, table) in self.database_tables() {
152            let ident = format_ident!("{}", model.name);
153            let db_module_ident = format_ident!("{}", table.name);
154
155            let mut from_schema_fields = Vec::new();
156            let mut from_db_fields = Vec::new();
157
158            for field in &model.fields {
159                let field_ident = format_ident!("{}", field.name);
160
161                let ty = strip_ty_option(&field.ty);
162
163                if is_ty_vec(ty) {
164                    from_schema_fields.push(
165                        quote!(#field_ident: val.#field_ident.into_iter().map(|v| v.into()).collect()),
166                    );
167                    from_db_fields.push(
168                        quote!(#field_ident: val.#field_ident.into_iter().map(|v| v.into()).collect()),
169                    );
170                } else {
171                    from_schema_fields.push(quote!(#field_ident: val.#field_ident.into()));
172                    from_db_fields.push(quote!(#field_ident: val.#field_ident.into()));
173                }
174            }
175
176            let expanded = quote!(
177                impl ::std::convert::From<crate::#db_module_ident::Model> for ::schema::#ident {
178                    #[allow(unused_variables)]
179                    fn from(val: crate::#db_module_ident::Model) -> Self {
180                        Self {
181                            #( #from_schema_fields, )*
182                        }
183                    }
184                }
185
186                impl ::std::convert::From<::schema::#ident> for crate::#db_module_ident::Model {
187                    #[allow(unused_variables)]
188                    fn from(val: ::schema::#ident) -> Self {
189                        Self {
190                            #( #from_db_fields, )*
191                        }
192                    }
193                }
194            );
195
196            write!(code, "{}", expanded.to_string()).unwrap();
197        }
198
199        for (model, table) in self.database_sub_tables() {
200            let ident = format_ident!("{}", model.name);
201            let db_module_ident = format_ident!("{}", table.name);
202
203            let active_values = model.fields.iter().map(|field| {
204                let field_ident = format_ident!("{}", field.name);
205                
206                let self_field = if is_ty_option(&field.ty) {
207                    let db_field = table.columns.iter().find(|column| column.name == field.name).unwrap();
208                    match &db_field.default {
209                        Some(DatabaseDefault::Bool(b)) => quote!(self.#field_ident.unwrap_or(#b)),
210                        Some(DatabaseDefault::Float(f)) => {
211                            let f = Literal::i64_unsuffixed(*f);
212                            quote!(self.#field_ident.unwrap_or(#f))
213                        },
214                        Some(DatabaseDefault::Int(i)) => {
215                            let i = Literal::u64_unsuffixed(*i);
216                            quote!(self.#field_ident.unwrap_or(#i))
217                        },
218                        Some(DatabaseDefault::String(s)) => quote!(self.#field_ident.unwrap_or(#s)),
219                        _ => quote!(self.#field_ident),
220                    }
221                } else {
222                    quote!(self.#field_ident)
223                };
224
225                quote!(
226                    #field_ident: ::sea_orm::entity::IntoActiveValue::into_active_value(#self_field).into()
227                )
228            });
229            
230            let expanded = quote!(
231                impl ::sea_orm::entity::IntoActiveModel<crate::#db_module_ident::ActiveModel> for ::schema::#ident {
232                    fn into_active_model(self) -> crate::#db_module_ident::ActiveModel {
233                        crate::#db_module_ident::ActiveModel {
234                            #( #active_values, )*
235                            ..Default::default()
236                        }
237                    }
238                }
239            );
240
241            write!(code, "{}", expanded.to_string()).unwrap();
242        }
243
244        code.trim().to_string()
245    }
246
247    #[cfg(feature = "async")]
248    async fn append_sea_orm_models(&self) -> Result<(), io::Error> {
249        use tokio::fs;
250        use tokio::io::AsyncWriteExt;
251
252        let out_dir = env::var("OUT_DIR").unwrap();
253        let rs_path = format!("{}/{}", out_dir, COMPILED_RUST_FILE);
254
255        let mut file = fs::OpenOptions::new().append(true).open(&rs_path).await?;
256
257        for (_, table) in self.database_tables() {
258            file.write(format!("pub mod {} {{", table.name).as_bytes()).await?;
259            file.write(format!(r#"    sea_orm::include_model!("{}");"#, table.name).as_bytes()).await?;
260            file.write(b"}").await?;
261        }
262
263        Ok(())
264    }
265    
266    #[cfg(not(feature = "async"))]
267    fn append_sea_orm_models(&self) -> Result<(), io::Error> {
268        use std::fs;
269        use std::io::Write;
270
271        let out_dir = env::var("OUT_DIR").unwrap();
272        let rs_path = format!("{}/{}", out_dir, COMPILED_RUST_FILE);
273
274        let mut file = fs::OpenOptions::new().append(true).open(&rs_path)?;
275
276        for (_, table) in self.database_tables() {
277            write!(file, "pub mod {} {{", table.name).unwrap();
278            write!(file, r#"    sea_orm::include_model!("{}");"#, table.name).unwrap();
279            write!(file, "}}").unwrap();
280        }
281
282        Ok(())
283    }
284
285    fn database_tables(&self) -> Vec<(&Model, &DatabaseTable)> {
286        self.models.iter().fold(Vec::new(), |mut acc, model| {
287            let roles = model
288                .roles
289                .iter()
290                .filter_map(|role| match role {
291                    Role::DatabaseTable(database_table) => Some((model, database_table)),
292                    _ => None,
293                })
294                .collect::<Vec<_>>();
295
296            acc.extend(roles);
297
298            acc
299        })
300    }
301
302    fn database_sub_tables(&self) -> Vec<(&Model, &DatabaseTable)> {
303        self.models.iter().fold(Vec::new(), |mut acc, model| {
304            let roles = model
305                .roles
306                .iter()
307                .filter_map(|role| match role {
308                    Role::DatabaseSubTable(database_sub_table) => Some((model, database_sub_table)),
309                    _ => None,
310                })
311                .collect::<Vec<_>>();
312
313            acc.extend(roles);
314
315            acc
316        })
317    }
318
319    async fn fetch_table(
320        &self,
321        table: &DatabaseTable,
322    ) -> Result<Option<Vec<DatabaseColumn>>, Error> {
323        #[derive(Debug, sqlx::FromRow)]
324        struct ColumnsQuery {
325            column_name: String,
326            column_default: Option<String>,
327            is_nullable: String,
328            data_type: String,
329            character_maximum_length: Option<i32>,
330            is_primary_key: bool,
331            is_unique: bool,
332            reference: Option<String>,
333        }
334
335        let raw_columns: Vec<ColumnsQuery> = sqlx::query_as(FETCH_TABLE_QUERY)
336            .bind("public")
337            .bind(&table.name)
338            .fetch_all(&*self.pool)
339            .await
340            .map_err(Error::Sqlx)?;
341
342        if raw_columns.is_empty() {
343            return Ok(None);
344        }
345
346        let columns: Vec<DatabaseColumn> = raw_columns
347            .into_iter()
348            .map(|col| {
349                let column_name = col.column_name;
350                let character_maximum_length = col.character_maximum_length;
351
352                Ok(DatabaseColumn {
353                    name: column_name.clone(),
354                    ty: col
355                        .data_type
356                        .parse::<DatabaseType>()
357                        .map(|database_type| {
358                            if let Some(max_len) = character_maximum_length {
359                                if matches!(database_type, DatabaseType::Text(None)) {
360                                    return DatabaseType::Text(Some(max_len));
361                                }
362                            }
363
364                            database_type
365                        })
366                        .map_err(|_| Error::UnsupportedType(table.name.clone(), column_name))?,
367                    nullable: col.is_nullable == "YES",
368                    default: col.column_default.map(|def| {
369                        if def.starts_with('\'') {
370                            let s = def
371                                .strip_prefix('\'')
372                                .unwrap()
373                                .splitn(2, '\'')
374                                .next()
375                                .unwrap()
376                                .to_string();
377                            DatabaseDefault::String(s)
378                        } else if def == "true" {
379                            DatabaseDefault::Bool(true)
380                        } else if def == "false" {
381                            DatabaseDefault::Bool(false)
382                        } else if let Ok(num) = def.parse::<u64>() {
383                            DatabaseDefault::Int(num)
384                        } else if let Ok(num) = def.parse::<i64>() {
385                            DatabaseDefault::Float(num)
386                        } else {
387                            DatabaseDefault::Raw(def)
388                        }
389                    }),
390                    unique: col.is_unique,
391                    constraint: None,
392                    primary_key: col.is_primary_key,
393                    references: if let Some(references) = col.reference {
394                        let mut parts = references.splitn(2, ':');
395                        if let Some(references_table) = parts.next() {
396                            parts.next().map(|references_column| {
397                                (references_table.to_string(), references_column.to_string())
398                            })
399                        } else {
400                            None
401                        }
402                    } else {
403                        None
404                    },
405                })
406            })
407            .collect::<Result<_, _>>()?;
408
409        Ok(Some(columns))
410    }
411
412    fn write_table_create_sql(&self, table: &DatabaseTable) -> String {
413        let mut sql = String::new();
414
415        writeln!(sql, "CREATE TABLE IF NOT EXISTS {} (", table.name).unwrap();
416
417        for (i, column) in table.columns.iter().enumerate() {
418            write!(sql, "  {}", self.write_column_sql(column)).unwrap();
419
420            if i < table.columns.len() - 1 {
421                writeln!(sql, ",").unwrap();
422            } else {
423                writeln!(sql).unwrap();
424            }
425        }
426
427        writeln!(sql, ");").unwrap();
428
429        sql
430    }
431
432    fn write_column_sql(&self, column: &DatabaseColumn) -> String {
433        let mut sql = String::new();
434
435        write!(sql, "{} {}", column.name, column.ty,).unwrap();
436
437        if !column.nullable {
438            write!(sql, " NOT NULL",).unwrap();
439        }
440
441        if let Some(default) = &column.default {
442            write!(sql, " DEFAULT {}", default).unwrap();
443        }
444
445        if let Some(constraint) = &column.constraint {
446            write!(sql, " CHECK ({})", constraint).unwrap();
447        }
448
449        if column.primary_key {
450            write!(sql, " PRIMARY KEY").unwrap();
451        }
452
453        if let Some((table, col)) = &column.references {
454            write!(sql, " REFERENCES {}({})", table, col).unwrap();
455        }
456
457        sql
458    }
459
460    async fn write_sync_sql(&self, table: &DatabaseTable, db_columns: &[DatabaseColumn]) -> String {
461        let mut sql = String::new();
462
463        for schema_col in &table.columns {
464            let db_col = match db_columns
465                .iter()
466                .find(|db_col| db_col.name == schema_col.name)
467            {
468                Some(db_col) => db_col,
469                None => {
470                    // Column does not exist in DB
471                    writeln!(
472                        sql,
473                        "ALTER TABLE {} ADD COLUMN {};",
474                        table.name,
475                        self.write_column_sql(schema_col)
476                    )
477                    .unwrap();
478                    continue;
479                }
480            };
481
482            // Check for type mismatch
483            if schema_col.ty != db_col.ty {
484                writeln!(
485                    sql,
486                    "ALTER TABLE {table} ALTER COLUMN {column} TYPE {ty} USING {column}::{ty};",
487                    table = table.name,
488                    column = schema_col.name,
489                    ty = schema_col.ty.to_string(),
490                )
491                .unwrap();
492            }
493
494            // Check for nullable mismatch
495            if schema_col.nullable != db_col.nullable {
496                if db_col.nullable {
497                    writeln!(
498                        sql,
499                        "ALTER TABLE {table} ALTER COLUMN {column} SET NOT NULL;",
500                        table = table.name,
501                        column = schema_col.name
502                    )
503                    .unwrap();
504                } else {
505                    writeln!(
506                        sql,
507                        "ALTER TABLE {table} ALTER COLUMN {column} DROP NOT NULL;",
508                        table = table.name,
509                        column = schema_col.name
510                    )
511                    .unwrap();
512                }
513            }
514
515            // Check for default mismatch
516            if schema_col.default != db_col.default {
517                if let Some(default) = &schema_col.default {
518                    writeln!(
519                        sql,
520                        "ALTER TABLE {table} ALTER COLUMN {column} SET DEFAULT {default};",
521                        table = table.name,
522                        column = schema_col.name,
523                        default = default
524                    )
525                    .unwrap();
526                } else {
527                    writeln!(
528                        sql,
529                        "ALTER TABLE {table} ALTER COLUMN {column} DROP DEFAULT;",
530                        table = table.name,
531                        column = schema_col.name
532                    )
533                    .unwrap();
534                }
535            }
536
537            // Check for unique mismatch
538            if schema_col.unique != db_col.unique {
539                if db_col.unique {
540                    writeln!(
541                        sql,
542                        "ALTER TABLE {table} DROP CONSTRAINT {table}_{column}_key;",
543                        table = table.name,
544                        column = schema_col.name
545                    )
546                    .unwrap();
547                } else {
548                    writeln!(
549                        sql,
550                        "ALTER TABLE {table} ADD CONSTRAINT {table}_{column}_key UNIQUE ({column});",
551                        table = table.name,
552                        column = schema_col.name
553                    )
554                    .unwrap();
555                }
556            }
557
558            // Check for references mismatch
559            if schema_col.references != db_col.references {
560                if let Some(references) = &schema_col.references {
561                    if db_col.references.is_some() {
562                        writeln!(
563                            sql,
564                            "ALTER TABLE {table} DROP CONSTRAINT {table}_{column}_fkey;",
565                            table = table.name,
566                            column = schema_col.name
567                        )
568                        .unwrap();
569                    }
570                    writeln!(
571                        sql,
572                        "ALTER TABLE {table} ADD CONSTRAINT {table}_{column}_fkey FOREIGN KEY ({column}) REFERENCES {reference_table} ({reference_column});",
573                        table = table.name,
574                        column = schema_col.name,
575                        reference_table = references.0,
576                        reference_column = references.1,
577                    )
578                    .unwrap();
579                } else {
580                    writeln!(
581                        sql,
582                        "ALTER TABLE {table} DROP CONSTRAINT {table}_{column}_fkey;",
583                        table = table.name,
584                        column = schema_col.name
585                    )
586                    .unwrap();
587                }
588            }
589        }
590
591        // Delete columns that exist in db but don't exist in schema
592        db_columns
593            .iter()
594            .filter(|db_col| {
595                table
596                    .columns
597                    .iter()
598                    .all(|schema_col| schema_col.name != db_col.name)
599            })
600            .for_each(|db_col| {
601                writeln!(
602                    sql,
603                    "ALTER TABLE {table} DROP COLUMN {column};",
604                    table = table.name,
605                    column = db_col.name
606                )
607                .unwrap();
608            });
609
610        sql
611    }
612}
613
/// Catalog query that describes every column of one table.
///
/// Bind parameters: `$1` is the schema name, `$2` is the table name.
///
/// Each row carries the column's name, default expression, nullability, data type and
/// maximum character length, plus three computed values: `is_primary_key`, `is_unique`, and
/// `reference` (`"<table>:<column>"` for a foreign key, otherwise NULL).
///
/// Every constraint subquery is restricted to schema `$1` so that a same-named table in
/// another schema cannot contribute rows.
const FETCH_TABLE_QUERY: &str = "
SELECT column_name, column_default, is_nullable, data_type, character_maximum_length,
(
    SELECT
        COUNT(*) > 0
    FROM information_schema.table_constraints tco
    JOIN information_schema.key_column_usage kcu
    ON kcu.constraint_name = tco.constraint_name
    AND kcu.constraint_schema = tco.constraint_schema
    WHERE
        tco.constraint_type = 'PRIMARY KEY' AND
        kcu.table_schema = $1 AND
        kcu.table_name = $2 AND
        kcu.column_name = information_schema.columns.column_name
) as is_primary_key,
(
    SELECT
        COUNT(*) > 0
    FROM information_schema.table_constraints tco
    JOIN information_schema.key_column_usage kcu
    ON kcu.constraint_name = tco.constraint_name
    AND kcu.constraint_schema = tco.constraint_schema
    WHERE
        tco.constraint_type = 'UNIQUE' AND
        kcu.table_schema = $1 AND
        kcu.table_name = $2 AND
        kcu.column_name = information_schema.columns.column_name
) as is_unique,
(
    SELECT CONCAT(
        rel_tco.table_name,
        ':',
        (
            SELECT u.column_name
            FROM information_schema.constraint_column_usage u
            INNER JOIN information_schema.referential_constraints fk
            ON
                u.constraint_catalog = fk.unique_constraint_catalog AND
                u.constraint_schema = fk.unique_constraint_schema AND
                u.constraint_name = fk.unique_constraint_name
            INNER JOIN information_schema.key_column_usage r
            ON
                r.constraint_catalog = fk.constraint_catalog AND
                r.constraint_schema = fk.constraint_schema AND
                r.constraint_name = fk.constraint_name
            WHERE
                fk.constraint_name = kcu.constraint_name AND
                u.table_schema = kcu.table_schema AND
                u.table_name = rel_tco.table_name
        )
    )
    FROM information_schema.table_constraints tco
    JOIN information_schema.key_column_usage kcu
    ON
        tco.constraint_schema = kcu.constraint_schema AND
        tco.constraint_name = kcu.constraint_name
    JOIN information_schema.referential_constraints rco
    ON
        tco.constraint_schema = rco.constraint_schema AND
        tco.constraint_name = rco.constraint_name
    JOIN information_schema.table_constraints rel_tco
    ON
        rco.unique_constraint_schema = rel_tco.constraint_schema AND
        rco.unique_constraint_name = rel_tco.constraint_name
    WHERE
        tco.constraint_type = 'FOREIGN KEY' AND
        kcu.table_schema = $1 AND
        kcu.table_name = $2 AND
        kcu.column_name = information_schema.columns.column_name
    GROUP BY
        kcu.table_schema,
        kcu.table_name,
        rel_tco.table_name,
        rel_tco.table_schema,
        kcu.constraint_name
    ORDER BY
        kcu.table_schema,
        kcu.table_name
) as reference
FROM information_schema.columns
WHERE table_schema = $1
AND table_name = $2;
";