mockforge_registry_server/
database.rs1use anyhow::Result;
4use sqlx::{postgres::PgPoolOptions, PgPool};
5
6#[derive(Clone, Debug)]
7pub struct Database {
8 pool: PgPool,
9}
10
11impl Database {
12 pub async fn connect(database_url: &str) -> Result<Self> {
13 let max_connections: u32 = std::env::var("DATABASE_MAX_CONNECTIONS")
16 .ok()
17 .and_then(|s| s.parse().ok())
18 .unwrap_or(20);
19
20 let pool = PgPoolOptions::new()
21 .max_connections(max_connections)
22 .connect(database_url)
23 .await?;
24
25 Ok(Self { pool })
26 }
27
28 pub async fn migrate(&self) -> Result<()> {
29 const MIGRATION_LOCK_ID: i64 = 8675309;
32
33 tracing::info!("Acquiring advisory lock for database migrations...");
34 sqlx::query("SELECT pg_advisory_lock($1)")
35 .bind(MIGRATION_LOCK_ID)
36 .execute(&self.pool)
37 .await?;
38 tracing::info!("Advisory lock acquired, running migrations...");
39
40 let result =
49 sqlx::migrate!("./migrations")
50 .run(&self.pool)
51 .await
52 .map_err(|e| -> anyhow::Error {
53 if e.to_string().contains("previously applied but is missing") {
54 tracing::error!(
55 "sqlx refused to migrate: the DB's `_sqlx_migrations` table has an \
56 applied row whose matching file is missing from the repo. This \
57 blocks ALL subsequent migrations from running. To fix, either \
58 restore the missing file or remove the orphaned tracking row \
59 manually (psql: `DELETE FROM _sqlx_migrations WHERE version = …`). \
60 Full error: {:?}",
61 e
62 );
63 }
64 e.into()
65 });
66
67 if let Err(unlock_err) = sqlx::query("SELECT pg_advisory_unlock($1)")
69 .bind(MIGRATION_LOCK_ID)
70 .execute(&self.pool)
71 .await
72 {
73 tracing::error!("Failed to release migration advisory lock: {}", unlock_err);
74 } else {
75 tracing::info!("Migration advisory lock released");
76 }
77
78 result
79 }
80
81 pub fn pool(&self) -> &PgPool {
82 &self.pool
83 }
84
85 pub async fn get_total_plugins(&self) -> Result<i64> {
87 let count: (i64,) =
88 sqlx::query_as("SELECT COUNT(*) FROM plugins").fetch_one(&self.pool).await?;
89 Ok(count.0)
90 }
91
92 pub async fn get_total_downloads(&self) -> Result<i64> {
94 let total: (Option<i64>,) =
96 sqlx::query_as("SELECT COALESCE(SUM(downloads_total)::BIGINT, 0) FROM plugins")
97 .fetch_one(&self.pool)
98 .await?;
99 Ok(total.0.unwrap_or(0))
100 }
101
102 pub async fn get_total_users(&self) -> Result<i64> {
104 let count: (i64,) =
105 sqlx::query_as("SELECT COUNT(*) FROM users").fetch_one(&self.pool).await?;
106 Ok(count.0)
107 }
108
109 pub async fn store_refresh_token_jti(
113 &self,
114 jti: &str,
115 user_id: uuid::Uuid,
116 expires_at: chrono::DateTime<chrono::Utc>,
117 ) -> Result<()> {
118 sqlx::query(
119 r#"
120 INSERT INTO token_revocations (jti, user_id, expires_at)
121 VALUES ($1, $2, $3)
122 ON CONFLICT (jti) DO NOTHING
123 "#,
124 )
125 .bind(jti)
126 .bind(user_id)
127 .bind(expires_at)
128 .execute(&self.pool)
129 .await?;
130
131 Ok(())
132 }
133
134 pub async fn is_token_revoked(&self, jti: &str) -> Result<bool> {
136 let result: Option<(Option<chrono::DateTime<chrono::Utc>>,)> = sqlx::query_as(
137 r#"
138 SELECT revoked_at FROM token_revocations WHERE jti = $1
139 "#,
140 )
141 .bind(jti)
142 .fetch_optional(&self.pool)
143 .await?;
144
145 match result {
146 Some((Some(_),)) => Ok(true),
148 Some((None,)) => Ok(false),
150 None => Ok(true),
152 }
153 }
154
155 pub async fn revoke_token(&self, jti: &str, reason: &str) -> Result<()> {
157 sqlx::query(
158 r#"
159 UPDATE token_revocations
160 SET revoked_at = NOW(), revocation_reason = $2
161 WHERE jti = $1 AND revoked_at IS NULL
162 "#,
163 )
164 .bind(jti)
165 .bind(reason)
166 .execute(&self.pool)
167 .await?;
168
169 Ok(())
170 }
171
172 pub async fn revoke_all_user_tokens(&self, user_id: uuid::Uuid, reason: &str) -> Result<u64> {
174 let result = sqlx::query(
175 r#"
176 UPDATE token_revocations
177 SET revoked_at = NOW(), revocation_reason = $2
178 WHERE user_id = $1 AND revoked_at IS NULL
179 "#,
180 )
181 .bind(user_id)
182 .bind(reason)
183 .execute(&self.pool)
184 .await?;
185
186 Ok(result.rows_affected())
187 }
188
189 pub async fn cleanup_expired_tokens(&self) -> Result<u64> {
191 let result = sqlx::query(
192 r#"
193 DELETE FROM token_revocations
194 WHERE expires_at < NOW() - INTERVAL '1 day'
195 "#,
196 )
197 .execute(&self.pool)
198 .await?;
199
200 Ok(result.rows_affected())
201 }
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    /// Connection string for the `#[ignore]`d integration tests.
    ///
    /// Uses `TEST_DATABASE_URL` when set, otherwise a local default.
    fn test_database_url() -> String {
        std::env::var("TEST_DATABASE_URL")
            .unwrap_or_else(|_| "postgresql://test:test@localhost/test_db".to_string())
    }

    /// Connects to the integration-test database.
    ///
    /// Panics instead of silently skipping, so an explicitly requested (`--ignored`) run
    /// cannot pass without actually touching a database.
    async fn connect_test_db() -> Database {
        Database::connect(&test_database_url())
            .await
            .expect("integration tests require a reachable TEST_DATABASE_URL")
    }

    #[test]
    fn test_database_clone() {
        fn requires_clone<T: Clone>() {}
        requires_clone::<Database>();
    }

    #[tokio::test]
    async fn test_database_connect() {
        // Port 1 is reserved and essentially never runs Postgres, so the failure does not
        // depend on whether a developer happens to have a local test database running.
        let database_url = "postgresql://test:test@127.0.0.1:1/test_db";

        let result = Database::connect(database_url).await;

        assert!(
            result.is_err(),
            "expected connection to fail without a running database, but got: {result:?}"
        );
    }

    #[test]
    fn test_database_pool_type() {
        fn check_pool_method(db: &Database) -> &PgPool {
            db.pool()
        }

        let _: fn(&Database) -> &PgPool = check_pool_method;
    }

    #[test]
    fn test_total_plugins_query_structure() {
        let query = "SELECT COUNT(*) FROM plugins";

        assert!(query.contains("SELECT"));
        assert!(query.contains("COUNT(*)"));
        assert!(query.contains("FROM plugins"));
    }

    #[test]
    fn test_total_downloads_query_structure() {
        let query = "SELECT COALESCE(SUM(downloads_total)::BIGINT, 0) FROM plugins";

        assert!(query.contains("SELECT"));
        assert!(query.contains("COALESCE"));
        assert!(query.contains("SUM(downloads_total)"));
        assert!(query.contains("FROM plugins"));
        assert!(query.contains("::BIGINT"));
    }

    #[test]
    fn test_total_users_query_structure() {
        let query = "SELECT COUNT(*) FROM users";

        assert!(query.contains("SELECT"));
        assert!(query.contains("COUNT(*)"));
        assert!(query.contains("FROM users"));
    }

    #[test]
    fn test_migration_error_handling() {
        let error_msg = "previously applied but is missing";

        assert!(error_msg.contains("previously applied"));
        assert!(error_msg.contains("missing"));
    }

    #[tokio::test]
    #[ignore]
    async fn test_database_migration() {
        let db = connect_test_db().await;

        let result = db.migrate().await;
        assert!(result.is_ok(), "migration failed: {result:?}");
    }

    #[tokio::test]
    #[ignore]
    async fn test_get_total_plugins() {
        let db = connect_test_db().await;
        db.migrate().await.expect("migrations must succeed before querying");

        let count = db.get_total_plugins().await.expect("plugin count query failed");
        assert!(count >= 0);
    }

    #[tokio::test]
    #[ignore]
    async fn test_get_total_downloads() {
        let db = connect_test_db().await;
        db.migrate().await.expect("migrations must succeed before querying");

        let count = db.get_total_downloads().await.expect("download total query failed");
        assert!(count >= 0);
    }

    #[tokio::test]
    #[ignore]
    async fn test_get_total_users() {
        let db = connect_test_db().await;
        db.migrate().await.expect("migrations must succeed before querying");

        let count = db.get_total_users().await.expect("user count query failed");
        assert!(count >= 0);
    }

    #[tokio::test]
    #[ignore]
    async fn test_pool_reuse() {
        let db = connect_test_db().await;

        let pool1 = db.pool();
        let pool2 = db.pool();

        assert!(std::ptr::eq(pool1, pool2));
    }

    #[test]
    fn test_database_connection_string_validation() {
        let valid_urls = [
            "postgresql://user:pass@localhost/db",
            "postgresql://user:pass@localhost:5432/db",
            "postgresql://localhost/db",
            "postgres://user:pass@host:5432/database?sslmode=require",
        ];

        for url in valid_urls {
            assert!(url.starts_with("postgres"));
            assert!(url.contains("://"));
        }
    }

    #[test]
    fn test_max_connections_config() {
        // Mirrors the default pool size used by `Database::connect`.
        let max_connections = 20;
        assert!(max_connections > 0);
        assert!(max_connections <= 100);
    }

    #[tokio::test]
    #[ignore]
    async fn test_migration_idempotency() {
        let db = connect_test_db().await;

        // A second run must neither fail nor hang waiting on a leaked advisory lock.
        let first = db.migrate().await;
        assert!(first.is_ok(), "first migration failed: {first:?}");

        let second = db.migrate().await;
        assert!(second.is_ok(), "second migration failed: {second:?}");
    }

    #[test]
    fn test_query_return_types() {
        fn check_total_plugins_type(_: i64) {}
        fn check_total_downloads_type(_: i64) {}
        fn check_total_users_type(_: i64) {}

        let _: fn(i64) = check_total_plugins_type;
        let _: fn(i64) = check_total_downloads_type;
        let _: fn(i64) = check_total_users_type;
    }

    #[test]
    fn test_database_error_types() {
        use anyhow::Result;

        fn returns_result() -> Result<()> {
            Ok(())
        }

        let result = returns_result();
        assert!(result.is_ok());
    }
}