postrust_core/
config.rs

1//! Configuration for Postrust.
2//!
3//! Mirrors PostgREST's configuration options.
4
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
8/// Main application configuration.
9#[derive(Clone, Debug, Serialize, Deserialize)]
10pub struct AppConfig {
11    // ========================================================================
12    // Database Settings
13    // ========================================================================
14    /// PostgreSQL connection URI
15    #[serde(default = "default_db_uri")]
16    pub db_uri: String,
17
18    /// Schemas to expose via the API
19    #[serde(default = "default_db_schemas")]
20    pub db_schemas: Vec<String>,
21
22    /// Role for unauthenticated requests
23    pub db_anon_role: Option<String>,
24
25    /// Connection pool size
26    #[serde(default = "default_pool_size")]
27    pub db_pool_size: u32,
28
29    /// Pool acquisition timeout in seconds
30    #[serde(default = "default_pool_timeout")]
31    pub db_pool_timeout: u64,
32
33    /// Use prepared statements
34    #[serde(default = "default_true")]
35    pub db_prepared_statements: bool,
36
37    /// Extra search path schemas
38    #[serde(default)]
39    pub db_extra_search_path: Vec<String>,
40
41    /// LISTEN/NOTIFY channel for schema reload
42    #[serde(default = "default_db_channel")]
43    pub db_channel: String,
44
45    /// Enable NOTIFY-based schema cache reload
46    #[serde(default)]
47    pub db_channel_enabled: bool,
48
49    /// Pre-request function to call
50    pub db_pre_request: Option<String>,
51
52    /// Maximum rows allowed in a response
53    pub db_max_rows: Option<i64>,
54
55    /// Enable aggregate functions
56    #[serde(default = "default_true")]
57    pub db_aggregates_enabled: bool,
58
59    // ========================================================================
60    // Server Settings
61    // ========================================================================
62    /// Server host to bind
63    #[serde(default = "default_host")]
64    pub server_host: String,
65
66    /// Server port
67    #[serde(default = "default_port")]
68    pub server_port: u16,
69
70    /// Unix socket path (alternative to host/port)
71    pub server_unix_socket: Option<String>,
72
73    /// Admin server port (for health checks)
74    pub admin_server_port: Option<u16>,
75
76    // ========================================================================
77    // JWT Settings
78    // ========================================================================
79    /// JWT secret key (or JWKS URL)
80    pub jwt_secret: Option<String>,
81
82    /// JWT secret as base64
83    #[serde(default)]
84    pub jwt_secret_is_base64: bool,
85
86    /// JWT audience claim to validate
87    pub jwt_aud: Option<String>,
88
89    /// JWT claim that contains the role
90    #[serde(default = "default_jwt_role_claim")]
91    pub jwt_role_claim_key: String,
92
93    /// Cache JWT validations
94    #[serde(default = "default_true")]
95    pub jwt_cache_enabled: bool,
96
97    /// JWT cache max entries
98    #[serde(default = "default_jwt_cache_max")]
99    pub jwt_cache_max_lifetime: u64,
100
101    // ========================================================================
102    // OpenAPI Settings
103    // ========================================================================
104    /// OpenAPI server URL
105    pub openapi_server_proxy_uri: Option<String>,
106
107    /// OpenAPI mode: disabled, follow-privileges, ignore-privileges, security-definer
108    #[serde(default = "default_openapi_mode")]
109    pub openapi_mode: OpenApiMode,
110
111    // ========================================================================
112    // Logging Settings
113    // ========================================================================
114    /// Log level: crit, error, warn, info, debug
115    #[serde(default = "default_log_level")]
116    pub log_level: LogLevel,
117
118    // ========================================================================
119    // Role Settings
120    // ========================================================================
121    /// Per-role settings (isolation level, timeout)
122    #[serde(default)]
123    pub role_settings: HashMap<String, RoleSettings>,
124
125    /// App-level settings to expose via GUC
126    #[serde(default)]
127    pub app_settings: HashMap<String, String>,
128}
129
130impl Default for AppConfig {
131    fn default() -> Self {
132        Self {
133            db_uri: default_db_uri(),
134            db_schemas: default_db_schemas(),
135            db_anon_role: None,
136            db_pool_size: default_pool_size(),
137            db_pool_timeout: default_pool_timeout(),
138            db_prepared_statements: true,
139            db_extra_search_path: vec![],
140            db_channel: default_db_channel(),
141            db_channel_enabled: false,
142            db_pre_request: None,
143            db_max_rows: None,
144            db_aggregates_enabled: true,
145            server_host: default_host(),
146            server_port: default_port(),
147            server_unix_socket: None,
148            admin_server_port: None,
149            jwt_secret: None,
150            jwt_secret_is_base64: false,
151            jwt_aud: None,
152            jwt_role_claim_key: default_jwt_role_claim(),
153            jwt_cache_enabled: true,
154            jwt_cache_max_lifetime: default_jwt_cache_max(),
155            openapi_server_proxy_uri: None,
156            openapi_mode: OpenApiMode::FollowPrivileges,
157            log_level: LogLevel::Error,
158            role_settings: HashMap::new(),
159            app_settings: HashMap::new(),
160        }
161    }
162}
163
164impl AppConfig {
165    /// Load configuration from environment variables.
166    pub fn from_env() -> Self {
167        let mut config = Self::default();
168
169        if let Ok(uri) = std::env::var("PGRST_DB_URI") {
170            config.db_uri = uri;
171        }
172        if let Ok(uri) = std::env::var("DATABASE_URL") {
173            config.db_uri = uri;
174        }
175        if let Ok(schemas) = std::env::var("PGRST_DB_SCHEMAS") {
176            config.db_schemas = schemas.split(',').map(|s| s.trim().to_string()).collect();
177        }
178        if let Ok(role) = std::env::var("PGRST_DB_ANON_ROLE") {
179            config.db_anon_role = Some(role);
180        }
181        if let Ok(size) = std::env::var("PGRST_DB_POOL") {
182            if let Ok(n) = size.parse() {
183                config.db_pool_size = n;
184            }
185        }
186        if let Ok(secret) = std::env::var("PGRST_JWT_SECRET") {
187            config.jwt_secret = Some(secret);
188        }
189        if let Ok(aud) = std::env::var("PGRST_JWT_AUD") {
190            config.jwt_aud = Some(aud);
191        }
192        if let Ok(host) = std::env::var("PGRST_SERVER_HOST") {
193            config.server_host = host;
194        }
195        if let Ok(port) = std::env::var("PGRST_SERVER_PORT") {
196            if let Ok(p) = port.parse() {
197                config.server_port = p;
198            }
199        }
200        if let Ok(port) = std::env::var("PORT") {
201            if let Ok(p) = port.parse() {
202                config.server_port = p;
203            }
204        }
205
206        config
207    }
208
209    /// Get the default schema (first in the list).
210    pub fn default_schema(&self) -> &str {
211        self.db_schemas.first().map(|s| s.as_str()).unwrap_or("public")
212    }
213}
214
215/// Per-role settings.
216#[derive(Clone, Debug, Serialize, Deserialize)]
217pub struct RoleSettings {
218    /// Isolation level for this role
219    pub isolation_level: Option<IsolationLevel>,
220    /// Statement timeout in milliseconds
221    pub statement_timeout: Option<u64>,
222}
223
224/// Transaction isolation levels.
225#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
226pub enum IsolationLevel {
227    ReadCommitted,
228    RepeatableRead,
229    Serializable,
230}
231
232impl IsolationLevel {
233    pub fn to_sql(&self) -> &'static str {
234        match self {
235            Self::ReadCommitted => "READ COMMITTED",
236            Self::RepeatableRead => "REPEATABLE READ",
237            Self::Serializable => "SERIALIZABLE",
238        }
239    }
240}
241
242/// OpenAPI generation mode.
243#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
244pub enum OpenApiMode {
245    Disabled,
246    FollowPrivileges,
247    IgnorePrivileges,
248    SecurityDefiner,
249}
250
251/// Log levels.
252#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
253pub enum LogLevel {
254    Crit,
255    Error,
256    Warn,
257    Info,
258    Debug,
259}
260
261impl LogLevel {
262    pub fn to_tracing(&self) -> tracing::Level {
263        match self {
264            Self::Crit | Self::Error => tracing::Level::ERROR,
265            Self::Warn => tracing::Level::WARN,
266            Self::Info => tracing::Level::INFO,
267            Self::Debug => tracing::Level::DEBUG,
268        }
269    }
270}
271
272// Default value functions
/// Default PostgreSQL connection URI: the `postgres` database on `localhost`.
fn default_db_uri() -> String {
    String::from("postgresql://localhost/postgres")
}
276
/// Default list of exposed schemas: just `public`.
fn default_db_schemas() -> Vec<String> {
    let mut schemas = Vec::with_capacity(1);
    schemas.push(String::from("public"));
    schemas
}
280
/// Default connection pool size.
fn default_pool_size() -> u32 {
    const POOL_SIZE: u32 = 10;
    POOL_SIZE
}
284
/// Default pool timeout value; the `db_pool_timeout` field it feeds is
/// documented as seconds.
fn default_pool_timeout() -> u64 {
    const POOL_TIMEOUT: u64 = 10;
    POOL_TIMEOUT
}
288
/// Default LISTEN/NOTIFY channel name.
fn default_db_channel() -> String {
    String::from("pgrst")
}
292
/// Default bind host: the IPv4 loopback address.
fn default_host() -> String {
    String::from("127.0.0.1")
}
296
/// Default server port.
fn default_port() -> u16 {
    const PORT: u16 = 3000;
    PORT
}
300
/// Default name of the JWT claim that carries the role.
fn default_jwt_role_claim() -> String {
    String::from("role")
}
304
/// Default for `jwt_cache_max_lifetime`.
///
/// NOTE(review): 3600 reads like one hour in seconds — confirm the unit with
/// the code that consumes this value.
fn default_jwt_cache_max() -> u64 {
    const JWT_CACHE_MAX: u64 = 3600;
    JWT_CACHE_MAX
}
308
309fn default_openapi_mode() -> OpenApiMode {
310    OpenApiMode::FollowPrivileges
311}
312
313fn default_log_level() -> LogLevel {
314    LogLevel::Error
315}
316
/// Serde default helper for boolean fields that are on unless set otherwise.
fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
320
#[cfg(test)]
mod tests {
    use super::*;

    /// The built-in defaults: port 3000, pool of 10, prepared statements on.
    #[test]
    fn test_default_config() {
        let AppConfig {
            server_port,
            db_pool_size,
            db_prepared_statements,
            ..
        } = AppConfig::default();

        assert_eq!(server_port, 3000);
        assert_eq!(db_pool_size, 10);
        assert!(db_prepared_statements);
    }

    /// `default_schema` reports the first configured schema.
    #[test]
    fn test_default_schema() {
        let defaults = AppConfig::default();
        assert_eq!(defaults.default_schema(), "public");

        let customised = AppConfig {
            db_schemas: vec![String::from("api"), String::from("public")],
            ..defaults
        };
        assert_eq!(customised.default_schema(), "api");
    }

    /// Isolation levels render as their SQL keywords.
    #[test]
    fn test_isolation_level_sql() {
        let expectations = [
            (IsolationLevel::ReadCommitted, "READ COMMITTED"),
            (IsolationLevel::Serializable, "SERIALIZABLE"),
        ];

        for (level, sql) in &expectations {
            assert_eq!(level.to_sql(), *sql);
        }
    }
}