Skip to main content

forge_core/config/
database.rs

1use std::time::Duration;
2
3use serde::{Deserialize, Serialize};
4
5use crate::error::{ForgeError, Result};
6
7use super::default_true;
8use super::types::DurationStr;
9
10/// Database configuration. One pool, no per-workload isolation: workload
11/// separation belongs at the worker level, not the connection level. The
12/// single-pool contention model and sizing formula are documented at the
13/// runtime side in `forge_runtime::pg::pool` module docs.
14#[derive(Debug, Clone, Serialize, Deserialize)]
15#[serde(deny_unknown_fields)]
16#[non_exhaustive]
17pub struct DatabaseConfig {
18    /// PostgreSQL connection URL.
19    #[serde(default)]
20    pub url: String,
21
22    /// Connection pool size. Should be sized as
23    /// `worker.max_concurrent + reactor cap + expected gateway concurrency
24    /// + ~6 for persistent listeners, leader holds, and headroom`. See
25    /// `forge_runtime::pg::pool` module docs.
26    #[serde(default = "default_pool_size")]
27    pub pool_size: u32,
28
29    /// Pool checkout timeout duration (e.g. "30s", "1m").
30    #[serde(default = "default_pool_timeout")]
31    pub pool_timeout: DurationStr,
32
33    /// Statement timeout duration (e.g. "30s", "5m").
34    #[serde(default = "default_statement_timeout")]
35    pub statement_timeout: DurationStr,
36
37    /// Read replica URLs for scaling reads.
38    #[serde(default)]
39    pub replica_urls: Vec<String>,
40
41    /// Whether to route read queries to replicas.
42    #[serde(default)]
43    pub read_from_replica: bool,
44
45    /// Replica pool size. When unset, defaults to `pool_size / 2`.
46    #[serde(default)]
47    pub replica_pool_size: Option<u32>,
48
49    /// Minimum connections to keep alive in the pool (pre-warming).
50    #[serde(default)]
51    pub min_pool_size: u32,
52
53    /// Run a health check query before handing out connections.
54    /// Disabling this halves round-trips for read queries.
55    #[serde(default = "default_true")]
56    pub test_before_acquire: bool,
57}
58
59impl Default for DatabaseConfig {
60    fn default() -> Self {
61        Self {
62            url: String::new(),
63            pool_size: default_pool_size(),
64            pool_timeout: default_pool_timeout(),
65            statement_timeout: default_statement_timeout(),
66            replica_urls: Vec::new(),
67            read_from_replica: false,
68            replica_pool_size: None,
69            min_pool_size: 0,
70            test_before_acquire: true,
71        }
72    }
73}
74
75impl DatabaseConfig {
76    /// Create a config with a database URL.
77    pub fn new(url: impl Into<String>) -> Self {
78        Self {
79            url: url.into(),
80            ..Default::default()
81        }
82    }
83
84    /// Get the database URL.
85    pub fn url(&self) -> &str {
86        &self.url
87    }
88
89    /// Validate the database configuration.
90    pub fn validate(&self) -> Result<()> {
91        if self.url.is_empty() {
92            return Err(ForgeError::config(
93                "database.url is required. \
94                 Set database.url to a PostgreSQL connection string \
95                 (e.g., \"postgres://user:pass@localhost/mydb\").",
96            ));
97        }
98        Ok(())
99    }
100}
101
/// Serde default for `DatabaseConfig::pool_size`: enough connections for the
/// framework's own consumers at default settings, plus light gateway headroom.
///
/// Users running at scale should set `pool_size` explicitly based on their
/// expected concurrent gateway load; see the sizing formula in
/// `forge_runtime::pg::pool`.
fn default_pool_size() -> u32 {
    // Worker slots: 8 default + 4 workflows + 2 cron.
    const WORKER_SLOTS: u32 = 8 + 4 + 2;
    // Reactor max-concurrent re-executions.
    const REACTOR_REEXECUTIONS: u32 = 64;
    // Persistent listeners, leader holds, health check, migration.
    const INTERNAL_CONNECTIONS: u32 = 6;
    // The three above total 84 connections before any gateway traffic;
    // this headroom for light gateway traffic lands the default at 100.
    const GATEWAY_HEADROOM: u32 = 16;

    WORKER_SLOTS + REACTOR_REEXECUTIONS + INTERNAL_CONNECTIONS + GATEWAY_HEADROOM
}
114
115fn default_pool_timeout() -> DurationStr {
116    DurationStr::new(Duration::from_secs(30))
117}
118
119fn default_statement_timeout() -> DurationStr {
120    DurationStr::new(Duration::from_secs(30))
121}
122
#[cfg(test)]
// In tests a panicking unwrap or index is simply a test failure, which is what we want.
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// Primary URL shared by the fixtures below.
    const LOCAL_URL: &str = "postgres://localhost/test";

    #[test]
    fn test_default_database_config() {
        let defaults = DatabaseConfig::default();
        assert!(defaults.url.is_empty());
        assert_eq!(defaults.pool_timeout.as_secs(), 30);
        assert_eq!(defaults.pool_size, 100);
    }

    #[test]
    fn test_new_config() {
        assert_eq!(DatabaseConfig::new(LOCAL_URL).url(), LOCAL_URL);
    }

    #[test]
    fn test_parse_config() {
        let source = r#"
            url = "postgres://localhost/test"
            pool_size = 100
            replica_urls = ["postgres://replica1/test", "postgres://replica2/test"]
            read_from_replica = true
        "#;

        let parsed: DatabaseConfig = toml::from_str(source).unwrap();
        assert_eq!(parsed.url(), LOCAL_URL);
        assert_eq!(parsed.pool_size, 100);
        assert_eq!(parsed.replica_urls.len(), 2);
        assert!(parsed.read_from_replica);
    }

    #[test]
    fn test_validate_with_url() {
        assert!(DatabaseConfig::new(LOCAL_URL).validate().is_ok());
    }

    #[test]
    fn test_validate_empty_url() {
        let err_msg = DatabaseConfig::default()
            .validate()
            .unwrap_err()
            .to_string();
        assert!(err_msg.contains("database.url is required"));
    }

    #[test]
    fn test_rejects_legacy_pools_blocks() {
        let source = r#"
            url = "postgres://localhost/test"
            [pools.jobs]
            size = 10
        "#;

        let msg = toml::from_str::<DatabaseConfig>(source)
            .unwrap_err()
            .to_string();
        assert!(
            msg.contains("unknown field"),
            "expected unknown-field error, got: {msg}"
        );
    }
}