Skip to main content

mcp_postgres/
config.rs

1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3use std::fmt;
4use std::str::FromStr;
5use std::time::Duration;
6
7#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
8pub enum AccessMode {
9    #[serde(rename = "unrestricted")]
10    Unrestricted,
11    #[serde(rename = "restricted")]
12    Restricted,
13}
14
15impl fmt::Display for AccessMode {
16    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
17        match self {
18            AccessMode::Unrestricted => write!(f, "unrestricted"),
19            AccessMode::Restricted => write!(f, "restricted"),
20        }
21    }
22}
23
24impl FromStr for AccessMode {
25    type Err = String;
26    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
27        match s.to_lowercase().as_str() {
28            "unrestricted" => Ok(AccessMode::Unrestricted),
29            "restricted" => Ok(AccessMode::Restricted),
30            _ => Err(format!("Invalid access mode: {s}. Use 'unrestricted' or 'restricted'")),
31        }
32    }
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct Config {
37    pub database: DatabaseConfig,
38    pub server: ServerConfig,
39    pub pool: PoolConfig,
40    pub metrics: MetricsConfig,
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct DatabaseConfig {
45    pub url: String,
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct ServerConfig {
50    pub host: String,
51    pub port: u16,
52    pub request_timeout: Duration,
53    pub access_mode: AccessMode,
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct PoolConfig {
58    pub min_size: u32,
59    pub max_size: u32,
60    pub queue_timeout: Duration,
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct MetricsConfig {
65    pub enabled: bool,
66    pub port: u16,
67}
68
69impl Config {
70    pub fn from_args(args: &super::Args) -> Result<Self> {
71        let database_url = args.database_url.clone()
72            .or_else(|| std::env::var("DATABASE_URL").ok())
73            .unwrap_or_else(|| "postgres://postgres:postgres@localhost:5432/postgres".to_string());
74
75        let num_cpus = num_cpus::get() as u32;
76        let min_size = args.min_connections.unwrap_or(1);
77        let max_size = args.max_connections.unwrap_or(num_cpus * 8);
78
79        Ok(Config {
80            database: DatabaseConfig {
81                url: database_url,
82            },
83            server: ServerConfig {
84                host: args.host.clone(),
85                port: args.port,
86                request_timeout: Duration::from_secs(30),
87                access_mode: args.access_mode,
88            },
89            pool: PoolConfig {
90                min_size,
91                max_size,
92                queue_timeout: Duration::from_secs(10),
93            },
94            metrics: MetricsConfig {
95                enabled: args.enable_metrics,
96                port: args.metrics_port,
97            },
98        })
99    }
100}
101
102impl Default for Config {
103    fn default() -> Self {
104        let num_cpus = num_cpus::get() as u32;
105        let max_size = num_cpus * 8;
106        let min_size = 1;
107
108        Self {
109            database: DatabaseConfig {
110                url: "postgres://postgres:postgres@localhost:5432/postgres".to_string(),
111            },
112            server: ServerConfig {
113                host: "127.0.0.1".to_string(),
114                port: 3000,
115                request_timeout: Duration::from_secs(30),
116                access_mode: AccessMode::Unrestricted,
117            },
118            pool: PoolConfig {
119                min_size,
120                max_size,
121                queue_timeout: Duration::from_secs(10),
122            },
123            metrics: MetricsConfig {
124                enabled: false,
125                port: 9090,
126            },
127        }
128    }
129}
130
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): `Config::from_args` is not covered here. Building a
    // `super::super::Args` (presumably via clap's `Parser::try_parse_from`)
    // needs the flag names, which are not visible from this file. Add a test
    // for it alongside `Args`.

    #[test]
    fn test_config_defaults() {
        let cfg = Config::default();
        assert_eq!(cfg.server.host, "127.0.0.1");
        assert_eq!(cfg.server.port, 3000);
        assert_eq!(cfg.server.request_timeout, Duration::from_secs(30));
        assert_eq!(cfg.server.access_mode, AccessMode::Unrestricted);
    }

    #[test]
    fn test_database_config_defaults() {
        let cfg = Config::default();
        assert_eq!(cfg.database.url, "postgres://postgres:postgres@localhost:5432/postgres");
    }

    #[test]
    fn test_pool_config_defaults() {
        let cfg = Config::default();
        let num_cpus = num_cpus::get() as u32;
        assert_eq!(cfg.pool.min_size, 1);
        assert_eq!(cfg.pool.max_size, num_cpus * 8);
        assert_eq!(cfg.pool.queue_timeout, Duration::from_secs(10));
    }

    #[test]
    fn test_metrics_config_defaults() {
        let cfg = Config::default();
        assert!(!cfg.metrics.enabled);
        assert_eq!(cfg.metrics.port, 9090);
    }

    #[test]
    fn test_config_serde() {
        let cfg = Config::default();
        let json = serde_json::to_string(&cfg).unwrap();
        let deserialized: Config = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.server.port, cfg.server.port);
        assert_eq!(deserialized.server.request_timeout, cfg.server.request_timeout);
        assert_eq!(deserialized.server.access_mode, cfg.server.access_mode);
        assert_eq!(deserialized.pool.min_size, cfg.pool.min_size);
        assert_eq!(deserialized.pool.max_size, cfg.pool.max_size);
        assert_eq!(deserialized.database.url, cfg.database.url);
    }

    #[test]
    fn test_pool_config_values() {
        let cfg = Config::default();
        assert!(cfg.pool.min_size > 0);
        assert!(cfg.pool.max_size >= cfg.pool.min_size);
    }

    #[test]
    fn test_server_config_debug() {
        let cfg = Config::default();
        let debug = format!("{:?}", cfg);
        assert!(debug.contains("127.0.0.1"));
        assert!(debug.contains("3000"));
    }

    #[test]
    fn test_access_mode_from_str_ignores_case() {
        assert_eq!("restricted".parse::<AccessMode>(), Ok(AccessMode::Restricted));
        assert_eq!("Restricted".parse::<AccessMode>(), Ok(AccessMode::Restricted));
        assert_eq!("UNRESTRICTED".parse::<AccessMode>(), Ok(AccessMode::Unrestricted));
    }

    #[test]
    fn test_access_mode_from_str_rejects_unknown() {
        let err = "readonly".parse::<AccessMode>().unwrap_err();
        assert!(err.contains("readonly"));
        assert!("".parse::<AccessMode>().is_err());
    }

    #[test]
    fn test_access_mode_display_round_trips() {
        for mode in [AccessMode::Unrestricted, AccessMode::Restricted] {
            assert_eq!(mode.to_string().parse::<AccessMode>(), Ok(mode));
        }
        assert_eq!(AccessMode::Restricted.to_string(), "restricted");
    }

    #[test]
    fn test_access_mode_serde_uses_lowercase_names() {
        assert_eq!(serde_json::to_string(&AccessMode::Restricted).unwrap(), "\"restricted\"");
        assert_eq!(serde_json::to_string(&AccessMode::Unrestricted).unwrap(), "\"unrestricted\"");
        let parsed: AccessMode = serde_json::from_str("\"unrestricted\"").unwrap();
        assert_eq!(parsed, AccessMode::Unrestricted);
    }
}