1use std::{collections::HashMap, time::Duration};
2
3use sqlx::{
4 any::{AnyPoolOptions, AnyRow},
5 Any, AnyPool, FromRow, Transaction,
6};
7use thiserror::Error;
8
/// Settings for building one sqlx `Any` connection pool.
#[derive(Debug, Clone)]
pub struct DbConfig {
    /// Database URL handed to sqlx's `Any` pool constructor.
    pub url: String,
    /// Passed to the pool builder as `max_connections`.
    pub max_connections: u32,
    /// Passed to the pool builder as `min_connections`.
    pub min_connections: u32,
    /// Passed to the pool builder as `acquire_timeout`.
    pub acquire_timeout: Duration,
}

impl DbConfig {
    const DEFAULT_MAX_CONNECTIONS: u32 = 10;
    const DEFAULT_MIN_CONNECTIONS: u32 = 1;
    const DEFAULT_ACQUIRE_TIMEOUT_SECS: u64 = 10;

    /// Creates a config for `url` with default sizing: at most 10 connections,
    /// at least 1, and a 10-second acquire timeout.
    pub fn new(url: impl Into<String>) -> Self {
        let url = url.into();
        Self {
            acquire_timeout: Duration::from_secs(Self::DEFAULT_ACQUIRE_TIMEOUT_SECS),
            min_connections: Self::DEFAULT_MIN_CONNECTIONS,
            max_connections: Self::DEFAULT_MAX_CONNECTIONS,
            url,
        }
    }

