easy_sqlx_core/sync/
mod.rs

1use std::io;
2
3use heck::ToSnakeCase;
4use sqlx::{ColumnIndex, Database, Decode, Encode, Executor, IntoArguments};
5
6use crate::sql::{
7    dialects::{
8        self,
9        schema::schema::{self, Schema},
10    },
11    schema::table::TableSchema,
12};
13
14pub mod table;
15
16// pub async fn execute<'a, DB: Database, T>(self, executor: T) -> Result<DB::QueryResult, Error>
17// where
18// for<'e> &'e mut T: Executor<'e, Database = Sqlite>,
19//     E: Executor<'a, Database = DB>,
20//     <DB as HasArguments<'a>>::Arguments: IntoArguments<'a, DB>,
21// {
22//     sqlx::query("").execute(executor).await
23// }
24
/// Synchronises `tables` with the database, using an empty default schema name.
///
/// This is a thin wrapper around `sync_tables_with_schema` that forwards `""` as the
/// `default_schema`. It is only compiled when the `postgres` feature is enabled.
///
/// # Errors
///
/// Returns exactly what `sync_tables_with_schema` returns.
#[cfg(feature = "postgres")]
pub async fn sync_tables<C, DB: Database>(conn: &mut C, tables: Vec<TableSchema>) -> io::Result<()>
where
    for<'e> &'e mut C: Executor<'e, Database = DB>,
    for<'a> DB::Arguments<'a>: IntoArguments<'a, DB>,
    for<'a> &'a str: ColumnIndex<DB::Row>,
    for<'a> bool: Decode<'a, DB> + sqlx::Type<DB>,
    for<'a> i32: Decode<'a, DB> + sqlx::Type<DB>,
    for<'a> i64: Decode<'a, DB> + Encode<'a, DB> + sqlx::Type<DB>,
    for<'a> String: Decode<'a, DB> + Encode<'a, DB> + sqlx::Type<DB>,
{
    // The empty name is forwarded as-is; how it is interpreted is up to
    // `dialects::schema::new`.
    let default_schema = "";
    sync_tables_with_schema(conn, tables, default_schema).await
}
41
42pub async fn sync_tables_with_schema<C, DB: Database>(
43    conn: &mut C,
44    tables: Vec<TableSchema>,
45    default_schema: &str,
46) -> io::Result<()>
47// where
48//     for<'e> &'e mut T: Executor<'e, Database = Postgres>,
49where
50    for<'e> &'e mut C: Executor<'e, Database = DB>,
51    for<'a> DB::Arguments<'a>: IntoArguments<'a, DB>,
52    for<'a> &'a str: ColumnIndex<DB::Row>,
53    for<'a> bool: sqlx::Decode<'a, DB> + sqlx::Type<DB>,
54    for<'a> i32: sqlx::Decode<'a, DB> + sqlx::Type<DB>,
55    for<'a> i64: sqlx::Decode<'a, DB> + sqlx::Type<DB>,
56    for<'a> i64: Encode<'a, DB>,
57    for<'a> std::string::String: Decode<'a, DB> + Encode<'a, DB> + sqlx::Type<DB>,
58{
59    // 查询数据库中表
60    let s = &dialects::schema::new(default_schema.to_string());
61
62    // 删除含有 recreate 控制字段的表
63    check_recreate(&tables, &mut *conn, s).await?;
64
65    let mut db_tables = s.get_tables(&mut *conn).await?;
66
67    // 遍历程序中定义的表
68    for table in tables {
69        if table.query_only {
70            // 只做查询操作,不同步表结构
71            continue;
72        }
73        if let Some(src_table) = &table.from {
74            let src_table_name = src_table.clone().to_snake_case();
75            if src_table_name != table.name {
76                // table 从 src_table 重命名而来
77                let mut src = table.clone();
78                src.name = src_table_name;
79                if db_tables
80                    .iter()
81                    .find(|t| s.is_table_name_equal(&table, t))
82                    .is_none()
83                {
84                    // 数据库中不存在 table 表
85                    if let Some(db_table) = db_tables
86                        .iter_mut()
87                        .find(|t| s.is_table_name_equal(&src, t))
88                    {
89                        // 数据库中存在 src 表
90                        // 重命名表
91                        let sql = s.sql_table_rename(&db_table, &table.name);
92                        s.execute_sql(&mut *conn, &sql).await?;
93                        // 修改 db_table 表名称为新的表名称
94                        db_table.name = table.name.clone();
95                    }
96                }
97            }
98        }
99
100        if let Some(db_table) = db_tables.iter().find(|t| s.is_table_name_equal(&table, t)) {
101            // 数据库中已经存在此表
102            // 检查字段差异
103            for col in &table.columns {
104                if let Some(db_col) = db_table.columns.iter().find(|c| col.is_name_equal(&c)) {
105                    // 列存在,检查列差异
106                    let sqls = s.sql_alter_column(&table, &db_col, &col)?;
107                    for sql in sqls {
108                        s.execute_sql(&mut *conn, &sql).await?;
109                        // println!("column: {sql}");
110                        // println!("column: {:?}", db_col);
111                    }
112                } else {
113                    // 列不存在,添加列
114                    let sql = s.sql_add_column(db_table, &col);
115                    s.execute_sql(&mut *conn, &sql).await?;
116                }
117            }
118            if table.trim_columns {
119                // 清理未定义的列
120                for db_col in &db_table.columns {
121                    if !table.columns.iter().any(|c| c.is_name_equal(db_col)) {
122                        // 定义中没有该列,删除数据库中的列
123                        let sql = s.sql_drop_column(&table, db_col);
124                        s.execute_sql(&mut *conn, &sql).await?;
125                    }
126                }
127            }
128
129            // 检查索引变化
130            if let Some(new_indexes) = &table.indexes {
131                for index in new_indexes {
132                    if let Some(olds) = &db_table.indexes {
133                        // 存在旧索引
134                        // 查找旧索引
135                        if let Some(old) = olds.iter().find(|idx| idx.is_name_equal(index)) {
136                            //  检查索引是否发生变化
137                            if old.is_columns_equal(index) {
138                                // 列没有发生变化
139                                continue;
140                            }
141                            // 索引列发生变化
142                            // 删除旧索引
143                            let sql = s.sql_drop_index(db_table, old);
144                            s.execute_sql(&mut *conn, &sql).await?;
145                        }
146                    }
147                    // 没有旧索引,创建索引
148                    if let Some(sql) = &s.sql_create_index(&table, &index) {
149                        s.execute_sql(&mut *conn, &sql).await?;
150                    }
151                }
152            }
153            if table.trim_indexes {
154                // 清理未定义的索引
155                if let Some(old) = &db_table.indexes {
156                    for oidx in old {
157                        if let Some(idxs) = &table.indexes {
158                            if idxs.iter().any(|idx| idx.is_name_equal(oidx)) {
159                                // 该索引已定义,不要删除
160                                continue;
161                            }
162                        }
163                        // 删除索引
164                        let sql = s.sql_drop_index(db_table, oidx);
165                        s.execute_sql(&mut *conn, &sql).await?;
166                    }
167                }
168            }
169        } else {
170            // 数据库中不存在此表,创建 table
171            let table_sqls = s.sql_create_table(&table)?;
172            for sql in table_sqls.iter() {
173                s.execute_sql(&mut *conn, &sql).await?;
174            }
175
176            // 创建索引
177            let index_sqls = s.sql_create_indexes(&table);
178            for sql in index_sqls.iter() {
179                s.execute_sql(&mut *conn, &sql).await?;
180            }
181        }
182    }
183
184    Ok(())
185}
186
187async fn check_recreate<'c, T, DB: Database>(
188    tables: &Vec<TableSchema>,
189    conn: &mut T,
190    s: &impl schema::Schema,
191) -> io::Result<()>
192where
193    for<'e> &'e mut T: Executor<'e, Database = DB>,
194    for<'a> DB::Arguments<'a>: IntoArguments<'a, DB>,
195    for<'a> &'a str: ColumnIndex<DB::Row>,
196    for<'a> bool: sqlx::Decode<'a, DB> + sqlx::Type<DB>,
197    for<'a> i32: sqlx::Decode<'a, DB> + sqlx::Type<DB>,
198    for<'a> i64: sqlx::Decode<'a, DB> + sqlx::Type<DB>,
199    for<'a> i64: Encode<'a, DB>,
200    for<'a> std::string::String: Decode<'a, DB> + Encode<'a, DB> + sqlx::Type<DB>,
201{
202    for table in tables {
203        if table.query_only {
204            // 只做查询操作,不生成表结构
205            continue;
206        }
207        if let Some(value) = &table.recreate {
208            // 有 recreate 定义
209            let table_name = &s.table_name_with_schema(table);
210            // tracing::info!("recreate {table_name}");
211            let values = s
212                .query_upgrade_tags(&mut *conn, table_name, &"recreate".to_string())
213                .await?;
214            // tracing::info!("join -----");
215            let found = values.iter().any(|v| v == value);
216            if !found {
217                // 删除表
218                let sql = s.sql_drop_table(table);
219                // tracing::info!("删除表 ----- {sql}");
220                s.execute_sql(&mut *conn, &sql).await?;
221                // tracing::info!("删除表{table_name}");
222                // 添加记录
223                s.insert_upgrade_tag(&mut *conn, table_name, &"recreate".to_string(), value)
224                    .await?;
225            }
226        }
227    }
228    Ok(())
229}