1use std::io;
2
3use heck::ToSnakeCase;
4use sqlx::{ColumnIndex, Database, Decode, Encode, Executor, IntoArguments};
5
6use crate::sql::{
7 dialects::{
8 self,
9 schema::schema::{self, Schema},
10 },
11 schema::table::TableSchema,
12};
13
14pub mod table;
15
/// Synchronizes the database behind `conn` with the declared `tables`,
/// using an empty default schema name.
///
/// Equivalent to calling [`sync_tables_with_schema`] with an empty
/// `default_schema`. Only compiled when the `postgres` feature is enabled.
///
/// # Errors
///
/// Forwards any [`io::Error`] returned by [`sync_tables_with_schema`].
#[cfg(feature = "postgres")]
pub async fn sync_tables<C, DB: Database>(
    conn: &mut C,
    tables: Vec<TableSchema>,
) -> io::Result<()>
where
    for<'e> &'e mut C: Executor<'e, Database = DB>,
    for<'a> DB::Arguments<'a>: IntoArguments<'a, DB>,
    for<'a> &'a str: ColumnIndex<DB::Row>,
    for<'a> bool: sqlx::Decode<'a, DB> + sqlx::Type<DB>,
    for<'a> i32: sqlx::Decode<'a, DB> + sqlx::Type<DB>,
    for<'a> i64: sqlx::Decode<'a, DB> + sqlx::Type<DB>,
    for<'a> i64: Encode<'a, DB>,
    for<'a> std::string::String: Decode<'a, DB> + Encode<'a, DB> + sqlx::Type<DB>,
{
    // Delegate with an empty default schema name.
    let empty_default_schema = "";
    sync_tables_with_schema(conn, tables, empty_default_schema).await
}
41
42pub async fn sync_tables_with_schema<C, DB: Database>(
43 conn: &mut C,
44 tables: Vec<TableSchema>,
45 default_schema: &str,
46) -> io::Result<()>
47where
50 for<'e> &'e mut C: Executor<'e, Database = DB>,
51 for<'a> DB::Arguments<'a>: IntoArguments<'a, DB>,
52 for<'a> &'a str: ColumnIndex<DB::Row>,
53 for<'a> bool: sqlx::Decode<'a, DB> + sqlx::Type<DB>,
54 for<'a> i32: sqlx::Decode<'a, DB> + sqlx::Type<DB>,
55 for<'a> i64: sqlx::Decode<'a, DB> + sqlx::Type<DB>,
56 for<'a> i64: Encode<'a, DB>,
57 for<'a> std::string::String: Decode<'a, DB> + Encode<'a, DB> + sqlx::Type<DB>,
58{
59 let s = &dialects::schema::new(default_schema.to_string());
61
62 check_recreate(&tables, &mut *conn, s).await?;
64
65 let mut db_tables = s.get_tables(&mut *conn).await?;
66
67 for table in tables {
69 if table.query_only {
70 continue;
72 }
73 if let Some(src_table) = &table.from {
74 let src_table_name = src_table.clone().to_snake_case();
75 if src_table_name != table.name {
76 let mut src = table.clone();
78 src.name = src_table_name;
79 if db_tables
80 .iter()
81 .find(|t| s.is_table_name_equal(&table, t))
82 .is_none()
83 {
84 if let Some(db_table) = db_tables
86 .iter_mut()
87 .find(|t| s.is_table_name_equal(&src, t))
88 {
89 let sql = s.sql_table_rename(&db_table, &table.name);
92 s.execute_sql(&mut *conn, &sql).await?;
93 db_table.name = table.name.clone();
95 }
96 }
97 }
98 }
99
100 if let Some(db_table) = db_tables.iter().find(|t| s.is_table_name_equal(&table, t)) {
101 for col in &table.columns {
104 if let Some(db_col) = db_table.columns.iter().find(|c| col.is_name_equal(&c)) {
105 let sqls = s.sql_alter_column(&table, &db_col, &col)?;
107 for sql in sqls {
108 s.execute_sql(&mut *conn, &sql).await?;
109 }
112 } else {
113 let sql = s.sql_add_column(db_table, &col);
115 s.execute_sql(&mut *conn, &sql).await?;
116 }
117 }
118 if table.trim_columns {
119 for db_col in &db_table.columns {
121 if !table.columns.iter().any(|c| c.is_name_equal(db_col)) {
122 let sql = s.sql_drop_column(&table, db_col);
124 s.execute_sql(&mut *conn, &sql).await?;
125 }
126 }
127 }
128
129 if let Some(new_indexes) = &table.indexes {
131 for index in new_indexes {
132 if let Some(olds) = &db_table.indexes {
133 if let Some(old) = olds.iter().find(|idx| idx.is_name_equal(index)) {
136 if old.is_columns_equal(index) {
138 continue;
140 }
141 let sql = s.sql_drop_index(db_table, old);
144 s.execute_sql(&mut *conn, &sql).await?;
145 }
146 }
147 if let Some(sql) = &s.sql_create_index(&table, &index) {
149 s.execute_sql(&mut *conn, &sql).await?;
150 }
151 }
152 }
153 if table.trim_indexes {
154 if let Some(old) = &db_table.indexes {
156 for oidx in old {
157 if let Some(idxs) = &table.indexes {
158 if idxs.iter().any(|idx| idx.is_name_equal(oidx)) {
159 continue;
161 }
162 }
163 let sql = s.sql_drop_index(db_table, oidx);
165 s.execute_sql(&mut *conn, &sql).await?;
166 }
167 }
168 }
169 } else {
170 let table_sqls = s.sql_create_table(&table)?;
172 for sql in table_sqls.iter() {
173 s.execute_sql(&mut *conn, &sql).await?;
174 }
175
176 let index_sqls = s.sql_create_indexes(&table);
178 for sql in index_sqls.iter() {
179 s.execute_sql(&mut *conn, &sql).await?;
180 }
181 }
182 }
183
184 Ok(())
185}
186
187async fn check_recreate<'c, T, DB: Database>(
188 tables: &Vec<TableSchema>,
189 conn: &mut T,
190 s: &impl schema::Schema,
191) -> io::Result<()>
192where
193 for<'e> &'e mut T: Executor<'e, Database = DB>,
194 for<'a> DB::Arguments<'a>: IntoArguments<'a, DB>,
195 for<'a> &'a str: ColumnIndex<DB::Row>,
196 for<'a> bool: sqlx::Decode<'a, DB> + sqlx::Type<DB>,
197 for<'a> i32: sqlx::Decode<'a, DB> + sqlx::Type<DB>,
198 for<'a> i64: sqlx::Decode<'a, DB> + sqlx::Type<DB>,
199 for<'a> i64: Encode<'a, DB>,
200 for<'a> std::string::String: Decode<'a, DB> + Encode<'a, DB> + sqlx::Type<DB>,
201{
202 for table in tables {
203 if table.query_only {
204 continue;
206 }
207 if let Some(value) = &table.recreate {
208 let table_name = &s.table_name_with_schema(table);
210 let values = s
212 .query_upgrade_tags(&mut *conn, table_name, &"recreate".to_string())
213 .await?;
214 let found = values.iter().any(|v| v == value);
216 if !found {
217 let sql = s.sql_drop_table(table);
219 s.execute_sql(&mut *conn, &sql).await?;
221 s.insert_upgrade_tag(&mut *conn, table_name, &"recreate".to_string(), value)
224 .await?;
225 }
226 }
227 }
228 Ok(())
229}