1use std::{collections::HashMap, time::Duration};
2
3use sqlx::{
4 any::{AnyPoolOptions, AnyRow},
5 Any, AnyPool, FromRow, Transaction,
6};
7use thiserror::Error;
8
/// Settings used to build one sqlx `AnyPool`.
///
/// `Debug` is implemented by hand (not derived) so that a password embedded in
/// [`DbConfig::url`] is masked instead of being written to logs.
#[derive(Clone)]
pub struct DbConfig {
    /// Database URL handed to sqlx as-is.
    pub url: String,
    /// Passed to sqlx `PoolOptions::max_connections`.
    pub max_connections: u32,
    /// Passed to sqlx `PoolOptions::min_connections`.
    pub min_connections: u32,
    /// Passed to sqlx `PoolOptions::acquire_timeout`.
    pub acquire_timeout: Duration,
}

impl DbConfig {
    /// Creates a config for `url` with `max_connections = 10`, `min_connections = 1`
    /// and a 10 second `acquire_timeout`.
    pub fn new(url: impl Into<String>) -> Self {
        Self {
            url: url.into(),
            max_connections: 10,
            min_connections: 1,
            acquire_timeout: Duration::from_secs(10),
        }
    }

    /// Config for a PostgreSQL server on `localhost`, logging in as user `postgres` with
    /// password `postgres`, connecting to `database`.
    pub fn postgres_local(database: &str) -> Self {
        Self::new(format!("postgres://postgres:postgres@localhost/{database}"))
    }
}

impl std::fmt::Debug for DbConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DbConfig")
            .field("url", &redact_url_password(&self.url))
            .field("max_connections", &self.max_connections)
            .field("min_connections", &self.min_connections)
            .field("acquire_timeout", &self.acquire_timeout)
            .finish()
    }
}

/// Returns `url` with the password of a `scheme://user:password@host...` URL replaced by
/// `***`. URLs without a password in their authority are returned unchanged.
fn redact_url_password(url: &str) -> String {
    masked_userinfo(url).unwrap_or_else(|| url.to_owned())
}

