Skip to main content

bottle_orm/
database.rs

1//! # Database Module
2//!
3//! This module provides the core database connection and management functionality for Bottle ORM.
4//! It handles connection pooling, driver detection, table creation, and foreign key management
5//! across PostgreSQL, MySQL, and SQLite.
6
7// ============================================================================
8// External Crate Imports
9// ============================================================================
10
11use futures::future::BoxFuture;
12use heck::ToSnakeCase;
13use sqlx::{any::AnyArguments, AnyPool, Row, Arguments};
14use std::sync::Arc;
15
16// ============================================================================
17// Internal Crate Imports
18// ============================================================================
19
20use crate::{migration::Migrator, Error, Model, QueryBuilder};
21
22// ============================================================================
23// Database Driver Enum
24// ============================================================================
25
/// Identifies which SQL backend a [`Database`] is connected to.
///
/// The active variant selects the SQL dialect Bottle ORM generates
/// (placeholder style, catalog queries, DDL quirks).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Drivers {
    /// PostgreSQL backend.
    Postgres,
    /// MySQL backend.
    MySQL,
    /// SQLite backend.
    SQLite,
}
36
37// ============================================================================
38// Database Struct
39// ============================================================================
40
41/// The main entry point for Bottle ORM database operations.
42///
43/// `Database` manages a connection pool and provides methods for starting
44/// transactions, creating tables, and building queries for models.
45///
46/// It is designed to be thread-safe and easily shared across an application
47/// (internally uses an `Arc` for the connection pool).
48#[derive(Debug, Clone)]
49pub struct Database {
50    /// The underlying SQLx connection pool
51    pub(crate) pool: AnyPool,
52    /// The detected database driver
53    pub(crate) driver: Drivers,
54}
55
56// ============================================================================
57// Database Implementation
58// ============================================================================
59
60impl Database {
61    /// Creates a new DatabaseBuilder for configuring the connection.
62    pub fn builder() -> DatabaseBuilder {
63        DatabaseBuilder::new()
64    }
65
66    /// Connects to a database using the provided connection string.
67    ///
68    /// This is a convenience method that uses default builder settings.
69    ///
70    /// # Arguments
71    ///
72    /// * `url` - A database connection URL (e.g., "postgres://user:pass@localhost/db")
73    pub async fn connect(url: &str) -> Result<Self, Error> {
74        DatabaseBuilder::new().connect(url).await
75    }
76
77    /// Returns a new Migrator instance for managing schema changes.
78    pub fn migrator(&self) -> Migrator<'_> {
79        Migrator::new(self)
80    }
81
82    /// Starts building a query for the specified model.
83    ///
84    /// # Type Parameters
85    ///
86    /// * `T` - The Model type to query.
87    pub fn model<T: Model + Send + Sync + Unpin + crate::AnyImpl>(&self) -> QueryBuilder<T, Self> {
88        let active_columns = T::active_columns();
89        let mut columns: Vec<String> = Vec::with_capacity(active_columns.capacity());
90
91        for col in active_columns {
92            columns.push(col.strip_prefix("r#").unwrap_or(col).to_snake_case());
93        }
94
95        QueryBuilder::new(self.clone(), self.driver, T::table_name(), <T as Model>::columns(), columns)
96    }
97
98    /// Creates a raw SQL query builder.
99    pub fn raw<'a>(&self, sql: &'a str) -> RawQuery<'a, Self> {
100        RawQuery::new(self.clone(), sql)
101    }
102    
103    /// This function should have been here a long time ago.
104    /// Retrieve the connection pool.
105    pub fn get_pool(&self) -> AnyPool {
106    	self.pool.clone()
107    }
108
109    /// Starts a new database transaction.
110    pub async fn begin(&self) -> Result<crate::transaction::Transaction<'_>, Error> {
111        let tx = self.pool.begin().await?;
112        Ok(crate::transaction::Transaction {
113            tx: Arc::new(tokio::sync::Mutex::new(Some(tx))),
114            pool: self.pool.clone(),
115            driver: self.driver,
116        })
117    }
118
119    /// Checks if a table exists in the database.
120    pub async fn table_exists(&self, table_name: &str) -> Result<bool, Error> {
121        let table_name_snake = table_name.to_snake_case();
122        let query = match self.driver {
123            Drivers::Postgres => {
124                "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = $1 AND table_schema = 'public')".to_string()
125            }
126            Drivers::MySQL => {
127                "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = ? AND table_schema = DATABASE())".to_string()
128            }
129            Drivers::SQLite => {
130                "SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?".to_string()
131            }
132        };
133
134        let row = sqlx::query(&query).bind(&table_name_snake).fetch_one(&self.pool).await?;
135
136        match self.driver {
137            Drivers::SQLite => {
138                let count: i64 = row.try_get(0)?;
139                Ok(count > 0)
140            }
141            _ => {
142                let exists: bool = row.try_get(0)?;
143                Ok(exists)
144            }
145        }
146    }
147
148    /// Creates a table based on the provided Model metadata.
149    pub async fn create_table<T: Model>(&self) -> Result<(), Error> {
150        let table_name = T::table_name().to_snake_case();
151        let columns = T::columns();
152
153        let mut query = format!("CREATE TABLE IF NOT EXISTS \"{}\" (", table_name);
154        let mut column_defs = Vec::new();
155        let mut indexes = Vec::new();
156
157        // Identify primary key columns
158        let pk_columns: Vec<String> = columns.iter()
159            .filter(|c| c.is_primary_key)
160            .map(|c| format!("\"{}\"", c.name.strip_prefix("r#").unwrap_or(c.name).to_snake_case()))
161            .collect();
162
163        for col in columns {
164            let col_name_clean = col.name.strip_prefix("r#").unwrap_or(col.name).to_snake_case();
165            let mut def = format!("\"{}\" {}", col_name_clean, col.sql_type);
166
167            // If it's a single primary key, we can keep it inline for simplicity
168            // If it's composite, we MUST define it as a table constraint
169            if col.is_primary_key && pk_columns.len() == 1 {
170                def.push_str(" PRIMARY KEY");
171            } else if !col.is_nullable || col.is_primary_key {
172                def.push_str(" NOT NULL");
173            }
174
175            if col.unique && !col.is_primary_key {
176                def.push_str(" UNIQUE");
177            }
178
179            if col.index && !col.is_primary_key && !col.unique {
180                indexes.push(format!(
181                    "CREATE INDEX IF NOT EXISTS \"idx_{}_{}\" ON \"{}\" (\"{}\")",
182                    table_name, col_name_clean, table_name, col_name_clean
183                ));
184            }
185
186            column_defs.push(def);
187        }
188
189        // Add composite primary key if multiple columns are specified
190        if pk_columns.len() > 1 {
191            column_defs.push(format!("PRIMARY KEY ({})", pk_columns.join(", ")));
192        }
193
194        query.push_str(&column_defs.join(", "));
195        query.push(')');
196
197        sqlx::query(&query).execute(&self.pool).await?;
198
199        for idx_query in indexes {
200            sqlx::query(&idx_query).execute(&self.pool).await?;
201        }
202
203        Ok(())
204    }
205
206    /// Synchronizes a table schema by adding missing columns or indexes.
207    pub async fn sync_table<T: Model>(&self) -> Result<(), Error> {
208        if !self.table_exists(T::table_name()).await? {
209            return self.create_table::<T>().await;
210        }
211
212        let table_name = T::table_name().to_snake_case();
213        let model_columns = T::columns();
214        let existing_columns = self.get_table_columns(&table_name).await?;
215
216        for col in model_columns {
217            let col_name_clean = col.name.strip_prefix("r#").unwrap_or(col.name).to_snake_case();
218            if !existing_columns.contains(&col_name_clean) {
219                let mut alter_query = format!("ALTER TABLE \"{}\" ADD COLUMN \"{}\" {}", table_name, col_name_clean, col.sql_type);
220                if !col.is_nullable {
221                    alter_query.push_str(" DEFAULT ");
222                    match col.sql_type {
223                        "INTEGER" | "INT" | "BIGINT" => alter_query.push('0'),
224                        "BOOLEAN" | "BOOL" => alter_query.push_str("FALSE"),
225                        _ => alter_query.push_str("''"),
226                    }
227                }
228                sqlx::query(&alter_query).execute(&self.pool).await?;
229            }
230
231            if col.index || col.unique {
232                let existing_indexes = self.get_table_indexes(&table_name).await?;
233                let idx_name = format!("idx_{}_{}", table_name, col_name_clean);
234                let uniq_name = format!("unique_{}_{}", table_name, col_name_clean);
235
236                if col.unique && !existing_indexes.contains(&uniq_name) {
237                    let mut query = format!("CREATE UNIQUE INDEX \"{}\" ON \"{}\" (\"{}\")", uniq_name, table_name, col_name_clean);
238                    if matches!(self.driver, Drivers::SQLite) {
239                        query = format!("CREATE UNIQUE INDEX IF NOT EXISTS \"{}\" ON \"{}\" (\"{}\")", uniq_name, table_name, col_name_clean);
240                    }
241                    sqlx::query(&query).execute(&self.pool).await?;
242                } else if col.index && !existing_indexes.contains(&idx_name) && !col.unique {
243                    let mut query = format!("CREATE INDEX \"{}\" ON \"{}\" (\"{}\")", idx_name, table_name, col_name_clean);
244                    if matches!(self.driver, Drivers::SQLite) {
245                        query = format!("CREATE INDEX IF NOT EXISTS \"{}\" ON \"{}\" (\"{}\")", idx_name, table_name, col_name_clean);
246                    }
247                    sqlx::query(&query).execute(&self.pool).await?;
248                }
249            }
250        }
251
252        Ok(())
253    }
254
255    /// Returns the current columns of a table.
256    pub async fn get_table_columns(&self, table_name: &str) -> Result<Vec<String>, Error> {
257        let table_name_snake = table_name.to_snake_case();
258        let query = match self.driver {
259            Drivers::Postgres => "SELECT column_name::TEXT FROM information_schema.columns WHERE table_name = $1 AND table_schema = 'public'".to_string(),
260            Drivers::MySQL => "SELECT column_name FROM information_schema.columns WHERE table_name = ? AND table_schema = DATABASE()".to_string(),
261            Drivers::SQLite => format!("PRAGMA table_info(\"{}\")", table_name_snake),
262        };
263
264        let rows = if let Drivers::SQLite = self.driver {
265            sqlx::query(&query).fetch_all(&self.pool).await?
266        } else {
267            sqlx::query(&query).bind(&table_name_snake).fetch_all(&self.pool).await?
268        };
269
270        let mut columns = Vec::new();
271        for row in rows {
272            let col_name: String = if let Drivers::SQLite = self.driver {
273                row.try_get("name")?
274            } else {
275                row.try_get(0)?
276            };
277            columns.push(col_name);
278        }
279        Ok(columns)
280    }
281
282    /// Returns the current indexes of a table.
283    pub async fn get_table_indexes(&self, table_name: &str) -> Result<Vec<String>, Error> {
284        let table_name_snake = table_name.to_snake_case();
285        let query = match self.driver {
286            Drivers::Postgres => "SELECT indexname::TEXT FROM pg_indexes WHERE tablename = $1 AND schemaname = 'public'".to_string(),
287            Drivers::MySQL => "SELECT INDEX_NAME FROM information_schema.STATISTICS WHERE TABLE_NAME = ? AND TABLE_SCHEMA = DATABASE()".to_string(),
288            Drivers::SQLite => format!("PRAGMA index_list(\"{}\")", table_name_snake),
289        };
290
291        let rows = if let Drivers::SQLite = self.driver {
292            sqlx::query(&query).fetch_all(&self.pool).await?
293        } else {
294            sqlx::query(&query).bind(&table_name_snake).fetch_all(&self.pool).await?
295        };
296
297        let mut indexes = Vec::new();
298        for row in rows {
299            let idx_name: String = if let Drivers::SQLite = self.driver {
300                row.try_get("name")?
301            } else {
302                row.try_get(0)?
303            };
304            indexes.push(idx_name);
305        }
306        Ok(indexes)
307    }
308
309    /// Assigns foreign keys to a table.
310    pub async fn assign_foreign_keys<T: Model>(&self) -> Result<(), Error> {
311        let table_name = T::table_name().to_snake_case();
312        let columns = T::columns();
313
314        for col in columns {
315            if let (Some(f_table), Some(f_key)) = (col.foreign_table, col.foreign_key) {
316                if matches!(self.driver, Drivers::SQLite) { continue; }
317                let constraint_name = format!("fk_{}_{}_{}", table_name, f_table.to_snake_case(), col.name.to_snake_case());
318                let query = format!(
319                    "ALTER TABLE \"{}\" ADD CONSTRAINT \"{}\" FOREIGN KEY (\"{}\") REFERENCES \"{}\"(\"{}\")",
320                    table_name, constraint_name, col.name.to_snake_case(), f_table.to_snake_case(), f_key.to_snake_case()
321                );
322                let _ = sqlx::query(&query).execute(&self.pool).await;
323            }
324        }
325        Ok(())
326    }
327}
328
329// ============================================================================
330// DatabaseBuilder Struct
331// ============================================================================
332
/// Configures and opens a [`Database`] connection pool.
///
/// Obtain one with [`Database::builder`] or `DatabaseBuilder::new()`, adjust
/// the settings, then call `connect`.
#[derive(Debug, Clone)]
pub struct DatabaseBuilder {
    /// Upper bound on pooled connections (5 unless changed via `max_connections`).
    max_connections: u32,
}
336
337impl DatabaseBuilder {
338    /// Creates a new DatabaseBuilder with default settings.
339    ///
340    /// # Example
341    ///
342    /// ```rust,ignore
343    /// let builder = DatabaseBuilder::new();
344    /// ```
345    pub fn new() -> Self { Self { max_connections: 5 } }
346
347    /// Sets the maximum number of connections for the database pool.
348    ///
349    /// # Arguments
350    ///
351    /// * `max` - The maximum number of connections.
352    ///
353    /// # Example
354    ///
355    /// ```rust,ignore
356    /// let db = Database::builder()
357    ///     .max_connections(10)
358    ///     .connect("sqlite::memory:")
359    ///     .await?;
360    /// ```
361    pub fn max_connections(mut self, max: u32) -> Self { self.max_connections = max; self }
362
363    /// Connects to the database using the configured settings.
364    ///
365    /// # Arguments
366    ///
367    /// * `url` - The database connection string.
368    ///
369    /// # Example
370    ///
371    /// ```rust,ignore
372    /// let db = Database::builder()
373    ///     .connect("sqlite::memory:")
374    ///     .await?;
375    /// ```
376    pub async fn connect(self, url: &str) -> Result<Database, Error> {
377        // Ensure sqlx drivers are registered for Any driver support
378        let _ = sqlx::any::install_default_drivers();
379
380        let pool = sqlx::any::AnyPoolOptions::new().max_connections(self.max_connections).connect(url).await?;
381        let driver = if url.starts_with("postgres") { Drivers::Postgres }
382                    else if url.starts_with("mysql") { Drivers::MySQL }
383                    else { Drivers::SQLite };
384        Ok(Database { pool, driver })
385    }
386}
387
388// ============================================================================
389// Connection Trait
390// ============================================================================
391
392pub trait Connection: Send + Sync {
393    fn driver(&self) -> Drivers;
394    fn execute<'a, 'q: 'a>(&'a self, sql: &'q str, args: AnyArguments<'q>) -> BoxFuture<'a, Result<sqlx::any::AnyQueryResult, sqlx::Error>>;
395    fn fetch_all<'a, 'q: 'a>(&'a self, sql: &'q str, args: AnyArguments<'q>) -> BoxFuture<'a, Result<Vec<sqlx::any::AnyRow>, sqlx::Error>>;
396    fn fetch_one<'a, 'q: 'a>(&'a self, sql: &'q str, args: AnyArguments<'q>) -> BoxFuture<'a, Result<sqlx::any::AnyRow, sqlx::Error>>;
397    fn fetch_optional<'a, 'q: 'a>(&'a self, sql: &'q str, args: AnyArguments<'q>) -> BoxFuture<'a, Result<Option<sqlx::any::AnyRow>, sqlx::Error>>;
398    fn clone_db(&self) -> Database;
399}
400
401impl Connection for Database {
402    fn driver(&self) -> Drivers { self.driver }
403    fn execute<'a, 'q: 'a>(&'a self, sql: &'q str, args: AnyArguments<'q>) -> BoxFuture<'a, Result<sqlx::any::AnyQueryResult, sqlx::Error>> {
404        Box::pin(async move { sqlx::query_with(sql, args).execute(&self.pool).await })
405    }
406    fn fetch_all<'a, 'q: 'a>(&'a self, sql: &'q str, args: AnyArguments<'q>) -> BoxFuture<'a, Result<Vec<sqlx::any::AnyRow>, sqlx::Error>> {
407        Box::pin(async move { sqlx::query_with(sql, args).fetch_all(&self.pool).await })
408    }
409    fn fetch_one<'a, 'q: 'a>(&'a self, sql: &'q str, args: AnyArguments<'q>) -> BoxFuture<'a, Result<sqlx::any::AnyRow, sqlx::Error>> {
410        Box::pin(async move { sqlx::query_with(sql, args).fetch_one(&self.pool).await })
411    }
412    fn fetch_optional<'a, 'q: 'a>(&'a self, sql: &'q str, args: AnyArguments<'q>) -> BoxFuture<'a, Result<Option<sqlx::any::AnyRow>, sqlx::Error>> {
413        Box::pin(async move { sqlx::query_with(sql, args).fetch_optional(&self.pool).await })
414    }
415    fn clone_db(&self) -> Database { self.clone() }
416}
417
418// ============================================================================
419// Raw SQL Query Builder
420// ============================================================================
421
422pub struct RawQuery<'a, C> {
423    conn: C,
424    sql: &'a str,
425    args: AnyArguments<'a>,
426}
427
428impl<'a, C> RawQuery<'a, C> where C: Connection {
429    pub(crate) fn new(conn: C, sql: &'a str) -> Self {
430        Self { conn, sql, args: AnyArguments::default() }
431    }
432
433    /// Binds a value to the SQL query.
434    ///
435    /// # Type Parameters
436    ///
437    /// * `T` - The type of the value to bind.
438    ///
439    /// # Example
440    ///
441    /// ```rust,ignore
442    /// let user: User = db.raw("SELECT * FROM users WHERE id = ?")
443    ///     .bind(1)
444    ///     .fetch_one()
445    ///     .await?;
446    /// ```
447    pub fn bind<T>(mut self, value: T) -> Self where T: 'a + sqlx::Encode<'a, sqlx::Any> + sqlx::Type<sqlx::Any> + Send + Sync {
448        let _ = self.args.add(value);
449        self
450    }
451
452    /// Executes the query and returns all matching rows.
453    ///
454    /// # Type Parameters
455    ///
456    /// * `T` - The type to map the rows to.
457    ///
458    /// # Example
459    ///
460    /// ```rust,ignore
461    /// let users: Vec<User> = db.raw("SELECT * FROM users")
462    ///     .fetch_all()
463    ///     .await?;
464    /// ```
465    pub async fn fetch_all<T>(self) -> Result<Vec<T>, Error> where T: for<'r> sqlx::FromRow<'r, sqlx::any::AnyRow> + Send + Unpin {
466        let rows = self.conn.fetch_all(self.sql, self.args).await?;
467        Ok(rows.iter().map(|r| T::from_row(r).unwrap()).collect())
468    }
469
470    /// Executes the query and returns exactly one row.
471    ///
472    /// # Type Parameters
473    ///
474    /// * `T` - The type to map the row to.
475    ///
476    /// # Example
477    ///
478    /// ```rust,ignore
479    /// let user: User = db.raw("SELECT * FROM users WHERE id = 1")
480    ///     .fetch_one()
481    ///     .await?;
482    /// ```
483    pub async fn fetch_one<T>(self) -> Result<T, Error> where T: for<'r> sqlx::FromRow<'r, sqlx::any::AnyRow> + Send + Unpin {
484        let row = self.conn.fetch_one(self.sql, self.args).await?;
485        Ok(T::from_row(&row)?)
486    }
487
488    /// Executes the query and returns an optional row.
489    ///
490    /// # Type Parameters
491    ///
492    /// * `T` - The type to map the row to.
493    ///
494    /// # Example
495    ///
496    /// ```rust,ignore
497    /// let user: Option<User> = db.raw("SELECT * FROM users WHERE id = 1")
498    ///     .fetch_optional()
499    ///     .await?;
500    /// ```
501    pub async fn fetch_optional<T>(self) -> Result<Option<T>, Error> where T: for<'r> sqlx::FromRow<'r, sqlx::any::AnyRow> + Send + Unpin {
502        let row = self.conn.fetch_optional(self.sql, self.args).await?;
503        Ok(row.map(|r| T::from_row(&r).unwrap()))
504    }
505
506    /// Executes the query and returns the number of affected rows.
507    ///
508    /// Useful for UPDATE, DELETE or INSERT queries.
509    ///
510    /// # Example
511    ///
512    /// ```rust,ignore
513    /// let affected = db.raw("DELETE FROM users WHERE id = 1")
514    ///     .execute()
515    ///     .await?;
516    /// ```
517    pub async fn execute(self) -> Result<u64, Error> {
518        let result = self.conn.execute(self.sql, self.args).await?;
519        Ok(result.rows_affected())
520    }
521}