Skip to main content

nestforge_db/
lib.rs

1use std::{collections::HashMap, time::Duration};
2
3use sqlx::{
4    any::{AnyPoolOptions, AnyRow},
5    Any, AnyPool, FromRow, Transaction,
6};
7use thiserror::Error;
8
/**
 * DbConfig
 *
 * Configuration for a database connection pool.
 *
 * `Debug` is implemented by hand so that a password embedded in `url`
 * (e.g. `postgres://user:secret@host/db`) is masked as `***` instead of
 * being written verbatim to logs.
 */
#[derive(Clone)]
pub struct DbConfig {
    /** The database connection URL (may contain credentials) */
    pub url: String,
    /** Maximum number of connections in the pool */
    pub max_connections: u32,
    /** Minimum number of connections to maintain */
    pub min_connections: u32,
    /** Timeout for acquiring a connection */
    pub acquire_timeout: Duration,
}

impl DbConfig {
    /**
     * Creates a new DbConfig with the given connection URL.
     *
     * Defaults: `max_connections = 10`, `min_connections = 1`,
     * `acquire_timeout = 10s`.
     */
    pub fn new(url: impl Into<String>) -> Self {
        Self {
            url: url.into(),
            max_connections: 10,
            min_connections: 1,
            acquire_timeout: Duration::from_secs(10),
        }
    }

    /**
     * Creates a config for a local PostgreSQL database, connecting as user
     * `postgres` with password `postgres` on `localhost`.
     */
    pub fn postgres_local(database: &str) -> Self {
        Self::new(format!("postgres://postgres:postgres@localhost/{database}"))
    }
}

impl std::fmt::Debug for DbConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DbConfig")
            .field("url", &redact_url_password(&self.url))
            .field("max_connections", &self.max_connections)
            .field("min_connections", &self.min_connections)
            .field("acquire_timeout", &self.acquire_timeout)
            .finish()
    }
}

/**
 * Returns `url` with the password in its `user:password@` section replaced
 * by `***`.
 *
 * URLs without a `scheme://` prefix, without a `@`-delimited userinfo, or
 * with a user but no password are returned unchanged.
 */