/// `Some(masked_url)` when `url` carries a password in its authority, otherwise `None`.
fn masked_userinfo(url: &str) -> Option<String> {
    let authority_start = url.find("://")? + "://".len();
    let rest = &url[authority_start..];
    // The authority ends at the first path, query or fragment delimiter.
    let authority_len = rest
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .unwrap_or(rest.len());
    // Userinfo is everything before the *last* `@` of the authority, so a password that
    // itself contains `@` is masked in full.
    let at = rest[..authority_len].rfind('@')?;
    let colon = rest[..at].find(':')?;
    Some(format!("{}:***{}", &url[..authority_start + colon], &rest[at..]))
}
46
47#[derive(Debug, Error)]
53pub enum DbError {
54 #[error("Invalid database configuration: {0}")]
55 InvalidConfig(&'static str),
56 #[error("Failed to connect to database")]
57 Connect {
58 #[source]
59 source: sqlx::Error,
60 },
61 #[error("Named connection `{name}` was not configured")]
62 NamedConnectionNotFound { name: String },
63 #[error("Database query failed: {0}")]
64 Query(#[from] sqlx::Error),
65}
66
67#[derive(Clone)]
74pub struct Db {
75 primary: AnyPool,
76 named: HashMap<String, AnyPool>,
77}
78
79impl Db {
80 pub async fn connect(config: DbConfig) -> Result<Self, DbError> {
87 let primary = connect_pool(&config).await?;
88 Ok(Self {
89 primary,
90 named: HashMap::new(),
91 })
92 }
93
94 pub fn connect_lazy(config: DbConfig) -> Result<Self, DbError> {
98 let primary = connect_pool_lazy(&config)?;
99 Ok(Self {
100 primary,
101 named: HashMap::new(),
102 })
103 }
104
105 pub async fn connect_many<I>(primary: DbConfig, named: I) -> Result<Self, DbError>
113 where
114 I: IntoIterator<Item = (String, DbConfig)>,
115 {
116 let primary_pool = connect_pool(&primary).await?;
117 let mut named_pools = HashMap::new();
118
119 for (name, config) in named {
120 let pool = connect_pool(&config).await?;
121 named_pools.insert(name, pool);
122 }
123
124 Ok(Self {
125 primary: primary_pool,
126 named: named_pools,
127 })
128 }
129
130 pub fn connect_many_lazy<I>(primary: DbConfig, named: I) -> Result<Self, DbError>
131 where
132 I: IntoIterator<Item = (String, DbConfig)>,
133 {
134 let primary_pool = connect_pool_lazy(&primary)?;
135 let mut named_pools = HashMap::new();
136
137 for (name, config) in named {
138 let pool = connect_pool_lazy(&config)?;
139 named_pools.insert(name, pool);
140 }
141
142 Ok(Self {
143 primary: primary_pool,
144 named: named_pools,
145 })
146 }
147
148 pub fn pool(&self) -> &AnyPool {
149 &self.primary
150 }
151
152 pub fn pool_named(&self, name: &str) -> Result<&AnyPool, DbError> {
153 self.named
154 .get(name)
155 .ok_or_else(|| DbError::NamedConnectionNotFound {
156 name: name.to_string(),
157 })
158 }
159
160 pub async fn execute(&self, sql: &str) -> Result<u64, DbError> {
161 let result = sqlx::query::<Any>(sql).execute(&self.primary).await?;
162 Ok(result.rows_affected())
163 }
164
165 pub async fn execute_script(&self, sql: &str) -> Result<(), DbError> {
166 sqlx::raw_sql(sql).execute(&self.primary).await?;
167 Ok(())
168 }
169
170 pub async fn execute_named(&self, name: &str, sql: &str) -> Result<u64, DbError> {
171 let pool = self.pool_named(name)?;
172 let result = sqlx::query::<Any>(sql).execute(pool).await?;
173 Ok(result.rows_affected())
174 }
175
176 pub async fn fetch_all<T>(&self, sql: &str) -> Result<Vec<T>, DbError>
177 where
178 for<'r> T: FromRow<'r, AnyRow> + Send + Unpin,
179 {
180 let rows = sqlx::query_as::<Any, T>(sql)
181 .fetch_all(&self.primary)
182 .await?;
183 Ok(rows)
184 }
185
186 pub async fn fetch_all_named<T>(&self, name: &str, sql: &str) -> Result<Vec<T>, DbError>
187 where
188 for<'r> T: FromRow<'r, AnyRow> + Send + Unpin,
189 {
190 let pool = self.pool_named(name)?;
191 let rows = sqlx::query_as::<Any, T>(sql).fetch_all(pool).await?;
192 Ok(rows)
193 }
194
195 pub async fn begin(&self) -> Result<DbTransaction, DbError> {
196 let tx = self.primary.begin().await?;
197 Ok(DbTransaction { tx })
198 }
199
200 pub async fn begin_named(&self, name: &str) -> Result<DbTransaction, DbError> {
201 let pool = self.pool_named(name)?;
202 let tx = pool.begin().await?;
203 Ok(DbTransaction { tx })
204 }
205}
206
207pub struct DbTransaction {
208 tx: Transaction<'static, Any>,
209}
210
211impl DbTransaction {
212 pub async fn execute(&mut self, sql: &str) -> Result<u64, DbError> {
213 let result = sqlx::query::<Any>(sql).execute(&mut *self.tx).await?;
214 Ok(result.rows_affected())
215 }
216
217 pub async fn execute_script(&mut self, sql: &str) -> Result<(), DbError> {
218 sqlx::raw_sql(sql).execute(&mut *self.tx).await?;
219 Ok(())
220 }
221
222 pub async fn fetch_all<T>(&mut self, sql: &str) -> Result<Vec<T>, DbError>
223 where
224 for<'r> T: FromRow<'r, AnyRow> + Send + Unpin,
225 {
226 let rows = sqlx::query_as::<Any, T>(sql)
227 .fetch_all(&mut *self.tx)
228 .await?;
229 Ok(rows)
230 }
231
232 pub async fn commit(self) -> Result<(), DbError> {
233 self.tx.commit().await?;
234 Ok(())
235 }
236
237 pub async fn rollback(self) -> Result<(), DbError> {
238 self.tx.rollback().await?;
239 Ok(())
240 }
241}
242
243async fn connect_pool(config: &DbConfig) -> Result<AnyPool, DbError> {
244 if config.url.trim().is_empty() {
245 return Err(DbError::InvalidConfig("url cannot be empty"));
246 }
247
248 sqlx::any::install_default_drivers();
249
250 AnyPoolOptions::new()
251 .max_connections(config.max_connections)
252 .min_connections(config.min_connections)
253 .acquire_timeout(config.acquire_timeout)
254 .connect(&config.url)
255 .await
256 .map_err(|source| DbError::Connect { source })
257}
258
259fn connect_pool_lazy(config: &DbConfig) -> Result<AnyPool, DbError> {
260 if config.url.trim().is_empty() {
261 return Err(DbError::InvalidConfig("url cannot be empty"));
262 }
263
264 sqlx::any::install_default_drivers();
265
266 AnyPoolOptions::new()
267 .max_connections(config.max_connections)
268 .min_connections(config.min_connections)
269 .acquire_timeout(config.acquire_timeout)
270 .connect_lazy(&config.url)
271 .map_err(|source| DbError::Connect { source })
272}
273
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn validates_empty_url_configuration() {
        let err = connect_pool(&DbConfig::new(""))
            .await
            .expect_err("config should fail");
        assert!(matches!(err, DbError::InvalidConfig(_)));
    }

    #[test]
    fn lazy_connect_rejects_blank_url_without_a_runtime() {
        // Validation happens before sqlx is involved, so no async runtime is needed here.
        let result = Db::connect_lazy(DbConfig::new("   "));
        assert!(matches!(result, Err(DbError::InvalidConfig(_))));
    }

    #[tokio::test]
    async fn returns_named_connection_error_for_missing_pool() {
        let db = Db::connect_lazy(DbConfig::postgres_local("postgres")).expect("lazy pool");

        let err = db
            .pool_named("analytics")
            .expect_err("missing pool should fail");
        assert!(matches!(
            err,
            DbError::NamedConnectionNotFound { ref name } if name == "analytics"
        ));
    }

    #[tokio::test]
    async fn named_execute_fails_before_touching_the_database() {
        let db = Db::connect_lazy(DbConfig::postgres_local("postgres")).expect("lazy pool");

        let err = db
            .execute_named("analytics", "SELECT 1")
            .await
            .expect_err("unknown name should fail");
        assert!(matches!(err, DbError::NamedConnectionNotFound { .. }));
    }

    #[tokio::test]
    async fn lazy_connect_many_registers_named_pools() {
        let db = Db::connect_many_lazy(
            DbConfig::postgres_local("postgres"),
            vec![(
                "analytics".to_string(),
                DbConfig::postgres_local("analytics"),
            )],
        )
        .expect("lazy pools");

        assert!(db.pool_named("analytics").is_ok());
        assert!(db.pool_named("reporting").is_err());
    }

    #[tokio::test]
    async fn creates_lazy_db_for_sync_module_registration() {
        // `expect` (rather than `assert!(is_ok())`) prints the underlying error on failure.
        let _db = Db::connect_lazy(DbConfig::postgres_local("postgres"))
            .expect("lazy db creation should succeed");
    }
}
309}