Skip to main content

mockforge_registry_server/
config.rs

1//! Server configuration
2
3use anyhow::{Context, Result};
4use serde::Deserialize;
5
6/// Helper to get a required environment variable with a descriptive error
7fn required_env(name: &str) -> Result<String> {
8    std::env::var(name).with_context(|| {
9        format!(
10            "Required environment variable '{name}' is not set. \
11             Please set it before starting the server."
12        )
13    })
14}
15
16#[derive(Debug, Clone, Deserialize)]
17pub struct Config {
18    /// Server port
19    pub port: u16,
20
21    /// Database connection URL
22    pub database_url: String,
23
24    /// JWT secret for authentication
25    pub jwt_secret: String,
26
27    /// S3 configuration
28    pub s3_bucket: String,
29    pub s3_region: String,
30    pub s3_endpoint: Option<String>, // For MinIO/custom S3
31
32    /// Upload limits
33    pub max_plugin_size: usize, // in bytes (default 50MB)
34
35    /// Rate limiting
36    pub rate_limit_per_minute: u32,
37
38    /// Analytics database path (optional, defaults to "mockforge-analytics.db" in current directory)
39    pub analytics_db_path: Option<String>,
40
41    /// Graceful shutdown timeout in seconds
42    pub shutdown_timeout_secs: u64,
43
44    /// Redis URL for caching and temporary storage (optional)
45    pub redis_url: Option<String>,
46
47    /// Skip running database migrations on startup (default: false)
48    /// Set SKIP_MIGRATIONS=true when running migrations as a separate K8s Job
49    pub skip_migrations: bool,
50
51    /// Whether two-factor authentication is enabled (requires Redis)
52    pub two_factor_enabled: Option<bool>,
53
54    /// Base URL of the application (for OAuth callbacks and email links)
55    pub app_base_url: String,
56
57    /// Stripe secret key for billing
58    pub stripe_secret_key: Option<String>,
59
60    /// Stripe price ID for Pro plan
61    pub stripe_price_id_pro: Option<String>,
62
63    /// Stripe price ID for Team plan
64    pub stripe_price_id_team: Option<String>,
65
66    /// Stripe webhook secret for verifying webhook signatures
67    pub stripe_webhook_secret: Option<String>,
68
69    /// GitHub OAuth client ID
70    pub oauth_github_client_id: Option<String>,
71
72    /// GitHub OAuth client secret
73    pub oauth_github_client_secret: Option<String>,
74
75    /// Google OAuth client ID
76    pub oauth_google_client_id: Option<String>,
77
78    /// Google OAuth client secret
79    pub oauth_google_client_secret: Option<String>,
80}
81
82impl Config {
83    /// Load configuration from environment variables.
84    ///
85    /// Required environment variables:
86    /// - `DATABASE_URL`: Database connection URL
87    /// - `JWT_SECRET`: Secret key for JWT token signing
88    ///
89    /// Optional environment variables (with defaults):
90    /// - `PORT`: Server port (default: 8080)
91    /// - `S3_BUCKET`: S3 bucket name (default: "mockforge-plugins")
92    /// - `S3_REGION`: S3 region (default: "us-east-1")
93    /// - `S3_ENDPOINT`: Custom S3 endpoint for MinIO/compatible storage
94    /// - `MAX_PLUGIN_SIZE`: Maximum plugin size in bytes (default: 52428800 / 50MB)
95    /// - `RATE_LIMIT_PER_MINUTE`: Rate limit per minute (default: 60)
96    /// - `ANALYTICS_DB_PATH`: Path to analytics database
97    /// - `SHUTDOWN_TIMEOUT_SECS`: Graceful shutdown timeout in seconds (default: 30)
98    pub fn load() -> Result<Self> {
99        dotenvy::dotenv().ok();
100
101        // Collect all missing required variables first for better error reporting
102        let mut missing_vars = Vec::new();
103
104        let database_url = match required_env("DATABASE_URL") {
105            Ok(url) => Some(url),
106            Err(_) => {
107                missing_vars.push("DATABASE_URL");
108                None
109            }
110        };
111
112        let jwt_secret = match required_env("JWT_SECRET") {
113            Ok(secret) => Some(secret),
114            Err(_) => {
115                missing_vars.push("JWT_SECRET");
116                None
117            }
118        };
119
120        // Report all missing required variables at once
121        if !missing_vars.is_empty() {
122            anyhow::bail!(
123                "Missing required environment variables: {}. \
124                 Please ensure these are set before starting the server.",
125                missing_vars.join(", ")
126            );
127        }
128
129        let config = Self {
130            port: std::env::var("PORT")
131                .unwrap_or_else(|_| "8080".to_string())
132                .parse()
133                .context("PORT must be a valid port number (0-65535)")?,
134            database_url: database_url.unwrap(),
135            jwt_secret: jwt_secret.unwrap(),
136            // Fall back to the AWS-standard env var names that Fly's
137            // `flyctl storage create` and other tooling auto-set. Lets
138            // Tigris (or any S3-compatible backend) wire up without
139            // manually aliasing secrets.
140            s3_bucket: std::env::var("S3_BUCKET")
141                .or_else(|_| std::env::var("BUCKET_NAME"))
142                .unwrap_or_else(|_| "mockforge-plugins".to_string()),
143            s3_region: std::env::var("S3_REGION")
144                .or_else(|_| std::env::var("AWS_REGION"))
145                .unwrap_or_else(|_| "us-east-1".to_string()),
146            s3_endpoint: std::env::var("S3_ENDPOINT")
147                .ok()
148                .or_else(|| std::env::var("AWS_ENDPOINT_URL_S3").ok()),
149            max_plugin_size: std::env::var("MAX_PLUGIN_SIZE")
150                .unwrap_or_else(|_| "52428800".to_string()) // 50MB
151                .parse()
152                .context("MAX_PLUGIN_SIZE must be a valid number")?,
153            rate_limit_per_minute: std::env::var("RATE_LIMIT_PER_MINUTE")
154                .unwrap_or_else(|_| "60".to_string())
155                .parse()
156                .context("RATE_LIMIT_PER_MINUTE must be a valid number")?,
157            analytics_db_path: std::env::var("ANALYTICS_DB_PATH").ok(),
158            shutdown_timeout_secs: std::env::var("SHUTDOWN_TIMEOUT_SECS")
159                .unwrap_or_else(|_| "30".to_string())
160                .parse()
161                .context("SHUTDOWN_TIMEOUT_SECS must be a valid number")?,
162            skip_migrations: std::env::var("SKIP_MIGRATIONS")
163                .ok()
164                .map(|v| v.to_lowercase() == "true" || v == "1")
165                .unwrap_or(false),
166            redis_url: std::env::var("REDIS_URL").ok(),
167            two_factor_enabled: std::env::var("TWO_FACTOR_ENABLED")
168                .ok()
169                .map(|v| v.to_lowercase() == "true" || v == "1"),
170            app_base_url: std::env::var("APP_BASE_URL")
171                .unwrap_or_else(|_| "http://localhost:3000".to_string()),
172            stripe_secret_key: std::env::var("STRIPE_SECRET_KEY").ok(),
173            stripe_price_id_pro: std::env::var("STRIPE_PRICE_ID_PRO").ok(),
174            stripe_price_id_team: std::env::var("STRIPE_PRICE_ID_TEAM").ok(),
175            stripe_webhook_secret: std::env::var("STRIPE_WEBHOOK_SECRET").ok(),
176            oauth_github_client_id: std::env::var("OAUTH_GITHUB_CLIENT_ID").ok(),
177            oauth_github_client_secret: std::env::var("OAUTH_GITHUB_CLIENT_SECRET").ok(),
178            oauth_google_client_id: std::env::var("OAUTH_GOOGLE_CLIENT_ID").ok(),
179            oauth_google_client_secret: std::env::var("OAUTH_GOOGLE_CLIENT_SECRET").ok(),
180        };
181
182        Ok(config)
183    }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    // Mutex to serialize tests that modify environment variables
    static ENV_MUTEX: Mutex<()> = Mutex::new(());

    /// Every environment variable `Config::load` reads, including the AWS-style
    /// fallbacks for the S3 settings.
    const CONFIG_VARS: &[&str] = &[
        "PORT",
        "DATABASE_URL",
        "JWT_SECRET",
        "S3_BUCKET",
        "BUCKET_NAME",
        "S3_REGION",
        "AWS_REGION",
        "S3_ENDPOINT",
        "AWS_ENDPOINT_URL_S3",
        "MAX_PLUGIN_SIZE",
        "RATE_LIMIT_PER_MINUTE",
        "ANALYTICS_DB_PATH",
        "SHUTDOWN_TIMEOUT_SECS",
        "SKIP_MIGRATIONS",
        "REDIS_URL",
        "TWO_FACTOR_ENABLED",
        "APP_BASE_URL",
        "STRIPE_SECRET_KEY",
        "STRIPE_PRICE_ID_PRO",
        "STRIPE_PRICE_ID_TEAM",
        "STRIPE_WEBHOOK_SECRET",
        "OAUTH_GITHUB_CLIENT_ID",
        "OAUTH_GITHUB_CLIENT_SECRET",
        "OAUTH_GOOGLE_CLIENT_ID",
        "OAUTH_GOOGLE_CLIENT_SECRET",
    ];

    /// Isolates one test's view of the process environment.
    ///
    /// While alive it:
    /// - holds `ENV_MUTEX`, serializing env-mutating tests;
    /// - starts from an environment where every variable in `CONFIG_VARS` is unset.
    ///
    /// On drop it restores the previous values, even when the test panics, so a failing
    /// assertion cannot leak variables into later tests.
    ///
    /// NOTE: `Config::load` also calls `dotenvy::dotenv()`, so a `.env` file in the working
    /// directory can still fill in variables these tests expect to be unset.
    struct EnvGuard {
        saved: Vec<(&'static str, Option<OsString>)>,
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvGuard {
        /// Locks `ENV_MUTEX`, then snapshots and unsets every variable in `CONFIG_VARS`.
        fn new() -> Self {
            // A test that panicked while holding the lock poisons it. Its own guard has
            // already restored the environment, so the poison is safe to ignore and must
            // not turn into failures in unrelated tests.
            let lock = ENV_MUTEX.lock().unwrap_or_else(PoisonError::into_inner);
            let saved = CONFIG_VARS
                .iter()
                .map(|&name| (name, std::env::var_os(name)))
                .collect();
            for name in CONFIG_VARS {
                std::env::remove_var(name);
            }
            Self { saved, _lock: lock }
        }

        /// Like [`EnvGuard::new`], then sets the two required variables.
        fn with_required_vars() -> Self {
            let guard = Self::new();
            std::env::set_var("DATABASE_URL", "postgres://localhost/test");
            std::env::set_var("JWT_SECRET", "test-secret");
            guard
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            // Runs before `_lock` is released, so no other test observes the restore.
            for (name, value) in &self.saved {
                match value {
                    Some(value) => std::env::set_var(name, value),
                    None => std::env::remove_var(name),
                }
            }
        }
    }

    #[test]
    fn test_config_defaults() {
        let _env = EnvGuard::with_required_vars();

        let config = Config::load().unwrap();

        assert_eq!(config.port, 8080);
        assert_eq!(config.s3_bucket, "mockforge-plugins");
        assert_eq!(config.s3_region, "us-east-1");
        assert_eq!(config.max_plugin_size, 52428800); // 50MB
        assert_eq!(config.rate_limit_per_minute, 60);
        assert!(config.s3_endpoint.is_none());
        assert!(config.analytics_db_path.is_none());
        assert_eq!(config.shutdown_timeout_secs, 30); // Default shutdown timeout
        assert!(!config.skip_migrations);
        assert!(config.redis_url.is_none());
        assert_eq!(config.two_factor_enabled, None);
        assert_eq!(config.app_base_url, "http://localhost:3000");
        assert!(config.stripe_secret_key.is_none());
    }

    #[test]
    fn test_config_custom_values() {
        let _env = EnvGuard::new();
        std::env::set_var("PORT", "9090");
        std::env::set_var("DATABASE_URL", "postgres://custom/db");
        std::env::set_var("JWT_SECRET", "custom-secret");
        std::env::set_var("S3_BUCKET", "custom-bucket");
        std::env::set_var("S3_REGION", "eu-west-1");
        std::env::set_var("S3_ENDPOINT", "http://localhost:9000");
        std::env::set_var("MAX_PLUGIN_SIZE", "10485760"); // 10MB
        std::env::set_var("RATE_LIMIT_PER_MINUTE", "120");
        std::env::set_var("ANALYTICS_DB_PATH", "/custom/path/analytics.db");
        std::env::set_var("SHUTDOWN_TIMEOUT_SECS", "60");

        let config = Config::load().unwrap();

        assert_eq!(config.port, 9090);
        assert_eq!(config.database_url, "postgres://custom/db");
        assert_eq!(config.jwt_secret, "custom-secret");
        assert_eq!(config.s3_bucket, "custom-bucket");
        assert_eq!(config.s3_region, "eu-west-1");
        assert_eq!(config.s3_endpoint, Some("http://localhost:9000".to_string()));
        assert_eq!(config.max_plugin_size, 10485760);
        assert_eq!(config.rate_limit_per_minute, 120);
        assert_eq!(config.analytics_db_path, Some("/custom/path/analytics.db".to_string()));
        assert_eq!(config.shutdown_timeout_secs, 60);
    }

    #[test]
    fn test_config_s3_falls_back_to_aws_style_vars() {
        let _env = EnvGuard::with_required_vars();
        std::env::set_var("BUCKET_NAME", "tigris-bucket");
        std::env::set_var("AWS_REGION", "auto");
        std::env::set_var("AWS_ENDPOINT_URL_S3", "https://storage.example");

        let config = Config::load().unwrap();

        assert_eq!(config.s3_bucket, "tigris-bucket");
        assert_eq!(config.s3_region, "auto");
        assert_eq!(config.s3_endpoint, Some("https://storage.example".to_string()));
    }

    #[test]
    fn test_config_s3_specific_vars_take_precedence() {
        let _env = EnvGuard::with_required_vars();
        std::env::set_var("S3_BUCKET", "s3-bucket");
        std::env::set_var("BUCKET_NAME", "fallback-bucket");
        std::env::set_var("S3_REGION", "eu-west-1");
        std::env::set_var("AWS_REGION", "fallback-region");
        std::env::set_var("S3_ENDPOINT", "http://localhost:9000");
        std::env::set_var("AWS_ENDPOINT_URL_S3", "https://fallback.example");

        let config = Config::load().unwrap();

        assert_eq!(config.s3_bucket, "s3-bucket");
        assert_eq!(config.s3_region, "eu-west-1");
        assert_eq!(config.s3_endpoint, Some("http://localhost:9000".to_string()));
    }

    #[test]
    fn test_config_boolean_flags() {
        let _env = EnvGuard::with_required_vars();

        for (raw, expected) in
            [("true", true), ("TRUE", true), ("1", true), ("false", false), ("0", false)]
        {
            std::env::set_var("SKIP_MIGRATIONS", raw);
            std::env::set_var("TWO_FACTOR_ENABLED", raw);

            let config = Config::load().unwrap();

            assert_eq!(config.skip_migrations, expected, "SKIP_MIGRATIONS={raw}");
            assert_eq!(config.two_factor_enabled, Some(expected), "TWO_FACTOR_ENABLED={raw}");
        }
    }

    #[test]
    fn test_config_missing_required_database_url() {
        let _env = EnvGuard::new();
        std::env::set_var("JWT_SECRET", "test-secret");

        let result = Config::load();

        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        assert!(
            error_msg.contains("DATABASE_URL"),
            "Error should mention DATABASE_URL: {error_msg}"
        );
    }

    #[test]
    fn test_config_missing_required_jwt_secret() {
        let _env = EnvGuard::new();
        std::env::set_var("DATABASE_URL", "postgres://localhost/test");

        let result = Config::load();

        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        assert!(error_msg.contains("JWT_SECRET"), "Error should mention JWT_SECRET: {error_msg}");
    }

    #[test]
    fn test_config_missing_both_required_vars() {
        let _env = EnvGuard::new();

        let result = Config::load();

        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        // Should report both missing variables
        assert!(
            error_msg.contains("DATABASE_URL") && error_msg.contains("JWT_SECRET"),
            "Error should mention both missing variables: {error_msg}"
        );
    }

    #[test]
    fn test_config_invalid_port() {
        let _env = EnvGuard::with_required_vars();
        std::env::set_var("PORT", "invalid");

        assert!(Config::load().is_err());
    }

    #[test]
    fn test_config_invalid_max_plugin_size() {
        let _env = EnvGuard::with_required_vars();
        std::env::set_var("MAX_PLUGIN_SIZE", "not-a-number");

        assert!(Config::load().is_err());
    }

    #[test]
    fn test_config_invalid_rate_limit() {
        let _env = EnvGuard::with_required_vars();
        std::env::set_var("RATE_LIMIT_PER_MINUTE", "not-a-number");

        assert!(Config::load().is_err());
    }

    #[test]
    fn test_config_port_boundary_values() {
        let _env = EnvGuard::with_required_vars();

        // Test port 0
        std::env::set_var("PORT", "0");
        let config = Config::load().unwrap();
        assert_eq!(config.port, 0);

        // Test max port
        std::env::set_var("PORT", "65535");
        let config = Config::load().unwrap();
        assert_eq!(config.port, 65535);

        // One past the max must be rejected rather than wrapped
        std::env::set_var("PORT", "65536");
        assert!(Config::load().is_err());
    }

    #[test]
    fn test_config_clone() {
        let _env = EnvGuard::with_required_vars();

        let config = Config::load().unwrap();
        let cloned = config.clone();

        assert_eq!(config.database_url, cloned.database_url);
        assert_eq!(config.jwt_secret, cloned.jwt_secret);
        assert_eq!(config.port, cloned.port);
    }

    #[test]
    fn test_config_debug() {
        let _env = EnvGuard::with_required_vars();

        let config = Config::load().unwrap();
        let debug_str = format!("{:?}", config);

        // Should contain field names
        assert!(debug_str.contains("port"));
        assert!(debug_str.contains("database_url"));
    }
}