// mcp_postgres/config.rs

use anyhow::{ensure, Result};
use serde::{Deserialize, Serialize};
use std::fmt;
use std::str::FromStr;
use std::time::Duration;
6
7#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
8pub enum AccessMode {
9    #[serde(rename = "unrestricted")]
10    Unrestricted,
11    #[serde(rename = "restricted")]
12    Restricted,
13}
14
15impl fmt::Display for AccessMode {
16    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
17        match self {
18            AccessMode::Unrestricted => write!(f, "unrestricted"),
19            AccessMode::Restricted => write!(f, "restricted"),
20        }
21    }
22}
23
24impl FromStr for AccessMode {
25    type Err = String;
26    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
27        match s.to_lowercase().as_str() {
28            "unrestricted" => Ok(AccessMode::Unrestricted),
29            "restricted" => Ok(AccessMode::Restricted),
30            _ => Err(format!(
31                "Invalid access mode: {s}. Use 'unrestricted' or 'restricted'"
32            )),
33        }
34    }
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct Config {
39    pub database: DatabaseConfig,
40    pub server: ServerConfig,
41    pub pool: PoolConfig,
42    pub metrics: MetricsConfig,
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct DatabaseConfig {
47    pub url: String,
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct ServerConfig {
52    pub host: String,
53    pub port: u16,
54    pub request_timeout: Duration,
55    pub access_mode: AccessMode,
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct PoolConfig {
60    pub min_size: u32,
61    pub max_size: u32,
62    pub queue_timeout: Duration,
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct MetricsConfig {
67    pub enabled: bool,
68    pub port: u16,
69}
70
71impl Config {
72    pub fn from_args(args: &super::Args) -> Result<Self> {
73        let database_url = args
74            .database_url
75            .clone()
76            .or_else(|| std::env::var("DATABASE_URL").ok())
77            .unwrap_or_else(|| "postgres://postgres:postgres@localhost:5432/postgres".to_string());
78
79        let min_size = args.min_connections.unwrap_or(5);
80        let max_size = args.max_connections.unwrap_or(20);
81
82        Ok(Config {
83            database: DatabaseConfig { url: database_url },
84            server: ServerConfig {
85                host: args.host.clone(),
86                port: args.port,
87                request_timeout: Duration::from_secs(30),
88                access_mode: args.access_mode,
89            },
90            pool: PoolConfig {
91                min_size,
92                max_size,
93                queue_timeout: Duration::from_secs(10),
94            },
95            metrics: MetricsConfig {
96                enabled: args.enable_metrics,
97                port: args.metrics_port,
98            },
99        })
100    }
101}
102
103impl Default for Config {
104    fn default() -> Self {
105        Self {
106            database: DatabaseConfig {
107                url: "postgres://postgres:postgres@localhost:5432/postgres".to_string(),
108            },
109            server: ServerConfig {
110                host: "127.0.0.1".to_string(),
111                port: 3000,
112                request_timeout: Duration::from_secs(30),
113                access_mode: AccessMode::Unrestricted,
114            },
115            pool: PoolConfig {
116                min_size: 5,
117                max_size: 20,
118                queue_timeout: Duration::from_secs(10),
119            },
120            metrics: MetricsConfig {
121                enabled: false,
122                port: 9090,
123            },
124        }
125    }
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_defaults() {
        let cfg = Config::default();
        assert_eq!(cfg.server.host, "127.0.0.1");
        assert_eq!(cfg.server.port, 3000);
        assert_eq!(cfg.server.request_timeout, Duration::from_secs(30));
        assert_eq!(cfg.server.access_mode, AccessMode::Unrestricted);
    }

    #[test]
    fn test_database_config_defaults() {
        let cfg = Config::default();
        assert_eq!(
            cfg.database.url,
            "postgres://postgres:postgres@localhost:5432/postgres"
        );
    }

    #[test]
    fn test_pool_config_defaults() {
        let cfg = Config::default();
        assert_eq!(cfg.pool.min_size, 5);
        assert_eq!(cfg.pool.max_size, 20);
        assert_eq!(cfg.pool.queue_timeout, Duration::from_secs(10));
    }

    #[test]
    fn test_metrics_config_defaults() {
        let cfg = Config::default();
        assert!(!cfg.metrics.enabled);
        assert_eq!(cfg.metrics.port, 9090);
    }

    #[test]
    fn test_config_serde() {
        let cfg = Config::default();
        let json = serde_json::to_string(&cfg).unwrap();
        let deserialized: Config = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.server.port, cfg.server.port);
        assert_eq!(deserialized.pool.min_size, cfg.pool.min_size);
        assert_eq!(deserialized.database.url, cfg.database.url);
    }

    #[test]
    fn test_pool_config_values() {
        let cfg = Config::default();
        assert!(cfg.pool.min_size > 0);
        assert!(cfg.pool.max_size >= cfg.pool.min_size);
    }

    #[test]
    fn test_server_config_debug() {
        let cfg = Config::default();
        let debug = format!("{cfg:?}");
        assert!(debug.contains("127.0.0.1"));
        assert!(debug.contains("3000"));
    }

    // Replaces the former `test_config_from_args_cpu_aware`, which only
    // asserted constants it had just assigned and never exercised this module.

    #[test]
    fn test_access_mode_from_str_ignores_case() {
        assert_eq!("restricted".parse::<AccessMode>(), Ok(AccessMode::Restricted));
        assert_eq!("UNRESTRICTED".parse::<AccessMode>(), Ok(AccessMode::Unrestricted));
        assert_eq!("Restricted".parse::<AccessMode>(), Ok(AccessMode::Restricted));
    }

    #[test]
    fn test_access_mode_from_str_rejects_unknown() {
        let err = "read-only".parse::<AccessMode>().unwrap_err();
        assert!(err.contains("read-only"));
    }

    #[test]
    fn test_access_mode_display_round_trips() {
        for mode in [AccessMode::Unrestricted, AccessMode::Restricted] {
            assert_eq!(mode.to_string().parse::<AccessMode>(), Ok(mode));
        }
    }

    #[test]
    fn test_access_mode_serde_names_match_display() {
        for mode in [AccessMode::Unrestricted, AccessMode::Restricted] {
            let json = serde_json::to_string(&mode).unwrap();
            assert_eq!(json, format!("\"{mode}\""));
        }
    }
}