// nestforge_db/lib.rs

1use std::{collections::HashMap, time::Duration};
2
3use sqlx::{
4    any::{AnyPoolOptions, AnyRow},
5    Any, AnyPool, FromRow, Transaction,
6};
7use thiserror::Error;
8
/// Settings used to build one sqlx `Any` connection pool.
///
/// All fields are public, so a value produced by a constructor can be adjusted
/// field-by-field before it is handed to one of the `Db::connect*` functions.
#[derive(Clone, Debug)]
pub struct DbConfig {
    /// Connection URL passed to the pool builder; it must not be blank.
    pub url: String,
    /// Value forwarded to the pool builder's `max_connections` setting.
    pub max_connections: u32,
    /// Value forwarded to the pool builder's `min_connections` setting.
    pub min_connections: u32,
    /// Value forwarded to the pool builder's `acquire_timeout` setting.
    pub acquire_timeout: Duration,
}
16
17impl DbConfig {
18    pub fn new(url: impl Into<String>) -> Self {
19        Self {
20            url: url.into(),
21            max_connections: 10,
22            min_connections: 1,
23            acquire_timeout: Duration::from_secs(10),
24        }
25    }
26
27    pub fn postgres_local(database: &str) -> Self {
28        Self::new(format!("postgres://postgres:postgres@localhost/{database}"))
29    }
30}
31
32#[derive(Debug, Error)]
33pub enum DbError {
34    #[error("Invalid database configuration: {0}")]
35    InvalidConfig(&'static str),
36    #[error("Failed to connect to database")]
37    Connect {
38        #[source]
39        source: sqlx::Error,
40    },
41    #[error("Named connection `{name}` was not configured")]
42    NamedConnectionNotFound { name: String },
43    #[error("Database query failed: {0}")]
44    Query(#[from] sqlx::Error),
45}
46
/// Handle bundling one primary connection pool with any number of
/// additional pools addressed by name.
#[derive(Clone)]
pub struct Db {
    /// Pool used by every method that does not take a connection name.
    primary: AnyPool,
    /// Extra pools keyed by the name supplied when they were configured.
    named: HashMap<String, AnyPool>,
}
52
53impl Db {
54    pub async fn connect(config: DbConfig) -> Result<Self, DbError> {
55        let primary = connect_pool(&config).await?;
56        Ok(Self {
57            primary,
58            named: HashMap::new(),
59        })
60    }
61
62    pub fn connect_lazy(config: DbConfig) -> Result<Self, DbError> {
63        let primary = connect_pool_lazy(&config)?;
64        Ok(Self {
65            primary,
66            named: HashMap::new(),
67        })
68    }
69
70    pub async fn connect_many<I>(primary: DbConfig, named: I) -> Result<Self, DbError>
71    where
72        I: IntoIterator<Item = (String, DbConfig)>,
73    {
74        let primary_pool = connect_pool(&primary).await?;
75        let mut named_pools = HashMap::new();
76
77        for (name, config) in named {
78            let pool = connect_pool(&config).await?;
79            named_pools.insert(name, pool);
80        }
81
82        Ok(Self {
83            primary: primary_pool,
84            named: named_pools,
85        })
86    }
87
88    pub fn connect_many_lazy<I>(primary: DbConfig, named: I) -> Result<Self, DbError>
89    where
90        I: IntoIterator<Item = (String, DbConfig)>,
91    {
92        let primary_pool = connect_pool_lazy(&primary)?;
93        let mut named_pools = HashMap::new();
94
95        for (name, config) in named {
96            let pool = connect_pool_lazy(&config)?;
97            named_pools.insert(name, pool);
98        }
99
100        Ok(Self {
101            primary: primary_pool,
102            named: named_pools,
103        })
104    }
105
106    pub fn pool(&self) -> &AnyPool {
107        &self.primary
108    }
109
110    pub fn pool_named(&self, name: &str) -> Result<&AnyPool, DbError> {
111        self.named
112            .get(name)
113            .ok_or_else(|| DbError::NamedConnectionNotFound {
114                name: name.to_string(),
115            })
116    }
117
118    pub async fn execute(&self, sql: &str) -> Result<u64, DbError> {
119        let result = sqlx::query::<Any>(sql).execute(&self.primary).await?;
120        Ok(result.rows_affected())
121    }
122
123    pub async fn execute_script(&self, sql: &str) -> Result<(), DbError> {
124        sqlx::raw_sql(sql).execute(&self.primary).await?;
125        Ok(())
126    }
127
128    pub async fn execute_named(&self, name: &str, sql: &str) -> Result<u64, DbError> {
129        let pool = self.pool_named(name)?;
130        let result = sqlx::query::<Any>(sql).execute(pool).await?;
131        Ok(result.rows_affected())
132    }
133
134    pub async fn fetch_all<T>(&self, sql: &str) -> Result<Vec<T>, DbError>
135    where
136        for<'r> T: FromRow<'r, AnyRow> + Send + Unpin,
137    {
138        let rows = sqlx::query_as::<Any, T>(sql)
139            .fetch_all(&self.primary)
140            .await?;
141        Ok(rows)
142    }
143
144    pub async fn fetch_all_named<T>(&self, name: &str, sql: &str) -> Result<Vec<T>, DbError>
145    where
146        for<'r> T: FromRow<'r, AnyRow> + Send + Unpin,
147    {
148        let pool = self.pool_named(name)?;
149        let rows = sqlx::query_as::<Any, T>(sql).fetch_all(pool).await?;
150        Ok(rows)
151    }
152
153    pub async fn begin(&self) -> Result<DbTransaction, DbError> {
154        let tx = self.primary.begin().await?;
155        Ok(DbTransaction { tx })
156    }
157
158    pub async fn begin_named(&self, name: &str) -> Result<DbTransaction, DbError> {
159        let pool = self.pool_named(name)?;
160        let tx = pool.begin().await?;
161        Ok(DbTransaction { tx })
162    }
163}
164
/// Wrapper around an open sqlx transaction, produced by `Db::begin` and
/// `Db::begin_named`.
pub struct DbTransaction {
    /// The underlying sqlx transaction that every method operates on.
    tx: Transaction<'static, Any>,
}
168
169impl DbTransaction {
170    pub async fn execute(&mut self, sql: &str) -> Result<u64, DbError> {
171        let result = sqlx::query::<Any>(sql).execute(&mut *self.tx).await?;
172        Ok(result.rows_affected())
173    }
174
175    pub async fn execute_script(&mut self, sql: &str) -> Result<(), DbError> {
176        sqlx::raw_sql(sql).execute(&mut *self.tx).await?;
177        Ok(())
178    }
179
180    pub async fn fetch_all<T>(&mut self, sql: &str) -> Result<Vec<T>, DbError>
181    where
182        for<'r> T: FromRow<'r, AnyRow> + Send + Unpin,
183    {
184        let rows = sqlx::query_as::<Any, T>(sql)
185            .fetch_all(&mut *self.tx)
186            .await?;
187        Ok(rows)
188    }
189
190    pub async fn commit(self) -> Result<(), DbError> {
191        self.tx.commit().await?;
192        Ok(())
193    }
194
195    pub async fn rollback(self) -> Result<(), DbError> {
196        self.tx.rollback().await?;
197        Ok(())
198    }
199}
200
201async fn connect_pool(config: &DbConfig) -> Result<AnyPool, DbError> {
202    if config.url.trim().is_empty() {
203        return Err(DbError::InvalidConfig("url cannot be empty"));
204    }
205
206    sqlx::any::install_default_drivers();
207
208    AnyPoolOptions::new()
209        .max_connections(config.max_connections)
210        .min_connections(config.min_connections)
211        .acquire_timeout(config.acquire_timeout)
212        .connect(&config.url)
213        .await
214        .map_err(|source| DbError::Connect { source })
215}
216
217fn connect_pool_lazy(config: &DbConfig) -> Result<AnyPool, DbError> {
218    if config.url.trim().is_empty() {
219        return Err(DbError::InvalidConfig("url cannot be empty"));
220    }
221
222    sqlx::any::install_default_drivers();
223
224    AnyPoolOptions::new()
225        .max_connections(config.max_connections)
226        .min_connections(config.min_connections)
227        .acquire_timeout(config.acquire_timeout)
228        .connect_lazy(&config.url)
229        .map_err(|source| DbError::Connect { source })
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    /// The eager builder rejects an empty URL before attempting to connect.
    #[tokio::test]
    async fn validates_empty_url_configuration() {
        let config = DbConfig::new("");
        let err = connect_pool(&config)
            .await
            .expect_err("config should fail");

        assert!(matches!(err, DbError::InvalidConfig(_)));
    }

    /// A URL made only of whitespace is treated the same as an empty one.
    #[tokio::test]
    async fn validates_whitespace_only_url_configuration() {
        let config = DbConfig::new("   ");
        let err = connect_pool(&config)
            .await
            .expect_err("config should fail");

        assert!(matches!(err, DbError::InvalidConfig(_)));
    }

    /// The lazy builder applies the same URL check; the check runs before any
    /// pool is built, so no async runtime is needed here.
    #[test]
    fn validates_empty_url_configuration_for_lazy_pool() {
        let config = DbConfig::new("");
        let err = connect_pool_lazy(&config).expect_err("config should fail");

        assert!(matches!(err, DbError::InvalidConfig(_)));
    }

    /// Looking up a name that was never registered yields a typed error.
    #[tokio::test]
    async fn returns_named_connection_error_for_missing_pool() {
        let db = Db {
            primary: AnyPoolOptions::new()
                .connect_lazy("postgres://postgres:postgres@localhost/postgres")
                .expect("lazy pool"),
            named: HashMap::new(),
        };

        let err = db
            .pool_named("analytics")
            .expect_err("missing pool should fail");
        assert!(matches!(err, DbError::NamedConnectionNotFound { .. }));
    }

    /// A lazily built handle can be created through the synchronous API.
    #[tokio::test]
    async fn creates_lazy_db_for_sync_module_registration() {
        let db = Db::connect_lazy(DbConfig::postgres_local("postgres"));
        assert!(db.is_ok(), "lazy db creation should succeed");
    }
}