1use futures::future::BoxFuture;
12use heck::ToSnakeCase;
13use sqlx::{any::AnyArguments, Any, AnyPool, Row, Arguments};
14use std::sync::Arc;
15
16use crate::{migration::Migrator, Error, Model, QueryBuilder};
21
/// Database backend a [`Database`] handle talks to.
///
/// Chosen from the connection URL scheme and consulted whenever the SQL dialect differs
/// between backends (placeholders, catalog queries, identifier quoting).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Drivers {
    /// PostgreSQL.
    Postgres,
    /// MySQL (or a MySQL-compatible server).
    MySQL,
    /// SQLite.
    SQLite,
}
36
37#[derive(Debug, Clone)]
49pub struct Database {
50 pub(crate) pool: AnyPool,
52 pub(crate) driver: Drivers,
54}
55
56impl Database {
61 pub fn builder() -> DatabaseBuilder {
63 DatabaseBuilder::new()
64 }
65
66 pub async fn connect(url: &str) -> Result<Self, Error> {
68 DatabaseBuilder::new().connect(url).await
69 }
70
71 pub fn migrator(&self) -> Migrator {
73 Migrator::new(self)
74 }
75
76 pub fn model<T: Model + Send + Sync + Unpin>(&self) -> QueryBuilder<T, Self> {
78 let active_columns = T::active_columns();
79 let mut columns: Vec<String> = Vec::with_capacity(active_columns.capacity());
80
81 for col in active_columns {
82 columns.push(col.strip_prefix("r#").unwrap_or(col).to_snake_case());
83 }
84
85 QueryBuilder::new(self.clone(), self.driver, T::table_name(), T::columns(), columns)
86 }
87
88 pub fn raw<'a>(&self, sql: &'a str) -> RawQuery<'a, Self> {
90 RawQuery::new(self.clone(), sql)
91 }
92
93 pub async fn begin(&self) -> Result<crate::transaction::Transaction<'_>, Error> {
95 let tx = self.pool.begin().await?;
96 Ok(crate::transaction::Transaction {
97 tx: Arc::new(tokio::sync::Mutex::new(Some(tx))),
98 driver: self.driver,
99 })
100 }
101
102 pub async fn table_exists(&self, table_name: &str) -> Result<bool, Error> {
104 let table_name_snake = table_name.to_snake_case();
105 let query = match self.driver {
106 Drivers::Postgres => {
107 "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = $1 AND table_schema = 'public')".to_string()
108 }
109 Drivers::MySQL => {
110 "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = ? AND table_schema = DATABASE())".to_string()
111 }
112 Drivers::SQLite => {
113 "SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?".to_string()
114 }
115 };
116
117 let row = sqlx::query(&query).bind(&table_name_snake).fetch_one(&self.pool).await?;
118
119 match self.driver {
120 Drivers::SQLite => {
121 let count: i64 = row.try_get(0)?;
122 Ok(count > 0)
123 }
124 _ => {
125 let exists: bool = row.try_get(0)?;
126 Ok(exists)
127 }
128 }
129 }
130
131 pub async fn create_table<T: Model>(&self) -> Result<(), Error> {
133 let table_name = T::table_name().to_snake_case();
134 let columns = T::columns();
135
136 let mut query = format!("CREATE TABLE IF NOT EXISTS \"{}\" (", table_name);
137 let mut column_defs = Vec::new();
138 let mut indexes = Vec::new();
139
140 for col in columns {
141 let col_name_clean = col.name.strip_prefix("r#").unwrap_or(col.name).to_snake_case();
142 let mut def = format!("\"{}\" {}", col_name_clean, col.sql_type);
143
144 if col.is_primary_key {
145 def.push_str(" PRIMARY KEY");
146 } else if !col.is_nullable {
147 def.push_str(" NOT NULL");
148 }
149
150 if col.unique && !col.is_primary_key {
151 def.push_str(" UNIQUE");
152 }
153
154 if col.index && !col.is_primary_key && !col.unique {
155 indexes.push(format!(
156 "CREATE INDEX IF NOT EXISTS \"idx_{}_{}\" ON \"{}\" (\"{}\")",
157 table_name, col_name_clean, table_name, col_name_clean
158 ));
159 }
160
161 column_defs.push(def);
162 }
163
164 query.push_str(&column_defs.join(", "));
165 query.push(')');
166
167 sqlx::query(&query).execute(&self.pool).await?;
168
169 for idx_query in indexes {
170 sqlx::query(&idx_query).execute(&self.pool).await?;
171 }
172
173 Ok(())
174 }
175
176 pub async fn sync_table<T: Model>(&self) -> Result<(), Error> {
178 if !self.table_exists(T::table_name()).await? {
179 return self.create_table::<T>().await;
180 }
181
182 let table_name = T::table_name().to_snake_case();
183 let model_columns = T::columns();
184 let existing_columns = self.get_table_columns(&table_name).await?;
185
186 for col in model_columns {
187 let col_name_clean = col.name.strip_prefix("r#").unwrap_or(col.name).to_snake_case();
188 if !existing_columns.contains(&col_name_clean) {
189 let mut alter_query = format!("ALTER TABLE \"{}\" ADD COLUMN \"{}\" {}", table_name, col_name_clean, col.sql_type);
190 if !col.is_nullable {
191 alter_query.push_str(" DEFAULT ");
192 match col.sql_type {
193 "INTEGER" | "INT" | "BIGINT" => alter_query.push('0'),
194 "BOOLEAN" | "BOOL" => alter_query.push_str("FALSE"),
195 _ => alter_query.push_str("''"),
196 }
197 }
198 sqlx::query(&alter_query).execute(&self.pool).await?;
199 }
200
201 if col.index || col.unique {
202 let existing_indexes = self.get_table_indexes(&table_name).await?;
203 let idx_name = format!("idx_{}_{}", table_name, col_name_clean);
204 let uniq_name = format!("unique_{}_{}", table_name, col_name_clean);
205
206 if col.unique && !existing_indexes.contains(&uniq_name) {
207 let mut query = format!("CREATE UNIQUE INDEX \"{}\" ON \"{}\" (\"{}\")", uniq_name, table_name, col_name_clean);
208 if matches!(self.driver, Drivers::SQLite) {
209 query = format!("CREATE UNIQUE INDEX IF NOT EXISTS \"{}\" ON \"{}\" (\"{}\")", uniq_name, table_name, col_name_clean);
210 }
211 sqlx::query(&query).execute(&self.pool).await?;
212 } else if col.index && !existing_indexes.contains(&idx_name) && !col.unique {
213 let mut query = format!("CREATE INDEX \"{}\" ON \"{}\" (\"{}\")", idx_name, table_name, col_name_clean);
214 if matches!(self.driver, Drivers::SQLite) {
215 query = format!("CREATE INDEX IF NOT EXISTS \"{}\" ON \"{}\" (\"{}\")", idx_name, table_name, col_name_clean);
216 }
217 sqlx::query(&query).execute(&self.pool).await?;
218 }
219 }
220 }
221
222 Ok(())
223 }
224
225 pub async fn get_table_columns(&self, table_name: &str) -> Result<Vec<String>, Error> {
227 let table_name_snake = table_name.to_snake_case();
228 let query = match self.driver {
229 Drivers::Postgres => "SELECT column_name::TEXT FROM information_schema.columns WHERE table_name = $1 AND table_schema = 'public'".to_string(),
230 Drivers::MySQL => "SELECT column_name FROM information_schema.columns WHERE table_name = ? AND table_schema = DATABASE()".to_string(),
231 Drivers::SQLite => format!("PRAGMA table_info(\"{}\")", table_name_snake),
232 };
233
234 let rows = if let Drivers::SQLite = self.driver {
235 sqlx::query(&query).fetch_all(&self.pool).await?
236 } else {
237 sqlx::query(&query).bind(&table_name_snake).fetch_all(&self.pool).await?
238 };
239
240 let mut columns = Vec::new();
241 for row in rows {
242 let col_name: String = if let Drivers::SQLite = self.driver {
243 row.try_get("name")?
244 } else {
245 row.try_get(0)?
246 };
247 columns.push(col_name);
248 }
249 Ok(columns)
250 }
251
252 pub async fn get_table_indexes(&self, table_name: &str) -> Result<Vec<String>, Error> {
254 let table_name_snake = table_name.to_snake_case();
255 let query = match self.driver {
256 Drivers::Postgres => "SELECT indexname::TEXT FROM pg_indexes WHERE tablename = $1 AND schemaname = 'public'".to_string(),
257 Drivers::MySQL => "SELECT INDEX_NAME FROM information_schema.STATISTICS WHERE TABLE_NAME = ? AND TABLE_SCHEMA = DATABASE()".to_string(),
258 Drivers::SQLite => format!("PRAGMA index_list(\"{}\")", table_name_snake),
259 };
260
261 let rows = if let Drivers::SQLite = self.driver {
262 sqlx::query(&query).fetch_all(&self.pool).await?
263 } else {
264 sqlx::query(&query).bind(&table_name_snake).fetch_all(&self.pool).await?
265 };
266
267 let mut indexes = Vec::new();
268 for row in rows {
269 let idx_name: String = if let Drivers::SQLite = self.driver {
270 row.try_get("name")?
271 } else {
272 row.try_get(0)?
273 };
274 indexes.push(idx_name);
275 }
276 Ok(indexes)
277 }
278
279 pub async fn assign_foreign_keys<T: Model>(&self) -> Result<(), Error> {
281 let table_name = T::table_name().to_snake_case();
282 let columns = T::columns();
283
284 for col in columns {
285 if let (Some(f_table), Some(f_key)) = (col.foreign_table, col.foreign_key) {
286 if matches!(self.driver, Drivers::SQLite) { continue; }
287 let constraint_name = format!("fk_{}_{}_{}", table_name, f_table.to_snake_case(), col.name.to_snake_case());
288 let query = format!(
289 "ALTER TABLE \"{}\" ADD CONSTRAINT \"{}\" FOREIGN KEY (\"{}\") REFERENCES \"{}\"(\"{}\")",
290 table_name, constraint_name, col.name.to_snake_case(), f_table.to_snake_case(), f_key.to_snake_case()
291 );
292 let _ = sqlx::query(&query).execute(&self.pool).await;
293 }
294 }
295 Ok(())
296 }
297}
298
/// Configures a connection pool before opening a [`Database`].
#[derive(Debug, Clone)]
pub struct DatabaseBuilder {
    /// Upper bound on pooled connections.
    max_connections: u32,
}
306
307impl DatabaseBuilder {
308 pub fn new() -> Self { Self { max_connections: 5 } }
309 pub fn max_connections(mut self, max: u32) -> Self { self.max_connections = max; self }
310 pub async fn connect(self, url: &str) -> Result<Database, Error> {
311 let pool = sqlx::any::AnyPoolOptions::new().max_connections(self.max_connections).connect(url).await?;
312 let driver = if url.starts_with("postgres") { Drivers::Postgres }
313 else if url.starts_with("mysql") { Drivers::MySQL }
314 else { Drivers::SQLite };
315 Ok(Database { pool, driver })
316 }
317}
318
319pub trait Connection: Send + Sync {
324 fn execute<'a, 'q: 'a>(&'a self, sql: &'q str, args: AnyArguments<'q>) -> BoxFuture<'a, Result<sqlx::any::AnyQueryResult, sqlx::Error>>;
325 fn fetch_all<'a, 'q: 'a>(&'a self, sql: &'q str, args: AnyArguments<'q>) -> BoxFuture<'a, Result<Vec<sqlx::any::AnyRow>, sqlx::Error>>;
326 fn fetch_one<'a, 'q: 'a>(&'a self, sql: &'q str, args: AnyArguments<'q>) -> BoxFuture<'a, Result<sqlx::any::AnyRow, sqlx::Error>>;
327 fn fetch_optional<'a, 'q: 'a>(&'a self, sql: &'q str, args: AnyArguments<'q>) -> BoxFuture<'a, Result<Option<sqlx::any::AnyRow>, sqlx::Error>>;
328}
329
330impl Connection for Database {
331 fn execute<'a, 'q: 'a>(&'a self, sql: &'q str, args: AnyArguments<'q>) -> BoxFuture<'a, Result<sqlx::any::AnyQueryResult, sqlx::Error>> {
332 Box::pin(async move { sqlx::query_with(sql, args).execute(&self.pool).await })
333 }
334 fn fetch_all<'a, 'q: 'a>(&'a self, sql: &'q str, args: AnyArguments<'q>) -> BoxFuture<'a, Result<Vec<sqlx::any::AnyRow>, sqlx::Error>> {
335 Box::pin(async move { sqlx::query_with(sql, args).fetch_all(&self.pool).await })
336 }
337 fn fetch_one<'a, 'q: 'a>(&'a self, sql: &'q str, args: AnyArguments<'q>) -> BoxFuture<'a, Result<sqlx::any::AnyRow, sqlx::Error>> {
338 Box::pin(async move { sqlx::query_with(sql, args).fetch_one(&self.pool).await })
339 }
340 fn fetch_optional<'a, 'q: 'a>(&'a self, sql: &'q str, args: AnyArguments<'q>) -> BoxFuture<'a, Result<Option<sqlx::any::AnyRow>, sqlx::Error>> {
341 Box::pin(async move { sqlx::query_with(sql, args).fetch_optional(&self.pool).await })
342 }
343}
344
345pub struct RawQuery<'a, C> {
350 conn: C,
351 sql: &'a str,
352 args: AnyArguments<'a>,
353}
354
355impl<'a, C> RawQuery<'a, C> where C: Connection {
356 pub(crate) fn new(conn: C, sql: &'a str) -> Self {
357 Self { conn, sql, args: AnyArguments::default() }
358 }
359 pub fn bind<T>(mut self, value: T) -> Self where T: 'a + sqlx::Encode<'a, sqlx::Any> + sqlx::Type<sqlx::Any> + Send + Sync {
360 let _ = self.args.add(value);
361 self
362 }
363 pub async fn fetch_all<T>(self) -> Result<Vec<T>, Error> where T: for<'r> sqlx::FromRow<'r, sqlx::any::AnyRow> + Send + Unpin {
364 let rows = self.conn.fetch_all(self.sql, self.args).await?;
365 Ok(rows.iter().map(|r| T::from_row(r).unwrap()).collect())
366 }
367 pub async fn fetch_one<T>(self) -> Result<T, Error> where T: for<'r> sqlx::FromRow<'r, sqlx::any::AnyRow> + Send + Unpin {
368 let row = self.conn.fetch_one(self.sql, self.args).await?;
369 Ok(T::from_row(&row)?)
370 }
371 pub async fn fetch_optional<T>(self) -> Result<Option<T>, Error> where T: for<'r> sqlx::FromRow<'r, sqlx::any::AnyRow> + Send + Unpin {
372 let row = self.conn.fetch_optional(self.sql, self.args).await?;
373 Ok(row.map(|r| T::from_row(&r).unwrap()))
374 }
375 pub async fn execute(self) -> Result<u64, Error> {
376 let result = self.conn.execute(self.sql, self.args).await?;
377 Ok(result.rows_affected())
378 }
379}