1use sqlx::{
2 any::AnyPoolOptions,
3 AnyPool, Error, Row,
4};
5use heck::ToSnakeCase;
6use crate::{model::Model, query_builder::QueryBuilder, migration::Migrator};
7
/// The SQL backend a [`Database`] is connected to.
///
/// Chosen from the connection URL's scheme by `Database::connect`; URLs whose
/// scheme is not recognised fall back to [`Drivers::SQLite`].
///
/// The enum is a plain tag, so it is `Copy` and comparable with `==`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Drivers {
    /// PostgreSQL.
    Postgres,
    /// SQLite (also the fallback for unrecognised URL schemes).
    SQLite,
    /// MySQL.
    MySQL,
}
18
19#[derive(Clone)]
23pub struct Database {
24 pub(crate) pool: AnyPool,
25 pub(crate) driver: Drivers,
26}
27
28impl Database {
29 pub async fn connect(url: &str) -> Result<Self, Error> {
43 sqlx::any::install_default_drivers();
44 let pool = AnyPoolOptions::new().max_connections(5).connect(url).await?;
45
46 let (driver_str, _) = url.split_once(":").unwrap_or(("sqlite", ""));
47 let driver = match driver_str {
48 "postgresql" | "postgres" => Drivers::Postgres,
49 "mysql" => Drivers::MySQL,
50 _ => Drivers::SQLite,
51 };
52
53 Ok(Self { pool, driver })
54 }
55
56 pub fn migrator(&self) -> Migrator<'_> {
58 Migrator::new(self)
59 }
60
61 pub fn model<T: Model + Send + Sync + Unpin>(&self) -> QueryBuilder<'_, T> {
69 let active_columns = T::active_columns();
70 let mut columns: Vec<String> = Vec::with_capacity(active_columns.capacity());
71 for col in active_columns {
72 columns.push(col.strip_prefix("r#").unwrap_or(col).to_snake_case());
73 }
74
75 QueryBuilder::new(self, T::table_name(), T::columns(), columns)
76 }
77
78 pub async fn create_table<T: Model>(&self) -> Result<&Self, Error> {
82 let table_name = T::table_name().to_snake_case();
83 let columns = T::columns();
84
85 let mut column_defs = Vec::new();
86 let mut index_statements = Vec::new();
87
88 for col in &columns {
89 let col_name = col.name.strip_prefix("r#").unwrap_or(col.name).to_snake_case();
90 let mut def = format!("\"{}\" {}", col_name, col.sql_type);
91
92 if col.is_primary_key {
93 def.push_str(" PRIMARY KEY");
94 }
95
96 if !col.is_nullable && !col.is_primary_key {
97 def.push_str(" NOT NULL");
98 }
99
100 if col.create_time {
101 def.push_str(" DEFAULT CURRENT_TIMESTAMP");
102 }
103
104 if col.unique {
105 def.push_str(" UNIQUE");
106 }
107
108 column_defs.push(def);
109
110 if col.index {
111 let index_type = if col.unique { "UNIQUE INDEX" } else { "INDEX" };
112 let index_name = format!("idx_{}_{}", table_name, col_name);
113
114 let index_query = format!(
115 "CREATE {} IF NOT EXISTS \"{}\" ON \"{}\" (\"{}\" )",
116 index_type, index_name, table_name, col_name,
117 );
118
119 index_statements.push(index_query);
120 }
121 }
122
123 let create_table_query =
124 format!("CREATE TABLE IF NOT EXISTS \"{}\" ({})", table_name.to_snake_case(), column_defs.join(", "));
125
126 sqlx::query(&create_table_query).execute(&self.pool).await?;
127 for idx_stmt in index_statements {
128 sqlx::query(&idx_stmt).execute(&self.pool).await?;
129 }
130 Ok(self)
131 }
132
133 pub async fn assign_foreign_keys<T: Model>(&self) -> Result<&Self, Error> {
138 let table_name = T::table_name().to_snake_case();
139 let columns = T::columns();
140
141 for col in columns {
142 if let (Some(f_table), Some(f_key)) = (col.foreign_table, col.foreign_key) {
143 let col_name = col.name.strip_prefix("r#").unwrap_or(col.name).to_snake_case();
144 let f_table_clean = f_table.to_snake_case();
145 let f_key_clean = f_key.to_snake_case();
146
147 let constraint_name = format!("fk_{}_{}", table_name, col_name);
148
149 let check_query =
150 "SELECT count(*) FROM information_schema.table_constraints WHERE constraint_name = $1";
151 let row = sqlx::query(check_query).bind(&constraint_name).fetch_one(&self.pool).await?;
152 let count: i64 = row.try_get(0).unwrap_or(0);
153 if count > 0 {
154 continue;
155 }
156
157 let alter_query = format!(
158 "ALTER TABLE \"{}\" ADD CONSTRAINT \"{}\" FOREIGN KEY (\"{}\") REFERENCES \"{}\" (\"{}\" )",
159 table_name, constraint_name, col_name, f_table_clean, f_key_clean
160 );
161
162 sqlx::query(&alter_query).execute(&self.pool).await?;
163 }
164 }
165
166 Ok(self)
167 }
168}