    /// Config for `database` on a PostgreSQL server at `localhost`, logging in
    /// as user `postgres` with password `postgres`. Uses the defaults of [`DbConfig::new`].
    pub fn postgres_local(database: &str) -> Self {
        let url = format!("postgres://postgres:postgres@localhost/{database}");
        Self::new(url)
    }
}
31
32#[derive(Debug, Error)]
33pub enum DbError {
34 #[error("Invalid database configuration: {0}")]
35 InvalidConfig(&'static str),
36 #[error("Failed to connect to database")]
37 Connect {
38 #[source]
39 source: sqlx::Error,
40 },
41 #[error("Named connection `{name}` was not configured")]
42 NamedConnectionNotFound { name: String },
43 #[error("Database query failed: {0}")]
44 Query(#[from] sqlx::Error),
45}
46
47#[derive(Clone)]
48pub struct Db {
49 primary: AnyPool,
50 named: HashMap<String, AnyPool>,
51}
52
53impl Db {
54 pub async fn connect(config: DbConfig) -> Result<Self, DbError> {
55 let primary = connect_pool(&config).await?;
56 Ok(Self {
57 primary,
58 named: HashMap::new(),
59 })
60 }
61
62 pub fn connect_lazy(config: DbConfig) -> Result<Self, DbError> {
63 let primary = connect_pool_lazy(&config)?;
64 Ok(Self {
65 primary,
66 named: HashMap::new(),
67 })
68 }
69
70 pub async fn connect_many<I>(primary: DbConfig, named: I) -> Result<Self, DbError>
71 where
72 I: IntoIterator<Item = (String, DbConfig)>,
73 {
74 let primary_pool = connect_pool(&primary).await?;
75 let mut named_pools = HashMap::new();
76
77 for (name, config) in named {
78 let pool = connect_pool(&config).await?;
79 named_pools.insert(name, pool);
80 }
81
82 Ok(Self {
83 primary: primary_pool,
84 named: named_pools,
85 })
86 }
87
88 pub fn connect_many_lazy<I>(primary: DbConfig, named: I) -> Result<Self, DbError>
89 where
90 I: IntoIterator<Item = (String, DbConfig)>,
91 {
92 let primary_pool = connect_pool_lazy(&primary)?;
93 let mut named_pools = HashMap::new();
94
95 for (name, config) in named {
96 let pool = connect_pool_lazy(&config)?;
97 named_pools.insert(name, pool);
98 }
99
100 Ok(Self {
101 primary: primary_pool,
102 named: named_pools,
103 })
104 }
105
106 pub fn pool(&self) -> &AnyPool {
107 &self.primary
108 }
109
110 pub fn pool_named(&self, name: &str) -> Result<&AnyPool, DbError> {
111 self.named
112 .get(name)
113 .ok_or_else(|| DbError::NamedConnectionNotFound {
114 name: name.to_string(),
115 })
116 }
117
118 pub async fn execute(&self, sql: &str) -> Result<u64, DbError> {
119 let result = sqlx::query::<Any>(sql).execute(&self.primary).await?;
120 Ok(result.rows_affected())
121 }
122
123 pub async fn execute_script(&self, sql: &str) -> Result<(), DbError> {
124 sqlx::raw_sql(sql).execute(&self.primary).await?;
125 Ok(())
126 }
127
128 pub async fn execute_named(&self, name: &str, sql: &str) -> Result<u64, DbError> {
129 let pool = self.pool_named(name)?;
130 let result = sqlx::query::<Any>(sql).execute(pool).await?;
131 Ok(result.rows_affected())
132 }
133
134 pub async fn fetch_all<T>(&self, sql: &str) -> Result<Vec<T>, DbError>
135 where
136 for<'r> T: FromRow<'r, AnyRow> + Send + Unpin,
137 {
138 let rows = sqlx::query_as::<Any, T>(sql)
139 .fetch_all(&self.primary)
140 .await?;
141 Ok(rows)
142 }
143
144 pub async fn fetch_all_named<T>(&self, name: &str, sql: &str) -> Result<Vec<T>, DbError>
145 where
146 for<'r> T: FromRow<'r, AnyRow> + Send + Unpin,
147 {
148 let pool = self.pool_named(name)?;
149 let rows = sqlx::query_as::<Any, T>(sql).fetch_all(pool).await?;
150 Ok(rows)
151 }
152
153 pub async fn begin(&self) -> Result<DbTransaction, DbError> {
154 let tx = self.primary.begin().await?;
155 Ok(DbTransaction { tx })
156 }
157
158 pub async fn begin_named(&self, name: &str) -> Result<DbTransaction, DbError> {
159 let pool = self.pool_named(name)?;
160 let tx = pool.begin().await?;
161 Ok(DbTransaction { tx })
162 }
163}
164
165pub struct DbTransaction {
166 tx: Transaction<'static, Any>,
167}
168
169impl DbTransaction {
170 pub async fn execute(&mut self, sql: &str) -> Result<u64, DbError> {
171 let result = sqlx::query::<Any>(sql).execute(&mut *self.tx).await?;
172 Ok(result.rows_affected())
173 }
174
175 pub async fn execute_script(&mut self, sql: &str) -> Result<(), DbError> {
176 sqlx::raw_sql(sql).execute(&mut *self.tx).await?;
177 Ok(())
178 }
179
180 pub async fn fetch_all<T>(&mut self, sql: &str) -> Result<Vec<T>, DbError>
181 where
182 for<'r> T: FromRow<'r, AnyRow> + Send + Unpin,
183 {
184 let rows = sqlx::query_as::<Any, T>(sql)
185 .fetch_all(&mut *self.tx)
186 .await?;
187 Ok(rows)
188 }
189
190 pub async fn commit(self) -> Result<(), DbError> {
191 self.tx.commit().await?;
192 Ok(())
193 }
194
195 pub async fn rollback(self) -> Result<(), DbError> {
196 self.tx.rollback().await?;
197 Ok(())
198 }
199}
200
201async fn connect_pool(config: &DbConfig) -> Result<AnyPool, DbError> {
202 if config.url.trim().is_empty() {
203 return Err(DbError::InvalidConfig("url cannot be empty"));
204 }
205
206 sqlx::any::install_default_drivers();
207
208 AnyPoolOptions::new()
209 .max_connections(config.max_connections)
210 .min_connections(config.min_connections)
211 .acquire_timeout(config.acquire_timeout)
212 .connect(&config.url)
213 .await
214 .map_err(|source| DbError::Connect { source })
215}
216
217fn connect_pool_lazy(config: &DbConfig) -> Result<AnyPool, DbError> {
218 if config.url.trim().is_empty() {
219 return Err(DbError::InvalidConfig("url cannot be empty"));
220 }
221
222 sqlx::any::install_default_drivers();
223
224 AnyPoolOptions::new()
225 .max_connections(config.max_connections)
226 .min_connections(config.min_connections)
227 .acquire_timeout(config.acquire_timeout)
228 .connect_lazy(&config.url)
229 .map_err(|source| DbError::Connect { source })
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    /// A blank URL must be rejected as a configuration error, not a connection error.
    #[tokio::test]
    async fn validates_empty_url_configuration() {
        let config = DbConfig::new("");
        let err = connect_pool(&config)
            .await
            .expect_err("config should fail");

        assert!(matches!(err, DbError::InvalidConfig(_)));
    }

    /// Looking up an unregistered name reports `NamedConnectionNotFound`.
    #[tokio::test]
    async fn returns_named_connection_error_for_missing_pool() {
        let db = Db {
            primary: AnyPoolOptions::new()
                .connect_lazy("postgres://postgres:postgres@localhost/postgres")
                .expect("lazy pool"),
            named: HashMap::new(),
        };

        let err = db
            .pool_named("analytics")
            .expect_err("missing pool should fail");
        assert!(matches!(err, DbError::NamedConnectionNotFound { .. }));
    }

    /// A lazy `Db` can be built without a reachable server. It runs under a
    /// tokio runtime because the lazy pool may spawn background tasks.
    #[tokio::test]
    async fn creates_lazy_db_for_sync_module_registration() {
        // `expect` (unlike `assert!(is_ok())`) prints the DbError on failure.
        let db = Db::connect_lazy(DbConfig::postgres_local("postgres"))
            .expect("lazy db creation should succeed");
        assert!(db.named.is_empty(), "lazy db should start without named pools");
    }
}