fn redact_url_password(url: &str) -> String {
    let (scheme, rest) = match url.split_once("://") {
        Some(parts) => parts,
        None => return url.to_owned(),
    };
    // The authority section ends at the first path, query, or fragment delimiter.
    let authority_end = rest
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .unwrap_or(rest.len());
    let (authority, tail) = rest.split_at(authority_end);
    // Split on the LAST '@' so an unencoded '@' inside the password stays masked.
    match authority.rsplit_once('@') {
        Some((userinfo, host)) => match userinfo.split_once(':') {
            Some((user, _password)) => format!("{scheme}://{user}:***@{host}{tail}"),
            None => url.to_owned(),
        },
        None => url.to_owned(),
    }
}
46
47/**
48 * DbError
49 *
50 * Error types that can occur during database operations.
51 */
52#[derive(Debug, Error)]
53pub enum DbError {
54    #[error("Invalid database configuration: {0}")]
55    InvalidConfig(&'static str),
56    #[error("Failed to connect to database")]
57    Connect {
58        #[source]
59        source: sqlx::Error,
60    },
61    #[error("Named connection `{name}` was not configured")]
62    NamedConnectionNotFound { name: String },
63    #[error("Database query failed: {0}")]
64    Query(#[from] sqlx::Error),
65}
66
67/**
68 * Db
69 *
70 * The database connection pool manager in NestForge.
71 * Provides access to one or more database connections.
72 */
73#[derive(Clone)]
74pub struct Db {
75    primary: AnyPool,
76    named: HashMap<String, AnyPool>,
77}
78
79impl Db {
80    /**
81     * Connects to the database synchronously.
82     *
83     * # Arguments
84     * - `config`: Database configuration
85     */
86    pub async fn connect(config: DbConfig) -> Result<Self, DbError> {
87        let primary = connect_pool(&config).await?;
88        Ok(Self {
89            primary,
90            named: HashMap::new(),
91        })
92    }
93
94    /**
95     * Connects to the database lazily (without verifying connection).
96     */
97    pub fn connect_lazy(config: DbConfig) -> Result<Self, DbError> {
98        let primary = connect_pool_lazy(&config)?;
99        Ok(Self {
100            primary,
101            named: HashMap::new(),
102        })
103    }
104
105    /**
106     * Connects to multiple databases simultaneously.
107     *
108     * # Arguments
109     * - `primary`: Configuration for the primary database
110     * - `named`: Iterator of (name, config) pairs for additional databases
111     */
112    pub async fn connect_many<I>(primary: DbConfig, named: I) -> Result<Self, DbError>
113    where
114        I: IntoIterator<Item = (String, DbConfig)>,
115    {
116        let primary_pool = connect_pool(&primary).await?;
117        let mut named_pools = HashMap::new();
118
119        for (name, config) in named {
120            let pool = connect_pool(&config).await?;
121            named_pools.insert(name, pool);
122        }
123
124        Ok(Self {
125            primary: primary_pool,
126            named: named_pools,
127        })
128    }
129
130    pub fn connect_many_lazy<I>(primary: DbConfig, named: I) -> Result<Self, DbError>
131    where
132        I: IntoIterator<Item = (String, DbConfig)>,
133    {
134        let primary_pool = connect_pool_lazy(&primary)?;
135        let mut named_pools = HashMap::new();
136
137        for (name, config) in named {
138            let pool = connect_pool_lazy(&config)?;
139            named_pools.insert(name, pool);
140        }
141
142        Ok(Self {
143            primary: primary_pool,
144            named: named_pools,
145        })
146    }
147
148    pub fn pool(&self) -> &AnyPool {
149        &self.primary
150    }
151
152    pub fn pool_named(&self, name: &str) -> Result<&AnyPool, DbError> {
153        self.named
154            .get(name)
155            .ok_or_else(|| DbError::NamedConnectionNotFound {
156                name: name.to_string(),
157            })
158    }
159
160    pub async fn execute(&self, sql: &str) -> Result<u64, DbError> {
161        let result = sqlx::query::<Any>(sql).execute(&self.primary).await?;
162        Ok(result.rows_affected())
163    }
164
165    pub async fn execute_script(&self, sql: &str) -> Result<(), DbError> {
166        sqlx::raw_sql(sql).execute(&self.primary).await?;
167        Ok(())
168    }
169
170    pub async fn execute_named(&self, name: &str, sql: &str) -> Result<u64, DbError> {
171        let pool = self.pool_named(name)?;
172        let result = sqlx::query::<Any>(sql).execute(pool).await?;
173        Ok(result.rows_affected())
174    }
175
176    pub async fn fetch_all<T>(&self, sql: &str) -> Result<Vec<T>, DbError>
177    where
178        for<'r> T: FromRow<'r, AnyRow> + Send + Unpin,
179    {
180        let rows = sqlx::query_as::<Any, T>(sql)
181            .fetch_all(&self.primary)
182            .await?;
183        Ok(rows)
184    }
185
186    pub async fn fetch_all_named<T>(&self, name: &str, sql: &str) -> Result<Vec<T>, DbError>
187    where
188        for<'r> T: FromRow<'r, AnyRow> + Send + Unpin,
189    {
190        let pool = self.pool_named(name)?;
191        let rows = sqlx::query_as::<Any, T>(sql).fetch_all(pool).await?;
192        Ok(rows)
193    }
194
195    pub async fn begin(&self) -> Result<DbTransaction, DbError> {
196        let tx = self.primary.begin().await?;
197        Ok(DbTransaction { tx })
198    }
199
200    pub async fn begin_named(&self, name: &str) -> Result<DbTransaction, DbError> {
201        let pool = self.pool_named(name)?;
202        let tx = pool.begin().await?;
203        Ok(DbTransaction { tx })
204    }
205}
206
207pub struct DbTransaction {
208    tx: Transaction<'static, Any>,
209}
210
211impl DbTransaction {
212    pub async fn execute(&mut self, sql: &str) -> Result<u64, DbError> {
213        let result = sqlx::query::<Any>(sql).execute(&mut *self.tx).await?;
214        Ok(result.rows_affected())
215    }
216
217    pub async fn execute_script(&mut self, sql: &str) -> Result<(), DbError> {
218        sqlx::raw_sql(sql).execute(&mut *self.tx).await?;
219        Ok(())
220    }
221
222    pub async fn fetch_all<T>(&mut self, sql: &str) -> Result<Vec<T>, DbError>
223    where
224        for<'r> T: FromRow<'r, AnyRow> + Send + Unpin,
225    {
226        let rows = sqlx::query_as::<Any, T>(sql)
227            .fetch_all(&mut *self.tx)
228            .await?;
229        Ok(rows)
230    }
231
232    pub async fn commit(self) -> Result<(), DbError> {
233        self.tx.commit().await?;
234        Ok(())
235    }
236
237    pub async fn rollback(self) -> Result<(), DbError> {
238        self.tx.rollback().await?;
239        Ok(())
240    }
241}
242
243async fn connect_pool(config: &DbConfig) -> Result<AnyPool, DbError> {
244    if config.url.trim().is_empty() {
245        return Err(DbError::InvalidConfig("url cannot be empty"));
246    }
247
248    sqlx::any::install_default_drivers();
249
250    AnyPoolOptions::new()
251        .max_connections(config.max_connections)
252        .min_connections(config.min_connections)
253        .acquire_timeout(config.acquire_timeout)
254        .connect(&config.url)
255        .await
256        .map_err(|source| DbError::Connect { source })
257}
258
259fn connect_pool_lazy(config: &DbConfig) -> Result<AnyPool, DbError> {
260    if config.url.trim().is_empty() {
261        return Err(DbError::InvalidConfig("url cannot be empty"));
262    }
263
264    sqlx::any::install_default_drivers();
265
266    AnyPoolOptions::new()
267        .max_connections(config.max_connections)
268        .min_connections(config.min_connections)
269        .acquire_timeout(config.acquire_timeout)
270        .connect_lazy(&config.url)
271        .map_err(|source| DbError::Connect { source })
272}
273
#[cfg(test)]
mod tests {
    use super::*;

    /** Builds a `Db` around a lazily created primary pool with no named pools. */
    fn db_without_named_pools() -> Db {
        let primary = AnyPoolOptions::new()
            .connect_lazy("postgres://postgres:postgres@localhost/postgres")
            .expect("lazy pool");
        Db {
            primary,
            named: HashMap::new(),
        }
    }

    #[tokio::test]
    async fn validates_empty_url_configuration() {
        let outcome = connect_pool(&DbConfig::new("")).await;
        let err = outcome.expect_err("config should fail");
        assert!(matches!(err, DbError::InvalidConfig(_)));
    }

    #[tokio::test]
    async fn returns_named_connection_error_for_missing_pool() {
        let db = db_without_named_pools();
        match db.pool_named("analytics") {
            Err(DbError::NamedConnectionNotFound { .. }) => {}
            Err(other) => panic!("unexpected error: {other:?}"),
            Ok(_) => panic!("missing pool should fail"),
        }
    }

    #[tokio::test]
    async fn creates_lazy_db_for_sync_module_registration() {
        let result = Db::connect_lazy(DbConfig::postgres_local("postgres"));
        assert!(result.is_ok(), "lazy db creation should succeed");
    }
}