Skip to main content

mcp_postgres/
config.rs

1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3use std::fmt;
4use std::str::FromStr;
5use std::time::Duration;
6
7#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
8pub enum AccessMode {
9    #[serde(rename = "unrestricted")]
10    Unrestricted,
11    #[serde(rename = "restricted")]
12    Restricted,
13}
14
15impl fmt::Display for AccessMode {
16    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
17        match self {
18            AccessMode::Unrestricted => write!(f, "unrestricted"),
19            AccessMode::Restricted => write!(f, "restricted"),
20        }
21    }
22}
23
24impl FromStr for AccessMode {
25    type Err = String;
26    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
27        match s.to_lowercase().as_str() {
28            "unrestricted" => Ok(AccessMode::Unrestricted),
29            "restricted" => Ok(AccessMode::Restricted),
30            _ => Err(format!(
31                "Invalid access mode: {s}. Use 'unrestricted' or 'restricted'"
32            )),
33        }
34    }
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct Config {
39    pub database: DatabaseConfig,
40    pub server: ServerConfig,
41    pub pool: PoolConfig,
42    pub metrics: MetricsConfig,
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct DatabaseConfig {
47    pub url: String,
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct ServerConfig {
52    pub host: String,
53    pub port: u16,
54    pub request_timeout: Duration,
55    pub access_mode: AccessMode,
56    /// Shared secret required for TCP/HTTP transports. `None` means no auth
57    /// (only permitted on loopback binds).
58    #[serde(default, skip_serializing)]
59    pub auth_token: Option<String>,
60    /// Whether the import_from_url tool may make outbound HTTP fetches.
61    #[serde(default)]
62    pub allow_url_import: bool,
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct PoolConfig {
67    pub min_size: u32,
68    pub max_size: u32,
69    pub queue_timeout: Duration,
70}
71
72#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct MetricsConfig {
74    pub enabled: bool,
75    pub port: u16,
76}
77
78impl Config {
79    pub fn from_args(args: &super::Args) -> Result<Self> {
80        let database_url = args
81            .database_url
82            .clone()
83            .or_else(|| std::env::var("DATABASE_URL").ok())
84            .unwrap_or_else(|| "postgres://postgres:postgres@localhost:5432/postgres".to_string());
85
86        let min_size = args.min_connections.unwrap_or(5);
87        let max_size = args.max_connections.unwrap_or(20);
88
89        let auth_token = args
90            .auth_token
91            .clone()
92            .or_else(|| std::env::var("MCP_AUTH_TOKEN").ok())
93            .filter(|t| !t.is_empty());
94
95        Ok(Config {
96            database: DatabaseConfig { url: database_url },
97            server: ServerConfig {
98                host: args.host.clone(),
99                port: args.port,
100                request_timeout: Duration::from_secs(30),
101                access_mode: args.access_mode,
102                auth_token,
103                allow_url_import: args.allow_url_import,
104            },
105            pool: PoolConfig {
106                min_size,
107                max_size,
108                queue_timeout: Duration::from_secs(10),
109            },
110            metrics: MetricsConfig {
111                enabled: args.enable_metrics,
112                port: args.metrics_port,
113            },
114        })
115    }
116}
117
118impl Default for Config {
119    fn default() -> Self {
120        Self {
121            database: DatabaseConfig {
122                url: "postgres://postgres:postgres@localhost:5432/postgres".to_string(),
123            },
124            server: ServerConfig {
125                host: "127.0.0.1".to_string(),
126                port: 3000,
127                request_timeout: Duration::from_secs(30),
128                access_mode: AccessMode::Unrestricted,
129                auth_token: None,
130                allow_url_import: false,
131            },
132            pool: PoolConfig {
133                min_size: 5,
134                max_size: 20,
135                queue_timeout: Duration::from_secs(10),
136            },
137            metrics: MetricsConfig {
138                enabled: false,
139                port: 9090,
140            },
141        }
142    }
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    /// Server section of the built-in defaults.
    #[test]
    fn test_config_defaults() {
        let cfg = Config::default();
        assert_eq!(cfg.server.host, "127.0.0.1");
        assert_eq!(cfg.server.port, 3000);
        assert_eq!(cfg.server.request_timeout, Duration::from_secs(30));
        assert_eq!(cfg.server.access_mode, AccessMode::Unrestricted);
        assert_eq!(cfg.server.auth_token, None);
        assert!(!cfg.server.allow_url_import);
    }

    /// Database section of the built-in defaults.
    #[test]
    fn test_database_config_defaults() {
        let cfg = Config::default();
        assert_eq!(
            cfg.database.url,
            "postgres://postgres:postgres@localhost:5432/postgres"
        );
    }

    /// Pool section of the built-in defaults.
    #[test]
    fn test_pool_config_defaults() {
        let cfg = Config::default();
        assert_eq!(cfg.pool.min_size, 5);
        assert_eq!(cfg.pool.max_size, 20);
        assert_eq!(cfg.pool.queue_timeout, Duration::from_secs(10));
    }

    /// Metrics section of the built-in defaults.
    #[test]
    fn test_metrics_config_defaults() {
        let cfg = Config::default();
        assert!(!cfg.metrics.enabled);
        assert_eq!(cfg.metrics.port, 9090);
    }

    /// A default config survives a JSON round trip.
    #[test]
    fn test_config_serde() {
        let cfg = Config::default();
        let json = serde_json::to_string(&cfg).unwrap();
        let deserialized: Config = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.server.port, cfg.server.port);
        assert_eq!(deserialized.server.access_mode, cfg.server.access_mode);
        assert_eq!(deserialized.pool.min_size, cfg.pool.min_size);
        assert_eq!(deserialized.database.url, cfg.database.url);
    }

    /// The auth token is a secret: it must not appear in serialized output,
    /// and a document without it deserializes to `None`.
    #[test]
    fn test_config_serde_omits_auth_token() {
        let mut cfg = Config::default();
        cfg.server.auth_token = Some("s3cr3t-token".to_string());

        let json = serde_json::to_string(&cfg).unwrap();
        assert!(!json.contains("s3cr3t-token"));
        assert!(!json.contains("auth_token"));

        let deserialized: Config = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.server.auth_token, None);
    }

    /// Parsing is case-insensitive and rejects unknown modes with a message
    /// that echoes the input.
    #[test]
    fn test_access_mode_from_str() {
        assert_eq!(
            "unrestricted".parse::<AccessMode>(),
            Ok(AccessMode::Unrestricted)
        );
        assert_eq!("RESTRICTED".parse::<AccessMode>(), Ok(AccessMode::Restricted));

        let err = "admin".parse::<AccessMode>().unwrap_err();
        assert!(err.contains("admin"));
    }

    /// `Display` output parses back to the same variant.
    #[test]
    fn test_access_mode_display_round_trip() {
        for mode in [AccessMode::Unrestricted, AccessMode::Restricted] {
            assert_eq!(mode.to_string().parse::<AccessMode>(), Ok(mode));
        }
    }

    /// Serialized variant names are the lowercase keywords.
    #[test]
    fn test_access_mode_serde_names() {
        assert_eq!(
            serde_json::to_string(&AccessMode::Unrestricted).unwrap(),
            "\"unrestricted\""
        );
        assert_eq!(
            serde_json::to_string(&AccessMode::Restricted).unwrap(),
            "\"restricted\""
        );
    }

    /// The default pool bounds are non-empty and ordered.
    #[test]
    fn test_pool_config_values() {
        let cfg = Config::default();
        assert!(cfg.pool.min_size > 0);
        assert!(cfg.pool.max_size >= cfg.pool.min_size);
    }

    /// `Debug` output still exposes the non-secret server fields.
    #[test]
    fn test_server_config_debug() {
        let cfg = Config::default();
        let debug = format!("{:?}", cfg);
        assert!(debug.contains("127.0.0.1"));
        assert!(debug.contains("3000"));
    }
}