Skip to main content

bottle_orm/
database.rs

1//! # Database Module
2//!
3//! This module provides the core database connection and management functionality for Bottle ORM.
4//! It handles connection pooling, driver detection, table creation, and foreign key management
5//! across PostgreSQL, MySQL, and SQLite.
6
7// ============================================================================
8// External Crate Imports
9// ============================================================================
10
11use futures::future::BoxFuture;
12use heck::ToSnakeCase;
13use sqlx::{any::AnyArguments, Any, AnyPool, Row, Arguments};
14use std::sync::Arc;
15
16// ============================================================================
17// Internal Crate Imports
18// ============================================================================
19
20use crate::{migration::Migrator, Error, Model, QueryBuilder};
21
22// ============================================================================
23// Database Driver Enum
24// ============================================================================
25
/// Database backends that Bottle ORM can generate SQL for.
///
/// The variant is stored on `Database` and consulted wherever the SQL
/// dialect differs between backends (placeholders, catalog queries, DDL).
/// `Hash` is derived so the tag can be used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Drivers {
    /// PostgreSQL backend.
    Postgres,
    /// MySQL backend.
    MySQL,
    /// SQLite backend.
    SQLite,
}
36
37// ============================================================================
38// Database Struct
39// ============================================================================
40
41/// The main entry point for Bottle ORM database operations.
42///
43/// `Database` manages a connection pool and provides methods for starting
44/// transactions, creating tables, and building queries for models.
45///
46/// It is designed to be thread-safe and easily shared across an application
47/// (internally uses an `Arc` for the connection pool).
48#[derive(Debug, Clone)]
49pub struct Database {
50    /// The underlying SQLx connection pool
51    pub(crate) pool: AnyPool,
52    /// The detected database driver
53    pub(crate) driver: Drivers,
54}
55
56// ============================================================================
57// Database Implementation
58// ============================================================================
59
60impl Database {
61    /// Creates a new DatabaseBuilder for configuring the connection.
62    pub fn builder() -> DatabaseBuilder {
63        DatabaseBuilder::new()
64    }
65
66    /// Connects to a database using the provided connection string.
67    pub async fn connect(url: &str) -> Result<Self, Error> {
68        DatabaseBuilder::new().connect(url).await
69    }
70
71    /// Returns a new Migrator instance for managing schema changes.
72    pub fn migrator(&self) -> Migrator {
73        Migrator::new(self)
74    }
75
76    /// Starts building a query for the specified model.
77    pub fn model<T: Model + Send + Sync + Unpin>(&self) -> QueryBuilder<T, Self> {
78        let active_columns = T::active_columns();
79        let mut columns: Vec<String> = Vec::with_capacity(active_columns.capacity());
80
81        for col in active_columns {
82            columns.push(col.strip_prefix("r#").unwrap_or(col).to_snake_case());
83        }
84
85        QueryBuilder::new(self.clone(), self.driver, T::table_name(), T::columns(), columns)
86    }
87
88    /// Creates a raw SQL query builder.
89    pub fn raw<'a>(&self, sql: &'a str) -> RawQuery<'a, Self> {
90        RawQuery::new(self.clone(), sql)
91    }
92
93    /// Starts a new database transaction.
94    pub async fn begin(&self) -> Result<crate::transaction::Transaction<'_>, Error> {
95        let tx = self.pool.begin().await?;
96        Ok(crate::transaction::Transaction {
97            tx: Arc::new(tokio::sync::Mutex::new(Some(tx))),
98            driver: self.driver,
99        })
100    }
101
102    /// Checks if a table exists in the database.
103    pub async fn table_exists(&self, table_name: &str) -> Result<bool, Error> {
104        let table_name_snake = table_name.to_snake_case();
105        let query = match self.driver {
106            Drivers::Postgres => {
107                "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = $1 AND table_schema = 'public')".to_string()
108            }
109            Drivers::MySQL => {
110                "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = ? AND table_schema = DATABASE())".to_string()
111            }
112            Drivers::SQLite => {
113                "SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?".to_string()
114            }
115        };
116
117        let row = sqlx::query(&query).bind(&table_name_snake).fetch_one(&self.pool).await?;
118
119        match self.driver {
120            Drivers::SQLite => {
121                let count: i64 = row.try_get(0)?;
122                Ok(count > 0)
123            }
124            _ => {
125                let exists: bool = row.try_get(0)?;
126                Ok(exists)
127            }
128        }
129    }
130
131    /// Creates a table based on the provided Model metadata.
132    pub async fn create_table<T: Model>(&self) -> Result<(), Error> {
133        let table_name = T::table_name().to_snake_case();
134        let columns = T::columns();
135
136        let mut query = format!("CREATE TABLE IF NOT EXISTS \"{}\" (", table_name);
137        let mut column_defs = Vec::new();
138        let mut indexes = Vec::new();
139
140        for col in columns {
141            let col_name_clean = col.name.strip_prefix("r#").unwrap_or(col.name).to_snake_case();
142            let mut def = format!("\"{}\" {}", col_name_clean, col.sql_type);
143
144            if col.is_primary_key {
145                def.push_str(" PRIMARY KEY");
146            } else if !col.is_nullable {
147                def.push_str(" NOT NULL");
148            }
149
150            if col.unique && !col.is_primary_key {
151                def.push_str(" UNIQUE");
152            }
153
154            if col.index && !col.is_primary_key && !col.unique {
155                indexes.push(format!(
156                    "CREATE INDEX IF NOT EXISTS \"idx_{}_{}\" ON \"{}\" (\"{}\")",
157                    table_name, col_name_clean, table_name, col_name_clean
158                ));
159            }
160
161            column_defs.push(def);
162        }
163
164        query.push_str(&column_defs.join(", "));
165        query.push(')');
166
167        sqlx::query(&query).execute(&self.pool).await?;
168
169        for idx_query in indexes {
170            sqlx::query(&idx_query).execute(&self.pool).await?;
171        }
172
173        Ok(())
174    }
175
176    /// Synchronizes a table schema by adding missing columns or indexes.
177    pub async fn sync_table<T: Model>(&self) -> Result<(), Error> {
178        if !self.table_exists(T::table_name()).await? {
179            return self.create_table::<T>().await;
180        }
181
182        let table_name = T::table_name().to_snake_case();
183        let model_columns = T::columns();
184        let existing_columns = self.get_table_columns(&table_name).await?;
185
186        for col in model_columns {
187            let col_name_clean = col.name.strip_prefix("r#").unwrap_or(col.name).to_snake_case();
188            if !existing_columns.contains(&col_name_clean) {
189                let mut alter_query = format!("ALTER TABLE \"{}\" ADD COLUMN \"{}\" {}", table_name, col_name_clean, col.sql_type);
190                if !col.is_nullable {
191                    alter_query.push_str(" DEFAULT ");
192                    match col.sql_type {
193                        "INTEGER" | "INT" | "BIGINT" => alter_query.push('0'),
194                        "BOOLEAN" | "BOOL" => alter_query.push_str("FALSE"),
195                        _ => alter_query.push_str("''"),
196                    }
197                }
198                sqlx::query(&alter_query).execute(&self.pool).await?;
199            }
200
201            if col.index || col.unique {
202                let existing_indexes = self.get_table_indexes(&table_name).await?;
203                let idx_name = format!("idx_{}_{}", table_name, col_name_clean);
204                let uniq_name = format!("unique_{}_{}", table_name, col_name_clean);
205
206                if col.unique && !existing_indexes.contains(&uniq_name) {
207                    let mut query = format!("CREATE UNIQUE INDEX \"{}\" ON \"{}\" (\"{}\")", uniq_name, table_name, col_name_clean);
208                    if matches!(self.driver, Drivers::SQLite) {
209                        query = format!("CREATE UNIQUE INDEX IF NOT EXISTS \"{}\" ON \"{}\" (\"{}\")", uniq_name, table_name, col_name_clean);
210                    }
211                    sqlx::query(&query).execute(&self.pool).await?;
212                } else if col.index && !existing_indexes.contains(&idx_name) && !col.unique {
213                    let mut query = format!("CREATE INDEX \"{}\" ON \"{}\" (\"{}\")", idx_name, table_name, col_name_clean);
214                    if matches!(self.driver, Drivers::SQLite) {
215                        query = format!("CREATE INDEX IF NOT EXISTS \"{}\" ON \"{}\" (\"{}\")", idx_name, table_name, col_name_clean);
216                    }
217                    sqlx::query(&query).execute(&self.pool).await?;
218                }
219            }
220        }
221
222        Ok(())
223    }
224
225    /// Returns the current columns of a table.
226    pub async fn get_table_columns(&self, table_name: &str) -> Result<Vec<String>, Error> {
227        let table_name_snake = table_name.to_snake_case();
228        let query = match self.driver {
229            Drivers::Postgres => "SELECT column_name::TEXT FROM information_schema.columns WHERE table_name = $1 AND table_schema = 'public'".to_string(),
230            Drivers::MySQL => "SELECT column_name FROM information_schema.columns WHERE table_name = ? AND table_schema = DATABASE()".to_string(),
231            Drivers::SQLite => format!("PRAGMA table_info(\"{}\")", table_name_snake),
232        };
233
234        let rows = if let Drivers::SQLite = self.driver {
235            sqlx::query(&query).fetch_all(&self.pool).await?
236        } else {
237            sqlx::query(&query).bind(&table_name_snake).fetch_all(&self.pool).await?
238        };
239
240        let mut columns = Vec::new();
241        for row in rows {
242            let col_name: String = if let Drivers::SQLite = self.driver {
243                row.try_get("name")?
244            } else {
245                row.try_get(0)?
246            };
247            columns.push(col_name);
248        }
249        Ok(columns)
250    }
251
252    /// Returns the current indexes of a table.
253    pub async fn get_table_indexes(&self, table_name: &str) -> Result<Vec<String>, Error> {
254        let table_name_snake = table_name.to_snake_case();
255        let query = match self.driver {
256            Drivers::Postgres => "SELECT indexname::TEXT FROM pg_indexes WHERE tablename = $1 AND schemaname = 'public'".to_string(),
257            Drivers::MySQL => "SELECT INDEX_NAME FROM information_schema.STATISTICS WHERE TABLE_NAME = ? AND TABLE_SCHEMA = DATABASE()".to_string(),
258            Drivers::SQLite => format!("PRAGMA index_list(\"{}\")", table_name_snake),
259        };
260
261        let rows = if let Drivers::SQLite = self.driver {
262            sqlx::query(&query).fetch_all(&self.pool).await?
263        } else {
264            sqlx::query(&query).bind(&table_name_snake).fetch_all(&self.pool).await?
265        };
266
267        let mut indexes = Vec::new();
268        for row in rows {
269            let idx_name: String = if let Drivers::SQLite = self.driver {
270                row.try_get("name")?
271            } else {
272                row.try_get(0)?
273            };
274            indexes.push(idx_name);
275        }
276        Ok(indexes)
277    }
278
279    /// Assigns foreign keys to a table.
280    pub async fn assign_foreign_keys<T: Model>(&self) -> Result<(), Error> {
281        let table_name = T::table_name().to_snake_case();
282        let columns = T::columns();
283
284        for col in columns {
285            if let (Some(f_table), Some(f_key)) = (col.foreign_table, col.foreign_key) {
286                if matches!(self.driver, Drivers::SQLite) { continue; }
287                let constraint_name = format!("fk_{}_{}_{}", table_name, f_table.to_snake_case(), col.name.to_snake_case());
288                let query = format!(
289                    "ALTER TABLE \"{}\" ADD CONSTRAINT \"{}\" FOREIGN KEY (\"{}\") REFERENCES \"{}\"(\"{}\")",
290                    table_name, constraint_name, col.name.to_snake_case(), f_table.to_snake_case(), f_key.to_snake_case()
291                );
292                let _ = sqlx::query(&query).execute(&self.pool).await;
293            }
294        }
295        Ok(())
296    }
297}
298
299// ============================================================================
300// DatabaseBuilder Struct
301// ============================================================================
302
/// Configuration collected before opening a [`Database`] connection pool.
///
/// Obtain one with `Database::builder()` or `DatabaseBuilder::new()`, adjust
/// it with the chained setters, then call `connect`.
#[derive(Debug, Clone)]
pub struct DatabaseBuilder {
    /// Upper bound on pooled connections passed to the SQLx pool options.
    max_connections: u32,
}
306
307impl DatabaseBuilder {
308    pub fn new() -> Self { Self { max_connections: 5 } }
309    pub fn max_connections(mut self, max: u32) -> Self { self.max_connections = max; self }
310    pub async fn connect(self, url: &str) -> Result<Database, Error> {
311        let pool = sqlx::any::AnyPoolOptions::new().max_connections(self.max_connections).connect(url).await?;
312        let driver = if url.starts_with("postgres") { Drivers::Postgres }
313                    else if url.starts_with("mysql") { Drivers::MySQL }
314                    else { Drivers::SQLite };
315        Ok(Database { pool, driver })
316    }
317}
318
319// ============================================================================
320// Connection Trait
321// ============================================================================
322
323pub trait Connection: Send + Sync {
324    fn execute<'a, 'q: 'a>(&'a self, sql: &'q str, args: AnyArguments<'q>) -> BoxFuture<'a, Result<sqlx::any::AnyQueryResult, sqlx::Error>>;
325    fn fetch_all<'a, 'q: 'a>(&'a self, sql: &'q str, args: AnyArguments<'q>) -> BoxFuture<'a, Result<Vec<sqlx::any::AnyRow>, sqlx::Error>>;
326    fn fetch_one<'a, 'q: 'a>(&'a self, sql: &'q str, args: AnyArguments<'q>) -> BoxFuture<'a, Result<sqlx::any::AnyRow, sqlx::Error>>;
327    fn fetch_optional<'a, 'q: 'a>(&'a self, sql: &'q str, args: AnyArguments<'q>) -> BoxFuture<'a, Result<Option<sqlx::any::AnyRow>, sqlx::Error>>;
328}
329
330impl Connection for Database {
331    fn execute<'a, 'q: 'a>(&'a self, sql: &'q str, args: AnyArguments<'q>) -> BoxFuture<'a, Result<sqlx::any::AnyQueryResult, sqlx::Error>> {
332        Box::pin(async move { sqlx::query_with(sql, args).execute(&self.pool).await })
333    }
334    fn fetch_all<'a, 'q: 'a>(&'a self, sql: &'q str, args: AnyArguments<'q>) -> BoxFuture<'a, Result<Vec<sqlx::any::AnyRow>, sqlx::Error>> {
335        Box::pin(async move { sqlx::query_with(sql, args).fetch_all(&self.pool).await })
336    }
337    fn fetch_one<'a, 'q: 'a>(&'a self, sql: &'q str, args: AnyArguments<'q>) -> BoxFuture<'a, Result<sqlx::any::AnyRow, sqlx::Error>> {
338        Box::pin(async move { sqlx::query_with(sql, args).fetch_one(&self.pool).await })
339    }
340    fn fetch_optional<'a, 'q: 'a>(&'a self, sql: &'q str, args: AnyArguments<'q>) -> BoxFuture<'a, Result<Option<sqlx::any::AnyRow>, sqlx::Error>> {
341        Box::pin(async move { sqlx::query_with(sql, args).fetch_optional(&self.pool).await })
342    }
343}
344
345// ============================================================================
346// Raw SQL Query Builder
347// ============================================================================
348
349pub struct RawQuery<'a, C> {
350    conn: C,
351    sql: &'a str,
352    args: AnyArguments<'a>,
353}
354
355impl<'a, C> RawQuery<'a, C> where C: Connection {
356    pub(crate) fn new(conn: C, sql: &'a str) -> Self {
357        Self { conn, sql, args: AnyArguments::default() }
358    }
359    pub fn bind<T>(mut self, value: T) -> Self where T: 'a + sqlx::Encode<'a, sqlx::Any> + sqlx::Type<sqlx::Any> + Send + Sync {
360        let _ = self.args.add(value);
361        self
362    }
363    pub async fn fetch_all<T>(self) -> Result<Vec<T>, Error> where T: for<'r> sqlx::FromRow<'r, sqlx::any::AnyRow> + Send + Unpin {
364        let rows = self.conn.fetch_all(self.sql, self.args).await?;
365        Ok(rows.iter().map(|r| T::from_row(r).unwrap()).collect())
366    }
367    pub async fn fetch_one<T>(self) -> Result<T, Error> where T: for<'r> sqlx::FromRow<'r, sqlx::any::AnyRow> + Send + Unpin {
368        let row = self.conn.fetch_one(self.sql, self.args).await?;
369        Ok(T::from_row(&row)?)
370    }
371    pub async fn fetch_optional<T>(self) -> Result<Option<T>, Error> where T: for<'r> sqlx::FromRow<'r, sqlx::any::AnyRow> + Send + Unpin {
372        let row = self.conn.fetch_optional(self.sql, self.args).await?;
373        Ok(row.map(|r| T::from_row(&r).unwrap()))
374    }
375    pub async fn execute(self) -> Result<u64, Error> {
376        let result = self.conn.execute(self.sql, self.args).await?;
377        Ok(result.rows_affected())
378    }
379}