//! Database connection and models

3use anyhow::Result;
4use sqlx::{postgres::PgPoolOptions, PgPool};
5
6#[derive(Clone, Debug)]
7pub struct Database {
8    pool: PgPool,
9}
10
11impl Database {
12    pub async fn connect(database_url: &str) -> Result<Self> {
13        // DATABASE_MAX_CONNECTIONS: Maximum number of database connections in the pool
14        // Default: 20
15        let max_connections: u32 = std::env::var("DATABASE_MAX_CONNECTIONS")
16            .ok()
17            .and_then(|s| s.parse().ok())
18            .unwrap_or(20);
19
20        let pool = PgPoolOptions::new()
21            .max_connections(max_connections)
22            .connect(database_url)
23            .await?;
24
25        Ok(Self { pool })
26    }
27
28    pub async fn migrate(&self) -> Result<()> {
29        // Acquire a PostgreSQL advisory lock to prevent concurrent migration runs
30        // across multiple replicas. Lock ID 8675309 is an arbitrary but stable identifier.
31        const MIGRATION_LOCK_ID: i64 = 8675309;
32
33        tracing::info!("Acquiring advisory lock for database migrations...");
34        sqlx::query("SELECT pg_advisory_lock($1)")
35            .bind(MIGRATION_LOCK_ID)
36            .execute(&self.pool)
37            .await?;
38        tracing::info!("Advisory lock acquired, running migrations...");
39
40        // Run migrations. Any error aborts boot — `sqlx::migrate!().run()` bails
41        // on the *first* inconsistency without applying subsequent pending
42        // migrations, so historically-permissive "warn and continue" handling
43        // silently skipped everything past the gap and left the DB multiple
44        // versions behind without surfacing a single error to operators. We'd
45        // rather refuse to boot and have someone repair the `_sqlx_migrations`
46        // table than crash-loop a worker that depends on a table that never
47        // got created.
48        let result =
49            sqlx::migrate!("./migrations")
50                .run(&self.pool)
51                .await
52                .map_err(|e| -> anyhow::Error {
53                    if e.to_string().contains("previously applied but is missing") {
54                        tracing::error!(
55                            "sqlx refused to migrate: the DB's `_sqlx_migrations` table has an \
56                         applied row whose matching file is missing from the repo. This \
57                         blocks ALL subsequent migrations from running. To fix, either \
58                         restore the missing file or remove the orphaned tracking row \
59                         manually (psql: `DELETE FROM _sqlx_migrations WHERE version = …`). \
60                         Full error: {:?}",
61                            e
62                        );
63                    }
64                    e.into()
65                });
66
67        // Release the advisory lock regardless of migration outcome
68        if let Err(unlock_err) = sqlx::query("SELECT pg_advisory_unlock($1)")
69            .bind(MIGRATION_LOCK_ID)
70            .execute(&self.pool)
71            .await
72        {
73            tracing::error!("Failed to release migration advisory lock: {}", unlock_err);
74        } else {
75            tracing::info!("Migration advisory lock released");
76        }
77
78        result
79    }
80
81    pub fn pool(&self) -> &PgPool {
82        &self.pool
83    }
84
85    /// Get total number of plugins
86    pub async fn get_total_plugins(&self) -> Result<i64> {
87        let count: (i64,) =
88            sqlx::query_as("SELECT COUNT(*) FROM plugins").fetch_one(&self.pool).await?;
89        Ok(count.0)
90    }
91
92    /// Get total downloads across all plugins
93    pub async fn get_total_downloads(&self) -> Result<i64> {
94        // downloads_total is NUMERIC in database, so we need to cast it
95        let total: (Option<i64>,) =
96            sqlx::query_as("SELECT COALESCE(SUM(downloads_total)::BIGINT, 0) FROM plugins")
97                .fetch_one(&self.pool)
98                .await?;
99        Ok(total.0.unwrap_or(0))
100    }
101
102    /// Get total number of users
103    pub async fn get_total_users(&self) -> Result<i64> {
104        let count: (i64,) =
105            sqlx::query_as("SELECT COUNT(*) FROM users").fetch_one(&self.pool).await?;
106        Ok(count.0)
107    }
108
109    // ==================== Token Revocation Functions ====================
110
111    /// Store a refresh token JTI for tracking (called on token creation)
112    pub async fn store_refresh_token_jti(
113        &self,
114        jti: &str,
115        user_id: uuid::Uuid,
116        expires_at: chrono::DateTime<chrono::Utc>,
117    ) -> Result<()> {
118        sqlx::query(
119            r#"
120            INSERT INTO token_revocations (jti, user_id, expires_at)
121            VALUES ($1, $2, $3)
122            ON CONFLICT (jti) DO NOTHING
123            "#,
124        )
125        .bind(jti)
126        .bind(user_id)
127        .bind(expires_at)
128        .execute(&self.pool)
129        .await?;
130
131        Ok(())
132    }
133
134    /// Check if a refresh token JTI has been revoked
135    pub async fn is_token_revoked(&self, jti: &str) -> Result<bool> {
136        let result: Option<(Option<chrono::DateTime<chrono::Utc>>,)> = sqlx::query_as(
137            r#"
138            SELECT revoked_at FROM token_revocations WHERE jti = $1
139            "#,
140        )
141        .bind(jti)
142        .fetch_optional(&self.pool)
143        .await?;
144
145        match result {
146            // Token found and has a revoked_at timestamp = revoked
147            Some((Some(_),)) => Ok(true),
148            // Token found but no revoked_at timestamp = not revoked (active)
149            Some((None,)) => Ok(false),
150            // Token not found = treat as revoked (unknown tokens should be rejected)
151            None => Ok(true),
152        }
153    }
154
155    /// Revoke a refresh token JTI (called on logout or token refresh)
156    pub async fn revoke_token(&self, jti: &str, reason: &str) -> Result<()> {
157        sqlx::query(
158            r#"
159            UPDATE token_revocations
160            SET revoked_at = NOW(), revocation_reason = $2
161            WHERE jti = $1 AND revoked_at IS NULL
162            "#,
163        )
164        .bind(jti)
165        .bind(reason)
166        .execute(&self.pool)
167        .await?;
168
169        Ok(())
170    }
171
172    /// Revoke all refresh tokens for a user (called on password change, security events)
173    pub async fn revoke_all_user_tokens(&self, user_id: uuid::Uuid, reason: &str) -> Result<u64> {
174        let result = sqlx::query(
175            r#"
176            UPDATE token_revocations
177            SET revoked_at = NOW(), revocation_reason = $2
178            WHERE user_id = $1 AND revoked_at IS NULL
179            "#,
180        )
181        .bind(user_id)
182        .bind(reason)
183        .execute(&self.pool)
184        .await?;
185
186        Ok(result.rows_affected())
187    }
188
189    /// Clean up expired token revocation records (for maintenance)
190    pub async fn cleanup_expired_tokens(&self) -> Result<u64> {
191        let result = sqlx::query(
192            r#"
193            DELETE FROM token_revocations
194            WHERE expires_at < NOW() - INTERVAL '1 day'
195            "#,
196        )
197        .execute(&self.pool)
198        .await?;
199
200        Ok(result.rows_affected())
201    }
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    /// Connection string for the `#[ignore]`d integration tests: `TEST_DATABASE_URL`,
    /// falling back to a local default.
    fn test_database_url() -> String {
        std::env::var("TEST_DATABASE_URL")
            .unwrap_or_else(|_| "postgresql://test:test@localhost/test_db".to_string())
    }

    /// Connect to the integration-test database.
    ///
    /// Panics instead of silently skipping. The callers are `#[ignore]`d and only
    /// run when a database is deliberately provided, so an unreachable database
    /// is a real failure rather than a reason to pass vacuously.
    async fn connect_test_db() -> Database {
        Database::connect(&test_database_url())
            .await
            .expect("failed to connect to TEST_DATABASE_URL")
    }

    /// Connect to the integration-test database and apply all migrations.
    async fn migrated_test_db() -> Database {
        let db = connect_test_db().await;
        db.migrate().await.expect("migrations should apply cleanly");
        db
    }

    #[test]
    fn test_database_clone() {
        // Verify Database implements Clone
        fn requires_clone<T: Clone>() {}
        requires_clone::<Database>();
    }

    #[tokio::test]
    async fn test_database_connect() {
        // This test would require a real Postgres database
        // We can test that the function exists and has the right signature
        let database_url = "postgresql://test:test@localhost/test_db";

        // Attempt to connect (will fail without a real database, which is expected)
        let result = Database::connect(database_url).await;

        // We expect this to fail since we don't have a database running
        assert!(
            result.is_err(),
            "expected connection to fail without a running database, but got: {result:?}"
        );
    }

    #[test]
    fn test_database_pool_type() {
        // Verify that Database has the expected structure
        // This ensures the API surface is correct
        fn check_pool_method(_db: &Database) -> &PgPool {
            _db.pool()
        }

        // Verify the function has the expected signature (compile-time check)
        let _: fn(&Database) -> &PgPool = check_pool_method;
    }

    // Mock test to verify query structures
    #[test]
    fn test_total_plugins_query_structure() {
        let query = "SELECT COUNT(*) FROM plugins";

        // Verify basic query structure
        assert!(query.contains("SELECT"));
        assert!(query.contains("COUNT(*)"));
        assert!(query.contains("FROM plugins"));
    }

    #[test]
    fn test_total_downloads_query_structure() {
        let query = "SELECT COALESCE(SUM(downloads_total)::BIGINT, 0) FROM plugins";

        // Verify query structure
        assert!(query.contains("SELECT"));
        assert!(query.contains("COALESCE"));
        assert!(query.contains("SUM(downloads_total)"));
        assert!(query.contains("FROM plugins"));
        assert!(query.contains("::BIGINT"));
    }

    #[test]
    fn test_total_users_query_structure() {
        let query = "SELECT COUNT(*) FROM users";

        // Verify basic query structure
        assert!(query.contains("SELECT"));
        assert!(query.contains("COUNT(*)"));
        assert!(query.contains("FROM users"));
    }

    #[test]
    fn test_migration_error_handling() {
        // `migrate()` gives special repair guidance for sqlx's "applied but
        // missing" error. Pin the message sqlx actually renders for that variant,
        // rather than asserting on a local string literal.
        let msg = sqlx::migrate::MigrateError::VersionMissing(42).to_string();
        assert!(
            msg.contains("previously applied but is missing"),
            "unexpected VersionMissing message: {msg}"
        );
    }

    // Integration tests: these need a real Postgres (`TEST_DATABASE_URL`) and are
    // `#[ignore]`d, so run them explicitly with `cargo test -- --ignored`.

    #[tokio::test]
    #[ignore] // Requires database
    async fn test_database_migration() {
        let db = connect_test_db().await;
        let result = db.migrate().await;
        assert!(result.is_ok(), "migration failed: {result:?}");
    }

    #[tokio::test]
    #[ignore] // Requires database
    async fn test_get_total_plugins() {
        let db = migrated_test_db().await;
        let count = db.get_total_plugins().await.expect("get_total_plugins failed");
        assert!(count >= 0);
    }

    #[tokio::test]
    #[ignore] // Requires database
    async fn test_get_total_downloads() {
        let db = migrated_test_db().await;
        let count = db.get_total_downloads().await.expect("get_total_downloads failed");
        assert!(count >= 0);
    }

    #[tokio::test]
    #[ignore] // Requires database
    async fn test_get_total_users() {
        let db = migrated_test_db().await;
        let count = db.get_total_users().await.expect("get_total_users failed");
        assert!(count >= 0);
    }

    #[tokio::test]
    #[ignore] // Requires database
    async fn test_pool_reuse() {
        let db = connect_test_db().await;

        // `pool()` must hand back the same pool every time, not a fresh one.
        assert!(std::ptr::eq(db.pool(), db.pool()));
    }

    #[test]
    fn test_database_connection_string_validation() {
        // Test various database URL formats
        let valid_urls = vec![
            "postgresql://user:pass@localhost/db",
            "postgresql://user:pass@localhost:5432/db",
            "postgresql://localhost/db",
            "postgres://user:pass@host:5432/database?sslmode=require",
        ];

        for url in valid_urls {
            assert!(url.starts_with("postgres"));
            assert!(url.contains("://"));
        }
    }

    #[test]
    fn test_max_connections_config() {
        // Verify the default max_connections value is reasonable
        let max_connections = 20; // Default value from DATABASE_MAX_CONNECTIONS env var

        assert!(max_connections > 0);
        assert!(max_connections <= 100); // Reasonable upper bound
    }

    #[tokio::test]
    #[ignore] // Requires database
    async fn test_migration_idempotency() {
        let db = connect_test_db().await;

        let first = db.migrate().await;
        assert!(first.is_ok(), "first migration run failed: {first:?}");

        // With nothing pending, a second run must be a no-op rather than an error.
        let second = db.migrate().await;
        assert!(second.is_ok(), "second migration run failed: {second:?}");
    }

    #[test]
    fn test_query_return_types() {
        // Verify that query return types are correct
        // This is a compile-time check that the types match expectations

        fn check_total_plugins_type(_: i64) {}
        fn check_total_downloads_type(_: i64) {}
        fn check_total_users_type(_: i64) {}

        // Verify the functions accept i64 (compile-time type check)
        let _: fn(i64) = check_total_plugins_type;
        let _: fn(i64) = check_total_downloads_type;
        let _: fn(i64) = check_total_users_type;
    }

    #[test]
    fn test_database_error_types() {
        // Verify error types are appropriate
        use anyhow::Result;

        fn returns_result() -> Result<()> {
            Ok(())
        }

        let result = returns_result();
        assert!(result.is_ok());
    }
}