1use std::{borrow::Cow, env, fmt::Write, io};
2
3use awto::{
4 database::{DatabaseColumn, DatabaseDefault, DatabaseTable, DatabaseType},
5 schema::{Model, Role},
6};
7use proc_macro2::Literal;
8use quote::{format_ident, quote};
9use sqlx::{Executor, PgPool};
10use tokio_stream::StreamExt;
11
12use crate::{
13 error::Error,
14 util::{is_ty_option, is_ty_vec, strip_ty_option},
15};
16
/// File name, relative to `OUT_DIR`, of the generated Rust source file.
const COMPILED_RUST_FILE: &str = "app.rs";
18
/// Summary returned by `compile_database` after executing the compiled SQL.
///
/// The `Default` value (all zeroes) is returned when there was no SQL to run.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
pub struct CompileDatabaseResult {
    /// Number of query results produced by executing the compiled SQL.
    pub queries_executed: usize,
    /// Sum of `rows_affected()` over every query result.
    pub rows_affected: u64,
}
24
/// Generates the Rust glue code into `$OUT_DIR/app.rs` and synchronises the
/// database at `uri` with the given `models`.
///
/// Steps, in order:
/// 1. connect to the database,
/// 2. write the generated conversion code to the output file,
/// 3. append one `pub mod` per database table to that file,
/// 4. compile the schema diff to SQL and execute it.
///
/// # Errors
///
/// Returns an error when `OUT_DIR` is not set, when the connection fails, when
/// the output file cannot be written, or when compiling/executing the SQL fails.
#[cfg(feature = "async")]
pub async fn compile_database(
    uri: &str,
    models: Vec<Model>,
) -> Result<CompileDatabaseResult, Box<dyn std::error::Error>> {
    use tokio::fs;

    // A missing `OUT_DIR` is reported to the caller instead of panicking.
    let out_dir = env::var("OUT_DIR")?;
    let pool = PgPool::connect(uri).await?;
    let compiler = DatabaseCompiler::from_pool(&pool, models);

    // Always (re)write the file, even when there is no generated code: this
    // truncates output left over from a previous build and guarantees the file
    // exists before the module declarations are appended to it.
    let rs_path = format!("{}/{}", out_dir, COMPILED_RUST_FILE);
    fs::write(rs_path, compiler.compile_generated_code()).await?;

    compiler.append_sea_orm_models().await?;

    let sql = compiler.compile().await?;
    if sql.is_empty() {
        return Ok(CompileDatabaseResult::default());
    }

    let results = pool
        .execute_many(sql.as_str())
        .collect::<Result<Vec<_>, _>>()
        .await?;

    Ok(CompileDatabaseResult {
        queries_executed: results.len(),
        rows_affected: results
            .iter()
            .map(|result| result.rows_affected())
            .sum::<u64>(),
    })
}
63
64#[cfg(not(feature = "async"))]
65pub async fn compile_database(
66 uri: &str,
67 models: Vec<Model>,
68) -> Result<CompileDatabaseResult, Box<dyn std::error::Error>> {
69 use std::fs;
70
71 let out_dir = env::var("OUT_DIR").unwrap();
72 let pool = PgPool::connect(uri).await?;
73 let compiler = DatabaseCompiler::from_pool(&pool, models);
74
75 let generated_code = compiler.compile_generated_code();
76 if !generated_code.is_empty() {
77 let rs_path = format!("{}/{}", out_dir, COMPILED_RUST_FILE);
78 fs::write(rs_path, generated_code)?;
79 }
80
81 compiler.append_sea_orm_models()?;
82
83 let sql = compiler.compile().await?;
84 if !sql.is_empty() {
85 let results = pool
86 .execute_many(sql.as_str())
87 .collect::<Result<Vec<_>, _>>()
88 .await?;
89 let queries_executed = results.len();
90 let rows_affected = results
91 .iter()
92 .fold(0, |acc, result| result.rows_affected() + acc);
93
94 Ok(CompileDatabaseResult {
95 queries_executed,
96 rows_affected,
97 })
98 } else {
99 Ok(CompileDatabaseResult::default())
100 }
101}
102
/// Produces SQL and generated Rust code from a set of schema models, using a
/// Postgres connection pool to inspect the current database state.
pub struct DatabaseCompiler<'pool> {
    /// Connection pool: owned when built by `connect`, borrowed when built by `from_pool`.
    pool: Cow<'pool, PgPool>,
    /// Models whose `DatabaseTable` / `DatabaseSubTable` roles drive the output.
    models: Vec<Model>,
}
107
108impl<'pool> DatabaseCompiler<'pool> {
109 pub async fn connect(
110 uri: &str,
111 models: Vec<Model>,
112 ) -> Result<DatabaseCompiler<'_>, sqlx::Error> {
113 let pool = sqlx::PgPool::connect(uri).await?;
114
115 Ok(DatabaseCompiler {
116 pool: Cow::Owned(pool),
117 models,
118 })
119 }
120
121 pub fn from_pool(pool: &'pool PgPool, models: Vec<Model>) -> DatabaseCompiler<'pool> {
122 DatabaseCompiler {
123 pool: Cow::Borrowed(pool),
124 models,
125 }
126 }
127
128 pub async fn compile(&self) -> Result<String, Error> {
129 let mut sql = String::new();
130
131 for (_, table) in self.database_tables() {
132 let db_columns = self.fetch_table(table).await?;
133
134 match db_columns {
135 Some(db_columns) => {
136 writeln!(sql, "{}", self.write_sync_sql(table, &db_columns).await).unwrap();
137 }
138 None => {
139 writeln!(sql, "{}", self.write_table_create_sql(table)).unwrap();
140 }
141 }
142 }
143
144 Ok(sql.trim().to_string())
145 }
146
147 pub fn compile_generated_code(&self) -> String {
149 let mut code = String::new();
150
151 for (model, table) in self.database_tables() {
152 let ident = format_ident!("{}", model.name);
153 let db_module_ident = format_ident!("{}", table.name);
154
155 let mut from_schema_fields = Vec::new();
156 let mut from_db_fields = Vec::new();
157
158 for field in &model.fields {
159 let field_ident = format_ident!("{}", field.name);
160
161 let ty = strip_ty_option(&field.ty);
162
163 if is_ty_vec(ty) {
164 from_schema_fields.push(
165 quote!(#field_ident: val.#field_ident.into_iter().map(|v| v.into()).collect()),
166 );
167 from_db_fields.push(
168 quote!(#field_ident: val.#field_ident.into_iter().map(|v| v.into()).collect()),
169 );
170 } else {
171 from_schema_fields.push(quote!(#field_ident: val.#field_ident.into()));
172 from_db_fields.push(quote!(#field_ident: val.#field_ident.into()));
173 }
174 }
175
176 let expanded = quote!(
177 impl ::std::convert::From<crate::#db_module_ident::Model> for ::schema::#ident {
178 #[allow(unused_variables)]
179 fn from(val: crate::#db_module_ident::Model) -> Self {
180 Self {
181 #( #from_schema_fields, )*
182 }
183 }
184 }
185
186 impl ::std::convert::From<::schema::#ident> for crate::#db_module_ident::Model {
187 #[allow(unused_variables)]
188 fn from(val: ::schema::#ident) -> Self {
189 Self {
190 #( #from_db_fields, )*
191 }
192 }
193 }
194 );
195
196 write!(code, "{}", expanded.to_string()).unwrap();
197 }
198
199 for (model, table) in self.database_sub_tables() {
200 let ident = format_ident!("{}", model.name);
201 let db_module_ident = format_ident!("{}", table.name);
202
203 let active_values = model.fields.iter().map(|field| {
204 let field_ident = format_ident!("{}", field.name);
205
206 let self_field = if is_ty_option(&field.ty) {
207 let db_field = table.columns.iter().find(|column| column.name == field.name).unwrap();
208 match &db_field.default {
209 Some(DatabaseDefault::Bool(b)) => quote!(self.#field_ident.unwrap_or(#b)),
210 Some(DatabaseDefault::Float(f)) => {
211 let f = Literal::i64_unsuffixed(*f);
212 quote!(self.#field_ident.unwrap_or(#f))
213 },
214 Some(DatabaseDefault::Int(i)) => {
215 let i = Literal::u64_unsuffixed(*i);
216 quote!(self.#field_ident.unwrap_or(#i))
217 },
218 Some(DatabaseDefault::String(s)) => quote!(self.#field_ident.unwrap_or(#s)),
219 _ => quote!(self.#field_ident),
220 }
221 } else {
222 quote!(self.#field_ident)
223 };
224
225 quote!(
226 #field_ident: ::sea_orm::entity::IntoActiveValue::into_active_value(#self_field).into()
227 )
228 });
229
230 let expanded = quote!(
231 impl ::sea_orm::entity::IntoActiveModel<crate::#db_module_ident::ActiveModel> for ::schema::#ident {
232 fn into_active_model(self) -> crate::#db_module_ident::ActiveModel {
233 crate::#db_module_ident::ActiveModel {
234 #( #active_values, )*
235 ..Default::default()
236 }
237 }
238 }
239 );
240
241 write!(code, "{}", expanded.to_string()).unwrap();
242 }
243
244 code.trim().to_string()
245 }
246
/// Appends one `pub mod <table> { sea_orm::include_model!("<table>");}`
/// declaration per database table to `$OUT_DIR/app.rs`.
///
/// The file is created when it does not exist yet. All declarations are
/// written with a single `write_all` and the file is flushed before returning,
/// so the data is not lost when the handle is dropped.
///
/// # Errors
///
/// Returns an `io::Error` when `OUT_DIR` is not set or when opening, writing
/// or flushing the file fails.
#[cfg(feature = "async")]
async fn append_sea_orm_models(&self) -> Result<(), io::Error> {
    use tokio::fs;
    use tokio::io::AsyncWriteExt;

    let out_dir = env::var("OUT_DIR")
        .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
    let rs_path = format!("{}/{}", out_dir, COMPILED_RUST_FILE);

    let mut modules = String::new();
    for (_, table) in self.database_tables() {
        modules.push_str(&format!("pub mod {} {{", table.name));
        modules.push_str(&format!(r#" sea_orm::include_model!("{}");"#, table.name));
        modules.push('}');
    }

    let mut file = fs::OpenOptions::new()
        .create(true)
        .append(true)
        .open(&rs_path)
        .await?;

    // `write` may accept only part of the buffer; `write_all` does not.
    file.write_all(modules.as_bytes()).await?;
    // Make sure the pending write has completed before the handle is dropped.
    file.flush().await?;

    Ok(())
}
265
266 #[cfg(not(feature = "async"))]
267 fn append_sea_orm_models(&self) -> Result<(), io::Error> {
268 use std::fs;
269 use std::io::Write;
270
271 let out_dir = env::var("OUT_DIR").unwrap();
272 let rs_path = format!("{}/{}", out_dir, COMPILED_RUST_FILE);
273
274 let mut file = fs::OpenOptions::new().append(true).open(&rs_path)?;
275
276 for (_, table) in self.database_tables() {
277 write!(file, "pub mod {} {{", table.name).unwrap();
278 write!(file, r#" sea_orm::include_model!("{}");"#, table.name).unwrap();
279 write!(file, "}}").unwrap();
280 }
281
282 Ok(())
283 }
284
285 fn database_tables(&self) -> Vec<(&Model, &DatabaseTable)> {
286 self.models.iter().fold(Vec::new(), |mut acc, model| {
287 let roles = model
288 .roles
289 .iter()
290 .filter_map(|role| match role {
291 Role::DatabaseTable(database_table) => Some((model, database_table)),
292 _ => None,
293 })
294 .collect::<Vec<_>>();
295
296 acc.extend(roles);
297
298 acc
299 })
300 }
301
302 fn database_sub_tables(&self) -> Vec<(&Model, &DatabaseTable)> {
303 self.models.iter().fold(Vec::new(), |mut acc, model| {
304 let roles = model
305 .roles
306 .iter()
307 .filter_map(|role| match role {
308 Role::DatabaseSubTable(database_sub_table) => Some((model, database_sub_table)),
309 _ => None,
310 })
311 .collect::<Vec<_>>();
312
313 acc.extend(roles);
314
315 acc
316 })
317 }
318
319 async fn fetch_table(
320 &self,
321 table: &DatabaseTable,
322 ) -> Result<Option<Vec<DatabaseColumn>>, Error> {
323 #[derive(Debug, sqlx::FromRow)]
324 struct ColumnsQuery {
325 column_name: String,
326 column_default: Option<String>,
327 is_nullable: String,
328 data_type: String,
329 character_maximum_length: Option<i32>,
330 is_primary_key: bool,
331 is_unique: bool,
332 reference: Option<String>,
333 }
334
335 let raw_columns: Vec<ColumnsQuery> = sqlx::query_as(FETCH_TABLE_QUERY)
336 .bind("public")
337 .bind(&table.name)
338 .fetch_all(&*self.pool)
339 .await
340 .map_err(Error::Sqlx)?;
341
342 if raw_columns.is_empty() {
343 return Ok(None);
344 }
345
346 let columns: Vec<DatabaseColumn> = raw_columns
347 .into_iter()
348 .map(|col| {
349 let column_name = col.column_name;
350 let character_maximum_length = col.character_maximum_length;
351
352 Ok(DatabaseColumn {
353 name: column_name.clone(),
354 ty: col
355 .data_type
356 .parse::<DatabaseType>()
357 .map(|database_type| {
358 if let Some(max_len) = character_maximum_length {
359 if matches!(database_type, DatabaseType::Text(None)) {
360 return DatabaseType::Text(Some(max_len));
361 }
362 }
363
364 database_type
365 })
366 .map_err(|_| Error::UnsupportedType(table.name.clone(), column_name))?,
367 nullable: col.is_nullable == "YES",
368 default: col.column_default.map(|def| {
369 if def.starts_with('\'') {
370 let s = def
371 .strip_prefix('\'')
372 .unwrap()
373 .splitn(2, '\'')
374 .next()
375 .unwrap()
376 .to_string();
377 DatabaseDefault::String(s)
378 } else if def == "true" {
379 DatabaseDefault::Bool(true)
380 } else if def == "false" {
381 DatabaseDefault::Bool(false)
382 } else if let Ok(num) = def.parse::<u64>() {
383 DatabaseDefault::Int(num)
384 } else if let Ok(num) = def.parse::<i64>() {
385 DatabaseDefault::Float(num)
386 } else {
387 DatabaseDefault::Raw(def)
388 }
389 }),
390 unique: col.is_unique,
391 constraint: None,
392 primary_key: col.is_primary_key,
393 references: if let Some(references) = col.reference {
394 let mut parts = references.splitn(2, ':');
395 if let Some(references_table) = parts.next() {
396 parts.next().map(|references_column| {
397 (references_table.to_string(), references_column.to_string())
398 })
399 } else {
400 None
401 }
402 } else {
403 None
404 },
405 })
406 })
407 .collect::<Result<_, _>>()?;
408
409 Ok(Some(columns))
410 }
411
412 fn write_table_create_sql(&self, table: &DatabaseTable) -> String {
413 let mut sql = String::new();
414
415 writeln!(sql, "CREATE TABLE IF NOT EXISTS {} (", table.name).unwrap();
416
417 for (i, column) in table.columns.iter().enumerate() {
418 write!(sql, " {}", self.write_column_sql(column)).unwrap();
419
420 if i < table.columns.len() - 1 {
421 writeln!(sql, ",").unwrap();
422 } else {
423 writeln!(sql).unwrap();
424 }
425 }
426
427 writeln!(sql, ");").unwrap();
428
429 sql
430 }
431
432 fn write_column_sql(&self, column: &DatabaseColumn) -> String {
433 let mut sql = String::new();
434
435 write!(sql, "{} {}", column.name, column.ty,).unwrap();
436
437 if !column.nullable {
438 write!(sql, " NOT NULL",).unwrap();
439 }
440
441 if let Some(default) = &column.default {
442 write!(sql, " DEFAULT {}", default).unwrap();
443 }
444
445 if let Some(constraint) = &column.constraint {
446 write!(sql, " CHECK ({})", constraint).unwrap();
447 }
448
449 if column.primary_key {
450 write!(sql, " PRIMARY KEY").unwrap();
451 }
452
453 if let Some((table, col)) = &column.references {
454 write!(sql, " REFERENCES {}({})", table, col).unwrap();
455 }
456
457 sql
458 }
459
460 async fn write_sync_sql(&self, table: &DatabaseTable, db_columns: &[DatabaseColumn]) -> String {
461 let mut sql = String::new();
462
463 for schema_col in &table.columns {
464 let db_col = match db_columns
465 .iter()
466 .find(|db_col| db_col.name == schema_col.name)
467 {
468 Some(db_col) => db_col,
469 None => {
470 writeln!(
472 sql,
473 "ALTER TABLE {} ADD COLUMN {};",
474 table.name,
475 self.write_column_sql(schema_col)
476 )
477 .unwrap();
478 continue;
479 }
480 };
481
482 if schema_col.ty != db_col.ty {
484 writeln!(
485 sql,
486 "ALTER TABLE {table} ALTER COLUMN {column} TYPE {ty} USING {column}::{ty};",
487 table = table.name,
488 column = schema_col.name,
489 ty = schema_col.ty.to_string(),
490 )
491 .unwrap();
492 }
493
494 if schema_col.nullable != db_col.nullable {
496 if db_col.nullable {
497 writeln!(
498 sql,
499 "ALTER TABLE {table} ALTER COLUMN {column} SET NOT NULL;",
500 table = table.name,
501 column = schema_col.name
502 )
503 .unwrap();
504 } else {
505 writeln!(
506 sql,
507 "ALTER TABLE {table} ALTER COLUMN {column} DROP NOT NULL;",
508 table = table.name,
509 column = schema_col.name
510 )
511 .unwrap();
512 }
513 }
514
515 if schema_col.default != db_col.default {
517 if let Some(default) = &schema_col.default {
518 writeln!(
519 sql,
520 "ALTER TABLE {table} ALTER COLUMN {column} SET DEFAULT {default};",
521 table = table.name,
522 column = schema_col.name,
523 default = default
524 )
525 .unwrap();
526 } else {
527 writeln!(
528 sql,
529 "ALTER TABLE {table} ALTER COLUMN {column} DROP DEFAULT;",
530 table = table.name,
531 column = schema_col.name
532 )
533 .unwrap();
534 }
535 }
536
537 if schema_col.unique != db_col.unique {
539 if db_col.unique {
540 writeln!(
541 sql,
542 "ALTER TABLE {table} DROP CONSTRAINT {table}_{column}_key;",
543 table = table.name,
544 column = schema_col.name
545 )
546 .unwrap();
547 } else {
548 writeln!(
549 sql,
550 "ALTER TABLE {table} ADD CONSTRAINT {table}_{column}_key UNIQUE ({column});",
551 table = table.name,
552 column = schema_col.name
553 )
554 .unwrap();
555 }
556 }
557
558 if schema_col.references != db_col.references {
560 if let Some(references) = &schema_col.references {
561 if db_col.references.is_some() {
562 writeln!(
563 sql,
564 "ALTER TABLE {table} DROP CONSTRAINT {table}_{column}_fkey;",
565 table = table.name,
566 column = schema_col.name
567 )
568 .unwrap();
569 }
570 writeln!(
571 sql,
572 "ALTER TABLE {table} ADD CONSTRAINT {table}_{column}_fkey FOREIGN KEY ({column}) REFERENCES {reference_table} ({reference_column});",
573 table = table.name,
574 column = schema_col.name,
575 reference_table = references.0,
576 reference_column = references.1,
577 )
578 .unwrap();
579 } else {
580 writeln!(
581 sql,
582 "ALTER TABLE {table} DROP CONSTRAINT {table}_{column}_fkey;",
583 table = table.name,
584 column = schema_col.name
585 )
586 .unwrap();
587 }
588 }
589 }
590
591 db_columns
593 .iter()
594 .filter(|db_col| {
595 table
596 .columns
597 .iter()
598 .all(|schema_col| schema_col.name != db_col.name)
599 })
600 .for_each(|db_col| {
601 writeln!(
602 sql,
603 "ALTER TABLE {table} DROP COLUMN {column};",
604 table = table.name,
605 column = db_col.name
606 )
607 .unwrap();
608 });
609
610 sql
611 }
612}
613
/// Query returning one row per column of a table.
///
/// Bind parameters: `$1` is the schema name, `$2` is the table name.
///
/// Besides the plain `information_schema.columns` fields, each row carries:
/// * `is_primary_key` — whether a `PRIMARY KEY` constraint covers the column,
/// * `is_unique` — whether a `UNIQUE` constraint covers the column,
/// * `reference` — `table:column` of the foreign key target, or `NULL`.
///
/// All three constraint lookups are restricted to the schema given in `$1`,
/// so a table with the same name in another schema cannot leak into the result.
const FETCH_TABLE_QUERY: &str = "
SELECT column_name, column_default, is_nullable, data_type, character_maximum_length,
(
 SELECT
 COUNT(*) > 0
 FROM information_schema.table_constraints tco
 JOIN information_schema.key_column_usage kcu
 ON kcu.constraint_name = tco.constraint_name
 AND kcu.constraint_schema = tco.constraint_schema
 AND kcu.constraint_name = tco.constraint_name
 WHERE
 tco.constraint_type = 'PRIMARY KEY' AND
 kcu.table_schema = $1 AND
 kcu.table_name = $2 AND
 kcu.column_name = information_schema.columns.column_name
) as is_primary_key,
(
 SELECT
 COUNT(*) > 0
 FROM information_schema.table_constraints tco
 JOIN information_schema.key_column_usage kcu
 ON kcu.constraint_name = tco.constraint_name
 AND kcu.constraint_schema = tco.constraint_schema
 AND kcu.constraint_name = tco.constraint_name
 WHERE
 tco.constraint_type = 'UNIQUE' AND
 kcu.table_schema = $1 AND
 kcu.table_name = $2 AND
 kcu.column_name = information_schema.columns.column_name
) as is_unique,
(
 SELECT CONCAT(
 rel_tco.table_name,
 ':',
 (
 SELECT u.column_name
 FROM information_schema.constraint_column_usage u
 INNER JOIN information_schema.referential_constraints fk
 ON
 u.constraint_catalog = fk.unique_constraint_catalog AND
 u.constraint_schema = fk.unique_constraint_schema AND
 u.constraint_name = fk.unique_constraint_name
 INNER JOIN information_schema.key_column_usage r
 ON
 r.constraint_catalog = fk.constraint_catalog AND
 r.constraint_schema = fk.constraint_schema AND
 r.constraint_name = fk.constraint_name
 WHERE
 fk.constraint_name = kcu.constraint_name AND
 u.table_schema = kcu.table_schema AND
 u.table_name = rel_tco.table_name
 )
 )
 FROM information_schema.table_constraints tco
 JOIN information_schema.key_column_usage kcu
 ON
 tco.constraint_schema = kcu.constraint_schema AND
 tco.constraint_name = kcu.constraint_name
 JOIN information_schema.referential_constraints rco
 ON
 tco.constraint_schema = rco.constraint_schema AND
 tco.constraint_name = rco.constraint_name
 JOIN information_schema.table_constraints rel_tco
 ON
 rco.unique_constraint_schema = rel_tco.constraint_schema AND
 rco.unique_constraint_name = rel_tco.constraint_name
 WHERE
 tco.constraint_type = 'FOREIGN KEY' AND
 kcu.table_schema = $1 AND
 kcu.table_name = $2 AND
 kcu.column_name = information_schema.columns.column_name
 GROUP BY
 kcu.table_schema,
 kcu.table_name,
 rel_tco.table_name,
 rel_tco.table_schema,
 kcu.constraint_name
 ORDER BY
 kcu.table_schema,
 kcu.table_name
) as reference
FROM information_schema.columns
WHERE table_schema = $1
AND table_name = $2;
";